fix pointers again

This commit is contained in:
evilchili 2025-09-29 23:10:33 -07:00
parent 5bddca974d
commit a9261321d2
4 changed files with 110 additions and 38 deletions

View File

@ -22,15 +22,17 @@ class RecordTable(table.Table):
super().__init__(db.storage, name, **kwargs) super().__init__(db.storage, name, **kwargs)
def insert(self, document): def insert(self, document):
document.before_insert() document.before_insert(self._db)
doc = document.serialize(self._db) doc = document.serialize()
self._check_constraints(doc) self._check_constraints(doc)
if doc.doc_id: if doc.doc_id:
last_insert_id = super().upsert(doc)[0] last_insert_id = super().upsert(doc)[0]
else: else:
last_insert_id = super().insert(dict(doc)) last_insert_id = super().insert(dict(doc))
return self.get(doc_id=last_insert_id) doc.doc_id = last_insert_id
doc.after_insert(self._db)
return doc.deserialize(self._db)
def get(self, doc_id: int, recurse: bool = True): def get(self, doc_id: int, recurse: bool = True):
document = super().get(doc_id=doc_id) document = super().get(doc_id=doc_id)

View File

@ -1,13 +1,24 @@
from grung.types import Collection, Field, Integer, Record from grung.types import BackReference, Collection, Field, Integer, Record
class User(Record): class User(Record):
@classmethod @classmethod
def fields(cls): def fields(cls):
return [*super().fields(), Field("name"), Integer("number", default=0), Field("email", unique=True)] return [
*super().fields(),
Field("name"),
Integer("number", default=0),
Field("email", unique=True),
BackReference("groups", Group),
]
class Group(Record): class Group(Record):
@classmethod @classmethod
def fields(cls): def fields(cls):
return [*super().fields(), Field("name", unique=True), Collection("members", User), Collection("groups", cls)] return [
*super().fields(),
Field("name", unique=True),
Collection("members", User),
Collection("groups", Group),
]

View File

@ -9,7 +9,7 @@ from tinydb import TinyDB, where
from grung.exceptions import PointerReferenceError from grung.exceptions import PointerReferenceError
Metadata = namedtuple("Metadata", ["table", "fields"]) Metadata = namedtuple("Metadata", ["table", "fields", "backrefs"])
@dataclass @dataclass
@ -23,20 +23,27 @@ class Field:
default: str = None default: str = None
unique: bool = False unique: bool = False
def serialize(self, rec: value_type, db: TinyDB) -> str: def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
return str(rec) pass
def deserialize(self, rec: value_type, db: TinyDB, recurse: bool = True) -> value_type: def after_insert(self, db: TinyDB, record: Record) -> None:
return rec pass
def serialize(self, value: value_type) -> str:
if value is not None:
return str(value)
def deserialize(self, value: value_type, db: TinyDB, recurse: bool = True) -> value_type:
return value
@dataclass @dataclass
class Integer(Field): class Integer(Field):
value_type: type = int value_type = int
default: int = 0 default: int = 0
def deserialize(self, rec: str, db: TinyDB, recurse: bool = True) -> value_type: def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
return int(rec) return int(value)
class Record(Dict[(str, Field)]): class Record(Dict[(str, Field)]):
@ -47,7 +54,13 @@ class Record(Dict[(str, Field)]):
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params): def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
self.doc_id = doc_id self.doc_id = doc_id
fields = self.__class__.fields() fields = self.__class__.fields()
self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in fields}) self._metadata = Metadata(
table=self.__class__.__name__,
fields={f.name: f for f in fields},
backrefs=lambda value_type: (
field for field in fields if type(field) == BackReference and field.value_type == value_type
),
)
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params)) super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
@classmethod @classmethod
@ -57,13 +70,13 @@ class Record(Dict[(str, Field)]):
Field("uid", default=nanoid.generate(size=8), unique=True) Field("uid", default=nanoid.generate(size=8), unique=True)
] ]
def serialize(self, db): def serialize(self):
""" """
Serialie every field on the record Serialie every field on the record
""" """
rec = {} rec = {}
for name, _field in self._metadata.fields.items(): for name, _field in self._metadata.fields.items():
rec[name] = _field.serialize(self[name], db) rec[name] = _field.serialize(self[name])
return self.__class__(rec, doc_id=self.doc_id) return self.__class__(rec, doc_id=self.doc_id)
def deserialize(self, db, recurse: bool = True): def deserialize(self, db, recurse: bool = True):
@ -75,8 +88,13 @@ class Record(Dict[(str, Field)]):
rec[name] = _field.deserialize(self[name], db, recurse=recurse) rec[name] = _field.deserialize(self[name], db, recurse=recurse)
return self.__class__(rec, doc_id=self.doc_id) return self.__class__(rec, doc_id=self.doc_id)
def before_insert(self): def before_insert(self, db: TinyDB) -> None:
pass for name, _field in self._metadata.fields.items():
_field.before_insert(self[name], db, self)
def after_insert(self, db: TinyDB) -> None:
for name, _field in self._metadata.fields.items():
_field.after_insert(db, self)
def __setattr__(self, key, value): def __setattr__(self, key, value):
if key in self: if key in self:
@ -92,7 +110,11 @@ class Record(Dict[(str, Field)]):
return hash(str(dict(self))) return hash(str(dict(self)))
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}" return (
f"{self.__class__.__name__}[{self.doc_id}]("
+ ", ".join([f"{key}={val}" for (key, val) in self.items()])
+ ")"
)
@dataclass @dataclass
@ -104,21 +126,32 @@ class Pointer(Field):
name: str = "" name: str = ""
value_type: type = Record value_type: type = Record
def serialize(self, value: value_type, db: TinyDB) -> str: def serialize(self, value: value_type) -> str:
if value: if value:
if not value.doc_id: if not value.doc_id:
raise PointerReferenceError(value) raise PointerReferenceError(value)
return f"{value._metadata.table}::{value.uid}" return f"{value._metadata.table}::{value.uid}"
return "" return None
def deserialize(self, value: str, db: TinyDB, recurse=True) -> value_type: def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
if value: return Pointer.dereference(value, db, recurse)
@classmethod
def dereference(cls, value: str, db: TinyDB, recurse: bool = True):
if not value:
return
elif type(value) == str:
pt, puid = value.split("::") pt, puid = value.split("::")
if puid: if puid:
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0] return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
return value return value
@dataclass
class BackReference(Pointer):
pass
@dataclass @dataclass
class Collection(Field): class Collection(Field):
""" """
@ -131,16 +164,29 @@ class Collection(Field):
def _pointer(self, rec): def _pointer(self, rec):
return Pointer(value_type=type(rec)) return Pointer(value_type=type(rec))
def serialize(self, values: List[Record], db: TinyDB) -> List[str]: def serialize(self, values: List[value_type]) -> List[str]:
if values: return [self._pointer(val).serialize(val) for val in values]
return [self._pointer(val).serialize(val, db=db) for val in values]
def deserialize(self, values: List[str], db: TinyDB, recurse=True) -> List[value_type]: def deserialize(self, values: List[str], db: TinyDB, recurse: bool = True) -> List[value_type]:
""" """
Recursively deserialize the objects in this collection Recursively deserialize the objects in this collection
""" """
recs = []
if not recurse: if not recurse:
return values return values
if values: for val in values:
return [self._pointer(val).deserialize(val, db=db, recurse=recurse) for val in values] recs.append(self._pointer(val).deserialize(val, db=db, recurse=recurse))
return [] return recs
def after_insert(self, db: TinyDB, record: Record) -> None:
"""
Populate any backreferences in the members of this collection with the parent record's uid.
"""
if not record[self.name]:
return
for member in record[self.name]:
reference = Pointer.dereference(member, db=db)
for backref in reference._metadata.backrefs(type(record)):
reference[backref.name] = record
db.table(reference._metadata.table).upsert(reference)

