fix recursion
This commit is contained in:
parent
bafaf043e3
commit
600209d485
|
@ -34,12 +34,12 @@ class RecordTable(table.Table):
|
||||||
doc.after_insert(self._db)
|
doc.after_insert(self._db)
|
||||||
return doc.deserialize(self._db)
|
return doc.deserialize(self._db)
|
||||||
|
|
||||||
def get(self, doc_id: int, recurse: bool = True):
|
def get(self, doc_id: int, recurse: bool = False):
|
||||||
document = super().get(doc_id=doc_id)
|
document = super().get(doc_id=doc_id)
|
||||||
if document:
|
if document:
|
||||||
return document.deserialize(self._db, recurse=recurse)
|
return document.deserialize(self._db, recurse=recurse)
|
||||||
|
|
||||||
def search(self, *args, recurse: bool = True, **kwargs) -> List[Record]:
|
def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]:
|
||||||
results = super().search(*args, **kwargs)
|
results = super().search(*args, **kwargs)
|
||||||
return [r.deserialize(self._db, recurse=recurse) for r in results]
|
return [r.deserialize(self._db, recurse=recurse) for r in results]
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ class Field:
|
||||||
if value is not None:
|
if value is not None:
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
def deserialize(self, value: value_type, db: TinyDB, recurse: bool = True) -> value_type:
|
def deserialize(self, value: value_type, db: TinyDB, recurse: bool = False) -> value_type:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ class Integer(Field):
|
||||||
value_type = int
|
value_type = int
|
||||||
default: int = 0
|
default: int = 0
|
||||||
|
|
||||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
|
||||||
return int(value)
|
return int(value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,6 +96,10 @@ class Record(Dict[(str, Field)]):
|
||||||
for name, _field in self._metadata.fields.items():
|
for name, _field in self._metadata.fields.items():
|
||||||
_field.after_insert(db, self)
|
_field.after_insert(db, self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference(self):
|
||||||
|
return Pointer.reference(self)
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
if key in self:
|
if key in self:
|
||||||
self[key] = value
|
self[key] = value
|
||||||
|
@ -147,7 +151,10 @@ class Pointer(Field):
|
||||||
elif type(value) == str:
|
elif type(value) == str:
|
||||||
pt, puid = value.split("::")
|
pt, puid = value.split("::")
|
||||||
if puid:
|
if puid:
|
||||||
|
try:
|
||||||
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
|
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
|
||||||
|
except IndexError:
|
||||||
|
raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@ -165,13 +172,10 @@ class Collection(Field):
|
||||||
value_type: type = Record
|
value_type: type = Record
|
||||||
default: List[value_type] = field(default_factory=lambda: [])
|
default: List[value_type] = field(default_factory=lambda: [])
|
||||||
|
|
||||||
def _pointer(self, rec):
|
|
||||||
return Pointer(value_type=type(rec))
|
|
||||||
|
|
||||||
def serialize(self, values: List[value_type]) -> List[str]:
|
def serialize(self, values: List[value_type]) -> List[str]:
|
||||||
return [self._pointer(val).serialize(val) for val in values]
|
return [Pointer.reference(val) for val in values]
|
||||||
|
|
||||||
def deserialize(self, values: List[str], db: TinyDB, recurse: bool = True) -> List[value_type]:
|
def deserialize(self, values: List[str], db: TinyDB, recurse: bool = False) -> List[value_type]:
|
||||||
"""
|
"""
|
||||||
Recursively deserialize the objects in this collection
|
Recursively deserialize the objects in this collection
|
||||||
"""
|
"""
|
||||||
|
@ -179,7 +183,7 @@ class Collection(Field):
|
||||||
if not recurse:
|
if not recurse:
|
||||||
return values
|
return values
|
||||||
for val in values:
|
for val in values:
|
||||||
recs.append(self._pointer(val).deserialize(val, db=db, recurse=recurse))
|
recs.append(Pointer.dereference(val, db=db, recurse=False))
|
||||||
return recs
|
return recs
|
||||||
|
|
||||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||||
|
@ -190,7 +194,9 @@ class Collection(Field):
|
||||||
return
|
return
|
||||||
|
|
||||||
for member in record[self.name]:
|
for member in record[self.name]:
|
||||||
target = Pointer.dereference(member, db=db)
|
target = Pointer.dereference(member, db=db, recurse=False)
|
||||||
for backref in target._metadata.backrefs(type(record)):
|
for backref in target._metadata.backrefs(type(record)):
|
||||||
target[backref.name] = Pointer.reference(record)
|
target[backref.name] = Pointer.reference(record)
|
||||||
db.table(target._metadata.table).update({backref.name: record}, where("uid") == target.uid)
|
db.table(target._metadata.table).update(
|
||||||
|
{backref.name: target[backref.name]}, where("uid") == target.uid
|
||||||
|
)
|
||||||
|
|
|
@ -80,7 +80,7 @@ def test_subgroups(db):
|
||||||
unique_users = set([user for group in trek.groups for user in group.members])
|
unique_users = set([user for group in trek.groups for user in group.members])
|
||||||
|
|
||||||
kirk = db.table("User").get(doc_id=kirk.doc_id)
|
kirk = db.table("User").get(doc_id=kirk.doc_id)
|
||||||
assert kirk in unique_users
|
assert kirk.reference in unique_users
|
||||||
|
|
||||||
|
|
||||||
def test_unique(db):
|
def test_unique(db):
|
||||||
|
@ -117,4 +117,4 @@ def test_search(db):
|
||||||
|
|
||||||
Group = Query()
|
Group = Query()
|
||||||
crew = db.Group.search(Group.name == "Crew", recurse=False)
|
crew = db.Group.search(Group.name == "Crew", recurse=False)
|
||||||
assert f"User::{kirk.uid}" in crew[0].members
|
assert kirk.reference in crew[0].members
|
||||||
|
|
Loading…
Reference in New Issue
Block a user