Skip to content

Commit 6cad942

Browse files
committed
use ServerSentEvent from sse_starlette
Signed-off-by: Ma, Xiangxiang <xiangxiang.ma@intel.com>
1 parent 6bf1e55 commit 6cad942

File tree

4 files changed

+13
-13
lines changed

4 files changed

+13
-13
lines changed

application/backend/src/api/endpoints/job_endpoints.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from uuid import UUID
66

77
from fastapi import APIRouter, Body, Depends
8-
from fastapi.responses import StreamingResponse
8+
from sse_starlette import EventSourceResponse
99

1010
from api.dependencies import get_job_id, get_job_service
1111
from api.endpoints import API_PREFIX
@@ -39,6 +39,6 @@ async def submit_train_job(
3939
async def get_job_logs(
4040
job_id: Annotated[UUID, Depends(get_job_id)],
4141
job_service: Annotated[JobService, Depends(get_job_service)],
42-
) -> StreamingResponse:
42+
) -> EventSourceResponse:
4343
"""Endpoint to get the logs of a job by its ID"""
44-
return StreamingResponse(job_service.stream_logs(job_id=job_id), media_type="text/event-stream")
44+
return EventSourceResponse(job_service.stream_logs(job_id=job_id))

application/backend/src/services/job_service.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import asyncio
44
import datetime
55
import os
6+
from collections.abc import AsyncGenerator
67
from uuid import UUID
78

89
import anyio
910
from sqlalchemy.exc import IntegrityError
11+
from sse_starlette import ServerSentEvent
1012

1113
from db import get_async_db_session_ctx
1214
from exceptions import DuplicateJobException, ResourceNotFoundException
@@ -75,7 +77,7 @@ async def update_job_status(
7577
await repo.update(job, updates)
7678

7779
@classmethod
78-
async def stream_logs(cls, job_id: UUID | str):
80+
async def stream_logs(cls, job_id: UUID | str) -> AsyncGenerator[ServerSentEvent, None]:
7981
from core.logging.utils import get_job_logs_path
8082

8183
log_file = get_job_logs_path(job_id=job_id)
@@ -110,8 +112,5 @@ async def is_job_still_running():
110112
continue
111113
# No more lines are expected
112114
else:
113-
yield "data: DONE\n\n"
114115
break
115-
116-
# Format as an SSE message
117-
yield f"data: {line.rstrip()}\n\n"
116+
yield ServerSentEvent(data=line.rstrip())

application/backend/tests/unit/endpoints/test_jobs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
from fastapi import status
7+
from sse_starlette import ServerSentEvent
78

89
from api.dependencies import get_job_service
910
from main import app
@@ -41,8 +42,8 @@ def test_get_job_logs_success(fxt_client, fxt_job_service, fxt_job):
4142

4243
# Mock the stream_logs generator
4344
async def mock_stream():
44-
yield '{"level": "INFO", "message": "Line 1"}\n'
45-
yield '{"level": "INFO", "message": "Line 2"}\n'
45+
yield ServerSentEvent(data='{"level": "INFO", "message": "Line 1"}')
46+
yield ServerSentEvent(data='{"level": "INFO", "message": "Line 2"}')
4647

4748
fxt_job_service.stream_logs.return_value = mock_stream()
4849

@@ -52,7 +53,7 @@ async def mock_stream():
5253
# Verify the streamed content
5354
content = response.content.decode("utf-8")
5455
lines = [line for line in content.split("\n") if line]
55-
assert len(lines) == 2
56+
assert len(lines) == 4 # 2 events + 2 newlines
5657
assert '"level": "INFO"' in lines[0]
5758
assert '"message": "Line 1"' in lines[0]
5859

application/backend/tests/unit/services/test_job_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ async def consume_stream():
206206

207207
def test_stream_logs_success(self, fxt_job_repository, fxt_job):
208208
"""Test streaming logs successfully from a completed job."""
209-
log_lines = ['{"level": "INFO", "message": "Line 1"}\n', '{"level": "INFO", "message": "Line 2"}\n']
209+
log_lines = ['{"level": "INFO", "message": "Line 1"}', '{"level": "INFO", "message": "Line 2"}']
210210

211211
# Mock job as completed
212212
completed_job = fxt_job.model_copy(update={"status": JobStatus.COMPLETED})
@@ -240,7 +240,7 @@ async def mock_anyio_open_file(*args, **kwargs):
240240
async def consume_stream():
241241
result = []
242242
async for line in JobService.stream_logs(fxt_job.id):
243-
result.append(line)
243+
result.append(line.data)
244244
return result
245245

246246
result = asyncio.run(consume_stream())

0 commit comments

Comments
 (0)