Skip to content

Commit d5169fc

Browse files
committed
fix: add TOOL_CHOICE to ModelOptions like TEMPERATURE not a sentinel
The pass-thru behavior was not clear enough, so adding it to ModelOptions where important options are known. Most of these are sentinels which are removed (because @@@) but this will be like TEMPERATURE which is passed through to the backends. No behavior change, but give a handly constant and a place to look for these. This does not address all the other possible pass through args. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent d2bf9c9 commit d5169fc

5 files changed

Lines changed: 18 additions & 11 deletions

File tree

cli/serve/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,14 @@ def _build_model_options(request: ChatCompletionRequest) -> dict:
115115
"response_format", # Response format (json_object) - not yet implemented
116116
"functions", # Legacy function calling - not yet implemented
117117
"function_call", # Legacy function calling - not yet implemented
118-
# Tool choice is passed through as-is (not a ModelOption sentinel)
119118
}
120119
openai_to_model_option = {
121120
"temperature": ModelOption.TEMPERATURE,
122121
"max_tokens": ModelOption.MAX_NEW_TOKENS,
123122
"seed": ModelOption.SEED,
124123
"stream": ModelOption.STREAM,
125124
"tools": ModelOption.TOOLS,
125+
"tool_choice": ModelOption.TOOL_CHOICE,
126126
}
127127

128128
# Get all non-None fields

docs/examples/m_serve/m_serve_example_tool_calling.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,13 @@ def serve(
126126
"""Serve function that handles tool calling.
127127
128128
This function demonstrates how to use tools with m serve. The tools
129-
are passed via model_options and the model can request to call them.
129+
are passed via model_options using ModelOption.TOOLS, and tool_choice
130+
can be specified using ModelOption.TOOL_CHOICE.
130131
131132
Args:
132133
input: List of chat messages
133134
requirements: Optional list of requirement strings
134-
model_options: Model options including tools and tool_choice
135+
model_options: Model options including ModelOption.TOOLS and ModelOption.TOOL_CHOICE
135136
136137
Returns:
137138
ModelOutputThunk with potential tool calls
@@ -169,9 +170,10 @@ def serve(
169170
# Example usage (for testing purposes)
170171
test_messages = [ChatMessage(role="user", content="What's the weather in Paris?")]
171172

172-
# Simulate tool definitions being passed
173+
# Simulate tool definitions being passed with tool_choice
173174
test_model_options = {
174-
ModelOption.TOOLS: [weather_tool.as_json_tool, stock_price_tool.as_json_tool]
175+
ModelOption.TOOLS: [weather_tool.as_json_tool, stock_price_tool.as_json_tool],
176+
ModelOption.TOOL_CHOICE: "auto", # Can be "none", "auto", or specific tool
175177
}
176178

177179
response = serve(input=test_messages, model_options=test_model_options)

mellea/backends/model_options.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class ModelOption:
2222
2323
Attributes:
2424
TOOLS (str): Sentinel key for a list or dict of tools to expose for tool calling.
25+
TOOL_CHOICE (str): Key for tool choice strategy (passed through to the backend).
2526
MAX_NEW_TOKENS (str): Sentinel key for the maximum number of new tokens to generate.
2627
SYSTEM_PROMPT (str): Sentinel key for the system prompt string.
2728
TEMPERATURE (str): Key for the sampling temperature (passed through to the backend).
@@ -34,6 +35,9 @@ class ModelOption:
3435
TOOLS = "@@@tools@@@"
3536
"""Must be a list[Callable] or a dict[str, Callable] where str is the name of the function."""
3637

38+
TOOL_CHOICE = "tool_choice"
39+
"""Controls which tool the model should use. Can be "none", "auto", or a specific tool name."""
40+
3741
MAX_NEW_TOKENS = "@@@max_new_tokens@@@"
3842
SYSTEM_PROMPT = "@@@system_prompt@@@"
3943
TEMPERATURE = "temperature"

test/cli/test_serve.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -505,9 +505,9 @@ async def test_tool_params_passed_to_model_options(self, mock_module):
505505

506506
# Tools should be passed with ModelOption.TOOLS key
507507
assert ModelOption.TOOLS in model_options
508-
# tool_choice should be passed through as-is
509-
assert "tool_choice" in model_options
510-
assert model_options["tool_choice"] == "auto"
508+
# tool_choice should be passed through using ModelOption.TOOL_CHOICE
509+
assert ModelOption.TOOL_CHOICE in model_options
510+
assert model_options[ModelOption.TOOL_CHOICE] == "auto"
511511
# Legacy function calling parameters should still be excluded
512512
assert "functions" not in model_options
513513
assert "function_call" not in model_options

test/cli/test_serve_tool_calling.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
FunctionParameters,
1616
ToolFunction,
1717
)
18+
from mellea.backends import ModelOption
1819
from mellea.core.base import AbstractMelleaTool, ModelOutputThunk, ModelToolCall
1920

2021

@@ -223,9 +224,9 @@ async def test_tool_choice_passed_to_model_options(
223224
assert call_args is not None
224225
model_options = call_args.kwargs["model_options"]
225226

226-
# tool_choice should be passed through as-is
227-
assert "tool_choice" in model_options
228-
assert model_options["tool_choice"] == "auto"
227+
# tool_choice should be passed through using ModelOption.TOOL_CHOICE
228+
assert ModelOption.TOOL_CHOICE in model_options
229+
assert model_options[ModelOption.TOOL_CHOICE] == "auto"
229230

230231
@pytest.mark.asyncio
231232
async def test_tool_calls_with_complex_arguments(

0 commit comments

Comments
 (0)