Skip to content

Commit 2d473a7

Browse files
cleop-google authored and copybara-github committed
feat: GenAI SDK client(multimodal) - Add to_batch_job_source and get_batch_job_destination to MultimodalDataset
PiperOrigin-RevId: 906352851
1 parent f5c4f8f commit 2d473a7

4 files changed

Lines changed: 142 additions & 73 deletions

File tree

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,20 @@ def mock_import_bigframes(is_replay_mode):
7474

7575
@pytest.fixture
7676
def mock_generate_multimodal_dataset_display_name():
77-
with mock.patch.object(
77+
with mock.patch.object(
7878
_datasets_utils, "generate_multimodal_dataset_display_name"
7979
) as mock_generate:
80-
mock_generate.return_value = "test-generated-name"
81-
yield mock_generate
80+
mock_generate.return_value = "test-generated-name"
81+
yield mock_generate
82+
83+
84+
@pytest.fixture
85+
def mock_get_batch_job_unique_name():
86+
with mock.patch.object(
87+
_datasets_utils, "get_batch_job_unique_name"
88+
) as mock_unique_name:
89+
mock_unique_name.return_value = "12345678901234_abcde"
90+
yield mock_unique_name
8291

8392

8493
def test_create_dataset(client):
@@ -169,43 +178,43 @@ def test_create_dataset_from_pandas(client, is_replay_mode):
169178
)
170179
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
171180
def test_create_dataset_from_bigframes(client, is_replay_mode):
172-
import bigframes.pandas
181+
import bigframes.pandas
173182

174-
dataframe = pd.DataFrame(
183+
dataframe = pd.DataFrame(
175184
{
176185
"col1": ["col1"],
177186
"col2": ["col2"],
178187
}
179188
)
180-
if is_replay_mode:
181-
bf_dataframe = mock.MagicMock()
182-
bf_dataframe.to_gbq.return_value = "temp_table_id"
183-
else:
184-
bf_dataframe = bigframes.pandas.DataFrame(dataframe)
189+
if is_replay_mode:
190+
bf_dataframe = mock.MagicMock()
191+
bf_dataframe.to_gbq.return_value = "temp_table_id"
192+
else:
193+
bf_dataframe = bigframes.pandas.DataFrame(dataframe)
185194

186-
dataset = client.datasets.create_from_bigframes(
195+
dataset = client.datasets.create_from_bigframes(
187196
dataframe=bf_dataframe,
188197
target_table_id=BIGQUERY_TABLE_NAME,
189198
multimodal_dataset={
190199
"display_name": "test-from-bigframes",
191200
},
192201
)
193202

194-
assert isinstance(dataset, types.MultimodalDataset)
195-
assert dataset.display_name == "test-from-bigframes"
196-
assert dataset.metadata.input_config.bigquery_source.uri == (
203+
assert isinstance(dataset, types.MultimodalDataset)
204+
assert dataset.display_name == "test-from-bigframes"
205+
assert dataset.metadata.input_config.bigquery_source.uri == (
197206
f"bq://{BIGQUERY_TABLE_NAME}"
198207
)
199-
if not is_replay_mode:
200-
bigquery_client = bigquery.Client(
208+
if not is_replay_mode:
209+
bigquery_client = bigquery.Client(
201210
project=client._api_client.project,
202211
location=client._api_client.location,
203212
credentials=client._api_client._credentials,
204213
)
205-
rows = bigquery_client.list_rows(
214+
rows = bigquery_client.list_rows(
206215
dataset.metadata.input_config.bigquery_source.uri[5:]
207216
)
208-
pd.testing.assert_frame_equal(
217+
pd.testing.assert_frame_equal(
209218
rows.to_dataframe(), dataframe, check_index_type=False
210219
)
211220

tests/unit/vertexai/genai/test_multimodal_datasets_genai.py

Lines changed: 84 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -23,140 +23,171 @@
2323

2424
@pytest.fixture
2525
def mock_import_bigframes():
26-
with mock.patch.object(
26+
with mock.patch.object(
2727
_datasets_utils, "_try_import_bigframes"
2828
) as mock_import_bigframes:
29-
mock_read_gbq_table_result = mock.MagicMock()
30-
mock_read_gbq_table_result.sql = "SELECT * FROM `project.dataset.table`"
29+
mock_read_gbq_table_result = mock.MagicMock()
30+
mock_read_gbq_table_result.sql = "SELECT * FROM `project.dataset.table`"
3131

32-
bigframes = mock.MagicMock()
33-
bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result
32+
bigframes = mock.MagicMock()
33+
bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result
3434

35-
mock_import_bigframes.return_value = bigframes
36-
yield mock_import_bigframes
35+
mock_import_bigframes.return_value = bigframes
36+
yield mock_import_bigframes
37+
38+
39+
@pytest.fixture
40+
def mock_get_batch_job_unique_name():
41+
with mock.patch.object(
42+
_datasets_utils, "get_batch_job_unique_name"
43+
) as mock_unique_name:
44+
mock_unique_name.return_value = "12345678901234_abcde"
45+
yield mock_unique_name
3746

3847

3948
class TestMultimodalDataset:
4049

41-
def test_read_config(self):
42-
dataset = types.MultimodalDataset(
50+
def test_read_config(self):
51+
dataset = types.MultimodalDataset(
4352
metadata={
4453
"gemini_request_read_config": {
4554
"assembled_request_column_name": "test_column",
4655
},
4756
},
4857
)
4958

50-
assert isinstance(dataset.read_config, types.GeminiRequestReadConfig)
51-
assert dataset.read_config.assembled_request_column_name == "test_column"
59+
assert isinstance(dataset.read_config, types.GeminiRequestReadConfig)
60+
assert dataset.read_config.assembled_request_column_name == "test_column"
5261

53-
def test_read_config_empty(self):
54-
dataset = types.MultimodalDataset()
55-
assert dataset.read_config is None
62+
def test_read_config_empty(self):
63+
dataset = types.MultimodalDataset()
64+
assert dataset.read_config is None
5665

57-
def test_set_read_config(self):
58-
dataset = types.MultimodalDataset()
66+
def test_set_read_config(self):
67+
dataset = types.MultimodalDataset()
5968

60-
dataset.set_read_config(
69+
dataset.set_read_config(
6170
read_config={
6271
"assembled_request_column_name": "test_column",
6372
},
6473
)
6574

66-
assert isinstance(dataset, types.MultimodalDataset)
67-
assert (
75+
assert isinstance(dataset, types.MultimodalDataset)
76+
assert (
6877
dataset.metadata.gemini_request_read_config.assembled_request_column_name
6978
== "test_column"
7079
)
7180

72-
def test_set_read_config_preserves_other_fields(self):
73-
dataset = types.MultimodalDataset(
81+
def test_set_read_config_preserves_other_fields(self):
82+
dataset = types.MultimodalDataset(
7483
metadata={
7584
"inputConfig": {
7685
"bigquerySource": {"uri": "bq://test_table"},
7786
},
7887
},
7988
)
8089

81-
dataset.set_read_config(
90+
dataset.set_read_config(
8291
read_config={
8392
"assembled_request_column_name": "test_column",
8493
},
8594
)
8695

87-
assert isinstance(dataset, types.MultimodalDataset)
88-
assert (
96+
assert isinstance(dataset, types.MultimodalDataset)
97+
assert (
8998
dataset.metadata.gemini_request_read_config.assembled_request_column_name
9099
== "test_column"
91100
)
92-
assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"
101+
assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"
93102

94-
def test_bigquery_uri(self):
95-
dataset = types.MultimodalDataset(
103+
def test_bigquery_uri(self):
104+
dataset = types.MultimodalDataset(
96105
metadata={
97106
"inputConfig": {
98107
"bigquerySource": {"uri": "bq://project.dataset.table"},
99108
},
100109
},
101110
)
102111

103-
assert dataset.bigquery_uri == "bq://project.dataset.table"
112+
assert dataset.bigquery_uri == "bq://project.dataset.table"
104113

105-
def test_bigquery_uri_empty(self):
106-
dataset = types.MultimodalDataset()
107-
assert dataset.bigquery_uri is None
114+
def test_bigquery_uri_empty(self):
115+
dataset = types.MultimodalDataset()
116+
assert dataset.bigquery_uri is None
108117

109-
def test_set_bigquery_uri(self):
110-
dataset = types.MultimodalDataset()
118+
def test_set_bigquery_uri(self):
119+
dataset = types.MultimodalDataset()
111120

112-
dataset.set_bigquery_uri("bq://project.dataset.table")
121+
dataset.set_bigquery_uri("bq://project.dataset.table")
113122

114-
assert isinstance(dataset, types.MultimodalDataset)
115-
assert (
123+
assert isinstance(dataset, types.MultimodalDataset)
124+
assert (
116125
dataset.metadata.input_config.bigquery_source.uri
117126
== "bq://project.dataset.table"
118127
)
119128

120-
def test_set_bigquery_uri_without_prefix(self):
121-
dataset = types.MultimodalDataset()
129+
def test_set_bigquery_uri_without_prefix(self):
130+
dataset = types.MultimodalDataset()
122131

123-
dataset.set_bigquery_uri("project.dataset.table")
132+
dataset.set_bigquery_uri("project.dataset.table")
124133

125-
assert isinstance(dataset, types.MultimodalDataset)
126-
assert (
134+
assert isinstance(dataset, types.MultimodalDataset)
135+
assert (
127136
dataset.metadata.input_config.bigquery_source.uri
128137
== "bq://project.dataset.table"
129138
)
130139

131-
def test_set_bigquery_uri_preserves_other_fields(self):
132-
dataset = types.MultimodalDataset(
140+
def test_set_bigquery_uri_preserves_other_fields(self):
141+
dataset = types.MultimodalDataset(
133142
metadata={
134143
"gemini_request_read_config": {
135144
"assembled_request_column_name": "test_column",
136145
},
137146
},
138147
)
139148

140-
dataset.set_bigquery_uri("bq://test_table")
149+
dataset.set_bigquery_uri("bq://test_table")
141150

142-
assert isinstance(dataset, types.MultimodalDataset)
143-
assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"
144-
assert (
151+
assert isinstance(dataset, types.MultimodalDataset)
152+
assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"
153+
assert (
145154
dataset.metadata.gemini_request_read_config.assembled_request_column_name
146155
== "test_column"
147156
)
148157

149-
def test_to_bigframes(self, mock_import_bigframes):
150-
dataset = types.MultimodalDataset()
151-
dataset.set_bigquery_uri("bq://project.dataset.table")
158+
def test_to_bigframes(self, mock_import_bigframes):
159+
dataset = types.MultimodalDataset()
160+
dataset.set_bigquery_uri("bq://project.dataset.table")
152161

153-
df = dataset.to_bigframes()
162+
df = dataset.to_bigframes()
154163

155-
assert "project.dataset.table" in df.sql
156-
mock_import_bigframes.return_value.pandas.read_gbq_table.assert_called_once_with(
164+
assert "project.dataset.table" in df.sql
165+
mock_import_bigframes.return_value.pandas.read_gbq_table.assert_called_once_with(
157166
"project.dataset.table"
158167
)
159168

169+
def test_get_batch_job_destination(self, mock_get_batch_job_unique_name):
170+
dataset = types.MultimodalDataset(
171+
name="projects/vertex-sdk-dev/locations/us-central1/datasets/12345",
172+
display_name="test_multimodal_dataset",
173+
metadata={
174+
"inputConfig": {
175+
"bigquerySource": {
176+
"uri": "bq://target_project.target_dataset.target_table"
177+
},
178+
},
179+
},
180+
)
181+
destination = dataset.get_batch_job_destination()
182+
assert (
183+
destination.vertex_dataset.display_name
184+
== "test_multimodal_dataset_batch_output_12345678901234_abcde"
185+
)
186+
assert (
187+
destination.vertex_dataset.bigquery_destination
188+
== "bq://target_project.target_dataset.target_table_batch_output_12345678901234_abcde"
189+
)
190+
160191

161192
class TestGeminiRequestReadConfig:
162193
def test_single_turn_template(self):

vertexai/_genai/_datasets_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,15 @@ def _generate_target_table_id(dataset_id: str) -> str:
229229

230230

231231
def generate_multimodal_dataset_display_name() -> str:
232-
"""Generates a display name with a timestamp."""
233-
return f"MultimodalDataset {datetime.datetime.now().isoformat(sep=' ')}"
232+
"""Generates a display name with a timestamp."""
233+
return f"MultimodalDataset {datetime.datetime.now().isoformat(sep=' ')}"
234+
235+
236+
def get_batch_job_unique_name() -> str:
237+
"""Generates a unique name suffix for a batch job destination."""
238+
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
239+
unique_id = uuid.uuid4().hex[0:5]
240+
return f"{timestamp}_{unique_id}"
234241

235242

236243
def save_dataframe_to_bigquery(

vertexai/_genai/types/common.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14094,6 +14094,28 @@ def to_bigframes(
1409414094
raise ValueError("Multimodal dataset bigquery source uri is not set.")
1409514095
return bigframes.pandas.read_gbq_table(self.bigquery_uri.removeprefix("bq://"))
1409614096

14097+
def to_batch_job_source(self) -> "genai_types.BatchJobSource":
14098+
"""Converts the dataset to a BatchJobSource."""
14099+
return genai_types.BatchJobSource(
14100+
vertex_dataset_name=self.name,
14101+
)
14102+
14103+
def get_batch_job_destination(self) -> "genai_types.BatchJobDestination":
14104+
"""Converts the dataset to a BatchJobDestination."""
14105+
from .. import _datasets_utils
14106+
14107+
unique_name = _datasets_utils.get_batch_job_unique_name()
14108+
bigquery_uri = self.bigquery_uri
14109+
if bigquery_uri is None:
14110+
raise ValueError("Multimodal dataset bigquery source uri is not set.")
14111+
curr_display_name = self.display_name or "genai_batch_job"
14112+
return genai_types.BatchJobDestination(
14113+
vertex_dataset=genai_types.VertexMultimodalDatasetDestination(
14114+
display_name=f"{curr_display_name}_batch_output_{unique_name}",
14115+
bigquery_destination=f"{bigquery_uri}_batch_output_{unique_name}",
14116+
)
14117+
)
14118+
1409714119

1409814120
class MultimodalDatasetDict(TypedDict, total=False):
1409914121
"""Represents a multimodal dataset."""

0 commit comments

Comments (0)