From 450d8d490aa2deecd35080c9697705045355ad1b Mon Sep 17 00:00:00 2001 From: evilchili Date: Sun, 28 Sep 2025 14:08:50 -0700 Subject: [PATCH] Add Pointer field --- src/grung/types.py | 51 +++++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/src/grung/types.py b/src/grung/types.py index 3303c2f..7a9b826 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -95,32 +95,51 @@ class Record(Dict[(str, Field)]): return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}" +@dataclass +class Pointer(Field): + """ + Store a string reference to a record. + """ + + name: str = "" + value_type: type = Record + + def serialize(self, value: value_type, db: TinyDB) -> str: + if value: + if not value.doc_id: + raise PointerReferenceError(value) + return f"{value._metadata.table}::{value.uid}" + return "" + + def deserialize(self, value: str, db: TinyDB, recurse=True) -> value_type: + pt, puid = value.split("::") + if puid: + return db.table(pt).search(where("uid") == puid, recurse=recurse)[0] + return value + + @dataclass class Collection(Field): """ - A collection of fields that store pointers instead of dicts. + A collection of pointers. """ value_type: type = Record default: List[value_type] = field(default_factory=lambda: []) - def serialize(self, value: value_type, db: TinyDB) -> List[str]: - vals = self.default - if value: - for rec in value: - if not rec.doc_id: - raise PointerReferenceError(rec) - vals.append(f"{rec._metadata.table}::{rec.uid}") - return vals + def _pointer(self, rec): + return Pointer(value_type=type(rec)) - def deserialize(self, rec: List[value_type], db: TinyDB, recurse=True) -> value_type: + def serialize(self, values: List[Record], db: TinyDB) -> List[str]: + if values: + return [self._pointer(val).serialize(val, db=db) for val in values] + + def deserialize(self, values: List[str], db: TinyDB, recurse=True) -> List[value_type]: """ Recursively deserialize the objects in this collection """ - vals = [] if not recurse: - return rec - for member in rec: - pt, puid = member.split("::") - vals.append(db.table(pt).search(where("uid") == puid)[0]) - return vals + return values + if values: + return [self._pointer(val).deserialize(val, db=db, recurse=recurse) for val in values] + return []