Skip to content

Commit 074a046

Browse files
committed
adding typing to melleatools
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent 208ca9b commit 074a046

2 files changed

Lines changed: 62 additions & 31 deletions

File tree

mellea/backends/tools.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import re
1313
from collections import defaultdict
1414
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
15-
from typing import Any, Literal, overload
15+
from typing import Any, Generic, Literal, ParamSpec, TypeVar, overload
1616

1717
from pydantic import BaseModel, ConfigDict, Field
1818

@@ -22,16 +22,23 @@
2222
from ..core.base import AbstractMelleaTool
2323
from .model_options import ModelOption
2424

25+
P = ParamSpec("P")
26+
R = TypeVar("R")
2527

26-
class MelleaTool(AbstractMelleaTool):
28+
29+
class MelleaTool(AbstractMelleaTool[P, R]):
2730
"""Tool class to represent a callable tool with an OpenAI-compatible JSON schema.
2831
2932
Wraps a Python callable alongside its JSON schema representation so it can be
3033
registered with backends that support tool calling (OpenAI, Ollama, HuggingFace, etc.).
3134
35+
Type parameters:
36+
P: Parameter specification for the underlying callable
37+
R: Return type of the tool
38+
3239
Args:
3340
name (str): The tool name used for identification and lookup.
34-
tool_call (Callable): The underlying Python callable to invoke when the tool is run.
41+
tool_call (Callable[P, R]): The underlying Python callable to invoke when the tool is run.
3542
as_json_tool (dict[str, Any]): The OpenAI-compatible JSON schema dict describing
3643
the tool's parameters.
3744
@@ -42,25 +49,25 @@ class MelleaTool(AbstractMelleaTool):
4249

4350
name: str
4451
_as_json_tool: dict[str, Any]
45-
_call_tool: Callable[..., Any]
52+
_call_tool: Callable[P, R]
4653

4754
def __init__(
48-
self, name: str, tool_call: Callable, as_json_tool: dict[str, Any]
55+
self, name: str, tool_call: Callable[P, R], as_json_tool: dict[str, Any]
4956
) -> None:
5057
"""Initialize the tool with a name, tool call and as_json_tool dict."""
5158
self.name = name
5259
self._as_json_tool = as_json_tool
5360
self._call_tool = tool_call
5461

55-
def run(self, *args, **kwargs) -> Any:
62+
def run(self, *args: P.args, **kwargs: P.kwargs) -> R:
5663
"""Run the tool with the given arguments.
5764
5865
Args:
59-
args: Positional arguments forwarded to the underlying callable.
60-
kwargs: Keyword arguments forwarded to the underlying callable.
66+
*args: Positional arguments forwarded to the underlying callable.
67+
**kwargs: Keyword arguments forwarded to the underlying callable.
6168
6269
Returns:
63-
Any: The return value of the underlying callable.
70+
R: The return value of the underlying callable.
6471
"""
6572
return self._call_tool(*args, **kwargs)
6673

@@ -70,14 +77,14 @@ def as_json_tool(self) -> dict[str, Any]:
7077
return self._as_json_tool.copy()
7178

7279
@classmethod
73-
def from_langchain(cls, tool: Any) -> "MelleaTool":
80+
def from_langchain(cls, tool: Any) -> "MelleaTool[P, Any]":
7481
"""Create a MelleaTool from a LangChain tool object.
7582
7683
Args:
7784
tool (Any): A ``langchain_core.tools.BaseTool`` instance to wrap.
7885
7986
Returns:
80-
MelleaTool: A Mellea tool wrapping the LangChain tool.
87+
MelleaTool[P, Any]: A Mellea tool wrapping the LangChain tool.
8188
8289
Raises:
8390
ImportError: If ``langchain-core`` is not installed.
@@ -117,14 +124,14 @@ def parameter_remapper(*args, **kwargs):
117124
) from e
118125

119126
@classmethod
120-
def from_smolagents(cls, tool: Any) -> "MelleaTool":
127+
def from_smolagents(cls, tool: Any) -> "MelleaTool[P, Any]":
121128
"""Create a Tool from a HuggingFace smolagents tool object.
122129
123130
Args:
124131
tool: A smolagents.Tool instance
125132
126133
Returns:
127-
MelleaTool: A Mellea tool wrapping the smolagents tool
134+
MelleaTool[P, Any]: A Mellea tool wrapping the smolagents tool
128135
129136
Raises:
130137
ImportError: If smolagents is not installed
@@ -172,18 +179,20 @@ def tool_call(*args, **kwargs):
172179
) from e
173180

174181
@classmethod
175-
def from_callable(cls, func: Callable, name: str | None = None) -> "MelleaTool":
182+
def from_callable(
183+
cls, func: Callable[P, R], name: str | None = None
184+
) -> "MelleaTool[P, R]":
176185
"""Create a MelleaTool from a plain Python callable.
177186
178187
Introspects the callable's signature and docstring to build an
179188
OpenAI-compatible JSON schema automatically.
180189
181190
Args:
182-
func (Callable): The Python callable to wrap as a tool.
191+
func (Callable[P, R]): The Python callable to wrap as a tool.
183192
name (str | None): Optional name override; defaults to ``func.__name__``.
184193
185194
Returns:
186-
MelleaTool: A Mellea tool wrapping the callable.
195+
MelleaTool[P, R]: A Mellea tool wrapping the callable with preserved parameter and return types.
187196
"""
188197
# Use the function name if the name is '' or None.
189198
tool_name = name or func.__name__
@@ -195,28 +204,34 @@ def from_callable(cls, func: Callable, name: str | None = None) -> "MelleaTool":
195204

196205

197206
@overload
198-
def tool(func: Callable, *, name: str | None = None) -> MelleaTool: ...
207+
def tool(func: Callable[P, R], *, name: str | None = None) -> MelleaTool[P, R]: ...
199208

200209

201210
@overload
202-
def tool(*, name: str | None = None) -> Callable[[Callable], MelleaTool]: ...
211+
def tool(
212+
*, name: str | None = None
213+
) -> Callable[[Callable[P, R]], MelleaTool[P, R]]: ...
203214

204215

205216
def tool(
206-
func: Callable | None = None, name: str | None = None
207-
) -> MelleaTool | Callable[[Callable], MelleaTool]:
208-
"""Decorator to mark a function as a Mellea tool.
217+
func: Callable[P, R] | None = None, name: str | None = None
218+
) -> MelleaTool[P, R] | Callable[[Callable[P, R]], MelleaTool[P, R]]:
219+
"""Decorator to mark a function as a Mellea tool with type-safe parameter and return types.
209220
210221
This decorator wraps a function to make it usable as a tool without
211222
requiring explicit MelleaTool.from_callable() calls. The decorated
212223
function returns a MelleaTool instance that must be called via .run().
213224
225+
Type parameters:
226+
P: Parameter specification of the decorated function
227+
R: Return type of the decorated function
228+
214229
Args:
215230
func: The function to decorate (when used without arguments)
216231
name: Optional custom name for the tool (defaults to function name)
217232
218233
Returns:
219-
A MelleaTool instance. Use .run() to invoke the tool.
234+
A MelleaTool[P, R] instance with preserved parameter and return types. Use .run() to invoke.
220235
The returned object passes isinstance(result, MelleaTool) checks.
221236
222237
Examples:
@@ -237,8 +252,8 @@ def tool(
237252
>>> # Can be used directly in tools list (no extraction needed)
238253
>>> tools = [get_weather]
239254
>>>
240-
>>> # Must use .run() to invoke the tool
241-
>>> result = get_weather.run(location="Boston")
255+
>>> # Must use .run() to invoke the tool - now with type hints
256+
>>> result = get_weather.run(location="Boston") # IDE shows: location: str, days: int = 1
242257
243258
With custom name (as decorator):
244259
>>> @tool(name="weather_api")
@@ -252,8 +267,8 @@ def tool(
252267
>>> differently_named_tool = tool(new_tool, name="different_name")
253268
"""
254269

255-
def decorator(f: Callable) -> MelleaTool:
256-
# Simply return the base MelleaTool instance
270+
def decorator(f: Callable[P, R]) -> MelleaTool[P, R]:
271+
# Simply return the base MelleaTool instance with preserved types
257272
return MelleaTool.from_callable(f, name=name)
258273

259274
# Handle both @tool and @tool() syntax

mellea/core/base.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,15 @@
2121
from copy import copy, deepcopy
2222
from dataclasses import dataclass
2323
from io import BytesIO
24-
from typing import Any, Generic, Literal, Protocol, TypeVar, runtime_checkable
24+
from typing import (
25+
Any,
26+
Generic,
27+
Literal,
28+
ParamSpec,
29+
Protocol,
30+
TypeVar,
31+
runtime_checkable,
32+
)
2533

2634
import typing_extensions
2735
from PIL import Image as PILImage
@@ -947,8 +955,16 @@ def view_for_generation(self) -> list[Component | CBlock] | None:
947955
...
948956

949957

950-
class AbstractMelleaTool(abc.ABC):
951-
"""Abstract base class for Mellea Tool.
958+
P = ParamSpec("P")
959+
R = TypeVar("R")
960+
961+
962+
class AbstractMelleaTool(abc.ABC, Generic[P, R]):
963+
"""Abstract base class for Mellea Tool with parameter and return type support.
964+
965+
Type parameters:
966+
P: Parameter specification for the tool's callable (via ParamSpec)
967+
R: Return type of the tool
952968
953969
Attributes:
954970
name (str): The unique name used to identify the tool in JSON descriptions and tool-call dispatch.
@@ -960,15 +976,15 @@ class AbstractMelleaTool(abc.ABC):
960976
"""Name of the tool."""
961977

962978
@abc.abstractmethod
963-
def run(self, *args: Any, **kwargs: Any) -> Any:
979+
def run(self, *args: P.args, **kwargs: P.kwargs) -> R:
964980
"""Executes the tool with the provided arguments and returns the result.
965981
966982
Args:
967983
*args: Positional arguments forwarded to the tool implementation.
968984
**kwargs: Keyword arguments forwarded to the tool implementation.
969985
970986
Returns:
971-
Any: The result produced by the tool; the concrete type depends on the implementation.
987+
R: The result produced by the tool; the concrete type depends on the implementation.
972988
"""
973989

974990
@property

0 commit comments

Comments
 (0)