Skip to content

Commit e1bec38

Browse files
committed
fix tool-calling argument handling
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent cba83f5 commit e1bec38

2 files changed

Lines changed: 128 additions & 37 deletions

File tree

docs/examples/agents/react/react_using_mellea.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ async def main():
4141
print(out)
4242

4343
# Version that looks up info and formats the final response as an Email object.
44-
# out, _ = await react(
45-
# goal="Write an email about the Mellea python library to Jake with the subject 'cool library'.",
46-
# context=ChatContext(),
47-
# backend=m.backend,
48-
# tools=[search_tool],
49-
# format=Email
50-
# )
51-
# print(out)
44+
out, _ = await react(
45+
goal="Write an email about the Mellea python library to Jake with the subject 'cool library'.",
46+
context=ChatContext(),
47+
backend=m.backend,
48+
tools=[search_tool],
49+
format=Email,
50+
loop_budget=20,
51+
)
52+
print(out)
5253

5354

5455
asyncio.run(main())

mellea/backends/tools.py

Lines changed: 119 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -499,19 +499,41 @@ def validate_tool_arguments(
499499
"object": dict,
500500
}
501501

502-
# Build Pydantic model from JSON schema
503-
field_definitions: dict[str, Any] = {}
502+
# Helper function to build Pydantic model from nested JSON schema
503+
def _build_pydantic_type_from_schema(schema: dict[str, Any]) -> Any:
504+
"""Recursively build Pydantic type from JSON schema."""
505+
json_type = schema.get("type", "string")
506+
507+
# Handle nested objects with properties
508+
if json_type == "object" and "properties" in schema:
509+
nested_properties = schema.get("properties", {})
510+
nested_required = schema.get("required", [])
511+
nested_fields: dict[str, Any] = {}
512+
513+
for nested_name, nested_schema in nested_properties.items():
514+
nested_type = _build_pydantic_type_from_schema(nested_schema)
515+
if nested_name in nested_required:
516+
nested_fields[nested_name] = (nested_type, ...)
517+
else:
518+
nested_fields[nested_name] = (nested_type, None)
519+
520+
# Create a nested Pydantic model
521+
return create_model(
522+
f"Nested_{hash(str(schema))}",
523+
__config__=ConfigDict(extra="allow"),
524+
**nested_fields,
525+
)
504526

505-
for param_name, param_schema in properties.items():
506-
# Get type from JSON schema
507-
json_type = param_schema.get("type", "string")
527+
# Handle arrays
528+
if json_type == "array":
529+
item_schema = schema.get("items", {})
530+
item_type = _build_pydantic_type_from_schema(item_schema)
531+
return list[item_type] # type: ignore
508532

