diff --git a/docs/examples/agents/react/react_using_mellea.py b/docs/examples/agents/react/react_using_mellea.py index 4bd23328a..b4e89712e 100644 --- a/docs/examples/agents/react/react_using_mellea.py +++ b/docs/examples/agents/react/react_using_mellea.py @@ -46,7 +46,8 @@ async def main(): # context=ChatContext(), # backend=m.backend, # tools=[search_tool], - # format=Email + # format=Email, + # loop_budget=20, # ) # print(out) diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index 3dd056512..da675eab3 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -7,6 +7,7 @@ validating/coercing tool arguments against the tool's JSON schema using Pydantic. """ +import copy import inspect import json import re @@ -499,19 +500,54 @@ def validate_tool_arguments( "object": dict, } - # Build Pydantic model from JSON schema - field_definitions: dict[str, Any] = {} + # Helper function to build Pydantic model from nested JSON schema + def _build_pydantic_type_from_schema(schema: dict[str, Any]) -> Any: + """Recursively build Pydantic type from JSON schema. - for param_name, param_schema in properties.items(): - # Get type from JSON schema - json_type = param_schema.get("type", "string") + Note: This assumes all $ref references have been resolved by + convert_function_to_ollama_tool before validation. If a schema + still contains $ref entries, this will return Any as a fallback. 
+ """ + # Early exit for unresolved $ref: return Any to disable validation + # rather than silently mistype as str + if "$ref" in schema: + return Any + + json_type = schema.get("type", "string") + + # Handle nested objects with properties + if json_type == "object" and "properties" in schema: + nested_properties = schema.get("properties", {}) + nested_required = schema.get("required", []) + nested_fields: dict[str, Any] = {} + + for nested_name, nested_schema in nested_properties.items(): + nested_type = _build_pydantic_type_from_schema(nested_schema) + if nested_name in nested_required: + nested_fields[nested_name] = (nested_type, ...) + else: + nested_fields[nested_name] = (nested_type, None) + + # Create a nested Pydantic model with deterministic name + # Use sorted field names for consistent naming across runs + # Respect strict mode: forbid extra fields if caller requested strict=True + model_name = f"Nested_{'_'.join(sorted(nested_fields.keys()))}" + return create_model( + model_name, + __config__=ConfigDict(extra="forbid" if strict else "allow"), + **nested_fields, + ) + + # Handle arrays + if json_type == "array": + item_schema = schema.get("items", {}) + item_type = _build_pydantic_type_from_schema(item_schema) + return list[item_type] # type: ignore - # Handle comma-separated types (e.g., "integer, string" for Union types) + # Handle comma-separated types (Union types) if isinstance(json_type, str) and "," in json_type: - # Create Union type for multiple types type_list = [t.strip() for t in json_type.split(",")] python_types = [JSON_TYPE_TO_PYTHON.get(t, Any) for t in type_list] - # Remove duplicates while preserving order seen = set() unique_types = [] for t in python_types: @@ -520,15 +556,46 @@ def validate_tool_arguments( unique_types.append(t) if len(unique_types) == 1: - param_type = unique_types[0] + return unique_types[0] else: from functools import reduce from operator import or_ - param_type = reduce(or_, unique_types) - else: - # Map to 
Python type - param_type = JSON_TYPE_TO_PYTHON.get(json_type, Any) + return reduce(or_, unique_types) + + # Handle anyOf (union types in JSON schema) + if "anyOf" in schema: + # Filter out null sub-schemas before recursion to prevent them from + # contaminating the union if an unresolved $ref returns Any. + # Null becomes an explicit Optional[] wrapper instead. + has_null = any(s.get("type") == "null" for s in schema["anyOf"]) + non_null_schemas = [s for s in schema["anyOf"] if s.get("type") != "null"] + + types_list = [] + for sub_schema in non_null_schemas: + sub_type = _build_pydantic_type_from_schema(sub_schema) + types_list.append(sub_type) + + if len(types_list) == 0: + result = Any + elif len(types_list) == 1: + result = types_list[0] + else: + from functools import reduce + from operator import or_ + + result = reduce(or_, types_list) + + return (result | None) if has_null else result + + # Simple type mapping + return JSON_TYPE_TO_PYTHON.get(json_type, Any) + + # Build Pydantic model from JSON schema + field_definitions: dict[str, Any] = {} + + for param_name, param_schema in properties.items(): + param_type = _build_pydantic_type_from_schema(param_schema) # Determine if parameter is required if param_name in required_fields: @@ -784,14 +851,20 @@ class Property(SubscriptableBaseModel): items (Any | None): Schema for array element types, if applicable. description (str | None): Human-readable description of this parameter. enum (Sequence[Any] | None): Allowed values for this parameter, if constrained. + properties (Mapping[str, Any] | None): Nested properties for object types. + required (Sequence[str] | None): Required fields for nested objects. + title (str | None): Title for the property schema. 
""" - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") type: str | Sequence[str] | None = None items: Any | None = None description: str | None = None enum: Sequence[Any] | None = None + properties: Mapping[str, Any] | None = None + required: Sequence[str] | None = None + title: str | None = None properties: Mapping[str, Property] | None = None @@ -845,13 +918,37 @@ def _parse_docstring(doc_string: str | None) -> dict[str, str]: return parsed_docstring +def _resolve_ref(ref_path: str, defs: dict) -> dict: + """Resolve a $ref path like '#/$defs/Email' to the actual schema.""" + if ref_path.startswith("#/$defs/"): + def_name = ref_path.split("/")[-1] + return defs.get(def_name, {}) + elif ref_path.startswith("#/definitions/"): + def_name = ref_path.split("/")[-1] + return defs.get(def_name, {}) + return {} + + +def _is_complex_anyof(v: dict) -> bool: + """Check if anyOf contains complex types (refs or nested objects).""" + any_of_schemas = v.get("anyOf", []) + for sub_schema in any_of_schemas: + # Skip null types - they just indicate optionality + if sub_schema.get("type") == "null": + continue + # Check for references or nested properties (don't recursively check allOf) + if "$ref" in sub_schema or "properties" in sub_schema: + return True + return False + + # https://github.com/ollama/ollama-python/blob/60e7b2f9ce710eeb57ef2986c46ea612ae7516af/ollama/_utils.py#L56-L90 def convert_function_to_ollama_tool( func: Callable, name: str | None = None ) -> OllamaTool: """Convert a Python callable to an Ollama-compatible tool schema. - Imported from Ollama. + Imported from Ollama, with enhancements to support Pydantic BaseModel parameters. Args: func: The Python callable to convert. 
@@ -876,21 +973,68 @@ def convert_function_to_ollama_tool( }, ).model_json_schema() # type: ignore + defs = schema.get("$defs", schema.get("definitions", {})) + for k, v in schema.get("properties", {}).items(): - # If type is missing, the default is string - types = ( - {t.get("type", "string") for t in v.get("anyOf")} - if "anyOf" in v - else {v.get("type", "string")} - ) - if "null" in types: - schema["required"].remove(k) - types.discard("null") - - schema["properties"][k] = { - "description": parsed_docstring[k], - "type": ", ".join(types), - } + # First pass: inline all $refs (at top level and within anyOf) + if "$ref" in v: + # Resolve the reference and inline it + ref_schema = _resolve_ref(v["$ref"], defs) + if ref_schema: + # Inline the referenced schema (deep copy to avoid mutations) + inlined = copy.deepcopy(ref_schema) + # Add description from docstring if available + if parsed_docstring.get(k): + inlined["description"] = parsed_docstring[k] + schema["properties"][k] = inlined + v = inlined # Update v to point to inlined schema + else: + # Fallback if we can't resolve + schema["properties"][k] = { + "description": parsed_docstring.get(k, ""), + "type": "object", + } + v = schema["properties"][k] + + # Inline $refs within anyOf + if "anyOf" in v: + for i, sub_schema in enumerate(v["anyOf"]): + if "$ref" in sub_schema: + ref_schema = _resolve_ref(sub_schema["$ref"], defs) + if ref_schema: + # Inline the referenced schema + v["anyOf"][i] = copy.deepcopy(ref_schema) + + # Second pass: determine how to handle the property type + if "properties" in v or "allOf" in v or ("anyOf" in v and _is_complex_anyof(v)): + # This is a complex/nested type - preserve the full schema + # Only add description if we have one + if parsed_docstring.get(k): + v["description"] = parsed_docstring[k] + # If anyOf contains null (making it Optional), remove from required + if "anyOf" in v and any( + t.get("type") == "null" for t in v.get("anyOf", []) + ): + if k in 
schema.get("required", []): + schema["required"].remove(k) + schema["properties"][k] = v + else: + # Simple type - use the original flattening logic + # This now handles Optional[primitive] types correctly + if "anyOf" in v: + types = {t.get("type", "string") for t in v.get("anyOf")} + else: + types = {v.get("type", "string")} + + if "null" in types: + if k in schema.get("required", []): + schema["required"].remove(k) + types.discard("null") + + schema["properties"][k] = { + "description": parsed_docstring.get(k, ""), + "type": ", ".join(types), + } tool = OllamaTool( type="function", diff --git a/test/backends/test_pydantic_tool_parameters.py b/test/backends/test_pydantic_tool_parameters.py new file mode 100644 index 000000000..c0a259126 --- /dev/null +++ b/test/backends/test_pydantic_tool_parameters.py @@ -0,0 +1,597 @@ +"""Tests for tools with Pydantic BaseModel parameters. + +This test file addresses the issue where tools defined with Pydantic BaseModel +parameters don't properly validate/coerce arguments from LLM responses. 
"""Tests for tools with Pydantic BaseModel parameters.

This test file addresses the issue where tools defined with Pydantic BaseModel
parameters don't properly validate/coerce arguments from LLM responses.
"""

