PoC: split sqlite connection into a public API and an internal minimal rusqlite executor.

This commit is contained in:
Sebastian Jeltsch 2026-04-13 11:35:23 +02:00
parent 5defef582f
commit 01f0edb5ef
11 changed files with 597 additions and 474 deletions

1
Cargo.lock generated
View file

@ -6891,6 +6891,7 @@ dependencies = [
"rusqlite",
"serde",
"serde_rusqlite",
"smallvec",
"tempfile",
"thiserror 2.0.18",
"tokio",

View file

@ -77,7 +77,8 @@ pub async fn query_handler(
.as_ref(),
)?;
let batched_rows_result = conn.execute_batch(request.query).await;
let batched_rows_result =
trailbase_sqlite::sqlite::batch::execute_batch(&conn, request.query).await;
// In the fallback case we always need to invalidate the cache.
if must_invalidate_schema_cache {
@ -91,5 +92,6 @@ pub async fn query_handler(
rows: rows_to_sql_value_rows(&rows)?,
}));
}
return Ok(Json(QueryResponse::default()));
}

View file

@ -26,6 +26,7 @@ parking_lot = { workspace = true }
rusqlite = { workspace = true }
serde = { workspace = true }
serde_rusqlite = { workspace = true }
smallvec = { version = "1.15.1", features = ["const_generics"] }
thiserror = "2.0.12"
tokio = { workspace = true }

View file

