fix pointers again
This commit is contained in:
parent
5bddca974d
commit
a9261321d2
|
@ -22,15 +22,17 @@ class RecordTable(table.Table):
|
|||
super().__init__(db.storage, name, **kwargs)
|
||||
|
||||
def insert(self, document):
|
||||
document.before_insert()
|
||||
doc = document.serialize(self._db)
|
||||
document.before_insert(self._db)
|
||||
doc = document.serialize()
|
||||
self._check_constraints(doc)
|
||||
|
||||
if doc.doc_id:
|
||||
last_insert_id = super().upsert(doc)[0]
|
||||
else:
|
||||
last_insert_id = super().insert(dict(doc))
|
||||
return self.get(doc_id=last_insert_id)
|
||||
doc.doc_id = last_insert_id
|
||||
doc.after_insert(self._db)
|
||||
return doc.deserialize(self._db)
|
||||
|
||||
def get(self, doc_id: int, recurse: bool = True):
|
||||
document = super().get(doc_id=doc_id)
|
||||
|
|
|
@ -1,13 +1,24 @@
|
|||
from grung.types import Collection, Field, Integer, Record
|
||||
from grung.types import BackReference, Collection, Field, Integer, Record
|
||||
|
||||
|
||||
class User(Record):
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return [*super().fields(), Field("name"), Integer("number", default=0), Field("email", unique=True)]
|
||||
return [
|
||||
*super().fields(),
|
||||
Field("name"),
|
||||
Integer("number", default=0),
|
||||
Field("email", unique=True),
|
||||
BackReference("groups", Group),
|
||||
]
|
||||
|
||||
|
||||
class Group(Record):
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return [*super().fields(), Field("name", unique=True), Collection("members", User), Collection("groups", cls)]
|
||||
return [
|
||||
*super().fields(),
|
||||
Field("name", unique=True),
|
||||
Collection("members", User),
|
||||
Collection("groups", Group),
|
||||
]
|
||||
|
|
|
@ -9,7 +9,7 @@ from tinydb import TinyDB, where
|
|||
|
||||
from grung.exceptions import PointerReferenceError
|
||||
|
||||
Metadata = namedtuple("Metadata", ["table", "fields"])
|
||||
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs"])
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -23,20 +23,27 @@ class Field:
|
|||
default: str = None
|
||||
unique: bool = False
|
||||
|
||||
def serialize(self, rec: value_type, db: TinyDB) -> str:
|
||||
return str(rec)
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
pass
|
||||
|
||||
def deserialize(self, rec: value_type, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
return rec
|
||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||
pass
|
||||
|
||||
def serialize(self, value: value_type) -> str:
|
||||
if value is not None:
|
||||
return str(value)
|
||||
|
||||
def deserialize(self, value: value_type, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
return value
|
||||
|
||||
|
||||
@dataclass
|
||||
class Integer(Field):
|
||||
value_type: type = int
|
||||
value_type = int
|
||||
default: int = 0
|
||||
|
||||
def deserialize(self, rec: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
return int(rec)
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
return int(value)
|
||||
|
||||
|
||||
class Record(Dict[(str, Field)]):
|
||||
|
@ -47,7 +54,13 @@ class Record(Dict[(str, Field)]):
|
|||
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
|
||||
self.doc_id = doc_id
|
||||
fields = self.__class__.fields()
|
||||
self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in fields})
|
||||
self._metadata = Metadata(
|
||||
table=self.__class__.__name__,
|
||||
fields={f.name: f for f in fields},
|
||||
backrefs=lambda value_type: (
|
||||
field for field in fields if type(field) == BackReference and field.value_type == value_type
|
||||
),
|
||||
)
|
||||
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
|
||||
|
||||
@classmethod
|
||||
|
@ -57,13 +70,13 @@ class Record(Dict[(str, Field)]):
|
|||
Field("uid", default=nanoid.generate(size=8), unique=True)
|
||||
]
|
||||
|
||||
def serialize(self, db):
|
||||
def serialize(self):
|
||||
"""
|
||||
Serialie every field on the record
|
||||
"""
|
||||
rec = {}
|
||||
for name, _field in self._metadata.fields.items():
|
||||
rec[name] = _field.serialize(self[name], db)
|
||||
rec[name] = _field.serialize(self[name])
|
||||
return self.__class__(rec, doc_id=self.doc_id)
|
||||
|
||||
def deserialize(self, db, recurse: bool = True):
|
||||
|
@ -75,8 +88,13 @@ class Record(Dict[(str, Field)]):
|
|||
rec[name] = _field.deserialize(self[name], db, recurse=recurse)
|
||||
return self.__class__(rec, doc_id=self.doc_id)
|
||||
|
||||
def before_insert(self):
|
||||
pass
|
||||
def before_insert(self, db: TinyDB) -> None:
|
||||
for name, _field in self._metadata.fields.items():
|
||||
_field.before_insert(self[name], db, self)
|
||||
|
||||
def after_insert(self, db: TinyDB) -> None:
|
||||
for name, _field in self._metadata.fields.items():
|
||||
_field.after_insert(db, self)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in self:
|
||||
|
@ -92,7 +110,11 @@ class Record(Dict[(str, Field)]):
|
|||
return hash(str(dict(self)))
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}"
|
||||
return (
|
||||
f"{self.__class__.__name__}[{self.doc_id}]("
|
||||
+ ", ".join([f"{key}={val}" for (key, val) in self.items()])
|
||||
+ ")"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -104,21 +126,32 @@ class Pointer(Field):
|
|||
name: str = ""
|
||||
value_type: type = Record
|
||||
|
||||
def serialize(self, value: value_type, db: TinyDB) -> str:
|
||||
def serialize(self, value: value_type) -> str:
|
||||
if value:
|
||||
if not value.doc_id:
|
||||
raise PointerReferenceError(value)
|
||||
return f"{value._metadata.table}::{value.uid}"
|
||||
return ""
|
||||
return None
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse=True) -> value_type:
|
||||
if value:
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
return Pointer.dereference(value, db, recurse)
|
||||
|
||||
@classmethod
|
||||
def dereference(cls, value: str, db: TinyDB, recurse: bool = True):
|
||||
if not value:
|
||||
return
|
||||
elif type(value) == str:
|
||||
pt, puid = value.split("::")
|
||||
if puid:
|
||||
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
|
||||
return value
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackReference(Pointer):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Collection(Field):
|
||||
"""
|
||||
|
@ -131,16 +164,29 @@ class Collection(Field):
|
|||
def _pointer(self, rec):
|
||||
return Pointer(value_type=type(rec))
|
||||
|
||||
def serialize(self, values: List[Record], db: TinyDB) -> List[str]:
|
||||
if values:
|
||||
return [self._pointer(val).serialize(val, db=db) for val in values]
|
||||
def serialize(self, values: List[value_type]) -> List[str]:
|
||||
return [self._pointer(val).serialize(val) for val in values]
|
||||
|
||||
def deserialize(self, values: List[str], db: TinyDB, recurse=True) -> List[value_type]:
|
||||
def deserialize(self, values: List[str], db: TinyDB, recurse: bool = True) -> List[value_type]:
|
||||
"""
|
||||
Recursively deserialize the objects in this collection
|
||||
"""
|
||||
recs = []
|
||||
if not recurse:
|
||||
return values
|
||||
if values:
|
||||
return [self._pointer(val).deserialize(val, db=db, recurse=recurse) for val in values]
|
||||
return []
|
||||
for val in values:
|
||||
recs.append(self._pointer(val).deserialize(val, db=db, recurse=recurse))
|
||||
return recs
|
||||
|
||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||
"""
|
||||
Populate any backreferences in the members of this collection with the parent record's uid.
|
||||
"""
|
||||
if not record[self.name]:
|
||||
return
|
||||
|
||||
for member in record[self.name]:
|
||||
reference = Pointer.dereference(member, db=db)
|
||||
for backref in reference._metadata.backrefs(type(record)):
|
||||
reference[backref.name] = record
|
||||
db.table(reference._metadata.table).upsert(reference)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from pprint import pprint as print
|
||||
|
||||
import pytest
|
||||
from tinydb import Query
|
||||
from tinydb.storages import MemoryStorage
|
||||
|
@ -44,16 +46,18 @@ def test_crud(db):
|
|||
|
||||
def test_pointers(db):
|
||||
user = examples.User(name="john", email="john@foo")
|
||||
players = examples.Group(name="players", members=[user], groups=[])
|
||||
players = examples.Group(name="players", members=[user])
|
||||
|
||||
with pytest.raises(PointerReferenceError):
|
||||
players = db.save(players)
|
||||
|
||||
user = db.save(user)
|
||||
players = db.save(examples.Group(name="players", members=[user]))
|
||||
|
||||
players.members[0] = user
|
||||
players = db.save(players)
|
||||
assert players.members[0] == user
|
||||
user = db.table('User').get(doc_id=user.doc_id)
|
||||
assert user.groups.uid == players.uid
|
||||
|
||||
assert players.members[0].groups.uid == players.uid
|
||||
|
||||
|
||||
def test_subgroups(db):
|
||||
|
@ -66,7 +70,11 @@ def test_subgroups(db):
|
|||
starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw]))
|
||||
assert tos in starfleet.groups
|
||||
assert snw in starfleet.groups
|
||||
assert kirk in set([user for group in starfleet.groups for user in group.members])
|
||||
|
||||
unique_users = set([user for group in starfleet.groups for user in group.members])
|
||||
|
||||
kirk = db.table('User').get(doc_id=kirk.doc_id)
|
||||
assert kirk in unique_users
|
||||
|
||||
|
||||
def test_unique(db):
|
||||
|
@ -80,18 +88,23 @@ def test_unique(db):
|
|||
|
||||
|
||||
def test_search(db):
|
||||
# create crew members
|
||||
kirk = db.save(examples.User(name="Captain James T. Kirk", email="riskybiznez@starfleet"))
|
||||
bones = db.save(examples.User(name="Doctor McCoy", email="dammitjim@starfleet"))
|
||||
ricky = db.save(examples.User(name="Ensign Ricky Redshirt", email="invincible@starfleet"))
|
||||
|
||||
# create the crew record
|
||||
crew = db.save(examples.Group(name="Crew", members=[kirk, bones, ricky]))
|
||||
|
||||
User = Query()
|
||||
captains = db.User.search(User.name.matches("Captain"))
|
||||
assert len(captains) == 1
|
||||
|
||||
Group = Query()
|
||||
crew = db.Group.search(Group.name == "Crew")[0]
|
||||
# update the crew members so they have the backreference to crew
|
||||
kirk = db.table('User').get(doc_id=kirk.doc_id)
|
||||
bones = db.table('User').get(doc_id=bones.doc_id)
|
||||
ricky = db.table('User').get(doc_id=ricky.doc_id)
|
||||
|
||||
assert kirk in crew.members
|
||||
assert bones in crew.members
|
||||
assert ricky in crew.members
|
||||
|
|
Loading…
Reference in New Issue
Block a user