diff --git a/src/grung/examples.py b/src/grung/examples.py index c6fa920..377e271 100644 --- a/src/grung/examples.py +++ b/src/grung/examples.py @@ -21,4 +21,5 @@ class Group(Record): Field("name", unique=True), Collection("members", User), Collection("groups", Group), + BackReference("parent", Group), ] diff --git a/src/grung/types.py b/src/grung/types.py index f5e3e3f..b860f52 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -127,15 +127,19 @@ class Pointer(Field): value_type: type = Record def serialize(self, value: value_type) -> str: + return Pointer.reference(value) + + def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: + return Pointer.dereference(value, db, recurse) + + @classmethod + def reference(cls, value: Record): if value: if not value.doc_id: raise PointerReferenceError(value) return f"{value._metadata.table}::{value.uid}" return None - def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: - return Pointer.dereference(value, db, recurse) - @classmethod def dereference(cls, value: str, db: TinyDB, recurse: bool = True): if not value: @@ -186,7 +190,7 @@ class Collection(Field): return for member in record[self.name]: - reference = Pointer.dereference(member, db=db) - for backref in reference._metadata.backrefs(type(record)): - reference[backref.name] = record - db.table(reference._metadata.table).upsert(reference) + target = Pointer.dereference(member, db=db) + for backref in target._metadata.backrefs(type(record)): + target[backref.name] = Pointer.reference(record) + db.table(target._metadata.table).update({backref.name: record}, where("uid") == target.uid) diff --git a/test/test_db.py b/test/test_db.py index 06e4780..b8d7533 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -68,9 +68,13 @@ def test_subgroups(db): snw = db.save(examples.Group(name="Strange New Worlds", members=[pike])) starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw])) + tos = db.table("Group").get(doc_id=tos.doc_id) + snw = db.table("Group").get(doc_id=snw.doc_id) assert tos in starfleet.groups assert snw in starfleet.groups + assert tos.parent.uid == starfleet.uid + unique_users = set([user for group in starfleet.groups for user in group.members]) kirk = db.table("User").get(doc_id=kirk.doc_id)