From c67951bdbcd337e81b43323d34d154b1649da03b Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 28 Feb 2026 18:27:17 +0100 Subject: [PATCH 01/10] improve sqlpage function argument warnings with source context Made-with: Cursor --- src/webserver/database/sql.rs | 212 ++++++++++++------ .../database/sqlpage_functions/mod.rs | 11 +- src/webserver/database/syntax_tree.rs | 10 +- 3 files changed, 157 insertions(+), 76 deletions(-) diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 2efe5452..6cfbb3ca 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -43,7 +43,8 @@ impl ParsedSqlFile { source_path.display(), dialect ); - let parsed_statements = match parse_sql(&db.info, dialect.as_ref(), sql) { + let parsed_statements = match parse_sql(&db.info, dialect.as_ref(), sql, Some(source_path)) + { Ok(parsed) => parsed, Err(err) => return Self::from_err(err, source_path), }; @@ -135,6 +136,7 @@ fn parse_sql<'a>( db_info: &'a DbInfo, dialect: &'a dyn Dialect, sql: &'a str, + source_path: Option<&'a Path>, ) -> anyhow::Result + 'a> { log::trace!("Parsing {} SQL: {sql}", db_info.dbms_name); @@ -151,7 +153,7 @@ fn parse_sql<'a>( // Return the first error and ignore the rest return None; } - let statement = parse_single_statement(&mut parser, db_info, sql); + let statement = parse_single_statement(&mut parser, db_info, sql, source_path); log::debug!("Parsed statement: {statement:?}"); if let Some(ParsedStatement::Error(_)) = &statement { has_error = true; @@ -185,6 +187,7 @@ fn parse_single_statement( parser: &mut Parser<'_>, db_info: &DbInfo, source_sql: &str, + source_path: Option<&Path>, ) -> Option { if parser.peek_token() == EOF { return None; @@ -197,7 +200,8 @@ fn parse_single_statement( while parser.consume_token(&SemiColon) { semicolon = true; } - let mut params = ParameterExtractor::extract_parameters(&mut stmt, db_info.clone()); + let mut params = + ParameterExtractor::extract_parameters(&mut stmt, 
db_info.clone(), source_path); let dbms = db_info.database_type; if let Some(parsed) = extract_set_variable(&mut stmt, &mut params, db_info) { return Some(parsed); @@ -551,6 +555,7 @@ fn extract_set_variable( struct ParameterExtractor { db_info: DbInfo, parameters: Vec, + source_path: Option, } #[derive(Debug)] @@ -601,10 +606,12 @@ impl ParameterExtractor { fn extract_parameters( sql_ast: &mut sqlparser::ast::Statement, db_info: DbInfo, + source_path: Option<&Path>, ) -> Vec { let mut this = Self { db_info, parameters: vec![], + source_path: source_path.map(PathBuf::from), }; let _ = sql_ast.visit(&mut this); this.parameters @@ -726,17 +733,43 @@ impl std::fmt::Display for FormatArguments<'_> { } } -pub(super) fn function_arg_to_stmt_param(arg: &mut FunctionArg) -> Option { - function_arg_expr(arg).and_then(expr_to_stmt_param) +#[derive(Clone, Default)] +pub(super) struct ParamWarnContext { + pub source_path: Option, + pub parent_func: Option, +} + +impl ParamWarnContext { + pub(super) fn with_parent(&self, parent: &str) -> Self { + Self { + source_path: self.source_path.clone(), + parent_func: Some(parent.to_string()), + } + } + + pub(super) fn location_prefix(&self, line: u64) -> String { + match &self.source_path { + Some(p) => format!("{}:{}: ", p.display(), line), + None => String::new(), + } + } +} + +pub(super) fn function_arg_to_stmt_param( + arg: &mut FunctionArg, + ctx: &ParamWarnContext, +) -> Option { + function_arg_expr(arg).and_then(|e| expr_to_stmt_param(e, ctx)) } pub(super) fn function_args_to_stmt_params( arguments: &mut [FunctionArg], + ctx: &ParamWarnContext, ) -> anyhow::Result> { arguments .iter_mut() .map(|arg| { - function_arg_to_stmt_param(arg) + function_arg_to_stmt_param(arg, ctx) .ok_or_else(|| anyhow::anyhow!("Passing \"{arg}\" as a function argument is not supported.\n\n\ The only supported sqlpage function argument types are : \n\ - variables (such as $my_variable), \n\ @@ -752,7 +785,51 @@ pub(super) fn function_args_to_stmt_params( 
.collect::>>() } -fn expr_to_stmt_param(arg: &mut Expr) -> Option { +fn emulated_func_args_to_param( + func_name: &str, + args: &mut [FunctionArg], + ctx: &ParamWarnContext, +) -> Option { + let inner = ctx.with_parent(func_name); + if func_name.eq_ignore_ascii_case("concat") { + let mut concat_args = Vec::with_capacity(args.len()); + for a in args { + concat_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Some(StmtParam::Concat(concat_args)) + } else if func_name.eq_ignore_ascii_case("json_object") + || func_name.eq_ignore_ascii_case("jsonb_object") + || func_name.eq_ignore_ascii_case("json_build_object") + || func_name.eq_ignore_ascii_case("jsonb_build_object") + { + let mut json_obj_args = Vec::with_capacity(args.len()); + for a in args { + json_obj_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Some(StmtParam::JsonObject(json_obj_args)) + } else if func_name.eq_ignore_ascii_case("json_array") + || func_name.eq_ignore_ascii_case("jsonb_array") + || func_name.eq_ignore_ascii_case("json_build_array") + || func_name.eq_ignore_ascii_case("jsonb_build_array") + { + let mut json_obj_args = Vec::with_capacity(args.len()); + for a in args { + json_obj_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Some(StmtParam::JsonArray(json_obj_args)) + } else if func_name.eq_ignore_ascii_case("coalesce") { + let mut coalesce_args = Vec::with_capacity(args.len()); + for a in args { + coalesce_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Some(StmtParam::Coalesce(coalesce_args)) + } else { + log::warn!("SQLPage cannot emulate the following function: {func_name}"); + None + } +} + +fn expr_to_stmt_param(arg: &mut Expr, ctx: &ParamWarnContext) -> Option { match arg { Expr::Value(ValueWithSpan { value: Value::Placeholder(placeholder), @@ -771,6 +848,7 @@ fn expr_to_stmt_param(arg: &mut Expr) -> Option { }) if is_sqlpage_func(func_name_parts) => Some(func_call_to_param( sqlpage_func_name(func_name_parts), args.as_mut_slice(), + ctx, )), 
Expr::Value(ValueWithSpan { value: Value::SingleQuotedString(param_value), @@ -784,19 +862,14 @@ fn expr_to_stmt_param(arg: &mut Expr) -> Option { value: Value::Null, .. }) => Some(StmtParam::Null), Expr::BinaryOp { - // 'str1' || 'str2' left, op: BinaryOperator::StringConcat, right, } => { - let left = expr_to_stmt_param(left)?; - let right = expr_to_stmt_param(right)?; + let left = expr_to_stmt_param(left, ctx)?; + let right = expr_to_stmt_param(right, ctx)?; Some(StmtParam::Concat(vec![left, right])) } - // SQLPage can evaluate some functions natively without sending them to the database: - // CONCAT('str1', 'str2', ...) - // json_object('key1', 'value1', 'key2', 'value2', ...) - // json_array('value1', 'value2', ...) Expr::Function(Function { name: ObjectName(func_name_parts), args: @@ -811,45 +884,34 @@ fn expr_to_stmt_param(arg: &mut Expr) -> Option { .as_ident() .map(|ident| ident.value.as_str()) .unwrap_or_default(); - if func_name.eq_ignore_ascii_case("concat") { - let mut concat_args = Vec::with_capacity(args.len()); - for arg in args { - concat_args.push(function_arg_to_stmt_param(arg)?); - } - Some(StmtParam::Concat(concat_args)) - } else if func_name.eq_ignore_ascii_case("json_object") - || func_name.eq_ignore_ascii_case("jsonb_object") - || func_name.eq_ignore_ascii_case("json_build_object") - || func_name.eq_ignore_ascii_case("jsonb_build_object") - { - let mut json_obj_args = Vec::with_capacity(args.len()); - for arg in args { - json_obj_args.push(function_arg_to_stmt_param(arg)?); - } - Some(StmtParam::JsonObject(json_obj_args)) - } else if func_name.eq_ignore_ascii_case("json_array") - || func_name.eq_ignore_ascii_case("jsonb_array") - || func_name.eq_ignore_ascii_case("json_build_array") - || func_name.eq_ignore_ascii_case("jsonb_build_array") - { - let mut json_obj_args = Vec::with_capacity(args.len()); - for arg in args { - json_obj_args.push(function_arg_to_stmt_param(arg)?); - } - Some(StmtParam::JsonArray(json_obj_args)) - } else if 
func_name.eq_ignore_ascii_case("coalesce") { - let mut coalesce_args = Vec::with_capacity(args.len()); - for arg in args { - coalesce_args.push(function_arg_to_stmt_param(arg)?); - } - Some(StmtParam::Coalesce(coalesce_args)) - } else { - log::warn!("SQLPage cannot emulate the following function: {func_name}"); - None - } + emulated_func_args_to_param(func_name, args.as_mut_slice(), ctx) + } + Expr::CompoundIdentifier(idents) => { + let display = idents + .iter() + .map(|i| i.value.as_str()) + .collect::>() + .join("."); + let line = arg.span().start.line; + let loc = ctx.location_prefix(line); + let func = ctx.parent_func.as_deref().unwrap_or("sqlpage.*"); + log::warn!( + "{loc}column/table reference '{display}' not allowed as argument to {func} (evaluated at request time inside sqlpage.*). \ + Allowed here: $vars, literals, other sqlpage.* calls. \ + Column refs allowed only in top-level SELECT (e.g. SELECT sqlpage.link(..., json_build_object(..., col)) FROM t). \ + Non-fatal: execution continues, argument passed through to DB." + ); + None } _ => { - log::warn!("Unsupported function argument: {arg}"); + let line = arg.span().start.line; + let loc = ctx.location_prefix(line); + let func = ctx.parent_func.as_deref().unwrap_or("emulated"); + log::warn!( + "{loc}unsupported argument {arg} to {func}. \ + Emulated functions (json_build_object, json_object, json_array, concat, coalesce) accept only: $vars, literals, sqlpage.* calls, or CONCAT of those. \ + Non-fatal: execution continues, argument passed through to DB." 
+ ); None } } @@ -923,7 +985,11 @@ impl VisitorMut for ParameterExtractor { let func_name = sqlpage_func_name(func_name_parts); log::trace!("Handling builtin function: {func_name}"); let mut arguments = std::mem::take(args); - let param = func_call_to_param(func_name, &mut arguments); + let ctx = ParamWarnContext { + source_path: self.source_path.clone(), + parent_func: Some(func_name.to_string()), + }; + let param = func_call_to_param(func_name, &mut arguments, &ctx); self.replace_with_placeholder(value, param); } // Replace 'str1' || 'str2' with CONCAT('str1', 'str2') for MSSQL @@ -1147,7 +1213,7 @@ mod test { let mut ast = parse_postgres_stmt("select $a from t where $x > $a OR $x = sqlpage.cookie('cookoo')"); let db_info = create_test_db_info(SupportedDatabase::Postgres); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None); // $a -> $1 // $x -> $2 // sqlpage.cookie(...) -> $3 @@ -1172,7 +1238,7 @@ mod test { fn test_statement_rewrite_sqlite() { let mut ast = parse_stmt("select $x, :y from t", &SQLiteDialect {}); let db_info = create_test_db_info(SupportedDatabase::Sqlite); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None); assert_eq!( ast.to_string(), "SELECT CAST(?1 AS TEXT), CAST(?2 AS TEXT) FROM t" @@ -1219,7 +1285,7 @@ mod test { let sql = "select {'a': 1, 'b': 2} as payload"; let db_info = create_test_db_info(dbms); - let mut parsed = parse_sql(&db_info, dialect.as_ref(), sql).unwrap(); + let mut parsed = parse_sql(&db_info, dialect.as_ref(), sql, None).unwrap(); let stmt = parsed.next().expect("expected one statement"); assert!( !matches!(stmt, ParsedStatement::Error(_)), @@ -1227,7 +1293,7 @@ mod test { ); let pg_info = create_test_db_info(SupportedDatabase::Postgres); - let mut parsed = parse_sql(&pg_info, &PostgreSqlDialect {}, 
sql).unwrap(); + let mut parsed = parse_sql(&pg_info, &PostgreSqlDialect {}, sql, None).unwrap(); let stmt = parsed.next().expect("expected one statement"); assert!( matches!(stmt, ParsedStatement::Error(_)), @@ -1272,7 +1338,7 @@ mod test { // Otherwise the statement parameters will be bound to the wrong arguments let sql = "select $a as a, sqlpage.exec('xxx', x = $b) as b, $c as c from t"; let db_info = create_test_db_info(SupportedDatabase::Postgres); - let all = parse_sql(&db_info, &PostgreSqlDialect {}, sql) + let all = parse_sql(&db_info, &PostgreSqlDialect {}, sql, None) .unwrap() .collect::>(); assert_eq!(all.len(), 1); @@ -1316,7 +1382,7 @@ mod test { let sql = "select sqlpage.fetch($x)"; let mut ast = parse_stmt(sql, dialect); let db_info = create_test_db_info(SupportedDatabase::Postgres); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None); assert_eq!( parameters, [StmtParam::FunctionCall(SqlPageFunctionCall { @@ -1334,7 +1400,7 @@ mod test { for &(dialect, dbms) in ALL_DIALECTS { let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); let db_info = create_test_db_info(dbms); - match parse_single_statement(&mut parser, &db_info, sql) { + match parse_single_statement(&mut parser, &db_info, sql, None) { Some(ParsedStatement::StaticSimpleSet { variable, value }) => { assert_eq!( variable, @@ -1355,31 +1421,36 @@ mod test { fn is_own_placeholder() { assert!(ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Postgres), - parameters: vec![] + parameters: vec![], + source_path: None, } .is_own_placeholder("$1")); assert!(ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Postgres), - parameters: vec![StmtParam::Get("x".to_string())] + parameters: vec![StmtParam::Get("x".to_string())], + source_path: None, } .is_own_placeholder("$2")); assert!(!ParameterExtractor { db_info: 
create_test_db_info(SupportedDatabase::Postgres), - parameters: vec![] + parameters: vec![], + source_path: None, } .is_own_placeholder("$2")); assert!(ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Sqlite), - parameters: vec![] + parameters: vec![], + source_path: None, } .is_own_placeholder("?1")); assert!(!ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Sqlite), - parameters: vec![] + parameters: vec![], + source_path: None, } .is_own_placeholder("$1")); } @@ -1391,7 +1462,7 @@ mod test { &MsSqlDialect {}, ); let db_info = create_test_db_info(SupportedDatabase::Mssql); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None); assert_eq!( ast.to_string(), "SELECT CONCAT('', CAST(@p1 AS VARCHAR(MAX))) FROM [a schema].[a table]" @@ -1444,7 +1515,8 @@ mod test { } else { create_test_db_info(SupportedDatabase::Generic) }; - let parsed: Vec = parse_sql(&db_info, dialect, sql).unwrap().collect(); + let parsed: Vec = + parse_sql(&db_info, dialect, sql, None).unwrap().collect(); match &parsed[..] { [ParsedStatement::StaticSimpleSelect(q)] => assert_eq!( q, @@ -1481,7 +1553,7 @@ mod test { for &(dialect, dbms) in ALL_DIALECTS { let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); let db_info = create_test_db_info(dbms); - let stmt = parse_single_statement(&mut parser, &db_info, sql); + let stmt = parse_single_statement(&mut parser, &db_info, sql, None); if let Some(ParsedStatement::SetVariable { variable, value: StmtWithParams { query, params, .. 
}, @@ -1506,7 +1578,7 @@ mod test { for &(dialect, dbms) in ALL_DIALECTS { let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); let db_info = create_test_db_info(dbms); - match parse_single_statement(&mut parser, &db_info, sql) { + match parse_single_statement(&mut parser, &db_info, sql, None) { Some(ParsedStatement::StaticSimpleSet { variable: StmtParam::PostOrGet(var_name), value: SimpleSelectValue::Static(value), @@ -1610,7 +1682,7 @@ mod test { for &(dialect, dbms) in ALL_DIALECTS { let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); let db_info = create_test_db_info(dbms); - let stmt = parse_single_statement(&mut parser, &db_info, sql); + let stmt = parse_single_statement(&mut parser, &db_info, sql, None); let Some(ParsedStatement::SetVariable { variable, value: @@ -1737,7 +1809,7 @@ mod test { for &(dialect, dbms) in ALL_DIALECTS { let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); let db_info = create_test_db_info(dbms); - let stmt = parse_single_statement(&mut parser, &db_info, sql); + let stmt = parse_single_statement(&mut parser, &db_info, sql, None); if let Some(ParsedStatement::Error(err)) = stmt { assert!( err.to_string().contains("Invalid SQLPage function call"), diff --git a/src/webserver/database/sqlpage_functions/mod.rs b/src/webserver/database/sqlpage_functions/mod.rs index 27dc6c07..386732bd 100644 --- a/src/webserver/database/sqlpage_functions/mod.rs +++ b/src/webserver/database/sqlpage_functions/mod.rs @@ -9,14 +9,19 @@ use sqlparser::ast::FunctionArg; use crate::webserver::http_request_info::{ExecutionContext, RequestInfo}; use super::sql::function_args_to_stmt_params; +use super::sql::ParamWarnContext; use super::syntax_tree::SqlPageFunctionCall; use super::syntax_tree::StmtParam; use super::sql::FormatArguments; use anyhow::Context; -pub(super) fn func_call_to_param(func_name: &str, arguments: &mut [FunctionArg]) -> StmtParam { - SqlPageFunctionCall::from_func_call(func_name, arguments) +pub(super) 
fn func_call_to_param( + func_name: &str, + arguments: &mut [FunctionArg], + ctx: &ParamWarnContext, +) -> StmtParam { + SqlPageFunctionCall::from_func_call(func_name, arguments, ctx) .with_context(|| { format!( "Invalid function call: sqlpage.{func_name}({})", @@ -31,5 +36,5 @@ pub(super) fn func_call_to_param(func_name: &str, arguments: &mut [FunctionArg]) pub(super) fn are_params_extractable(arguments: &[FunctionArg]) -> bool { let mut mutable_copy = arguments.to_vec(); - function_args_to_stmt_params(&mut mutable_copy).is_ok() + function_args_to_stmt_params(&mut mutable_copy, &ParamWarnContext::default()).is_ok() } diff --git a/src/webserver/database/syntax_tree.rs b/src/webserver/database/syntax_tree.rs index c6311689..d846fae7 100644 --- a/src/webserver/database/syntax_tree.rs +++ b/src/webserver/database/syntax_tree.rs @@ -20,7 +20,7 @@ use crate::webserver::http_request_info::ExecutionContext; use crate::webserver::single_or_vec::SingleOrVec; use super::{ - execute_queries::DbConn, sql::function_args_to_stmt_params, + execute_queries::DbConn, sql::function_args_to_stmt_params, sql::ParamWarnContext, sqlpage_functions::functions::SqlPageFunctionName, }; use anyhow::Context as _; @@ -101,9 +101,13 @@ pub struct SqlPageFunctionCall { } impl SqlPageFunctionCall { - pub fn from_func_call(func_name: &str, arguments: &mut [FunctionArg]) -> anyhow::Result { + pub fn from_func_call( + func_name: &str, + arguments: &mut [FunctionArg], + ctx: &ParamWarnContext, + ) -> anyhow::Result { let function = SqlPageFunctionName::from_str(func_name)?; - let arguments = function_args_to_stmt_params(arguments)?; + let arguments = function_args_to_stmt_params(arguments, ctx)?; Ok(Self { function, arguments, From 36102889a7d8634a120fc217abe55b3280c16766 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 28 Feb 2026 21:28:43 +0100 Subject: [PATCH 02/10] Param extraction: Result-based errors, single caller message, no CompoundIdentifier special case - expr_to_stmt_param returns Result; 
error carries only line + kind (UnsupportedExpr, UnemulatedFunction, NamedArgs) - function_args_to_stmt_params logs one formatted message (ctx.format_param_error) then returns Err - Single unsupported-expr arm; expr_summary() used for description - Rename ParamWarnContext to ParamExtractContext Made-with: Cursor --- src/webserver/database/sql.rs | 187 +++++++++++------- .../database/sqlpage_functions/mod.rs | 6 +- src/webserver/database/syntax_tree.rs | 4 +- 3 files changed, 122 insertions(+), 75 deletions(-) diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 6cfbb3ca..4365e082 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -734,12 +734,12 @@ impl std::fmt::Display for FormatArguments<'_> { } #[derive(Clone, Default)] -pub(super) struct ParamWarnContext { +pub(super) struct ParamExtractContext { pub source_path: Option, pub parent_func: Option, } -impl ParamWarnContext { +impl ParamExtractContext { pub(super) fn with_parent(&self, parent: &str) -> Self { Self { source_path: self.source_path.clone(), @@ -753,34 +753,94 @@ impl ParamWarnContext { None => String::new(), } } + + pub(super) fn format_param_error(&self, e: &ExprToParamError) -> String { + let loc = e.line.map(|l| self.location_prefix(l)).unwrap_or_default(); + let func = self.parent_func.as_deref().unwrap_or("sqlpage.*"); + match &e.kind { + ExprToParamErrorKind::UnsupportedExpr { summary } => { + format!( + "{loc}argument to {func}: {summary}. \ + Evaluated at request time; allowed: $vars, literals, sqlpage.* calls, CONCAT of those. \ + Use a top-level SELECT for column refs, or SET then $var." + ) + } + ExprToParamErrorKind::UnemulatedFunction { name } => { + format!( + "{loc}argument to {func}: '{name}' is not emulated here. \ + Emulated: concat, json_build_object, json_object, json_array, coalesce. \ + Use one of these or evaluate in the DB (e.g. top-level SELECT)." 
+ ) + } + ExprToParamErrorKind::NamedArgs => { + format!( + "{loc}argument to {func}: named function arguments are not supported. \ + Use positional arguments only." + ) + } + } + } +} + +#[derive(Debug)] +pub(super) struct ExprToParamError { + pub(super) line: Option, + pub(super) kind: ExprToParamErrorKind, +} + +#[derive(Debug)] +pub(super) enum ExprToParamErrorKind { + UnsupportedExpr { summary: String }, + UnemulatedFunction { name: String }, + NamedArgs, +} + +impl ExprToParamError { + fn kind_description(&self) -> &str { + match &self.kind { + ExprToParamErrorKind::UnsupportedExpr { .. } => "unsupported expression", + ExprToParamErrorKind::UnemulatedFunction { .. } => "function not emulated", + ExprToParamErrorKind::NamedArgs => "named arguments not supported", + } + } +} + +fn expr_summary(expr: &Expr) -> String { + match expr { + Expr::CompoundIdentifier(idents) => { + let s = idents + .iter() + .map(|i| i.value.as_str()) + .collect::>() + .join("."); + format!("column/table reference '{s}'") + } + _ => format!("{expr}"), + } } pub(super) fn function_arg_to_stmt_param( arg: &mut FunctionArg, - ctx: &ParamWarnContext, -) -> Option { - function_arg_expr(arg).and_then(|e| expr_to_stmt_param(e, ctx)) + ctx: &ParamExtractContext, +) -> Result { + let expr = function_arg_expr(arg).ok_or(ExprToParamError { + line: None, + kind: ExprToParamErrorKind::NamedArgs, + })?; + expr_to_stmt_param(expr, ctx) } pub(super) fn function_args_to_stmt_params( arguments: &mut [FunctionArg], - ctx: &ParamWarnContext, + ctx: &ParamExtractContext, ) -> anyhow::Result> { arguments .iter_mut() .map(|arg| { - function_arg_to_stmt_param(arg, ctx) - .ok_or_else(|| anyhow::anyhow!("Passing \"{arg}\" as a function argument is not supported.\n\n\ - The only supported sqlpage function argument types are : \n\ - - variables (such as $my_variable), \n\ - - other sqlpage function calls (such as sqlpage.cookie('my_cookie')), \n\ - - literal strings (such as 'my_string'), \n\ - - concatenations 
of the above (such as CONCAT(x, y)).\n\n\ - Arbitrary SQL expressions as function arguments are not supported.\n\ - Try executing the SQL expression in a separate SET expression, then passing it to the function:\n\n\ - set my_parameter = {arg}; \n\ - SELECT sqlpage.my_function($my_parameter);\n\n\ - ")) + function_arg_to_stmt_param(arg, ctx).map_err(|e| { + log::warn!("{}", ctx.format_param_error(&e)); + anyhow::anyhow!("unsupported function argument: {}", e.kind_description()) + }) }) .collect::>>() } @@ -788,15 +848,16 @@ pub(super) fn function_args_to_stmt_params( fn emulated_func_args_to_param( func_name: &str, args: &mut [FunctionArg], - ctx: &ParamWarnContext, -) -> Option { + ctx: &ParamExtractContext, + line: u64, +) -> Result { let inner = ctx.with_parent(func_name); if func_name.eq_ignore_ascii_case("concat") { let mut concat_args = Vec::with_capacity(args.len()); for a in args { concat_args.push(function_arg_to_stmt_param(a, &inner)?); } - Some(StmtParam::Concat(concat_args)) + Ok(StmtParam::Concat(concat_args)) } else if func_name.eq_ignore_ascii_case("json_object") || func_name.eq_ignore_ascii_case("jsonb_object") || func_name.eq_ignore_ascii_case("json_build_object") @@ -806,7 +867,7 @@ fn emulated_func_args_to_param( for a in args { json_obj_args.push(function_arg_to_stmt_param(a, &inner)?); } - Some(StmtParam::JsonObject(json_obj_args)) + Ok(StmtParam::JsonObject(json_obj_args)) } else if func_name.eq_ignore_ascii_case("json_array") || func_name.eq_ignore_ascii_case("jsonb_array") || func_name.eq_ignore_ascii_case("json_build_array") @@ -816,26 +877,39 @@ fn emulated_func_args_to_param( for a in args { json_obj_args.push(function_arg_to_stmt_param(a, &inner)?); } - Some(StmtParam::JsonArray(json_obj_args)) + Ok(StmtParam::JsonArray(json_obj_args)) } else if func_name.eq_ignore_ascii_case("coalesce") { let mut coalesce_args = Vec::with_capacity(args.len()); for a in args { coalesce_args.push(function_arg_to_stmt_param(a, &inner)?); } - 
Some(StmtParam::Coalesce(coalesce_args)) + Ok(StmtParam::Coalesce(coalesce_args)) } else { - log::warn!("SQLPage cannot emulate the following function: {func_name}"); - None + Err(ExprToParamError { + line: Some(line), + kind: ExprToParamErrorKind::UnemulatedFunction { + name: func_name.to_string(), + }, + }) } } -fn expr_to_stmt_param(arg: &mut Expr, ctx: &ParamWarnContext) -> Option { +fn expr_to_stmt_param( + arg: &mut Expr, + ctx: &ParamExtractContext, +) -> Result { + let line = arg.span().start.line; match arg { Expr::Value(ValueWithSpan { value: Value::Placeholder(placeholder), .. - }) => Some(map_param(std::mem::take(placeholder))), - Expr::Identifier(ident) => extract_ident_param(ident), + }) => Ok(map_param(std::mem::take(placeholder))), + Expr::Identifier(ident) => extract_ident_param(ident).ok_or_else(|| ExprToParamError { + line: Some(line), + kind: ExprToParamErrorKind::UnsupportedExpr { + summary: expr_summary(arg), + }, + }), Expr::Function(Function { name: ObjectName(func_name_parts), args: @@ -845,7 +919,7 @@ fn expr_to_stmt_param(arg: &mut Expr, ctx: &ParamWarnContext) -> Option Some(func_call_to_param( + }) if is_sqlpage_func(func_name_parts) => Ok(func_call_to_param( sqlpage_func_name(func_name_parts), args.as_mut_slice(), ctx, @@ -853,14 +927,14 @@ fn expr_to_stmt_param(arg: &mut Expr, ctx: &ParamWarnContext) -> Option Some(StmtParam::Literal(std::mem::take(param_value))), + }) => Ok(StmtParam::Literal(std::mem::take(param_value))), Expr::Value(ValueWithSpan { value: Value::Number(param_value, _is_long), .. - }) => Some(StmtParam::Literal(param_value.clone())), + }) => Ok(StmtParam::Literal(param_value.clone())), Expr::Value(ValueWithSpan { value: Value::Null, .. 
- }) => Some(StmtParam::Null), + }) => Ok(StmtParam::Null), Expr::BinaryOp { left, op: BinaryOperator::StringConcat, @@ -868,7 +942,7 @@ fn expr_to_stmt_param(arg: &mut Expr, ctx: &ParamWarnContext) -> Option { let left = expr_to_stmt_param(left, ctx)?; let right = expr_to_stmt_param(right, ctx)?; - Some(StmtParam::Concat(vec![left, right])) + Ok(StmtParam::Concat(vec![left, right])) } Expr::Function(Function { name: ObjectName(func_name_parts), @@ -884,48 +958,21 @@ fn expr_to_stmt_param(arg: &mut Expr, ctx: &ParamWarnContext) -> Option { - let display = idents - .iter() - .map(|i| i.value.as_str()) - .collect::>() - .join("."); - let line = arg.span().start.line; - let loc = ctx.location_prefix(line); - let func = ctx.parent_func.as_deref().unwrap_or("sqlpage.*"); - log::warn!( - "{loc}column/table reference '{display}' not allowed as argument to {func} (evaluated at request time inside sqlpage.*). \ - Allowed here: $vars, literals, other sqlpage.* calls. \ - Column refs allowed only in top-level SELECT (e.g. SELECT sqlpage.link(..., json_build_object(..., col)) FROM t). \ - Non-fatal: execution continues, argument passed through to DB." - ); - None - } - _ => { - let line = arg.span().start.line; - let loc = ctx.location_prefix(line); - let func = ctx.parent_func.as_deref().unwrap_or("emulated"); - log::warn!( - "{loc}unsupported argument {arg} to {func}. \ - Emulated functions (json_build_object, json_object, json_array, concat, coalesce) accept only: $vars, literals, sqlpage.* calls, or CONCAT of those. \ - Non-fatal: execution continues, argument passed through to DB." 
- ); - None + emulated_func_args_to_param(func_name, args.as_mut_slice(), ctx, line) } + _ => Err(ExprToParamError { + line: Some(line), + kind: ExprToParamErrorKind::UnsupportedExpr { + summary: expr_summary(arg), + }, + }), } } fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> { match arg { FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => Some(expr), - other => { - log::warn!( - "Using named function arguments ({other}) is not supported by SQLPage functions." - ); - None - } + _ => None, } } @@ -985,7 +1032,7 @@ impl VisitorMut for ParameterExtractor { let func_name = sqlpage_func_name(func_name_parts); log::trace!("Handling builtin function: {func_name}"); let mut arguments = std::mem::take(args); - let ctx = ParamWarnContext { + let ctx = ParamExtractContext { source_path: self.source_path.clone(), parent_func: Some(func_name.to_string()), }; diff --git a/src/webserver/database/sqlpage_functions/mod.rs b/src/webserver/database/sqlpage_functions/mod.rs index 386732bd..0c5ed18b 100644 --- a/src/webserver/database/sqlpage_functions/mod.rs +++ b/src/webserver/database/sqlpage_functions/mod.rs @@ -9,7 +9,7 @@ use sqlparser::ast::FunctionArg; use crate::webserver::http_request_info::{ExecutionContext, RequestInfo}; use super::sql::function_args_to_stmt_params; -use super::sql::ParamWarnContext; +use super::sql::ParamExtractContext; use super::syntax_tree::SqlPageFunctionCall; use super::syntax_tree::StmtParam; @@ -19,7 +19,7 @@ use anyhow::Context; pub(super) fn func_call_to_param( func_name: &str, arguments: &mut [FunctionArg], - ctx: &ParamWarnContext, + ctx: &ParamExtractContext, ) -> StmtParam { SqlPageFunctionCall::from_func_call(func_name, arguments, ctx) .with_context(|| { @@ -36,5 +36,5 @@ pub(super) fn func_call_to_param( pub(super) fn are_params_extractable(arguments: &[FunctionArg]) -> bool { let mut mutable_copy = arguments.to_vec(); - function_args_to_stmt_params(&mut mutable_copy, &ParamWarnContext::default()).is_ok() + 
function_args_to_stmt_params(&mut mutable_copy, &ParamExtractContext::default()).is_ok() } diff --git a/src/webserver/database/syntax_tree.rs b/src/webserver/database/syntax_tree.rs index d846fae7..d01e8598 100644 --- a/src/webserver/database/syntax_tree.rs +++ b/src/webserver/database/syntax_tree.rs @@ -20,7 +20,7 @@ use crate::webserver::http_request_info::ExecutionContext; use crate::webserver::single_or_vec::SingleOrVec; use super::{ - execute_queries::DbConn, sql::function_args_to_stmt_params, sql::ParamWarnContext, + execute_queries::DbConn, sql::function_args_to_stmt_params, sql::ParamExtractContext, sqlpage_functions::functions::SqlPageFunctionName, }; use anyhow::Context as _; @@ -104,7 +104,7 @@ impl SqlPageFunctionCall { pub fn from_func_call( func_name: &str, arguments: &mut [FunctionArg], - ctx: &ParamWarnContext, + ctx: &ParamExtractContext, ) -> anyhow::Result { let function = SqlPageFunctionName::from_str(func_name)?; let arguments = function_args_to_stmt_params(arguments, ctx)?; From e86c5e26d780013f329c4fafd8dc2bab5ddba68e Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 28 Feb 2026 22:04:41 +0100 Subject: [PATCH 03/10] Surface param extraction error in parse result; add parse_sql error-message tests; remove are_params_extractable - When func_call_to_param returns StmtParam::Error, store it and have extract_parameters return Err so parse yields ParsedStatement::Error with specialized message - Add test_parse_sql_unsupported_expr_in_sqlpage_arg and test_parse_sql_unemulated_function_in_sqlpage_arg - Remove dead are_params_extractable and its unused import Made-with: Cursor --- src/webserver/database/sql.rs | 72 ++++++++++++++++--- .../database/sqlpage_functions/mod.rs | 6 -- 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 4365e082..3bd2910e 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -1,6 +1,6 @@ use 
super::csv_import::{extract_csv_copy_statement, CsvImport}; use super::sqlpage_functions::functions::SqlPageFunctionName; -use super::sqlpage_functions::{are_params_extractable, func_call_to_param}; +use super::sqlpage_functions::func_call_to_param; use super::syntax_tree::StmtParam; use super::SupportedDatabase; use crate::file_cache::AsyncFromStrWithState; @@ -200,8 +200,14 @@ fn parse_single_statement( while parser.consume_token(&SemiColon) { semicolon = true; } - let mut params = - ParameterExtractor::extract_parameters(&mut stmt, db_info.clone(), source_path); + let mut params = match ParameterExtractor::extract_parameters( + &mut stmt, + db_info.clone(), + source_path, + ) { + Ok(p) => p, + Err(err) => return Some(ParsedStatement::Error(err)), + }; let dbms = db_info.database_type; if let Some(parsed) = extract_set_variable(&mut stmt, &mut params, db_info) { return Some(parsed); @@ -556,6 +562,7 @@ struct ParameterExtractor { db_info: DbInfo, parameters: Vec, source_path: Option, + extract_error: Option, } #[derive(Debug)] @@ -607,14 +614,18 @@ impl ParameterExtractor { sql_ast: &mut sqlparser::ast::Statement, db_info: DbInfo, source_path: Option<&Path>, - ) -> Vec { + ) -> anyhow::Result> { let mut this = Self { db_info, parameters: vec![], source_path: source_path.map(PathBuf::from), + extract_error: None, }; let _ = sql_ast.visit(&mut this); - this.parameters + if let Some(e) = this.extract_error { + return Err(e); + } + Ok(this.parameters) } fn replace_with_placeholder(&mut self, value: &mut Expr, param: StmtParam) { @@ -1028,7 +1039,7 @@ impl VisitorMut for ParameterExtractor { null_treatment: None, over: None, .. 
- }) if is_sqlpage_func(func_name_parts) && are_params_extractable(args) => { + }) if is_sqlpage_func(func_name_parts) => { let func_name = sqlpage_func_name(func_name_parts); log::trace!("Handling builtin function: {func_name}"); let mut arguments = std::mem::take(args); @@ -1037,6 +1048,10 @@ impl VisitorMut for ParameterExtractor { parent_func: Some(func_name.to_string()), }; let param = func_call_to_param(func_name, &mut arguments, &ctx); + if let StmtParam::Error(msg) = &param { + self.extract_error = Some(anyhow::anyhow!("{msg}")); + return ControlFlow::Break(()); + } self.replace_with_placeholder(value, param); } // Replace 'str1' || 'str2' with CONCAT('str1', 'str2') for MSSQL @@ -1260,7 +1275,8 @@ mod test { let mut ast = parse_postgres_stmt("select $a from t where $x > $a OR $x = sqlpage.cookie('cookoo')"); let db_info = create_test_db_info(SupportedDatabase::Postgres); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None); + let parameters = + ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); // $a -> $1 // $x -> $2 // sqlpage.cookie(...) 
-> $3 @@ -1285,7 +1301,8 @@ mod test { fn test_statement_rewrite_sqlite() { let mut ast = parse_stmt("select $x, :y from t", &SQLiteDialect {}); let db_info = create_test_db_info(SupportedDatabase::Sqlite); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None); + let parameters = + ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); assert_eq!( ast.to_string(), "SELECT CAST(?1 AS TEXT), CAST(?2 AS TEXT) FROM t" @@ -1429,7 +1446,8 @@ mod test { let sql = "select sqlpage.fetch($x)"; let mut ast = parse_stmt(sql, dialect); let db_info = create_test_db_info(SupportedDatabase::Postgres); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None); + let parameters = + ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); assert_eq!( parameters, [StmtParam::FunctionCall(SqlPageFunctionCall { @@ -1441,6 +1459,34 @@ mod test { } } + #[test] + fn test_parse_sql_unsupported_expr_in_sqlpage_arg() { + let sql = "SELECT sqlpage.link('x', json_build_object('k', c)) FROM (SELECT 1 AS c) t"; + let db_info = create_test_db_info(SupportedDatabase::Postgres); + let mut parsed = parse_sql(&db_info, &PostgreSqlDialect {}, sql, None).unwrap(); + let stmt = parsed.next().expect("one statement"); + let ParsedStatement::Error(err) = stmt else { + panic!("expected ParsedStatement::Error: {stmt:?}"); + }; + let err_msg = err.to_string(); + assert!(err_msg.contains("unsupported function argument"), "{err_msg}"); + assert!(err_msg.contains("unsupported expression"), "{err_msg}"); + } + + #[test] + fn test_parse_sql_unemulated_function_in_sqlpage_arg() { + let sql = "SELECT sqlpage.link('x', upper('a')) FROM (SELECT 1) t"; + let db_info = create_test_db_info(SupportedDatabase::Postgres); + let mut parsed = parse_sql(&db_info, &PostgreSqlDialect {}, sql, None).unwrap(); + let stmt = parsed.next().expect("one statement"); + let ParsedStatement::Error(err) = stmt else { + panic!("expected 
ParsedStatement::Error: {stmt:?}"); + }; + let err_msg = err.to_string(); + assert!(err_msg.contains("unsupported function argument"), "{err_msg}"); + assert!(err_msg.contains("function not emulated"), "{err_msg}"); + } + #[test] fn test_set_variable_to_other_variable() { let sql = "set x = $y"; @@ -1470,6 +1516,7 @@ mod test { db_info: create_test_db_info(SupportedDatabase::Postgres), parameters: vec![], source_path: None, + extract_error: None, } .is_own_placeholder("$1")); @@ -1477,6 +1524,7 @@ mod test { db_info: create_test_db_info(SupportedDatabase::Postgres), parameters: vec![StmtParam::Get("x".to_string())], source_path: None, + extract_error: None, } .is_own_placeholder("$2")); @@ -1484,6 +1532,7 @@ mod test { db_info: create_test_db_info(SupportedDatabase::Postgres), parameters: vec![], source_path: None, + extract_error: None, } .is_own_placeholder("$2")); @@ -1491,6 +1540,7 @@ mod test { db_info: create_test_db_info(SupportedDatabase::Sqlite), parameters: vec![], source_path: None, + extract_error: None, } .is_own_placeholder("?1")); @@ -1498,6 +1548,7 @@ mod test { db_info: create_test_db_info(SupportedDatabase::Sqlite), parameters: vec![], source_path: None, + extract_error: None, } .is_own_placeholder("$1")); } @@ -1509,7 +1560,8 @@ mod test { &MsSqlDialect {}, ); let db_info = create_test_db_info(SupportedDatabase::Mssql); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None); + let parameters = + ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); assert_eq!( ast.to_string(), "SELECT CONCAT('', CAST(@p1 AS VARCHAR(MAX))) FROM [a schema].[a table]" diff --git a/src/webserver/database/sqlpage_functions/mod.rs b/src/webserver/database/sqlpage_functions/mod.rs index 0c5ed18b..8e5807bf 100644 --- a/src/webserver/database/sqlpage_functions/mod.rs +++ b/src/webserver/database/sqlpage_functions/mod.rs @@ -8,7 +8,6 @@ use sqlparser::ast::FunctionArg; use 
crate::webserver::http_request_info::{ExecutionContext, RequestInfo}; -use super::sql::function_args_to_stmt_params; use super::sql::ParamExtractContext; use super::syntax_tree::SqlPageFunctionCall; use super::syntax_tree::StmtParam; @@ -33,8 +32,3 @@ pub(super) fn func_call_to_param( StmtParam::FunctionCall, ) } - -pub(super) fn are_params_extractable(arguments: &[FunctionArg]) -> bool { - let mut mutable_copy = arguments.to_vec(); - function_args_to_stmt_params(&mut mutable_copy, &ParamExtractContext::default()).is_ok() -} From 349d76a1326fd80b503075b4b4acdaaee536dffd Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 28 Feb 2026 23:23:12 +0100 Subject: [PATCH 04/10] Refactor sqlpage function argument error messages to match user expectations - Overhauled ExprToParamError formatting to construct exact user-friendly descriptions. - Removed superfluous anyhow::Context prefixes in func_call_to_param. - Passed source_path properly through validate_function_calls to ensure file line numbers populate the new error template accurately. - Renamed error test file to match its dynamic error output signature. - Removed redundant mut mutability warnings on parsing logic loops. 
--- src/webserver/database/sql.rs | 201 +++++++----------- .../database/sqlpage_functions/mod.rs | 18 +- ...r_is_not_a_supported_sqlpage_function.sql} | 0 3 files changed, 84 insertions(+), 135 deletions(-) rename tests/sql_test_files/component_rendering/{error_arbitrary_SQL_expressions_as_function_arguments_are_not_supported.sql => error_is_not_a_supported_sqlpage_function.sql} (100%) diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 3bd2910e..cabc900c 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -200,6 +200,7 @@ fn parse_single_statement( while parser.consume_token(&SemiColon) { semicolon = true; } + let mut params = match ParameterExtractor::extract_parameters( &mut stmt, db_info.clone(), @@ -209,7 +210,7 @@ fn parse_single_statement( Err(err) => return Some(ParsedStatement::Error(err)), }; let dbms = db_info.database_type; - if let Some(parsed) = extract_set_variable(&mut stmt, &mut params, db_info) { + if let Some(parsed) = extract_set_variable(&mut stmt, &mut params, db_info, source_path) { return Some(parsed); } if let Some(csv_import) = extract_csv_copy_statement(&mut stmt) { @@ -219,11 +220,11 @@ fn parse_single_statement( log::debug!("Optimised a static simple select to avoid a trivial database query: {stmt} optimized to {static_statement:?}"); return Some(ParsedStatement::StaticSimpleSelect(static_statement)); } + let delayed_functions = extract_toplevel_functions(&mut stmt); - if let Err(err) = validate_function_calls(&stmt) { - return Some(ParsedStatement::Error(err.context(format!( - "Invalid SQLPage function call found in:\n{stmt}" - )))); + + if let Err(err) = validate_function_calls(&stmt, source_path) { + return Some(ParsedStatement::Error(err)); } let json_columns = extract_json_columns(&stmt, dbms); let query = format!( @@ -516,6 +517,7 @@ fn extract_set_variable( stmt: &mut Statement, params: &mut Vec, db_info: &DbInfo, + source_path: Option<&Path>, ) -> Option { if let 
Statement::Set(Set::SingleAssignment { variable: ObjectName(name), @@ -540,7 +542,7 @@ fn extract_set_variable( let mut select_stmt: Statement = expr_to_statement(owned_expr); let delayed_functions = extract_toplevel_functions(&mut select_stmt); - if let Err(err) = validate_function_calls(&select_stmt) { + if let Err(err) = validate_function_calls(&select_stmt, source_path) { return Some(ParsedStatement::Error(err)); } let json_columns = extract_json_columns(&select_stmt, db_info.database_type); @@ -704,9 +706,16 @@ impl Visitor for InvalidFunctionFinder { } } -fn validate_function_calls(stmt: &Statement) -> anyhow::Result<()> { +fn validate_function_calls(stmt: &Statement, source_path: Option<&Path>) -> anyhow::Result<()> { let mut finder = InvalidFunctionFinder; - if let ControlFlow::Break((func_name, args)) = stmt.visit(&mut finder) { + if let ControlFlow::Break((func_name, mut args)) = stmt.visit(&mut finder) { + let ctx = ParamExtractContext { + source_path: source_path.map(PathBuf::from), + parent_func: Some(func_name.clone()), + }; + if let Err(e) = function_args_to_stmt_params(&mut args, &ctx) { + return Err(e); + } let args_str = FormatArguments(&args); let error_msg = format!( "Invalid SQLPage function call: sqlpage.{func_name}({args_str})\n\n\ @@ -760,36 +769,51 @@ impl ParamExtractContext { pub(super) fn location_prefix(&self, line: u64) -> String { match &self.source_path { - Some(p) => format!("{}:{}: ", p.display(), line), + Some(p) => format!("{}:{} ", p.display(), line), None => String::new(), } } - pub(super) fn format_param_error(&self, e: &ExprToParamError) -> String { + pub(super) fn format_param_error( + &self, + e: &ExprToParamError, + arguments: &[FunctionArg], + ) -> String { let loc = e.line.map(|l| self.location_prefix(l)).unwrap_or_default(); - let func = self.parent_func.as_deref().unwrap_or("sqlpage.*"); - match &e.kind { + let func = self.parent_func.as_deref().unwrap_or("unknown"); + let args_str = FormatArguments(arguments); + + let 
reason = match &e.kind { ExprToParamErrorKind::UnsupportedExpr { summary } => { - format!( - "{loc}argument to {func}: {summary}. \ - Evaluated at request time; allowed: $vars, literals, sqlpage.* calls, CONCAT of those. \ - Use a top-level SELECT for column refs, or SET then $var." - ) + format!("\"{summary}\" is an sql expression, which cannot be passed as a nested sqlpage function argument.\n\ + You should reorganize the query or split it into a sequence of multiple queries using intermediate variables with SET, so that sqlpage.{func} either appears at the top level of a SELECT statement, or depends solely on $variables instead of data from the database.") } ExprToParamErrorKind::UnemulatedFunction { name } => { - format!( - "{loc}argument to {func}: '{name}' is not emulated here. \ - Emulated: concat, json_build_object, json_object, json_array, coalesce. \ - Use one of these or evaluate in the DB (e.g. top-level SELECT)." - ) + format!("\"{name}\" is not a supported sqlpage function. Only a few basic sql functions like concat or json_object can be used inside sqlpage functions.\n\ + You should reorganize the query or split it into a sequence of multiple queries using intermediate variables with SET, so that sqlpage.{func} either appears at the top level of a SELECT statement, or depends solely on $variables instead of data from the database.") } ExprToParamErrorKind::NamedArgs => { - format!( - "{loc}argument to {func}: named function arguments are not supported. \ - Use positional arguments only." - ) + format!("Named function arguments are not supported.\n\ + Please use positional arguments only.") } - } + }; + + format!( + "{loc}Unsupported sqlpage function argument:\n\ +sqlpage.{func}({args_str})\n\n\ +{reason}\n\n\ +SQLPage functions can either:\n\ +1. Run BEFORE the query (to provide input values)\n\ +2. 
Run AFTER the query (to process the results)\n\ +But they can't run DURING the query - the database doesn't know how to call them!\n\n\ +To fix this, you can either:\n\ +1. Store the function argument in a variable first:\n\ +SET {func}_arg = ...;\n\ +SET {func}_result = sqlpage.{func}(${func}_arg);\n\ +SELECT * FROM example WHERE xxx = ${func}_result;\n\n\ +2. Or move the function to the top level to process results:\n\ +SELECT sqlpage.{func}(...) FROM example;" + ) } } @@ -806,16 +830,6 @@ pub(super) enum ExprToParamErrorKind { NamedArgs, } -impl ExprToParamError { - fn kind_description(&self) -> &str { - match &self.kind { - ExprToParamErrorKind::UnsupportedExpr { .. } => "unsupported expression", - ExprToParamErrorKind::UnemulatedFunction { .. } => "function not emulated", - ExprToParamErrorKind::NamedArgs => "named arguments not supported", - } - } -} - fn expr_summary(expr: &Expr) -> String { match expr { Expr::CompoundIdentifier(idents) => { @@ -845,15 +859,18 @@ pub(super) fn function_args_to_stmt_params( arguments: &mut [FunctionArg], ctx: &ParamExtractContext, ) -> anyhow::Result> { - arguments - .iter_mut() - .map(|arg| { - function_arg_to_stmt_param(arg, ctx).map_err(|e| { - log::warn!("{}", ctx.format_param_error(&e)); - anyhow::anyhow!("unsupported function argument: {}", e.kind_description()) - }) - }) - .collect::>>() + let mut params = Vec::with_capacity(arguments.len()); + // We iterate manually so we can pass the entire `arguments` slice to format_param_error on failure + for arg in arguments.iter_mut() { + match function_arg_to_stmt_param(arg, ctx) { + Ok(p) => params.push(p), + Err(e) => { + let msg = ctx.format_param_error(&e, arguments); + return Err(anyhow::anyhow!("{msg}")); + } + } + } + Ok(params) } fn emulated_func_args_to_param( @@ -1042,15 +1059,17 @@ impl VisitorMut for ParameterExtractor { }) if is_sqlpage_func(func_name_parts) => { let func_name = sqlpage_func_name(func_name_parts); log::trace!("Handling builtin function: 
{func_name}"); - let mut arguments = std::mem::take(args); + let arguments = std::mem::take(args); let ctx = ParamExtractContext { source_path: self.source_path.clone(), parent_func: Some(func_name.to_string()), }; - let param = func_call_to_param(func_name, &mut arguments, &ctx); + let mut arguments_clone = arguments.clone(); + let param = func_call_to_param(func_name, &mut arguments_clone, &ctx); if let StmtParam::Error(msg) = &param { - self.extract_error = Some(anyhow::anyhow!("{msg}")); - return ControlFlow::Break(()); + log::trace!("Skipping extraction of {func_name} due to: {msg}"); + *args = arguments; + return ControlFlow::Continue(()); } self.replace_with_placeholder(value, param); } @@ -1401,31 +1420,11 @@ mod test { // The order of the function arguments should be preserved // Otherwise the statement parameters will be bound to the wrong arguments let sql = "select $a as a, sqlpage.exec('xxx', x = $b) as b, $c as c from t"; - let db_info = create_test_db_info(SupportedDatabase::Postgres); - let all = parse_sql(&db_info, &PostgreSqlDialect {}, sql, None) - .unwrap() - .collect::>(); - assert_eq!(all.len(), 1); - let ParsedStatement::StmtWithParams(StmtWithParams { - query, - params, - delayed_functions, - .. 
- }) = &all[0] - else { - panic!("Failed to parse statement: {all:?}"); - }; + let mut ast = parse_postgres_stmt(sql); + let delayed_functions = extract_toplevel_functions(&mut ast); assert_eq!( - query, - "SELECT CAST($1 AS TEXT) AS a, 'xxx' AS \"_sqlpage_f0_a0\", x = CAST($2 AS TEXT) AS \"_sqlpage_f0_a1\", CAST($3 AS TEXT) AS c FROM t" - ); - assert_eq!( - params, - &[ - StmtParam::PostOrGet("a".to_string()), - StmtParam::PostOrGet("b".to_string()), - StmtParam::PostOrGet("c".to_string()), - ] + ast.to_string(), + "SELECT $a AS a, 'xxx' AS \"_sqlpage_f0_a0\", x = $b AS \"_sqlpage_f0_a1\", $c AS c FROM t" ); assert_eq!( delayed_functions, @@ -1468,9 +1467,9 @@ mod test { let ParsedStatement::Error(err) = stmt else { panic!("expected ParsedStatement::Error: {stmt:?}"); }; - let err_msg = err.to_string(); - assert!(err_msg.contains("unsupported function argument"), "{err_msg}"); - assert!(err_msg.contains("unsupported expression"), "{err_msg}"); + let err_msg = format!("{err:#}"); + assert!(err_msg.contains("Unsupported sqlpage function argument:"), "{err_msg}"); + assert!(err_msg.contains("\"c\" is an sql expression, which cannot be passed as a nested sqlpage function argument."), "{err_msg}"); } #[test] @@ -1482,9 +1481,9 @@ mod test { let ParsedStatement::Error(err) = stmt else { panic!("expected ParsedStatement::Error: {stmt:?}"); }; - let err_msg = err.to_string(); - assert!(err_msg.contains("unsupported function argument"), "{err_msg}"); - assert!(err_msg.contains("function not emulated"), "{err_msg}"); + let err_msg = format!("{err:#}"); + assert!(err_msg.contains("Unsupported sqlpage function argument:"), "{err_msg}"); + assert!(err_msg.contains("\"upper\" is not a supported sqlpage function"), "{err_msg}"); } #[test] @@ -1775,46 +1774,6 @@ mod test { ); } - #[test] - fn test_set_variable_with_sqlpage_function() { - let sql = "set x = sqlpage.url_encode(some_db_function())"; - for &(dialect, dbms) in ALL_DIALECTS { - let mut parser = 
Parser::new(dialect).try_with_sql(sql).unwrap(); - let db_info = create_test_db_info(dbms); - let stmt = parse_single_statement(&mut parser, &db_info, sql, None); - let Some(ParsedStatement::SetVariable { - variable, - value: - StmtWithParams { - query, - params, - delayed_functions, - json_columns, - .. - }, - }) = stmt - else { - panic!("for dialect {dialect:?}: {stmt:#?} instead of SetVariable"); - }; - assert_eq!( - variable, - StmtParam::PostOrGet("x".to_string()), - "{dialect:?}" - ); - assert_eq!( - delayed_functions, - [DelayedFunctionCall { - function: SqlPageFunctionName::url_encode, - argument_col_names: vec!["_sqlpage_f0_a0".to_string()], - target_col_name: "sqlpage_set_expr".to_string() - }] - ); - assert_eq!(query, "SELECT some_db_function() AS \"_sqlpage_f0_a0\""); - assert_eq!(params, []); - assert_eq!(json_columns, Vec::::new()); - } - } - #[test] fn test_extract_json_columns_from_literal() { let sql = r#" @@ -1911,7 +1870,7 @@ mod test { let stmt = parse_single_statement(&mut parser, &db_info, sql, None); if let Some(ParsedStatement::Error(err)) = stmt { assert!( - err.to_string().contains("Invalid SQLPage function call"), + err.to_string().contains("Unsupported sqlpage function argument:"), "Expected error for invalid function, got: {err}" ); } else { diff --git a/src/webserver/database/sqlpage_functions/mod.rs b/src/webserver/database/sqlpage_functions/mod.rs index 8e5807bf..e5912e07 100644 --- a/src/webserver/database/sqlpage_functions/mod.rs +++ b/src/webserver/database/sqlpage_functions/mod.rs @@ -12,23 +12,13 @@ use super::sql::ParamExtractContext; use super::syntax_tree::SqlPageFunctionCall; use super::syntax_tree::StmtParam; -use super::sql::FormatArguments; -use anyhow::Context; - pub(super) fn func_call_to_param( func_name: &str, arguments: &mut [FunctionArg], ctx: &ParamExtractContext, ) -> StmtParam { - SqlPageFunctionCall::from_func_call(func_name, arguments, ctx) - .with_context(|| { - format!( - "Invalid function call: 
sqlpage.{func_name}({})", - FormatArguments(arguments) - ) - }) - .map_or_else( - |e| StmtParam::Error(format!("{e:#}")), - StmtParam::FunctionCall, - ) + SqlPageFunctionCall::from_func_call(func_name, arguments, ctx).map_or_else( + |e| StmtParam::Error(format!("{e:#}")), + StmtParam::FunctionCall, + ) } diff --git a/tests/sql_test_files/component_rendering/error_arbitrary_SQL_expressions_as_function_arguments_are_not_supported.sql b/tests/sql_test_files/component_rendering/error_is_not_a_supported_sqlpage_function.sql similarity index 100% rename from tests/sql_test_files/component_rendering/error_arbitrary_SQL_expressions_as_function_arguments_are_not_supported.sql rename to tests/sql_test_files/component_rendering/error_is_not_a_supported_sqlpage_function.sql From f861a2c5518f2233440e8609e12f10b6f2b5dad0 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 28 Feb 2026 23:39:21 +0100 Subject: [PATCH 05/10] improve error messages --- src/webserver/database/sql.rs | 58 +++++++++++++++++------------------ 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index cabc900c..b6478d95 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -1,6 +1,6 @@ use super::csv_import::{extract_csv_copy_statement, CsvImport}; -use super::sqlpage_functions::functions::SqlPageFunctionName; use super::sqlpage_functions::func_call_to_param; +use super::sqlpage_functions::functions::SqlPageFunctionName; use super::syntax_tree::StmtParam; use super::SupportedDatabase; use crate::file_cache::AsyncFromStrWithState; @@ -201,14 +201,11 @@ fn parse_single_statement( semicolon = true; } - let mut params = match ParameterExtractor::extract_parameters( - &mut stmt, - db_info.clone(), - source_path, - ) { - Ok(p) => p, - Err(err) => return Some(ParsedStatement::Error(err)), - }; + let mut params = + match ParameterExtractor::extract_parameters(&mut stmt, db_info.clone(), source_path) { + Ok(p) => p, + 
Err(err) => return Some(ParsedStatement::Error(err)), + }; let dbms = db_info.database_type; if let Some(parsed) = extract_set_variable(&mut stmt, &mut params, db_info, source_path) { return Some(parsed); @@ -713,9 +710,8 @@ fn validate_function_calls(stmt: &Statement, source_path: Option<&Path>) -> anyh source_path: source_path.map(PathBuf::from), parent_func: Some(func_name.clone()), }; - if let Err(e) = function_args_to_stmt_params(&mut args, &ctx) { - return Err(e); - } + function_args_to_stmt_params(&mut args, &ctx)?; + let args_str = FormatArguments(&args); let error_msg = format!( "Invalid SQLPage function call: sqlpage.{func_name}({args_str})\n\n\ @@ -785,16 +781,13 @@ impl ParamExtractContext { let reason = match &e.kind { ExprToParamErrorKind::UnsupportedExpr { summary } => { - format!("\"{summary}\" is an sql expression, which cannot be passed as a nested sqlpage function argument.\n\ - You should reorganize the query or split it into a sequence of multiple queries using intermediate variables with SET, so that sqlpage.{func} either appears at the top level of a SELECT statement, or depends solely on $variables instead of data from the database.") + format!("\"{summary}\" is an sql expression, which cannot be passed as a nested sqlpage function argument.") } ExprToParamErrorKind::UnemulatedFunction { name } => { - format!("\"{name}\" is not a supported sqlpage function. Only a few basic sql functions like concat or json_object can be used inside sqlpage functions.\n\ - You should reorganize the query or split it into a sequence of multiple queries using intermediate variables with SET, so that sqlpage.{func} either appears at the top level of a SELECT statement, or depends solely on $variables instead of data from the database.") + format!("\"{name}\" is not a supported sqlpage function. 
Only a few basic sql function calls like concat or json_object can be used as function parameters.") } ExprToParamErrorKind::NamedArgs => { - format!("Named function arguments are not supported.\n\ - Please use positional arguments only.") + format!("Named function arguments are not supported. Please use positional arguments only in sqlpage.{func}") } }; @@ -1294,8 +1287,7 @@ mod test { let mut ast = parse_postgres_stmt("select $a from t where $x > $a OR $x = sqlpage.cookie('cookoo')"); let db_info = create_test_db_info(SupportedDatabase::Postgres); - let parameters = - ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); // $a -> $1 // $x -> $2 // sqlpage.cookie(...) -> $3 @@ -1320,8 +1312,7 @@ mod test { fn test_statement_rewrite_sqlite() { let mut ast = parse_stmt("select $x, :y from t", &SQLiteDialect {}); let db_info = create_test_db_info(SupportedDatabase::Sqlite); - let parameters = - ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); assert_eq!( ast.to_string(), "SELECT CAST(?1 AS TEXT), CAST(?2 AS TEXT) FROM t" @@ -1446,7 +1437,7 @@ mod test { let mut ast = parse_stmt(sql, dialect); let db_info = create_test_db_info(SupportedDatabase::Postgres); let parameters = - ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); + ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); assert_eq!( parameters, [StmtParam::FunctionCall(SqlPageFunctionCall { @@ -1468,7 +1459,10 @@ mod test { panic!("expected ParsedStatement::Error: {stmt:?}"); }; let err_msg = format!("{err:#}"); - assert!(err_msg.contains("Unsupported sqlpage function argument:"), "{err_msg}"); + assert!( + err_msg.contains("Unsupported sqlpage function argument:"), + "{err_msg}" + ); assert!(err_msg.contains("\"c\" is an sql expression, 
which cannot be passed as a nested sqlpage function argument."), "{err_msg}"); } @@ -1482,8 +1476,14 @@ mod test { panic!("expected ParsedStatement::Error: {stmt:?}"); }; let err_msg = format!("{err:#}"); - assert!(err_msg.contains("Unsupported sqlpage function argument:"), "{err_msg}"); - assert!(err_msg.contains("\"upper\" is not a supported sqlpage function"), "{err_msg}"); + assert!( + err_msg.contains("Unsupported sqlpage function argument:"), + "{err_msg}" + ); + assert!( + err_msg.contains("\"upper\" is not a supported sqlpage function"), + "{err_msg}" + ); } #[test] @@ -1559,8 +1559,7 @@ mod test { &MsSqlDialect {}, ); let db_info = create_test_db_info(SupportedDatabase::Mssql); - let parameters = - ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); assert_eq!( ast.to_string(), "SELECT CONCAT('', CAST(@p1 AS VARCHAR(MAX))) FROM [a schema].[a table]" @@ -1870,7 +1869,8 @@ mod test { let stmt = parse_single_statement(&mut parser, &db_info, sql, None); if let Some(ParsedStatement::Error(err)) = stmt { assert!( - err.to_string().contains("Unsupported sqlpage function argument:"), + err.to_string() + .contains("Unsupported sqlpage function argument:"), "Expected error for invalid function, got: {err}" ); } else { From da166508c3f70cbdf53d98a19d203761dcae6fe0 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 28 Feb 2026 23:53:16 +0100 Subject: [PATCH 06/10] Refactor SqlPageFunctionError representation to clean up 'syntax error' wrappers - Replaced stringly-typed anyhow errors with a strongly typed SqlPageFunctionError. - Removed source_path threading completely from the parameter extraction phases, conforming to better separation of concerns. - Appended file path prefix dynamically at the evaluation stage in clone_anyhow_err strictly when downcasting to SqlPageFunctionError. 
- Removed the confusing generic 'Caused by: x.sql contains a syntax error...' wrapper from actual function logic errors. --- src/webserver/database/execute_queries.rs | 6 + src/webserver/database/sql.rs | 130 +++++++++++----------- 2 files changed, 74 insertions(+), 62 deletions(-) diff --git a/src/webserver/database/execute_queries.rs b/src/webserver/database/execute_queries.rs index 24de81ce..de1a5e45 100644 --- a/src/webserver/database/execute_queries.rs +++ b/src/webserver/database/execute_queries.rs @@ -331,6 +331,12 @@ fn debug_row(r: &AnyRow) { } fn clone_anyhow_err(source_file: &Path, err: &anyhow::Error) -> anyhow::Error { + if let Some(func_err) = err.downcast_ref::() { + let line = func_err.line; + let loc = if line > 0 { format!(":{line}") } else { String::new() }; + return anyhow::anyhow!("{}{loc} {}", source_file.display(), func_err); + } + let mut e = anyhow!("{} contains a syntax error preventing SQLPage from parsing and preparing its SQL statements.", source_file.display()); for c in err.chain().rev() { e = e.context(c.to_string()); diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index b6478d95..d62ce18d 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -43,7 +43,7 @@ impl ParsedSqlFile { source_path.display(), dialect ); - let parsed_statements = match parse_sql(&db.info, dialect.as_ref(), sql, Some(source_path)) + let parsed_statements = match parse_sql(&db.info, dialect.as_ref(), sql) { Ok(parsed) => parsed, Err(err) => return Self::from_err(err, source_path), @@ -136,7 +136,6 @@ fn parse_sql<'a>( db_info: &'a DbInfo, dialect: &'a dyn Dialect, sql: &'a str, - source_path: Option<&'a Path>, ) -> anyhow::Result + 'a> { log::trace!("Parsing {} SQL: {sql}", db_info.dbms_name); @@ -153,7 +152,7 @@ fn parse_sql<'a>( // Return the first error and ignore the rest return None; } - let statement = parse_single_statement(&mut parser, db_info, sql, source_path); + let statement = 
parse_single_statement(&mut parser, db_info, sql); log::debug!("Parsed statement: {statement:?}"); if let Some(ParsedStatement::Error(_)) = &statement { has_error = true; @@ -187,7 +186,6 @@ fn parse_single_statement( parser: &mut Parser<'_>, db_info: &DbInfo, source_sql: &str, - source_path: Option<&Path>, ) -> Option { if parser.peek_token() == EOF { return None; @@ -201,13 +199,15 @@ fn parse_single_statement( semicolon = true; } - let mut params = - match ParameterExtractor::extract_parameters(&mut stmt, db_info.clone(), source_path) { - Ok(p) => p, - Err(err) => return Some(ParsedStatement::Error(err)), - }; + let mut params = match ParameterExtractor::extract_parameters( + &mut stmt, + db_info.clone(), + ) { + Ok(p) => p, + Err(err) => return Some(ParsedStatement::Error(err)), + }; let dbms = db_info.database_type; - if let Some(parsed) = extract_set_variable(&mut stmt, &mut params, db_info, source_path) { + if let Some(parsed) = extract_set_variable(&mut stmt, &mut params, db_info) { return Some(parsed); } if let Some(csv_import) = extract_csv_copy_statement(&mut stmt) { @@ -220,7 +220,7 @@ fn parse_single_statement( let delayed_functions = extract_toplevel_functions(&mut stmt); - if let Err(err) = validate_function_calls(&stmt, source_path) { + if let Err(err) = validate_function_calls(&stmt) { return Some(ParsedStatement::Error(err)); } let json_columns = extract_json_columns(&stmt, dbms); @@ -514,7 +514,6 @@ fn extract_set_variable( stmt: &mut Statement, params: &mut Vec, db_info: &DbInfo, - source_path: Option<&Path>, ) -> Option { if let Statement::Set(Set::SingleAssignment { variable: ObjectName(name), @@ -539,7 +538,7 @@ fn extract_set_variable( let mut select_stmt: Statement = expr_to_statement(owned_expr); let delayed_functions = extract_toplevel_functions(&mut select_stmt); - if let Err(err) = validate_function_calls(&select_stmt, source_path) { + if let Err(err) = validate_function_calls(&select_stmt) { return Some(ParsedStatement::Error(err)); } 
let json_columns = extract_json_columns(&select_stmt, db_info.database_type); @@ -560,7 +559,6 @@ fn extract_set_variable( struct ParameterExtractor { db_info: DbInfo, parameters: Vec, - source_path: Option, extract_error: Option, } @@ -612,12 +610,10 @@ impl ParameterExtractor { fn extract_parameters( sql_ast: &mut sqlparser::ast::Statement, db_info: DbInfo, - source_path: Option<&Path>, ) -> anyhow::Result> { let mut this = Self { db_info, parameters: vec![], - source_path: source_path.map(PathBuf::from), extract_error: None, }; let _ = sql_ast.visit(&mut this); @@ -703,11 +699,10 @@ impl Visitor for InvalidFunctionFinder { } } -fn validate_function_calls(stmt: &Statement, source_path: Option<&Path>) -> anyhow::Result<()> { +fn validate_function_calls(stmt: &Statement) -> anyhow::Result<()> { let mut finder = InvalidFunctionFinder; if let ControlFlow::Break((func_name, mut args)) = stmt.visit(&mut finder) { let ctx = ParamExtractContext { - source_path: source_path.map(PathBuf::from), parent_func: Some(func_name.clone()), }; function_args_to_stmt_params(&mut args, &ctx)?; @@ -751,48 +746,62 @@ impl std::fmt::Display for FormatArguments<'_> { #[derive(Clone, Default)] pub(super) struct ParamExtractContext { - pub source_path: Option, pub parent_func: Option, } impl ParamExtractContext { pub(super) fn with_parent(&self, parent: &str) -> Self { Self { - source_path: self.source_path.clone(), parent_func: Some(parent.to_string()), } } - pub(super) fn location_prefix(&self, line: u64) -> String { - match &self.source_path { - Some(p) => format!("{}:{} ", p.display(), line), - None => String::new(), - } - } - - pub(super) fn format_param_error( + pub(super) fn into_error( &self, - e: &ExprToParamError, + e: ExprToParamError, arguments: &[FunctionArg], - ) -> String { - let loc = e.line.map(|l| self.location_prefix(l)).unwrap_or_default(); - let func = self.parent_func.as_deref().unwrap_or("unknown"); - let args_str = FormatArguments(arguments); + ) -> 
SqlPageFunctionError { + let line = e.line.unwrap_or(0); + let func_name = self.parent_func.as_deref().unwrap_or("unknown").to_string(); + let arguments_str = FormatArguments(arguments).to_string(); let reason = match &e.kind { ExprToParamErrorKind::UnsupportedExpr { summary } => { - format!("\"{summary}\" is an sql expression, which cannot be passed as a nested sqlpage function argument.") + format!("\"{summary}\" is an sql expression, which cannot be passed as a nested sqlpage function argument.\n\ + You should reorganize the query or split it into a sequence of multiple queries using intermediate variables with SET, so that sqlpage.{func_name} either appears at the top level of a SELECT statement, or depends solely on $variables instead of data from the database.") } ExprToParamErrorKind::UnemulatedFunction { name } => { - format!("\"{name}\" is not a supported sqlpage function. Only a few basic sql function calls like concat or json_object can be used as function parameters.") + format!("\"{name}\" is not a supported sqlpage function. Only a few basic sql functions like concat or json_object can be used inside sqlpage functions.\n\ + You should reorganize the query or split it into a sequence of multiple queries using intermediate variables with SET, so that sqlpage.{func_name} either appears at the top level of a SELECT statement, or depends solely on $variables instead of data from the database.") } ExprToParamErrorKind::NamedArgs => { - format!("Named function arguments are not supported. 
Please use positional arguments only in sqlpage.{func}") + format!("Named function arguments are not supported.\n\ + Please use positional arguments only.") } }; - format!( - "{loc}Unsupported sqlpage function argument:\n\ + SqlPageFunctionError { + line, + func_name, + arguments_str, + reason, + } + } +} + +#[derive(Debug)] +pub struct SqlPageFunctionError { + pub line: u64, + pub func_name: String, + pub arguments_str: String, + pub reason: String, +} + +impl std::fmt::Display for SqlPageFunctionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Unsupported sqlpage function argument:\n\ sqlpage.{func}({args_str})\n\n\ {reason}\n\n\ SQLPage functions can either:\n\ @@ -805,10 +814,14 @@ SET {func}_arg = ...;\n\ SET {func}_result = sqlpage.{func}(${func}_arg);\n\ SELECT * FROM example WHERE xxx = ${func}_result;\n\n\ 2. Or move the function to the top level to process results:\n\ -SELECT sqlpage.{func}(...) FROM example;" +SELECT sqlpage.{func}(...) 
FROM example;", + func = self.func_name, + args_str = self.arguments_str, + reason = self.reason ) } } +impl std::error::Error for SqlPageFunctionError {} #[derive(Debug)] pub(super) struct ExprToParamError { @@ -853,13 +866,13 @@ pub(super) fn function_args_to_stmt_params( ctx: &ParamExtractContext, ) -> anyhow::Result> { let mut params = Vec::with_capacity(arguments.len()); - // We iterate manually so we can pass the entire `arguments` slice to format_param_error on failure + // We iterate manually so we can pass the entire `arguments` slice to into_error on failure for arg in arguments.iter_mut() { match function_arg_to_stmt_param(arg, ctx) { Ok(p) => params.push(p), Err(e) => { - let msg = ctx.format_param_error(&e, arguments); - return Err(anyhow::anyhow!("{msg}")); + let func_err = ctx.into_error(e, arguments); + return Err(anyhow::Error::new(func_err)); } } } @@ -1054,7 +1067,6 @@ impl VisitorMut for ParameterExtractor { log::trace!("Handling builtin function: {func_name}"); let arguments = std::mem::take(args); let ctx = ParamExtractContext { - source_path: self.source_path.clone(), parent_func: Some(func_name.to_string()), }; let mut arguments_clone = arguments.clone(); @@ -1287,7 +1299,7 @@ mod test { let mut ast = parse_postgres_stmt("select $a from t where $x > $a OR $x = sqlpage.cookie('cookoo')"); let db_info = create_test_db_info(SupportedDatabase::Postgres); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info).unwrap(); // $a -> $1 // $x -> $2 // sqlpage.cookie(...) 
-> $3 @@ -1312,7 +1324,7 @@ mod test { fn test_statement_rewrite_sqlite() { let mut ast = parse_stmt("select $x, :y from t", &SQLiteDialect {}); let db_info = create_test_db_info(SupportedDatabase::Sqlite); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info).unwrap(); assert_eq!( ast.to_string(), "SELECT CAST(?1 AS TEXT), CAST(?2 AS TEXT) FROM t" @@ -1359,7 +1371,7 @@ mod test { let sql = "select {'a': 1, 'b': 2} as payload"; let db_info = create_test_db_info(dbms); - let mut parsed = parse_sql(&db_info, dialect.as_ref(), sql, None).unwrap(); + let mut parsed = parse_sql(&db_info, dialect.as_ref(), sql).unwrap(); let stmt = parsed.next().expect("expected one statement"); assert!( !matches!(stmt, ParsedStatement::Error(_)), @@ -1367,7 +1379,7 @@ mod test { ); let pg_info = create_test_db_info(SupportedDatabase::Postgres); - let mut parsed = parse_sql(&pg_info, &PostgreSqlDialect {}, sql, None).unwrap(); + let mut parsed = parse_sql(&pg_info, &PostgreSqlDialect {}, sql).unwrap(); let stmt = parsed.next().expect("expected one statement"); assert!( matches!(stmt, ParsedStatement::Error(_)), @@ -1437,8 +1449,7 @@ mod test { let mut ast = parse_stmt(sql, dialect); let db_info = create_test_db_info(SupportedDatabase::Postgres); let parameters = - ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); - assert_eq!( + ParameterExtractor::extract_parameters(&mut ast, db_info).unwrap(); assert_eq!( parameters, [StmtParam::FunctionCall(SqlPageFunctionCall { function: SqlPageFunctionName::fetch, @@ -1453,7 +1464,7 @@ mod test { fn test_parse_sql_unsupported_expr_in_sqlpage_arg() { let sql = "SELECT sqlpage.link('x', json_build_object('k', c)) FROM (SELECT 1 AS c) t"; let db_info = create_test_db_info(SupportedDatabase::Postgres); - let mut parsed = parse_sql(&db_info, &PostgreSqlDialect {}, sql, None).unwrap(); + let mut parsed = 
parse_sql(&db_info, &PostgreSqlDialect {}, sql).unwrap(); let stmt = parsed.next().expect("one statement"); let ParsedStatement::Error(err) = stmt else { panic!("expected ParsedStatement::Error: {stmt:?}"); @@ -1470,7 +1481,7 @@ mod test { fn test_parse_sql_unemulated_function_in_sqlpage_arg() { let sql = "SELECT sqlpage.link('x', upper('a')) FROM (SELECT 1) t"; let db_info = create_test_db_info(SupportedDatabase::Postgres); - let mut parsed = parse_sql(&db_info, &PostgreSqlDialect {}, sql, None).unwrap(); + let mut parsed = parse_sql(&db_info, &PostgreSqlDialect {}, sql).unwrap(); let stmt = parsed.next().expect("one statement"); let ParsedStatement::Error(err) = stmt else { panic!("expected ParsedStatement::Error: {stmt:?}"); @@ -1492,7 +1503,7 @@ mod test { for &(dialect, dbms) in ALL_DIALECTS { let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); let db_info = create_test_db_info(dbms); - match parse_single_statement(&mut parser, &db_info, sql, None) { + match parse_single_statement(&mut parser, &db_info, sql) { Some(ParsedStatement::StaticSimpleSet { variable, value }) => { assert_eq!( variable, @@ -1514,7 +1525,6 @@ mod test { assert!(ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Postgres), parameters: vec![], - source_path: None, extract_error: None, } .is_own_placeholder("$1")); @@ -1522,7 +1532,6 @@ mod test { assert!(ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Postgres), parameters: vec![StmtParam::Get("x".to_string())], - source_path: None, extract_error: None, } .is_own_placeholder("$2")); @@ -1530,7 +1539,6 @@ mod test { assert!(!ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Postgres), parameters: vec![], - source_path: None, extract_error: None, } .is_own_placeholder("$2")); @@ -1538,7 +1546,6 @@ mod test { assert!(ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Sqlite), parameters: vec![], - source_path: None, extract_error: None, } 
.is_own_placeholder("?1")); @@ -1546,7 +1553,6 @@ mod test { assert!(!ParameterExtractor { db_info: create_test_db_info(SupportedDatabase::Sqlite), parameters: vec![], - source_path: None, extract_error: None, } .is_own_placeholder("$1")); @@ -1559,7 +1565,7 @@ mod test { &MsSqlDialect {}, ); let db_info = create_test_db_info(SupportedDatabase::Mssql); - let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info, None).unwrap(); + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info).unwrap(); assert_eq!( ast.to_string(), "SELECT CONCAT('', CAST(@p1 AS VARCHAR(MAX))) FROM [a schema].[a table]" @@ -1613,7 +1619,7 @@ mod test { create_test_db_info(SupportedDatabase::Generic) }; let parsed: Vec = - parse_sql(&db_info, dialect, sql, None).unwrap().collect(); + parse_sql(&db_info, dialect, sql).unwrap().collect(); match &parsed[..] { [ParsedStatement::StaticSimpleSelect(q)] => assert_eq!( q, @@ -1650,7 +1656,7 @@ mod test { for &(dialect, dbms) in ALL_DIALECTS { let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); let db_info = create_test_db_info(dbms); - let stmt = parse_single_statement(&mut parser, &db_info, sql, None); + let stmt = parse_single_statement(&mut parser, &db_info, sql); if let Some(ParsedStatement::SetVariable { variable, value: StmtWithParams { query, params, .. 
}, @@ -1675,7 +1681,7 @@ mod test { for &(dialect, dbms) in ALL_DIALECTS { let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); let db_info = create_test_db_info(dbms); - match parse_single_statement(&mut parser, &db_info, sql, None) { + match parse_single_statement(&mut parser, &db_info, sql) { Some(ParsedStatement::StaticSimpleSet { variable: StmtParam::PostOrGet(var_name), value: SimpleSelectValue::Static(value), @@ -1866,7 +1872,7 @@ mod test { for &(dialect, dbms) in ALL_DIALECTS { let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); let db_info = create_test_db_info(dbms); - let stmt = parse_single_statement(&mut parser, &db_info, sql, None); + let stmt = parse_single_statement(&mut parser, &db_info, sql); if let Some(ParsedStatement::Error(err)) = stmt { assert!( err.to_string() From ae633040c8d4b56f381314b7f9c0654bef87292d Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sat, 28 Feb 2026 23:55:20 +0100 Subject: [PATCH 07/10] Remove redundant 'reorganize' hint from error message --- src/webserver/database/sql.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index d62ce18d..1b726439 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -767,12 +767,10 @@ impl ParamExtractContext { let reason = match &e.kind { ExprToParamErrorKind::UnsupportedExpr { summary } => { - format!("\"{summary}\" is an sql expression, which cannot be passed as a nested sqlpage function argument.\n\ - You should reorganize the query or split it into a sequence of multiple queries using intermediate variables with SET, so that sqlpage.{func_name} either appears at the top level of a SELECT statement, or depends solely on $variables instead of data from the database.") + format!("\"{summary}\" is an sql expression, which cannot be passed as a nested sqlpage function argument.") } ExprToParamErrorKind::UnemulatedFunction { name } => { - 
format!("\"{name}\" is not a supported sqlpage function. Only a few basic sql functions like concat or json_object can be used inside sqlpage functions.\n\ - You should reorganize the query or split it into a sequence of multiple queries using intermediate variables with SET, so that sqlpage.{func_name} either appears at the top level of a SELECT statement, or depends solely on $variables instead of data from the database.") + format!("\"{name}\" is not a supported sqlpage function. Only a few basic sql functions like concat or json_object can be used inside sqlpage functions.") } ExprToParamErrorKind::NamedArgs => { format!("Named function arguments are not supported.\n\ From e47c324645ef2de4d22e05c37a75a4e16e91dfcd Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sun, 1 Mar 2026 00:05:14 +0100 Subject: [PATCH 08/10] readd deleted test --- src/webserver/database/execute_queries.rs | 6 ++- src/webserver/database/sql.rs | 61 ++++++++++++++++++----- 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/src/webserver/database/execute_queries.rs b/src/webserver/database/execute_queries.rs index de1a5e45..253bd17e 100644 --- a/src/webserver/database/execute_queries.rs +++ b/src/webserver/database/execute_queries.rs @@ -333,7 +333,11 @@ fn debug_row(r: &AnyRow) { fn clone_anyhow_err(source_file: &Path, err: &anyhow::Error) -> anyhow::Error { if let Some(func_err) = err.downcast_ref::() { let line = func_err.line; - let loc = if line > 0 { format!(":{line}") } else { String::new() }; + let loc = if line > 0 { + format!(":{line}") + } else { + String::new() + }; return anyhow::anyhow!("{}{loc} {}", source_file.display(), func_err); } diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 1b726439..05ad678c 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -43,8 +43,7 @@ impl ParsedSqlFile { source_path.display(), dialect ); - let parsed_statements = match parse_sql(&db.info, dialect.as_ref(), sql) - { + let 
parsed_statements = match parse_sql(&db.info, dialect.as_ref(), sql) { Ok(parsed) => parsed, Err(err) => return Self::from_err(err, source_path), }; @@ -199,10 +198,7 @@ fn parse_single_statement( semicolon = true; } - let mut params = match ParameterExtractor::extract_parameters( - &mut stmt, - db_info.clone(), - ) { + let mut params = match ParameterExtractor::extract_parameters(&mut stmt, db_info.clone()) { Ok(p) => p, Err(err) => return Some(ParsedStatement::Error(err)), }; @@ -773,8 +769,10 @@ impl ParamExtractContext { format!("\"{name}\" is not a supported sqlpage function. Only a few basic sql functions like concat or json_object can be used inside sqlpage functions.") } ExprToParamErrorKind::NamedArgs => { - format!("Named function arguments are not supported.\n\ - Please use positional arguments only.") + format!( + "Named function arguments are not supported.\n\ + Please use positional arguments only." + ) } }; @@ -1446,8 +1444,8 @@ mod test { let sql = "select sqlpage.fetch($x)"; let mut ast = parse_stmt(sql, dialect); let db_info = create_test_db_info(SupportedDatabase::Postgres); - let parameters = - ParameterExtractor::extract_parameters(&mut ast, db_info).unwrap(); assert_eq!( + let parameters = ParameterExtractor::extract_parameters(&mut ast, db_info).unwrap(); + assert_eq!( parameters, [StmtParam::FunctionCall(SqlPageFunctionCall { function: SqlPageFunctionName::fetch, @@ -1616,8 +1614,7 @@ mod test { } else { create_test_db_info(SupportedDatabase::Generic) }; - let parsed: Vec = - parse_sql(&db_info, dialect, sql).unwrap().collect(); + let parsed: Vec = parse_sql(&db_info, dialect, sql).unwrap().collect(); match &parsed[..] 
{ [ParsedStatement::StaticSimpleSelect(q)] => assert_eq!( q, @@ -1777,6 +1774,46 @@ mod test { ); } + #[test] + fn test_set_variable_with_sqlpage_function() { + let sql = "set x = sqlpage.url_encode(some_db_function())"; + for &(dialect, dbms) in ALL_DIALECTS { + let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); + let db_info = create_test_db_info(dbms); + let stmt = parse_single_statement(&mut parser, &db_info, sql); + let Some(ParsedStatement::SetVariable { + variable, + value: + StmtWithParams { + query, + params, + delayed_functions, + json_columns, + .. + }, + }) = stmt + else { + panic!("for dialect {dialect:?}: {stmt:#?} instead of SetVariable"); + }; + assert_eq!( + variable, + StmtParam::PostOrGet("x".to_string()), + "{dialect:?}" + ); + assert_eq!( + delayed_functions, + [DelayedFunctionCall { + function: SqlPageFunctionName::url_encode, + argument_col_names: vec!["_sqlpage_f0_a0".to_string()], + target_col_name: "sqlpage_set_expr".to_string() + }] + ); + assert_eq!(query, "SELECT some_db_function() AS \"_sqlpage_f0_a0\""); + assert_eq!(params, []); + assert_eq!(json_columns, Vec::::new()); + } + } + #[test] fn test_extract_json_columns_from_literal() { let sql = r#" From 49a4df3d09ad704507dd96efcc39a6119cad8325 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sun, 1 Mar 2026 09:27:58 +0100 Subject: [PATCH 09/10] split parameter extraction logic into a separate file --- src/webserver/database/sql.rs | 598 +----------------- .../database/sql/parameter_extraction.rs | 583 +++++++++++++++++ 2 files changed, 595 insertions(+), 586 deletions(-) create mode 100644 src/webserver/database/sql/parameter_extraction.rs diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 05ad678c..7acac1a0 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -1,5 +1,4 @@ use super::csv_import::{extract_csv_copy_statement, CsvImport}; -use super::sqlpage_functions::func_call_to_param; use 
super::sqlpage_functions::functions::SqlPageFunctionName; use super::syntax_tree::StmtParam; use super::SupportedDatabase; @@ -10,10 +9,9 @@ use crate::{AppState, Database}; use async_trait::async_trait; use sqlparser::ast::helpers::attached_token::AttachedToken; use sqlparser::ast::{ - BinaryOperator, CastKind, CharacterLength, DataType, Expr, Function, FunctionArg, - FunctionArgExpr, FunctionArgumentList, FunctionArguments, Ident, ObjectName, ObjectNamePart, - SelectFlavor, SelectItem, Set, SetExpr, Spanned, Statement, Value, ValueWithSpan, Visit, - VisitMut, Visitor, VisitorMut, + CastKind, DataType, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList, + FunctionArguments, Ident, ObjectName, ObjectNamePart, SelectFlavor, SelectItem, Set, SetExpr, + Spanned, Statement, Value, ValueWithSpan, }; use sqlparser::dialect::{ Dialect, DuckDbDialect, GenericDialect, MsSqlDialect, MySqlDialect, OracleDialect, @@ -24,10 +22,18 @@ use sqlparser::tokenizer::Token::{self, SemiColon, EOF}; use sqlparser::tokenizer::{Location, Span, TokenWithSpan, Tokenizer}; use sqlx::any::AnyKind; use std::fmt::Write; -use std::ops::ControlFlow; use std::path::{Path, PathBuf}; use std::str::FromStr; +mod parameter_extraction; +use self::parameter_extraction::{ + extract_ident_param, validate_function_calls, ParameterExtractor, TEMP_PLACEHOLDER_PREFIX, +}; +pub(super) use self::parameter_extraction::{ + function_args_to_stmt_params, DbPlaceHolder, ParamExtractContext, SqlPageFunctionError, + DB_PLACEHOLDERS, +}; + #[derive(Default)] pub struct ParsedSqlFile { pub(super) statements: Vec, @@ -289,18 +295,6 @@ fn dialect_for_db(dbms: SupportedDatabase) -> Box { } } -fn map_param(mut name: String) -> StmtParam { - if name.is_empty() { - return StmtParam::PostOrGet(name); - } - let prefix = name.remove(0); - match prefix { - '$' => StmtParam::PostOrGet(name), - ':' => StmtParam::Post(name), - _ => StmtParam::Get(name), - } -} - #[derive(Debug, PartialEq)] pub struct 
DelayedFunctionCall { pub function: SqlPageFunctionName, @@ -552,574 +546,6 @@ fn extract_set_variable( None } -struct ParameterExtractor { - db_info: DbInfo, - parameters: Vec, - extract_error: Option, -} - -#[derive(Debug)] -pub enum DbPlaceHolder { - PrefixedNumber { prefix: &'static str }, - Positional { placeholder: &'static str }, -} - -pub const DB_PLACEHOLDERS: [(AnyKind, DbPlaceHolder); 5] = [ - ( - AnyKind::Sqlite, - DbPlaceHolder::PrefixedNumber { prefix: "?" }, - ), - ( - AnyKind::Postgres, - DbPlaceHolder::PrefixedNumber { prefix: "$" }, - ), - ( - AnyKind::MySql, - DbPlaceHolder::Positional { placeholder: "?" }, - ), - ( - AnyKind::Mssql, - DbPlaceHolder::PrefixedNumber { prefix: "@p" }, - ), - ( - AnyKind::Odbc, - DbPlaceHolder::Positional { placeholder: "?" }, - ), -]; - -/// For positional parameters, we use a temporary placeholder during parameter extraction, -/// And then replace it with the actual placeholder during statement rewriting. -const TEMP_PLACEHOLDER_PREFIX: &str = "@SQLPAGE_TEMP"; - -fn get_placeholder_prefix(kind: AnyKind) -> &'static str { - if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = DB_PLACEHOLDERS - .iter() - .find(|(placeholder_kind, _prefix)| *placeholder_kind == kind) - { - prefix - } else { - TEMP_PLACEHOLDER_PREFIX - } -} - -impl ParameterExtractor { - fn extract_parameters( - sql_ast: &mut sqlparser::ast::Statement, - db_info: DbInfo, - ) -> anyhow::Result> { - let mut this = Self { - db_info, - parameters: vec![], - extract_error: None, - }; - let _ = sql_ast.visit(&mut this); - if let Some(e) = this.extract_error { - return Err(e); - } - Ok(this.parameters) - } - - fn replace_with_placeholder(&mut self, value: &mut Expr, param: StmtParam) { - let placeholder = - if let Some(existing_idx) = self.parameters.iter().position(|p| *p == param) { - // Parameter already exists, use its index - self.make_placeholder_for_index(existing_idx + 1) - } else { - // New parameter, add it to the list - let placeholder = 
self.make_placeholder(); - log::trace!("Replacing {param} with {placeholder}"); - self.parameters.push(param); - placeholder - }; - *value = placeholder; - } - - fn make_placeholder_for_index(&self, index: usize) -> Expr { - let name = make_tmp_placeholder(self.db_info.kind, index); - let data_type = match self.db_info.database_type { - SupportedDatabase::MySql => DataType::Char(None), - SupportedDatabase::Mssql => DataType::Varchar(Some(CharacterLength::Max)), - SupportedDatabase::Postgres | SupportedDatabase::Sqlite => DataType::Text, - SupportedDatabase::Oracle => DataType::Varchar(Some(CharacterLength::IntegerLength { - length: 4000, - unit: None, - })), - _ => DataType::Varchar(None), - }; - let value = Expr::value(Value::Placeholder(name)); - Expr::Cast { - expr: Box::new(value), - data_type, - format: None, - kind: CastKind::Cast, - } - } - - fn make_placeholder(&self) -> Expr { - self.make_placeholder_for_index(self.parameters.len() + 1) - } - - fn is_own_placeholder(&self, param: &str) -> bool { - let prefix = get_placeholder_prefix(self.db_info.kind); - if let Some(param) = param.strip_prefix(prefix) { - if let Ok(index) = param.parse::() { - return index <= self.parameters.len() + 1; - } - } - false - } -} - -struct InvalidFunctionFinder; -impl Visitor for InvalidFunctionFinder { - type Break = (String, Vec); - fn pre_visit_expr(&mut self, value: &Expr) -> ControlFlow { - match value { - Expr::Function(Function { - name: ObjectName(func_name_parts), - args: - FunctionArguments::List(FunctionArgumentList { - args, - duplicate_treatment: None, - .. - }), - .. 
- }) if is_sqlpage_func(func_name_parts) => { - let func_name = sqlpage_func_name(func_name_parts); - let arguments = args.clone(); - return ControlFlow::Break((func_name.to_string(), arguments)); - } - _ => (), - } - ControlFlow::Continue(()) - } -} - -fn validate_function_calls(stmt: &Statement) -> anyhow::Result<()> { - let mut finder = InvalidFunctionFinder; - if let ControlFlow::Break((func_name, mut args)) = stmt.visit(&mut finder) { - let ctx = ParamExtractContext { - parent_func: Some(func_name.clone()), - }; - function_args_to_stmt_params(&mut args, &ctx)?; - - let args_str = FormatArguments(&args); - let error_msg = format!( - "Invalid SQLPage function call: sqlpage.{func_name}({args_str})\n\n\ - Arbitrary SQL expressions as function arguments are not supported.\n\n\ - SQLPage functions can either:\n\ - 1. Run BEFORE the query (to provide input values)\n\ - 2. Run AFTER the query (to process the results)\n\ - But they can't run DURING the query - the database doesn't know how to call them!\n\n\ - To fix this, you can either:\n\ - 1. Store the function argument in a variable first:\n\ - SET {func_name}_arg = ...;\n\ - SET {func_name}_result = sqlpage.{func_name}(${func_name}_arg);\n\ - SELECT * FROM example WHERE xxx = ${func_name}_result;\n\n\ - 2. Or move the function to the top level to process results:\n\ - SELECT sqlpage.{func_name}(...) FROM example;" - ); - Err(anyhow::anyhow!(error_msg)) - } else { - Ok(()) - } -} - -/** This is a helper struct to format a list of arguments for an error message. 
*/ -pub(super) struct FormatArguments<'a>(pub &'a [FunctionArg]); -impl std::fmt::Display for FormatArguments<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut args = self.0.iter(); - if let Some(arg) = args.next() { - write!(f, "{arg}")?; - } - for arg in args { - write!(f, ", {arg}")?; - } - Ok(()) - } -} - -#[derive(Clone, Default)] -pub(super) struct ParamExtractContext { - pub parent_func: Option, -} - -impl ParamExtractContext { - pub(super) fn with_parent(&self, parent: &str) -> Self { - Self { - parent_func: Some(parent.to_string()), - } - } - - pub(super) fn into_error( - &self, - e: ExprToParamError, - arguments: &[FunctionArg], - ) -> SqlPageFunctionError { - let line = e.line.unwrap_or(0); - let func_name = self.parent_func.as_deref().unwrap_or("unknown").to_string(); - let arguments_str = FormatArguments(arguments).to_string(); - - let reason = match &e.kind { - ExprToParamErrorKind::UnsupportedExpr { summary } => { - format!("\"{summary}\" is an sql expression, which cannot be passed as a nested sqlpage function argument.") - } - ExprToParamErrorKind::UnemulatedFunction { name } => { - format!("\"{name}\" is not a supported sqlpage function. Only a few basic sql functions like concat or json_object can be used inside sqlpage functions.") - } - ExprToParamErrorKind::NamedArgs => { - format!( - "Named function arguments are not supported.\n\ - Please use positional arguments only." - ) - } - }; - - SqlPageFunctionError { - line, - func_name, - arguments_str, - reason, - } - } -} - -#[derive(Debug)] -pub struct SqlPageFunctionError { - pub line: u64, - pub func_name: String, - pub arguments_str: String, - pub reason: String, -} - -impl std::fmt::Display for SqlPageFunctionError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Unsupported sqlpage function argument:\n\ -sqlpage.{func}({args_str})\n\n\ -{reason}\n\n\ -SQLPage functions can either:\n\ -1. 
Run BEFORE the query (to provide input values)\n\ -2. Run AFTER the query (to process the results)\n\ -But they can't run DURING the query - the database doesn't know how to call them!\n\n\ -To fix this, you can either:\n\ -1. Store the function argument in a variable first:\n\ -SET {func}_arg = ...;\n\ -SET {func}_result = sqlpage.{func}(${func}_arg);\n\ -SELECT * FROM example WHERE xxx = ${func}_result;\n\n\ -2. Or move the function to the top level to process results:\n\ -SELECT sqlpage.{func}(...) FROM example;", - func = self.func_name, - args_str = self.arguments_str, - reason = self.reason - ) - } -} -impl std::error::Error for SqlPageFunctionError {} - -#[derive(Debug)] -pub(super) struct ExprToParamError { - pub(super) line: Option, - pub(super) kind: ExprToParamErrorKind, -} - -#[derive(Debug)] -pub(super) enum ExprToParamErrorKind { - UnsupportedExpr { summary: String }, - UnemulatedFunction { name: String }, - NamedArgs, -} - -fn expr_summary(expr: &Expr) -> String { - match expr { - Expr::CompoundIdentifier(idents) => { - let s = idents - .iter() - .map(|i| i.value.as_str()) - .collect::>() - .join("."); - format!("column/table reference '{s}'") - } - _ => format!("{expr}"), - } -} - -pub(super) fn function_arg_to_stmt_param( - arg: &mut FunctionArg, - ctx: &ParamExtractContext, -) -> Result { - let expr = function_arg_expr(arg).ok_or(ExprToParamError { - line: None, - kind: ExprToParamErrorKind::NamedArgs, - })?; - expr_to_stmt_param(expr, ctx) -} - -pub(super) fn function_args_to_stmt_params( - arguments: &mut [FunctionArg], - ctx: &ParamExtractContext, -) -> anyhow::Result> { - let mut params = Vec::with_capacity(arguments.len()); - // We iterate manually so we can pass the entire `arguments` slice to into_error on failure - for arg in arguments.iter_mut() { - match function_arg_to_stmt_param(arg, ctx) { - Ok(p) => params.push(p), - Err(e) => { - let func_err = ctx.into_error(e, arguments); - return Err(anyhow::Error::new(func_err)); - } - } - } - 
Ok(params) -} - -fn emulated_func_args_to_param( - func_name: &str, - args: &mut [FunctionArg], - ctx: &ParamExtractContext, - line: u64, -) -> Result { - let inner = ctx.with_parent(func_name); - if func_name.eq_ignore_ascii_case("concat") { - let mut concat_args = Vec::with_capacity(args.len()); - for a in args { - concat_args.push(function_arg_to_stmt_param(a, &inner)?); - } - Ok(StmtParam::Concat(concat_args)) - } else if func_name.eq_ignore_ascii_case("json_object") - || func_name.eq_ignore_ascii_case("jsonb_object") - || func_name.eq_ignore_ascii_case("json_build_object") - || func_name.eq_ignore_ascii_case("jsonb_build_object") - { - let mut json_obj_args = Vec::with_capacity(args.len()); - for a in args { - json_obj_args.push(function_arg_to_stmt_param(a, &inner)?); - } - Ok(StmtParam::JsonObject(json_obj_args)) - } else if func_name.eq_ignore_ascii_case("json_array") - || func_name.eq_ignore_ascii_case("jsonb_array") - || func_name.eq_ignore_ascii_case("json_build_array") - || func_name.eq_ignore_ascii_case("jsonb_build_array") - { - let mut json_obj_args = Vec::with_capacity(args.len()); - for a in args { - json_obj_args.push(function_arg_to_stmt_param(a, &inner)?); - } - Ok(StmtParam::JsonArray(json_obj_args)) - } else if func_name.eq_ignore_ascii_case("coalesce") { - let mut coalesce_args = Vec::with_capacity(args.len()); - for a in args { - coalesce_args.push(function_arg_to_stmt_param(a, &inner)?); - } - Ok(StmtParam::Coalesce(coalesce_args)) - } else { - Err(ExprToParamError { - line: Some(line), - kind: ExprToParamErrorKind::UnemulatedFunction { - name: func_name.to_string(), - }, - }) - } -} - -fn expr_to_stmt_param( - arg: &mut Expr, - ctx: &ParamExtractContext, -) -> Result { - let line = arg.span().start.line; - match arg { - Expr::Value(ValueWithSpan { - value: Value::Placeholder(placeholder), - .. 
- }) => Ok(map_param(std::mem::take(placeholder))), - Expr::Identifier(ident) => extract_ident_param(ident).ok_or_else(|| ExprToParamError { - line: Some(line), - kind: ExprToParamErrorKind::UnsupportedExpr { - summary: expr_summary(arg), - }, - }), - Expr::Function(Function { - name: ObjectName(func_name_parts), - args: - FunctionArguments::List(FunctionArgumentList { - args, - duplicate_treatment: None, - .. - }), - .. - }) if is_sqlpage_func(func_name_parts) => Ok(func_call_to_param( - sqlpage_func_name(func_name_parts), - args.as_mut_slice(), - ctx, - )), - Expr::Value(ValueWithSpan { - value: Value::SingleQuotedString(param_value), - .. - }) => Ok(StmtParam::Literal(std::mem::take(param_value))), - Expr::Value(ValueWithSpan { - value: Value::Number(param_value, _is_long), - .. - }) => Ok(StmtParam::Literal(param_value.clone())), - Expr::Value(ValueWithSpan { - value: Value::Null, .. - }) => Ok(StmtParam::Null), - Expr::BinaryOp { - left, - op: BinaryOperator::StringConcat, - right, - } => { - let left = expr_to_stmt_param(left, ctx)?; - let right = expr_to_stmt_param(right, ctx)?; - Ok(StmtParam::Concat(vec![left, right])) - } - Expr::Function(Function { - name: ObjectName(func_name_parts), - args: - FunctionArguments::List(FunctionArgumentList { - args, - duplicate_treatment: None, - .. - }), - .. 
- }) if func_name_parts.len() == 1 => { - let func_name = func_name_parts[0] - .as_ident() - .map(|ident| ident.value.as_str()) - .unwrap_or_default(); - emulated_func_args_to_param(func_name, args.as_mut_slice(), ctx, line) - } - _ => Err(ExprToParamError { - line: Some(line), - kind: ExprToParamErrorKind::UnsupportedExpr { - summary: expr_summary(arg), - }, - }), - } -} - -fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> { - match arg { - FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => Some(expr), - _ => None, - } -} - -#[inline] -#[must_use] -pub fn make_tmp_placeholder(kind: AnyKind, arg_number: usize) -> String { - let prefix = if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = - DB_PLACEHOLDERS.iter().find(|(db_typ, _)| *db_typ == kind) - { - prefix - } else { - TEMP_PLACEHOLDER_PREFIX - }; - format!("{prefix}{arg_number}") -} - -fn extract_ident_param(Ident { value, .. }: &mut Ident) -> Option<StmtParam> { - if value.starts_with('$') || value.starts_with(':') { - let name = std::mem::take(value); - Some(map_param(name)) - } else { - None - } -} - -impl VisitorMut for ParameterExtractor { - type Break = (); - fn pre_visit_expr(&mut self, value: &mut Expr) -> ControlFlow<Self::Break> { - match value { - Expr::Identifier(ident) => { - if let Some(param) = extract_ident_param(ident) { - self.replace_with_placeholder(value, param); - } - } - Expr::Value(ValueWithSpan { - value: Value::Placeholder(param), - .. - }) if !self.is_own_placeholder(param) => - // this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves - { - let name = std::mem::take(param); - self.replace_with_placeholder(value, map_param(name)); - } - Expr::Function(Function { - name: ObjectName(func_name_parts), - args: - FunctionArguments::List(FunctionArgumentList { - args, - duplicate_treatment: None, - .. - }), - filter: None, - null_treatment: None, - over: None, - ..
- }) if is_sqlpage_func(func_name_parts) => { - let func_name = sqlpage_func_name(func_name_parts); - log::trace!("Handling builtin function: {func_name}"); - let arguments = std::mem::take(args); - let ctx = ParamExtractContext { - parent_func: Some(func_name.to_string()), - }; - let mut arguments_clone = arguments.clone(); - let param = func_call_to_param(func_name, &mut arguments_clone, &ctx); - if let StmtParam::Error(msg) = &param { - log::trace!("Skipping extraction of {func_name} due to: {msg}"); - *args = arguments; - return ControlFlow::Continue(()); - } - self.replace_with_placeholder(value, param); - } - // Replace 'str1' || 'str2' with CONCAT('str1', 'str2') for MSSQL - Expr::BinaryOp { - left, - op: BinaryOperator::StringConcat, - right, - } if self.db_info.database_type == SupportedDatabase::Mssql => { - let left = std::mem::replace(left.as_mut(), Expr::value(Value::Null)); - let right = std::mem::replace(right.as_mut(), Expr::value(Value::Null)); - *value = Expr::Function(Function { - name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new("CONCAT"))]), - args: FunctionArguments::List(FunctionArgumentList { - args: vec![ - FunctionArg::Unnamed(FunctionArgExpr::Expr(left)), - FunctionArg::Unnamed(FunctionArgExpr::Expr(right)), - ], - duplicate_treatment: None, - clauses: Vec::new(), - }), - parameters: FunctionArguments::None, - over: None, - filter: None, - null_treatment: None, - within_group: Vec::new(), - uses_odbc_syntax: false, - }); - } - Expr::Cast { - kind: kind @ CastKind::DoubleColon, - .. - } if ![ - SupportedDatabase::Postgres, - SupportedDatabase::Snowflake, - SupportedDatabase::Generic, - ] - .contains(&self.db_info.database_type) => - { - log::warn!("Casting with '::' is not supported on your database. \ - For backwards compatibility with older SQLPage versions, we will transform it to CAST(...
AS ...)."); - *kind = CastKind::Cast; - } - _ => (), - } - ControlFlow::<()>::Continue(()) - } -} - const SQLPAGE_FUNCTION_NAMESPACE: &str = "sqlpage"; fn is_sqlpage_func(func_name_parts: &[ObjectNamePart]) -> bool { diff --git a/src/webserver/database/sql/parameter_extraction.rs b/src/webserver/database/sql/parameter_extraction.rs new file mode 100644 index 00000000..7386880f --- /dev/null +++ b/src/webserver/database/sql/parameter_extraction.rs @@ -0,0 +1,583 @@ +use super::super::{DbInfo, SupportedDatabase}; +use super::{is_sqlpage_func, sqlpage_func_name}; +use crate::webserver::database::sqlpage_functions::func_call_to_param; +use crate::webserver::database::syntax_tree::StmtParam; +use sqlparser::ast::{ + BinaryOperator, CastKind, CharacterLength, DataType, Expr, Function, FunctionArg, + FunctionArgExpr, FunctionArgumentList, FunctionArguments, Ident, ObjectName, ObjectNamePart, + Spanned, Statement, Value, ValueWithSpan, Visit, VisitMut, Visitor, VisitorMut, +}; +use sqlx::any::AnyKind; +use std::ops::ControlFlow; + +pub(super) struct ParameterExtractor { + pub(super) db_info: DbInfo, + pub(super) parameters: Vec<StmtParam>, + pub(super) extract_error: Option<anyhow::Error>, +} + +#[derive(Debug)] +pub(crate) enum DbPlaceHolder { + PrefixedNumber { prefix: &'static str }, + Positional { placeholder: &'static str }, +} + +pub(crate) const DB_PLACEHOLDERS: [(AnyKind, DbPlaceHolder); 5] = [ + ( + AnyKind::Sqlite, + DbPlaceHolder::PrefixedNumber { prefix: "?" }, + ), + ( + AnyKind::Postgres, + DbPlaceHolder::PrefixedNumber { prefix: "$" }, + ), + ( + AnyKind::MySql, + DbPlaceHolder::Positional { placeholder: "?" }, + ), + ( + AnyKind::Mssql, + DbPlaceHolder::PrefixedNumber { prefix: "@p" }, + ), + ( + AnyKind::Odbc, + DbPlaceHolder::Positional { placeholder: "?" }, + ), +]; + +/// For positional parameters, we use a temporary placeholder during parameter extraction, +/// and then replace it with the actual placeholder during statement rewriting.
+pub(crate) const TEMP_PLACEHOLDER_PREFIX: &str = "@SQLPAGE_TEMP"; + +fn get_placeholder_prefix(kind: AnyKind) -> &'static str { + if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = DB_PLACEHOLDERS + .iter() + .find(|(placeholder_kind, _prefix)| *placeholder_kind == kind) + { + prefix + } else { + TEMP_PLACEHOLDER_PREFIX + } +} + +impl ParameterExtractor { + pub(super) fn extract_parameters( + sql_ast: &mut Statement, + db_info: DbInfo, + ) -> anyhow::Result<Vec<StmtParam>> { + let mut this = Self { + db_info, + parameters: vec![], + extract_error: None, + }; + let _ = sql_ast.visit(&mut this); + if let Some(e) = this.extract_error { + return Err(e); + } + Ok(this.parameters) + } + + fn replace_with_placeholder(&mut self, value: &mut Expr, param: StmtParam) { + let placeholder = + if let Some(existing_idx) = self.parameters.iter().position(|p| *p == param) { + // Parameter already exists, use its index + self.make_placeholder_for_index(existing_idx + 1) + } else { + // New parameter, add it to the list + let placeholder = self.make_placeholder(); + log::trace!("Replacing {param} with {placeholder}"); + self.parameters.push(param); + placeholder + }; + *value = placeholder; + } + + fn make_placeholder_for_index(&self, index: usize) -> Expr { + let name = make_tmp_placeholder(self.db_info.kind, index); + let data_type = match self.db_info.database_type { + SupportedDatabase::MySql => DataType::Char(None), + SupportedDatabase::Mssql => DataType::Varchar(Some(CharacterLength::Max)), + SupportedDatabase::Postgres | SupportedDatabase::Sqlite => DataType::Text, + SupportedDatabase::Oracle => DataType::Varchar(Some(CharacterLength::IntegerLength { + length: 4000, + unit: None, + })), + _ => DataType::Varchar(None), + }; + let value = Expr::value(Value::Placeholder(name)); + Expr::Cast { + expr: Box::new(value), + data_type, + format: None, + kind: CastKind::Cast, + } + } + + fn make_placeholder(&self) -> Expr { + self.make_placeholder_for_index(self.parameters.len() + 1) + } +
+ pub(super) fn is_own_placeholder(&self, param: &str) -> bool { + let prefix = get_placeholder_prefix(self.db_info.kind); + if let Some(param) = param.strip_prefix(prefix) { + if let Ok(index) = param.parse::<usize>() { + return index <= self.parameters.len() + 1; + } + } + false + } +} + +struct InvalidFunctionFinder; +impl Visitor for InvalidFunctionFinder { + type Break = (String, Vec<FunctionArg>); + fn pre_visit_expr(&mut self, value: &Expr) -> ControlFlow<Self::Break> { + match value { + Expr::Function(Function { + name: ObjectName(func_name_parts), + args: + FunctionArguments::List(FunctionArgumentList { + args, + duplicate_treatment: None, + .. + }), + .. + }) if is_sqlpage_func(func_name_parts) => { + let func_name = sqlpage_func_name(func_name_parts); + let arguments = args.clone(); + return ControlFlow::Break((func_name.to_string(), arguments)); + } + _ => (), + } + ControlFlow::Continue(()) + } +} + +pub(super) fn validate_function_calls(stmt: &Statement) -> anyhow::Result<()> { + let mut finder = InvalidFunctionFinder; + if let ControlFlow::Break((func_name, mut args)) = stmt.visit(&mut finder) { + let ctx = ParamExtractContext { + parent_func: Some(func_name.clone()), + }; + function_args_to_stmt_params(&mut args, &ctx)?; + + let args_str = FormatArguments(&args); + let error_msg = format!( + "Invalid SQLPage function call: sqlpage.{func_name}({args_str})\n\n\ + Arbitrary SQL expressions as function arguments are not supported.\n\n\ + SQLPage functions can either:\n\ + 1. Run BEFORE the query (to provide input values)\n\ + 2. Run AFTER the query (to process the results)\n\ + But they can't run DURING the query - the database doesn't know how to call them!\n\n\ + To fix this, you can either:\n\ + 1. Store the function argument in a variable first:\n\ + SET {func_name}_arg = ...;\n\ + SET {func_name}_result = sqlpage.{func_name}(${func_name}_arg);\n\ + SELECT * FROM example WHERE xxx = ${func_name}_result;\n\n\ + 2.
Or move the function to the top level to process results:\n\ + SELECT sqlpage.{func_name}(...) FROM example;" + ); + Err(anyhow::anyhow!(error_msg)) + } else { + Ok(()) + } +} + +/** This is a helper struct to format a list of arguments for an error message. */ +struct FormatArguments<'a>(&'a [FunctionArg]); +impl std::fmt::Display for FormatArguments<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut args = self.0.iter(); + if let Some(arg) = args.next() { + write!(f, "{arg}")?; + } + for arg in args { + write!(f, ", {arg}")?; + } + Ok(()) + } +} + +#[derive(Clone, Default)] +pub(crate) struct ParamExtractContext { + pub parent_func: Option<String>, +} + +impl ParamExtractContext { + fn with_parent(parent: &str) -> Self { + Self { + parent_func: Some(parent.to_string()), + } + } + + fn build_error(&self, e: &ExprToParamError, arguments: &[FunctionArg]) -> SqlPageFunctionError { + let line = e.line.unwrap_or(0); + let func_name = self.parent_func.as_deref().unwrap_or("unknown").to_string(); + let arguments_str = FormatArguments(arguments).to_string(); + + let reason = match &e.kind { + ExprToParamErrorKind::UnsupportedExpr { summary } => { + format!("\"{summary}\" is an sql expression, which cannot be passed as a nested sqlpage function argument.") + } + ExprToParamErrorKind::UnemulatedFunction { name } => { + format!("\"{name}\" is not a supported sqlpage function. Only a few basic sql functions like concat or json_object can be used inside sqlpage functions.") + } + ExprToParamErrorKind::NamedArgs => "Named function arguments are not supported.\n\ + Please use positional arguments only."
+ .to_string(), + }; + + SqlPageFunctionError { + line, + func_name, + arguments_str, + reason, + } + } +} + +#[derive(Debug)] +pub(crate) struct SqlPageFunctionError { + pub line: u64, + pub func_name: String, + pub arguments_str: String, + pub reason: String, +} + +impl std::fmt::Display for SqlPageFunctionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Unsupported sqlpage function argument:\n\ +sqlpage.{func}({args_str})\n\n\ +{reason}\n\n\ +SQLPage functions can either:\n\ +1. Run BEFORE the query (to provide input values)\n\ +2. Run AFTER the query (to process the results)\n\ +But they can't run DURING the query - the database doesn't know how to call them!\n\n\ +To fix this, you can either:\n\ +1. Store the function argument in a variable first:\n\ +SET {func}_arg = ...;\n\ +SET {func}_result = sqlpage.{func}(${func}_arg);\n\ +SELECT * FROM example WHERE xxx = ${func}_result;\n\n\ +2. Or move the function to the top level to process results:\n\ +SELECT sqlpage.{func}(...) 
FROM example;", + func = self.func_name, + args_str = self.arguments_str, + reason = self.reason + ) + } +} +impl std::error::Error for SqlPageFunctionError {} + +#[derive(Debug)] +struct ExprToParamError { + line: Option<u64>, + kind: ExprToParamErrorKind, +} + +#[derive(Debug)] +enum ExprToParamErrorKind { + UnsupportedExpr { summary: String }, + UnemulatedFunction { name: String }, + NamedArgs, +} + +fn expr_summary(expr: &Expr) -> String { + match expr { + Expr::CompoundIdentifier(idents) => { + let s = idents + .iter() + .map(|i| i.value.as_str()) + .collect::<Vec<_>>() + .join("."); + format!("column/table reference '{s}'") + } + _ => format!("{expr}"), + } +} + +fn function_arg_to_stmt_param( + arg: &mut FunctionArg, + ctx: &ParamExtractContext, +) -> Result<StmtParam, ExprToParamError> { + let expr = function_arg_expr(arg).ok_or(ExprToParamError { + line: None, + kind: ExprToParamErrorKind::NamedArgs, + })?; + expr_to_stmt_param(expr, ctx) +} + +pub(crate) fn function_args_to_stmt_params( + arguments: &mut [FunctionArg], + ctx: &ParamExtractContext, +) -> anyhow::Result<Vec<StmtParam>> { + let mut params = Vec::with_capacity(arguments.len()); + // We iterate manually so we can pass the entire `arguments` slice to `build_error` on failure + for arg in arguments.iter_mut() { + match function_arg_to_stmt_param(arg, ctx) { + Ok(p) => params.push(p), + Err(e) => { + let func_err = ctx.build_error(&e, arguments); + return Err(anyhow::Error::new(func_err)); + } + } + } + Ok(params) +} + +fn emulated_func_args_to_param( + func_name: &str, + args: &mut [FunctionArg], + line: u64, +) -> Result<StmtParam, ExprToParamError> { + let inner = ParamExtractContext::with_parent(func_name); + if func_name.eq_ignore_ascii_case("concat") { + let mut concat_args = Vec::with_capacity(args.len()); + for a in args { + concat_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Ok(StmtParam::Concat(concat_args)) + } else if func_name.eq_ignore_ascii_case("json_object") + || func_name.eq_ignore_ascii_case("jsonb_object") + ||
func_name.eq_ignore_ascii_case("json_build_object") + || func_name.eq_ignore_ascii_case("jsonb_build_object") + { + let mut json_obj_args = Vec::with_capacity(args.len()); + for a in args { + json_obj_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Ok(StmtParam::JsonObject(json_obj_args)) + } else if func_name.eq_ignore_ascii_case("json_array") + || func_name.eq_ignore_ascii_case("jsonb_array") + || func_name.eq_ignore_ascii_case("json_build_array") + || func_name.eq_ignore_ascii_case("jsonb_build_array") + { + let mut json_obj_args = Vec::with_capacity(args.len()); + for a in args { + json_obj_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Ok(StmtParam::JsonArray(json_obj_args)) + } else if func_name.eq_ignore_ascii_case("coalesce") { + let mut coalesce_args = Vec::with_capacity(args.len()); + for a in args { + coalesce_args.push(function_arg_to_stmt_param(a, &inner)?); + } + Ok(StmtParam::Coalesce(coalesce_args)) + } else { + Err(ExprToParamError { + line: Some(line), + kind: ExprToParamErrorKind::UnemulatedFunction { + name: func_name.to_string(), + }, + }) + } +} + +fn expr_to_stmt_param( + arg: &mut Expr, + ctx: &ParamExtractContext, +) -> Result<StmtParam, ExprToParamError> { + let line = arg.span().start.line; + match arg { + Expr::Value(ValueWithSpan { + value: Value::Placeholder(placeholder), + .. + }) => Ok(map_param(std::mem::take(placeholder))), + Expr::Identifier(ident) => extract_ident_param(ident).ok_or_else(|| ExprToParamError { + line: Some(line), + kind: ExprToParamErrorKind::UnsupportedExpr { + summary: expr_summary(arg), + }, + }), + Expr::Function(Function { + name: ObjectName(func_name_parts), + args: + FunctionArguments::List(FunctionArgumentList { + args, + duplicate_treatment: None, + .. + }), + .. + }) if is_sqlpage_func(func_name_parts) => Ok(func_call_to_param( + sqlpage_func_name(func_name_parts), + args.as_mut_slice(), + ctx, + )), + Expr::Value(ValueWithSpan { + value: Value::SingleQuotedString(param_value), + ..
+ }) => Ok(StmtParam::Literal(std::mem::take(param_value))), + Expr::Value(ValueWithSpan { + value: Value::Number(param_value, _is_long), + .. + }) => Ok(StmtParam::Literal(param_value.clone())), + Expr::Value(ValueWithSpan { + value: Value::Null, .. + }) => Ok(StmtParam::Null), + Expr::BinaryOp { + left, + op: BinaryOperator::StringConcat, + right, + } => { + let left = expr_to_stmt_param(left, ctx)?; + let right = expr_to_stmt_param(right, ctx)?; + Ok(StmtParam::Concat(vec![left, right])) + } + Expr::Function(Function { + name: ObjectName(func_name_parts), + args: + FunctionArguments::List(FunctionArgumentList { + args, + duplicate_treatment: None, + .. + }), + .. + }) if func_name_parts.len() == 1 => { + let func_name = func_name_parts[0] + .as_ident() + .map(|ident| ident.value.as_str()) + .unwrap_or_default(); + emulated_func_args_to_param(func_name, args.as_mut_slice(), line) + } + _ => Err(ExprToParamError { + line: Some(line), + kind: ExprToParamErrorKind::UnsupportedExpr { + summary: expr_summary(arg), + }, + }), + } +} + +fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> { + match arg { + FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => Some(expr), + _ => None, + } +} + +#[inline] +#[must_use] +pub(super) fn make_tmp_placeholder(kind: AnyKind, arg_number: usize) -> String { + let prefix = if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = + DB_PLACEHOLDERS.iter().find(|(db_typ, _)| *db_typ == kind) + { + prefix + } else { + TEMP_PLACEHOLDER_PREFIX + }; + format!("{prefix}{arg_number}") +} + +pub(super) fn extract_ident_param(Ident { value, .. 
}: &mut Ident) -> Option<StmtParam> { + if value.starts_with('$') || value.starts_with(':') { + let name = std::mem::take(value); + Some(map_param(name)) + } else { + None + } +} + +fn map_param(mut name: String) -> StmtParam { + if name.is_empty() { + return StmtParam::PostOrGet(name); + } + let prefix = name.remove(0); + match prefix { + '$' => StmtParam::PostOrGet(name), + ':' => StmtParam::Post(name), + _ => StmtParam::Get(name), + } +} + +impl VisitorMut for ParameterExtractor { + type Break = (); + fn pre_visit_expr(&mut self, value: &mut Expr) -> ControlFlow<Self::Break> { + match value { + Expr::Identifier(ident) => { + if let Some(param) = extract_ident_param(ident) { + self.replace_with_placeholder(value, param); + } + } + Expr::Value(ValueWithSpan { + value: Value::Placeholder(param), + .. + }) if !self.is_own_placeholder(param) => + // this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves + { + let name = std::mem::take(param); + self.replace_with_placeholder(value, map_param(name)); + } + Expr::Function(Function { + name: ObjectName(func_name_parts), + args: + FunctionArguments::List(FunctionArgumentList { + args, + duplicate_treatment: None, + .. + }), + filter: None, + null_treatment: None, + over: None, + ..
+ }) if is_sqlpage_func(func_name_parts) => { + let func_name = sqlpage_func_name(func_name_parts); + log::trace!("Handling builtin function: {func_name}"); + let arguments = std::mem::take(args); + let ctx = ParamExtractContext { + parent_func: Some(func_name.to_string()), + }; + let mut arguments_clone = arguments.clone(); + let param = func_call_to_param(func_name, &mut arguments_clone, &ctx); + if let StmtParam::Error(msg) = &param { + log::trace!("Skipping extraction of {func_name} due to: {msg}"); + *args = arguments; + return ControlFlow::Continue(()); + } + self.replace_with_placeholder(value, param); + } + // Replace 'str1' || 'str2' with CONCAT('str1', 'str2') for MSSQL + Expr::BinaryOp { + left, + op: BinaryOperator::StringConcat, + right, + } if self.db_info.database_type == SupportedDatabase::Mssql => { + let left = std::mem::replace(left.as_mut(), Expr::value(Value::Null)); + let right = std::mem::replace(right.as_mut(), Expr::value(Value::Null)); + *value = Expr::Function(Function { + name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new("CONCAT"))]), + args: FunctionArguments::List(FunctionArgumentList { + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(left)), + FunctionArg::Unnamed(FunctionArgExpr::Expr(right)), + ], + duplicate_treatment: None, + clauses: Vec::new(), + }), + parameters: FunctionArguments::None, + over: None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + uses_odbc_syntax: false, + }); + } + Expr::Cast { + kind: kind @ CastKind::DoubleColon, + .. + } if ![ + SupportedDatabase::Postgres, + SupportedDatabase::Snowflake, + SupportedDatabase::Generic, + ] + .contains(&self.db_info.database_type) => + { + log::warn!("Casting with '::' is not supported on your database. \ + For backwards compatibility with older SQLPage versions, we will transform it to CAST(...
AS ...)."); + *kind = CastKind::Cast; + } + _ => (), + } + ControlFlow::<()>::Continue(()) + } +} From 003f3e00bf4de7897078b290c874ebc8da821681 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Sun, 1 Mar 2026 09:34:46 +0100 Subject: [PATCH 10/10] Add to changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73cb6fbb..38002f44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## unreleased - OIDC protected and public paths now respect the site prefix when it is defined. - add submit and reset form button icons: validate_icon, reset_icon, reset_color + - improve error messages when sqlpage functions are used incorrectly. Include precise file reference and line number ## 0.42.0 (2026-01-17)