diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 18a30e49aa578..96ba8f533ef1b 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -55,7 +55,15 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std:: bool common_chat_msg_parser::add_tool_call(const json & tool_call) { std::string name = tool_call.contains("name") ? tool_call.at("name") : ""; std::string id = tool_call.contains("id") ? tool_call.at("id") : ""; - std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : ""; + std::string arguments = ""; + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else { + arguments = tool_call.at("arguments"); + } + } + return add_tool_call(name, id, arguments); } diff --git a/common/chat.cpp b/common/chat.cpp index 60805ab3b53f5..316bd24170c9e 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -606,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; + case COMMON_CHAT_FORMAT_GRANITE: return "Granite"; case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; default: throw std::runtime_error("Unknown chat format"); @@ -618,6 +619,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) { case COMMON_REASONING_FORMAT_AUTO: return "auto"; case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy"; + case COMMON_REASONING_FORMAT_GRANITE: return "granite"; default: throw std::runtime_error("Unknown reasoning format"); } @@ -1734,6 +1736,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { builder.add_content(builder.consume_rest()); } +static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // Pass thinking context for Granite template + json additional_context = { + {"thinking", inputs.enable_thinking}, + }; + + data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context); + data.format = COMMON_CHAT_FORMAT_GRANITE; + + if (string_ends_with(data.prompt, "\n") || string_ends_with(data.prompt, "")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + if (!inputs.tools.is_null()) { + // Granite uses <|tool_call|> followed by JSON list + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name + +"-args", { + {"type", "object"}, + {"properties", { + {"name", {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + }))); + }); + + auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")); + auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\""); + + if (data.thinking_forced_open) { + builder.add_rule("root", "\"\" space \"\" space [^<]* \"\" space \"<|tool_call|>\" space " + tool_list); + } else { + builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list); + } + + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + "<|tool_call|>" + }); + + data.preserved_tokens = { + "", + "", + "", + "", + "<|tool_call|>", + }; + }); + } else { + // Handle thinking tags for non-tool responses + if (data.thinking_forced_open && inputs.enable_thinking) { + data.grammar_lazy = false; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + builder.add_rule("root", "\"\" space \"\" space .* \"\" space"); + }); + data.preserved_tokens = { + "", + "", + "", + "", + }; + } + } + + return data; +} + +static void common_chat_parse_granite(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("", ""); + + // Parse response tags using regex + static const common_regex response_regex("([\\s\\S]*?)"); + if (auto res = builder.try_find_regex(response_regex)) { + // Extract the content between the tags (capture group 1) + auto content = builder.str(res->groups[1]); + builder.add_content(content); + builder.move_to(res->groups[0].end); + } + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("<|tool_call|>")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + if (!builder.add_tool_calls(tool_calls_data.json)) { + builder.add_content("<|tool_call|>" + tool_calls_data.json.dump()); + } + } else { + builder.add_content("<|tool_call|>" + tool_calls_data.json.dump()); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -1805,6 +1925,11 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_command_r7b(tmpl, params); } + // Granite (IBM) - detects thinking / tools support + if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) { + return common_chat_params_init_granite(tmpl, params); + } + // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) if (src.find("") != std::string::npos && params.json_schema.is_null()) { return common_chat_params_init_hermes_2_pro(tmpl, params); @@ -1865,6 +1990,7 @@ static common_chat_params common_chat_templates_apply_legacy( int alloc_size = 0; std::vector chat; std::vector contents; + for (const auto & msg : inputs.messages) { auto content = msg.content; for (const auto & part : msg.content_parts) { @@ -1966,6 +2092,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_COMMAND_R7B: common_chat_parse_command_r7b(builder); break; + case COMMON_CHAT_FORMAT_GRANITE: + common_chat_parse_granite(builder); + break; case COMMON_CHAT_FORMAT_GPT_OSS: common_chat_parse_gpt_oss(builder); break; diff --git a/common/chat.h b/common/chat.h index b014f9f0aaeb4..eb628d8bc275d 100644 --- a/common/chat.h +++ b/common/chat.h @@ -109,6 +109,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, COMMON_CHAT_FORMAT_HERMES_2_PRO, COMMON_CHAT_FORMAT_COMMAND_R7B, + COMMON_CHAT_FORMAT_GRANITE, COMMON_CHAT_FORMAT_GPT_OSS, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats diff --git a/common/common.h b/common/common.h index 6c1c7ee237a3a..5eab199af559e 100644 --- a/common/common.h +++ b/common/common.h @@ -239,6 +239,7 @@ enum common_reasoning_format { COMMON_REASONING_FORMAT_AUTO, COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in tags in stream mode COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. + COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. }; struct common_params { diff --git a/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja b/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja new file mode 100644 index 0000000000000..f5065360960f0 --- /dev/null +++ b/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja @@ -0,0 +1,59 @@ +{# Alias tools -> available_tools #} +{%- if tools and not available_tools -%} + {%- set available_tools = tools -%} +{%- endif -%} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content'] %} + {%- set loop_messages = messages[1:] %} + {%- else %} + {%- set system_message = "Knowledge Cutoff Date: April 2024. Today's Date: " + strftime_now('%B %d, %Y') + ". You are Granite, developed by IBM." %} + {%- if available_tools and documents %} + {%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request. Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %} + {%- elif available_tools %} + {%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %} + {%- elif documents %} + {%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %} + {%- elif thinking %} + {%- set system_message = system_message + " You are a helpful AI assistant. +Respond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. Write your thoughts between and write your response between for each user query." %} + {%- else %} + {%- set system_message = system_message + " You are a helpful AI assistant." %} + {%- endif %} + {%- if 'citations' in controls and documents %} + {%- set system_message = system_message + ' +Use the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %} + {%- endif %} + {%- if 'hallucinations' in controls and documents %} + {%- set system_message = system_message + ' +Finally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.' %} + {%- endif %} + {%- set loop_messages = messages %} + {%- endif %} + {{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|> +' }} + {%- if available_tools %} + {{- '<|start_of_role|>available_tools<|end_of_role|>' }} + {{- available_tools | tojson(indent=4) }} + {{- '<|end_of_text|> +' }} + {%- endif %} + {%- if documents %} + {%- for document in documents %} + {{- '<|start_of_role|>document {"document_id": "' + document['doc_id'] | string + '"}<|end_of_role|> +' }} + {{- document['text'] }} + {{- '<|end_of_text|> +' }} + {%- endfor %} + {%- endif %} + {%- for message in loop_messages %} + {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- if loop.last and add_generation_prompt %} + {{- '<|start_of_role|>assistant' }} + {%- if controls %} + {{- ' ' + controls | tojson()}} + {%- endif %} + {{- '<|end_of_role|>' }} + {%- endif %} + {%- endfor %} diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 73c98bfa207fc..99b4b4d5bac7b 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1386,6 +1386,59 @@ static void test_template_output_parsers() { "{\"arg1\": 1}\n" "```<|tool▁call▁end|><|tool▁calls▁end|>"); } + { + auto tmpls = read_templates("models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja"); + std::vector end_tokens{ "<|end_of_text|>" }; + + assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + + assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + + // Test parsing regular content + assert_msg_equals(message_assist, + common_chat_parse( + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_GRANITE})); + + // Test parsing content with thinking + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_GRANITE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_GRANITE, + })); + + // Test parsing tool calls + assert_msg_equals(message_assist_call, + common_chat_parse( + "<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_GRANITE})); + + // Test template generation for regular content + test_templates(tmpls.get(), end_tokens, message_assist, tools, + "Hello, world!\nWhat's up?", + /* expect_grammar_triggered= */ false); + + // Test template generation for tool calls + test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools, + "{\n" + " \"tool_calls\": [\n" + " {\n" + " \"name\": \"special_function\",\n" + " \"arguments\": {\n" + " \"arg1\": 1\n" + " },\n" + " \"id\": \"123456789\"\n" + " }\n" + " ]\n" + "}", + /* expect_grammar_triggered= */ false + ); + } } static void test_msg_diffs_compute() {