import pytest
from pydantic import BaseModel, ValidationError

from mellea.backends.tools import MelleaTool, validate_tool_arguments
from mellea.core import ModelToolCall

# ============================================================================
# Test Fixtures - Pydantic Models
# ============================================================================


class Email(BaseModel):
    """An email message."""

    to: str
    subject: str
    body: str


class Address(BaseModel):
    """A physical address."""

    street: str
    city: str
    state: str
    zip_code: str


class Person(BaseModel):
    """A person with contact info."""

    name: str
    age: int
    email: str
    address: Address | None = None


# ============================================================================
# Test Fixtures - Tool Functions (for schema generation testing)
# These use Pydantic BaseModel types to test schema generation
# ============================================================================


def send_email_typed(email: Email) -> str:
    """Send an email message.

    Args:
        email: The email to send
    """
    # In practice, this receives a dict from LLM tool calls
    # but we type it as Email for schema generation
    if isinstance(email, dict):
        return f"Sent email to {email['to']} with subject '{email['subject']}'"
    return f"Sent email to {email.to} with subject '{email.subject}'"


def create_contact_typed(person: Person) -> str:
    """Create a new contact.

    Args:
        person: The person's information
    """
    if isinstance(person, dict):
        return f"Created contact for {person['name']}"
    return f"Created contact for {person.name}"


def simple_nested_typed(data: Email, priority: int = 1) -> str:
    """Tool with both BaseModel and primitive parameters.

    Args:
        data: Email data
        priority: Priority level
    """
    if isinstance(data, dict):
        return f"Priority {priority}: {data['subject']}"
    return f"Priority {priority}: {data.subject}"


# ============================================================================
# Test Fixtures - Tool Functions (for actual tool calls)
# These accept dicts as LLM tool calls provide
# ============================================================================


def send_email(email: dict) -> str:
    """Send an email message.

    Args:
        email: The email to send (dict with to, subject, body)
    """
    return f"Sent email to {email['to']} with subject '{email['subject']}'"


def create_contact(person: dict) -> str:
    """Create a new contact.

    Args:
        person: The person's information (dict with name, age, email, optional address)
    """
    return f"Created contact for {person['name']}"


def simple_nested(data: dict, priority: int = 1) -> str:
    """Tool with both a dict and a primitive parameter.

    Args:
        data: Email data (dict with to, subject, body)
        priority: Priority level
    """
    return f"Priority {priority}: {data['subject']}"


# ============================================================================
# Test Cases
# ============================================================================


class TestPydanticParameterSchemaGeneration:
    """Test that Pydantic BaseModel parameters generate correct schemas."""

    def test_simple_basemodel_schema(self):
        """Test schema generation for simple BaseModel parameter."""
        tool = MelleaTool.from_callable(send_email_typed)
        schema = tool.as_json_tool

        # Check basic structure
        assert "function" in schema
        assert schema["function"]["name"] == "send_email_typed"

        # Check parameters
        params = schema["function"]["parameters"]
        assert "properties" in params
        assert "email" in params["properties"]

        # The email parameter must be inlined as an object with its fields.
        # Asserting the type directly (instead of guarding the checks with an
        # ``if``) keeps the test from passing vacuously on a wrong schema.
        email_schema = params["properties"]["email"]
        assert email_schema["type"] == "object"
        assert "properties" in email_schema
        for field in ("to", "subject", "body"):
            assert field in email_schema["properties"]

    def test_nested_basemodel_schema(self):
        """Test schema generation for nested BaseModel (Person with Address)."""
        tool = MelleaTool.from_callable(create_contact_typed)
        schema = tool.as_json_tool

        params = schema["function"]["parameters"]
        person_schema = params["properties"]["person"]

        # Person should be an inlined object type
        assert person_schema["type"] == "object"
        assert "properties" in person_schema
        for field in ("name", "age", "email", "address"):
            assert field in person_schema["properties"]

        # Address is Optional[Address]; it is only checked when it has been
        # inlined as an object (nested $ref resolution is covered by the
        # skipped test_nested_models_fully_inlined below).
        address_schema = person_schema["properties"]["address"]
        if address_schema.get("type") == "object":
            assert "properties" in address_schema
            assert "street" in address_schema["properties"]

    def test_mixed_parameters_schema(self):
        """Test schema with both BaseModel and primitive parameters."""
        tool = MelleaTool.from_callable(simple_nested_typed)
        schema = tool.as_json_tool

        params = schema["function"]["parameters"]
        assert "data" in params["properties"]
        assert "priority" in params["properties"]

        # Priority should be a simple integer
        priority_schema = params["properties"]["priority"]
        assert "integer" in priority_schema.get("type", "")


class TestPydanticParameterValidation:
    """Test validation of Pydantic BaseModel parameters."""

    def test_valid_nested_object(self):
        """Test validation with correctly structured nested object."""
        tool = MelleaTool.from_callable(send_email)

        # LLM returns a properly structured email object
        args = {
            "email": {
                "to": "user@example.com",
                "subject": "Test Subject",
                "body": "Test body content",
            }
        }

        validated = validate_tool_arguments(tool, args, coerce_types=True)

        # Should validate successfully
        assert "email" in validated
        assert validated["email"]["to"] == "user@example.com"
        assert validated["email"]["subject"] == "Test Subject"
        assert validated["email"]["body"] == "Test body content"

    def test_nested_object_with_type_coercion(self):
        """Test that nested object fields can be coerced."""
        tool = MelleaTool.from_callable(create_contact_typed)

        # LLM returns age as string
        args = {
            "person": {
                "name": "John Doe",
                "age": "30",  # String instead of int
                "email": "john@example.com",
            }
        }

        validated = validate_tool_arguments(tool, args, coerce_types=True)

        # Age should be coerced to int
        assert validated["person"]["age"] == 30
        assert isinstance(validated["person"]["age"], int)

    def test_missing_required_nested_field(self):
        """Test validation fails when required nested field is missing."""
        tool = MelleaTool.from_callable(send_email_typed)

        # Missing 'body' field
        args = {"email": {"to": "user@example.com", "subject": "Test"}}

        # In lenient mode, should return original args
        validated = validate_tool_arguments(tool, args, strict=False)
        assert validated == args

        # In strict mode, should raise
        with pytest.raises(ValidationError):
            validate_tool_arguments(tool, args, strict=True)

    def test_optional_nested_object(self):
        """Test validation with optional nested object."""
        tool = MelleaTool.from_callable(create_contact_typed)

        # Address is optional, so this should be valid
        args = {"person": {"name": "Jane Doe", "age": 25, "email": "jane@example.com"}}

        validated = validate_tool_arguments(tool, args, coerce_types=True)
        assert validated["person"]["name"] == "Jane Doe"
        assert (
            "address" not in validated["person"]
            or validated["person"].get("address") is None
        )

    def test_nested_object_with_all_fields(self):
        """Test validation with all fields including optional nested object."""
        tool = MelleaTool.from_callable(create_contact_typed)

        args = {
            "person": {
                "name": "Bob Smith",
                "age": 35,
                "email": "bob@example.com",
                "address": {
                    "street": "123 Main St",
                    "city": "Boston",
                    "state": "MA",
                    "zip_code": "02101",
                },
            }
        }

        validated = validate_tool_arguments(tool, args, coerce_types=True)
        assert validated["person"]["address"]["city"] == "Boston"


class TestPydanticParameterToolCall:
    """Test actual tool calls with Pydantic BaseModel parameters."""

    def test_tool_call_with_basemodel(self):
        """Test that tool can be called with validated BaseModel args."""
        tool = MelleaTool.from_callable(send_email)

        args = {
            "email": {
                "to": "test@example.com",
                "subject": "Hello",
                "body": "Test message",
            }
        }

        validated = validate_tool_arguments(tool, args, coerce_types=True)
        tool_call = ModelToolCall("send_email", tool, validated)
        result = tool_call.call_func()

        assert "test@example.com" in result
        assert "Hello" in result

    def test_tool_call_with_coerced_nested_types(self):
        """Test tool call when a nested value arrives with the wrong type."""
        tool = MelleaTool.from_callable(create_contact)

        # Age arrives as a string. ``create_contact`` takes a plain dict, so
        # only the successful call is asserted here; nested coercion itself
        # is covered by test_nested_object_with_type_coercion.
        args = {"person": {"name": "Alice", "age": "28", "email": "alice@example.com"}}

        validated = validate_tool_arguments(tool, args, coerce_types=True)
        tool_call = ModelToolCall("create_contact", tool, validated)
        result = tool_call.call_func()

        assert "Alice" in result

    def test_mixed_parameters_tool_call(self):
        """Test tool with both BaseModel and primitive parameters."""
        tool = MelleaTool.from_callable(simple_nested_typed)

        args = {
            "data": {
                "to": "user@example.com",
                "subject": "Important",
                "body": "Message",
            },
            "priority": "5",  # String that should be coerced to int
        }

        validated = validate_tool_arguments(tool, args, coerce_types=True)
        assert validated["priority"] == 5
        assert isinstance(validated["priority"], int)

        # The call name matches the callable the tool was built from.
        tool_call = ModelToolCall("simple_nested_typed", tool, validated)
        result = tool_call.call_func()

        assert "Priority 5" in result
        assert "Important" in result


class TestEdgeCases:
    """Test edge cases with Pydantic parameters."""

    def test_flat_dict_instead_of_nested(self):
        """Test when LLM returns flat dict instead of nested structure."""
        tool = MelleaTool.from_callable(send_email_typed)

        # LLM might incorrectly flatten the structure
        args = {"to": "user@example.com", "subject": "Test", "body": "Content"}

        # This should fail validation in strict mode
        with pytest.raises(ValidationError):
            validate_tool_arguments(tool, args, strict=True)

        # In lenient mode, returns original
        validated = validate_tool_arguments(tool, args, strict=False)
        assert validated == args

    def test_extra_fields_in_nested_object(self):
        """Test nested object with extra fields not in schema."""
        tool = MelleaTool.from_callable(send_email_typed)

        args = {
            "email": {
                "to": "user@example.com",
                "subject": "Test",
                "body": "Content",
                "extra_field": "should be ignored or preserved",
            }
        }

        # In lenient mode, extra fields might be preserved
        validated = validate_tool_arguments(tool, args, strict=False)
        assert validated["email"]["to"] == "user@example.com"


class TestOptionalParameterRegression:
    """Test cases to prevent regression of Optional parameter handling.

    These tests verify the fix for the anyOf narrowing issue where Optional
    parameters were incorrectly treated as complex types and added to required.
    """

    def test_basemodel_param_inlined_no_ref(self):
        """Test def f(email: Email) — required BaseModel param.

        Confirms the core fix works: email is inlined in the schema (no $ref in
        the output), and validate_tool_arguments accepts the dict without error.
        """

        def send_email(email: Email) -> str:
            """Send an email.

            Args:
                email: The email to send
            """
            return f"Sent to {email.to}"

        tool = MelleaTool.from_callable(send_email)
        schema = tool.as_json_tool

        # Verify email is inlined (no $ref)
        params = schema["function"]["parameters"]
        email_prop = params["properties"]["email"]
        assert "$ref" not in email_prop, "Email should be inlined, not a $ref"
        assert email_prop["type"] == "object"
        assert "properties" in email_prop
        assert "to" in email_prop["properties"]
        assert "subject" in email_prop["properties"]
        assert "body" in email_prop["properties"]

        # Verify email is required
        assert "email" in params["required"]

        # Verify validation works
        args = {"email": {"to": "a@b.com", "subject": "hi", "body": "test"}}
        validated = validate_tool_arguments(tool, args, coerce_types=True)
        assert validated["email"]["to"] == "a@b.com"
        assert validated["email"]["subject"] == "hi"

    def test_optional_scalar_not_required(self):
        """Test def f(x: str, y: str | None = None) — Optional scalar.

        Confirms Optional params still work: y must be absent from required and
        the schema type must be "string", not a raw anyOf structure.
        """

        def process_text(x: str, y: str | None = None) -> str:
            """Process text with optional parameter.

            Args:
                x: Required text
                y: Optional additional text
            """
            return f"{x} {y or ''}"

        tool = MelleaTool.from_callable(process_text)
        schema = tool.as_json_tool

        params = schema["function"]["parameters"]

        # Verify x is required
        assert "x" in params["required"]

        # Verify y is NOT required
        assert "y" not in params["required"], (
            "Optional parameter y should not be in required"
        )

        # Verify y has simple string type, not raw anyOf
        y_prop = params["properties"]["y"]
        assert y_prop["type"] == "string", (
            f"Expected 'string', got {y_prop.get('type')}"
        )
        assert "anyOf" not in y_prop, (
            "Optional scalar should be flattened, not preserve anyOf"
        )

        # Verify validation works with y absent
        args = {"x": "hello"}
        validated = validate_tool_arguments(tool, args, coerce_types=True)
        assert validated["x"] == "hello"

        # Verify validation works with y present
        args = {"x": "hello", "y": "world"}
        validated = validate_tool_arguments(tool, args, coerce_types=True)
        assert validated["x"] == "hello"
        assert validated["y"] == "world"

    def test_optional_basemodel_not_required(self):
        """Test def f(email: Email | None = None) - Optional BaseModel param.

        Confirms optional BaseModel params are absent from required, and that
        validate_tool_arguments rejects a malformed nested dict in strict mode.
        """

        def send(email: Email | None = None) -> str:
            """Send an email.

            Args:
                email: The email to send
            """
            return "ok"

        tool = MelleaTool.from_callable(send)
        schema = tool.as_json_tool
        params = schema["function"]["parameters"]

        # email should NOT be required
        assert "email" not in params.get("required", [])

        # strict mode should reject a missing required nested field
        with pytest.raises(ValidationError):
            validate_tool_arguments(
                tool,
                {"email": {"to": "a@b.com"}},  # missing subject + body
                strict=True,
            )

    @pytest.mark.skip(
        reason="Nested model resolution not yet implemented. "
        "This test documents the expected behavior once recursive $ref resolution is added. "
        "Currently fails because Address remains as a dangling $ref inside Person's schema. "
        "NESTED_MODEL_RESOLUTION_ISSUE.md in "
        "https://github.com/generative-computing/mellea/issues/911 for implementation details."
    )
    def test_nested_models_fully_inlined(self):
        """Test def f(person: Person) where Person has address: Address.

        Confirms nested models work end-to-end: both Person and Address fully
        inlined in the schema, and validate_tool_arguments accepts a nested
        dict without a ValidationError.

        DISABLED: This test is currently skipped because nested model resolution
        is not yet implemented. The test will be enabled once the recursive
        reference resolution feature is added to convert_function_to_ollama_tool.
        """

        def create_person(person: Person) -> str:
            """Create a person record.

            Args:
                person: The person's information
            """
            return f"Created {person.name}"

        tool = MelleaTool.from_callable(create_person)
        schema = tool.as_json_tool

        params = schema["function"]["parameters"]
        person_prop = params["properties"]["person"]

        # Verify Person is inlined
        assert "$ref" not in person_prop, "Person should be inlined, not a $ref"
        assert person_prop["type"] == "object"
        assert "properties" in person_prop

        # Verify Person has all expected fields
        person_props = person_prop["properties"]
        assert "name" in person_props
        assert "age" in person_props
        assert "email" in person_props
        assert "address" in person_props

        # Verify Address is also inlined (not a dangling $ref)
        address_prop = person_props["address"]
        # Address is Optional[Address], so it might be in anyOf
        if "anyOf" in address_prop:
            # Find the non-null schema in anyOf
            address_schemas = [
                s for s in address_prop["anyOf"] if s.get("type") != "null"
            ]
            assert len(address_schemas) > 0, "Should have at least one non-null schema"
            address_schema = address_schemas[0]
        else:
            address_schema = address_prop

        # The address schema should be fully resolved (no $ref)
        assert "$ref" not in address_schema, (
            "Address should be inlined, not a dangling $ref"
        )

        # Verify validation works with nested dict
        args = {
            "person": {
                "name": "John Doe",
                "age": 30,
                "email": "john@example.com",
                "address": {
                    "street": "123 Main St",
                    "city": "Boston",
                    "state": "MA",
                    "zip_code": "02101",
                },
            }
        }

        # This should not raise ValidationError
        validated = validate_tool_arguments(tool, args, coerce_types=True)
        assert validated["person"]["name"] == "John Doe"
        assert validated["person"]["address"]["city"] == "Boston"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
"""Tests for schema helper functions in tools.py.

Tests for _resolve_ref and _is_complex_anyof functions that handle
JSON schema resolution and complex type detection.
"""

import pytest

from mellea.backends.tools import _is_complex_anyof, _resolve_ref


class TestResolveRef:
    """Tests for the _resolve_ref helper function."""

    def test_resolve_defs_style_ref(self):
        """Test resolving #/$defs/ style references."""
        defs = {
            "Email": {
                "type": "object",
                "properties": {"to": {"type": "string"}, "subject": {"type": "string"}},
            }
        }
        assert _resolve_ref("#/$defs/Email", defs) == defs["Email"]

    def test_resolve_definitions_style_ref(self):
        """Test resolving #/definitions/ style references."""
        defs = {
            "Address": {
                "type": "object",
                "properties": {
                    "street": {"type": "string"},
                    "city": {"type": "string"},
                },
            }
        }
        assert _resolve_ref("#/definitions/Address", defs) == defs["Address"]

    def test_resolve_missing_ref_returns_empty_dict(self):
        """Test that resolving a non-existent ref returns empty dict."""
        defs = {"Email": {"type": "object"}}
        assert _resolve_ref("#/$defs/NotFound", defs) == {}

    def test_resolve_invalid_ref_format_returns_empty_dict(self):
        """Test that invalid ref format returns empty dict."""
        defs = {"Email": {"type": "object"}}
        assert _resolve_ref("#/invalid/Email", defs) == {}

    def test_resolve_with_empty_defs(self):
        """Test resolving against empty defs dict."""
        assert _resolve_ref("#/$defs/Email", {}) == {}

    def test_resolve_nested_ref_name(self):
        """Test resolving a definition whose name contains an underscore and digit."""
        defs = {
            "User_v2": {"type": "object", "properties": {"id": {"type": "integer"}}}
        }
        assert _resolve_ref("#/$defs/User_v2", defs) == defs["User_v2"]


class TestIsComplexAnyof:
    """Tests for the _is_complex_anyof helper function."""

    def test_simple_optional_primitive_not_complex(self):
        """Test that Optional[str] (anyOf with null and string) is not complex."""
        schema = {"anyOf": [{"type": "string"}, {"type": "null"}]}
        assert _is_complex_anyof(schema) is False

    def test_optional_int_not_complex(self):
        """Test that Optional[int] is not complex."""
        schema = {"anyOf": [{"type": "integer"}, {"type": "null"}]}
        assert _is_complex_anyof(schema) is False

    def test_anyof_with_ref_is_complex(self):
        """Test that anyOf with a $ref is complex."""
        schema = {"anyOf": [{"$ref": "#/$defs/Email"}, {"type": "null"}]}
        assert _is_complex_anyof(schema) is True

    def test_anyof_with_nested_object_is_complex(self):
        """Test that anyOf with nested object (properties) is complex."""
        schema = {
            "anyOf": [
                {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "age": {"type": "integer"},
                    },
                },
                {"type": "null"},
            ]
        }
        assert _is_complex_anyof(schema) is True

    def test_anyof_with_multiple_types_no_ref_not_complex(self):
        """Test that anyOf with multiple primitives (no ref/props) is not complex."""
        schema = {"anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "null"}]}
        assert _is_complex_anyof(schema) is False

    def test_anyof_with_ref_and_primitive_is_complex(self):
        """Test that anyOf with both ref and primitive is complex."""
        schema = {"anyOf": [{"$ref": "#/$defs/Email"}, {"type": "string"}]}
        assert _is_complex_anyof(schema) is True

    def test_anyof_with_only_null_not_complex(self):
        """Test that anyOf with only null type is not complex."""
        schema = {"anyOf": [{"type": "null"}]}
        assert _is_complex_anyof(schema) is False

    def test_empty_anyof_not_complex(self):
        """Test that empty anyOf is not complex."""
        schema = {"anyOf": []}
        assert _is_complex_anyof(schema) is False

    def test_anyof_missing_from_schema_not_complex(self):
        """Test that schema without anyOf is not complex."""
        schema = {"type": "object"}
        assert _is_complex_anyof(schema) is False

    def test_anyof_with_allof_not_complex(self):
        """Test that an anyOf member wrapping allOf is not reported as complex.

        The helper only looks for $ref or properties directly on each anyOf
        member; it does not descend into allOf. Renamed from
        ``..._is_complex`` so the name agrees with the assertion.
        """
        schema = {
            "anyOf": [
                {
                    "allOf": [
                        {"$ref": "#/$defs/Base"},
                        {"properties": {"extra": {"type": "string"}}},
                    ]
                },
                {"type": "null"},
            ]
        }
        assert _is_complex_anyof(schema) is False

    def test_anyof_union_with_ref(self):
        """Test anyOf representing a union of multiple types including ref."""
        schema = {"anyOf": [{"$ref": "#/$defs/User"}, {"$ref": "#/$defs/Admin"}]}
        assert _is_complex_anyof(schema) is True

    def test_complex_nested_structure(self):
        """Test an optional object whose own property is another nested object."""
        schema = {
            "anyOf": [
                {
                    "type": "object",
                    "properties": {
                        "user": {
                            "type": "object",
                            "properties": {"name": {"type": "string"}},
                        }
                    },
                },
                {"type": "null"},
            ]
        }
        assert _is_complex_anyof(schema) is True


if __name__ == "__main__":
    pytest.main([__file__])