Skip to content

Commit 0aacc4b

Browse files
committed
refactor Database init to take opts dict
1 parent 5d445dd commit 0aacc4b

File tree

2 files changed

+33
-32
lines changed

2 files changed

+33
-32
lines changed

stat_log_db/src/stat_log_db/cli.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,19 @@ def main():
99
"""Main CLI entry point."""
1010

1111
# TODO: Read info from pyproject.toml?
12-
parser = create_parser({
13-
"prog": "sldb",
14-
"description": "My CLI tool",
15-
}, "0.0.1")
12+
# parser = create_parser({
13+
# "prog": "sldb",
14+
# "description": "My CLI tool",
15+
# }, "0.0.1")
1616

17-
args = parser.parse_args()
17+
# args = parser.parse_args()
1818

1919
# print(f"{args=}")
2020

21-
sl_db = MemDB(":memory:", True, True)
21+
sl_db = MemDB({
22+
"is_mem": True,
23+
"fkey_constraint": True
24+
})
2225
con = sl_db.init_db(True)
2326
con.create_table("test", [('notes', 'TEXT')], False, True)
2427
con.execute("INSERT INTO test (notes) VALUES (?);", ("Hello world!",))

stat_log_db/src/stat_log_db/db.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,32 @@
11
# from abc import ABC, abstractmethod
22
import sqlite3
33
import uuid
4+
from typing import Any
45

56

67
from .exceptions import raise_auto_arg_type_error
78

89

910
class Database():
10-
def __init__(self, db_name: str | None = None, is_mem: bool = False, fkey_constraint: bool = True):
11+
def __init__(self, options: dict[str, Any] = {}):
1112
# Validate arguments
12-
# database name
13-
if db_name is None:
14-
self._db_name = str(uuid.uuid4())
15-
elif not isinstance(db_name, str):
16-
raise_auto_arg_type_error("db_name")
17-
else:
18-
self._db_name = db_name
19-
# is memory or file database
20-
if not isinstance(is_mem, bool):
21-
raise_auto_arg_type_error("is_mem")
22-
self._in_memory = is_mem
23-
self._is_file = not is_mem
24-
# database file name
25-
if is_mem:
26-
self._db_file_name = ":memory:"
27-
else:
28-
self._db_file_name = self._db_name.replace(" ", "_")
29-
if not isinstance(fkey_constraint, bool):
30-
raise_auto_arg_type_error("fkey_constraint")
31-
self._fkey_constraint = fkey_constraint
13+
valid_options = {
14+
"db_name": str,
15+
"is_mem": bool,
16+
"fkey_constraint": bool
17+
}
18+
for opt, opt_type in options.items():
19+
if opt not in valid_options.keys():
20+
raise ValueError(f"Invalid option provided: '{opt}'. Must be one of {list(valid_options.keys())}.")
21+
expected_type = valid_options[opt]
22+
if not isinstance(opt_type, expected_type):
23+
raise TypeError(f"Option '{opt}' must be of type {expected_type.__name__}, got {type(opt_type).__name__}.")
24+
# Assign arguments to class attributes
25+
self._in_memory: bool = options.get("is_mem", False)
26+
self._is_file: bool = bool(not self._in_memory)
27+
self._db_name: str = options.get("db_name", str(uuid.uuid4()))
28+
self._db_file_name: str = ":memory:" if self._in_memory else self._db_name.replace(" ", "_")
29+
self._fkey_constraint: bool = options.get("fkey_constraint", True)
3230
# Keep track of active connections (to ensure that they are closed)
3331
self._connections: dict[str, BaseConnection] = dict()
3432

@@ -166,8 +164,8 @@ def close_db(self):
166164

167165

168166
class MemDB(Database):
169-
def __init__(self, db_name: str | None = None, is_mem: bool = False, fkey_constraint: bool = True):
170-
super().__init__(db_name, is_mem, fkey_constraint)
167+
def __init__(self, options: dict[str, Any] = {}):
168+
super().__init__(options=options)
171169
if not self.in_memory:
172170
raise ValueError("MemDB can only be used for in-memory databases.")
173171

@@ -191,8 +189,8 @@ def init_db_auto_close(self):
191189

192190

193191
class FileDB(Database):
194-
def __init__(self, db_name: str | None = None, fkey_constraint: bool = True):
195-
super().__init__(db_name, fkey_constraint)
192+
def __init__(self, options: dict[str, Any] = {}):
193+
super().__init__(options=options)
196194
if not self.is_file:
197195
raise ValueError("FileDB can only be used for file-based databases.")
198196

@@ -261,7 +259,7 @@ def enforce_foreign_key_constraints(self, commit: bool = True):
261259
self.connection.commit()
262260

263261
def _open(self):
264-
self._connection = sqlite3.connect(self.db_name)
262+
self._connection = sqlite3.connect(self.db_file_name)
265263
self._cursor = self._connection.cursor()
266264

267265
def open(self):

0 commit comments

Comments
 (0)