From f6e0f6584978bf7ebade604892b54aa648a773fa Mon Sep 17 00:00:00 2001
From: NatashaTheRobot
Date: Fri, 29 Nov 2024 14:36:49 +0530
Subject: [PATCH] WIP

---
 Package.resolved                                   |  14 +-
 Package.swift                                      |  28 ++
 .../ModelIdentifier.Provider.swift                 |  11 +
 .../Service/_MIServiceTypeIdentifier.swift         |   1 +
 .../Intramodular/ElevenLabs.Model.swift            |   4 +-
 .../xAI.APISpecification.RequestBodies.swift       | 144 ++++++++++
 .../xAI.APISpecification.ResponseBodies.swift      |   6 +
 .../API/xAI.APISpecification.swift                 | 120 ++++++++
 .../Models/xAI.ChatCompletion.swift                |  36 +++
 .../Models/xAI.ChatFunctionDefinition.swift        |  17 ++
 .../Intramodular/Models/xAI.ChatMessage.swift      | 210 ++++++++++++++
 .../Models/xAI.ChatMessageBody.swift               | 215 +++++++++++++++
 .../Intramodular/Models/xAI.ChatRole.swift         |  41 +++
 .../xAI/Intramodular/Models/xAI.Tool.swift         |  51 ++++
 .../Intramodular/xAI+LLMRequestHandling.swift      | 156 +++++++++++
 .../xAI.ChatMessage+LargeLanguageModels.swift      | 258 ++++++++++++++++++
 Sources/xAI/Intramodular/xAI.Client.swift          |  50 ++++
 Sources/xAI/Intramodular/xAI.Model.swift           |  52 ++++
 Sources/xAI/Intramodular/xAI.swift                 |   7 +
 Sources/xAI/module.swift                           |   5 +
 .../Intramodular/EmbeddingsTests.swift             |   1 -
 Tests/xAI/Intramodular/CompletionTests.swift       |  43 +++
 .../Intramodular/FunctionCallingTests.swift        | 103 +++++++
 Tests/xAI/module.swift                             |  12 +
 24 files changed, 1575 insertions(+), 10 deletions(-)
 create mode 100644 Sources/xAI/Intramodular/API/xAI.APISpecification.RequestBodies.swift
 create mode 100644 Sources/xAI/Intramodular/API/xAI.APISpecification.ResponseBodies.swift
 create mode 100644 Sources/xAI/Intramodular/API/xAI.APISpecification.swift
 create mode 100644 Sources/xAI/Intramodular/Models/xAI.ChatCompletion.swift
 create mode 100644 Sources/xAI/Intramodular/Models/xAI.ChatFunctionDefinition.swift
 create mode 100644 Sources/xAI/Intramodular/Models/xAI.ChatMessage.swift
 create mode 100644 Sources/xAI/Intramodular/Models/xAI.ChatMessageBody.swift
 create mode 100644 Sources/xAI/Intramodular/Models/xAI.ChatRole.swift
 create mode 100644 Sources/xAI/Intramodular/Models/xAI.Tool.swift
 create mode 100644 Sources/xAI/Intramodular/xAI+LLMRequestHandling.swift
 create mode 100644 Sources/xAI/Intramodular/xAI.ChatMessage+LargeLanguageModels.swift
 create mode 100644 Sources/xAI/Intramodular/xAI.Client.swift
 create mode 100644 Sources/xAI/Intramodular/xAI.Model.swift
 create mode 100644 Sources/xAI/Intramodular/xAI.swift
 create mode 100644 Sources/xAI/module.swift
 create mode 100644 Tests/xAI/Intramodular/CompletionTests.swift
 create mode 100644 Tests/xAI/Intramodular/FunctionCallingTests.swift
 create mode 100644 Tests/xAI/module.swift

