Skip to content

Commit dc2dd15

Browse files
committed
fix: before failing react loop ask LLM if it has the answer
With some models (for example, with Ollama), we get stuck when the model has the answer but won't call finalize. Before failing due to the iteration limit, ask the model whether it has the answer and, if it responds True, use it. Note: This is only done at the end of the iterations because it is questionable to penalize other models on each iteration. When failure is the only other option, it seems to be worth a try. Fixes: #762 Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent b34a816 commit dc2dd15

1 file changed

Lines changed: 48 additions & 12 deletions

File tree

mellea/stdlib/frameworks/react.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
history tracking. Raises ``RuntimeError`` if the loop ends without a final answer.
88
"""
99

10+
import pydantic
11+
1012
# from PIL import Image as PILImage
1113
from mellea.backends.model_options import ModelOption
1214
from mellea.core.backend import Backend, BaseModelSubclass
@@ -24,6 +26,14 @@
2426
from mellea.stdlib.context import ChatContext
2527

2628

29+
# Structured-output schema for the last-resort "done check" in react(): it is
# passed as format=TrueOrFalse to the chat call, and the reply is parsed with
# TrueOrFalse.model_validate_json(...).answer.
# NOTE(review): pydantic presumably exposes the class docstring and the field
# description in the model's JSON schema, so rewording either string may change
# what the LLM sees -- confirm before editing them.
class TrueOrFalse(pydantic.BaseModel):
30+
"""Response indicating whether the ReACT agent has completed its task."""
31+
32+
# True -> enough information to answer; False -> more tool calls are needed.
answer: bool = pydantic.Field(
33+
description="True if you have enough information to answer the user's question, False if you need more tool calls"
34+
)
35+
36+
2737
async def react(
2838
goal: str,
2939
context: ChatContext,
@@ -106,19 +116,43 @@ async def react(
106116
if tool_res.name == MELLEA_FINALIZER_TOOL:
107117
is_final = True
108118

109-
# Check if we should return: either finalizer was called or model gave direct answer
110-
should_return = is_final or (step.tool_calls is None and step.value is not None)
111-
112-
if should_return:
113-
if is_final:
114-
assert len(tool_responses) == 1, (
115-
"multiple tools were called with 'final'"
119+
# Check for special case where model already has the answer, but it won't call the finalize tool.
120+
# Instead of letting this run out of iterations and fail, let's ask.
121+
# Only do this before we fail on iteration limit as a last resort because it's hard to justify doing it earlier for now.
122+
elif -1 < loop_budget <= turn_num and step.value:
123+
# If the turn number has reached the end of loop budget (and budget is not unlimited),
124+
# then it's time to check if the model is just loopy and already has the answer.
125+
print("### Done Check")
126+
print("STEP_TOOL_CALLS:", step.tool_calls)
127+
print("STEP:", step)
128+
print("CONTEXT:", context)
129+
content = mfuncs.chat(
130+
content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
131+
context=context,
132+
backend=backend,
133+
format=TrueOrFalse,
134+
)[0].content
135+
have_answer = TrueOrFalse.model_validate_json(content).answer
136+
137+
print("### Done Check ANSWER: ", have_answer)
138+
if have_answer:
139+
# Create a synthetic finalizer tool response to be consistent with normal loop
140+
finalizer_response = ToolMessage(
141+
role="tool",
142+
content=step.value,
143+
tool_output=step.value,
144+
name=MELLEA_FINALIZER_TOOL,
145+
args={},
146+
tool=None, # type: ignore
116147
)
117-
if format is None:
118-
# The tool has already been called above.
119-
step._underlying_value = str(tool_responses[0].content)
148+
tool_responses = [finalizer_response]
149+
context = context.add(finalizer_response)
150+
is_final = True
120151

121-
# Apply format if requested (works for both finalizer and direct answer cases)
152+
if is_final:
153+
assert len(tool_responses) == 1, "multiple tools were called with 'final'"
154+
155+
# Apply format if requested
122156
if format is not None:
123157
step, next_context = await mfuncs.aact(
124158
action=ReactThought(),
@@ -132,7 +166,9 @@ async def react(
132166
)
133167
assert isinstance(next_context, ChatContext)
134168
context = next_context
135-
169+
else:
170+
# The tool has already been called above.
171+
step._underlying_value = str(tool_responses[0].content)
136172
return step, context
137173

138174
raise RuntimeError(f"could not complete react loop in {loop_budget} iterations")

0 commit comments

Comments
 (0)