diff --git a/tests/unit/test_logging.py b/tests/unit/test_logging.py
index 53ce3e1a6f3b..fc0c23ade964 100644
--- a/tests/unit/test_logging.py
+++ b/tests/unit/test_logging.py
@@ -119,8 +119,9 @@ def test_includeme(monkeypatch, settings, expected_level):
                 structlog.stdlib.filter_by_level,
                 structlog.stdlib.add_logger_name,
                 structlog.stdlib.add_log_level,
-                mock.ANY,
-                mock.ANY,
+                mock.ANY,  # PositionalArgumentsFormatter
+                mock.ANY,  # TimeStamper
+                mock.ANY,  # StackInfoRenderer
                 structlog.processors.format_exc_info,
                 wlogging.RENDERER,
             ],
@@ -135,6 +136,10 @@ def test_includeme(monkeypatch, settings, expected_level):
     )
     assert isinstance(
         configure.calls[0].kwargs["processors"][4],
+        structlog.processors.TimeStamper,
+    )
+    assert isinstance(
+        configure.calls[0].kwargs["processors"][5],
         structlog.processors.StackInfoRenderer,
     )
     assert isinstance(
@@ -144,3 +149,29 @@ def test_includeme(monkeypatch, settings, expected_level):
         pretend.call(wlogging._create_id, name="id", reify=True),
         pretend.call(wlogging._create_logger, name="log", reify=True),
     ]
+
+
+def test_configure_celery_logging(monkeypatch):
+    configure = pretend.call_recorder(lambda **kw: None)
+    monkeypatch.setattr(structlog, "configure", configure)
+
+    mock_handler = pretend.stub(setFormatter=pretend.call_recorder(lambda f: None))
+    mock_logger = pretend.stub(
+        handlers=pretend.stub(clear=pretend.call_recorder(lambda: None)),
+        setLevel=pretend.call_recorder(lambda level: None),
+        addHandler=pretend.call_recorder(lambda add_handler: None),
+        removeHandler=pretend.call_recorder(lambda remove_handler: None),
+    )
+    monkeypatch.setattr(logging, "getLogger", lambda: mock_logger)
+    monkeypatch.setattr(logging, "StreamHandler", lambda: mock_handler)
+
+    wlogging.configure_celery_logging()
+
+    # Verify handlers cleared and new one added
+    assert mock_logger.handlers.clear.calls == [pretend.call()]
+    assert len(mock_logger.addHandler.calls) == 1
+    assert mock_logger.setLevel.calls == [pretend.call(logging.INFO)]
+
+    # Verify processors
+    processors = configure.calls[0].kwargs["processors"]
+    assert structlog.contextvars.merge_contextvars in processors
diff --git a/tests/unit/test_tasks.py b/tests/unit/test_tasks.py
index 27457d554bf0..c50f88ae5f4f 100644
--- a/tests/unit/test_tasks.py
+++ b/tests/unit/test_tasks.py
@@ -524,3 +524,28 @@ def test_includeme(env, ssl, broker_redis_url, expected_url, transport_options):
     assert config.add_request_method.calls == [
         pretend.call(tasks._get_task_from_request, name="task", reify=True)
     ]
+
+
+def test_on_after_setup_logger(monkeypatch):
+    configure_celery_logging = pretend.call_recorder(lambda logfile, loglevel: None)
+    # Patch the name bound in ``warehouse.tasks``: it imports the function
+    # directly, so patching ``warehouse.logging`` would not affect the handler.
+    monkeypatch.setattr(
+        "warehouse.tasks.configure_celery_logging", configure_celery_logging
+    )
+
+    tasks.on_after_setup_logger("logger", "loglevel", "logfile")
+
+    assert configure_celery_logging.calls == [pretend.call("logfile", "loglevel")]
+
+
+def test_on_task_prerun(monkeypatch):
+    bind_contextvars = pretend.call_recorder(lambda **kw: None)
+    monkeypatch.setattr("structlog.contextvars.bind_contextvars", bind_contextvars)
+
+    task = pretend.stub(name="test.task")
+    tasks.on_task_prerun(None, "task-123", task)
+
+    assert bind_contextvars.calls == [
+        pretend.call(task_id="task-123", task_name="test.task")
+    ]
diff --git a/warehouse/logging.py b/warehouse/logging.py
index 9c383716cf62..9c0a1ad50847 100644
--- a/warehouse/logging.py
+++ b/warehouse/logging.py
@@ -36,6 +36,49 @@ def _create_id(request):
     return str(uuid.uuid4())
 
 
