import requests
from http import HTTPStatus
from enum import Enum
-from typing import Dict, List
+from typing import Dict, List, Union


class AIProvider(Enum):
@@ -44,7 +44,30 @@ def __init__(self, llm: LLM, access_token: str):
    def _get_message(self, message: str, role: str = "user"):
        return {"role": role, "content": message}

+    def _handle_api_response(self, response) -> Dict[str, Union[str, int]]:
+        """
+        Handles the API response, returning a success or error structure that the frontend can use.
+        """
+        if response.status_code == HTTPStatus.OK:
+            return {
+                "status": "success",
+                "data": response.json()["choices"][0]["message"]["content"],
+            }
+        elif response.status_code == HTTPStatus.UNAUTHORIZED:
+            return {
+                "status": "error",
+                "message": "Unauthorized Access: Your access token is either missing, expired, or invalid. Please ensure that you are providing a valid token.",
+            }
+        else:
+            return {
+                "status": "error",
+                "message": f"Unexpected error: {response.text}",
+            }
+
    def _open_ai_fetch_completion_open_ai(self, messages: List[Dict[str, str]]):
+        """
+        Handles the request to OpenAI API for fetching completions.
+        """
        payload = {
            "model": self.LLM_NAME_TO_MODEL_MAP[self._llm],
            "temperature": 0.6,
@@ -53,13 +76,12 @@ def _open_ai_fetch_completion_open_ai(self, messages: List[Dict[str, str]]):
        api_url = "https://api.openai.com/v1/chat/completions"
        response = requests.post(api_url, headers=self._headers, json=payload)

-        print(payload, api_url, response)
-        if response.status_code != HTTPStatus.OK:
-            raise Exception(response.json())
-
-        return response.json()
+        return self._handle_api_response(response)

    def _fireworks_ai_fetch_completions(self, messages: List[Dict[str, str]]):
+        """
+        Handles the request to Fireworks AI API for fetching completions.
+        """
        payload = {
            "model": self.LLM_NAME_TO_MODEL_MAP[self._llm],
            "temperature": 0.6,
@@ -73,28 +95,28 @@ def _fireworks_ai_fetch_completions(self, messages: List[Dict[str, str]]):
        api_url = "https://api.fireworks.ai/inference/v1/chat/completions"
        response = requests.post(api_url, headers=self._headers, json=payload)

-        if response.status_code != HTTPStatus.OK:
-            raise Exception(response.json())
-
-        return response.json()
-
-    def _fetch_completion(self, messages: List[Dict[str, str]]):
+        return self._handle_api_response(response)

+    def _fetch_completion(
+        self, messages: List[Dict[str, str]]
+    ) -> Dict[str, Union[str, int]]:
+        """
+        Fetches the completion using the appropriate AI provider based on the LLM.
+        """
        if self._ai_provider == AIProvider.FIREWORKS_AI:
-            return self._fireworks_ai_fetch_completions(messages)["choices"][0][
-                "message"
-            ]["content"]
+            return self._fireworks_ai_fetch_completions(messages)

        if self._ai_provider == AIProvider.OPEN_AI:
-            return self._open_ai_fetch_completion_open_ai(messages)["choices"][0][
-                "message"
-            ]["content"]
+            return self._open_ai_fetch_completion_open_ai(messages)

-        raise Exception(f"Invalid AI provider {self._ai_provider}")
+        return {
+            "status": "error",
+            "message": f"Invalid AI provider {self._ai_provider}",
+        }

    def get_dora_metrics_score(
        self, four_keys_data: Dict[str, float]
-    ) -> Dict[str, str]:
+    ) -> Dict[str, Union[str, int]]:
        """
        Calculate the DORA metrics score using input data and an LLM (Language Learning Model).
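As a quick usage sketch (not part of the diff): a caller holding an instance of this class, named `client` here purely for illustration, can branch on the uniform `status` key that `_handle_api_response` and `_fetch_completion` now return.

# Illustrative only: `client` is an assumed instance of the class changed above.
result = client._fetch_completion(
    [{"role": "user", "content": "Summarize these DORA metrics."}]
)

if result["status"] == "success":
    print(result["data"])     # completion text taken from choices[0].message.content
else:
    print(result["message"])  # human-readable error the frontend can display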