fix pointers again

This commit is contained in:
evilchili 2025-09-29 23:10:33 -07:00
parent 5bddca974d
commit a9261321d2
4 changed files with 110 additions and 38 deletions

View File

@ -22,15 +22,17 @@ class RecordTable(table.Table):
super().__init__(db.storage, name, **kwargs)
def insert(self, document):
document.before_insert()
doc = document.serialize(self._db)
document.before_insert(self._db)
doc = document.serialize()
self._check_constraints(doc)
if doc.doc_id:
last_insert_id = super().upsert(doc)[0]
else:
last_insert_id = super().insert(dict(doc))
return self.get(doc_id=last_insert_id)
doc.doc_id = last_insert_id
doc.after_insert(self._db)
return doc.deserialize(self._db)
def get(self, doc_id: int, recurse: bool = True):
document = super().get(doc_id=doc_id)

View File

@ -1,13 +1,24 @@
from grung.types import Collection, Field, Integer, Record
from grung.types import BackReference, Collection, Field, Integer, Record
class User(Record):
@classmethod
def fields(cls):
return [*super().fields(), Field("name"), Integer("number", default=0), Field("email", unique=True)]
return [
*super().fields(),
Field("name"),
Integer("number", default=0),
Field("email", unique=True),
BackReference("groups", Group),
]
class Group(Record):
@classmethod
def fields(cls):
return [*super().fields(), Field("name", unique=True), Collection("members", User), Collection("groups", cls)]
return [
*super().fields(),
Field("name", unique=True),
Collection("members", User),
Collection("groups", Group),
]

View File

