-
Notifications
You must be signed in to change notification settings - Fork 39
feat: add async 'for' loop support to LogScanner (#424) #438
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
768266d
0e01b8b
3aa067b
1065665
195ec7c
4ad2fd8
08eef13
68426a0
d619b13
134e56b
efbcb8c
db23dd6
6ad8cab
3981fff
ffd5161
d521133
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ use arrow_schema::SchemaRef; | |
| use fluss::record::to_arrow_schema; | ||
| use fluss::rpc::message::OffsetSpec; | ||
| use indexmap::IndexMap; | ||
| use pyo3::IntoPyObjectExt; | ||
| use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyTypeError}; | ||
| use pyo3::sync::PyOnceLock; | ||
| use pyo3::types::{ | ||
|
|
@@ -1887,7 +1888,7 @@ impl ScannerKind { | |
| /// Both `LogScanner` and `RecordBatchLogScanner` share the same subscribe interface. | ||
| macro_rules! with_scanner { | ||
| ($scanner:expr, $method:ident($($arg:expr),*)) => { | ||
| match $scanner { | ||
| match $scanner.as_ref() { | ||
| ScannerKind::Record(s) => s.$method($($arg),*).await, | ||
| ScannerKind::Batch(s) => s.$method($($arg),*).await, | ||
| } | ||
|
|
@@ -1901,7 +1902,7 @@ macro_rules! with_scanner { | |
| /// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches | ||
| #[pyclass] | ||
| pub struct LogScanner { | ||
| scanner: ScannerKind, | ||
| kind: Arc<ScannerKind>, | ||
| admin: fcore::client::FlussAdmin, | ||
| table_info: fcore::metadata::TableInfo, | ||
| /// The projected Arrow schema to use for empty table creation | ||
|
|
@@ -1922,7 +1923,7 @@ impl LogScanner { | |
| fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> { | ||
| py.detach(|| { | ||
| TOKIO_RUNTIME.block_on(async { | ||
| with_scanner!(&self.scanner, subscribe(bucket_id, start_offset)) | ||
| with_scanner!(&self.kind, subscribe(bucket_id, start_offset)) | ||
| .map_err(|e| FlussError::from_core_error(&e)) | ||
| }) | ||
| }) | ||
|
|
@@ -1935,7 +1936,7 @@ impl LogScanner { | |
| fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap<i32, i64>) -> PyResult<()> { | ||
| py.detach(|| { | ||
| TOKIO_RUNTIME.block_on(async { | ||
| with_scanner!(&self.scanner, subscribe_buckets(&bucket_offsets)) | ||
| with_scanner!(&self.kind, subscribe_buckets(&bucket_offsets)) | ||
| .map_err(|e| FlussError::from_core_error(&e)) | ||
| }) | ||
| }) | ||
|
|
@@ -1957,7 +1958,7 @@ impl LogScanner { | |
| py.detach(|| { | ||
| TOKIO_RUNTIME.block_on(async { | ||
| with_scanner!( | ||
| &self.scanner, | ||
| &self.kind, | ||
| subscribe_partition(partition_id, bucket_id, start_offset) | ||
| ) | ||
| .map_err(|e| FlussError::from_core_error(&e)) | ||
|
|
@@ -1977,7 +1978,7 @@ impl LogScanner { | |
| py.detach(|| { | ||
| TOKIO_RUNTIME.block_on(async { | ||
| with_scanner!( | ||
| &self.scanner, | ||
| &self.kind, | ||
| subscribe_partition_buckets(&partition_bucket_offsets) | ||
| ) | ||
| .map_err(|e| FlussError::from_core_error(&e)) | ||
|
|
@@ -1992,7 +1993,7 @@ impl LogScanner { | |
| fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> { | ||
| py.detach(|| { | ||
| TOKIO_RUNTIME.block_on(async { | ||
| with_scanner!(&self.scanner, unsubscribe(bucket_id)) | ||
| with_scanner!(&self.kind, unsubscribe(bucket_id)) | ||
| .map_err(|e| FlussError::from_core_error(&e)) | ||
| }) | ||
| }) | ||
|
|
@@ -2006,11 +2007,8 @@ impl LogScanner { | |
| fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> { | ||
| py.detach(|| { | ||
| TOKIO_RUNTIME.block_on(async { | ||
| with_scanner!( | ||
| &self.scanner, | ||
| unsubscribe_partition(partition_id, bucket_id) | ||
| ) | ||
| .map_err(|e| FlussError::from_core_error(&e)) | ||
| with_scanner!(&self.kind, unsubscribe_partition(partition_id, bucket_id)) | ||
| .map_err(|e| FlussError::from_core_error(&e)) | ||
| }) | ||
| }) | ||
| } | ||
|
|
@@ -2030,7 +2028,7 @@ impl LogScanner { | |
| /// - Returns an empty ScanRecords if no records are available | ||
| /// - When timeout expires, returns an empty ScanRecords (NOT an error) | ||
| fn poll(&self, py: Python, timeout_ms: i64) -> PyResult<ScanRecords> { | ||
| let scanner = self.scanner.as_record()?; | ||
| let scanner = self.kind.as_record()?; | ||
|
|
||
| if timeout_ms < 0 { | ||
| return Err(FlussError::new_err(format!( | ||
|
|
@@ -2079,7 +2077,7 @@ impl LogScanner { | |
| /// - Returns an empty list if no batches are available | ||
| /// - When timeout expires, returns an empty list (NOT an error) | ||
| fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult<Vec<RecordBatch>> { | ||
| let scanner = self.scanner.as_batch()?; | ||
| let scanner = self.kind.as_batch()?; | ||
|
|
||
| if timeout_ms < 0 { | ||
| return Err(FlussError::new_err(format!( | ||
|
|
@@ -2114,7 +2112,7 @@ impl LogScanner { | |
| /// - Returns an empty table (with correct schema) if no records are available | ||
| /// - When timeout expires, returns an empty table (NOT an error) | ||
| fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult<Py<PyAny>> { | ||
| let scanner = self.scanner.as_batch()?; | ||
| let scanner = self.kind.as_batch()?; | ||
|
|
||
| if timeout_ms < 0 { | ||
| return Err(FlussError::new_err(format!( | ||
|
|
@@ -2167,13 +2165,16 @@ impl LogScanner { | |
| /// Returns: | ||
| /// PyArrow Table containing all data from subscribed buckets | ||
| fn to_arrow(&self, py: Python) -> PyResult<Py<PyAny>> { | ||
| let scanner = self.scanner.as_batch()?; | ||
| let subscribed = scanner.get_subscribed_buckets(); | ||
| if subscribed.is_empty() { | ||
| return Err(FlussError::new_err( | ||
| "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", | ||
| )); | ||
| } | ||
| let subscribed = { | ||
| let scanner = self.kind.as_batch()?; | ||
| let subs = scanner.get_subscribed_buckets(); | ||
| if subs.is_empty() { | ||
| return Err(FlussError::new_err( | ||
| "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", | ||
| )); | ||
| } | ||
| subs.clone() | ||
| }; | ||
|
|
||
| // 2. Query latest offsets for all subscribed buckets | ||
| let stopping_offsets = self.query_latest_offsets(py, &subscribed)?; | ||
|
|
@@ -2199,6 +2200,171 @@ impl LogScanner { | |
| Ok(df) | ||
| } | ||
|
|
||
| fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult<Bound<'py, PyAny>> { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We need to add this method to the .pyi stubs.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi @fresh-borzoni, just added.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it's better to leave
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi @fresh-borzoni, thanks for the clarification, removed those two entries in the |
||
| let py = slf.py(); | ||
|
|
||
| match slf.kind.as_ref() { | ||
| ScannerKind::Record(_) => { | ||
| static RECORD_ASYNC_GEN_FN: PyOnceLock<Py<PyAny>> = PyOnceLock::new(); | ||
| let gen_fn = RECORD_ASYNC_GEN_FN.get_or_init(py, || { | ||
| let code = pyo3::ffi::c_str!( | ||
| r#" | ||
| async def _async_scan(scanner, timeout_ms=1000): | ||
| while True: | ||
| batch = await scanner._async_poll(timeout_ms) | ||
| if batch: | ||
| for record in batch: | ||
| yield record | ||
| "# | ||
| ); | ||
| let globals = pyo3::types::PyDict::new(py); | ||
| py.run(code, Some(&globals), None).unwrap(); | ||
| globals.get_item("_async_scan").unwrap().unwrap().unbind() | ||
| }); | ||
| gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) | ||
| } | ||
| ScannerKind::Batch(_) => { | ||
| static BATCH_ASYNC_GEN_FN: PyOnceLock<Py<PyAny>> = PyOnceLock::new(); | ||
| let gen_fn = BATCH_ASYNC_GEN_FN.get_or_init(py, || { | ||
| let code = pyo3::ffi::c_str!( | ||
| r#" | ||
| async def _async_batch_scan(scanner, timeout_ms=1000): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The two
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi @fresh-borzoni, made the changes in 3981fff, tested locally and passing. |
||
| while True: | ||
| batches = await scanner._async_poll_batches(timeout_ms) | ||
| if batches: | ||
| for rb in batches: | ||
| yield rb | ||
| "# | ||
| ); | ||
| let globals = pyo3::types::PyDict::new(py); | ||
| py.run(code, Some(&globals), None).unwrap(); | ||
| globals | ||
| .get_item("_async_batch_scan") | ||
| .unwrap() | ||
| .unwrap() | ||
| .unbind() | ||
| }); | ||
| gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) | ||
| } | ||
| } | ||
| } | ||
|
Comment on lines
+2201
to
+2239
|
||
|
|
||
| /// Perform a single bounded poll and return a list of ScanRecord objects. | ||
| /// | ||
| /// This is the async building block used by `__aiter__` to implement | ||
| /// `async for`. Each call does exactly one network poll (bounded by | ||
| /// `timeout_ms`), converts any results to Python objects, and returns | ||
| /// them as a list. An empty list signals a timeout (no data yet), not | ||
| /// end-of-stream. | ||
| /// | ||
| /// Args: | ||
| /// timeout_ms: Timeout in milliseconds for the network poll (default: 1000) | ||
| /// | ||
| /// Returns: | ||
| /// Awaitable that resolves to a list of ScanRecord objects | ||
| fn _async_poll<'py>( | ||
| &self, | ||
| py: Python<'py>, | ||
| timeout_ms: Option<i64>, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It seems like we hardcode this anyway in generated code — do we need to define an arg for this? Also, nit: `poll_interval_ms` is more accurate.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi @leekeiabstraction, in d521133 I made this into a global variable, |
||
| ) -> PyResult<Bound<'py, PyAny>> { | ||
| let timeout_ms = timeout_ms.unwrap_or(1000); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: magic number.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi @leekeiabstraction, this is mentioned above. |
||
| if timeout_ms < 0 { | ||
| return Err(FlussError::new_err(format!( | ||
| "timeout_ms must be non-negative, got: {timeout_ms}" | ||
| ))); | ||
| } | ||
|
|
||
| let scanner = Arc::clone(&self.kind); | ||
| let projected_row_type = self.projected_row_type.clone(); | ||
| let timeout = Duration::from_millis(timeout_ms as u64); | ||
|
|
||
| future_into_py(py, async move { | ||
| let core_scanner = match scanner.as_ref() { | ||
| ScannerKind::Record(s) => s, | ||
| ScannerKind::Batch(_) => { | ||
| return Err(PyTypeError::new_err( | ||
| "This internal method only supports record-based scanners. \ | ||
| For batch-based scanners, use 'async for' or 'poll_record_batch' instead.", | ||
| )); | ||
| } | ||
| }; | ||
|
|
||
| let scan_records = core_scanner | ||
| .poll(timeout) | ||
| .await | ||
| .map_err(|e| FlussError::from_core_error(&e))?; | ||
|
|
||
| // Convert to Python list | ||
| Python::attach(|py| { | ||
| let mut result: Vec<Py<ScanRecord>> = Vec::new(); | ||
| for (_, records) in scan_records.into_records_by_buckets() { | ||
| for core_record in records { | ||
| let scan_record = | ||
| ScanRecord::from_core(py, &core_record, &projected_row_type)?; | ||
| result.push(Py::new(py, scan_record)?); | ||
| } | ||
| } | ||
| Ok(result) | ||
| }) | ||
| }) | ||
| } | ||
|
|
||
| /// Perform a single bounded poll and return a list of RecordBatch objects. | ||
| /// | ||
| /// This is the async building block used by `__aiter__` (batch mode) to | ||
| /// implement `async for`. Each call does exactly one network poll (bounded | ||
| /// by `timeout_ms`), converts any results to Python RecordBatch objects, | ||
| /// and returns them as a list. An empty list signals a timeout (no data | ||
| /// yet), not end-of-stream. | ||
| /// | ||
| /// Args: | ||
| /// timeout_ms: Timeout in milliseconds for the network poll (default: 1000) | ||
| /// | ||
| /// Returns: | ||
| /// Awaitable that resolves to a list of RecordBatch objects | ||
| fn _async_poll_batches<'py>( | ||
| &self, | ||
| py: Python<'py>, | ||
| timeout_ms: Option<i64>, | ||
| ) -> PyResult<Bound<'py, PyAny>> { | ||
| let timeout_ms = timeout_ms.unwrap_or(1000); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: magic number.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi @leekeiabstraction, this is mentioned above. |
||
| if timeout_ms < 0 { | ||
| return Err(FlussError::new_err(format!( | ||
| "timeout_ms must be non-negative, got: {timeout_ms}" | ||
| ))); | ||
| } | ||
|
|
||
| let scanner = Arc::clone(&self.kind); | ||
| let timeout = Duration::from_millis(timeout_ms as u64); | ||
|
|
||
| future_into_py(py, async move { | ||
| let core_scanner = match scanner.as_ref() { | ||
| ScannerKind::Batch(s) => s, | ||
| ScannerKind::Record(_) => { | ||
| return Err(PyTypeError::new_err( | ||
| "This internal method only supports batch-based scanners. \ | ||
| For record-based scanners, use 'async for' or 'poll' instead.", | ||
| )); | ||
| } | ||
| }; | ||
|
|
||
| let scan_batches = core_scanner | ||
| .poll(timeout) | ||
| .await | ||
| .map_err(|e| FlussError::from_core_error(&e))?; | ||
|
|
||
| // Convert to Python list of RecordBatch objects | ||
| Python::attach(|py| { | ||
| let mut result: Vec<Py<RecordBatch>> = Vec::new(); | ||
| for scan_batch in scan_batches { | ||
| let rb = RecordBatch::from_scan_batch(scan_batch); | ||
| result.push(Py::new(py, rb)?); | ||
| } | ||
| Ok(result) | ||
| }) | ||
| }) | ||
| } | ||
|
|
||
| fn __repr__(&self) -> String { | ||
| format!("LogScanner(table={})", self.table_info.table_path) | ||
| } | ||
|
|
@@ -2213,7 +2379,7 @@ impl LogScanner { | |
| projected_row_type: fcore::metadata::RowType, | ||
| ) -> Self { | ||
| Self { | ||
| scanner, | ||
| kind: Arc::new(scanner), | ||
| admin, | ||
| table_info, | ||
| projected_schema, | ||
|
|
@@ -2264,7 +2430,7 @@ impl LogScanner { | |
| py: Python, | ||
| subscribed: &[(fcore::metadata::TableBucket, i64)], | ||
| ) -> PyResult<HashMap<fcore::metadata::TableBucket, i64>> { | ||
| let scanner = self.scanner.as_batch()?; | ||
| let scanner = self.kind.as_batch()?; | ||
| let is_partitioned = scanner.is_partitioned(); | ||
| let table_path = &self.table_info.table_path; | ||
|
|
||
|
|
@@ -2367,7 +2533,7 @@ impl LogScanner { | |
| py: Python, | ||
| mut stopping_offsets: HashMap<fcore::metadata::TableBucket, i64>, | ||
| ) -> PyResult<Py<PyAny>> { | ||
| let scanner = self.scanner.as_batch()?; | ||
| let scanner = self.kind.as_batch()?; | ||
| let mut all_batches = Vec::new(); | ||
|
|
||
| while !stopping_offsets.is_empty() { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: scoping block + subs.clone() was needed with the Mutex, not needed with Arc - all borrows are shared now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @fresh-borzoni, made the changes in 6ad8cab, tested locally and passing.