3 changes: 2 additions & 1 deletion docs/examples/agents/react/react_using_mellea.py
@@ -46,7 +46,8 @@ async def main():
# context=ChatContext(),
# backend=m.backend,
# tools=[search_tool],
# format=Email
# format=Email,
# loop_budget=20,
# )
# print(out)

202 changes: 173 additions & 29 deletions mellea/backends/tools.py
@@ -7,6 +7,7 @@
validating/coercing tool arguments against the tool's JSON schema using Pydantic.
"""

import copy
import inspect
import json
import re
@@ -499,19 +500,54 @@ def validate_tool_arguments(
"object": dict,
}

# Build Pydantic model from JSON schema
field_definitions: dict[str, Any] = {}
# Helper function to build Pydantic model from nested JSON schema
def _build_pydantic_type_from_schema(schema: dict[str, Any]) -> Any:
"""Recursively build Pydantic type from JSON schema.

for param_name, param_schema in properties.items():
# Get type from JSON schema
json_type = param_schema.get("type", "string")
Note: This assumes all $ref references have been resolved by
convert_function_to_ollama_tool before validation. If a schema
still contains $ref entries, this will return Any as a fallback.
"""
# Early exit for unresolved $ref: return Any to disable validation
# rather than silently mistype as str
if "$ref" in schema:
return Any

json_type = schema.get("type", "string")

# Handle nested objects with properties
if json_type == "object" and "properties" in schema:
nested_properties = schema.get("properties", {})
nested_required = schema.get("required", [])
nested_fields: dict[str, Any] = {}

for nested_name, nested_schema in nested_properties.items():
nested_type = _build_pydantic_type_from_schema(nested_schema)
if nested_name in nested_required:
nested_fields[nested_name] = (nested_type, ...)
else:
nested_fields[nested_name] = (nested_type, None)

# Create a nested Pydantic model with deterministic name
# Use sorted field names for consistent naming across runs
# Respect strict mode: forbid extra fields if caller requested strict=True
model_name = f"Nested_{'_'.join(sorted(nested_fields.keys()))}"
return create_model(
model_name,
__config__=ConfigDict(extra="forbid" if strict else "allow"),
**nested_fields,
)

# Handle arrays
if json_type == "array":
item_schema = schema.get("items", {})
item_type = _build_pydantic_type_from_schema(item_schema)
return list[item_type] # type: ignore

# Handle comma-separated types (e.g., "integer, string" for Union types)
# Handle comma-separated types (Union types)
if isinstance(json_type, str) and "," in json_type:
# Create Union type for multiple types
type_list = [t.strip() for t in json_type.split(",")]
python_types = [JSON_TYPE_TO_PYTHON.get(t, Any) for t in type_list]
# Remove duplicates while preserving order
seen = set()
unique_types = []
for t in python_types:
@@ -520,15 +556,46 @@ def validate_tool_arguments(
unique_types.append(t)

if len(unique_types) == 1:
param_type = unique_types[0]
return unique_types[0]
else:
from functools import reduce
from operator import or_

param_type = reduce(or_, unique_types)
else:
# Map to Python type
param_type = JSON_TYPE_TO_PYTHON.get(json_type, Any)
return reduce(or_, unique_types)

# Handle anyOf (union types in JSON schema)
if "anyOf" in schema:
# Filter out null sub-schemas before recursion to prevent them from
# contaminating the union if an unresolved $ref returns Any.
# Null becomes an explicit Optional[] wrapper instead.
has_null = any(s.get("type") == "null" for s in schema["anyOf"])
non_null_schemas = [s for s in schema["anyOf"] if s.get("type") != "null"]

types_list = []
for sub_schema in non_null_schemas:
sub_type = _build_pydantic_type_from_schema(sub_schema)
types_list.append(sub_type)

if len(types_list) == 0:
result = Any
elif len(types_list) == 1:
result = types_list[0]
else:
from functools import reduce
from operator import or_

result = reduce(or_, types_list)

return (result | None) if has_null else result

# Simple type mapping
return JSON_TYPE_TO_PYTHON.get(json_type, Any)

# Build Pydantic model from JSON schema
field_definitions: dict[str, Any] = {}

for param_name, param_schema in properties.items():
param_type = _build_pydantic_type_from_schema(param_schema)

# Determine if parameter is required
if param_name in required_fields:
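
For reviewers unfamiliar with `create_model`, here is a minimal standalone sketch (not the PR's code, assumes Pydantic v2) of the recursion this helper performs: a nested JSON-schema object becomes a dynamically created Pydantic model, arrays map to `list[...]`, and required vs. optional fields are expressed through `(type, ...)` / `(type, None)` field tuples. The `build_type` name and the email-style schema below are illustrative only.

```python
# Minimal sketch (illustrative, assumes Pydantic v2) of the recursive
# JSON schema -> Pydantic type construction used for nested tool arguments.
from typing import Any

from pydantic import ConfigDict, ValidationError, create_model

JSON_TO_PY = {"string": str, "integer": int, "number": float, "boolean": bool}


def build_type(schema: dict[str, Any]) -> Any:
    """Map a JSON-schema fragment to a Python/Pydantic type."""
    if schema.get("type") == "object" and "properties" in schema:
        required = set(schema.get("required", []))
        fields = {
            name: (build_type(sub), ... if name in required else None)
            for name, sub in schema["properties"].items()
        }
        # Dynamically created nested model; extra="forbid" mirrors strict mode.
        return create_model("Nested", __config__=ConfigDict(extra="forbid"), **fields)
    if schema.get("type") == "array":
        return list[build_type(schema.get("items", {}))]  # type: ignore[misc]
    return JSON_TO_PY.get(schema.get("type", "string"), Any)


email_schema = {
    "type": "object",
    "properties": {
        "to": {"type": "string"},
        "cc": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["to"],
}
EmailArgs = build_type(email_schema)
print(EmailArgs(to="a@b.c", cc=["d@e.f"]))  # nested arguments validate
try:
    EmailArgs(to="a@b.c", unexpected=1)  # rejected because extra="forbid"
except ValidationError as err:
    print(err.error_count(), "validation error")
```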
@@ -784,14 +851,20 @@ class Property(SubscriptableBaseModel):
items (Any | None): Schema for array element types, if applicable.
description (str | None): Human-readable description of this parameter.
enum (Sequence[Any] | None): Allowed values for this parameter, if constrained.
properties (Mapping[str, Any] | None): Nested properties for object types.
required (Sequence[str] | None): Required fields for nested objects.
title (str | None): Title for the property schema.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

type: str | Sequence[str] | None = None
items: Any | None = None
description: str | None = None
enum: Sequence[Any] | None = None
properties: Mapping[str, Any] | None = None
required: Sequence[str] | None = None
title: str | None = None

properties: Mapping[str, Property] | None = None
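
A small, hypothetical sketch of what `extra="allow"` buys on `Property`: JSON-schema keys the model does not declare (e.g. `minimum`, `format`) are retained through validation and round-trip via `model_dump()` instead of being dropped. `PropertySketch` is illustrative, not the class above, and assumes Pydantic v2.

```python
# Illustrative only (assumes Pydantic v2): extra="allow" keeps schema keys
# the model does not declare, so nothing is lost when round-tripping.
from collections.abc import Mapping, Sequence
from typing import Any

from pydantic import BaseModel, ConfigDict


class PropertySketch(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

    type: str | Sequence[str] | None = None
    description: str | None = None
    properties: Mapping[str, Any] | None = None
    required: Sequence[str] | None = None
    title: str | None = None


p = PropertySketch(type="integer", description="retry count", minimum=0)
print(p.model_dump())  # 'minimum' survives even though it is not a declared field
```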

@@ -845,13 +918,37 @@ def _parse_docstring(doc_string: str | None) -> dict[str, str]:
return parsed_docstring


def _resolve_ref(ref_path: str, defs: dict) -> dict:
"""Resolve a $ref path like '#/$defs/Email' to the actual schema."""
if ref_path.startswith("#/$defs/"):
def_name = ref_path.split("/")[-1]
return defs.get(def_name, {})
elif ref_path.startswith("#/definitions/"):
def_name = ref_path.split("/")[-1]
return defs.get(def_name, {})
return {}


def _is_complex_anyof(v: dict) -> bool:
"""Check if anyOf contains complex types (refs or nested objects)."""
any_of_schemas = v.get("anyOf", [])
for sub_schema in any_of_schemas:
# Skip null types - they just indicate optionality
if sub_schema.get("type") == "null":
continue
# Check for references or nested properties (don't recursively check allOf)
if "$ref" in sub_schema or "properties" in sub_schema:
return True
return False


# https://github.com/ollama/ollama-python/blob/60e7b2f9ce710eeb57ef2986c46ea612ae7516af/ollama/_utils.py#L56-L90
def convert_function_to_ollama_tool(
func: Callable, name: str | None = None
) -> OllamaTool:
"""Convert a Python callable to an Ollama-compatible tool schema.

Imported from Ollama.
Imported from Ollama, with enhancements to support Pydantic BaseModel parameters.

Args:
func: The Python callable to convert.
@@ -876,21 +973,68 @@ def convert_function_to_ollama_tool(
},
).model_json_schema() # type: ignore

defs = schema.get("$defs", schema.get("definitions", {}))

for k, v in schema.get("properties", {}).items():
# If type is missing, the default is string
types = (
{t.get("type", "string") for t in v.get("anyOf")}
if "anyOf" in v
else {v.get("type", "string")}
)
if "null" in types:
schema["required"].remove(k)
types.discard("null")

schema["properties"][k] = {
"description": parsed_docstring[k],
"type": ", ".join(types),
}
# First pass: inline all $refs (at top level and within anyOf)
if "$ref" in v:
# Resolve the reference and inline it
ref_schema = _resolve_ref(v["$ref"], defs)
if ref_schema:
# Inline the referenced schema (deep copy to avoid mutations)
inlined = copy.deepcopy(ref_schema)
# Add description from docstring if available
if parsed_docstring.get(k):
inlined["description"] = parsed_docstring[k]
schema["properties"][k] = inlined
v = inlined # Update v to point to inlined schema
else:
# Fallback if we can't resolve
schema["properties"][k] = {
"description": parsed_docstring.get(k, ""),
"type": "object",
}
v = schema["properties"][k]

# Inline $refs within anyOf
if "anyOf" in v:
for i, sub_schema in enumerate(v["anyOf"]):
if "$ref" in sub_schema:
ref_schema = _resolve_ref(sub_schema["$ref"], defs)
if ref_schema:
# Inline the referenced schema
v["anyOf"][i] = copy.deepcopy(ref_schema)

# Second pass: determine how to handle the property type
if "properties" in v or "allOf" in v or ("anyOf" in v and _is_complex_anyof(v)):
# This is a complex/nested type - preserve the full schema
# Only add description if we have one
if parsed_docstring.get(k):
v["description"] = parsed_docstring[k]
# If anyOf contains null (making it Optional), remove from required
if "anyOf" in v and any(
t.get("type") == "null" for t in v.get("anyOf", [])
):
if k in schema.get("required", []):
schema["required"].remove(k)
schema["properties"][k] = v
else:
# Simple type - use the original flattening logic
# This now handles Optional[primitive] types correctly
if "anyOf" in v:
types = {t.get("type", "string") for t in v.get("anyOf")}
else:
types = {v.get("type", "string")}

if "null" in types:
if k in schema.get("required", []):
schema["required"].remove(k)
types.discard("null")

schema["properties"][k] = {
"description": parsed_docstring.get(k, ""),
"type": ", ".join(types),
}

tool = OllamaTool(
type="function",