Skip to content

Commit b4c05c5

Browse files
committed
feat: cli support for OpenAI API tool calling with streaming
Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
Assisted-by: IBM Bob
1 parent 313a497 commit b4c05c5

6 files changed

Lines changed: 544 additions & 33 deletions

File tree

cli/serve/app.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import asyncio
44
import importlib.util
55
import inspect
6-
import json
76
import os
87
import sys
98
import time
@@ -23,7 +22,10 @@
2322
) from e
2423

2524
from mellea.backends.model_options import ModelOption
26-
from mellea.helpers.openai_compatible_helpers import build_completion_usage
25+
from mellea.helpers.openai_compatible_helpers import (
26+
build_completion_usage,
27+
build_tool_calls,
28+
)
2729

2830
from .models import (
2931
ChatCompletion,
@@ -176,34 +178,30 @@ async def endpoint(request: ChatCompletionRequest):
176178
)
177179

178180
# Extract tool calls from the ModelOutputThunk if available
179-
tool_calls = None
180-
finish_reason: Literal[
181-
"stop", "length", "content_filter", "tool_calls", "function_call"
182-
] = "stop"
183-
if (
184-
hasattr(output, "tool_calls")
185-
and output.tool_calls is not None
186-
and isinstance(output.tool_calls, dict)
187-
and output.tool_calls # Check dict is not empty
188-
):
189-
tool_calls = []
190-
for model_tool_call in output.tool_calls.values():
191-
# Generate a unique ID for this tool call
192-
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
193-
194-
# Serialize the arguments to JSON string
195-
args_json = json.dumps(model_tool_call.args)
196-
197-
tool_calls.append(
198-
ChatCompletionMessageToolCall(
199-
id=tool_call_id,
200-
type="function",
201-
function=ToolCallFunction(
202-
name=model_tool_call.name, arguments=args_json
203-
),
204-
)
181+
tool_calls_list = build_tool_calls(output)
182+
tool_calls = (
183+
[
184+
ChatCompletionMessageToolCall(
185+
id=tc["id"],
186+
type=tc["type"],
187+
function=ToolCallFunction(
188+
name=tc["function"]["name"],
189+
arguments=tc["function"]["arguments"],
190+
),
205191
)
206-
finish_reason = "tool_calls"
192+
for tc in tool_calls_list
193+
]
194+
if tool_calls_list
195+
else None
196+
)
197+
198+
# Determine finish_reason based on tool calls
199+
finish_reason: (
200+
Literal[
201+
"stop", "length", "content_filter", "tool_calls", "function_call"
202+
]
203+
| None
204+
) = "tool_calls" if tool_calls else "stop"
207205

208206
# system_fingerprint represents backend config hash, not model name
209207
# The model name is already in response.model (line 73)

cli/serve/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ class ChatCompletionChunkDelta(BaseModel):
170170
refusal: str | None = None
171171
"""The refusal message fragment, if any."""
172172

173+
tool_calls: list[ChatCompletionMessageToolCall] | None = None
174+
"""The tool calls generated by the model (only in tool call chunks)."""
175+
173176

174177
class ChatCompletionChunkChoice(BaseModel):
175178
"""A choice in a streaming chunk."""

cli/serve/streaming.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
"""Streaming utilities for OpenAI-compatible server responses."""
22

33
from collections.abc import AsyncGenerator
4+
from typing import Literal
45

56
from mellea.core.base import ModelOutputThunk
67
from mellea.core.utils import MelleaLogger
7-
from mellea.helpers.openai_compatible_helpers import build_completion_usage
8+
from mellea.helpers.openai_compatible_helpers import (
9+
build_completion_usage,
10+
build_tool_calls,
11+
)
812

913
from .models import (
1014
ChatCompletionChunk,
1115
ChatCompletionChunkChoice,
1216
ChatCompletionChunkDelta,
17+
ChatCompletionMessageToolCall,
1318
OpenAIError,
1419
OpenAIErrorResponse,
1520
StreamOptions,
21+
ToolCallFunction,
1622
)
1723

1824

@@ -98,6 +104,46 @@ async def stream_chat_completion_chunks(
98104
)
99105
yield f"data: {chunk.model_dump_json()}\n\n"
100106

107+
# Extract tool calls from the ModelOutputThunk if available
108+
tool_calls_list = build_tool_calls(output)
109+
110+
if tool_calls_list:
111+
# Convert to ChatCompletionMessageToolCall objects
112+
tool_calls = [
113+
ChatCompletionMessageToolCall(
114+
id=tc["id"],
115+
type=tc["type"],
116+
function=ToolCallFunction(
117+
name=tc["function"]["name"],
118+
arguments=tc["function"]["arguments"],
119+
),
120+
)
121+
for tc in tool_calls_list
122+
]
123+
124+
# Emit tool calls in a separate chunk before the final chunk
125+
tool_call_chunk = ChatCompletionChunk(
126+
id=completion_id,
127+
model=model,
128+
created=created,
129+
choices=[
130+
ChatCompletionChunkChoice(
131+
index=0,
132+
delta=ChatCompletionChunkDelta(tool_calls=tool_calls),
133+
finish_reason=None,
134+
)
135+
],
136+
object="chat.completion.chunk",
137+
system_fingerprint=system_fingerprint,
138+
)
139+
yield f"data: {tool_call_chunk.model_dump_json()}\n\n"
140+
141+
# Determine finish_reason based on tool calls
142+
finish_reason: (
143+
Literal["stop", "length", "content_filter", "tool_calls", "function_call"]
144+
| None
145+
) = "tool_calls" if tool_calls_list else "stop"
146+
101147
# Include usage in final chunk only if explicitly requested via stream_options
102148
# Per OpenAI spec: usage is only included when stream_options.include_usage=True
103149
include_usage = stream_options is not None and stream_options.include_usage
@@ -112,7 +158,7 @@ async def stream_chat_completion_chunks(
112158
ChatCompletionChunkChoice(
113159
index=0,
114160
delta=ChatCompletionChunkDelta(content=None),
115-
finish_reason="stop",
161+
finish_reason=finish_reason,
116162
)
117163
],
118164
object="chat.completion.chunk",

0 commit comments

Comments (0)