@ -9,7 +9,7 @@ from tinydb import TinyDB, where
from grung.exceptions import PointerReferenceError
Metadata = namedtuple("Metadata", ["table", "fields"])
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs"])
@dataclass
@ -23,20 +23,27 @@ class Field:
default: str = None
unique: bool = False
def serialize(self, rec: value_type, db: TinyDB) -> str:
return str(rec)
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
pass
def deserialize(self, rec: value_type, db: TinyDB, recurse: bool = True) -> value_type:
return rec
def after_insert(self, db: TinyDB, record: Record) -> None:
pass
def serialize(self, value: value_type) -> str:
if value is not None:
return str(value)
def deserialize(self, value: value_type, db: TinyDB, recurse: bool = True) -> value_type:
return value
@dataclass
class Integer(Field):
value_type: type = int
value_type = int
default: int = 0
def deserialize(self, rec: str, db: TinyDB, recurse: bool = True) -> value_type:
return int(rec)
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
return int(value)
class Record(Dict[(str, Field)]):
@ -47,7 +54,13 @@ class Record(Dict[(str, Field)]):
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
self.doc_id = doc_id
fields = self.__class__.fields()
self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in fields})
self._metadata = Metadata(
table=self.__class__.__name__,
fields={f.name: f for f in fields},
backrefs=lambda value_type: (
field for field in fields if type(field) == BackReference and field.value_type == value_type
),
)
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
@classmethod
@ -57,13 +70,13 @@ class Record(Dict[(str, Field)]):
Field("uid", default=nanoid.generate(size=8), unique=True)
]
def serialize(self, db):
def serialize(self):
"""
Serialie every field on the record
"""
rec = {}
for name, _field in self._metadata.fields.items():
rec[name] = _field.serialize(self[name], db)
rec[name] = _field.serialize(self[name])
return self.__class__(rec, doc_id=self.doc_id)
def deserialize(self, db, recurse: bool = True):
@ -75,8 +88,13 @@ class Record(Dict[(str, Field)]):
rec[name] = _field.deserialize(self[name], db, recurse=recurse)
return self.__class__(rec, doc_id=self.doc_id)
def before_insert(self):
pass
def before_insert(self, db: TinyDB) -> None:
for name, _field in self._metadata.fields.items():
_field.before_insert(self[name], db, self)
def after_insert(self, db: TinyDB) -> None:
for name, _field in self._metadata.fields.items():
_field.after_insert(db, self)
def __setattr__(self, key, value):
if key in self:
@ -92,7 +110,11 @@ class Record(Dict[(str, Field)]):
return hash(str(dict(self)))
def __repr__(self):
return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}"
return (
f"{self.__class__.__name__}[{self.doc_id}]("
+ ", ".join([f"{key}={val}" for (key, val) in self.items()])
+ ")"
)
@dataclass
@ -104,21 +126,32 @@ class Pointer(Field):
name: str = ""
value_type: type = Record
def serialize(self, value: value_type, db: TinyDB) -> str:
def serialize(self, value: value_type) -> str:
if value:
if not value.doc_id:
raise PointerReferenceError(value)
return f"{value._metadata.table}::{value.uid}"
return ""
return None
def deserialize(self, value: str, db: TinyDB, recurse=True) -> value_type:
if value:
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
return Pointer.dereference(value, db, recurse)
@classmethod
def dereference(cls, value: str, db: TinyDB, recurse: bool = True):
if not value:
return
elif type(value) == str:
pt, puid = value.split("::")
if puid:
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
return value
@dataclass
class BackReference(Pointer):
pass
@dataclass
class Collection(Field):
"""
@ -131,16 +164,29 @@ class Collection(Field):
def _pointer(self, rec):
return Pointer(value_type=type(rec))
def serialize(self, values: List[Record], db: TinyDB) -> List[str]:
if values:
return [self._pointer(val).serialize(val, db=db) for val in values]
def serialize(self, values: List[value_type]) -> List[str]:
return [self._pointer(val).serialize(val) for val in values]
def deserialize(self, values: List[str], db: TinyDB, recurse=True) -> List[value_type]:
def deserialize(self, values: List[str], db: TinyDB, recurse: bool = True) -> List[value_type]:
"""
Recursively deserialize the objects in this collection
"""
recs = []
if not recurse:
return values
if values:
return [self._pointer(val).deserialize(val, db=db, recurse=recurse) for val in values]
return []
for val in values:
recs.append(self._pointer(val).deserialize(val, db=db, recurse=recurse))
return recs
def after_insert(self, db: TinyDB, record: Record) -> None:
"""
Populate any backreferences in the members of this collection with the parent record's uid.
"""
if not record[self.name]:
return
for member in record[self.name]:
reference = Pointer.dereference(member, db=db)
for backref in reference._metadata.backrefs(type(record)):
reference[backref.name] = record
db.table(reference._metadata.table).upsert(reference)

View File

@ -1,3 +1,5 @@
from pprint import pprint as print
import pytest
from tinydb import Query
from tinydb.storages import MemoryStorage
@ -44,16 +46,18 @@ def test_crud(db):
def test_pointers(db):
user = examples.User(name="john", email="john@foo")
players = examples.Group(name="players", members=[user], groups=[])
players = examples.Group(name="players", members=[user])
with pytest.raises(PointerReferenceError):
players = db.save(players)
user = db.save(user)
players = db.save(examples.Group(name="players", members=[user]))
players.members[0] = user
players = db.save(players)
assert players.members[0] == user
user = db.table('User').get(doc_id=user.doc_id)
assert user.groups.uid == players.uid
assert players.members[0].groups.uid == players.uid
def test_subgroups(db):
@ -66,7 +70,11 @@ def test_subgroups(db):
starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw]))
assert tos in starfleet.groups
assert snw in starfleet.groups
assert kirk in set([user for group in starfleet.groups for user in group.members])
unique_users = set([user for group in starfleet.groups for user in group.members])
kirk = db.table('User').get(doc_id=kirk.doc_id)
assert kirk in unique_users
def test_unique(db):
@ -80,18 +88,23 @@ def test_unique(db):
def test_search(db):
# create crew members
kirk = db.save(examples.User(name="Captain James T. Kirk", email="riskybiznez@starfleet"))
bones = db.save(examples.User(name="Doctor McCoy", email="dammitjim@starfleet"))
ricky = db.save(examples.User(name="Ensign Ricky Redshirt", email="invincible@starfleet"))
# create the crew record
crew = db.save(examples.Group(name="Crew", members=[kirk, bones, ricky]))
User = Query()
captains = db.User.search(User.name.matches("Captain"))
assert len(captains) == 1
Group = Query()
crew = db.Group.search(Group.name == "Crew")[0]
# update the crew members so they have the backreference to crew
kirk = db.table('User').get(doc_id=kirk.doc_id)
bones = db.table('User').get(doc_id=bones.doc_id)
ricky = db.table('User').get(doc_id=ricky.doc_id)
assert kirk in crew.members
assert bones in crew.members
assert ricky in crew.members