Skip to content

Commit b0a8b40

Browse files
committed
feat: add callback to react() to allow checking for final answer
* optional callback * example shows checking for "Final Answer" each iteration * and also querying the LLM when loop budget is reached Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent dc2dd15 commit b0a8b40

3 files changed

Lines changed: 125 additions & 22 deletions

File tree

docs/examples/agents/react/react_using_mellea.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from langchain_community.tools import DuckDuckGoSearchResults
99

1010
from mellea.backends.tools import MelleaTool
11+
from mellea.stdlib import functional as mfuncs
1112
from mellea.stdlib.context import ChatContext
1213
from mellea.stdlib.frameworks.react import react
1314
from mellea.stdlib.session import start_session
@@ -28,15 +29,75 @@ class Email(pydantic.BaseModel):
2829
body: str
2930

3031

32+
class TrueOrFalse(pydantic.BaseModel):
33+
"""Response indicating whether the ReACT agent has completed its task."""
34+
35+
answer: bool = pydantic.Field(
36+
description="True if you have enough information to answer the user's question, False if you need more tool calls"
37+
)
38+
39+
40+
async def last_loop_completion_check(
41+
goal, step, context, backend, model_options, turn_num, loop_budget
42+
):
43+
"""Completion check that asks the model if it has the answer on the last iteration.
44+
45+
Only checks on the last iteration (when turn_num == loop_budget) to avoid
46+
unnecessary LLM calls. Returns False for all other iterations.
47+
48+
Note: step.value is guaranteed to exist when this is called.
49+
"""
50+
# Only check on last iteration (and not for unlimited budget)
51+
if loop_budget == -1 or turn_num < loop_budget:
52+
return False
53+
54+
content = mfuncs.chat(
55+
content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
56+
context=context,
57+
backend=backend,
58+
format=TrueOrFalse,
59+
)[0].content
60+
have_answer = TrueOrFalse.model_validate_json(content).answer
61+
62+
return have_answer
63+
64+
65+
async def custom_completion_check(
66+
goal, step, context, backend, model_options, turn_num, loop_budget
67+
):
68+
"""Custom completion check combining keyword detection and fallback to last-loop check.
69+
70+
This runs every iteration:
71+
1. First checks if response contains "final answer" for early termination
72+
2. On the last iteration, falls back to asking the model if it has the answer
73+
74+
Note: step.value is guaranteed to exist when this is called.
75+
"""
76+
# Check every iteration for "final answer" keyword (early termination)
77+
if "final answer" in step.value.lower():
78+
return True
79+
80+
# On last iteration, fall back to asking the model if it has the answer
81+
if loop_budget != -1 and turn_num >= loop_budget:
82+
return await last_loop_completion_check(
83+
goal, step, context, backend, model_options, turn_num, loop_budget
84+
)
85+
86+
return False
87+
88+
3189
async def main():
3290
"""Example."""
33-
# Simple version that just searches for an answer.
91+
# Version with custom answer check that terminates early
92+
# when the model says "final answer" and queries the LLM
93+
# if it reaches the loop_budget.
3494
out, _ = await react(
3595
goal="What is the Mellea python library?",
3696
context=ChatContext(),
3797
backend=m.backend,
3898
tools=[search_tool],
3999
loop_budget=12,
100+
answer_check=custom_completion_check,
40101
)
41102
print(out)
42103

@@ -46,6 +107,7 @@ async def main():
46107
# context=ChatContext(),
47108
# backend=m.backend,
48109
# tools=[search_tool],
110+
# answer_check = custom_completion_check,
49111
# format=Email
50112
# )
51113
# print(out)

