Skip to content

Commit 00dfe10

Browse files
authored
Merge pull request #47 from pmorris-dev/fix/interruptible-transaction-drop-leaks-writer
fix: rollback on drop so write conn returns to pool clean (#46)
2 parents 240eb77 + 2751290 commit 00dfe10

5 files changed

Lines changed: 322 additions & 60 deletions

File tree

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,10 @@ await tx.commit();
381381
* Only one interruptible transaction can be active per database at a time
382382
* The write lock is held for the entire duration - keep transactions short
383383
* Uncommitted writes are visible only within the transaction's `read()` method
384-
* Always commit or rollback - abandoned transactions will rollback automatically
385-
on app exit
384+
* If the transaction handle is dropped without calling `commit()` or
385+
`rollback()`, the transaction is automatically rolled back and the write
386+
connection is released back to the pool. This also happens on app exit
387+
and on transaction timeout.
386388

387389
To rollback instead of committing:
388390

@@ -784,6 +786,10 @@ println!("Transaction completed: {} statements executed", results.len());
784786

785787
For transactions that need to read data mid-transaction:
786788

789+
If `tx` is dropped without calling `commit()` or `rollback()` — including via
790+
an early return from a `?` operator — the transaction is automatically rolled
791+
back and the write connection is released back to the pool.
792+
787793
```rust
788794
// Assuming user_id, product_id, item_total are defined in your application context
789795
let user_id = 123;

crates/sqlx-sqlite-conn-mgr/src/database.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use sqlx::{ConnectOptions, Pool, Sqlite};
1010
use std::path::{Path, PathBuf};
1111
use std::sync::Arc;
1212
use std::sync::atomic::{AtomicBool, Ordering};
13-
use tracing::error;
13+
use tracing::{error, warn};
1414

1515
/// Analysis limit for PRAGMA optimize on close.
1616
/// SQLite recommends 100-1000 for older versions; 3.46.0+ handles automatically.
@@ -177,12 +177,40 @@ impl SqliteDatabase {
177177
.read_only(false)
178178
.optimize_on_close(true, OPTIMIZE_ANALYSIS_LIMIT);
179179

180+
// Defense-in-depth: when any writer is returned to the pool, issue
181+
// ROLLBACK to discard any transaction that a caller may have left open
182+
// (e.g., a writer dropped after BEGIN without COMMIT/ROLLBACK). SQLite
183+
// only auto-rollbacks on connection close, not on pool return, so
184+
// without this the next acquire_writer() sees "cannot start a
185+
// transaction within a transaction".
186+
//
187+
// Error handling: the expected benign case on a clean connection is
188+
// "cannot rollback - no transaction is active" — recycle normally.
189+
// Anything else means ROLLBACK itself failed or the connection is
190+
// wedged; tell the pool not to recycle so a broken connection isn't
191+
// handed to the next caller.
180192
let write_conn = SqlitePoolOptions::new()
181193
.max_connections(1)
182194
.min_connections(0)
183195
.idle_timeout(Some(std::time::Duration::from_secs(
184196
config.idle_timeout_secs,
185197
)))
198+
.after_release(|conn, _meta| {
199+
Box::pin(async move {
200+
match sqlx::query("ROLLBACK").execute(&mut *conn).await {
201+
Ok(_) => Ok(true),
202+
Err(sqlx::Error::Database(e))
203+
if e.message().contains("no transaction is active") =>
204+
{
205+
Ok(true)
206+
}
207+
Err(err) => {
208+
warn!("after_release ROLLBACK failed, discarding connection: {err}");
209+
Ok(false)
210+
}
211+
}
212+
})
213+
})
186214
.connect_with(write_options)
187215
.await?;
188216

crates/sqlx-sqlite-toolkit/src/transactions.rs

Lines changed: 109 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,28 @@ pub struct ActiveInterruptibleTransaction {
9999
transaction_id: String,
100100
writer: Option<TransactionWriter>,
101101
created_at: Instant,
102+
// Captured at construction so Drop can always spawn the rollback task on a
103+
// valid runtime, even when the struct is dropped from a thread that has no
104+
// tokio thread-local (e.g., Tauri teardown on the main thread). Without a
105+
// stored handle, Drop's synchronous path through PoolConnection::Drop would
106+
// call sqlx's rt::spawn and panic with "this functionality requires a Tokio
107+
// context".
108+
runtime_handle: tokio::runtime::Handle,
102109
}
103110

104111
impl ActiveInterruptibleTransaction {
112+
/// # Panics
113+
///
114+
/// Panics if called outside a tokio runtime context. Both production call
115+
/// sites (the plugin command handler and the direct Rust API) run inside
116+
/// async functions, so this is a programming error, not a runtime risk.
105117
pub fn new(db_path: String, transaction_id: String, writer: TransactionWriter) -> Self {
106118
Self {
107119
db_path,
108120
transaction_id,
109121
writer: Some(writer),
110122
created_at: Instant::now(),
123+
runtime_handle: tokio::runtime::Handle::current(),
111124
}
112125
}
113126

@@ -230,17 +243,62 @@ impl From<(String, Vec<JsonValue>)> for Statement {
230243
}
231244
}
232245

246+
/// Upper bound on how long the auto-rollback task may hold the writer permit
247+
/// before it is considered hung and the connection is abandoned.
248+
const DROP_ROLLBACK_TIMEOUT: Duration = Duration::from_secs(5);
249+
233250
impl Drop for ActiveInterruptibleTransaction {
234251
fn drop(&mut self) {
235-
// If writer is still present, it means commit/rollback wasn't called.
236-
// SQLite will automatically ROLLBACK the transaction when the connection
237-
// is returned to the pool if no explicit COMMIT was issued.
238-
if self.writer.is_some() {
239-
debug!(
240-
"Dropping transaction for db: {}, tx_id: {} (will auto-rollback)",
241-
self.db_path, self.transaction_id
242-
);
243-
}
252+
// If writer is still present, commit/rollback was not called. The connection
253+
// is about to return to the pool — we must issue ROLLBACK explicitly because
254+
// sqlx pools reuse the connection (SQLite only auto-rollbacks on close, not
255+
// on pool return). Without this, the next acquire_writer() gets a connection
256+
// with an open transaction and "BEGIN IMMEDIATE" fails.
257+
let Some(mut writer) = self.writer.take() else {
258+
return;
259+
};
260+
let db_path = std::mem::take(&mut self.db_path);
261+
let tx_id = std::mem::take(&mut self.transaction_id);
262+
263+
debug!(
264+
"Dropping transaction for db: {}, tx_id: {} (auto-rollback scheduled)",
265+
db_path, tx_id
266+
);
267+
268+
// No race with the next acquire_writer(): `writer` owns the PoolConnection
269+
// (via WriteGuard / AttachedWriteGuard), which holds the single-writer
270+
// permit. The permit is not released until `writer` drops at the end of
271+
// this task — after ROLLBACK completes. The next acquire_writer() blocks
272+
// on that permit, so it cannot see a connection with a still-open tx.
273+
//
274+
// The timeout bounds how long a pathological ROLLBACK (stuck I/O, a
275+
// rogue busy lock) can keep the single-writer pool stalled. On timeout
276+
// we drop `writer` inside the runtime; after_release then cleans up.
277+
self.runtime_handle.spawn(async move {
278+
let result = tokio::time::timeout(DROP_ROLLBACK_TIMEOUT, async {
279+
if let Err(e) = writer.rollback().await {
280+
warn!(
281+
"auto-rollback on drop failed (db: {}, tx: {}): {}",
282+
db_path, tx_id, e
283+
);
284+
}
285+
if let Err(e) = writer.detach_if_attached().await {
286+
warn!(
287+
"detach_all after auto-rollback failed (db: {}, tx: {}): {}",
288+
db_path, tx_id, e
289+
);
290+
}
291+
// writer drops here — connection returns to pool clean
292+
})
293+
.await;
294+
295+
if result.is_err() {
296+
warn!(
297+
"auto-rollback on drop timed out after {:?} (db: {}, tx: {}) — pool's after_release hook will reconcile",
298+
DROP_ROLLBACK_TIMEOUT, db_path, tx_id
299+
);
300+
}
301+
});
244302
}
245303
}
246304

@@ -288,17 +346,21 @@ impl ActiveInterruptibleTransactions {
288346
Ok(())
289347
}
290348
Entry::Occupied(mut e) => {
291-
// If the existing transaction has expired, drop it (auto-rollback) and
292-
// replace with the new one.
349+
// If the existing transaction has expired, roll it back and replace
350+
// with the new one. We rollback explicitly (rather than relying on
351+
// Drop) so the writer is guaranteed to return to the pool clean
352+
// before the caller tries to start a new transaction on it.
293353
if e.get().created_at.elapsed() >= self.timeout {
294354
warn!(
295355
"Evicting expired transaction for db: {} (age: {:?}, timeout: {:?})",
296356
db_path,
297357
e.get().created_at.elapsed(),
298358
self.timeout,
299359
);
300-
// Drop the expired transaction (auto-rollback) before inserting the new one
301-
let _expired = e.insert(tx);
360+
let expired = e.insert(tx);
361+
if let Err(err) = expired.rollback().await {
362+
warn!("rollback of expired transaction failed (db: {db_path}): {err}");
363+
}
302364
Ok(())
303365
} else {
304366
Err(Error::TransactionAlreadyActive(db_path))
@@ -308,34 +370,37 @@ impl ActiveInterruptibleTransactions {
308370
}
309371

310372
pub async fn abort_all(&self) {
311-
let mut txs = self.inner.lock().await;
312-
debug!("Aborting {} active interruptible transaction(s)", txs.len());
313-
314-
for db_path in txs.keys() {
373+
// Drain under the lock, then release it before awaiting rollbacks so we
374+
// don't hold the mutex across a chain of awaits.
375+
let drained: Vec<(String, ActiveInterruptibleTransaction)> = {
376+
let mut txs = self.inner.lock().await;
377+
debug!("Aborting {} active interruptible transaction(s)", txs.len());
378+
txs.drain().collect()
379+
};
380+
381+
for (db_path, tx) in drained {
315382
debug!(
316-
"Dropping interruptible transaction for database: {}",
383+
"Rolling back interruptible transaction for database: {}",
317384
db_path
318385
);
386+
if let Err(err) = tx.rollback().await {
387+
warn!("rollback during abort_all failed (db: {db_path}): {err}");
388+
}
319389
}
320-
321-
// Clear all transactions to drop WriteGuards and release locks
322-
// Dropping triggers auto-rollback via Drop trait
323-
txs.clear();
324390
}
325391

326392
/// Remove and return transaction for commit/rollback.
327393
///
328394
/// Returns `Err(Error::TransactionTimedOut)` if the transaction has exceeded the
329-
/// configured timeout. The expired transaction is dropped (auto-rolled-back) in
330-
/// that case.
395+
/// configured timeout. The expired transaction is rolled back before the error
396+
/// is returned.
331397
pub async fn remove(
332398
&self,
333399
db_path: &str,
334400
token_id: &str,
335401
) -> Result<ActiveInterruptibleTransaction> {
336402
let mut txs = self.inner.lock().await;
337403

338-
// Validate token before removal
339404
let tx = txs
340405
.get(db_path)
341406
.ok_or_else(|| Error::NoActiveTransaction(db_path.to_string()))?;
@@ -344,21 +409,27 @@ impl ActiveInterruptibleTransactions {
344409
return Err(Error::InvalidTransactionToken);
345410
}
346411

347-
// Check if the transaction has expired
348-
if tx.created_at.elapsed() >= self.timeout {
349-
warn!(
350-
"Transaction timed out for db: {} (age: {:?}, timeout: {:?})",
351-
db_path,
352-
tx.created_at.elapsed(),
353-
self.timeout,
354-
);
355-
// Drop the expired transaction (auto-rollback via Drop)
356-
txs.remove(db_path);
357-
return Err(Error::TransactionTimedOut(db_path.to_string()));
412+
// Happy path: not expired, hand it back to the caller.
413+
if tx.created_at.elapsed() < self.timeout {
414+
// Safe unwrap: we just confirmed the key exists above.
415+
return Ok(txs.remove(db_path).unwrap());
358416
}
359417

360-
// Safe unwrap: we just confirmed the key exists above
361-
Ok(txs.remove(db_path).unwrap())
418+
// Expired: take it out, release the lock, then rollback without holding
419+
// it so other callers aren't blocked on an unrelated cleanup.
420+
warn!(
421+
"Transaction timed out for db: {} (age: {:?}, timeout: {:?})",
422+
db_path,
423+
tx.created_at.elapsed(),
424+
self.timeout,
425+
);
426+
let expired = txs.remove(db_path).unwrap();
427+
drop(txs);
428+
429+
if let Err(err) = expired.rollback().await {
430+
warn!("rollback of timed-out transaction failed (db: {db_path}): {err}");
431+
}
432+
Err(Error::TransactionTimedOut(db_path.to_string()))
362433
}
363434
}
364435

0 commit comments

Comments
 (0)