509-
# Handle comma-separated types (e.g., "integer, string" for Union types)
533+
# Handle comma-separated types (Union types)
510534
if isinstance(json_type, str) and "," in json_type:
511-
# Create Union type for multiple types
512535
type_list = [t.strip() for t in json_type.split(",")]
513536
python_types = [JSON_TYPE_TO_PYTHON.get(t, Any) for t in type_list]
514-
# Remove duplicates while preserving order
515537
seen = set()
516538
unique_types = []
517539
for t in python_types:
@@ -520,15 +542,36 @@ def validate_tool_arguments(
520542
unique_types.append(t)
521543

522544
if len(unique_types) == 1:
523-
param_type = unique_types[0]
545+
return unique_types[0]
524546
else:
525547
from functools import reduce
526548
from operator import or_
527549

528-
param_type = reduce(or_, unique_types)
529-
else:
530-
# Map to Python type
531-
param_type = JSON_TYPE_TO_PYTHON.get(json_type, Any)
550+
return reduce(or_, unique_types)
551+
552+
# Handle anyOf (union types in JSON schema)
553+
if "anyOf" in schema:
554+
types_list = []
555+
for sub_schema in schema["anyOf"]:
556+
sub_type = _build_pydantic_type_from_schema(sub_schema)
557+
types_list.append(sub_type)
558+
559+
if len(types_list) == 1:
560+
return types_list[0]
561+
else:
562+
from functools import reduce
563+
from operator import or_
564+
565+
return reduce(or_, types_list)
566+
567+
# Simple type mapping
568+
return JSON_TYPE_TO_PYTHON.get(json_type, Any)
569+
570+
# Build Pydantic model from JSON schema
571+
field_definitions: dict[str, Any] = {}
572+
573+
for param_name, param_schema in properties.items():
574+
param_type = _build_pydantic_type_from_schema(param_schema)
532575

533576
# Determine if parameter is required
534577
if param_name in required_fields:
@@ -784,14 +827,20 @@ class Property(SubscriptableBaseModel):
784827
items (Any | None): Schema for array element types, if applicable.
785828
description (str | None): Human-readable description of this parameter.
786829
enum (Sequence[Any] | None): Allowed values for this parameter, if constrained.
830+
properties (Mapping[str, Any] | None): Nested properties for object types.
831+
required (Sequence[str] | None): Required fields for nested objects.
832+
title (str | None): Title for the property schema.
787833
"""
788834

789-
model_config = ConfigDict(arbitrary_types_allowed=True)
835+
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
790836

791837
type: str | Sequence[str] | None = None
792838
items: Any | None = None
793839
description: str | None = None
794840
enum: Sequence[Any] | None = None
841+
properties: Mapping[str, Any] | None = None
842+
required: Sequence[str] | None = None
843+
title: str | None = None
795844

796845
properties: Mapping[str, Property] | None = None
797846

@@ -851,7 +900,7 @@ def convert_function_to_ollama_tool(
851900
) -> OllamaTool:
852901
"""Convert a Python callable to an Ollama-compatible tool schema.
853902
854-
Imported from Ollama.
903+
Imported from Ollama, with enhancements to support Pydantic BaseModel parameters.
855904
856905
Args:
857906
func: The Python callable to convert.
@@ -861,6 +910,8 @@ def convert_function_to_ollama_tool(
861910
An ``OllamaTool`` instance representing the function as an OpenAI-compatible
862911
tool schema.
863912
"""
913+
import copy
914+
864915
doc_string_hash = str(hash(inspect.getdoc(func)))
865916
parsed_docstring = _parse_docstring(inspect.getdoc(func))
866917
schema = type(
@@ -876,21 +927,60 @@ def convert_function_to_ollama_tool(
876927
},
877928
).model_json_schema() # type: ignore
878929

930+
# Helper to resolve $ref references
931+
def resolve_ref(ref_path: str, defs: dict) -> dict:
932+
"""Resolve a $ref path like '#/$defs/Email' to the actual schema."""
933+
if ref_path.startswith("#/$defs/"):
934+
def_name = ref_path.split("/")[-1]
935+
return defs.get(def_name, {})
936+
elif ref_path.startswith("#/definitions/"):
937+
def_name = ref_path.split("/")[-1]
938+
return defs.get(def_name, {})
939+
return {}
940+
941+
defs = schema.get("$defs", schema.get("definitions", {}))
942+
879943
for k, v in schema.get("properties", {}).items():
880-
# If type is missing, the default is string
881-
types = (
882-
{t.get("type", "string") for t in v.get("anyOf")}
883-
if "anyOf" in v
884-
else {v.get("type", "string")}
885-
)
886-
if "null" in types:
887-
schema["required"].remove(k)
888-
types.discard("null")
889-
890-
schema["properties"][k] = {
891-
"description": parsed_docstring[k],
892-
"type": ", ".join(types),
893-
}
944+
# Check if this property has a $ref (reference to a definition)
945+
if "$ref" in v:
946+
# Resolve the reference and inline it
947+
ref_schema = resolve_ref(v["$ref"], defs)
948+
if ref_schema:
949+
# Inline the referenced schema (deep copy to avoid mutations)
950+
inlined = copy.deepcopy(ref_schema)
951+
# Add description from docstring if available
952+
if parsed_docstring.get(k):
953+
inlined["description"] = parsed_docstring[k]
954+
schema["properties"][k] = inlined
955+
else:
956+
# Fallback if we can't resolve
957+
schema["properties"][k] = {
958+
"description": parsed_docstring[k],
959+
"type": "object",
960+
}
961+
# Check if this property is a nested object (has 'properties' or complex types)
962+
elif "properties" in v or "allOf" in v or "anyOf" in v:
963+
# This is a complex/nested type - preserve the full schema
964+
# Only add description if we have one
965+
if parsed_docstring.get(k):
966+
v["description"] = parsed_docstring[k]
967+
schema["properties"][k] = v
968+
else:
969+
# Simple type - use the original flattening logic
970+
types = (
971+
{t.get("type", "string") for t in v.get("anyOf")}
972+
if "anyOf" in v
973+
else {v.get("type", "string")}
974+
)
975+
if "null" in types:
976+
if k in schema.get("required", []):
977+
schema["required"].remove(k)
978+
types.discard("null")
979+
980+
schema["properties"][k] = {
981+
"description": parsed_docstring[k],
982+
"type": ", ".join(types),
983+
}
894984

895985
tool = OllamaTool(
896986
type="function",

0 commit comments

Comments (0)