add serialization, custom field types

This commit is contained in:
evilchili 2025-09-27 15:13:17 -07:00
parent 7e649ee6e0
commit f9ebb4a8d8
5 changed files with 116 additions and 28 deletions

View File

@ -21,18 +21,25 @@ class RecordTable(table.Table):
super().__init__(db.storage, name, **kwargs) super().__init__(db.storage, name, **kwargs)
def insert(self, document): def insert(self, document):
self._satisfy_constraints(document) doc = document.serialize(self._db)
if document.doc_id: self._check_constraints(doc)
last_insert_id = super().upsert(document)[0]
if doc.doc_id:
last_insert_id = super().upsert(doc)[0]
else: else:
last_insert_id = super().insert(dict(document)) last_insert_id = super().insert(dict(doc))
return self.get(doc_id=last_insert_id) return self.get(doc_id=last_insert_id)
def get(self, doc_id: int):
document = super().get(doc_id=doc_id)
if document:
return document.deserialize(self._db)
def remove(self, document): def remove(self, document):
if document.doc_id: if document.doc_id:
super().remove(doc_ids=[document.doc_id]) super().remove(doc_ids=[document.doc_id])
def _satisfy_constraints(self, document) -> bool: def _check_constraints(self, document) -> bool:
self._check_unique(document) self._check_unique(document)
def _check_unique(self, document) -> bool: def _check_unique(self, document) -> bool:

View File

@ -1,11 +1,9 @@
from typing import List from grung.types import Collection, Field, Integer, Record
from grung.types import Field, Record
class User(Record): class User(Record):
_fields = [Field("name"), Field("email", unique=True)] _fields = [Field("name"), Integer("number", default=0), Field("email", unique=True)]
class Group(Record): class Group(Record):
_fields = [Field("name", unique=True), Field("users", List[User])] _fields = [Field("name", unique=True), Collection("users", User)]

View File

@ -10,3 +10,17 @@ class UniqueConstraintError(Exception):
f" * Error: Unique constraint failure\n" f" * Error: Unique constraint failure\n"
" * The record matches the following existing records:\n\n" + "\n".join(str(c) for c in collisions) " * The record matches the following existing records:\n\n" + "\n".join(str(c) for c in collisions)
) )
class PointerReferenceError(Exception):
"""
Thrown when a document field containing a document could not be resolve to an existing record in the database.
"""
def __init__(self, reference):
super().__init__(
"\n"
f" * Reference: {reference}\n"
f" * Error: Invalid Pointer\n"
" * This collection member does not refer an existing record. Do you need to save it first?"
)

View File

@ -1,7 +1,13 @@
from __future__ import annotations
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List
import nanoid import nanoid
from tinydb import where
from grung.exceptions import PointerReferenceError
Metadata = namedtuple("Metadata", ["table", "fields"]) Metadata = namedtuple("Metadata", ["table", "fields"])
@ -12,13 +18,29 @@ class Field:
Represents a single field in a Record. Represents a single field in a Record.
""" """
value_type = str
name: str name: str
value_type: type = str
default: value_type | None = None default: value_type | None = None
unique: bool = False unique: bool = False
def serialize(self, rec: value_type, db: TinyDB) -> str:
return str(rec)
class Record(dict): def deserialize(self, rec: str, db: TinyDB) -> value_type:
return rec
class Integer(Field):
value_type = int
default: value_type = 0
def deserialize(self, rec: str, db: TinyDB) -> value_type:
return int(rec)
class Record(Dict[(str, Field)]):
""" """
Base type for a single database record. Base type for a single database record.
""" """
@ -33,6 +55,24 @@ class Record(dict):
self.doc_id = doc_id self.doc_id = doc_id
super().__init__(dict({field.name: field.default for field in self._fields}, **raw_doc, **params)) super().__init__(dict({field.name: field.default for field in self._fields}, **raw_doc, **params))
def serialize(self, db):
"""
Serialie every field on the record
"""
rec = {}
for name, field in self._metadata.fields.items():
rec[name] = field.serialize(self[name], db)
return self.__class__(rec, doc_id=self.doc_id)
def deserialize(self, db):
"""
Deserialize every field on the record
"""
rec = {}
for name, field in self._metadata.fields.items():
rec[name] = field.deserialize(self[name], db)
return self.__class__(rec, doc_id=self.doc_id)
def __setattr__(self, key, value): def __setattr__(self, key, value):
if key in self: if key in self:
self[key] = value self[key] = value
@ -45,3 +85,29 @@ class Record(dict):
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}" return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}"
class Collection(Field):
"""
A collection of fields that store pointers instead of dicts.
"""
value_type = List[Record]
def serialize(self, recs: value_type, db: TinyDB) -> List[str]:
vals = []
for rec in recs:
if not rec.doc_id:
raise PointerReferenceError(rec)
vals.append(f"{rec._metadata.table}::{rec.uid}")
return vals
def deserialize(self, rec: List[str], db: TinyDB) -> Collection.value_type:
"""
Recursively deserialize the objects in this collection
"""
vals = []
for member in rec:
pt, puid = member.split("::")
vals.append(db.table(pt).search(where("uid") == puid)[0].deserialize(db))
return vals

View File

@ -3,7 +3,7 @@ from tinydb.storages import MemoryStorage
from grung import examples from grung import examples
from grung.db import GrungDB from grung.db import GrungDB
from grung.exceptions import UniqueConstraintError from grung.exceptions import PointerReferenceError, UniqueConstraintError
@pytest.fixture @pytest.fixture
@ -14,7 +14,7 @@ def db():
def test_crud(db): def test_crud(db):
user = examples.User(name="john", email="john@foo") user = examples.User(name="john", number=23, email="john@foo")
assert user.uid assert user.uid
assert user._metadata.fields["uid"].unique assert user._metadata.fields["uid"].unique
@ -25,6 +25,7 @@ def test_crud(db):
# read back # read back
assert db.User.get(doc_id=last_insert_id) == john_something assert db.User.get(doc_id=last_insert_id) == john_something
assert john_something.name == user.name assert john_something.name == user.name
assert john_something.number == 23
assert john_something.email == user.email assert john_something.email == user.email
assert john_something.uid == user.uid assert john_something.uid == user.uid
@ -35,21 +36,23 @@ def test_crud(db):
assert after_update == john_something assert after_update == john_something
assert before_update != after_update assert before_update != after_update
# pointers
players = examples.Group(name="players", users=[john_something])
players = db.save(players)
players.users[0]["name"] = "fnord"
db.save(players)
# modify records
players.users = []
db.save(players)
after_update = db.Group.get(doc_id=players.doc_id)
assert after_update.users == []
# delete # delete
db.delete(players) db.delete(john_something)
assert len(db.Group) == 0 assert len(db.User) == 0
def test_pointers(db):
user = examples.User(name="john", email="john@foo")
players = examples.Group(name="players", users=[user])
with pytest.raises(PointerReferenceError):
players = db.save(players)
user = db.save(user)
players.users[0] = user
players = db.save(players)
assert players.users[0] == user
def test_unique(db): def test_unique(db):