
Commit d1f61ae

cleop-google authored and copybara-github committed

feat: GenAI SDK client(multimodal) - Implement create_from_gemini_request_jsonl in Datasets and AsyncDatasets classes.

PiperOrigin-RevId: 908222654

1 parent 762d20c · commit d1f61ae

3 files changed: 303 additions & 0 deletions
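
Taken together, the commit adds a one-call path from a Gemini-request JSONL file on GCS to a Vertex multimodal dataset. A minimal usage sketch; the `vertexai.Client` entry point and the bucket and table names below are illustrative assumptions, not taken from this commit:

    import vertexai  # assumed entry point for the GenAI SDK client

    client = vertexai.Client(project="my-project", location="us-central1")

    # Downloads gs://my-bucket/requests.jsonl, loads it into a BigQuery table
    # with a single "requests" column, and creates the multimodal dataset.
    dataset = client.datasets.create_from_gemini_request_jsonl(
        gcs_uri="gs://my-bucket/requests.jsonl",
        target_table_id="my_bq_dataset.my_table",  # optional; generated if omitted
        multimodal_dataset={"display_name": "my-multimodal-dataset"},
    )
    print(dataset.display_name)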

File tree

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py
vertexai/_genai/_datasets_utils.py
vertexai/_genai/datasets.py

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 97 additions & 0 deletions
@@ -81,6 +81,26 @@ def mock_generate_multimodal_dataset_display_name():
     yield mock_generate
 
 
+@pytest.fixture
+def mock_try_import_storage():
+    with mock.patch.object(
+        _datasets_utils, "_try_import_storage"
+    ) as mock_import_storage:
+        blob = mock.MagicMock()
+        blob.download_as_text.return_value = (
+            '{"contents": ["test1"]}\n{"contents": ["test2"]}'
+        )
+
+        bucket = mock.MagicMock()
+        bucket.blob.return_value = blob
+
+        client = mock.MagicMock()
+        client.bucket.return_value = bucket
+        mock_import_storage.return_value.Client.return_value = client
+
+        yield mock_import_storage
+
+
 def test_create_dataset(client):
     create_dataset_operation = client.datasets._create_multimodal_dataset(
         name="projects/vertex-sdk-dev/locations/us-central1",
@@ -295,6 +315,43 @@ def test_create_dataset_from_bigframes_preserves_other_metadata(client, is_replay_mode):
     )
 
 
+@pytest.mark.usefixtures(
+    "mock_bigquery_client", "mock_import_bigframes", "mock_try_import_storage"
+)
+def test_create_from_gemini_request_jsonl(client, is_replay_mode):
+    if is_replay_mode:
+        with mock.patch.object(client.datasets, "create_from_bigframes") as mock_create:
+            mock_ds = mock.MagicMock()
+            mock_ds.display_name = "test-from-gemini-jsonl"
+            mock_create.return_value = mock_ds
+
+            dataset = client.datasets.create_from_gemini_request_jsonl(
+                gcs_uri="gs://test-bucket/test-blob.jsonl",
+                target_table_id=BIGQUERY_TABLE_NAME,
+                multimodal_dataset={
+                    "display_name": "test-from-gemini-jsonl",
+                },
+            )
+            assert dataset.display_name == "test-from-gemini-jsonl"
+            assert (
+                dataset.metadata.gemini_request_read_config.assembled_request_column_name
+                == "requests"
+            )
+    else:
+        dataset = client.datasets.create_from_gemini_request_jsonl(
+            gcs_uri="gs://test-bucket/test-blob.jsonl",
+            target_table_id=BIGQUERY_TABLE_NAME,
+            multimodal_dataset={
+                "display_name": "test-from-gemini-jsonl",
+            },
+        )
+        assert dataset.display_name == "test-from-gemini-jsonl"
+        assert (
+            dataset.metadata.gemini_request_read_config.assembled_request_column_name
+            == "requests"
+        )
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),
@@ -549,3 +606,43 @@ async def test_create_dataset_from_bigframes_preserves_other_metadata_async(
     assert dataset.metadata.input_config.bigquery_source.uri == (
         f"bq://{BIGQUERY_TABLE_NAME}"
     )
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures(
+    "mock_bigquery_client", "mock_import_bigframes", "mock_try_import_storage"
+)
+async def test_create_from_gemini_request_jsonl_async(client, is_replay_mode):
+    if is_replay_mode:
+        with mock.patch.object(
+            client.aio.datasets, "create_from_bigframes"
+        ) as mock_create:
+            mock_ds = mock.MagicMock()
+            mock_ds.display_name = "test-from-gemini-jsonl-async"
+            mock_create.return_value = mock_ds
+
+            dataset = await client.aio.datasets.create_from_gemini_request_jsonl(
+                gcs_uri="gs://test-bucket/test-blob-async.jsonl",
+                target_table_id=BIGQUERY_TABLE_NAME,
+                multimodal_dataset={
+                    "display_name": "test-from-gemini-jsonl-async",
+                },
+            )
+            assert dataset.display_name == "test-from-gemini-jsonl-async"
+            assert (
+                dataset.metadata.gemini_request_read_config.assembled_request_column_name
+                == "requests"
+            )
+    else:
+        dataset = await client.aio.datasets.create_from_gemini_request_jsonl(
+            gcs_uri="gs://test-bucket/test-blob-async.jsonl",
+            target_table_id=BIGQUERY_TABLE_NAME,
+            multimodal_dataset={
+                "display_name": "test-from-gemini-jsonl-async",
+            },
+        )
+        assert dataset.display_name == "test-from-gemini-jsonl-async"
+        assert (
+            dataset.metadata.gemini_request_read_config.assembled_request_column_name
+            == "requests"
+        )
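
The storage fixture above simulates `download_as_text` with two minimal JSON lines. For reference, an input file is expected to carry one serialized `GenerateContentRequest` per line; a hypothetical two-line payload, expressed the same way the fixture does:

    # Hypothetical JSONL content; field names follow the GenerateContentRequest
    # schema, but the prompts themselves are made up for illustration.
    example_jsonl = (
        '{"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]}\n'
        '{"contents": [{"role": "user", "parts": [{"text": "Summarize this."}]}]}'
    )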

vertexai/_genai/_datasets_utils.py

Lines changed: 13 additions & 0 deletions
@@ -115,6 +115,19 @@ def _try_import_bigquery() -> Any:
         ) from exc
 
 
+def _try_import_storage() -> Any:
+    """Tries to import `storage`."""
+    try:
+        from google.cloud import storage  # type: ignore[attr-defined]
+
+        return storage
+    except ImportError as exc:
+        raise ImportError(
+            "`storage` is not installed. Please call 'pip install"
+            " google-cloud-storage'."
+        ) from exc
+
+
 def _bq_dataset_location_allowed(
     vertex_location: str, bq_dataset_location: str
 ) -> bool:
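
The helper mirrors `_try_import_bigquery` directly above it: the storage dependency stays optional, and the import only happens when the GCS code path runs. A small sketch of the resulting behavior, assuming an environment without the package installed:

    # Sketch: what a caller sees when google-cloud-storage is missing.
    try:
        storage = _try_import_storage()
    except ImportError as exc:
        # Message raised by the helper, chained to the original ImportError:
        # `storage` is not installed. Please call 'pip install google-cloud-storage'.
        print(exc)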

vertexai/_genai/datasets.py

Lines changed: 193 additions & 0 deletions
@@ -16,6 +16,7 @@
 # Code generated by the Google Gen AI SDK generator DO NOT EDIT.
 
 import asyncio
+import io
 import json
 import logging
 import time
@@ -1112,6 +1113,102 @@ def create_from_bigframes(
             multimodal_dataset=multimodal_dataset, config=config
         )
 
+    def create_from_gemini_request_jsonl(
+        self,
+        *,
+        gcs_uri: str,
+        multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
+        target_table_id: Optional[str] = None,
+        config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
+    ) -> types.MultimodalDataset:
+        """Creates a multimodal dataset from a JSONL file stored on GCS.
+
+        The JSONL file should contain instances of Gemini
+        `GenerateContentRequest` on each line. The data will be stored in a
+        BigQuery table with a single column called "requests". The
+        request_column_name in the dataset metadata will be set to "requests".
+
+        Args:
+            gcs_uri (str):
+                The Google Cloud Storage URI of the JSONL file to import.
+                For example, 'gs://my-bucket/path/to/data.jsonl'.
+            multimodal_dataset:
+                Optional. A representation of a multimodal dataset.
+            target_table_id (str):
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded. The table id can be in the format "dataset.table"
+                or "project.dataset.table". Note that the BigQuery dataset
+                must already exist and be in the same location as the
+                multimodal dataset. If not provided, a generated table id
+                will be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
+            config:
+                Optional. A configuration for creating the multimodal
+                dataset. If not provided, the default configuration will be
+                used.
+
+        Returns:
+            The created multimodal dataset.
+        """
+        storage = _datasets_utils._try_import_storage()
+
+        if isinstance(multimodal_dataset, dict):
+            multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
+        elif not multimodal_dataset:
+            multimodal_dataset = types.MultimodalDataset()
+
+        gcs_uri_prefix = "gs://"
+        if gcs_uri.startswith(gcs_uri_prefix):
+            gcs_uri = gcs_uri[len(gcs_uri_prefix) :]
+        parts = gcs_uri.split("/", 1)
+        if len(parts) != 2:
+            raise ValueError(
+                "Invalid GCS URI format. Expected: gs://bucket-name/object-path"
+            )
+        bucket_name = parts[0]
+        blob_name = parts[1]
+
+        project = self._api_client.project
+        location = self._api_client.location
+        credentials = self._api_client._credentials
+
+        storage_client = storage.Client(project=project)
+        bucket = storage_client.bucket(bucket_name)
+        blob = bucket.blob(blob_name)
+        request_column_name = "requests"
+
+        jsonl_string = blob.download_as_text()
+        lines = [line.strip() for line in jsonl_string.splitlines() if line.strip()]
+        json_string = json.dumps({request_column_name: lines})
+
+        multimodal_dataset = multimodal_dataset.model_copy(deep=True)
+        metadata = multimodal_dataset.metadata or types.SchemaTablesDatasetMetadata()
+
+        read_config = (
+            metadata.gemini_request_read_config or types.GeminiRequestReadConfig()
+        )
+        read_config.assembled_request_column_name = request_column_name
+        metadata.gemini_request_read_config = read_config
+
+        multimodal_dataset.metadata = metadata
+
+        bigframes = _datasets_utils._try_import_bigframes()
+        session_options = bigframes.BigQueryOptions(
+            credentials=credentials,
+            project=project,
+            location=location,
+        )
+        with bigframes.connect(session_options) as session:
+            temp_bigframes_df = session.read_json(io.StringIO(json_string))
+            temp_bigframes_df[request_column_name] = bigframes.bigquery.parse_json(
+                temp_bigframes_df[request_column_name]
+            )
+            return self.create_from_bigframes(
+                dataframe=temp_bigframes_df,
+                multimodal_dataset=multimodal_dataset,
+                target_table_id=target_table_id,
+                config=config,
+            )
+
     def update_multimodal_dataset(
         self,
         *,
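
A worked trace of the string handling above, as a standalone sketch rather than SDK code: the URI splits once into bucket and object path, each non-empty JSONL line becomes one element of a "requests" array, and that single JSON document is what `session.read_json` ingests.

    import json

    gcs_uri = "gs://my-bucket/path/to/data.jsonl"  # hypothetical URI
    # Strip the scheme, then split once: bucket name vs. object path.
    bucket_name, blob_name = gcs_uri[len("gs://"):].split("/", 1)
    assert (bucket_name, blob_name) == ("my-bucket", "path/to/data.jsonl")

    # Blank lines are dropped; each remaining line is kept as a raw string.
    jsonl_string = '{"contents": ["a"]}\n\n{"contents": ["b"]}\n'
    lines = [line.strip() for line in jsonl_string.splitlines() if line.strip()]
    json_string = json.dumps({"requests": lines})
    assert json.loads(json_string)["requests"] == [
        '{"contents": ["a"]}',
        '{"contents": ["b"]}',
    ]

bigframes' `parse_json` then converts those raw strings in the "requests" column into JSON values before the table is written.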
@@ -2400,6 +2497,102 @@ async def create_from_bigframes(
             multimodal_dataset=multimodal_dataset, config=config
         )
 
+    async def create_from_gemini_request_jsonl(
+        self,
+        *,
+        gcs_uri: str,
+        multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
+        target_table_id: Optional[str] = None,
+        config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
+    ) -> types.MultimodalDataset:
+        """Creates a multimodal dataset from a JSONL file stored on GCS.
+
+        The JSONL file should contain instances of Gemini
+        `GenerateContentRequest` on each line. The data will be stored in a
+        BigQuery table with a single column called "requests". The
+        request_column_name in the dataset metadata will be set to "requests".
+
+        Args:
+            gcs_uri (str):
+                The Google Cloud Storage URI of the JSONL file to import.
+                For example, 'gs://my-bucket/path/to/data.jsonl'.
+            multimodal_dataset:
+                Optional. A representation of a multimodal dataset.
+            target_table_id (str):
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded. The table id can be in the format "dataset.table"
+                or "project.dataset.table". Note that the BigQuery dataset
+                must already exist and be in the same location as the
+                multimodal dataset. If not provided, a generated table id
+                will be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
+            config:
+                Optional. A configuration for creating the multimodal
+                dataset. If not provided, the default configuration will be
+                used.
+
+        Returns:
+            The created multimodal dataset.
+        """
+        storage = _datasets_utils._try_import_storage()
+
+        if isinstance(multimodal_dataset, dict):
+            multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
+        elif not multimodal_dataset:
+            multimodal_dataset = types.MultimodalDataset()
+
+        gcs_uri_prefix = "gs://"
+        if gcs_uri.startswith(gcs_uri_prefix):
+            gcs_uri = gcs_uri[len(gcs_uri_prefix) :]
+        parts = gcs_uri.split("/", 1)
+        if len(parts) != 2:
+            raise ValueError(
+                "Invalid GCS URI format. Expected: gs://bucket-name/object-path"
+            )
+        bucket_name = parts[0]
+        blob_name = parts[1]
+
+        project = self._api_client.project
+        location = self._api_client.location
+        credentials = self._api_client._credentials
+
+        storage_client = storage.Client(project=project)
+        bucket = storage_client.bucket(bucket_name)
+        blob = bucket.blob(blob_name)
+        request_column_name = "requests"
+
+        jsonl_string = await asyncio.to_thread(blob.download_as_text)
+        lines = [line.strip() for line in jsonl_string.splitlines() if line.strip()]
+        json_string = json.dumps({request_column_name: lines})
+
+        multimodal_dataset = multimodal_dataset.model_copy(deep=True)
+        metadata = multimodal_dataset.metadata or types.SchemaTablesDatasetMetadata()
+
+        read_config = (
+            metadata.gemini_request_read_config or types.GeminiRequestReadConfig()
+        )
+        read_config.assembled_request_column_name = request_column_name
+        metadata.gemini_request_read_config = read_config
+
+        multimodal_dataset.metadata = metadata
+
+        bigframes = _datasets_utils._try_import_bigframes()
+        session_options = bigframes.BigQueryOptions(
+            credentials=credentials,
+            project=project,
+            location=location,
+        )
+        with bigframes.connect(session_options) as session:
+            temp_bigframes_df = session.read_json(io.StringIO(json_string))
+            temp_bigframes_df[request_column_name] = bigframes.bigquery.parse_json(
+                temp_bigframes_df[request_column_name]
+            )
+            return await self.create_from_bigframes(
+                dataframe=temp_bigframes_df,
+                multimodal_dataset=multimodal_dataset,
+                target_table_id=target_table_id,
+                config=config,
+            )
+
     async def update_multimodal_dataset(
         self,
         *,
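
The async variant offloads only the blob download to a worker thread via `asyncio.to_thread`; the bigframes session work still runs inside the coroutine. A usage sketch mirroring the async test, with the same illustrative client and names as the sync example above:

    import asyncio

    import vertexai  # assumed entry point, as in the sync sketch

    async def main() -> None:
        client = vertexai.Client(project="my-project", location="us-central1")
        dataset = await client.aio.datasets.create_from_gemini_request_jsonl(
            gcs_uri="gs://my-bucket/requests.jsonl",
            multimodal_dataset={"display_name": "my-async-dataset"},
        )
        print(dataset.display_name)

    asyncio.run(main())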
