fix pointers again
This commit is contained in:
parent
5bddca974d
commit
a9261321d2
|
@ -22,15 +22,17 @@ class RecordTable(table.Table):
|
||||||
super().__init__(db.storage, name, **kwargs)
|
super().__init__(db.storage, name, **kwargs)
|
||||||
|
|
||||||
def insert(self, document):
|
def insert(self, document):
|
||||||
document.before_insert()
|
document.before_insert(self._db)
|
||||||
doc = document.serialize(self._db)
|
doc = document.serialize()
|
||||||
self._check_constraints(doc)
|
self._check_constraints(doc)
|
||||||
|
|
||||||
if doc.doc_id:
|
if doc.doc_id:
|
||||||
last_insert_id = super().upsert(doc)[0]
|
last_insert_id = super().upsert(doc)[0]
|
||||||
else:
|
else:
|
||||||
last_insert_id = super().insert(dict(doc))
|
last_insert_id = super().insert(dict(doc))
|
||||||
return self.get(doc_id=last_insert_id)
|
doc.doc_id = last_insert_id
|
||||||
|
doc.after_insert(self._db)
|
||||||
|
return doc.deserialize(self._db)
|
||||||
|
|
||||||
def get(self, doc_id: int, recurse: bool = True):
|
def get(self, doc_id: int, recurse: bool = True):
|
||||||
document = super().get(doc_id=doc_id)
|
document = super().get(doc_id=doc_id)
|
||||||
|
|
|
@ -1,13 +1,24 @@
|
||||||
from grung.types import Collection, Field, Integer, Record
|
from grung.types import BackReference, Collection, Field, Integer, Record
|
||||||
|
|
||||||
|
|
||||||
class User(Record):
|
class User(Record):
|
||||||
@classmethod
|
@classmethod
|
||||||
def fields(cls):
|
def fields(cls):
|
||||||
return [*super().fields(), Field("name"), Integer("number", default=0), Field("email", unique=True)]
|
return [
|
||||||
|
*super().fields(),
|
||||||
|
Field("name"),
|
||||||
|
Integer("number", default=0),
|
||||||
|
Field("email", unique=True),
|
||||||
|
BackReference("groups", Group),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Group(Record):
|
class Group(Record):
|
||||||
@classmethod
|
@classmethod
|
||||||
def fields(cls):
|
def fields(cls):
|
||||||
return [*super().fields(), Field("name", unique=True), Collection("members", User), Collection("groups", cls)]
|
return [
|
||||||
|
*super().fields(),
|
||||||
|
Field("name", unique=True),
|
||||||
|
Collection("members", User),
|
||||||
|
Collection("groups", Group),
|
||||||
|
]
|
||||||
|
|
|
@ -9,7 +9,7 @@ from tinydb import TinyDB, where
|
||||||
|
|
||||||
from grung.exceptions import PointerReferenceError
|
from grung.exceptions import PointerReferenceError
|
||||||
|
|
||||||
Metadata = namedtuple("Metadata", ["table", "fields"])
|
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs"])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -23,20 +23,27 @@ class Field:
|
||||||
default: str = None
|
default: str = None
|
||||||
unique: bool = False
|
unique: bool = False
|
||||||
|
|
||||||
def serialize(self, rec: value_type, db: TinyDB) -> str:
|
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||||
return str(rec)
|
pass
|
||||||
|
|
||||||
def deserialize(self, rec: value_type, db: TinyDB, recurse: bool = True) -> value_type:
|
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||||
return rec
|
pass
|
||||||
|
|
||||||
|
def serialize(self, value: value_type) -> str:
|
||||||
|
if value is not None:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
def deserialize(self, value: value_type, db: TinyDB, recurse: bool = True) -> value_type:
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Integer(Field):
|
class Integer(Field):
|
||||||
value_type: type = int
|
value_type = int
|
||||||
default: int = 0
|
default: int = 0
|
||||||
|
|
||||||
def deserialize(self, rec: str, db: TinyDB, recurse: bool = True) -> value_type:
|
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||||
return int(rec)
|
return int(value)
|
||||||
|
|
||||||
|
|
||||||
class Record(Dict[(str, Field)]):
|
class Record(Dict[(str, Field)]):
|
||||||
|
@ -47,7 +54,13 @@ class Record(Dict[(str, Field)]):
|
||||||
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
|
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
|
||||||
self.doc_id = doc_id
|
self.doc_id = doc_id
|
||||||
fields = self.__class__.fields()
|
fields = self.__class__.fields()
|
||||||
self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in fields})
|
self._metadata = Metadata(
|
||||||
|
table=self.__class__.__name__,
|
||||||
|
fields={f.name: f for f in fields},
|
||||||
|
backrefs=lambda value_type: (
|
||||||
|
field for field in fields if type(field) == BackReference and field.value_type == value_type
|
||||||
|
),
|
||||||
|
)
|
||||||
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
|
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -57,13 +70,13 @@ class Record(Dict[(str, Field)]):
|
||||||
Field("uid", default=nanoid.generate(size=8), unique=True)
|
Field("uid", default=nanoid.generate(size=8), unique=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
def serialize(self, db):
|
def serialize(self):
|
||||||
"""
|
"""
|
||||||
Serialie every field on the record
|
Serialie every field on the record
|
||||||
"""
|
"""
|
||||||
rec = {}
|
rec = {}
|
||||||
for name, _field in self._metadata.fields.items():
|
for name, _field in self._metadata.fields.items():
|
||||||
rec[name] = _field.serialize(self[name], db)
|
rec[name] = _field.serialize(self[name])
|
||||||
return self.__class__(rec, doc_id=self.doc_id)
|
return self.__class__(rec, doc_id=self.doc_id)
|
||||||
|
|
||||||
def deserialize(self, db, recurse: bool = True):
|
def deserialize(self, db, recurse: bool = True):
|
||||||
|
@ -75,8 +88,13 @@ class Record(Dict[(str, Field)]):
|
||||||
rec[name] = _field.deserialize(self[name], db, recurse=recurse)
|
rec[name] = _field.deserialize(self[name], db, recurse=recurse)
|
||||||
return self.__class__(rec, doc_id=self.doc_id)
|
return self.__class__(rec, doc_id=self.doc_id)
|
||||||
|
|
||||||
def before_insert(self):
|
def before_insert(self, db: TinyDB) -> None:
|
||||||
pass
|
for name, _field in self._metadata.fields.items():
|
||||||
|
_field.before_insert(self[name], db, self)
|
||||||
|
|
||||||
|
def after_insert(self, db: TinyDB) -> None:
|
||||||
|
for name, _field in self._metadata.fields.items():
|
||||||
|
_field.after_insert(db, self)
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
if key in self:
|
if key in self:
|
||||||
|
@ -92,7 +110,11 @@ class Record(Dict[(str, Field)]):
|
||||||
return hash(str(dict(self)))
|
return hash(str(dict(self)))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}"
|
return (
|
||||||
|
f"{self.__class__.__name__}[{self.doc_id}]("
|
||||||
|
+ ", ".join([f"{key}={val}" for (key, val) in self.items()])
|
||||||
|
+ ")"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -104,21 +126,32 @@ class Pointer(Field):
|
||||||
name: str = ""
|
name: str = ""
|
||||||
value_type: type = Record
|
value_type: type = Record
|
||||||
|
|
||||||
def serialize(self, value: value_type, db: TinyDB) -> str:
|
def serialize(self, value: value_type) -> str:
|
||||||
if value:
|
if value:
|
||||||
if not value.doc_id:
|
if not value.doc_id:
|
||||||
raise PointerReferenceError(value)
|
raise PointerReferenceError(value)
|
||||||
return f"{value._metadata.table}::{value.uid}"
|
return f"{value._metadata.table}::{value.uid}"
|
||||||
return ""
|
return None
|
||||||
|
|
||||||
def deserialize(self, value: str, db: TinyDB, recurse=True) -> value_type:
|
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||||
if value:
|
return Pointer.dereference(value, db, recurse)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dereference(cls, value: str, db: TinyDB, recurse: bool = True):
|
||||||
|
if not value:
|
||||||
|
return
|
||||||
|
elif type(value) == str:
|
||||||
pt, puid = value.split("::")
|
pt, puid = value.split("::")
|
||||||
if puid:
|
if puid:
|
||||||
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
|
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BackReference(Pointer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Collection(Field):
|
class Collection(Field):
|
||||||
"""
|
"""
|
||||||
|
@ -131,16 +164,29 @@ class Collection(Field):
|
||||||
def _pointer(self, rec):
|
def _pointer(self, rec):
|
||||||
return Pointer(value_type=type(rec))
|
return Pointer(value_type=type(rec))
|
||||||
|
|
||||||
def serialize(self, values: List[Record], db: TinyDB) -> List[str]:
|
def serialize(self, values: List[value_type]) -> List[str]:
|
||||||
if values:
|
return [self._pointer(val).serialize(val) for val in values]
|
||||||
return [self._pointer(val).serialize(val, db=db) for val in values]
|
|
||||||
|
|
||||||
def deserialize(self, values: List[str], db: TinyDB, recurse=True) -> List[value_type]:
|
def deserialize(self, values: List[str], db: TinyDB, recurse: bool = True) -> List[value_type]:
|
||||||
"""
|
"""
|
||||||
Recursively deserialize the objects in this collection
|
Recursively deserialize the objects in this collection
|
||||||
"""
|
"""
|
||||||
|
recs = []
|
||||||
if not recurse:
|
if not recurse:
|
||||||
return values
|
return values
|
||||||
if values:
|
for val in values:
|
||||||
return [self._pointer(val).deserialize(val, db=db, recurse=recurse) for val in values]
|
recs.append(self._pointer(val).deserialize(val, db=db, recurse=recurse))
|
||||||
return []
|
return recs
|
||||||
|
|
||||||
|
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||||
|
"""
|
||||||
|
Populate any backreferences in the members of this collection with the parent record's uid.
|
||||||
|
"""
|
||||||
|
if not record[self.name]:
|
||||||
|
return
|
||||||
|
|
||||||
|
for member in record[self.name]:
|
||||||
|
reference = Pointer.dereference(member, db=db)
|
||||||
|
for backref in reference._metadata.backrefs(type(record)):
|
||||||
|
reference[backref.name] = record
|
||||||
|
db.table(reference._metadata.table).upsert(reference)
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from pprint import pprint as print
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from tinydb import Query
|
from tinydb import Query
|
||||||
from tinydb.storages import MemoryStorage
|
from tinydb.storages import MemoryStorage
|
||||||
|
@ -44,16 +46,18 @@ def test_crud(db):
|
||||||
|
|
||||||
def test_pointers(db):
|
def test_pointers(db):
|
||||||
user = examples.User(name="john", email="john@foo")
|
user = examples.User(name="john", email="john@foo")
|
||||||
players = examples.Group(name="players", members=[user], groups=[])
|
players = examples.Group(name="players", members=[user])
|
||||||
|
|
||||||
with pytest.raises(PointerReferenceError):
|
with pytest.raises(PointerReferenceError):
|
||||||
players = db.save(players)
|
players = db.save(players)
|
||||||
|
|
||||||
user = db.save(user)
|
user = db.save(user)
|
||||||
|
players = db.save(examples.Group(name="players", members=[user]))
|
||||||
|
|
||||||
players.members[0] = user
|
user = db.table('User').get(doc_id=user.doc_id)
|
||||||
players = db.save(players)
|
assert user.groups.uid == players.uid
|
||||||
assert players.members[0] == user
|
|
||||||
|
assert players.members[0].groups.uid == players.uid
|
||||||
|
|
||||||
|
|
||||||
def test_subgroups(db):
|
def test_subgroups(db):
|
||||||
|
@ -66,7 +70,11 @@ def test_subgroups(db):
|
||||||
starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw]))
|
starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw]))
|
||||||
assert tos in starfleet.groups
|
assert tos in starfleet.groups
|
||||||
assert snw in starfleet.groups
|
assert snw in starfleet.groups
|
||||||
assert kirk in set([user for group in starfleet.groups for user in group.members])
|
|
||||||
|
unique_users = set([user for group in starfleet.groups for user in group.members])
|
||||||
|
|
||||||
|
kirk = db.table('User').get(doc_id=kirk.doc_id)
|
||||||
|
assert kirk in unique_users
|
||||||
|
|
||||||
|
|
||||||
def test_unique(db):
|
def test_unique(db):
|
||||||
|
@ -80,18 +88,23 @@ def test_unique(db):
|
||||||
|
|
||||||
|
|
||||||
def test_search(db):
|
def test_search(db):
|
||||||
|
# create crew members
|
||||||
kirk = db.save(examples.User(name="Captain James T. Kirk", email="riskybiznez@starfleet"))
|
kirk = db.save(examples.User(name="Captain James T. Kirk", email="riskybiznez@starfleet"))
|
||||||
bones = db.save(examples.User(name="Doctor McCoy", email="dammitjim@starfleet"))
|
bones = db.save(examples.User(name="Doctor McCoy", email="dammitjim@starfleet"))
|
||||||
ricky = db.save(examples.User(name="Ensign Ricky Redshirt", email="invincible@starfleet"))
|
ricky = db.save(examples.User(name="Ensign Ricky Redshirt", email="invincible@starfleet"))
|
||||||
|
|
||||||
|
# create the crew record
|
||||||
crew = db.save(examples.Group(name="Crew", members=[kirk, bones, ricky]))
|
crew = db.save(examples.Group(name="Crew", members=[kirk, bones, ricky]))
|
||||||
|
|
||||||
User = Query()
|
User = Query()
|
||||||
captains = db.User.search(User.name.matches("Captain"))
|
captains = db.User.search(User.name.matches("Captain"))
|
||||||
assert len(captains) == 1
|
assert len(captains) == 1
|
||||||
|
|
||||||
Group = Query()
|
# update the crew members so they have the backreference to crew
|
||||||
crew = db.Group.search(Group.name == "Crew")[0]
|
kirk = db.table('User').get(doc_id=kirk.doc_id)
|
||||||
|
bones = db.table('User').get(doc_id=bones.doc_id)
|
||||||
|
ricky = db.table('User').get(doc_id=ricky.doc_id)
|
||||||
|
|
||||||
assert kirk in crew.members
|
assert kirk in crew.members
|
||||||
assert bones in crew.members
|
assert bones in crew.members
|
||||||
assert ricky in crew.members
|
assert ricky in crew.members
|
||||||
|
|
Loading…
Reference in New Issue
Block a user