From a9261321d26366b115a178f6638243d478f82505 Mon Sep 17 00:00:00 2001 From: evilchili Date: Mon, 29 Sep 2025 23:10:33 -0700 Subject: [PATCH] fix pointers again --- src/grung/db.py | 8 ++-- src/grung/examples.py | 17 ++++++-- src/grung/types.py | 96 ++++++++++++++++++++++++++++++++----------- test/test_db.py | 27 ++++++++---- 4 files changed, 110 insertions(+), 38 deletions(-) diff --git a/src/grung/db.py b/src/grung/db.py index 183c6b9..18f9381 100644 --- a/src/grung/db.py +++ b/src/grung/db.py @@ -22,15 +22,17 @@ class RecordTable(table.Table): super().__init__(db.storage, name, **kwargs) def insert(self, document): - document.before_insert() - doc = document.serialize(self._db) + document.before_insert(self._db) + doc = document.serialize() self._check_constraints(doc) if doc.doc_id: last_insert_id = super().upsert(doc)[0] else: last_insert_id = super().insert(dict(doc)) - return self.get(doc_id=last_insert_id) + doc.doc_id = last_insert_id + doc.after_insert(self._db) + return doc.deserialize(self._db) def get(self, doc_id: int, recurse: bool = True): document = super().get(doc_id=doc_id) diff --git a/src/grung/examples.py b/src/grung/examples.py index 5887e86..c6fa920 100644 --- a/src/grung/examples.py +++ b/src/grung/examples.py @@ -1,13 +1,24 @@ -from grung.types import Collection, Field, Integer, Record +from grung.types import BackReference, Collection, Field, Integer, Record class User(Record): @classmethod def fields(cls): - return [*super().fields(), Field("name"), Integer("number", default=0), Field("email", unique=True)] + return [ + *super().fields(), + Field("name"), + Integer("number", default=0), + Field("email", unique=True), + BackReference("groups", Group), + ] class Group(Record): @classmethod def fields(cls): - return [*super().fields(), Field("name", unique=True), Collection("members", User), Collection("groups", cls)] + return [ + *super().fields(), + Field("name", unique=True), + Collection("members", User), + Collection("groups", Group), + ] diff --git a/src/grung/types.py b/src/grung/types.py index 0c8a714..f5e3e3f 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -9,7 +9,7 @@ from tinydb import TinyDB, where from grung.exceptions import PointerReferenceError -Metadata = namedtuple("Metadata", ["table", "fields"]) +Metadata = namedtuple("Metadata", ["table", "fields", "backrefs"]) @dataclass @@ -23,20 +23,27 @@ class Field: default: str = None unique: bool = False - def serialize(self, rec: value_type, db: TinyDB) -> str: - return str(rec) + def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: + pass - def deserialize(self, rec: value_type, db: TinyDB, recurse: bool = True) -> value_type: - return rec + def after_insert(self, db: TinyDB, record: Record) -> None: + pass + + def serialize(self, value: value_type) -> str: + if value is not None: + return str(value) + + def deserialize(self, value: value_type, db: TinyDB, recurse: bool = True) -> value_type: + return value @dataclass class Integer(Field): - value_type: type = int + value_type = int default: int = 0 - def deserialize(self, rec: str, db: TinyDB, recurse: bool = True) -> value_type: - return int(rec) + def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: + return int(value) class Record(Dict[(str, Field)]): @@ -47,7 +54,13 @@ class Record(Dict[(str, Field)]): def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params): self.doc_id = doc_id fields = self.__class__.fields() - self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in fields}) + self._metadata = Metadata( + table=self.__class__.__name__, + fields={f.name: f for f in fields}, + backrefs=lambda value_type: ( + field for field in fields if type(field) == BackReference and field.value_type == value_type + ), + ) super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params)) @classmethod @@ -57,13 +70,13 @@ class Record(Dict[(str, Field)]): Field("uid", default=nanoid.generate(size=8), unique=True) ] - def serialize(self, db): + def serialize(self): """ Serialie every field on the record """ rec = {} for name, _field in self._metadata.fields.items(): - rec[name] = _field.serialize(self[name], db) + rec[name] = _field.serialize(self[name]) return self.__class__(rec, doc_id=self.doc_id) def deserialize(self, db, recurse: bool = True): @@ -75,8 +88,13 @@ class Record(Dict[(str, Field)]): rec[name] = _field.deserialize(self[name], db, recurse=recurse) return self.__class__(rec, doc_id=self.doc_id) - def before_insert(self): - pass + def before_insert(self, db: TinyDB) -> None: + for name, _field in self._metadata.fields.items(): + _field.before_insert(self[name], db, self) + + def after_insert(self, db: TinyDB) -> None: + for name, _field in self._metadata.fields.items(): + _field.after_insert(db, self) def __setattr__(self, key, value): if key in self: @@ -92,7 +110,11 @@ class Record(Dict[(str, Field)]): return hash(str(dict(self))) def __repr__(self): - return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}" + return ( + f"{self.__class__.__name__}[{self.doc_id}](" + + ", ".join([f"{key}={val}" for (key, val) in self.items()]) + + ")" + ) @dataclass @@ -104,21 +126,32 @@ class Pointer(Field): name: str = "" value_type: type = Record - def serialize(self, value: value_type, db: TinyDB) -> str: + def serialize(self, value: value_type) -> str: if value: if not value.doc_id: raise PointerReferenceError(value) return f"{value._metadata.table}::{value.uid}" - return "" + return None - def deserialize(self, value: str, db: TinyDB, recurse=True) -> value_type: - if value: + def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: + return Pointer.dereference(value, db, recurse) + + @classmethod + def dereference(cls, value: str, db: TinyDB, recurse: bool = True): + if not value: + return + elif type(value) == str: pt, puid = value.split("::") if puid: return db.table(pt).search(where("uid") == puid, recurse=recurse)[0] return value +@dataclass +class BackReference(Pointer): + pass + + @dataclass class Collection(Field): """ @@ -131,16 +164,29 @@ class Collection(Field): def _pointer(self, rec): return Pointer(value_type=type(rec)) - def serialize(self, values: List[Record], db: TinyDB) -> List[str]: - if values: - return [self._pointer(val).serialize(val, db=db) for val in values] + def serialize(self, values: List[value_type]) -> List[str]: + return [self._pointer(val).serialize(val) for val in values] - def deserialize(self, values: List[str], db: TinyDB, recurse=True) -> List[value_type]: + def deserialize(self, values: List[str], db: TinyDB, recurse: bool = True) -> List[value_type]: """ Recursively deserialize the objects in this collection """ + recs = [] if not recurse: return values - if values: - return [self._pointer(val).deserialize(val, db=db, recurse=recurse) for val in values] - return [] + for val in values: + recs.append(self._pointer(val).deserialize(val, db=db, recurse=recurse)) + return recs + + def after_insert(self, db: TinyDB, record: Record) -> None: + """ + Populate any backreferences in the members of this collection with the parent record's uid. + """ + if not record[self.name]: + return + + for member in record[self.name]: + reference = Pointer.dereference(member, db=db) + for backref in reference._metadata.backrefs(type(record)): + reference[backref.name] = record + db.table(reference._metadata.table).upsert(reference) diff --git a/test/test_db.py b/test/test_db.py index 35121fe..ac3dc31 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -1,3 +1,5 @@ +from pprint import pprint as print + import pytest from tinydb import Query from tinydb.storages import MemoryStorage @@ -44,16 +46,18 @@ def test_crud(db): def test_pointers(db): user = examples.User(name="john", email="john@foo") - players = examples.Group(name="players", members=[user], groups=[]) + players = examples.Group(name="players", members=[user]) with pytest.raises(PointerReferenceError): players = db.save(players) user = db.save(user) + players = db.save(examples.Group(name="players", members=[user])) - players.members[0] = user - players = db.save(players) - assert players.members[0] == user + user = db.table('User').get(doc_id=user.doc_id) + assert user.groups.uid == players.uid + + assert players.members[0].groups.uid == players.uid def test_subgroups(db): @@ -66,7 +70,11 @@ def test_subgroups(db): starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw])) assert tos in starfleet.groups assert snw in starfleet.groups - assert kirk in set([user for group in starfleet.groups for user in group.members]) + + unique_users = set([user for group in starfleet.groups for user in group.members]) + + kirk = db.table('User').get(doc_id=kirk.doc_id) + assert kirk in unique_users def test_unique(db): @@ -80,18 +88,23 @@ def test_unique(db): def test_search(db): + # create crew members kirk = db.save(examples.User(name="Captain James T. Kirk", email="riskybiznez@starfleet")) bones = db.save(examples.User(name="Doctor McCoy", email="dammitjim@starfleet")) ricky = db.save(examples.User(name="Ensign Ricky Redshirt", email="invincible@starfleet")) + # create the crew record crew = db.save(examples.Group(name="Crew", members=[kirk, bones, ricky])) User = Query() captains = db.User.search(User.name.matches("Captain")) assert len(captains) == 1 - Group = Query() - crew = db.Group.search(Group.name == "Crew")[0] + # update the crew members so they have the backreference to crew + kirk = db.table('User').get(doc_id=kirk.doc_id) + bones = db.table('User').get(doc_id=bones.doc_id) + ricky = db.table('User').get(doc_id=ricky.doc_id) + assert kirk in crew.members assert bones in crew.members assert ricky in crew.members