Skip to content
64 changes: 63 additions & 1 deletion docs/examples/agents/react/react_using_mellea.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain_community.tools import DuckDuckGoSearchResults

from mellea.backends.tools import MelleaTool
from mellea.stdlib import functional as mfuncs
from mellea.stdlib.context import ChatContext
from mellea.stdlib.frameworks.react import react
from mellea.stdlib.session import start_session
Expand All @@ -28,15 +29,75 @@ class Email(pydantic.BaseModel):
body: str


class TrueOrFalse(pydantic.BaseModel):
    """Structured boolean reply used by the ReACT completion checks.

    Passed as ``format=`` to ``mfuncs.achat`` so the reply content can be
    parsed back with ``TrueOrFalse.model_validate_json``.
    """

    # The field description is part of the JSON schema the model sees,
    # so it doubles as the instruction for what True/False mean here.
    answer: bool = pydantic.Field(
        description="True if you have enough information to answer the user's question, False if you need more tool calls"
    )


async def last_loop_completion_check(
    goal, context, backend, step, model_options, turn_num, loop_budget
):
    """Completion check that queries the model only on the final iteration.

    Skips the (comparatively expensive) LLM call on every iteration except
    the last budgeted one; an unlimited budget (-1) never triggers the check.

    Note: step.value is guaranteed to exist when this is called.
    """
    # De Morgan of the skip condition: proceed only on the final budgeted turn.
    at_final_turn = loop_budget != -1 and turn_num >= loop_budget
    if not at_final_turn:
        return False

    reply, _ = await mfuncs.achat(
        content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
        context=context,
        backend=backend,
        format=TrueOrFalse,
    )
    # format=TrueOrFalse constrains the reply, so the content parses as its JSON.
    return TrueOrFalse.model_validate_json(reply.content).answer


async def custom_completion_check(
    goal, context, backend, step, model_options, turn_num, loop_budget
):
    """Completion check: keyword early-exit plus a last-turn LLM fallback.

    Runs on every iteration:
    1. Accepts immediately if the output contains "final answer".
    2. Otherwise, on the final budgeted iteration only, delegates to
       ``last_loop_completion_check`` to ask the model directly.

    Note: step.value is guaranteed to exist when this is called.
    """
    lowered_output = step.value.lower()

    # Early termination: the model signalled completion in plain text.
    if "final answer" in lowered_output:
        return True

    # Final budgeted turn (never reached for an unlimited budget of -1):
    # fall back to asking the model whether it can answer now.
    at_final_turn = loop_budget != -1 and turn_num >= loop_budget
    if at_final_turn:
        return await last_loop_completion_check(
            goal, context, backend, step, model_options, turn_num, loop_budget
        )

    return False


async def main():
"""Example."""
# Simple version that just searches for an answer.
# Version with custom answer check that terminates early
# when the model says "final answer" and queries the LLM
# if it reaches the loop_budget.
out, _ = await react(
goal="What is the Mellea python library?",
context=ChatContext(),
backend=m.backend,
tools=[search_tool],
loop_budget=12,
answer_check=custom_completion_check,
)
print(out)

Expand All @@ -47,6 +108,7 @@ async def main():
# backend=m.backend,
# tools=[search_tool],
# format=Email,
# answer_check=custom_completion_check,
# loop_budget=20,
# )
# print(out)
Expand Down
71 changes: 70 additions & 1 deletion mellea/stdlib/frameworks/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
history tracking. Raises ``RuntimeError`` if the loop ends without a final answer.
"""

from collections.abc import Awaitable, Callable
from typing import Protocol

# from PIL import Image as PILImage
from mellea.backends.model_options import ModelOption
from mellea.core.backend import Backend, BaseModelSubclass
Expand All @@ -24,6 +27,42 @@
from mellea.stdlib.context import ChatContext


class AnswerCheckCallable(Protocol):
    """Protocol for answer validation callbacks in the ReACT loop.

    Called each iteration when no tool calls are made and the model output is non-empty.
    Allows custom logic to determine if the current output should be accepted as the
    final answer.
    """

    async def __call__(
        self,
        goal: str,
        context: ChatContext,
        backend: Backend,
        step: ComputedModelOutputThunk[str],
        model_options: dict | None,
        turn_num: int,
        loop_budget: int,
    ) -> bool:
        """Validate whether the current step should be accepted as the final answer.

        Args:
            goal: The original goal or question to accomplish.
            context: The current conversation context.
            backend: The backend being used for generation.
            step: The current model output thunk containing the potential answer.
            model_options: Model options in effect for this generation.
            turn_num: Current iteration number (0-indexed).
                NOTE(review): the bundled example checks compare
                ``turn_num >= loop_budget`` to detect the final turn — confirm
                whether the loop passes a 0- or 1-indexed value on that turn,
                since a strictly 0-indexed counter would never reach loop_budget.
            loop_budget: Maximum allowed iterations (-1 means unlimited).

        Returns:
            True to accept step.value as the final answer and exit the loop;
            False to continue iterating.
        """
        ...


async def react(
goal: str,
context: ChatContext,
Expand All @@ -36,6 +75,7 @@ async def react(
model_options: dict | None = None,
tools: list[AbstractMelleaTool] | None,
loop_budget: int = 10,
answer_check: AnswerCheckCallable | None = None,
) -> tuple[ComputedModelOutputThunk[str], ChatContext]:
"""Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools.

