diff --git a/pyproject.toml b/pyproject.toml index 4bd8399..61b25b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,8 @@ flask = "*" tinydb = "^4.8.2" pyyaml = "^6.0.2" nanoid = "^2.0.0" +# grung-db = {git = "https://git.evilchi.li/evilchili/grung-db.git"} +grung-db = {git = "file:///home/greg/dev/grung-db/"} [tool.poetry.group.dev.dependencies] pytest = "*" diff --git a/src/ttfrog/app.py b/src/ttfrog/app.py index 13e944e..b9cb9a7 100644 --- a/src/ttfrog/app.py +++ b/src/ttfrog/app.py @@ -2,21 +2,25 @@ import os import sys from flask import Flask +from grung.db import GrungDB from tinydb.storages import MemoryStorage -from ttfrog.db import Database +from ttfrog import schema class ApplicationContext: def __init__(self): self.web: Flask = Flask("ttfrog") - self.db: Database = Database(storage=MemoryStorage) + self.db: GrungDB = None self._initialized = False - def initialize(self, db: Database = None): + def initialize(self, db: GrungDB = None): if not self._initialized: self.web.config["SECRET_KEY"] = os.getenv("SECRET_KEY", "secret string") - self.db = db or Database("ttfrog.db.json") + if os.environ.get("TTFROG_IN_MEMORY_DB"): + self.db = GrungDB.with_schema(schema, storage=MemoryStorage) + else: + self.db = GrungDB.with_schema(schema, "ttfrog.db.json") self._initialized = True diff --git a/src/ttfrog/cli.py b/src/ttfrog/cli.py index 14d0313..368fbbc 100644 --- a/src/ttfrog/cli.py +++ b/src/ttfrog/cli.py @@ -19,9 +19,6 @@ LOG_LEVEL=INFO main_app = typer.Typer() -_context = ttfrog.app.initialize() -flask_app = _context.app - app_state = dict( config_file=Path("~/.config/ttfrog.conf").expanduser(), ) @@ -53,6 +50,8 @@ def callback( ) app_state["verbose"] = verbose + ttfrog.app.initialize() + if context.invoked_subcommand is None: logger.debug("No command specified; invoking default handler.") run(context) @@ -62,19 +61,19 @@ def run(context: typer.Context): """ The default CLI entrypoint is ttfrog.cli.run(). """ - flask_app.run() + ttfrog.app.web.run() -@flask_app.shell_context_processor +@ttfrog.app.web.shell_context_processor def make_shell_context(): - return flask_app + return ttfrog.app -@click.group(cls=FlaskGroup, create_app=lambda: flask_app) +@click.group(cls=FlaskGroup, create_app=lambda: ttfrog.app.web) @click.pass_context def app(ctx): """ - Application management functions + Web Application management functions """ diff --git a/src/ttfrog/db.py b/src/ttfrog/db.py deleted file mode 100644 index 0b7da82..0000000 --- a/src/ttfrog/db.py +++ /dev/null @@ -1,53 +0,0 @@ -from tinydb import TinyDB, table -from tinydb.table import Document - -from ttfrog import schema - - -class RecordTable(table.Table): - """ - Wrapper around tinydb Tables that handles Records instead of dicts. - """ - - def __init__(self, storage, name, **kwargs): - self.document_class = getattr(schema, name, Document) - super().__init__(storage, name, **kwargs) - - def insert(self, document): - self._satisfy_constraints(document) - if document.doc_id: - last_insert_id = super().upsert(document)[0] - else: - last_insert_id = super().insert(dict(document)) - return self.get(doc_id=last_insert_id) - - def _satisfy_constraints(self, document): - # check for uniqueness, etc. - pass - - -class Database(TinyDB): - """ - A TinyDB database instance that uses RecordTable instances for each table - and Record instances for each document in the table. - """ - - default_table_name = "Record" - - def table(self, name: str, **kwargs) -> RecordTable: - if name not in self._tables: - self._tables[name] = RecordTable(self.storage, name, **kwargs) - return self._tables[name] - - def save(self, record): - """ - Create or update a record in its table. - """ - return self.table(record._metadata.table).insert(record) - - def __getattr__(self, attr_name): - """ - Make tables attributes of the instance. - """ - if attr_name in self.tables(): - return self.table(attr_name) diff --git a/src/ttfrog/schema.py b/src/ttfrog/schema.py index 62597a1..16f7f14 100644 --- a/src/ttfrog/schema.py +++ b/src/ttfrog/schema.py @@ -1,58 +1,11 @@ -from collections import namedtuple -from dataclasses import dataclass +from typing import List -import nanoid - -Metadata = namedtuple("Metadata", ["table", "fields"]) - - -@dataclass -class Field: - """ - Represents a single field in a Record. - """ - - name: str - value_type: type = str - default: value_type | None = None - unique: bool = False - - -class Record(dict): - """ - Base type for a single database record. - """ - - _fields = [Field("uid", default="", unique=True)] - - def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params): - # populate the metadata - fields = Record._fields - if self.__class__ != Record: - fields += self._fields - self._metadata = Metadata(table=self.__class__.__name__, fields={f.name: f for f in fields}) - - self.doc_id = doc_id - - vals = dict({field.name: field.default for field in fields}, **raw_doc, **params) - if not vals["uid"]: - vals["uid"] = nanoid.generate(size=8) # 1% collision rate at ~2M records - - super().__init__(vals) - - def __setattr__(self, key, value): - if key in self: - self[key] = value - super().__setattr__(key, value) - - def __getattr__(self, attr_name): - if attr_name in self: - return self.get(attr_name) - return super().__getattr__(attr_name) - - def __repr__(self): - return f"{self.__class__.__name__}[{self.doc_id}]: {self.items()}" +from grung.types import Field, Record class User(Record): _fields = [Field("name"), Field("email", unique=True)] + + +class Group(Record): + _fields = [Field("name", unique=True), Field("users", List[User])] diff --git a/test/test_db.py b/test/test_db.py index e254c95..fea1f96 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -1,19 +1,15 @@ -import os - import pytest import ttfrog.app from ttfrog import schema -from ttfrog.db import Database @pytest.fixture -def app(): - ttfrog.app.initialize(Database("ttfrog-tests")) - ttfrog.app.db.drop_tables() +def app(monkeypatch): + monkeypatch.setenv("TTFROG_IN_MEMORY_DB", "1") + ttfrog.app.initialize() yield ttfrog.app ttfrog.app.db.close() - os.unlink("ttfrog-tests") def test_create(app): @@ -37,3 +33,8 @@ def test_create(app): after_update = app.db.save(john_something) assert after_update == john_something assert before_update != after_update + + players = schema.Group(name="players", users=[john_something]) + players = app.db.save(players) + players.users[0]["name"] = "fnord" + app.db.save(players)