From a02bde3efd609d9d44d20b1d94396ddbb070e0ba Mon Sep 17 00:00:00 2001 From: mq200 Date: Thu, 5 Sep 2024 18:22:28 -0700 Subject: [PATCH 1/4] watsonx-integration --- .../java/ee/carlrobert/codegpt/Icons.java | 1 + .../completions/CompletionClientProvider.java | 12 + .../CompletionRequestProvider.java | 25 ++ .../completions/CompletionRequestService.java | 5 +- .../conversations/ConversationService.java | 4 + .../codegpt/settings/GeneralSettings.java | 10 + .../codegpt/settings/service/ServiceType.java | 4 +- .../service/watsonx/WatsonxSettings.java | 36 +++ .../service/watsonx/WatsonxSettingsForm.java | 304 ++++++++++++++++++ .../service/watsonx/WatsonxSettingsState.java | 181 +++++++++++ .../chat/ui/textarea/ModelComboBoxAction.java | 23 +- .../carlrobert/codegpt/ui/ModelIconLabel.java | 3 + .../CodeCompletionFeatureToggleActions.kt | 10 +- .../CodeCompletionRequestFactory.kt | 28 ++ .../codecompletions/CodeCompletionService.kt | 6 + .../CodeGPTInlineCompletionProvider.kt | 0 .../DebouncedCodeCompletionProvider.kt | 2 + .../codegpt/credentials/CredentialsStore.kt | 1 + .../service/ServiceConfigurableComponent.kt | 4 +- .../service/WatsonxServiceConfigurable.kt | 45 +++ .../service/codegpt/CodeGPTAvailableModels.kt | 15 + src/main/resources/META-INF/plugin.xml | 5 +- src/main/resources/icons/watsonx.svg | 6 + src/main/resources/icons/watsonx_dark.svg | 6 + .../resources/messages/codegpt.properties | 36 ++- 25 files changed, 742 insertions(+), 30 deletions(-) create mode 100644 src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettings.java create mode 100644 src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java create mode 100644 src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeGPTInlineCompletionProvider.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/settings/service/WatsonxServiceConfigurable.kt create mode 100644 src/main/resources/icons/watsonx.svg create mode 100644 src/main/resources/icons/watsonx_dark.svg diff --git a/src/main/java/ee/carlrobert/codegpt/Icons.java b/src/main/java/ee/carlrobert/codegpt/Icons.java index fb680fac2..3fbe596eb 100644 --- a/src/main/java/ee/carlrobert/codegpt/Icons.java +++ b/src/main/java/ee/carlrobert/codegpt/Icons.java @@ -24,6 +24,7 @@ public final class Icons { public static final Icon You = IconLoader.getIcon("/icons/you.svg", Icons.class); public static final Icon YouSmall = IconLoader.getIcon("/icons/you_small.png", Icons.class); public static final Icon Ollama = IconLoader.getIcon("/icons/ollama.svg", Icons.class); + public static final Icon Watsonx = IconLoader.getIcon("/icons/watsonx.svg", Icons.class); public static final Icon User = IconLoader.getIcon("/icons/user.svg", Icons.class); public static final Icon Upload = IconLoader.getIcon("/icons/upload.svg", Icons.class); public static final Icon GreenCheckmark = diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java index fe7c796d9..4ce48d514 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java @@ -10,6 +10,7 @@ import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings; import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings; import ee.carlrobert.llm.client.anthropic.ClaudeClient; import ee.carlrobert.llm.client.azure.AzureClient; import ee.carlrobert.llm.client.azure.AzureCompletionRequestParams; @@ -18,6 +19,7 @@ import ee.carlrobert.llm.client.llama.LlamaClient; import ee.carlrobert.llm.client.ollama.OllamaClient; import ee.carlrobert.llm.client.openai.OpenAIClient; +import ee.carlrobert.llm.client.watsonx.WatsonxClient; import java.net.InetSocketAddress; import java.net.Proxy; import java.util.concurrent.TimeUnit; @@ -32,6 +34,16 @@ public static CodeGPTClient getCodeGPTClient() { getDefaultClientBuilder()); } + public static WatsonxClient getWatsonxClient() { + return new WatsonxClient.Builder(getCredential(CredentialKey.WATSONX_API_KEY)) + .setApiVersion(WatsonxSettings.getCurrentState().getApiVersion()) + .setIsOnPrem(WatsonxSettings.getCurrentState().isOnPrem()) + .setHost(WatsonxSettings.getCurrentState().getOnPremHost()) + .setUsername(WatsonxSettings.getCurrentState().getUsername()) + .setIsZenApiKey(WatsonxSettings.getCurrentState().isZenApiKey()) + .build(getDefaultClientBuilder()); + } + public static OpenAIClient getOpenAIClient() { return new OpenAIClient.Builder(getCredential(CredentialKey.OPENAI_API_KEY)) .setOrganization(OpenAISettings.getCurrentState().getOrganization()) diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index 9886b420a..d6dd2b47a 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -28,6 +28,7 @@ import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings; import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings; import ee.carlrobert.codegpt.util.file.FileUtil; import ee.carlrobert.llm.client.anthropic.completion.ClaudeBase64Source; import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionDetailedMessage; @@ -55,6 +56,7 @@ import ee.carlrobert.llm.client.openai.completion.request.OpenAIMessageImageURLContent; import ee.carlrobert.llm.client.openai.completion.request.OpenAIMessageTextContent; import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDetails; +import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; @@ -307,6 +309,29 @@ private static Request buildCustomOpenAIChatCompletionRequest( } } + public WatsonxCompletionRequest buildWatsonxChatCompletionRequest( + CallParameters callParameters) { + var settings = WatsonxSettings.getCurrentState(); + String prompt = PersonaSettings.getSystemPrompt(); + prompt += "\n"+callParameters.getMessage().getPrompt(); + var builder = new WatsonxCompletionRequest.Builder(prompt); + builder.setDecodingMethod(settings.isGreedyDecoding() ? "greedy" : "sample"); + builder.setModelId(settings.getModel()); + builder.setProjectId(settings.getProjectId()); + builder.setSpaceId(settings.getSpaceId()); + builder.setMaxNewTokens(settings.getMaxNewTokens()); + builder.setMinNewTokens(settings.getMinNewTokens()); + builder.setTemperature(settings.getTemperature()); + builder.setStopSequences(settings.getStopSequences().isEmpty() ? null : settings.getStopSequences().split(",")); + builder.setTopP(settings.getTopP()); + builder.setTopK(settings.getTopK()); + builder.setIncludeStopSequence(settings.getIncludeStopSequence()); + builder.setRandomSeed(settings.getRandomSeed()); + builder.setRepetitionPenalty(settings.getRepetitionPenalty()); + builder.setStream(true); + return builder.build(); + } + public ClaudeCompletionRequest buildAnthropicChatCompletionRequest( CallParameters callParameters) { var configuration = ConfigurationSettings.getState(); diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index 56066bad0..89cdfcffb 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -126,6 +126,9 @@ public EventSource getChatCompletionAsync( settings.getModel(), eventListener); } + case WATSONX -> CompletionClientProvider.getWatsonxClient().getCompletionAsync( + requestProvider.buildWatsonxChatCompletionRequest(callParameters), + eventListener); }; } @@ -285,7 +288,7 @@ public static boolean isRequestAllowed(ServiceType serviceType) { AzureSettings.getCurrentState().isUseAzureApiKeyAuthentication() ? CredentialKey.AZURE_OPENAI_API_KEY : CredentialKey.AZURE_ACTIVE_DIRECTORY_TOKEN); - case CODEGPT, CUSTOM_OPENAI, ANTHROPIC, LLAMA_CPP, OLLAMA -> true; + case CODEGPT, CUSTOM_OPENAI, ANTHROPIC, LLAMA_CPP, OLLAMA, WATSONX -> true; case GOOGLE -> CredentialsStore.INSTANCE.isCredentialSet(CredentialKey.GOOGLE_API_KEY); }; } diff --git a/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java b/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java index 603c9776e..45dec10fc 100644 --- a/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java +++ b/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java @@ -13,6 +13,7 @@ import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings; import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Comparator; @@ -210,6 +211,9 @@ private static String getModelForSelectedService(ServiceType serviceType) { case GOOGLE -> application.getService(GoogleSettings.class) .getState() .getModel(); + case WATSONX -> application.getService(WatsonxSettings.class) + .getState() + .getModel(); }; } } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/GeneralSettings.java b/src/main/java/ee/carlrobert/codegpt/settings/GeneralSettings.java index 2a254230c..a29a97737 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/GeneralSettings.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/GeneralSettings.java @@ -18,6 +18,7 @@ import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings; import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings; import ee.carlrobert.codegpt.util.ApplicationUtil; import org.jetbrains.annotations.NotNull; @@ -92,6 +93,10 @@ public void sync(Conversation conversation) { ApplicationManager.getApplication().getService(OllamaSettings.class).getState() .setModel(conversation.getModel()); break; + case WATSONX: + ApplicationManager.getApplication().getService(WatsonxSettings.class).getState() + .setModel(conversation.getModel()); + break; default: break; } @@ -144,6 +149,11 @@ public String getModel() { .getService(GoogleSettings.class) .getState() .getModel(); + case WATSONX: + return ApplicationManager.getApplication() + .getService(WatsonxSettings.class) + .getState() + .getModel(); default: return "Unknown"; } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java b/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java index e8c8216ff..06d65256e 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java @@ -12,7 +12,9 @@ public enum ServiceType { AZURE("AZURE", "service.azure.title", "azure.chat.completion"), GOOGLE("GOOGLE", "service.google.title", "google.chat.completion"), LLAMA_CPP("LLAMA_CPP", "service.llama.title", "llama.chat.completion"), - OLLAMA("OLLAMA", "service.ollama.title", "ollama.chat.completion"); + OLLAMA("OLLAMA", "service.ollama.title", "ollama.chat.completion"), + WATSONX("WATSONX", "service.watsonx.title", "watsonx.chat.completion"); + private final String code; private final String label; diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettings.java b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettings.java new file mode 100644 index 000000000..a375a572d --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettings.java @@ -0,0 +1,36 @@ +package ee.carlrobert.codegpt.settings.service.watsonx; + +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.components.PersistentStateComponent; +import com.intellij.openapi.components.State; +import com.intellij.openapi.components.Storage; +import ee.carlrobert.codegpt.completions.llama.LlamaModel; +import org.jetbrains.annotations.NotNull; + +@State(name = "CodeGPT_WatsonxSettings", storages = @Storage("CodeGPT_WatsonxSettings.xml")) +public class WatsonxSettings implements PersistentStateComponent { + + private WatsonxSettingsState state = new WatsonxSettingsState(); + + @Override + @NotNull + public WatsonxSettingsState getState() { + return state; + } + + @Override + public void loadState(@NotNull WatsonxSettingsState state) { + this.state = state; + } + + public static WatsonxSettingsState getCurrentState() { + return getInstance().getState(); + } + public static boolean isCodeCompletionsPossible() { + return getInstance().getState().isCodeCompletionsEnabled(); + } + + public static ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings getInstance() { + return ApplicationManager.getApplication().getService(WatsonxSettings.class); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java new file mode 100644 index 000000000..72e59c0b6 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java @@ -0,0 +1,304 @@ +package ee.carlrobert.codegpt.settings.service.watsonx; + +import static com.intellij.ide.BrowserUtil.open; +import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.WATSONX_API_KEY; + +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.ui.ComboBox; +import com.intellij.ui.EnumComboBoxModel; + +import com.intellij.ui.TitledSeparator; +import com.intellij.ui.components.JBLabel; +import com.intellij.ui.components.JBPasswordField; +import com.intellij.ui.components.JBTextField; +import com.intellij.ui.components.JBCheckBox; +import com.intellij.util.ui.FormBuilder; +import com.intellij.util.ui.UI; +import ee.carlrobert.codegpt.CodeGPTBundle; +import ee.carlrobert.codegpt.credentials.CredentialsStore; +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings; +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsState; +import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionModel; + +import ee.carlrobert.codegpt.ui.UIUtil; + +import javax.swing.*; +import java.awt.*; +import java.awt.event.ActionEvent; +import java.awt.event.ActionListener; +import org.jetbrains.annotations.Nullable; + +public class WatsonxSettingsForm { + + private final JBCheckBox onPremCheckbox; + private final JBPasswordField apiKeyField; + private final JBPasswordField onPremApiKeyField; + private final JBTextField onPremHostField; + private final JBTextField usernameField; + private final JBCheckBox zenApiKeyCheckbox; + + private final JPanel onPremAuthenticationFieldPanel; + private final JPanel onCloudAuthenticationFieldPanel; + + private final String getStartedText = "Click here to get started with IBM watsonx.ai as a Service"; + private final String getStartedUrl = "https://dataplatform.cloud.ibm.com/registration/stepone?context=wx"; + private final JButton getStartedLink; + private final JBTextField apiVersionField; + private final JBTextField projectIdField; + private final JBTextField spaceIdField; + + private final ComboBox modelComboBox; + private final JBCheckBox greedyDecodingCheckbox; + private final JBTextField temperatureField; + private final JBTextField topKField; + private final JBTextField topPField; + private final JBTextField repetitionPenaltyField; + private final JBTextField randomSeedField; + private final JBTextField maxNewTokensField; + private final JBTextField minNewTokensField; + private final JBTextField stopSequencesField; + private final JBCheckBox includeStopSequenceCheckbox; + private final JPanel sampleParametersFieldPanel; + + public WatsonxSettingsForm(WatsonxSettingsState settings) { + onPremCheckbox = new JBCheckBox("Watsonx.ai on-premises software", false); + onPremHostField = new JBTextField(settings.getOnPremHost(), 35); + + usernameField = new JBTextField(settings.getUsername(), 35); + + zenApiKeyCheckbox = new JBCheckBox("Is Platform (Zen) API key", false); + + + class OpenUrlAction implements ActionListener { + @Override public void actionPerformed(ActionEvent e) { + open(getStartedUrl); + } + } + getStartedLink = new JButton(); + getStartedLink.setText(""+getStartedText+""); + getStartedLink.setHorizontalAlignment(SwingConstants.LEFT); + getStartedLink.setBorderPainted(false); + getStartedLink.setCursor(new Cursor(Cursor.HAND_CURSOR)); + getStartedLink.setToolTipText(getStartedUrl); + getStartedLink.addActionListener(new OpenUrlAction()); + + apiKeyField = new JBPasswordField(); + apiKeyField.setColumns(35); + ApplicationManager.getApplication().executeOnPooledThread(() -> { + var apiKey = CredentialsStore.getCredential(WATSONX_API_KEY); + SwingUtilities.invokeLater(() -> apiKeyField.setText(apiKey)); + }); + + onPremApiKeyField = new JBPasswordField(); + onPremApiKeyField.setColumns(35); + ApplicationManager.getApplication().executeOnPooledThread(() -> { + var apiKey = CredentialsStore.getCredential(WATSONX_API_KEY); + SwingUtilities.invokeLater(() -> onPremApiKeyField.setText(apiKey)); + }); + + onPremAuthenticationFieldPanel = new UI.PanelFactory().grid() + .add(UI.PanelFactory.panel(onPremHostField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.onPremHost.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.onPremHost.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(usernameField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.username.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.username.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(onPremApiKeyField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.onPremApiKey.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.onPremApiKey.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(zenApiKeyCheckbox).resizeX(false)) + .createPanel(); + + onCloudAuthenticationFieldPanel = new UI.PanelFactory().grid() + .add(UI.PanelFactory.panel(getStartedLink)) + .add(UI.PanelFactory.panel(apiKeyField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.onCloudApiKey.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.onCloudApiKey.comment")) + .resizeX(false)) + .createPanel(); + + apiVersionField = new JBTextField(settings.getApiVersion(), 35); + projectIdField = new JBTextField(settings.getProjectId(), 35); + spaceIdField = new JBTextField(settings.getSpaceId(), 35); + greedyDecodingCheckbox = new JBCheckBox("Greedy decoding", false); + + modelComboBox = new ComboBox<>(new EnumComboBoxModel(WatsonxCompletionModel.class)); + modelComboBox.setSelectedItem(WatsonxCompletionModel.findByCode(settings.getModel())); + temperatureField=new JBTextField(String.valueOf(settings.getTemperature()),35); + topKField=new JBTextField(String.valueOf(settings.getTopK()),35); + topPField=new JBTextField(String.valueOf(settings.getTopP()),35); + repetitionPenaltyField=new JBTextField(String.valueOf(settings.getRepetitionPenalty()),35); + randomSeedField=new JBTextField(settings.getRandomSeed() == null ? "" : String.valueOf(settings.getRandomSeed()),35); + maxNewTokensField=new JBTextField(String.valueOf(settings.getMaxNewTokens()),35); + minNewTokensField=new JBTextField(String.valueOf(settings.getMinNewTokens()),35); + stopSequencesField=new JBTextField(String.valueOf(settings.getStopSequences()),35); + includeStopSequenceCheckbox=new JBCheckBox(String.valueOf(settings.getIncludeStopSequence())); + includeStopSequenceCheckbox.setText("Include stop sequence"); + includeStopSequenceCheckbox.setSelected(false); + + sampleParametersFieldPanel = UI.PanelFactory.grid() + .add(UI.PanelFactory.panel(temperatureField) + .withLabel(CodeGPTBundle.get("configurationConfigurable.section.assistant.temperatureField.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.temperature.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(topKField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.llama.topK.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.topK.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(topPField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.llama.topP.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.llama.topP.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(randomSeedField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.randomSeed.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.randomSeed.comment")) + .resizeX(false)) + .createPanel(); + + registerPanelsVisibility(settings); + + onPremCheckbox.addActionListener(e -> { + onPremAuthenticationFieldPanel.setVisible(!onPremAuthenticationFieldPanel.isVisible()); + onCloudAuthenticationFieldPanel.setVisible(!onCloudAuthenticationFieldPanel.isVisible()); + settings.setOnPrem(!settings.isOnPrem()); + }); + + greedyDecodingCheckbox.addActionListener(e -> { + sampleParametersFieldPanel.setVisible(!sampleParametersFieldPanel.isVisible()); + settings.setGreedyDecoding(!settings.isGreedyDecoding()); + }); + + + } + + private void registerPanelsVisibility(WatsonxSettingsState settings) { + onPremAuthenticationFieldPanel.setVisible(settings.isOnPrem()); + onCloudAuthenticationFieldPanel.setVisible(!settings.isOnPrem()); + sampleParametersFieldPanel.setVisible(!settings.isGreedyDecoding()); + } + + public JPanel getForm() { + return FormBuilder.createFormBuilder() + .addComponent(new TitledSeparator("Connection Parameters")) + .addComponent(UI.PanelFactory.grid() + .add(UI.PanelFactory.panel(onPremCheckbox) + .resizeX(false)).createPanel()) + .addComponent(onCloudAuthenticationFieldPanel) + .addComponent(onPremAuthenticationFieldPanel) + .addComponent(UI.PanelFactory.grid() + .add(UI.PanelFactory.panel(apiVersionField) + .withLabel(CodeGPTBundle.get("shared.apiVersion")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.apiVersion.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(projectIdField) + .withLabel("Project ID:") + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.projectId.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(spaceIdField) + .withLabel("Space ID:") + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.spaceId.comment")) + .resizeX(false)) + .createPanel()) + .addComponent(new TitledSeparator("Generation Parameters")) + .addComponent(UI.PanelFactory.grid() + .add(UI.PanelFactory.panel(modelComboBox) + .withLabel(CodeGPTBundle.get("settingsConfigurable.shared.model.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.modelId.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(maxNewTokensField) + .withLabel(CodeGPTBundle.get("configurationConfigurable.section.assistant.maxTokensField.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.maxNewTokens.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(minNewTokensField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.minNewTokens.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.minNewTokens.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(stopSequencesField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.stopSequences.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.stopSequences.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(includeStopSequenceCheckbox) + .resizeX(false)) + + .add(UI.PanelFactory.panel(repetitionPenaltyField) + .withLabel("Repetition Penalty:") + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.repetitionPenalty.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(greedyDecodingCheckbox) + .resizeX(false)) + .createPanel()) + .addComponent(sampleParametersFieldPanel) + .addComponentFillVertically(new JPanel(), 0) + .getPanel(); + } + + public WatsonxSettingsState getCurrentState() { + var state = new WatsonxSettingsState(); + state.setModel(((WatsonxCompletionModel) modelComboBox.getSelectedItem()).getCode()); + state.setOnPrem(onPremCheckbox.isSelected()); + state.setOnPremHost(onPremHostField.getText()); + state.setUsername(usernameField.getText()); + state.setZenApiKey(zenApiKeyCheckbox.isSelected()); + state.setApiVersion(apiVersionField.getText()); + state.setSpaceId(spaceIdField.getText()); + state.setProjectId(projectIdField.getText()); + state.setGreedyDecoding(greedyDecodingCheckbox.isSelected()); + state.setMaxNewTokens(Integer.valueOf(maxNewTokensField.getText())); + state.setMinNewTokens(Integer.valueOf(minNewTokensField.getText())); + state.setTemperature(Double.valueOf(temperatureField.getText())); + state.setRandomSeed(randomSeedField.getText().isEmpty() ? null : Integer.valueOf(randomSeedField.getText())); + state.setTopP(Double.valueOf(topPField.getText())); + state.setTopK(Integer.valueOf(topKField.getText())); + state.setIncludeStopSequence(includeStopSequenceCheckbox.isSelected()); + state.setStopSequences(stopSequencesField.getText()); + state.setRepetitionPenalty(Double.valueOf(repetitionPenaltyField.getText())); + return state; + } + + public void resetForm() { + var state = WatsonxSettings.getCurrentState(); + onPremCheckbox.setSelected(state.isOnPrem()); + onPremHostField.setText(state.getOnPremHost()); + usernameField.setText(state.getUsername()); + zenApiKeyCheckbox.setSelected(state.isZenApiKey()); + apiKeyField.setText(CredentialsStore.getCredential(WATSONX_API_KEY)); + apiVersionField.setText(state.getApiVersion()); + spaceIdField.setText(state.getSpaceId()); + projectIdField.setText(state.getProjectId()); + greedyDecodingCheckbox.setSelected(state.isGreedyDecoding()); + modelComboBox.setSelectedItem(WatsonxCompletionModel.findByCode(state.getModel())); + maxNewTokensField.setText(String.valueOf(state.getMaxNewTokens())); + minNewTokensField.setText(String.valueOf(state.getMinNewTokens())); + temperatureField.setText(String.valueOf(state.getTemperature())); + randomSeedField.setText(state.getRandomSeed() == null ? "" : String.valueOf(state.getRandomSeed())); + topPField.setText(String.valueOf(state.getTopP())); + topKField.setText(String.valueOf(state.getTopK())); + repetitionPenaltyField.setText(String.valueOf(state.getRepetitionPenalty())); + includeStopSequenceCheckbox.setSelected(state.getIncludeStopSequence()); + stopSequencesField.setText(String.valueOf(state.getStopSequences())); + } + + public @Nullable String getApiKey() { + var apiKey = new String(apiKeyField.getPassword()); + return apiKey.isEmpty() ? null : apiKey; + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java new file mode 100644 index 000000000..b322753c9 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java @@ -0,0 +1,181 @@ +package ee.carlrobert.codegpt.settings.service.watsonx; + +import java.util.Objects; + +public class WatsonxSettingsState { + + private String onPremHost; + private String username; + private boolean isOnPrem = false; + private boolean isZenApiKey = false; + private String apiVersion = "2024-03-14"; + // use this model as default + private String model = "ibm/granite-3b-code-instruct"; + private String spaceId; + private String projectId; + private boolean isGreedyDecoding = false; + private boolean codeCompletionsEnabled = false; + private Double temperature = 0.9; + private Integer topK = 40; + private Double topP = 0.9; + private Integer maxNewTokens= 4000; + private Integer minNewTokens = 0; + private Boolean includeStopSequence = false; + private String stopSequences = ""; + private Double repetitionPenalty = 1.1; + private Integer randomSeed; + + public boolean isOnPrem() { + return isOnPrem; + } + + public void setOnPrem(boolean onPrem) { + isOnPrem = onPrem; + } + + public String getOnPremHost() { + return onPremHost; + } + + public void setOnPremHost(String onPremHost) { + this.onPremHost = onPremHost; + } + + public boolean isZenApiKey() { + return isZenApiKey; + } + + public void setZenApiKey(boolean zenApiKey) { + isZenApiKey = zenApiKey; + } + + public String getUsername() { + return username; + } + + public void setUsername(String username) { + this.username = username; + } + + public String getApiVersion() { + return apiVersion; + } + public void setApiVersion(String apiVersion) { + this.apiVersion = apiVersion; + } + + public String getSpaceId() {return spaceId;} + public void setSpaceId(String spaceId) { + this.spaceId = spaceId; + } + + public String getProjectId() {return projectId;} + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public Boolean isGreedyDecoding() { + return isGreedyDecoding; + } + public void setGreedyDecoding(boolean isGreedyDecoding) { + this.isGreedyDecoding = isGreedyDecoding; + } + public String getModel() { + return model; + } + public void setModel(String model) { + this.model = model; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + public Integer getMaxNewTokens() { + return maxNewTokens; + } + public void setMaxNewTokens(Integer maxNewTokens) { + this.maxNewTokens = maxNewTokens; + } + public Integer getMinNewTokens() { + return minNewTokens; + } + public void setMinNewTokens(Integer minNewTokens) { + this.minNewTokens = minNewTokens; + } + + public Double getRepetitionPenalty() { + return repetitionPenalty; + } + + public void setRepetitionPenalty(Double repetitionPenalty) { + this.repetitionPenalty = repetitionPenalty; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public Integer getTopK() { + return topK; + } + public void setTopK(Integer topK) { + this.topK = topK; + } + + public Boolean getIncludeStopSequence() { + return includeStopSequence; + } + + public void setIncludeStopSequence(Boolean includeStopSequence) { + this.includeStopSequence = includeStopSequence; + } + + public String getStopSequences() { + return stopSequences; + } + public void setStopSequences(String stopSequences){ + this.stopSequences = stopSequences; + } + + public Integer getRandomSeed() { + return randomSeed; + } + public void setRandomSeed(Integer randomSeed){ + this.randomSeed=randomSeed; + } + + public boolean isCodeCompletionsEnabled() { + return codeCompletionsEnabled; + } + + public void setCodeCompletionsEnabled(boolean codeCompletionsEnabled) { + this.codeCompletionsEnabled = codeCompletionsEnabled; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsState that = (ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsState) o; + return Objects.equals(apiVersion, that.apiVersion) && Objects.equals(spaceId, that.spaceId) && Objects.equals(projectId, that.projectId) && Objects.equals(model, that.model) && Objects.equals(temperature,that.temperature) && Objects.equals(topP,that.topP) && Objects.equals(topK,that.topK) && Objects.equals(randomSeed,that.randomSeed) && Objects.equals(repetitionPenalty,that.repetitionPenalty) && Objects.equals(maxNewTokens, that.maxNewTokens) && Objects.equals(minNewTokens,that.minNewTokens) && Objects.equals(isGreedyDecoding,that.isGreedyDecoding) && Objects.equals(isOnPrem,that.isOnPrem) && Objects.equals(isZenApiKey,that.isZenApiKey); + + } + + @Override + public int hashCode() { + return Objects.hash(apiVersion, model, apiVersion, projectId, spaceId,temperature,topP,topK,randomSeed,includeStopSequence,stopSequences,repetitionPenalty, maxNewTokens,minNewTokens,isGreedyDecoding,isOnPrem,isZenApiKey); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java index da1a093cf..cab7e3967 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java @@ -1,13 +1,6 @@ package ee.carlrobert.codegpt.toolwindow.chat.ui.textarea; -import static ee.carlrobert.codegpt.settings.service.ServiceType.ANTHROPIC; -import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE; -import static ee.carlrobert.codegpt.settings.service.ServiceType.CODEGPT; -import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI; -import static ee.carlrobert.codegpt.settings.service.ServiceType.GOOGLE; -import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP; -import static ee.carlrobert.codegpt.settings.service.ServiceType.OLLAMA; -import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI; +import static ee.carlrobert.codegpt.settings.service.ServiceType.*; import static java.lang.String.format; import com.intellij.openapi.actionSystem.ActionUpdateThread; @@ -32,9 +25,11 @@ import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings; import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel; + import java.util.Arrays; import java.util.List; import java.util.function.Consumer; +import java.util.stream.Stream; import javax.swing.Icon; import javax.swing.JComponent; import org.jetbrains.annotations.NotNull; @@ -135,6 +130,14 @@ private AnAction[] getCodeGPTModelActions(Project project, Presentation presenta Icons.Google, presentation)); } + if (availableProviders.contains(WATSONX)) { + actionGroup.addSeparator("IBM"); + actionGroup.add(createModelAction( + WATSONX, + "Watsonx", + Icons.Watsonx, + presentation)); + } if (availableProviders.contains(LLAMA_CPP)) { actionGroup.addSeparator("LLaMA C/C++"); actionGroup.add(createModelAction( @@ -212,6 +215,10 @@ private void updateTemplatePresentation(ServiceType selectedService) { templatePresentation.setText("Google (Gemini)"); templatePresentation.setIcon(Icons.Google); break; + case WATSONX: + templatePresentation.setIcon(Icons.Watsonx); + templatePresentation.setText("Watsonx"); + break; default: break; } diff --git a/src/main/java/ee/carlrobert/codegpt/ui/ModelIconLabel.java b/src/main/java/ee/carlrobert/codegpt/ui/ModelIconLabel.java index f125a5b18..38a86dedd 100644 --- a/src/main/java/ee/carlrobert/codegpt/ui/ModelIconLabel.java +++ b/src/main/java/ee/carlrobert/codegpt/ui/ModelIconLabel.java @@ -30,6 +30,9 @@ public ModelIconLabel(String clientCode, String modelCode) { if ("google.chat.completion".equals(clientCode)) { setIcon(Icons.Google); } + if ("watsonx.chat.completion".equals(clientCode)) { + setIcon(Icons.Watsonx); + } setText(formatModelName(modelCode)); setFont(JBFont.small()); setHorizontalAlignment(SwingConstants.LEADING); diff --git a/src/main/kotlin/ee/carlrobert/codegpt/actions/CodeCompletionFeatureToggleActions.kt b/src/main/kotlin/ee/carlrobert/codegpt/actions/CodeCompletionFeatureToggleActions.kt index 1f3b4aea2..a41a2dc18 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/actions/CodeCompletionFeatureToggleActions.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/actions/CodeCompletionFeatureToggleActions.kt @@ -12,6 +12,7 @@ import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings abstract class CodeCompletionFeatureToggleActions( private val enableFeatureAction: Boolean @@ -19,14 +20,11 @@ abstract class CodeCompletionFeatureToggleActions( override fun actionPerformed(e: AnActionEvent) = when (GeneralSettings.getSelectedService()) { CODEGPT -> service().state.codeCompletionSettings::codeCompletionsEnabled::set - OPENAI -> OpenAISettings.getCurrentState()::setCodeCompletionsEnabled - LLAMA_CPP -> LlamaSettings.getCurrentState()::setCodeCompletionsEnabled - OLLAMA -> service().state::codeCompletionsEnabled::set - CUSTOM_OPENAI -> service().state.codeCompletionSettings::codeCompletionsEnabled::set + WATSONX -> WatsonxSettings.getCurrentState()::setCodeCompletionsEnabled ANTHROPIC, AZURE, @@ -44,7 +42,9 @@ abstract class CodeCompletionFeatureToggleActions( OPENAI, CUSTOM_OPENAI, LLAMA_CPP, - OLLAMA -> true + OLLAMA, + WATSONX + -> true ANTHROPIC, AZURE, diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt index 722444b2a..d47a7170e 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt @@ -8,16 +8,19 @@ import ee.carlrobert.codegpt.completions.llama.LlamaModel import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential import ee.carlrobert.codegpt.settings.configuration.Placeholder.* +import ee.carlrobert.codegpt.settings.persona.PersonaSettings.Companion.getSystemPrompt import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings import ee.carlrobert.codegpt.settings.service.llama.LlamaSettingsState import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings import ee.carlrobert.llm.client.codegpt.request.CodeCompletionRequest +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest import ee.carlrobert.llm.client.ollama.completion.request.OllamaCompletionRequest import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters import ee.carlrobert.llm.client.openai.completion.request.OpenAITextCompletionRequest +import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest import okhttp3.MediaType.Companion.toMediaType import okhttp3.Request import okhttp3.RequestBody.Companion.toRequestBody @@ -25,6 +28,30 @@ import java.nio.charset.StandardCharsets object CodeCompletionRequestFactory { + @JvmStatic + fun buildWatsonxRequest(details: InfillRequest): WatsonxCompletionRequest { + val settings = WatsonxSettings.getCurrentState(); + val builder = WatsonxCompletionRequest.Builder(details.prefix) + builder.setDecodingMethod(if (settings.isGreedyDecoding) "greedy" else "sample") + builder.setModelId(settings.model) + builder.setProjectId(settings.projectId) + builder.setSpaceId(settings.spaceId) + builder.setMaxNewTokens(settings.maxNewTokens) + builder.setMinNewTokens(settings.minNewTokens) + builder.setTemperature(settings.temperature) + builder.setStopSequences( + if (settings.stopSequences.isEmpty()) null else settings.stopSequences.split(",".toRegex()) + .dropLastWhile { it.isEmpty() } + .toTypedArray()) + builder.setTopP(settings.topP) + builder.setTopK(settings.topK) + builder.setIncludeStopSequence(settings.includeStopSequence) + builder.setRandomSeed(settings.randomSeed) + builder.setRepetitionPenalty(settings.repetitionPenalty) + builder.setStream(true) + return builder.build() + } + private const val MAX_TOKENS = 128 @JvmStatic @@ -158,6 +185,7 @@ object CodeCompletionRequestFactory { ?.replace(SUFFIX.code, suffix) ?: value } } + return 36 } private fun getCompletionContext(request: InfillRequest): Pair { diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.kt index ede794201..3ca7e15be 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.kt @@ -7,6 +7,7 @@ import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildC import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildLlamaRequest import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildOllamaRequest import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildOpenAIRequest +import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildWatsonxRequest import ee.carlrobert.codegpt.completions.CompletionClientProvider import ee.carlrobert.codegpt.settings.GeneralSettings import ee.carlrobert.codegpt.settings.service.ServiceType @@ -16,6 +17,7 @@ import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener import ee.carlrobert.llm.completion.CompletionEventListener import okhttp3.sse.EventSource @@ -46,6 +48,7 @@ class CodeCompletionService { OPENAI -> OpenAISettings.getCurrentState().isCodeCompletionsEnabled CUSTOM_OPENAI -> service().state.codeCompletionSettings.codeCompletionsEnabled LLAMA_CPP -> LlamaSettings.isCodeCompletionsPossible() + WATSONX -> WatsonxSettings.isCodeCompletionsPossible() OLLAMA -> service().state.codeCompletionsEnabled else -> false } @@ -74,6 +77,9 @@ class CodeCompletionService { LLAMA_CPP -> CompletionClientProvider.getLlamaClient() .getChatCompletionAsync(buildLlamaRequest(requestDetails), eventListener) + WATSONX -> CompletionClientProvider.getWatsonxClient() + .getCompletionAsync(buildWatsonxRequest(requestDetails), eventListener) + else -> throw IllegalArgumentException("Code completion not supported for ${selectedService.name}") } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeGPTInlineCompletionProvider.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeGPTInlineCompletionProvider.kt new file mode 100644 index 000000000..e69de29bb diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/DebouncedCodeCompletionProvider.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/DebouncedCodeCompletionProvider.kt index 5e499a93c..421bd2f06 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/DebouncedCodeCompletionProvider.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/DebouncedCodeCompletionProvider.kt @@ -18,6 +18,7 @@ import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings import kotlinx.coroutines.channels.ProducerScope import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.flow.channelFlow @@ -95,6 +96,7 @@ class DebouncedCodeCompletionProvider : DebouncedInlineCompletionProvider() { ServiceType.CUSTOM_OPENAI -> service().state.codeCompletionSettings.codeCompletionsEnabled ServiceType.LLAMA_CPP -> LlamaSettings.isCodeCompletionsPossible() ServiceType.OLLAMA -> service().state.codeCompletionsEnabled + ServiceType.WATSONX -> WatsonxSettings.isCodeCompletionsPossible() ServiceType.ANTHROPIC, ServiceType.AZURE, ServiceType.GOOGLE, diff --git a/src/main/kotlin/ee/carlrobert/codegpt/credentials/CredentialsStore.kt b/src/main/kotlin/ee/carlrobert/codegpt/credentials/CredentialsStore.kt index 229808762..8c29972f5 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/credentials/CredentialsStore.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/credentials/CredentialsStore.kt @@ -43,5 +43,6 @@ object CredentialsStore { LLAMA_API_KEY, GOOGLE_API_KEY, OLLAMA_API_KEY, + WATSONX_API_KEY } } \ No newline at end of file diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/ServiceConfigurableComponent.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/ServiceConfigurableComponent.kt index 8c706f572..e2e3629a1 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/ServiceConfigurableComponent.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/ServiceConfigurableComponent.kt @@ -15,6 +15,7 @@ import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceForm import ee.carlrobert.codegpt.settings.service.custom.CustomServiceConfigurable import ee.carlrobert.codegpt.settings.service.google.GoogleSettingsConfigurable import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettingsConfigurable +import ee.carlrobert.codegpt.settings.service.WatsonxServiceConfigurable import javax.swing.JPanel class ServiceConfigurableComponent { @@ -60,7 +61,8 @@ class ServiceConfigurableComponent { "Google" to GoogleSettingsConfigurable::class.java, "LLaMA C/C++ (Local)" to LlamaServiceConfigurable::class.java, "Ollama (Local)" to OllamaSettingsConfigurable::class.java, - ).entries.forEach { (name, configurableClass) -> + "Watsonx" to WatsonxServiceConfigurable::class.java, + ).entries.forEach { (name, configurableClass) -> formBuilder.addComponent(ActionLink(name) { val context = service().getDataContext(it.source as ActionLink) val settings = Settings.KEY.getData(context) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/WatsonxServiceConfigurable.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/WatsonxServiceConfigurable.kt new file mode 100644 index 000000000..6740ba4f1 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/WatsonxServiceConfigurable.kt @@ -0,0 +1,45 @@ +package ee.carlrobert.codegpt.settings.service; +import com.intellij.openapi.components.service +import com.intellij.openapi.options.Configurable +import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.OPENAI_API_KEY +import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.WATSONX_API_KEY +import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential +import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential +import ee.carlrobert.codegpt.settings.GeneralSettings +import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings +import ee.carlrobert.codegpt.settings.service.openai.OpenAISettingsForm +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings +import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsForm +import javax.swing.JComponent + +public class WatsonxServiceConfigurable: Configurable { + + private lateinit var component: WatsonxSettingsForm + + override fun getDisplayName(): String { + return "CodeGPT: Watsonx Service" + } + + override fun createComponent(): JComponent { + component = WatsonxSettingsForm(service().state) + return component.form + } + + override fun isModified(): Boolean { + return component.getCurrentState() != service().state + || component.getApiKey() != getCredential(WATSONX_API_KEY) + } + + override fun apply() { + service().state.selectedService = ServiceType.WATSONX + setCredential(WATSONX_API_KEY, component.getApiKey()) + service().loadState(component.getCurrentState()) + } + + override fun reset() { + component.resetForm() + } +} + + + diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt index afab7bbe5..ccd08e4df 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt @@ -38,6 +38,15 @@ object CodeGPTAvailableModels { @JvmStatic val BASE_CHAT_MODELS: List = listOf( + CodeGPTModel("Mixtral (8x7B)", "mistralai/mixtral-8x7b-instruct-v01", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("Mistral Large", "mistralai/mistral-large", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("Llama 3.1 Instruct (70B)", "meta-llama/llama-3-1-70b-instruct", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("Llama 3.1 Instruct (8B)", "meta-llama/llama-3-1-8b-instruct", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("Llama 2 Chat (70B)", "meta-llama/llama-2-70b-chat", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("Llama 2 Chat (13B)", "meta-llama/llama-2-13b-chat", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("IBM Granite 13B Instruct V2", "ibm/granite-13b-instruct-v2", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("IBM Granite 13B Chat V2", "ibm/granite-13b-chat-v2", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("IBM Granite 20B Multilingual", "ibm/granite-20b-multilingual", Icons.Watsonx, INDIVIDUAL), CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), CodeGPTModel("GPT-4o mini", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS), CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL), @@ -59,6 +68,12 @@ object CodeGPTAvailableModels { @JvmStatic val CODE_MODELS: List = listOf( + CodeGPTModel("Code Llama 34B Instruct", "codellama/codellama-34b-instruct", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("IBM Granite 3B Code Instruct", "ibm/granite-3b-code-instruct", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("IBM Granite 8B Code Instruct", "ibm/granite-8b-code-instruct", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("IBM Granite 20B Code Instruct", "ibm/granite-20b-code-instruct", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("IBM Granite 34B Code Instruct", "ibm/granite-34b-code-instruct", Icons.Watsonx, INDIVIDUAL), + CodeGPTModel("Codestral", "codestral", Icons.OpenAI, INDIVIDUAL), CodeGPTModel("GPT-3.5 Turbo Instruct", "gpt-3.5-turbo-instruct", Icons.OpenAI, INDIVIDUAL), CodeGPTModel("StarCoder (16B)", "starcoder-16b", Icons.CodeGPTModel, FREE), CodeGPTModel("StarCoder (7B) - FREE", "starcoder-7b", Icons.CodeGPTModel, FREE), diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index 6e6d2e8df..cf38c106f 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -44,9 +44,11 @@ instance="ee.carlrobert.codegpt.settings.service.LlamaServiceConfigurable"/> + - @@ -66,6 +68,7 @@ + diff --git a/src/main/resources/icons/watsonx.svg b/src/main/resources/icons/watsonx.svg new file mode 100644 index 000000000..b9641822a --- /dev/null +++ b/src/main/resources/icons/watsonx.svg @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/src/main/resources/icons/watsonx_dark.svg b/src/main/resources/icons/watsonx_dark.svg new file mode 100644 index 000000000..907960975 --- /dev/null +++ b/src/main/resources/icons/watsonx_dark.svg @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/src/main/resources/messages/codegpt.properties b/src/main/resources/messages/codegpt.properties index c1892cdb6..f21316f04 100644 --- a/src/main/resources/messages/codegpt.properties +++ b/src/main/resources/messages/codegpt.properties @@ -142,6 +142,28 @@ settingsConfigurable.service.custom.openai.linkToDocs=Link to API docs settingsConfigurable.service.custom.openai.connectionSuccess=Connection successful. settingsConfigurable.service.custom.openai.connectionFailed=Connection failed. settingsConfigurable.service.ollama.models.refresh=Refresh Models +settingsConfigurable.service.watsonx.onPremHost.label=URL: +settingsConfigurable.service.watsonx.onPremHost.comment= +settingsConfigurable.service.watsonx.username.label=Username: +settingsConfigurable.service.watsonx.username.comment= +settingsConfigurable.service.watsonx.onPremApiKey.label=API Key: +settingsConfigurable.service.watsonx.onPremApiKey.comment=Provide an API key or Platform (Zen) API key +settingsConfigurable.service.watsonx.onCloudApiKey.label=API Key: +settingsConfigurable.service.watsonx.onCloudApiKey.comment=Provide an IBM Cloud API Key +settingsConfigurable.service.watsonx.temperature.comment=The value of randomness. Must be between 0 and 2. +settingsConfigurable.service.watsonx.topK.comment=Limit the next token selection to the top K most probable tokens. +settingsConfigurable.service.watsonx.randomSeed.label=Random seed: +settingsConfigurable.service.watsonx.randomSeed.comment=Specify an integer value for reproducibility of results. +settingsConfigurable.service.watsonx.apiVersion.comment= +settingsConfigurable.service.watsonx.projectId.comment=Provide a project ID +settingsConfigurable.service.watsonx.spaceId.comment=Provide a deployment space ID +settingsConfigurable.service.watsonx.modelId.comment=Select a model from the list +settingsConfigurable.service.watsonx.maxNewTokens.comment= +settingsConfigurable.service.watsonx.minNewTokens.comment= +settingsConfigurable.service.watsonx.minNewTokens.label=Min completion tokens: +settingsConfigurable.service.watsonx.stopSequences.label=Stop sequences +settingsConfigurable.service.watsonx.stopSequences.comment=Comma-separated list of stop sequences +settingsConfigurable.service.watsonx.repetitionPenalty.comment= configurationConfigurable.section.commitMessage.title=Commit Message Template configurationConfigurable.section.commitMessage.systemPromptField.label=Prompt template: configurationConfigurable.section.inlineCompletion.title=Inline Completion @@ -202,6 +224,7 @@ service.azure.title=Azure service.google.title=Google service.llama.title=LLaMA C/C++ (Local) service.ollama.title=Ollama (Local) +service.watsonx.title=Watsonx validation.error.model.notExists='%s' is not available, please select another model validation.error.fieldRequired=This field is required. validation.error.invalidEmail=The email you entered is invalid. @@ -240,19 +263,6 @@ shared.image=Image shared.chatCompletions=Chat Completions shared.codeCompletions=Code Completions codeCompletionsForm.enableFeatureText=Enable code completions -codeCompletionsForm.maxTokensLabel=Max tokens: -codeCompletionsForm.maxTokensComment=The maximum number of tokens that will be generated in the code completion. -editCodePopover.title=Edit Code -editCodePopover.textField.emptyText=Editing instructions... -editCodePopover.textField.followUp.emptyText=Ask a follow-up question -editCodePopover.textField.comment=Provide instructions for the code modification. -editCodePopover.submitButton.title=Submit Edit -editCodePopover.acceptButton.title=Accept Suggestion -editCodePopover.followUpButton.title=Submit Follow-up -smartTextPane.submitButton.title=Send Message -smartTextPane.submitButton.description=Send message -smartTextPane.stopButton.title=Stop -smartTextPane.stopButton.description=Stop completion chatMessageResponseBody.webPages.title=WEB PAGES chatMessageResponseBody.webDocs.startProgress.label=Analyzing web content... addDocumentation.popup.title=Add Documentation From 81f6178af7d9d0b3c648e347d573e7e457ac602d Mon Sep 17 00:00:00 2001 From: mq200 Date: Thu, 5 Sep 2024 20:21:39 -0700 Subject: [PATCH 2/4] cloud region customization --- .../completions/CompletionClientProvider.java | 10 +++++++++- .../service/watsonx/WatsonxSettingsForm.java | 9 +++++++++ .../service/watsonx/WatsonxSettingsState.java | 13 +++++++++---- src/main/resources/messages/codegpt.properties | 2 ++ 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java index 4ce48d514..aa402d2f0 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java @@ -35,10 +35,18 @@ public static CodeGPTClient getCodeGPTClient() { } public static WatsonxClient getWatsonxClient() { + String regionCode = switch(WatsonxSettings.getCurrentState().getRegion()) { + case "Dallas" -> "us-south"; + case "Frankfurt" -> "eu-de"; + case "London" -> "eu-gb"; + case "Tokyo" -> "jp-tok"; + default -> "us-south"; + }; + String host = WatsonxSettings.getCurrentState().isOnPrem() ? WatsonxSettings.getCurrentState().getOnPremHost() : "https://" + regionCode + ".ml.cloud.ibm.com"; return new WatsonxClient.Builder(getCredential(CredentialKey.WATSONX_API_KEY)) .setApiVersion(WatsonxSettings.getCurrentState().getApiVersion()) .setIsOnPrem(WatsonxSettings.getCurrentState().isOnPrem()) - .setHost(WatsonxSettings.getCurrentState().getOnPremHost()) + .setHost(host) .setUsername(WatsonxSettings.getCurrentState().getUsername()) .setIsZenApiKey(WatsonxSettings.getCurrentState().isZenApiKey()) .build(getDefaultClientBuilder()); diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java index 72e59c0b6..a754e9453 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java @@ -33,6 +33,7 @@ public class WatsonxSettingsForm { private final JBCheckBox onPremCheckbox; private final JBPasswordField apiKeyField; private final JBPasswordField onPremApiKeyField; + private final ComboBox regionComboBox; private final JBTextField onPremHostField; private final JBTextField usernameField; private final JBCheckBox zenApiKeyCheckbox; @@ -82,6 +83,9 @@ class OpenUrlAction implements ActionListener { getStartedLink.setToolTipText(getStartedUrl); getStartedLink.addActionListener(new OpenUrlAction()); + regionComboBox = new ComboBox(new String[] {"Dallas", "Frankfurt", "London", "Tokyo"}); + regionComboBox.setSelectedItem(settings.getRegion()); + apiKeyField = new JBPasswordField(); apiKeyField.setColumns(35); ApplicationManager.getApplication().executeOnPooledThread(() -> { @@ -117,6 +121,10 @@ class OpenUrlAction implements ActionListener { onCloudAuthenticationFieldPanel = new UI.PanelFactory().grid() .add(UI.PanelFactory.panel(getStartedLink)) + .add(UI.PanelFactory.panel(regionComboBox) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.cloudRegion.label")) + .withComment(CodeGPTBundle.get("settingsConfigurable.service.watsonx.cloudRegion.comment")) + .resizeX(false)) .add(UI.PanelFactory.panel(apiKeyField) .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.onCloudApiKey.label")) .withComment(CodeGPTBundle.get( @@ -258,6 +266,7 @@ public WatsonxSettingsState getCurrentState() { state.setOnPremHost(onPremHostField.getText()); state.setUsername(usernameField.getText()); state.setZenApiKey(zenApiKeyCheckbox.isSelected()); + state.setRegion((String)regionComboBox.getSelectedItem()); state.setApiVersion(apiVersionField.getText()); state.setSpaceId(spaceIdField.getText()); state.setProjectId(projectIdField.getText()); diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java index b322753c9..567ec5b4d 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java @@ -8,6 +8,7 @@ public class WatsonxSettingsState { private String username; private boolean isOnPrem = false; private boolean isZenApiKey = false; + private String region = "us-south"; private String apiVersion = "2024-03-14"; // use this model as default private String model = "ibm/granite-3b-code-instruct"; @@ -57,6 +58,12 @@ public void setUsername(String username) { this.username = username; } + public String getRegion() {return region;} + + public void setRegion(String region) { + this.region = region; + } + public String getApiVersion() { return apiVersion; } @@ -170,12 +177,10 @@ public boolean equals(Object o) { return false; } ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsState that = (ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsState) o; - return Objects.equals(apiVersion, that.apiVersion) && Objects.equals(spaceId, that.spaceId) && Objects.equals(projectId, that.projectId) && Objects.equals(model, that.model) && Objects.equals(temperature,that.temperature) && Objects.equals(topP,that.topP) && Objects.equals(topK,that.topK) && Objects.equals(randomSeed,that.randomSeed) && Objects.equals(repetitionPenalty,that.repetitionPenalty) && Objects.equals(maxNewTokens, that.maxNewTokens) && Objects.equals(minNewTokens,that.minNewTokens) && Objects.equals(isGreedyDecoding,that.isGreedyDecoding) && Objects.equals(isOnPrem,that.isOnPrem) && Objects.equals(isZenApiKey,that.isZenApiKey); - + return Objects.equals(apiVersion, that.apiVersion) && Objects.equals(region, that.region) && Objects.equals(spaceId, that.spaceId) && Objects.equals(projectId, that.projectId) && Objects.equals(model, that.model) && Objects.equals(temperature,that.temperature) && Objects.equals(topP,that.topP) && Objects.equals(topK,that.topK) && Objects.equals(randomSeed,that.randomSeed) && Objects.equals(repetitionPenalty,that.repetitionPenalty) && Objects.equals(maxNewTokens, that.maxNewTokens) && Objects.equals(minNewTokens,that.minNewTokens) && Objects.equals(isGreedyDecoding,that.isGreedyDecoding) && Objects.equals(isOnPrem,that.isOnPrem) && Objects.equals(isZenApiKey,that.isZenApiKey); } @Override public int hashCode() { - return Objects.hash(apiVersion, model, apiVersion, projectId, spaceId,temperature,topP,topK,randomSeed,includeStopSequence,stopSequences,repetitionPenalty, maxNewTokens,minNewTokens,isGreedyDecoding,isOnPrem,isZenApiKey); - } + return Objects.hash(apiVersion, region, model, apiVersion, projectId, spaceId,temperature,topP,topK,randomSeed,includeStopSequence,stopSequences,repetitionPenalty, maxNewTokens,minNewTokens,isGreedyDecoding,isOnPrem,isZenApiKey); } } diff --git a/src/main/resources/messages/codegpt.properties b/src/main/resources/messages/codegpt.properties index f21316f04..11ce1eeed 100644 --- a/src/main/resources/messages/codegpt.properties +++ b/src/main/resources/messages/codegpt.properties @@ -164,6 +164,8 @@ settingsConfigurable.service.watsonx.minNewTokens.label=Min completion tokens: settingsConfigurable.service.watsonx.stopSequences.label=Stop sequences settingsConfigurable.service.watsonx.stopSequences.comment=Comma-separated list of stop sequences settingsConfigurable.service.watsonx.repetitionPenalty.comment= +settingsConfigurable.service.watsonx.cloudRegion.label=IBM Cloud region: +settingsConfigurable.service.watsonx.cloudRegion.comment= configurationConfigurable.section.commitMessage.title=Commit Message Template configurationConfigurable.section.commitMessage.systemPromptField.label=Prompt template: configurationConfigurable.section.inlineCompletion.title=Inline Completion From 9bd709144b91549ceb420086233b58086c352b5a Mon Sep 17 00:00:00 2001 From: mq200 Date: Fri, 6 Sep 2024 18:43:58 -0700 Subject: [PATCH 3/4] replace settings checkboxes with radio buttons + update link style --- .../service/watsonx/WatsonxSettingsForm.java | 78 ++++++++++++------- 1 file changed, 50 insertions(+), 28 deletions(-) diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java index a754e9453..a18be9d62 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java @@ -8,19 +8,13 @@ import com.intellij.ui.EnumComboBoxModel; import com.intellij.ui.TitledSeparator; -import com.intellij.ui.components.JBLabel; -import com.intellij.ui.components.JBPasswordField; -import com.intellij.ui.components.JBTextField; -import com.intellij.ui.components.JBCheckBox; +import com.intellij.ui.components.*; import com.intellij.util.ui.FormBuilder; import com.intellij.util.ui.UI; import ee.carlrobert.codegpt.CodeGPTBundle; import ee.carlrobert.codegpt.credentials.CredentialsStore; -import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings; -import ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsState; import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionModel; -import ee.carlrobert.codegpt.ui.UIUtil; import javax.swing.*; import java.awt.*; @@ -30,7 +24,8 @@ public class WatsonxSettingsForm { - private final JBCheckBox onPremCheckbox; + private final JBRadioButton onPremRadio; + private final JBRadioButton onCloudRadio; private final JBPasswordField apiKeyField; private final JBPasswordField onPremApiKeyField; private final ComboBox regionComboBox; @@ -41,7 +36,7 @@ public class WatsonxSettingsForm { private final JPanel onPremAuthenticationFieldPanel; private final JPanel onCloudAuthenticationFieldPanel; - private final String getStartedText = "Click here to get started with IBM watsonx.ai as a Service"; + private final String getStartedText = "Get started with IBM watsonx.ai as a Service"; private final String getStartedUrl = "https://dataplatform.cloud.ibm.com/registration/stepone?context=wx"; private final JButton getStartedLink; private final JBTextField apiVersionField; @@ -49,7 +44,8 @@ public class WatsonxSettingsForm { private final JBTextField spaceIdField; private final ComboBox modelComboBox; - private final JBCheckBox greedyDecodingCheckbox; + private final JBRadioButton greedyDecodingRadio; + private final JBRadioButton sampleDecodingRadio; private final JBTextField temperatureField; private final JBTextField topKField; private final JBTextField topPField; @@ -62,7 +58,9 @@ public class WatsonxSettingsForm { private final JPanel sampleParametersFieldPanel; public WatsonxSettingsForm(WatsonxSettingsState settings) { - onPremCheckbox = new JBCheckBox("Watsonx.ai on-premises software", false); + onPremRadio = new JBRadioButton("Watsonx.ai on-premises software", false); + onCloudRadio = new JBRadioButton("Watsonx.ai as a Service", true); + onPremHostField = new JBTextField(settings.getOnPremHost(), 35); usernameField = new JBTextField(settings.getUsername(), 35); @@ -76,9 +74,12 @@ class OpenUrlAction implements ActionListener { } } getStartedLink = new JButton(); - getStartedLink.setText(""+getStartedText+""); + getStartedLink.setText(""+getStartedText+""); getStartedLink.setHorizontalAlignment(SwingConstants.LEFT); getStartedLink.setBorderPainted(false); + getStartedLink.setContentAreaFilled(false); + getStartedLink.setFocusPainted(false); + getStartedLink.setOpaque(false); getStartedLink.setCursor(new Cursor(Cursor.HAND_CURSOR)); getStartedLink.setToolTipText(getStartedUrl); getStartedLink.addActionListener(new OpenUrlAction()); @@ -135,7 +136,8 @@ class OpenUrlAction implements ActionListener { apiVersionField = new JBTextField(settings.getApiVersion(), 35); projectIdField = new JBTextField(settings.getProjectId(), 35); spaceIdField = new JBTextField(settings.getSpaceId(), 35); - greedyDecodingCheckbox = new JBCheckBox("Greedy decoding", false); + greedyDecodingRadio = new JBRadioButton("Greedy decoding", false); + sampleDecodingRadio = new JBRadioButton("Sample decoding", true); modelComboBox = new ComboBox<>(new EnumComboBoxModel(WatsonxCompletionModel.class)); modelComboBox.setSelectedItem(WatsonxCompletionModel.findByCode(settings.getModel())); @@ -176,18 +178,33 @@ class OpenUrlAction implements ActionListener { registerPanelsVisibility(settings); - onPremCheckbox.addActionListener(e -> { - onPremAuthenticationFieldPanel.setVisible(!onPremAuthenticationFieldPanel.isVisible()); - onCloudAuthenticationFieldPanel.setVisible(!onCloudAuthenticationFieldPanel.isVisible()); - settings.setOnPrem(!settings.isOnPrem()); + onPremRadio.addActionListener(e -> { + settings.setOnPrem(true); + onCloudRadio.setSelected(false); + onPremAuthenticationFieldPanel.setVisible(true); + onCloudAuthenticationFieldPanel.setVisible(false); + }); - greedyDecodingCheckbox.addActionListener(e -> { - sampleParametersFieldPanel.setVisible(!sampleParametersFieldPanel.isVisible()); - settings.setGreedyDecoding(!settings.isGreedyDecoding()); + onCloudRadio.addActionListener(e -> { + settings.setOnPrem(false); + onPremRadio.setSelected(false); + onPremAuthenticationFieldPanel.setVisible(false); + onCloudAuthenticationFieldPanel.setVisible(true); + }); + greedyDecodingRadio.addActionListener(e -> { + settings.setGreedyDecoding(true); + sampleDecodingRadio.setSelected(false); + sampleParametersFieldPanel.setVisible(false); + }); + sampleDecodingRadio.addActionListener(e -> { + settings.setGreedyDecoding(false); + greedyDecodingRadio.setSelected(false); + sampleParametersFieldPanel.setVisible(true); + }); } private void registerPanelsVisibility(WatsonxSettingsState settings) { @@ -200,8 +217,10 @@ public JPanel getForm() { return FormBuilder.createFormBuilder() .addComponent(new TitledSeparator("Connection Parameters")) .addComponent(UI.PanelFactory.grid() - .add(UI.PanelFactory.panel(onPremCheckbox) - .resizeX(false)).createPanel()) + .add(UI.PanelFactory.panel(onCloudRadio) + .resizeX(false)) + .add(UI.PanelFactory.panel(onPremRadio) + .resizeX(false)).createPanel()) .addComponent(onCloudAuthenticationFieldPanel) .addComponent(onPremAuthenticationFieldPanel) .addComponent(UI.PanelFactory.grid() @@ -245,13 +264,14 @@ public JPanel getForm() { .resizeX(false)) .add(UI.PanelFactory.panel(includeStopSequenceCheckbox) .resizeX(false)) - .add(UI.PanelFactory.panel(repetitionPenaltyField) .withLabel("Repetition Penalty:") .withComment(CodeGPTBundle.get( "settingsConfigurable.service.watsonx.repetitionPenalty.comment")) .resizeX(false)) - .add(UI.PanelFactory.panel(greedyDecodingCheckbox) + .add(UI.PanelFactory.panel(greedyDecodingRadio) + .resizeX(false)) + .add(UI.PanelFactory.panel(sampleDecodingRadio) .resizeX(false)) .createPanel()) .addComponent(sampleParametersFieldPanel) @@ -262,7 +282,7 @@ public JPanel getForm() { public WatsonxSettingsState getCurrentState() { var state = new WatsonxSettingsState(); state.setModel(((WatsonxCompletionModel) modelComboBox.getSelectedItem()).getCode()); - state.setOnPrem(onPremCheckbox.isSelected()); + state.setOnPrem(onPremRadio.isSelected()); state.setOnPremHost(onPremHostField.getText()); state.setUsername(usernameField.getText()); state.setZenApiKey(zenApiKeyCheckbox.isSelected()); @@ -270,7 +290,7 @@ public WatsonxSettingsState getCurrentState() { state.setApiVersion(apiVersionField.getText()); state.setSpaceId(spaceIdField.getText()); state.setProjectId(projectIdField.getText()); - state.setGreedyDecoding(greedyDecodingCheckbox.isSelected()); + state.setGreedyDecoding(greedyDecodingRadio.isSelected()); state.setMaxNewTokens(Integer.valueOf(maxNewTokensField.getText())); state.setMinNewTokens(Integer.valueOf(minNewTokensField.getText())); state.setTemperature(Double.valueOf(temperatureField.getText())); @@ -285,7 +305,8 @@ public WatsonxSettingsState getCurrentState() { public void resetForm() { var state = WatsonxSettings.getCurrentState(); - onPremCheckbox.setSelected(state.isOnPrem()); + onPremRadio.setSelected(state.isOnPrem()); + onCloudRadio.setSelected(!state.isOnPrem()); onPremHostField.setText(state.getOnPremHost()); usernameField.setText(state.getUsername()); zenApiKeyCheckbox.setSelected(state.isZenApiKey()); @@ -293,7 +314,8 @@ public void resetForm() { apiVersionField.setText(state.getApiVersion()); spaceIdField.setText(state.getSpaceId()); projectIdField.setText(state.getProjectId()); - greedyDecodingCheckbox.setSelected(state.isGreedyDecoding()); + greedyDecodingRadio.setSelected(state.isGreedyDecoding()); + sampleDecodingRadio.setSelected(!state.isGreedyDecoding()); modelComboBox.setSelectedItem(WatsonxCompletionModel.findByCode(state.getModel())); maxNewTokensField.setText(String.valueOf(state.getMaxNewTokens())); minNewTokensField.setText(String.valueOf(state.getMinNewTokens())); From 84e54e3e8b1f10cd1e54e75b18dbe8e2ad96a64d Mon Sep 17 00:00:00 2001 From: mq200 Date: Sat, 14 Sep 2024 14:49:11 -0700 Subject: [PATCH 4/4] add support for watsonx deployments --- .../CompletionRequestProvider.java | 10 +- .../service/watsonx/WatsonxSettings.java | 48 +- .../service/watsonx/WatsonxSettingsForm.java | 648 +++++++++--------- .../service/watsonx/WatsonxSettingsState.java | 395 ++++++----- .../CodeCompletionRequestFactory.kt | 10 +- .../resources/messages/codegpt.properties | 5 +- 6 files changed, 592 insertions(+), 524 deletions(-) diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index d6dd2b47a..1ebf0a503 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -316,9 +316,13 @@ public WatsonxCompletionRequest buildWatsonxChatCompletionRequest( prompt += "\n"+callParameters.getMessage().getPrompt(); var builder = new WatsonxCompletionRequest.Builder(prompt); builder.setDecodingMethod(settings.isGreedyDecoding() ? "greedy" : "sample"); - builder.setModelId(settings.getModel()); - builder.setProjectId(settings.getProjectId()); - builder.setSpaceId(settings.getSpaceId()); + if (settings.getDeploymentId() != null && !settings.getDeploymentId().isEmpty()) { + builder.setDeploymentId(settings.getDeploymentId()); + } else { + builder.setModelId(settings.getModel()); + builder.setProjectId(settings.getProjectId()); + builder.setSpaceId(settings.getSpaceId()); + } builder.setMaxNewTokens(settings.getMaxNewTokens()); builder.setMinNewTokens(settings.getMinNewTokens()); builder.setTemperature(settings.getTemperature()); diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettings.java b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettings.java index a375a572d..4d0646c18 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettings.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettings.java @@ -4,33 +4,33 @@ import com.intellij.openapi.components.PersistentStateComponent; import com.intellij.openapi.components.State; import com.intellij.openapi.components.Storage; -import ee.carlrobert.codegpt.completions.llama.LlamaModel; import org.jetbrains.annotations.NotNull; @State(name = "CodeGPT_WatsonxSettings", storages = @Storage("CodeGPT_WatsonxSettings.xml")) public class WatsonxSettings implements PersistentStateComponent { - private WatsonxSettingsState state = new WatsonxSettingsState(); - - @Override - @NotNull - public WatsonxSettingsState getState() { - return state; - } - - @Override - public void loadState(@NotNull WatsonxSettingsState state) { - this.state = state; - } - - public static WatsonxSettingsState getCurrentState() { - return getInstance().getState(); - } - public static boolean isCodeCompletionsPossible() { - return getInstance().getState().isCodeCompletionsEnabled(); - } - - public static ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings getInstance() { - return ApplicationManager.getApplication().getService(WatsonxSettings.class); - } + private WatsonxSettingsState state = new WatsonxSettingsState(); + + public static WatsonxSettingsState getCurrentState() { + return getInstance().getState(); + } + + public static boolean isCodeCompletionsPossible() { + return getInstance().getState().isCodeCompletionsEnabled(); + } + + public static ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettings getInstance() { + return ApplicationManager.getApplication().getService(WatsonxSettings.class); + } + + @Override + @NotNull + public WatsonxSettingsState getState() { + return state; + } + + @Override + public void loadState(@NotNull WatsonxSettingsState state) { + this.state = state; + } } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java index a18be9d62..3c7b1aa88 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsForm.java @@ -6,330 +6,352 @@ import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.ui.ComboBox; import com.intellij.ui.EnumComboBoxModel; - import com.intellij.ui.TitledSeparator; -import com.intellij.ui.components.*; +import com.intellij.ui.components.JBCheckBox; +import com.intellij.ui.components.JBPasswordField; +import com.intellij.ui.components.JBRadioButton; +import com.intellij.ui.components.JBTextField; import com.intellij.util.ui.FormBuilder; import com.intellij.util.ui.UI; import ee.carlrobert.codegpt.CodeGPTBundle; import ee.carlrobert.codegpt.credentials.CredentialsStore; import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionModel; - - -import javax.swing.*; -import java.awt.*; +import java.awt.Cursor; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; +import javax.swing.JButton; +import javax.swing.JPanel; +import javax.swing.SwingConstants; +import javax.swing.SwingUtilities; import org.jetbrains.annotations.Nullable; public class WatsonxSettingsForm { - private final JBRadioButton onPremRadio; - private final JBRadioButton onCloudRadio; - private final JBPasswordField apiKeyField; - private final JBPasswordField onPremApiKeyField; - private final ComboBox regionComboBox; - private final JBTextField onPremHostField; - private final JBTextField usernameField; - private final JBCheckBox zenApiKeyCheckbox; - - private final JPanel onPremAuthenticationFieldPanel; - private final JPanel onCloudAuthenticationFieldPanel; - - private final String getStartedText = "Get started with IBM watsonx.ai as a Service"; - private final String getStartedUrl = "https://dataplatform.cloud.ibm.com/registration/stepone?context=wx"; - private final JButton getStartedLink; - private final JBTextField apiVersionField; - private final JBTextField projectIdField; - private final JBTextField spaceIdField; - - private final ComboBox modelComboBox; - private final JBRadioButton greedyDecodingRadio; - private final JBRadioButton sampleDecodingRadio; - private final JBTextField temperatureField; - private final JBTextField topKField; - private final JBTextField topPField; - private final JBTextField repetitionPenaltyField; - private final JBTextField randomSeedField; - private final JBTextField maxNewTokensField; - private final JBTextField minNewTokensField; - private final JBTextField stopSequencesField; - private final JBCheckBox includeStopSequenceCheckbox; - private final JPanel sampleParametersFieldPanel; - - public WatsonxSettingsForm(WatsonxSettingsState settings) { - onPremRadio = new JBRadioButton("Watsonx.ai on-premises software", false); - onCloudRadio = new JBRadioButton("Watsonx.ai as a Service", true); - - onPremHostField = new JBTextField(settings.getOnPremHost(), 35); - - usernameField = new JBTextField(settings.getUsername(), 35); - - zenApiKeyCheckbox = new JBCheckBox("Is Platform (Zen) API key", false); - - - class OpenUrlAction implements ActionListener { - @Override public void actionPerformed(ActionEvent e) { - open(getStartedUrl); - } - } - getStartedLink = new JButton(); - getStartedLink.setText(""+getStartedText+""); - getStartedLink.setHorizontalAlignment(SwingConstants.LEFT); - getStartedLink.setBorderPainted(false); - getStartedLink.setContentAreaFilled(false); - getStartedLink.setFocusPainted(false); - getStartedLink.setOpaque(false); - getStartedLink.setCursor(new Cursor(Cursor.HAND_CURSOR)); - getStartedLink.setToolTipText(getStartedUrl); - getStartedLink.addActionListener(new OpenUrlAction()); - - regionComboBox = new ComboBox(new String[] {"Dallas", "Frankfurt", "London", "Tokyo"}); - regionComboBox.setSelectedItem(settings.getRegion()); - - apiKeyField = new JBPasswordField(); - apiKeyField.setColumns(35); - ApplicationManager.getApplication().executeOnPooledThread(() -> { - var apiKey = CredentialsStore.getCredential(WATSONX_API_KEY); - SwingUtilities.invokeLater(() -> apiKeyField.setText(apiKey)); - }); - - onPremApiKeyField = new JBPasswordField(); - onPremApiKeyField.setColumns(35); - ApplicationManager.getApplication().executeOnPooledThread(() -> { - var apiKey = CredentialsStore.getCredential(WATSONX_API_KEY); - SwingUtilities.invokeLater(() -> onPremApiKeyField.setText(apiKey)); - }); - - onPremAuthenticationFieldPanel = new UI.PanelFactory().grid() - .add(UI.PanelFactory.panel(onPremHostField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.onPremHost.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.onPremHost.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(usernameField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.username.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.username.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(onPremApiKeyField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.onPremApiKey.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.onPremApiKey.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(zenApiKeyCheckbox).resizeX(false)) - .createPanel(); - - onCloudAuthenticationFieldPanel = new UI.PanelFactory().grid() - .add(UI.PanelFactory.panel(getStartedLink)) - .add(UI.PanelFactory.panel(regionComboBox) - .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.cloudRegion.label")) - .withComment(CodeGPTBundle.get("settingsConfigurable.service.watsonx.cloudRegion.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(apiKeyField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.onCloudApiKey.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.onCloudApiKey.comment")) - .resizeX(false)) - .createPanel(); - - apiVersionField = new JBTextField(settings.getApiVersion(), 35); - projectIdField = new JBTextField(settings.getProjectId(), 35); - spaceIdField = new JBTextField(settings.getSpaceId(), 35); - greedyDecodingRadio = new JBRadioButton("Greedy decoding", false); - sampleDecodingRadio = new JBRadioButton("Sample decoding", true); - - modelComboBox = new ComboBox<>(new EnumComboBoxModel(WatsonxCompletionModel.class)); - modelComboBox.setSelectedItem(WatsonxCompletionModel.findByCode(settings.getModel())); - temperatureField=new JBTextField(String.valueOf(settings.getTemperature()),35); - topKField=new JBTextField(String.valueOf(settings.getTopK()),35); - topPField=new JBTextField(String.valueOf(settings.getTopP()),35); - repetitionPenaltyField=new JBTextField(String.valueOf(settings.getRepetitionPenalty()),35); - randomSeedField=new JBTextField(settings.getRandomSeed() == null ? "" : String.valueOf(settings.getRandomSeed()),35); - maxNewTokensField=new JBTextField(String.valueOf(settings.getMaxNewTokens()),35); - minNewTokensField=new JBTextField(String.valueOf(settings.getMinNewTokens()),35); - stopSequencesField=new JBTextField(String.valueOf(settings.getStopSequences()),35); - includeStopSequenceCheckbox=new JBCheckBox(String.valueOf(settings.getIncludeStopSequence())); - includeStopSequenceCheckbox.setText("Include stop sequence"); - includeStopSequenceCheckbox.setSelected(false); - - sampleParametersFieldPanel = UI.PanelFactory.grid() - .add(UI.PanelFactory.panel(temperatureField) - .withLabel(CodeGPTBundle.get("configurationConfigurable.section.assistant.temperatureField.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.temperature.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(topKField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.service.llama.topK.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.topK.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(topPField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.service.llama.topP.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.llama.topP.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(randomSeedField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.randomSeed.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.randomSeed.comment")) - .resizeX(false)) - .createPanel(); - - registerPanelsVisibility(settings); - - onPremRadio.addActionListener(e -> { - settings.setOnPrem(true); - onCloudRadio.setSelected(false); - onPremAuthenticationFieldPanel.setVisible(true); - onCloudAuthenticationFieldPanel.setVisible(false); - - }); - - onCloudRadio.addActionListener(e -> { - settings.setOnPrem(false); - onPremRadio.setSelected(false); - onPremAuthenticationFieldPanel.setVisible(false); - onCloudAuthenticationFieldPanel.setVisible(true); - - }); - - greedyDecodingRadio.addActionListener(e -> { - settings.setGreedyDecoding(true); - sampleDecodingRadio.setSelected(false); - sampleParametersFieldPanel.setVisible(false); - }); - - sampleDecodingRadio.addActionListener(e -> { - settings.setGreedyDecoding(false); - greedyDecodingRadio.setSelected(false); - sampleParametersFieldPanel.setVisible(true); - }); - } - - private void registerPanelsVisibility(WatsonxSettingsState settings) { - onPremAuthenticationFieldPanel.setVisible(settings.isOnPrem()); - onCloudAuthenticationFieldPanel.setVisible(!settings.isOnPrem()); - sampleParametersFieldPanel.setVisible(!settings.isGreedyDecoding()); - } - - public JPanel getForm() { - return FormBuilder.createFormBuilder() - .addComponent(new TitledSeparator("Connection Parameters")) - .addComponent(UI.PanelFactory.grid() - .add(UI.PanelFactory.panel(onCloudRadio) - .resizeX(false)) - .add(UI.PanelFactory.panel(onPremRadio) - .resizeX(false)).createPanel()) - .addComponent(onCloudAuthenticationFieldPanel) - .addComponent(onPremAuthenticationFieldPanel) - .addComponent(UI.PanelFactory.grid() - .add(UI.PanelFactory.panel(apiVersionField) - .withLabel(CodeGPTBundle.get("shared.apiVersion")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.apiVersion.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(projectIdField) - .withLabel("Project ID:") - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.projectId.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(spaceIdField) - .withLabel("Space ID:") - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.spaceId.comment")) - .resizeX(false)) - .createPanel()) - .addComponent(new TitledSeparator("Generation Parameters")) - .addComponent(UI.PanelFactory.grid() - .add(UI.PanelFactory.panel(modelComboBox) - .withLabel(CodeGPTBundle.get("settingsConfigurable.shared.model.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.modelId.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(maxNewTokensField) - .withLabel(CodeGPTBundle.get("configurationConfigurable.section.assistant.maxTokensField.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.maxNewTokens.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(minNewTokensField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.minNewTokens.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.minNewTokens.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(stopSequencesField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.stopSequences.label")) - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.stopSequences.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(includeStopSequenceCheckbox) - .resizeX(false)) - .add(UI.PanelFactory.panel(repetitionPenaltyField) - .withLabel("Repetition Penalty:") - .withComment(CodeGPTBundle.get( - "settingsConfigurable.service.watsonx.repetitionPenalty.comment")) - .resizeX(false)) - .add(UI.PanelFactory.panel(greedyDecodingRadio) - .resizeX(false)) - .add(UI.PanelFactory.panel(sampleDecodingRadio) - .resizeX(false)) - .createPanel()) - .addComponent(sampleParametersFieldPanel) - .addComponentFillVertically(new JPanel(), 0) - .getPanel(); - } - - public WatsonxSettingsState getCurrentState() { - var state = new WatsonxSettingsState(); - state.setModel(((WatsonxCompletionModel) modelComboBox.getSelectedItem()).getCode()); - state.setOnPrem(onPremRadio.isSelected()); - state.setOnPremHost(onPremHostField.getText()); - state.setUsername(usernameField.getText()); - state.setZenApiKey(zenApiKeyCheckbox.isSelected()); - state.setRegion((String)regionComboBox.getSelectedItem()); - state.setApiVersion(apiVersionField.getText()); - state.setSpaceId(spaceIdField.getText()); - state.setProjectId(projectIdField.getText()); - state.setGreedyDecoding(greedyDecodingRadio.isSelected()); - state.setMaxNewTokens(Integer.valueOf(maxNewTokensField.getText())); - state.setMinNewTokens(Integer.valueOf(minNewTokensField.getText())); - state.setTemperature(Double.valueOf(temperatureField.getText())); - state.setRandomSeed(randomSeedField.getText().isEmpty() ? null : Integer.valueOf(randomSeedField.getText())); - state.setTopP(Double.valueOf(topPField.getText())); - state.setTopK(Integer.valueOf(topKField.getText())); - state.setIncludeStopSequence(includeStopSequenceCheckbox.isSelected()); - state.setStopSequences(stopSequencesField.getText()); - state.setRepetitionPenalty(Double.valueOf(repetitionPenaltyField.getText())); - return state; - } - - public void resetForm() { - var state = WatsonxSettings.getCurrentState(); - onPremRadio.setSelected(state.isOnPrem()); - onCloudRadio.setSelected(!state.isOnPrem()); - onPremHostField.setText(state.getOnPremHost()); - usernameField.setText(state.getUsername()); - zenApiKeyCheckbox.setSelected(state.isZenApiKey()); - apiKeyField.setText(CredentialsStore.getCredential(WATSONX_API_KEY)); - apiVersionField.setText(state.getApiVersion()); - spaceIdField.setText(state.getSpaceId()); - projectIdField.setText(state.getProjectId()); - greedyDecodingRadio.setSelected(state.isGreedyDecoding()); - sampleDecodingRadio.setSelected(!state.isGreedyDecoding()); - modelComboBox.setSelectedItem(WatsonxCompletionModel.findByCode(state.getModel())); - maxNewTokensField.setText(String.valueOf(state.getMaxNewTokens())); - minNewTokensField.setText(String.valueOf(state.getMinNewTokens())); - temperatureField.setText(String.valueOf(state.getTemperature())); - randomSeedField.setText(state.getRandomSeed() == null ? "" : String.valueOf(state.getRandomSeed())); - topPField.setText(String.valueOf(state.getTopP())); - topKField.setText(String.valueOf(state.getTopK())); - repetitionPenaltyField.setText(String.valueOf(state.getRepetitionPenalty())); - includeStopSequenceCheckbox.setSelected(state.getIncludeStopSequence()); - stopSequencesField.setText(String.valueOf(state.getStopSequences())); - } - - public @Nullable String getApiKey() { - var apiKey = new String(apiKeyField.getPassword()); - return apiKey.isEmpty() ? null : apiKey; + private final JBRadioButton onPremRadio; + private final JBRadioButton onCloudRadio; + private final JBPasswordField apiKeyField; + private final JBPasswordField onPremApiKeyField; + private final ComboBox regionComboBox; + private final JBTextField onPremHostField; + private final JBTextField usernameField; + private final JBCheckBox zenApiKeyCheckbox; + + private final JPanel onPremAuthenticationFieldPanel; + private final JPanel onCloudAuthenticationFieldPanel; + + private final String getStartedText = "Get started with IBM watsonx.ai as a Service"; + private final String getStartedUrl = "https://dataplatform.cloud.ibm.com/registration/stepone?context=wx"; + private final JButton getStartedLink; + private final JBTextField apiVersionField; + private final JBTextField projectIdField; + private final JBTextField spaceIdField; + + private final ComboBox modelComboBox; + private final JBTextField deploymentIdField; + private final JBRadioButton greedyDecodingRadio; + private final JBRadioButton sampleDecodingRadio; + private final JBTextField temperatureField; + private final JBTextField topKField; + private final JBTextField topPField; + private final JBTextField repetitionPenaltyField; + private final JBTextField randomSeedField; + private final JBTextField maxNewTokensField; + private final JBTextField minNewTokensField; + private final JBTextField stopSequencesField; + private final JBCheckBox includeStopSequenceCheckbox; + private final JPanel sampleParametersFieldPanel; + + public WatsonxSettingsForm(WatsonxSettingsState settings) { + onPremRadio = new JBRadioButton("Watsonx.ai on-premises software", false); + onCloudRadio = new JBRadioButton("Watsonx.ai as a Service", true); + + onPremHostField = new JBTextField(settings.getOnPremHost(), 35); + + usernameField = new JBTextField(settings.getUsername(), 35); + + zenApiKeyCheckbox = new JBCheckBox("Is Platform (Zen) API key", false); + + class OpenUrlAction implements ActionListener { + + @Override + public void actionPerformed(ActionEvent e) { + open(getStartedUrl); + } } + getStartedLink = new JButton(); + getStartedLink.setText("" + getStartedText + ""); + getStartedLink.setHorizontalAlignment(SwingConstants.LEFT); + getStartedLink.setBorderPainted(false); + getStartedLink.setContentAreaFilled(false); + getStartedLink.setFocusPainted(false); + getStartedLink.setOpaque(false); + getStartedLink.setCursor(new Cursor(Cursor.HAND_CURSOR)); + getStartedLink.setToolTipText(getStartedUrl); + getStartedLink.addActionListener(new OpenUrlAction()); + + regionComboBox = new ComboBox(new String[]{"Dallas", "Frankfurt", "London", "Tokyo"}); + regionComboBox.setSelectedItem(settings.getRegion()); + + apiKeyField = new JBPasswordField(); + apiKeyField.setColumns(35); + ApplicationManager.getApplication().executeOnPooledThread(() -> { + var apiKey = CredentialsStore.getCredential(WATSONX_API_KEY); + SwingUtilities.invokeLater(() -> apiKeyField.setText(apiKey)); + }); + + onPremApiKeyField = new JBPasswordField(); + onPremApiKeyField.setColumns(35); + ApplicationManager.getApplication().executeOnPooledThread(() -> { + var apiKey = CredentialsStore.getCredential(WATSONX_API_KEY); + SwingUtilities.invokeLater(() -> onPremApiKeyField.setText(apiKey)); + }); + + onPremAuthenticationFieldPanel = new UI.PanelFactory().grid() + .add(UI.PanelFactory.panel(onPremHostField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.onPremHost.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.onPremHost.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(usernameField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.username.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.username.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(onPremApiKeyField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.onPremApiKey.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.onPremApiKey.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(zenApiKeyCheckbox).resizeX(false)) + .createPanel(); + + onCloudAuthenticationFieldPanel = new UI.PanelFactory().grid() + .add(UI.PanelFactory.panel(getStartedLink)) + .add(UI.PanelFactory.panel(regionComboBox) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.cloudRegion.label")) + .withComment( + CodeGPTBundle.get("settingsConfigurable.service.watsonx.cloudRegion.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(apiKeyField) + .withLabel( + CodeGPTBundle.get("settingsConfigurable.service.watsonx.onCloudApiKey.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.onCloudApiKey.comment")) + .resizeX(false)) + .createPanel(); + + apiVersionField = new JBTextField(settings.getApiVersion(), 35); + projectIdField = new JBTextField(settings.getProjectId(), 35); + spaceIdField = new JBTextField(settings.getSpaceId(), 35); + greedyDecodingRadio = new JBRadioButton("Greedy decoding", false); + sampleDecodingRadio = new JBRadioButton("Sample decoding", true); + + modelComboBox = new ComboBox<>(new EnumComboBoxModel(WatsonxCompletionModel.class)); + modelComboBox.setSelectedItem(WatsonxCompletionModel.findByCode(settings.getModel())); + deploymentIdField = new JBTextField(settings.getSpaceId(), 35); + temperatureField = new JBTextField(String.valueOf(settings.getTemperature()), 17); + topKField = new JBTextField(String.valueOf(settings.getTopK()), 17); + topPField = new JBTextField(String.valueOf(settings.getTopP()), 17); + repetitionPenaltyField = new JBTextField(String.valueOf(settings.getRepetitionPenalty()), 17); + randomSeedField = new JBTextField( + settings.getRandomSeed() == null ? "" : String.valueOf(settings.getRandomSeed()), 17); + maxNewTokensField = new JBTextField(String.valueOf(settings.getMaxNewTokens()), 17); + minNewTokensField = new JBTextField(String.valueOf(settings.getMinNewTokens()), 17); + stopSequencesField = new JBTextField(String.valueOf(settings.getStopSequences()), 35); + includeStopSequenceCheckbox = new JBCheckBox(String.valueOf(settings.getIncludeStopSequence())); + includeStopSequenceCheckbox.setText("Include stop sequence"); + includeStopSequenceCheckbox.setSelected(false); + + sampleParametersFieldPanel = UI.PanelFactory.grid() + .add(UI.PanelFactory.panel(temperatureField) + .withLabel(CodeGPTBundle.get( + "configurationConfigurable.section.assistant.temperatureField.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.temperature.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(topKField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.llama.topK.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.topK.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(topPField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.llama.topP.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.llama.topP.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(randomSeedField) + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.watsonx.randomSeed.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.randomSeed.comment")) + .resizeX(false)) + .createPanel(); + + registerPanelsVisibility(settings); + + onPremRadio.addActionListener(e -> { + settings.setOnPrem(true); + onCloudRadio.setSelected(false); + onPremAuthenticationFieldPanel.setVisible(true); + onCloudAuthenticationFieldPanel.setVisible(false); + + }); + + onCloudRadio.addActionListener(e -> { + settings.setOnPrem(false); + onPremRadio.setSelected(false); + onPremAuthenticationFieldPanel.setVisible(false); + onCloudAuthenticationFieldPanel.setVisible(true); + + }); + + greedyDecodingRadio.addActionListener(e -> { + settings.setGreedyDecoding(true); + sampleDecodingRadio.setSelected(false); + sampleParametersFieldPanel.setVisible(false); + }); + + sampleDecodingRadio.addActionListener(e -> { + settings.setGreedyDecoding(false); + greedyDecodingRadio.setSelected(false); + sampleParametersFieldPanel.setVisible(true); + }); + } + + private void registerPanelsVisibility(WatsonxSettingsState settings) { + onPremAuthenticationFieldPanel.setVisible(settings.isOnPrem()); + onCloudAuthenticationFieldPanel.setVisible(!settings.isOnPrem()); + sampleParametersFieldPanel.setVisible(!settings.isGreedyDecoding()); + } + + public JPanel getForm() { + return FormBuilder.createFormBuilder() + .addComponent(new TitledSeparator("Connection Parameters")) + .addComponent(UI.PanelFactory.grid() + .add(UI.PanelFactory.panel(onCloudRadio) + .resizeX(false)) + .add(UI.PanelFactory.panel(onPremRadio) + .resizeX(false)).createPanel()) + .addComponent(onCloudAuthenticationFieldPanel) + .addComponent(onPremAuthenticationFieldPanel) + .addComponent(UI.PanelFactory.grid() + .add(UI.PanelFactory.panel(apiVersionField) + .withLabel(CodeGPTBundle.get("shared.apiVersion")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.apiVersion.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(projectIdField) + .withLabel("Project ID:") + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.projectId.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(spaceIdField) + .withLabel("Space ID:") + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.spaceId.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(modelComboBox) + .withLabel(CodeGPTBundle.get("settingsConfigurable.shared.model.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.modelId.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(deploymentIdField) + .withLabel("Deployment ID:") + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.deploymentId.comment")) + .resizeX(false)) + .createPanel()) + .addComponent(new TitledSeparator("Generation Parameters")) + .addComponent(UI.PanelFactory.grid() + .add(UI.PanelFactory.panel(maxNewTokensField) + .withLabel(CodeGPTBundle.get( + "configurationConfigurable.section.assistant.maxTokensField.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.maxNewTokens.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(minNewTokensField) + .withLabel( + CodeGPTBundle.get("settingsConfigurable.service.watsonx.minNewTokens.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.minNewTokens.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(stopSequencesField) + .withLabel( + CodeGPTBundle.get("settingsConfigurable.service.watsonx.stopSequences.label")) + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.stopSequences.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(includeStopSequenceCheckbox) + .resizeX(false)) + .add(UI.PanelFactory.panel(repetitionPenaltyField) + .withLabel("Repetition Penalty:") + .withComment(CodeGPTBundle.get( + "settingsConfigurable.service.watsonx.repetitionPenalty.comment")) + .resizeX(false)) + .add(UI.PanelFactory.panel(greedyDecodingRadio) + .resizeX(false)) + .add(UI.PanelFactory.panel(sampleDecodingRadio) + .resizeX(false)) + .createPanel()) + .addComponent(sampleParametersFieldPanel) + .addComponentFillVertically(new JPanel(), 0) + .getPanel(); + } + + public WatsonxSettingsState getCurrentState() { + var state = new WatsonxSettingsState(); + state.setModel(((WatsonxCompletionModel) modelComboBox.getSelectedItem()).getCode()); + state.setDeploymentId(deploymentIdField.getText()); + state.setOnPrem(onPremRadio.isSelected()); + state.setOnPremHost(onPremHostField.getText()); + state.setUsername(usernameField.getText()); + state.setZenApiKey(zenApiKeyCheckbox.isSelected()); + state.setRegion((String) regionComboBox.getSelectedItem()); + state.setApiVersion(apiVersionField.getText()); + state.setSpaceId(spaceIdField.getText()); + state.setProjectId(projectIdField.getText()); + state.setGreedyDecoding(greedyDecodingRadio.isSelected()); + state.setMaxNewTokens(Integer.valueOf(maxNewTokensField.getText())); + state.setMinNewTokens(Integer.valueOf(minNewTokensField.getText())); + state.setTemperature(Double.valueOf(temperatureField.getText())); + state.setRandomSeed( + randomSeedField.getText().isEmpty() ? null : Integer.valueOf(randomSeedField.getText())); + state.setTopP(Double.valueOf(topPField.getText())); + state.setTopK(Integer.valueOf(topKField.getText())); + state.setIncludeStopSequence(includeStopSequenceCheckbox.isSelected()); + state.setStopSequences(stopSequencesField.getText()); + state.setRepetitionPenalty(Double.valueOf(repetitionPenaltyField.getText())); + return state; + } + + public void resetForm() { + var state = WatsonxSettings.getCurrentState(); + onPremRadio.setSelected(state.isOnPrem()); + onCloudRadio.setSelected(!state.isOnPrem()); + onPremHostField.setText(state.getOnPremHost()); + usernameField.setText(state.getUsername()); + zenApiKeyCheckbox.setSelected(state.isZenApiKey()); + apiKeyField.setText(CredentialsStore.getCredential(WATSONX_API_KEY)); + apiVersionField.setText(state.getApiVersion()); + spaceIdField.setText(state.getSpaceId()); + projectIdField.setText(state.getProjectId()); + greedyDecodingRadio.setSelected(state.isGreedyDecoding()); + sampleDecodingRadio.setSelected(!state.isGreedyDecoding()); + modelComboBox.setSelectedItem(WatsonxCompletionModel.findByCode(state.getModel())); + deploymentIdField.setText(state.getDeploymentId()); + maxNewTokensField.setText(String.valueOf(state.getMaxNewTokens())); + minNewTokensField.setText(String.valueOf(state.getMinNewTokens())); + temperatureField.setText(String.valueOf(state.getTemperature())); + randomSeedField.setText( + state.getRandomSeed() == null ? "" : String.valueOf(state.getRandomSeed())); + topPField.setText(String.valueOf(state.getTopP())); + topKField.setText(String.valueOf(state.getTopK())); + repetitionPenaltyField.setText(String.valueOf(state.getRepetitionPenalty())); + includeStopSequenceCheckbox.setSelected(state.getIncludeStopSequence()); + stopSequencesField.setText(String.valueOf(state.getStopSequences())); + } + + public @Nullable String getApiKey() { + var apiKey = new String(apiKeyField.getPassword()); + return apiKey.isEmpty() ? null : apiKey; + } } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java index 567ec5b4d..0c5ef00c3 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/watsonx/WatsonxSettingsState.java @@ -4,183 +4,220 @@ public class WatsonxSettingsState { - private String onPremHost; - private String username; - private boolean isOnPrem = false; - private boolean isZenApiKey = false; - private String region = "us-south"; - private String apiVersion = "2024-03-14"; - // use this model as default - private String model = "ibm/granite-3b-code-instruct"; - private String spaceId; - private String projectId; - private boolean isGreedyDecoding = false; - private boolean codeCompletionsEnabled = false; - private Double temperature = 0.9; - private Integer topK = 40; - private Double topP = 0.9; - private Integer maxNewTokens= 4000; - private Integer minNewTokens = 0; - private Boolean includeStopSequence = false; - private String stopSequences = ""; - private Double repetitionPenalty = 1.1; - private Integer randomSeed; - - public boolean isOnPrem() { - return isOnPrem; - } - - public void setOnPrem(boolean onPrem) { - isOnPrem = onPrem; - } - - public String getOnPremHost() { - return onPremHost; - } - - public void setOnPremHost(String onPremHost) { - this.onPremHost = onPremHost; - } - - public boolean isZenApiKey() { - return isZenApiKey; - } - - public void setZenApiKey(boolean zenApiKey) { - isZenApiKey = zenApiKey; - } - - public String getUsername() { - return username; - } - - public void setUsername(String username) { - this.username = username; - } - - public String getRegion() {return region;} - - public void setRegion(String region) { - this.region = region; - } - - public String getApiVersion() { - return apiVersion; - } - public void setApiVersion(String apiVersion) { - this.apiVersion = apiVersion; - } - - public String getSpaceId() {return spaceId;} - public void setSpaceId(String spaceId) { - this.spaceId = spaceId; - } - - public String getProjectId() {return projectId;} - public void setProjectId(String projectId) { - this.projectId = projectId; - } - - public Boolean isGreedyDecoding() { - return isGreedyDecoding; - } - public void setGreedyDecoding(boolean isGreedyDecoding) { - this.isGreedyDecoding = isGreedyDecoding; - } - public String getModel() { - return model; - } - public void setModel(String model) { - this.model = model; - } - - public Double getTemperature() { - return temperature; - } - - public void setTemperature(Double temperature) { - this.temperature = temperature; - } - - public Integer getMaxNewTokens() { - return maxNewTokens; - } - public void setMaxNewTokens(Integer maxNewTokens) { - this.maxNewTokens = maxNewTokens; - } - public Integer getMinNewTokens() { - return minNewTokens; - } - public void setMinNewTokens(Integer minNewTokens) { - this.minNewTokens = minNewTokens; - } - - public Double getRepetitionPenalty() { - return repetitionPenalty; - } - - public void setRepetitionPenalty(Double repetitionPenalty) { - this.repetitionPenalty = repetitionPenalty; - } - - public Double getTopP() { - return topP; - } - - public void setTopP(Double topP) { - this.topP = topP; - } - - public Integer getTopK() { - return topK; - } - public void setTopK(Integer topK) { - this.topK = topK; - } - - public Boolean getIncludeStopSequence() { - return includeStopSequence; - } - - public void setIncludeStopSequence(Boolean includeStopSequence) { - this.includeStopSequence = includeStopSequence; - } - - public String getStopSequences() { - return stopSequences; - } - public void setStopSequences(String stopSequences){ - this.stopSequences = stopSequences; - } - - public Integer getRandomSeed() { - return randomSeed; - } - public void setRandomSeed(Integer randomSeed){ - this.randomSeed=randomSeed; - } - - public boolean isCodeCompletionsEnabled() { - return codeCompletionsEnabled; - } - - public void setCodeCompletionsEnabled(boolean codeCompletionsEnabled) { - this.codeCompletionsEnabled = codeCompletionsEnabled; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsState that = (ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsState) o; - return Objects.equals(apiVersion, that.apiVersion) && Objects.equals(region, that.region) && Objects.equals(spaceId, that.spaceId) && Objects.equals(projectId, that.projectId) && Objects.equals(model, that.model) && Objects.equals(temperature,that.temperature) && Objects.equals(topP,that.topP) && Objects.equals(topK,that.topK) && Objects.equals(randomSeed,that.randomSeed) && Objects.equals(repetitionPenalty,that.repetitionPenalty) && Objects.equals(maxNewTokens, that.maxNewTokens) && Objects.equals(minNewTokens,that.minNewTokens) && Objects.equals(isGreedyDecoding,that.isGreedyDecoding) && Objects.equals(isOnPrem,that.isOnPrem) && Objects.equals(isZenApiKey,that.isZenApiKey); - } - - @Override - public int hashCode() { - return Objects.hash(apiVersion, region, model, apiVersion, projectId, spaceId,temperature,topP,topK,randomSeed,includeStopSequence,stopSequences,repetitionPenalty, maxNewTokens,minNewTokens,isGreedyDecoding,isOnPrem,isZenApiKey); } + private String onPremHost; + private String username; + private boolean isOnPrem = false; + private boolean isZenApiKey = false; + private String region = "us-south"; + private String apiVersion = "2024-03-14"; + private String model = "ibm/granite-3b-code-instruct"; + private String deploymentId; + private String spaceId; + private String projectId; + private boolean isGreedyDecoding = false; + private boolean codeCompletionsEnabled = false; + private Double temperature = 0.9; + private Integer topK = 40; + private Double topP = 0.9; + private Integer maxNewTokens = 4000; + private Integer minNewTokens = 0; + private Boolean includeStopSequence = false; + private String stopSequences = ""; + private Double repetitionPenalty = 1.1; + private Integer randomSeed; + + public boolean isOnPrem() { + return isOnPrem; + } + + public void setOnPrem(boolean onPrem) { + isOnPrem = onPrem; + } + + public String getOnPremHost() { + return onPremHost; + } + + public void setOnPremHost(String onPremHost) { + this.onPremHost = onPremHost; + } + + public boolean isZenApiKey() { + return isZenApiKey; + } + + public void setZenApiKey(boolean zenApiKey) { + isZenApiKey = zenApiKey; + } + + public String getUsername() { + return username; + } + + public void setUsername(String username) { + this.username = username; + } + + public String getRegion() { + return region; + } + + public void setRegion(String region) { + this.region = region; + } + + public String getDeploymentId() { + return deploymentId; + } + + public void setDeploymentId(String deploymentId) { + this.deploymentId = deploymentId; + } + + public String getApiVersion() { + return apiVersion; + } + + public void setApiVersion(String apiVersion) { + this.apiVersion = apiVersion; + } + + public String getSpaceId() { + return spaceId; + } + + public void setSpaceId(String spaceId) { + this.spaceId = spaceId; + } + + public String getProjectId() { + return projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public Boolean isGreedyDecoding() { + return isGreedyDecoding; + } + + public void setGreedyDecoding(boolean isGreedyDecoding) { + this.isGreedyDecoding = isGreedyDecoding; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + public Integer getMaxNewTokens() { + return maxNewTokens; + } + + public void setMaxNewTokens(Integer maxNewTokens) { + this.maxNewTokens = maxNewTokens; + } + + public Integer getMinNewTokens() { + return minNewTokens; + } + + public void setMinNewTokens(Integer minNewTokens) { + this.minNewTokens = minNewTokens; + } + + public Double getRepetitionPenalty() { + return repetitionPenalty; + } + + public void setRepetitionPenalty(Double repetitionPenalty) { + this.repetitionPenalty = repetitionPenalty; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + public Boolean getIncludeStopSequence() { + return includeStopSequence; + } + + public void setIncludeStopSequence(Boolean includeStopSequence) { + this.includeStopSequence = includeStopSequence; + } + + public String getStopSequences() { + return stopSequences; + } + + public void setStopSequences(String stopSequences) { + this.stopSequences = stopSequences; + } + + public Integer getRandomSeed() { + return randomSeed; + } + + public void setRandomSeed(Integer randomSeed) { + this.randomSeed = randomSeed; + } + + public boolean isCodeCompletionsEnabled() { + return codeCompletionsEnabled; + } + + public void setCodeCompletionsEnabled(boolean codeCompletionsEnabled) { + this.codeCompletionsEnabled = codeCompletionsEnabled; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsState that = (ee.carlrobert.codegpt.settings.service.watsonx.WatsonxSettingsState) o; + return Objects.equals(apiVersion, that.apiVersion) && Objects.equals(region, that.region) + && Objects.equals(spaceId, that.spaceId) && Objects.equals(projectId, that.projectId) + && Objects.equals(model, that.model) && Objects.equals(deploymentId, that.deploymentId) + && Objects.equals(temperature, that.temperature) && Objects.equals(topP, that.topP) + && Objects.equals(topK, that.topK) && Objects.equals(randomSeed, that.randomSeed) + && Objects.equals(repetitionPenalty, that.repetitionPenalty) && Objects.equals(maxNewTokens, + that.maxNewTokens) && Objects.equals(minNewTokens, that.minNewTokens) && Objects.equals( + isGreedyDecoding, that.isGreedyDecoding) && Objects.equals(isOnPrem, that.isOnPrem) + && Objects.equals(isZenApiKey, that.isZenApiKey); + } + + @Override + public int hashCode() { + return Objects.hash(apiVersion, region, model, deploymentId, projectId, spaceId, temperature, + topP, topK, randomSeed, includeStopSequence, stopSequences, repetitionPenalty, maxNewTokens, + minNewTokens, isGreedyDecoding, isOnPrem, isZenApiKey); + } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt index d47a7170e..075f3593d 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt @@ -33,9 +33,13 @@ object CodeCompletionRequestFactory { val settings = WatsonxSettings.getCurrentState(); val builder = WatsonxCompletionRequest.Builder(details.prefix) builder.setDecodingMethod(if (settings.isGreedyDecoding) "greedy" else "sample") - builder.setModelId(settings.model) - builder.setProjectId(settings.projectId) - builder.setSpaceId(settings.spaceId) + if (settings.deploymentId != null && !settings.deploymentId.isEmpty()) { + builder.setDeploymentId(settings.deploymentId); + } else { + builder.setModelId(settings.model) + builder.setProjectId(settings.projectId) + builder.setSpaceId(settings.spaceId) + } builder.setMaxNewTokens(settings.maxNewTokens) builder.setMinNewTokens(settings.minNewTokens) builder.setTemperature(settings.temperature) diff --git a/src/main/resources/messages/codegpt.properties b/src/main/resources/messages/codegpt.properties index 11ce1eeed..88916d8ea 100644 --- a/src/main/resources/messages/codegpt.properties +++ b/src/main/resources/messages/codegpt.properties @@ -157,11 +157,12 @@ settingsConfigurable.service.watsonx.randomSeed.comment=Specify an integer value settingsConfigurable.service.watsonx.apiVersion.comment= settingsConfigurable.service.watsonx.projectId.comment=Provide a project ID settingsConfigurable.service.watsonx.spaceId.comment=Provide a deployment space ID -settingsConfigurable.service.watsonx.modelId.comment=Select a model from the list +settingsConfigurable.service.watsonx.deploymentId.comment=Provide a deployment ID +settingsConfigurable.service.watsonx.modelId.comment= settingsConfigurable.service.watsonx.maxNewTokens.comment= settingsConfigurable.service.watsonx.minNewTokens.comment= settingsConfigurable.service.watsonx.minNewTokens.label=Min completion tokens: -settingsConfigurable.service.watsonx.stopSequences.label=Stop sequences +settingsConfigurable.service.watsonx.stopSequences.label=Stop sequences: settingsConfigurable.service.watsonx.stopSequences.comment=Comma-separated list of stop sequences settingsConfigurable.service.watsonx.repetitionPenalty.comment= settingsConfigurable.service.watsonx.cloudRegion.label=IBM Cloud region: