diff --git a/docs/docs/concepts/plugins.mdx b/docs/docs/concepts/plugins.mdx index 44d4530de..effca689c 100644 --- a/docs/docs/concepts/plugins.mdx +++ b/docs/docs/concepts/plugins.mdx @@ -658,7 +658,7 @@ async def cap_tokens(payload, ctx): **Fires:** Before invoking a tool from LLM output. -**Payload fields:** `model_tool_call` (contains `name`, `args`, `callable`) +**Payload fields:** `model_tool_call` (contains `name`, `args`, `callable`), `is_control_flow` **Writable fields:** `model_tool_call` @@ -674,11 +674,15 @@ async def enforce_tool_allowlist(payload, ctx): return block(f"Tool '{payload.model_tool_call.name}' not permitted", code="TOOL_NOT_ALLOWED") ``` + +The payload includes an `is_control_flow` field that is `True` for framework control-flow tools (e.g. the ReAct loop's `final_answer`). Allowlist plugins should check this field to avoid blocking internal tools. See [Control-flow tools](#control-flow-tools) for the recommended pattern. + + #### `tool_post_invoke` **Fires:** After tool execution completes. -**Payload fields:** `model_tool_call`, `tool_output`, `tool_message`, `execution_time_ms`, `success`, `error` +**Payload fields:** `model_tool_call`, `tool_output`, `tool_message`, `execution_time_ms`, `success`, `error`, `is_control_flow` **Writable fields:** `tool_output` @@ -810,13 +814,15 @@ The `tool_pre_invoke` and `tool_post_invoke` hooks give you fine-grained control ### Tool allow-listing -Block any tool not on an explicit approved list: +Block any tool not on an explicit approved list. The `is_control_flow` guard ensures framework tools like `final_answer` are not blocked: ```python ALLOWED_TOOLS = frozenset({"get_weather", "calculator"}) @hook(HookType.TOOL_PRE_INVOKE, mode=PluginMode.CONCURRENT, priority=5) async def enforce_tool_allowlist(payload, ctx): + if payload.is_control_flow: + return # framework control-flow tools are exempt tool_name = payload.model_tool_call.name if tool_name not in ALLOWED_TOOLS: return block(f"Tool '{tool_name}' is not permitted", code="TOOL_NOT_ALLOWED") @@ -878,6 +884,43 @@ with start_session(plugins=[tool_security]) as m: See the [full tool hooks example](https://github.com/generative-computing/mellea/blob/main/docs/examples/plugins/tool_hooks.py). +### Control-flow tools + +Mellea's frameworks use internal tools for control flow. For example, the [ReAct loop](../reference/glossary#react) uses a `final_answer` tool to signal that the agent has finished reasoning. These tools flow through the same invocation path as user-defined tools — hooks always fire for them — but the payload carries an `is_control_flow` flag so each plugin can decide its own policy. + +The recommended pattern for allowlist plugins is to skip control-flow tools explicitly: + +```python +ALLOWED_TOOLS = frozenset({"get_weather", "calculator"}) + +@hook(HookType.TOOL_PRE_INVOKE, mode=PluginMode.CONCURRENT, priority=5) +async def enforce_tool_allowlist(payload, ctx): + if payload.is_control_flow: + return # framework control-flow tools are exempt + if payload.model_tool_call.name not in ALLOWED_TOOLS: + return block(f"Tool '{payload.model_tool_call.name}' not permitted") +``` + +Logging and telemetry plugins typically do **not** check this flag — they observe all tool calls including control-flow tools: + +```python +@hook(HookType.TOOL_POST_INVOKE, mode=PluginMode.FIRE_AND_FORGET) +async def log_all_tools(payload, ctx): + logger.info("tool=%s control_flow=%s ms=%d", payload.model_tool_call.name, + payload.is_control_flow, payload.execution_time_ms) +``` + +#### Querying the registry + +Use `is_internal_tool()` to check whether a tool name is a known control-flow tool: + +```python +from mellea.plugins import is_internal_tool + +is_internal_tool("final_answer") # True +is_internal_tool("get_weather") # False +``` + --- ## Patterns and best practices @@ -971,18 +1014,19 @@ All public symbols are available from a single import: ```python from mellea.plugins import ( - HookType, # Enum of all hook types (e.g., GENERATION_PRE_CALL) - Plugin, # Base class for class-based plugins - PluginMode, # Execution mode enum (SEQUENTIAL, TRANSFORM, AUDIT, CONCURRENT, FIRE_AND_FORGET) - PluginResult, # Return type for hooks that modify or block - PluginSet, # Named group of hooks/plugins for composition - PluginViolationError,# Exception raised when a hook blocks execution - block, # Helper to create a blocking PluginResult - hook, # Decorator to register an async function as a hook handler - modify, # Helper to create a modifying PluginResult - plugin_scope, # Context manager for with-block scoped activation - register, # Register hooks/plugins globally or per-session - unregister, # Remove globally-registered hooks/plugins + HookType, # Enum of all hook types (e.g., GENERATION_PRE_CALL) + Plugin, # Base class for class-based plugins + PluginMode, # Execution mode enum (SEQUENTIAL, TRANSFORM, ...) + PluginResult, # Return type for hooks that modify or block + PluginSet, # Named group of hooks/plugins for composition + PluginViolationError, # Exception raised when a hook blocks execution + block, # Helper to create a blocking PluginResult + hook, # Decorator to register an async function as a hook handler + is_internal_tool, # Check if a tool is a framework control-flow tool + modify, # Helper to create a modifying PluginResult + plugin_scope, # Context manager for with-block scoped activation + register, # Register hooks/plugins globally or per-session + unregister, # Remove globally-registered hooks/plugins ) ``` @@ -996,6 +1040,7 @@ from mellea.plugins import ( | `plugin_scope(*items)` | Context manager that registers on enter, deregisters on exit | | `block(reason, *, code, details)` | Create a blocking `PluginResult` | | `modify(payload, **field_updates)` | Create a modifying `PluginResult` via `model_copy` | +| `is_internal_tool(tool_name)` | Returns `True` if the tool is a framework control-flow tool (e.g. `final_answer`) | | `HookType` | Enum with all 18 hook types | | `PluginMode` | Enum: `SEQUENTIAL`, `TRANSFORM`, `AUDIT`, `CONCURRENT`, `FIRE_AND_FORGET` | | `PluginResult` | Typed result with `continue_processing`, `modified_payload`, and `violation` | diff --git a/docs/examples/plugins/tool_hooks.py b/docs/examples/plugins/tool_hooks.py index 3b574ba75..96ec518ba 100644 --- a/docs/examples/plugins/tool_hooks.py +++ b/docs/examples/plugins/tool_hooks.py @@ -150,6 +150,8 @@ def parse_factor(): @hook(HookType.TOOL_PRE_INVOKE, mode=PluginMode.CONCURRENT, priority=5) async def enforce_tool_allowlist(payload, _): """Block any tool not on the explicit allow list.""" + if payload.is_control_flow: + return # framework control-flow tools (e.g. final_answer) are exempt tool_name = payload.model_tool_call.name if tool_name not in ALLOWED_TOOLS: log.warning( diff --git a/mellea/plugins/__init__.py b/mellea/plugins/__init__.py index f5433fdd8..c4a7aa6aa 100644 --- a/mellea/plugins/__init__.py +++ b/mellea/plugins/__init__.py @@ -9,6 +9,7 @@ from .base import Plugin, PluginResult, PluginViolationError from .decorators import hook +from .manager import is_internal_tool from .pluginset import PluginSet from .registry import block, modify, plugin_scope, register, unregister from .types import HookType, PluginMode @@ -22,6 +23,7 @@ "PluginViolationError", "block", "hook", + "is_internal_tool", "modify", "plugin_scope", "register", diff --git a/mellea/plugins/hooks/tool.py b/mellea/plugins/hooks/tool.py index bc92d430e..8bba52ad3 100644 --- a/mellea/plugins/hooks/tool.py +++ b/mellea/plugins/hooks/tool.py @@ -13,9 +13,13 @@ class ToolPreInvokePayload(MelleaBasePayload): Attributes: model_tool_call: The ``ModelToolCall`` about to be executed (writable — plugins may modify arguments or swap the tool entirely). + is_control_flow: ``True`` when this tool is used for framework control + flow (e.g. ``final_answer`` in ReAct) rather than data processing. + Plugins should check this field to decide whether to act. """ model_tool_call: Any = None + is_control_flow: bool = False class ToolPostInvokePayload(MelleaBasePayload): @@ -29,6 +33,9 @@ class ToolPostInvokePayload(MelleaBasePayload): execution_time_ms: Wall-clock time of the tool execution in milliseconds. success: ``True`` if the tool executed without raising an exception. error: The ``Exception`` raised during execution, or ``None`` on success. + is_control_flow: ``True`` when this tool is used for framework control + flow (e.g. ``final_answer`` in ReAct) rather than data processing. + Plugins should check this field to decide whether to act. """ model_tool_call: Any = None @@ -37,3 +44,4 @@ class ToolPostInvokePayload(MelleaBasePayload): execution_time_ms: int = 0 success: bool = True error: Any = None + is_control_flow: bool = False diff --git a/mellea/plugins/manager.py b/mellea/plugins/manager.py index e196a0a25..ce2859d38 100644 --- a/mellea/plugins/manager.py +++ b/mellea/plugins/manager.py @@ -29,6 +29,10 @@ _pending_background_results: list[Any] = [] _collect_background_results: bool = False # opt-in; only tests enable this +# Framework control-flow tool names (e.g. loop terminators). +# These are flagged on the payload so plugins can decide per-tool policy. +_INTERNAL_TOOL_NAMES: frozenset[str] = frozenset({"final_answer"}) + DEFAULT_PLUGIN_TIMEOUT: int = 5 # seconds DEFAULT_HOOK_POLICY: Literal["allow"] | Literal["deny"] = "deny" @@ -88,6 +92,18 @@ def has_plugins(hook_type: HookType | None = None) -> bool: return True +def is_internal_tool(tool_name: str) -> bool: + """Return whether the given tool name is a framework-internal tool. + + Args: + tool_name: Name of the tool to check. + + Returns: + ``True`` if the tool is in the internal tools registry. + """ + return tool_name in _INTERNAL_TOOL_NAMES + + def get_plugin_manager() -> Any | None: """Return the initialized PluginManager, or ``None`` if plugins are not configured. diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index 849a72cf8..a281e7c3f 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -30,7 +30,7 @@ ) from ..helpers import _run_async_in_thread from ..plugins.hooks.tool import ToolPostInvokePayload, ToolPreInvokePayload -from ..plugins.manager import has_plugins, invoke_hook +from ..plugins.manager import has_plugins, invoke_hook, is_internal_tool from ..plugins.types import HookType from ..telemetry import set_span_attribute, trace_application from .components import ( @@ -1287,9 +1287,13 @@ async def _acall_tools(result: ModelOutputThunk, backend: Backend) -> list[ToolM return outputs for name, tool in tool_calls.items(): + control_flow = is_internal_tool(name) + # --- tool_pre_invoke --- if has_plugins(HookType.TOOL_PRE_INVOKE): - pre_payload = ToolPreInvokePayload(model_tool_call=tool) + pre_payload = ToolPreInvokePayload( + model_tool_call=tool, is_control_flow=control_flow + ) _, pre_payload = await invoke_hook( HookType.TOOL_PRE_INVOKE, pre_payload, backend=backend ) @@ -1335,6 +1339,7 @@ async def _acall_tools(result: ModelOutputThunk, backend: Backend) -> list[ToolM execution_time_ms=latency_ms, success=success, error=error, + is_control_flow=control_flow, ) _, post_payload = await invoke_hook( HookType.TOOL_POST_INVOKE, post_payload, backend=backend diff --git a/pyproject.toml b/pyproject.toml index b13604a1b..e1b95cd8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ switch = [ backends = ["mellea[watsonx,hf,litellm]"] hooks = [ - "cpex>=0.1.0.dev12; python_version >= '3.11'", + "cpex>=0.1.0rc1", "grpcio>=1.78.0", ] diff --git a/test/plugins/test_internal_tool_hook_skip.py b/test/plugins/test_internal_tool_hook_skip.py new file mode 100644 index 000000000..6f7836d4e --- /dev/null +++ b/test/plugins/test_internal_tool_hook_skip.py @@ -0,0 +1,204 @@ +"""Tests for control-flow tool signalling on tool hook payloads. + +Verifies that TOOL_PRE_INVOKE and TOOL_POST_INVOKE hooks always fire for all +tools (including framework-internal ones like ``final_answer``), and that the +``is_control_flow`` field is correctly populated so plugins can decide their +own policy. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +pytest.importorskip("cpex.framework") + +from mellea.core.base import AbstractMelleaTool, ModelOutputThunk, ModelToolCall +from mellea.plugins import block, hook, is_internal_tool, register +from mellea.plugins.manager import shutdown_plugins +from mellea.plugins.types import HookType, PluginMode +from mellea.stdlib.functional import _acall_tools + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _RecordingTool(AbstractMelleaTool): + """A tool that records invocations.""" + + def __init__(self, name: str = "test_tool") -> None: + self.name = name + self.calls: list[dict[str, Any]] = [] + + def run(self, **kwargs: Any) -> str: + self.calls.append(dict(kwargs)) + return f"result from {self.name}" + + @property + def as_json_tool(self) -> dict[str, Any]: + return {"name": self.name, "description": "recording tool", "parameters": {}} + + +def _make_result(*tool_calls: ModelToolCall) -> ModelOutputThunk: + """Wrap one or more ModelToolCalls in a minimal ModelOutputThunk.""" + mot = MagicMock(spec=ModelOutputThunk) + mot.tool_calls = {tc.name: tc for tc in tool_calls} + return mot + + +# --------------------------------------------------------------------------- +# Tests — is_internal_tool +# --------------------------------------------------------------------------- + + +class TestIsInternalTool: + def test_recognizes_final_answer(self) -> None: + assert is_internal_tool("final_answer") is True + + def test_rejects_user_tool(self) -> None: + assert is_internal_tool("search") is False + assert is_internal_tool("get_weather") is False + + +# --------------------------------------------------------------------------- +# Tests — hooks always fire, payload carries is_control_flow +# --------------------------------------------------------------------------- + + +class TestControlFlowPayloadField: + async def test_pre_hook_fires_for_internal_tool(self) -> None: + """TOOL_PRE_INVOKE fires for final_answer with is_control_flow=True.""" + tool = _RecordingTool("final_answer") + tc = ModelToolCall(name="final_answer", func=tool, args={"answer": "42"}) + result = _make_result(tc) + + captured: list[Any] = [] + + @hook(HookType.TOOL_PRE_INVOKE) + async def spy(payload, *_): + captured.append(payload) + + register(spy) + + await _acall_tools(result, MagicMock()) + + assert len(captured) == 1 + assert captured[0].model_tool_call.name == "final_answer" + assert captured[0].is_control_flow is True + + async def test_post_hook_fires_for_internal_tool(self) -> None: + """TOOL_POST_INVOKE fires for final_answer with is_control_flow=True.""" + tool = _RecordingTool("final_answer") + tc = ModelToolCall(name="final_answer", func=tool, args={"answer": "42"}) + result = _make_result(tc) + + captured: list[Any] = [] + + @hook(HookType.TOOL_POST_INVOKE) + async def spy(payload, *_): + captured.append(payload) + + register(spy) + + await _acall_tools(result, MagicMock()) + + assert len(captured) == 1 + assert captured[0].is_control_flow is True + + async def test_user_tool_has_control_flow_false(self) -> None: + """User tools get is_control_flow=False.""" + tool = _RecordingTool("search") + tc = ModelToolCall(name="search", func=tool, args={}) + result = _make_result(tc) + + captured: list[Any] = [] + + @hook(HookType.TOOL_PRE_INVOKE) + async def spy(payload, *_): + captured.append(payload) + + register(spy) + + await _acall_tools(result, MagicMock()) + + assert len(captured) == 1 + assert captured[0].is_control_flow is False + + async def test_mixed_batch_sets_flag_correctly(self) -> None: + """In a batch, each tool gets the correct is_control_flow value.""" + internal_tool = _RecordingTool("final_answer") + user_tool = _RecordingTool("search") + tc_internal = ModelToolCall( + name="final_answer", func=internal_tool, args={"answer": "done"} + ) + tc_user = ModelToolCall(name="search", func=user_tool, args={}) + result = _make_result(tc_internal, tc_user) + + captured: list[Any] = [] + + @hook(HookType.TOOL_PRE_INVOKE) + async def spy(payload, *_): + captured.append(payload) + + register(spy) + + await _acall_tools(result, MagicMock()) + + assert len(captured) == 2 + by_name = {p.model_tool_call.name: p for p in captured} + assert by_name["final_answer"].is_control_flow is True + assert by_name["search"].is_control_flow is False + + +# --------------------------------------------------------------------------- +# Tests — plugin pattern: allowlist that skips control-flow tools +# --------------------------------------------------------------------------- + + +class TestAllowlistPluginPattern: + async def test_allowlist_does_not_block_control_flow_tool(self) -> None: + """An allowlist plugin using is_control_flow guard does not block final_answer.""" + allowed_tools = frozenset({"search"}) + + @hook(HookType.TOOL_PRE_INVOKE, mode=PluginMode.CONCURRENT, priority=5) + async def enforce_allowlist(payload, _): + if payload.is_control_flow: + return + if payload.model_tool_call.name not in allowed_tools: + return block(f"Tool '{payload.model_tool_call.name}' not permitted") + + register(enforce_allowlist) + + internal_tool = _RecordingTool("final_answer") + tc = ModelToolCall( + name="final_answer", func=internal_tool, args={"answer": "ok"} + ) + result = _make_result(tc) + + msgs = await _acall_tools(result, MagicMock()) + assert len(msgs) == 1 + + async def test_allowlist_still_blocks_unknown_user_tools(self) -> None: + """The allowlist pattern still blocks non-allowed user tools.""" + from mellea.plugins.base import PluginViolationError + + allowed_tools = frozenset({"search"}) + + @hook(HookType.TOOL_PRE_INVOKE, mode=PluginMode.CONCURRENT, priority=5) + async def enforce_allowlist(payload, _): + if payload.is_control_flow: + return + if payload.model_tool_call.name not in allowed_tools: + return block(f"Tool '{payload.model_tool_call.name}' not permitted") + + register(enforce_allowlist) + + unknown_tool = _RecordingTool("hack_system") + tc = ModelToolCall(name="hack_system", func=unknown_tool, args={}) + result = _make_result(tc) + + with pytest.raises(PluginViolationError, match="not permitted"): + await _acall_tools(result, MagicMock()) diff --git a/uv.lock b/uv.lock index 5908de1cb..1c1cb3558 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.15'", @@ -883,7 +883,7 @@ toml = [ [[package]] name = "cpex" -version = "0.1.0.dev12" +version = "0.1.0rc1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fastapi" }, @@ -899,9 +899,9 @@ dependencies = [ { name = "pydantic-settings" }, { name = "pyyaml" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0d/f6/d5a194b338b3d55b1b9b8619baafa504ae8146168cf4b91fcefa95811a16/cpex-0.1.0.dev12.tar.gz", hash = "sha256:9fb08e0fa27236747c26c841260951a83252029c0e55a7550c65a060473f200c", size = 3475629, upload-time = "2026-04-23T17:34:14.434Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/f1/608f5295bcd77a62ce520c91c76ba9fe07c50378e2600aeb7edbf80298c2/cpex-0.1.0rc1.tar.gz", hash = "sha256:36c8c85395073f5a8e828ab972b7a3eedc9f68066e6665473090947481319915", size = 1211891, upload-time = "2026-05-01T03:04:54.602Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/ed/f70537bd8adbf1f847a703e02c3abd2cdc53dfa87e44977aa25dd163774b/cpex-0.1.0.dev12-py3-none-any.whl", hash = "sha256:5c10688b6f7ca8c3673fce9dfd94d0b3a348e0e63566546ced5068574a38403e", size = 236654, upload-time = "2026-04-23T17:34:12.592Z" }, + { url = "https://files.pythonhosted.org/packages/6a/e5/a9f07008443e42eceedb433dfd683e480f648c2db4072190aa7ac4209c9d/cpex-0.1.0rc1-py3-none-any.whl", hash = "sha256:6aa9395af7792653dfa18815f11f9c34ebc5ac91284a47f40834ba3f507a7f21", size = 241048, upload-time = "2026-05-01T03:04:52.996Z" }, ] [[package]] @@ -3568,7 +3568,7 @@ typecheck = [ requires-dist = [ { name = "accelerate", marker = "extra == 'hf'", specifier = ">=1.9.0" }, { name = "boto3", marker = "extra == 'litellm'" }, - { name = "cpex", marker = "python_full_version >= '3.11' and extra == 'hooks'", specifier = ">=0.1.0.dev12" }, + { name = "cpex", marker = "extra == 'hooks'", specifier = ">=0.1.0rc1" }, { name = "datasets", marker = "extra == 'hf'", specifier = ">=4.0.0" }, { name = "docling", marker = "extra == 'docling'", specifier = ">=2.45.0" }, { name = "elasticsearch", marker = "extra == 'granite-retriever'", specifier = ">=8.0.0,<9.0.0" },