add serialization, custom field types
This commit is contained in:
parent
7e649ee6e0
commit
f9ebb4a8d8
|
@ -21,18 +21,25 @@ class RecordTable(table.Table):
|
|||
super().__init__(db.storage, name, **kwargs)
|
||||
|
||||
def insert(self, document):
|
||||
self._satisfy_constraints(document)
|
||||
if document.doc_id:
|
||||
last_insert_id = super().upsert(document)[0]
|
||||
doc = document.serialize(self._db)
|
||||
self._check_constraints(doc)
|
||||
|
||||
if doc.doc_id:
|
||||
last_insert_id = super().upsert(doc)[0]
|
||||
else:
|
||||
last_insert_id = super().insert(dict(document))
|
||||
last_insert_id = super().insert(dict(doc))
|
||||
return self.get(doc_id=last_insert_id)
|
||||
|
||||
def get(self, doc_id: int):
|
||||
document = super().get(doc_id=doc_id)
|
||||
if document:
|
||||
return document.deserialize(self._db)
|
||||
|
||||
def remove(self, document):
|
||||
if document.doc_id:
|
||||
super().remove(doc_ids=[document.doc_id])
|
||||
|
||||
def _satisfy_constraints(self, document) -> bool:
|
||||
def _check_constraints(self, document) -> bool:
|
||||
self._check_unique(document)
|
||||
|
||||
def _check_unique(self, document) -> bool:
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
from typing import List
|
||||
|
||||
from grung.types import Field, Record
|
||||
from grung.types import Collection, Field, Integer, Record
|
||||
|
||||
|
||||
class User(Record):
|
||||
_fields = [Field("name"), Field("email", unique=True)]
|
||||
_fields = [Field("name"), Integer("number", default=0), Field("email", unique=True)]
|
||||
|
||||
|
||||
class Group(Record):
|
||||
_fields = [Field("name", unique=True), Field("users", List[User])]
|
||||
_fields = [Field("name", unique=True), Collection("users", User)]
|
||||
|
|
|
@ -10,3 +10,17 @@ class UniqueConstraintError(Exception):
|
|||
f" * Error: Unique constraint failure\n"
|
||||
" * The record matches the following existing records:\n\n" + "\n".join(str(c) for c in collisions)
|
||||
)
|
||||
|
||||
|
||||
class PointerReferenceError(Exception):
|
||||
"""
|
||||
Thrown when a document field containing a document could not be resolve to an existing record in the database.
|
||||
"""
|
||||
|
||||
def __init__(self, reference):
|
||||
super().__init__(
|
||||
"\n"
|
||||
f" * Reference: {reference}\n"
|
||||
f" * Error: Invalid Pointer\n"
|
||||
" * This collection member does not refer an existing record. Do you need to save it first?"
|
||||
)
|
||||
|
|
|
@ -1,7 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import nanoid
|
||||
from tinydb import where
|
||||
|
||||
from grung.exceptions import PointerReferenceError
|
||||
|
||||
Metadata = namedtuple("Metadata", ["table", "fields"])
|
||||
|
||||
|
@ -12,13 +18,29 @@ class Field:
|
|||
Represents a single field in a Record.
|
||||
"""
|
||||
|
||||
value_type = str
|
||||
|
||||
name: str
|
||||
value_type: type = str
|
||||
default: value_type | None = None
|
||||
unique: bool = False
|
||||
|
||||
def serialize(self, rec: value_type, db: TinyDB) -> str:
|
||||
return str(rec)
|
||||
|
||||
class Record(dict):
|
||||
def deserialize(self, rec: str, db: TinyDB) -> value_type:
|
||||
return rec
|
||||
|
||||
|
||||
class Integer(Field):
|
||||
value_type = int
|
||||
|
||||
default: value_type = 0
|
||||
|
||||
def deserialize(self, rec: str, db: TinyDB) -> value_type:
|
||||
return int(rec)
|
||||
|
||||
|
||||
class Record(Dict[(str, Field)]):
|
||||
"""
|
||||
Base type for a single database record.
|
||||
"""
|
||||
|
@ -33,6 +55,24 @@ class Record(dict):
|
|||
self.doc_id = doc_id
|
||||
super().__init__(dict({field.name: field.default for field in self._fields}, **raw_doc, **params))
|
||||
|
||||
def serialize(self, db):
|
||||
"""
|
||||
Serialie every field on the record
|
||||
"""
|
||||
rec = {}
|
||||
for name, field in self._metadata.fields.items():
|
||||
rec[name] = field.serialize(self[name], db)
|
||||
return self.__class__(rec, doc_id=self.doc_id)
|
||||
|
||||
def deserialize(self, db):
|
||||
"""
|
||||
Deserialize every field on the record
|
||||
"""
|
||||
rec = {}
|
||||
for name, field in self._metadata.fields.items():
|
||||
rec[name] = field.deserialize(self[name], db)
|
||||
return self.__class__(rec, doc_id=self.doc_id)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in self:
|
||||
self[key] = value
|
||||
|
@ -45,3 +85,29 @@ class Record(dict):
|
|||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}"
|
||||
|
||||
|
||||
class Collection(Field):
|
||||
"""
|
||||
A collection of fields that store pointers instead of dicts.
|
||||
"""
|
||||
|
||||
value_type = List[Record]
|
||||
|
||||
def serialize(self, recs: value_type, db: TinyDB) -> List[str]:
|
||||
vals = []
|
||||
for rec in recs:
|
||||
if not rec.doc_id:
|
||||
raise PointerReferenceError(rec)
|
||||
vals.append(f"{rec._metadata.table}::{rec.uid}")
|
||||
return vals
|
||||
|
||||
def deserialize(self, rec: List[str], db: TinyDB) -> Collection.value_type:
|
||||
"""
|
||||
Recursively deserialize the objects in this collection
|
||||
"""
|
||||
vals = []
|
||||
for member in rec:
|
||||
pt, puid = member.split("::")
|
||||
vals.append(db.table(pt).search(where("uid") == puid)[0].deserialize(db))
|
||||
return vals
|
||||
|
|
|
@ -3,7 +3,7 @@ from tinydb.storages import MemoryStorage
|
|||
|
||||
from grung import examples
|
||||
from grung.db import GrungDB
|
||||
from grung.exceptions import UniqueConstraintError
|
||||
from grung.exceptions import PointerReferenceError, UniqueConstraintError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -14,7 +14,7 @@ def db():
|
|||
|
||||
|
||||
def test_crud(db):
|
||||
user = examples.User(name="john", email="john@foo")
|
||||
user = examples.User(name="john", number=23, email="john@foo")
|
||||
assert user.uid
|
||||
assert user._metadata.fields["uid"].unique
|
||||
|
||||
|
@ -25,6 +25,7 @@ def test_crud(db):
|
|||
# read back
|
||||
assert db.User.get(doc_id=last_insert_id) == john_something
|
||||
assert john_something.name == user.name
|
||||
assert john_something.number == 23
|
||||
assert john_something.email == user.email
|
||||
assert john_something.uid == user.uid
|
||||
|
||||
|
@ -35,21 +36,23 @@ def test_crud(db):
|
|||
assert after_update == john_something
|
||||
assert before_update != after_update
|
||||
|
||||
# pointers
|
||||
players = examples.Group(name="players", users=[john_something])
|
||||
players = db.save(players)
|
||||
players.users[0]["name"] = "fnord"
|
||||
db.save(players)
|
||||
|
||||
# modify records
|
||||
players.users = []
|
||||
db.save(players)
|
||||
after_update = db.Group.get(doc_id=players.doc_id)
|
||||
assert after_update.users == []
|
||||
|
||||
# delete
|
||||
db.delete(players)
|
||||
assert len(db.Group) == 0
|
||||
db.delete(john_something)
|
||||
assert len(db.User) == 0
|
||||
|
||||
|
||||
def test_pointers(db):
|
||||
user = examples.User(name="john", email="john@foo")
|
||||
players = examples.Group(name="players", users=[user])
|
||||
|
||||
with pytest.raises(PointerReferenceError):
|
||||
players = db.save(players)
|
||||
|
||||
user = db.save(user)
|
||||
|
||||
players.users[0] = user
|
||||
players = db.save(players)
|
||||
assert players.users[0] == user
|
||||
|
||||
|
||||
def test_unique(db):
|
||||
|
|
Loading…
Reference in New Issue
Block a user