diff --git a/CHANGELOG.md b/CHANGELOG.md index 73cb6fbb..38002f44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## unreleased - OIDC protected and public paths now respect the site prefix when it is defined. - add submit and reset form button icons: validate_icon, reset_icon, reset_color + - improve error messages when sqlpage functions are used incorrectly. Include precise file reference and line number ## 0.42.0 (2026-01-17) diff --git a/src/webserver/database/execute_queries.rs b/src/webserver/database/execute_queries.rs index 24de81ce..253bd17e 100644 --- a/src/webserver/database/execute_queries.rs +++ b/src/webserver/database/execute_queries.rs @@ -331,6 +331,16 @@ fn debug_row(r: &AnyRow) { } fn clone_anyhow_err(source_file: &Path, err: &anyhow::Error) -> anyhow::Error { + if let Some(func_err) = err.downcast_ref::() { + let line = func_err.line; + let loc = if line > 0 { + format!(":{line}") + } else { + String::new() + }; + return anyhow::anyhow!("{}{loc} {}", source_file.display(), func_err); + } + let mut e = anyhow!("{} contains a syntax error preventing SQLPage from parsing and preparing its SQL statements.", source_file.display()); for c in err.chain().rev() { e = e.context(c.to_string()); diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 2efe5452..7acac1a0 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -1,6 +1,5 @@ use super::csv_import::{extract_csv_copy_statement, CsvImport}; use super::sqlpage_functions::functions::SqlPageFunctionName; -use super::sqlpage_functions::{are_params_extractable, func_call_to_param}; use super::syntax_tree::StmtParam; use super::SupportedDatabase; use crate::file_cache::AsyncFromStrWithState; @@ -10,10 +9,9 @@ use crate::{AppState, Database}; use async_trait::async_trait; use sqlparser::ast::helpers::attached_token::AttachedToken; use sqlparser::ast::{ - BinaryOperator, CastKind, CharacterLength, DataType, Expr, Function, FunctionArg, - FunctionArgExpr, FunctionArgumentList, FunctionArguments, Ident, ObjectName, ObjectNamePart, - SelectFlavor, SelectItem, Set, SetExpr, Spanned, Statement, Value, ValueWithSpan, Visit, - VisitMut, Visitor, VisitorMut, + CastKind, DataType, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList, + FunctionArguments, Ident, ObjectName, ObjectNamePart, SelectFlavor, SelectItem, Set, SetExpr, + Spanned, Statement, Value, ValueWithSpan, }; use sqlparser::dialect::{ Dialect, DuckDbDialect, GenericDialect, MsSqlDialect, MySqlDialect, OracleDialect, @@ -24,10 +22,18 @@ use sqlparser::tokenizer::Token::{self, SemiColon, EOF}; use sqlparser::tokenizer::{Location, Span, TokenWithSpan, Tokenizer}; use sqlx::any::AnyKind; use std::fmt::Write; -use std::ops::ControlFlow; use std::path::{Path, PathBuf}; use std::str::FromStr; +mod parameter_extraction; +use self::parameter_extraction::{ + extract_ident_param, validate_function_calls, ParameterExtractor, TEMP_PLACEHOLDER_PREFIX, +}; +pub(super) use self::parameter_extraction::{ + function_args_to_stmt_params, DbPlaceHolder, ParamExtractContext, SqlPageFunctionError, + DB_PLACEHOLDERS, +}; + #[derive(Default)] pub struct ParsedSqlFile { pub(super) statements: Vec, @@ -197,7 +203,11 @@ fn parse_single_statement( while parser.consume_token(&SemiColon) { semicolon = true; } - let mut params = ParameterExtractor::extract_parameters(&mut stmt, db_info.clone()); + + let mut params = match ParameterExtractor::extract_parameters(&mut stmt, db_info.clone()) { + Ok(p) => p, + Err(err) => return Some(ParsedStatement::Error(err)), + }; let dbms = db_info.database_type; if let Some(parsed) = extract_set_variable(&mut stmt, &mut params, db_info) { return Some(parsed); @@ -209,11 +219,11 @@ fn parse_single_statement( log::debug!("Optimised a static simple select to avoid a trivial database query: {stmt} optimized to {static_statement:?}"); return Some(ParsedStatement::StaticSimpleSelect(static_statement)); } + let delayed_functions = extract_toplevel_functions(&mut stmt); + if let Err(err) = validate_function_calls(&stmt) { - return Some(ParsedStatement::Error(err.context(format!( - "Invalid SQLPage function call found in:\n{stmt}" - )))); + return Some(ParsedStatement::Error(err)); } let json_columns = extract_json_columns(&stmt, dbms); let query = format!( @@ -285,18 +295,6 @@ fn dialect_for_db(dbms: SupportedDatabase) -> Box { } } -fn map_param(mut name: String) -> StmtParam { - if name.is_empty() { - return StmtParam::PostOrGet(name); - } - let prefix = name.remove(0); - match prefix { - '$' => StmtParam::PostOrGet(name), - ':' => StmtParam::Post(name), - _ => StmtParam::Get(name), - } -} - #[derive(Debug, PartialEq)] pub struct DelayedFunctionCall { pub function: SqlPageFunctionName, @@ -548,430 +546,6 @@ fn extract_set_variable( None } -struct ParameterExtractor { - db_info: DbInfo, - parameters: Vec, -} - -#[derive(Debug)] -pub enum DbPlaceHolder { - PrefixedNumber { prefix: &'static str }, - Positional { placeholder: &'static str }, -} - -pub const DB_PLACEHOLDERS: [(AnyKind, DbPlaceHolder); 5] = [ - ( - AnyKind::Sqlite, - DbPlaceHolder::PrefixedNumber { prefix: "?" }, - ), - ( - AnyKind::Postgres, - DbPlaceHolder::PrefixedNumber { prefix: "$" }, - ), - ( - AnyKind::MySql, - DbPlaceHolder::Positional { placeholder: "?" }, - ), - ( - AnyKind::Mssql, - DbPlaceHolder::PrefixedNumber { prefix: "@p" }, - ), - ( - AnyKind::Odbc, - DbPlaceHolder::Positional { placeholder: "?" }, - ), -]; - -/// For positional parameters, we use a temporary placeholder during parameter extraction, -/// And then replace it with the actual placeholder during statement rewriting. -const TEMP_PLACEHOLDER_PREFIX: &str = "@SQLPAGE_TEMP"; - -fn get_placeholder_prefix(kind: AnyKind) -> &'static str { - if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = DB_PLACEHOLDERS - .iter() - .find(|(placeholder_kind, _prefix)| *placeholder_kind == kind) - { - prefix - } else { - TEMP_PLACEHOLDER_PREFIX - } -} - -impl ParameterExtractor { - fn extract_parameters( - sql_ast: &mut sqlparser::ast::Statement, - db_info: DbInfo, - ) -> Vec { - let mut this = Self { - db_info, - parameters: vec![], - }; - let _ = sql_ast.visit(&mut this); - this.parameters - } - - fn replace_with_placeholder(&mut self, value: &mut Expr, param: StmtParam) { - let placeholder = - if let Some(existing_idx) = self.parameters.iter().position(|p| *p == param) { - // Parameter already exists, use its index - self.make_placeholder_for_index(existing_idx + 1) - } else { - // New parameter, add it to the list - let placeholder = self.make_placeholder(); - log::trace!("Replacing {param} with {placeholder}"); - self.parameters.push(param); - placeholder - }; - *value = placeholder; - } - - fn make_placeholder_for_index(&self, index: usize) -> Expr { - let name = make_tmp_placeholder(self.db_info.kind, index); - let data_type = match self.db_info.database_type { - SupportedDatabase::MySql => DataType::Char(None), - SupportedDatabase::Mssql => DataType::Varchar(Some(CharacterLength::Max)), - SupportedDatabase::Postgres | SupportedDatabase::Sqlite => DataType::Text, - SupportedDatabase::Oracle => DataType::Varchar(Some(CharacterLength::IntegerLength { - length: 4000, - unit: None, - })), - _ => DataType::Varchar(None), - }; - let value = Expr::value(Value::Placeholder(name)); - Expr::Cast { - expr: Box::new(value), - data_type, - format: None, - kind: CastKind::Cast, - } - } - - fn make_placeholder(&self) -> Expr { - self.make_placeholder_for_index(self.parameters.len() + 1) - } - - fn is_own_placeholder(&self, param: &str) -> bool { - let prefix = get_placeholder_prefix(self.db_info.kind); - if let Some(param) = param.strip_prefix(prefix) { - if let Ok(index) = param.parse::() { - return index <= self.parameters.len() + 1; - } - } - false - } -} - -struct InvalidFunctionFinder; -impl Visitor for InvalidFunctionFinder { - type Break = (String, Vec); - fn pre_visit_expr(&mut self, value: &Expr) -> ControlFlow { - match value { - Expr::Function(Function { - name: ObjectName(func_name_parts), - args: - FunctionArguments::List(FunctionArgumentList { - args, - duplicate_treatment: None, - .. - }), - .. - }) if is_sqlpage_func(func_name_parts) => { - let func_name = sqlpage_func_name(func_name_parts); - let arguments = args.clone(); - return ControlFlow::Break((func_name.to_string(), arguments)); - } - _ => (), - } - ControlFlow::Continue(()) - } -} - -fn validate_function_calls(stmt: &Statement) -> anyhow::Result<()> { - let mut finder = InvalidFunctionFinder; - if let ControlFlow::Break((func_name, args)) = stmt.visit(&mut finder) { - let args_str = FormatArguments(&args); - let error_msg = format!( - "Invalid SQLPage function call: sqlpage.{func_name}({args_str})\n\n\ - Arbitrary SQL expressions as function arguments are not supported.\n\n\ - SQLPage functions can either:\n\ - 1. Run BEFORE the query (to provide input values)\n\ - 2. Run AFTER the query (to process the results)\n\ - But they can't run DURING the query - the database doesn't know how to call them!\n\n\ - To fix this, you can either:\n\ - 1. Store the function argument in a variable first:\n\ - SET {func_name}_arg = ...;\n\ - SET {func_name}_result = sqlpage.{func_name}(${func_name}_arg);\n\ - SELECT * FROM example WHERE xxx = ${func_name}_result;\n\n\ - 2. Or move the function to the top level to process results:\n\ - SELECT sqlpage.{func_name}(...) FROM example;" - ); - Err(anyhow::anyhow!(error_msg)) - } else { - Ok(()) - } -} - -/** This is a helper struct to format a list of arguments for an error message. */ -pub(super) struct FormatArguments<'a>(pub &'a [FunctionArg]); -impl std::fmt::Display for FormatArguments<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut args = self.0.iter(); - if let Some(arg) = args.next() { - write!(f, "{arg}")?; - } - for arg in args { - write!(f, ", {arg}")?; - } - Ok(()) - } -} - -pub(super) fn function_arg_to_stmt_param(arg: &mut FunctionArg) -> Option { - function_arg_expr(arg).and_then(expr_to_stmt_param) -} - -pub(super) fn function_args_to_stmt_params( - arguments: &mut [FunctionArg], -) -> anyhow::Result> { - arguments - .iter_mut() - .map(|arg| { - function_arg_to_stmt_param(arg) - .ok_or_else(|| anyhow::anyhow!("Passing \"{arg}\" as a function argument is not supported.\n\n\ - The only supported sqlpage function argument types are : \n\ - - variables (such as $my_variable), \n\ - - other sqlpage function calls (such as sqlpage.cookie('my_cookie')), \n\ - - literal strings (such as 'my_string'), \n\ - - concatenations of the above (such as CONCAT(x, y)).\n\n\ - Arbitrary SQL expressions as function arguments are not supported.\n\ - Try executing the SQL expression in a separate SET expression, then passing it to the function:\n\n\ - set my_parameter = {arg}; \n\ - SELECT sqlpage.my_function($my_parameter);\n\n\ - ")) - }) - .collect::>>() -} - -fn expr_to_stmt_param(arg: &mut Expr) -> Option { - match arg { - Expr::Value(ValueWithSpan { - value: Value::Placeholder(placeholder), - .. - }) => Some(map_param(std::mem::take(placeholder))), - Expr::Identifier(ident) => extract_ident_param(ident), - Expr::Function(Function { - name: ObjectName(func_name_parts), - args: - FunctionArguments::List(FunctionArgumentList { - args, - duplicate_treatment: None, - .. - }), - .. - }) if is_sqlpage_func(func_name_parts) => Some(func_call_to_param( - sqlpage_func_name(func_name_parts), - args.as_mut_slice(), - )), - Expr::Value(ValueWithSpan { - value: Value::SingleQuotedString(param_value), - .. - }) => Some(StmtParam::Literal(std::mem::take(param_value))), - Expr::Value(ValueWithSpan { - value: Value::Number(param_value, _is_long), - .. - }) => Some(StmtParam::Literal(param_value.clone())), - Expr::Value(ValueWithSpan { - value: Value::Null, .. - }) => Some(StmtParam::Null), - Expr::BinaryOp { - // 'str1' || 'str2' - left, - op: BinaryOperator::StringConcat, - right, - } => { - let left = expr_to_stmt_param(left)?; - let right = expr_to_stmt_param(right)?; - Some(StmtParam::Concat(vec![left, right])) - } - // SQLPage can evaluate some functions natively without sending them to the database: - // CONCAT('str1', 'str2', ...) - // json_object('key1', 'value1', 'key2', 'value2', ...) - // json_array('value1', 'value2', ...) - Expr::Function(Function { - name: ObjectName(func_name_parts), - args: - FunctionArguments::List(FunctionArgumentList { - args, - duplicate_treatment: None, - .. - }), - .. - }) if func_name_parts.len() == 1 => { - let func_name = func_name_parts[0] - .as_ident() - .map(|ident| ident.value.as_str()) - .unwrap_or_default(); - if func_name.eq_ignore_ascii_case("concat") { - let mut concat_args = Vec::with_capacity(args.len()); - for arg in args { - concat_args.push(function_arg_to_stmt_param(arg)?); - } - Some(StmtParam::Concat(concat_args)) - } else if func_name.eq_ignore_ascii_case("json_object") - || func_name.eq_ignore_ascii_case("jsonb_object") - || func_name.eq_ignore_ascii_case("json_build_object") - || func_name.eq_ignore_ascii_case("jsonb_build_object") - { - let mut json_obj_args = Vec::with_capacity(args.len()); - for arg in args { - json_obj_args.push(function_arg_to_stmt_param(arg)?); - } - Some(StmtParam::JsonObject(json_obj_args)) - } else if func_name.eq_ignore_ascii_case("json_array") - || func_name.eq_ignore_ascii_case("jsonb_array") - || func_name.eq_ignore_ascii_case("json_build_array") - || func_name.eq_ignore_ascii_case("jsonb_build_array") - { - let mut json_obj_args = Vec::with_capacity(args.len()); - for arg in args { - json_obj_args.push(function_arg_to_stmt_param(arg)?); - } - Some(StmtParam::JsonArray(json_obj_args)) - } else if func_name.eq_ignore_ascii_case("coalesce") { - let mut coalesce_args = Vec::with_capacity(args.len()); - for arg in args { - coalesce_args.push(function_arg_to_stmt_param(arg)?); - } - Some(StmtParam::Coalesce(coalesce_args)) - } else { - log::warn!("SQLPage cannot emulate the following function: {func_name}"); - None - } - } - _ => { - log::warn!("Unsupported function argument: {arg}"); - None - } - } -} - -fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> { - match arg { - FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => Some(expr), - other => { - log::warn!( - "Using named function arguments ({other}) is not supported by SQLPage functions." - ); - None - } - } -} - -#[inline] -#[must_use] -pub fn make_tmp_placeholder(kind: AnyKind, arg_number: usize) -> String { - let prefix = if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = - DB_PLACEHOLDERS.iter().find(|(db_typ, _)| *db_typ == kind) - { - prefix - } else { - TEMP_PLACEHOLDER_PREFIX - }; - format!("{prefix}{arg_number}") -} - -fn extract_ident_param(Ident { value, .. }: &mut Ident) -> Option { - if value.starts_with('$') || value.starts_with(':') { - let name = std::mem::take(value); - Some(map_param(name)) - } else { - None - } -} - -impl VisitorMut for ParameterExtractor { - type Break = (); - fn pre_visit_expr(&mut self, value: &mut Expr) -> ControlFlow { - match value { - Expr::Identifier(ident) => { - if let Some(param) = extract_ident_param(ident) { - self.replace_with_placeholder(value, param); - } - } - Expr::Value(ValueWithSpan { - value: Value::Placeholder(param), - .. - }) if !self.is_own_placeholder(param) => - // this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves - { - let name = std::mem::take(param); - self.replace_with_placeholder(value, map_param(name)); - } - Expr::Function(Function { - name: ObjectName(func_name_parts), - args: - FunctionArguments::List(FunctionArgumentList { - args, - duplicate_treatment: None, - .. - }), - filter: None, - null_treatment: None, - over: None, - .. - }) if is_sqlpage_func(func_name_parts) && are_params_extractable(args) => { - let func_name = sqlpage_func_name(func_name_parts); - log::trace!("Handling builtin function: {func_name}"); - let mut arguments = std::mem::take(args); - let param = func_call_to_param(func_name, &mut arguments); - self.replace_with_placeholder(value, param); - } - // Replace 'str1' || 'str2' with CONCAT('str1', 'str2') for MSSQL - Expr::BinaryOp { - left, - op: BinaryOperator::StringConcat, - right, - } if self.db_info.database_type == SupportedDatabase::Mssql => { - let left = std::mem::replace(left.as_mut(), Expr::value(Value::Null)); - let right = std::mem::replace(right.as_mut(), Expr::value(Value::Null)); - *value = Expr::Function(Function { - name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new("CONCAT"))]), - args: FunctionArguments::List(FunctionArgumentList { - args: vec![ - FunctionArg::Unnamed(FunctionArgExpr::Expr(left)), - FunctionArg::Unnamed(FunctionArgExpr::Expr(right)), - ], - duplicate_treatment: None, - clauses: Vec::new(), - }), - parameters: FunctionArguments::None, - over: None, - filter: None, - null_treatment: None, - within_group: Vec::new(), - uses_odbc_syntax: false, - }); - } - Expr::Cast { - kind: kind @ CastKind::DoubleColon, - .. - } if ![ - SupportedDatabase::Postgres, - SupportedDatabase::Snowflake, - SupportedDatabase::Generic, - ] - .contains(&self.db_info.database_type) => - { - log::warn!("Casting with '::' is not supported on your database. \ - For backwards compatibility with older SQLPage versions, we will transform it to CAST(... AS ...)."); - *kind = CastKind::Cast; - } - _ => (), - } - ControlFlow::<()>::Continue(()) - } -} - const SQLPAGE_FUNCTION_NAMESPACE: &str = "sqlpage"; fn is_sqlpage_func(func_name_parts: &[ObjectNamePart]) -> bool { @@ -1147,7 +721,7 @@ mod test { let mut ast = parse_postgres_stmt("select $a from t where $x > $a OR $x = sqlpage.cookie('cookoo')"); let db_info = create_test_db_info(SupportedDatabase::Postgres); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info).unwrap(); // $a -> $1 // $x -> $2 // sqlpage.cookie(...) -> $3 @@ -1172,7 +746,7 @@ mod test { fn test_statement_rewrite_sqlite() { let mut ast = parse_stmt("select $x, :y from t", &SQLiteDialect {}); let db_info = create_test_db_info(SupportedDatabase::Sqlite); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info).unwrap(); assert_eq!( ast.to_string(), "SELECT CAST(?1 AS TEXT), CAST(?2 AS TEXT) FROM t" @@ -1271,31 +845,11 @@ mod test { // The order of the function arguments should be preserved // Otherwise the statement parameters will be bound to the wrong arguments let sql = "select $a as a, sqlpage.exec('xxx', x = $b) as b, $c as c from t"; - let db_info = create_test_db_info(SupportedDatabase::Postgres); - let all = parse_sql(&db_info, &PostgreSqlDialect {}, sql) - .unwrap() - .collect::>(); - assert_eq!(all.len(), 1); - let ParsedStatement::StmtWithParams(StmtWithParams { - query, - params, - delayed_functions, - .. - }) = &all[0] - else { - panic!("Failed to parse statement: {all:?}"); - }; - assert_eq!( - query, - "SELECT CAST($1 AS TEXT) AS a, 'xxx' AS \"_sqlpage_f0_a0\", x = CAST($2 AS TEXT) AS \"_sqlpage_f0_a1\", CAST($3 AS TEXT) AS c FROM t" - ); + let mut ast = parse_postgres_stmt(sql); + let delayed_functions = extract_toplevel_functions(&mut ast); assert_eq!( - params, - &[ - StmtParam::PostOrGet("a".to_string()), - StmtParam::PostOrGet("b".to_string()), - StmtParam::PostOrGet("c".to_string()), - ] + ast.to_string(), + "SELECT $a AS a, 'xxx' AS \"_sqlpage_f0_a0\", x = $b AS \"_sqlpage_f0_a1\", $c AS c FROM t" ); assert_eq!( delayed_functions, @@ -1316,7 +870,7 @@ mod test { let sql = "select sqlpage.fetch($x)"; let mut ast = parse_stmt(sql, dialect); let db_info = create_test_db_info(SupportedDatabase::Postgres); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info).unwrap(); assert_eq!( parameters, [StmtParam::FunctionCall(SqlPageFunctionCall { @@ -1328,6 +882,43 @@ mod test { } } + #[test] + fn test_parse_sql_unsupported_expr_in_sqlpage_arg() { + let sql = "SELECT sqlpage.link('x', json_build_object('k', c)) FROM (SELECT 1 AS c) t"; + let db_info = create_test_db_info(SupportedDatabase::Postgres); + let mut parsed = parse_sql(&db_info, &PostgreSqlDialect {}, sql).unwrap(); + let stmt = parsed.next().expect("one statement"); + let ParsedStatement::Error(err) = stmt else { + panic!("expected ParsedStatement::Error: {stmt:?}"); + }; + let err_msg = format!("{err:#}"); + assert!( + err_msg.contains("Unsupported sqlpage function argument:"), + "{err_msg}" + ); + assert!(err_msg.contains("\"c\" is an sql expression, which cannot be passed as a nested sqlpage function argument."), "{err_msg}"); + } + + #[test] + fn test_parse_sql_unemulated_function_in_sqlpage_arg() { + let sql = "SELECT sqlpage.link('x', upper('a')) FROM (SELECT 1) t"; + let db_info = create_test_db_info(SupportedDatabase::Postgres); + let mut parsed = parse_sql(&db_info, &PostgreSqlDialect {}, sql).unwrap(); + let stmt = parsed.next().expect("one statement"); + let ParsedStatement::Error(err) = stmt else { + panic!("expected ParsedStatement::Error: {stmt:?}"); + }; + let err_msg = format!("{err:#}"); + assert!( + err_msg.contains("Unsupported sqlpage function argument:"), + "{err_msg}" + ); + assert!( + err_msg.contains("\"upper\" is not a supported sqlpage function"), + "{err_msg}" + ); + } + #[test] fn test_set_variable_to_other_variable() { let sql = "set x = $y"; @@ -1355,31 +946,36 @@ mod test { fn is_own_placeholder() { assert!(ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Postgres), - parameters: vec![] + parameters: vec![], + extract_error: None, } .is_own_placeholder("$1")); assert!(ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Postgres), - parameters: vec![StmtParam::Get("x".to_string())] + parameters: vec![StmtParam::Get("x".to_string())], + extract_error: None, } .is_own_placeholder("$2")); assert!(!ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Postgres), - parameters: vec![] + parameters: vec![], + extract_error: None, } .is_own_placeholder("$2")); assert!(ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Sqlite), - parameters: vec![] + parameters: vec![], + extract_error: None, } .is_own_placeholder("?1")); assert!(!ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Sqlite), - parameters: vec![] + parameters: vec![], + extract_error: None, } .is_own_placeholder("$1")); } @@ -1391,7 +987,7 @@ mod test { &MsSqlDialect {}, ); let db_info = create_test_db_info(SupportedDatabase::Mssql); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info).unwrap(); assert_eq!( ast.to_string(), "SELECT CONCAT('', CAST(@p1 AS VARCHAR(MAX))) FROM [a schema].[a table]" @@ -1740,7 +1336,8 @@ mod test { let stmt = parse_single_statement(&mut parser, &db_info, sql); if let Some(ParsedStatement::Error(err)) = stmt { assert!( - err.to_string().contains("Invalid SQLPage function call"), + err.to_string() + .contains("Unsupported sqlpage function argument:"), "Expected error for invalid function, got: {err}" ); } else { diff --git a/src/webserver/database/sql/parameter_extraction.rs b/src/webserver/database/sql/parameter_extraction.rs new file mode 100644 index 00000000..7386880f --- /dev/null +++ b/src/webserver/database/sql/parameter_extraction.rs @@ -0,0 +1,583 @@ +use super::super::{DbInfo, SupportedDatabase}; +use super::{is_sqlpage_func, sqlpage_func_name}; +use crate::webserver::database::sqlpage_functions::func_call_to_param; +use crate::webserver::database::syntax_tree::StmtParam; +use sqlparser::ast::{ + BinaryOperator, CastKind, CharacterLength, DataType, Expr, Function, FunctionArg, + FunctionArgExpr, FunctionArgumentList, FunctionArguments, Ident, ObjectName, ObjectNamePart, + Spanned, Statement, Value, ValueWithSpan, Visit, VisitMut, Visitor, VisitorMut, +}; +use sqlx::any::AnyKind; +use std::ops::ControlFlow; + +pub(super) struct ParameterExtractor { + pub(super) db_info: DbInfo, + pub(super) parameters: Vec, + pub(super) extract_error: Option, +} + +#[derive(Debug)] +pub(crate) enum DbPlaceHolder { + PrefixedNumber { prefix: &'static str }, + Positional { placeholder: &'static str }, +} + +pub(crate) const DB_PLACEHOLDERS: [(AnyKind, DbPlaceHolder); 5] = [ + ( + AnyKind::Sqlite, + DbPlaceHolder::PrefixedNumber { prefix: "?" }, + ), + ( + AnyKind::Postgres, + DbPlaceHolder::PrefixedNumber { prefix: "$" }, + ), + ( + AnyKind::MySql, + DbPlaceHolder::Positional { placeholder: "?" }, + ), + ( + AnyKind::Mssql, + DbPlaceHolder::PrefixedNumber { prefix: "@p" }, + ), + ( + AnyKind::Odbc, + DbPlaceHolder::Positional { placeholder: "?" }, + ), +]; + +/// For positional parameters, we use a temporary placeholder during parameter extraction, +/// And then replace it with the actual placeholder during statement rewriting. +pub(crate) const TEMP_PLACEHOLDER_PREFIX: &str = "@SQLPAGE_TEMP"; + +fn get_placeholder_prefix(kind: AnyKind) -> &'static str { + if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = DB_PLACEHOLDERS + .iter() + .find(|(placeholder_kind, _prefix)| *placeholder_kind == kind) + { + prefix + } else { + TEMP_PLACEHOLDER_PREFIX + } +} + +impl ParameterExtractor { + pub(super) fn extract_parameters( + sql_ast: &mut Statement, + db_info: DbInfo, + ) -> anyhow::Result> { + let mut this = Self { + db_info, + parameters: vec![], + extract_error: None, + }; + let _ = sql_ast.visit(&mut this); + if let Some(e) = this.extract_error { + return Err(e); + } + Ok(this.parameters) + } + + fn replace_with_placeholder(&mut self, value: &mut Expr, param: StmtParam) { + let placeholder = + if let Some(existing_idx) = self.parameters.iter().position(|p| *p == param) { + // Parameter already exists, use its index + self.make_placeholder_for_index(existing_idx + 1) + } else { + // New parameter, add it to the list + let placeholder = self.make_placeholder(); + log::trace!("Replacing {param} with {placeholder}"); + self.parameters.push(param); + placeholder + }; + *value = placeholder; + } + + fn make_placeholder_for_index(&self, index: usize) -> Expr { + let name = make_tmp_placeholder(self.db_info.kind, index); + let data_type = match self.db_info.database_type { + SupportedDatabase::MySql => DataType::Char(None), + SupportedDatabase::Mssql => DataType::Varchar(Some(CharacterLength::Max)), + SupportedDatabase::Postgres | SupportedDatabase::Sqlite => DataType::Text, + SupportedDatabase::Oracle => DataType::Varchar(Some(CharacterLength::IntegerLength { + length: 4000, + unit: None, + })), + _ => DataType::Varchar(None), + }; + let value = Expr::value(Value::Placeholder(name)); + Expr::Cast { + expr: Box::new(value), + data_type, + format: None, + kind: CastKind::Cast, + } + } + + fn make_placeholder(&self) -> Expr { + self.make_placeholder_for_index(self.parameters.len() + 1) + } + + pub(super) fn is_own_placeholder(&self, param: &str) -> bool { + let prefix = get_placeholder_prefix(self.db_info.kind); + if let Some(param) = param.strip_prefix(prefix) { + if let Ok(index) = param.parse::() { + return index <= self.parameters.len() + 1; + } + } + false + } +} + +struct InvalidFunctionFinder; +impl Visitor for InvalidFunctionFinder { + type Break = (String, Vec); + fn pre_visit_expr(&mut self, value: &Expr) -> ControlFlow { + match value { + Expr::Function(Function { + name: ObjectName(func_name_parts), + args: + FunctionArguments::List(FunctionArgumentList { + args, + duplicate_treatment: None, + .. + }), + .. + }) if is_sqlpage_func(func_name_parts) => { + let func_name = sqlpage_func_name(func_name_parts); + let arguments = args.clone(); + return ControlFlow::Break((func_name.to_string(), arguments)); + } + _ => (), + } + ControlFlow::Continue(()) + } +} + +pub(super) fn validate_function_calls(stmt: &Statement) -> anyhow::Result<()> { + let mut finder = InvalidFunctionFinder; + if let ControlFlow::Break((func_name, mut args)) = stmt.visit(&mut finder) { + let ctx = ParamExtractContext { + parent_func: Some(func_name.clone()), + }; + function_args_to_stmt_params(&mut args, &ctx)?; + + let args_str = FormatArguments(&args); + let error_msg = format!( + "Invalid SQLPage function call: sqlpage.{func_name}({args_str})\n\n\ + Arbitrary SQL expressions as function arguments are not supported.\n\n\ + SQLPage functions can either:\n\ + 1. Run BEFORE the query (to provide input values)\n\ + 2. Run AFTER the query (to process the results)\n\ + But they can't run DURING the query - the database doesn't know how to call them!\n\n\ + To fix this, you can either:\n\ + 1. Store the function argument in a variable first:\n\ + SET {func_name}_arg = ...;\n\ + SET {func_name}_result = sqlpage.{func_name}(${func_name}_arg);\n\ + SELECT * FROM example WHERE xxx = ${func_name}_result;\n\n\ + 2. Or move the function to the top level to process results:\n\ + SELECT sqlpage.{func_name}(...) FROM example;" + ); + Err(anyhow::anyhow!(error_msg)) + } else { + Ok(()) + } +} + +/** This is a helper struct to format a list of arguments for an error message. */ +struct FormatArguments<'a>(&'a [FunctionArg]); +impl std::fmt::Display for FormatArguments<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut args = self.0.iter(); + if let Some(arg) = args.next() { + write!(f, "{arg}")?; + } + for arg in args { + write!(f, ", {arg}")?; + } + Ok(()) + } +} + +#[derive(Clone, Default)] +pub(crate) struct ParamExtractContext { + pub parent_func: Option, +} + +impl ParamExtractContext { + fn with_parent(parent: &str) -> Self { + Self { + parent_func: Some(parent.to_string()), + } + } + + fn build_error(&self, e: &ExprToParamError, arguments: &[FunctionArg]) -> SqlPageFunctionError { + let line = e.line.unwrap_or(0); + let func_name = self.parent_func.as_deref().unwrap_or("unknown").to_string(); + let arguments_str = FormatArguments(arguments).to_string(); + + let reason = match &e.kind { + ExprToParamErrorKind::UnsupportedExpr { summary } => { + format!("\"{summary}\" is an sql expression, which cannot be passed as a nested sqlpage function argument.") + } + ExprToParamErrorKind::UnemulatedFunction { name } => { + format!("\"{name}\" is not a supported sqlpage function. Only a few basic sql functions like concat or json_object can be used inside sqlpage functions.") + } + ExprToParamErrorKind::NamedArgs => "Named function arguments are not supported.\n\ + Please use positional arguments only." + .to_string(), + }; + + SqlPageFunctionError { + line, + func_name, + arguments_str, + reason, + } + } +} + +#[derive(Debug)] +pub(crate) struct SqlPageFunctionError { + pub line: u64, + pub func_name: String, + pub arguments_str: String, + pub reason: String, +} + +impl std::fmt::Display for SqlPageFunctionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Unsupported sqlpage function argument:\n\ +sqlpage.{func}({args_str})\n\n\ +{reason}\n\n\ +SQLPage functions can either:\n\ +1. Run BEFORE the query (to provide input values)\n\ +2. Run AFTER the query (to process the results)\n\ +But they can't run DURING the query - the database doesn't know how to call them!\n\n\ +To fix this, you can either:\n\ +1. Store the function argument in a variable first:\n\ +SET {func}_arg = ...;\n\ +SET {func}_result = sqlpage.{func}(${func}_arg);\n\ +SELECT * FROM example WHERE xxx = ${func}_result;\n\n\ +2. Or move the function to the top level to process results:\n\ +SELECT sqlpage.{func}(...) FROM example;", + func = self.func_name, + args_str = self.arguments_str, + reason = self.reason + ) + } +} +impl std::error::Error for SqlPageFunctionError {} + +#[derive(Debug)] +struct ExprToParamError { + line: Option, + kind: ExprToParamErrorKind, +} + +#[derive(Debug)] +enum ExprToParamErrorKind { + UnsupportedExpr { summary: String }, + UnemulatedFunction { name: String }, + NamedArgs, +} + +fn expr_summary(expr: &Expr) -> String { + match expr { + Expr::CompoundIdentifier(idents) => { + let s = idents + .iter() + .map(|i| i.value.as_str()) + .collect::>() + .join("."); + format!("column/table reference '{s}'") + } + _ => format!("{expr}"), + } +} + +fn function_arg_to_stmt_param( + arg: &mut FunctionArg, + ctx: &ParamExtractContext, +) -> Result { + let expr = function_arg_expr(arg).ok_or(ExprToParamError { + line: None, + kind: ExprToParamErrorKind::NamedArgs, + })?; + expr_to_stmt_param(expr, ctx) +} + +pub(crate) fn function_args_to_stmt_params( + arguments: &mut [FunctionArg], + ctx: &ParamExtractContext, +) -> anyhow::Result> { + let mut params = Vec::with_capacity(arguments.len()); + // We iterate manually so we can pass the entire `arguments` slice to into_error on failure + for arg in arguments.iter_mut() { + match function_arg_to_stmt_param(arg, ctx) { + Ok(p) => params.push(p), + Err(e) => { + let func_err = ctx.build_error(&e, arguments); + return Err(anyhow::Error::new(func_err)); + } + } + } + Ok(params) +} + +fn emulated_func_args_to_param( + func_name: &str, + args: &mut [FunctionArg], + line: u64, +) -> Result { + let inner = ParamExtractContext::with_parent(func_name); + if func_name.eq_ignore_ascii_case("concat") { + let mut concat_args = Vec::with_capacity(args.len()); + for a in args { + concat_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Ok(StmtParam::Concat(concat_args)) + } else if func_name.eq_ignore_ascii_case("json_object") + || func_name.eq_ignore_ascii_case("jsonb_object") + || func_name.eq_ignore_ascii_case("json_build_object") + || func_name.eq_ignore_ascii_case("jsonb_build_object") + { + let mut json_obj_args = Vec::with_capacity(args.len()); + for a in args { + json_obj_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Ok(StmtParam::JsonObject(json_obj_args)) + } else if func_name.eq_ignore_ascii_case("json_array") + || func_name.eq_ignore_ascii_case("jsonb_array") + || func_name.eq_ignore_ascii_case("json_build_array") + || func_name.eq_ignore_ascii_case("jsonb_build_array") + { + let mut json_obj_args = Vec::with_capacity(args.len()); + for a in args { + json_obj_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Ok(StmtParam::JsonArray(json_obj_args)) + } else if func_name.eq_ignore_ascii_case("coalesce") { + let mut coalesce_args = Vec::with_capacity(args.len()); + for a in args { + coalesce_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Ok(StmtParam::Coalesce(coalesce_args)) + } else { + Err(ExprToParamError { + line: Some(line), + kind: ExprToParamErrorKind::UnemulatedFunction { + name: func_name.to_string(), + }, + }) + } +} + +fn expr_to_stmt_param( + arg: &mut Expr, + ctx: &ParamExtractContext, +) -> Result { + let line = arg.span().start.line; + match arg { + Expr::Value(ValueWithSpan { + value: Value::Placeholder(placeholder), + .. + }) => Ok(map_param(std::mem::take(placeholder))), + Expr::Identifier(ident) => extract_ident_param(ident).ok_or_else(|| ExprToParamError { + line: Some(line), + kind: ExprToParamErrorKind::UnsupportedExpr { + summary: expr_summary(arg), + }, + }), + Expr::Function(Function { + name: ObjectName(func_name_parts), + args: + FunctionArguments::List(FunctionArgumentList { + args, + duplicate_treatment: None, + .. + }), + .. + }) if is_sqlpage_func(func_name_parts) => Ok(func_call_to_param( + sqlpage_func_name(func_name_parts), + args.as_mut_slice(), + ctx, + )), + Expr::Value(ValueWithSpan { + value: Value::SingleQuotedString(param_value), + .. + }) => Ok(StmtParam::Literal(std::mem::take(param_value))), + Expr::Value(ValueWithSpan { + value: Value::Number(param_value, _is_long), + .. + }) => Ok(StmtParam::Literal(param_value.clone())), + Expr::Value(ValueWithSpan { + value: Value::Null, .. + }) => Ok(StmtParam::Null), + Expr::BinaryOp { + left, + op: BinaryOperator::StringConcat, + right, + } => { + let left = expr_to_stmt_param(left, ctx)?; + let right = expr_to_stmt_param(right, ctx)?; + Ok(StmtParam::Concat(vec![left, right])) + } + Expr::Function(Function { + name: ObjectName(func_name_parts), + args: + FunctionArguments::List(FunctionArgumentList { + args, + duplicate_treatment: None, + .. + }), + .. + }) if func_name_parts.len() == 1 => { + let func_name = func_name_parts[0] + .as_ident() + .map(|ident| ident.value.as_str()) + .unwrap_or_default(); + emulated_func_args_to_param(func_name, args.as_mut_slice(), line) + } + _ => Err(ExprToParamError { + line: Some(line), + kind: ExprToParamErrorKind::UnsupportedExpr { + summary: expr_summary(arg), + }, + }), + } +} + +fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> { + match arg { + FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => Some(expr), + _ => None, + } +} + +#[inline] +#[must_use] +pub(super) fn make_tmp_placeholder(kind: AnyKind, arg_number: usize) -> String { + let prefix = if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = + DB_PLACEHOLDERS.iter().find(|(db_typ, _)| *db_typ == kind) + { + prefix + } else { + TEMP_PLACEHOLDER_PREFIX + }; + format!("{prefix}{arg_number}") +} + +pub(super) fn extract_ident_param(Ident { value, .. }: &mut Ident) -> Option { + if value.starts_with('$') || value.starts_with(':') { + let name = std::mem::take(value); + Some(map_param(name)) + } else { + None + } +} + +fn map_param(mut name: String) -> StmtParam { + if name.is_empty() { + return StmtParam::PostOrGet(name); + } + let prefix = name.remove(0); + match prefix { + '$' => StmtParam::PostOrGet(name), + ':' => StmtParam::Post(name), + _ => StmtParam::Get(name), + } +} + +impl VisitorMut for ParameterExtractor { + type Break = (); + fn pre_visit_expr(&mut self, value: &mut Expr) -> ControlFlow { + match value { + Expr::Identifier(ident) => { + if let Some(param) = extract_ident_param(ident) { + self.replace_with_placeholder(value, param); + } + } + Expr::Value(ValueWithSpan { + value: Value::Placeholder(param), + .. + }) if !self.is_own_placeholder(param) => + // this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves + { + let name = std::mem::take(param); + self.replace_with_placeholder(value, map_param(name)); + } + Expr::Function(Function { + name: ObjectName(func_name_parts), + args: + FunctionArguments::List(FunctionArgumentList { + args, + duplicate_treatment: None, + .. + }), + filter: None, + null_treatment: None, + over: None, + .. + }) if is_sqlpage_func(func_name_parts) => { + let func_name = sqlpage_func_name(func_name_parts); + log::trace!("Handling builtin function: {func_name}"); + let arguments = std::mem::take(args); + let ctx = ParamExtractContext { + parent_func: Some(func_name.to_string()), + }; + let mut arguments_clone = arguments.clone(); + let param = func_call_to_param(func_name, &mut arguments_clone, &ctx); + if let StmtParam::Error(msg) = ¶m { + log::trace!("Skipping extraction of {func_name} due to: {msg}"); + *args = arguments; + return ControlFlow::Continue(()); + } + self.replace_with_placeholder(value, param); + } + // Replace 'str1' || 'str2' with CONCAT('str1', 'str2') for MSSQL + Expr::BinaryOp { + left, + op: BinaryOperator::StringConcat, + right, + } if self.db_info.database_type == SupportedDatabase::Mssql => { + let left = std::mem::replace(left.as_mut(), Expr::value(Value::Null)); + let right = std::mem::replace(right.as_mut(), Expr::value(Value::Null)); + *value = Expr::Function(Function { + name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new("CONCAT"))]), + args: FunctionArguments::List(FunctionArgumentList { + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(left)), + FunctionArg::Unnamed(FunctionArgExpr::Expr(right)), + ], + duplicate_treatment: None, + clauses: Vec::new(), + }), + parameters: FunctionArguments::None, + over: None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + uses_odbc_syntax: false, + }); + } + Expr::Cast { + kind: kind @ CastKind::DoubleColon, + .. + } if ![ + SupportedDatabase::Postgres, + SupportedDatabase::Snowflake, + SupportedDatabase::Generic, + ] + .contains(&self.db_info.database_type) => + { + log::warn!("Casting with '::' is not supported on your database. \ + For backwards compatibility with older SQLPage versions, we will transform it to CAST(... AS ...)."); + *kind = CastKind::Cast; + } + _ => (), + } + ControlFlow::<()>::Continue(()) + } +} diff --git a/src/webserver/database/sqlpage_functions/mod.rs b/src/webserver/database/sqlpage_functions/mod.rs index 27dc6c07..e5912e07 100644 --- a/src/webserver/database/sqlpage_functions/mod.rs +++ b/src/webserver/database/sqlpage_functions/mod.rs @@ -8,28 +8,17 @@ use sqlparser::ast::FunctionArg; use crate::webserver::http_request_info::{ExecutionContext, RequestInfo}; -use super::sql::function_args_to_stmt_params; +use super::sql::ParamExtractContext; use super::syntax_tree::SqlPageFunctionCall; use super::syntax_tree::StmtParam; -use super::sql::FormatArguments; -use anyhow::Context; - -pub(super) fn func_call_to_param(func_name: &str, arguments: &mut [FunctionArg]) -> StmtParam { - SqlPageFunctionCall::from_func_call(func_name, arguments) - .with_context(|| { - format!( - "Invalid function call: sqlpage.{func_name}({})", - FormatArguments(arguments) - ) - }) - .map_or_else( - |e| StmtParam::Error(format!("{e:#}")), - StmtParam::FunctionCall, - ) -} - -pub(super) fn are_params_extractable(arguments: &[FunctionArg]) -> bool { - let mut mutable_copy = arguments.to_vec(); - function_args_to_stmt_params(&mut mutable_copy).is_ok() +pub(super) fn func_call_to_param( + func_name: &str, + arguments: &mut [FunctionArg], + ctx: &ParamExtractContext, +) -> StmtParam { + SqlPageFunctionCall::from_func_call(func_name, arguments, ctx).map_or_else( + |e| StmtParam::Error(format!("{e:#}")), + StmtParam::FunctionCall, + ) } diff --git a/src/webserver/database/syntax_tree.rs b/src/webserver/database/syntax_tree.rs index c6311689..d01e8598 100644 --- a/src/webserver/database/syntax_tree.rs +++ b/src/webserver/database/syntax_tree.rs @@ -20,7 +20,7 @@ use crate::webserver::http_request_info::ExecutionContext; use crate::webserver::single_or_vec::SingleOrVec; use super::{ - execute_queries::DbConn, sql::function_args_to_stmt_params, + execute_queries::DbConn, sql::function_args_to_stmt_params, sql::ParamExtractContext, sqlpage_functions::functions::SqlPageFunctionName, }; use anyhow::Context as _; @@ -101,9 +101,13 @@ pub struct SqlPageFunctionCall { } impl SqlPageFunctionCall { - pub fn from_func_call(func_name: &str, arguments: &mut [FunctionArg]) -> anyhow::Result { + pub fn from_func_call( + func_name: &str, + arguments: &mut [FunctionArg], + ctx: &ParamExtractContext, + ) -> anyhow::Result { let function = SqlPageFunctionName::from_str(func_name)?; - let arguments = function_args_to_stmt_params(arguments)?; + let arguments = function_args_to_stmt_params(arguments, ctx)?; Ok(Self { function, arguments, diff --git a/tests/sql_test_files/component_rendering/error_arbitrary_SQL_expressions_as_function_arguments_are_not_supported.sql b/tests/sql_test_files/component_rendering/error_is_not_a_supported_sqlpage_function.sql similarity index 100% rename from tests/sql_test_files/component_rendering/error_arbitrary_SQL_expressions_as_function_arguments_are_not_supported.sql rename to tests/sql_test_files/component_rendering/error_is_not_a_supported_sqlpage_function.sql