mellea/stdlib/frameworks/react.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
history tracking. Raises ``RuntimeError`` if the loop ends without a final answer.
88
"""
99

10+
from collections.abc import Awaitable, Callable
11+
1012
import pydantic
1113

1214
# from PIL import Image as PILImage
@@ -46,6 +48,19 @@ async def react(
4648
model_options: dict | None = None,
4749
tools: list[AbstractMelleaTool] | None,
4850
loop_budget: int = 10,
51+
answer_check: Callable[
52+
[
53+
str,
54+
ComputedModelOutputThunk[str],
55+
ChatContext,
56+
Backend,
57+
dict | None,
58+
int,
59+
int,
60+
],
61+
Awaitable[bool],
62+
]
63+
| None = None,
4964
) -> tuple[ComputedModelOutputThunk[str], ChatContext]:
5065
"""Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools.
5166
@@ -57,6 +72,11 @@ async def react(
5772
model_options: additional model options, which will upsert into the model/backend's defaults.
5873
tools: the list of tools to use
5974
loop_budget: the number of steps allowed; use -1 for unlimited
75+
answer_check: optional callable to determine if the agent has completed its task.
76+
Called every iteration when no tool calls are made and step.value exists (if provided).
77+
Receives (goal, step, context, backend, model_options, turn_num, loop_budget).
78+
Returns bool indicating if the task is complete.
79+
If None, no answer check is performed (loop continues until finalizer or budget exhausted).
6080
6181
Returns:
6282
A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
@@ -116,31 +136,19 @@ async def react(
116136
if tool_res.name == MELLEA_FINALIZER_TOOL:
117137
is_final = True
118138

119-
# Check for special case where model already has the answer, but it won't call the finalize tool.
120-
# Instead of letting this run out of iterations and fail, let's ask.
121-
# Only do this before we fail on iteration limit as a last resort because it's hard to justify doing it earlier for now.
122-
elif -1 < loop_budget <= turn_num and step.value:
123-
# If the turn number has reached the end of loop budget (and budget is not unlimited),
124-
# then it's time to check if the model is just loopy and already has the answer.
125-
print("### Done Check")
126-
print("STEP_TOOL_CALLS:", step.tool_calls)
127-
print("STEP:", step)
128-
print("CONTEXT:", context)
129-
content = mfuncs.chat(
130-
content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
131-
context=context,
132-
backend=backend,
133-
format=TrueOrFalse,
134-
)[0].content
135-
have_answer = TrueOrFalse.model_validate_json(content).answer
136-
137-
print("### Done Check ANSWER: ", have_answer)
139+
# Check if the agent has completed its task (runs every iteration if answer_check is provided and there's a value)
140+
# The answer_check function can decide when to actually check based on turn_num and loop_budget
141+
elif not is_final and answer_check and step.value:
142+
have_answer = await answer_check(
143+
goal, step, context, backend, model_options, turn_num, loop_budget
144+
)
145+
138146
if have_answer:
139147
# Create a synthetic finalizer tool response to be consistent with normal loop
140148
finalizer_response = ToolMessage(
141149
role="tool",
142-
content=step.value,
143-
tool_output=step.value,
150+
content=step.value or "",
151+
tool_output=step.value or "",
144152
name=MELLEA_FINALIZER_TOOL,
145153
args={},
146154
tool=None, # type: ignore

test/stdlib/test_react_direct_answer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,44 @@
11
"""Test ReACT framework handling of direct answers without tool calls."""
22

3+
import pydantic
34
import pytest
45

56
from mellea.backends.tools import tool
7+
from mellea.stdlib import functional as mfuncs
68
from mellea.stdlib.context import ChatContext
79
from mellea.stdlib.frameworks.react import react
810
from mellea.stdlib.session import start_session
911

1012

13+
class TrueOrFalse(pydantic.BaseModel):
14+
"""Response indicating whether the ReACT agent has completed its task."""
15+
16+
answer: bool = pydantic.Field(
17+
description="True if you have enough information to answer the user's question, False if you need more tool calls"
18+
)
19+
20+
21+
async def last_loop_completion_check(
22+
goal, step, context, backend, model_options, turn_num, loop_budget
23+
):
24+
"""Completion check that asks the model if it has the answer on the last iteration.
25+
26+
Note: step.value is guaranteed to exist when this is called.
27+
"""
28+
# Only check on last iteration (and not for unlimited budget)
29+
if loop_budget == -1 or turn_num < loop_budget:
30+
return False
31+
32+
content = mfuncs.chat(
33+
content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
34+
context=context,
35+
backend=backend,
36+
format=TrueOrFalse,
37+
)[0].content
38+
have_answer = TrueOrFalse.model_validate_json(content).answer
39+
return have_answer
40+
41+
1142
@pytest.mark.ollama
1243
@pytest.mark.llm
1344
async def test_react_direct_answer_without_tools():
@@ -27,6 +58,7 @@ async def test_react_direct_answer_without_tools():
2758
backend=m.backend,
2859
tools=[], # No tools provided
2960
loop_budget=3, # Should complete in 1 iteration, not exhaust budget
61+
answer_check=last_loop_completion_check,
3062
)
3163

3264
# Verify we got an answer
@@ -61,6 +93,7 @@ def search_web(query: str) -> str:
6193
backend=m.backend,
6294
tools=[search_web],
6395
loop_budget=3,
96+
answer_check=last_loop_completion_check,
6497
)
6598

6699
# Verify we got an answer

0 commit comments

Comments
 (0)