diff --git a/spacy_llm/models/rest/base.py b/spacy_llm/models/rest/base.py
index df089961..900b52c8 100644
--- a/spacy_llm/models/rest/base.py
+++ b/spacy_llm/models/rest/base.py
@@ -1,7 +1,7 @@
 import abc
 import time
 from enum import Enum
-from typing import Any, Callable, Dict, Iterable, Optional
+from typing import Any, Callable, Dict, Iterable, List, Optional
 
 import requests  # type: ignore
 from requests import ConnectTimeout, ReadTimeout
@@ -34,6 +34,7 @@ def __init__(
         interval: float,
         max_request_time: float,
         context_length: Optional[int],
+        conversational_history: Optional[List[Dict[str, str]]] = None,
     ):
         """Initializes new instance of REST-based model.
         name (str): Model name.
@@ -60,6 +61,7 @@ def __init__(
         self._max_request_time = max_request_time
         self._credentials = self.credentials
         self._context_length = context_length
+        self._conversational_history = conversational_history
 
         assert self._max_tries >= 1
         assert self._interval > 0
diff --git a/spacy_llm/models/rest/openai/model.py b/spacy_llm/models/rest/openai/model.py
index 7715f12c..f0213ee8 100644
--- a/spacy_llm/models/rest/openai/model.py
+++ b/spacy_llm/models/rest/openai/model.py
@@ -125,8 +125,13 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
             else:
                 for prompt in prompts_for_doc:
-                    responses = _request(
-                        {"messages": [{"role": "user", "content": prompt}]}
-                    )
+                    user_message = [{"role": "user", "content": prompt}]
+                    # Prepend the optional conversational history to each prompt.
+                    messages = (
+                        self._conversational_history + user_message
+                        if self._conversational_history
+                        else user_message
+                    )
+                    responses = _request({"messages": messages})
                     if "error" in responses:
                         return responses["error"]
 
diff --git a/spacy_llm/models/rest/openai/registry.py b/spacy_llm/models/rest/openai/registry.py
index 3c3793ff..c7e68141 100644
--- a/spacy_llm/models/rest/openai/registry.py
+++ b/spacy_llm/models/rest/openai/registry.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 from confection import SimpleFrozenDict
 
@@ -21,6 +21,41 @@
 endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
 """
 
+@registry.llm_models("spacy.GPT-4.v4")
+def openai_gpt_4_v4(
+    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
+    name: str = "gpt-4",
+    strict: bool = OpenAI.DEFAULT_STRICT,
+    max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
+    interval: float = OpenAI.DEFAULT_INTERVAL,
+    max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
+    endpoint: Optional[str] = None,
+    context_length: Optional[int] = None,
+    conversational_history: Optional[List[Dict[str, str]]] = None,
+) -> OpenAI:
+    """Returns OpenAI instance for 'gpt-4' model using REST to prompt API.
+
+    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
+    name (str): Model name to use. Can be any model name supported by the OpenAI API, e.g. 'gpt-4',
+        "gpt-4-1106-preview", ....
+    conversational_history (Optional[List[Dict[str, str]]]): Optional ChatML-style conversation
+        history whose user/assistant messages are prepended to each prompt request.
+    RETURNS (OpenAI): OpenAI instance for 'gpt-4' model.
+
+    DOCS: https://spacy.io/api/large-language-models#models
+    """
+    return OpenAI(
+        name=name,
+        endpoint=endpoint or Endpoints.CHAT.value,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=context_length,
+        conversational_history=conversational_history,
+    )
+
 
 @registry.llm_models("spacy.GPT-4.v3")
 def openai_gpt_4_v3(
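
A minimal usage sketch for reviewers (hypothetical, not part of the patch): it assumes `OPENAI_API_KEY` is set in the environment and calls the `spacy.GPT-4.v4` factory added above directly. With `conversational_history` set, every request sends those messages first, followed by the task prompt as the final user message.

```python
# Hypothetical reviewer sketch; the factory name and module path are taken
# from this patch. Requires OPENAI_API_KEY in the environment.
from spacy_llm.models.rest.openai.registry import openai_gpt_4_v4

model = openai_gpt_4_v4(
    name="gpt-4",
    conversational_history=[
        # ChatML-style turns prepended before each task prompt.
        {"role": "user", "content": "Answer in one word where possible."},
        {"role": "assistant", "content": "Understood."},
    ],
)

# Each request body then becomes:
#   {"messages": conversational_history + [{"role": "user", "content": prompt}]}
```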