@ -1,61 +1,20 @@
use kanal::{Receiver, Sender};
use log::*;
use parking_lot::RwLock;
use rusqlite::fallible_iterator::FallibleIterator;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::oneshot;
use crate::database::Database;
use crate::error::Error;
use crate::params::Params;
use crate::rows::{Column, columns};
use crate::rows::{Row, Rows};
use crate::rows::{Row, Rows, columns};
use crate::sqlite::connection::{ConnectionImpl, map_first};
#[derive(Clone, Debug, PartialEq, serde::Deserialize)]
pub struct Database {
pub seq: u8,
pub name: String,
}
#[derive(Default)]
struct ConnectionVec(Vec<rusqlite::Connection>);
// NOTE: We must never access the same connection concurrently even as immutable &Connection, due
// to intrinsic statement cache. We can ensure this by uniquely assigning one connection to each
// thread.
unsafe impl Sync for ConnectionVec {}
enum Message {
RunMut(Box<dyn FnOnce(&mut rusqlite::Connection) + Send>),
RunConst(Box<dyn FnOnce(&rusqlite::Connection) + Send>),
Terminate,
}
#[derive(Clone)]
pub struct Options {
pub busy_timeout: std::time::Duration,
pub n_read_threads: usize,
}
impl Default for Options {
fn default() -> Self {
return Self {
busy_timeout: std::time::Duration::from_secs(5),
n_read_threads: 0,
};
}
}
// NOTE: We should probably decouple from the impl.
pub use crate::sqlite::connection::{ArcLockGuard, LockGuard, Options};
/// A handle to call functions in background thread.
#[derive(Clone)]
pub struct Connection {
id: usize,
reader: Sender<Message>,
writer: Sender<Message>,
conns: Arc<RwLock<ConnectionVec>>,
c: ConnectionImpl,
}
impl Connection {
@ -63,91 +22,8 @@ impl Connection {
builder: impl Fn() -> Result<rusqlite::Connection, E>,
opt: Option<Options>,
) -> std::result::Result<Self, E> {
let new_conn = || -> Result<rusqlite::Connection, E> {
let conn = builder()?;
if let Some(timeout) = opt.as_ref().map(|o| o.busy_timeout) {
conn.busy_timeout(timeout).expect("busy timeout failed");
}
return Ok(conn);
};
let write_conn = new_conn()?;
let path = write_conn.path().map(|p| p.to_string());
// Returns empty string for in-memory databases.
let in_memory = path.as_ref().is_none_or(|s| !s.is_empty());
let n_read_threads: i64 = match (in_memory, opt.as_ref().map_or(0, |o| o.n_read_threads)) {
(true, _) => {
// We cannot share an in-memory database across threads, they're all independent.
0
}
(false, 1) => {
warn!("A single reader thread won't improve performance, falling back to 0.");
0
}
(false, n) => {
if let Ok(max) = std::thread::available_parallelism()
&& n > max.get()
{
warn!(
"Num read threads '{n}' exceeds hardware parallelism: {}",
max.get()
);
}
n as i64
}
};
let conns = Arc::new(RwLock::new(ConnectionVec({
let mut conns = vec![write_conn];
for _ in 0..(n_read_threads - 1).max(0) {
conns.push(new_conn()?);
}
conns
})));
assert_eq!(n_read_threads.max(1) as usize, conns.read().0.len());
// Spawn writer.
let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
{
let conns = conns.clone();
std::thread::Builder::new()
.name("tb-sqlite-writer".to_string())
.spawn(move || event_loop(0, conns, shared_write_receiver))
.expect("startup");
}
// Spawn readers.
let shared_read_sender = if n_read_threads > 0 {
let (shared_read_sender, shared_read_receiver) = kanal::unbounded::<Message>();
for i in 0..n_read_threads {
// NOTE: read and writer threads are sharing the first conn, given they're mutually
// exclusive.
let index = i as usize;
let shared_read_receiver = shared_read_receiver.clone();
let conns = conns.clone();
std::thread::Builder::new()
.name(format!("tb-sqlite-reader-{index}"))
.spawn(move || event_loop(index, conns, shared_read_receiver))
.expect("startup");
}
shared_read_sender
} else {
shared_write_sender.clone()
};
debug!(
"Opened SQLite DB '{}' with {n_read_threads} reader threads",
path.as_deref().unwrap_or("<in-memory>")
);
return Ok(Self {
id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst),
reader: shared_read_sender,
writer: shared_write_sender,
conns,
c: ConnectionImpl::new(builder, opt)?,
});
}
@ -157,41 +33,29 @@ impl Connection {
///
/// Will return `Err` if the underlying SQLite open call fails.
pub fn open_in_memory() -> Result<Self, Error> {
return Self::new(|| Ok(rusqlite::Connection::open_in_memory()?), None);
let conn = Self::new(
rusqlite::Connection::open_in_memory,
Some(Options {
n_read_threads: 0,
..Default::default()
}),
)?;
assert_eq!(1, conn.c.len());
return Ok(conn);
}
pub fn id(&self) -> usize {
return self.id;
return self.c.id();
}
#[inline]
pub fn write_lock(&self) -> LockGuard<'_> {
return LockGuard {
guard: self.conns.write(),
};
return self.c.write_lock();
}
// #[inline]
// pub fn try_write_lock_for(&self, duration: tokio::time::Duration) -> Option<LockGuard<'_>> {
// return self
// .conns
// .try_write_for(duration)
// .map(|guard| LockGuard { guard });
// }
// #[inline]
// pub fn write_arc_lock(&self) -> ArcLockGuard {
// return ArcLockGuard {
// guard: self.conns.write_arc(),
// };
// }
#[inline]
pub fn try_write_arc_lock_for(&self, duration: tokio::time::Duration) -> Option<ArcLockGuard> {
return self
.conns
.try_write_arc_for(duration)
.map(|guard| ArcLockGuard { guard });
return self.c.try_write_arc_lock_for(duration);
}
/// Call a function in background thread and get the result
@ -200,45 +64,20 @@ impl Connection {
/// # Failure
///
/// Will return `Err` if the database connection has been closed.
#[inline]
pub async fn call<F, R>(&self, function: F) -> Result<R, Error>
where
F: FnOnce(&mut rusqlite::Connection) -> Result<R, Error> + Send + 'static,
R: Send + 'static,
{
// return call_impl(&self.writer, function).await;
let (sender, receiver) = oneshot::channel::<Result<R, Error>>();
self
.writer
.send(Message::RunMut(Box::new(move |conn| {
if !sender.is_closed() {
let _ = sender.send(function(conn));
}
})))
.map_err(|_| Error::ConnectionClosed)?;
receiver.await.map_err(|_| Error::ConnectionClosed)?
return self.c.call(function).await;
}
#[inline]
pub async fn call_reader<F, R>(&self, function: F) -> Result<R, Error>
where
F: FnOnce(&rusqlite::Connection) -> Result<R, Error> + Send + 'static,
R: Send + 'static,
{
let (sender, receiver) = oneshot::channel::<Result<R, Error>>();
self
.reader
.send(Message::RunConst(Box::new(move |conn| {
if !sender.is_closed() {
let _ = sender.send(function(conn));
}
})))
.map_err(|_| Error::ConnectionClosed)?;
receiver.await.map_err(|_| Error::ConnectionClosed)?
return self.c.call_reader(function).await;
}
/// Query SQL statement.
@ -248,30 +87,8 @@ impl Connection {
params: impl Params + Send + 'static,
) -> Result<Rows, Error> {
return self
.call_reader(move |conn: &rusqlite::Connection| {
let mut stmt = conn.prepare_cached(sql.as_ref())?;
assert!(stmt.readonly());
params.bind(&mut stmt)?;
let rows = stmt.raw_query();
return crate::rows::from_rows(rows);
})
.await;
}
pub async fn write_query_rows(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
) -> Result<Rows, Error> {
return self
.call(move |conn: &mut rusqlite::Connection| {
let mut stmt = conn.prepare_cached(sql.as_ref())?;
params.bind(&mut stmt)?;
let rows = stmt.raw_query();
return crate::rows::from_rows(rows);
})
.c
.read_query_rows_f(sql, params, crate::rows::from_rows)
.await;
}
@ -281,88 +98,15 @@ impl Connection {
params: impl Params + Send + 'static,
) -> Result<Option<Row>, Error> {
return self
.read_query_row_f(sql, params, |row| {
return crate::rows::from_row(row, Arc::new(columns(row.as_ref())));
.c
.read_query_rows_f(sql, params, |rows| {
return map_first(rows, |row| {
return crate::rows::from_row(row, Arc::new(columns(row.as_ref())));
});
})
.await;
}
#[inline]
async fn query_row_f<T, E>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
f: impl (FnOnce(&rusqlite::Row<'_>) -> Result<T, E>) + Send + 'static,
) -> Result<Option<T>, Error>
where
T: Send + 'static,
Error: From<E>,
{
return self
.call(move |conn: &mut rusqlite::Connection| {
let mut stmt = conn.prepare_cached(sql.as_ref())?;
params.bind(&mut stmt)?;
let mut rows = stmt.raw_query();
if let Some(row) = rows.next()? {
return Ok(Some(f(row)?));
}
Ok(None)
})
.await;
}
#[inline]
pub async fn query_row_get<T>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
index: usize,
) -> Result<Option<T>, Error>
where
T: rusqlite::types::FromSql + Send + 'static,
{
return self
.query_row_f(
sql,
params,
move |row: &rusqlite::Row<'_>| -> Result<T, Error> {
return Ok(row.get(index)?);
},
)
.await;
}
#[inline]
async fn read_query_row_f<T, E>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
f: impl (FnOnce(&rusqlite::Row<'_>) -> std::result::Result<T, E>) + Send + 'static,
) -> Result<Option<T>, Error>
where
T: Send + 'static,
Error: From<E>,
{
return self
.call_reader(move |conn: &rusqlite::Connection| {
let mut stmt = conn.prepare_cached(sql.as_ref())?;
assert!(stmt.readonly());
params.bind(&mut stmt)?;
let mut rows = stmt.raw_query();
if let Some(row) = rows.next()? {
return Ok(Some(f(row)?));
}
Ok(None)
})
.await;
}
#[inline]
pub async fn read_query_row_get<T>(
&self,
sql: impl AsRef<str> + Send + 'static,
@ -373,13 +117,12 @@ impl Connection {
T: rusqlite::types::FromSql + Send + 'static,
{
return self
.read_query_row_f(
sql,
params,
move |row: &rusqlite::Row<'_>| -> Result<T, Error> {
.c
.read_query_rows_f(sql, params, move |rows| {
return map_first(rows, move |row| {
return Ok(row.get(index)?);
},
)
});
})
.await;
}
@ -389,20 +132,11 @@ impl Connection {
params: impl Params + Send + 'static,
) -> Result<Option<T>, Error> {
return self
.read_query_row_f(sql, params, |row| {
serde_rusqlite::from_row(row).map_err(Error::DeserializeValue)
})
.await;
}
pub async fn write_query_value<T: serde::de::DeserializeOwned + Send + 'static>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
) -> Result<Option<T>, Error> {
return self
.query_row_f(sql, params, |row| {
serde_rusqlite::from_row(row).map_err(Error::DeserializeValue)
.c
.read_query_rows_f(sql, params, |rows| {
return map_first(rows, move |row| {
serde_rusqlite::from_row(row).map_err(Error::DeserializeValue)
});
})
.await;
}
@ -413,18 +147,56 @@ impl Connection {
params: impl Params + Send + 'static,
) -> Result<Vec<T>, Error> {
return self
.call_reader(move |conn: &rusqlite::Connection| {
let mut stmt = conn.prepare_cached(sql.as_ref())?;
assert!(stmt.readonly());
.c
.read_query_rows_f(sql, params, |rows| {
return serde_rusqlite::from_rows(rows)
.collect::<Result<Vec<_>, _>>()
.map_err(Error::DeserializeValue);
})
.await;
}
params.bind(&mut stmt)?;
let mut rows = stmt.raw_query();
pub async fn write_query_rows(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
) -> Result<Rows, Error> {
return self
.c
.write_query_rows_f(sql, params, crate::rows::from_rows)
.await;
}
let mut values = vec![];
while let Some(row) = rows.next()? {
values.push(serde_rusqlite::from_row(row).map_err(Error::DeserializeValue)?);
}
return Ok(values);
pub async fn query_row_get<T>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
index: usize,
) -> Result<Option<T>, Error>
where
T: rusqlite::types::FromSql + Send + 'static,
{
return self
.c
.write_query_rows_f(sql, params, move |rows| {
return map_first(rows, move |row| {
return Ok(row.get(index)?);
});
})
.await;
}
pub async fn write_query_value<T: serde::de::DeserializeOwned + Send + 'static>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
) -> Result<Option<T>, Error> {
return self
.c
.write_query_rows_f(sql, params, |rows| {
return map_first(rows, |row| {
serde_rusqlite::from_row(row).map_err(Error::DeserializeValue)
});
})
.await;
}
@ -435,17 +207,11 @@ impl Connection {
params: impl Params + Send + 'static,
) -> Result<Vec<T>, Error> {
return self
.call(move |conn: &mut rusqlite::Connection| {
let mut stmt = conn.prepare_cached(sql.as_ref())?;
params.bind(&mut stmt)?;
let mut rows = stmt.raw_query();
let mut values = vec![];
while let Some(row) = rows.next()? {
values.push(serde_rusqlite::from_row(row).map_err(Error::DeserializeValue)?);
}
return Ok(values);
.c
.write_query_rows_f(sql, params, |rows| {
return serde_rusqlite::from_rows(rows)
.collect::<Result<Vec<_>, _>>()
.map_err(Error::DeserializeValue);
})
.await;
}
@ -456,75 +222,35 @@ impl Connection {
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
) -> Result<usize, Error> {
return self
.call(move |conn: &mut rusqlite::Connection| {
let mut stmt = conn.prepare_cached(sql.as_ref())?;
params.bind(&mut stmt)?;
let n = stmt.raw_execute()?;
return Ok(n);
})
.await;
return self.c.execute(sql, params).await;
}
/// Batch execute SQL statements and return rows of last statement.
pub async fn execute_batch(
&self,
sql: impl AsRef<str> + Send + 'static,
) -> Result<Option<Rows>, Error> {
return self
.call(move |conn: &mut rusqlite::Connection| {
let batch = rusqlite::Batch::new(conn, sql.as_ref());
let mut p = batch.peekable();
while let Some(mut stmt) = p.next()? {
let mut rows = stmt.raw_query();
let row = rows.next()?;
match p.peek()? {
Some(_) => {}
None => {
if let Some(row) = row {
let cols: Arc<Vec<Column>> = Arc::new(columns(row.as_ref()));
let mut result = vec![crate::rows::from_row(row, cols.clone())?];
while let Some(row) = rows.next()? {
result.push(crate::rows::from_row(row, cols.clone())?);
}
return Ok(Some(Rows(result, cols)));
}
return Ok(None);
}
}
}
return Ok(None);
})
.await;
/// Batch execute provided SQL statementsi in batch.
pub async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> Result<(), Error> {
return self.c.execute_batch(sql).await;
}
pub fn attach(&self, path: &str, name: &str) -> Result<(), Error> {
let query = format!("ATTACH DATABASE '{path}' AS {name} ");
let lock = self.conns.write();
for conn in &lock.0 {
return self.c.map(move |conn| {
conn.execute(&query, ())?;
}
return Ok(());
return Ok(());
});
}
pub fn detach(&self, name: &str) -> Result<(), Error> {
let query = format!("DETACH DATABASE {name}");
let lock = self.conns.write();
for conn in &lock.0 {
return self.c.map(move |conn| {
conn.execute(&query, ())?;
}
return Ok(());
return Ok(());
});
}
pub async fn list_databases(&self) -> Result<Vec<Database>, Error> {
return self.call_reader(crate::sqlite::list_databases).await;
return self
.c
.call_reader(crate::sqlite::util::list_databases)
.await;
}
/// Close the database connection.
@ -539,28 +265,7 @@ impl Connection {
///
/// Will return `Err` if the underlying SQLite close call fails.
pub async fn close(self) -> Result<(), Error> {
let _ = self.writer.send(Message::Terminate);
while self.reader.send(Message::Terminate).is_ok() {
// Continue to close readers while the channel is alive.
}
let mut errors = vec![];
let conns: ConnectionVec = std::mem::take(&mut self.conns.write());
for conn in conns.0 {
// NOTE: rusqlite's `Connection::close()` returns itself, to allow users to retry
// failed closes. We on the other, may be left in a partially closed state with multiple
// connections. Ignorance is bliss.
if let Err((_self, err)) = conn.close() {
errors.push(err);
};
}
if !errors.is_empty() {
warn!("Closing connection: {errors:?}");
return Err(errors.swap_remove(0).into());
}
return Ok(());
return self.c.close().await;
}
}
@ -572,72 +277,14 @@ impl Debug for Connection {
impl Hash for Connection {
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state);
self.id().hash(state);
}
}
impl PartialEq for Connection {
fn eq(&self, other: &Self) -> bool {
return self.id == other.id;
return self.id() == other.id();
}
}
impl Eq for Connection {}
fn event_loop(id: usize, conns: Arc<RwLock<ConnectionVec>>, receiver: Receiver<Message>) {
while let Ok(message) = receiver.recv() {
match message {
Message::RunConst(f) => {
let lock = conns.read();
f(&lock.0[id])
}
Message::RunMut(f) => {
let mut lock = conns.write();
f(&mut lock.0[0])
}
Message::Terminate => {
return;
}
};
}
}
pub struct LockGuard<'a> {
guard: parking_lot::RwLockWriteGuard<'a, ConnectionVec>,
}
impl Deref for LockGuard<'_> {
type Target = rusqlite::Connection;
#[inline]
fn deref(&self) -> &rusqlite::Connection {
return &self.guard.deref().0[0];
}
}
impl DerefMut for LockGuard<'_> {
#[inline]
fn deref_mut(&mut self) -> &mut rusqlite::Connection {
return &mut self.guard.deref_mut().0[0];
}
}
pub struct ArcLockGuard {
guard: parking_lot::ArcRwLockWriteGuard<parking_lot::RawRwLock, ConnectionVec>,
}
impl Deref for ArcLockGuard {
type Target = rusqlite::Connection;
#[inline]
fn deref(&self) -> &rusqlite::Connection {
return &self.guard.deref().0[0];
}
}
impl DerefMut for ArcLockGuard {
#[inline]
fn deref_mut(&mut self) -> &mut rusqlite::Connection {
return &mut self.guard.deref_mut().0[0];
}
}
static UNIQUE_CONN_ID: AtomicUsize = AtomicUsize::new(0);

View file

@ -0,0 +1,5 @@
#[derive(Clone, Debug, PartialEq, serde::Deserialize)]
pub struct Database {
pub seq: u8,
pub name: String,
}

View file

@ -11,6 +11,7 @@
)]
pub mod connection;
pub mod database;
pub mod error;
pub mod params;
pub mod rows;
@ -19,6 +20,7 @@ pub mod to_sql;
pub mod value;
pub use connection::Connection;
pub use database::Database;
pub use error::Error;
pub use params::{NamedParamRef, NamedParams, NamedParamsRef, Params};
pub use rows::{Row, Rows, ValueType};

View file

@ -0,0 +1,47 @@
use rusqlite::fallible_iterator::FallibleIterator;
use std::sync::Arc;
use crate::connection::Connection;
use crate::error::Error;
use crate::rows::{Column, Rows, columns, from_row};
/// Batch execute SQL statements and return rows of last statement.
///
/// NOTE: This is a weird batch flavor that returns the last statement's rows. Most clients don't
/// support this. We thus keep it out of `Connection::execute_batch()`. We currently only rely on
/// it in the admin dashboard. Avoid further adoption.
pub async fn execute_batch(
conn: &Connection,
sql: impl AsRef<str> + Send + 'static,
) -> Result<Option<Rows>, Error> {
return conn
.call(move |conn: &mut rusqlite::Connection| {
let batch = rusqlite::Batch::new(conn, sql.as_ref());
let mut p = batch.peekable();
while let Some(mut stmt) = p.next()? {
let mut rows = stmt.raw_query();
let row = rows.next()?;
match p.peek()? {
Some(_) => {}
None => {
if let Some(row) = row {
let cols: Arc<Vec<Column>> = Arc::new(columns(row.as_ref()));
let mut result = vec![from_row(row, cols.clone())?];
while let Some(row) = rows.next()? {
result.push(crate::rows::from_row(row, cols.clone())?);
}
return Ok(Some(Rows(result, cols)));
}
return Ok(None);
}
}
}
return Ok(None);
})
.await;
}

View file

@ -0,0 +1,402 @@
use kanal::{Receiver, Sender};
use log::*;
use parking_lot::RwLock;
use rusqlite::fallible_iterator::FallibleIterator;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::oneshot;
use crate::error::Error;
use crate::params::Params;
#[derive(Default)]
struct ConnectionVec(smallvec::SmallVec<[rusqlite::Connection; 32]>);
// NOTE: We must never access the same connection concurrently even as immutable &Connection, due
// to intrinsic statement cache. We can ensure this by uniquely assigning one connection to each
// thread.
unsafe impl Sync for ConnectionVec {}
enum Message {
RunMut(Box<dyn FnOnce(&mut rusqlite::Connection) + Send>),
RunConst(Box<dyn FnOnce(&rusqlite::Connection) + Send>),
Terminate,
}
#[derive(Clone)]
pub struct Options {
pub busy_timeout: std::time::Duration,
pub n_read_threads: usize,
}
impl Default for Options {
fn default() -> Self {
return Self {
busy_timeout: std::time::Duration::from_secs(5),
n_read_threads: 0,
};
}
}
/// A handle to call functions in background thread.
#[derive(Clone)]
pub(crate) struct ConnectionImpl {
id: usize,
reader: Sender<Message>,
writer: Sender<Message>,
// NOTE: Is shared across reader and writer worker threads.
conns: Arc<RwLock<ConnectionVec>>,
}
impl ConnectionImpl {
pub fn new<E>(
builder: impl Fn() -> Result<rusqlite::Connection, E>,
opt: Option<Options>,
) -> Result<Self, E> {
let Options {
busy_timeout,
n_read_threads,
} = opt.unwrap_or_default();
let new_conn = || -> Result<rusqlite::Connection, E> {
let conn = builder()?;
if !busy_timeout.is_zero() {
conn
.busy_timeout(busy_timeout)
.expect("busy timeout failed");
}
return Ok(conn);
};
let write_conn = new_conn()?;
let path = write_conn.path().map(|p| p.to_string());
// Returns empty string for in-memory databases.
let in_memory = path
.as_ref()
.is_none_or(|s| s.is_empty() || s == ":memory:");
let n_read_threads: i64 = match (in_memory, n_read_threads) {
(true, _) => {
// We cannot share an in-memory database across threads, they're all independent.
0
}
(false, 1) => {
warn!("A single reader thread won't improve performance, falling back to 0.");
0
}
(false, n) => {
if let Ok(max) = std::thread::available_parallelism()
&& n > max.get()
{
warn!(
"Num read threads '{n}' exceeds hardware parallelism: {}",
max.get()
);
}
n as i64
}
};
let conns = Arc::new(RwLock::new(ConnectionVec({
let mut conns = vec![write_conn];
for _ in 0..(n_read_threads - 1).max(0) {
conns.push(new_conn()?);
}
conns.into()
})));
assert_eq!(n_read_threads.max(1) as usize, conns.read().0.len());
// Spawn writer.
let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
{
let conns = conns.clone();
std::thread::Builder::new()
.name("tb-sqlite-writer".to_string())
.spawn(move || event_loop(0, conns, shared_write_receiver))
.expect("startup");
}
// Spawn readers.
let shared_read_sender = if n_read_threads > 0 {
let (shared_read_sender, shared_read_receiver) = kanal::unbounded::<Message>();
for i in 0..n_read_threads {
// NOTE: read and writer threads are sharing the first conn, given they're mutually
// exclusive.
let index = i as usize;
let shared_read_receiver = shared_read_receiver.clone();
let conns = conns.clone();
std::thread::Builder::new()
.name(format!("tb-sqlite-reader-{index}"))
.spawn(move || event_loop(index, conns, shared_read_receiver))
.expect("startup");
}
shared_read_sender
} else {
shared_write_sender.clone()
};
debug!(
"Opened SQLite DB '{}' with {n_read_threads} reader threads",
path.as_deref().unwrap_or("<in-memory>")
);
return Ok(Self {
id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst),
reader: shared_read_sender,
writer: shared_write_sender,
conns,
});
}
pub fn id(&self) -> usize {
return self.id;
}
pub(crate) fn len(&self) -> usize {
return self.conns.read().0.len();
}
#[inline]
pub fn write_lock(&self) -> LockGuard<'_> {
return LockGuard {
guard: self.conns.write(),
};
}
#[inline]
pub fn try_write_arc_lock_for(&self, duration: tokio::time::Duration) -> Option<ArcLockGuard> {
return self
.conns
.try_write_arc_for(duration)
.map(|guard| ArcLockGuard { guard });
}
#[inline]
pub(crate) fn map(
&self,
f: impl Fn(&rusqlite::Connection) -> Result<(), Error> + Send + 'static,
) -> Result<(), Error> {
let lock = self.conns.write();
for conn in &lock.0 {
f(conn)?;
}
return Ok(());
}
#[inline]
pub async fn call<F, R>(&self, function: F) -> Result<R, Error>
where
F: FnOnce(&mut rusqlite::Connection) -> Result<R, Error> + Send + 'static,
R: Send + 'static,
{
// return call_impl(&self.writer, function).await;
let (sender, receiver) = oneshot::channel::<Result<R, Error>>();
self
.writer
.send(Message::RunMut(Box::new(move |conn| {
if !sender.is_closed() {
let _ = sender.send(function(conn));
}
})))
.map_err(|_| Error::ConnectionClosed)?;
receiver.await.map_err(|_| Error::ConnectionClosed)?
}
#[inline]
pub async fn call_reader<F, R>(&self, function: F) -> Result<R, Error>
where
F: FnOnce(&rusqlite::Connection) -> Result<R, Error> + Send + 'static,
R: Send + 'static,
{
let (sender, receiver) = oneshot::channel::<Result<R, Error>>();
self
.reader
.send(Message::RunConst(Box::new(move |conn| {
if !sender.is_closed() {
let _ = sender.send(function(conn));
}
})))
.map_err(|_| Error::ConnectionClosed)?;
receiver.await.map_err(|_| Error::ConnectionClosed)?
}
#[inline]
pub async fn write_query_rows_f<T>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
f: impl (FnOnce(rusqlite::Rows<'_>) -> Result<T, Error>) + Send + 'static,
) -> Result<T, Error>
where
T: Send + 'static,
{
return self
.call(move |conn: &mut rusqlite::Connection| {
let mut stmt = conn.prepare_cached(sql.as_ref())?;
params.bind(&mut stmt)?;
return f(stmt.raw_query());
})
.await;
}
#[inline]
pub async fn read_query_rows_f<T>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
f: impl (FnOnce(rusqlite::Rows<'_>) -> Result<T, Error>) + Send + 'static,
) -> Result<T, Error>
where
T: Send + 'static,
{
return self
.call_reader(move |conn: &rusqlite::Connection| {
let mut stmt = conn.prepare_cached(sql.as_ref())?;
assert!(stmt.readonly());
params.bind(&mut stmt)?;
return f(stmt.raw_query());
})
.await;
}
pub async fn execute(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
) -> Result<usize, Error> {
return self
.call(move |conn: &mut rusqlite::Connection| {
let mut stmt = conn.prepare_cached(sql.as_ref())?;
params.bind(&mut stmt)?;
return Ok(stmt.raw_execute()?);
})
.await;
}
pub async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> Result<(), Error> {
self
.call(move |conn: &mut rusqlite::Connection| {
let mut batch = rusqlite::Batch::new(conn, sql.as_ref());
while let Some(mut stmt) = batch.next()? {
// NOTE: We must use `raw_query` instead of `raw_execute`, otherwise queries
// returning rows (e.g. SELECT) will return an error. Rusqlite's batch_execute
// behaves consistently.
let _row = stmt.raw_query().next()?;
}
return Ok(());
})
.await?;
return Ok(());
}
pub async fn close(self) -> Result<(), Error> {
let _ = self.writer.send(Message::Terminate);
while self.reader.send(Message::Terminate).is_ok() {
// Continue to close readers while the channel is alive.
}
let mut errors = vec![];
let conns: ConnectionVec = std::mem::take(&mut self.conns.write());
for conn in conns.0 {
// NOTE: rusqlite's `Connection::close()` returns itself, to allow users to retry
// failed closes. We on the other, may be left in a partially closed state with multiple
// connections. Ignorance is bliss.
if let Err((_self, err)) = conn.close() {
errors.push(err);
};
}
if !errors.is_empty() {
warn!("Closing connection: {errors:?}");
return Err(errors.swap_remove(0).into());
}
return Ok(());
}
}
fn event_loop(id: usize, conns: Arc<RwLock<ConnectionVec>>, receiver: Receiver<Message>) {
while let Ok(message) = receiver.recv() {
match message {
Message::RunConst(f) => {
let lock = conns.read();
f(&lock.0[id])
}
Message::RunMut(f) => {
let mut lock = conns.write();
f(&mut lock.0[0])
}
Message::Terminate => {
return;
}
};
}
}
pub struct LockGuard<'a> {
guard: parking_lot::RwLockWriteGuard<'a, ConnectionVec>,
}
impl Deref for LockGuard<'_> {
type Target = rusqlite::Connection;
#[inline]
fn deref(&self) -> &rusqlite::Connection {
return &self.guard.deref().0[0];
}
}
impl DerefMut for LockGuard<'_> {
#[inline]
fn deref_mut(&mut self) -> &mut rusqlite::Connection {
return &mut self.guard.deref_mut().0[0];
}
}
pub struct ArcLockGuard {
guard: parking_lot::ArcRwLockWriteGuard<parking_lot::RawRwLock, ConnectionVec>,
}
impl Deref for ArcLockGuard {
type Target = rusqlite::Connection;
#[inline]
fn deref(&self) -> &rusqlite::Connection {
return &self.guard.deref().0[0];
}
}
impl DerefMut for ArcLockGuard {
#[inline]
fn deref_mut(&mut self) -> &mut rusqlite::Connection {
return &mut self.guard.deref_mut().0[0];
}
}
#[inline]
pub(crate) fn map_first<T>(
mut rows: rusqlite::Rows<'_>,
f: impl (FnOnce(&rusqlite::Row<'_>) -> Result<T, Error>) + Send + 'static,
) -> Result<Option<T>, Error>
where
T: Send + 'static,
{
if let Some(row) = rows.next()? {
return Ok(Some(f(row)?));
}
return Ok(None);
}
static UNIQUE_CONN_ID: AtomicUsize = AtomicUsize::new(0);

View file

@ -0,0 +1,7 @@
pub mod batch;
pub mod connection;
pub mod util;
pub use util::extract_record_values;
pub use util::extract_row_id;
pub use util::list_databases;

View file

@ -1,6 +1,6 @@
use rusqlite::hooks::PreUpdateCase;
use crate::connection::Database;
use crate::database::Database;
use crate::error::Error;
use crate::value::Value;

View file

@ -4,9 +4,9 @@ use rusqlite::{ErrorCode, ffi};
use serde::Deserialize;
use std::borrow::Cow;
use crate::connection::{Connection, Database, Options};
use crate::connection::{Connection, Options};
use crate::sqlite::extract_row_id;
use crate::{Error, Value, ValueType};
use crate::{Database, Error, Value, ValueType};
#[tokio::test]
async fn open_in_memory_test() {
@ -275,19 +275,28 @@ async fn test_execute_and_query() {
assert_eq!(person.id, 1);
assert_eq!(person.name, "baz");
let rows = conn
.execute_batch(
r#"
let rows = crate::sqlite::batch::execute_batch(
&conn,
r#"
CREATE TABLE foo (id INTEGER) STRICT;
INSERT INTO foo (id) VALUES (17);
SELECT * FROM foo;
"#,
)
.await
.unwrap()
.unwrap();
)
.await
.unwrap()
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows.0.get(0).unwrap().get::<i64>(0), Ok(17));
// Lastly make sure rusqlite and out Connection consistently execute batches
// containing statements returning rows.
let query = "
SELECT * FROM foo;
SELECT * FROM foo;
";
conn.write_lock().execute_batch(query).unwrap();
conn.execute_batch(query).await.unwrap();
}
#[tokio::test]