Skip to content

Commit 406f7c7

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Vertex AI Skill Registry - Create Skill method
PiperOrigin-RevId: 910161153
1 parent 9ea4aa6 commit 406f7c7

8 files changed

Lines changed: 1469 additions & 0 deletions

File tree

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Tests the skills.create() method against the Vertex AI endpoint using replays."""
16+
17+
import os
18+
import tempfile
19+
20+
from vertexai._genai import types
21+
22+
23+
def test_create_skill(client):
24+
with tempfile.TemporaryDirectory() as tmpdir:
25+
# Create a dummy skill structure (SKILL.md is required by the spec)
26+
with open(os.path.join(tmpdir, "SKILL.md"), "w") as f:
27+
f.write("# My Replay Skill\nThis is a test skill for replay tests.")
28+
29+
skill = client.skills.create(
30+
display_name="My Replay Skill",
31+
description="My Replay Skill Description",
32+
local_path=tmpdir,
33+
config=types.CreateSkillConfig(wait_for_completion=True),
34+
)
35+
36+
assert skill.name is not None
37+
assert skill.display_name == "My Replay Skill"
38+
assert skill.description == "My Replay Skill Description"
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Tests the skills.get() method against the autopush endpoint."""
2+
3+
from google.api_core import exceptions
4+
from tests.unit.vertexai.genai.replays import pytest_helper
5+
import pytest
6+
7+
PROJECT_ID = "srbai-testing"
8+
REGION = "us-central1"
9+
# SKILL_ID = "5578834038405201920"
10+
SKILL_ID = "7184367305562783744"
11+
ENDPOINT = f"{REGION}-autopush-aiplatform.sandbox.googleapis.com"
12+
13+
# # Configure HTTP options to target the autopush endpoint
14+
# my_http_options = genai_types.HttpOptions(
15+
# api_version="v1beta1",
16+
# base_url=f"https://{ENDPOINT}/v1beta1/" # <---APPENDED /v1beta1/ here
17+
# )
18+
19+
pytestmark = pytest_helper.setup(
20+
file=__file__,
21+
globals_for_file=globals(),
22+
# http_options=my_http_options,
23+
)
24+
25+
26+
def test_get_skill(client): # client fixture is injected by pytest_helper.setup
27+
"""Tests the skills.get() method against the autopush endpoint."""
28+
29+
client._api_client._http_options.base_url = (
30+
"https://us-central1-autopush-aiplatform.sandbox.googleapis.com"
31+
)
32+
skill_name = f"projects/{PROJECT_ID}/locations/{REGION}/skills/{SKILL_ID}"
33+
34+
try:
35+
skill = client.skills.get(name=skill_name)
36+
assert skill.name == skill_name
37+
38+
except exceptions.GoogleAPIError as e:
39+
pytest.fail(f"Error calling client.skills.get(): {e}")
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
# //third_party/py/google/cloud/aiplatform/tests/unit/vertexai/genai/test_genai_skills.py
2+
import importlib
3+
import json
4+
from unittest import mock
5+
from vertexai import _genai as genai
6+
from vertexai._genai import client as vertexai_client
7+
from google.genai import types as genai_types
8+
import pytest
9+
10+
11+
@pytest.fixture
12+
def skills_client():
13+
creds = mock.MagicMock()
14+
creds.token = "test_token"
15+
client = vertexai_client.Client(
16+
project="test-project", location="test-location", credentials=creds
17+
)
18+
return client.skills
19+
20+
21+
@pytest.fixture
22+
def async_skills_client():
23+
creds = mock.MagicMock()
24+
creds.token = "test_token"
25+
client = vertexai_client.Client(
26+
project="test-project", location="test-location", credentials=creds
27+
)
28+
return client.aio.skills
29+
30+
31+
class TestGenaiSkills:
32+
mock_get_skill_response = {
33+
"name": "projects/test-project/locations/test-location/skills/test-skill",
34+
"displayName": "My Test Skill",
35+
}
36+
37+
def test_get_skill(self, skills_client):
38+
"""Tests the get_skill method."""
39+
with mock.patch.object(skills_client._api_client, "request") as request_mock:
40+
request_mock.return_value = genai_types.HttpResponse(
41+
body=json.dumps(self.mock_get_skill_response)
42+
)
43+
skill_name = (
44+
"projects/test-project/locations/test-location/skills/test-skill"
45+
)
46+
skill = skills_client.get(name=skill_name)
47+
request_mock.assert_called_with(
48+
"get",
49+
skill_name,
50+
{"_url": {"name": skill_name}},
51+
None,
52+
)
53+
assert isinstance(skill, genai.types.Skill)
54+
assert skill.name == skill_name
55+
assert skill.display_name == "My Test Skill"
56+
57+
def test_create_skill(self, skills_client):
58+
"""Tests the create_skill method with wait_for_completion=True."""
59+
import tempfile
60+
import os
61+
62+
with tempfile.TemporaryDirectory() as tmpdir:
63+
# Create a dummy file in tmpdir
64+
with open(os.path.join(tmpdir, "SKILL.md"), "w") as f:
65+
f.write("# Test Skill")
66+
67+
# Prepare mock responses
68+
pending_op = {
69+
"name": "projects/test-project/locations/test-location/skills/test-skill/operations/op-123",
70+
"done": False,
71+
}
72+
finished_op = {
73+
"name": "projects/test-project/locations/test-location/skills/test-skill/operations/op-123",
74+
"done": True,
75+
"response": {
76+
"name": "projects/test-project/locations/test-location/skills/test-skill",
77+
"displayName": "My Test Skill",
78+
"description": "My Test Skill Description",
79+
}
80+
}
81+
82+
with mock.patch.object(skills_client._api_client, "request") as request_mock:
83+
request_mock.side_effect = [
84+
genai_types.HttpResponse(body=json.dumps(pending_op)),
85+
genai_types.HttpResponse(body=json.dumps(finished_op)),
86+
]
87+
88+
# We mock time.sleep to speed up the test
89+
with mock.patch("time.sleep", return_value=None):
90+
skill = skills_client.create(
91+
display_name="My Test Skill",
92+
description="My Test Skill Description",
93+
local_path=tmpdir,
94+
config={"wait_for_completion": True}
95+
)
96+
97+
# Assertions
98+
assert request_mock.call_count == 2
99+
100+
# Verify POST request
101+
post_call = request_mock.call_args_list[0]
102+
assert post_call[0][0] == "post"
103+
assert post_call[0][1] == "skills"
104+
105+
post_body = post_call[0][2]
106+
assert post_body["skill"]["displayName"] == "My Test Skill"
107+
assert post_body["skill"]["description"] == "My Test Skill Description"
108+
assert isinstance(post_body["skill"]["zippedFilesystem"], str)
109+
110+
# Verify GET request (polling)
111+
get_call = request_mock.call_args_list[1]
112+
assert get_call[0][0] == "get"
113+
assert get_call[0][1] == "projects/test-project/locations/test-location/skills/test-skill/operations/op-123"
114+
115+
# Verify returned skill
116+
assert isinstance(skill, genai.types.Skill)
117+
assert skill.name == "projects/test-project/locations/test-location/skills/test-skill"
118+
assert skill.display_name == "My Test Skill"
119+
assert skill.description == "My Test Skill Description"
120+
121+
def test_create_skill_no_wait(self, skills_client):
122+
"""Tests the create_skill method with wait_for_completion=False."""
123+
import tempfile
124+
import os
125+
126+
with tempfile.TemporaryDirectory() as tmpdir:
127+
with open(os.path.join(tmpdir, "SKILL.md"), "w") as f:
128+
f.write("# Test Skill")
129+
130+
pending_op = {
131+
"name": "projects/test-project/locations/test-location/skills/test-skill/operations/op-123",
132+
"done": False,
133+
}
134+
135+
with mock.patch.object(skills_client._api_client, "request") as request_mock:
136+
request_mock.return_value = genai_types.HttpResponse(body=json.dumps(pending_op))
137+
138+
operation = skills_client.create(
139+
display_name="My Test Skill",
140+
description="My Test Skill Description",
141+
local_path=tmpdir,
142+
config={"wait_for_completion": False}
143+
)
144+
145+
# Assertions
146+
assert request_mock.call_count == 1
147+
assert isinstance(operation, genai.types.SkillOperation)
148+
assert operation.name == "projects/test-project/locations/test-location/skills/test-skill/operations/op-123"
149+
assert not operation.done
150+
151+
@pytest.mark.asyncio
152+
async def test_create_skill_async(self, async_skills_client):
153+
"""Tests the create_skill method asynchronously with wait_for_completion=True."""
154+
import tempfile
155+
import os
156+
157+
with tempfile.TemporaryDirectory() as tmpdir:
158+
with open(os.path.join(tmpdir, "SKILL.md"), "w") as f:
159+
f.write("# Test Skill")
160+
161+
pending_op = {
162+
"name": "projects/test-project/locations/test-location/skills/test-skill/operations/op-123",
163+
"done": False,
164+
}
165+
finished_op = {
166+
"name": "projects/test-project/locations/test-location/skills/test-skill/operations/op-123",
167+
"done": True,
168+
"response": {
169+
"name": "projects/test-project/locations/test-location/skills/test-skill",
170+
"displayName": "My Test Skill",
171+
"description": "My Test Skill Description",
172+
}
173+
}
174+
175+
with mock.patch.object(async_skills_client._api_client, "async_request") as request_mock:
176+
request_mock.side_effect = [
177+
genai_types.HttpResponse(body=json.dumps(pending_op)),
178+
genai_types.HttpResponse(body=json.dumps(finished_op)),
179+
]
180+
181+
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock):
182+
skill = await async_skills_client.create(
183+
display_name="My Test Skill",
184+
description="My Test Skill Description",
185+
local_path=tmpdir,
186+
config={"wait_for_completion": True}
187+
)
188+
189+
# Assertions
190+
assert request_mock.call_count == 2
191+
192+
# Verify POST request
193+
post_call = request_mock.call_args_list[0]
194+
assert post_call[0][0] == "post"
195+
assert post_call[0][1] == "skills"
196+
197+
# Verify returned skill
198+
assert isinstance(skill, genai.types.Skill)
199+
assert skill.name == "projects/test-project/locations/test-location/skills/test-skill"
200+

