From cf709a2215006849ee974b6f1d25ad3fb6831daa Mon Sep 17 00:00:00 2001 From: JWittmeyer Date: Tue, 6 May 2025 17:50:15 +0200 Subject: [PATCH 1/7] Adds cache decorator & cache or middleware calls --- business_objects/organization.py | 2 + business_objects/user.py | 3 + db_cache.py | 180 +++++++++++++++++++++++++++++++ 3 files changed, 185 insertions(+) create mode 100644 db_cache.py diff --git a/business_objects/organization.py b/business_objects/organization.py index e20abe7e..843a3170 100644 --- a/business_objects/organization.py +++ b/business_objects/organization.py @@ -8,8 +8,10 @@ from ..models import Organization, Project, User from ..business_objects import project, user, general from ..util import prevent_sql_injection +from ..db_cache import TTLCacheDecorator, CacheEnum +@TTLCacheDecorator(CacheEnum.ORGANIZATION, 5, "id") def get(id: str) -> Organization: return session.query(Organization).get(id) diff --git a/business_objects/user.py b/business_objects/user.py index a20210b6..dada9e82 100644 --- a/business_objects/user.py +++ b/business_objects/user.py @@ -5,10 +5,13 @@ from typing import List, Optional from sqlalchemy import sql +from ..db_cache import TTLCacheDecorator, CacheEnum + from ..util import prevent_sql_injection +@TTLCacheDecorator(CacheEnum.USER, 5, "user_id") def get(user_id: str) -> User: return session.query(User).get(user_id) diff --git a/db_cache.py b/db_cache.py new file mode 100644 index 00000000..00f7949d --- /dev/null +++ b/db_cache.py @@ -0,0 +1,180 @@ +import time +import functools +import inspect +import threading +from enum import Enum +from .daemon import run_without_db_token + + +# Enum for logical cache separation\ +class CacheEnum(Enum): + DEFAULT = "default" + USER = "user" + ORGANIZATION = "organization" + TEAM = "team" + # extend with more categories as needed + + +# Global cache map: each cache_type -> its own dict of key -> (value, expires_at) +_GLOBAL_CACHE_MAP = {} +# Lock to protect cache operations +_CACHE_LOCK = threading.Lock() + + +def _cleanup_expired(): + while True: + time.sleep(60 * 60) # run every hour + now = time.time() + with _CACHE_LOCK: + for cache in _GLOBAL_CACHE_MAP.values(): + # collect expired keys first + expired_keys = [key for key, (_, exp) in cache.items() if now >= exp] + for key in expired_keys: + del cache[key] + + +# # Start cleanup thread as daemon +# _cleanup_thread = threading.Thread(target=_cleanup_expired, daemon=True) +# _cleanup_thread.start() + + +def start_cleanup_thread(): + run_without_db_token(_cleanup_expired) + + +class TTLCacheDecorator: + def __init__(self, cache_type=CacheEnum.DEFAULT, ttl_minutes=None, *key_fields): + """ + cache_type: namespace for the cache + ttl_minutes: time-to-live for cache entries, in minutes + key_fields: argument names (str) to build cache key; positions are not supported + """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + if ttl_minutes is None: + raise ValueError("ttl_minutes must be specified") + self.cache_type = cache_type + # convert minutes to seconds + self.ttl = ttl_minutes * 60 + # only named fields + for f in key_fields: + if not isinstance(f, str): + raise TypeError("key_fields must be argument names (strings)") + self.key_fields = key_fields + + def __call__(self, fn): + sig = inspect.signature(fn) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + bound = sig.bind_partial(*args, **kwargs) + # build cache key tuple from named fields + try: + key = tuple(bound.arguments[field] for field in self.key_fields) + except KeyError as e: + raise KeyError(f"Missing argument for cache key: {e}") + + now = time.time() + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP.setdefault(self.cache_type, {}) + entry = cache.get(key) + if entry: + value, expires_at = entry + if now < expires_at: + print(f"Cache hit for {key} in {self.cache_type}") + return value + # expired + del cache[key] + + # miss or expired + print(f"No cache hit for {key} in {self.cache_type}") + result = fn(*args, **kwargs) + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP[self.cache_type] + cache[key] = (result, now + self.ttl) + return result + + # management methods + def invalidate(**kws): + try: + key = tuple(kws[field] for field in self.key_fields) + except KeyError as e: + raise KeyError(f"Missing argument for invalidate key: {e}") + with _CACHE_LOCK: + _GLOBAL_CACHE_MAP.get(self.cache_type, {}).pop(key, None) + + def update(value, **kws): + try: + key = tuple(kws[field] for field in self.key_fields) + except KeyError as e: + raise KeyError(f"Missing argument for update key: {e}") + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP.setdefault(self.cache_type, {}) + cache[key] = (value, time.time() + self.ttl) + + def clear_all(): + with _CACHE_LOCK: + _GLOBAL_CACHE_MAP[self.cache_type] = {} + + wrapper.invalidate = invalidate + wrapper.update = update + wrapper.clear_all = clear_all + wrapper.ttl = self.ttl + wrapper.key_fields = self.key_fields + wrapper.cache_type = self.cache_type + + return wrapper + + +# ─── GLOBAL INVALIDATE / UPDATE ────────────────────────────────────────────── + + +def invalidate_cache(cache_type: CacheEnum, key: tuple): + """ + Remove a single entry from the given cache. + key must be the exact tuple used when caching. + """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + with _CACHE_LOCK: + _GLOBAL_CACHE_MAP.get(cache_type, {}).pop(key, None) + + +def update_cache(cache_type: CacheEnum, key: tuple, value, ttl_minutes: float): + """ + Force-set a value in cache under `cache_type` and `key`, overriding any existing entry. + """ + if not isinstance(cache_type, CacheEnum): + raise TypeError("cache_type must be a CacheEnum member") + expires_at = time.time() + ttl_minutes * 60 + with _CACHE_LOCK: + cache = _GLOBAL_CACHE_MAP.setdefault(cache_type, {}) + cache[key] = (value, expires_at) + + +# Example usage: !Note the tuples syntax +# invalidate_cache(CacheEnum.USER, (user_id,)) +# update_cache(CacheEnum.USER, (user_id,), some_value, ttl_minutes=5) + + +# Example usage: +# @TTLCacheDecorator(CacheEnum.USER, 60, 'user_id') +# def get_user_by_id(user_id): +# print(f"Fetching user {user_id} from database") +# return {"id": user_id, "name": "John"} + +# @TTLCacheDecorator(CacheEnum.USER, 60, 'user_id') +# def get_admin_user(user_id, dummy=None): +# print(f"Fetching admin user {user_id} from database") +# return {"id": user_id, "role": "admin"} + +# @TTLCacheDecorator(CacheEnum.RECORD, 120, 'project_id', 'record_id') +# def get_record(project_id, record_id): +# print(f"Fetching record {project_id}, {record_id} from database") +# return {"project_id": project_id, "id": record_id, "value": "Some data"} + +# Management examples: +# get_user_by_id.invalidate(user_id=1) +# get_user_by_id.update({"id":1, "name":"Jane"}, user_id=1) +# get_user_by_id.clear_all() +# get_record.clear_all() From 064cb7db0b9650e41ac7e60d5630fc249b6f1b35 Mon Sep 17 00:00:00 2001 From: JWittmeyer Date: Wed, 7 May 2025 09:17:06 +0200 Subject: [PATCH 2/7] Remove prints --- db_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/db_cache.py b/db_cache.py index 00f7949d..d6acb2cd 100644 --- a/db_cache.py +++ b/db_cache.py @@ -81,13 +81,13 @@ def wrapper(*args, **kwargs): if entry: value, expires_at = entry if now < expires_at: - print(f"Cache hit for {key} in {self.cache_type}") + # print(f"Cache hit for {key} in {self.cache_type}") return value # expired del cache[key] # miss or expired - print(f"No cache hit for {key} in {self.cache_type}") + # print(f"No cache hit for {key} in {self.cache_type}") result = fn(*args, **kwargs) with _CACHE_LOCK: cache = _GLOBAL_CACHE_MAP[self.cache_type] From 8b8321d63f735818cb08b7224c8c7cc2bc1de935 Mon Sep 17 00:00:00 2001 From: JWittmeyer Date: Wed, 7 May 2025 11:55:48 +0200 Subject: [PATCH 3/7] Expunge test --- business_objects/organization.py | 13 +++++++++++-- business_objects/user.py | 20 ++++++++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/business_objects/organization.py b/business_objects/organization.py index 843a3170..feb6d7bb 100644 --- a/business_objects/organization.py +++ b/business_objects/organization.py @@ -11,11 +11,20 @@ from ..db_cache import TTLCacheDecorator, CacheEnum -@TTLCacheDecorator(CacheEnum.ORGANIZATION, 5, "id") def get(id: str) -> Organization: return session.query(Organization).get(id) +@TTLCacheDecorator(CacheEnum.ORGANIZATION, 5, "id") +def get_org_cached(id: str) -> Organization: + o = get(id) + if not o: + return None + general.expunge(o) + general.make_transient(o) + return o + + def get_by_name(name: str) -> Organization: return session.query(Organization).filter(Organization.name == name).first() @@ -120,7 +129,7 @@ def log_admin_requests(org_id: str) -> str: # enum AdminLogLevel if not org_id: # e.g. not assigned to an organization = not logged return None - if o := get(org_id): + if o := get_org_cached(org_id): return o.log_admin_requests return None diff --git a/business_objects/user.py b/business_objects/user.py index dada9e82..faf52c96 100644 --- a/business_objects/user.py +++ b/business_objects/user.py @@ -1,21 +1,37 @@ from datetime import datetime from . import general, organization, team_member from .. import User, enums +from typing import Dict, Any from ..session import session from typing import List, Optional from sqlalchemy import sql from ..db_cache import TTLCacheDecorator, CacheEnum - from ..util import prevent_sql_injection -@TTLCacheDecorator(CacheEnum.USER, 5, "user_id") def get(user_id: str) -> User: return session.query(User).get(user_id) +@TTLCacheDecorator(CacheEnum.USER, 5, "user_id") +def get_user_cached(user_id: str) -> User: + """ + Get user by id and return as dict + """ + user = get(user_id) + if not user: + return None + + general.expunge(user) + general.make_transient(user) + return user + # if not user: + # return {} + # return sql_alchemy_to_dict(user) + + def get_by_id_list(user_ids: List[str]) -> List[User]: return session.query(User).filter(User.id.in_(user_ids)).all() From d7773b180cb969392365957987e1ec433f57fbbd Mon Sep 17 00:00:00 2001 From: JWittmeyer Date: Mon, 12 May 2025 15:49:27 +0200 Subject: [PATCH 4/7] Adds session decorator and middleware cached options --- business_objects/general.py | 8 ++- business_objects/organization.py | 5 +- business_objects/user.py | 8 ++- cognition_objects/project.py | 28 ++++++++- db_cache.py | 1 + session_wrapper.py | 98 ++++++++++++++++++++++++++++++++ 6 files changed, 143 insertions(+), 5 deletions(-) create mode 100644 session_wrapper.py diff --git a/business_objects/general.py b/business_objects/general.py index 76147ae5..e1215209 100644 --- a/business_objects/general.py +++ b/business_objects/general.py @@ -15,6 +15,8 @@ __THREAD_LOCK = Lock() +IS_DEV = True + session_lookup = {} @@ -23,7 +25,11 @@ def get_ctx_token() -> Any: session_uuid = str(uuid.uuid4()) session_id = request_id_ctx_var.set(session_uuid) - call_stack = "".join(traceback.format_stack()[-5:]) + if IS_DEV: + # traces are usually long running and only useful for debugging + call_stack = "".join(traceback.format_stack()[-5:]) + else: + call_stack = "Activate dev mode to see call stack" with __THREAD_LOCK: session_lookup[session_uuid] = { "session_id": session_uuid, diff --git a/business_objects/organization.py b/business_objects/organization.py index feb6d7bb..84fa420b 100644 --- a/business_objects/organization.py +++ b/business_objects/organization.py @@ -4,11 +4,12 @@ from submodules.model import enums -from ..session import session +from ..session import session, request_id_ctx_var from ..models import Organization, Project, User from ..business_objects import project, user, general from ..util import prevent_sql_injection from ..db_cache import TTLCacheDecorator, CacheEnum +from ..session_wrapper import with_session def get(id: str) -> Organization: @@ -16,7 +17,9 @@ def get(id: str) -> Organization: @TTLCacheDecorator(CacheEnum.ORGANIZATION, 5, "id") +@with_session() def get_org_cached(id: str) -> Organization: + print("get_org_cached with session:", request_id_ctx_var.get(), flush=True) o = get(id) if not o: return None diff --git a/business_objects/user.py b/business_objects/user.py index faf52c96..5d5cf5f7 100644 --- a/business_objects/user.py +++ b/business_objects/user.py @@ -1,12 +1,14 @@ from datetime import datetime from . import general, organization, team_member from .. import User, enums -from typing import Dict, Any -from ..session import session + +# from typing import Dict, Any +from ..session import session, request_id_ctx_var from typing import List, Optional from sqlalchemy import sql from ..db_cache import TTLCacheDecorator, CacheEnum +from ..session_wrapper import with_session from ..util import prevent_sql_injection @@ -16,10 +18,12 @@ def get(user_id: str) -> User: @TTLCacheDecorator(CacheEnum.USER, 5, "user_id") +@with_session() def get_user_cached(user_id: str) -> User: """ Get user by id and return as dict """ + print("get_user_cached with session:", request_id_ctx_var.get(), flush=True) user = get(user_id) if not user: return None diff --git a/cognition_objects/project.py b/cognition_objects/project.py index 9ebdd4d3..06f30ba6 100644 --- a/cognition_objects/project.py +++ b/cognition_objects/project.py @@ -1,13 +1,15 @@ from typing import List, Optional, Dict, Any, Iterable from ..business_objects import general, team_resource, user from ..cognition_objects import consumption_log, consumption_summary -from ..session import session +from ..session import session, request_id_ctx_var from ..models import CognitionProject, TeamMember, TeamResource from .. import enums from datetime import datetime from ..util import prevent_sql_injection from sqlalchemy.orm.attributes import flag_modified from copy import deepcopy +from ..db_cache import TTLCacheDecorator, CacheEnum +from ..session_wrapper import with_session def get(project_id: str) -> CognitionProject: @@ -18,6 +20,18 @@ def get(project_id: str) -> CognitionProject: ) +@TTLCacheDecorator(CacheEnum.PROJECT, 5, "project_id") +@with_session() +def get_cached(project_id: str) -> CognitionProject: + print("get_project_cached with session:", request_id_ctx_var.get(), flush=True) + p = get(project_id) + if not p: + return None + general.expunge(p) + general.make_transient(p) + return p + + def get_org_id(project_id: str) -> str: if p := get(project_id): return str(p.organization_id) @@ -42,6 +56,18 @@ def get_by_user(project_id: str, user_id: str) -> CognitionProject: ) +@TTLCacheDecorator(CacheEnum.PROJECT, 5, "project_id", "user_id") +@with_session() +def get_by_user_cached(project_id: str, user_id: str) -> CognitionProject: + print("get_by_user_cached with session:", request_id_ctx_var.get(), flush=True) + p = get_by_user(project_id, user_id) + if not p: + return None + general.expunge(p) + general.make_transient(p) + return p + + def get_all(org_id: str, order_by_name: bool = False) -> List[CognitionProject]: query = session.query(CognitionProject).filter( CognitionProject.organization_id == org_id diff --git a/db_cache.py b/db_cache.py index d6acb2cd..8dbff58e 100644 --- a/db_cache.py +++ b/db_cache.py @@ -11,6 +11,7 @@ class CacheEnum(Enum): DEFAULT = "default" USER = "user" ORGANIZATION = "organization" + PROJECT = "project" TEAM = "team" # extend with more categories as needed diff --git a/session_wrapper.py b/session_wrapper.py new file mode 100644 index 00000000..9eb68621 --- /dev/null +++ b/session_wrapper.py @@ -0,0 +1,98 @@ +# db_utils.py + +import asyncio + +# import uuid +from .business_objects import general +from contextvars import copy_context +from .session import session, request_id_ctx_var +import functools + + +def _run_with_session( + fn, *args, auto_remove: bool = True, new_session: bool = True, **kwargs +): + """ + Sync helper: ensures a request-id is set (or reset), runs fn(*args, **kwargs), + then optionally removes the session. + + Args: + fn: the DB function to run + auto_remove: if True, calls Session.remove() after execution + new_session: if True, always assign a fresh UUID as the request ID + """ + # decide on request ID behavior + if new_session or request_id_ctx_var.get() is None: + # generate a unique request id for this session + # request_id_ctx_var.set(str(uuid.uuid4())) + general.get_ctx_token() + + try: + # Scoped Session uses request_id_ctx_var under the hood + return fn(*args, **kwargs) + except Exception: + session.rollback() + raise + finally: + if auto_remove: + session.remove() + + +def with_session(auto_remove: bool = True, new_session: bool = True): + """ + Decorator for sync DB functions. + + Args: + auto_remove: session.remove() after fn returns (default True) + new_session: force a fresh session UUID for each call (default False) + + Usage: + @with_session() + def read_data(...): + session = Session() + return session.query(...) + + @with_session(auto_remove=False, new_session=True) + def batch_ops(...): + session = Session() + # do writes in an isolated session context + """ + + def decorator(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return _run_with_session( + fn, + *args, + auto_remove=auto_remove, + new_session=new_session, + **kwargs, + ) + + return wrapper + + return decorator + + +async def run_db( + fn, *args, auto_remove: bool = True, new_session: bool = True, **kwargs +): + """ + Async helper: runs a sync @with_session function in a threadpool. + + Args: + fn: the @with_session-decorated function to call + auto_remove: pass-through to control session removal + new_session: pass-through to force fresh session UUID + """ + ctx = copy_context() + return await asyncio.to_thread( + lambda: ctx.run( + _run_with_session, + fn, + *args, + auto_remove, + new_session, + **kwargs, + ) + ) From 371f73b3850c75f5453f0021c7ba1af179770ced Mon Sep 17 00:00:00 2001 From: JWittmeyer Date: Mon, 12 May 2025 16:06:31 +0200 Subject: [PATCH 5/7] Change to kwargs instead of args --- session_wrapper.py | 88 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 70 insertions(+), 18 deletions(-) diff --git a/session_wrapper.py b/session_wrapper.py index 9eb68621..5ff5196a 100644 --- a/session_wrapper.py +++ b/session_wrapper.py @@ -38,24 +38,73 @@ def _run_with_session( session.remove() -def with_session(auto_remove: bool = True, new_session: bool = True): +# def with_session(auto_remove: bool = True, new_session: bool = True): +# """ +# Decorator for sync DB functions. + +# Args: +# auto_remove: session.remove() after fn returns (default True) +# new_session: force a fresh session UUID for each call (default False) + +# Usage: +# @with_session() +# def read_data(...): +# session = Session() +# return session.query(...) + +# @with_session(auto_remove=False, new_session=True) +# def batch_ops(...): +# session = Session() +# # do writes in an isolated session context +# """ + +# def decorator(fn): +# @functools.wraps(fn) +# def wrapper(*args, **kwargs): +# return _run_with_session( +# fn, +# *args, +# auto_remove=auto_remove, +# new_session=new_session, +# **kwargs, +# ) + +# return wrapper + +# return decorator + + +# async def run_db( +# fn, *args, auto_remove: bool = True, new_session: bool = True, **kwargs +# ): +# """ +# Async helper: runs a sync @with_session function in a threadpool. + +# Args: +# fn: the @with_session-decorated function to call +# auto_remove: pass-through to control session removal +# new_session: pass-through to force fresh session UUID +# """ +# ctx = copy_context() +# return await asyncio.to_thread( +# lambda: ctx.run( +# _run_with_session, +# fn, +# *args, +# auto_remove, +# new_session, +# **kwargs, +# ) +# ) + + +def with_session(auto_remove: bool = True, new_session: bool = False): """ Decorator for sync DB functions. Args: auto_remove: session.remove() after fn returns (default True) new_session: force a fresh session UUID for each call (default False) - - Usage: - @with_session() - def read_data(...): - session = Session() - return session.query(...) - - @with_session(auto_remove=False, new_session=True) - def batch_ops(...): - session = Session() - # do writes in an isolated session context """ def decorator(fn): @@ -75,7 +124,7 @@ def wrapper(*args, **kwargs): async def run_db( - fn, *args, auto_remove: bool = True, new_session: bool = True, **kwargs + fn, *args, auto_remove: bool = True, new_session: bool = False, **kwargs ): """ Async helper: runs a sync @with_session function in a threadpool. @@ -86,13 +135,16 @@ async def run_db( new_session: pass-through to force fresh session UUID """ ctx = copy_context() - return await asyncio.to_thread( - lambda: ctx.run( + + def call(): + # explicitly pass keyword-only args + return ctx.run( _run_with_session, fn, *args, - auto_remove, - new_session, + auto_remove=auto_remove, + new_session=new_session, **kwargs, ) - ) + + return await asyncio.to_thread(call) From 4c36e2d40d4e8c703434bcc0c34a3cd05e8fa50b Mon Sep 17 00:00:00 2001 From: JWittmeyer Date: Mon, 12 May 2025 16:18:36 +0200 Subject: [PATCH 6/7] Pool info --- session.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/session.py b/session.py index 9e588ea1..3f65f36c 100644 --- a/session.py +++ b/session.py @@ -88,3 +88,22 @@ def __start_session_cleanup(): except Exception: traceback.print_exc() time.sleep(10) + + +def pool_report(): + """ + Returns a dict with pool metrics for the engine bound to the given + SQLAlchemy Session (or global `engine` if sess is None). + """ + # eng = sess.get_bind() if sess else engine + pool = engine.pool + + return { + "pool_size": pool.size(), + "checked_in": pool.checkedin(), + "overflow": pool.overflow(), + "checked_out": pool.checkedout(), + "max_overflow": pool._max_overflow, + "total_capacity": pool.size() + pool._max_overflow, + "available": (pool.size() + pool._max_overflow) - pool.checkedout(), + } From eacdfc9fa7e22e7b17a3a70a162fc8601208e403 Mon Sep 17 00:00:00 2001 From: JWittmeyer Date: Mon, 12 May 2025 16:23:06 +0200 Subject: [PATCH 7/7] IS_DEV change --- business_objects/general.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/business_objects/general.py b/business_objects/general.py index e1215209..f0ce8c05 100644 --- a/business_objects/general.py +++ b/business_objects/general.py @@ -11,11 +11,12 @@ from threading import Lock from sqlalchemy.dialects import postgresql from sqlalchemy.sql import Select +import os __THREAD_LOCK = Lock() -IS_DEV = True +IS_DEV = os.getenv("IS_DEV", "false").lower() in {"true", "1", "yes", "y"} session_lookup = {}