diff --git a/Package.resolved b/Package.resolved
index e257e27a..666d3ccf 100644
--- a/Package.resolved
+++ b/Package.resolved
@@ -1,5 +1,5 @@
 {
-  "originHash" : "864ef9201dffd6ebf57da3ab4413cc1001edb70cd0c6d264b91e19b3967fb7ba",
+  "originHash" : "0db351194242c55d1647132cde67f92244bdfb0f30f99916a91c9464f714a06e",
   "pins" : [
     {
       "identity" : "corepersistence",
@@ -7,7 +7,7 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/vmanot/CorePersistence.git",
       "state" : {
         "branch" : "main",
-        "revision" : "38fd5271fa906a2d8395e4b42724142886a3c763"
+        "revision" : "e5026e82410d4140aa5468fe98069e4f5c4fa5cf"
       }
     },
     {
@@ -16,7 +16,7 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/vmanot/Merge.git",
       "state" : {
         "branch" : "master",
-        "revision" : "e8bc37c8dc203cab481efedd71237c151882c007"
+        "revision" : "2a47b62831164bea212a1616cfe3d4da32902f7f"
       }
     },
     {
@@ -34,7 +34,7 @@
       "location" : "https://github.com/vmanot/Swallow.git",
       "state" : {
         "branch" : "master",
-        "revision" : "6227a1114e341daf54e90df61e173599b187a9b1"
+        "revision" : "f91811c49863ed506da4f3aed21f6c216a258ca0"
       }
     },
     {
@@ -51,8 +51,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/apple/swift-syntax.git",
       "state" : {
-        "revision" : "2bc86522d115234d1f588efe2bcb4ce4be8f8b82",
-        "version" : "510.0.3"
+        "revision" : "0687f71944021d616d34d922343dcef086855920",
+        "version" : "600.0.1"
       }
     },
     {
@@ -70,7 +70,7 @@
       "location" : "https://github.com/SwiftUIX/SwiftUIX.git",
       "state" : {
         "branch" : "master",
-        "revision" : "836fc284a9bb07fc9ab6d2dce6ebd0e32aabde26"
+        "revision" : "275be4fd06b570b71fe1eb47b4246f4c4285b177"
       }
     }
   ],

diff --git a/Package.swift b/Package.swift
index e4abb826..244517eb 100644
--- a/Package.swift
+++ b/Package.swift
@@ -22,6 +22,7 @@ let package = Package(
         "ElevenLabs",
         "_Gemini",
         "Groq",
+        "xAI",
         "HuggingFace",
         "Jina",
         "Mistral",
@@ -150,6 +151,21 @@ let package = Package(
                 .enableExperimentalFeature("AccessLevelOnImport")
             ]
         ),
+        .target(
+            name: "xAI",
+            dependencies: [
+                "CorePersistence",
+                "CoreMI",
+                "LargeLanguageModels",
+                "Merge",
+                "NetworkKit",
+                "Swallow"
+            ],
+            path: "Sources/xAI",
+            swiftSettings: [
+                .enableExperimentalFeature("AccessLevelOnImport")
+            ]
+        ),
         .target(
             name: "Ollama",
             dependencies: [
@@ -275,6 +291,7 @@ let package = Package(
                 "Mistral",
                 "_Gemini",
                 "Groq",
+                "xAI",
                 "Ollama",
                 "OpenAI",
                 "Perplexity",
@@ -345,6 +362,17 @@ let package = Package(
                 .enableExperimentalFeature("AccessLevelOnImport")
             ]
         ),
+        .testTarget(
+            name: "xAITests",
+            dependencies: [
+                "AI",
+                "Swallow"
+            ],
+            path: "Tests/xAI",
+            swiftSettings: [
+                .enableExperimentalFeature("AccessLevelOnImport")
+            ]
+        ),
         .testTarget(
             name: "PerplexityTests",
             dependencies: [

diff --git a/Sources/CoreMI/Intramodular/Model Identifier/ModelIdentifier.Provider.swift b/Sources/CoreMI/Intramodular/Model Identifier/ModelIdentifier.Provider.swift
index fbdc91cc..0ec80659 100644
--- a/Sources/CoreMI/Intramodular/Model Identifier/ModelIdentifier.Provider.swift
+++ b/Sources/CoreMI/Intramodular/Model Identifier/ModelIdentifier.Provider.swift
@@ -15,6 +15,7 @@ extension ModelIdentifier {
         case _Fal
         case _Mistral
         case _Groq
+        case _xAI
         case _Ollama
         case _OpenAI
         case _Gemini
@@ -47,6 +48,10 @@ extension ModelIdentifier {
             Self._Groq
         }
 
+        public static var xAI: Self {
+            Self._xAI
+        }
+
         public static var gemini: Self {
             Self._Gemini
         }
@@ -92,6 +97,8 @@ extension ModelIdentifier.Provider: CustomStringConvertible {
                 return "Mistral"
             case ._Groq:
                 return "Groq"
+            case ._xAI:
+                return "xAI"
             case ._Ollama:
                 return "Ollama"
             case ._OpenAI:
@@ -129,6 +136,8 @@ extension ModelIdentifier.Provider: RawRepresentable {
                 return "mistral"
             case ._Groq:
                 return "groq"
+            case ._xAI:
+                return "xai"
             case ._Ollama:
                 return "ollama"
             case ._OpenAI:
@@ -164,6 +173,8 @@ extension ModelIdentifier.Provider: RawRepresentable {
                 self = ._Mistral
             case Self._Groq.rawValue:
                 self = ._Groq
+            case Self._xAI.rawValue:
+                self = ._xAI
             case Self._OpenAI.rawValue:
                 self = ._OpenAI
             case Self._Gemini.rawValue:

diff --git a/Sources/CoreMI/Intramodular/Service/_MIServiceTypeIdentifier.swift b/Sources/CoreMI/Intramodular/Service/_MIServiceTypeIdentifier.swift
index 00fe0f99..413465c0 100644
--- a/Sources/CoreMI/Intramodular/Service/_MIServiceTypeIdentifier.swift
+++ b/Sources/CoreMI/Intramodular/Service/_MIServiceTypeIdentifier.swift
@@ -38,4 +38,5 @@ extension _MIServiceTypeIdentifier {
     public static let _VoyageAI = _MIServiceTypeIdentifier(rawValue: "hajat-fufoh-janaf-disam")
     public static let _Cohere = _MIServiceTypeIdentifier(rawValue: "guzob-fipin-navij-duvon")
"guzob-fipin-navij-duvon") public static let _TogetherAI = _MIServiceTypeIdentifier(rawValue: "pafob-vopoj-lurig-zilur") + public static let _xAI = _MIServiceTypeIdentifier(rawValue: "niluj-futol-guhaj-pabas") } diff --git a/Sources/ElevenLabs/Intramodular/ElevenLabs.Model.swift b/Sources/ElevenLabs/Intramodular/ElevenLabs.Model.swift index 6d181027..35f1b09a 100644 --- a/Sources/ElevenLabs/Intramodular/ElevenLabs.Model.swift +++ b/Sources/ElevenLabs/Intramodular/ElevenLabs.Model.swift @@ -33,7 +33,7 @@ extension ElevenLabs.Model: CustomStringConvertible { extension ElevenLabs.Model: ModelIdentifierRepresentable { public init(from identifier: ModelIdentifier) throws { - guard identifier.provider == ._Groq, identifier.revision == nil else { + guard identifier.provider == ._ElevenLabs, identifier.revision == nil else { throw Never.Reason.illegal } @@ -46,7 +46,7 @@ extension ElevenLabs.Model: ModelIdentifierRepresentable { public func __conversion() -> ModelIdentifier { ModelIdentifier( - provider: ._Groq, + provider: ._ElevenLabs, name: rawValue, revision: nil ) diff --git a/Sources/xAI/Intramodular/API/xAI.APISpecification.RequestBodies.swift b/Sources/xAI/Intramodular/API/xAI.APISpecification.RequestBodies.swift new file mode 100644 index 00000000..1753bfaf --- /dev/null +++ b/Sources/xAI/Intramodular/API/xAI.APISpecification.RequestBodies.swift @@ -0,0 +1,144 @@ + + +import Foundation + +extension xAI.APISpecification.RequestBodies { + + /* https://docs.x.ai/api/endpoints#chat-completions */ + struct ChatCompletions: Codable, Hashable, Sendable { + private enum CodingKeys: String, CodingKey { + case model + case messages + case temperature + case topProbabilityMass = "top_p" + case choices = "n" + case stream + case stop + case maxTokens = "max_tokens" + case presencePenalty = "presence_penalty" + case frequencyPenalty = "frequency_penalty" + case logprobs = "logprobs" + case logitBias = "logit_bias" + case seed = "seed" + case topLogprobs = "top_logprobs" + case user + case tools + } + + /* Model name for the model to use. */ + var model: xAI.Model + + /* A list of messages that make up the the chat conversation. Different models support different message types, such as image and text.*/ + var messages: [xAI.ChatMessage] + + /* What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.*/ + var temperature: Double? + + /* An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. It is generally recommended to alter this or `temperature` but not both.*/ + var topProbabilityMass: Double? + + /*The maximum number of tokens that can be generated in the chat completion. This value can be used to control costs for text generated via API.*/ + var maxTokens: Int? + + /* If set, partial message deltas will be sent. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a `data: [DONE]` message.*/ + var stream: Bool? + + /* If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result. Determinism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.*/ + var seed: Int? 
+
+        /* Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. */
+        var frequencyPenalty: Double?
+
+        /* A JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. */
+        var logitBias: [String: Int]?
+
+        /* Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. */
+        var logprobs: Bool?
+
+        /* How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. */
+        var choices: Int?
+
+        /* Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. */
+        var presencePenalty: Double?
+
+        /* Up to 4 sequences where the API will stop generating further tokens. */
+        var stop: [String]?
+
+        /* An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. */
+        var topLogprobs: Int?
+
+        /* A unique identifier representing your end-user, which can help xAI to monitor and detect abuse. */
+        var user: String?
+
+        /* A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported. */
+        var tools: [xAI.Tool]?
+
+        init(
+            messages: [xAI.ChatMessage],
+            model: xAI.Model,
+            frequencyPenalty: Double? = nil,
+            logitBias: [String : Int]? = nil,
+            logprobs: Bool? = nil,
+            topLogprobs: Int? = nil,
+            maxTokens: Int? = nil,
+            choices: Int? = nil,
+            presencePenalty: Double? = nil,
+            seed: Int? = nil,
+            stop: [String]? = nil,
+            stream: Bool? = nil,
+            temperature: Double? = nil,
+            topProbabilityMass: Double? = nil,
+            user: String? = nil,
+            functions: [xAI.ChatFunctionDefinition]? = nil
+        ) {
+            self.messages = messages
+            self.model = model
+            self.frequencyPenalty = frequencyPenalty
+            self.logitBias = logitBias
+            self.logprobs = logprobs
+            self.topLogprobs = topLogprobs
+            self.maxTokens = maxTokens
+            self.choices = choices
+            self.presencePenalty = presencePenalty
+            self.seed = seed
+            self.stop = stop
+            self.stream = stream
+            self.temperature = temperature
+            self.topProbabilityMass = topProbabilityMass
+            self.user = user
+            self.tools = functions?.map { xAI.Tool.function($0) }
+        }
+
+        init(
+            user: String?,
+            messages: [xAI.ChatMessage],
+            model: xAI.Model,
+            temperature: Double?,
+            topProbabilityMass: Double?,
+            choices: Int?,
+            stream: Bool?,
+            stop: [String]?,
+            maxTokens: Int?,
+            presencePenalty: Double?,
+            frequencyPenalty: Double?
+        ) {
+            self.user = user
+            self.messages = messages
+            self.model = model
+            self.temperature = temperature
+            self.topProbabilityMass = topProbabilityMass
+            self.choices = choices
+            self.stream = stream
+            self.stop = stop
+            self.maxTokens = maxTokens
+            self.presencePenalty = presencePenalty
+            self.frequencyPenalty = frequencyPenalty
+
+            self.logitBias = nil
+            self.logprobs = nil
+            self.topLogprobs = nil
+            self.seed = nil
+            self.tools = nil
+        }
+    }
+}

diff --git a/Sources/xAI/Intramodular/API/xAI.APISpecification.ResponseBodies.swift b/Sources/xAI/Intramodular/API/xAI.APISpecification.ResponseBodies.swift
new file mode 100644
index 00000000..6ca529f0
--- /dev/null
+++ b/Sources/xAI/Intramodular/API/xAI.APISpecification.ResponseBodies.swift
@@ -0,0 +1,6 @@
+
+import Foundation
+
+extension xAI.APISpecification.ResponseBodies {
+
+}

diff --git a/Sources/xAI/Intramodular/API/xAI.APISpecification.swift b/Sources/xAI/Intramodular/API/xAI.APISpecification.swift
new file mode 100644
index 00000000..89be500e
--- /dev/null
+++ b/Sources/xAI/Intramodular/API/xAI.APISpecification.swift
@@ -0,0 +1,120 @@
+
+
+import NetworkKit
+import FoundationX
+import Swallow
+
+extension xAI {
+    public enum APIError: APIErrorProtocol {
+        public typealias API = xAI.APISpecification
+
+        case apiKeyMissing
+        case incorrectAPIKeyProvided
+        case rateLimitExceeded
+        case badRequest(request: API.Request?, error: API.Request.Error)
+        case runtime(AnyError)
+
+        public var traits: ErrorTraits {
+            [.domain(.networking)]
+        }
+    }
+
+    public struct APISpecification: RESTAPISpecification {
+        public typealias Error = APIError
+
+        public struct Configuration: Codable, Hashable {
+            public var apiKey: String?
+        }
+
+        public let configuration: Configuration
+
+        public var host: URL {
+            URL(string: "https://api.x.ai/v1")!
+        }
+
+        public var id: some Hashable {
+            configuration
+        }
+
+        @POST
+        @Path("chat/completions")
+        var chatCompletions = Endpoint()
+    }
+}
+
+extension xAI.APISpecification {
+    public final class Endpoint: BaseHTTPEndpoint {
+        override public func buildRequestBase(
+            from input: Input,
+            context: BuildRequestContext
+        ) throws -> Request {
+            let configuration = context.root.configuration
+
+            return try super
+                .buildRequestBase(from: input, context: context)
+                .jsonBody(input, keyEncodingStrategy: .convertToSnakeCase)
+                .header(.contentType(.json))
+                .header(.accept(.json))
+                .header(.authorization(.bearer, configuration.apiKey.unwrap()))
+        }
+
+        struct _ErrorWrapper: Codable, Hashable, Sendable {
+            struct Error: Codable, Hashable, Sendable {
+                let type: String
+                let param: AnyCodable?
+                let message: String
+            }
+
+            let error: Error
+        }
+
+        override public func decodeOutputBase(
+            from response: Request.Response,
+            context: DecodeOutputContext
+        ) throws -> Output {
+            do {
+                try response.validate()
+            } catch {
+                let apiError: Error
+
+                if let error = error as? Request.Error {
+                    if let error = try? response.decode(
+                        _ErrorWrapper.self,
+                        keyDecodingStrategy: .convertFromSnakeCase
+                    ).error {
+                        if error.message.contains("You didn't provide an API key") {
+                            throw Error.apiKeyMissing
+                        } else if error.message.contains("Incorrect API key provided") {
+                            throw Error.incorrectAPIKeyProvided
+                        }
+                    }
+
+                    if response.statusCode.rawValue == 429 {
+                        apiError = .rateLimitExceeded
+                    } else {
+                        apiError = .badRequest(request: nil, error: error)
+                    }
+                } else {
+                    apiError = .runtime(error)
+                }
+
+                throw apiError
+            }
+
+            return try response.decode(
+                Output.self,
+                keyDecodingStrategy: .convertFromSnakeCase
+            )
+        }
+    }
+}
+
+extension xAI.APISpecification {
+    public enum RequestBodies: _StaticSwift.Namespace {
+
+    }
+
+    public enum ResponseBodies: _StaticSwift.Namespace {
+
+    }
+}

diff --git a/Sources/xAI/Intramodular/Models/xAI.ChatCompletion.swift b/Sources/xAI/Intramodular/Models/xAI.ChatCompletion.swift
new file mode 100644
index 00000000..084aebdc
--- /dev/null
+++ b/Sources/xAI/Intramodular/Models/xAI.ChatCompletion.swift
@@ -0,0 +1,36 @@
+
+
+import Foundation
+
+extension xAI {
+    public struct ChatCompletion: Codable, Hashable, Sendable {
+        public struct Choice: Codable, Hashable, Sendable {
+            public enum FinishReason: String, Codable, Hashable, Sendable {
+                case stop = "stop"
+                case length = "length"
+                case modelLength = "model_length"
+                case toolCalls = "tool_calls"
+            }
+
+            public let index: Int
+            public let message: ChatMessage
+            public let finishReason: FinishReason
+        }
+
+        public struct Usage: Codable, Hashable, Sendable {
+            public let promptTokens: Int
+            public let completionTokens: Int
+            public let totalTokens: Int
+        }
+
+        public var id: String
+        public var object: String
+        public var created: Date
+        public var model: Model
+        public var choices: [Choice]
+        public let usage: Usage
+        public let systemFingerprint: String
+    }
+}

diff --git a/Sources/xAI/Intramodular/Models/xAI.ChatFunctionDefinition.swift b/Sources/xAI/Intramodular/Models/xAI.ChatFunctionDefinition.swift
new file mode 100644
index 00000000..ca4b8169
--- /dev/null
+++ b/Sources/xAI/Intramodular/Models/xAI.ChatFunctionDefinition.swift
@@ -0,0 +1,17 @@
+
+import CorePersistence
+
+extension xAI {
+    public struct ChatFunctionDefinition: Codable, Hashable, Sendable {
+        public let name: String
+        public let description: String
+        public let parameters: JSONSchema
+
+        public init(name: String, description: String, parameters: JSONSchema) {
+            self.name = name
+            self.description = description
+            self.parameters = parameters
+        }
+    }
+}

diff --git a/Sources/xAI/Intramodular/Models/xAI.ChatMessage.swift b/Sources/xAI/Intramodular/Models/xAI.ChatMessage.swift
new file mode 100644
index 00000000..13b3f651
--- /dev/null
+++ b/Sources/xAI/Intramodular/Models/xAI.ChatMessage.swift
@@ -0,0 +1,210 @@
+
+import CorePersistence
+import Diagnostics
+import LargeLanguageModels
+import Swallow
+
+extension xAI {
+    public struct ChatMessage: Hashable, Sendable {
+        public typealias ID = String
+
+        public let id: ID
+        public let role: ChatRole
+        public var body: ChatMessageBody
+
+        public init(
+            id: ID? = nil,
+            role: ChatRole,
+            body: ChatMessageBody
+        ) {
+            switch body {
+            case .text:
+                assert(role != .function)
+            case .content:
+                assert(role != .function)
+            case .functionCall:
+                assert(role == .assistant)
+            case .toolCalls(_):
+                assert(role == .assistant)
+            case .functionInvocation:
+                assert(role == .function)
+            }
+
+            self.id = id ?? UUID().stringValue // FIXME: !!!
+            self.role = role
+            self.body = body
+        }
+    }
+
+    public enum FunctionCallingStrategy: Codable, Hashable, Sendable {
+        enum CodingKeys: String, CodingKey {
+            case none = "none"
+            case auto = "auto"
+            case function = "name"
+        }
+
+        case none
+        case auto
+        case function(String)
+
+        public init(from decoder: Decoder) throws {
+            switch try decoder._determineContainerKind() {
+            case .singleValue:
+                let rawValue = try decoder.singleValueContainer().decode(String.self)
+
+                switch rawValue {
+                case CodingKeys.none.rawValue:
+                    self = .none
+                case CodingKeys.auto.rawValue:
+                    self = .auto
+                default:
+                    throw DecodingError.dataCorrupted(.init(codingPath: []))
+                }
+            case .keyed:
+                let container = try decoder.container(keyedBy: CodingKeys.self)
+
+                self = try .function(container.decode(String.self, forKey: .function))
+            default:
+                throw DecodingError.dataCorrupted(.init(codingPath: []))
+            }
+        }
+
+        public func encode(to encoder: Encoder) throws {
+            switch self {
+            case .none:
+                var container = encoder.singleValueContainer()
+
+                try container.encode(CodingKeys.none.rawValue)
+            case .auto:
+                var container = encoder.singleValueContainer()
+
+                try container.encode(CodingKeys.auto.rawValue)
+            case .function(let name):
+                var container = encoder.container(keyedBy: CodingKeys.self)
+
+                try container.encode(name, forKey: .function)
+            }
+        }
+    }
+}
+
+// MARK: - Conformances
+
+extension xAI.ChatMessage: AbstractLLM.ChatMessageConvertible {
+    public func __conversion() throws -> AbstractLLM.ChatMessage {
+        .init(
+            id: .init(rawValue: id),
+            role: try role.__conversion(),
+            content: try PromptLiteral(from: self)
+        )
+    }
+}
+
+extension xAI.ChatMessage: Codable {
+    public enum CodingKeys: CodingKey {
+        case id
+        case role
+        case content
+        case name
+        case functionCall
+        case toolCalls
+    }
+
+    public init(from decoder: Decoder) throws {
+        let container = try decoder.container(keyedBy: CodingKeys.self)
+
+        self.id = try container.decodeIfPresent(String.self, forKey: .id) ?? UUID().stringValue // FIXME
+        self.role = try container.decode(xAI.ChatRole.self, forKey: .role)
+
+        switch role {
+        case .function:
+            self.body = .functionInvocation(
+                .init(
+                    name: try container.decode(String.self, forKey: .name),
+                    response: try container.decode(String.self, forKey: .content)
+                )
+            )
+        case .assistant:
+            if let toolCalls = try container.decodeIfPresent([xAI.ToolCall].self, forKey: .toolCalls) {
+                if let function = toolCalls.first?.function {
+                    self.body = .functionCall(function)
+                } else {
+                    self.body = .toolCalls(toolCalls)
+                }
+            } else {
+                self.body = try .content(container.decode(String.self, forKey: .content))
+            }
+        default:
+            self.body = try .content(container.decode(String.self, forKey: .content))
+        }
+    }
+
+    public func encode(to encoder: Encoder) throws {
+        var container = encoder.container(keyedBy: CodingKeys.self)
+
+        try container.encode(role, forKey: .role)
+
+        switch body {
+        case .text(let content):
+            try container.encode(content, forKey: .content)
+        case .content(let content):
+            try container.encode(content, forKey: .content)
+        case .functionCall(let call):
+            try _tryAssert(role == .assistant)
+
+            try container.encode(call, forKey: .functionCall)
+            try container.encodeNil(forKey: .content)
+        case .toolCalls(let calls):
+            try _tryAssert(role == .assistant)
+
+            try container.encode(calls, forKey: .toolCalls)
+            try container.encodeNil(forKey: .content)
+        case .functionInvocation(let invocation):
+            try _tryAssert(role == .function)
+
+            try container.encode(invocation.name, forKey: .name)
+            try container.encode(invocation.response, forKey: .content)
+        }
+    }
+}
+
+// MARK: - Initializers
+
+extension xAI.ChatMessage {
+    public init(
+        id: ID? = nil,
+        role: xAI.ChatRole,
+        body: String
+    ) {
+        self.init(
+            id: id,
+            role: role,
+            body: .content(body)
+        )
+    }
+
+    public init(
+        role: xAI.ChatRole,
+        content: String
+    ) {
+        self.init(
+            role: role,
+            body: content
+        )
+    }
+
+    public static func system(
+        _ content: String
+    ) -> Self {
+        Self(id: UUID().stringValue, role: .system, body: .content(content))
+    }
+
+    public static func user(
+        _ content: String
+    ) -> Self {
+        Self(id: UUID().stringValue, role: .user, body: .content(content))
+    }
+}

diff --git a/Sources/xAI/Intramodular/Models/xAI.ChatMessageBody.swift b/Sources/xAI/Intramodular/Models/xAI.ChatMessageBody.swift
new file mode 100644
index 00000000..1f9122e1
--- /dev/null
+++ b/Sources/xAI/Intramodular/Models/xAI.ChatMessageBody.swift
@@ -0,0 +1,215 @@
+
+
+import CorePersistence
+import Diagnostics
+import Swift
+
+extension xAI {
+    public enum ChatMessageBody: Hashable, Sendable {
+        public struct FunctionCall: Codable, Hashable, Sendable {
+            public let name: String
+            public let arguments: String
+
+            public init(name: String, arguments: String) {
+                self.name = name
+                self.arguments = arguments
+            }
+        }
+
+        public struct FunctionInvocation: Codable, Hashable, Sendable {
+            public let name: String
+            public let response: String
+
+            public init(name: String, response: String) {
+                self.name = name
+                self.response = response
+            }
+        }
+
+        case text(String)
+        case content([_Content])
+        /// The call made to a function provided to the LLM.
+        case functionCall(FunctionCall)
+        case toolCalls([ToolCall])
+        /// The result of invoking a function that was provided to the LLM.
+        case functionInvocation(FunctionInvocation)
+    }
+}
+
+// MARK: - Initializers
+
+extension xAI.ChatMessageBody {
+    public static func content(_ text: String) -> Self {
+        .text(text)
+    }
+}
+
+// MARK: - Extensions
+
+extension xAI.ChatMessageBody {
+    public var isEmpty: Bool {
+        switch self {
+        case .text(let text):
+            return text.isEmpty
+        case .content(let content):
+            return content.isEmpty
+        case .functionCall:
+            return false
+        case .toolCalls:
+            return false
+        case .functionInvocation:
+            return false
+        }
+    }
+
+    var _textValue: String? {
+        guard case .text(let string) = self else {
+            return nil
+        }
+
+        return string
+    }
+
+    public mutating func append(_ newText: String) throws {
+        switch self {
+        case .text(let text):
+            self = .text(text.appending(contentsOf: newText))
+        case .content(let content):
+            self = .content(content.appending(.text(newText)))
+        case .functionCall:
+            throw Never.Reason.illegal
+        case .toolCalls:
+            throw Never.Reason.illegal
+        case .functionInvocation:
+            throw Never.Reason.illegal
+        }
+    }
+
+    public static func += (lhs: inout Self, rhs: String) throws {
+        try lhs.append(rhs)
+    }
+}
+
+// MARK: - Auxiliary
+
+extension xAI.ChatMessageBody {
+    enum _ContentType: String, Codable, Hashable, Sendable {
+        case text = "text"
+        case imageURL = "image_url"
+    }
+
+    public enum _Content: Sendable {
+        public struct ImageURL: Codable, Hashable, Sendable {
+            /// https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
+            public enum ImageDetail: String, Codable, Hashable, Sendable {
+                case low
+                case high
+                case auto
+            }
+
+            public let url: URL
+            public let detail: ImageDetail
+
+            public init(url: URL, detail: ImageDetail = .auto) {
+                self.url = url
+                self.detail = detail
+            }
+        }
+
+        case text(String)
+        case imageURL(ImageURL)
+
+        public static func imageURL(_ url: URL) -> Self {
+            Self.imageURL(ImageURL(url: url, detail: .auto))
+        }
+    }
+}
+
+// MARK: - Conformances
+
+extension xAI.ChatMessageBody._Content: Codable {
+    fileprivate enum CodingKeys: String, CodingKey {
+        case type
+        case text
+        case imageURL = "image_url"
+    }
+
+    public init(from decoder: Decoder) throws {
+        let container = try decoder.container(keyedBy: CodingKeys.self)
+
+        let contentType = try container.decode(xAI.ChatMessageBody._ContentType.self, forKey: .type)
+
+        switch contentType {
+        case .text:
+            self = .text(try container.decode(String.self, forKey: .text))
+        case .imageURL:
+            self = .imageURL(try container.decode(ImageURL.self, forKey: .imageURL))
+        }
+    }
+
+    public func encode(to encoder: Encoder) throws {
+        var container = encoder.container(keyedBy: CodingKeys.self)
+
+        switch self {
+        case .text(let text):
+            try container.encode("text", forKey: .type)
+            try container.encode(text, forKey: .text)
+        case .imageURL(let imageURL):
+            try container.encode("image_url", forKey: .type)
+            try container.encode(imageURL, forKey: .imageURL)
+        }
+    }
+}
+
+extension xAI.ChatMessageBody: CustomStringConvertible {
+    public var description: String {
+        switch self {
+        case .text(let text):
+            return text.description
+        case .content(let content):
+            return content.description
+        case .functionCall(let call):
+            return "\(call.name)(\(call.arguments))"
+        case .toolCalls(let calls):
+            return calls.map { "\($0.function.name)(\($0.function.arguments))" }.joined(separator: ", ")
+        case .functionInvocation(let invocation):
+            return "\(invocation.name)(...) = \(invocation.response)"
+        }
+    }
+}
+
+extension xAI.ChatMessageBody._Content: CustomStringConvertible {
+    public var description: String {
+        switch self {
+        case .text(let text):
+            return text.description
+        case .imageURL(let imageURL):
+            return imageURL.url.description
+        }
+    }
+}
+
+extension xAI.ChatMessageBody._Content: Hashable {
+    public func hash(into hasher: inout Hasher) {
+        switch self {
+        case .text(let string):
+            hasher.combine(string)
+        case .imageURL(let url):
+            hasher.combine(url)
+        }
+    }
+
+    public static func == (lhs: Self, rhs: Self) -> Bool {
+        switch (lhs, rhs) {
+        case let (.text(a), .text(b)):
+            return a == b
+        case let (.imageURL(a), .imageURL(b)):
+            return a == b
+        default:
+            return false
+        }
+    }
+}

diff --git a/Sources/xAI/Intramodular/Models/xAI.ChatRole.swift b/Sources/xAI/Intramodular/Models/xAI.ChatRole.swift
new file mode 100644
index 00000000..b7e04857
--- /dev/null
+++ b/Sources/xAI/Intramodular/Models/xAI.ChatRole.swift
@@ -0,0 +1,41 @@
+
+import CorePersistence
+import Diagnostics
+import LargeLanguageModels
+import Swallow
+
+extension xAI {
+    public enum ChatRole: String, Codable, Hashable, Sendable {
+        case system
+        case user
+        case assistant
+        case function
+
+        public init(from role: AbstractLLM.ChatRole) {
+            switch role {
+            case .system:
+                self = .system
+            case .user:
+                self = .user
+            case .assistant:
+                self = .assistant
+            case .other(.function):
+                self = .function
+            }
+        }
+
+        public func __conversion() throws -> AbstractLLM.ChatRole {
+            switch self {
+            case .system:
+                return .system
+            case .user:
+                return .user
+            case .assistant:
+                return .assistant
+            case .function:
+                return .other(.function)
+            }
+        }
+    }
+}

diff --git a/Sources/xAI/Intramodular/Models/xAI.Tool.swift b/Sources/xAI/Intramodular/Models/xAI.Tool.swift
new file mode 100644
index 00000000..148be8e4
--- /dev/null
+++ b/Sources/xAI/Intramodular/Models/xAI.Tool.swift
@@ -0,0 +1,51 @@
+
+
+extension xAI {
+    public enum ToolType: String, CaseIterable, Codable, Hashable, Sendable {
+        /* Currently, only functions are supported as a tool. */
+        case function
+    }
+
+    public struct Tool: Codable, Hashable, Sendable {
+        public let type: ToolType
+        public let function: xAI.ChatFunctionDefinition?
+
+        private init(
+            type: ToolType,
+            function: xAI.ChatFunctionDefinition?
+        ) {
+            self.type = type
+            self.function = function
+
+            if function != nil {
+                assert(type == .function)
+            }
+        }
+
+        public static func function(
+            _ function: xAI.ChatFunctionDefinition
+        ) -> Self {
+            Self(type: .function, function: function)
+        }
+    }
+
+    public struct ToolCall: Codable, Hashable, Sendable {
+        public let index: Int?
+        public let id: String?
+        public let type: ToolType?
+        public let function: ChatMessageBody.FunctionCall
+
+        public init(
+            index: Int? = nil,
+            id: String?,
+            type: ToolType = .function,
+            function: ChatMessageBody.FunctionCall
+        ) {
+            self.index = index
+            self.id = id
+            self.type = type
+            self.function = function
+        }
+    }
+}

diff --git a/Sources/xAI/Intramodular/xAI+LLMRequestHandling.swift b/Sources/xAI/Intramodular/xAI+LLMRequestHandling.swift
new file mode 100644
index 00000000..e978f81a
--- /dev/null
+++ b/Sources/xAI/Intramodular/xAI+LLMRequestHandling.swift
@@ -0,0 +1,156 @@
+
+
+import CorePersistence
+import LargeLanguageModels
+import NetworkKit
+import Swallow
+
+extension xAI.Client: _TaskDependenciesExporting {
+    public var _exportedTaskDependencies: TaskDependencies {
+        var result = TaskDependencies()
+
+        result[\.llm] = self
+
+        return result
+    }
+}
+
+extension xAI.Client: LLMRequestHandling {
+    public var _availableModels: [ModelIdentifier]? {
+        xAI.Model.allCases.map({ $0.__conversion() })
+    }
+
+    public func complete<Prompt: AbstractLLM.Prompt>(
+        prompt: Prompt,
+        parameters: Prompt.CompletionParameters
+    ) async throws -> Prompt.Completion {
+        let _completion: Any
+
+        switch prompt {
+        case let prompt as AbstractLLM.TextPrompt:
+            _completion = try await _complete(
+                prompt: prompt,
+                parameters: try cast(parameters)
+            )
+        case let prompt as AbstractLLM.ChatPrompt:
+            _completion = try await _complete(
+                prompt: prompt,
+                parameters: try cast(parameters)
+            )
+        default:
+            throw LLMRequestHandlingError.unsupportedPromptType(Prompt.self)
+        }
+
+        return try cast(_completion)
+    }
+
+    private func _complete(
+        prompt: AbstractLLM.TextPrompt,
+        parameters: AbstractLLM.TextCompletionParameters
+    ) async throws -> AbstractLLM.TextCompletion {
+        throw LLMRequestHandlingError.unsupportedPromptType(.init(Swift.type(of: prompt)))
+    }
+
+    private func _complete(
+        prompt: AbstractLLM.ChatPrompt,
+        parameters: AbstractLLM.ChatCompletionParameters
+    ) async throws -> AbstractLLM.ChatCompletion {
+        var messages: [xAI.ChatMessage] = []
+
+        for message in prompt.messages {
+            let chatMessage = try await xAI.ChatMessage(from: message)
+
+            messages.append(chatMessage)
+        }
+
+        let response: xAI.ChatCompletion = try await run(
+            \.chatCompletions,
+            with: .init(
+                messages: messages,
+                model: _model(for: prompt, parameters: parameters),
+                maxTokens: parameters.tokenLimit?.fixedValue,
+                seed: nil,
+                stream: false,
+                temperature: parameters.temperatureOrTopP?.temperature,
+                topProbabilityMass: parameters.temperatureOrTopP?.topProbabilityMass,
+                functions: parameters.functions?.map { xAI.ChatFunctionDefinition(from: $0) }
+            )
+        )
+
+        assert(response.choices.count == 1)
+
+        let message = try AbstractLLM.ChatMessage(from: response, choiceIndex: 0)
+
+        return AbstractLLM.ChatCompletion(
+            prompt: prompt.messages,
+            message: message,
+            stopReason: .init() // FIXME: !!!
+        )
+    }
+
+    private func _model(
+        for prompt: AbstractLLM.ChatPrompt,
+        parameters: AbstractLLM.ChatCompletionParameters?
+    ) throws -> xAI.Model {
+        try prompt.context.get(\.modelIdentifier)?.as(xAI.Model.self) ?? .grok_beta
+    }
+}
+
+// MARK: - Auxiliary
+
+extension AbstractLLM.ChatRole {
+    public init(
+        from role: xAI.ChatRole
+    ) throws {
+        switch role {
+        case .system:
+            self = .system
+        case .user:
+            self = .user
+        case .assistant:
+            self = .assistant
+        case .function:
+            self = .other(.function)
+        }
+    }
+}
+
+extension AbstractLLM.ChatMessage {
+    public init(
+        from completion: xAI.ChatCompletion,
+        choiceIndex: Int
+    ) throws {
+        let choice = completion.choices[choiceIndex]
+
+        self.init(
+            id: AnyPersistentIdentifier(erasing: "\(completion.id)_\(choiceIndex.description)"),
+            role: try AbstractLLM.ChatRole(from: choice.message.role),
+            content: PromptLiteral(choice.message.body.description)
+        )
+    }
+}
+
+extension xAI.ChatMessage {
+    public init(
+        from message: AbstractLLM.ChatMessage
+    ) throws {
+        self.init(
+            role: xAI.ChatRole(
+                from: message.role
+            ),
+            content: try message.content._stripToText()
+        )
+    }
+}
+
+extension xAI.ChatFunctionDefinition {
+    public init(
+        from function: AbstractLLM.ChatFunctionDefinition
+    ) {
+        self.init(
+            name: function.name.rawValue,
+            description: function.context,
+            parameters: function.parameters
+        )
+    }
+}

diff --git a/Sources/xAI/Intramodular/xAI.ChatMessage+LargeLanguageModels.swift b/Sources/xAI/Intramodular/xAI.ChatMessage+LargeLanguageModels.swift
new file mode 100644
index 00000000..a1937e11
--- /dev/null
+++ b/Sources/xAI/Intramodular/xAI.ChatMessage+LargeLanguageModels.swift
@@ -0,0 +1,258 @@
+
+
+import CorePersistence
+import FoundationX
+@_spi(Internal) import LargeLanguageModels
+
+extension xAI.ChatMessage: _PromptLiteralEncodingContainer {
+    public mutating func encode(
+        _ component: PromptLiteral._Degenerate.Component
+    ) async throws {
+        var content: [xAI.ChatMessageBody._Content]
+
+        switch self.body {
+        case .text(let _content):
+            content = [.text(_content)]
+        case .content(let _content):
+            content = _content
+        case .functionCall(_):
+            throw Never.Reason.unsupported
+        case .toolCalls(_):
+            throw Never.Reason.unsupported
+        case .functionInvocation(_):
+            throw Never.Reason.unsupported
+        }
+
+        switch component.payload {
+        case .string(let string):
+            content.append(.text(string))
+        case .image(let image):
+            let imageURL: Base64DataURL = try await image.toBase64DataURL()
+
+            content.append(.imageURL(xAI.ChatMessageBody._Content.ImageURL(url: imageURL.url, detail: .auto)))
+        case .functionCall:
+            throw Never.Reason.unsupported
+        case .resultOfFunctionCall:
+            throw Never.Reason.unsupported
+        }
+
+        self = .init(
+            id: nil, // FIXME: !!!
+            role: role,
+            body: .content(content)
+        )
+    }
+}
+
+extension xAI.ChatMessage {
+    public init(
+        from message: AbstractLLM.ChatMessage
+    ) async throws {
+        let role: xAI.ChatRole
+
+        switch message.role {
+        case .system:
+            role = .system
+        case .user:
+            role = .user
+        case .assistant:
+            role = .assistant
+        case .other(.function):
+            role = .function
+        }
+
+        let _content = try message.content._degenerate()
+
+        if _content.components.contains(where: { $0.payload.type == .functionCall || $0.payload.type == .functionInvocation }) {
+            switch try _content.components.toCollectionOfOne().value.payload {
+            case .functionCall(let call):
+                self.init(
+                    id: nil, // FIXME: !!!
+                    role: role,
+                    body: .functionCall(
+                        xAI.ChatMessageBody.FunctionCall(
+                            name: call.name.rawValue,
+                            arguments: try call.arguments.__conversion()
+                        )
+                    )
+                )
+            case .resultOfFunctionCall(let result):
+                self.init(
+                    id: nil, // FIXME: !!!
+                    role: role,
+                    body: .functionInvocation(
+                        .init(
+                            name: result.name.rawValue,
+                            response: try result.result.__conversion() as String
+                        )
+                    )
+                )
+            default:
+                assertionFailure("Unsupported prompt literal.")
+
+                throw Never.Reason.illegal
+            }
+        } else {
+            var _temp = Self(
+                id: nil, // FIXME: !!!
+                role: role,
+                body: .content([])
+            )
+
+            try await message.content._encode(to: &_temp)
+
+            self = _temp
+        }
+    }
+}
+
+extension AbstractLLM.ChatMessage {
+    public init(
+        from message: xAI.ChatMessage
+    ) throws {
+        let id = message.id
+        let role: AbstractLLM.ChatRole
+
+        switch message.role {
+        case .system:
+            role = .system
+        case .user:
+            role = .user
+        case .assistant:
+            role = .assistant
+        case .function:
+            role = .other(.function)
+        }
+
+        switch message.body {
+        case .text(let content):
+            self.init(
+                id: AnyPersistentIdentifier(erasing: id),
+                role: role,
+                content: PromptLiteral(
+                    content,
+                    role: .chat(role)
+                )
+            )
+        case .content(let content):
+            self.init(
+                id: AnyPersistentIdentifier(erasing: id),
+                role: role,
+                content: PromptLiteral(
+                    from: content,
+                    role: .chat(role)
+                )
+            )
+        case .functionCall(let call):
+            self.init(
+                id: AnyPersistentIdentifier(erasing: id),
+                role: role,
+                content: try PromptLiteral(
+                    functionCall: .init(
+                        functionID: nil,
+                        name: AbstractLLM.ChatFunction.Name(rawValue: call.name),
+                        arguments: AbstractLLM.ChatFunctionCall.Arguments(unencoded: call.arguments),
+                        context: .init()
+                    ),
+                    role: .chat(role)
+                )
+            )
+        case .toolCalls(let calls):
+            guard let firstCall = calls.first?.function else {
+                throw DecodingError.dataCorrupted(.init(
+                    codingPath: [],
+                    debugDescription: "Tool calls array is empty"
+                ))
+            }
+
+            self.init(
+                id: AnyPersistentIdentifier(erasing: id),
+                role: role,
+                content: try PromptLiteral(
+                    functionCall: .init(
+                        functionID: nil,
+                        name: AbstractLLM.ChatFunction.Name(rawValue: firstCall.name),
+                        arguments: AbstractLLM.ChatFunctionCall.Arguments(unencoded: firstCall.arguments),
+                        context: .init()
+                    ),
+                    role: .chat(role)
+                )
+            )
+        case .functionInvocation(let invocation):
+            self.init(
+                id: AnyPersistentIdentifier(erasing: id),
+                role: role,
+                content: try .init(
+                    functionInvocation: .init(
+                        functionID: nil,
+                        name: AbstractLLM.ChatFunction.Name(rawValue: invocation.name),
+                        result: .init(rawValue: invocation.response)
+                    ),
+                    role: .chat(role)
+                )
+            )
+        }
+    }
+}
+
+extension PromptLiteral {
+    public init(from message: xAI.ChatMessage) throws {
+        let role: PromptMatterRole
+
+        switch message.role {
+        case .system:
+            role = .chat(.system)
+        case .user:
+            role = .chat(.user)
+        case .assistant:
+            role = .chat(.assistant)
+        case .function:
+            role = .chat(.other(.function))
+        }
+
+        switch message.body {
+        case .text(let text):
+            self.init(from: [.text(text)], role: role)
+        case .content(let content):
+            self.init(from: content, role: role)
+        case .functionCall:
+            TODO.unimplemented
+        case .toolCalls(_):
+            TODO.unimplemented
+        case .functionInvocation:
+            TODO.unimplemented
+        }
+    }
+
+    init(
+        from contents: [xAI.ChatMessageBody._Content],
+        role: PromptMatterRole
+    ) {
+        var components: [PromptLiteral.StringInterpolation.Component] = []
+
+        for content in contents {
+            switch content {
+            case .text(let content):
+                components.append(
+                    PromptLiteral.StringInterpolation.Component(
+                        payload: .stringLiteral(content),
+                        role: role
+                    )
+                )
+            case .imageURL(let image):
+                assert(image.detail == .auto) // FIXME
+
+                components.append(
+                    PromptLiteral.StringInterpolation.Component(
+                        payload: .image(.url(image.url)),
+                        role: role
+                    )
+                )
+            }
+        }
+
+        self.init(stringInterpolation: .init(components: components))
+    }
+}

diff --git a/Sources/xAI/Intramodular/xAI.Client.swift b/Sources/xAI/Intramodular/xAI.Client.swift
new file mode 100644
index 00000000..eb33832f
--- /dev/null
+++ b/Sources/xAI/Intramodular/xAI.Client.swift
@@ -0,0 +1,50 @@
+
+
+import CorePersistence
+import LargeLanguageModels
+import Merge
+import NetworkKit
+import Swallow
+
+extension xAI {
+    @RuntimeDiscoverable
+    public final class Client: HTTPClient, _StaticSwift.Namespace {
+        public static var persistentTypeRepresentation: some IdentityRepresentation {
+            _MIServiceTypeIdentifier._xAI
+        }
+
+        public let interface: APISpecification
+        public let session: HTTPSession
+
+        public init(interface: APISpecification, session: HTTPSession) {
+            self.interface = interface
+            self.session = session
+        }
+
+        public convenience init(apiKey: String?) {
+            self.init(
+                interface: .init(configuration: .init(apiKey: apiKey)),
+                session: .shared
+            )
+        }
+    }
+}
+
+extension xAI.Client: _MIService {
+    public convenience init(
+        account: (any _MIServiceAccount)?
+    ) async throws {
+        let account: any _MIServiceAccount = try account.unwrap()
+        let serviceIdentifier: _MIServiceTypeIdentifier = account.serviceIdentifier
+
+        guard serviceIdentifier == _MIServiceTypeIdentifier._xAI else {
+            throw _MIServiceError.serviceTypeIncompatible(serviceIdentifier)
+        }
+
+        guard let credential = account.credential as? _MIServiceAPIKeyCredential else {
+            throw _MIServiceError.invalidCredentials(account.credential)
+        }
+
+        self.init(apiKey: credential.apiKey)
+    }
+}

diff --git a/Sources/xAI/Intramodular/xAI.Model.swift b/Sources/xAI/Intramodular/xAI.Model.swift
new file mode 100644
index 00000000..bc06398d
--- /dev/null
+++ b/Sources/xAI/Intramodular/xAI.Model.swift
@@ -0,0 +1,52 @@
+
+
+import CoreMI
+import CorePersistence
+import LargeLanguageModels
+import Swallow
+
+extension xAI {
+    public enum Model: String, CaseIterable, Codable, Hashable, Named, Sendable {
+        case grok_beta = "grok-beta"
+        case grok_vision_beta = "grok-vision-beta"
+
+        public var name: String {
+            switch self {
+            case .grok_beta:
+                return "Grok Beta"
+            case .grok_vision_beta:
+                return "Grok Vision Beta"
+            }
+        }
+    }
+}
+
+// MARK: - Conformances
+
+extension xAI.Model: CustomStringConvertible {
+    public var description: String {
+        rawValue
+    }
+}
+
+extension xAI.Model: ModelIdentifierRepresentable {
+    public init(from identifier: ModelIdentifier) throws {
+        guard identifier.provider == ._xAI, identifier.revision == nil else {
+            throw Never.Reason.illegal
+        }
+
+        guard let model = Self(rawValue: identifier.name) else {
+            throw Never.Reason.unexpected
+        }
+
+        self = model
+    }
+
+    public func __conversion() -> ModelIdentifier {
+        ModelIdentifier(
+            provider: ._xAI,
+            name: rawValue,
+            revision: nil
+        )
+    }
+}

diff --git a/Sources/xAI/Intramodular/xAI.swift b/Sources/xAI/Intramodular/xAI.swift
new file mode 100644
index 00000000..24ccd67a
--- /dev/null
+++ b/Sources/xAI/Intramodular/xAI.swift
@@ -0,0 +1,7 @@
+
+import Swift
+
+public enum xAI {
+
+}

diff --git a/Sources/xAI/module.swift b/Sources/xAI/module.swift
new file mode 100644
index 00000000..86462ca6
--- /dev/null
+++ b/Sources/xAI/module.swift
@@ -0,0 +1,5 @@
+
+
+@_exported import LargeLanguageModels
+@_exported import SwallowMacrosClient

diff --git a/Tests/Mistral/Intramodular/EmbeddingsTests.swift b/Tests/Mistral/Intramodular/EmbeddingsTests.swift
index 4aee6269..3d8c2d75 100644
--- a/Tests/Mistral/Intramodular/EmbeddingsTests.swift
+++ b/Tests/Mistral/Intramodular/EmbeddingsTests.swift
@@ -3,7 +3,6 @@
 //
 
 import LargeLanguageModels
-import Groq
 import XCTest
 import Mistral

diff --git a/Tests/xAI/Intramodular/CompletionTests.swift b/Tests/xAI/Intramodular/CompletionTests.swift
new file mode 100644
index 00000000..ef9e78b9
--- /dev/null
+++ b/Tests/xAI/Intramodular/CompletionTests.swift
@@ -0,0 +1,43 @@
+
+
+import LargeLanguageModels
+import xAI
+import XCTest
+
+final class CompletionTests: XCTestCase {
+
+    let llm: any LLMRequestHandling = client
+
+    func testChatCompletionsGrokBeta() async throws {
+        let result = try await resultForModel(xAI.Model.grok_beta)
+        print(result) // "Hey! What's up with you?"
+    }
+
+    func testChatCompletionsGrokVisionBeta() async throws {
+        let result = try await resultForModel(xAI.Model.grok_vision_beta)
+        print(result) // "Hey! How can I help you today?"
+    }
+
+    private func resultForModel(_ model: xAI.Model) async throws -> String {
+        let messages: [AbstractLLM.ChatMessage] = [
+            AbstractLLM.ChatMessage(
+                role: .system,
+                body: "You are an extremely intelligent assistant."
+            ),
+            AbstractLLM.ChatMessage(
+                role: .user,
+                body: "Sup?"
+            )
+        ]
+
+        let result: String = try await llm.complete(
+            messages,
+            model: model,
+            as: String.self
+        )
+
+        return result
+    }
+}

diff --git a/Tests/xAI/Intramodular/FunctionCallingTests.swift b/Tests/xAI/Intramodular/FunctionCallingTests.swift
new file mode 100644
index 00000000..ba6af3cb
--- /dev/null
+++ b/Tests/xAI/Intramodular/FunctionCallingTests.swift
@@ -0,0 +1,103 @@
+
+
+import CorePersistence
+import xAI
+import XCTest
+
+final class FunctionCallingTests: XCTestCase {
+    let llm: any LLMRequestHandling = client
+
+    func testFunctionCalling() async throws {
+        let messages: [AbstractLLM.ChatMessage] = [
+            .system {
+                "You are a Meteorologist Expert accurately giving weather data in fahrenheit at any given city around the world"
+            },
+            .user {
+                "What is the weather in San Francisco, CA?"
+            }
+        ]
+
+        let functionCall1: AbstractLLM.ChatFunctionCall = try await llm.complete(
+            messages,
+            functions: [makeGetWeatherFunction1()],
+            as: .functionCall
+        )
+
+        let functionCall2: AbstractLLM.ChatFunctionCall = try await llm.complete(
+            messages,
+            functions: [makeGetWeatherFunction2()],
+            as: .functionCall
+        )
+
+        let result1 = try functionCall1.decode(GetWeatherParameters.self)
+        let result2 = try functionCall2.decode(GetWeatherParameters.self)
+
+        print(result1, result2)
+    }
+
+    private func makeGetWeatherFunction1() -> AbstractLLM.ChatFunctionDefinition {
+        let weatherObjectSchema = JSONSchema(
+            type: .object,
+            description: "Weather in a certain location",
+            properties: [
+                "location": JSONSchema(
+                    type: .string,
+                    description: "The city and state, e.g. San Francisco, CA"
+                ),
+                "unit_fahrenheit": JSONSchema(
+                    type: .number,
+                    description: "The unit of temperature in 'fahrenheit'"
+                )
+            ],
+            required: true
+        )
+
+        let getWeatherFunction: AbstractLLM.ChatFunctionDefinition = AbstractLLM.ChatFunctionDefinition(
+            name: "get_weather",
+            context: "Get the current weather in a given location",
+            parameters: JSONSchema(
+                type: .object,
+                description: "Weather data for a given location in fahrenheit",
+                properties: [
+                    "weather": .array(weatherObjectSchema)
+                ]
+            )
+        )
+
+        return getWeatherFunction
+    }
+
+    struct GetWeatherParameters: Codable, Hashable, Sendable {
+        let weather: [WeatherObject]
+    }
+
+    struct WeatherObject: Codable, Hashable, Sendable {
+        let location: String
+        let unit_fahrenheit: Double?
+    }
+
+    private func makeGetWeatherFunction2() throws -> AbstractLLM.ChatFunctionDefinition {
+        let getWeatherFunction: AbstractLLM.ChatFunctionDefinition = AbstractLLM.ChatFunctionDefinition(
+            name: "get_weather",
+            context: "Get the current weather in a given location",
+            parameters: JSONSchema(
+                type: .object,
+                description: "Weather data for a given location in fahrenheit",
+                properties: [
+                    "weather": try .array {
+                        try JSONSchema(
+                            type: WeatherObject.self,
+                            description: "Weather in a certain location",
+                            propertyDescriptions: [
+                                "location": "The city and state, e.g. San Francisco, CA",
+                                "unit_fahrenheit": "The unit of temperature in 'fahrenheit'"
+                            ]
+                        )
+                    }
+                ]
+            )
+        )
+
+        return getWeatherFunction
+    }
+}

diff --git a/Tests/xAI/module.swift b/Tests/xAI/module.swift
new file mode 100644
index 00000000..bd88733f
--- /dev/null
+++ b/Tests/xAI/module.swift
@@ -0,0 +1,12 @@
+
+
+import Foundation
+import xAI
+
+/* NOTE: read the key from the environment rather than committing it to source
+   control; XAI_API_KEY is an assumed variable name. */
+public var xAI_API_KEY: String {
+    ProcessInfo.processInfo.environment["XAI_API_KEY"] ?? ""
+}
+
+public var client: xAI.Client {
+    xAI.Client(apiKey: xAI_API_KEY)
+}