diff --git a/tests/unit/vertexai/genai/replays/test_skills_delete.py b/tests/unit/vertexai/genai/replays/test_skills_delete.py new file mode 100644 index 0000000000..92bb1a8228 --- /dev/null +++ b/tests/unit/vertexai/genai/replays/test_skills_delete.py @@ -0,0 +1,44 @@ +"""Tests the skills.delete() method against the autopush endpoint.""" + +from tests.unit.vertexai.genai.replays import pytest_helper +from vertexai._genai import types +from google.genai import errors +import pytest + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + + +def test_delete_skill(client, tmp_path): + # Target the autopush sandbox endpoint for the Skill Registry API + client._api_client._http_options.base_url = ( + "https://us-central1-autopush-aiplatform.sandbox.googleapis.com" + ) + + # 1. Create a fresh unique skill first + with open(tmp_path / "SKILL.md", "w") as f: + f.write("# Test Skill\nTo be deleted.") + + created_skill = client.skills.create( + display_name="To Be Deleted Skill", + description="Skill to be deleted", + config=types.CreateSkillConfig( + local_path=str(tmp_path), wait_for_completion=True + ), + ) + + assert created_skill.name is not None + + # 2. Delete the skill and wait for LRO completion + client.skills.delete( + name=created_skill.name, + config=types.DeleteSkillConfig(wait_for_completion=True), + ) + + # 3. 
Verify the skill is successfully deleted (Getting it should raise NotFound) + with pytest.raises(errors.ClientError) as exc_info: + client.skills.get(name=created_skill.name) + + assert exc_info.value.code == 404 diff --git a/tests/unit/vertexai/genai/replays/test_skills_list.py b/tests/unit/vertexai/genai/replays/test_skills_list.py new file mode 100644 index 0000000000..173095cc75 --- /dev/null +++ b/tests/unit/vertexai/genai/replays/test_skills_list.py @@ -0,0 +1,21 @@ +"""Tests the skills.list() method against the autopush endpoint.""" + +from tests.unit.vertexai.genai.replays import pytest_helper +from vertexai._genai import types + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + + +def test_list_skills(client): + # Target the autopush sandbox endpoint for the Skill Registry API + client._api_client._http_options.base_url = ( + "https://us-central1-autopush-aiplatform.sandbox.googleapis.com" + ) + + skills = client.skills.list() + for skill in skills: + assert isinstance(skill, types.Skill) + assert skill.name is not None diff --git a/tests/unit/vertexai/genai/replays/test_skills_update.py b/tests/unit/vertexai/genai/replays/test_skills_update.py new file mode 100644 index 0000000000..f001be7d08 --- /dev/null +++ b/tests/unit/vertexai/genai/replays/test_skills_update.py @@ -0,0 +1,103 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""Tests the skills.update() method against the Vertex AI endpoint using replays.""" + +import io +import os +import zipfile + +from tests.unit.vertexai.genai.replays import pytest_helper +from vertexai._genai import types + +# MANDATORY: Initialize the replay test framework for this module +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + +PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT", "srbai-testing") +REGION = "us-central1" + + +def test_update_skill(client, tmp_path): + # Target the autopush sandbox endpoint for the Skill Registry API + client._api_client._http_options.base_url = ( + "https://us-central1-autopush-aiplatform.sandbox.googleapis.com" + ) + + # 1. Create a fresh unique skill first + with open(tmp_path / "SKILL.md", "w") as f: + f.write("# Test Skill\nInitial content.") + + created_skill = client.skills.create( + display_name="Original Skill", + description="Original Description", + config=types.CreateSkillConfig( + local_path=str(tmp_path), wait_for_completion=True + ), + ) + + # 2. Perform the metadata-only update on the new skill + updated_skill = client.skills.update( + name=created_skill.name, + config=types.UpdateSkillConfig( + display_name="My Updated Replay Skill", + description="My Updated Replay Skill Description", + wait_for_completion=True, + ), + ) + + assert updated_skill.name == created_skill.name + assert updated_skill.display_name == "My Updated Replay Skill" + assert updated_skill.description == "My Updated Replay Skill Description" + + +def test_update_skill_with_zipped_bytes(client, tmp_path): + # Target the autopush sandbox endpoint for the Skill Registry API + client._api_client._http_options.base_url = ( + "https://us-central1-autopush-aiplatform.sandbox.googleapis.com" + ) + + # 1. 
Create a fresh unique skill first + with open(tmp_path / "SKILL.md", "w") as f: + f.write("# Test Skill\nInitial content.") + + created_skill = client.skills.create( + display_name="Original Skill", + description="Original Description", + config=types.CreateSkillConfig( + local_path=str(tmp_path), wait_for_completion=True + ), + ) + + # 2. Prepare zipped bytes for update + zip_buffer = io.BytesIO() + zinfo = zipfile.ZipInfo("SKILL.md", date_time=(1980, 1, 1, 0, 0, 0)) + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr(zinfo, "# My Updated Zipped Replay Skill\nThis is updated.") + zipped_bytes = zip_buffer.getvalue() + + # 3. Update the skill with new zipped bytes + updated_skill = client.skills.update( + name=created_skill.name, + config=types.UpdateSkillConfig( + zipped_filesystem=zipped_bytes, wait_for_completion=True + ), + ) + + assert updated_skill.name == created_skill.name + assert ( + updated_skill.display_name == "Original Skill" + ) # Display name remains unchanged diff --git a/tests/unit/vertexai/genai/test_genai_skills.py b/tests/unit/vertexai/genai/test_genai_skills.py index cc6a54ecaf..bdfd40d5b7 100644 --- a/tests/unit/vertexai/genai/test_genai_skills.py +++ b/tests/unit/vertexai/genai/test_genai_skills.py @@ -1,5 +1,21 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# # //third_party/py/google/cloud/aiplatform/tests/unit/vertexai/genai/test_genai_skills.py import json +import os +import tempfile from unittest import mock import google.auth.credentials from vertexai import _genai as genai @@ -29,6 +45,8 @@ def async_skills_client(): class TestGenaiSkills: + """Tests the Genai Skills client.""" + mock_get_skill_response = { "name": "projects/test-project/locations/test-location/skills/test-skill", "displayName": "My Test Skill", @@ -144,3 +162,662 @@ async def test_retrieve_skills_async(self, async_skills_client): {"_query": {"query": "test query", "topK": 1}}, None, ) + + def test_create_skill(self, skills_client): + """Tests the create_skill method with wait_for_completion=True.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create a dummy file in tmpdir + with open(os.path.join(tmpdir, "SKILL.md"), "w") as f: + f.write("# Test Skill") + + # Prepare mock responses + pending_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-123" + ), + "done": False, + } + finished_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-123" + ), + "done": True, + "response": { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill" + ), + "displayName": "My Test Skill", + "description": "My Test Skill Description", + }, + } + + # Final Skill response returned by get call + skill_response = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill" + ), + "displayName": "My Test Skill", + "description": "My Test Skill Description", + } + + with mock.patch.object( + skills_client._api_client, "request", autospec=True + ) as request_mock: + request_mock.side_effect = [ + genai_types.HttpResponse(body=json.dumps(pending_op)), + genai_types.HttpResponse(body=json.dumps(finished_op)), + genai_types.HttpResponse(body=json.dumps(skill_response)), + ] + + # We mock time.sleep to speed up the test + with 
mock.patch("time.sleep", return_value=None): + skill = skills_client.create( + display_name="My Test Skill", + description="My Test Skill Description", + config={"local_path": tmpdir, "wait_for_completion": True}, + ) + + # Assertions + assert request_mock.call_count == 3 + + # Verify POST request + post_call = request_mock.call_args_list[0] + assert post_call[0][0] == "post" + assert post_call[0][1] == "skills" + + post_body = post_call[0][2] + assert post_body["displayName"] == "My Test Skill" + assert post_body["description"] == "My Test Skill Description" + assert isinstance(post_body["zippedFilesystem"], str) + + # Verify GET request (polling) + get_call = request_mock.call_args_list[1] + assert get_call[0][0] == "get" + assert ( + get_call[0][1] + == "projects/test-project/locations/test-location/skills/test-skill/operations/op-123" + ) + + # Verify final GET request to fetch the skill + get_skill_call = request_mock.call_args_list[2] + assert get_skill_call[0][0] == "get" + assert ( + get_skill_call[0][1] + == "projects/test-project/locations/test-location/skills/test-skill" + ) + + # Verify returned skill + assert isinstance(skill, genai.types.Skill) + assert ( + skill.name + == "projects/test-project/locations/test-location/skills/test-skill" + ) + assert skill.display_name == "My Test Skill" + assert skill.description == "My Test Skill Description" + + def test_create_skill_no_wait(self, skills_client): + """Tests the create_skill method with wait_for_completion=False.""" + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "SKILL.md"), "w") as f: + f.write("# Test Skill") + + pending_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-123" + ), + "done": False, + } + + with mock.patch.object( + skills_client._api_client, "request", autospec=True + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse( + body=json.dumps(pending_op) + ) + + operation = 
skills_client.create( + display_name="My Test Skill", + description="My Test Skill Description", + config={"local_path": tmpdir, "wait_for_completion": False}, + ) + + # Assertions + assert request_mock.call_count == 1 + assert isinstance(operation, genai.types.SkillOperation) + assert ( + operation.name + == "projects/test-project/locations/test-location/skills/test-skill/operations/op-123" + ) + assert not operation.done + + @pytest.mark.asyncio + async def test_create_skill_async(self, async_skills_client): + """Tests the create_skill method asynchronously with wait_for_completion=True.""" + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "SKILL.md"), "w") as f: + f.write("# Test Skill") + + pending_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-123" + ), + "done": False, + } + finished_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-123" + ), + "done": True, + "response": { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill" + ), + "displayName": "My Test Skill", + "description": "My Test Skill Description", + }, + } + + # Final Skill response returned by async get call + skill_response = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill" + ), + "displayName": "My Test Skill", + "description": "My Test Skill Description", + } + + with mock.patch.object( + async_skills_client._api_client, "async_request", autospec=True + ) as request_mock: + request_mock.side_effect = [ + genai_types.HttpResponse(body=json.dumps(pending_op)), + genai_types.HttpResponse(body=json.dumps(finished_op)), + genai_types.HttpResponse(body=json.dumps(skill_response)), + ] + + with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock): + skill = await async_skills_client.create( + display_name="My Test Skill", + description="My Test Skill Description", + config={"local_path": tmpdir, 
"wait_for_completion": True}, + ) + + # Assertions + assert request_mock.call_count == 3 + + # Verify POST request + post_call = request_mock.call_args_list[0] + assert post_call[0][0] == "post" + assert post_call[0][1] == "skills" + + # Verify final GET request to fetch the skill + get_skill_call = request_mock.call_args_list[2] + assert get_skill_call[0][0] == "get" + assert ( + get_skill_call[0][1] + == "projects/test-project/locations/test-location/skills/test-skill" + ) + + # Verify returned skill + assert isinstance(skill, genai.types.Skill) + assert ( + skill.name + == "projects/test-project/locations/test-location/skills/test-skill" + ) + assert skill.display_name == "My Test Skill" + assert skill.description == "My Test Skill Description" + + def test_update_skill(self, skills_client): + """Tests the update method with wait_for_completion=True (default).""" + skill_name = "projects/test-project/locations/test-location/skills/test-skill" + # Prepare mock responses + pending_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-456" + ), + "done": False, + } + finished_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-456" + ), + "done": True, + "response": { + "name": skill_name, + "displayName": "Updated Skill", + "description": "Updated Description", + }, + } + skill_response = { + "name": skill_name, + "displayName": "Updated Skill", + "description": "Updated Description", + } + + with mock.patch.object( + skills_client._api_client, "request", autospec=True + ) as request_mock: + request_mock.side_effect = [ + genai_types.HttpResponse(body=json.dumps(pending_op)), + genai_types.HttpResponse(body=json.dumps(finished_op)), + genai_types.HttpResponse(body=json.dumps(skill_response)), + ] + + with mock.patch("time.sleep", return_value=None): + skill = skills_client.update( + name=skill_name, + config={ + "display_name": "Updated Skill", + "description": "Updated 
Description", + }, + ) + + # Assertions + assert request_mock.call_count == 3 + + # Verify PATCH request + patch_call = request_mock.call_args_list[0] + assert patch_call[0][0] == "patch" + assert ( + patch_call[0][1] == f"{skill_name}?updateMask=displayName%2Cdescription" + ) + + patch_body = patch_call[0][2] + assert patch_body["displayName"] == "Updated Skill" + assert patch_body["description"] == "Updated Description" + + # Verify GET request (polling) + get_call = request_mock.call_args_list[1] + assert get_call[0][0] == "get" + assert ( + get_call[0][1] + == "projects/test-project/locations/test-location/skills/test-skill/operations/op-456" + ) + + # Verify final GET request to fetch the skill + get_skill_call = request_mock.call_args_list[2] + assert get_skill_call[0][0] == "get" + assert get_skill_call[0][1] == skill_name + + # Verify returned skill + assert isinstance(skill, genai.types.Skill) + assert skill.name == skill_name + assert skill.display_name == "Updated Skill" + assert skill.description == "Updated Description" + + def test_update_skill_no_wait(self, skills_client): + """Tests the update method with wait_for_completion=False.""" + skill_name = "projects/test-project/locations/test-location/skills/test-skill" + pending_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-456" + ), + "done": False, + } + + with mock.patch.object( + skills_client._api_client, "request", autospec=True + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse( + body=json.dumps(pending_op) + ) + + operation = skills_client.update( + name=skill_name, + config={ + "display_name": "Updated Skill", + "wait_for_completion": False, + }, + ) + + # Assertions + assert isinstance(operation, genai.types.SkillOperation) + assert ( + operation.name + == "projects/test-project/locations/test-location/skills/test-skill/operations/op-456" + ) + assert not operation.done + + request_mock.assert_called_once_with( + 
"patch", + f"{skill_name}?updateMask=displayName", + { + "_url": { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill" + ) + }, + "displayName": "Updated Skill", + "_query": { + "updateMask": "displayName", + }, + }, + None, + ) + + @pytest.mark.asyncio + async def test_update_skill_async(self, async_skills_client): + """Tests the async update method with wait_for_completion=True.""" + skill_name = "projects/test-project/locations/test-location/skills/test-skill" + pending_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-456" + ), + "done": False, + } + finished_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-456" + ), + "done": True, + "response": { + "name": skill_name, + "displayName": "Updated Skill", + }, + } + skill_response = { + "name": skill_name, + "displayName": "Updated Skill", + } + + with mock.patch.object( + async_skills_client._api_client, "async_request", autospec=True + ) as request_mock: + request_mock.side_effect = [ + genai_types.HttpResponse(body=json.dumps(pending_op)), + genai_types.HttpResponse(body=json.dumps(finished_op)), + genai_types.HttpResponse(body=json.dumps(skill_response)), + ] + + with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock): + skill = await async_skills_client.update( + name=skill_name, + config={ + "display_name": "Updated Skill", + "wait_for_completion": True, + }, + ) + + # Assertions + assert request_mock.call_count == 3 + + # Verify PATCH request + patch_call = request_mock.call_args_list[0] + assert patch_call[0][0] == "patch" + assert patch_call[0][1] == f"{skill_name}?updateMask=displayName" + + # Verify returned skill + assert isinstance(skill, genai.types.Skill) + assert skill.name == skill_name + assert skill.display_name == "Updated Skill" + + def test_update_skill_invalid_args(self, skills_client): + """Verifies ValueError is raised when no update fields are 
provided.""" + skill_name = "projects/test-project/locations/test-location/skills/test-skill" + with pytest.raises( + ValueError, + match=( + "At least one of `display_name`, `description`, `local_path`, or" + " `zipped_filesystem` must be provided for update in config." + ), + ): + skills_client.update(name=skill_name) + + def test_update_skill_mutually_exclusive_args(self, skills_client): + """Verifies ValueError is raised when both local_path and zipped_filesystem are provided.""" + skill_name = "projects/test-project/locations/test-location/skills/test-skill" + with pytest.raises( + ValueError, + match="Only one of `local_path` or `zipped_filesystem` can be provided", + ): + skills_client.update( + name=skill_name, + config={ + "local_path": "/some/path", + "zipped_filesystem": b"zipped_bytes", + }, + ) + + def test_list_skills(self, skills_client): + """Tests the list method using the standard Pager.""" + mock_list_response = { + "skills": [ + { + "name": ( + "projects/test-project/locations/test-location/skills/skill-1" + ), + "displayName": "Skill 1", + }, + { + "name": ( + "projects/test-project/locations/test-location/skills/skill-2" + ), + "displayName": "Skill 2", + }, + ], + "nextPageToken": "token-123", + } + mock_list_response_page_2 = { + "skills": [ + { + "name": ( + "projects/test-project/locations/test-location/skills/skill-3" + ), + "displayName": "Skill 3", + }, + ], + } + + with mock.patch.object( + skills_client._api_client, "request", autospec=True + ) as request_mock: + request_mock.side_effect = [ + genai_types.HttpResponse(body=json.dumps(mock_list_response)), + genai_types.HttpResponse(body=json.dumps(mock_list_response_page_2)), + ] + + skills = list(skills_client.list()) + + # Verify Pager correct retrieval across pages + assert len(skills) == 3 + assert skills[0].display_name == "Skill 1" + assert skills[1].display_name == "Skill 2" + assert skills[2].display_name == "Skill 3" + + # Verify requests using robust assert_has_calls + 
request_mock.assert_has_calls( + [ + mock.call( + "get", + "skills", + {}, + None, + ), + mock.call( + "get", + "skills?pageToken=token-123", + {"_query": {"pageToken": "token-123"}}, + None, + ), + ] + ) + + @pytest.mark.asyncio + async def test_list_skills_async(self, async_skills_client): + """Tests the async list method returning AsyncPager.""" + mock_list_response = { + "skills": [ + { + "name": ( + "projects/test-project/locations/test-location/skills/skill-1" + ), + "displayName": "Skill 1", + }, + ], + } + + with mock.patch.object( + async_skills_client._api_client, "async_request", autospec=True + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse( + body=json.dumps(mock_list_response) + ) + + skills = [] + pager = await async_skills_client.list() + async for skill in pager: + skills.append(skill) + + assert len(skills) == 1 + assert skills[0].display_name == "Skill 1" + request_mock.assert_called_once_with( + "get", + "skills", + {}, + None, + ) + + def test_delete_skill(self, skills_client): + """Tests the delete method with wait_for_completion=True (default).""" + skill_name = "projects/test-project/locations/test-location/skills/test-skill" + pending_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-789" + ), + "done": False, + } + finished_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-789" + ), + "done": True, + "response": {}, + } + + with mock.patch.object( + skills_client._api_client, "request", autospec=True + ) as request_mock: + request_mock.side_effect = [ + genai_types.HttpResponse(body=json.dumps(pending_op)), + genai_types.HttpResponse(body=json.dumps(finished_op)), + ] + + with mock.patch("time.sleep", return_value=None): + result = skills_client.delete(name=skill_name) + + assert result is None + + # Verify both DELETE and LRO GET requests using robust assert_has_calls + request_mock.assert_has_calls( + [ + 
mock.call( + "delete", + skill_name, + {"_url": {"name": skill_name}}, + None, + ), + mock.call( + "get", + "projects/test-project/locations/test-location/skills/test-skill/operations/op-789", + { + "_url": { + "operationName": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-789" + ) + } + }, + None, + ), + ] + ) + + def test_delete_skill_no_wait(self, skills_client): + """Tests the delete method with wait_for_completion=False.""" + skill_name = "projects/test-project/locations/test-location/skills/test-skill" + pending_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-789" + ), + "done": False, + } + + with mock.patch.object( + skills_client._api_client, "request", autospec=True + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse( + body=json.dumps(pending_op) + ) + + operation = skills_client.delete( + name=skill_name, config={"wait_for_completion": False} + ) + + assert isinstance(operation, genai.types.DeleteSkillOperation) + assert ( + operation.name + == "projects/test-project/locations/test-location/skills/test-skill/operations/op-789" + ) + assert not operation.done + request_mock.assert_called_once_with( + "delete", + skill_name, + {"_url": {"name": skill_name}}, + None, + ) + + @pytest.mark.asyncio + async def test_delete_skill_async(self, async_skills_client): + """Tests the async delete method with wait_for_completion=True.""" + skill_name = "projects/test-project/locations/test-location/skills/test-skill" + pending_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-789" + ), + "done": False, + } + finished_op = { + "name": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-789" + ), + "done": True, + "response": {}, + } + + with mock.patch.object( + async_skills_client._api_client, "async_request", autospec=True + ) as request_mock: + request_mock.side_effect = [ + 
genai_types.HttpResponse(body=json.dumps(pending_op)), + genai_types.HttpResponse(body=json.dumps(finished_op)), + ] + + with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock): + result = await async_skills_client.delete( + name=skill_name, config={"wait_for_completion": True} + ) + + assert result is None + + # Verify both DELETE and LRO GET requests asynchronously using robust assert_has_calls + request_mock.assert_has_calls( + [ + mock.call( + "delete", + skill_name, + {"_url": {"name": skill_name}}, + None, + ), + mock.call( + "get", + "projects/test-project/locations/test-location/skills/test-skill/operations/op-789", + { + "_url": { + "operationName": ( + "projects/test-project/locations/test-location/skills/test-skill/operations/op-789" + ) + } + }, + None, + ), + ] + ) diff --git a/vertexai/_genai/skills.py b/vertexai/_genai/skills.py index 665a07ad26..9d9035ea31 100644 --- a/vertexai/_genai/skills.py +++ b/vertexai/_genai/skills.py @@ -19,13 +19,14 @@ import base64 import json import logging -from typing import Any, Optional, Union +from typing import Any, Iterator, Optional, Union from urllib.parse import urlencode from google.genai import _api_module from google.genai import _common from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv +from google.genai.pagers import AsyncPager, Pager from . import _skills_utils from . 
import types @@ -69,6 +70,17 @@ def _CreateSkillRequestParameters_to_vertex( return to_object +def _DeleteSkillRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + def _GetSkillOperationParameters_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -96,6 +108,36 @@ def _GetSkillRequestParameters_to_vertex( return to_object +def _ListSkillsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + return to_object + + +def _ListSkillsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _ListSkillsConfig_to_vertex(getv(from_object, ["config"]), to_object), + ) + + return to_object + + def _RetrieveSkillsConfig_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -126,6 +168,47 @@ def _RetrieveSkillsRequestParameters_to_vertex( return to_object +def _UpdateSkillConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], 
getv(from_object, ["display_name"])) + + if getv(from_object, ["description"]) is not None: + setv(parent_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["zipped_filesystem"]) is not None: + setv( + parent_object, + ["zippedFilesystem"], + getv(from_object, ["zipped_filesystem"]), + ) + + if getv(from_object, ["update_mask"]) is not None: + setv( + parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"]) + ) + + return to_object + + +def _UpdateSkillRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _UpdateSkillConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + class Skills(_api_module.BaseModule): """Class for managing Skills in the Skill Registry.""" @@ -341,14 +424,15 @@ def _create( self._api_client._verify_response(return_value) return return_value - def _get_skill_operation( - self, - *, - operation_name: str, - config: Optional[types.GetSkillOperationConfigOrDict] = None, + def _update( + self, *, name: str, config: Optional[types.UpdateSkillConfigOrDict] = None ) -> types.SkillOperation: - parameter_model = types._GetSkillOperationParameters( - operation_name=operation_name, + """ + Updates a Skill. + """ + + parameter_model = types._UpdateSkillRequestParameters( + name=name, config=config, ) @@ -358,12 +442,12 @@ def _get_skill_operation( "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
) else: - request_dict = _GetSkillOperationParameters_to_vertex(parameter_model) + request_dict = _UpdateSkillRequestParameters_to_vertex(parameter_model) request_url_dict = request_dict.get("_url") if request_url_dict: - path = "{operationName}".format_map(request_url_dict) + path = "{name}".format_map(request_url_dict) else: - path = "{operationName}" + path = "{name}" query_params = request_dict.get("_query") if query_params: @@ -381,7 +465,7 @@ def _get_skill_operation( request_dict = _common.convert_to_dict(request_dict) request_dict = _common.encode_unserializable_types(request_dict) - response = self._api_client.request("get", path, request_dict, http_options) + response = self._api_client.request("patch", path, request_dict, http_options) response_dict = {} if not response.body else json.loads(response.body) @@ -409,102 +493,14 @@ def _get_skill_operation( self._api_client._verify_response(return_value) return return_value - def create( - self, - *, - display_name: str, - description: str, - config: Optional[types.CreateSkillConfigOrDict] = None, - ) -> Union[types.Skill, types.SkillOperation]: - """Creates a new Skill. - - Args: - display_name (str): - Required. The display name of the Skill. - description (str): - Required. The description of the Skill. - config (CreateSkillConfigOrDict): - Optional. The configuration for creating the Skill. - - Returns: - Skill: The created Skill if wait_for_completion is True. - SkillOperation: The operation for creating the Skill if - wait_for_completion is False. - """ - if config is None: - config = types.CreateSkillConfig() - elif isinstance(config, dict): - config = types.CreateSkillConfig.model_validate(config) - elif not isinstance(config, types.CreateSkillConfig): - raise TypeError( - f"config must be a dict or CreateSkillConfig, but got {type(config)}." 
- ) - - config = config.model_copy() - - local_path = config.local_path - zipped_filesystem = config.zipped_filesystem - - if local_path and zipped_filesystem: - raise ValueError( - "Only one of `local_path` or `zipped_filesystem` can be provided in config." - ) - if not local_path and not zipped_filesystem: - raise ValueError( - "Either `local_path` or `zipped_filesystem` must be provided in config." - ) - - if local_path: - zipped_filesystem_payload = _skills_utils.get_zipped_filesystem_payload( - local_path - ) - else: - # Narrow type for mypy - if zipped_filesystem is None: - raise ValueError( - "zipped_filesystem is required if local_path is not provided." - ) - if isinstance(zipped_filesystem, bytes): - zipped_filesystem_payload = base64.b64encode(zipped_filesystem).decode( - "utf-8" - ) - else: - zipped_filesystem_payload = zipped_filesystem - - # Mutate the config object to populate the zipped_filesystem payload - config.zipped_filesystem = zipped_filesystem_payload - - operation = self._create( - display_name=display_name, - description=description, - config=config, - ) - - if config.wait_for_completion: - operation = _skills_utils.await_operation( - operation_name=operation.name, - get_operation_fn=self._get_skill_operation, - ) - if operation.error: - raise RuntimeError(f"Failed to create Skill: {operation.error}") - # Fetch the fully populated Skill resource from the server - return self.get(name=operation.response.name) - - return operation - - -class AsyncSkills(_api_module.BaseModule): - """Class for managing Skills in the Skill Registry.""" - - async def get( - self, *, name: str, config: Optional[types.GetSkillConfigOrDict] = None - ) -> types.Skill: + def _list( + self, *, config: Optional[types.ListSkillsConfigOrDict] = None + ) -> types.ListSkillsResponse: """ - Gets a Skill. + Lists Skills in the Skill Registry. 
""" - parameter_model = types._GetSkillRequestParameters( - name=name, + parameter_model = types._ListSkillsRequestParameters( config=config, ) @@ -514,12 +510,12 @@ async def get( "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." ) else: - request_dict = _GetSkillRequestParameters_to_vertex(parameter_model) + request_dict = _ListSkillsRequestParameters_to_vertex(parameter_model) request_url_dict = request_dict.get("_url") if request_url_dict: - path = "{name}".format_map(request_url_dict) + path = "skills".format_map(request_url_dict) else: - path = "{name}" + path = "skills" query_params = request_dict.get("_query") if query_params: @@ -537,13 +533,11 @@ async def get( request_dict = _common.convert_to_dict(request_dict) request_dict = _common.encode_unserializable_types(request_dict) - response = await self._api_client.async_request( - "get", path, request_dict, http_options - ) + response = self._api_client.request("get", path, request_dict, http_options) response_dict = {} if not response.body else json.loads(response.body) - return_value = types.Skill._from_response( + return_value = types.ListSkillsResponse._from_response( response=response_dict, kwargs=( { @@ -567,15 +561,15 @@ async def get( self._api_client._verify_response(return_value) return return_value - async def retrieve( - self, *, query: str, config: Optional[types.RetrieveSkillsConfigOrDict] = None - ) -> types.RetrieveSkillsResponse: + def _delete( + self, *, name: str, config: Optional[types.DeleteSkillConfigOrDict] = None + ) -> types.DeleteSkillOperation: """ - Retrieves skills semantically matched to a query. + Deletes a Skill. 
""" - parameter_model = types._RetrieveSkillsRequestParameters( - query=query, + parameter_model = types._DeleteSkillRequestParameters( + name=name, config=config, ) @@ -585,12 +579,12 @@ async def retrieve( "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." ) else: - request_dict = _RetrieveSkillsRequestParameters_to_vertex(parameter_model) + request_dict = _DeleteSkillRequestParameters_to_vertex(parameter_model) request_url_dict = request_dict.get("_url") if request_url_dict: - path = "skills:retrieve".format_map(request_url_dict) + path = "{name}".format_map(request_url_dict) else: - path = "skills:retrieve" + path = "{name}" query_params = request_dict.get("_query") if query_params: @@ -608,13 +602,11 @@ async def retrieve( request_dict = _common.convert_to_dict(request_dict) request_dict = _common.encode_unserializable_types(request_dict) - response = await self._api_client.async_request( - "get", path, request_dict, http_options - ) + response = self._api_client.request("delete", path, request_dict, http_options) response_dict = {} if not response.body else json.loads(response.body) - return_value = types.RetrieveSkillsResponse._from_response( + return_value = types.DeleteSkillOperation._from_response( response=response_dict, kwargs=( { @@ -638,20 +630,14 @@ async def retrieve( self._api_client._verify_response(return_value) return return_value - async def _create( + def _get_skill_operation( self, *, - display_name: str, - description: str, - config: Optional[types.CreateSkillConfigOrDict] = None, + operation_name: str, + config: Optional[types.GetSkillOperationConfigOrDict] = None, ) -> types.SkillOperation: - """ - Creates a new Skill. 
- """ - - parameter_model = types._CreateSkillRequestParameters( - display_name=display_name, - description=description, + parameter_model = types._GetSkillOperationParameters( + operation_name=operation_name, config=config, ) @@ -661,12 +647,12 @@ async def _create( "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." ) else: - request_dict = _CreateSkillRequestParameters_to_vertex(parameter_model) + request_dict = _GetSkillOperationParameters_to_vertex(parameter_model) request_url_dict = request_dict.get("_url") if request_url_dict: - path = "skills".format_map(request_url_dict) + path = "{operationName}".format_map(request_url_dict) else: - path = "skills" + path = "{operationName}" query_params = request_dict.get("_query") if query_params: @@ -684,9 +670,7 @@ async def _create( request_dict = _common.convert_to_dict(request_dict) request_dict = _common.encode_unserializable_types(request_dict) - response = await self._api_client.async_request( - "post", path, request_dict, http_options - ) + response = self._api_client.request("get", path, request_dict, http_options) response_dict = {} if not response.body else json.loads(response.body) @@ -714,29 +698,702 @@ async def _create( self._api_client._verify_response(return_value) return return_value - async def _get_skill_operation( + def create( self, *, - operation_name: str, - config: Optional[types.GetSkillOperationConfigOrDict] = None, - ) -> types.SkillOperation: - parameter_model = types._GetSkillOperationParameters( - operation_name=operation_name, - config=config, - ) - - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError( - "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
- ) - else: - request_dict = _GetSkillOperationParameters_to_vertex(parameter_model) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{operationName}".format_map(request_url_dict) - else: - path = "{operationName}" + display_name: str, + description: str, + config: Optional[types.CreateSkillConfigOrDict] = None, + ) -> Union[types.Skill, types.SkillOperation]: + """Creates a new Skill. + + Args: + display_name (str): + Required. The display name of the Skill. + description (str): + Required. The description of the Skill. + config (CreateSkillConfigOrDict): + Optional. The configuration for creating the Skill. + + Returns: + Skill: The created Skill if wait_for_completion is True. + SkillOperation: The operation for creating the Skill if + wait_for_completion is False. + """ + if config is None: + config = types.CreateSkillConfig() + elif isinstance(config, dict): + config = types.CreateSkillConfig.model_validate(config) + elif not isinstance(config, types.CreateSkillConfig): + raise TypeError( + f"config must be a dict or CreateSkillConfig, but got {type(config)}." + ) + + config = config.model_copy() + + local_path = config.local_path + zipped_filesystem = config.zipped_filesystem + + if local_path and zipped_filesystem: + raise ValueError( + "Only one of `local_path` or `zipped_filesystem` can be provided in config." + ) + if not local_path and not zipped_filesystem: + raise ValueError( + "Either `local_path` or `zipped_filesystem` must be provided in config." + ) + + if local_path: + zipped_filesystem_payload = _skills_utils.get_zipped_filesystem_payload( + local_path + ) + else: + # Narrow type for mypy + if zipped_filesystem is None: + raise ValueError( + "zipped_filesystem is required if local_path is not provided." 
+ ) + if isinstance(zipped_filesystem, bytes): + zipped_filesystem_payload = base64.b64encode(zipped_filesystem).decode( + "utf-8" + ) + else: + zipped_filesystem_payload = zipped_filesystem + + # Mutate the config object to populate the zipped_filesystem payload + config.zipped_filesystem = zipped_filesystem_payload + + operation = self._create( + display_name=display_name, + description=description, + config=config, + ) + + if config.wait_for_completion: + operation = _skills_utils.await_operation( + operation_name=operation.name, + get_operation_fn=self._get_skill_operation, + ) + if operation.error: + raise RuntimeError(f"Failed to create Skill: {operation.error}") + # Fetch the fully populated Skill resource from the server + return self.get(name=operation.response.name) + + return operation + + def update( + self, + *, + name: str, + config: Optional[types.UpdateSkillConfigOrDict] = None, + ) -> Union[types.Skill, types.SkillOperation]: + """Updates an existing Skill. + + Args: + name (str): + Required. The resource name of the Skill to update. + Format: projects/{project}/locations/{location}/skills/{skill} + config (UpdateSkillConfigOrDict): + Optional. The configuration for updating the Skill. + + Returns: + Skill: The updated Skill if wait_for_completion is True. + SkillOperation: The operation for updating the Skill if + wait_for_completion is False. + """ + if config is None: + config = types.UpdateSkillConfig() + elif isinstance(config, dict): + config = types.UpdateSkillConfig.model_validate(config) + elif not isinstance(config, types.UpdateSkillConfig): + raise TypeError( + f"config must be a dict or UpdateSkillConfig, but got {type(config)}." 
+ ) + + config = config.model_copy() + + display_name = config.display_name + description = config.description + local_path = config.local_path + zipped_filesystem = config.zipped_filesystem + + if local_path and zipped_filesystem: + raise ValueError( + "Only one of `local_path` or `zipped_filesystem` can be provided in config." + ) + + # Construct update_mask and prepare payload + update_mask_paths = [] + zipped_filesystem_payload = None + + if display_name is not None: + update_mask_paths.append("displayName") + + if description is not None: + update_mask_paths.append("description") + + if local_path: + zipped_filesystem_payload = _skills_utils.get_zipped_filesystem_payload( + local_path + ) + update_mask_paths.append("zippedFilesystem") + elif zipped_filesystem is not None: + if isinstance(zipped_filesystem, bytes): + zipped_filesystem_payload = base64.b64encode(zipped_filesystem).decode( + "utf-8" + ) + else: + zipped_filesystem_payload = zipped_filesystem + update_mask_paths.append("zippedFilesystem") + + if not update_mask_paths: + raise ValueError( + "At least one of `display_name`, `description`, `local_path`, or " + "`zipped_filesystem` must be provided for update in config." 
+ ) + + update_mask = ",".join(update_mask_paths) + + # Mutate config in place to populate the generated update_mask and zipped_filesystem + config.update_mask = update_mask + config.zipped_filesystem = zipped_filesystem_payload + + operation = self._update( + name=name, + config=config, + ) + + if config.wait_for_completion: + operation = _skills_utils.await_operation( + operation_name=operation.name, + get_operation_fn=self._get_skill_operation, + ) + if operation.error: + raise RuntimeError(f"Failed to update Skill: {operation.error}") + # Fetch the fully populated Skill resource from the server + return self.get(name=name) + + return operation + + def list( + self, + *, + config: Optional[types.ListSkillsConfigOrDict] = None, + ) -> Iterator[types.Skill]: + """Lists Skills in the Skill Registry. + + Args: + config (ListSkillsConfigOrDict): + Optional. Additional configuration for listing Skills. + + Returns: + Iterator[Skill]: An iterator (Pager) of Skills. + """ + return Pager( + "skills", + self._list, + self._list(config=config), + config, + ) + + def delete( + self, + *, + name: str, + config: Optional[types.DeleteSkillConfigOrDict] = None, + ) -> Optional[types.DeleteSkillOperation]: + """Deletes a Skill. + + Args: + name (str): + Required. The resource name of the Skill to delete. + Format: projects/{project}/locations/{location}/skills/{skill} + config (DeleteSkillConfigOrDict): + Optional. Additional configuration for the delete operation. + + Returns: + DeleteSkillOperation: The pending LRO if wait_for_completion is False, + otherwise None (blocks until done). + """ + if config is None: + config = types.DeleteSkillConfig() + elif isinstance(config, dict): + config = types.DeleteSkillConfig.model_validate(config) + elif not isinstance(config, types.DeleteSkillConfig): + raise TypeError( + f"config must be a dict or DeleteSkillConfig, but got {type(config)}." 
+ ) + + operation = self._delete(name=name, config=config) + + if config.wait_for_completion: + operation = _skills_utils.await_operation( + operation_name=operation.name, + get_operation_fn=self._get_skill_operation, + ) + if operation.error: + raise RuntimeError(f"Failed to delete Skill: {operation.error}") + return None + + return operation + + +class AsyncSkills(_api_module.BaseModule): + """Class for managing Skills in the Skill Registry.""" + + async def get( + self, *, name: str, config: Optional[types.GetSkillConfigOrDict] = None + ) -> types.Skill: + """ + Gets a Skill. + """ + + parameter_model = types._GetSkillRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetSkillRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Skill._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def retrieve( + self, *, query: str, config: Optional[types.RetrieveSkillsConfigOrDict] = None + ) -> types.RetrieveSkillsResponse: + """ + Retrieves skills semantically matched to a query. + """ + + parameter_model = types._RetrieveSkillsRequestParameters( + query=query, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RetrieveSkillsRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "skills:retrieve".format_map(request_url_dict) + else: + path = "skills:retrieve" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RetrieveSkillsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _create( + self, + *, + display_name: str, + description: str, + config: Optional[types.CreateSkillConfigOrDict] = None, + ) -> types.SkillOperation: + """ + Creates a new Skill. + """ + + parameter_model = types._CreateSkillRequestParameters( + display_name=display_name, + description=description, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _CreateSkillRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "skills".format_map(request_url_dict) + else: + path = "skills" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SkillOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _update( + self, *, name: str, config: Optional[types.UpdateSkillConfigOrDict] = None + ) -> types.SkillOperation: + """ + Updates a Skill. + """ + + parameter_model = types._UpdateSkillRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _UpdateSkillRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "patch", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SkillOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, *, config: Optional[types.ListSkillsConfigOrDict] = None + ) -> types.ListSkillsResponse: + """ + Lists Skills in the Skill Registry. + """ + + parameter_model = types._ListSkillsRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _ListSkillsRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "skills".format_map(request_url_dict) + else: + path = "skills" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListSkillsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _delete( + self, *, name: str, config: Optional[types.DeleteSkillConfigOrDict] = None + ) -> types.DeleteSkillOperation: + """ + Deletes a Skill. + """ + + parameter_model = types._DeleteSkillRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _DeleteSkillRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteSkillOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_skill_operation( + self, + *, + operation_name: str, + config: Optional[types.GetSkillOperationConfigOrDict] = None, + ) -> types.SkillOperation: + parameter_model = types._GetSkillOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _GetSkillOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" query_params = request_dict.get("_query") if query_params: @@ -867,3 +1524,160 @@ async def create( return await self.get(name=operation.response.name) return operation + + async def update( + self, + *, + name: str, + config: Optional[types.UpdateSkillConfigOrDict] = None, + ) -> Union[types.Skill, types.SkillOperation]: + """Updates an existing Skill asynchronously. + + Args: + name (str): + Required. The resource name of the Skill to update. + Format: projects/{project}/locations/{location}/skills/{skill} + config (UpdateSkillConfigOrDict): + Optional. The configuration for updating the Skill. + + Returns: + Skill: The updated Skill if wait_for_completion is True. + SkillOperation: The operation for updating the Skill if + wait_for_completion is False. + """ + if config is None: + config = types.UpdateSkillConfig() + elif isinstance(config, dict): + config = types.UpdateSkillConfig.model_validate(config) + elif not isinstance(config, types.UpdateSkillConfig): + raise TypeError( + f"config must be a dict or UpdateSkillConfig, but got {type(config)}." + ) + + config = config.model_copy() + + display_name = config.display_name + description = config.description + local_path = config.local_path + zipped_filesystem = config.zipped_filesystem + + if local_path and zipped_filesystem: + raise ValueError( + "Only one of `local_path` or `zipped_filesystem` can be provided in config." 
+ ) + + # Construct update_mask and prepare payload + update_mask_paths = [] + zipped_filesystem_payload = None + + if display_name is not None: + update_mask_paths.append("displayName") + + if description is not None: + update_mask_paths.append("description") + + if local_path: + loop = asyncio.get_running_loop() + zipped_filesystem_payload = await loop.run_in_executor( + None, _skills_utils.get_zipped_filesystem_payload, local_path + ) + update_mask_paths.append("zippedFilesystem") + elif zipped_filesystem is not None: + if isinstance(zipped_filesystem, bytes): + zipped_filesystem_payload = base64.b64encode(zipped_filesystem).decode( + "utf-8" + ) + else: + zipped_filesystem_payload = zipped_filesystem + update_mask_paths.append("zippedFilesystem") + + if not update_mask_paths: + raise ValueError( + "At least one of `display_name`, `description`, `local_path`, or " + "`zipped_filesystem` must be provided for update in config." + ) + + update_mask = ",".join(update_mask_paths) + + # Mutate config in place to populate the generated update_mask and zipped_filesystem + config.update_mask = update_mask + config.zipped_filesystem = zipped_filesystem_payload + + operation = await self._update( + name=name, + config=config, + ) + + if config.wait_for_completion: + operation = await _skills_utils.await_operation_async( + operation_name=operation.name, + get_operation_fn=self._get_skill_operation, + ) + if operation.error: + raise RuntimeError(f"Failed to update Skill: {operation.error}") + # Fetch the fully populated Skill resource asynchronously + return await self.get(name=name) + + return operation + + async def list( + self, + *, + config: Optional[types.ListSkillsConfigOrDict] = None, + ) -> AsyncPager[types.Skill]: + """Lists Skills in the Skill Registry asynchronously. + + Args: + config (ListSkillsConfigOrDict): + Optional. Additional configuration for listing Skills. + + Returns: + AsyncPager[Skill]: An async pager of Skills. 
+ """ + return AsyncPager( + "skills", + self._list, + await self._list(config=config), + config, + ) + + async def delete( + self, + *, + name: str, + config: Optional[types.DeleteSkillConfigOrDict] = None, + ) -> Optional[types.DeleteSkillOperation]: + """Deletes a Skill asynchronously. + + Args: + name (str): + Required. The resource name of the Skill to delete. + Format: projects/{project}/locations/{location}/skills/{skill} + config (DeleteSkillConfigOrDict): + Optional. Additional configuration for the delete operation. + + Returns: + DeleteSkillOperation: The pending LRO if wait_for_completion is False, + otherwise None (blocks until done). + """ + if config is None: + config = types.DeleteSkillConfig() + elif isinstance(config, dict): + config = types.DeleteSkillConfig.model_validate(config) + elif not isinstance(config, types.DeleteSkillConfig): + raise TypeError( + f"config must be a dict or DeleteSkillConfig, but got {type(config)}." + ) + + operation = await self._delete(name=name, config=config) + + if config.wait_for_completion: + operation = await _skills_utils.await_operation_async( + operation_name=operation.name, + get_operation_fn=self._get_skill_operation, + ) + if operation.error: + raise RuntimeError(f"Failed to delete Skill: {operation.error}") + return None + + return operation diff --git a/vertexai/_genai/types/__init__.py b/vertexai/_genai/types/__init__.py index 2e4acdba09..cb1a224775 100644 --- a/vertexai/_genai/types/__init__.py +++ b/vertexai/_genai/types/__init__.py @@ -57,6 +57,7 @@ from .common import _DeletePromptVersionRequestParameters from .common import _DeleteSandboxEnvironmentSnapshotRequestParameters from .common import _DeleteSandboxEnvironmentTemplateRequestParameters +from .common import _DeleteSkillRequestParameters from .common import _EvaluateInstancesRequestParameters from .common import _ExecuteCodeAgentEngineSandboxRequestParameters from .common import _GenerateAgentEngineMemoriesRequestParameters @@ -109,6 +110,7 
@@ from .common import _ListMultimodalDatasetsRequestParameters from .common import _ListSandboxEnvironmentSnapshotsRequestParameters from .common import _ListSandboxEnvironmentTemplatesRequestParameters +from .common import _ListSkillsRequestParameters from .common import _OptimizeRequestParameters from .common import _OptimizeRequestParameters from .common import _PurgeAgentEngineMemoriesRequestParameters @@ -128,6 +130,7 @@ from .common import _UpdateAgentEngineSessionRequestParameters from .common import _UpdateDatasetParameters from .common import _UpdateMultimodalDatasetParameters +from .common import _UpdateSkillRequestParameters from .common import A2aTask from .common import A2aTaskDict from .common import A2aTaskOrDict @@ -384,6 +387,12 @@ from .common import DeleteSandboxEnvironmentTemplateOperation from .common import DeleteSandboxEnvironmentTemplateOperationDict from .common import DeleteSandboxEnvironmentTemplateOperationOrDict +from .common import DeleteSkillConfig +from .common import DeleteSkillConfigDict +from .common import DeleteSkillConfigOrDict +from .common import DeleteSkillOperation +from .common import DeleteSkillOperationDict +from .common import DeleteSkillOperationOrDict from .common import DiskSpec from .common import DiskSpecDict from .common import DiskSpecOrDict @@ -730,6 +739,12 @@ from .common import ListSandboxEnvironmentTemplatesResponse from .common import ListSandboxEnvironmentTemplatesResponseDict from .common import ListSandboxEnvironmentTemplatesResponseOrDict +from .common import ListSkillsConfig +from .common import ListSkillsConfigDict +from .common import ListSkillsConfigOrDict +from .common import ListSkillsResponse +from .common import ListSkillsResponseDict +from .common import ListSkillsResponseOrDict from .common import LLMMetric from .common import LossAnalysisConfig from .common import LossAnalysisConfigDict @@ -1404,6 +1419,9 @@ from .common import UpdatePromptConfig from .common import UpdatePromptConfigDict 
# --- Update-skill types ---------------------------------------------------


class UpdateSkillConfig(_common.BaseModel):
    """Config for updating a skill."""

    http_options: Optional[genai_types.HttpOptions] = Field(
        default=None, description="""Used to override HTTP request options."""
    )
    # Defaults to True: update() blocks on the LRO unless explicitly disabled.
    wait_for_completion: Optional[bool] = Field(
        default=True,
        description="""Whether to wait for the long running operation to complete.""",
    )
    # When set, the SDK zips this directory and uploads it as the skill body
    # (mutually exclusive with zipped_filesystem, presumably — confirm in the
    # update() implementation).
    local_path: Optional[str] = Field(
        default=None,
        description="""Optional. The local path to the directory containing the Skill to
            be zipped and uploaded.
            """,
    )
    display_name: Optional[str] = Field(
        default=None, description="""Optional. The display name of the Skill."""
    )
    description: Optional[str] = Field(
        default=None, description="""Optional. The description of the Skill."""
    )
    # Typed Any: the wire format of the zipped payload is not constrained here.
    zipped_filesystem: Optional[Any] = Field(
        default=None, description="""Optional. The zipped filesystem of the Skill."""
    )
    # FieldMask-style comma-separated paths selecting which fields to update.
    # NOTE(review): exact mask semantics are enforced server-side — confirm.
    update_mask: Optional[str] = Field(
        default=None, description="""Optional. The update mask to apply."""
    )


class UpdateSkillConfigDict(TypedDict, total=False):
    """Config for updating a skill."""

    http_options: Optional[genai_types.HttpOptionsDict]
    """Used to override HTTP request options."""

    wait_for_completion: Optional[bool]
    """Whether to wait for the long running operation to complete."""

    local_path: Optional[str]
    """Optional. The local path to the directory containing the Skill to
    be zipped and uploaded.
    """

    display_name: Optional[str]
    """Optional. The display name of the Skill."""

    description: Optional[str]
    """Optional. The description of the Skill."""

    zipped_filesystem: Optional[Any]
    """Optional. The zipped filesystem of the Skill."""

    update_mask: Optional[str]
    """Optional. The update mask to apply."""


UpdateSkillConfigOrDict = Union[UpdateSkillConfig, UpdateSkillConfigDict]


# Internal request envelope for skills.update(); not part of the public API.
class _UpdateSkillRequestParameters(_common.BaseModel):
    """Parameters for updating a skill."""

    name: Optional[str] = Field(
        default=None,
        description="""Required. The resource name of the Skill to update.""",
    )
    config: Optional[UpdateSkillConfig] = Field(default=None, description="""""")


class _UpdateSkillRequestParametersDict(TypedDict, total=False):
    """Parameters for updating a skill."""

    name: Optional[str]
    """Required. The resource name of the Skill to update."""

    config: Optional[UpdateSkillConfigDict]
    """"""


_UpdateSkillRequestParametersOrDict = Union[
    _UpdateSkillRequestParameters, _UpdateSkillRequestParametersDict
]


# --- List-skills types ----------------------------------------------------


class ListSkillsConfig(_common.BaseModel):
    """Config for listing skills."""

    http_options: Optional[genai_types.HttpOptions] = Field(
        default=None, description="""Used to override HTTP request options."""
    )
    # Standard AIP list pagination knobs; page_token comes from a prior
    # response's next_page_token.
    page_size: Optional[int] = Field(default=None, description="""""")
    page_token: Optional[str] = Field(default=None, description="""""")


class ListSkillsConfigDict(TypedDict, total=False):
    """Config for listing skills."""

    http_options: Optional[genai_types.HttpOptionsDict]
    """Used to override HTTP request options."""

    page_size: Optional[int]
    """"""

    page_token: Optional[str]
    """"""


ListSkillsConfigOrDict = Union[ListSkillsConfig, ListSkillsConfigDict]


# Internal request envelope for skills.list(); not part of the public API.
class _ListSkillsRequestParameters(_common.BaseModel):
    """Parameters for listing skills."""

    config: Optional[ListSkillsConfig] = Field(default=None, description="""""")


class _ListSkillsRequestParametersDict(TypedDict, total=False):
    """Parameters for listing skills."""

    config: Optional[ListSkillsConfigDict]
    """"""


_ListSkillsRequestParametersOrDict = Union[
    _ListSkillsRequestParameters, _ListSkillsRequestParametersDict
]


class ListSkillsResponse(_common.BaseModel):
    """Response for listing skills."""

    sdk_http_response: Optional[genai_types.HttpResponse] = Field(
        default=None, description="""Used to retain the full HTTP response."""
    )
    # Empty/None when there are no further pages to fetch.
    next_page_token: Optional[str] = Field(default=None, description="""""")
    skills: Optional[list[Skill]] = Field(
        default=None, description="""List of Skills."""
    )


class ListSkillsResponseDict(TypedDict, total=False):
    """Response for listing skills."""

    sdk_http_response: Optional[genai_types.HttpResponseDict]
    """Used to retain the full HTTP response."""

    next_page_token: Optional[str]
    """"""

    skills: Optional[list[SkillDict]]
    """List of Skills."""


ListSkillsResponseOrDict = Union[ListSkillsResponse, ListSkillsResponseDict]
# --- Delete-skill types ---------------------------------------------------


class DeleteSkillConfig(_common.BaseModel):
    """Config for deleting a skill."""

    http_options: Optional[genai_types.HttpOptions] = Field(
        default=None, description="""Used to override HTTP request options."""
    )
    # Defaults to True: delete() blocks on the LRO unless explicitly disabled.
    wait_for_completion: Optional[bool] = Field(
        default=True,
        description="""Whether to wait for the long running operation to complete.""",
    )


class DeleteSkillConfigDict(TypedDict, total=False):
    """Config for deleting a skill."""

    http_options: Optional[genai_types.HttpOptionsDict]
    """Used to override HTTP request options."""

    wait_for_completion: Optional[bool]
    """Whether to wait for the long running operation to complete."""


DeleteSkillConfigOrDict = Union[DeleteSkillConfig, DeleteSkillConfigDict]


# Internal request envelope for skills.delete(); not part of the public API.
class _DeleteSkillRequestParameters(_common.BaseModel):
    """Parameters for deleting a skill."""

    name: Optional[str] = Field(
        default=None,
        description="""Required. The resource name of the Skill to delete.""",
    )
    config: Optional[DeleteSkillConfig] = Field(default=None, description="""""")


class _DeleteSkillRequestParametersDict(TypedDict, total=False):
    """Parameters for deleting a skill."""

    name: Optional[str]
    """Required. The resource name of the Skill to delete."""

    config: Optional[DeleteSkillConfigDict]
    """"""


_DeleteSkillRequestParametersOrDict = Union[
    _DeleteSkillRequestParameters, _DeleteSkillRequestParametersDict
]


# Mirrors google.longrunning.Operation (name/metadata/done/error); no
# `response` field is modeled here since Delete resolves to Empty.
class DeleteSkillOperation(_common.BaseModel):
    """Operation for deleting a skill."""

    name: Optional[str] = Field(
        default=None,
        description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""",
    )
    metadata: Optional[dict[str, Any]] = Field(
        default=None,
        description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""",
    )
    done: Optional[bool] = Field(
        default=None,
        description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""",
    )
    error: Optional[dict[str, Any]] = Field(
        default=None,
        description="""The error result of the operation in case of failure or cancellation.""",
    )


class DeleteSkillOperationDict(TypedDict, total=False):
    """Operation for deleting a skill."""

    name: Optional[str]
    """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`."""

    metadata: Optional[dict[str, Any]]
    """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any."""

    done: Optional[bool]
    """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available."""

    error: Optional[dict[str, Any]]
    """The error result of the operation in case of failure or cancellation."""


DeleteSkillOperationOrDict = Union[DeleteSkillOperation, DeleteSkillOperationDict]