From b6097b60cc876a13a09c028a9c3d8bb3eb1f150c Mon Sep 17 00:00:00 2001 From: evilchili Date: Sun, 28 Sep 2025 11:11:34 -0700 Subject: [PATCH] deserialize search results --- src/grung/db.py | 5 +++++ src/grung/types.py | 11 ++++++----- test/test_db.py | 19 +++++++++++++++++++ 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/grung/db.py b/src/grung/db.py index 0ec964e..035fa65 100644 --- a/src/grung/db.py +++ b/src/grung/db.py @@ -2,6 +2,7 @@ import inspect import re from functools import reduce from operator import ior +from typing import List from tinydb import Query, TinyDB, table from tinydb.table import Document @@ -35,6 +36,10 @@ class RecordTable(table.Table): if document: return document.deserialize(self._db) + def search(self, *args, **kwargs) -> List[Record]: + results = super().search(*args, **kwargs) + return [r.deserialize(self._db) for r in results] + def remove(self, document): if document.doc_id: super().remove(doc_ids=[document.doc_id]) diff --git a/src/grung/types.py b/src/grung/types.py index 246c8a3..3f49237 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -32,6 +32,7 @@ class Field: @dataclass class Integer(Field): + value_type: type = int default: int = 0 def deserialize(self, rec: str, db: TinyDB) -> value_type: @@ -61,8 +62,8 @@ class Record(Dict[(str, Field)]): Serialie every field on the record """ rec = {} - for name, field in self._metadata.fields.items(): - rec[name] = field.serialize(self[name], db) + for name, _field in self._metadata.fields.items(): + rec[name] = _field.serialize(self[name], db) return self.__class__(rec, doc_id=self.doc_id) def deserialize(self, db): @@ -70,8 +71,8 @@ class Record(Dict[(str, Field)]): Deserialize every field on the record """ rec = {} - for name, field in self._metadata.fields.items(): - rec[name] = field.deserialize(self[name], db) + for name, _field in self._metadata.fields.items(): + rec[name] = _field.deserialize(self[name], db) return self.__class__(rec, doc_id=self.doc_id) def __setattr__(self, key, value): @@ -116,5 +117,5 @@ class Collection(Field): vals = [] for member in rec: pt, puid = member.split("::") - vals.append(db.table(pt).search(where("uid") == puid)[0].deserialize(db)) + vals.append(db.table(pt).search(where("uid") == puid)[0]) return vals diff --git a/test/test_db.py b/test/test_db.py index 4d8e35c..b63ad04 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -1,4 +1,5 @@ import pytest +from tinydb import Query from tinydb.storages import MemoryStorage from grung import examples @@ -76,3 +77,21 @@ def test_unique(db): with pytest.raises(UniqueConstraintError): user2 = db.save(user2) db.save(user1) + + +def test_search(db): + kirk = db.save(examples.User(name="Captain James T. Kirk", email="riskybiznez@starfleet")) + bones = db.save(examples.User(name="Doctor McCoy", email="dammitjim@starfleet")) + ricky = db.save(examples.User(name="Ensign Ricky Redshirt", email="invincible@starfleet")) + + crew = db.save(examples.Group(name="Crew", members=[kirk, bones, ricky])) + + User = Query() + captains = db.User.search(User.name.matches("Captain")) + assert len(captains) == 1 + + Group = Query() + crew = db.Group.search(Group.name == "Crew")[0] + assert kirk in crew.members + assert bones in crew.members + assert ricky in crew.members