mirror of
https://github.com/trailbaseio/trailbase
synced 2026-04-21 13:37:44 +00:00
Add more validation to access rules, e.g. ensure that magics like _REQ_, _ROW_, ... are only included in appropriate queries.
This commit is contained in:
parent
d6d0534df4
commit
fa59a49c9c
2 changed files with 192 additions and 115 deletions
|
|
@ -7,7 +7,6 @@ use trailbase_schema::metadata::{
|
|||
JsonColumnMetadata, TableMetadata, TableOrViewMetadata, ViewMetadata, find_file_column_indexes,
|
||||
find_user_id_foreign_key_columns,
|
||||
};
|
||||
use trailbase_schema::parse::parse_into_statement;
|
||||
use trailbase_schema::sqlite::Column;
|
||||
use trailbase_schema::{QualifiedName, QualifiedNameEscaped};
|
||||
use trailbase_sqlite::{NamedParams, Params as _, Value};
|
||||
|
|
@ -690,63 +689,6 @@ impl<'a> trailbase_sqlite::Params for SubscriptionAclParams<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn validate_rule(rule: &str) -> Result<(), String> {
|
||||
let stmt = parse_into_statement(&format!("SELECT {rule}"))
|
||||
.map_err(|err| format!("'{rule}' not a valid SQL expression: {err}"))?;
|
||||
|
||||
let Some(sqlite3_parser::ast::Stmt::Select(select)) = stmt else {
|
||||
panic!("Expected SELECT");
|
||||
};
|
||||
|
||||
let sqlite3_parser::ast::OneSelect::Select { mut columns, .. } = select.body.select else {
|
||||
panic!("Expected SELECT");
|
||||
};
|
||||
|
||||
if columns.len() != 1 {
|
||||
return Err("Expected single column".to_string());
|
||||
}
|
||||
|
||||
let sqlite3_parser::ast::ResultColumn::Expr(expr, _) = columns.swap_remove(0) else {
|
||||
return Err("Expected expr".to_string());
|
||||
};
|
||||
|
||||
validate_expr_recursively(&expr)?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn validate_expr_recursively(expr: &sqlite3_parser::ast::Expr) -> Result<(), String> {
|
||||
use sqlite3_parser::ast;
|
||||
|
||||
match &expr {
|
||||
ast::Expr::Binary(lhs, _op, rhs) => {
|
||||
validate_expr_recursively(lhs)?;
|
||||
validate_expr_recursively(rhs)?;
|
||||
}
|
||||
ast::Expr::IsNull(inner) => {
|
||||
validate_expr_recursively(inner)?;
|
||||
}
|
||||
ast::Expr::InTable { lhs, rhs, .. } => {
|
||||
match rhs {
|
||||
ast::QualifiedName {
|
||||
name: ast::Name(name),
|
||||
..
|
||||
} if **name == *"_REQ_FIELDS_" => {
|
||||
if !matches!(**lhs, ast::Expr::Literal(ast::Literal::String(_))) {
|
||||
return Err(format!("Expected literal string: {lhs:?}"));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
|
||||
validate_expr_recursively(lhs)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
#[derive(Template)]
|
||||
#[template(
|
||||
escape = "none",
|
||||
|
|
@ -1014,18 +956,4 @@ mod tests {
|
|||
assert!(has_access(acl, Permission::Update), "ACL: {acl}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_rule() {
|
||||
assert!(validate_rule("").is_err());
|
||||
assert!(validate_rule("1, 1").is_err());
|
||||
assert!(validate_rule("1").is_ok());
|
||||
|
||||
validate_rule("_USER_.id IS NOT NULL").unwrap();
|
||||
validate_rule("_USER_.id IS NOT NULL AND _ROW_.userid = _USER_.id").unwrap();
|
||||
validate_rule("_USER_.id IS NOT NULL AND _REQ_.field IS NOT NULL").unwrap();
|
||||
|
||||
assert!(validate_rule("'field' IN _REQ_FIELDS_").is_ok());
|
||||
assert!(validate_rule("field IN _REQ_FIELDS_").is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,20 +1,18 @@
|
|||
use itertools::Itertools;
|
||||
use trailbase_schema::QualifiedName;
|
||||
use trailbase_schema::parse::parse_into_statement;
|
||||
use trailbase_schema::sqlite::ColumnOption;
|
||||
|
||||
use crate::config::{ConfigError, proto};
|
||||
use crate::records::record_api::validate_rule;
|
||||
use crate::schema_metadata::{SchemaMetadataCache, TableOrViewMetadata};
|
||||
|
||||
fn validate_record_api_name(name: &str) -> Result<(), ConfigError> {
|
||||
if name.is_empty() {
|
||||
return Err(ConfigError::Invalid(
|
||||
"Invalid api name: cannot be empty".to_string(),
|
||||
));
|
||||
return Err(invalid("Invalid api name: cannot be empty"));
|
||||
}
|
||||
|
||||
if !name.chars().all(|x| x.is_ascii_alphanumeric() || x == '_') {
|
||||
return Err(ConfigError::Invalid(format!(
|
||||
return Err(invalid(format!(
|
||||
"Invalid api name: {name}. Must only contain alphanumeric characters or '_'."
|
||||
)));
|
||||
}
|
||||
|
|
@ -26,46 +24,46 @@ pub(crate) fn validate_record_api_config(
|
|||
schemas: &SchemaMetadataCache,
|
||||
api_config: &proto::RecordApiConfig,
|
||||
) -> Result<String, ConfigError> {
|
||||
let ierr = |msg: &str| Err(ConfigError::Invalid(msg.to_string()));
|
||||
|
||||
let Some(ref api_name) = api_config.name else {
|
||||
return ierr("RecordApi config misses name.");
|
||||
return Err(invalid("RecordApi config misses name."));
|
||||
};
|
||||
validate_record_api_name(api_name)?;
|
||||
|
||||
let Some(ref table_name) = api_config.table_name else {
|
||||
return ierr("RecordApi config misses table name.");
|
||||
return Err(invalid("RecordApi config misses table name."));
|
||||
};
|
||||
|
||||
let metadata: std::sync::Arc<dyn TableOrViewMetadata> = {
|
||||
let table_name = QualifiedName::parse(table_name)?;
|
||||
if let Some(metadata) = schemas.get_table(&table_name) {
|
||||
if metadata.schema.temporary {
|
||||
return ierr("Record APIs must not reference TEMPORARY tables");
|
||||
return Err(invalid("Record APIs must not reference TEMPORARY tables"));
|
||||
}
|
||||
|
||||
metadata
|
||||
} else if let Some(metadata) = schemas.get_view(&table_name) {
|
||||
if metadata.schema.temporary {
|
||||
return ierr("Record APIs must not reference TEMPORARY views");
|
||||
return Err(invalid("Record APIs must not reference TEMPORARY views"));
|
||||
}
|
||||
|
||||
metadata
|
||||
} else {
|
||||
return ierr(&format!("Missing table or view for API: {api_name}"));
|
||||
return Err(invalid(format!(
|
||||
"Missing table or view for API: {api_name}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let Some((pk_index, _)) = metadata.record_pk_column() else {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"Table for api '{api_name}' is missing valid integer/UUID primary key column."
|
||||
));
|
||||
)));
|
||||
};
|
||||
|
||||
let Some(columns) = metadata.columns() else {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"View for api '{api_name}' is not a \"simple\" view, i.e unable to infer types for strong type-safety"
|
||||
));
|
||||
)));
|
||||
};
|
||||
|
||||
for excluded_column_name in &api_config.excluded_columns {
|
||||
|
|
@ -73,32 +71,36 @@ pub(crate) fn validate_record_api_config(
|
|||
.iter()
|
||||
.position(|col| col.name == *excluded_column_name)
|
||||
else {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"Excluded column '{excluded_column_name}' in API '{api_name}' not found.",
|
||||
));
|
||||
)));
|
||||
};
|
||||
|
||||
if excluded_index == pk_index {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"PK column '{excluded_column_name}' cannot be excluded from API '{api_name}'.",
|
||||
));
|
||||
)));
|
||||
}
|
||||
|
||||
let excluded_column = &columns[excluded_index];
|
||||
if excluded_column.is_not_null() && !excluded_column.has_default() {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"Cannot exclude column '{excluded_column_name}' from API '{api_name}', which is NOT NULL and w/o DEFAULT",
|
||||
));
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
for expand in &api_config.expand {
|
||||
if expand.starts_with("_") {
|
||||
return ierr(&format!("{api_name} expands hidden column: {expand}"));
|
||||
return Err(invalid(format!(
|
||||
"{api_name} expands hidden column: {expand}"
|
||||
)));
|
||||
}
|
||||
|
||||
let Some(column) = columns.iter().find(|c| c.name == *expand) else {
|
||||
return ierr(&format!("{api_name} expands missing column: {expand}"));
|
||||
return Err(invalid(format!(
|
||||
"{api_name} expands missing column: {expand}"
|
||||
)));
|
||||
};
|
||||
|
||||
let Some(ColumnOption::ForeignKey {
|
||||
|
|
@ -110,56 +112,203 @@ pub(crate) fn validate_record_api_config(
|
|||
.iter()
|
||||
.find_or_first(|o| matches!(o, ColumnOption::ForeignKey { .. }))
|
||||
else {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"{api_name} expands non-foreign-key column: {expand}"
|
||||
));
|
||||
)));
|
||||
};
|
||||
|
||||
if foreign_table_name.starts_with("_") {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"{api_name} expands reference '{expand}' to hidden table: {foreign_table_name}"
|
||||
));
|
||||
)));
|
||||
}
|
||||
|
||||
let Some(foreign_table) = schemas.get_table(&QualifiedName::parse(foreign_table_name)?) else {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"{api_name} reference missing table: {foreign_table_name}"
|
||||
));
|
||||
)));
|
||||
};
|
||||
|
||||
let Some((_idx, foreign_pk_column)) = foreign_table.record_pk_column() else {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"{api_name} references pk-less table: {foreign_table_name}"
|
||||
));
|
||||
)));
|
||||
};
|
||||
|
||||
match referred_columns.len() {
|
||||
0 => {}
|
||||
1 => {
|
||||
if referred_columns[0] != foreign_pk_column.name {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"{api_name}.{expand} expands non-primary-key reference"
|
||||
));
|
||||
)));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return ierr(&format!(
|
||||
return Err(invalid(format!(
|
||||
"Composite keys cannot be expanded for {api_name}.{expand}"
|
||||
));
|
||||
)));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let rules = [
|
||||
&api_config.create_access_rule,
|
||||
&api_config.read_access_rule,
|
||||
&api_config.update_access_rule,
|
||||
&api_config.delete_access_rule,
|
||||
&api_config.schema_access_rule,
|
||||
(AccessKind::Create, api_config.create_access_rule.as_ref()),
|
||||
(AccessKind::Read, api_config.read_access_rule.as_ref()),
|
||||
(AccessKind::Update, api_config.update_access_rule.as_ref()),
|
||||
(AccessKind::Delete, api_config.delete_access_rule.as_ref()),
|
||||
(AccessKind::Schema, api_config.schema_access_rule.as_ref()),
|
||||
];
|
||||
for rule in rules.into_iter().flatten() {
|
||||
validate_rule(rule).map_err(ConfigError::Invalid)?;
|
||||
for (kind, rule) in rules {
|
||||
if let Some(rule) = rule {
|
||||
validate_rule(kind, rule).map_err(invalid)?;
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(api_name.to_owned());
|
||||
}
|
||||
|
||||
enum AccessKind {
|
||||
Create,
|
||||
Read,
|
||||
Update,
|
||||
Delete,
|
||||
Schema,
|
||||
}
|
||||
|
||||
fn validate_rule(kind: AccessKind, rule: &str) -> Result<(), ConfigError> {
|
||||
for magic in ["_USER_", "_REQ_", "_REQ_FIELDS_", "_ROW_"] {
|
||||
if rule.contains(&magic.to_lowercase()) {
|
||||
return Err(invalid(
|
||||
"Access rule '{rule}', contained lower-case {magic}, upper-case expected",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: We could probably do this as part of the recursive AST traversal below rather than
|
||||
// string match.
|
||||
// We may also want to scan more actively for typos... , e.g. _ROW_ vs _row_.
|
||||
match kind {
|
||||
AccessKind::Create => {
|
||||
if rule.contains("_ROW_") {
|
||||
return Err(invalid("Create rule cannot reference _ROW_"));
|
||||
}
|
||||
}
|
||||
AccessKind::Read => {
|
||||
if rule.contains("_REQ_") || rule.contains("_REQ_FIELDS_") {
|
||||
return Err(invalid("Read rule cannot reference _REQ_"));
|
||||
}
|
||||
}
|
||||
AccessKind::Update => {}
|
||||
AccessKind::Delete => {
|
||||
if rule.contains("_REQ_") || rule.contains("_REQ_FIELDS_") {
|
||||
return Err(invalid("Delete rule cannot reference _REQ_"));
|
||||
}
|
||||
}
|
||||
AccessKind::Schema => {
|
||||
if rule.contains("_ROW_") {
|
||||
return Err(invalid("Schema rule cannot reference _ROW_"));
|
||||
}
|
||||
if rule.contains("_REQ_") || rule.contains("_REQ_FIELDS_") {
|
||||
return Err(invalid("Schema rule cannot reference _REQ_"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let stmt = parse_into_statement(&format!("SELECT {rule}"))
|
||||
.map_err(|err| invalid(format!("'{rule}' not a valid SQL expression: {err}")))?;
|
||||
|
||||
let Some(sqlite3_parser::ast::Stmt::Select(select)) = stmt else {
|
||||
return Err(invalid(format!(
|
||||
"Access rule '{rule}' not a select statement"
|
||||
)));
|
||||
};
|
||||
|
||||
let sqlite3_parser::ast::OneSelect::Select { mut columns, .. } = select.body.select else {
|
||||
return Err(invalid(format!(
|
||||
"Access rule '{rule}' not a select statement"
|
||||
)));
|
||||
};
|
||||
|
||||
if columns.len() != 1 {
|
||||
return Err(invalid("Expected single column"));
|
||||
}
|
||||
|
||||
let sqlite3_parser::ast::ResultColumn::Expr(expr, _) = columns.swap_remove(0) else {
|
||||
return Err(invalid("Expected expr"));
|
||||
};
|
||||
|
||||
validate_expr_recursively(&expr)?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn validate_expr_recursively(expr: &sqlite3_parser::ast::Expr) -> Result<(), ConfigError> {
|
||||
use sqlite3_parser::ast;
|
||||
|
||||
match &expr {
|
||||
ast::Expr::Binary(lhs, _op, rhs) => {
|
||||
validate_expr_recursively(lhs)?;
|
||||
validate_expr_recursively(rhs)?;
|
||||
}
|
||||
ast::Expr::IsNull(inner) => {
|
||||
validate_expr_recursively(inner)?;
|
||||
}
|
||||
// Ensure `IN _REQ_FIELDS_` expression are preceded by literals, e.g.:
|
||||
// `'field' IN _REQ_FIELDS_`.
|
||||
ast::Expr::InTable { lhs, rhs, .. } => {
|
||||
match rhs {
|
||||
ast::QualifiedName {
|
||||
name: ast::Name(name),
|
||||
..
|
||||
} if name.as_ref() == "_REQ_FIELDS_" => {
|
||||
if !matches!(**lhs, ast::Expr::Literal(ast::Literal::String(_))) {
|
||||
return Err(invalid(format!(
|
||||
"Expected literal string on LHS of `IN _REQ_FIELDS_`, got: {lhs:?}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
|
||||
validate_expr_recursively(lhs)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn invalid(err: impl std::string::ToString) -> ConfigError {
|
||||
return ConfigError::Invalid(err.to_string());
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_validate_rule() {
|
||||
assert!(validate_rule(AccessKind::Read, "").is_err());
|
||||
assert!(validate_rule(AccessKind::Read, "1, 1").is_err());
|
||||
assert!(validate_rule(AccessKind::Read, "1").is_ok());
|
||||
|
||||
validate_rule(AccessKind::Read, "_USER_.id IS NOT NULL").unwrap();
|
||||
validate_rule(
|
||||
AccessKind::Read,
|
||||
"_USER_.id IS NOT NULL AND _ROW_.userid = _USER_.id",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(validate_rule(AccessKind::Read, "_REQ_.field = 'magic'").is_err());
|
||||
|
||||
validate_rule(
|
||||
AccessKind::Create,
|
||||
"_USER_.id IS NOT NULL AND _REQ_.field IS NOT NULL",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(validate_rule(AccessKind::Update, "'field' IN _REQ_FIELDS_").is_ok());
|
||||
assert!(validate_rule(AccessKind::Update, "field IN _REQ_FIELDS_").is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue