diff --git a/Cargo.lock b/Cargo.lock index 41de0463..bf72075e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2066,6 +2066,9 @@ name = "fastrand" version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" +dependencies = [ + "getrandom 0.3.4", +] [[package]] name = "fd-lock" @@ -2143,6 +2146,18 @@ dependencies = [ "serde", ] +[[package]] +name = "flume" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be" +dependencies = [ + "fastrand", + "futures-core", + "futures-sink", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -6894,10 +6909,10 @@ dependencies = [ "criterion", "csv", "env_logger", + "flume", "futures-util", "glob", "itertools 0.14.0", - "kanal", "log", "parking_lot", "rand 0.10.1", diff --git a/crates/core/benches/benchmark.rs b/crates/core/benches/benchmark.rs index 646fed68..80288511 100644 --- a/crates/core/benches/benchmark.rs +++ b/crates/core/benches/benchmark.rs @@ -60,7 +60,7 @@ async fn add_room( name: &str, ) -> Result<[u8; 16], anyhow::Error> { let room: [u8; 16] = conn - .query_row_get( + .write_query_row_get( "INSERT INTO room (name) VALUES ($1) RETURNING id", params!(name.to_string()), 0, diff --git a/crates/core/src/admin/table/alter_index.rs b/crates/core/src/admin/table/alter_index.rs index 2b8e2d2c..176e547a 100644 --- a/crates/core/src/admin/table/alter_index.rs +++ b/crates/core/src/admin/table/alter_index.rs @@ -54,7 +54,7 @@ pub async fn alter_index_handler( let unqualified_source_index_name = request.source_schema.name.name.clone(); let tx_log = conn - .call(move |conn| { + .call_writer(move |conn| { let mut tx = TransactionRecorder::new(conn)?; // Drop old index diff --git a/crates/core/src/admin/table/alter_table.rs b/crates/core/src/admin/table/alter_table.rs index 2d526d88..9bfa6a32 100644 --- a/crates/core/src/admin/table/alter_table.rs +++ b/crates/core/src/admin/table/alter_table.rs @@ -92,7 +92,7 @@ pub async fn alter_table_handler( ephemeral_table_schema.name.database_schema = None; conn - .call( + .call_writer( move |conn| -> Result, trailbase_sqlite::Error> { let mut tx = TransactionRecorder::new(conn) .map_err(|err| trailbase_sqlite::Error::Other(err.into()))?; diff --git a/crates/core/src/admin/table/create_index.rs b/crates/core/src/admin/table/create_index.rs index a37ebb7f..3df52086 100644 --- a/crates/core/src/admin/table/create_index.rs +++ b/crates/core/src/admin/table/create_index.rs @@ -37,7 +37,7 @@ pub async fn create_index_handler( let create_index_query = index_schema.create_index_statement(); let tx_log = conn - .call(move |conn| { + .call_writer(move |conn| { let mut tx = TransactionRecorder::new(conn)?; tx.execute(&create_index_query, ())?; diff --git a/crates/core/src/admin/table/create_table.rs b/crates/core/src/admin/table/create_table.rs index 818bacd4..3ab63d5a 100644 --- a/crates/core/src/admin/table/create_table.rs +++ b/crates/core/src/admin/table/create_table.rs @@ -41,7 +41,7 @@ pub async fn create_table_handler( let create_table_query = table_schema.create_table_statement(); let tx_log = conn - .call(move |conn| { + .call_writer(move |conn| { let mut tx = TransactionRecorder::new(conn)?; tx.execute(&create_table_query, ())?; diff --git a/crates/core/src/admin/table/drop_index.rs b/crates/core/src/admin/table/drop_index.rs index f7ed9036..6f25d3e4 100644 --- a/crates/core/src/admin/table/drop_index.rs +++ b/crates/core/src/admin/table/drop_index.rs @@ -40,7 +40,7 @@ pub async fn drop_index_handler( let tx_log = { let unqualified_index_name = unqualified_index_name.clone(); conn - .call(move |conn| { + .call_writer(move |conn| { let mut tx = TransactionRecorder::new(conn)?; let query = format!("DROP INDEX IF EXISTS \"{unqualified_index_name}\""); diff --git a/crates/core/src/admin/table/drop_table.rs b/crates/core/src/admin/table/drop_table.rs index e8b9a1e6..f0d9420d 100644 --- a/crates/core/src/admin/table/drop_table.rs +++ b/crates/core/src/admin/table/drop_table.rs @@ -69,7 +69,7 @@ pub async fn drop_table_handler( let unqualified_table_name = unqualified_table_name.clone(); let entity_type = entity_type.clone(); conn - .call(move |conn| { + .call_writer(move |conn| { let mut tx = TransactionRecorder::new(conn)?; let query = format!( diff --git a/crates/core/src/connection.rs b/crates/core/src/connection.rs index 4c0ace66..bf0a454b 100644 --- a/crates/core/src/connection.rs +++ b/crates/core/src/connection.rs @@ -353,10 +353,10 @@ fn init_main_db_impl( return Ok(conn); }, trailbase_sqlite::Options { - n_read_threads: match (data_dir, std::thread::available_parallelism()) { - (None, _) => Some(0), + num_threads: match (data_dir, std::thread::available_parallelism()) { + (None, _) => Some(1), (Some(_), Ok(n)) => Some(n.get().clamp(2, 4)), - (Some(_), Err(_)) => Some(4), + (Some(_), Err(_)) => Some(2), }, ..Default::default() }, diff --git a/crates/core/src/logging.rs b/crates/core/src/logging.rs index 2af7d1ee..e2e7d204 100644 --- a/crates/core/src/logging.rs +++ b/crates/core/src/logging.rs @@ -249,7 +249,7 @@ impl SqliteLogLayer { // NOTE: awaiting the `conn.call()` is the secret to batching, since we won't read from the // channel until the database write is complete. let result = conn - .call(move |conn| { + .call_writer(move |conn| { Self::insert_logs(conn, buffer)?; Ok(()) }) diff --git a/crates/core/src/records/record_api.rs b/crates/core/src/records/record_api.rs index cecf6194..895a9956 100644 --- a/crates/core/src/records/record_api.rs +++ b/crates/core/src/records/record_api.rs @@ -496,7 +496,7 @@ impl RecordApi { self .state .conn - .call(move |conn| { + .call_writer(move |conn| { Ok(Self::check_record_level_access_impl( conn, &access_query, diff --git a/crates/core/src/records/subscribe/tests.rs b/crates/core/src/records/subscribe/tests.rs index b1fe6a68..62ab4f7c 100644 --- a/crates/core/src/records/subscribe/tests.rs +++ b/crates/core/src/records/subscribe/tests.rs @@ -57,7 +57,7 @@ async fn subscribe_to_record_test() { let record_id_raw = 0; let rowid: i64 = conn - .query_row_get( + .write_query_row_get( "INSERT INTO test (id, text) VALUES ($1, 'foo') RETURNING _rowid_", [trailbase_sqlite::Value::Integer(record_id_raw)], 0, @@ -290,7 +290,7 @@ async fn subscription_lifecycle_test() { let record_id_raw = 0; let record_id = trailbase_sqlite::Value::Integer(record_id_raw); let rowid: i64 = conn - .query_row_get( + .write_query_row_get( "INSERT INTO test (id, text) VALUES ($1, 'foo') RETURNING _rowid_", [record_id], 0, @@ -402,7 +402,7 @@ async fn subscription_acl_test() { let record_id_raw = 0; let record_id = trailbase_sqlite::Value::Integer(record_id_raw); let _rowid: i64 = conn - .query_row_get( + .write_query_row_get( "INSERT INTO test (id, user, text) VALUES ($1, $2, 'foo') RETURNING _rowid_", [ record_id.clone(), @@ -578,7 +578,7 @@ async fn subscription_acl_change_owner() { let record_id = 0; let _rowid: i64 = conn - .query_row_get( + .write_query_row_get( "INSERT INTO test (id, user, text) VALUES ($1, $2, 'foo') RETURNING _rowid_", [ trailbase_sqlite::Value::Integer(record_id), diff --git a/crates/core/src/records/test_utils.rs b/crates/core/src/records/test_utils.rs index e7c33a0f..00a8413b 100644 --- a/crates/core/src/records/test_utils.rs +++ b/crates/core/src/records/test_utils.rs @@ -159,7 +159,7 @@ mod tests { name: &str, ) -> Result<[u8; 16], anyhow::Error> { let room: [u8; 16] = conn - .query_row_get( + .write_query_row_get( "INSERT INTO room (name) VALUES ($1) RETURNING rid", params!(name.to_string()), 0, @@ -191,7 +191,7 @@ mod tests { message: &str, ) -> Result<[u8; 16], anyhow::Error> { let id: [u8; 16] = conn - .query_row_get( + .write_query_row_get( "INSERT INTO message (_owner, room, data) VALUES ($1, $2, $3) RETURNING mid", params!(user, room, message.to_string()), 0, diff --git a/crates/core/src/records/transaction.rs b/crates/core/src/records/transaction.rs index f968e8f9..73add926 100644 --- a/crates/core/src/records/transaction.rs +++ b/crates/core/src/records/transaction.rs @@ -229,7 +229,7 @@ pub async fn record_transactions_handler( let ids = if request.transaction.unwrap_or(false) { conn - .call( + .call_writer( move |conn: &mut rusqlite::Connection| -> Result, trailbase_sqlite::Error> { let tx = conn.transaction()?; @@ -248,7 +248,7 @@ pub async fn record_transactions_handler( .await? } else { conn - .call( + .call_writer( move |conn: &mut rusqlite::Connection| -> Result, trailbase_sqlite::Error> { let mut ids: Vec = vec![]; for op in operations { diff --git a/crates/core/src/records/write_queries.rs b/crates/core/src/records/write_queries.rs index 745f7fba..c1c50109 100644 --- a/crates/core/src/records/write_queries.rs +++ b/crates/core/src/records/write_queries.rs @@ -200,7 +200,7 @@ pub(crate) async fn run_queries( }; let result: Vec = conn - .call(move |conn| { + .call_writer(move |conn| { let tx = conn.transaction()?; let rows: Vec = queries @@ -251,7 +251,7 @@ pub(crate) async fn run_insert_query( }; let (rowid, return_value): (i64, trailbase_sqlite::Value) = conn - .call(move |conn| { + .call_writer(move |conn| { let result = query.apply(conn)?; return Ok((result.rowid, result.pk_value.expect("insert"))); }) @@ -288,7 +288,7 @@ pub(crate) async fn run_update_query( }; let rowid: i64 = conn - .call(move |conn| { + .call_writer(move |conn| { return Ok(query.apply(conn)?.rowid); }) .await?; @@ -315,7 +315,7 @@ pub(crate) async fn run_delete_query( let query = WriteQuery::new_delete(table_name, pk_column, pk_value)?; let rowid: i64 = conn - .call(move |conn| { + .call_writer(move |conn| { return Ok(query.apply(conn)?.rowid); }) .await?; diff --git a/crates/core/src/scheduler.rs b/crates/core/src/scheduler.rs index 4d3b9ec3..f57018c2 100644 --- a/crates/core/src/scheduler.rs +++ b/crates/core/src/scheduler.rs @@ -274,7 +274,7 @@ fn build_job( return async move { conn - .call(|conn| { + .call_writer(|conn| { return Ok(conn.backup("main", backup_file, /* progress= */ None)?); }) .await diff --git a/crates/core/src/server/mod.rs b/crates/core/src/server/mod.rs index 2be7208d..2a79fec3 100644 --- a/crates/core/src/server/mod.rs +++ b/crates/core/src/server/mod.rs @@ -318,7 +318,7 @@ impl Server { .connection_manager() .main_entry() .connection - .call(|conn: &mut rusqlite::Connection| { + .call_writer(|conn: &mut rusqlite::Connection| { return crate::migrations::apply_main_migrations(conn, Some(user_migrations_path)) .map_err(|err| trailbase_sqlite::Error::Other(err.into())); }) diff --git a/crates/core/src/transaction.rs b/crates/core/src/transaction.rs index 3ffcd71a..1dadcb8f 100644 --- a/crates/core/src/transaction.rs +++ b/crates/core/src/transaction.rs @@ -77,7 +77,7 @@ impl TransactionLog { let runner = migrations::new_migration_runner(&migrations).set_abort_missing(false); let report = conn - .call(move |conn| { + .call_writer(move |conn| { let report = runner .run(conn) .map_err(|err| trailbase_sqlite::Error::Other(err.into()))?; @@ -129,7 +129,7 @@ impl TransactionLog { conn: &trailbase_sqlite::Connection, ) -> Result<(), trailbase_sqlite::Error> { conn - .call(|conn: &mut rusqlite::Connection| { + .call_writer(|conn: &mut rusqlite::Connection| { let tx = conn.transaction()?; for (query_type, stmt) in self.log { match query_type { @@ -261,7 +261,7 @@ mod tests { // Just double checking that rusqlite's query and execute ignore everything but the first // statement. let result = conn - .query_row_get::( + .write_query_row_get::( r#" SELECT name FROM 'table' WHERE id = 0; SELECT name FROM 'table' WHERE id = 1; diff --git a/crates/core/tests/integration_test.rs b/crates/core/tests/integration_test.rs index e10d630d..f2a607c0 100644 --- a/crates/core/tests/integration_test.rs +++ b/crates/core/tests/integration_test.rs @@ -310,7 +310,7 @@ async fn add_room( name: &str, ) -> Result<[u8; 16], anyhow::Error> { let room: [u8; 16] = conn - .query_row_get( + .write_query_row_get( "INSERT INTO room (name) VALUES ($1) RETURNING id", params!(name.to_string()), 0, diff --git a/crates/sqlite/Cargo.toml b/crates/sqlite/Cargo.toml index 3b1ae03e..d9c6004d 100644 --- a/crates/sqlite/Cargo.toml +++ b/crates/sqlite/Cargo.toml @@ -20,7 +20,7 @@ path = "benches/join-order/main.rs" harness = false [dependencies] -kanal = "0.1.1" +flume = { version = "0.12.0", default-feature = false, features = ["select"] } log = { version = "^0.4.21", default-features = false } parking_lot = { workspace = true } rusqlite = { workspace = true } diff --git a/crates/sqlite/benches/synthetic/benchmark.rs b/crates/sqlite/benches/synthetic/benchmark.rs index 949ff527..6c7ad504 100644 --- a/crates/sqlite/benches/synthetic/benchmark.rs +++ b/crates/sqlite/benches/synthetic/benchmark.rs @@ -120,7 +120,7 @@ fn insert_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(0), + num_threads: Some(1), ..Default::default() }, )?); @@ -132,7 +132,7 @@ fn insert_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(2), + num_threads: Some(2), ..Default::default() }, )?); @@ -144,7 +144,7 @@ fn insert_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(4), + num_threads: Some(4), ..Default::default() }, )?); @@ -156,7 +156,7 @@ fn insert_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(8), + num_threads: Some(8), ..Default::default() }, )?); @@ -258,7 +258,7 @@ fn read_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(2), + num_threads: Some(1), ..Default::default() }, )?); @@ -270,7 +270,7 @@ fn read_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(2), + num_threads: Some(2), ..Default::default() }, )?); @@ -282,7 +282,7 @@ fn read_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(4), + num_threads: Some(4), ..Default::default() }, )?); @@ -294,7 +294,7 @@ fn read_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(8), + num_threads: Some(8), ..Default::default() }, )?); @@ -434,7 +434,7 @@ fn mixed_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(0), + num_threads: Some(1), ..Default::default() }, )?); @@ -446,7 +446,7 @@ fn mixed_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(2), + num_threads: Some(2), ..Default::default() }, )?); @@ -458,7 +458,7 @@ fn mixed_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(4), + num_threads: Some(4), ..Default::default() }, )?); @@ -470,7 +470,7 @@ fn mixed_benchmark_group(c: &mut Criterion) { return Ok(Connection::with_opts( || rusqlite::Connection::open(&fname), Options { - n_read_threads: Some(8), + num_threads: Some(8), ..Default::default() }, )?); diff --git a/crates/sqlite/benches/synthetic/connection.rs b/crates/sqlite/benches/synthetic/connection.rs index 7bdb6591..57ce92cd 100644 --- a/crates/sqlite/benches/synthetic/connection.rs +++ b/crates/sqlite/benches/synthetic/connection.rs @@ -33,7 +33,7 @@ impl AsyncConnection for Connection { ) -> Result { return Ok( self - .query_row_get::>(sql.into(), params.into(), 0) + .write_query_row_get::>(sql.into(), params.into(), 0) .await? .unwrap() .0, diff --git a/crates/sqlite/src/sqlite/batch.rs b/crates/sqlite/src/sqlite/batch.rs index 8f43620c..c3b7204b 100644 --- a/crates/sqlite/src/sqlite/batch.rs +++ b/crates/sqlite/src/sqlite/batch.rs @@ -16,7 +16,7 @@ pub async fn execute_batch( sql: impl AsRef + Send + 'static, ) -> Result, Error> { return conn - .call(move |conn: &mut rusqlite::Connection| { + .call_writer(move |conn: &mut rusqlite::Connection| { let batch = rusqlite::Batch::new(conn, sql.as_ref()); let mut p = batch.peekable(); diff --git a/crates/sqlite/src/sqlite/connection.rs b/crates/sqlite/src/sqlite/connection.rs index 0f4a89e7..2064d321 100644 --- a/crates/sqlite/src/sqlite/connection.rs +++ b/crates/sqlite/src/sqlite/connection.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use crate::database::Database; use crate::error::Error; @@ -16,6 +17,7 @@ pub use crate::sqlite::executor::{ArcLockGuard, LockGuard, Options}; /// A handle to call functions in background thread. #[derive(Clone)] pub struct Connection { + id: usize, exec: Executor, } @@ -29,6 +31,7 @@ impl Connection { opt: Options, ) -> std::result::Result { return Ok(Self { + id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst), exec: Executor::new(builder, opt)?, }); } @@ -42,7 +45,7 @@ impl Connection { let conn = Self::with_opts( rusqlite::Connection::open_in_memory, Options { - n_read_threads: Some(0), + num_threads: Some(1), ..Default::default() }, )?; @@ -53,7 +56,7 @@ impl Connection { } pub fn id(&self) -> usize { - return self.exec.id(); + return self.id; } pub fn threads(&self) -> usize { @@ -84,12 +87,12 @@ impl Connection { /// during startup/SIGHUP). /// * Batch log inserts to minimize thread slushing. /// * Backups from scheduler (API could be easily hoisted) - pub async fn call(&self, function: F) -> Result + pub async fn call_writer(&self, function: F) -> Result where F: FnOnce(&mut rusqlite::Connection) -> Result + Send + 'static, R: Send + 'static, { - return self.exec.call(function).await; + return self.exec.call_writer(function).await; } pub async fn call_reader(&self, function: F) -> Result @@ -181,7 +184,7 @@ impl Connection { return self.exec.write_query_rows_f(sql, params, from_rows).await; } - pub async fn query_row_get( + pub async fn write_query_row_get( &self, sql: impl AsRef + Send + 'static, params: impl Params + Send + 'static, @@ -302,3 +305,5 @@ impl PartialEq for Connection { } impl Eq for Connection {} + +static UNIQUE_CONN_ID: AtomicUsize = AtomicUsize::new(0); diff --git a/crates/sqlite/src/sqlite/executor.rs b/crates/sqlite/src/sqlite/executor.rs index 0c50cabf..01e10a7f 100644 --- a/crates/sqlite/src/sqlite/executor.rs +++ b/crates/sqlite/src/sqlite/executor.rs @@ -1,10 +1,9 @@ -use kanal::{Receiver, Sender}; +use flume::{Receiver, Sender}; use log::*; use parking_lot::RwLock; use rusqlite::fallible_iterator::FallibleIterator; use std::ops::{Deref, DerefMut}; use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; use tokio::sync::oneshot; use crate::error::Error; @@ -18,24 +17,26 @@ struct ConnectionVec(smallvec::SmallVec<[rusqlite::Connection; 32]>); // thread. unsafe impl Sync for ConnectionVec {} -enum Message { - RunMut(Box), +enum ReaderMessage { RunConst(Box), Terminate, } +enum WriterMessage { + RunMut(Box), +} + #[derive(Clone, Default)] pub struct Options { pub busy_timeout: Option, - pub n_read_threads: Option, + pub num_threads: Option, } /// A handle to call functions in background thread. #[derive(Clone)] pub(crate) struct Executor { - id: usize, - reader: Sender, - writer: Sender, + reader: Sender, + writer: Sender, // NOTE: Is shared across reader and writer worker threads. conns: Arc>, } @@ -47,7 +48,7 @@ impl Executor { ) -> Result { let Options { busy_timeout, - n_read_threads, + num_threads, } = opt; let new_conn = || -> Result { @@ -67,83 +68,84 @@ impl Executor { return s.is_empty(); }); - let n_read_threads: i64 = match (in_memory, n_read_threads.unwrap_or(0)) { + let num_threads: usize = match (in_memory, num_threads.unwrap_or(1)) { (true, _) => { // We cannot share an in-memory database across threads, they're all independent. - 0 + 1 } - (false, 1) => { - warn!("A single reader thread won't improve performance, falling back to 0."); - 0 + (false, 0) => { + warn!("Executor needs at least one thread, falling back to 1."); + 1 } (false, n) => { if let Ok(max) = std::thread::available_parallelism() && n > max.get() { warn!( - "Num read threads '{n}' exceeds hardware parallelism: {}", + "Num threads '{n}' exceeds hardware parallelism: {}", max.get() ); } - n as i64 + + n } }; + assert!(num_threads > 0); + + let num_read_threads = num_threads - 1; let conns = Arc::new(RwLock::new(ConnectionVec({ - let mut conns = vec![write_conn]; - for _ in 0..(n_read_threads - 1).max(0) { + let mut conns = Vec::with_capacity(num_threads); + conns.push(write_conn); + for _ in 0..num_read_threads { conns.push(new_conn()?); } conns.into() }))); - assert_eq!(n_read_threads.max(1) as usize, conns.read().0.len()); + assert_eq!(num_threads, conns.read().0.len()); - // Spawn writer. - let (shared_write_sender, shared_write_receiver) = kanal::unbounded::(); - { - let conns = conns.clone(); - std::thread::Builder::new() - .name("tb-sqlite-writer".to_string()) - .spawn(move || event_loop(0, conns, shared_write_receiver)) - .expect("startup"); - } + let (shared_write_sender, shared_write_receiver) = flume::unbounded::(); + let (shared_read_sender, shared_read_receiver) = flume::unbounded::(); - // Spawn readers. - let shared_read_sender = if n_read_threads > 0 { - let (shared_read_sender, shared_read_receiver) = kanal::unbounded::(); - for i in 0..n_read_threads { - // NOTE: read and writer threads are sharing the first conn, given they're mutually - // exclusive. - let index = i as usize; + // Spawn writer thread. + std::thread::Builder::new() + .name("tb-sqlite-0 (rw)".to_string()) + .spawn({ let shared_read_receiver = shared_read_receiver.clone(); let conns = conns.clone(); - std::thread::Builder::new() - .name(format!("tb-sqlite-reader-{index}")) - .spawn(move || event_loop(index, conns, shared_read_receiver)) - .expect("startup"); - } - shared_read_sender - } else { - shared_write_sender.clone() - }; + move || writer_event_loop(conns, shared_read_receiver, shared_write_receiver) + }) + .expect("startup"); + + // Spawn readers threads. + for index in 0..num_read_threads { + std::thread::Builder::new() + .name(format!("tb-sqlite-{index} (ro)")) + .spawn({ + let shared_read_receiver = shared_read_receiver.clone(); + let conns = conns.clone(); + + move || reader_event_loop(index + 1, conns, shared_read_receiver) + }) + .expect("startup"); + } debug!( - "Opened SQLite DB '{}' with {n_read_threads} reader threads", + "Opened SQLite DB '{}' ({num_threads} threads, in-memory: {in_memory})", path.as_deref().unwrap_or("") ); - return Ok(Self { - id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst), + let conn = Self { reader: shared_read_sender, writer: shared_write_sender, conns, - }); - } + }; - pub fn id(&self) -> usize { - return self.id; + assert_eq!(num_threads, conn.threads()); + + return Ok(conn); } pub fn threads(&self) -> usize { @@ -178,7 +180,7 @@ impl Executor { } #[inline] - pub async fn call(&self, function: F) -> Result + pub async fn call_writer(&self, function: F) -> Result where F: FnOnce(&mut rusqlite::Connection) -> Result + Send + 'static, R: Send + 'static, @@ -188,7 +190,7 @@ impl Executor { self .writer - .send(Message::RunMut(Box::new(move |conn| { + .send(WriterMessage::RunMut(Box::new(move |conn| { if !sender.is_closed() { let _ = sender.send(function(conn)); } @@ -208,7 +210,7 @@ impl Executor { self .reader - .send(Message::RunConst(Box::new(move |conn| { + .send(ReaderMessage::RunConst(Box::new(move |conn| { if !sender.is_closed() { let _ = sender.send(function(conn)); } @@ -229,7 +231,7 @@ impl Executor { T: Send + 'static, { return self - .call(move |conn: &mut rusqlite::Connection| { + .call_writer(move |conn: &mut rusqlite::Connection| { let mut stmt = conn.prepare_cached(sql.as_ref())?; params.bind(&mut stmt)?; @@ -267,7 +269,7 @@ impl Executor { params: impl Params + Send + 'static, ) -> Result { return self - .call(move |conn: &mut rusqlite::Connection| { + .call_writer(move |conn: &mut rusqlite::Connection| { let mut stmt = conn.prepare_cached(sql.as_ref())?; params.bind(&mut stmt)?; @@ -279,7 +281,7 @@ impl Executor { pub async fn execute_batch(&self, sql: impl AsRef + Send + 'static) -> Result<(), Error> { self - .call(move |conn: &mut rusqlite::Connection| { + .call_writer(move |conn: &mut rusqlite::Connection| { let mut batch = rusqlite::Batch::new(conn, sql.as_ref()); while let Some(mut stmt) = batch.next()? { // NOTE: We must use `raw_query` instead of `raw_execute`, otherwise queries @@ -295,9 +297,8 @@ impl Executor { } pub async fn close(self) -> Result<(), Error> { - let _ = self.writer.send(Message::Terminate); - while self.reader.send(Message::Terminate).is_ok() { - // Continue to close readers while the channel is alive. + while self.reader.send(ReaderMessage::Terminate).is_ok() { + // Continue to close readers (as well as the reader/writer) while the channel is alive. } let mut errors = vec![]; @@ -320,22 +321,67 @@ impl Executor { } } -fn event_loop(id: usize, conns: Arc>, receiver: Receiver) { +fn reader_event_loop( + idx: usize, + conns: Arc>, + receiver: Receiver, +) { while let Ok(message) = receiver.recv() { match message { - Message::RunConst(f) => { + ReaderMessage::RunConst(f) => { let lock = conns.read(); - f(&lock.0[id]) + f(&lock.0[idx]) } - Message::RunMut(f) => { - let mut lock = conns.write(); - f(&mut lock.0[0]) - } - Message::Terminate => { + ReaderMessage::Terminate => { return; } }; } + + debug!("reader thread shut down"); +} + +fn writer_event_loop( + conns: Arc>, + reader_receiver: Receiver, + writer_receiver: Receiver, +) { + while flume::Selector::new() + .recv(&writer_receiver, |m| { + let Ok(m) = m else { + return false; + }; + + return match m { + WriterMessage::RunMut(f) => { + let mut lock = conns.write(); + f(&mut lock.0[0]); + + // Continue + true + } + }; + }) + .recv(&reader_receiver, |m| { + let Ok(m) = m else { + return false; + }; + + return match m { + ReaderMessage::Terminate => false, + ReaderMessage::RunConst(f) => { + let lock = conns.read(); + f(&lock.0[0]); + + // Continue + true + } + }; + }) + .wait() + {} + + debug!("writer thread shut down"); } pub struct LockGuard<'a> { @@ -375,5 +421,3 @@ impl DerefMut for ArcLockGuard { return &mut self.guard.deref_mut().0[0]; } } - -static UNIQUE_CONN_ID: AtomicUsize = AtomicUsize::new(0); diff --git a/crates/sqlite/src/tests.rs b/crates/sqlite/src/tests.rs index d12965f3..1a516003 100644 --- a/crates/sqlite/src/tests.rs +++ b/crates/sqlite/src/tests.rs @@ -20,7 +20,7 @@ async fn call_success_test() { let conn = Connection::open_in_memory().unwrap(); let result = conn - .call(|conn| { + .call_writer(|conn| { conn .execute( "CREATE TABLE person(id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL);", @@ -38,7 +38,7 @@ async fn call_failure_test() { let conn = Connection::open_in_memory().unwrap(); let result = conn - .call(|conn| conn.execute("Invalid sql", []).map_err(|e| e.into())) + .call_writer(|conn| conn.execute("Invalid sql", []).map_err(|e| e.into())) .await; assert!(match result.unwrap_err() { @@ -65,13 +65,13 @@ async fn close_success_test() { let conn = Connection::with_opts( move || rusqlite::Connection::open(&db_path), Options { - n_read_threads: Some(2), + num_threads: Some(3), ..Default::default() }, ) .unwrap(); - assert_eq!(2, conn.threads()); + assert_eq!(3, conn.threads()); conn .execute("CREATE TABLE 'test' (id INTEGER PRIMARY KEY)", ()) @@ -116,7 +116,7 @@ async fn close_call_test() { assert!(conn.close().await.is_ok()); let result = conn2 - .call(|conn| conn.execute("SELECT 1;", []).map_err(|e| e.into())) + .call_writer(|conn| conn.execute("SELECT 1;", []).map_err(|e| e.into())) .await; assert!(matches!( @@ -130,7 +130,7 @@ async fn close_failure_test() { let conn = Connection::open_in_memory().unwrap(); conn - .call(|conn| { + .call_writer(|conn| { conn .execute( "CREATE TABLE person(id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL);", @@ -142,7 +142,7 @@ async fn close_failure_test() { .unwrap(); conn - .call(|conn| { + .call_writer(|conn| { // Leak a prepared statement to make the database uncloseable // See https://www.sqlite.org/c3ref/close.html for details regarding this behaviour let stmt = Box::new(conn.prepare("INSERT INTO person VALUES (1, ?1);").unwrap()); @@ -197,7 +197,7 @@ async fn test_ergonomic_errors() { let conn = Connection::open_in_memory().unwrap(); let res = conn - .call(|conn| failable_func(conn).map_err(|e| Error::Other(Box::new(e)))) + .call_writer(|conn| failable_func(conn).map_err(|e| Error::Other(Box::new(e)))) .await .unwrap_err(); @@ -213,7 +213,7 @@ async fn test_execute_and_query() { let conn = Connection::open_in_memory().unwrap(); let result = conn - .call(|conn| { + .call_writer(|conn| { conn .execute( "CREATE TABLE person(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", @@ -328,7 +328,7 @@ async fn test_execute_batch() { let count = async |table: &str| -> i64 { return conn - .query_row_get(format!("SELECT COUNT(*) FROM {table}"), (), 0) + .write_query_row_get(format!("SELECT COUNT(*) FROM {table}"), (), 0) .await .unwrap() .unwrap(); @@ -402,7 +402,7 @@ async fn test_params() { let conn = Connection::open_in_memory().unwrap(); conn - .call(|conn| { + .call_writer(|conn| { conn .execute( "CREATE TABLE person(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", @@ -476,7 +476,7 @@ async fn test_hooks() { row_id: i64, } - let (sender, receiver) = kanal::unbounded::(); + let (sender, receiver) = flume::unbounded::(); conn .write_lock() @@ -607,7 +607,7 @@ fn test_busy() { return Ok(conn); }, Options { - n_read_threads: Some(2), + num_threads: Some(3), ..Default::default() }, ) diff --git a/crates/wasm-runtime-host/src/lib.rs b/crates/wasm-runtime-host/src/lib.rs index dce00ee2..0269979b 100644 --- a/crates/wasm-runtime-host/src/lib.rs +++ b/crates/wasm-runtime-host/src/lib.rs @@ -551,7 +551,7 @@ mod tests { assert_eq!( 1, conn - .query_row_get::("SELECT COUNT(*) FROM tx;", (), 0) + .read_query_row_get::("SELECT COUNT(*) FROM tx;", (), 0) .await .unwrap() .unwrap()