25 changes: 25 additions & 0 deletions tests/unit/vertexai/genai/replays/test_create_evaluation_run.py
@@ -294,6 +294,31 @@ def test_create_eval_run_with_inference_configs(client):
assert evaluation_run.error is None


def test_create_eval_run_with_allow_cross_region_model(client):
"""Tests that create_evaluation_run() works with allow_cross_region_model in config."""
client._api_client._http_options.api_version = "v1beta1"
inference_config = types.EvaluationRunInferenceConfig(
model=MODEL_NAME,
prompt_template=types.EvaluationRunPromptTemplate(
prompt_template="test prompt template"
),
)
evaluation_run = client.evals.create_evaluation_run(
name="test_inference_config",
display_name="test_inference_config",
dataset=types.EvaluationRunDataSource(evaluation_set=EVAL_SET_NAME),
dest=GCS_DEST,
metrics=[GENERAL_QUALITY_METRIC],
inference_configs={"model_1": inference_config},
labels={"label1": "value1"},
config={"allow_cross_region_model": True},
)
assert isinstance(evaluation_run, types.EvaluationRun)
assert evaluation_run.display_name == "test_inference_config"
assert evaluation_run.state == types.EvaluationRunState.PENDING
assert evaluation_run.error is None


@mock.patch("uuid.uuid4")
def test_create_eval_run_with_metric_resource_name(mock_uuid4, client):
"""Tests create_evaluation_run with metric_resource_name."""
101 changes: 101 additions & 0 deletions tests/unit/vertexai/genai/test_evals.py
@@ -9136,3 +9136,104 @@ def test_computation_metric_retry_on_resource_exhausted(
summary_metric = result.summary_metrics[0]
assert summary_metric.metric_name == "bleu"
assert summary_metric.mean_score == 0.85


class TestAllowCrossRegionModel:
"""Tests for allow_cross_region_model flag for create_evaluation_run."""

def setup_method(self, method):
self.mock_api_client = mock.MagicMock()
self.mock_api_client.vertexai = True

self.mock_response = mock.MagicMock()
self.mock_response.body = json.dumps(
{
"name": "projects/123/locations/us-central1/evaluationRuns/456",
"displayName": "test_run",
"state": "PENDING",
}
)
self.mock_api_client.request.return_value = self.mock_response

def test_create_evaluation_run_config_has_allow_cross_region_model(self):
"""Verifies allow_cross_region_model field exists on CreateEvaluationRunConfig."""
config = vertexai_genai_types.CreateEvaluationRunConfig(
allow_cross_region_model=True,
)
assert config.allow_cross_region_model is True

def test_create_evaluation_run_config_from_dict(self):
"""Verifies allow_cross_region_model can be set via dict on CreateEvaluationRunConfig."""
config = vertexai_genai_types.CreateEvaluationRunConfig.model_validate(
{"allow_cross_region_model": True}
)
assert config.allow_cross_region_model is True

def test_create_evaluation_run_config_default_is_none(self):
"""Verifies the default value of allow_cross_region_model is None."""
config = vertexai_genai_types.CreateEvaluationRunConfig()
assert config.allow_cross_region_model is None

def test_create_evaluation_run_passes_allow_cross_region_model(self):
"""Verifies allow_cross_region_model is sent inside evaluationConfig in the API request."""
evals_module = evals.Evals(api_client_=self.mock_api_client)

evals_module.create_evaluation_run(
dataset=vertexai_genai_types.EvaluationRunDataSource(
evaluation_set="projects/123/locations/us-central1/evaluationSets/789"
),
metrics=[
vertexai_genai_types.EvaluationRunMetric(
metric="general_quality_v1",
metric_config=vertexai_genai_types.UnifiedMetric(
predefined_metric_spec=genai_types.PredefinedMetricSpec(
metric_spec_name="general_quality_v1",
)
),
)
],
dest="gs://test-bucket/output",
config={"allow_cross_region_model": True},
)

self.mock_api_client.request.assert_called_once()
call_args = self.mock_api_client.request.call_args
request_body = call_args[0][2] # Third positional arg is the request dict
assert (
request_body.get("evaluationConfig", {}).get("allowCrossRegionModel")
is True
)

@pytest.mark.asyncio
async def test_create_evaluation_run_async_passes_allow_cross_region_model(self):
"""Verifies allow_cross_region_model is sent inside evaluationConfig in the async API request."""
self.mock_api_client.async_request = mock.AsyncMock(
return_value=self.mock_response
)
async_evals_module = evals.AsyncEvals(api_client_=self.mock_api_client)

await async_evals_module.create_evaluation_run(
dataset=vertexai_genai_types.EvaluationRunDataSource(
evaluation_set="projects/123/locations/us-central1/evaluationSets/789"
),
metrics=[
vertexai_genai_types.EvaluationRunMetric(
metric="general_quality_v1",
metric_config=vertexai_genai_types.UnifiedMetric(
predefined_metric_spec=genai_types.PredefinedMetricSpec(
metric_spec_name="general_quality_v1",
)
),
)
],
dest="gs://test-bucket/output",
config={"allow_cross_region_model": True},
)

self.mock_api_client.async_request.assert_called_once()
call_args = self.mock_api_client.async_request.call_args
request_body = call_args[0][2] # Third positional arg is the request dict
assert (
request_body.get("evaluationConfig", {}).get("allowCrossRegionModel")
is True
)
35 changes: 35 additions & 0 deletions vertexai/_genai/evals.py
@@ -391,6 +391,13 @@ def _EvaluationRunConfig_from_vertex(
[item for item in getv(from_object, ["lossAnalysisConfig"])],
)

if getv(from_object, ["allowCrossRegionModel"]) is not None:
setv(
to_object,
["allow_cross_region_model"],
getv(from_object, ["allowCrossRegionModel"]),
)

return to_object


@@ -425,6 +432,13 @@ def _EvaluationRunConfig_to_vertex(
[item for item in getv(from_object, ["loss_analysis_config"])],
)

if getv(from_object, ["allow_cross_region_model"]) is not None:
setv(
to_object,
["allowCrossRegionModel"],
getv(from_object, ["allow_cross_region_model"]),
)

return to_object
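For reference, a minimal stand-in for what the converter above does with the new field, written with plain dicts instead of the SDK's getv/setv helpers (the function name here is hypothetical, used only to illustrate the resulting wire shape):

def _to_vertex_sketch(from_object: dict) -> dict:
    # Copy allow_cross_region_model to its camelCase wire name when set;
    # unset (None/absent) fields are omitted from the request body.
    to_object = {}
    if from_object.get("allow_cross_region_model") is not None:
        to_object["allowCrossRegionModel"] = from_object["allow_cross_region_model"]
    return to_object

assert _to_vertex_sketch({"allow_cross_region_model": True}) == {"allowCrossRegionModel": True}
assert _to_vertex_sketch({}) == {}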


@@ -2653,6 +2667,13 @@ def create_evaluation_run(
``max_top_cluster_count``. Mutually exclusive with
``loss_analysis_metrics``.
config: The configuration for the evaluation run.
- allow_cross_region_model: Allows the evaluation run to use cross-region
models. When this flag is set, the service may route traffic to other
regions if a model is unavailable in the current region (e.g., to a
`global` endpoint). If a fully qualified model endpoint resource name in a
region different from the run location is provided elsewhere in the run
config, this flag must be set to true or the request will fail. (A usage
sketch follows below.)

Returns:
The created evaluation run.
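A hedged usage sketch for this flag; client, EVAL_SET_NAME, GCS_DEST, and GENERAL_QUALITY_METRIC are placeholders in the spirit of the replay test above, not names introduced by this change:

# Sketch only: placeholders mirror the replay-test constants above.
evaluation_run = client.evals.create_evaluation_run(
    name="cross_region_run",
    display_name="cross_region_run",
    dataset=types.EvaluationRunDataSource(evaluation_set=EVAL_SET_NAME),
    dest=GCS_DEST,
    metrics=[GENERAL_QUALITY_METRIC],
    # Opt in to cross-region routing; required when a model endpoint in a
    # region other than the run location appears elsewhere in the config.
    config={"allow_cross_region_model": True},
)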
@@ -2672,6 +2693,11 @@ def create_evaluation_run(
else (agent_info or evals_types.AgentInfo())
)

if not config:
config = types.CreateEvaluationRunConfig()
if isinstance(config, dict):
config = types.CreateEvaluationRunConfig.model_validate(config)

if agent_info and not inference_configs:
parsed_user_simulator_config = (
evals_types.UserSimulatorConfig.model_validate(user_simulator_config)
@@ -2712,6 +2738,7 @@ def create_evaluation_run(
output_config=output_config,
metrics=resolved_metrics,
loss_analysis_config=resolved_loss_configs,
allow_cross_region_model=getattr(config, "allow_cross_region_model", None),
)
resolved_inference_configs = _evals_common._resolve_inference_configs(
self._api_client, resolved_dataset, inference_configs, parsed_agent_info
@@ -4422,6 +4449,8 @@ async def create_evaluation_run(
``max_top_cluster_count``. Mutually exclusive with
``loss_analysis_metrics``.
config: The configuration for the evaluation run.
- allow_cross_region_model: Opt-in flag to authorize cross-region
routing for model inference. Applies to both scraping and evaluation
(see the async sketch below).

Returns:
The created evaluation run.
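A hedged async sketch, assuming an AsyncEvals instance constructed as in the unit test above (async_evals and the dataset/metric placeholders are illustrative, not part of this change):

# Sketch only: async_evals is assumed to be an evals.AsyncEvals instance, as
# in the unit test above; the typed config is equivalent to passing the dict
# {"allow_cross_region_model": True}.
evaluation_run = await async_evals.create_evaluation_run(
    dataset=types.EvaluationRunDataSource(evaluation_set=EVAL_SET_NAME),
    dest=GCS_DEST,
    metrics=[GENERAL_QUALITY_METRIC],
    config=types.CreateEvaluationRunConfig(allow_cross_region_model=True),
)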
@@ -4441,6 +4470,11 @@ async def create_evaluation_run(
else (agent_info or evals_types.AgentInfo())
)

if not config:
config = types.CreateEvaluationRunConfig()
if isinstance(config, dict):
config = types.CreateEvaluationRunConfig.model_validate(config)

if agent_info and not inference_configs:
parsed_user_simulator_config = (
evals_types.UserSimulatorConfig.model_validate(user_simulator_config)
@@ -4481,6 +4515,7 @@ async def create_evaluation_run(
output_config=output_config,
metrics=resolved_metrics,
loss_analysis_config=resolved_loss_configs,
allow_cross_region_model=getattr(config, "allow_cross_region_model", None),
)
resolved_inference_configs = _evals_common._resolve_inference_configs(
self._api_client, resolved_dataset, inference_configs, parsed_agent_info
26 changes: 26 additions & 0 deletions vertexai/_genai/types/common.py
@@ -2401,6 +2401,11 @@ class EvaluationRunConfig(_common.BaseModel):
default=None,
description="""Specifications for loss analysis. Each config specifies a metric and candidate to analyze for loss patterns.""",
)
allow_cross_region_model: Optional[bool] = Field(
default=None,
description="""Opt-in flag to authorize cross-region routing for model inference.
Applies to both scraping and evaluation.""",
)


class EvaluationRunConfigDict(TypedDict, total=False):
@@ -2421,6 +2426,10 @@ class EvaluationRunConfigDict(TypedDict, total=False):
loss_analysis_config: Optional[list[LossAnalysisConfigDict]]
"""Specifications for loss analysis. Each config specifies a metric and candidate to analyze for loss patterns."""

allow_cross_region_model: Optional[bool]
"""Opt-in flag to authorize cross-region routing for model inference.
Applies to both scraping and evaluation."""


EvaluationRunConfigOrDict = Union[EvaluationRunConfig, EvaluationRunConfigDict]

@@ -2551,6 +2560,15 @@ class CreateEvaluationRunConfig(_common.BaseModel):
http_options: Optional[genai_types.HttpOptions] = Field(
default=None, description="""Used to override HTTP request options."""
)
allow_cross_region_model: Optional[bool] = Field(
default=None,
description="""Allows the evaluation run to use cross region models. When this
flag is set, the service may route traffic to other regions if a model is
unavailable in the current region (e.g., to a `global`endpoint). If a
fully-qualified model endpoint resource name with a different region than
the run location is provided elsewhere in the runconfig, this flag must
be set to true or the request will fail.""",
)


class CreateEvaluationRunConfigDict(TypedDict, total=False):
@@ -2559,6 +2577,14 @@ class CreateEvaluationRunConfigDict(TypedDict, total=False):
http_options: Optional[genai_types.HttpOptionsDict]
"""Used to override HTTP request options."""

allow_cross_region_model: Optional[bool]
"""Allows the evaluation run to use cross region models. When this
flag is set, the service may route traffic to other regions if a model is
unavailable in the current region (e.g., to a `global`endpoint). If a
fully-qualified model endpoint resource name with a different region than
the run location is provided elsewhere in the runconfig, this flag must
be set to true or the request will fail."""


CreateEvaluationRunConfigOrDict = Union[
CreateEvaluationRunConfig, CreateEvaluationRunConfigDict
]
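Since create_evaluation_run() normalizes a plain dict through model_validate(), both members of the union above construct the same config; a minimal sketch, using the test modules' vertexai_genai_types alias for this types package:

typed = vertexai_genai_types.CreateEvaluationRunConfig(
    allow_cross_region_model=True
)
from_dict = vertexai_genai_types.CreateEvaluationRunConfig.model_validate(
    {"allow_cross_region_model": True}
)
assert typed.allow_cross_region_model is True
assert from_dict.allow_cross_region_model is True
# The default is None ("not set"), which the converters omit from the payload.
assert (
    vertexai_genai_types.CreateEvaluationRunConfig().allow_cross_region_model
    is None
)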