diff --git a/docs/examples/agents/react/react_using_mellea.py b/docs/examples/agents/react/react_using_mellea.py
index 422e8f3e5..aed1884e7 100644
--- a/docs/examples/agents/react/react_using_mellea.py
+++ b/docs/examples/agents/react/react_using_mellea.py
@@ -8,6 +8,7 @@ from langchain_community.tools import DuckDuckGoSearchResults
 from mellea.backends.tools import MelleaTool
+from mellea.stdlib import functional as mfuncs
 from mellea.stdlib.context import ChatContext
 from mellea.stdlib.frameworks.react import react
 from mellea.stdlib.session import start_session
 
 
@@ -28,15 +29,75 @@ class Email(pydantic.BaseModel):
     body: str
 
 
+class TrueOrFalse(pydantic.BaseModel):
+    """Response indicating whether the ReACT agent has completed its task."""
+
+    answer: bool = pydantic.Field(
+        description="True if you have enough information to answer the user's question, False if you need more tool calls"
+    )
+
+
+async def last_loop_completion_check(
+    goal, context, backend, step, model_options, turn_num, loop_budget
+):
+    """Completion check that asks the model whether it has the answer on the last iteration.
+
+    Only checks on the last iteration (when turn_num reaches loop_budget) to avoid
+    unnecessary LLM calls. Returns False on all other iterations.
+
+    Note: step.value is guaranteed to exist when this is called.
+    """
+    # Only check on the last iteration (and never for an unlimited budget)
+    if loop_budget == -1 or turn_num < loop_budget:
+        return False
+
+    message, _ = await mfuncs.achat(
+        content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
+        context=context,
+        backend=backend,
+        format=TrueOrFalse,
+    )
+    have_answer = TrueOrFalse.model_validate_json(message.content).answer
+
+    return have_answer
+
+
+async def custom_completion_check(
+    goal, context, backend, step, model_options, turn_num, loop_budget
+):
+    """Custom completion check combining keyword detection with a last-loop fallback.
+
+    This runs every iteration:
+    1. First checks whether the response contains "final answer" (early termination).
+    2. On the last iteration, falls back to asking the model if it has the answer.
+
+    Note: step.value is guaranteed to exist when this is called.
+    """
+    # Check every iteration for the "final answer" keyword (early termination)
+    if "final answer" in step.value.lower():
+        return True
+
+    # On the last iteration, fall back to asking the model if it has the answer
+    if loop_budget != -1 and turn_num >= loop_budget:
+        return await last_loop_completion_check(
+            goal, context, backend, step, model_options, turn_num, loop_budget
+        )
+
+    return False
+
+
 async def main():
     """Example."""
-    # Simple version that just searches for an answer.
+    # Version with a custom answer check: terminates early when the model
+    # says "final answer", and falls back to querying the LLM if the loop
+    # reaches its loop_budget.
     out, _ = await react(
         goal="What is the Mellea python library?",
         context=ChatContext(),
         backend=m.backend,
         tools=[search_tool],
         loop_budget=12,
+        answer_check=custom_completion_check,
     )
     print(out)
 
@@ -47,6 +108,7 @@ async def main():
     #     backend=m.backend,
     #     tools=[search_tool],
     #     format=Email,
+    #     answer_check=custom_completion_check,
     #     loop_budget=20,
     # )
     # print(out)
diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py
index c39338990..4fdebbc5c 100644
--- a/mellea/stdlib/frameworks/react.py
+++ b/mellea/stdlib/frameworks/react.py
@@ -7,6 +7,8 @@ history tracking.
 Raises ``RuntimeError`` if the loop ends without a final answer.
 """
 
+from typing import Protocol
+
 # from PIL import Image as PILImage
 from mellea.backends.model_options import ModelOption
 from mellea.core.backend import Backend, BaseModelSubclass
@@ -24,6 +27,42 @@ from mellea.stdlib.context import ChatContext
 
 
+class AnswerCheckCallable(Protocol):
+    """Protocol for answer-validation callbacks in the ReACT loop.
+
+    Called each iteration when no tool calls are made and the model output is
+    non-empty. Allows custom logic to determine whether the current output
+    should be accepted as the final answer.
+    """
+
+    async def __call__(
+        self,
+        goal: str,
+        context: ChatContext,
+        backend: Backend,
+        step: ComputedModelOutputThunk[str],
+        model_options: dict | None,
+        turn_num: int,
+        loop_budget: int,
+    ) -> bool:
+        """Validate whether the current step should be accepted as the final answer.
+
+        Args:
+            goal: The original goal or question to accomplish.
+            context: The current conversation context.
+            backend: The backend being used for generation.
+            step: The current model output thunk containing the potential answer.
+            model_options: Model options in effect for this generation.
+            turn_num: Current iteration number; equals ``loop_budget`` on the final iteration.
+            loop_budget: Maximum allowed iterations.
+
+        Returns:
+            True to accept step.value as the final answer and exit the loop;
+            False to continue iterating.
+        """
+        ...
+
+
 async def react(
     goal: str,
     context: ChatContext,
     backend: Backend,
@@ -36,6 +75,7 @@ async def react(
     model_options: dict | None = None,
     tools: list[AbstractMelleaTool] | None,
     loop_budget: int = 10,
+    answer_check: AnswerCheckCallable | None = None,
 ) -> tuple[ComputedModelOutputThunk[str], ChatContext]:
     """Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools.
 