View File

@ -1,3 +1,5 @@
from pprint import pprint as print
import pytest import pytest
from tinydb import Query from tinydb import Query
from tinydb.storages import MemoryStorage from tinydb.storages import MemoryStorage
@ -44,16 +46,18 @@ def test_crud(db):
def test_pointers(db): def test_pointers(db):
user = examples.User(name="john", email="john@foo") user = examples.User(name="john", email="john@foo")
players = examples.Group(name="players", members=[user], groups=[]) players = examples.Group(name="players", members=[user])
with pytest.raises(PointerReferenceError): with pytest.raises(PointerReferenceError):
players = db.save(players) players = db.save(players)
user = db.save(user) user = db.save(user)
players = db.save(examples.Group(name="players", members=[user]))
players.members[0] = user user = db.table('User').get(doc_id=user.doc_id)
players = db.save(players) assert user.groups.uid == players.uid
assert players.members[0] == user
assert players.members[0].groups.uid == players.uid
def test_subgroups(db): def test_subgroups(db):
@ -66,7 +70,11 @@ def test_subgroups(db):
starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw])) starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw]))
assert tos in starfleet.groups assert tos in starfleet.groups
assert snw in starfleet.groups assert snw in starfleet.groups
assert kirk in set([user for group in starfleet.groups for user in group.members])
unique_users = set([user for group in starfleet.groups for user in group.members])
kirk = db.table('User').get(doc_id=kirk.doc_id)
assert kirk in unique_users
def test_unique(db): def test_unique(db):
@ -80,18 +88,23 @@ def test_unique(db):
def test_search(db): def test_search(db):
# create crew members
kirk = db.save(examples.User(name="Captain James T. Kirk", email="riskybiznez@starfleet")) kirk = db.save(examples.User(name="Captain James T. Kirk", email="riskybiznez@starfleet"))
bones = db.save(examples.User(name="Doctor McCoy", email="dammitjim@starfleet")) bones = db.save(examples.User(name="Doctor McCoy", email="dammitjim@starfleet"))
ricky = db.save(examples.User(name="Ensign Ricky Redshirt", email="invincible@starfleet")) ricky = db.save(examples.User(name="Ensign Ricky Redshirt", email="invincible@starfleet"))
# create the crew record
crew = db.save(examples.Group(name="Crew", members=[kirk, bones, ricky])) crew = db.save(examples.Group(name="Crew", members=[kirk, bones, ricky]))
User = Query() User = Query()
captains = db.User.search(User.name.matches("Captain")) captains = db.User.search(User.name.matches("Captain"))
assert len(captains) == 1 assert len(captains) == 1
Group = Query() # update the crew members so they have the backreference to crew
crew = db.Group.search(Group.name == "Crew")[0] kirk = db.table('User').get(doc_id=kirk.doc_id)
bones = db.table('User').get(doc_id=bones.doc_id)
ricky = db.table('User').get(doc_id=ricky.doc_id)
assert kirk in crew.members assert kirk in crew.members
assert bones in crew.members assert bones in crew.members
assert ricky in crew.members assert ricky in crew.members