diff --git a/src/grung/db.py b/src/grung/db.py index 18f9381..00c7629 100644 --- a/src/grung/db.py +++ b/src/grung/db.py @@ -34,12 +34,12 @@ class RecordTable(table.Table): doc.after_insert(self._db) return doc.deserialize(self._db) - def get(self, doc_id: int, recurse: bool = True): + def get(self, doc_id: int, recurse: bool = False): document = super().get(doc_id=doc_id) if document: return document.deserialize(self._db, recurse=recurse) - def search(self, *args, recurse: bool = True, **kwargs) -> List[Record]: + def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]: results = super().search(*args, **kwargs) return [r.deserialize(self._db, recurse=recurse) for r in results] diff --git a/src/grung/types.py b/src/grung/types.py index b860f52..d929524 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -33,7 +33,7 @@ class Field: if value is not None: return str(value) - def deserialize(self, value: value_type, db: TinyDB, recurse: bool = True) -> value_type: + def deserialize(self, value: value_type, db: TinyDB, recurse: bool = False) -> value_type: return value @@ -42,7 +42,7 @@ class Integer(Field): value_type = int default: int = 0 - def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: + def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type: return int(value) @@ -96,6 +96,10 @@ class Record(Dict[(str, Field)]): for name, _field in self._metadata.fields.items(): _field.after_insert(db, self) + @property + def reference(self): + return Pointer.reference(self) + def __setattr__(self, key, value): if key in self: self[key] = value @@ -147,7 +151,10 @@ class Pointer(Field): elif type(value) == str: pt, puid = value.split("::") if puid: - return db.table(pt).search(where("uid") == puid, recurse=recurse)[0] + try: + return db.table(pt).search(where("uid") == puid, recurse=recurse)[0] + except IndexError: + raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!") return value @@ -165,13 +172,10 @@ class Collection(Field): value_type: type = Record default: List[value_type] = field(default_factory=lambda: []) - def _pointer(self, rec): - return Pointer(value_type=type(rec)) - def serialize(self, values: List[value_type]) -> List[str]: - return [self._pointer(val).serialize(val) for val in values] + return [Pointer.reference(val) for val in values] - def deserialize(self, values: List[str], db: TinyDB, recurse: bool = True) -> List[value_type]: + def deserialize(self, values: List[str], db: TinyDB, recurse: bool = False) -> List[value_type]: """ Recursively deserialize the objects in this collection """ @@ -179,7 +183,7 @@ class Collection(Field): if not recurse: return values for val in values: - recs.append(self._pointer(val).deserialize(val, db=db, recurse=recurse)) + recs.append(Pointer.dereference(val, db=db, recurse=False)) return recs def after_insert(self, db: TinyDB, record: Record) -> None: @@ -190,7 +194,9 @@ class Collection(Field): return for member in record[self.name]: - target = Pointer.dereference(member, db=db) + target = Pointer.dereference(member, db=db, recurse=False) for backref in target._metadata.backrefs(type(record)): target[backref.name] = Pointer.reference(record) - db.table(target._metadata.table).update({backref.name: record}, where("uid") == target.uid) + db.table(target._metadata.table).update( + {backref.name: target[backref.name]}, where("uid") == target.uid + ) diff --git a/test/test_db.py b/test/test_db.py index 2ec5397..2675092 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -80,7 +80,7 @@ def test_subgroups(db): unique_users = set([user for group in trek.groups for user in group.members]) kirk = db.table("User").get(doc_id=kirk.doc_id) - assert kirk in unique_users + assert kirk.reference in unique_users def test_unique(db): @@ -117,4 +117,4 @@ def test_search(db): Group = Query() crew = db.Group.search(Group.name == "Crew", recurse=False) - assert f"User::{kirk.uid}" in crew[0].members + assert kirk.reference in crew[0].members