diff --git a/src/grung/examples.py b/src/grung/examples.py index e46f4f0..5887e86 100644 --- a/src/grung/examples.py +++ b/src/grung/examples.py @@ -2,8 +2,12 @@ from grung.types import Collection, Field, Integer, Record class User(Record): - _fields = [Field("name"), Integer("number", default=0), Field("email", unique=True)] + @classmethod + def fields(cls): + return [*super().fields(), Field("name"), Integer("number", default=0), Field("email", unique=True)] class Group(Record): - _fields = [Field("name", unique=True), Collection("users", User)] + @classmethod + def fields(cls): + return [*super().fields(), Field("name", unique=True), Collection("members", User), Collection("groups", cls)] diff --git a/src/grung/types.py b/src/grung/types.py index b501f38..246c8a3 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -1,11 +1,11 @@ from __future__ import annotations from collections import namedtuple -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict, List import nanoid -from tinydb import where +from tinydb import TinyDB, where from grung.exceptions import PointerReferenceError @@ -18,23 +18,21 @@ class Field: Represents a single field in a Record. """ - value_type = str - name: str - default: value_type | None = None + value_type: type = str + default: str = None unique: bool = False def serialize(self, rec: value_type, db: TinyDB) -> str: return str(rec) - def deserialize(self, rec: str, db: TinyDB) -> value_type: + def deserialize(self, rec: value_type, db: TinyDB) -> value_type: return rec +@dataclass class Integer(Field): - value_type = int - - default: value_type = 0 + default: int = 0 def deserialize(self, rec: str, db: TinyDB) -> value_type: return int(rec) @@ -46,14 +44,17 @@ class Record(Dict[(str, Field)]): """ def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params): - # populate the metadata - self._fields.append( + self.doc_id = doc_id + fields = self.__class__.fields() + self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in fields}) + super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params)) + + @classmethod + def fields(self): + return [ # 1% collision rate at ~2M records Field("uid", default=nanoid.generate(size=8), unique=True) - ) - self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in self._fields}) - self.doc_id = doc_id - super().__init__(dict({field.name: field.default for field in self._fields}, **raw_doc, **params)) + ] def serialize(self, db): """ @@ -83,26 +84,32 @@ class Record(Dict[(str, Field)]): return self.get(attr_name) return super().__getattr__(attr_name) + def __hash__(self): + return hash(str(dict(self))) + def __repr__(self): return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}" +@dataclass class Collection(Field): """ A collection of fields that store pointers instead of dicts. """ - value_type = List[Record] + value_type: type = Record + default: List[value_type] = field(default_factory=lambda: []) - def serialize(self, recs: value_type, db: TinyDB) -> List[str]: - vals = [] - for rec in recs: - if not rec.doc_id: - raise PointerReferenceError(rec) - vals.append(f"{rec._metadata.table}::{rec.uid}") + def serialize(self, value: value_type, db: TinyDB) -> List[str]: + vals = self.default + if value: + for rec in value: + if not rec.doc_id: + raise PointerReferenceError(rec) + vals.append(f"{rec._metadata.table}::{rec.uid}") return vals - def deserialize(self, rec: List[str], db: TinyDB) -> Collection.value_type: + def deserialize(self, rec: List[value_type], db: TinyDB) -> value_type: """ Recursively deserialize the objects in this collection """ diff --git a/test/test_db.py b/test/test_db.py index 6ad7053..4d8e35c 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -43,16 +43,29 @@ def test_crud(db): def test_pointers(db): user = examples.User(name="john", email="john@foo") - players = examples.Group(name="players", users=[user]) + players = examples.Group(name="players", members=[user], groups=[]) with pytest.raises(PointerReferenceError): players = db.save(players) user = db.save(user) - players.users[0] = user + players.members[0] = user players = db.save(players) - assert players.users[0] == user + assert players.members[0] == user + + +def test_subgroups(db): + kirk = db.save(examples.User(name="James T. Kirk", email="riskybiznez@starfleet")) + pike = db.save(examples.User(name="Christopher Pike", email="hitit@starfleet")) + + tos = db.save(examples.Group(name="The Original Series", members=[kirk])) + snw = db.save(examples.Group(name="Strange New Worlds", members=[pike])) + + starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw])) + assert tos in starfleet.groups + assert snw in starfleet.groups + assert kirk in set([user for group in starfleet.groups for user in group.members]) def test_unique(db):