Skip to content

Commit 9dccee2

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: faster sandbox creation with templates and snapshots and improve dataplane routing and security.
Multitenancy Sandbox support PiperOrigin-RevId: 910405607
1 parent 9ea4aa6 commit 9dccee2

3 files changed

Lines changed: 124 additions & 26 deletions

File tree

tests/unit/vertexai/genai/test_sandbox.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,19 @@
1515

1616
import importlib
1717
import os
18-
1918
from unittest import mock
2019

2120
from google import auth
2221
from google.auth import credentials as auth_credentials
2322
from google.cloud import aiplatform
2423
import vertexai
2524
from google.cloud.aiplatform import initializer
25+
from vertexai._genai import sandboxes
2626
from google.genai import client
2727
from google.genai import types as genai_types
2828
import pytest
2929

30+
3031
_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
3132
_TEST_LOCATION = "us-central1"
3233
_TEST_PROJECT = "test-project"
@@ -73,8 +74,11 @@ def teardown_method(self):
7374
@mock.patch.object(client.Client, "_get_api_client")
7475
def test_send_command(self, mock_get_api_client):
7576
mock_sandbox = mock.Mock()
76-
mock_sandbox.connection_info.load_balancer_ip = "127.0.0.1"
77-
mock_sandbox.connection_info.load_balancer_hostname = None
77+
mock_sandbox.connection_info.load_balancer_ip = None
78+
mock_sandbox.connection_info.load_balancer_hostname = (
79+
"test-us-central1.sandbox.vertexai.goog"
80+
)
81+
mock_sandbox.connection_info.routing_token = "test_routing_token"
7882
mock_http_client = mock_get_api_client.return_value
7983
mock_http_client.request.return_value = genai_types.HttpResponse(
8084
body=b"{}", headers={}
@@ -91,7 +95,39 @@ def test_send_command(self, mock_get_api_client):
9195
assert call_args is not None
9296
_, kwargs = call_args
9397
http_options = kwargs["http_options"]
94-
assert http_options.base_url == "http://127.0.0.1/test/path"
98+
assert http_options.base_url == (
99+
"https://test-us-central1.sandbox.vertexai.goog/test/path"
100+
)
95101
assert http_options.headers["Authorization"] == "Bearer test_token"
96102

97103
mock_http_client.request.assert_called_with("GET", "test/path", {})
104+
105+
@mock.patch.object(sandboxes.Sandboxes, "generate_access_token")
106+
@mock.patch.object(client.Client, "_get_api_client")
107+
def test_generate_browser_ws_headers(
108+
self, mock_get_api_client, mock_generate_access_token
109+
):
110+
mock_generate_access_token.return_value = "test_token"
111+
112+
mock_sandbox = mock.Mock()
113+
mock_sandbox.connection_info.load_balancer_ip = None
114+
mock_sandbox.connection_info.load_balancer_hostname = (
115+
"test-us-central1.sandbox.vertexai.goog"
116+
)
117+
mock_sandbox.connection_info.routing_token = "test_routing_token"
118+
mock_http_client = mock_get_api_client.return_value
119+
mock_http_client.request.return_value = genai_types.HttpResponse(
120+
body=b'{"endpoint": "test/endpoint"}', headers={}
121+
)
122+
ws_url, headers = (
123+
self.client.agent_engines.sandboxes.generate_browser_ws_headers(
124+
sandbox_environment=mock_sandbox,
125+
service_account_email=_TEST_SERVICE_ACCOUNT_EMAIL,
126+
timeout=3600,
127+
)
128+
)
129+
assert ws_url == "wss://test-us-central1.sandbox.vertexai.goog/test/endpoint"
130+
assert (
131+
headers["Sec-WebSocket-Protocol"]
132+
== "v1.stream, test_token, test_routing_token, 9222"
133+
)

vertexai/_genai/sandboxes.py

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,23 @@ def _CreateAgentEngineSandboxConfig_to_vertex(
5656
if getv(from_object, ["ttl"]) is not None:
5757
setv(parent_object, ["ttl"], getv(from_object, ["ttl"]))
5858

59+
if getv(from_object, ["sandbox_environment_template"]) is not None:
60+
setv(
61+
parent_object,
62+
["sandboxEnvironmentTemplate"],
63+
getv(from_object, ["sandbox_environment_template"]),
64+
)
65+
66+
if getv(from_object, ["sandbox_environment_snapshot"]) is not None:
67+
setv(
68+
parent_object,
69+
["sandboxEnvironmentSnapshot"],
70+
getv(from_object, ["sandbox_environment_snapshot"]),
71+
)
72+
73+
if getv(from_object, ["owner"]) is not None:
74+
setv(parent_object, ["owner"], getv(from_object, ["owner"]))
75+
5976
return to_object
6077

6178

@@ -853,7 +870,7 @@ def delete(
853870
def generate_access_token(
854871
self,
855872
service_account_email: str,
856-
sandbox_id: str,
873+
sandbox_hostname: str,
857874
port: str = "8080",
858875
timeout: int = 3600,
859876
) -> str:
@@ -862,8 +879,8 @@ def generate_access_token(
862879
Args:
863880
service_account_email (str):
864881
Required. The email of the service account to use for signing.
865-
sandbox_id (str):
866-
Required. The resource name of the sandbox to generate a token for.
882+
sandbox_hostname (str):
883+
Required. The hostname of the sandbox to generate a token for.
867884
port (str):
868885
Optional. The port to use for the token. Defaults to "8080".
869886
timeout (int):
@@ -874,13 +891,14 @@ def generate_access_token(
874891
"""
875892
client = iam_credentials_v1.IAMCredentialsClient()
876893
name = f"projects/-/serviceAccounts/{service_account_email}"
877-
custom_claims = {"port": port, "sandbox_id": sandbox_id}
894+
custom_claims = {"hostname": sandbox_hostname, "port": port}
878895
payload = {
879896
"iat": int(time.time()),
880897
"exp": int(time.time()) + timeout,
881898
"iss": service_account_email,
899+
"sub": service_account_email,
882900
"nonce": secrets.randbelow(1000000000) + 1,
883-
"aud": "vmaas-proxy-api", # default audience for sandbox proxy
901+
"aud": "https://aiplatform.googleapis.com/", # default audience for sandbox proxy
884902
**custom_claims,
885903
}
886904
request = iam_credentials_v1.SignJwtRequest(
@@ -896,6 +914,7 @@ def send_command(
896914
http_method: str,
897915
access_token: str,
898916
sandbox_environment: types.SandboxEnvironment,
917+
port: str = "8080",
899918
path: Optional[str] = None,
900919
query_params: Optional[dict[str, object]] = None,
901920
headers: Optional[dict[str, str]] = None,
@@ -910,6 +929,8 @@ def send_command(
910929
Required. The access token to use for authorization.
911930
sandbox_environment (types.SandboxEnvironment):
912931
Required. The sandbox environment to send the command to.
932+
port (str):
933+
Optional. The port to use for the token. Defaults to "8080". This should be one of the ports specified during template creation.
913934
path (str):
914935
Optional. The path to send the command to.
915936
query_params (dict[str, object]):
@@ -934,10 +955,16 @@ def send_command(
934955
else:
935956
raise ValueError("Load balancer hostname or ip is not available.")
936957

958+
routing_token = connection_info.routing_token
959+
if not routing_token:
960+
raise ValueError("Routing token is not available.")
961+
937962
path = path or ""
938963
if query_params:
939964
path = f"{path}?{urlencode(query_params)}"
940965
headers["Authorization"] = f"Bearer {access_token}"
966+
headers["X-Sandbox-Routing-Token"] = routing_token
967+
headers["X-Sandbox-Port"] = port
941968
endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path
942969
http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint)
943970
http_client = genai.Client(vertexai=True, http_options=http_options)
@@ -953,6 +980,7 @@ def generate_browser_ws_headers(
953980
self,
954981
sandbox_environment: types.SandboxEnvironment,
955982
service_account_email: str,
983+
port: str = "8080",
956984
timeout: int = 3600,
957985
) -> tuple[str, dict[str, str]]:
958986
"""Generates the websocket upgrade headers for the browser.
@@ -962,47 +990,56 @@ def generate_browser_ws_headers(
962990
Required. The sandbox environment to generate websocket headers for.
963991
service_account_email (str):
964992
Required. The email of the service account to use for signing.
993+
port (str):
994+
Optional. The port to use for the CDP websocket endpoint url fetching.
995+
Defaults to "8080". This should be one of the ports specified during template creation.
965996
timeout (int):
966997
Optional. The timeout in seconds for the token. Defaults to 3600.
967-
968998
Returns:
969999
tuple[str, dict[str, str]]: A tuple containing the websocket URL and
9701000
the headers for websocket upgrade.
9711001
"""
972-
sandbox_id = sandbox_environment.name
973-
# port 8080 is the default port for http endpoint.
1002+
if not sandbox_environment.connection_info:
1003+
raise ValueError("Connection info is not available.")
1004+
1005+
connection_info = sandbox_environment.connection_info
1006+
if connection_info.load_balancer_hostname:
1007+
ws_base_url = "wss://" + connection_info.load_balancer_hostname
1008+
elif connection_info.load_balancer_ip:
1009+
ws_base_url = "ws://" + connection_info.load_balancer_ip
1010+
else:
1011+
raise ValueError("Load balancer hostname or ip is not available.")
1012+
9741013
http_access_token = self.generate_access_token(
975-
service_account_email, sandbox_id, "8080", timeout
1014+
service_account_email, connection_info.load_balancer_hostname, port, timeout
9761015
)
9771016
response = self.send_command(
9781017
http_method="GET",
9791018
access_token=http_access_token,
9801019
sandbox_environment=sandbox_environment,
1020+
port=port,
9811021
path="/cdp_ws_endpoint",
9821022
)
9831023
if not response:
9841024
raise ValueError("Failed to get the websocket endpoint.")
9851025
body_dict = json.loads(response.body)
9861026
ws_path = body_dict["endpoint"]
987-
988-
ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
989-
if sandbox_environment and sandbox_environment.connection_info:
990-
connection_info = sandbox_environment.connection_info
991-
if connection_info.load_balancer_hostname:
992-
ws_url = "wss://" + connection_info.load_balancer_hostname
993-
elif connection_info.load_balancer_ip:
994-
ws_url = "ws://" + connection_info.load_balancer_ip
995-
else:
996-
raise ValueError("Load balancer hostname or ip is not available.")
997-
ws_url = ws_url + "/" + ws_path
1027+
ws_url = ws_base_url + "/" + ws_path
9981028

9991029
# port 9222 is the default port for the browser websocket endpoint.
10001030
ws_access_token = self.generate_access_token(
1001-
service_account_email, sandbox_id, "9222", timeout
1031+
service_account_email,
1032+
connection_info.load_balancer_hostname,
1033+
"9222",
1034+
timeout,
10021035
)
10031036

1037+
routing_token = connection_info.routing_token
1038+
10041039
headers = {}
1005-
headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}"
1040+
headers["Sec-WebSocket-Protocol"] = (
1041+
f"v1.stream, {ws_access_token}, {routing_token}, 9222"
1042+
)
10061043
return ws_url, headers
10071044

10081045

vertexai/_genai/types/common.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11744,6 +11744,20 @@ class CreateAgentEngineSandboxConfig(_common.BaseModel):
1174411744
default=None,
1174511745
description="""The TTL for this resource. The expiration time is computed: now + TTL.""",
1174611746
)
11747+
sandbox_environment_template: Optional[str] = Field(
11748+
default=None,
11749+
description="""The name of the sandbox environment template to create the sandbox from. The sandbox environment template should be in the format:
11750+
projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentTemplates/{sandbox_environment_template}""",
11751+
)
11752+
sandbox_environment_snapshot: Optional[str] = Field(
11753+
default=None,
11754+
description="""The name of the sandbox environment snapshot to restore the sandbox from. The sandbox environment snapshot should be in the format:
11755+
projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}""",
11756+
)
11757+
owner: Optional[str] = Field(
11758+
default=None,
11759+
description="""Owner information for this sandbox environment. A sandbox can only be restored from a snapshot belonging to the same owner.""",
11760+
)
1174711761

1174811762

1174911763
class CreateAgentEngineSandboxConfigDict(TypedDict, total=False):
@@ -11764,6 +11778,17 @@ class CreateAgentEngineSandboxConfigDict(TypedDict, total=False):
1176411778
ttl: Optional[str]
1176511779
"""The TTL for this resource. The expiration time is computed: now + TTL."""
1176611780

11781+
sandbox_environment_template: Optional[str]
11782+
"""The name of the sandbox environment template to create the sandbox from. The sandbox environment template should be in the format:
11783+
projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentTemplates/{sandbox_environment_template}"""
11784+
11785+
sandbox_environment_snapshot: Optional[str]
11786+
"""The name of the sandbox environment snapshot to restore the sandbox from. The sandbox environment snapshot should be in the format:
11787+
projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}"""
11788+
11789+
owner: Optional[str]
11790+
"""Owner information for this sandbox environment. A sandbox can only be restored from a snapshot belonging to the same owner."""
11791+
1176711792

1176811793
CreateAgentEngineSandboxConfigOrDict = Union[
1176911794
CreateAgentEngineSandboxConfig, CreateAgentEngineSandboxConfigDict

0 commit comments

Comments
 (0)