diff --git a/cli/serve/app.py b/cli/serve/app.py index 583b28c01..7ad354a95 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -7,6 +7,7 @@ import sys import time import uuid +from typing import Literal try: import typer @@ -21,11 +22,15 @@ ) from e from mellea.backends.model_options import ModelOption -from mellea.helpers.openai_compatible_helpers import build_completion_usage +from mellea.helpers.openai_compatible_helpers import ( + build_completion_usage, + build_tool_calls, +) from .models import ( ChatCompletion, ChatCompletionMessage, + ChatCompletionMessageToolCall, ChatCompletionRequest, Choice, OpenAIError, @@ -111,14 +116,14 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: "response_format", # Response format (json_object) - not yet implemented "functions", # Legacy function calling - not yet implemented "function_call", # Legacy function calling - not yet implemented - "tools", # Tool calling - not yet implemented - "tool_choice", # Tool choice - not yet implemented } openai_to_model_option = { "temperature": ModelOption.TEMPERATURE, "max_tokens": ModelOption.MAX_NEW_TOKENS, "seed": ModelOption.SEED, "stream": ModelOption.STREAM, + "tools": ModelOption.TOOLS, + "tool_choice": ModelOption.TOOL_CHOICE, } # Get all non-None fields @@ -171,8 +176,6 @@ async def endpoint(request: ChatCompletionRequest): model_options=model_options, ) - # system_fingerprint represents backend config hash, not model name - # The model name is already in response.model (line 73) # Leave as None since we don't track backend config fingerprints yet system_fingerprint = None @@ -190,6 +193,24 @@ async def endpoint(request: ChatCompletionRequest): media_type="text/event-stream", ) + tool_calls_list = build_tool_calls(output) + tool_calls = ( + [ + ChatCompletionMessageToolCall.model_validate(tc) + for tc in tool_calls_list + ] + if tool_calls_list + else None + ) + + # Determine finish_reason based on tool calls + finish_reason: ( + Literal[ + "stop", "length", "content_filter", "tool_calls", "function_call" + ] + | None + ) = "tool_calls" if tool_calls else "stop" + return ChatCompletion( id=completion_id, model=request.model, @@ -198,9 +219,11 @@ async def endpoint(request: ChatCompletionRequest): Choice( index=0, message=ChatCompletionMessage( - content=output.value, role="assistant" + content=output.value, + role="assistant", + tool_calls=tool_calls, ), - finish_reason="stop", + finish_reason=finish_reason, ) ], object="chat.completion", # type: ignore diff --git a/cli/serve/models.py b/cli/serve/models.py index 7e738730e..1130e62a9 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -80,6 +80,67 @@ class ChatCompletionRequest(BaseModel): extra: dict[str, Any] = Field(default_factory=dict) +class ToolCallFunction(BaseModel): + """Function details for a tool call.""" + + name: str + """The name of the function to call.""" + + arguments: str + """The arguments to call the function with, as a JSON string.""" + + +class ChatCompletionMessageToolCall(BaseModel): + """A tool call generated by the model (non-streaming).""" + + id: str + """The ID of the tool call.""" + + type: Literal["function"] + """The type of the tool. Currently, only 'function' is supported.""" + + function: ToolCallFunction + """The function that the model called.""" + + +class ToolCallFunctionDelta(BaseModel): + """Function details for a streaming tool call delta. + + In streaming responses, function name and arguments may arrive across + multiple chunks, so both fields are optional. 
+    """
+
+    name: str | None = None
+    """The name of the function to call (may be None in delta chunks)."""
+
+    arguments: str | None = None
+    """The arguments fragment for this delta (may be None in delta chunks)."""
+
+
+class ChatCompletionMessageToolCallDelta(BaseModel):
+    """A tool call delta in a streaming response.
+
+    Per OpenAI streaming spec, each delta must include an index field that
+    clients use to reassemble tool calls across chunks. The id, type, and
+    function fields are optional since they may arrive incrementally.
+    """
+
+    index: int
+    """The index of this tool call in the tool_calls array.
+
+    Required for delta reassembly in OpenAI SDK and compatible clients.
+    """
+
+    id: str | None = None
+    """The ID of the tool call (may be None in subsequent delta chunks)."""
+
+    type: Literal["function"] | None = None
+    """The type of the tool (may be None in subsequent delta chunks)."""
+
+    function: ToolCallFunctionDelta | None = None
+    """The function delta for this chunk (may be None in some chunks)."""
+
+
+# Taking this from OpenAI types https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py,
 class ChatCompletionMessage(BaseModel):
     content: str | None = None
@@ -91,6 +152,9 @@ class ChatCompletionMessage(BaseModel):
     role: Literal["assistant"]
     """The role of the author of this message."""
 
+    tool_calls: list[ChatCompletionMessageToolCall] | None = None
+    """The tool calls generated by the model, such as function calls."""
+
 
 class Choice(BaseModel):
     index: int
@@ -144,6 +208,14 @@ class ChatCompletionChunkDelta(BaseModel):
     refusal: str | None = None
     """The refusal message fragment, if any."""
 
+    tool_calls: list[ChatCompletionMessageToolCallDelta] | None = None
+    """The tool call deltas in this chunk.
+
+    Each delta includes a required index field for reassembly by OpenAI SDK
+    and compatible clients. The id, type, and function fields are optional
+    since they may arrive incrementally across multiple chunks.
+ """ + class ChatCompletionChunkChoice(BaseModel): """A choice in a streaming chunk.""" diff --git a/cli/serve/streaming.py b/cli/serve/streaming.py index 51ff33c3c..4d2f8f8ec 100644 --- a/cli/serve/streaming.py +++ b/cli/serve/streaming.py @@ -1,15 +1,20 @@ """Streaming utilities for OpenAI-compatible server responses.""" from collections.abc import AsyncGenerator +from typing import Literal from mellea.core.base import ModelOutputThunk from mellea.core.utils import MelleaLogger -from mellea.helpers.openai_compatible_helpers import build_completion_usage +from mellea.helpers.openai_compatible_helpers import ( + build_completion_usage, + build_tool_calls, +) from .models import ( ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkDelta, + ChatCompletionMessageToolCallDelta, OpenAIError, OpenAIErrorResponse, StreamOptions, @@ -98,6 +103,38 @@ async def stream_chat_completion_chunks( ) yield f"data: {chunk.model_dump_json()}\n\n" + tool_calls_list = build_tool_calls(output) + + if tool_calls_list: + # Convert to ChatCompletionMessageToolCallDelta objects with required index + tool_calls = [ + ChatCompletionMessageToolCallDelta.model_validate({**tc, "index": idx}) + for idx, tc in enumerate(tool_calls_list) + ] + + # Emit tool calls in a separate chunk before the final chunk + tool_call_chunk = ChatCompletionChunk( + id=completion_id, + model=model, + created=created, + choices=[ + ChatCompletionChunkChoice( + index=0, + delta=ChatCompletionChunkDelta(tool_calls=tool_calls), + finish_reason=None, + ) + ], + object="chat.completion.chunk", + system_fingerprint=system_fingerprint, + ) + yield f"data: {tool_call_chunk.model_dump_json()}\n\n" + + # Determine finish_reason based on tool calls + finish_reason: ( + Literal["stop", "length", "content_filter", "tool_calls", "function_call"] + | None + ) = "tool_calls" if tool_calls_list else "stop" + # Include usage in final chunk only if explicitly requested via stream_options # Per OpenAI spec: usage is only included when stream_options.include_usage=True include_usage = stream_options is not None and stream_options.include_usage @@ -112,7 +149,7 @@ async def stream_chat_completion_chunks( ChatCompletionChunkChoice( index=0, delta=ChatCompletionChunkDelta(content=None), - finish_reason="stop", + finish_reason=finish_reason, ) ], object="chat.completion.chunk", diff --git a/docs/examples/m_serve/client_streaming_tool_calling.py b/docs/examples/m_serve/client_streaming_tool_calling.py new file mode 100644 index 000000000..2a406564e --- /dev/null +++ b/docs/examples/m_serve/client_streaming_tool_calling.py @@ -0,0 +1,323 @@ +"""Client example for testing streaming with tool calling. + +This script demonstrates how to use streaming responses with tool calls +from an m serve server using the OpenAI-compatible API. + +Usage: + 1. Start the server: + uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py + + 2. Run this client: + uv run python docs/examples/m_serve/client_streaming_tool_calling.py +""" + +import json +from typing import Any + +import requests + +# Server configuration +BASE_URL = "http://localhost:8080" +ENDPOINT = f"{BASE_URL}/v1/chat/completions" + +# Define tools in OpenAI format +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "RootModel": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. 
San Francisco", + }, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + } + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_stock_price", + "description": "Get the current stock price for a given ticker symbol", + "parameters": { + "RootModel": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL, GOOGL", + } + }, + "required": ["symbol"], + } + }, + }, + }, +] + + +def make_streaming_request( + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + tool_name: str | None = None, +) -> tuple[str, list[dict[str, Any]] | None, str]: + """Make a streaming request to the m serve API. + + Args: + messages: List of message dictionaries + tools: Optional list of tool definitions + tool_name: Optional tool name to request explicitly + + Returns: + Tuple of (content, tool_calls, finish_reason) + """ + payload: dict[str, Any] = { + "model": "gpt-3.5-turbo", # Model name (not used by m serve) + "messages": messages, + "temperature": 0.7, + "stream": True, + } + + if tools: + payload["tools"] = tools + if tool_name is not None: + payload["tool_choice"] = { + "type": "function", + "function": {"name": tool_name}, + } + else: + payload["tool_choice"] = "auto" + + response = requests.post(ENDPOINT, json=payload, stream=True, timeout=30) + + if response.status_code >= 400: + try: + error_payload = response.json() + except ValueError: + error_payload = {"error": {"message": response.text}} + + error_message = error_payload.get("error", {}).get("message", response.text) + raise requests.HTTPError( + f"{response.status_code} Server Error: {error_message}", response=response + ) + + content_chunks = [] + tool_calls: list[dict[str, Any]] | None = None + finish_reason = "stop" + + for line in response.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + data_str = line_str[6:] + if data_str == "[DONE]": + break + + chunk = json.loads(data_str) + choice = chunk["choices"][0] + delta = choice.get("delta", {}) + + # Collect content + if delta.get("content"): + content_chunks.append(delta["content"]) + print(delta["content"], end="", flush=True) + + # Collect tool calls + if delta.get("tool_calls"): + tool_calls = delta["tool_calls"] + + # Get finish reason + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + + content = "".join(content_chunks) + return content, tool_calls, finish_reason + + +def _run_local_tool(tool_name: str, args: dict) -> str: + """Simulate local execution of the example tools.""" + if tool_name == "get_weather": + units = args.get("units") or "celsius" + unit_suffix = "C" if units == "celsius" else "F" + return f"The weather in {args['location']} is sunny and 22°{unit_suffix}" + if tool_name == "get_stock_price": + mock_prices = { + "AAPL": "$175.43", + "GOOGL": "$142.87", + "MSFT": "$378.91", + "TSLA": "$242.15", + } + symbol = args["symbol"].upper() + return f"The current price of {symbol} is {mock_prices.get(symbol, '$100.00')}" + return "Tool result" + + +def main(): + """Run example streaming tool calling interactions.""" + print("=" * 60) + print("Streaming Tool Calling Example with m serve") + print("=" * 60) + + # Example 1: Request that should trigger weather tool + print("\n1. 
Weather Query (Streaming)") + print("-" * 60) + messages = [{"role": "user", "content": "What's the weather like in Tokyo?"}] + + print(f"User: {messages[0]['content']}") + print("\nAssistant: ", end="", flush=True) + content, tool_calls, finish_reason = make_streaming_request( + messages, tools=tools, tool_name="get_weather" + ) + + print(f"\n\nFinish Reason: {finish_reason}") + + if tool_calls: + print("\nTool Calls:") + for tool_call in tool_calls: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + elif content: + print("(Content already displayed above)") + else: + print("Assistant returned no content and no tool calls.") + + # Example 2: Request that should trigger stock price tool + print("\n\n2. Stock Price Query (Streaming)") + print("-" * 60) + messages = [{"role": "user", "content": "What's the current stock price of AAPL?"}] + + print(f"User: {messages[0]['content']}") + print("\nAssistant: ", end="", flush=True) + content, tool_calls, finish_reason = make_streaming_request( + messages, tools=tools, tool_name="get_stock_price" + ) + + print(f"\n\nFinish Reason: {finish_reason}") + + if tool_calls: + print("\nTool Calls:") + for tool_call in tool_calls: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + elif content: + print("(Content already displayed above)") + else: + print("Assistant returned no content and no tool calls.") + + # Example 3: Request without tools (normal chat) + print("\n\n3. Normal Chat (No Tools, Streaming)") + print("-" * 60) + messages = [{"role": "user", "content": "Hello! How are you?"}] + + print(f"User: {messages[0]['content']}") + print("\nAssistant: ", end="", flush=True) + content, tool_calls, finish_reason = make_streaming_request(messages, tools=None) + + print(f"\n\nFinish Reason: {finish_reason}") + + # Example 4: Multi-turn conversation with tool use + print("\n\n4. Multi-turn Conversation (Streaming)") + print("-" * 60) + messages = [{"role": "user", "content": "What's the weather in Paris?"}] + + print(f"User: {messages[0]['content']}") + print("\nAssistant: ", end="", flush=True) + content, tool_calls, finish_reason = make_streaming_request( + messages, tools=tools, tool_name="get_weather" + ) + print() # New line after streaming + + if tool_calls: + print("\nAssistant requested tool calls:") + + # Add assistant message once before processing tool calls + messages.append( + { + "role": "assistant", + "content": content if content else None, + "tool_calls": tool_calls, + } + ) + + tool_results: list[str] = [] + + # Process each tool call and add tool responses + for tool_call in tool_calls: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + + tool_result = _run_local_tool(func["name"], args) + tool_results.append(tool_result) + print(f" Result: {tool_result}") + + # Add tool response to conversation + messages.append( + { + "role": "tool", + "tool_call_id": tool_call["id"], + "content": tool_result, + } + ) + + # Get final response after tool execution + messages.append( + { + "role": "user", + "content": ( + f"Original question: {messages[0]['content']}\n" + f"Tool result: {'; '.join(tool_results)}\n" + "Answer the original question directly using only that tool " + "result. Do not mention unrelated topics or other tools." 
+ ), + } + ) + print("\nGetting final response after tool execution...") + print("Assistant: ", end="", flush=True) + content, tool_calls, finish_reason = make_streaming_request( + messages, tools=None + ) + print() # New line after streaming + if not content: + print("Assistant returned no content after tool execution.") + elif content: + print("(Content already displayed above)") + else: + print("Assistant returned no content and no tool calls.") + + print("\n" + "=" * 60) + print("Examples completed!") + print("=" * 60) + + +if __name__ == "__main__": + try: + main() + except requests.exceptions.ConnectionError: + print("Error: Could not connect to server.") + print("Make sure the server is running:") + print(" uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py") + except requests.exceptions.HTTPError as e: + print(f"Error: {e}") + if e.response is not None: + try: + print("Server response:", json.dumps(e.response.json(), indent=2)) + except ValueError: + print("Server response:", e.response.text) + except Exception as e: + print(f"Unexpected error: {e}") + raise diff --git a/docs/examples/m_serve/client_tool_calling.py b/docs/examples/m_serve/client_tool_calling.py new file mode 100644 index 000000000..d68e5d238 --- /dev/null +++ b/docs/examples/m_serve/client_tool_calling.py @@ -0,0 +1,291 @@ +"""Client example for testing tool calling with m serve. + +This script demonstrates how to interact with an m serve server +that supports tool calling using the OpenAI-compatible API. + +Usage: + 1. Start the server: + uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py + + 2. Run this client: + uv run python docs/examples/m_serve/client_tool_calling.py +""" + +import json + +import requests + +# Server configuration +BASE_URL = "http://localhost:8080" +ENDPOINT = f"{BASE_URL}/v1/chat/completions" + +# Define tools in OpenAI format +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "RootModel": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. San Francisco", + }, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + } + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_stock_price", + "description": "Get the current stock price for a given ticker symbol", + "parameters": { + "RootModel": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL, GOOGL", + } + }, + "required": ["symbol"], + } + }, + }, + }, +] + + +def make_request( + messages: list[dict], tools: list[dict] | None = None, tool_name: str | None = None +) -> dict: + """Make a request to the m serve API. + + Args: + messages: List of message dictionaries + tools: Optional list of tool definitions + tool_name: Optional tool name to request explicitly + + Returns: + Response dictionary from the API + """ + payload = { + "model": "gpt-3.5-turbo", # Model name (not used by m serve) + "messages": messages, + "temperature": 0.7, + } + + if tools: + payload["tools"] = tools + if tool_name is not None: + # m serve forwards tool_choice to compatible backends, but the + # downstream provider/model may ignore it or treat it as a weak + # preference rather than a guarantee. 
Use an explicit function + # selection in this client so the example demonstrates the API + # contract even when the model would otherwise decline to call tools. + payload["tool_choice"] = { + "type": "function", + "function": {"name": tool_name}, + } + else: + payload["tool_choice"] = "auto" + + response = requests.post(ENDPOINT, json=payload, timeout=30) + + if response.status_code >= 400: + try: + error_payload = response.json() + except ValueError: + error_payload = {"error": {"message": response.text}} + + error_message = error_payload.get("error", {}).get("message", response.text) + raise requests.HTTPError( + f"{response.status_code} Server Error: {error_message}", response=response + ) + + return response.json() + + +def _run_local_tool(tool_name: str, args: dict) -> str: + """Simulate local execution of the example tools.""" + if tool_name == "get_weather": + units = args.get("units") or "celsius" + unit_suffix = "C" if units == "celsius" else "F" + return f"The weather in {args['location']} is sunny and 22°{unit_suffix}" + if tool_name == "get_stock_price": + mock_prices = { + "AAPL": "$175.43", + "GOOGL": "$142.87", + "MSFT": "$378.91", + "TSLA": "$242.15", + } + symbol = args["symbol"].upper() + return f"The current price of {symbol} is {mock_prices.get(symbol, '$100.00')}" + return "Tool result" + + +def main(): + """Run example tool calling interactions.""" + print("=" * 60) + print("Tool Calling Example with m serve") + print("=" * 60) + + # Example 1: Request that should trigger weather tool + print("\n1. Weather Query") + print("-" * 60) + messages = [{"role": "user", "content": "What's the weather like in Tokyo?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=tools, tool_name="get_weather") + + choice = response["choices"][0] + print(f"\nFinish Reason: {choice['finish_reason']}") + + if choice.get("message", {}).get("tool_calls"): + print("\nTool Calls:") + for tool_call in choice["message"]["tool_calls"]: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + elif choice.get("message", {}).get("content"): + print(f"Assistant: {choice['message']['content']}") + else: + print("Assistant returned no content and no tool calls.") + + # Example 2: Request that should trigger stock price tool + print("\n\n2. Stock Price Query") + print("-" * 60) + messages = [{"role": "user", "content": "What's the current stock price of AAPL?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=tools, tool_name="get_stock_price") + + choice = response["choices"][0] + print(f"\nFinish Reason: {choice['finish_reason']}") + + if choice.get("message", {}).get("tool_calls"): + print("\nTool Calls:") + for tool_call in choice["message"]["tool_calls"]: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + elif choice.get("message", {}).get("content"): + print(f"Assistant: {choice['message']['content']}") + else: + print("Assistant returned no content and no tool calls.") + + # Example 3: Request without tools (normal chat) + print("\n\n3. Normal Chat (No Tools)") + print("-" * 60) + messages = [{"role": "user", "content": "Hello! 
How are you?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=None) + + choice = response["choices"][0] + print(f"\nFinish Reason: {choice['finish_reason']}") + print(f"Assistant: {choice['message']['content']}") + + # Example 4: Multi-turn conversation with tool use + print("\n\n4. Multi-turn Conversation") + print("-" * 60) + messages = [{"role": "user", "content": "What's the weather in Paris?"}] + + print(f"User: {messages[0]['content']}") + response = make_request(messages, tools=tools, tool_name="get_weather") + + choice = response["choices"][0] + assistant_message = choice["message"] + + if assistant_message.get("tool_calls"): + print("\nAssistant requested tool calls:") + + # Add assistant message once before processing tool calls + messages.append( + { + "role": "assistant", + "content": assistant_message.get("content"), + "tool_calls": assistant_message["tool_calls"], + } + ) + + tool_results: list[str] = [] + + # Process each tool call and add tool responses + for tool_call in assistant_message["tool_calls"]: + func = tool_call["function"] + args = json.loads(func["arguments"]) + print(f" - {func['name']}({json.dumps(args)})") + + tool_result = _run_local_tool(func["name"], args) + tool_results.append(tool_result) + print(f" Result: {tool_result}") + + # Add tool response to conversation + messages.append( + { + "role": "tool", + "tool_call_id": tool_call["id"], + "content": tool_result, + } + ) + + # Get final response after tool execution. + # Ask for a concise answer that explicitly uses the tool result so the + # example output includes the actual weather/price instead of only a + # conversational acknowledgement. + messages.append( + { + "role": "user", + "content": ( + f"Original question: {messages[0]['content']}\n" + f"Tool result: {'; '.join(tool_results)}\n" + "Answer the original question directly using only that tool " + "result. Do not mention unrelated topics or other tools." + ), + } + ) + print("\nGetting final response after tool execution...") + response = make_request(messages, tools=None) + choice = response["choices"][0] + if choice.get("message", {}).get("content"): + print(f"Assistant: {choice['message']['content']}") + else: + print("Assistant returned no content after tool execution.") + elif assistant_message.get("content"): + print(f"Assistant: {assistant_message['content']}") + else: + print("Assistant returned no content and no tool calls.") + + print("\n" + "=" * 60) + print("Examples completed!") + print("=" * 60) + + +if __name__ == "__main__": + try: + main() + except requests.exceptions.ConnectionError: + print("Error: Could not connect to server.") + print("Make sure the server is running:") + print(" uv run m serve docs/examples/m_serve/m_serve_example_tool_calling.py") + except requests.exceptions.HTTPError as e: + print(f"Error: {e}") + if e.response is not None: + try: + print("Server response:", json.dumps(e.response.json(), indent=2)) + except ValueError: + print("Server response:", e.response.text) + except Exception as e: + print(f"Error: {e}") diff --git a/docs/examples/m_serve/m_serve_example_tool_calling.py b/docs/examples/m_serve/m_serve_example_tool_calling.py new file mode 100644 index 000000000..839c91b1b --- /dev/null +++ b/docs/examples/m_serve/m_serve_example_tool_calling.py @@ -0,0 +1,271 @@ +# pytest: ollama, e2e + +"""Example demonstrating tool calling with m serve. + +This file supports two distinct usage patterns: + +1. 
Running it directly with ``uv run python ...`` performs a local smoke test + using native Mellea tool calling. +2. Serving it with ``m serve`` exposes an OpenAI-compatible endpoint that + accepts OpenAI-style tool definitions in the request. + +The direct ``__main__`` smoke test is intentionally separate from the +OpenAI-compatible server flow because local ``session.instruct(...)`` calls +should use ``MelleaTool`` objects directly. +""" + +import os +from typing import Any + +from cli.serve.models import ChatMessage +from mellea.backends import ModelOption +from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO +from mellea.backends.openai import OpenAIBackend +from mellea.backends.tools import MelleaTool +from mellea.core import ModelOutputThunk, Requirement +from mellea.core.base import AbstractMelleaTool +from mellea.formatters import TemplateFormatter +from mellea.stdlib.context import ChatContext +from mellea.stdlib.session import MelleaSession + +_ollama_host = os.environ.get("OLLAMA_HOST", "localhost:11434") +if not _ollama_host.startswith(("http://", "https://")): + _ollama_host = f"http://{_ollama_host}" + +backend = OpenAIBackend( + model_id=IBM_GRANITE_4_HYBRID_MICRO.ollama_name, # type: ignore[arg-type] + formatter=TemplateFormatter(model_id=IBM_GRANITE_4_HYBRID_MICRO.hf_model_name), # type: ignore[arg-type] + base_url=f"{_ollama_host}/v1", + api_key="ollama", +) +session = MelleaSession(backend, ctx=ChatContext()) + + +class GetWeatherTool(AbstractMelleaTool): + """Tool for getting weather information.""" + + name = "get_weather" + + def run(self, location: str, units: str | None = "celsius") -> str: + """Get the current weather for a location. + + Args: + location: The city name + units: Temperature units (celsius or fahrenheit) + + Returns: + Weather information as a string + """ + # Models sometimes emit optional arguments explicitly as null/None. + resolved_units = units or "celsius" + # In a real implementation, this would call a weather API + return f"The weather in {location} is sunny and 22°{resolved_units[0].upper()}" + + @property + def as_json_tool(self) -> dict[str, Any]: + """Return JSON schema for this tool.""" + return { + "type": "function", + "function": { + "name": self.name, + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. San Francisco", + }, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + }, + }, + } + + +class GetStockPriceTool(AbstractMelleaTool): + """Tool for getting stock price information.""" + + name = "get_stock_price" + + def run(self, symbol: str) -> str: + """Get the current stock price for a symbol. 
+ + Args: + symbol: The stock ticker symbol (e.g., AAPL, GOOGL) + + Returns: + Stock price information as a string + """ + # In a real implementation, this would call a stock market API + mock_prices = { + "AAPL": "$175.43", + "GOOGL": "$142.87", + "MSFT": "$378.91", + "TSLA": "$242.15", + } + price = mock_prices.get(symbol.upper(), "$100.00") + return f"The current price of {symbol.upper()} is {price}" + + @property + def as_json_tool(self) -> dict[str, Any]: + """Return JSON schema for this tool.""" + return { + "type": "function", + "function": { + "name": self.name, + "description": "Get the current stock price for a given ticker symbol", + "parameters": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL, GOOGL", + } + }, + "required": ["symbol"], + }, + }, + } + + +# Create tool instances for server-side lookup +weather_tool_impl = GetWeatherTool() +stock_price_tool_impl = GetStockPriceTool() + +# Native MelleaTool wrappers are only needed for the direct ``__main__`` path. +# The backend helper used by local ``session.instruct(..., ModelOption.TOOLS=[...])`` +# expects ``MelleaTool`` instances in a list, while the server path below uses the +# class-based implementations via the ``TOOLS`` lookup. +weather_tool = MelleaTool( + name=weather_tool_impl.name, + tool_call=weather_tool_impl.run, + as_json_tool=weather_tool_impl.as_json_tool, +) +stock_price_tool = MelleaTool( + name=stock_price_tool_impl.name, + tool_call=stock_price_tool_impl.run, + as_json_tool=stock_price_tool_impl.as_json_tool, +) + +# Map tool names to server-side tool implementations for easy lookup +TOOLS = { + weather_tool_impl.name: weather_tool_impl, + stock_price_tool_impl.name: stock_price_tool_impl, +} + + +def _extract_mellea_tools_from_model_options( + model_options: dict | None, +) -> dict[str, AbstractMelleaTool]: + """Normalize example tool inputs to native tool instances. + + This example supports only two shapes: + - OpenAI-style JSON tool definitions from the server path + - native tool objects from the direct ``__main__`` path + """ + if model_options is None or ModelOption.TOOLS not in model_options: + return {} + + provided_tools = model_options[ModelOption.TOOLS] + tools: dict[str, AbstractMelleaTool] = {} + + for tool_def in provided_tools: + if isinstance(tool_def, AbstractMelleaTool): + tools[tool_def.name] = tool_def + else: + tool_name = tool_def["function"]["name"] + if tool_name in TOOLS: + tools[tool_name] = TOOLS[tool_name] + + return tools + + +def serve( + input: list[ChatMessage], + requirements: list[str] | None = None, + model_options: None | dict = None, +) -> ModelOutputThunk: + """Serve function that handles tool calling. + + This function demonstrates how to use tools with m serve. The tools + are passed via model_options using ModelOption.TOOLS, and tool_choice + can be specified using ModelOption.TOOL_CHOICE. Mellea forwards that + setting to compatible backends, but the downstream provider/model may + still ignore it or treat it as a weak preference. 
+ + Args: + input: List of chat messages + requirements: Optional list of requirement strings + model_options: Model options including ModelOption.TOOLS and ModelOption.TOOL_CHOICE + + Returns: + ModelOutputThunk with potential tool calls + """ + requirements = requirements if requirements else [] + message = input[-1].content + + # Extract tools from model_options if provided + tools = _extract_mellea_tools_from_model_options(model_options) + + # Build model options with tools. + # If the caller explicitly selected a single function via tool_choice, + # narrow the advertised tool set to that one tool so the backend/model + # is not asked to choose among unrelated tools. + final_model_options = dict(model_options or {}) + selected_tool_name: str | None = None + if tools: + selected_tools = tools + if model_options is not None and ModelOption.TOOL_CHOICE in model_options: + tool_choice = model_options[ModelOption.TOOL_CHOICE] + if isinstance(tool_choice, dict): + selected_tool_name = tool_choice.get("function", {}).get("name") + if selected_tool_name in tools: + selected_tools = {selected_tool_name: tools[selected_tool_name]} + final_model_options[ModelOption.TOOLS] = selected_tools + + # Keep the serve path deterministic for the client example by retrying only + # at the request level. Enforcing uses_tool(...) inside session.instruct() + # caused noisy server-side failures when the model ignored the tool request + # on a particular sample. + result = session.instruct( + description=message, # type: ignore + requirements=[Requirement(req) for req in requirements], # type: ignore + model_options=final_model_options, + tool_calls=True, + strategy=None, + ) + + return result + + +if __name__ == "__main__": + response = session.instruct( + "What's the weather in Boston?", + model_options={ + ModelOption.TOOLS: [weather_tool], + # This direct path now uses the OpenAI backend against Ollama's + # OpenAI-compatible endpoint, so TOOL_CHOICE is forwarded by + # Mellea. Ollama and/or the selected model may still ignore it + # or not enforce it strictly in practice. + ModelOption.TOOL_CHOICE: "auto", + ModelOption.MAX_NEW_TOKENS: 1000, + }, + strategy=None, + tool_calls=True, + ) + + print(f"Response: {response.value}") + print( + "Tool calls requested:", + None if response.tool_calls is None else list(response.tool_calls.keys()), + ) + + if response.tool_calls and weather_tool.name in response.tool_calls: + tool_result = response.tool_calls[weather_tool.name].call_func() + print(f"Tool result: {tool_result}") diff --git a/mellea/backends/model_options.py b/mellea/backends/model_options.py index decc8c34b..a03e8625c 100644 --- a/mellea/backends/model_options.py +++ b/mellea/backends/model_options.py @@ -22,6 +22,7 @@ class ModelOption: Attributes: TOOLS (str): Sentinel key for a list or dict of ``MelleaTool`` instances to expose for tool calling. + TOOL_CHOICE (str): Key for tool choice strategy (passed through to the backend). MAX_NEW_TOKENS (str): Sentinel key for the maximum number of new tokens to generate. SYSTEM_PROMPT (str): Sentinel key for the system prompt string. TEMPERATURE (str): Key for the sampling temperature (passed through to the backend). @@ -34,6 +35,9 @@ class ModelOption: TOOLS = "@@@tools@@@" """Must be a list[MelleaTool] or a dict[str, MelleaTool]. Use ``MelleaTool.from_callable()`` or the ``@tool`` decorator to wrap plain callables.""" + TOOL_CHOICE = "tool_choice" + """Controls which tool the model should use. 
Can be "none", "auto", or a specific tool name.""" + MAX_NEW_TOKENS = "@@@max_new_tokens@@@" SYSTEM_PROMPT = "@@@system_prompt@@@" TEMPERATURE = "temperature" diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py index dfa9dd122..bd1507910 100644 --- a/mellea/helpers/openai_compatible_helpers.py +++ b/mellea/helpers/openai_compatible_helpers.py @@ -1,7 +1,8 @@ """A file for helper functions that deal with OpenAI API compatible helpers.""" import json -from typing import Any +import uuid +from typing import Any, Literal, TypedDict from pydantic import BaseModel @@ -11,6 +12,21 @@ from ..stdlib.components import Document, Message +class ToolCallFunction(TypedDict): + """Function details in a tool call.""" + + name: str + arguments: str + + +class ToolCallDict(TypedDict): + """OpenAI-compatible tool call dictionary with ID and function.""" + + id: str + type: Literal["function"] + function: ToolCallFunction + + class CompletionUsage(BaseModel): """Token usage statistics for a completion request.""" @@ -251,3 +267,39 @@ def build_completion_usage(output: ModelOutputThunk) -> CompletionUsage | None: completion_tokens=completion_tokens, total_tokens=total_tokens, ) + + +def build_tool_calls(output: ModelOutputThunk) -> list[ToolCallDict] | None: + """Build OpenAI-compatible tool calls from a model output, if available. + + Args: + output: Model output thunk that may expose a ``tool_calls`` mapping. + + Returns: + List of ``ToolCallDict`` objects when tool calls are present, + otherwise ``None``. + """ + # Check for tool calls - ModelOutputThunk always has tool_calls attribute + if ( + output.tool_calls is None + or not isinstance(output.tool_calls, dict) + or not output.tool_calls + ): + return None + + tool_calls: list[ToolCallDict] = [] + for model_tool_call in output.tool_calls.values(): + # Generate a unique ID for this tool call + tool_call_id = f"call_{uuid.uuid4().hex[:24]}" + + # Serialize arguments to JSON with str fallback for non-serializable types + args_json = json.dumps(model_tool_call.args, default=str) + + tool_call: ToolCallDict = { + "id": tool_call_id, + "type": "function", + "function": {"name": model_tool_call.name, "arguments": args_json}, + } + tool_calls.append(tool_call) + + return tool_calls diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 2e626e6d5..add2682e5 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -451,18 +451,19 @@ async def test_unsupported_params_excluded_from_model_options(self, mock_module) assert "logit_bias" not in model_options @pytest.mark.asyncio - async def test_tool_params_excluded_from_model_options(self, mock_module): - """Test that tool-related parameters are excluded from model_options.""" + async def test_tool_params_passed_to_model_options(self, mock_module): + """Test that tool-related parameters are passed to model_options.""" from cli.serve.models import ( FunctionDefinition, FunctionParameters, ToolFunction, ) + from mellea.backends.model_options import ModelOption request = ChatCompletionRequest( model="test-model", messages=[ChatMessage(role="user", content="Hello")], - # Tool-related parameters that should be excluded + # Tool-related parameters tools=[ ToolFunction( type="function", @@ -498,9 +499,12 @@ async def test_tool_params_excluded_from_model_options(self, mock_module): assert call_args is not None model_options = call_args.kwargs["model_options"] - # Tool-related parameters should NOT be in model_options - assert "tools" not in model_options 
- assert "tool_choice" not in model_options + # Tools should be passed with ModelOption.TOOLS key + assert ModelOption.TOOLS in model_options + # tool_choice should be passed through using ModelOption.TOOL_CHOICE + assert ModelOption.TOOL_CHOICE in model_options + assert model_options[ModelOption.TOOL_CHOICE] == "auto" + # Legacy function calling parameters should still be excluded assert "functions" not in model_options assert "function_call" not in model_options @@ -531,3 +535,93 @@ async def test_response_format_excluded_from_model_options(self, mock_module): # response_format should NOT be in model_options assert "response_format" not in model_options + + @pytest.mark.asyncio + async def test_streaming_with_tool_calls(self, mock_module): + """Test that tool calls are properly emitted in streaming responses.""" + import json + from unittest.mock import Mock + + from fastapi.responses import StreamingResponse + + from mellea.core.base import ModelToolCall + + # Create a mock tool + mock_tool = Mock() + mock_tool.name = "get_weather" + + # Create a mock output with tool calls + # Real backends may return content alongside tool calls (e.g., "I'll check that for you") + mock_output = ModelOutputThunk("I'll check the weather for you.") + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", + func=mock_tool, + args={"location": "San Francisco", "units": "celsius"}, + ) + } + mock_module.serve.return_value = mock_output + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="What's the weather?")], + stream=True, + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify it's a streaming response + assert isinstance(response, StreamingResponse) + + # Collect all chunks + chunks = [] + async for chunk_data in response.body_iterator: + # Convert to string for parsing + if isinstance(chunk_data, (bytes, memoryview)): + chunk_str = ( + bytes(chunk_data).decode("utf-8") + if isinstance(chunk_data, memoryview) + else chunk_data.decode("utf-8") + ) + else: + chunk_str = chunk_data + + # Parse SSE format: "data: {json}\n\n" + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + if json_str and json_str != "[DONE]": + chunks.append(json.loads(json_str)) + + # Verify we have the expected chunk sequence + # Expected: initial (role), content, tool_calls, final = 4 chunks + assert len(chunks) == 4, f"Should have exactly 4 chunks, got {len(chunks)}" + + # Chunk 0: Initial chunk with role + initial_chunk = chunks[0] + assert initial_chunk["choices"][0]["delta"].get("role") == "assistant" + assert initial_chunk["choices"][0]["finish_reason"] is None + + # Chunk 1: Content chunk + content_chunk = chunks[1] + assert ( + content_chunk["choices"][0]["delta"].get("content") + == "I'll check the weather for you." 
+ ) + assert content_chunk["choices"][0]["finish_reason"] is None + + # Chunk 2: Tool call chunk + tool_call_chunk = chunks[2] + tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"] + assert len(tool_calls) == 1 + # Verify required index field is present (OpenAI streaming spec requirement) + assert "index" in tool_calls[0], "tool_calls delta must include index field" + assert tool_calls[0]["index"] == 0 + assert tool_calls[0]["function"]["name"] == "get_weather" + assert "location" in tool_calls[0]["function"]["arguments"] + assert tool_call_chunk["choices"][0]["finish_reason"] is None + + # Chunk 3: Final chunk has finish_reason="tool_calls" + final_chunk = chunks[3] + assert final_chunk["choices"][0]["delta"].get("content") is None + assert final_chunk["choices"][0]["finish_reason"] == "tool_calls" diff --git a/test/cli/test_serve_integration.py b/test/cli/test_serve_integration.py new file mode 100644 index 000000000..ff6792470 --- /dev/null +++ b/test/cli/test_serve_integration.py @@ -0,0 +1,572 @@ +"""Integration tests for m serve using FastAPI TestClient. + +Tests the full HTTP request/response cycle including: +- Streaming responses (SSE format, headers, chunking) +- Tool calling responses via HTTP +- Error handling at the HTTP layer +""" + +import json +from typing import Any +from unittest.mock import Mock + +import pytest +from fastapi import FastAPI +from fastapi.exceptions import RequestValidationError +from fastapi.testclient import TestClient + +from cli.serve.app import make_chat_endpoint, validation_exception_handler +from cli.serve.models import FunctionDefinition, FunctionParameters, ToolFunction +from mellea.core.base import AbstractMelleaTool, ModelOutputThunk, ModelToolCall + +# Mark all tests in this module as integration tests +pytestmark = pytest.mark.integration + + +class MockWeatherTool(AbstractMelleaTool): + """Mock weather tool for testing.""" + + name = "get_weather" + + def run(self, location: str, units: str = "celsius") -> str: + """Mock run method.""" + return f"Weather in {location} is 22°{units[0].upper()}" + + @property + def as_json_tool(self) -> dict[str, Any]: + """Return JSON schema for this tool.""" + return { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + }, + }, + } + + @property + def as_tool_function(self) -> ToolFunction: + """Return ToolFunction model for HTTP requests.""" + return ToolFunction( + type="function", + function=FunctionDefinition( + name="get_weather", + description="Get the current weather in a location", + parameters=FunctionParameters( + RootModel={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + } + ), + ), + ) + + +@pytest.fixture +def mock_module(): + """Create a mock module with a serve function.""" + module = Mock() + module.__name__ = "test_integration_module" + return module + + +@pytest.fixture +def test_app(mock_module): + """Create a FastAPI test app with the chat endpoint.""" + app = FastAPI() + app.add_exception_handler(RequestValidationError, 
validation_exception_handler) + app.add_api_route( + "/v1/chat/completions", make_chat_endpoint(mock_module), methods=["POST"] + ) + return app + + +@pytest.fixture +def client(test_app): + """Create a TestClient for the app.""" + return TestClient(test_app) + + +class TestStreamingIntegration: + """Integration tests for streaming responses via HTTP.""" + + def test_streaming_response_headers(self, client, mock_module): + """Test that streaming responses have correct HTTP headers.""" + mock_output = ModelOutputThunk("Hello, streaming world!") + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) + + # Verify streaming headers + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + def test_streaming_sse_format(self, client, mock_module): + """Test that streaming responses follow SSE format.""" + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) + + # Parse SSE chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # Verify chunk structure + assert len(chunks) > 0 + for chunk in chunks: + assert chunk["object"] == "chat.completion.chunk" + assert "id" in chunk + assert "model" in chunk + assert "created" in chunk + assert "choices" in chunk + assert len(chunk["choices"]) == 1 + + # Verify final chunk has finish_reason + assert chunks[-1]["choices"][0]["finish_reason"] == "stop" + + def test_streaming_content_chunks(self, client, mock_module): + """Test that content is properly chunked in streaming response.""" + mock_output = ModelOutputThunk("Hello world!") + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Say hello"}], + "stream": True, + }, + ) + + # Parse chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # First chunk should have role + assert chunks[0]["choices"][0]["delta"].get("role") == "assistant" + + # Second chunk should have content + assert chunks[1]["choices"][0]["delta"].get("content") == "Hello world!" 
+ + # Final chunk should have finish_reason + assert chunks[-1]["choices"][0]["finish_reason"] == "stop" + + def test_streaming_with_usage_field(self, client, mock_module): + """Test streaming response includes usage when stream_options.include_usage=True.""" + mock_output = ModelOutputThunk("Response") + mock_output.generation.usage = { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + "stream_options": {"include_usage": True}, + }, + ) + + # Parse chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # Final chunk should include usage + final_chunk = chunks[-1] + assert "usage" in final_chunk + assert final_chunk["usage"]["prompt_tokens"] == 10 + assert final_chunk["usage"]["completion_tokens"] == 5 + assert final_chunk["usage"]["total_tokens"] == 15 + + def test_streaming_done_marker(self, client, mock_module): + """Test that streaming response ends with [DONE] marker.""" + mock_output = ModelOutputThunk("Test") + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) + + # Verify [DONE] marker is present + assert "data: [DONE]" in response.text + assert response.text.strip().endswith("data: [DONE]") + + +class TestToolCallingIntegration: + """Integration tests for tool calling via HTTP.""" + + def test_tool_call_response_structure(self, client, mock_module): + """Test that tool calls are properly formatted in HTTP response.""" + mock_output = ModelOutputThunk("I'll check the weather.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", + func=mock_tool, + args={"location": "Paris", "units": "celsius"}, + ) + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather in Paris?"} + ], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + }, + ) + + assert response.status_code == 200 + data = response.json() + + # Verify response structure + assert data["object"] == "chat.completion" + assert data["choices"][0]["finish_reason"] == "tool_calls" + assert data["choices"][0]["message"]["tool_calls"] is not None + assert len(data["choices"][0]["message"]["tool_calls"]) == 1 + + # Verify tool call details + tool_call = data["choices"][0]["message"]["tool_calls"][0] + assert tool_call["type"] == "function" + assert tool_call["function"]["name"] == "get_weather" + assert tool_call["id"].startswith("call_") + + # Verify arguments + args = json.loads(tool_call["function"]["arguments"]) + assert args["location"] == "Paris" + assert args["units"] == "celsius" + + def test_multiple_tool_calls_via_http(self, client, mock_module): + """Test multiple tool calls in a single HTTP response.""" + mock_output = ModelOutputThunk("Checking multiple locations.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "weather_paris": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ), + "weather_london": ModelToolCall( + name="get_weather", 
func=mock_tool, args={"location": "London"} + ), + "weather_tokyo": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Tokyo"} + ), + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather in multiple cities"}], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + }, + ) + + assert response.status_code == 200 + data = response.json() + + # Verify multiple tool calls + tool_calls = data["choices"][0]["message"]["tool_calls"] + assert len(tool_calls) == 3 + + # Verify each has unique ID + ids = [tc["id"] for tc in tool_calls] + assert len(ids) == len(set(ids)), "Tool call IDs should be unique" + + # Verify locations + locations = [ + json.loads(tc["function"]["arguments"])["location"] for tc in tool_calls + ] + assert set(locations) == {"Paris", "London", "Tokyo"} + + def test_tool_calls_with_usage_info(self, client, mock_module): + """Test that usage info is included with tool calls.""" + mock_output = ModelOutputThunk("Calling tool.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ) + } + mock_output.generation.usage = { + "prompt_tokens": 50, + "completion_tokens": 20, + "total_tokens": 70, + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather?"}], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + }, + ) + + assert response.status_code == 200 + data = response.json() + + # Verify both tool calls and usage + assert data["choices"][0]["finish_reason"] == "tool_calls" + assert data["usage"] is not None + assert data["usage"]["total_tokens"] == 70 + + +class TestStreamingWithToolCalls: + """Integration tests for streaming responses with tool calls.""" + + def test_streaming_tool_call_response(self, client, mock_module): + """Test streaming response with tool calls via HTTP.""" + mock_output = ModelOutputThunk("I'll check that for you.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ) + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather in Paris?"}], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + "stream": True, + }, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + # Parse chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # Should have: initial (role), content, tool_calls, final + assert len(chunks) == 4 + + # Verify chunk sequence + assert chunks[0]["choices"][0]["delta"].get("role") == "assistant" + assert ( + chunks[1]["choices"][0]["delta"].get("content") + == "I'll check that for you." 
+ ) + assert "tool_calls" in chunks[2]["choices"][0]["delta"] + assert chunks[3]["choices"][0]["finish_reason"] == "tool_calls" + + # Verify tool call structure in streaming chunk + tool_calls = chunks[2]["choices"][0]["delta"]["tool_calls"] + assert len(tool_calls) == 1 + assert "index" in tool_calls[0], "Streaming tool calls must include index" + assert tool_calls[0]["index"] == 0 + assert tool_calls[0]["type"] == "function" + assert tool_calls[0]["function"]["name"] == "get_weather" + + def test_streaming_multiple_tool_calls(self, client, mock_module): + """Test streaming with multiple tool calls via HTTP.""" + mock_output = ModelOutputThunk("Checking multiple locations.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "weather_1": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ), + "weather_2": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "London"} + ), + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather?"}], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + "stream": True, + }, + ) + + # Parse chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # Find tool call chunk (with non-None tool_calls) + tool_call_chunk = next( + c + for c in chunks + if "tool_calls" in c["choices"][0]["delta"] + and c["choices"][0]["delta"]["tool_calls"] is not None + ) + tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"] + + # Verify multiple tool calls with indices + assert len(tool_calls) == 2 + indices = [tc["index"] for tc in tool_calls] + assert indices == [0, 1] + + def test_streaming_tool_calls_with_usage(self, client, mock_module): + """Test streaming tool calls with usage info via HTTP.""" + mock_output = ModelOutputThunk("Calling tool.") + mock_tool = MockWeatherTool() + mock_output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "Paris"} + ) + } + mock_output.generation.usage = { + "prompt_tokens": 30, + "completion_tokens": 15, + "total_tokens": 45, + } + mock_module.serve.return_value = mock_output + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather?"}], + "tools": [mock_tool.as_tool_function.model_dump(mode="json")], + "stream": True, + "stream_options": {"include_usage": True}, + }, + ) + + # Parse chunks + chunks = [] + for line in response.text.split("\n\n"): + if line.startswith("data: "): + data = line[6:].strip() + if data != "[DONE]": + chunks.append(json.loads(data)) + + # Final chunk should have both finish_reason and usage + final_chunk = chunks[-1] + assert final_chunk["choices"][0]["finish_reason"] == "tool_calls" + assert "usage" in final_chunk + assert final_chunk["usage"]["total_tokens"] == 45 + + +class TestHTTPErrorHandling: + """Integration tests for error handling at HTTP layer.""" + + def test_invalid_request_returns_400(self, client, mock_module): + """Test that invalid requests return 400 with OpenAI error format.""" + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "n": 0, # Invalid: must be >= 1 + }, + ) + + assert response.status_code == 400 + data = response.json() + 
assert "error" in data + assert data["error"]["type"] == "invalid_request_error" + assert data["error"]["param"] == "n" + + def test_unsupported_n_parameter(self, client, mock_module): + """Test that n > 1 is rejected with proper error.""" + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "n": 2, + }, + ) + + assert response.status_code == 400 + data = response.json() + assert "error" in data + assert data["error"]["type"] == "invalid_request_error" + assert data["error"]["param"] == "n" + assert "not supported" in data["error"]["message"].lower() + + def test_server_error_returns_500(self, client, mock_module): + """Test that server errors return 500 with OpenAI error format.""" + # Make serve raise an exception + mock_module.serve.side_effect = RuntimeError("Internal error") + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + }, + ) + + assert response.status_code == 500 + data = response.json() + assert "error" in data + assert data["error"]["type"] == "server_error" + assert "Internal error" in data["error"]["message"] diff --git a/test/cli/test_serve_streaming_tool_calls.py b/test/cli/test_serve_streaming_tool_calls.py new file mode 100644 index 000000000..0b5a3a63e --- /dev/null +++ b/test/cli/test_serve_streaming_tool_calls.py @@ -0,0 +1,515 @@ +"""Unit tests for streaming with tool calls, usage fields, and error handling. + +This file contains new tests added in the tool-calling PR. The main streaming +tests (from main branch) are in test_serve_streaming.py. +""" + +import json +from unittest.mock import AsyncMock, Mock + +import pytest + +from cli.serve.models import StreamOptions +from cli.serve.streaming import stream_chat_completion_chunks +from mellea.core.base import ModelOutputThunk, ModelToolCall + + +class TestStreamingToolCalls: + """Tests for streaming responses with tool calls.""" + + @pytest.mark.asyncio + async def test_streaming_tool_call_chunk_structure(self): + """Test that tool call chunks have correct structure with index field.""" + # Create a mock tool + mock_tool = Mock() + mock_tool.name = "get_weather" + + # Create output with tool calls + output = ModelOutputThunk("Checking weather...") + output.tool_calls = { + "get_weather": ModelToolCall( + name="get_weather", + func=mock_tool, + args={"location": "San Francisco", "units": "celsius"}, + ) + } + + # Stream chunks + chunks = [] + async for chunk_data in stream_chat_completion_chunks( + output=output, + completion_id="chatcmpl-test123", + model="test-model", + created=1234567890, + ): + if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]": + json_str = chunk_data[6:].strip() + chunks.append(json.loads(json_str)) + + # Should have: initial (role), content, tool_calls, final = 4 chunks + assert len(chunks) == 4 + + # Verify tool call chunk structure + tool_call_chunk = chunks[2] + tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"] + assert len(tool_calls) == 1 + + # Critical: index field must be present (OpenAI streaming spec) + assert "index" in tool_calls[0], "tool_calls delta must include index field" + assert tool_calls[0]["index"] == 0 + assert tool_calls[0]["id"] is not None + assert tool_calls[0]["type"] == "function" + assert tool_calls[0]["function"]["name"] == "get_weather" + assert "location" in tool_calls[0]["function"]["arguments"] + + @pytest.mark.asyncio + async def 
+    @pytest.mark.asyncio
+    async def test_finish_reason_tool_calls(self):
+        """Test that finish_reason is 'tool_calls' when tool calls are present."""
+        mock_tool = Mock()
+        mock_tool.name = "test_func"
+
+        output = ModelOutputThunk("Response")
+        output.tool_calls = {
+            "test_func": ModelToolCall(
+                name="test_func", func=mock_tool, args={"arg": "value"}
+            )
+        }
+
+        chunks = []
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test",
+            model="test-model",
+            created=1234567890,
+        ):
+            if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]":
+                json_str = chunk_data[6:].strip()
+                chunks.append(json.loads(json_str))
+
+        # Final chunk should have finish_reason="tool_calls"
+        final_chunk = chunks[-1]
+        assert final_chunk["choices"][0]["finish_reason"] == "tool_calls"
+
+    @pytest.mark.asyncio
+    async def test_finish_reason_stop_without_tool_calls(self):
+        """Test that finish_reason is 'stop' when no tool calls are present."""
+        output = ModelOutputThunk("Simple response")
+
+        chunks = []
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test",
+            model="test-model",
+            created=1234567890,
+        ):
+            if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]":
+                json_str = chunk_data[6:].strip()
+                chunks.append(json.loads(json_str))
+
+        # Final chunk should have finish_reason="stop"
+        final_chunk = chunks[-1]
+        assert final_chunk["choices"][0]["finish_reason"] == "stop"
+
+    @pytest.mark.asyncio
+    async def test_multiple_tool_calls_with_indices(self):
+        """Test that multiple tool calls each get correct index values."""
+        mock_tool1 = Mock()
+        mock_tool1.name = "func1"
+        mock_tool2 = Mock()
+        mock_tool2.name = "func2"
+        mock_tool3 = Mock()
+        mock_tool3.name = "func3"
+
+        output = ModelOutputThunk("Calling multiple functions")
+        output.tool_calls = {
+            "func1": ModelToolCall(name="func1", func=mock_tool1, args={"a": 1}),
+            "func2": ModelToolCall(name="func2", func=mock_tool2, args={"b": 2}),
+            "func3": ModelToolCall(name="func3", func=mock_tool3, args={"c": 3}),
+        }
+
+        chunks = []
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test",
+            model="test-model",
+            created=1234567890,
+        ):
+            if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]":
+                json_str = chunk_data[6:].strip()
+                chunks.append(json.loads(json_str))
+
+        # Find tool call chunk
+        tool_call_chunk = chunks[2]
+        tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"]
+
+        # Should have 3 tool calls with indices 0, 1, 2
+        assert len(tool_calls) == 3
+        indices = [tc["index"] for tc in tool_calls]
+        assert indices == [0, 1, 2]
+
+        # Verify each has required fields
+        for tc in tool_calls:
+            assert "index" in tc
+            assert "id" in tc
+            assert "type" in tc
+            assert tc["type"] == "function"
+            assert "function" in tc
+            assert "name" in tc["function"]
+            assert "arguments" in tc["function"]
+
+    @pytest.mark.asyncio
+    async def test_tool_call_chunk_before_final_chunk(self):
+        """Test that tool call chunk is emitted before final chunk."""
+        mock_tool = Mock()
+        mock_tool.name = "test_func"
+
+        output = ModelOutputThunk("Response")
+        output.tool_calls = {
+            "test_func": ModelToolCall(name="test_func", func=mock_tool, args={})
+        }
+
+        chunks = []
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test",
+            model="test-model",
+            created=1234567890,
+        ):
+            if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]":
+                json_str = chunk_data[6:].strip()
+                chunks.append(json.loads(json_str))
+
+        # Verify chunk sequence
+        assert len(chunks) == 4
+
+        # Chunk 0: initial with role
+        assert chunks[0]["choices"][0]["delta"].get("role") == "assistant"
+        assert chunks[0]["choices"][0]["finish_reason"] is None
+
+        # Chunk 1: content
+        assert chunks[1]["choices"][0]["delta"].get("content") == "Response"
+        assert chunks[1]["choices"][0]["finish_reason"] is None
+
+        # Chunk 2: tool calls (before final)
+        assert "tool_calls" in chunks[2]["choices"][0]["delta"]
+        assert chunks[2]["choices"][0]["finish_reason"] is None
+
+        # Chunk 3: final with finish_reason
+        assert chunks[3]["choices"][0]["finish_reason"] == "tool_calls"
+
+
+class TestStreamingIncrementalContent:
+    """Tests for streaming with incremental content (not pre-computed)."""
+
+    @pytest.mark.asyncio
+    async def test_streaming_incremental_chunks(self):
+        """Test streaming with incremental content via astream()."""
+        from unittest.mock import patch
+
+        # Create output that streams incrementally
+        output = ModelOutputThunk("")
+
+        # Mock astream to return incremental chunks
+        chunks_to_stream = ["Hello", " ", "world", "!"]
+        stream_index = 0
+
+        async def mock_astream():
+            nonlocal stream_index
+            if stream_index < len(chunks_to_stream):
+                chunk = chunks_to_stream[stream_index]
+                stream_index += 1
+                return chunk
+            else:
+                # Mark as computed by setting the value property
+                output.value = "Hello world!"
+                # Use object.__setattr__ to bypass property setter for _computed
+                object.__setattr__(output, "_computed", True)
+                return ""
+
+        def mock_is_computed():
+            return stream_index >= len(chunks_to_stream)
+
+        # Patch the astream and is_computed methods
+        with (
+            patch.object(output, "astream", side_effect=mock_astream),
+            patch.object(output, "is_computed", side_effect=mock_is_computed),
+        ):
+            # Collect streamed chunks
+            collected_chunks = []
+            async for chunk_data in stream_chat_completion_chunks(
+                output=output,
+                completion_id="chatcmpl-test",
+                model="test-model",
+                created=1234567890,
+            ):
+                if (
+                    chunk_data.startswith("data: ")
+                    and chunk_data.strip() != "data: [DONE]"
+                ):
+                    json_str = chunk_data[6:].strip()
+                    parsed = json.loads(json_str)
+                    delta_content = parsed["choices"][0]["delta"].get("content")
+                    if delta_content:
+                        collected_chunks.append(delta_content)
+
+        # Should have initial role chunk + 4 content chunks
+        # (role chunk has content=None, so not collected)
+        assert collected_chunks == ["Hello", " ", "world", "!"]
+
+    @pytest.mark.asyncio
+    async def test_streaming_with_tool_calls_after_incremental_content(self):
+        """Test that tool calls are emitted after incremental content streaming."""
+        from unittest.mock import patch
+
+        mock_tool = Mock()
+        mock_tool.name = "test_func"
+
+        # Create output that streams incrementally
+        output = ModelOutputThunk("")
+        output.tool_calls = {
+            "test_func": ModelToolCall(
+                name="test_func", func=mock_tool, args={"key": "value"}
+            )
+        }
+
+        # Mock astream
+        chunks_to_stream = ["Part1", "Part2"]
+        stream_index = 0
+
+        async def mock_astream():
+            nonlocal stream_index
+            if stream_index < len(chunks_to_stream):
+                chunk = chunks_to_stream[stream_index]
+                stream_index += 1
+                return chunk
+            else:
+                output.value = "Part1Part2"
+                object.__setattr__(output, "_computed", True)
+                return ""
+
+        def mock_is_computed():
+            return stream_index >= len(chunks_to_stream)
+
+        # Patch the astream and is_computed methods
+        with (
+            patch.object(output, "astream", side_effect=mock_astream),
+            patch.object(output, "is_computed", side_effect=mock_is_computed),
+        ):
+            # Collect all chunks
+            chunks = []
+            async for chunk_data in stream_chat_completion_chunks(
+                output=output,
+                completion_id="chatcmpl-test",
+                model="test-model",
+                created=1234567890,
+            ):
+                if (
+                    chunk_data.startswith("data: ")
+                    and chunk_data.strip() != "data: [DONE]"
+                ):
+                    json_str = chunk_data[6:].strip()
+                    chunks.append(json.loads(json_str))
+
+        # Should have: initial, Part1, Part2, tool_calls, final = 5 chunks
+        assert len(chunks) == 5
+
+        # Verify sequence
+        assert chunks[0]["choices"][0]["delta"].get("role") == "assistant"
+        assert chunks[1]["choices"][0]["delta"].get("content") == "Part1"
+        assert chunks[2]["choices"][0]["delta"].get("content") == "Part2"
+        assert "tool_calls" in chunks[3]["choices"][0]["delta"]
+        assert chunks[4]["choices"][0]["finish_reason"] == "tool_calls"
+
+
+class TestStreamingUsageField:
+    """Tests for usage field in streaming responses."""
+
+    @pytest.mark.asyncio
+    async def test_usage_included_when_stream_options_set(self):
+        """Test that usage is included in final chunk when stream_options.include_usage=True."""
+        output = ModelOutputThunk("Response")
+        output.generation.usage = {
+            "prompt_tokens": 10,
+            "completion_tokens": 5,
+            "total_tokens": 15,
+        }
+
+        chunks = []
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test",
+            model="test-model",
+            created=1234567890,
+            stream_options=StreamOptions(include_usage=True),
+        ):
+            if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]":
+                json_str = chunk_data[6:].strip()
+                chunks.append(json.loads(json_str))
+
+        # Final chunk should include usage
+        final_chunk = chunks[-1]
+        assert "usage" in final_chunk
+        assert final_chunk["usage"]["prompt_tokens"] == 10
+        assert final_chunk["usage"]["completion_tokens"] == 5
+        assert final_chunk["usage"]["total_tokens"] == 15
+
+    @pytest.mark.asyncio
+    async def test_usage_excluded_when_stream_options_not_set(self):
+        """Test that usage is excluded when stream_options is None."""
+        output = ModelOutputThunk("Response")
+        output.generation.usage = {
+            "prompt_tokens": 10,
+            "completion_tokens": 5,
+            "total_tokens": 15,
+        }
+
+        chunks = []
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test",
+            model="test-model",
+            created=1234567890,
+            stream_options=None,
+        ):
+            if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]":
+                json_str = chunk_data[6:].strip()
+                chunks.append(json.loads(json_str))
+
+        # Final chunk should NOT include usage
+        final_chunk = chunks[-1]
+        assert "usage" not in final_chunk or final_chunk["usage"] is None
+
+    @pytest.mark.asyncio
+    async def test_usage_excluded_when_include_usage_false(self):
+        """Test that usage is excluded when stream_options.include_usage=False."""
+        output = ModelOutputThunk("Response")
+        output.generation.usage = {
+            "prompt_tokens": 10,
+            "completion_tokens": 5,
+            "total_tokens": 15,
+        }
+
+        chunks = []
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test",
+            model="test-model",
+            created=1234567890,
+            stream_options=StreamOptions(include_usage=False),
+        ):
+            if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]":
+                json_str = chunk_data[6:].strip()
+                chunks.append(json.loads(json_str))
+
+        # Final chunk should NOT include usage
+        final_chunk = chunks[-1]
+        assert "usage" not in final_chunk or final_chunk["usage"] is None
+
+
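+# Reference for the assertions above (shape only; the token counts are the
+# values set on the mocked output, and this note is illustrative, not a spec):
+# with include_usage=True the final data chunk additionally carries
+# "usage": {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...},
+# and the stream still terminates with the literal "data: [DONE]" line.
+
+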
+class TestStreamingErrorHandling:
+    """Tests for error handling in streaming."""
+
+    @pytest.mark.asyncio
+    async def test_streaming_error_emits_error_response(self):
+        """Test that streaming errors emit OpenAI-compatible error responses."""
+        # Create output that will raise an error during streaming
+        output = ModelOutputThunk("")
+        output._computed = False
+
+        # Use AsyncMock with side_effect to raise error
+        output.astream = AsyncMock(
+            side_effect=RuntimeError("Simulated streaming error")
+        )
+
+        # Collect chunks
+        chunks = []
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test",
+            model="test-model",
+            created=1234567890,
+        ):
+            chunks.append(chunk_data)
+
+        # Should have: initial chunk, error response, [DONE]
+        assert len(chunks) >= 3
+
+        # Find error response (second-to-last before [DONE])
+        error_chunk_data = chunks[-2]
+        assert error_chunk_data.startswith("data: ")
+        json_str = error_chunk_data[6:].strip()
+        error_response = json.loads(json_str)
+
+        # Verify error structure
+        assert "error" in error_response
+        assert error_response["error"]["type"] == "server_error"
+        assert "Streaming error" in error_response["error"]["message"]
+        assert "Simulated streaming error" in error_response["error"]["message"]
+
+        # Should still end with [DONE]
+        assert chunks[-1] == "data: [DONE]\n\n"
+
+
+class TestStreamingChunkMetadata:
+    """Tests for chunk metadata fields."""
+
+    @pytest.mark.asyncio
+    async def test_all_chunks_have_required_fields(self):
+        """Test that all chunks have required OpenAI fields."""
+        output = ModelOutputThunk("Test response")
+
+        chunks = []
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test123",
+            model="test-model-name",
+            created=1234567890,
+            system_fingerprint="test-fingerprint",
+        ):
+            if chunk_data.startswith("data: ") and chunk_data.strip() != "data: [DONE]":
+                json_str = chunk_data[6:].strip()
+                chunks.append(json.loads(json_str))
+
+        # Verify all chunks have required fields
+        for chunk in chunks:
+            assert chunk["id"] == "chatcmpl-test123"
+            assert chunk["model"] == "test-model-name"
+            assert chunk["created"] == 1234567890
+            assert chunk["object"] == "chat.completion.chunk"
+            assert chunk["system_fingerprint"] == "test-fingerprint"
+            assert "choices" in chunk
+            assert len(chunk["choices"]) == 1
+            assert chunk["choices"][0]["index"] == 0
+
+    @pytest.mark.asyncio
+    async def test_done_marker_emitted(self):
+        """Test that [DONE] marker is always emitted at the end."""
+        output = ModelOutputThunk("Response")
+
+        chunks = []
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test",
+            model="test-model",
+            created=1234567890,
+        ):
+            chunks.append(chunk_data)
+
+        # Last chunk should be [DONE]
+        assert chunks[-1] == "data: [DONE]\n\n"
+
+    @pytest.mark.asyncio
+    async def test_sse_format_correct(self):
+        """Test that chunks follow SSE format: 'data: {json}\\n\\n'."""
+        output = ModelOutputThunk("Response")
+
+        async for chunk_data in stream_chat_completion_chunks(
+            output=output,
+            completion_id="chatcmpl-test",
+            model="test-model",
+            created=1234567890,
+        ):
+            # All chunks should start with "data: "
+            assert chunk_data.startswith("data: ")
+            # All chunks should end with double newline
+            assert chunk_data.endswith("\n\n")
diff --git a/test/cli/test_serve_tool_calling.py b/test/cli/test_serve_tool_calling.py
new file mode 100644
index 000000000..29c5bbf1b
--- /dev/null
+++ b/test/cli/test_serve_tool_calling.py
@@ -0,0 +1,313 @@
+"""Tests for tool calling support in m serve OpenAI-compatible API server."""
+
+import json
+from typing import Any
+from unittest.mock import Mock
+
+import pytest
+
+from cli.serve.app import make_chat_endpoint
+from cli.serve.models import (
+    ChatCompletion,
+    ChatCompletionRequest,
+    ChatMessage,
+    FunctionDefinition,
+    FunctionParameters,
+    ToolFunction,
+)
+from mellea.backends import ModelOption
+from mellea.core.base import AbstractMelleaTool, ModelOutputThunk, ModelToolCall
+
+
+class MockTool(AbstractMelleaTool):
+    """Mock tool for testing."""
+
+    name = "get_weather"
+
+    def run(self, location: str) -> str:
+        """Mock run method."""
+        return f"Weather in {location} is sunny"
+
+    @property
+    def as_json_tool(self) -> dict[str, Any]:
+        """Return JSON schema for this tool."""
+        return {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get the current weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                    "required": ["location"],
+                },
+            },
+        }
+
+
+@pytest.fixture
+def mock_module():
+    """Create a mock module with a serve function."""
+    module = Mock()
+    module.__name__ = "test_module"
+    return module
+
+
+@pytest.fixture
+def sample_tool_request():
+    """Create a sample ChatCompletionRequest with tools."""
+    return ChatCompletionRequest(
+        model="test-model",
+        messages=[ChatMessage(role="user", content="What's the weather in Paris?")],
+        tools=[
+            ToolFunction(
+                type="function",
+                function=FunctionDefinition(
+                    name="get_weather",
+                    description="Get the current weather in a location",
+                    parameters=FunctionParameters(
+                        RootModel={
+                            "type": "object",
+                            "properties": {
+                                "location": {
+                                    "type": "string",
+                                    "description": "The city name",
+                                }
+                            },
+                            "required": ["location"],
+                        }
+                    ),
+                ),
+            )
+        ],
+        tool_choice="auto",
+    )
+
+
+class TestToolCalling:
+    """Tests for tool calling functionality."""
+
+    @pytest.mark.asyncio
+    async def test_tool_calls_in_response(self, mock_module, sample_tool_request):
+        """Test that tool calls are properly formatted in the response."""
+        # Setup mock output with tool calls
+        mock_output = ModelOutputThunk("I'll check the weather for you.")
+        mock_tool = MockTool()
+        mock_output.tool_calls = {
+            "get_weather": ModelToolCall(
+                name="get_weather", func=mock_tool, args={"location": "Paris"}
+            )
+        }
+        mock_module.serve.return_value = mock_output
+
+        # Create endpoint and call it
+        endpoint = make_chat_endpoint(mock_module)
+        response = await endpoint(sample_tool_request)
+
+        # Verify response structure
+        assert isinstance(response, ChatCompletion)
+        assert response.choices[0].finish_reason == "tool_calls"
+        assert response.choices[0].message.tool_calls is not None
+        assert len(response.choices[0].message.tool_calls) == 1
+
+        # Verify tool call details
+        tool_call = response.choices[0].message.tool_calls[0]
+        assert tool_call.type == "function"
+        assert tool_call.function.name == "get_weather"
+
+        # Parse and verify arguments
+        args = json.loads(tool_call.function.arguments)
+        assert args == {"location": "Paris"}
+
+        # Verify tool call ID format
+        assert tool_call.id.startswith("call_")
+        assert len(tool_call.id) > len("call_")
+
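+    # Shape exercised by the assertions above (illustrative; the id suffix is
+    # generated by the server and only guaranteed to start with "call_"):
+    #
+    #     "tool_calls": [{"id": "call_<suffix>", "type": "function",
+    #                     "function": {"name": "get_weather",
+    #                                  "arguments": "{\"location\": \"Paris\"}"}}]
+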
"Paris"} + ), + "get_weather_london": ModelToolCall( + name="get_weather", func=mock_tool, args={"location": "London"} + ), + } + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Verify multiple tool calls + assert response.choices[0].finish_reason == "tool_calls" + assert len(response.choices[0].message.tool_calls) == 2 + + # Verify each tool call has unique ID + ids = [tc.id for tc in response.choices[0].message.tool_calls] + assert len(ids) == len(set(ids)), "Tool call IDs should be unique" + + @pytest.mark.asyncio + async def test_no_tool_calls_finish_reason_stop( + self, mock_module, sample_tool_request + ): + """Test that finish_reason is 'stop' when no tool calls are made.""" + mock_output = ModelOutputThunk("The weather is sunny.") + # No tool_calls set + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + assert response.choices[0].finish_reason == "stop" + assert response.choices[0].message.tool_calls is None + + @pytest.mark.asyncio + async def test_empty_tool_calls_dict_finish_reason_stop( + self, mock_module, sample_tool_request + ): + """Test that finish_reason is 'stop' when tool_calls is an empty dict. + + Regression test for bug where empty tool_calls dict {} produces + finish_reason='tool_calls' with an empty array instead of + finish_reason='stop' with tool_calls=None. + """ + mock_output = ModelOutputThunk("Hello! How can I help?") + # Set tool_calls to empty dict (the bug case) + mock_output.tool_calls = {} + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(sample_tool_request) + + # Should behave like no tool calls at all + assert response.choices[0].finish_reason == "stop" + assert response.choices[0].message.tool_calls is None + + @pytest.mark.asyncio + async def test_tools_passed_to_model_options( + self, mock_module, sample_tool_request + ): + """Test that tools are passed to serve function in model_options.""" + from mellea.backends.model_options import ModelOption + + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + await endpoint(sample_tool_request) + + # Verify serve was called with tools in model_options + call_args = mock_module.serve.call_args + assert call_args is not None + model_options = call_args.kwargs["model_options"] + + # Tools should be in model_options with the ModelOption.TOOLS key + assert ModelOption.TOOLS in model_options + assert model_options[ModelOption.TOOLS] is not None + + @pytest.mark.asyncio + async def test_tool_choice_passed_to_model_options( + self, mock_module, sample_tool_request + ): + """Test that tool_choice is passed to serve function in model_options.""" + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + await endpoint(sample_tool_request) + + # Verify serve was called with tool_choice in model_options + call_args = mock_module.serve.call_args + assert call_args is not None + model_options = call_args.kwargs["model_options"] + + # tool_choice should be passed through using ModelOption.TOOL_CHOICE + assert ModelOption.TOOL_CHOICE in model_options + assert model_options[ModelOption.TOOL_CHOICE] == "auto" + + @pytest.mark.asyncio + async def test_tool_calls_with_complex_arguments( + 
+    async def test_tool_calls_with_complex_arguments(
+        self, mock_module, sample_tool_request
+    ):
+        """Test tool calls with complex nested arguments."""
+        mock_output = ModelOutputThunk("Processing complex request.")
+        mock_tool = MockTool()
+        mock_output.tool_calls = {
+            "complex_tool": ModelToolCall(
+                name="complex_function",
+                func=mock_tool,
+                args={
+                    "location": "Paris",
+                    "options": {
+                        "units": "celsius",
+                        "include_forecast": True,
+                        "days": 5,
+                    },
+                    "tags": ["weather", "forecast"],
+                },
+            )
+        }
+        mock_module.serve.return_value = mock_output
+
+        endpoint = make_chat_endpoint(mock_module)
+        response = await endpoint(sample_tool_request)
+
+        # Verify complex arguments are properly serialized
+        tool_call = response.choices[0].message.tool_calls[0]
+        args = json.loads(tool_call.function.arguments)
+
+        assert args["location"] == "Paris"
+        assert args["options"]["units"] == "celsius"
+        assert args["options"]["include_forecast"] is True
+        assert args["options"]["days"] == 5
+        assert args["tags"] == ["weather", "forecast"]
+
+    @pytest.mark.asyncio
+    async def test_tool_calls_with_usage_info(self, mock_module, sample_tool_request):
+        """Test that usage info is included alongside tool calls."""
+        mock_output = ModelOutputThunk("Calling tool.")
+        mock_tool = MockTool()
+        mock_output.tool_calls = {
+            "get_weather": ModelToolCall(
+                name="get_weather", func=mock_tool, args={"location": "Paris"}
+            )
+        }
+        mock_output.generation.usage = {
+            "prompt_tokens": 50,
+            "completion_tokens": 20,
+            "total_tokens": 70,
+        }
+        mock_module.serve.return_value = mock_output
+
+        endpoint = make_chat_endpoint(mock_module)
+        response = await endpoint(sample_tool_request)
+
+        # Verify both tool calls and usage are present
+        assert response.choices[0].finish_reason == "tool_calls"
+        assert response.choices[0].message.tool_calls is not None
+        assert response.usage is not None
+        assert response.usage.total_tokens == 70
+
+    @pytest.mark.asyncio
+    async def test_request_without_tools(self, mock_module):
+        """Test that requests without tools still work normally."""
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[ChatMessage(role="user", content="Hello")],
+            # No tools specified
+        )
+
+        mock_output = ModelOutputThunk("Hello! How can I help?")
+        mock_module.serve.return_value = mock_output
+
+        endpoint = make_chat_endpoint(mock_module)
+        response = await endpoint(request)
+
+        # Should work normally without tool-related fields
+        assert isinstance(response, ChatCompletion)
+        assert response.choices[0].finish_reason == "stop"
+        assert response.choices[0].message.tool_calls is None
+        assert response.choices[0].message.content == "Hello! How can I help?"
diff --git a/test/cli/test_tool_call_index_verification.py b/test/cli/test_tool_call_index_verification.py
new file mode 100644
index 000000000..ad1232609
--- /dev/null
+++ b/test/cli/test_tool_call_index_verification.py
@@ -0,0 +1,151 @@
+"""Verification that streaming tool call deltas include required index field.
+
+This test demonstrates that our streaming implementation is compatible with
+OpenAI SDK delta reassembly logic, which requires the index field.
+"""
+
+import json
+from unittest.mock import Mock
+
+import pytest
+
+from cli.serve.app import make_chat_endpoint
+from cli.serve.models import ChatCompletionRequest, ChatMessage
+from mellea.core.base import ModelOutputThunk, ModelToolCall
+
+
+@pytest.mark.asyncio
+async def test_tool_call_delta_has_required_index_field():
+    """Verify that streaming tool call deltas include the required index field.
+
+    The OpenAI streaming spec requires each item in delta.tool_calls to carry
+    an index field. Clients including the openai Python SDK, LangChain, and
+    LiteLLM key their delta-reassembly state machine on this field.
+
+    Without it, they silently drop tool calls, coalesce them incorrectly, or
+    raise a TypeError depending on version.
+    """
+    # Create a mock module with a serve function
+    mock_module = Mock()
+    mock_module.__name__ = "test_module"
+
+    # Create a mock tool
+    mock_tool = Mock()
+    mock_tool.name = "get_weather"
+
+    # Create a mock output with multiple tool calls to test indexing
+    mock_output = ModelOutputThunk("I'll check the weather for you.")
+    mock_output.tool_calls = {
+        "get_weather": ModelToolCall(
+            name="get_weather",
+            func=mock_tool,
+            args={"location": "San Francisco", "units": "celsius"},
+        ),
+        "get_forecast": ModelToolCall(
+            name="get_forecast",
+            func=mock_tool,
+            args={"location": "San Francisco", "days": 3},
+        ),
+    }
+    mock_module.serve.return_value = mock_output
+
+    request = ChatCompletionRequest(
+        model="test-model",
+        messages=[ChatMessage(role="user", content="What's the weather?")],
+        stream=True,
+    )
+
+    endpoint = make_chat_endpoint(mock_module)
+    response = await endpoint(request)
+
+    # Collect all chunks
+    chunks = []
+    async for chunk_data in response.body_iterator:
+        chunk_str = (
+            bytes(chunk_data).decode("utf-8")
+            if isinstance(chunk_data, (bytes, memoryview))
+            else chunk_data
+        )
+
+        if chunk_str.startswith("data: "):
+            json_str = chunk_str[6:].strip()
+            if json_str and json_str != "[DONE]":
+                chunks.append(json.loads(json_str))
+
+    # Find the tool call chunk
+    tool_call_chunk = None
+    for chunk in chunks:
+        if chunk["choices"][0]["delta"].get("tool_calls"):
+            tool_call_chunk = chunk
+            break
+
+    assert tool_call_chunk is not None, "Should have a tool call chunk"
+
+    tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"]
+    assert len(tool_calls) == 2, "Should have 2 tool calls"
+
+    # Verify REQUIRED index field is present on each tool call delta
+    for i, tc in enumerate(tool_calls):
+        assert "index" in tc, f"tool_calls[{i}] must include index field"
+        assert isinstance(tc["index"], int), "index must be an integer"
+        assert tc["index"] == i, f"tool_calls[{i}] should have index={i}"
+
+        # Verify other fields are present (id, type, function)
+        assert "id" in tc, f"tool_calls[{i}] should have id"
+        assert "type" in tc, f"tool_calls[{i}] should have type"
+        assert tc["type"] == "function", f"tool_calls[{i}] type should be 'function'"
+        assert "function" in tc, f"tool_calls[{i}] should have function"
+        assert "name" in tc["function"], f"tool_calls[{i}].function should have name"
+        assert "arguments" in tc["function"], (
+            f"tool_calls[{i}].function should have arguments"
+        )
+
+
+@pytest.mark.asyncio
+async def test_single_tool_call_has_index_zero():
+    """Verify that a single tool call has index=0."""
+    mock_module = Mock()
+    mock_module.__name__ = "test_module"
+
+    mock_tool = Mock()
+    mock_tool.name = "search"
+
+    mock_output = ModelOutputThunk("Searching...")
+    mock_output.tool_calls = {
+        "search": ModelToolCall(name="search", func=mock_tool, args={"query": "test"})
+    }
+    mock_module.serve.return_value = mock_output
+
+    request = ChatCompletionRequest(
+        model="test-model",
+        messages=[ChatMessage(role="user", content="Search for test")],
+        stream=True,
+    )
+
+    endpoint = make_chat_endpoint(mock_module)
+    response = await endpoint(request)
+
+    chunks = []
+    async for chunk_data in response.body_iterator:
+        chunk_str = (
+            bytes(chunk_data).decode("utf-8")
+            if isinstance(chunk_data, (bytes, memoryview))
+            else chunk_data
+        )
+
+        if chunk_str.startswith("data: "):
+            json_str = chunk_str[6:].strip()
+            if json_str and json_str != "[DONE]":
+                chunks.append(json.loads(json_str))
+
+    # Find the tool call chunk
+    tool_call_chunk = None
+    for chunk in chunks:
+        if chunk["choices"][0]["delta"].get("tool_calls"):
+            tool_call_chunk = chunk
+            break
+
+    assert tool_call_chunk is not None
+    tool_calls = tool_call_chunk["choices"][0]["delta"]["tool_calls"]
+    assert len(tool_calls) == 1
+    assert tool_calls[0]["index"] == 0, "Single tool call should have index=0"
diff --git a/test/core/test_component_typing.py b/test/core/test_component_typing.py
index f8d3d411d..829fb9eaa 100644
--- a/test/core/test_component_typing.py
+++ b/test/core/test_component_typing.py
@@ -78,16 +78,16 @@ def session(backend) -> MelleaSession:
 
 def test_mot_init_typing():
     mot = ModelOutputThunk[float](value="1")
-    assert hasattr(mot, "__orig_class__"), (
-        "mots are generics and should have this field"
+    assert "__orig_class__" in mot.__dict__, (
+        "mots are generics and should have this field in instance dict"
     )
     assert get_args(mot.__orig_class__)[0] is float, (  # type: ignore
         f"expected float, got {get_args(mot.__orig_class__)[0]} as mot type"  # type: ignore
     )  # type: ignore
 
     unknown_mot = ModelOutputThunk(value="2")
-    assert not hasattr(unknown_mot, "__orig_class__"), (
-        "unknown mots / mots with no type defined at instantiate don't have this attribute"
+    assert "__orig_class__" not in unknown_mot.__dict__, (
+        "unknown mots / mots with no type defined at instantiate don't have this attribute in instance dict"
     )
 
 
diff --git a/test/helpers/test_openai_compatible_helpers.py b/test/helpers/test_openai_compatible_helpers.py
index 3963c3d8f..2f29a554f 100644
--- a/test/helpers/test_openai_compatible_helpers.py
+++ b/test/helpers/test_openai_compatible_helpers.py
@@ -2,12 +2,15 @@
 
 import base64
 import json
+from datetime import datetime
+from decimal import Decimal
 
 import pytest
 
 from mellea.backends.tools import MelleaTool
-from mellea.core.base import ImageBlock
+from mellea.core.base import ImageBlock, ModelOutputThunk, ModelToolCall
 from mellea.helpers.openai_compatible_helpers import (
+    build_tool_calls,
     chat_completion_delta_merge,
     extract_model_tool_requests,
     message_to_openai_message,
@@ -359,5 +362,30 @@ def test_docs_across_messages(self):
         assert result[1]["text"] == "b"
 
+
+# --- build_tool_calls ---
+
+
+class TestBuildToolCalls:
+    def test_with_non_json_serializable_args(self):
+        """Non-JSON-serializable values (datetime, Decimal) are converted to strings."""
+        tool = _make_tool("test_tool")
+        tool_call = ModelToolCall(
+            name="test_tool",
+            func=tool,
+            args={"timestamp": datetime(2024, 1, 15), "amount": Decimal("123.45")},
+        )
+        output = ModelOutputThunk(value="test", tool_calls={"test_tool": tool_call})
+
+        result = build_tool_calls(output)
+
+        assert result is not None
+        assert len(result) == 1
+        # Verify arguments are valid JSON and values were converted to strings
+        args = json.loads(result[0]["function"]["arguments"])
+        assert isinstance(args["timestamp"], str)
+        assert "2024-01-15" in args["timestamp"]
+        assert args["amount"] == "123.45"
+
 
 if __name__ == "__main__":
     pytest.main([__file__])