Skip to content

Commit f3e8cfd

Browse files
committed
use ServerSentEvent from sse_starlette
Signed-off-by: Ma, Xiangxiang <xiangxiang.ma@intel.com>
1 parent 701aca5 commit f3e8cfd

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed

application/backend/src/api/endpoints/job_endpoints.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from uuid import UUID
66

77
from fastapi import APIRouter, Body, Depends
8-
from fastapi.responses import StreamingResponse
8+
from sse_starlette import EventSourceResponse
99

1010
from api.dependencies import get_job_id, get_job_service
1111
from api.endpoints import API_PREFIX
@@ -39,6 +39,6 @@ async def submit_train_job(
3939
async def get_job_logs(
4040
job_id: Annotated[UUID, Depends(get_job_id)],
4141
job_service: Annotated[JobService, Depends(get_job_service)],
42-
) -> StreamingResponse:
42+
) -> EventSourceResponse:
4343
"""Endpoint to get the logs of a job by its ID"""
44-
return StreamingResponse(job_service.stream_logs(job_id=job_id), media_type="text/event-stream")
44+
return EventSourceResponse(job_service.stream_logs(job_id=job_id))

application/backend/src/services/job_service.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
# SPDX-License-Identifier: Apache-2.0
33
import asyncio
44
import os
5+
from collections.abc import AsyncGenerator
56
from uuid import UUID
67

78
import anyio
89
from sqlalchemy.exc import IntegrityError
10+
from sse_starlette import ServerSentEvent
911

1012
from db import get_async_db_session_ctx
1113
from exceptions import DuplicateJobException, ResourceNotFoundException
@@ -70,7 +72,7 @@ async def update_job_status(
7072
await repo.update(job, updates)
7173

7274
@classmethod
73-
async def stream_logs(cls, job_id: UUID | str):
75+
async def stream_logs(cls, job_id: UUID | str) -> AsyncGenerator[ServerSentEvent, None]:
7476
from core.logging.utils import get_job_logs_path
7577

7678
log_file = get_job_logs_path(job_id=job_id)
@@ -106,4 +108,4 @@ async def is_job_still_running():
106108
# No more lines are expected
107109
else:
108110
break
109-
yield line
111+
yield ServerSentEvent(data=line.rstrip())

application/backend/tests/unit/endpoints/test_jobs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
from fastapi import status
7+
from sse_starlette import ServerSentEvent
78

89
from api.dependencies import get_job_service
910
from main import app
@@ -41,8 +42,8 @@ def test_get_job_logs_success(fxt_client, fxt_job_service, fxt_job):
4142

4243
# Mock the stream_logs generator
4344
async def mock_stream():
44-
yield '{"level": "INFO", "message": "Line 1"}\n'
45-
yield '{"level": "INFO", "message": "Line 2"}\n'
45+
yield ServerSentEvent(data='{"level": "INFO", "message": "Line 1"}')
46+
yield ServerSentEvent(data='{"level": "INFO", "message": "Line 2"}')
4647

4748
fxt_job_service.stream_logs.return_value = mock_stream()
4849

@@ -52,7 +53,7 @@ async def mock_stream():
5253
# Verify the streamed content
5354
content = response.content.decode("utf-8")
5455
lines = [line for line in content.split("\n") if line]
55-
assert len(lines) == 2
56+
assert len(lines) == 4 # 2 events + 2 newlines
5657
assert '"level": "INFO"' in lines[0]
5758
assert '"message": "Line 1"' in lines[0]
5859

application/backend/tests/unit/services/test_job_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ async def consume_stream():
206206

207207
def test_stream_logs_success(self, fxt_job_repository, fxt_job):
208208
"""Test streaming logs successfully from a completed job."""
209-
log_lines = ['{"level": "INFO", "message": "Line 1"}\n', '{"level": "INFO", "message": "Line 2"}\n']
209+
log_lines = ['{"level": "INFO", "message": "Line 1"}', '{"level": "INFO", "message": "Line 2"}']
210210

211211
# Mock job as completed
212212
completed_job = fxt_job.model_copy(update={"status": JobStatus.COMPLETED})
@@ -240,7 +240,7 @@ async def mock_anyio_open_file(*args, **kwargs):
240240
async def consume_stream():
241241
result = []
242242
async for line in JobService.stream_logs(fxt_job.id):
243-
result.append(line)
243+
result.append(line.data)
244244
return result
245245

246246
result = asyncio.run(consume_stream())

0 commit comments

Comments (0)