Skip to content

Commit 0831bcc

Browse files
committed
review comment
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent 1e78c81 commit 0831bcc

2 files changed

Lines changed: 78 additions & 13 deletions

File tree

mellea/backends/tools.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,11 @@ def _build_pydantic_type_from_schema(schema: dict[str, Any]) -> Any:
508508
convert_function_to_ollama_tool before validation. If a schema
509509
still contains $ref entries, this will return Any as a fallback.
510510
"""
511+
# Early exit for unresolved $ref: return Any to disable validation
512+
# rather than silently mistype as str
513+
if "$ref" in schema:
514+
return Any
515+
511516
json_type = schema.get("type", "string")
512517

513518
# Handle nested objects with properties
@@ -525,9 +530,12 @@ def _build_pydantic_type_from_schema(schema: dict[str, Any]) -> Any:
525530

526531
# Create a nested Pydantic model with deterministic name
527532
# Use sorted field names for consistent naming across runs
533+
# Respect strict mode: forbid extra fields if caller requested strict=True
528534
model_name = f"Nested_{'_'.join(sorted(nested_fields.keys()))}"
529535
return create_model(
530-
model_name, __config__=ConfigDict(extra="allow"), **nested_fields
536+
model_name,
537+
__config__=ConfigDict(extra="forbid" if strict else "allow"),
538+
**nested_fields,
531539
)
532540

533541
# Handle arrays
@@ -557,18 +565,28 @@ def _build_pydantic_type_from_schema(schema: dict[str, Any]) -> Any:
557565

558566
# Handle anyOf (union types in JSON schema)
559567
if "anyOf" in schema:
568+
# Filter out null sub-schemas before recursion to prevent them from
569+
# contaminating the union if an unresolved $ref returns Any.
570+
# Null becomes an explicit Optional[] wrapper instead.
571+
has_null = any(s.get("type") == "null" for s in schema["anyOf"])
572+
non_null_schemas = [s for s in schema["anyOf"] if s.get("type") != "null"]
573+
560574
types_list = []
561-
for sub_schema in schema["anyOf"]:
575+
for sub_schema in non_null_schemas:
562576
sub_type = _build_pydantic_type_from_schema(sub_schema)
563577
types_list.append(sub_type)
564578

565-
if len(types_list) == 1:
566-
return types_list[0]
579+
if len(types_list) == 0:
580+
result = Any
581+
elif len(types_list) == 1:
582+
result = types_list[0]
567583
else:
568584
from functools import reduce
569585
from operator import or_
570586

571-
return reduce(or_, types_list)
587+
result = reduce(or_, types_list)
588+
589+
return (result | None) if has_null else result
572590

573591
# Simple type mapping
574592
return JSON_TYPE_TO_PYTHON.get(json_type, Any)
@@ -918,8 +936,8 @@ def _is_complex_anyof(v: dict) -> bool:
918936
# Skip null types - they just indicate optionality
919937
if sub_schema.get("type") == "null":
920938
continue
921-
# Check for references or nested properties
922-
if "$ref" in sub_schema or "properties" in sub_schema:
939+
# Check for references, nested properties, or allOf (inherited models)
940+
if "$ref" in sub_schema or "properties" in sub_schema or "allOf" in sub_schema:
923941
return True
924942
return False
925943

@@ -958,7 +976,7 @@ def convert_function_to_ollama_tool(
958976
defs = schema.get("$defs", schema.get("definitions", {}))
959977

960978
for k, v in schema.get("properties", {}).items():
961-
# Check if this property has a $ref (reference to a definition)
979+
# First pass: inline all $refs (at top level and within anyOf)
962980
if "$ref" in v:
963981
# Resolve the reference and inline it
964982
ref_schema = _resolve_ref(v["$ref"], defs)
@@ -969,21 +987,36 @@ def convert_function_to_ollama_tool(
969987
if parsed_docstring.get(k):
970988
inlined["description"] = parsed_docstring[k]
971989
schema["properties"][k] = inlined
990+
v = inlined # Update v to point to inlined schema
972991
else:
973992
# Fallback if we can't resolve
974993
schema["properties"][k] = {
975994
"description": parsed_docstring.get(k, ""),
976995
"type": "object",
977996
}
978-
# Check if this property is a nested object (has 'properties' or complex types)
979-
# Narrow anyOf check to only complex unions, not Optional[primitive]
980-
elif (
981-
"properties" in v or "allOf" in v or ("anyOf" in v and _is_complex_anyof(v))
982-
):
997+
v = schema["properties"][k]
998+
999+
# Inline $refs within anyOf
1000+
if "anyOf" in v:
1001+
for i, sub_schema in enumerate(v["anyOf"]):
1002+
if "$ref" in sub_schema:
1003+
ref_schema = _resolve_ref(sub_schema["$ref"], defs)
1004+
if ref_schema:
1005+
# Inline the referenced schema
1006+
v["anyOf"][i] = copy.deepcopy(ref_schema)
1007+
1008+
# Second pass: determine how to handle the property type
1009+
if "properties" in v or "allOf" in v or ("anyOf" in v and _is_complex_anyof(v)):
9831010
# This is a complex/nested type - preserve the full schema
9841011
# Only add description if we have one
9851012
if parsed_docstring.get(k):
9861013
v["description"] = parsed_docstring[k]
1014+
# If anyOf contains null (making it Optional), remove from required
1015+
if "anyOf" in v and any(
1016+
t.get("type") == "null" for t in v.get("anyOf", [])
1017+
):
1018+
if k in schema.get("required", []):
1019+
schema["required"].remove(k)
9871020
schema["properties"][k] = v
9881021
else:
9891022
# Simple type - use the original flattening logic

test/backends/test_pydantic_tool_parameters.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,38 @@ def process_text(x: str, y: str | None = None) -> str:
474474
assert validated["x"] == "hello"
475475
assert validated["y"] == "world"
476476

477+
def test_optional_basemodel_not_required(self):
478+
"""Test def f(email: Email | None = None) - Optional BaseModel param.
479+
480+
Confirms optional BaseModel params are absent from required, and that
481+
validate_tool_arguments rejects a malformed nested dict in strict mode.
482+
"""
483+
484+
def send(email: Email | None = None) -> str:
485+
"""Send an email.
486+
487+
Args:
488+
email: The email to send
489+
"""
490+
return "ok"
491+
492+
tool = MelleaTool.from_callable(send)
493+
schema = tool.as_json_tool
494+
params = schema["function"]["parameters"]
495+
496+
# email should NOT be required
497+
assert "email" not in params.get("required", [])
498+
499+
# strict mode should reject a missing required nested field
500+
from pydantic import ValidationError
501+
502+
with pytest.raises(ValidationError):
503+
validate_tool_arguments(
504+
tool,
505+
{"email": {"to": "a@b.com"}},
506+
strict=True, # missing subject + body
507+
)
508+
477509
@pytest.mark.skip(
478510
reason="Nested model resolution not yet implemented. "
479511
"This test documents the expected behavior once recursive $ref resolution is added. "

0 commit comments

Comments
 (0)