fix recursion

This commit is contained in:
evilchili 2025-09-30 00:52:37 -07:00
parent bafaf043e3
commit 600209d485
3 changed files with 21 additions and 15 deletions

View File

@ -34,12 +34,12 @@ class RecordTable(table.Table):
doc.after_insert(self._db)
return doc.deserialize(self._db)
def get(self, doc_id: int, recurse: bool = True):
def get(self, doc_id: int, recurse: bool = False):
document = super().get(doc_id=doc_id)
if document:
return document.deserialize(self._db, recurse=recurse)
def search(self, *args, recurse: bool = True, **kwargs) -> List[Record]:
def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]:
results = super().search(*args, **kwargs)
return [r.deserialize(self._db, recurse=recurse) for r in results]

View File

@ -33,7 +33,7 @@ class Field:
if value is not None:
return str(value)
def deserialize(self, value: value_type, db: TinyDB, recurse: bool = True) -> value_type:
def deserialize(self, value: value_type, db: TinyDB, recurse: bool = False) -> value_type:
return value
@ -42,7 +42,7 @@ class Integer(Field):
value_type = int
default: int = 0
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
return int(value)
@ -96,6 +96,10 @@ class Record(Dict[(str, Field)]):
for name, _field in self._metadata.fields.items():
_field.after_insert(db, self)
@property
def reference(self):
return Pointer.reference(self)
def __setattr__(self, key, value):
if key in self:
self[key] = value
@ -147,7 +151,10 @@ class Pointer(Field):
elif type(value) == str:
pt, puid = value.split("::")
if puid:
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
try:
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
except IndexError:
raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!")
return value
@ -165,13 +172,10 @@ class Collection(Field):
value_type: type = Record
default: List[value_type] = field(default_factory=lambda: [])
def _pointer(self, rec):
return Pointer(value_type=type(rec))
def serialize(self, values: List[value_type]) -> List[str]:
return [self._pointer(val).serialize(val) for val in values]
return [Pointer.reference(val) for val in values]
def deserialize(self, values: List[str], db: TinyDB, recurse: bool = True) -> List[value_type]:
def deserialize(self, values: List[str], db: TinyDB, recurse: bool = False) -> List[value_type]:
"""
Recursively deserialize the objects in this collection
"""
@ -179,7 +183,7 @@ class Collection(Field):
if not recurse:
return values
for val in values:
recs.append(self._pointer(val).deserialize(val, db=db, recurse=recurse))
recs.append(Pointer.dereference(val, db=db, recurse=False))
return recs
def after_insert(self, db: TinyDB, record: Record) -> None:
@ -190,7 +194,9 @@ class Collection(Field):
return
for member in record[self.name]:
target = Pointer.dereference(member, db=db)
target = Pointer.dereference(member, db=db, recurse=False)
for backref in target._metadata.backrefs(type(record)):
target[backref.name] = Pointer.reference(record)
db.table(target._metadata.table).update({backref.name: record}, where("uid") == target.uid)
db.table(target._metadata.table).update(
{backref.name: target[backref.name]}, where("uid") == target.uid
)

View File

@ -80,7 +80,7 @@ def test_subgroups(db):
unique_users = set([user for group in trek.groups for user in group.members])
kirk = db.table("User").get(doc_id=kirk.doc_id)
assert kirk in unique_users
assert kirk.reference in unique_users
def test_unique(db):
@ -117,4 +117,4 @@ def test_search(db):
Group = Query()
crew = db.Group.search(Group.name == "Crew", recurse=False)
assert f"User::{kirk.uid}" in crew[0].members
assert kirk.reference in crew[0].members