Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions bindings/python/fluss/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,17 @@

from enum import IntEnum
from types import TracebackType
from typing import Dict, Iterator, List, Optional, Tuple, Union, overload
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
overload,
)

import pandas as pd
import pyarrow as pa
Expand Down Expand Up @@ -765,8 +775,12 @@ class LogScanner:

You must call subscribe(), subscribe_buckets(), or subscribe_partition() first.
"""
...
def __repr__(self) -> str: ...
def __aiter__(self) -> AsyncIterator[Union[ScanRecord, RecordBatch]]: ...
async def _async_poll(self, timeout_ms: Optional[int] = ...) -> List[ScanRecord]: ...
async def _async_poll_batches(
self, timeout_ms: Optional[int] = ...
) -> List[RecordBatch]: ...

class Schema:
def __init__(
Expand Down
216 changes: 191 additions & 25 deletions bindings/python/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow_schema::SchemaRef;
use fluss::record::to_arrow_schema;
use fluss::rpc::message::OffsetSpec;
use indexmap::IndexMap;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyTypeError};
use pyo3::sync::PyOnceLock;
use pyo3::types::{
Expand Down Expand Up @@ -1887,7 +1888,7 @@ impl ScannerKind {
/// Both `LogScanner` and `RecordBatchLogScanner` share the same subscribe interface.
macro_rules! with_scanner {
($scanner:expr, $method:ident($($arg:expr),*)) => {
match $scanner {
match $scanner.as_ref() {
ScannerKind::Record(s) => s.$method($($arg),*).await,
ScannerKind::Batch(s) => s.$method($($arg),*).await,
}
Expand All @@ -1901,7 +1902,7 @@ macro_rules! with_scanner {
/// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches
#[pyclass]
pub struct LogScanner {
scanner: ScannerKind,
kind: Arc<ScannerKind>,
admin: fcore::client::FlussAdmin,
table_info: fcore::metadata::TableInfo,
/// The projected Arrow schema to use for empty table creation
Expand All @@ -1922,7 +1923,7 @@ impl LogScanner {
fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(&self.scanner, subscribe(bucket_id, start_offset))
with_scanner!(&self.kind, subscribe(bucket_id, start_offset))
.map_err(|e| FlussError::from_core_error(&e))
})
})
Expand All @@ -1935,7 +1936,7 @@ impl LogScanner {
fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap<i32, i64>) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(&self.scanner, subscribe_buckets(&bucket_offsets))
with_scanner!(&self.kind, subscribe_buckets(&bucket_offsets))
.map_err(|e| FlussError::from_core_error(&e))
})
})
Expand All @@ -1957,7 +1958,7 @@ impl LogScanner {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(
&self.scanner,
&self.kind,
subscribe_partition(partition_id, bucket_id, start_offset)
)
.map_err(|e| FlussError::from_core_error(&e))
Expand All @@ -1977,7 +1978,7 @@ impl LogScanner {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(
&self.scanner,
&self.kind,
subscribe_partition_buckets(&partition_bucket_offsets)
)
.map_err(|e| FlussError::from_core_error(&e))
Expand All @@ -1992,7 +1993,7 @@ impl LogScanner {
fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(&self.scanner, unsubscribe(bucket_id))
with_scanner!(&self.kind, unsubscribe(bucket_id))
.map_err(|e| FlussError::from_core_error(&e))
})
})
Expand All @@ -2006,11 +2007,8 @@ impl LogScanner {
fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(
&self.scanner,
unsubscribe_partition(partition_id, bucket_id)
)
.map_err(|e| FlussError::from_core_error(&e))
with_scanner!(&self.kind, unsubscribe_partition(partition_id, bucket_id))
.map_err(|e| FlussError::from_core_error(&e))
})
})
}
Expand All @@ -2030,7 +2028,7 @@ impl LogScanner {
/// - Returns an empty ScanRecords if no records are available
/// - When timeout expires, returns an empty ScanRecords (NOT an error)
fn poll(&self, py: Python, timeout_ms: i64) -> PyResult<ScanRecords> {
let scanner = self.scanner.as_record()?;
let scanner = self.kind.as_record()?;

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
Expand Down Expand Up @@ -2079,7 +2077,7 @@ impl LogScanner {
/// - Returns an empty list if no batches are available
/// - When timeout expires, returns an empty list (NOT an error)
fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult<Vec<RecordBatch>> {
let scanner = self.scanner.as_batch()?;
let scanner = self.kind.as_batch()?;

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
Expand Down Expand Up @@ -2114,7 +2112,7 @@ impl LogScanner {
/// - Returns an empty table (with correct schema) if no records are available
/// - When timeout expires, returns an empty table (NOT an error)
fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult<Py<PyAny>> {
let scanner = self.scanner.as_batch()?;
let scanner = self.kind.as_batch()?;

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
Expand Down Expand Up @@ -2167,13 +2165,16 @@ impl LogScanner {
/// Returns:
/// PyArrow Table containing all data from subscribed buckets
fn to_arrow(&self, py: Python) -> PyResult<Py<PyAny>> {
let scanner = self.scanner.as_batch()?;
let subscribed = scanner.get_subscribed_buckets();
if subscribed.is_empty() {
return Err(FlussError::new_err(
"No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.",
));
}
let subscribed = {
let scanner = self.kind.as_batch()?;
let subs = scanner.get_subscribed_buckets();
if subs.is_empty() {
return Err(FlussError::new_err(
"No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.",
));
}
subs.clone()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: scoping block + subs.clone() was needed with the Mutex, not needed with Arc - all borrows are shared now

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @fresh-borzoni, made the changes in 6ad8cab, tested locally and passing.

};

// 2. Query latest offsets for all subscribed buckets
let stopping_offsets = self.query_latest_offsets(py, &subscribed)?;
Expand All @@ -2199,6 +2200,171 @@ impl LogScanner {
Ok(df)
}

fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult<Bound<'py, PyAny>> {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to add this method to .pyi stubs

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @fresh-borzoni, just added __aiter__ to __init__.pyi here 134e56b along with with _async_poll and _async_poll_batches.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to leave _async_poll and _async_poll_batches out of .pyi bc these methods ideally should be private implementation details.
So exposing __aiter__ makes sense to just signal IDE that we support async for, but the rest of underscore methods added - we don't want to encourage users to use them directly

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @fresh-borzoni, thanks for the clarification, removed those two entries in the .pyi file in db23dd6.

let py = slf.py();

match slf.kind.as_ref() {
ScannerKind::Record(_) => {
static RECORD_ASYNC_GEN_FN: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
let gen_fn = RECORD_ASYNC_GEN_FN.get_or_init(py, || {
let code = pyo3::ffi::c_str!(
r#"
async def _async_scan(scanner, timeout_ms=1000):
while True:
batch = await scanner._async_poll(timeout_ms)
if batch:
for record in batch:
yield record
"#
);
let globals = pyo3::types::PyDict::new(py);
py.run(code, Some(&globals), None).unwrap();
globals.get_item("_async_scan").unwrap().unwrap().unbind()
});
gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,))
}
ScannerKind::Batch(_) => {
static BATCH_ASYNC_GEN_FN: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
let gen_fn = BATCH_ASYNC_GEN_FN.get_or_init(py, || {
let code = pyo3::ffi::c_str!(
r#"
async def _async_batch_scan(scanner, timeout_ms=1000):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two __aiter__ branches are identical except for the poll method name. You can collapse to a single PyOnceLock + generator that takes a callable, and dispatch by passing _async_poll or _async_poll_batches

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @fresh-borzoni, made the changes in 3981fff, tested locally and passing.

while True:
batches = await scanner._async_poll_batches(timeout_ms)
if batches:
for rb in batches:
yield rb
"#
);
let globals = pyo3::types::PyDict::new(py);
py.run(code, Some(&globals), None).unwrap();
globals
.get_item("_async_batch_scan")
.unwrap()
.unwrap()
.unbind()
});
gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,))
}
}
}
Comment on lines +2201 to +2239
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__aiter__ recompiles and executes Python source via py.run() on every iteration start. Consider caching the adapter function (e.g., in a PyOnceLock) or returning self directly as the async iterator if possible; this avoids repeated code compilation and reduces overhead per async for loop.

Copilot uses AI. Check for mistakes.

/// Perform a single bounded poll and return a list of ScanRecord objects.
///
/// This is the async building block used by `__aiter__` to implement
/// `async for`. Each call does exactly one network poll (bounded by
/// `timeout_ms`), converts any results to Python objects, and returns
/// them as a list. An empty list signals a timeout (no data yet), not
/// end-of-stream.
///
/// Args:
/// timeout_ms: Timeout in milliseconds for the network poll (default: 1000)
///
/// Returns:
/// Awaitable that resolves to a list of ScanRecord objects
fn _async_poll<'py>(
&self,
py: Python<'py>,
timeout_ms: Option<i64>,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like we hardcode this anyway in generated code, do we need to define an arg for this?

Also nit: poll_interval_ms is more accurate

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @leekeiabstraction, in d521133 I made this into a global variable, DEFAULT_POLL_INTERVAL_MS that is initialized (to 1000) at the top of the script. It's then later referenced within the _async_poll and _async_poll_batches functions.

) -> PyResult<Bound<'py, PyAny>> {
let timeout_ms = timeout_ms.unwrap_or(1000);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: magic number

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @leekeiabstraction, this is mentioned above.

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
"timeout_ms must be non-negative, got: {timeout_ms}"
)));
}

let scanner = Arc::clone(&self.kind);
let projected_row_type = self.projected_row_type.clone();
let timeout = Duration::from_millis(timeout_ms as u64);

future_into_py(py, async move {
let core_scanner = match scanner.as_ref() {
ScannerKind::Record(s) => s,
ScannerKind::Batch(_) => {
return Err(PyTypeError::new_err(
"This internal method only supports record-based scanners. \
For batch-based scanners, use 'async for' or 'poll_record_batch' instead.",
));
}
};

let scan_records = core_scanner
.poll(timeout)
.await
.map_err(|e| FlussError::from_core_error(&e))?;

// Convert to Python list
Python::attach(|py| {
let mut result: Vec<Py<ScanRecord>> = Vec::new();
for (_, records) in scan_records.into_records_by_buckets() {
for core_record in records {
let scan_record =
ScanRecord::from_core(py, &core_record, &projected_row_type)?;
result.push(Py::new(py, scan_record)?);
}
}
Ok(result)
})
})
}

/// Perform a single bounded poll and return a list of RecordBatch objects.
///
/// This is the async building block used by `__aiter__` (batch mode) to
/// implement `async for`. Each call does exactly one network poll (bounded
/// by `timeout_ms`), converts any results to Python RecordBatch objects,
/// and returns them as a list. An empty list signals a timeout (no data
/// yet), not end-of-stream.
///
/// Args:
/// timeout_ms: Timeout in milliseconds for the network poll (default: 1000)
///
/// Returns:
/// Awaitable that resolves to a list of RecordBatch objects
fn _async_poll_batches<'py>(
&self,
py: Python<'py>,
timeout_ms: Option<i64>,
) -> PyResult<Bound<'py, PyAny>> {
let timeout_ms = timeout_ms.unwrap_or(1000);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: magic number

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @leekeiabstraction, this is mentioned above.

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
"timeout_ms must be non-negative, got: {timeout_ms}"
)));
}

let scanner = Arc::clone(&self.kind);
let timeout = Duration::from_millis(timeout_ms as u64);

future_into_py(py, async move {
let core_scanner = match scanner.as_ref() {
ScannerKind::Batch(s) => s,
ScannerKind::Record(_) => {
return Err(PyTypeError::new_err(
"This internal method only supports batch-based scanners. \
For record-based scanners, use 'async for' or 'poll' instead.",
));
}
};

let scan_batches = core_scanner
.poll(timeout)
.await
.map_err(|e| FlussError::from_core_error(&e))?;

// Convert to Python list of RecordBatch objects
Python::attach(|py| {
let mut result: Vec<Py<RecordBatch>> = Vec::new();
for scan_batch in scan_batches {
let rb = RecordBatch::from_scan_batch(scan_batch);
result.push(Py::new(py, rb)?);
}
Ok(result)
})
})
}

fn __repr__(&self) -> String {
format!("LogScanner(table={})", self.table_info.table_path)
}
Expand All @@ -2213,7 +2379,7 @@ impl LogScanner {
projected_row_type: fcore::metadata::RowType,
) -> Self {
Self {
scanner,
kind: Arc::new(scanner),
admin,
table_info,
projected_schema,
Expand Down Expand Up @@ -2264,7 +2430,7 @@ impl LogScanner {
py: Python,
subscribed: &[(fcore::metadata::TableBucket, i64)],
) -> PyResult<HashMap<fcore::metadata::TableBucket, i64>> {
let scanner = self.scanner.as_batch()?;
let scanner = self.kind.as_batch()?;
let is_partitioned = scanner.is_partitioned();
let table_path = &self.table_info.table_path;

Expand Down Expand Up @@ -2367,7 +2533,7 @@ impl LogScanner {
py: Python,
mut stopping_offsets: HashMap<fcore::metadata::TableBucket, i64>,
) -> PyResult<Py<PyAny>> {
let scanner = self.scanner.as_batch()?;
let scanner = self.kind.as_batch()?;
let mut all_batches = Vec::new();

while !stopping_offsets.is_empty() {
Expand Down
Loading
Loading