diff --git a/agentkit/toolkit/cli/cli_config.py b/agentkit/toolkit/cli/cli_config.py index 6bc4494..8998e06 100644 --- a/agentkit/toolkit/cli/cli_config.py +++ b/agentkit/toolkit/cli/cli_config.py @@ -372,7 +372,7 @@ def _interactive_config(config_file: Optional[str] = None): ) common_config = create_common_config_interactively( - config.get_common_config().to_dict() + (config.get_raw_data() or {}).get("common", {}) ) config.update_common_config(common_config) diff --git a/agentkit/toolkit/cli/interactive_config.py b/agentkit/toolkit/cli/interactive_config.py index 181f8a5..30bb25a 100644 --- a/agentkit/toolkit/cli/interactive_config.py +++ b/agentkit/toolkit/cli/interactive_config.py @@ -178,6 +178,7 @@ def generate_config( dataclass_type: type, existing_config: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None, + carry_over_config: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: self.current_dataclass_type = dataclass_type if not is_dataclass(dataclass_type): @@ -185,6 +186,12 @@ def generate_config( config = {} existing_config = existing_config or {} + if carry_over_config is not None and not isinstance(carry_over_config, dict): + raise TypeError( + "carry_over_config must be a dict when provided; " + f"got {type(carry_over_config).__name__}" + ) + carry_over = existing_config if carry_over_config is None else carry_over_config # Get dataclass metadata # Try to get from class attributes; if not found, create instance to get field values @@ -267,8 +274,8 @@ def generate_config( if field.metadata.get("hidden", False) or field.metadata.get( "system", False ): - if field_name in existing_config: - config[field_name] = existing_config[field_name] + if isinstance(carry_over, dict) and field_name in carry_over: + config[field_name] = carry_over[field_name] # Filter out MISSING values filtered_config = {} @@ -1222,8 +1229,14 @@ def generate_config_from_dataclass( dataclass_type: type, existing_config: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None, + carry_over_config: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - return auto_prompt.generate_config(dataclass_type, existing_config, context=context) + return auto_prompt.generate_config( + dataclass_type, + existing_config, + context=context, + carry_over_config=carry_over_config, + ) def create_common_config_interactively( @@ -1245,6 +1258,11 @@ def create_common_config_interactively( """ from agentkit.toolkit.config import CommonConfig - existing = CommonConfig.from_dict(existing_config or {}) - config_dict = auto_prompt.generate_config(CommonConfig, existing.to_dict()) + raw_existing_config = existing_config or {} + existing = CommonConfig.from_dict(raw_existing_config) + config_dict = auto_prompt.generate_config( + CommonConfig, + existing.to_dict(), + carry_over_config=raw_existing_config, + ) return CommonConfig.from_dict(config_dict) diff --git a/agentkit/version.py b/agentkit/version.py index 36512d8..ddf32e6 100644 --- a/agentkit/version.py +++ b/agentkit/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -VERSION = "0.5.4" +VERSION = "0.5.5" diff --git a/pyproject.toml b/pyproject.toml index 4a1642b..9ef310d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "agentkit-sdk-python" -version = "0.5.4" +version = "0.5.5" description = "Python SDK for transforming any AI agent into a production-ready application. Framework-agnostic primitives for runtime, memory, authentication, and tools with volcengine-managed infrastructure." readme = "README.md" requires-python = ">=3.10" diff --git a/tests/toolkit/cli/test_cli_config_interactive_provider_resolution.py b/tests/toolkit/cli/test_cli_config_interactive_provider_resolution.py index 6bc2e3d..9a29794 100644 --- a/tests/toolkit/cli/test_cli_config_interactive_provider_resolution.py +++ b/tests/toolkit/cli/test_cli_config_interactive_provider_resolution.py @@ -73,3 +73,64 @@ def fake_generate_config_from_dataclass( clear_config_cache() cli_config._interactive_config(config_file=str(config_path)) + + +def test_interactive_config_common_input_uses_raw_yaml_common( + tmp_path: Path, monkeypatch +) -> None: + from agentkit.toolkit.cli import cli_config + from agentkit.toolkit.config import CommonConfig + from agentkit.toolkit.config.config import clear_config_cache + import agentkit.toolkit.config.global_config as global_cfg_mod + from agentkit.toolkit.config.global_config import GlobalConfig + import agentkit.toolkit.cli.interactive_config as interactive_config + + config_path = tmp_path / "agentkit.yaml" + raw_common = { + "agent_name": "demo", + "entry_point": "agent.py", + "launch_type": "cloud", + } + config_path.write_text( + yaml.safe_dump( + { + "common": raw_common, + "launch_types": {"cloud": {}}, + "docker_build": {}, + }, + sort_keys=False, + allow_unicode=True, + ), + encoding="utf-8", + ) + + global_cfg = GlobalConfig() + global_cfg.defaults.cloud_provider = "byteplus" + monkeypatch.setattr(global_cfg_mod, "get_global_config", lambda: global_cfg) + + def fake_create_common_config_interactively(existing_config): + assert isinstance(existing_config, dict) + assert "cloud_provider" not in existing_config + for k, v in raw_common.items(): + assert existing_config.get(k) == v + return CommonConfig.from_dict(existing_config or {}) + + def fake_generate_config_from_dataclass( + _dataclass_type, existing_config=None, context=None + ): + assert context == {"cloud_provider": "byteplus"} + return {"region": "ap-southeast-1"} + + monkeypatch.setattr( + interactive_config, + "create_common_config_interactively", + fake_create_common_config_interactively, + ) + monkeypatch.setattr( + interactive_config, + "generate_config_from_dataclass", + fake_generate_config_from_dataclass, + ) + + clear_config_cache() + cli_config._interactive_config(config_file=str(config_path)) diff --git a/tests/toolkit/cli/test_interactive_hidden_fields_carry_over.py b/tests/toolkit/cli/test_interactive_hidden_fields_carry_over.py new file mode 100644 index 0000000..b751d5f --- /dev/null +++ b/tests/toolkit/cli/test_interactive_hidden_fields_carry_over.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import pytest + + +def test_generate_config_rejects_non_dict_carry_over_config(monkeypatch) -> None: + from agentkit.toolkit.cli.interactive_config import AutoPromptGenerator + from agentkit.toolkit.config import CommonConfig + + monkeypatch.setattr( + AutoPromptGenerator, "_show_welcome_panel", lambda *a, **k: None + ) + monkeypatch.setattr( + AutoPromptGenerator, "_show_completion_panel", lambda *a, **k: None + ) + monkeypatch.setattr( + AutoPromptGenerator, + "_prompt_for_field", + lambda self, + name, + field_type, + description, + default, + metadata=None, + current=1, + total=1, + current_config=None, + resolver_context=None: default, + ) + + generator = AutoPromptGenerator() + + with pytest.raises(TypeError, match=r"carry_over_config must be a dict"): + generator.generate_config(CommonConfig, {}, carry_over_config=[]) # type: ignore[arg-type] + + +def test_generate_config_does_not_carry_hidden_fields_from_prefill_when_carry_over_missing( + monkeypatch, +) -> None: + from agentkit.toolkit.cli.interactive_config import AutoPromptGenerator + from agentkit.toolkit.config import CommonConfig + + monkeypatch.setattr( + AutoPromptGenerator, "_show_welcome_panel", lambda *a, **k: None + ) + monkeypatch.setattr( + AutoPromptGenerator, "_show_completion_panel", lambda *a, **k: None + ) + monkeypatch.setattr( + AutoPromptGenerator, + "_prompt_for_field", + lambda self, + name, + field_type, + description, + default, + metadata=None, + current=1, + total=1, + current_config=None, + resolver_context=None: default, + ) + + generator = AutoPromptGenerator() + + prefill = CommonConfig.from_dict({}).to_dict() + carry_over = { + "agent_name": "demo", + "entry_point": "agent.py", + "launch_type": "cloud", + } + + result = generator.generate_config( + CommonConfig, + prefill, + carry_over_config=carry_over, + ) + + assert "cloud_provider" not in result + + +def test_generate_config_carries_hidden_fields_from_carry_over(monkeypatch) -> None: + from agentkit.toolkit.cli.interactive_config import AutoPromptGenerator + from agentkit.toolkit.config import CommonConfig + + monkeypatch.setattr( + AutoPromptGenerator, "_show_welcome_panel", lambda *a, **k: None + ) + monkeypatch.setattr( + AutoPromptGenerator, "_show_completion_panel", lambda *a, **k: None + ) + monkeypatch.setattr( + AutoPromptGenerator, + "_prompt_for_field", + lambda self, + name, + field_type, + description, + default, + metadata=None, + current=1, + total=1, + current_config=None, + resolver_context=None: default, + ) + + generator = AutoPromptGenerator() + + prefill = CommonConfig.from_dict({}).to_dict() + carry_over = { + "agent_name": "demo", + "entry_point": "agent.py", + "launch_type": "cloud", + "cloud_provider": "byteplus", + } + + result = generator.generate_config( + CommonConfig, + prefill, + carry_over_config=carry_over, + ) + + assert result["cloud_provider"] == "byteplus" + + +def test_generate_config_default_behavior_carries_hidden_fields_from_existing_config( + monkeypatch, +) -> None: + from agentkit.toolkit.cli.interactive_config import AutoPromptGenerator + from agentkit.toolkit.config import CommonConfig + + monkeypatch.setattr( + AutoPromptGenerator, "_show_welcome_panel", lambda *a, **k: None + ) + monkeypatch.setattr( + AutoPromptGenerator, "_show_completion_panel", lambda *a, **k: None + ) + monkeypatch.setattr( + AutoPromptGenerator, + "_prompt_for_field", + lambda self, + name, + field_type, + description, + default, + metadata=None, + current=1, + total=1, + current_config=None, + resolver_context=None: default, + ) + + generator = AutoPromptGenerator() + + existing_config = { + "agent_name": "demo", + "entry_point": "agent.py", + "launch_type": "cloud", + "cloud_provider": "byteplus", + } + + result = generator.generate_config(CommonConfig, existing_config) + + assert result["cloud_provider"] == "byteplus"