|
13 | 13 | from ads.aqua.common.enums import PredictEndpoints |
14 | 14 | from notebook.base.handlers import IPythonHandler |
15 | 15 | from parameterized import parameterized |
| 16 | +import openai |
16 | 17 |
|
17 | 18 | import ads.aqua |
18 | 19 | import ads.config |
@@ -247,6 +248,9 @@ def test_validate_deployment_params( |
247 | 248 |
|
248 | 249 |
|
249 | 250 | class TestAquaDeploymentStreamingInferenceHandler(unittest.TestCase): |
| 251 | + |
| 252 | + EXPECTED_OCID = "ocid1.compartment.oc1..aaaaaaaaser65kfcfht7iddoioa4s6xos3vi53d3i7bi3czjkqyluawp2itq" |
| 253 | + |
250 | 254 | @patch.object(IPythonHandler, "__init__") |
251 | 255 | def setUp(self, ipython_init_mock) -> None: |
252 | 256 | ipython_init_mock.return_value = None |
@@ -315,7 +319,9 @@ def test_extract_text_from_choice_object_delta_content(self): |
315 | 319 |
|
316 | 320 | def test_extract_text_from_choice_object_message_str(self): |
317 | 321 | """Test object choice with message as string.""" |
318 | | - choice = MagicMock(message="direct-string") |
| 322 | + choice = MagicMock() |
| 323 | + choice.delta = None # No delta, so message takes precedence |
| 324 | + choice.message = "direct-string" |
319 | 325 | result = self.handler._extract_text_from_choice(choice) |
320 | 326 | self.assertEqual(result, "direct-string") |
321 | 327 |
|
@@ -350,150 +356,7 @@ def test_extract_text_from_chunk_empty(self): |
350 | 356 | self.assertIsNone(result) |
351 | 357 | result = self.handler._extract_text_from_chunk(None) |
352 | 358 | self.assertIsNone(result) |
353 | | - |
354 | | - @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
355 | | - def test_missing_required_keys_raises_http_error(self, mock_aqua_app): |
356 | | - """Test missing required payload keys raises HTTPError.""" |
357 | | - payload = {"prompt": "test"} |
358 | | - with self.assertRaises(HTTPError) as cm: |
359 | | - list(self.handler._get_model_deployment_response("test-id", payload)) |
360 | | - self.assertEqual(cm.exception.status_code, 400) |
361 | | - self.assertIn("model", str(cm.exception)) |
362 | | - |
363 | | - @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
364 | | - @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk') |
365 | | - def test_chat_completions_no_image_yields_chunks(self, mock_extract, mock_aqua_app): |
366 | | - """Test chat completions without image streams correctly.""" |
367 | | - mock_deployment = MagicMock() |
368 | | - mock_deployment.endpoint = "https://test-endpoint" |
369 | | - mock_aqua_app.return_value.get.return_value = mock_deployment |
370 | | - |
371 | | - mock_stream = iter([MagicMock(choices=[{"delta": {"content": "hello"}}])]) |
372 | | - mock_client = MagicMock() |
373 | | - mock_client.chat.completions.create.return_value = mock_stream |
374 | | - with patch.object(self.handler, 'OpenAI', return_value=mock_client): |
375 | | - payload = { |
376 | | - "endpoint_type": PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT, |
377 | | - "prompt": "test prompt", |
378 | | - "model": "test-model" |
379 | | - } |
380 | | - result = list(self.handler._get_model_deployment_response("test-id", payload)) |
381 | | - |
382 | | - mock_extract.assert_called() |
383 | | - self.assertEqual(result, ["hello"]) |
384 | | - |
385 | | - @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
386 | | - @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk') |
387 | | - def test_text_completions_endpoint(self, mock_extract, mock_aqua_app): |
388 | | - """Test text completions endpoint path.""" |
389 | | - mock_deployment = MagicMock() |
390 | | - mock_deployment.endpoint = "https://test-endpoint" |
391 | | - mock_aqua_app.return_value.get.return_value = mock_deployment |
392 | | - |
393 | | - mock_stream = iter([MagicMock(choices=[{"delta": {"content": "text"}}])]) |
394 | | - mock_client = MagicMock() |
395 | | - mock_client.completions.create.return_value = mock_stream |
396 | | - with patch.object(self.handler, 'OpenAI', return_value=mock_client): |
397 | | - payload = { |
398 | | - "endpoint_type": PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT, |
399 | | - "prompt": "test", |
400 | | - "model": "test-model" |
401 | | - } |
402 | | - result = list(self.handler._get_model_deployment_response("test-id", payload)) |
403 | | - |
404 | | - self.assertEqual(result, ["text"]) |
405 | | - |
406 | | - @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
407 | | - @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk') |
408 | | - def test_image_chat_completions(self, mock_extract, mock_aqua_app): |
409 | | - """Test chat completions with image input.""" |
410 | | - mock_deployment = MagicMock() |
411 | | - mock_deployment.endpoint = "https://test-endpoint" |
412 | | - mock_aqua_app.return_value.get.return_value = mock_deployment |
413 | | - |
414 | | - mock_stream = iter([MagicMock()]) |
415 | | - mock_client = MagicMock() |
416 | | - mock_client.chat.completions.create.return_value = mock_stream |
417 | | - with patch.object(self.handler, 'OpenAI', return_value=mock_client): |
418 | | - payload = { |
419 | | - "endpoint_type": PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT, |
420 | | - "prompt": "describe image", |
421 | | - "model": "test-model", |
422 | | - "encoded_image": "data:image/jpeg;base64,...", |
423 | | - "file_type": "image/jpeg" |
424 | | - } |
425 | | - list(self.handler._get_model_deployment_response("test-id", payload)) |
426 | | - |
427 | | - expected_call = call( |
428 | | - model="test-model", |
429 | | - messages=[{ |
430 | | - "role": "user", |
431 | | - "content": [ |
432 | | - {"type": "text", "text": "describe image"}, |
433 | | - {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}} # Note: f-string expands |
434 | | - ] |
435 | | - }], |
436 | | - stream=True |
437 | | - ) |
438 | | - mock_client.chat.completions.create.assert_has_calls([expected_call]) |
439 | | - |
440 | | - @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
441 | | - def test_unsupported_endpoint_type_raises_error(self, mock_aqua_app): |
442 | | - """Test unsupported endpoint_type raises HTTPError.""" |
443 | | - mock_aqua_app.return_value.get.return_value = MagicMock(endpoint="test") |
444 | | - payload = { |
445 | | - "endpoint_type": "invalid-type", |
446 | | - "prompt": "test", |
447 | | - "model": "test-model" |
448 | | - } |
449 | | - with self.assertRaises(HTTPError) as cm: |
450 | | - list(self.handler._get_model_deployment_response("test-id", payload)) |
451 | | - self.assertEqual(cm.exception.status_code, 400) |
452 | | - |
453 | | - @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
454 | | - @patch.object(AquaDeploymentStreamingInferenceHandler, '_extract_text_from_chunk') |
455 | | - def test_responses_endpoint_with_params(self, mock_extract, mock_aqua_app): |
456 | | - """Test responses endpoint with temperature/top_p filtering.""" |
457 | | - mock_deployment = MagicMock() |
458 | | - mock_deployment.endpoint = "https://test-endpoint" |
459 | | - mock_aqua_app.return_value.get.return_value = mock_deployment |
460 | | - |
461 | | - mock_stream = iter([MagicMock()]) |
462 | | - mock_client = MagicMock() |
463 | | - mock_client.responses.create.return_value = mock_stream |
464 | | - with patch.object(self.handler, 'OpenAI', return_value=mock_client): |
465 | | - payload = { |
466 | | - "endpoint_type": PredictEndpoints.RESPONSES, |
467 | | - "prompt": "test", |
468 | | - "model": "test-model", |
469 | | - "temperature": 0.7, |
470 | | - "top_p": 0.9 |
471 | | - } |
472 | | - list(self.handler._get_model_deployment_response("test-id", payload)) |
473 | | - |
474 | | - mock_client.responses.create.assert_called_once_with( |
475 | | - model="test-model", |
476 | | - input="test", |
477 | | - stream=True, |
478 | | - temperature=0.7, |
479 | | - top_p=0.9 |
480 | | - ) |
481 | 359 |
|
482 | | - @patch('ads.aqua.modeldeployment.AquaDeploymentApp') |
483 | | - def test_stop_param_normalization(self, mock_aqua_app): |
484 | | - """Test stop=[] gets normalized to None.""" |
485 | | - mock_aqua_app.return_value.get.return_value = MagicMock(endpoint="test") |
486 | | - payload = { |
487 | | - "endpoint_type": PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT, |
488 | | - "prompt": "test", |
489 | | - "model": "test-model", |
490 | | - "stop": [] |
491 | | - } |
492 | | - # Just verify it doesn't crash - normalization happens before API calls |
493 | | - try: |
494 | | - next(self.handler._get_model_deployment_response("test-id", payload)) |
495 | | - except HTTPError: |
496 | | - pass # Expected due to missing client mocks, but normalization should work |
497 | 360 |
|
498 | 361 |
|
499 | 362 |
|
|
0 commit comments