use classmethod to build fields on records

This commit is contained in:
evilchili 2025-09-28 10:33:44 -07:00
parent f9ebb4a8d8
commit a76afaa126
3 changed files with 52 additions and 28 deletions

View File

@ -2,8 +2,12 @@ from grung.types import Collection, Field, Integer, Record
class User(Record):
_fields = [Field("name"), Integer("number", default=0), Field("email", unique=True)]
@classmethod
def fields(cls):
return [*super().fields(), Field("name"), Integer("number", default=0), Field("email", unique=True)]
class Group(Record):
_fields = [Field("name", unique=True), Collection("users", User)]
@classmethod
def fields(cls):
return [*super().fields(), Field("name", unique=True), Collection("members", User), Collection("groups", cls)]

View File

@ -1,11 +1,11 @@
from __future__ import annotations
from collections import namedtuple
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict, List
import nanoid
from tinydb import where
from tinydb import TinyDB, where
from grung.exceptions import PointerReferenceError
@ -18,23 +18,21 @@ class Field:
Represents a single field in a Record.
"""
value_type = str
name: str
default: value_type | None = None
value_type: type = str
default: str = None
unique: bool = False
def serialize(self, rec: value_type, db: TinyDB) -> str:
return str(rec)
def deserialize(self, rec: str, db: TinyDB) -> value_type:
def deserialize(self, rec: value_type, db: TinyDB) -> value_type:
return rec
@dataclass
class Integer(Field):
value_type = int
default: value_type = 0
default: int = 0
def deserialize(self, rec: str, db: TinyDB) -> value_type:
return int(rec)
@ -46,14 +44,17 @@ class Record(Dict[(str, Field)]):
"""
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
# populate the metadata
self._fields.append(
self.doc_id = doc_id
fields = self.__class__.fields()
self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in fields})
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
@classmethod
def fields(self):
return [
# 1% collision rate at ~2M records
Field("uid", default=nanoid.generate(size=8), unique=True)
)
self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in self._fields})
self.doc_id = doc_id
super().__init__(dict({field.name: field.default for field in self._fields}, **raw_doc, **params))
]
def serialize(self, db):
"""
@ -83,26 +84,32 @@ class Record(Dict[(str, Field)]):
return self.get(attr_name)
return super().__getattr__(attr_name)
def __hash__(self):
return hash(str(dict(self)))
def __repr__(self):
return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}"
@dataclass
class Collection(Field):
"""
A collection of fields that store pointers instead of dicts.
"""
value_type = List[Record]
value_type: type = Record
default: List[value_type] = field(default_factory=lambda: [])
def serialize(self, recs: value_type, db: TinyDB) -> List[str]:
vals = []
for rec in recs:
if not rec.doc_id:
raise PointerReferenceError(rec)
vals.append(f"{rec._metadata.table}::{rec.uid}")
def serialize(self, value: value_type, db: TinyDB) -> List[str]:
vals = self.default
if value:
for rec in value:
if not rec.doc_id:
raise PointerReferenceError(rec)
vals.append(f"{rec._metadata.table}::{rec.uid}")
return vals
def deserialize(self, rec: List[str], db: TinyDB) -> Collection.value_type:
def deserialize(self, rec: List[value_type], db: TinyDB) -> value_type:
"""
Recursively deserialize the objects in this collection
"""

View File

@ -43,16 +43,29 @@ def test_crud(db):
def test_pointers(db):
user = examples.User(name="john", email="john@foo")
players = examples.Group(name="players", users=[user])
players = examples.Group(name="players", members=[user], groups=[])
with pytest.raises(PointerReferenceError):
players = db.save(players)
user = db.save(user)
players.users[0] = user
players.members[0] = user
players = db.save(players)
assert players.users[0] == user
assert players.members[0] == user
def test_subgroups(db):
kirk = db.save(examples.User(name="James T. Kirk", email="riskybiznez@starfleet"))
pike = db.save(examples.User(name="Christopher Pike", email="hitit@starfleet"))
tos = db.save(examples.Group(name="The Original Series", members=[kirk]))
snw = db.save(examples.Group(name="Strange New Worlds", members=[pike]))
starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw]))
assert tos in starfleet.groups
assert snw in starfleet.groups
assert kirk in set([user for group in starfleet.groups for user in group.members])
def test_unique(db):