From f9ebb4a8d80060f3ea52917a16dedb2d34c51921 Mon Sep 17 00:00:00 2001 From: evilchili Date: Sat, 27 Sep 2025 15:13:17 -0700 Subject: [PATCH] add serialization, custom field types --- src/grung/db.py | 17 +++++++--- src/grung/examples.py | 8 ++--- src/grung/exceptions.py | 14 +++++++++ src/grung/types.py | 70 +++++++++++++++++++++++++++++++++++++++-- test/test_db.py | 35 +++++++++++---------- 5 files changed, 116 insertions(+), 28 deletions(-) diff --git a/src/grung/db.py b/src/grung/db.py index d1227ab..0ec964e 100644 --- a/src/grung/db.py +++ b/src/grung/db.py @@ -21,18 +21,25 @@ class RecordTable(table.Table): super().__init__(db.storage, name, **kwargs) def insert(self, document): - self._satisfy_constraints(document) - if document.doc_id: - last_insert_id = super().upsert(document)[0] + doc = document.serialize(self._db) + self._check_constraints(doc) + + if doc.doc_id: + last_insert_id = super().upsert(doc)[0] else: - last_insert_id = super().insert(dict(document)) + last_insert_id = super().insert(dict(doc)) return self.get(doc_id=last_insert_id) + def get(self, doc_id: int): + document = super().get(doc_id=doc_id) + if document: + return document.deserialize(self._db) + def remove(self, document): if document.doc_id: super().remove(doc_ids=[document.doc_id]) - def _satisfy_constraints(self, document) -> bool: + def _check_constraints(self, document) -> bool: self._check_unique(document) def _check_unique(self, document) -> bool: diff --git a/src/grung/examples.py b/src/grung/examples.py index 16f7f14..e46f4f0 100644 --- a/src/grung/examples.py +++ b/src/grung/examples.py @@ -1,11 +1,9 @@ -from typing import List - -from grung.types import Field, Record +from grung.types import Collection, Field, Integer, Record class User(Record): - _fields = [Field("name"), Field("email", unique=True)] + _fields = [Field("name"), Integer("number", default=0), Field("email", unique=True)] class Group(Record): - _fields = [Field("name", unique=True), Field("users", List[User])] + _fields = [Field("name", unique=True), Collection("users", User)] diff --git a/src/grung/exceptions.py b/src/grung/exceptions.py index bd1d48f..9b60c50 100644 --- a/src/grung/exceptions.py +++ b/src/grung/exceptions.py @@ -10,3 +10,17 @@ class UniqueConstraintError(Exception): f" * Error: Unique constraint failure\n" " * The record matches the following existing records:\n\n" + "\n".join(str(c) for c in collisions) ) + + +class PointerReferenceError(Exception): + """ + Thrown when a document field containing a document could not be resolve to an existing record in the database. + """ + + def __init__(self, reference): + super().__init__( + "\n" + f" * Reference: {reference}\n" + f" * Error: Invalid Pointer\n" + " * This collection member does not refer an existing record. Do you need to save it first?" + ) diff --git a/src/grung/types.py b/src/grung/types.py index 98ba5b7..b501f38 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -1,7 +1,13 @@ +from __future__ import annotations + from collections import namedtuple from dataclasses import dataclass +from typing import Dict, List import nanoid +from tinydb import where + +from grung.exceptions import PointerReferenceError Metadata = namedtuple("Metadata", ["table", "fields"]) @@ -12,13 +18,29 @@ class Field: Represents a single field in a Record. """ + value_type = str + name: str - value_type: type = str default: value_type | None = None unique: bool = False + def serialize(self, rec: value_type, db: TinyDB) -> str: + return str(rec) -class Record(dict): + def deserialize(self, rec: str, db: TinyDB) -> value_type: + return rec + + +class Integer(Field): + value_type = int + + default: value_type = 0 + + def deserialize(self, rec: str, db: TinyDB) -> value_type: + return int(rec) + + +class Record(Dict[(str, Field)]): """ Base type for a single database record. """ @@ -33,6 +55,24 @@ class Record(dict): self.doc_id = doc_id super().__init__(dict({field.name: field.default for field in self._fields}, **raw_doc, **params)) + def serialize(self, db): + """ + Serialie every field on the record + """ + rec = {} + for name, field in self._metadata.fields.items(): + rec[name] = field.serialize(self[name], db) + return self.__class__(rec, doc_id=self.doc_id) + + def deserialize(self, db): + """ + Deserialize every field on the record + """ + rec = {} + for name, field in self._metadata.fields.items(): + rec[name] = field.deserialize(self[name], db) + return self.__class__(rec, doc_id=self.doc_id) + def __setattr__(self, key, value): if key in self: self[key] = value @@ -45,3 +85,29 @@ class Record(dict): def __repr__(self): return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}" + + +class Collection(Field): + """ + A collection of fields that store pointers instead of dicts. + """ + + value_type = List[Record] + + def serialize(self, recs: value_type, db: TinyDB) -> List[str]: + vals = [] + for rec in recs: + if not rec.doc_id: + raise PointerReferenceError(rec) + vals.append(f"{rec._metadata.table}::{rec.uid}") + return vals + + def deserialize(self, rec: List[str], db: TinyDB) -> Collection.value_type: + """ + Recursively deserialize the objects in this collection + """ + vals = [] + for member in rec: + pt, puid = member.split("::") + vals.append(db.table(pt).search(where("uid") == puid)[0].deserialize(db)) + return vals diff --git a/test/test_db.py b/test/test_db.py index 0d22f8b..6ad7053 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -3,7 +3,7 @@ from tinydb.storages import MemoryStorage from grung import examples from grung.db import GrungDB -from grung.exceptions import UniqueConstraintError +from grung.exceptions import PointerReferenceError, UniqueConstraintError @pytest.fixture @@ -14,7 +14,7 @@ def db(): def test_crud(db): - user = examples.User(name="john", email="john@foo") + user = examples.User(name="john", number=23, email="john@foo") assert user.uid assert user._metadata.fields["uid"].unique @@ -25,6 +25,7 @@ def test_crud(db): # read back assert db.User.get(doc_id=last_insert_id) == john_something assert john_something.name == user.name + assert john_something.number == 23 assert john_something.email == user.email assert john_something.uid == user.uid @@ -35,21 +36,23 @@ def test_crud(db): assert after_update == john_something assert before_update != after_update - # pointers - players = examples.Group(name="players", users=[john_something]) - players = db.save(players) - players.users[0]["name"] = "fnord" - db.save(players) - - # modify records - players.users = [] - db.save(players) - after_update = db.Group.get(doc_id=players.doc_id) - assert after_update.users == [] - # delete - db.delete(players) - assert len(db.Group) == 0 + db.delete(john_something) + assert len(db.User) == 0 + + +def test_pointers(db): + user = examples.User(name="john", email="john@foo") + players = examples.Group(name="players", users=[user]) + + with pytest.raises(PointerReferenceError): + players = db.save(players) + + user = db.save(user) + + players.users[0] = user + players = db.save(players) + assert players.users[0] == user def test_unique(db):