mirror of
https://github.com/trailbaseio/trailbase
synced 2026-04-21 13:37:44 +00:00
Further reduce leaky trailbase_sqlite::Connection API surface.
This commit is contained in:
parent
23b659b8c6
commit
59ec41a260
26 changed files with 195 additions and 195 deletions
|
|
@ -60,10 +60,10 @@ async fn add_room(
|
|||
name: &str,
|
||||
) -> Result<[u8; 16], anyhow::Error> {
|
||||
let room: [u8; 16] = conn
|
||||
.query_row_f(
|
||||
.query_row_get(
|
||||
"INSERT INTO room (name) VALUES ($1) RETURNING id",
|
||||
params!(name.to_string()),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -59,13 +59,13 @@ pub async fn list_logs_handler(
|
|||
};
|
||||
|
||||
let total_row_count: i64 = conn
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!(
|
||||
"SELECT COUNT(*) FROM {LOGS_TABLE} AS {TABLE_ALIAS} WHERE {where_clause}",
|
||||
where_clause = filter_where_clause.clause
|
||||
),
|
||||
filter_where_clause.params.clone(),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
.unwrap_or(-1);
|
||||
|
|
|
|||
|
|
@ -216,11 +216,9 @@ mod tests {
|
|||
Uuid::from_slice(&row.myid).unwrap()
|
||||
};
|
||||
|
||||
let count = || async {
|
||||
let count = async || -> i64 {
|
||||
conn
|
||||
.read_query_row_f(format!("SELECT COUNT(*) FROM '{table_name}'"), (), |row| {
|
||||
row.get::<_, i64>(0)
|
||||
})
|
||||
.read_query_row_get(format!("SELECT COUNT(*) FROM '{table_name}'"), (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
|
|
|
|||
|
|
@ -81,9 +81,7 @@ pub async fn list_rows_handler(
|
|||
table = qualified_name.escaped_string()
|
||||
);
|
||||
conn
|
||||
.read_query_row_f(count_query, filter_where_clause.params.clone(), |row| {
|
||||
row.get(0)
|
||||
})
|
||||
.read_query_row_get(count_query, filter_where_clause.params.clone(), 0)
|
||||
.await?
|
||||
.unwrap_or(-1)
|
||||
};
|
||||
|
|
@ -295,7 +293,7 @@ mod tests {
|
|||
.unwrap();
|
||||
|
||||
let cnt: i64 = conn
|
||||
.read_query_row_f("SELECT COUNT(*) FROM test_table", (), |row| row.get(0))
|
||||
.read_query_row_get("SELECT COUNT(*) FROM test_table", (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -81,13 +81,13 @@ pub async fn list_users_handler(
|
|||
build_filter_where_clause("_ROW_", &table_metadata.column_metadata, filter_params)?;
|
||||
|
||||
let total_row_count: i64 = conn
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!(
|
||||
"SELECT COUNT(*) FROM {USER_TABLE} AS _ROW_ WHERE {where_clause}",
|
||||
where_clause = filter_where_clause.clause
|
||||
),
|
||||
filter_where_clause.params.clone(),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
.unwrap_or(-1);
|
||||
|
|
|
|||
|
|
@ -222,10 +222,10 @@ pub async fn login_otp_handler(
|
|||
|
||||
let Some(user_id) = state
|
||||
.session_conn()
|
||||
.query_row_f(
|
||||
.read_query_row_get(
|
||||
LOOKUP_OTP_QUERY,
|
||||
params!(normalized_email, otp_code.to_string()),
|
||||
|row| row.get::<_, [u8; 16]>(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
|
|
|
|||
|
|
@ -67,10 +67,10 @@ pub(crate) async fn auth_code_to_token_handler(
|
|||
|
||||
let Some(user_id) = state
|
||||
.session_conn()
|
||||
.query_row_f(
|
||||
.read_query_row_get(
|
||||
AUTH_CODE_QUERY,
|
||||
params!(authorization_code, pkce_code_challenge),
|
||||
|row| row.get::<_, [u8; 16]>(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
|
|
|
|||
|
|
@ -372,10 +372,10 @@ async fn test_auth_password_login_flow_without_pkce() {
|
|||
|
||||
let refresh_token: String = state
|
||||
.session_conn()
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!("SELECT refresh_token FROM {SESSION_TABLE} WHERE user = $1;"),
|
||||
(user.uuid.into_bytes().to_vec(),),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
@ -613,10 +613,10 @@ async fn test_auth_change_email_flow() {
|
|||
|
||||
let db_email: String = state
|
||||
.user_conn()
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!(r#"SELECT email FROM "{USER_TABLE}" WHERE id = $1"#),
|
||||
params!(user.uuid.into_bytes()),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
@ -692,10 +692,10 @@ async fn test_auth_delete_user_flow() {
|
|||
|
||||
let user_exists: bool = state
|
||||
.user_conn()
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!(r#"SELECT EXISTS(SELECT * FROM "{USER_TABLE}" WHERE id = $1)"#),
|
||||
params!(user.uuid.into_bytes()),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
@ -801,10 +801,10 @@ async fn test_auth_otp_flow() {
|
|||
async fn session_exists(state: &AppState, user_id: Uuid) -> bool {
|
||||
return state
|
||||
.session_conn()
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!("SELECT EXISTS(SELECT 1 FROM {SESSION_TABLE} WHERE user = $1)"),
|
||||
params!(user_id.into_bytes().to_vec()),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
|
|||
|
|
@ -386,10 +386,10 @@ fn get_redirect_location<T: IntoResponse>(response: T) -> Option<String> {
|
|||
async fn session_exists(state: &AppState, user_id: Uuid) -> bool {
|
||||
return state
|
||||
.session_conn()
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!("SELECT EXISTS(SELECT 1 FROM {SESSION_TABLE} WHERE user = $1)"),
|
||||
(user_id.into_bytes().to_vec(),),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
|
|||
|
|
@ -219,9 +219,7 @@ pub(crate) async fn reauth_with_refresh_token(
|
|||
|
||||
let Some(user_id) = state
|
||||
.session_conn()
|
||||
.query_row_f(SESSION_QUERY, params!(refresh_token), |row| {
|
||||
row.get::<_, [u8; 16]>(0)
|
||||
})
|
||||
.read_query_row_get::<[u8; 16]>(SESSION_QUERY, params!(refresh_token), 0)
|
||||
.await?
|
||||
else {
|
||||
// Row not found case, typically expected in one of 4 cases:
|
||||
|
|
|
|||
|
|
@ -304,9 +304,7 @@ pub async fn user_exists(state: &AppState, email: &str) -> bool {
|
|||
|
||||
return match state
|
||||
.user_conn()
|
||||
.read_query_row_f(QUERY, params!(email.to_string()), |row| {
|
||||
row.get::<_, bool>(0)
|
||||
})
|
||||
.read_query_row_get(QUERY, params!(email.to_string()), 0)
|
||||
.await
|
||||
{
|
||||
Ok(Some(row)) => row,
|
||||
|
|
@ -324,9 +322,7 @@ pub(crate) async fn is_admin(state: &AppState, user_id: &uuid::Uuid) -> bool {
|
|||
|
||||
return match state
|
||||
.user_conn()
|
||||
.read_query_row_f(QUERY, params!(user_id.as_bytes().to_vec()), |row| {
|
||||
row.get::<_, i64>(0)
|
||||
})
|
||||
.read_query_row_get::<i64>(QUERY, params!(user_id.as_bytes().to_vec()), 0)
|
||||
.await
|
||||
{
|
||||
Ok(Some(row)) => row > 0,
|
||||
|
|
|
|||
|
|
@ -434,7 +434,7 @@ mod test {
|
|||
// Bulk inserts are rolled back in a transaction is second insert fails.
|
||||
let count_before: i64 = state
|
||||
.conn()
|
||||
.read_query_row_f("SELECT COUNT(*) FROM message", (), |row| row.get(0))
|
||||
.read_query_row_get("SELECT COUNT(*) FROM message", (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
|
@ -460,7 +460,7 @@ mod test {
|
|||
|
||||
let count_after: i64 = state
|
||||
.conn()
|
||||
.read_query_row_f("SELECT COUNT(*) FROM message", (), |row| row.get(0))
|
||||
.read_query_row_get("SELECT COUNT(*) FROM message", (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -152,10 +152,10 @@ mod test {
|
|||
|
||||
async fn message_exists(conn: &trailbase_sqlite::Connection, id: &[u8; 16]) -> bool {
|
||||
let count: i64 = conn
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
"SELECT COUNT(*) FROM message WHERE mid = $1",
|
||||
params!(*id),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
|
|||
|
|
@ -297,14 +297,14 @@ mod test {
|
|||
.unwrap();
|
||||
|
||||
let count: i64 = conn
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!(r#"SELECT COUNT(*) from "{USER_TABLE}" WHERE email = :email"#),
|
||||
trailbase_sqlite::named_params! {
|
||||
":email": EMAIL,
|
||||
":unused": "unused",
|
||||
":foo": 42,
|
||||
},
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
@ -1010,9 +1010,7 @@ mod test {
|
|||
|
||||
let index: String = state
|
||||
.conn()
|
||||
.read_query_row_f(r#"SELECT "index" from "table" WHERE pid = 2"#, (), |row| {
|
||||
row.get(0)
|
||||
})
|
||||
.read_query_row_get(r#"SELECT "index" from "table" WHERE pid = 2"#, (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -564,6 +564,10 @@ impl RecordApi {
|
|||
}
|
||||
|
||||
/// Check if the given user (if any) can access a record given the request and the operation.
|
||||
///
|
||||
/// QUESTION: Could we structure this in a way that we yield less in case there's no read-access
|
||||
/// rule, e.g. sync and return (yikes):
|
||||
/// `Result<Option<Box<dyn Future<Output=Result<(), RecordError>>>>, RecordErrorr>`
|
||||
#[inline]
|
||||
pub(crate) async fn check_record_level_read_access_for_subscriptions(
|
||||
&self,
|
||||
|
|
|
|||
|
|
@ -252,10 +252,10 @@ impl PerConnectionState {
|
|||
|
||||
let Some(row_id): Option<i64> = api
|
||||
.conn()
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!(r#"SELECT _rowid_ FROM {table_name} WHERE "{pk_column}" = $1"#),
|
||||
[record],
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
|
|
|
|||
|
|
@ -56,12 +56,11 @@ async fn subscribe_to_record_test() {
|
|||
let conn = state.conn().clone();
|
||||
|
||||
let record_id_raw = 0;
|
||||
let record_id = trailbase_sqlite::Value::Integer(record_id_raw);
|
||||
let rowid: i64 = conn
|
||||
.query_row_f(
|
||||
.query_row_get(
|
||||
"INSERT INTO test (id, text) VALUES ($1, 'foo') RETURNING _rowid_",
|
||||
[record_id],
|
||||
|row| row.get(0),
|
||||
[trailbase_sqlite::Value::Integer(record_id_raw)],
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
@ -137,10 +136,7 @@ async fn subscribe_to_record_test() {
|
|||
}
|
||||
|
||||
// Implicitly await for scheduled cleanups to go through.
|
||||
conn
|
||||
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
|
||||
.await
|
||||
.unwrap();
|
||||
conn.read_query_row("SELECT 1", ()).await.unwrap();
|
||||
}
|
||||
|
||||
async fn subscribe_to_records(
|
||||
|
|
@ -281,10 +277,7 @@ async fn subscribe_to_table_test() {
|
|||
}
|
||||
|
||||
// Implicitly await for scheduled cleanups to go through.
|
||||
conn
|
||||
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
|
||||
.await
|
||||
.unwrap();
|
||||
conn.read_query_row("SELECT 1", ()).await.unwrap();
|
||||
|
||||
assert_eq!(0, manager.num_table_subscriptions());
|
||||
}
|
||||
|
|
@ -297,10 +290,10 @@ async fn subscription_lifecycle_test() {
|
|||
let record_id_raw = 0;
|
||||
let record_id = trailbase_sqlite::Value::Integer(record_id_raw);
|
||||
let rowid: i64 = conn
|
||||
.query_row_f(
|
||||
.query_row_get(
|
||||
"INSERT INTO test (id, text) VALUES ($1, 'foo') RETURNING _rowid_",
|
||||
[record_id],
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
@ -323,10 +316,7 @@ async fn subscription_lifecycle_test() {
|
|||
drop(sse);
|
||||
|
||||
// Implicitly await for the cleanup to be scheduled on the sqlite executor.
|
||||
conn
|
||||
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
|
||||
.await
|
||||
.unwrap();
|
||||
conn.read_query_row("SELECT 1", ()).await.unwrap();
|
||||
|
||||
assert_eq!(0, manager.num_record_subscriptions());
|
||||
}
|
||||
|
|
@ -412,13 +402,13 @@ async fn subscription_acl_test() {
|
|||
let record_id_raw = 0;
|
||||
let record_id = trailbase_sqlite::Value::Integer(record_id_raw);
|
||||
let _rowid: i64 = conn
|
||||
.query_row_f(
|
||||
.query_row_get(
|
||||
"INSERT INTO test (id, user, text) VALUES ($1, $2, 'foo') RETURNING _rowid_",
|
||||
[
|
||||
record_id.clone(),
|
||||
trailbase_sqlite::Value::Blob(user_x.to_vec()),
|
||||
],
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
@ -565,10 +555,7 @@ async fn test_acl_selective_table_subs() {
|
|||
}
|
||||
|
||||
// Implicitly await for scheduled cleanups to go through.
|
||||
conn
|
||||
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
|
||||
.await
|
||||
.unwrap();
|
||||
conn.read_query_row("SELECT 1", ()).await.unwrap();
|
||||
|
||||
assert_eq!(0, manager.num_table_subscriptions());
|
||||
}
|
||||
|
|
@ -591,13 +578,13 @@ async fn subscription_acl_change_owner() {
|
|||
|
||||
let record_id = 0;
|
||||
let _rowid: i64 = conn
|
||||
.query_row_f(
|
||||
.query_row_get(
|
||||
"INSERT INTO test (id, user, text) VALUES ($1, $2, 'foo') RETURNING _rowid_",
|
||||
[
|
||||
trailbase_sqlite::Value::Integer(record_id),
|
||||
trailbase_sqlite::Value::Blob(user_x_id.into()),
|
||||
],
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
|
|
@ -721,10 +708,7 @@ async fn subscription_filter_test() {
|
|||
}
|
||||
|
||||
// Implicitly await for scheduled cleanups to go through.
|
||||
conn
|
||||
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
|
||||
.await
|
||||
.unwrap();
|
||||
conn.read_query_row("SELECT 1", ()).await.unwrap();
|
||||
|
||||
assert_eq!(0, manager.num_table_subscriptions());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -159,10 +159,10 @@ mod tests {
|
|||
name: &str,
|
||||
) -> Result<[u8; 16], anyhow::Error> {
|
||||
let room: [u8; 16] = conn
|
||||
.query_row_f(
|
||||
.query_row_get(
|
||||
"INSERT INTO room (name) VALUES ($1) RETURNING rid",
|
||||
params!(name.to_string()),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
|
||||
|
|
@ -191,10 +191,10 @@ mod tests {
|
|||
message: &str,
|
||||
) -> Result<[u8; 16], anyhow::Error> {
|
||||
let id: [u8; 16] = conn
|
||||
.query_row_f(
|
||||
.query_row_get(
|
||||
"INSERT INTO message (_owner, room, data) VALUES ($1, $2, $3) RETURNING mid",
|
||||
params!(user, room, message.to_string()),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
.ok_or(rusqlite::Error::QueryReturnedNoRows)?;
|
||||
|
|
|
|||
|
|
@ -36,13 +36,13 @@ pub async fn lookup_and_parse_table_schema(
|
|||
) -> Result<Table, SchemaLookupError> {
|
||||
// Then get the actual table.
|
||||
let sql: String = conn
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!(
|
||||
"SELECT sql FROM {db}.{SQLITE_SCHEMA_TABLE} WHERE type = 'table' AND name = $1",
|
||||
db = database.unwrap_or("main")
|
||||
),
|
||||
params!(table_name.to_string()),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
.ok_or_else(|| trailbase_sqlite::Error::Rusqlite(rusqlite::Error::QueryReturnedNoRows))?;
|
||||
|
|
|
|||
|
|
@ -123,10 +123,10 @@ pub async fn init_app_state(args: InitArgs) -> Result<(bool, AppState), InitErro
|
|||
if new_db {
|
||||
let num_admins: i64 = app_state
|
||||
.user_conn()
|
||||
.read_query_row_f(
|
||||
.read_query_row_get(
|
||||
format!("SELECT COUNT(*) FROM {USER_TABLE} WHERE admin = TRUE"),
|
||||
(),
|
||||
|row| row.get(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
.unwrap_or(0);
|
||||
|
|
|
|||
|
|
@ -242,7 +242,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_transaction_log() {
|
||||
let mut conn = rusqlite::Connection::open_in_memory().unwrap();
|
||||
let conn = trailbase_sqlite::Connection::open_in_memory().unwrap();
|
||||
conn
|
||||
.execute_batch(
|
||||
r#"
|
||||
|
|
@ -255,35 +255,45 @@ mod tests {
|
|||
INSERT INTO 'table' (id, name, age) VALUES (0, 'Alice', 21), (1, 'Bob', 18);
|
||||
"#,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Just double checking that rusqlite's query and execute ignore everything but the first
|
||||
// statement.
|
||||
let result = conn.query_row(
|
||||
r#"
|
||||
let result = conn
|
||||
.query_row_get::<String>(
|
||||
r#"
|
||||
SELECT name FROM 'table' WHERE id = 0;
|
||||
SELECT name FROM 'table' WHERE id = 1;
|
||||
DROP TABLE 'table';
|
||||
"#,
|
||||
(),
|
||||
|row| row.get::<_, String>(0),
|
||||
);
|
||||
assert!(matches!(result, Err(rusqlite::Error::MultipleStatement)));
|
||||
(),
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(trailbase_sqlite::Error::Rusqlite(
|
||||
rusqlite::Error::MultipleStatement
|
||||
))
|
||||
));
|
||||
|
||||
let mut recorder = TransactionRecorder::new(&mut conn).unwrap();
|
||||
let log = {
|
||||
let mut lock = conn.write_lock();
|
||||
let mut recorder = TransactionRecorder::new(&mut lock).unwrap();
|
||||
|
||||
recorder
|
||||
.execute("DELETE FROM 'table' WHERE age < ?1", rusqlite::params!(20))
|
||||
.unwrap();
|
||||
let log = recorder.rollback().unwrap().unwrap();
|
||||
recorder
|
||||
.execute("DELETE FROM 'table' WHERE age < ?1", rusqlite::params!(20))
|
||||
.unwrap();
|
||||
recorder.rollback().unwrap().unwrap()
|
||||
};
|
||||
|
||||
assert_eq!(log.log.len(), 1);
|
||||
assert_eq!(log.log[0].0, QueryType::Execute);
|
||||
assert_eq!(log.log[0].1, "DELETE FROM 'table' WHERE age < 20");
|
||||
|
||||
let conn = trailbase_sqlite::Connection::from_connection_test_only(conn);
|
||||
let count: i64 = conn
|
||||
.query_row_f("SELECT COUNT(*) FROM 'table'", (), |row| row.get(0))
|
||||
.read_query_row_get("SELECT COUNT(*) FROM 'table'", (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
|
@ -292,7 +302,7 @@ mod tests {
|
|||
log.commit(&conn).await.unwrap();
|
||||
|
||||
let count: i64 = conn
|
||||
.query_row_f("SELECT COUNT(*) FROM 'table'", (), |row| row.get(0))
|
||||
.read_query_row_get("SELECT COUNT(*) FROM 'table'", (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use axum::extract::{Json, State};
|
|||
use axum::http::StatusCode;
|
||||
use axum_test::TestServer;
|
||||
use axum_test::multipart::MultipartForm;
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use tower_cookies::Cookie;
|
||||
use trailbase_sqlite::params;
|
||||
|
|
@ -237,29 +238,33 @@ async fn test_record_apis() {
|
|||
}
|
||||
|
||||
let logs_count: i64 = logs_conn
|
||||
.read_query_row_f("SELECT COUNT(*) FROM _logs", (), |row| row.get(0))
|
||||
.read_query_row_get("SELECT COUNT(*) FROM _logs", (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(logs_count > 0);
|
||||
|
||||
let (fetched_ip, latency, status): (String, f64, i64) = logs_conn
|
||||
.read_query_row_f(
|
||||
#[derive(Deserialize)]
|
||||
struct Log {
|
||||
client_ip: String,
|
||||
latency: f64,
|
||||
status: i64,
|
||||
}
|
||||
|
||||
let got: Log = logs_conn
|
||||
.read_query_value(
|
||||
"SELECT client_ip, latency, status FROM _logs WHERE client_ip = $1",
|
||||
trailbase_sqlite::params!(client_ip),
|
||||
|row| -> Result<_, rusqlite::Error> {
|
||||
return Ok((row.get(0)?, row.get(1)?, row.get(2)?));
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
// We're also testing stiching here, since client_ip is recorded on_request and latency/status
|
||||
// We're also testing stitching here, since client_ip is recorded on_request and latency/status
|
||||
// on_response.
|
||||
assert_eq!(fetched_ip, client_ip);
|
||||
assert!(latency > 0.0);
|
||||
assert_eq!(status, 200);
|
||||
assert_eq!(got.client_ip, client_ip);
|
||||
assert!(got.latency > 0.0);
|
||||
assert_eq!(got.status, 200);
|
||||
}
|
||||
|
||||
async fn create_chat_message_app_tables(
|
||||
|
|
@ -305,10 +310,10 @@ async fn add_room(
|
|||
name: &str,
|
||||
) -> Result<[u8; 16], anyhow::Error> {
|
||||
let room: [u8; 16] = conn
|
||||
.query_row_f(
|
||||
.query_row_get(
|
||||
"INSERT INTO room (name) VALUES ($1) RETURNING id",
|
||||
params!(name.to_string()),
|
||||
|row| row.get::<_, [u8; 16]>(0),
|
||||
0,
|
||||
)
|
||||
.await?
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ impl AsyncConnection for Connection {
|
|||
) -> Result<T, BenchmarkError> {
|
||||
return Ok(
|
||||
self
|
||||
.query_row_f(sql.into(), params.into(), |row| row.get::<_, T>(0))
|
||||
.query_row_get(sql.into(), params.into(), 0)
|
||||
.await?
|
||||
.unwrap(),
|
||||
);
|
||||
|
|
@ -46,7 +46,7 @@ impl AsyncConnection for Connection {
|
|||
) -> Result<T, BenchmarkError> {
|
||||
return Ok(
|
||||
self
|
||||
.read_query_row_f(sql.into(), params.into(), |row| row.get::<_, T>(0))
|
||||
.read_query_row_get(sql.into(), params.into(), 0)
|
||||
.await?
|
||||
.unwrap(),
|
||||
);
|
||||
|
|
|
|||
|
|
@ -151,22 +151,6 @@ impl Connection {
|
|||
});
|
||||
}
|
||||
|
||||
pub fn from_connection_test_only(conn: rusqlite::Connection) -> Self {
|
||||
use parking_lot::lock_api::RwLock;
|
||||
|
||||
let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
|
||||
let conns = Arc::new(RwLock::new(ConnectionVec(vec![conn])));
|
||||
let conns_clone = conns.clone();
|
||||
std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver));
|
||||
|
||||
return Self {
|
||||
id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst),
|
||||
reader: shared_write_sender.clone(),
|
||||
writer: shared_write_sender,
|
||||
conns,
|
||||
};
|
||||
}
|
||||
|
||||
/// Open a new connection to an in-memory SQLite database.
|
||||
///
|
||||
/// # Failure
|
||||
|
|
@ -237,13 +221,6 @@ impl Connection {
|
|||
receiver.await.map_err(|_| Error::ConnectionClosed)?
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn call_and_forget(&self, function: impl FnOnce(&rusqlite::Connection) + Send + 'static) {
|
||||
let _ = self
|
||||
.writer
|
||||
.send(Message::RunMut(Box::new(move |conn| function(conn))));
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn call_reader<F, R>(&self, function: F) -> Result<R, Error>
|
||||
where
|
||||
|
|
@ -264,16 +241,6 @@ impl Connection {
|
|||
receiver.await.map_err(|_| Error::ConnectionClosed)?
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn call_reader_and_forget(
|
||||
&self,
|
||||
function: impl FnOnce(&rusqlite::Connection) + Send + 'static,
|
||||
) {
|
||||
let _ = self
|
||||
.writer
|
||||
.send(Message::RunConst(Box::new(move |conn| function(conn))));
|
||||
}
|
||||
|
||||
/// Query SQL statement.
|
||||
pub async fn read_query_rows(
|
||||
&self,
|
||||
|
|
@ -287,7 +254,7 @@ impl Connection {
|
|||
|
||||
params.bind(&mut stmt)?;
|
||||
let rows = stmt.raw_query();
|
||||
Ok(crate::rows::from_rows(rows)?)
|
||||
return crate::rows::from_rows(rows);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
|
@ -303,7 +270,7 @@ impl Connection {
|
|||
|
||||
params.bind(&mut stmt)?;
|
||||
let rows = stmt.raw_query();
|
||||
Ok(crate::rows::from_rows(rows)?)
|
||||
return crate::rows::from_rows(rows);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
|
@ -321,7 +288,7 @@ impl Connection {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn query_row_f<T, E: Into<Error>>(
|
||||
async fn query_row_f<T, E>(
|
||||
&self,
|
||||
sql: impl AsRef<str> + Send + 'static,
|
||||
params: impl Params + Send + 'static,
|
||||
|
|
@ -347,7 +314,28 @@ impl Connection {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn read_query_row_f<T, E>(
|
||||
pub async fn query_row_get<T>(
|
||||
&self,
|
||||
sql: impl AsRef<str> + Send + 'static,
|
||||
params: impl Params + Send + 'static,
|
||||
index: usize,
|
||||
) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: rusqlite::types::FromSql + Send + 'static,
|
||||
{
|
||||
return self
|
||||
.query_row_f(
|
||||
sql,
|
||||
params,
|
||||
move |row: &rusqlite::Row<'_>| -> Result<T, Error> {
|
||||
return Ok(row.get(index)?);
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
async fn read_query_row_f<T, E>(
|
||||
&self,
|
||||
sql: impl AsRef<str> + Send + 'static,
|
||||
params: impl Params + Send + 'static,
|
||||
|
|
@ -374,6 +362,27 @@ impl Connection {
|
|||
.await;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn read_query_row_get<T>(
|
||||
&self,
|
||||
sql: impl AsRef<str> + Send + 'static,
|
||||
params: impl Params + Send + 'static,
|
||||
index: usize,
|
||||
) -> Result<Option<T>, Error>
|
||||
where
|
||||
T: rusqlite::types::FromSql + Send + 'static,
|
||||
{
|
||||
return self
|
||||
.read_query_row_f(
|
||||
sql,
|
||||
params,
|
||||
move |row: &rusqlite::Row<'_>| -> Result<T, Error> {
|
||||
return Ok(row.get(index)?);
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn read_query_value<T: serde::de::DeserializeOwned + Send + 'static>(
|
||||
&self,
|
||||
sql: impl AsRef<str> + Send + 'static,
|
||||
|
|
|
|||
|
|
@ -316,9 +316,7 @@ async fn test_execute_batch() {
|
|||
|
||||
let count = async |table: &str| -> i64 {
|
||||
return conn
|
||||
.query_row_f(format!("SELECT COUNT(*) FROM {table}"), (), |row| {
|
||||
row.get(0)
|
||||
})
|
||||
.query_row_get(format!("SELECT COUNT(*) FROM {table}"), (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
|
@ -443,7 +441,7 @@ async fn test_params() {
|
|||
.unwrap();
|
||||
|
||||
let count: i64 = conn
|
||||
.read_query_row_f("SELECT COUNT(*) FROM person", (), |row| row.get(0))
|
||||
.read_query_row_get("SELECT COUNT(*) FROM person", (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
|
@ -466,44 +464,47 @@ async fn test_hooks() {
|
|||
row_id: i64,
|
||||
}
|
||||
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
let c = conn.clone();
|
||||
let (sender, receiver) = kanal::unbounded::<String>();
|
||||
|
||||
conn
|
||||
.write_lock()
|
||||
.preupdate_hook(Some(
|
||||
move |action: rusqlite::hooks::Action, _db: &str, table_name: &str, case: &PreUpdateCase| {
|
||||
let row_id = extract_row_id(case).unwrap();
|
||||
let state = State {
|
||||
action,
|
||||
table_name: table_name.to_string(),
|
||||
row_id,
|
||||
};
|
||||
|
||||
let sender = sender.clone();
|
||||
c.call_and_forget(move |conn| {
|
||||
match state.action {
|
||||
rusqlite::hooks::Action::SQLITE_INSERT => {
|
||||
let text = conn
|
||||
.query_row(
|
||||
&format!(
|
||||
r#"SELECT text FROM "{}" WHERE _rowid_ = $1"#,
|
||||
state.table_name
|
||||
),
|
||||
[state.row_id],
|
||||
|row| row.get::<_, String>(0),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
sender.send(text).unwrap();
|
||||
}
|
||||
_ => {
|
||||
panic!("unexpected action: {:?}", state.action);
|
||||
}
|
||||
.preupdate_hook({
|
||||
let conn = conn.clone();
|
||||
Some(
|
||||
move |action: rusqlite::hooks::Action,
|
||||
_db: &str,
|
||||
table_name: &str,
|
||||
case: &PreUpdateCase| {
|
||||
let row_id = extract_row_id(case).unwrap();
|
||||
let state = State {
|
||||
action,
|
||||
table_name: table_name.to_string(),
|
||||
row_id,
|
||||
};
|
||||
});
|
||||
},
|
||||
))
|
||||
|
||||
if state.action != rusqlite::hooks::Action::SQLITE_INSERT {
|
||||
panic!("unexpected action: {:?}", state.action);
|
||||
}
|
||||
|
||||
let sender = sender.clone();
|
||||
let conn = conn.clone();
|
||||
|
||||
// We can't lock here since the lock is held by the `execute` below triggering
|
||||
// the hook. Thus delay the query until after the `execute` completes.
|
||||
std::thread::spawn(move || {
|
||||
let query = format!("SELECT text FROM '{}' WHERE _rowid_ = $1", state.table_name);
|
||||
sender
|
||||
.send(
|
||||
conn
|
||||
.write_lock()
|
||||
.query_row(&query, (state.row_id,), |row| row.get(0))
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
},
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
conn
|
||||
|
|
@ -511,7 +512,7 @@ async fn test_hooks() {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let text = receiver.recv().await.unwrap();
|
||||
let text = receiver.recv().unwrap();
|
||||
assert_eq!(text, "foo");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -552,7 +552,7 @@ mod tests {
|
|||
assert_eq!(
|
||||
1,
|
||||
conn
|
||||
.query_row_f("SELECT COUNT(*) FROM tx;", (), |row| row.get::<_, i64>(0))
|
||||
.query_row_get::<i64>("SELECT COUNT(*) FROM tx;", (), 0)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
|
|
@ -585,10 +585,9 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_custom_sqlite_function() {
|
||||
let conn = rusqlite::Connection::open_in_memory().unwrap();
|
||||
let _sqlite_function_runtime = init_sqlite_function_runtime(&conn).await;
|
||||
let conn = trailbase_sqlite::Connection::open_in_memory().unwrap();
|
||||
|
||||
let conn = trailbase_sqlite::Connection::from_connection_test_only(conn);
|
||||
let _sqlite_function_runtime = init_sqlite_function_runtime(&conn.write_lock()).await;
|
||||
let runtime = init_runtime(Some(conn.clone()));
|
||||
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in a new issue