@@ -56,6 +56,23 @@ def _CreateAgentEngineSandboxConfig_to_vertex(
5656 if getv (from_object , ["ttl" ]) is not None :
5757 setv (parent_object , ["ttl" ], getv (from_object , ["ttl" ]))
5858
59+ if getv (from_object , ["sandbox_environment_template" ]) is not None :
60+ setv (
61+ parent_object ,
62+ ["sandboxEnvironmentTemplate" ],
63+ getv (from_object , ["sandbox_environment_template" ]),
64+ )
65+
66+ if getv (from_object , ["sandbox_environment_snapshot" ]) is not None :
67+ setv (
68+ parent_object ,
69+ ["sandboxEnvironmentSnapshot" ],
70+ getv (from_object , ["sandbox_environment_snapshot" ]),
71+ )
72+
73+ if getv (from_object , ["owner" ]) is not None :
74+ setv (parent_object , ["owner" ], getv (from_object , ["owner" ]))
75+
5976 return to_object
6077
6178
@@ -853,7 +870,7 @@ def delete(
853870 def generate_access_token (
854871 self ,
855872 service_account_email : str ,
856- sandbox_id : str ,
873+ sandbox_hostname : str ,
857874 port : str = "8080" ,
858875 timeout : int = 3600 ,
859876 ) -> str :
@@ -862,8 +879,8 @@ def generate_access_token(
862879 Args:
863880 service_account_email (str):
864881 Required. The email of the service account to use for signing.
865- sandbox_id (str):
866- Required. The resource name of the sandbox to generate a token for.
882+ sandbox_hostname (str):
883+ Required. The hostname of the sandbox to generate a token for.
867884 port (str):
868885 Optional. The port to use for the token. Defaults to "8080".
869886 timeout (int):
@@ -874,13 +891,14 @@ def generate_access_token(
874891 """
875892 client = iam_credentials_v1 .IAMCredentialsClient ()
876893 name = f"projects/-/serviceAccounts/{ service_account_email } "
877- custom_claims = {"port " : port , "sandbox_id " : sandbox_id }
894+ custom_claims = {"hostname " : sandbox_hostname , "port " : port }
878895 payload = {
879896 "iat" : int (time .time ()),
880897 "exp" : int (time .time ()) + timeout ,
881898 "iss" : service_account_email ,
899+ "sub" : service_account_email ,
882900 "nonce" : secrets .randbelow (1000000000 ) + 1 ,
883- "aud" : "vmaas-proxy-api " , # default audience for sandbox proxy
901+ "aud" : "https://aiplatform.googleapis.com/ " , # default audience for sandbox proxy
884902 ** custom_claims ,
885903 }
886904 request = iam_credentials_v1 .SignJwtRequest (
@@ -896,6 +914,7 @@ def send_command(
896914 http_method : str ,
897915 access_token : str ,
898916 sandbox_environment : types .SandboxEnvironment ,
917+ port : str = "8080" ,
899918 path : Optional [str ] = None ,
900919 query_params : Optional [dict [str , object ]] = None ,
901920 headers : Optional [dict [str , str ]] = None ,
@@ -910,6 +929,8 @@ def send_command(
910929 Required. The access token to use for authorization.
911930 sandbox_environment (types.SandboxEnvironment):
912931 Required. The sandbox environment to send the command to.
932+ port (str):
933+ Optional. The port to use for the token. Defaults to "8080". This should be one of the ports specified during template creation.
913934 path (str):
914935 Optional. The path to send the command to.
915936 query_params (dict[str, object]):
@@ -934,10 +955,16 @@ def send_command(
934955 else :
935956 raise ValueError ("Load balancer hostname or ip is not available." )
936957
958+ routing_token = connection_info .routing_token
959+ if not routing_token :
960+ raise ValueError ("Routing token is not available." )
961+
937962 path = path or ""
938963 if query_params :
939964 path = f"{ path } ?{ urlencode (query_params )} "
940965 headers ["Authorization" ] = f"Bearer { access_token } "
966+ headers ["X-Sandbox-Routing-Token" ] = routing_token
967+ headers ["X-Sandbox-Port" ] = port
941968 endpoint = endpoint + path if path .startswith ("/" ) else endpoint + "/" + path
942969 http_options = genai_types .HttpOptions (headers = headers , base_url = endpoint )
943970 http_client = genai .Client (vertexai = True , http_options = http_options )
@@ -953,6 +980,7 @@ def generate_browser_ws_headers(
953980 self ,
954981 sandbox_environment : types .SandboxEnvironment ,
955982 service_account_email : str ,
983+ port : str = "8080" ,
956984 timeout : int = 3600 ,
957985 ) -> tuple [str , dict [str , str ]]:
958986 """Generates the websocket upgrade headers for the browser.
@@ -962,47 +990,56 @@ def generate_browser_ws_headers(
962990 Required. The sandbox environment to generate websocket headers for.
963991 service_account_email (str):
964992 Required. The email of the service account to use for signing.
993+ port (str):
994+ Optional. The port to use for the CDP websocket endpoint url fetching.
995+ Defaults to "8080". This should be one of the ports specified during template creation.
965996 timeout (int):
966997 Optional. The timeout in seconds for the token. Defaults to 3600.
967-
968998 Returns:
969999 tuple[str, dict[str, str]]: A tuple containing the websocket URL and
9701000 the headers for websocket upgrade.
9711001 """
972- sandbox_id = sandbox_environment .name
973- # port 8080 is the default port for http endpoint.
1002+ if not sandbox_environment .connection_info :
1003+ raise ValueError ("Connection info is not available." )
1004+
1005+ connection_info = sandbox_environment .connection_info
1006+ if connection_info .load_balancer_hostname :
1007+ ws_base_url = "wss://" + connection_info .load_balancer_hostname
1008+ elif connection_info .load_balancer_ip :
1009+ ws_base_url = "ws://" + connection_info .load_balancer_ip
1010+ else :
1011+ raise ValueError ("Load balancer hostname or ip is not available." )
1012+
9741013 http_access_token = self .generate_access_token (
975- service_account_email , sandbox_id , "8080" , timeout
1014+ service_account_email , connection_info . load_balancer_hostname , port , timeout
9761015 )
9771016 response = self .send_command (
9781017 http_method = "GET" ,
9791018 access_token = http_access_token ,
9801019 sandbox_environment = sandbox_environment ,
1020+ port = port ,
9811021 path = "/cdp_ws_endpoint" ,
9821022 )
9831023 if not response :
9841024 raise ValueError ("Failed to get the websocket endpoint." )
9851025 body_dict = json .loads (response .body )
9861026 ws_path = body_dict ["endpoint" ]
987-
988- ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
989- if sandbox_environment and sandbox_environment .connection_info :
990- connection_info = sandbox_environment .connection_info
991- if connection_info .load_balancer_hostname :
992- ws_url = "wss://" + connection_info .load_balancer_hostname
993- elif connection_info .load_balancer_ip :
994- ws_url = "ws://" + connection_info .load_balancer_ip
995- else :
996- raise ValueError ("Load balancer hostname or ip is not available." )
997- ws_url = ws_url + "/" + ws_path
1027+ ws_url = ws_base_url + "/" + ws_path
9981028
9991029 # port 9222 is the default port for the browser websocket endpoint.
10001030 ws_access_token = self .generate_access_token (
1001- service_account_email , sandbox_id , "9222" , timeout
1031+ service_account_email ,
1032+ connection_info .load_balancer_hostname ,
1033+ "9222" ,
1034+ timeout ,
10021035 )
10031036
1037+ routing_token = connection_info .routing_token
1038+
10041039 headers = {}
1005- headers ["Sec-WebSocket-Protocol" ] = f"binary, { ws_access_token } "
1040+ headers ["Sec-WebSocket-Protocol" ] = (
1041+ f"v1.stream, { ws_access_token } , { routing_token } , 9222"
1042+ )
10061043 return ws_url , headers
10071044
10081045
0 commit comments