Skip to content

Commit b34a816

Browse files
committed
fix: handle direct model answers in ReACT loop
The ReACT framework now properly handles cases where the model provides a direct answer without calling tools. Previously, these answers were ignored and the loop would continue until exhausting the budget. Added test coverage for both scenarios (no tools, unused tools). Fixes: #762 Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent 0fd342e commit b34a816

2 files changed

Lines changed: 88 additions & 5 deletions

File tree

mellea/stdlib/frameworks/react.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,19 @@ async def react(
106106
if tool_res.name == MELLEA_FINALIZER_TOOL:
107107
is_final = True
108108

109-
if is_final:
110-
assert len(tool_responses) == 1, "multiple tools were called with 'final'"
109+
# Check if we should return: either finalizer was called or model gave direct answer
110+
should_return = is_final or (step.tool_calls is None and step.value is not None)
111111

112+
if should_return:
113+
if is_final:
114+
assert len(tool_responses) == 1, (
115+
"multiple tools were called with 'final'"
116+
)
117+
if format is None:
118+
# The tool has already been called above.
119+
step._underlying_value = str(tool_responses[0].content)
120+
121+
# Apply format if requested (works for both finalizer and direct answer cases)
112122
if format is not None:
113123
step, next_context = await mfuncs.aact(
114124
action=ReactThought(),
@@ -122,9 +132,7 @@ async def react(
122132
)
123133
assert isinstance(next_context, ChatContext)
124134
context = next_context
125-
else:
126-
# The tool has already been called above.
127-
step._underlying_value = str(tool_responses[0].content)
135+
128136
return step, context
129137

130138
raise RuntimeError(f"could not complete react loop in {loop_budget} iterations")
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Test ReACT framework handling of direct answers without tool calls."""
2+
3+
import pytest
4+
5+
from mellea.backends.tools import tool
6+
from mellea.stdlib.context import ChatContext
7+
from mellea.stdlib.frameworks.react import react
8+
from mellea.stdlib.session import start_session
9+
10+
11+
@pytest.mark.ollama
12+
@pytest.mark.llm
13+
async def test_react_direct_answer_without_tools():
14+
"""Test that ReACT handles direct answers when model doesn't call tools.
15+
16+
This tests the case where the model provides a direct answer in step.value
17+
without making any tool calls. The fix ensures the loop terminates properly
18+
instead of continuing until loop_budget is exhausted.
19+
"""
20+
m = start_session()
21+
22+
# Ask a simple question that doesn't require tools
23+
# The model should provide a direct answer without calling any tools
24+
out, _ = await react(
25+
goal="What is 2 + 2?",
26+
context=ChatContext(),
27+
backend=m.backend,
28+
tools=[], # No tools provided
29+
loop_budget=3, # Should complete in 1 iteration, not exhaust budget
30+
)
31+
32+
# Verify we got an answer
33+
assert out.value is not None
34+
assert len(out.value) > 0
35+
36+
# The answer should contain "4" or "four"
37+
answer_lower = out.value.lower()
38+
assert "4" in answer_lower or "four" in answer_lower
39+
40+
41+
@pytest.mark.ollama
42+
@pytest.mark.llm
43+
async def test_react_direct_answer_with_unused_tools():
44+
"""Test that ReACT handles direct answers even when tools are available.
45+
46+
This tests the case where tools are provided but the model chooses to
47+
answer directly without using them.
48+
"""
49+
m = start_session()
50+
51+
# Create a dummy tool that won't be needed
52+
@tool
53+
def search_web(query: str) -> str:
54+
"""Search the web for information."""
55+
return "Search results"
56+
57+
# Ask a question that doesn't need the tool
58+
out, _ = await react(
59+
goal="What is the capital of France?",
60+
context=ChatContext(),
61+
backend=m.backend,
62+
tools=[search_web],
63+
loop_budget=3,
64+
)
65+
66+
# Verify we got an answer
67+
assert out.value is not None
68+
assert len(out.value) > 0
69+
70+
# The answer should mention Paris
71+
answer_lower = out.value.lower()
72+
assert "paris" in answer_lower
73+
74+
75+
# Made with Bob

0 commit comments

Comments
 (0)