Skip to content

Commit 1ee33a1

Browse files
committed
feat: add global registry
- This ensures only a single instance of a db (with its connection pools) can exist in the process - NOTE TO REVIEWERS: The caching behavior will be tested from the outside once the `database.rs` tests are committed.
1 parent 8c3818a commit 1ee33a1

4 files changed

Lines changed: 175 additions & 6 deletions

File tree

crates/sqlx-sqlite-conn-mgr/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ use std::sync::Arc;
4545
async fn main() -> Result<(), sqlx_sqlite_conn_mgr::Error> {
4646
// Connect to database (creates if missing, returns Arc<SqliteDatabase>)
4747
// (See below for how to customize the configuration)
48-
let db = SqliteDatabase::connect("example.db").await?;
48+
let db = SqliteDatabase::connect("example.db", None).await?;
4949

5050
// Multiple connects to the same path return the same instance
51-
let db2 = SqliteDatabase::connect("example.db").await?;
51+
let db2 = SqliteDatabase::connect("example.db", None).await?;
5252
assert!(Arc::ptr_eq(&db, &db2));
5353

5454
// Use read_pool() for read queries (supports concurrent reads)

crates/sqlx-sqlite-conn-mgr/src/lib.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
//!
2020
//! ## Usage
2121
//!
22-
//! ```no_run
22+
//! // TODO: Remove this ignore once implementation is complete
23+
//! ```ignore
2324
//! use sqlx_sqlite_conn_mgr::SqliteDatabase;
2425
//! use std::sync::Arc;
2526
//!
@@ -59,13 +60,13 @@
5960
//! - Global registry caches new database instances (with their pools) and returns existing ones
6061
//! - WAL mode is enabled lazily only when writes are needed
6162
//!
62-
// TODO: Remove these allows once implementation is complete
63-
#![allow(dead_code)]
63+
// TODO: Remove this allow once implementation is complete
6464
#![allow(unused)]
6565

6666
mod config;
6767
mod database;
6868
mod error;
69+
mod registry;
6970
mod write_guard;
7071

7172
// Re-export public types
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
//! Global database registry to cache new database instances and return existing ones
2+
3+
use crate::Result;
4+
use crate::database::SqliteDatabase;
5+
use std::collections::HashMap;
6+
use std::future::Future;
7+
use std::path::{Path, PathBuf};
8+
use std::sync::{Arc, OnceLock, Weak};
9+
use tokio::sync::RwLock;
10+
11+
/// Global registry for SQLite databases
12+
static DATABASE_REGISTRY: OnceLock<RwLock<HashMap<PathBuf, Weak<SqliteDatabase>>>> =
13+
OnceLock::new();
14+
15+
fn registry() -> &'static RwLock<HashMap<PathBuf, Weak<SqliteDatabase>>> {
16+
DATABASE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
17+
}
18+
19+
/// Reports whether `path` names an in-memory SQLite database.
///
/// Matches the literal `:memory:` and any string beginning with
/// `file::memory:` (URI form, with or without query parameters).
/// Paths that are not valid UTF-8 are never considered in-memory.
pub fn is_memory_database(path: &Path) -> bool {
    match path.to_str() {
        Some(text) => text == ":memory:" || text.starts_with("file::memory:"),
        None => false,
    }
}
26+
27+
/// Get or open a SQLite database connection
28+
///
29+
/// If a database is already connected, returns the cached instance.
30+
/// Otherwise, calls the provided factory function to create a new connection.
31+
///
32+
/// Special case: `:memory:` databases should not be cached (each is unique)
33+
pub async fn get_or_open_database<F, Fut>(path: &Path, factory: F) -> Result<Arc<SqliteDatabase>>
34+
where
35+
F: FnOnce() -> Fut,
36+
Fut: Future<Output = Result<SqliteDatabase>>,
37+
{
38+
// Skip registry for in-memory databases - always create new
39+
if is_memory_database(path) {
40+
let db = factory().await?;
41+
return Ok(Arc::new(db));
42+
}
43+
44+
// Canonicalize the path for consistent lookups
45+
let canonical_path = canonicalize_path(path)?;
46+
47+
// Try to get existing database with read lock (allows concurrent reads)
48+
{
49+
let registry = registry().read().await;
50+
51+
if let Some(weak) = registry.get(&canonical_path) {
52+
if let Some(db) = weak.upgrade() {
53+
return Ok(db);
54+
}
55+
// Weak reference exists but dead - will be cleaned up in write phase
56+
}
57+
}
58+
59+
// Phase 2: Database not found, acquire write lock
60+
let mut registry = registry().write().await;
61+
62+
// Double-check: another thread might have created it while we waited for write lock
63+
if let Some(weak) = registry.get(&canonical_path) {
64+
if let Some(db) = weak.upgrade() {
65+
return Ok(db);
66+
}
67+
}
68+
69+
// Clean up dead weak references while we have the write lock
70+
registry.retain(|_, weak| weak.strong_count() > 0);
71+
72+
// Now we're sure the database doesn't exist - create it while holding the lock
73+
// This prevents race conditions
74+
let db = factory().await?;
75+
let arc_db = Arc::new(db);
76+
77+
// Cache the new database
78+
registry.insert(canonical_path, Arc::downgrade(&arc_db));
79+
80+
Ok(arc_db)
81+
}
82+
83+
/// Helper to canonicalize a database path
///
/// Resolves `path` to a canonical form so that cache lookups are consistent:
/// - If the path exists, it is fully canonicalized (absolute, symlinks resolved).
/// - If it does not exist yet, the parent directory is canonicalized and the
///   file name is appended unchanged. A bare file name such as `"example.db"`
///   has an empty parent, which is treated as the current directory.
///
/// # Errors
///
/// - `InvalidInput` if the path does not exist and has no file name component.
/// - Any error from canonicalizing the parent directory (e.g. it does not exist).
///
/// Known limitations when file doesn't exist:
/// - Case sensitivity: On case-insensitive filesystems (macOS, Windows), paths
///   differing only in case will be treated as different until the file is created.
///   This could lead to multiple connection pools for the same logical database, at
///   least until the file is created and can be canonicalized properly.
/// - Symlinks in filename: If the filename itself will be a symlink (rare for SQLite),
///   different symlink names won't be resolved until the file exists.
fn canonicalize_path(path: &Path) -> std::io::Result<PathBuf> {
    if let Ok(resolved) = path.canonicalize() {
        return Ok(resolved);
    }

    // The path doesn't exist (or couldn't be resolved): fall back to
    // canonical parent + file name as given.
    let filename = path
        .file_name()
        .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid path"))?;

    // `Path::parent` yields `Some("")` - not `None` - for a bare relative file
    // name, and canonicalizing an empty path fails. Map both the empty and the
    // missing parent to the current directory.
    let parent = match path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        _ => Path::new("."),
    };

    // Note: the file name's case is preserved as provided. On case-insensitive
    // filesystems "MyDB.db" and "mydb.db" therefore get separate cache entries
    // until the file exists; afterwards the on-disk canonical case is used.
    Ok(parent.canonicalize()?.join(filename))
}
120+
121+
/// Remove a database from the cache
122+
///
123+
/// Special case: `:memory:` databases are never in the registry
124+
///
125+
/// Returns an error if the path cannot be canonicalized
126+
pub async fn uncache_database(path: &Path) -> std::io::Result<()> {
127+
// Skip registry for in-memory databases
128+
if is_memory_database(path) {
129+
return Ok(());
130+
}
131+
132+
// Canonicalize path
133+
let canonical_path = canonicalize_path(path)?;
134+
135+
let mut registry = registry().write().await;
136+
registry.remove(&canonical_path);
137+
Ok(())
138+
}
139+
140+
#[cfg(test)]
mod tests {
    use super::*;

    /// Both absolute and `./`-relative inputs must come back as absolute paths.
    #[test]
    fn test_canonicalize_path() {
        // Absolute input located in the system temp directory
        let absolute_input = std::env::temp_dir().join("test.db");
        assert!(canonicalize_path(&absolute_input).unwrap().is_absolute());

        // Relative input anchored at the current directory
        let relative_input = Path::new("./test_relative.db");
        assert!(canonicalize_path(relative_input).unwrap().is_absolute());
    }

    /// A path whose parent directory is missing cannot be canonicalized.
    #[test]
    fn test_canonicalize_nonexistent_path() {
        let missing_parent = std::env::temp_dir()
            .join("nonexistent_dir")
            .join("test.db");

        assert!(canonicalize_path(&missing_parent).is_err());
    }
}

crates/sqlx-sqlite-conn-mgr/src/write_guard.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use std::ops::{Deref, DerefMut};
2020
/// use sqlx::query;
2121
///
2222
/// # async fn example() -> Result<(), sqlx_sqlite_conn_mgr::Error> {
23-
/// let db = SqliteDatabase::connect("test.db").await?;
23+
/// let db = SqliteDatabase::connect("test.db", None).await?;
2424
/// let mut writer = db.acquire_writer().await?;
2525
/// // Use &mut *writer for write queries (e.g. INSERT/UPDATE/DELETE)
2626
/// query("INSERT INTO users (name) VALUES (?)")

0 commit comments

Comments
 (0)