vertexai/_genai/_skills_utils.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Utility functions for Skills."""
16+
17+
import asyncio
18+
import base64
19+
import io
20+
import os
21+
import time
22+
from typing import Any, Callable, Awaitable
23+
import zipfile
24+
25+
26+
def zip_directory(directory_path: str) -> bytes:
27+
"""Zips a directory into memory and returns the bytes.
28+
29+
Args:
30+
directory_path (str): Required. The local path to the directory.
31+
32+
Returns:
33+
bytes: The zipped directory content.
34+
"""
35+
if not os.path.isdir(directory_path):
36+
raise ValueError(f"Path is not a directory: {directory_path}")
37+
38+
zip_buffer = io.BytesIO()
39+
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
40+
for root, _, files in os.walk(directory_path):
41+
for file in files:
42+
file_path = os.path.join(root, file)
43+
arcname = os.path.relpath(file_path, directory_path)
44+
zip_file.write(file_path, arcname)
45+
return zip_buffer.getvalue()
46+
47+
48+
def get_zipped_filesystem_payload(directory_path: str) -> str:
49+
"""Zips a directory and base64-encodes the result to a UTF-8 string.
50+
51+
Args:
52+
directory_path (str): Required. The local path to the directory.
53+
54+
Returns:
55+
str: The base64-encoded zipped directory.
56+
"""
57+
zip_bytes = zip_directory(directory_path)
58+
return base64.b64encode(zip_bytes).decode("utf-8")
59+
60+
61+
def await_operation(
62+
*,
63+
operation_name: str,
64+
get_operation_fn: Callable[..., Any],
65+
poll_interval_seconds: float = 10.0,
66+
) -> Any:
67+
"""Waits for a long running operation to complete.
68+
69+
Args:
70+
operation_name (str): Required. The name of the operation.
71+
get_operation_fn (Callable): Required. Function to get the operation
72+
status.
73+
poll_interval_seconds (float): The interval between polls in seconds.
74+
75+
Returns:
76+
Any: The completed operation.
77+
"""
78+
operation = get_operation_fn(operation_name=operation_name)
79+
while not operation.done:
80+
time.sleep(poll_interval_seconds)
81+
operation = get_operation_fn(operation_name=operation.name)
82+
return operation
83+
84+
85+
async def await_operation_async(
86+
*,
87+
operation_name: str,
88+
get_operation_fn: Callable[..., Awaitable[Any]],
89+
poll_interval_seconds: float = 10.0,
90+
) -> Any:
91+
"""Waits for a long running operation to complete asynchronously.
92+
93+
Args:
94+
operation_name (str): Required. The name of the operation.
95+
get_operation_fn (Callable): Required. Async function to get the
96+
operation status.
97+
poll_interval_seconds (float): The interval between polls in seconds.
98+
99+
Returns:
100+
Any: The completed operation.
101+
"""
102+
operation = await get_operation_fn(operation_name=operation_name)
103+
while not operation.done:
104+
await asyncio.sleep(poll_interval_seconds)
105+
operation = await get_operation_fn(operation_name=operation.name)
106+
return operation

0 commit comments

Comments
 (0)