@@ -47,9 +87,16 @@ async def react(
         model_options: additional model options, which will upsert into the model/backend's defaults.
         tools: the list of tools to use
         loop_budget: the number of steps allowed; use -1 for unlimited
+        answer_check: optional async callable invoked each iteration when no tool
+            calls are made and ``step.value`` is non-empty. Receives ``goal`` (str),
+            ``context`` (ChatContext), ``backend`` (Backend), ``step``
+            (ComputedModelOutputThunk[str]), ``model_options`` (dict | None),
+            ``turn_num`` (int), and ``loop_budget`` (int), in that order. Return
+            ``True`` to accept ``step.value`` as the final answer, ``False`` to
+            continue. If ``None``, the loop runs until ``final_answer`` is called
+            or ``loop_budget`` is exhausted.
 
     Returns:
-        A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
+        Tuple of ``(ComputedModelOutputThunk[str], ChatContext)``: the final answer thunk and the updated context.
 
     Raises:
         RuntimeError: if the loop ends before a final answer is found
@@ -107,9 +154,31 @@ async def react(
                 if tool_res.name == MELLEA_FINALIZER_TOOL:
                     is_final = True
 
+        # Runs every iteration when answer_check is provided and step.value is non-empty;
+        # answer_check decides when to actually check based on turn_num and loop_budget.
+        elif not is_final and answer_check and step.value:
+            have_answer = await answer_check(
+                goal, context, backend, step, model_options, turn_num, loop_budget
+            )
+
+            if have_answer:
+                # Create a synthetic finalizer tool response to match the normal exit path
+                finalizer_response = ToolMessage(
+                    role="tool",
+                    content=step.value or "",
+                    tool_output=step.value or "",
+                    name=MELLEA_FINALIZER_TOOL,
+                    args={},
+                    tool=None,  # type: ignore
+                )
+                tool_responses = [finalizer_response]
+                context = context.add(finalizer_response)
+                is_final = True
+
         if is_final:
             assert len(tool_responses) == 1, "multiple tools were called with 'final'"
 
+            # Apply format if requested
             if format is not None:
                 step, next_context = await mfuncs.aact(
                     action=ReactThought(),
diff --git a/test/stdlib/frameworks/test_react_direct_answer.py b/test/stdlib/frameworks/test_react_direct_answer.py
new file mode 100644
index 000000000..0024afd25
--- /dev/null
+++ b/test/stdlib/frameworks/test_react_direct_answer.py
@@ -0,0 +1,107 @@
+"""Test ReACT framework handling of direct answers without tool calls."""
+
+import pydantic
+import pytest
+
+from mellea.backends.tools import tool
+from mellea.stdlib import functional as mfuncs
+from mellea.stdlib.context import ChatContext
+from mellea.stdlib.frameworks.react import react
+from mellea.stdlib.session import start_session
+
+
+class TrueOrFalse(pydantic.BaseModel):
+    """Response indicating whether the ReACT agent has completed its task."""
+
+    answer: bool = pydantic.Field(
+        description="True if you have enough information to answer the user's question, False if you need more tool calls"
+    )
+
+
+async def last_loop_completion_check(
+    goal, context, backend, step, model_options, turn_num, loop_budget
+):
+    """Completion check that asks the model whether it has the answer on the last iteration.
+
+    Note: step.value is guaranteed to exist when this is called.
+    """
+    # Only check on the last iteration (and never for an unlimited budget)
+    if loop_budget == -1 or turn_num < loop_budget:
+        return False
+
+    message, _ = await mfuncs.achat(
+        content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
+        context=context,
+        backend=backend,
+        format=TrueOrFalse,
+    )
+    have_answer = TrueOrFalse.model_validate_json(message.content).answer
+    return have_answer
+
+
+@pytest.mark.asyncio
+@pytest.mark.ollama
+@pytest.mark.e2e
+@pytest.mark.qualitative
+async def test_react_direct_answer_without_tools():
+    """Test that ReACT handles direct answers when the model doesn't call tools.
+
+    This covers the case where the model provides a direct answer in step.value
+    without making any tool calls. The fix ensures the loop terminates with that
+    answer instead of raising RuntimeError when loop_budget is exhausted.
+ """ + m = start_session() + + # Ask a simple question that doesn't require tools + # The model should provide a direct answer without calling any tools + out, _ = await react( + goal="What is 2 + 2?", + context=ChatContext(), + backend=m.backend, + tools=[], # No tools provided + loop_budget=3, # Should complete in 1 iteration, not exhaust budget + answer_check=last_loop_completion_check, + ) + + # Verify we got an answer + assert out.value is not None + assert len(out.value) > 0 + + # The answer should contain "4" or "four" + answer_lower = out.value.lower() + assert "4" in answer_lower or "four" in answer_lower + + +@pytest.mark.ollama +@pytest.mark.e2e +@pytest.mark.qualitative +async def test_react_direct_answer_with_unused_tools(): + """Test that ReACT handles direct answers even when tools are available. + + This tests the case where tools are provided but the model chooses to + answer directly without using them. + """ + m = start_session() + + # Create a dummy tool that won't be needed + @tool + def search_web(query: str) -> str: + """Search the web for information.""" + return "Search results" + + # Ask a question that doesn't need the tool + out, _ = await react( + goal="What is the capital of France?", + context=ChatContext(), + backend=m.backend, + tools=[search_web], + loop_budget=3, + answer_check=last_loop_completion_check, + ) + + # Verify we got an answer + assert out.value is not None + assert len(out.value) > 0 + + # The answer should mention Paris + answer_lower = out.value.lower() + assert "paris" in answer_lower diff --git a/test/stdlib/frameworks/test_react_framework.py b/test/stdlib/frameworks/test_react_framework.py index e121a91f5..ab3d47022 100644 --- a/test/stdlib/frameworks/test_react_framework.py +++ b/test/stdlib/frameworks/test_react_framework.py @@ -231,5 +231,64 @@ async def test_react_rejects_non_chat_context(): await react(goal="g", context=Mock(), backend=Mock(), tools=None) +@pytest.mark.asyncio +async def test_react_answer_check_terminates_on_direct_response(): + """answer_check returning True on a no-tool-call turn exits the loop.""" + backend = ScriptedBackend([_ScriptedTurn(value="42")]) + + async def always_done( + goal, context, backend, step, model_options, turn_num, loop_budget + ): + return True + + result, _ = await react( + goal="answer", + context=ChatContext(), + backend=backend, + tools=None, + loop_budget=5, + answer_check=always_done, + ) + assert result.value == "42" + + +@pytest.mark.asyncio +async def test_react_answer_check_continues_when_false(): + """answer_check returning False on a no-tool-call turn continues the loop.""" + backend = ScriptedBackend( + [ + _ScriptedTurn( + value="thinking..." + ), # First turn, answer_check returns False + _final_answer_call("42"), # Second turn, model calls final_answer + ] + ) + + call_count = 0 + + async def check_on_second_call( + goal, context, backend, step, model_options, turn_num, loop_budget + ): + nonlocal call_count + call_count += 1 + # Return False on first call to test the continue branch + # On second call, model will use final_answer tool instead + return False + + result, _ = await react( + goal="answer", + context=ChatContext(), + backend=backend, + tools=None, + loop_budget=5, + answer_check=check_on_second_call, + ) + + # Verify answer_check was called once (returned False, loop continued) + assert call_count == 1 + # Verify we got the final answer from the second turn + assert result.value == "42" + + if __name__ == "__main__": pytest.main([__file__])