diff --git a/front/parser/src/parser/ast.rs b/front/parser/src/parser/ast.rs index fc687a80..f98a8d5f 100644 --- a/front/parser/src/parser/ast.rs +++ b/front/parser/src/parser/ast.rs @@ -74,6 +74,14 @@ pub enum FormatPart { Placeholder, } +#[derive(Debug, Clone)] +pub enum IncDecKind { + PreInc, + PreDec, + PostInc, + PostDec, +} + #[derive(Debug, Clone)] pub enum Expression { StructLiteral { @@ -126,6 +134,10 @@ pub enum Expression { operator: Operator, expr: Box, }, + IncDec { + kind: IncDecKind, + target: Box, + }, } #[derive(Debug, Clone)] @@ -220,6 +232,7 @@ pub enum StatementNode { } #[derive(Debug, Clone, PartialEq)] +#[derive(Copy)] pub enum Mutability { Var, Let, diff --git a/front/parser/src/parser/format.rs b/front/parser/src/parser/format.rs index 36891217..238eb3e1 100644 --- a/front/parser/src/parser/format.rs +++ b/front/parser/src/parser/format.rs @@ -1,5 +1,5 @@ use crate::ast::Expression::Variable; -use crate::ast::{AssignOperator, Expression, FormatPart, Literal, Operator}; +use crate::ast::{AssignOperator, Expression, FormatPart, IncDecKind, Literal, Operator}; use lexer::{Token, TokenType}; use std::iter::Peekable; use std::slice::Iter; @@ -33,6 +33,36 @@ pub fn parse_format_string(s: &str) -> Vec { parts } +fn is_assignable(expr: &Expression) -> bool { + match expr { + Expression::Variable(_) => true, + Expression::Deref(_) => true, + Expression::FieldAccess { .. } => true, + Expression::IndexAccess { .. 
} => true, + + Expression::Grouped(inner) => is_assignable(inner), + + _ => false, + } +} + +fn desugar_incdec(line: usize, target: Expression, is_inc: bool) -> Option { + if !is_assignable(&target) { + println!("Error: ++/-- target must be assignable (line {})", line); + return None; + } + + Some(Expression::AssignOperation { + target: Box::new(target), + operator: if is_inc { + AssignOperator::AddAssign + } else { + AssignOperator::SubAssign + }, + value: Box::new(Expression::Literal(Literal::Number(1))), + }) +} + pub fn parse_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where T: Iterator, @@ -321,6 +351,30 @@ where let inner = parse_unary_expression(tokens)?; return Some(Expression::Deref(Box::new(inner))); } + TokenType::Increment => { + let tok = tokens.next()?; // '++' + let inner = parse_unary_expression(tokens)?; + if !is_assignable(&inner) { + println!("Error: ++ target must be assignable (line {})", tok.line); + return None; + } + return Some(Expression::IncDec { + kind: IncDecKind::PreInc, + target: Box::new(inner), + }); + } + TokenType::Decrement => { + let tok = tokens.next()?; // '--' + let inner = parse_unary_expression(tokens)?; + if !is_assignable(&inner) { + println!("Error: -- target must be assignable (line {})", tok.line); + return None; + } + return Some(Expression::IncDec { + kind: IncDecKind::PreDec, + target: Box::new(inner), + }); + } _ => {} } } @@ -762,6 +816,40 @@ where }); } + Some(TokenType::Increment) => { + let line = tokens.peek().unwrap().line; + tokens.next(); // consume '++' + + let base = expr.take()?; + if !is_assignable(&base) { + println!("Error: postfix ++ target must be assignable (line {})", line); + return None; + } + + expr = Some(Expression::IncDec { + kind: IncDecKind::PostInc, + target: Box::new(base), + }); + break; + } + + Some(TokenType::Decrement) => { + let line = tokens.peek().unwrap().line; + tokens.next(); // consume '--' + + let base = expr.take()?; + if !is_assignable(&base) { + 
println!("Error: postfix -- target must be assignable (line {})", line); + return None; + } + + expr = Some(Expression::IncDec { + kind: IncDecKind::PostDec, + target: Box::new(base), + }); + break; + } + _ => break, } } diff --git a/front/parser/src/parser/mod.rs b/front/parser/src/parser/mod.rs index 1cad9d2b..7cbcc9f9 100644 --- a/front/parser/src/parser/mod.rs +++ b/front/parser/src/parser/mod.rs @@ -23,5 +23,6 @@ pub mod import; pub mod parser; pub mod stdlib; pub mod type_system; +mod verification; pub use parser::*; diff --git a/front/parser/src/parser/parser.rs b/front/parser/src/parser/parser.rs index 327f5b5e..c7e0712f 100644 --- a/front/parser/src/parser/parser.rs +++ b/front/parser/src/parser/parser.rs @@ -7,6 +7,7 @@ use regex::Regex; use std::collections::HashSet; use std::iter::Peekable; use std::slice::Iter; +use crate::parser::verification::validate_program; pub fn parse(tokens: &Vec) -> Option> { let mut iter = tokens.iter().peekable(); @@ -68,6 +69,11 @@ pub fn parse(tokens: &Vec) -> Option> { } } + if let Err(e) = validate_program(&nodes) { + println!("❌ {}", e); + return None; + } + Some(nodes) } diff --git a/front/parser/src/parser/verification.rs b/front/parser/src/parser/verification.rs new file mode 100644 index 00000000..9cf11576 --- /dev/null +++ b/front/parser/src/parser/verification.rs @@ -0,0 +1,256 @@ +use std::collections::HashMap; +use crate::ast::{ASTNode, Expression, Mutability, StatementNode}; + +fn lookup_mutability( + name: &str, + scopes: &Vec>, + globals: &HashMap, +) -> Option { + for scope in scopes.iter().rev() { + if let Some(m) = scope.get(name) { + return Some(*m); + } + } + globals.get(name).copied() +} + +fn find_base_var(target: &Expression, saw_deref: bool) -> Option<(String, bool)> { + match target { + Expression::Variable(name) => Some((name.clone(), saw_deref)), + Expression::Grouped(inner) => find_base_var(inner, saw_deref), + + Expression::FieldAccess { object, .. 
} => find_base_var(object, saw_deref), + Expression::IndexAccess { target, .. } => find_base_var(target, saw_deref), + + Expression::Deref(inner) => find_base_var(inner, true), + + _ => None, + } +} + +fn ensure_mutable_write_target( + target: &Expression, + scopes: &Vec>, + globals: &HashMap, + why: &str, +) -> Result<(), String> { + let Some((base, saw_deref)) = find_base_var(target, false) else { + return Ok(()); + }; + + if saw_deref { + return Ok(()); + } + + if let Some(m) = lookup_mutability(&base, scopes, globals) { + match m { + Mutability::Let | Mutability::Const => { + return Err(format!("cannot {} immutable binding `{}` ({:?})", why, base, m)); + } + _ => {} + } + } + + Ok(()) +} + +fn validate_expr( + expr: &Expression, + scopes: &Vec>, + globals: &HashMap, +) -> Result<(), String> { + match expr { + Expression::IncDec { target, .. } => { + ensure_mutable_write_target(target, scopes, globals, "modify with ++/--")?; + validate_expr(target, scopes, globals)?; + } + + Expression::AssignOperation { target, value, .. } => { + ensure_mutable_write_target(target, scopes, globals, "assign")?; + validate_expr(target, scopes, globals)?; + validate_expr(value, scopes, globals)?; + } + + Expression::Assignment { target, value } => { + ensure_mutable_write_target(target, scopes, globals, "assign")?; + validate_expr(target, scopes, globals)?; + validate_expr(value, scopes, globals)?; + } + + Expression::BinaryExpression { left, right, .. } => { + validate_expr(left, scopes, globals)?; + validate_expr(right, scopes, globals)?; + } + + Expression::Unary { expr, .. } => validate_expr(expr, scopes, globals)?, + + Expression::FunctionCall { args, .. } => { + for a in args { + validate_expr(a, scopes, globals)?; + } + } + Expression::MethodCall { object, args, .. 
} => { + validate_expr(object, scopes, globals)?; + for a in args { + validate_expr(a, scopes, globals)?; + } + } + + Expression::IndexAccess { target, index } => { + validate_expr(target, scopes, globals)?; + validate_expr(index, scopes, globals)?; + } + + Expression::ArrayLiteral(items) => { + for it in items { + validate_expr(it, scopes, globals)?; + } + } + + Expression::FieldAccess { object, .. } => validate_expr(object, scopes, globals)?, + + Expression::StructLiteral { fields, .. } => { + for (_, v) in fields { + validate_expr(v, scopes, globals)?; + } + } + + Expression::AsmBlock { inputs, outputs, .. } => { + for (_, e) in inputs { + validate_expr(e, scopes, globals)?; + } + for (_, e) in outputs { + validate_expr(e, scopes, globals)?; + } + } + + Expression::Deref(inner) | Expression::AddressOf(inner) => { + validate_expr(inner, scopes, globals)?; + } + + Expression::Literal(_) | Expression::Variable(_) => {} + + _ => {} + } + + Ok(()) +} + +fn validate_node( + node: &ASTNode, + scopes: &mut Vec>, + globals: &HashMap, +) -> Result<(), String> { + match node { + ASTNode::Variable(v) => { + scopes + .last_mut() + .unwrap() + .insert(v.name.clone(), v.mutability.clone()); + } + + ASTNode::Statement(stmt) => match stmt { + StatementNode::Expression(e) => validate_expr(e, scopes, globals)?, + + StatementNode::Assign { variable, value } => { + let fake_target = Expression::Variable(variable.clone()); + ensure_mutable_write_target(&fake_target, scopes, globals, "assign")?; + validate_expr(value, scopes, globals)?; + } + + StatementNode::PrintlnFormat { args, .. } + | StatementNode::PrintFormat { args, .. 
} => { + for a in args { + validate_expr(a, scopes, globals)?; + } + } + + StatementNode::Return(Some(e)) => validate_expr(e, scopes, globals)?, + + StatementNode::If { + condition, + body, + else_if_blocks, + else_block, + } => { + validate_expr(condition, scopes, globals)?; + + scopes.push(HashMap::new()); + for n in body { + validate_node(n, scopes, globals)?; + } + scopes.pop(); + + if let Some(blocks) = else_if_blocks { + for (cond, b) in blocks.iter() { + validate_expr(cond, scopes, globals)?; + scopes.push(HashMap::new()); + for n in b { + validate_node(n, scopes, globals)?; + } + scopes.pop(); + } + } + + if let Some(b) = else_block { + scopes.push(HashMap::new()); + for n in b.iter() { + validate_node(n, scopes, globals)?; + } + scopes.pop(); + } + } + + StatementNode::While { condition, body } => { + validate_expr(condition, scopes, globals)?; + scopes.push(HashMap::new()); + for n in body { + validate_node(n, scopes, globals)?; + } + scopes.pop(); + } + + _ => {} + }, + + ASTNode::Function(func) => { + scopes.push(HashMap::new()); + + for p in &func.parameters { + scopes + .last_mut() + .unwrap() + .insert(p.name.clone(), Mutability::Var); + } + + for n in &func.body { + validate_node(n, scopes, globals)?; + } + + scopes.pop(); + } + + _ => {} + } + + Ok(()) +} + +pub fn validate_program(nodes: &Vec) -> Result<(), String> { + let mut globals: HashMap = HashMap::new(); + for n in nodes { + if let ASTNode::Variable(v) = n { + if v.mutability == Mutability::Const { + globals.insert(v.name.clone(), Mutability::Const); + } + } + } + + let mut scopes: Vec> = vec![HashMap::new()]; + + for n in nodes { + validate_node(n, &mut scopes, &globals)?; + } + + Ok(()) +} diff --git a/llvm_temporary/src/llvm_temporary/expression.rs b/llvm_temporary/src/llvm_temporary/expression.rs index 182145cf..dda25f70 100644 --- a/llvm_temporary/src/llvm_temporary/expression.rs +++ b/llvm_temporary/src/llvm_temporary/expression.rs @@ -3,7 +3,7 @@ use inkwell::context::Context; use 
inkwell::types::{AnyTypeEnum, BasicType, BasicTypeEnum, StructType}; use inkwell::values::{BasicMetadataValueEnum, BasicValue, BasicValueEnum, IntValue}; use inkwell::{FloatPredicate, IntPredicate}; -use parser::ast::{ASTNode, AssignOperator, Expression, Literal, Operator, WaveType}; +use parser::ast::{ASTNode, AssignOperator, Expression, IncDecKind, Literal, Operator, WaveType}; use std::collections::HashMap; use inkwell::builder::Builder; @@ -1103,6 +1103,57 @@ pub fn generate_expression_ir<'ctx>( } } + Expression::IncDec { kind, target } => { + let ptr = generate_address_ir(context, builder, target, variables, module); + let old_val = builder.build_load(ptr, "incdec_old").unwrap(); + + let new_val: BasicValueEnum<'ctx> = match old_val { + BasicValueEnum::IntValue(iv) => { + if iv.get_type().get_bit_width() == 1 { + panic!("++/-- not allowed on bool"); + } + + let one = iv.get_type().const_int(1, false); + let nv = match kind { + IncDecKind::PreInc | IncDecKind::PostInc => builder.build_int_add(iv, one, "inc").unwrap(), + IncDecKind::PreDec | IncDecKind::PostDec => builder.build_int_sub(iv, one, "dec").unwrap(), + }; + nv.as_basic_value_enum() + } + + BasicValueEnum::FloatValue(fv) => { + let one = fv.get_type().const_float(1.0); + let nv = match kind { + IncDecKind::PreInc | IncDecKind::PostInc => builder.build_float_add(fv, one, "finc").unwrap(), + IncDecKind::PreDec | IncDecKind::PostDec => builder.build_float_sub(fv, one, "fdec").unwrap(), + }; + nv.as_basic_value_enum() + } + + BasicValueEnum::PointerValue(pv) => { + let idx = match kind { + IncDecKind::PreInc | IncDecKind::PostInc => context.i64_type().const_int(1, true), + IncDecKind::PreDec | IncDecKind::PostDec => context.i64_type().const_int((-1i64) as u64, true), + }; + let gep = unsafe { + builder + .build_in_bounds_gep(pv, &[idx], "pincdec") + .unwrap() + }; + gep.as_basic_value_enum() + } + + _ => panic!("Unsupported type for ++/--: {:?}", old_val), + }; + + builder.build_store(ptr, 
new_val).unwrap(); + + match kind { + IncDecKind::PreInc | IncDecKind::PreDec => new_val, + IncDecKind::PostInc | IncDecKind::PostDec => old_val, + } + } + Expression::Grouped(inner) => { generate_expression_ir( context, diff --git a/llvm_temporary/src/llvm_temporary/llvm_codegen.rs b/llvm_temporary/src/llvm_temporary/llvm_codegen.rs index 39954255..db14fa59 100644 --- a/llvm_temporary/src/llvm_temporary/llvm_codegen.rs +++ b/llvm_temporary/src/llvm_temporary/llvm_codegen.rs @@ -293,6 +293,10 @@ pub fn generate_address_ir<'ctx>( module: &'ctx inkwell::module::Module<'ctx>, ) -> PointerValue<'ctx> { match expr { + Expression::Grouped(inner) => { + generate_address_ir(context, builder, inner, variables, module) + } + Expression::Variable(name) => { let var_info = variables .get(name) @@ -301,18 +305,25 @@ pub fn generate_address_ir<'ctx>( var_info.ptr } - Expression::Deref(inner_expr) => match &**inner_expr { - Expression::Variable(var_name) => { - let ptr_to_ptr = variables - .get(var_name) - .unwrap_or_else(|| panic!("Variable {} not found", var_name)) - .ptr; + Expression::Deref(inner_expr) => { + let mut inner: &Expression = inner_expr.as_ref(); + while let Expression::Grouped(g) = inner { + inner = g.as_ref(); + } + + match inner { + Expression::Variable(var_name) => { + let ptr_to_ptr = variables + .get(var_name) + .unwrap_or_else(|| panic!("Variable {} not found", var_name)) + .ptr; - let actual_ptr = builder.build_load(ptr_to_ptr, "deref_target").unwrap(); - actual_ptr.into_pointer_value() + let actual_ptr = builder.build_load(ptr_to_ptr, "deref_target").unwrap(); + actual_ptr.into_pointer_value() + } + _ => panic!("Cannot take address: deref target is not a variable"), } - _ => panic!("Nested deref not supported"), - }, + } _ => panic!("Cannot take address of this expression"), } diff --git a/test/test72.wave b/test/test72.wave new file mode 100644 index 00000000..deb32dbc --- /dev/null +++ b/test/test72.wave @@ -0,0 +1,8 @@ +fun main() { + var x: i32 = 10; 
+ x++; + ++x; + x--; + --x; + println("x = {}", x); +} diff --git a/test/test73.wave b/test/test73.wave new file mode 100644 index 00000000..f9e65113 --- /dev/null +++ b/test/test73.wave @@ -0,0 +1,6 @@ +fun main() { + var x: i32 = 1; + let p: ptr = &x; + (deref p)++; + println("{}", p); +}