use classmethod to build fields on records
This commit is contained in:
parent
f9ebb4a8d8
commit
a76afaa126
|
@ -2,8 +2,12 @@ from grung.types import Collection, Field, Integer, Record
|
||||||
|
|
||||||
|
|
||||||
class User(Record):
|
class User(Record):
|
||||||
_fields = [Field("name"), Integer("number", default=0), Field("email", unique=True)]
|
@classmethod
|
||||||
|
def fields(cls):
|
||||||
|
return [*super().fields(), Field("name"), Integer("number", default=0), Field("email", unique=True)]
|
||||||
|
|
||||||
|
|
||||||
class Group(Record):
|
class Group(Record):
|
||||||
_fields = [Field("name", unique=True), Collection("users", User)]
|
@classmethod
|
||||||
|
def fields(cls):
|
||||||
|
return [*super().fields(), Field("name", unique=True), Collection("members", User), Collection("groups", cls)]
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
import nanoid
|
import nanoid
|
||||||
from tinydb import where
|
from tinydb import TinyDB, where
|
||||||
|
|
||||||
from grung.exceptions import PointerReferenceError
|
from grung.exceptions import PointerReferenceError
|
||||||
|
|
||||||
|
@ -18,23 +18,21 @@ class Field:
|
||||||
Represents a single field in a Record.
|
Represents a single field in a Record.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
value_type = str
|
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
default: value_type | None = None
|
value_type: type = str
|
||||||
|
default: str = None
|
||||||
unique: bool = False
|
unique: bool = False
|
||||||
|
|
||||||
def serialize(self, rec: value_type, db: TinyDB) -> str:
|
def serialize(self, rec: value_type, db: TinyDB) -> str:
|
||||||
return str(rec)
|
return str(rec)
|
||||||
|
|
||||||
def deserialize(self, rec: str, db: TinyDB) -> value_type:
|
def deserialize(self, rec: value_type, db: TinyDB) -> value_type:
|
||||||
return rec
|
return rec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class Integer(Field):
|
class Integer(Field):
|
||||||
value_type = int
|
default: int = 0
|
||||||
|
|
||||||
default: value_type = 0
|
|
||||||
|
|
||||||
def deserialize(self, rec: str, db: TinyDB) -> value_type:
|
def deserialize(self, rec: str, db: TinyDB) -> value_type:
|
||||||
return int(rec)
|
return int(rec)
|
||||||
|
@ -46,14 +44,17 @@ class Record(Dict[(str, Field)]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
|
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
|
||||||
# populate the metadata
|
self.doc_id = doc_id
|
||||||
self._fields.append(
|
fields = self.__class__.fields()
|
||||||
|
self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in fields})
|
||||||
|
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fields(self):
|
||||||
|
return [
|
||||||
# 1% collision rate at ~2M records
|
# 1% collision rate at ~2M records
|
||||||
Field("uid", default=nanoid.generate(size=8), unique=True)
|
Field("uid", default=nanoid.generate(size=8), unique=True)
|
||||||
)
|
]
|
||||||
self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in self._fields})
|
|
||||||
self.doc_id = doc_id
|
|
||||||
super().__init__(dict({field.name: field.default for field in self._fields}, **raw_doc, **params))
|
|
||||||
|
|
||||||
def serialize(self, db):
|
def serialize(self, db):
|
||||||
"""
|
"""
|
||||||
|
@ -83,26 +84,32 @@ class Record(Dict[(str, Field)]):
|
||||||
return self.get(attr_name)
|
return self.get(attr_name)
|
||||||
return super().__getattr__(attr_name)
|
return super().__getattr__(attr_name)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(str(dict(self)))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}"
|
return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class Collection(Field):
|
class Collection(Field):
|
||||||
"""
|
"""
|
||||||
A collection of fields that store pointers instead of dicts.
|
A collection of fields that store pointers instead of dicts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
value_type = List[Record]
|
value_type: type = Record
|
||||||
|
default: List[value_type] = field(default_factory=lambda: [])
|
||||||
|
|
||||||
def serialize(self, recs: value_type, db: TinyDB) -> List[str]:
|
def serialize(self, value: value_type, db: TinyDB) -> List[str]:
|
||||||
vals = []
|
vals = self.default
|
||||||
for rec in recs:
|
if value:
|
||||||
if not rec.doc_id:
|
for rec in value:
|
||||||
raise PointerReferenceError(rec)
|
if not rec.doc_id:
|
||||||
vals.append(f"{rec._metadata.table}::{rec.uid}")
|
raise PointerReferenceError(rec)
|
||||||
|
vals.append(f"{rec._metadata.table}::{rec.uid}")
|
||||||
return vals
|
return vals
|
||||||
|
|
||||||
def deserialize(self, rec: List[str], db: TinyDB) -> Collection.value_type:
|
def deserialize(self, rec: List[value_type], db: TinyDB) -> value_type:
|
||||||
"""
|
"""
|
||||||
Recursively deserialize the objects in this collection
|
Recursively deserialize the objects in this collection
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -43,16 +43,29 @@ def test_crud(db):
|
||||||
|
|
||||||
def test_pointers(db):
|
def test_pointers(db):
|
||||||
user = examples.User(name="john", email="john@foo")
|
user = examples.User(name="john", email="john@foo")
|
||||||
players = examples.Group(name="players", users=[user])
|
players = examples.Group(name="players", members=[user], groups=[])
|
||||||
|
|
||||||
with pytest.raises(PointerReferenceError):
|
with pytest.raises(PointerReferenceError):
|
||||||
players = db.save(players)
|
players = db.save(players)
|
||||||
|
|
||||||
user = db.save(user)
|
user = db.save(user)
|
||||||
|
|
||||||
players.users[0] = user
|
players.members[0] = user
|
||||||
players = db.save(players)
|
players = db.save(players)
|
||||||
assert players.users[0] == user
|
assert players.members[0] == user
|
||||||
|
|
||||||
|
|
||||||
|
def test_subgroups(db):
|
||||||
|
kirk = db.save(examples.User(name="James T. Kirk", email="riskybiznez@starfleet"))
|
||||||
|
pike = db.save(examples.User(name="Christopher Pike", email="hitit@starfleet"))
|
||||||
|
|
||||||
|
tos = db.save(examples.Group(name="The Original Series", members=[kirk]))
|
||||||
|
snw = db.save(examples.Group(name="Strange New Worlds", members=[pike]))
|
||||||
|
|
||||||
|
starfleet = db.save(examples.Group(name="Starfleet", groups=[tos, snw]))
|
||||||
|
assert tos in starfleet.groups
|
||||||
|
assert snw in starfleet.groups
|
||||||
|
assert kirk in set([user for group in starfleet.groups for user in group.members])
|
||||||
|
|
||||||
|
|
||||||
def test_unique(db):
|
def test_unique(db):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user