Further reduce leaky trailbase_sqlite::Connection API surface.

This commit is contained in:
Sebastian Jeltsch 2026-04-12 12:52:34 +02:00
parent 23b659b8c6
commit 59ec41a260
26 changed files with 195 additions and 195 deletions

View file

@ -60,10 +60,10 @@ async fn add_room(
name: &str,
) -> Result<[u8; 16], anyhow::Error> {
let room: [u8; 16] = conn
.query_row_f(
.query_row_get(
"INSERT INTO room (name) VALUES ($1) RETURNING id",
params!(name.to_string()),
|row| row.get(0),
0,
)
.await?
.unwrap();

View file

@ -59,13 +59,13 @@ pub async fn list_logs_handler(
};
let total_row_count: i64 = conn
.read_query_row_f(
.read_query_row_get(
format!(
"SELECT COUNT(*) FROM {LOGS_TABLE} AS {TABLE_ALIAS} WHERE {where_clause}",
where_clause = filter_where_clause.clause
),
filter_where_clause.params.clone(),
|row| row.get(0),
0,
)
.await?
.unwrap_or(-1);

View file

@ -216,11 +216,9 @@ mod tests {
Uuid::from_slice(&row.myid).unwrap()
};
let count = || async {
let count = async || -> i64 {
conn
.read_query_row_f(format!("SELECT COUNT(*) FROM '{table_name}'"), (), |row| {
row.get::<_, i64>(0)
})
.read_query_row_get(format!("SELECT COUNT(*) FROM '{table_name}'"), (), 0)
.await
.unwrap()
.unwrap()

View file

@ -81,9 +81,7 @@ pub async fn list_rows_handler(
table = qualified_name.escaped_string()
);
conn
.read_query_row_f(count_query, filter_where_clause.params.clone(), |row| {
row.get(0)
})
.read_query_row_get(count_query, filter_where_clause.params.clone(), 0)
.await?
.unwrap_or(-1)
};
@ -295,7 +293,7 @@ mod tests {
.unwrap();
let cnt: i64 = conn
.read_query_row_f("SELECT COUNT(*) FROM test_table", (), |row| row.get(0))
.read_query_row_get("SELECT COUNT(*) FROM test_table", (), 0)
.await
.unwrap()
.unwrap();

View file

@ -81,13 +81,13 @@ pub async fn list_users_handler(
build_filter_where_clause("_ROW_", &table_metadata.column_metadata, filter_params)?;
let total_row_count: i64 = conn
.read_query_row_f(
.read_query_row_get(
format!(
"SELECT COUNT(*) FROM {USER_TABLE} AS _ROW_ WHERE {where_clause}",
where_clause = filter_where_clause.clause
),
filter_where_clause.params.clone(),
|row| row.get(0),
0,
)
.await?
.unwrap_or(-1);

View file

@ -222,10 +222,10 @@ pub async fn login_otp_handler(
let Some(user_id) = state
.session_conn()
.query_row_f(
.read_query_row_get(
LOOKUP_OTP_QUERY,
params!(normalized_email, otp_code.to_string()),
|row| row.get::<_, [u8; 16]>(0),
0,
)
.await?
else {

View file

@ -67,10 +67,10 @@ pub(crate) async fn auth_code_to_token_handler(
let Some(user_id) = state
.session_conn()
.query_row_f(
.read_query_row_get(
AUTH_CODE_QUERY,
params!(authorization_code, pkce_code_challenge),
|row| row.get::<_, [u8; 16]>(0),
0,
)
.await?
else {

View file

@ -372,10 +372,10 @@ async fn test_auth_password_login_flow_without_pkce() {
let refresh_token: String = state
.session_conn()
.read_query_row_f(
.read_query_row_get(
format!("SELECT refresh_token FROM {SESSION_TABLE} WHERE user = $1;"),
(user.uuid.into_bytes().to_vec(),),
|row| row.get(0),
0,
)
.await
.unwrap()
@ -613,10 +613,10 @@ async fn test_auth_change_email_flow() {
let db_email: String = state
.user_conn()
.read_query_row_f(
.read_query_row_get(
format!(r#"SELECT email FROM "{USER_TABLE}" WHERE id = $1"#),
params!(user.uuid.into_bytes()),
|row| row.get(0),
0,
)
.await
.unwrap()
@ -692,10 +692,10 @@ async fn test_auth_delete_user_flow() {
let user_exists: bool = state
.user_conn()
.read_query_row_f(
.read_query_row_get(
format!(r#"SELECT EXISTS(SELECT * FROM "{USER_TABLE}" WHERE id = $1)"#),
params!(user.uuid.into_bytes()),
|row| row.get(0),
0,
)
.await
.unwrap()
@ -801,10 +801,10 @@ async fn test_auth_otp_flow() {
async fn session_exists(state: &AppState, user_id: Uuid) -> bool {
return state
.session_conn()
.read_query_row_f(
.read_query_row_get(
format!("SELECT EXISTS(SELECT 1 FROM {SESSION_TABLE} WHERE user = $1)"),
params!(user_id.into_bytes().to_vec()),
|row| row.get(0),
0,
)
.await
.unwrap()

View file

@ -386,10 +386,10 @@ fn get_redirect_location<T: IntoResponse>(response: T) -> Option<String> {
async fn session_exists(state: &AppState, user_id: Uuid) -> bool {
return state
.session_conn()
.read_query_row_f(
.read_query_row_get(
format!("SELECT EXISTS(SELECT 1 FROM {SESSION_TABLE} WHERE user = $1)"),
(user_id.into_bytes().to_vec(),),
|row| row.get(0),
0,
)
.await
.unwrap()

View file

@ -219,9 +219,7 @@ pub(crate) async fn reauth_with_refresh_token(
let Some(user_id) = state
.session_conn()
.query_row_f(SESSION_QUERY, params!(refresh_token), |row| {
row.get::<_, [u8; 16]>(0)
})
.read_query_row_get::<[u8; 16]>(SESSION_QUERY, params!(refresh_token), 0)
.await?
else {
// Row not found case, typically expected in one of 4 cases:

View file

@ -304,9 +304,7 @@ pub async fn user_exists(state: &AppState, email: &str) -> bool {
return match state
.user_conn()
.read_query_row_f(QUERY, params!(email.to_string()), |row| {
row.get::<_, bool>(0)
})
.read_query_row_get(QUERY, params!(email.to_string()), 0)
.await
{
Ok(Some(row)) => row,
@ -324,9 +322,7 @@ pub(crate) async fn is_admin(state: &AppState, user_id: &uuid::Uuid) -> bool {
return match state
.user_conn()
.read_query_row_f(QUERY, params!(user_id.as_bytes().to_vec()), |row| {
row.get::<_, i64>(0)
})
.read_query_row_get::<i64>(QUERY, params!(user_id.as_bytes().to_vec()), 0)
.await
{
Ok(Some(row)) => row > 0,

View file

@ -434,7 +434,7 @@ mod test {
// Bulk inserts are rolled back in a transaction is second insert fails.
let count_before: i64 = state
.conn()
.read_query_row_f("SELECT COUNT(*) FROM message", (), |row| row.get(0))
.read_query_row_get("SELECT COUNT(*) FROM message", (), 0)
.await
.unwrap()
.unwrap();
@ -460,7 +460,7 @@ mod test {
let count_after: i64 = state
.conn()
.read_query_row_f("SELECT COUNT(*) FROM message", (), |row| row.get(0))
.read_query_row_get("SELECT COUNT(*) FROM message", (), 0)
.await
.unwrap()
.unwrap();

View file

@ -152,10 +152,10 @@ mod test {
async fn message_exists(conn: &trailbase_sqlite::Connection, id: &[u8; 16]) -> bool {
let count: i64 = conn
.read_query_row_f(
.read_query_row_get(
"SELECT COUNT(*) FROM message WHERE mid = $1",
params!(*id),
|row| row.get(0),
0,
)
.await
.unwrap()

View file

@ -297,14 +297,14 @@ mod test {
.unwrap();
let count: i64 = conn
.read_query_row_f(
.read_query_row_get(
format!(r#"SELECT COUNT(*) from "{USER_TABLE}" WHERE email = :email"#),
trailbase_sqlite::named_params! {
":email": EMAIL,
":unused": "unused",
":foo": 42,
},
|row| row.get(0),
0,
)
.await
.unwrap()
@ -1010,9 +1010,7 @@ mod test {
let index: String = state
.conn()
.read_query_row_f(r#"SELECT "index" from "table" WHERE pid = 2"#, (), |row| {
row.get(0)
})
.read_query_row_get(r#"SELECT "index" from "table" WHERE pid = 2"#, (), 0)
.await
.unwrap()
.unwrap();

View file

@ -564,6 +564,10 @@ impl RecordApi {
}
/// Check if the given user (if any) can access a record given the request and the operation.
///
/// QUESTION: Could we structure this in a way that we yield less in case there's no read-access
/// rule, e.g. sync and return (yikes):
/// `Result<Option<Box<dyn Future<Output=Result<(), RecordError>>>>, RecordErrorr>`
#[inline]
pub(crate) async fn check_record_level_read_access_for_subscriptions(
&self,

View file

@ -252,10 +252,10 @@ impl PerConnectionState {
let Some(row_id): Option<i64> = api
.conn()
.read_query_row_f(
.read_query_row_get(
format!(r#"SELECT _rowid_ FROM {table_name} WHERE "{pk_column}" = $1"#),
[record],
|row| row.get(0),
0,
)
.await?
else {

View file

@ -56,12 +56,11 @@ async fn subscribe_to_record_test() {
let conn = state.conn().clone();
let record_id_raw = 0;
let record_id = trailbase_sqlite::Value::Integer(record_id_raw);
let rowid: i64 = conn
.query_row_f(
.query_row_get(
"INSERT INTO test (id, text) VALUES ($1, 'foo') RETURNING _rowid_",
[record_id],
|row| row.get(0),
[trailbase_sqlite::Value::Integer(record_id_raw)],
0,
)
.await
.unwrap()
@ -137,10 +136,7 @@ async fn subscribe_to_record_test() {
}
// Implicitly await for scheduled cleanups to go through.
conn
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
.await
.unwrap();
conn.read_query_row("SELECT 1", ()).await.unwrap();
}
async fn subscribe_to_records(
@ -281,10 +277,7 @@ async fn subscribe_to_table_test() {
}
// Implicitly await for scheduled cleanups to go through.
conn
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
.await
.unwrap();
conn.read_query_row("SELECT 1", ()).await.unwrap();
assert_eq!(0, manager.num_table_subscriptions());
}
@ -297,10 +290,10 @@ async fn subscription_lifecycle_test() {
let record_id_raw = 0;
let record_id = trailbase_sqlite::Value::Integer(record_id_raw);
let rowid: i64 = conn
.query_row_f(
.query_row_get(
"INSERT INTO test (id, text) VALUES ($1, 'foo') RETURNING _rowid_",
[record_id],
|row| row.get(0),
0,
)
.await
.unwrap()
@ -323,10 +316,7 @@ async fn subscription_lifecycle_test() {
drop(sse);
// Implicitly await for the cleanup to be scheduled on the sqlite executor.
conn
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
.await
.unwrap();
conn.read_query_row("SELECT 1", ()).await.unwrap();
assert_eq!(0, manager.num_record_subscriptions());
}
@ -412,13 +402,13 @@ async fn subscription_acl_test() {
let record_id_raw = 0;
let record_id = trailbase_sqlite::Value::Integer(record_id_raw);
let _rowid: i64 = conn
.query_row_f(
.query_row_get(
"INSERT INTO test (id, user, text) VALUES ($1, $2, 'foo') RETURNING _rowid_",
[
record_id.clone(),
trailbase_sqlite::Value::Blob(user_x.to_vec()),
],
|row| row.get(0),
0,
)
.await
.unwrap()
@ -565,10 +555,7 @@ async fn test_acl_selective_table_subs() {
}
// Implicitly await for scheduled cleanups to go through.
conn
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
.await
.unwrap();
conn.read_query_row("SELECT 1", ()).await.unwrap();
assert_eq!(0, manager.num_table_subscriptions());
}
@ -591,13 +578,13 @@ async fn subscription_acl_change_owner() {
let record_id = 0;
let _rowid: i64 = conn
.query_row_f(
.query_row_get(
"INSERT INTO test (id, user, text) VALUES ($1, $2, 'foo') RETURNING _rowid_",
[
trailbase_sqlite::Value::Integer(record_id),
trailbase_sqlite::Value::Blob(user_x_id.into()),
],
|row| row.get(0),
0,
)
.await
.unwrap()
@ -721,10 +708,7 @@ async fn subscription_filter_test() {
}
// Implicitly await for scheduled cleanups to go through.
conn
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
.await
.unwrap();
conn.read_query_row("SELECT 1", ()).await.unwrap();
assert_eq!(0, manager.num_table_subscriptions());
}

View file

@ -159,10 +159,10 @@ mod tests {
name: &str,
) -> Result<[u8; 16], anyhow::Error> {
let room: [u8; 16] = conn
.query_row_f(
.query_row_get(
"INSERT INTO room (name) VALUES ($1) RETURNING rid",
params!(name.to_string()),
|row| row.get(0),
0,
)
.await?
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
@ -191,10 +191,10 @@ mod tests {
message: &str,
) -> Result<[u8; 16], anyhow::Error> {
let id: [u8; 16] = conn
.query_row_f(
.query_row_get(
"INSERT INTO message (_owner, room, data) VALUES ($1, $2, $3) RETURNING mid",
params!(user, room, message.to_string()),
|row| row.get(0),
0,
)
.await?
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;

View file

@ -36,13 +36,13 @@ pub async fn lookup_and_parse_table_schema(
) -> Result<Table, SchemaLookupError> {
// Then get the actual table.
let sql: String = conn
.read_query_row_f(
.read_query_row_get(
format!(
"SELECT sql FROM {db}.{SQLITE_SCHEMA_TABLE} WHERE type = 'table' AND name = $1",
db = database.unwrap_or("main")
),
params!(table_name.to_string()),
|row| row.get(0),
0,
)
.await?
.ok_or_else(|| trailbase_sqlite::Error::Rusqlite(rusqlite::Error::QueryReturnedNoRows))?;

View file

@ -123,10 +123,10 @@ pub async fn init_app_state(args: InitArgs) -> Result<(bool, AppState), InitErro
if new_db {
let num_admins: i64 = app_state
.user_conn()
.read_query_row_f(
.read_query_row_get(
format!("SELECT COUNT(*) FROM {USER_TABLE} WHERE admin = TRUE"),
(),
|row| row.get(0),
0,
)
.await?
.unwrap_or(0);

View file

@ -242,7 +242,7 @@ mod tests {
#[tokio::test]
async fn test_transaction_log() {
let mut conn = rusqlite::Connection::open_in_memory().unwrap();
let conn = trailbase_sqlite::Connection::open_in_memory().unwrap();
conn
.execute_batch(
r#"
@ -255,35 +255,45 @@ mod tests {
INSERT INTO 'table' (id, name, age) VALUES (0, 'Alice', 21), (1, 'Bob', 18);
"#,
)
.await
.unwrap();
// Just double checking that rusqlite's query and execute ignore everything but the first
// statement.
let result = conn.query_row(
r#"
let result = conn
.query_row_get::<String>(
r#"
SELECT name FROM 'table' WHERE id = 0;
SELECT name FROM 'table' WHERE id = 1;
DROP TABLE 'table';
"#,
(),
|row| row.get::<_, String>(0),
);
assert!(matches!(result, Err(rusqlite::Error::MultipleStatement)));
(),
0,
)
.await;
assert!(matches!(
result,
Err(trailbase_sqlite::Error::Rusqlite(
rusqlite::Error::MultipleStatement
))
));
let mut recorder = TransactionRecorder::new(&mut conn).unwrap();
let log = {
let mut lock = conn.write_lock();
let mut recorder = TransactionRecorder::new(&mut lock).unwrap();
recorder
.execute("DELETE FROM 'table' WHERE age < ?1", rusqlite::params!(20))
.unwrap();
let log = recorder.rollback().unwrap().unwrap();
recorder
.execute("DELETE FROM 'table' WHERE age < ?1", rusqlite::params!(20))
.unwrap();
recorder.rollback().unwrap().unwrap()
};
assert_eq!(log.log.len(), 1);
assert_eq!(log.log[0].0, QueryType::Execute);
assert_eq!(log.log[0].1, "DELETE FROM 'table' WHERE age < 20");
let conn = trailbase_sqlite::Connection::from_connection_test_only(conn);
let count: i64 = conn
.query_row_f("SELECT COUNT(*) FROM 'table'", (), |row| row.get(0))
.read_query_row_get("SELECT COUNT(*) FROM 'table'", (), 0)
.await
.unwrap()
.unwrap();
@ -292,7 +302,7 @@ mod tests {
log.commit(&conn).await.unwrap();
let count: i64 = conn
.query_row_f("SELECT COUNT(*) FROM 'table'", (), |row| row.get(0))
.read_query_row_get("SELECT COUNT(*) FROM 'table'", (), 0)
.await
.unwrap()
.unwrap();

View file

@ -2,6 +2,7 @@ use axum::extract::{Json, State};
use axum::http::StatusCode;
use axum_test::TestServer;
use axum_test::multipart::MultipartForm;
use serde::Deserialize;
use std::sync::Arc;
use tower_cookies::Cookie;
use trailbase_sqlite::params;
@ -237,29 +238,33 @@ async fn test_record_apis() {
}
let logs_count: i64 = logs_conn
.read_query_row_f("SELECT COUNT(*) FROM _logs", (), |row| row.get(0))
.read_query_row_get("SELECT COUNT(*) FROM _logs", (), 0)
.await
.unwrap()
.unwrap();
assert!(logs_count > 0);
let (fetched_ip, latency, status): (String, f64, i64) = logs_conn
.read_query_row_f(
#[derive(Deserialize)]
struct Log {
client_ip: String,
latency: f64,
status: i64,
}
let got: Log = logs_conn
.read_query_value(
"SELECT client_ip, latency, status FROM _logs WHERE client_ip = $1",
trailbase_sqlite::params!(client_ip),
|row| -> Result<_, rusqlite::Error> {
return Ok((row.get(0)?, row.get(1)?, row.get(2)?));
},
)
.await
.unwrap()
.unwrap();
// We're also testing stiching here, since client_ip is recorded on_request and latency/status
// We're also testing stitching here, since client_ip is recorded on_request and latency/status
// on_response.
assert_eq!(fetched_ip, client_ip);
assert!(latency > 0.0);
assert_eq!(status, 200);
assert_eq!(got.client_ip, client_ip);
assert!(got.latency > 0.0);
assert_eq!(got.status, 200);
}
async fn create_chat_message_app_tables(
@ -305,10 +310,10 @@ async fn add_room(
name: &str,
) -> Result<[u8; 16], anyhow::Error> {
let room: [u8; 16] = conn
.query_row_f(
.query_row_get(
"INSERT INTO room (name) VALUES ($1) RETURNING id",
params!(name.to_string()),
|row| row.get::<_, [u8; 16]>(0),
0,
)
.await?
.unwrap();

View file

@ -33,7 +33,7 @@ impl AsyncConnection for Connection {
) -> Result<T, BenchmarkError> {
return Ok(
self
.query_row_f(sql.into(), params.into(), |row| row.get::<_, T>(0))
.query_row_get(sql.into(), params.into(), 0)
.await?
.unwrap(),
);
@ -46,7 +46,7 @@ impl AsyncConnection for Connection {
) -> Result<T, BenchmarkError> {
return Ok(
self
.read_query_row_f(sql.into(), params.into(), |row| row.get::<_, T>(0))
.read_query_row_get(sql.into(), params.into(), 0)
.await?
.unwrap(),
);

View file

@ -151,22 +151,6 @@ impl Connection {
});
}
pub fn from_connection_test_only(conn: rusqlite::Connection) -> Self {
use parking_lot::lock_api::RwLock;
let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
let conns = Arc::new(RwLock::new(ConnectionVec(vec![conn])));
let conns_clone = conns.clone();
std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver));
return Self {
id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst),
reader: shared_write_sender.clone(),
writer: shared_write_sender,
conns,
};
}
/// Open a new connection to an in-memory SQLite database.
///
/// # Failure
@ -237,13 +221,6 @@ impl Connection {
receiver.await.map_err(|_| Error::ConnectionClosed)?
}
#[inline]
pub fn call_and_forget(&self, function: impl FnOnce(&rusqlite::Connection) + Send + 'static) {
let _ = self
.writer
.send(Message::RunMut(Box::new(move |conn| function(conn))));
}
#[inline]
pub async fn call_reader<F, R>(&self, function: F) -> Result<R, Error>
where
@ -264,16 +241,6 @@ impl Connection {
receiver.await.map_err(|_| Error::ConnectionClosed)?
}
#[inline]
pub fn call_reader_and_forget(
&self,
function: impl FnOnce(&rusqlite::Connection) + Send + 'static,
) {
let _ = self
.writer
.send(Message::RunConst(Box::new(move |conn| function(conn))));
}
/// Query SQL statement.
pub async fn read_query_rows(
&self,
@ -287,7 +254,7 @@ impl Connection {
params.bind(&mut stmt)?;
let rows = stmt.raw_query();
Ok(crate::rows::from_rows(rows)?)
return crate::rows::from_rows(rows);
})
.await;
}
@ -303,7 +270,7 @@ impl Connection {
params.bind(&mut stmt)?;
let rows = stmt.raw_query();
Ok(crate::rows::from_rows(rows)?)
return crate::rows::from_rows(rows);
})
.await;
}
@ -321,7 +288,7 @@ impl Connection {
}
#[inline]
pub async fn query_row_f<T, E: Into<Error>>(
async fn query_row_f<T, E>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
@ -347,7 +314,28 @@ impl Connection {
}
#[inline]
pub async fn read_query_row_f<T, E>(
pub async fn query_row_get<T>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
index: usize,
) -> Result<Option<T>, Error>
where
T: rusqlite::types::FromSql + Send + 'static,
{
return self
.query_row_f(
sql,
params,
move |row: &rusqlite::Row<'_>| -> Result<T, Error> {
return Ok(row.get(index)?);
},
)
.await;
}
#[inline]
async fn read_query_row_f<T, E>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
@ -374,6 +362,27 @@ impl Connection {
.await;
}
#[inline]
pub async fn read_query_row_get<T>(
&self,
sql: impl AsRef<str> + Send + 'static,
params: impl Params + Send + 'static,
index: usize,
) -> Result<Option<T>, Error>
where
T: rusqlite::types::FromSql + Send + 'static,
{
return self
.read_query_row_f(
sql,
params,
move |row: &rusqlite::Row<'_>| -> Result<T, Error> {
return Ok(row.get(index)?);
},
)
.await;
}
pub async fn read_query_value<T: serde::de::DeserializeOwned + Send + 'static>(
&self,
sql: impl AsRef<str> + Send + 'static,

View file

@ -316,9 +316,7 @@ async fn test_execute_batch() {
let count = async |table: &str| -> i64 {
return conn
.query_row_f(format!("SELECT COUNT(*) FROM {table}"), (), |row| {
row.get(0)
})
.query_row_get(format!("SELECT COUNT(*) FROM {table}"), (), 0)
.await
.unwrap()
.unwrap();
@ -443,7 +441,7 @@ async fn test_params() {
.unwrap();
let count: i64 = conn
.read_query_row_f("SELECT COUNT(*) FROM person", (), |row| row.get(0))
.read_query_row_get("SELECT COUNT(*) FROM person", (), 0)
.await
.unwrap()
.unwrap();
@ -466,44 +464,47 @@ async fn test_hooks() {
row_id: i64,
}
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<String>();
let c = conn.clone();
let (sender, receiver) = kanal::unbounded::<String>();
conn
.write_lock()
.preupdate_hook(Some(
move |action: rusqlite::hooks::Action, _db: &str, table_name: &str, case: &PreUpdateCase| {
let row_id = extract_row_id(case).unwrap();
let state = State {
action,
table_name: table_name.to_string(),
row_id,
};
let sender = sender.clone();
c.call_and_forget(move |conn| {
match state.action {
rusqlite::hooks::Action::SQLITE_INSERT => {
let text = conn
.query_row(
&format!(
r#"SELECT text FROM "{}" WHERE _rowid_ = $1"#,
state.table_name
),
[state.row_id],
|row| row.get::<_, String>(0),
)
.unwrap();
sender.send(text).unwrap();
}
_ => {
panic!("unexpected action: {:?}", state.action);
}
.preupdate_hook({
let conn = conn.clone();
Some(
move |action: rusqlite::hooks::Action,
_db: &str,
table_name: &str,
case: &PreUpdateCase| {
let row_id = extract_row_id(case).unwrap();
let state = State {
action,
table_name: table_name.to_string(),
row_id,
};
});
},
))
if state.action != rusqlite::hooks::Action::SQLITE_INSERT {
panic!("unexpected action: {:?}", state.action);
}
let sender = sender.clone();
let conn = conn.clone();
// We can't lock here since the lock is held by the `execute` below triggering
// the hook. Thus delay the query until after the `execute` completes.
std::thread::spawn(move || {
let query = format!("SELECT text FROM '{}' WHERE _rowid_ = $1", state.table_name);
sender
.send(
conn
.write_lock()
.query_row(&query, (state.row_id,), |row| row.get(0))
.unwrap(),
)
.unwrap();
});
},
)
})
.unwrap();
conn
@ -511,7 +512,7 @@ async fn test_hooks() {
.await
.unwrap();
let text = receiver.recv().await.unwrap();
let text = receiver.recv().unwrap();
assert_eq!(text, "foo");
}

View file

@ -552,7 +552,7 @@ mod tests {
assert_eq!(
1,
conn
.query_row_f("SELECT COUNT(*) FROM tx;", (), |row| row.get::<_, i64>(0))
.query_row_get::<i64>("SELECT COUNT(*) FROM tx;", (), 0)
.await
.unwrap()
.unwrap()
@ -585,10 +585,9 @@ mod tests {
#[tokio::test]
async fn test_custom_sqlite_function() {
let conn = rusqlite::Connection::open_in_memory().unwrap();
let _sqlite_function_runtime = init_sqlite_function_runtime(&conn).await;
let conn = trailbase_sqlite::Connection::open_in_memory().unwrap();
let conn = trailbase_sqlite::Connection::from_connection_test_only(conn);
let _sqlite_function_runtime = init_sqlite_function_runtime(&conn.write_lock()).await;
let runtime = init_runtime(Some(conn.clone()));
{