+def configure_celery_logging(logfile: str | None = None, loglevel: int = logging.INFO):
+    """Configure unified structlog logging for Celery that handles all log types.
+
+    Replaces the root logger's handlers with a single handler -- a
+    ``FileHandler`` when ``logfile`` is given, otherwise a ``StreamHandler`` --
+    formatted through ``RENDERER``, and applies ``loglevel`` to both the root
+    logger and structlog's filtering bound logger so the two layers agree.
+
+    ``loglevel`` is expected to be one of the standard ``logging`` levels.
+    """
+    processors = [
+        structlog.contextvars.merge_contextvars,
+        structlog.processors.TimeStamper(fmt="iso"),
+        structlog.stdlib.add_log_level,
+        structlog.processors.StackInfoRenderer(),
+        structlog.processors.format_exc_info,
+    ]
+    formatter = structlog.stdlib.ProcessorFormatter(
+        processor=RENDERER,
+        foreign_pre_chain=processors,  # type: ignore[arg-type]
+    )
+
+    handler = logging.FileHandler(logfile) if logfile else logging.StreamHandler()
+    handler.setFormatter(formatter)
+
+    root = logging.getLogger()
+    root.handlers.clear()
+    root.addHandler(handler)
+    root.setLevel(loglevel)
+
+    structlog.configure(
+        processors=processors  # type: ignore[arg-type]
+        + [
+            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
+        ],
+        logger_factory=structlog.stdlib.LoggerFactory(),
+        # Filter at the requested level rather than a fixed INFO so structlog
+        # and the root logger agree on which records are emitted.
+        wrapper_class=structlog.make_filtering_bound_logger(loglevel),
+        cache_logger_on_first_use=True,
+    )
+
+
 def _create_logger(request):
     # This has to use **{} instead of just a kwarg because request.id is not
     # an allowed kwarg name.
@@ -88,6 +131,7 @@ def includeme(config):
             structlog.stdlib.add_logger_name,
             structlog.stdlib.add_log_level,
             structlog.stdlib.PositionalArgumentsFormatter(),
+            structlog.processors.TimeStamper(fmt="iso"),
             structlog.processors.StackInfoRenderer(),
             structlog.processors.format_exc_info,
             RENDERER,
diff --git a/warehouse/tasks.py b/warehouse/tasks.py
index d44ca0730822..68b40ad74354 100644
--- a/warehouse/tasks.py
+++ b/warehouse/tasks.py
@@ -8,19 +8,23 @@
 import time
 import typing
 import urllib.parse
+import uuid
 
 import celery
 import celery.app.backends
 import celery.backends.redis
 import pyramid.scripting
 import pyramid_retry
+import structlog
 import transaction
 import venusian
 
+from celery import signals
 from kombu import Queue
 from pyramid.threadlocal import get_current_request
 
 from warehouse.config import Environment
+from warehouse.logging import configure_celery_logging
 from warehouse.metrics import IMetricsService
 
 if typing.TYPE_CHECKING:
@@ -37,6 +41,20 @@
 logger = logging.getLogger(__name__)
 
 
+# Celery signal handlers for unified structlog configuration
+@signals.after_setup_logger.connect
+def on_after_setup_logger(logger, loglevel, logfile, *args, **kwargs):
+    """Override Celery's default logging behavior
+    with unified structlog configuration."""
+    configure_celery_logging(logfile, loglevel)
+
+
+@signals.task_prerun.connect
+def on_task_prerun(sender, task_id, task, **_):
+    """Bind task metadata to contextvars for all logs within the task."""
+    structlog.contextvars.bind_contextvars(task_id=task_id, task_name=task.name)
+
+
 class TLSRedisBackend(celery.backends.redis.RedisBackend):
     def _params_from_url(self, url, defaults):
         params = super()._params_from_url(url, defaults)
@@ -122,6 +140,10 @@ def get_request(self) -> Request:
             env["request"].remote_addr_hashed = hashlib.sha256(
                 ("127.0.0.1" + registry.settings["warehouse.ip_salt"]).encode("utf8")
             ).hexdigest()
+            request_id = str(uuid.uuid4())
+            env["request"].id = request_id
+            structlog.contextvars.bind_contextvars(**{"request.id": request_id})
+            env["request"].log = structlog.get_logger("warehouse.request")
             self.request.update(pyramid_env=env)
 
         return self.request.pyramid_env["request"]  # type: ignore[attr-defined]
@@ -302,6 +324,10 @@ def includeme(config: Configurator) -> None:
         REDBEAT_REDIS_URL=s["celery.scheduler_url"],
         # Silence deprecation warning on startup
         broker_connection_retry_on_startup=False,
+        # Disable Celery's logger hijacking for unified structlog control
+        worker_hijack_root_logger=False,
+        worker_log_format="%(message)s",
+        worker_task_log_format="%(message)s",
     )
     config.registry["celery.app"].Task = WarehouseTask
     config.registry["celery.app"].pyramid_config = config