use classmethod to build fields on records

This commit is contained in:
evilchili 2025-09-28 10:33:44 -07:00
parent f9ebb4a8d8
commit a76afaa126
3 changed files with 52 additions and 28 deletions

View File

@ -2,8 +2,12 @@ from grung.types import Collection, Field, Integer, Record
class User(Record): class User(Record):
_fields = [Field("name"), Integer("number", default=0), Field("email", unique=True)] @classmethod
def fields(cls):
return [*super().fields(), Field("name"), Integer("number", default=0), Field("email", unique=True)]
class Group(Record): class Group(Record):
_fields = [Field("name", unique=True), Collection("users", User)] @classmethod
def fields(cls):
return [*super().fields(), Field("name", unique=True), Collection("members", User), Collection("groups", cls)]

View File

@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Dict, List from typing import Dict, List
import nanoid import nanoid
from tinydb import where from tinydb import TinyDB, where
from grung.exceptions import PointerReferenceError from grung.exceptions import PointerReferenceError
@ -18,23 +18,21 @@ class Field:
Represents a single field in a Record. Represents a single field in a Record.
""" """
value_type = str
name: str name: str
default: value_type | None = None value_type: type = str
default: str = None
unique: bool = False unique: bool = False
def serialize(self, rec: value_type, db: TinyDB) -> str: def serialize(self, rec: value_type, db: TinyDB) -> str:
return str(rec) return str(rec)
def deserialize(self, rec: str, db: TinyDB) -> value_type: def deserialize(self, rec: value_type, db: TinyDB) -> value_type:
return rec return rec
@dataclass
class Integer(Field): class Integer(Field):
value_type = int default: int = 0
default: value_type = 0
def deserialize(self, rec: str, db: TinyDB) -> value_type: def deserialize(self, rec: str, db: TinyDB) -> value_type:
return int(rec) return int(rec)
@ -46,14 +44,17 @@ class Record(Dict[(str, Field)]):
""" """
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params): def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
# populate the metadata self.doc_id = doc_id
self._fields.append( fields = self.__class__.fields()
self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in fields})
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
@classmethod
def fields(self):
return [
# 1% collision rate at ~2M records # 1% collision rate at ~2M records
Field("uid", default=nanoid.generate(size=8), unique=True) Field("uid", default=nanoid.generate(size=8), unique=True)
) ]
self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in self._fields})
self.doc_id = doc_id
super().__init__(dict({field.name: field.default for field in self._fields}, **raw_doc, **params))
def serialize(self, db): def serialize(self, db):
""" """
@ -83,26 +84,32 @@ class Record(Dict[(str, Field)]):
return self.get(attr_name) return self.get(attr_name)
return super().__getattr__(attr_name) return super().__getattr__(attr_name)
def __hash__(self):
return hash(str(dict(self)))
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}" return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}"
@dataclass
class Collection(Field): class Collection(Field):
""" """
A collection of fields that store pointers instead of dicts. A collection of fields that store pointers instead of dicts.
""" """
value_type = List[Record] value_type: type = Record
default: List[value_type] = field(default_factory=lambda: [])
def serialize(self, recs: value_type, db: TinyDB) -> List[str]: def serialize(self, value: value_type, db: TinyDB) -> List[str]:
vals = [] vals = self.default
for rec in recs: if value:
for rec in value:
if not rec.doc_id: if not rec.doc_id:
raise PointerReferenceError(rec) raise PointerReferenceError(rec)
vals.append(f"{rec._metadata.table}::{rec.uid}") vals.append(f"{rec._metadata.table}::{rec.uid}")
return vals return vals
def deserialize(self, rec: List[str], db: TinyDB) -> Collection.value_type: def deserialize(self, rec: List[value_type], db: TinyDB) -> value_type:
""" """
Recursively deserialize the objects in this collection Recursively deserialize the objects in this collection
""" """

View File

@ -43,16 +43,29 @@ def test_crud(db):
def test_pointers(db): def test_pointers(db):
user = examples.User(name="john", email="john@foo") user = examples.User(name="john", email="john@foo")
players = examples.Group(name="players", users=[user]) players = examples.Group(name="players", members=[user], groups=[])
with pytest.raises(PointerReferenceError): with pytest.raises(PointerReferenceError):
players = db.save(players) players = db.save(players)
user = db.save(user) user = db.save(user)
players.users[0] = user players.members[0] = user
players = db.save(players) players = db.save(players)
assert players.users[0] == user assert players.members[0] == user
def test_subgroups(db):
kirk = db.save(examples.User(name="James T. Kirk", email="riskybiznez@starfleet"))
pike = db.save(examples.User(name="Christopher Pike", email="hitit@starfleet"))
tos = db.save(examples.Group(name="The Original Series", members=[kirk]))
snw = db.save(examples.Group(name="Strange New Worlds", members=[pike]))
starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw]))
assert tos in starfleet.groups
assert snw in starfleet.groups
assert kirk in set([user for group in starfleet.groups for user in group.members])
def test_unique(db): def test_unique(db):