Expand All @@ -47,9 +87,16 @@ async def react(
model_options: additional model options, which will upsert into the model/backend's defaults.
tools: the list of tools to use
loop_budget: the number of steps allowed; use -1 for unlimited
answer_check: Optional async callable invoked each iteration when no tool
calls are made and ``step.value`` is non-empty. Receives ``goal`` (str),
``step`` (ComputedModelOutputThunk[str]), ``context`` (ChatContext),
``backend`` (Backend), ``model_options`` (dict | None), ``turn_num`` (int),
and ``loop_budget`` (int). Return ``True`` to accept ``step.value`` as the
final answer, ``False`` to continue. If ``None``, the loop runs until
``final_answer`` is called or ``loop_budget`` is exhausted.

Returns:
A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
Tuple of ``(ComputedModelOutputThunk[str], ChatContext)`` — the final answer thunk and updated context.

Raises:
RuntimeError: if the loop ends before a final answer is found
Expand Down Expand Up @@ -107,9 +154,31 @@ async def react(
if tool_res.name == MELLEA_FINALIZER_TOOL:
is_final = True

# Check if the agent has completed its task (runs every iteration if answer_check is provided and there's a value)
# The answer_check function can decide when to actually check based on turn_num and loop_budget
elif not is_final and answer_check and step.value:
have_answer = await answer_check(
goal, context, backend, step, model_options, turn_num, loop_budget
)

if have_answer:
# Create a synthetic finalizer tool response to be consistent with normal loop
finalizer_response = ToolMessage(
role="tool",
content=step.value or "",
tool_output=step.value or "",
name=MELLEA_FINALIZER_TOOL,
args={},
tool=None, # type: ignore
)
tool_responses = [finalizer_response]
context = context.add(finalizer_response)
is_final = True

if is_final:
assert len(tool_responses) == 1, "multiple tools were called with 'final'"

# Apply format if requested
if format is not None:
step, next_context = await mfuncs.aact(
action=ReactThought(),
Expand Down
107 changes: 107 additions & 0 deletions test/stdlib/frameworks/test_react_direct_answer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Test ReACT framework handling of direct answers without tool calls."""
Comment thread
markstur marked this conversation as resolved.

import pydantic
import pytest

from mellea.backends.tools import tool
from mellea.stdlib import functional as mfuncs
from mellea.stdlib.context import ChatContext
from mellea.stdlib.frameworks.react import react
from mellea.stdlib.session import start_session


class TrueOrFalse(pydantic.BaseModel):
    """Structured boolean reply used by the test's completion check.

    Passed as ``format=`` to the chat call so the reply can be parsed back
    with ``TrueOrFalse.model_validate_json``.
    """

    # The field description is part of the JSON schema the model sees,
    # so it doubles as the instruction for what True/False mean here.
    answer: bool = pydantic.Field(
        description="True if you have enough information to answer the user's question, False if you need more tool calls"
    )


async def last_loop_completion_check(
    goal, context, backend, step, model_options, turn_num, loop_budget
):
    """Completion check that asks the model if it has the answer on the last iteration.

    Only runs its LLM query on the final budgeted iteration; all earlier
    iterations (and an unlimited budget of -1) immediately return False.

    Note: step.value is guaranteed to exist when this is called.

    Returns:
        True if the model reports it can answer the original goal, else False.
    """
    # Only check on last iteration (and not for unlimited budget)
    if loop_budget == -1 or turn_num < loop_budget:
        return False

    # Use the async chat API so this coroutine does not block the event loop;
    # this matches the achat-based pattern used by the documented example
    # (the previous synchronous mfuncs.chat call was a blocking call inside
    # an async def).
    message, _ = await mfuncs.achat(
        content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
        context=context,
        backend=backend,
        format=TrueOrFalse,
    )
    return TrueOrFalse.model_validate_json(message.content).answer


@pytest.mark.ollama
@pytest.mark.e2e
@pytest.mark.qualitative
async def test_react_direct_answer_without_tools():
    """ReACT terminates on a direct answer when the model makes no tool calls.

    The model answers in step.value without any tool calls; the answer_check
    hook lets the loop accept that answer instead of spinning until
    loop_budget is exhausted.
    """
    session = start_session()

    # No tools at all: finishing early is only possible via answer_check.
    answer, _ = await react(
        goal="What is 2 + 2?",
        context=ChatContext(),
        backend=session.backend,
        tools=[],  # No tools provided
        loop_budget=3,  # Should complete in 1 iteration, not exhaust budget
        answer_check=last_loop_completion_check,
    )

    # A non-empty answer must have been produced.
    assert answer.value is not None
    assert len(answer.value) > 0

    # The answer should actually contain the arithmetic result.
    lowered = answer.value.lower()
    assert "4" in lowered or "four" in lowered


@pytest.mark.ollama
@pytest.mark.e2e
@pytest.mark.qualitative
async def test_react_direct_answer_with_unused_tools():
    """ReACT can answer directly even when tools are available.

    The model is given a search tool it does not need and should answer
    without invoking it.
    """
    session = start_session()

    # Dummy tool; the question below should not require it.
    # (Name, docstring, and return value are visible to the model,
    # so they are kept exactly as-is.)
    @tool
    def search_web(query: str) -> str:
        """Search the web for information."""
        return "Search results"

    answer, _ = await react(
        goal="What is the capital of France?",
        context=ChatContext(),
        backend=session.backend,
        tools=[search_web],
        loop_budget=3,
        answer_check=last_loop_completion_check,
    )

    # A non-empty answer must have been produced.
    assert answer.value is not None
    assert len(answer.value) > 0

    # The answer should mention Paris.
    assert "paris" in answer.value.lower()
59 changes: 59 additions & 0 deletions test/stdlib/frameworks/test_react_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,64 @@ async def test_react_rejects_non_chat_context():
await react(goal="g", context=Mock(), backend=Mock(), tools=None)


@pytest.mark.asyncio
async def test_react_answer_check_terminates_on_direct_response():
    """answer_check returning True on a no-tool-call turn exits the loop."""
    # Single scripted turn: plain text output, no tool calls.
    scripted_backend = ScriptedBackend([_ScriptedTurn(value="42")])

    async def accept_everything(
        goal, context, backend, step, model_options, turn_num, loop_budget
    ):
        # Accept the very first no-tool-call output as the final answer.
        return True

    answer, _ = await react(
        goal="answer",
        context=ChatContext(),
        backend=scripted_backend,
        tools=None,
        loop_budget=5,
        answer_check=accept_everything,
    )

    assert answer.value == "42"


@pytest.mark.asyncio
async def test_react_answer_check_continues_when_false():
    """answer_check returning False on a no-tool-call turn continues the loop."""
    scripted_backend = ScriptedBackend(
        [
            # Turn 1: plain text with no tool calls; answer_check rejects it.
            _ScriptedTurn(value="thinking..."),
            # Turn 2: the model invokes the final_answer tool itself.
            _final_answer_call("42"),
        ]
    )

    invocations = 0

    async def reject_and_count(
        goal, context, backend, step, model_options, turn_num, loop_budget
    ):
        nonlocal invocations
        invocations += 1
        # Always reject: the second turn ends via the finalizer tool instead,
        # so this callback is never consulted again.
        return False

    answer, _ = await react(
        goal="answer",
        context=ChatContext(),
        backend=scripted_backend,
        tools=None,
        loop_budget=5,
        answer_check=reject_and_count,
    )

    # The hook ran exactly once (turn 1, returned False) and the loop continued.
    assert invocations == 1
    # The final answer came from the second turn's finalizer call.
    assert answer.value == "42"


# Allow running this test module directly (python test_react_framework.py)
# in addition to the usual pytest collection.
if __name__ == "__main__":
    pytest.main([__file__])
Loading