diff --git a/examples/block_comments.kit b/examples/block_comments.kit index e03c757..c667223 100644 --- a/examples/block_comments.kit +++ b/examples/block_comments.kit @@ -3,7 +3,7 @@ include "stdio.h"; /* Block comment at start of file */ function main() { /* Block comment with indentation */ - var x: int = 42; /* End of line block comment */ + var x: Int = 42; /* End of line block comment */ /* * Multi-line block comment diff --git a/examples/fail_return_type.kit b/examples/fail_return_type.kit new file mode 100644 index 0000000..146aba1 --- /dev/null +++ b/examples/fail_return_type.kit @@ -0,0 +1,5 @@ +include "stdio.h"; + +function main(): Int { + return "hello"; +} diff --git a/examples/inference_test.kit b/examples/inference_test.kit new file mode 100644 index 0000000..aac4ed8 --- /dev/null +++ b/examples/inference_test.kit @@ -0,0 +1,23 @@ +include "stdio.h"; + +function add(a: Int, b: Int) { + return a + b; +} + +function sub(a: Int, b: Int) { + return a - b; +} + +function main() { + var x = 50; + var y = 20; + + var z = add(x, y); + z = sub(z, 5); + + var w = if z > 60 then 10 else 0; + + printf("Value of w: %d\n", w); + + return 0; +} diff --git a/examples/inference_test.kit.expected b/examples/inference_test.kit.expected new file mode 100644 index 0000000..d773a93 --- /dev/null +++ b/examples/inference_test.kit.expected @@ -0,0 +1 @@ +Value of w: 10 diff --git a/examples/line_comments.kit b/examples/line_comments.kit index a643ee4..84cfbbb 100644 --- a/examples/line_comments.kit +++ b/examples/line_comments.kit @@ -3,7 +3,7 @@ include "stdio.h"; // Line comment at start of file function main() { // Line comment with indentation - var x: int = 42; // End of line comment + var x: Int = 42; // End of line comment // Another line comment printf("%d", x); diff --git a/examples/mixed_comments.kit b/examples/mixed_comments.kit index e377519..c9b6403 100644 --- a/examples/mixed_comments.kit +++ b/examples/mixed_comments.kit @@ -3,16 +3,16 @@ include "stdio.h"; // Mix of line and block comments function main() { // Start with line comment - var x: int = 1; + var x: Int = 1; /* Block comment in middle */ - var y: int = 2; // End of line comment + var y: Int = 2; // End of line comment /* * Multi-line block comment * followed by line comment */ // After multi-line - var z: int = x + y; + var z: Int = x + y; printf("%d", z); /* Block comment at end */ } diff --git a/kitc/src/main.rs b/kitc/src/main.rs index 37aa993..d473e01 100644 --- a/kitc/src/main.rs +++ b/kitc/src/main.rs @@ -69,7 +69,7 @@ fn main() -> Result<(), Error> { fn compile(source: &PathBuf, libs: &[String], measure: bool) -> Result { let init = time::Instant::now(); - fs::read_to_string(source).map_err(|_| format!("couldn't read {:?}", source))?; + fs::read_to_string(source).map_err(|_| format!("couldn't read {}", source.display()))?; let ext = if cfg!(windows) { "exe" } else { "" }; let exe_path = source.with_extension(ext); @@ -92,7 +92,7 @@ fn compile(source: &PathBuf, libs: &[String], measure: bool) -> Result Result<(), String> { let status = Command::new(exe_path) .status() - .map_err(|e| format!("failed to launch executable: {}", e))?; + .map_err(|e| format!("failed to launch executable: {e}"))?; if !status.success() { std::process::exit(status.code().unwrap_or(1)); diff --git a/kitc/tests/examples.rs b/kitc/tests/examples.rs index 907b958..cc54eb7 100644 --- a/kitc/tests/examples.rs +++ b/kitc/tests/examples.rs @@ -23,8 +23,8 @@ fn run_example_test( .ok_or("couldn't get workspace root")?; let examples_dir = workspace_root.join("examples"); - let example_file = examples_dir.join(format!("{}.kit", example_name)); - let expected_file = examples_dir.join(format!("{}.kit.expected", example_name)); + let example_file = examples_dir.join(format!("{example_name}.kit")); + let expected_file = examples_dir.join(format!("{example_name}.kit.expected")); assert!( example_file.exists(), @@ -39,8 +39,7 @@ fn run_example_test( ); log::info!( - "Running example {} in {} (path: {}). Expected file is at {}", - example_name, + "Running example {example_name} in {} (path: {}). Expected file is at {}", workspace_root.display(), example_file.display(), expected_file.display() @@ -77,9 +76,7 @@ fn run_example_test( .success(); // TODO: executable files are actually generated in the CWD, not in the examples folder. - // This explains why the executable is not actually removed. But I don't get why these tests - // passed on Linux and Mac if std::fs::remove_file is supposed to also fail when the file - // doesn't exist. + // This explains why the executable is not actually generated in the examples folder. if let Err(err) = std::fs::remove_file(&executable_path) { log::error!("Failed to remove executable: {err}"); } @@ -175,6 +172,11 @@ fn test_mixed_comments() -> Result<(), Box> { run_example_test("mixed_comments", None) } +#[test] +fn test_inference() -> Result<(), Box> { + run_example_test("inference_test", None) +} + #[test] fn test_nested_comments() -> Result<(), Box> { let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) diff --git a/kitlang/src/codegen/ast.rs b/kitlang/src/codegen/ast.rs index 94332cd..536576b 100644 --- a/kitlang/src/codegen/ast.rs +++ b/kitlang/src/codegen/ast.rs @@ -1,11 +1,11 @@ -use crate::codegen::types::*; +use crate::codegen::types::{AssignmentOperator, BinaryOperator, Type, TypeId, UnaryOperator}; use std::collections::HashSet; /// Represents a C header inclusion. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Include { - /// Path to the header file (e.g., "stdio.h"). + /// Path to header file (e.g., "stdio.h"). pub path: String, } @@ -16,8 +16,10 @@ pub struct Function { pub name: String, /// List of function parameters. pub params: Vec, - /// Return type (`None` for void functions). + /// Return type annotation (`None` for void inference). pub return_type: Option, + /// Inferred return type ID. + pub inferred_return: Option, /// Function body as a block of statements. pub body: Block, } @@ -27,8 +29,10 @@ pub struct Function { pub struct Param { /// Parameter name. pub name: String, - /// Parameter type. - pub ty: Type, + /// Parameter type annotation (if specified). + pub annotation: Option, + /// Inferred parameter type ID. + pub ty: TypeId, } /// Represents a block of statements (e.g., function body or scope block). @@ -46,7 +50,9 @@ pub enum Stmt { /// Variable name. name: String, /// Type annotation (`None` for type inference). - ty: Option, + annotation: Option, + /// Inferred variable type ID. + inferred: TypeId, /// Initializer expression (`None` for uninitialized). init: Option, }, @@ -89,15 +95,17 @@ pub enum Stmt { #[derive(Clone, Debug, PartialEq)] pub enum Expr { /// Variable or function identifier. - Identifier(String), + Identifier(String, TypeId), /// Literal value. - Literal(Literal), + Literal(Literal, TypeId), /// Function call. Call { /// Name of the callee function. callee: String, /// Arguments passed to the function. args: Vec, + /// Inferred return type. + ty: TypeId, }, /// Unary operation. UnaryOp { @@ -105,18 +113,24 @@ pub enum Expr { op: UnaryOperator, /// The operand expression. expr: Box, + /// Inferred result type. + ty: TypeId, }, /// Binary operation. BinaryOp { op: BinaryOperator, left: Box, right: Box, + /// Inferred result type. + ty: TypeId, }, /// Assignment operation. Assign { op: AssignmentOperator, left: Box, right: Box, + /// Inferred result type. + ty: TypeId, }, /// If-then-else expression. If { @@ -126,6 +140,8 @@ pub enum Expr { then_branch: Box, /// The expression to evaluate if the condition is false. else_branch: Box, + /// Inferred result type. + ty: TypeId, }, /// Range literal expression (e.g., `1...10`). RangeLiteral { @@ -153,15 +169,16 @@ pub enum Literal { impl Literal { /// Converts the literal to its C representation string. + #[must_use] pub fn to_c(&self) -> String { match self { Literal::Int(i) => i.to_string(), Literal::Float(f) => { // Ensure float literals have 'f' suffix in C if f.fract() == 0.0 { - format!("{}.0f", f) + format!("{f}.0f") } else { - format!("{}f", f) + format!("{f}f") } } Literal::String(s) => { @@ -176,7 +193,7 @@ impl Literal { _ => c.to_string(), }) .collect(); - format!("\"{}\"", escaped) + format!("\"{escaped}\"") } Literal::Bool(b) => b.to_string(), Literal::Null => "NULL".to_string(), diff --git a/kitlang/src/codegen/compiler.rs b/kitlang/src/codegen/compiler.rs index bddc012..b0db0be 100644 --- a/kitlang/src/codegen/compiler.rs +++ b/kitlang/src/codegen/compiler.rs @@ -146,7 +146,7 @@ impl Toolchain { match self { Toolchain::Gcc | Toolchain::Clang => { let flags = ["-std=c99", "-Wall", "-Wextra", "-pedantic"]; - flags.iter().map(|s| s.to_string()).collect() + flags.iter().map(ToString::to_string).collect() } #[cfg(windows)] Toolchain::Msvc => { @@ -181,7 +181,7 @@ pub struct CompilerOptions { pub link_opts: Vec, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct CompilerMeta(pub Toolchain); impl CompilerOptions { @@ -256,9 +256,10 @@ impl CompilerOptions { /// Build the compiler invocation used to spawn `Command`. /// - /// Returns (path_to_compiler_executable, args_vec). + /// Returns (`path_to_compiler_executable`, `args_vec`). + /// + /// # Errors /// - /// Errors: /// - if `sources` is empty /// - if `output` is not set /// - if no system compiler can be found and no `enforced_toolchain` is usable diff --git a/kitlang/src/codegen/frontend.rs b/kitlang/src/codegen/frontend.rs index 9dd2733..e23537c 100644 --- a/kitlang/src/codegen/frontend.rs +++ b/kitlang/src/codegen/frontend.rs @@ -1,17 +1,18 @@ +use crate::error::CompileResult; use crate::{KitParser, Rule, error::CompilationError}; use pest::Parser; use std::collections::HashSet; +use std::fmt::Write; use std::path::{Path, PathBuf}; use std::process::Command; -use crate::codegen::ast::*; +use crate::codegen::ast::{Block, Expr, Function, Include, Program, Stmt}; use crate::codegen::compiler::{CompilerMeta, CompilerOptions, Toolchain}; +use crate::codegen::inference::TypeInferencer; use crate::codegen::parser::Parser as CodeParser; use crate::codegen::types::{ToCRepr, Type}; -pub type CompileResult = Result; - pub struct Compiler { files: Vec, output: PathBuf, @@ -19,6 +20,7 @@ pub struct Compiler { includes: Vec, libs: Vec, parser: CodeParser, + inferencer: TypeInferencer, } impl Compiler { @@ -30,6 +32,7 @@ impl Compiler { includes: Vec::new(), libs, parser: CodeParser::new(), + inferencer: TypeInferencer::new(), } } @@ -58,7 +61,7 @@ impl Compiler { } } - self.includes = includes.clone(); + self.includes.clone_from(&includes); Ok(Program { includes, @@ -67,14 +70,15 @@ impl Compiler { }) } - fn transpile_with_program(&mut self, prog: Program) { + /// Generate C code from the AST and write it to the output path + fn transpile_with_program(&mut self, prog: &Program) { let c_code = self.generate_c_code(prog); if let Err(e) = std::fs::write(&self.c_output, c_code) { - panic!("Failed to write output: {}", e); + panic!("Failed to write output: {e}"); } } - fn generate_c_code(&self, prog: Program) -> String { + fn generate_c_code(&self, prog: &Program) -> String { let mut out = String::new(); // emit regular includes from the source `include` statements @@ -86,7 +90,7 @@ impl Compiler { let mut seen_headers = HashSet::new(); // Vec preserves order - let mut seen_declarations = Vec::new(); + let mut seen_declarations: Vec = Vec::new(); let mut collect_from_type = |t: &Type| { let ctype = t.to_c_repr(); @@ -102,22 +106,36 @@ impl Compiler { // scan every function signature & body for types to gather their headers/typedefs for func in &prog.functions { - if let Some(r) = &func.return_type { + // Use inferred return type + if let Some(ret_id) = func.inferred_return { + if let Ok(ty) = self.inferencer.store.resolve(ret_id) { + collect_from_type(&ty); + } + } else if let Some(r) = &func.return_type { collect_from_type(r); } + for p in &func.params { - collect_from_type(&p.ty); + // Use inferred param type + if let Ok(ty) = self.inferencer.store.resolve(p.ty) { + collect_from_type(&ty); + } else if let Some(ann) = &p.annotation { + collect_from_type(ann); + } } + for stmt in &func.body.stmts { - if let Stmt::VarDecl { ty: Some(t), .. } = stmt { - collect_from_type(t); + if let Stmt::VarDecl { inferred, .. } = stmt + && let Ok(ty) = self.inferencer.store.resolve(*inferred) + { + collect_from_type(&ty); } } } // emit unique headers for hdr in seen_headers { - out.push_str(&format!("#include {}\n", hdr)); + writeln!(out, "#include {hdr}").unwrap(); } out.push('\n'); @@ -126,11 +144,10 @@ impl Compiler { out.push_str(&decl); out.push('\n'); } - out.push('\n'); // emit functions as before... - for func in prog.functions { - out.push_str(&self.transpile_function(&func)); + for func in &prog.functions { + out.push_str(&self.transpile_function(func)); out.push_str("\n\n"); } out @@ -140,19 +157,31 @@ impl Compiler { let return_type = if func.name == "main" { "int".to_string() } else { - func.return_type - .as_ref() - .map_or("void".to_string(), |t| self.transpile_type(t)) + // Try inferred return type first + func.inferred_return + .and_then(|id| self.inferencer.store.resolve(id).ok()) + .map(|t| t.to_c_repr().name) + .or_else(|| func.return_type.as_ref().map(|t| t.to_c_repr().name)) + .unwrap_or_else(|| "void".to_string()) }; let params = func .params .iter() - .map(|p| format!("{} {}", self.transpile_type(&p.ty), p.name)) + .map(|p| { + let ty_name = self + .inferencer + .store + .resolve(p.ty) + .map(|t| t.to_c_repr().name) + .or_else(|_| p.annotation.as_ref().map(|t| t.to_c_repr().name).ok_or(())) + .unwrap_or("void*".to_string()); // Fallback + format!("{} {}", ty_name, p.name) + }) .collect::>() .join(", "); - let mut body = self.transpile_block(&func.body); + let mut body_code = self.transpile_block(&func.body); if func.name == "main" { let has_return = func @@ -161,155 +190,165 @@ impl Compiler { .iter() .any(|stmt| matches!(stmt, Stmt::Return(_))); if !has_return { - body.push_str(" return 0;\n"); + // Insert return 0 before the closing brace + if let Some(pos) = body_code.rfind('}') { + body_code.insert_str(pos, "return 0;\n"); + } } } - format!("{} {}({}) {{\n{}}}", return_type, func.name, params, body) + format!("{} {}({}) {}", return_type, func.name, params, body_code) } fn transpile_block(&self, block: &Block) -> String { - let mut code = String::new(); + let mut code = String::from("{\n"); for stmt in &block.stmts { - match stmt { - Stmt::VarDecl { name, ty, init } => { - let ty_str = if let Some(t) = ty { - self.transpile_type(t) - } else if let Some(Expr::Literal(Literal::Int(_))) = init { - "int".to_string() - } else { - "auto".to_string() - }; + let stmt_code = match stmt { + Stmt::VarDecl { + name, + annotation: _, + inferred, + init, + } => { + let ty_str = self + .inferencer + .store + .resolve(*inferred) + .map_or_else(|_| "auto".to_string(), |t| t.to_c_repr().name); + match init { Some(expr) => { let init_str = self.transpile_expr(expr); - code.push_str(&format!("{} {} = {};\n", ty_str, name, init_str)); + format!("{ty_str} {name} = {init_str};\n") } None => { - code.push_str(&format!("{} {};\n", ty_str, name)); + format!("{ty_str} {name};\n") } } } Stmt::Expr(expr) => { - code.push_str(&self.transpile_expr(expr)); - code.push_str(";\n"); + format!("{};\n", self.transpile_expr(expr)) } - // TODO: should add a return to the main function anyway Stmt::Return(expr) => { - let expr_str = expr - .as_ref() - .map_or(String::new(), |e| format!(" {}", self.transpile_expr(e))); - code.push_str(&format!("return{};\n", expr_str)); + if let Some(e) = expr { + format!("return {};\n", self.transpile_expr(e)) + } else { + "return;\n".to_string() + } } Stmt::If { cond, then_branch, else_branch, } => { - let cond_str = self.transpile_expr(cond); - let then_code = self.transpile_block(then_branch); - let mut if_str = format!("if ({}) {{\n{}}}", cond_str, then_code); + let mut s = format!("if ({}) ", self.transpile_expr(cond)); + s.push_str(&self.transpile_block(then_branch)); if let Some(else_b) = else_branch { - let else_code = self.transpile_block(else_b); - if_str.push_str(&format!(" else {{\n{}}}", else_code)); + s.push_str(" else "); + s.push_str(&self.transpile_block(else_b)); } - code.push_str(&if_str); + s.push('\n'); + s } Stmt::While { cond, body } => { - let cond_str = self.transpile_expr(cond); - let body_code = self.transpile_block(body); - code.push_str(&format!("while ({}) {{\n{}}}", cond_str, body_code)); + let mut s = format!("while ({}) ", self.transpile_expr(cond)); + s.push_str(&self.transpile_block(body)); + s.push('\n'); + s } - // Translate `for i in 10` to `for (int i = 0; i < 10; ++i)` - // Of course, this assumes `iter` (i) is an integer literal or expression that evaluates to an integer. Stmt::For { var, iter, body } => { - let body_code = self.transpile_block(body); - match iter { - Expr::RangeLiteral { start, end } => { - // Handle range literals: `for i in 1...10` - let start_str = self.transpile_expr(start); - let end_str = self.transpile_expr(end); - code.push_str(&format!( - "for (int {} = {}; {} < {}; ++{}) {{\n{}}}", - var, start_str, var, end_str, var, body_code - )); - } - _ => { - // Handle single integer expressions: `for i in 3` - let iter_str = self.transpile_expr(iter); - code.push_str(&format!( - "for (int {} = 0; {} < {}; ++{}) {{\n{}}}", - var, var, iter_str, var, body_code - )); - } - } - } - Stmt::Break => { - code.push_str("break;\n"); - } - Stmt::Continue => { - code.push_str("continue;\n"); + let mut s = if let Expr::RangeLiteral { start, end } = iter { + let start_str = self.transpile_expr(start); + let end_str = self.transpile_expr(end); + format!("for (int {var} = {start_str}; {var} < {end_str}; ++{var}) ") + } else { + let iter_str = self.transpile_expr(iter); + format!("for (int {var} = 0; {var} < {iter_str}; ++{var}) ") + }; + s.push_str(&self.transpile_block(body)); + s } + Stmt::Break => "break;\n".to_string(), + Stmt::Continue => "continue;\n".to_string(), + }; + + for line in stmt_code.lines() { + code.push_str(" "); + code.push_str(line); + code.push('\n'); } } + code.push('}'); code } fn transpile_expr(&self, expr: &Expr) -> String { match expr { - Expr::Identifier(id) => id.clone(), - Expr::Literal(lit) => lit.to_c(), - Expr::Call { callee, args } => { + Expr::Identifier(name, _) => name.clone(), + Expr::Literal(lit, _) => lit.to_c(), + Expr::Call { + callee, + args, + ty: _, + } => { let args_str = args .iter() .map(|a| self.transpile_expr(a)) .collect::>() .join(", "); - format!("{}({})", callee, args_str) + format!("{callee}({args_str})") } - Expr::UnaryOp { op, expr } => { + Expr::UnaryOp { op, expr, ty: _ } => { let expr_str = self.transpile_expr(expr); - op.to_string_with_expr(expr_str) + format!("{}({})", op.to_c_str(), expr_str) + } + Expr::BinaryOp { + op, + left, + right, + ty: _, + } => { + let left_str = self.transpile_expr(left); + let right_str = self.transpile_expr(right); + format!("({left_str} {} {right_str})", op.to_c_str()) } - Expr::BinaryOp { op, left, right } => { + Expr::Assign { + op, + left, + right, + ty: _, + } => { let left_str = self.transpile_expr(left); let right_str = self.transpile_expr(right); - format!("({} {} {})", left_str, op.to_c_str(), right_str) + format!("{left_str} {} {right_str}", op.to_c_str()) } Expr::If { cond, then_branch, else_branch, + ty: _, } => { let cond_str = self.transpile_expr(cond); let then_str = self.transpile_expr(then_branch); let else_str = self.transpile_expr(else_branch); - format!("({}) ? ({}) : ({})", cond_str, then_str, else_str) - } - Expr::Assign { op, left, right } => { - let left_str = self.transpile_expr(left); - let right_str = self.transpile_expr(right); - format!("({} {} {})", left_str, op.to_c_str(), right_str) + format!("({cond_str} ? {then_str} : {else_str})") } - Expr::RangeLiteral { start: _, end: _ } => { - // Range literals are not directly transpiled to C - // They are only used in for loop context - panic!("Range literals should only be used in for loop expressions") + Expr::RangeLiteral { .. } => { + // Should technically not be used alone, but return something safe to avoid panic + "/* range literal */ 0".to_string() } } } - fn transpile_type(&self, ty: &Type) -> String { - ty.to_c_repr().name - } - pub fn compile(&mut self) -> CompileResult<()> { - let prog = self.parse()?; - self.transpile_with_program(prog); + let mut prog = self.parse()?; + + self.inferencer.infer_program(&mut prog)?; + self.transpile_with_program(&prog); let detected = Toolchain::executable_path().ok_or(CompilationError::ToolchainNotFound)?; - // FIX: Handle non-UTF-8 paths + // TODO: Handle non-UTF-8 paths let target_path = self .output .clone() diff --git a/kitlang/src/codegen/inference.rs b/kitlang/src/codegen/inference.rs new file mode 100644 index 0000000..712196b --- /dev/null +++ b/kitlang/src/codegen/inference.rs @@ -0,0 +1,373 @@ +use super::ast::{Block, Expr, Function, Literal, Program, Stmt}; +use super::symbols::SymbolTable; +use super::types::{BinaryOperator, Type, TypeId, TypeStore, UnaryOperator}; +use crate::error::{CompilationError, CompileResult}; + +/// Type inference engine using Hindley-Milner algorithm. +pub struct TypeInferencer { + pub store: TypeStore, + symbols: SymbolTable, + current_return_type: Option, +} + +impl Default for TypeInferencer { + fn default() -> Self { + Self::new() + } +} + +impl TypeInferencer { + pub fn new() -> Self { + Self { + store: TypeStore::new(), + symbols: SymbolTable::new(), + current_return_type: None, + } + } + + /// Infer types for an entire program + pub fn infer_program(&mut self, prog: &mut Program) -> CompileResult<()> { + for func in &mut prog.functions { + self.infer_function(func)?; + } + Ok(()) + } + + /// Infer types for a function definition + fn infer_function(&mut self, func: &mut Function) -> CompileResult<()> { + // Infer parameter types (fresh unknowns if unannotated) + for param in &mut func.params { + param.ty = if let Some(ann) = ¶m.annotation { + self.store.new_known(ann.clone()) + } else { + self.store.new_unknown() + }; + self.symbols.define_var(¶m.name, param.ty); + } + + // Infer return type + func.inferred_return = if let Some(ann) = &func.return_type { + Some(self.store.new_known(ann.clone())) + } else { + Some(self.store.new_unknown()) + }; + + self.current_return_type = func.inferred_return; + + // Infer function body + self.infer_block(&mut func.body)?; + + self.current_return_type = None; + + // Register function signature in symbol table + if let Some(ret_ty) = func.inferred_return { + let param_tys: Vec = func.params.iter().map(|p| p.ty).collect(); + self.symbols.define_function(&func.name, param_tys, ret_ty); + } + + Ok(()) + } + + /// Infer types for a block of statements + fn infer_block(&mut self, block: &mut Block) -> CompileResult<()> { + for stmt in &mut block.stmts { + self.infer_stmt(stmt)?; + } + Ok(()) + } + + /// Infer types for a single statement + fn infer_stmt(&mut self, stmt: &mut Stmt) -> CompileResult<()> { + match stmt { + Stmt::VarDecl { + name, + annotation, + inferred, + init, + } => { + if let Some(init_expr) = init { + let init_ty = self.infer_expr(init_expr)?; + + *inferred = if let Some(ann) = annotation { + let ann_ty = self.store.new_known(ann.clone()); + self.unify(ann_ty, init_ty)?; + ann_ty + } else { + init_ty + }; + + self.symbols.define_var(name, *inferred); + } else if let Some(ann) = annotation { + // Declaration without initializer -> just use annotation + *inferred = self.store.new_known(ann.clone()); + self.symbols.define_var(name, *inferred); + } else { + return Err(CompilationError::TypeError(format!( + "Variable '{name}' declared without type annotation or initializer", + ))); + } + } + + Stmt::Expr(expr) => { + self.infer_expr(expr)?; + } + + Stmt::Return(Some(expr)) => { + let expr_ty = self.infer_expr(expr)?; + if let Some(ret_ty) = self.current_return_type { + self.unify(ret_ty, expr_ty)?; + } else { + return Err(CompilationError::TypeError( + "Return statement outside of function".into(), + )); + } + } + + // Void return - check if function expects void + Stmt::Return(None) => { + if let Some(ret_ty) = self.current_return_type { + let void_ty = self.store.new_known(Type::Void); + self.unify(ret_ty, void_ty)?; + } else { + return Err(CompilationError::TypeError( + "Return statement outside of function".into(), + )); + } + } + + Stmt::If { + cond, + then_branch, + else_branch, + } => { + let cond_ty = self.infer_expr(cond)?; + let bool_ty = self.store.new_known(Type::Bool); + self.unify(cond_ty, bool_ty)?; + + self.infer_block(then_branch)?; + if let Some(else_b) = else_branch { + self.infer_block(else_b)?; + } + } + + Stmt::While { cond, body } => { + let cond_ty = self.infer_expr(cond)?; + let bool_ty = self.store.new_known(Type::Bool); + self.unify(cond_ty, bool_ty)?; + + self.infer_block(body)?; + } + + Stmt::For { var, iter, body } => { + let iter_ty = self.infer_expr(iter)?; + + // For i in N: iter should be Int-like OR a range (which we currently typed as Void) + // TODO: Better range typing + let iter_resolved = self + .store + .resolve(iter_ty) + .map_err(CompilationError::TypeError)?; + if iter_resolved != Type::Int && iter_resolved != Type::Void { + return Err(CompilationError::TypeError(format!( + "For loop iterator must be Int or Range, found {iter_resolved:?}" + ))); + } + + let var_ty = self.store.new_known(Type::Int); + self.symbols.define_var(var, var_ty); + + self.infer_block(body)?; + } + + Stmt::Break | Stmt::Continue => { + // No type inference needed + } + } + Ok(()) + } + + /// Infer types for an expression + fn infer_expr(&mut self, expr: &mut Expr) -> Result { + let ty = match expr { + Expr::Identifier(name, ty_id) => { + let var_ty = self.symbols.lookup_var(name).ok_or_else(|| { + CompilationError::TypeError(format!("Use of undeclared variable '{name}'")) + })?; + *ty_id = var_ty; + var_ty + } + + Expr::Literal(lit, ty_id) => { + let ty = match lit { + Literal::Int(_) => Type::Int, + Literal::Float(_) => Type::Float, + Literal::Bool(_) => Type::Bool, + Literal::String(_) => Type::CString, + Literal::Null => Type::Ptr(Box::new(Type::Void)), + }; + let type_id = self.store.new_known(ty); + *ty_id = type_id; + type_id + } + + Expr::Call { callee, args, ty } => { + let (param_tys, ret_ty) = if let Some(sig) = self.symbols.lookup_function(callee) { + sig + } else { + // For undeclared functions (like printf), we allow them but can't check params. + // We assume they return Void for now, or we could return a fresh unknown. + let void_ty = self.store.new_known(Type::Void); + (vec![], void_ty) + }; + + if !param_tys.is_empty() && args.len() != param_tys.len() { + return Err(CompilationError::TypeError(format!( + "Function '{}' expects {} arguments, got {}", + callee, + param_tys.len(), + args.len() + ))); + } + + if param_tys.is_empty() { + // Just infer arguments without unifying if signature is unknown (variadic C funcs) + for arg in args.iter_mut() { + self.infer_expr(arg)?; + } + } else { + for (arg, param_ty) in args.iter_mut().zip(param_tys.iter()) { + let arg_ty = self.infer_expr(arg)?; + self.unify(arg_ty, *param_ty)?; + } + } + + *ty = ret_ty; + ret_ty + } + + Expr::UnaryOp { op, expr, ty } => { + let expr_ty = self.infer_expr(expr)?; + + // Unary operators typically preserve type (except address-of) + let result_ty = match op { + UnaryOperator::AddressOf => { + let resolved = self + .store + .resolve(expr_ty) + .map_err(CompilationError::TypeError)?; + let ptr_ty = Type::Ptr(Box::new(resolved)); + self.store.new_known(ptr_ty) + } + UnaryOperator::Dereference => { + let resolved = self + .store + .resolve(expr_ty) + .map_err(CompilationError::TypeError)?; + if let Type::Ptr(inner_ty) = resolved { + self.store.new_known(*inner_ty) + } else { + return Err(CompilationError::TypeError(format!( + "Cannot dereference non-pointer type: {resolved:?}" + ))); + } + } + _ => expr_ty, + }; + + *ty = result_ty; + result_ty + } + + Expr::BinaryOp { + op, + left, + right, + ty, + } => { + let left_ty = self.infer_expr(left)?; + let right_ty = self.infer_expr(right)?; + + // Result type depends on operator + let result_ty = match op { + BinaryOperator::And | BinaryOperator::Or => { + let bool_ty = self.store.new_known(Type::Bool); + self.unify(left_ty, bool_ty)?; + self.unify(right_ty, bool_ty)?; + bool_ty + } + BinaryOperator::Eq + | BinaryOperator::Ne + | BinaryOperator::Lt + | BinaryOperator::Gt + | BinaryOperator::Le + | BinaryOperator::Ge => { + self.unify(left_ty, right_ty)?; + self.store.new_known(Type::Bool) + } + _ => { + self.unify(left_ty, right_ty)?; + left_ty + } + }; + + *ty = result_ty; + result_ty + } + + Expr::Assign { + op: _, + left, + right, + ty, + } => { + let right_ty = self.infer_expr(right)?; + let left_ty = self.infer_expr(left)?; + + // Assignment: right must unify with left + self.unify(left_ty, right_ty)?; + + *ty = left_ty; + left_ty + } + + Expr::If { + cond, + then_branch, + else_branch, + ty, + } => { + let cond_ty = self.infer_expr(cond)?; + let bool_ty = self.store.new_known(Type::Bool); + self.unify(cond_ty, bool_ty)?; + + let then_ty = self.infer_expr(then_branch)?; + let else_ty = self.infer_expr(else_branch)?; + + self.unify(then_ty, else_ty)?; + + *ty = then_ty; + then_ty + } + + Expr::RangeLiteral { start, end } => { + // Range literals are only used in for loop context + let start_ty = self.infer_expr(start)?; + let end_ty = self.infer_expr(end)?; + + let int_ty = self.store.new_known(Type::Int); + self.unify(start_ty, int_ty)?; + self.unify(end_ty, int_ty)?; + + // Return a dummy type -> ranges don't have their own type + self.store.new_known(Type::Void) + } + }; + + Ok(ty) + } + + /// Unify two type IDs + fn unify(&mut self, a: TypeId, b: TypeId) -> CompileResult<()> { + self.store.unify(a, b).map_err(CompilationError::TypeError) + } +} diff --git a/kitlang/src/codegen/mod.rs b/kitlang/src/codegen/mod.rs index c245746..cd71322 100644 --- a/kitlang/src/codegen/mod.rs +++ b/kitlang/src/codegen/mod.rs @@ -1,7 +1,7 @@ //! The `codegen` module is responsible for generating executable code //! from the parsed Kit Abstract Syntax Tree (AST). //! -//! It orchestrates the compilation process, translating the AST into an +//! It orchestrates compilation process, translating AST into an //! intermediate representation and then into target-specific machine code. pub mod ast; @@ -10,8 +10,8 @@ pub mod parser; pub use ast::{Block, Expr, Function, Include, Literal, Param, Program, Stmt}; pub use compiler::Toolchain; -/// Handles the initial parsing of Kitlang source files, constructs the -/// Abstract Syntax Tree (AST), and orchestrates the generation of C code +/// Handles the initial parsing of Kitlang source files, constructs +/// Abstract Syntax Tree (AST), and orchestrates generation of C code /// from this AST. pub mod frontend; @@ -20,3 +20,9 @@ pub mod frontend; /// expressions, and literals. It also handles the conversion logic between Kitlang /// and C type systems. pub mod types; + +/// Symbol table for tracking variables and functions during type inference. +pub mod symbols; + +/// Type inference engine using Hindley-Milner algorithm. +pub mod inference; diff --git a/kitlang/src/codegen/parser.rs b/kitlang/src/codegen/parser.rs index 3d584b4..b79c411 100644 --- a/kitlang/src/codegen/parser.rs +++ b/kitlang/src/codegen/parser.rs @@ -5,16 +5,16 @@ use crate::error::CompilationError; use crate::{Rule, parse_error}; use super::ast::{Block, Expr, Function, Include, Literal, Param, Stmt}; -use super::frontend::CompileResult; -use super::types::{AssignmentOperator, Type}; +use super::types::{AssignmentOperator, Type, TypeId}; +use crate::error::CompileResult; use std::path::PathBuf; use std::str::FromStr; #[derive(Default, Debug)] pub struct Parser { - current_file: Option, - source_content: String, + _current_file: Option, + _source_content: String, } impl Parser { @@ -26,20 +26,11 @@ impl Parser { let source_content = std::fs::read_to_string(&file).unwrap_or_else(|_| String::new()); Self { - current_file: Some(file), - source_content, + _current_file: Some(file), + _source_content: source_content, } } - /// Gets the line content from a line number. Returns an empty string if the line does not exist - fn get_line_content(&self, line_num: usize) -> String { - self.source_content - .lines() - .nth(line_num.saturating_sub(1)) - .unwrap_or("") - .to_string() - } - pub fn parse_include(&self, pair: Pair) -> Include { // include_stmt = { "include" ~ string ~ ("=>" ~ string)? ~ ";" } let mut inner = pair.into_inner(); @@ -84,6 +75,7 @@ impl Parser { name, params, return_type, + inferred_return: None, body, }) } @@ -97,8 +89,12 @@ impl Parser { // SAFETY: Grammar guarantees param has identifier and type let name = inner.next().unwrap().as_str().to_string(); let type_node = inner.next().unwrap(); - let ty = self.parse_type(type_node)?; - Ok(Param { name, ty }) + let ty_ann = self.parse_type(type_node)?; + Ok(Param { + name, + annotation: Some(ty_ann), + ty: TypeId::default(), + }) }) .collect() } @@ -122,8 +118,7 @@ impl Parser { Rule::break_stmt => Ok(Stmt::Break), Rule::continue_stmt => Ok(Stmt::Continue), other => Err(CompilationError::ParseError(format!( - "unexpected statement: {:?}", - other + "unexpected statement: {other:?}", ))), } }) @@ -155,7 +150,12 @@ impl Parser { } let name = name.ok_or(parse_error!("var_decl missing identifier"))?; - Ok(Stmt::VarDecl { name, ty, init }) + Ok(Stmt::VarDecl { + name, + annotation: ty, + inferred: TypeId::default(), + init, + }) } fn parse_type(&self, pair: Pair) -> CompileResult { @@ -271,6 +271,7 @@ impl Parser { op, left: Box::new(left), right: Box::new(right), + ty: TypeId::default(), }; } Ok(left) @@ -284,13 +285,14 @@ impl Parser { Rule::unary_op => { let op_str = first_pair.as_str(); let op = UnaryOperator::from_str(op_str) - .map_err(|_| parse_error!("invalid unary operation: {}", op_str))?; + .map_err(|()| parse_error!("invalid unary operation: {op_str}"))?; // SAFETY: Grammar guarantees expression after unary op let expr = self.parse_expr(inner_pairs.next().unwrap())?; Ok(Expr::UnaryOp { op, expr: Box::new(expr), + ty: TypeId::default(), }) } Rule::ADDRESS_OF_OP => { @@ -300,13 +302,17 @@ impl Parser { Ok(Expr::UnaryOp { op, expr: Box::new(expr), + ty: TypeId::default(), }) } Rule::primary => self.parse_expr(first_pair), - _other => Err(parse_error!("Unexpected rule in unary: {:?}", _other)), + other => Err(parse_error!("Unexpected rule in unary: {other:?}")), } } - Rule::identifier => Ok(Expr::Identifier(pair.as_str().to_string())), + Rule::identifier => Ok(Expr::Identifier( + pair.as_str().to_string(), + TypeId::default(), + )), Rule::literal => { // SAFETY: Grammar guarantees exactly one child in literal let inner = pair.into_inner().next().unwrap(); @@ -320,23 +326,19 @@ impl Parser { let i = s.parse::().map_err(|e| { parse_error!("invalid integer literal '{s}': {:?}", e) })?; - Ok(Expr::Literal(Literal::Int(i))) + Ok(Expr::Literal(Literal::Int(i), TypeId::default())) } Rule::float => { let s = num_pair.as_str(); let f = s.parse::().map_err(|e| { parse_error!("invalid float literal '{s}': {:?}", e) })?; - Ok(Expr::Literal(Literal::Float(f))) + Ok(Expr::Literal(Literal::Float(f), TypeId::default())) } _ => Err(parse_error!("Unexpected number type")), } } - Rule::boolean => match inner.as_str() { - "true" => Ok(Expr::Literal(Literal::Bool(true))), - "false" => Ok(Expr::Literal(Literal::Bool(false))), - _s => Err(parse_error!("invalid boolean literal: {}", _s)), - }, + Rule::boolean => Self::parse_bool_literal(inner.as_str()), Rule::char_literal => todo!("char literal parsing not implemented"), _ => Err(parse_error!( "Unexpected literal type: {:?}", @@ -348,7 +350,7 @@ impl Parser { let full = pair.as_str(); let inner = &full[1..full.len() - 1]; let unescaped = Self::unescape(inner).unwrap_or_else(|| inner.to_string()); - Ok(Expr::Literal(Literal::String(unescaped))) + Ok(Expr::Literal(Literal::String(unescaped), TypeId::default())) } Rule::function_call_expr => { let mut inner = pair.into_inner(); @@ -358,7 +360,11 @@ impl Parser { .filter(|p: &Pair| p.as_rule() == Rule::expr) .map(|p: Pair| self.parse_expr(p)) .collect::, _>>()?; // Collect and propagate errors - Ok(Expr::Call { callee, args }) + Ok(Expr::Call { + callee, + args, + ty: TypeId::default(), + }) } Rule::if_expr => { let mut inner = pair.into_inner(); @@ -369,14 +375,50 @@ impl Parser { cond: Box::new(cond), then_branch: Box::new(then_branch), else_branch: Box::new(else_branch), + ty: TypeId::default(), }) } Rule::primary => { - // SAFETY: Primary rule always has exactly one child - let inner = pair.into_inner().next().unwrap(); - self.parse_expr(inner) + let text = pair.as_str(); + let mut inner = pair.into_inner(); + + // Tokens like "null", "this", "Self", "true", "false" have no inner pairs + if inner.peek().is_none() { + match text { + "null" => Ok(Expr::Literal(Literal::Null, TypeId::default())), + "true" | "false" => Self::parse_bool_literal(text), + // "this" => Ok(Expr::This(TypeId::default())), + // "Self" => Ok(Expr::SelfType), + other => Err(parse_error!("Unknown primary keyword: {}", other)), + } + } else { + // Otherwise, unwrap and parse the inner rule + let inner_pair = inner.next().unwrap(); + match inner_pair.as_rule() { + Rule::identifier => Ok(Expr::Identifier( + inner_pair.as_str().to_string(), + TypeId::default(), + )), + Rule::literal + | Rule::function_call_expr + | Rule::array_literal + | Rule::struct_init + | Rule::union_init + | Rule::tuple_literal + | Rule::if_expr + | Rule::range_expr + | Rule::string + | Rule::expr + | Rule::unary => self.parse_expr(inner_pair), + _ => Err(parse_error!( + "Unexpected primary inner rule: {:?}", + inner_pair.as_rule() + )), + } + } } + Rule::range_expr => { let mut inner = pair.into_inner(); let start = self.parse_expr(inner.next().unwrap())?; @@ -387,8 +429,7 @@ impl Parser { }) } other => Err(CompilationError::ParseError(format!( - "Unexpected expr rule: {:?}", - other + "Unexpected expr rule: {other:?}" ))), } } @@ -413,6 +454,7 @@ impl Parser { op, left: Box::new(left), right: Box::new(right), + ty: TypeId::default(), }) } else { // No assignment operator, so it's just the expression itself (the logical_or that formed the LHS) @@ -420,6 +462,14 @@ impl Parser { } } + fn parse_bool_literal(s: &str) -> CompileResult { + match s { + "true" => Ok(Expr::Literal(Literal::Bool(true), TypeId::default())), + "false" => Ok(Expr::Literal(Literal::Bool(false), TypeId::default())), + _ => Err(parse_error!("invalid boolean literal: {}", s)), + } + } + fn unescape(s: impl AsRef) -> Option { // TODO: search if there is an escape sequence beforehand to // avoid the allocation and search logic diff --git a/kitlang/src/codegen/symbols.rs b/kitlang/src/codegen/symbols.rs new file mode 100644 index 0000000..a40c282 --- /dev/null +++ b/kitlang/src/codegen/symbols.rs @@ -0,0 +1,49 @@ +use super::types::TypeId; +use std::collections::HashMap; + +/// Symbol table for tracking variable and function types during inference. +/// +/// Currently uses a flat scope (no nesting). Variables and functions are tracked +/// by their names and their `TypeId`s. +pub struct SymbolTable { + /// Maps variable names to their inferred `TypeId`s. + vars: HashMap, + + /// Maps function names to their signatures (parameter types, return type). + functions: HashMap, TypeId)>, +} + +impl Default for SymbolTable { + fn default() -> Self { + Self::new() + } +} + +impl SymbolTable { + pub fn new() -> Self { + Self { + vars: HashMap::new(), + functions: HashMap::new(), + } + } + + /// Define a variable in the current scope. + pub fn define_var(&mut self, name: &str, ty: TypeId) { + self.vars.insert(name.to_string(), ty); + } + + /// Look up a variable's type. + pub fn lookup_var(&self, name: &str) -> Option { + self.vars.get(name).copied() + } + + /// Define a function signature. + pub fn define_function(&mut self, name: &str, params: Vec, ret: TypeId) { + self.functions.insert(name.to_string(), (params, ret)); + } + + /// Look up a function's signature. + pub fn lookup_function(&self, name: &str) -> Option<(Vec, TypeId)> { + self.functions.get(name).cloned() + } +} diff --git a/kitlang/src/codegen/types.rs b/kitlang/src/codegen/types.rs index 41ad021..b0d56d9 100644 --- a/kitlang/src/codegen/types.rs +++ b/kitlang/src/codegen/types.rs @@ -6,29 +6,246 @@ use pest::iterators::Pair; use std::collections::HashSet; use std::str::FromStr; -/// Trait for converting types to their C representation. +/// Identity handle for a type in `TypeStore`. /// -/// This trait should be implemented for any type that needs to be represented in generated C code. -/// The conversion should include necessary header dependencies and any required type declarations. -pub trait ToCRepr { - /// Converts `self` to its C representation. - fn to_c_repr(&self) -> CType; +/// Types need stable identity for inference - we can't use the enum alone. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct TypeId(u32); + +impl Default for TypeId { + fn default() -> Self { + Self(u32::MAX) + } +} + +/// Identity handle for a type variable (unknown type during inference). +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct TypeVarId(u32); + +/// Represents a type variable used during inference. +/// +/// Type variables start unbound and may later be bound to a `TypeId`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TypeVar { + binding: Option, +} + +/// Central type storage for type inference. +/// +/// All type mutations go through here, making inference predictable. +pub struct TypeStore { + nodes: Vec, + type_vars: Vec, + next_id: u32, +} + +#[derive(Debug, Clone)] +enum TypeNode { + /// Fully known Kit type + Known(Type), + /// Inference-only placeholder + Unknown(TypeVarId), } -// Blanket implementation for references to types that implement ToCRepr. -// This allows calling `to_c_repr()` on references without explicit dereferencing. -impl ToCRepr for &T { - fn to_c_repr(&self) -> CType { - (*self).to_c_repr() +impl Default for TypeStore { + fn default() -> Self { + Self::new() + } +} + +impl TypeStore { + pub const fn new() -> Self { + Self { + nodes: Vec::new(), + type_vars: Vec::new(), + next_id: 0, + } + } + + /// Create a new known type from a Type enum. + pub fn new_known(&mut self, ty: Type) -> TypeId { + let id = TypeId(self.next_id); + self.next_id += 1; + self.nodes.push(TypeNode::Known(ty)); + id + } + + /// Create a new unknown type (type variable) for inference. + pub fn new_unknown(&mut self) -> TypeId { + let var_id = TypeVarId(self.type_vars.len() as u32); + self.type_vars.push(TypeVar { binding: None }); + let id = TypeId(self.next_id); + self.next_id += 1; + self.nodes.push(TypeNode::Unknown(var_id)); + id + } + + /// Bind a type variable to a specific type ID. + pub fn bind_type_var(&mut self, var_id: TypeVarId, ty: TypeId) -> Result<(), String> { + if let Some(existing) = self.type_vars.get_mut(var_id.0 as usize) { + if let Some(binding) = existing.binding { + return Err(format!( + "Type variable {var_id:?} already bound to {binding:?}" + )); + } + existing.binding = Some(ty); + Ok(()) + } else { + Err(format!("Type variable {var_id:?} does not exist")) + } + } + + /// Resolve a `TypeId` to its concrete Type. + /// + /// Follows type variable bindings. Returns error if any type variables remain unbound. + pub fn resolve(&self, mut id: TypeId) -> Result { + loop { + let Some(node) = self.nodes.get(id.0 as usize) else { + return Err(format!("Type ID {id:?} does not exist")); + }; + + id = match node { + TypeNode::Known(ty) => return Ok(ty.clone()), + TypeNode::Unknown(var_id) => self.resolve_var(id, *var_id)?, + }; + } + } + + fn resolve_var(&self, id: TypeId, var_id: TypeVarId) -> Result { + let Some(var) = self.type_vars.get(var_id.0 as usize) else { + return Err(format!( + "Type variable {var_id:?} does not exist in TypeStore", + )); + }; + + var.binding.ok_or_else(|| { + format!("Cannot resolve type ID {id:?}: type variable {var_id:?} is unbound") + }) + } + + /// Check if a `TypeId` is an unknown type variable. + pub fn is_unknown(&self, id: TypeId) -> bool { + matches!(self.nodes.get(id.0 as usize), Some(TypeNode::Unknown(_))) + } + + fn get_node(&self, id: TypeId) -> &TypeNode { + // We assume valid IDs here as they are managed by TypeStore + &self.nodes[id.0 as usize] + } + + /// Follow bindings to find the representative `TypeId`. + pub fn find_rep(&self, mut id: TypeId) -> TypeId { + loop { + match self.nodes.get(id.0 as usize) { + Some(TypeNode::Unknown(var_id)) => { + match self.type_vars.get(var_id.0 as usize) { + Some(TypeVar { + binding: Some(next_id), + }) => id = *next_id, + _ => return id, // Unbound + } + } + _ => return id, // Known + } + } + } + + /// Unify two type IDs (the core inference algorithm). + /// + /// Makes two types agree by either binding unknowns or comparing known types structurally. + pub fn unify(&mut self, a: TypeId, b: TypeId) -> Result<(), String> { + let rep_a = self.find_rep(a); + let rep_b = self.find_rep(b); + + if rep_a == rep_b { + return Ok(()); + } + + match (self.get_node(rep_a).clone(), self.get_node(rep_b).clone()) { + // Unknown + Anything + (TypeNode::Unknown(var_id), _) => self.bind_type_var(var_id, rep_b), + (_, TypeNode::Unknown(var_id)) => self.bind_type_var(var_id, rep_a), + + // Both Known -> structural comparison + (TypeNode::Known(ty_a), TypeNode::Known(ty_b)) => self.unify_types(&ty_a, &ty_b), + } + } + + /// Unify two known Type enum values structurally. + fn unify_types(&mut self, a: &Type, b: &Type) -> Result<(), String> { + match (a, b) { + // Simple type equality + (Type::Int8, Type::Int8) => Ok(()), + (Type::Int16, Type::Int16) => Ok(()), + (Type::Int32, Type::Int32) => Ok(()), + (Type::Int64, Type::Int64) => Ok(()), + (Type::Uint8, Type::Uint8) => Ok(()), + (Type::Uint16, Type::Uint16) => Ok(()), + (Type::Uint32, Type::Uint32) => Ok(()), + (Type::Uint64, Type::Uint64) => Ok(()), + (Type::Float32, Type::Float32) => Ok(()), + (Type::Float64, Type::Float64) => Ok(()), + (Type::Int | Type::Bool, Type::Int) | (Type::Int, Type::Bool) => Ok(()), + (Type::Float, Type::Float) => Ok(()), + (Type::Size, Type::Size) => Ok(()), + (Type::Char, Type::Char) => Ok(()), + (Type::Bool, Type::Bool) => Ok(()), + (Type::CString, Type::CString) => Ok(()), + (Type::Void, Type::Void) => Ok(()), + + // Pointer types: unify inner types + (Type::Ptr(t1), Type::Ptr(t2)) => self.unify_type_ids((**t1).clone(), (**t2).clone()), + + // Tuple types: unify element-wise + (Type::Tuple(v1), Type::Tuple(v2)) => { + if v1.len() != v2.len() { + return Err(format!( + "Cannot unify tuples of different sizes: {} vs {}", + v1.len(), + v2.len() + )); + } + for (elem1, elem2) in v1.iter().zip(v2.iter()) { + self.unify_type_ids(elem1.clone(), elem2.clone())?; + } + Ok(()) + } + + // Array types: unify element type and length + (Type::CArray(elem1, len1), Type::CArray(elem2, len2)) => { + if len1 != len2 { + return Err(format!( + "Cannot unify arrays of different sizes: {len1:?} vs {len2:?}" + )); + } + self.unify_type_ids((**elem1).clone(), (**elem2).clone()) + } + + // Named types: check string equality + (Type::Named(n1), Type::Named(n2)) => { + if n1 == n2 { + Ok(()) + } else { + Err(format!("Cannot unify different named types: {n1} vs {n2}")) + } + } + + // Everything else is a type mismatch + _ => Err(format!("Type mismatch: {a:?} vs {b:?}")), + } + } + + /// Helper to unify boxed Type values. + fn unify_type_ids(&mut self, a: Type, b: Type) -> Result<(), String> { + let a_id = self.new_known(a); + let b_id = self.new_known(b); + self.unify(a_id, b_id) } } /// Represents a type in the Kit language. /// -/// This enum covers both primitive C types and composite types. Note that floating-point variants -/// don't implement `Eq` or `Hash` by default, but we manually derive `PartialEq` and `Hash` for -/// practical usage in the compiler. The `Hash` implementation treats floating-point types as -/// having fixed bit patterns (which is valid since we only hash known constant types). +/// TODO: further description #[derive(Clone, Debug, PartialEq, Hash)] pub enum Type { /// User-defined named type (fallback for types not covered by other variants). @@ -69,268 +286,251 @@ pub enum Type { CString, /// Tuple type (represented as a struct in C). Tuple(Vec), - /// C array type (fixed or variable length). + /// C array type (TODO: is this variable length or fixed length?). /// - /// The second field is `Some(n)` for fixed-size arrays or `None` for variable-length arrays. - CArray(Box, Option), + /// ... + CArray(Box, usize), /// Represents a void type (e.g., for functions with no return value). Void, } impl Type { - /// Converts a Kit type name to its internal representation. - /// - /// This handles built-in types directly and falls back to `Named` for user-defined types. - pub fn from_kit(s: &str) -> Self { - match s { + pub fn from_kit(name: &str) -> Self { + match name { + "Int8" => Type::Int8, + "Int16" => Type::Int16, + "Int32" => Type::Int32, + "Int64" => Type::Int64, + "Uint8" => Type::Uint8, + "Uint16" => Type::Uint16, + "Uint32" => Type::Uint32, + "Uint64" => Type::Uint64, + "Float32" => Type::Float32, + "Float64" => Type::Float64, "Int" => Type::Int, "Float" => Type::Float, - "Char" => Type::Char, "Size" => Type::Size, - "CString" => Type::CString, + "Char" => Type::Char, "Bool" => Type::Bool, + "CString" => Type::CString, "Void" => Type::Void, - other => Type::Named(other.to_string()), + _ => Type::Named(name.to_string()), } } } -/// Unary operators supported in Kit expressions. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub enum UnaryOperator { - /// Logical NOT (`!`). - Not, - /// Arithmetic negation (`-`). - Negate, - /// Address-of operator (`&`). - AddressOf, - /// Pointer dereference (`*`). - Dereference, - /// Prefix increment (`++`). - Increment, - /// Prefix decrement (`--`). - Decrement, - /// Bitwise NOT (`~`). - BitwiseNot, + +#[derive(Clone, Debug, PartialEq, Eq)] +/// Represents ..., with ... +pub struct CRepr { + pub name: String, + pub declaration: Option, + pub headers: HashSet, } -impl UnaryOperator { - /// Formats the operator with its operand as a C expression string. - /// e.g. `to_string_with_expr("++", "x") -> "++x"` - pub fn to_string_with_expr(&self, expr: impl Into) -> String { - let expr = expr.into(); +pub trait ToCRepr { + fn to_c_repr(&self) -> CRepr; +} + +impl ToCRepr for Type { + fn to_c_repr(&self) -> CRepr { match self { - UnaryOperator::Not => format!("!{}", expr), - UnaryOperator::Negate => format!("-{}", expr), - UnaryOperator::AddressOf => format!("&{}", expr), - UnaryOperator::Dereference => format!("*{}", expr), - UnaryOperator::Increment => format!("++{}", expr), - UnaryOperator::Decrement => format!("--{}", expr), - UnaryOperator::BitwiseNot => format!("~{}", expr), + Type::Int8 => simple_c_type("int8_t", &["stdint.h"]), + Type::Int16 => simple_c_type("int16_t", &["stdint.h"]), + Type::Int32 => simple_c_type("int32_t", &["stdint.h"]), + Type::Int64 => simple_c_type("int64_t", &["stdint.h"]), + Type::Uint8 => simple_c_type("uint8_t", &["stdint.h"]), + Type::Uint16 => simple_c_type("uint16_t", &["stdint.h"]), + Type::Uint32 => simple_c_type("uint32_t", &["stdint.h"]), + Type::Uint64 => simple_c_type("uint64_t", &["stdint.h"]), + Type::Float32 | Type::Float => simple_c_type("float", &[]), + Type::Float64 => simple_c_type("double", &[]), + Type::Int => simple_c_type("int", &[]), + Type::Size => simple_c_type("size_t", &["stddef.h"]), + Type::Char => simple_c_type("char", &[]), + Type::Bool => simple_c_type("bool", &["stdbool.h"]), + Type::CString => simple_c_type("char*", &[]), + Type::Void => simple_c_type("void", &[]), + Type::Ptr(inner) => { + let inner_repr = inner.to_c_repr(); + let headers = inner_repr.headers; + CRepr { + name: format!("{}*", inner_repr.name), + declaration: inner_repr.declaration, + headers, + } + } + Type::Named(name) => simple_c_type(name, &[]), + _ => simple_c_type("void*", &[]), // Fallback } } } -impl FromStr for UnaryOperator { - type Err = (); - - /// Parses a unary operator from its string representation. - /// - /// Returns `Err(())` for invalid operator strings. - fn from_str(s: &str) -> Result { - match s { - "!" => Ok(UnaryOperator::Not), - "-" => Ok(UnaryOperator::Negate), - "&" => Ok(UnaryOperator::AddressOf), - "*" => Ok(UnaryOperator::Dereference), - "++" => Ok(UnaryOperator::Increment), - "--" => Ok(UnaryOperator::Decrement), - "~" => Ok(UnaryOperator::BitwiseNot), - _ => Err(()), - } +fn simple_c_type(name: &str, headers: &[&str]) -> CRepr { + let mut h = HashSet::new(); + for header in headers { + h.insert(format!("<{header}>")); + } + CRepr { + name: name.to_string(), + declaration: None, + headers: h, } } -/// Binary operators supported in Kit expressions. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq)] pub enum BinaryOperator { - /// Additive Add, - /// Subtractive: `-` - Subtract, - // Multiplicative - Multiply, - /// Divide - Divide, - /// Modulo - Modulo, - /// Equality + Sub, + Mul, + Div, + Mod, Eq, - /// Negative equality - Neq, - /// Greater than - Gt, - /// Greater than or equal - Gte, + /// Not equal + Ne, /// Less than Lt, + /// Greater than + Gt, /// Less than or equal - Lte, - // Logical + Le, + /// Greater than or equal + Ge, And, - /// Logical Or, - /// Bitwise AND - BitwiseAnd, - /// Bitwise OR - BitwiseOr, - /// Bitwise XOR - BitwiseXor, - /// Bitwise left shift - BitwiseLeftShift, - /// Bitwise right shift - BitwiseRightShift, + BitAnd, + BitOr, + BitXor, + /// Shift Left + Shl, + /// Shift Right + Shr, } impl BinaryOperator { + pub fn to_c_str(&self) -> &'static str { + match self { + BinaryOperator::Add => "+", + BinaryOperator::Sub => "-", + BinaryOperator::Mul => "*", + BinaryOperator::Div => "/", + BinaryOperator::Mod => "%", + BinaryOperator::Eq => "==", + BinaryOperator::Ne => "!=", + BinaryOperator::Lt => "<", + BinaryOperator::Gt => ">", + BinaryOperator::Le => "<=", + BinaryOperator::Ge => ">=", + BinaryOperator::And => "&&", + BinaryOperator::Or => "||", + BinaryOperator::BitAnd => "&", + BinaryOperator::BitOr => "|", + BinaryOperator::BitXor => "^", + BinaryOperator::Shl => "<<", + BinaryOperator::Shr => ">>", + } + } + pub fn from_rule_pair(pair: &Pair) -> Result { match pair.as_rule() { Rule::additive_op => match pair.as_str() { "+" => Ok(BinaryOperator::Add), - "-" => Ok(BinaryOperator::Subtract), - _ => Err(CompilationError::InvalidOperator(format!( - "Unknown additive operator: {}", - pair.as_str() - ))), + "-" => Ok(BinaryOperator::Sub), + _ => Err(CompilationError::InvalidOperator(pair.as_str().to_string())), }, Rule::multiplicative_op => match pair.as_str() { - "*" => Ok(BinaryOperator::Multiply), - "/" => Ok(BinaryOperator::Divide), - "%" => Ok(BinaryOperator::Modulo), - _ => Err(CompilationError::InvalidOperator(format!( - "Unknown multiplicative operator: {}", - pair.as_str() - ))), + "*" => Ok(BinaryOperator::Mul), + "/" => Ok(BinaryOperator::Div), + "%" => Ok(BinaryOperator::Mod), + _ => Err(CompilationError::InvalidOperator(pair.as_str().to_string())), }, Rule::eq_op => match pair.as_str() { "==" => Ok(BinaryOperator::Eq), - "!=" => Ok(BinaryOperator::Neq), - _ => Err(CompilationError::InvalidOperator(format!( - "Unknown equality operator: {}", - pair.as_str() - ))), + "!=" => Ok(BinaryOperator::Ne), + _ => Err(CompilationError::InvalidOperator(pair.as_str().to_string())), }, Rule::comp_op => match pair.as_str() { - ">" => Ok(BinaryOperator::Gt), - ">=" => Ok(BinaryOperator::Gte), "<" => Ok(BinaryOperator::Lt), - "<=" => Ok(BinaryOperator::Lte), - _ => Err(CompilationError::InvalidOperator(format!( - "Unknown comparison operator: {}", - pair.as_str() - ))), - }, - Rule::and_ops => match pair.as_str() { - "&&" => Ok(BinaryOperator::And), - "&" => Ok(BinaryOperator::BitwiseAnd), - _ => unreachable!(), // Should not happen with atomic rules + ">" => Ok(BinaryOperator::Gt), + "<=" => Ok(BinaryOperator::Le), + ">=" => Ok(BinaryOperator::Ge), + _ => Err(CompilationError::InvalidOperator(pair.as_str().to_string())), }, - Rule::logical_or_op => Ok(BinaryOperator::Or), // Use new logical_or_op - Rule::bitwise_or_op => Ok(BinaryOperator::BitwiseOr), // Use new bitwise_or_op - Rule::bitwise_xor_op => Ok(BinaryOperator::BitwiseXor), + Rule::and_ops => Ok(BinaryOperator::And), // && + Rule::logical_or_op => Ok(BinaryOperator::Or), // || + Rule::bitwise_or_op => Ok(BinaryOperator::BitOr), + Rule::bitwise_xor_op => Ok(BinaryOperator::BitXor), Rule::shift_op => match pair.as_str() { - "<<" => Ok(BinaryOperator::BitwiseLeftShift), - ">>" => Ok(BinaryOperator::BitwiseRightShift), - _ => Err(CompilationError::InvalidOperator(format!( - "Unknown shift operator: {}", - pair.as_str() - ))), + "<<" => Ok(BinaryOperator::Shl), + ">>" => Ok(BinaryOperator::Shr), + _ => Err(CompilationError::InvalidOperator(pair.as_str().to_string())), }, + // Need to check specific logic for & vs && in grammar _ => Err(CompilationError::InvalidOperator(format!( - "Unexpected rule for binary operator: {:?}", + "{:?}", pair.as_rule() ))), } } } -impl BinaryOperator { +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum UnaryOperator { + Neg, + Not, + BitNot, + AddressOf, + Dereference, +} + +impl UnaryOperator { pub fn to_c_str(&self) -> &'static str { match self { - BinaryOperator::Add => "+", - BinaryOperator::Subtract => "-", - BinaryOperator::Multiply => "*", - BinaryOperator::Divide => "/", - BinaryOperator::Modulo => "%", - BinaryOperator::Eq => "==", - BinaryOperator::Neq => "!=", - BinaryOperator::Gt => ">", - BinaryOperator::Gte => ">=", - BinaryOperator::Lt => "<", - BinaryOperator::Lte => "<=", - BinaryOperator::And => "&&", - BinaryOperator::Or => "||", - BinaryOperator::BitwiseAnd => "&", - BinaryOperator::BitwiseOr => "|", - BinaryOperator::BitwiseXor => "^", - BinaryOperator::BitwiseLeftShift => "<<", - BinaryOperator::BitwiseRightShift => ">>", + UnaryOperator::Neg => "-", + UnaryOperator::Not => "!", + UnaryOperator::BitNot => "~", + UnaryOperator::AddressOf => "&", + UnaryOperator::Dereference => "*", + } + } +} + +impl FromStr for UnaryOperator { + type Err = (); + fn from_str(s: &str) -> Result { + match s { + "-" => Ok(UnaryOperator::Neg), + "!" => Ok(UnaryOperator::Not), + "~" => Ok(UnaryOperator::BitNot), + // AddressOf is typically handled separately in parser due to grammar structure + _ => Err(()), } } } -/// Assignment operators supported in Kit expressions. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum AssignmentOperator { - /// Simple assignment (`=`). + /// Simple assignment Assign, - /// Add and assign (`+=`). + /// Add assignment (+=) AddAssign, - /// Subtract and assign (`-=`). - SubtractAssign, - /// Multiply and assign (`*=`). - MultiplyAssign, - /// Divide and assign (`/=`). - DivideAssign, - /// Modulo and assign (`%=`). - ModuloAssign, - /// Bitwise AND and assign (`&=`). - BitwiseAndAssign, - /// Bitwise OR and assign (`|=`). - BitwiseOrAssign, - /// Bitwise XOR and assign (`^=`). - BitwiseXorAssign, - /// Bitwise left shift and assign (`<<=`). - BitwiseLeftShiftAssign, - /// Bitwise right shift and assign (`>>=`). - BitwiseRightShiftAssign, -} - -impl AssignmentOperator { - pub fn from_rule_pair(pair: &Pair) -> Result { - match pair.as_rule() { - Rule::ASSIGN_OP => match pair.as_str() { - "=" => Ok(AssignmentOperator::Assign), - "+=" => Ok(AssignmentOperator::AddAssign), - "-=" => Ok(AssignmentOperator::SubtractAssign), - "*=" => Ok(AssignmentOperator::MultiplyAssign), - "/=" => Ok(AssignmentOperator::DivideAssign), - "%=" => Ok(AssignmentOperator::ModuloAssign), - "&=" => Ok(AssignmentOperator::BitwiseAndAssign), - "|=" => Ok(AssignmentOperator::BitwiseOrAssign), - "^=" => Ok(AssignmentOperator::BitwiseXorAssign), - "<<=" => Ok(AssignmentOperator::BitwiseLeftShiftAssign), - ">>=" => Ok(AssignmentOperator::BitwiseRightShiftAssign), - _ => Err(CompilationError::InvalidOperator(format!( - "Unknown assignment operator: {}", - pair.as_str() - ))), - }, - _ => Err(CompilationError::InvalidOperator(format!( - "Unexpected rule for assignment operator: {:?}", - pair.as_rule() - ))), - } - } + /// Subtract assignment (-=) + SubAssign, + /// Multiply assignment (*=) + MulAssign, + /// Divide assignment (/=) + DivAssign, + /// Modulo assignment (%=) + ModAssign, + /// Bitwise and assignment (&=) + AndAssign, + /// Bitwise or assignment (|=) + OrAssign, + /// Bitwise xor assignment (^=) + XorAssign, + /// Shift left assignment (<<=) + ShlAssign, + /// Shift right assignment (>>=) + ShrAssign, } impl AssignmentOperator { @@ -338,217 +538,32 @@ impl AssignmentOperator { match self { AssignmentOperator::Assign => "=", AssignmentOperator::AddAssign => "+=", - AssignmentOperator::SubtractAssign => "-=", - AssignmentOperator::MultiplyAssign => "*=", - AssignmentOperator::DivideAssign => "/=", - AssignmentOperator::ModuloAssign => "%=", - AssignmentOperator::BitwiseAndAssign => "&=", - AssignmentOperator::BitwiseOrAssign => "|=", - AssignmentOperator::BitwiseXorAssign => "^=", - AssignmentOperator::BitwiseLeftShiftAssign => "<<=", - AssignmentOperator::BitwiseRightShiftAssign => ">>=", + AssignmentOperator::SubAssign => "-=", + AssignmentOperator::MulAssign => "*=", + AssignmentOperator::DivAssign => "/=", + AssignmentOperator::ModAssign => "%=", + AssignmentOperator::AndAssign => "&=", + AssignmentOperator::OrAssign => "|=", + AssignmentOperator::XorAssign => "^=", + AssignmentOperator::ShlAssign => "<<=", + AssignmentOperator::ShrAssign => ">>=", } } -} -/// C type representation for code generation. -/// -/// This struct encapsulates all information needed to generate a C type: -/// - The type name as it appears in C code -/// - Required header dependencies -/// - Optional type declaration (for structs or typedefs) -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct CType { - /// The C type name (e.g., "int", "MyStruct", "uint32_t"). - pub name: String, - /// Headers required for this type (e.g., [""]). - pub headers: Vec, - /// Custom declaration needed for this type (e.g., struct definitions). - /// `None` for primitive/built-in types. - pub declaration: Option, -} -impl CType { - /// Creates a new C type representation with no headers or declaration. - fn new(name: impl Into) -> Self { - Self { - name: name.into(), - headers: Vec::new(), - declaration: None, - } - } - - /// Creates a C type that requires a specific header. - fn with_header(name: impl Into, header: impl Into) -> Self { - Self { - name: name.into(), - headers: vec![header.into()], - declaration: None, - } - } -} - -/// Creates a sanitized C identifier from a Kit type. -/// -/// This is used for generating struct names for tuples and arrays. The identifier: -/// - Uses standard C fixed-width integer names (e.g., "int32_t" instead of "i32") -/// - Uses "float" and "double" for floating-point types -/// - Prefixes composite types with "__Kit" to avoid naming conflicts -/// - Escapes type structures into valid C identifiers -fn type_to_c_ident_string(t: &Type) -> String { - match t { - Type::Named(s) => s.clone(), - Type::Ptr(inner) => format!("{}_ptr", type_to_c_ident_string(inner)), - Type::Int8 => "int8_t".to_string(), - Type::Int16 => "int16_t".to_string(), - Type::Int32 => "int32_t".to_string(), - Type::Int64 => "int64_t".to_string(), - Type::Uint8 => "uint8_t".to_string(), - Type::Uint16 => "uint16_t".to_string(), - Type::Uint32 => "uint32_t".to_string(), - Type::Uint64 => "uint64_t".to_string(), - Type::Float32 => "float".to_string(), - Type::Float64 => "double".to_string(), - Type::Int => "int".to_string(), - Type::Float => "float".to_string(), - Type::Size => "size_t".to_string(), - Type::Char => "char".to_string(), - Type::Bool => "bool".to_string(), - Type::CString => "cstring".to_string(), - Type::Tuple(types) => { - let member_types = types - .iter() - .map(type_to_c_ident_string) - .collect::>() - .join("_"); - format!("__KitTuple_{}", member_types) - } - Type::CArray(inner, _) => format!("{}__KitArray", type_to_c_ident_string(inner)), - Type::Void => "void".to_string(), - } -} - -impl ToCRepr for Type { - fn to_c_repr(&self) -> CType { - // small helper to reduce repetition for header-backed types - fn hdr(name: &str, header: &str) -> CType { - CType::with_header(name, header) - } - - match self { - // fixed-width integer types -> - Type::Int8 => hdr("int8_t", ""), - Type::Int16 => hdr("int16_t", ""), - Type::Int32 => hdr("int32_t", ""), - Type::Int64 => hdr("int64_t", ""), - Type::Uint8 => hdr("uint8_t", ""), - Type::Uint16 => hdr("uint16_t", ""), - Type::Uint32 => hdr("uint32_t", ""), - Type::Uint64 => hdr("uint64_t", ""), - - Type::Float32 => CType::new("float"), - Type::Float64 => CType::new("double"), - Type::Int => CType::new("int"), - Type::Float => CType::new("float"), - Type::Size => CType::with_header("size_t", ""), - Type::Char => CType::new("char"), - Type::Bool => CType::with_header("bool", ""), - Type::CString => CType::new("char*"), - - Type::Ptr(inner) => { - let mut c = inner.to_c_repr(); - c.name = format!("{}*", c.name); // clearer than push - c - } - - // Transform a tuple type into a C struct definition - Type::Tuple(fields) => { - // Mangle each field's C identifier and join with '_', e.g., "i32_f64_..." - // to avoid name conflicts when tuples with same size have different types. - // In practice, this makes (Int, Int) and (Float, Float) two entirely - // different types when converted to C. - let type_names_mangled = fields - .iter() - .map(type_to_c_ident_string) - .collect::>() - .join("_"); - - // Build a unique struct name using the mangled field list - let struct_name = format!("__KitTuple_{}", type_names_mangled); - - // Collect all required headers and declarations from the fields - let mut all_headers = HashSet::new(); - let mut all_declarations = Vec::new(); - - // Generate the struct members (like "int _0;") and gather headers/decls - let members = fields - .iter() - .enumerate() - .map(|(i, f)| { - let c = f.to_c_repr(); - - // collect needed headers and declarations - all_headers.extend(c.headers); - if let Some(decl) = c.declaration { - all_declarations.push(decl); - } - - format!(" {} _{};\n", c.name, i) // struct member line - }) - .collect::(); - - // Append the full typedef for the tuple struct - all_declarations.push(format!( - "typedef struct {{\n{}}} {};\n", - members, struct_name - )); - - // Return the C type description for the tuple - CType { - name: struct_name, - headers: all_headers.into_iter().collect(), - declaration: Some(all_declarations.join("\n")), - } - } - - // Transform a C-array type - Type::CArray(elem, len) => { - let base = elem.to_c_repr(); - - // Fixed-size array, like int[10] - if let Some(n) = len { - let mut ctype = base; - ctype.name = format!("{}[{}]", ctype.name, n); - ctype - // Variable-length array: convert to wrapper struct with length + pointer - } else { - let type_name_mangled = type_to_c_ident_string(elem); - let struct_name = format!("__KitArray_{}", type_name_mangled); - let decl = format!( - "typedef struct {{\n size_t len;\n {} *data;\n}} {};\n", - base.name, struct_name - ); - - // Gather headers (need for size_t) and any nested decls - let mut all_headers: HashSet = base.headers.into_iter().collect(); - all_headers.insert("".to_string()); - - let mut all_declarations = Vec::new(); - if let Some(d) = base.declaration { - all_declarations.push(d); - } - all_declarations.push(decl); - - CType { - name: struct_name, - headers: all_headers.into_iter().collect(), - declaration: Some(all_declarations.join("\n")), - } - } - } - Type::Void => CType::new("void"), - - // User-defined types are assumed to be already declared elsewhere - Type::Named(name) => CType::new(name.to_string()), + pub fn from_rule_pair(pair: &Pair) -> Result { + match pair.as_str() { + "=" => Ok(AssignmentOperator::Assign), + "+=" => Ok(AssignmentOperator::AddAssign), + "-=" => Ok(AssignmentOperator::SubAssign), + "*=" => Ok(AssignmentOperator::MulAssign), + "/=" => Ok(AssignmentOperator::DivAssign), + "%=" => Ok(AssignmentOperator::ModAssign), + "&=" => Ok(AssignmentOperator::AndAssign), + "|=" => Ok(AssignmentOperator::OrAssign), + "^=" => Ok(AssignmentOperator::XorAssign), + "<<=" => Ok(AssignmentOperator::ShlAssign), + ">>=" => Ok(AssignmentOperator::ShrAssign), + _ => Err(CompilationError::InvalidOperator(pair.as_str().to_string())), } } } diff --git a/kitlang/src/error/mod.rs b/kitlang/src/error/mod.rs index 47e2828..fe31725 100644 --- a/kitlang/src/error/mod.rs +++ b/kitlang/src/error/mod.rs @@ -1,5 +1,7 @@ use thiserror::Error; +pub type CompileResult = Result; + #[derive(Error, Debug)] pub enum CompilationError { #[error("Failed to compile: {0}")] @@ -11,6 +13,9 @@ pub enum CompilationError { #[error("Invalid operator: {0}")] InvalidOperator(String), + #[error("Type error: {0}")] + TypeError(String), + #[error("Failed to compile C code:\n{}", String::from_utf8_lossy(.0))] CCompileError(Vec), @@ -30,12 +35,7 @@ pub enum CompilationError { /// Helper macro to create a `CompilationError::ParseError` #[macro_export] macro_rules! parse_error { - // No arguments: just a literal string - ( $msg:literal ) => { - $crate::error::CompilationError::ParseError($msg.to_string()) - }; - // Literal with one or more format arguments - ( $fmt:literal, $($arg:tt)+ ) => { - $crate::error::CompilationError::ParseError(format!($fmt, $($arg)+)) + ( $($arg:tt)* ) => { + $crate::error::CompilationError::ParseError(format!($($arg)*)) }; } diff --git a/kitlang/src/grammar/kit.pest b/kitlang/src/grammar/kit.pest index 3a1f206..c15b299 100644 --- a/kitlang/src/grammar/kit.pest +++ b/kitlang/src/grammar/kit.pest @@ -163,4 +163,8 @@ char_literal = @{ "'" ~ ( "\\'" | !"'" ~ ANY ) ~ "'" } string = @{ "\"" ~ ( "\\" ~ ANY | !"\"" ~ ANY )* ~ "\"" } -identifier = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } +identifier = @{ + !("true" | "false" | "null" | "this" | "Self") + ~ (ASCII_ALPHA | "_") + ~ (ASCII_ALPHANUMERIC | "_")* +}