Skip to content

Adding chat template support for Granite model #14864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion common/chat-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,15 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
std::string arguments = "";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else {
arguments = tool_call.at("arguments");
}
}

return add_tool_call(name, id, arguments);
}

Expand Down
130 changes: 130 additions & 0 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";

default:
throw std::runtime_error("Unknown chat format");
}
Expand All @@ -602,6 +604,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
case COMMON_REASONING_FORMAT_NONE: return "none";
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
case COMMON_REASONING_FORMAT_GRANITE: return "granite";
default:
throw std::runtime_error("Unknown reasoning format");
}
Expand Down Expand Up @@ -1700,6 +1703,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
}
}

static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;

// Pass thinking context for Granite template
json additional_context = {
{"thinking", inputs.enable_thinking},
};

data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
data.format = COMMON_CHAT_FORMAT_GRANITE;

if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}

if (!inputs.tools.is_null()) {
// Granite uses <|tool_call|> followed by JSON list
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name +
"-args", {
{"type", "object"},
{"properties", {
{"name", {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
})));
});

auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");

if (data.thinking_forced_open) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space [^<]* \"</response>\" space \"<|tool_call|>\" space " + tool_list);
} else {
builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list);
}

data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
"<|tool_call|>"
});

data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
"<|tool_call|>",
};
});
} else {
// Handle thinking tags for non-tool responses
if (data.thinking_forced_open && inputs.enable_thinking) {
data.grammar_lazy = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space .* \"</response>\" space");
});
data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
};
}
}

return data;
}

static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");

// Parse response tags using regex
static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
if (auto res = builder.try_find_regex(response_regex)) {
// Extract the content between the tags (capture group 1)
auto content = builder.str(res->groups[1]);
builder.add_content(content);
builder.move_to(res->groups[0].end);
}

if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);

// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.add_tool_calls(tool_calls_data.json)) {
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
}
} else {
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
}
} else {
builder.add_content(builder.consume_rest());
}
}

static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
Expand Down Expand Up @@ -1769,6 +1890,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_command_r7b(tmpl, params);
}

// Granite (IBM) - detects thinking support
if (src.find("elif thinking") != std::string::npos && src.find("<think>") != std::string::npos) {
return common_chat_params_init_granite(tmpl, params);
}

// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_hermes_2_pro(tmpl, params);
Expand Down Expand Up @@ -1824,6 +1950,7 @@ static common_chat_params common_chat_templates_apply_legacy(
int alloc_size = 0;
std::vector<llama_chat_message> chat;
std::vector<std::string> contents;

for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
Expand Down Expand Up @@ -1925,6 +2052,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
Expand Down
1 change: 1 addition & 0 deletions common/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,

COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};

struct common_params {
Expand Down
59 changes: 59 additions & 0 deletions models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
{# Alias tools -> available_tools #}
{%- if tools and not available_tools -%}
{%- set available_tools = tools -%}
{%- endif -%}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = "Knowledge Cutoff Date: April 2024. Today's Date: " + strftime_now('%B %d, %Y') + ". You are Granite, developed by IBM." %}
{%- if available_tools and documents %}
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request. Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif available_tools %}
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
{%- elif documents %}
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif thinking %}
{%- set system_message = system_message + " You are a helpful AI assistant.
Respond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. Write your thoughts between <think></think> and write your response between <response></response> for each user query." %}
{%- else %}
{%- set system_message = system_message + " You are a helpful AI assistant." %}
{%- endif %}
{%- if 'citations' in controls and documents %}
{%- set system_message = system_message + '
Use the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
{%- endif %}
{%- if 'hallucinations' in controls and documents %}
{%- set system_message = system_message + '
Finally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.' %}
{%- endif %}
{%- set loop_messages = messages %}
{%- endif %}
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
' }}
{%- if available_tools %}
{{- '<|start_of_role|>available_tools<|end_of_role|>' }}
{{- available_tools | tojson(indent=4) }}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- if documents %}
{%- for document in documents %}
{{- '<|start_of_role|>document {"document_id": "' + document['doc_id'] | string + '"}<|end_of_role|>
' }}
{{- document['text'] }}
{{- '<|end_of_text|>
' }}
{%- endfor %}
{%- endif %}
{%- for message in loop_messages %}
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- if loop.last and add_generation_prompt %}
{{- '<|start_of_role|>assistant' }}
{%- if controls %}
{{- ' ' + controls | tojson()}}
{%- endif %}
{{- '<|end_of_role|>' }}
{%- endif %}
{%- endfor %}
53 changes: 53 additions & 0 deletions tests/test-chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,59 @@ static void test_template_output_parsers() {
"{\"arg1\": 1}\n"
"```<|tool▁call▁end|><|tool▁calls▁end|>");
}
{
auto tmpls = read_templates("models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja");
std::vector<std::string> end_tokens{ "<|end_of_text|>" };

assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);

assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

// Test parsing regular content
assert_msg_equals(message_assist,
common_chat_parse(
"Hello, world!\nWhat's up?",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_GRANITE}));

// Test parsing content with thinking
assert_msg_equals(message_assist_thoughts,
common_chat_parse(
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
/* is_partial= */ false,
{
/* .format = */ COMMON_CHAT_FORMAT_GRANITE,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_GRANITE,
}));

// Test parsing tool calls
assert_msg_equals(message_assist_call,
common_chat_parse(
"<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_GRANITE}));

// Test template generation for regular content
test_templates(tmpls.get(), end_tokens, message_assist, tools,
"Hello, world!\nWhat's up?",
/* expect_grammar_triggered= */ false);

// Test template generation for tool calls
test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
"{\n"
" \"tool_calls\": [\n"
" {\n"
" \"name\": \"special_function\",\n"
" \"arguments\": {\n"
" \"arg1\": 1\n"
" },\n"
" \"id\": \"123456789\"\n"
" }\n"
" ]\n"
"}",
/* expect_grammar_triggered= */ false
);
}
}

static void test_msg_diffs_compute() {
Expand Down
Loading