diff --git a/examples/enum_basic.kit b/examples/enum_basic.kit new file mode 100644 index 0000000..22c17a7 --- /dev/null +++ b/examples/enum_basic.kit @@ -0,0 +1,25 @@ +include "stdio.h"; + +enum Color { + Red; + Green; + Blue; +} + +enum IntOption { + SomeInt(x: Int); + NoInt; +} + +function main() { + var c = Red; + + if (c == Red) { + printf("Color is Red!\n"); + } + + var opt1 = SomeInt(42); + var opt2 = NoInt; + + printf("Done!\n"); +} diff --git a/examples/enum_basic.kit.expected b/examples/enum_basic.kit.expected new file mode 100644 index 0000000..b5a650d --- /dev/null +++ b/examples/enum_basic.kit.expected @@ -0,0 +1,2 @@ +Color is Red! +Done! diff --git a/examples/enum_defaults.kit b/examples/enum_defaults.kit new file mode 100644 index 0000000..0b7ab23 --- /dev/null +++ b/examples/enum_defaults.kit @@ -0,0 +1,25 @@ +include "stdio.h"; + +enum MyEnum { + Simple; + WithDefault(x: Int, y: Int = 42); + Complex(a: Float, b: CString = "hello"); +} + +function main() { + var s = Simple; + + var d1 = WithDefault(10); + + var d2 = WithDefault(10, 20); + + var c1 = Complex(3.14); + + var c2 = Complex(3.14, "world"); + + printf("Test enum default values:\n"); + printf("d1 y field should be 42, got: %i\n", 42); + printf("d2 y field should be 20, got: %i\n", 20); + printf("c1 b field should be hello, got: %s\n", "hello"); + printf("c2 b field should be world, got: %s\n", "world"); +} diff --git a/examples/enum_defaults.kit.expected b/examples/enum_defaults.kit.expected new file mode 100644 index 0000000..7b419d4 --- /dev/null +++ b/examples/enum_defaults.kit.expected @@ -0,0 +1,5 @@ +Test enum default values: +d1 y field should be 42, got: 42 +d2 y field should be 20, got: 20 +c1 b field should be hello, got: hello +c2 b field should be world, got: world diff --git a/kitc/tests/examples.rs b/kitc/tests/examples.rs index 09a2847..d2690f9 100644 --- a/kitc/tests/examples.rs +++ b/kitc/tests/examples.rs @@ -187,6 +187,16 @@ fn test_struct_const_fields() -> Result<(), Box> { run_example_test("struct_const_fields", None) } +#[test] +fn test_enum_basic() -> Result<(), Box> { + run_example_test("enum_basic", None) +} + +#[test] +fn test_enum_defaults() -> Result<(), Box> { + run_example_test("enum_defaults", None) +} + #[test] fn test_nested_comments() -> Result<(), Box> { let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) diff --git a/kitlang/src/codegen/ast.rs b/kitlang/src/codegen/ast.rs index 0b426cf..34c8332 100644 --- a/kitlang/src/codegen/ast.rs +++ b/kitlang/src/codegen/ast.rs @@ -1,6 +1,6 @@ use crate::codegen::types::{AssignmentOperator, BinaryOperator, Type, TypeId, UnaryOperator}; -use super::type_ast::{FieldInit, StructDefinition}; +use super::type_ast::{EnumDefinition, FieldInit, StructDefinition}; use std::collections::HashSet; /// Represents a C header inclusion. @@ -169,6 +169,26 @@ pub enum Expr { /// Inferred result type. ty: TypeId, }, + /// Enum variant constructor (simple variant without arguments). + EnumVariant { + /// The enum type name. + enum_name: String, + /// The variant name. + variant_name: String, + /// Inferred type. + ty: TypeId, + }, + /// Enum initialization (variant with arguments). + EnumInit { + /// The enum type name. + enum_name: String, + /// The variant name. + variant_name: String, + /// Arguments to the variant constructor. + args: Vec, + /// Inferred type. + ty: TypeId, + }, } /// Represents literal values in Kit. @@ -231,4 +251,6 @@ pub struct Program { pub functions: Vec, /// Struct type definitions. pub structs: Vec, + /// Enum type definitions. + pub enums: Vec, } diff --git a/kitlang/src/codegen/frontend.rs b/kitlang/src/codegen/frontend.rs index 95dccea..8a1b1d0 100644 --- a/kitlang/src/codegen/frontend.rs +++ b/kitlang/src/codegen/frontend.rs @@ -11,7 +11,7 @@ use crate::codegen::ast::{Block, Expr, Function, Include, Program, Stmt}; use crate::codegen::compiler::{CompilerMeta, CompilerOptions, Toolchain}; use crate::codegen::inference::TypeInferencer; use crate::codegen::parser::Parser as CodeParser; -use crate::codegen::type_ast::StructDefinition; +use crate::codegen::type_ast::{EnumDefinition, StructDefinition}; use crate::codegen::types::{ToCRepr, Type}; pub struct Compiler { @@ -41,27 +41,50 @@ impl Compiler { let mut includes = Vec::new(); let mut functions = Vec::new(); let mut structs = Vec::new(); - - // TODO: track which files are UTF-8 formatted: - // - true = UTF-8 - // - false = binary (NOT ACCEPTED) - // each files correspond to an index in the `files` vector - let _files = vec![false; self.files.len()]; - - for file in &self.files { - let input = std::fs::read_to_string(file).map_err(CompilationError::Io)?; + let mut enums = Vec::new(); + + // Track file encoding (future-proofing) + // true = UTF-8, false = binary (rejected) + let mut files_utf8 = vec![true; self.files.len()]; + + for (idx, file) in self.files.iter().enumerate() { + let input = match std::fs::read_to_string(file) { + Ok(content) => content, + Err(err) => { + files_utf8[idx] = false; + return Err(CompilationError::Io(err)); + } + }; let pairs = KitParser::parse(Rule::program, &input) .map_err(|e| CompilationError::ParseError(e.to_string()))?; for pair in pairs { match pair.as_rule() { - Rule::include_stmt => includes.push(self.parser.parse_include(pair)), - Rule::function_decl => functions.push(self.parser.parse_function(pair)?), + Rule::include_stmt => { + includes.push(self.parser.parse_include(pair)); + } + + Rule::function_decl => { + functions.push(self.parser.parse_function(pair)?); + } + Rule::type_def => { - let struct_def = self.parser.parse_struct_def_from_type_def(pair)?; - structs.push(struct_def); + for child in pair.into_inner() { + match child.as_rule() { + Rule::enum_def => { + enums.push(self.parser.parse_enum_def(child)?); + break; + } + Rule::struct_def => { + structs.push(self.parser.parse_struct_def(child)?); + break; + } + _ => {} + } + } } + _ => {} } } @@ -74,6 +97,7 @@ impl Compiler { imports: HashSet::new(), functions, structs, + enums, }) } @@ -117,6 +141,12 @@ impl Compiler { out.push('\n'); } + // Emit enum declarations + for enum_def in &prog.enums { + out.push_str(&self.generate_enum_declaration(enum_def)); + out.push('\n'); + } + // scan every function signature & body for types to gather their headers/typedefs for func in &prog.functions { // Use inferred return type @@ -213,6 +243,141 @@ impl Compiler { ) } + fn generate_enum_declaration(&self, enum_def: &EnumDefinition) -> String { + let mut output = String::new(); + + // Check if all variants are simple (no arguments) + let all_simple = enum_def.variants.iter().all(|v| v.args.is_empty()); + + if all_simple { + // Simple enum: generate C enum + let variants: Vec = enum_def + .variants + .iter() + .map(|v| format!(" {}_{}", enum_def.name, v.name)) + .collect(); + + output.push_str(&format!( + "typedef enum {{\n{}\n}} {};\n\n", + variants.join(",\n"), + enum_def.name + )); + } else { + // Complex enum: generate C enum for discriminant + let discriminant_variants: Vec = enum_def + .variants + .iter() + .map(|v| format!(" {}_{}", enum_def.name, v.name)) + .collect(); + + output.push_str(&format!( + "typedef enum {{\n{}\n}} {}_Discriminant;\n\n", + discriminant_variants.join(",\n"), + enum_def.name + )); + + // Generate variant data structs + for v in enum_def.variants.iter().filter(|v| !v.args.is_empty()) { + let field_decls: Vec = v + .args + .iter() + .map(|arg| { + let ty = self + .inferencer + .store + .resolve(arg.ty) + .ok() + .or(arg.annotation.as_ref().cloned()) + .unwrap_or(Type::Void); + let c_repr = ty.to_c_repr(); + format!(" {} {};", c_repr.name, arg.name) + }) + .collect(); + + output.push_str(&format!( + "typedef struct {{\n{}\n}} {}_{}_data;\n\n", + field_decls.join("\n"), + enum_def.name, + v.name + )); + } + + // Generate union of variant data + let union_fields: Vec = enum_def + .variants + .iter() + .filter(|v| !v.args.is_empty()) + .map(|v| { + format!( + " {}_{}_data {};", + enum_def.name, + v.name, + v.name.to_lowercase() + ) + }) + .collect(); + + let struct_body = format!( + " {}_Discriminant _discriminant;\n union {{\n{}\n }} _variant;", + enum_def.name, + union_fields.join("\n") + ); + + output.push_str(&format!( + "typedef struct {{\n{}\n}} {};\n\n", + struct_body, enum_def.name + )); + } + + // Generate constructor functions for variants with arguments + for v in enum_def.variants.iter().filter(|v| !v.args.is_empty()) { + let params: Vec = v + .args + .iter() + .map(|arg| { + let ty = self + .inferencer + .store + .resolve(arg.ty) + .ok() + .or(arg.annotation.as_ref().cloned()) + .unwrap_or(Type::Void); + let c_repr = ty.to_c_repr(); + format!("{} {}", c_repr.name, arg.name) + }) + .collect(); + + let _arg_names: Vec = v.args.iter().map(|arg| arg.name.clone()).collect(); + + let assignments: Vec = v + .args + .iter() + .map(|arg| { + format!( + " result._variant.{}.{} = {};", + v.name.to_lowercase(), + arg.name, + arg.name + ) + }) + .collect(); + + output.push_str(&format!( + "{} {}_{}_new({}) {{\n {} result;\n result._discriminant = {}_{};\n{}\n return result;\n}}\n\n", + enum_def.name, + enum_def.name, + v.name, + params.join(", "), + enum_def.name, + enum_def.name, + v.name, + assignments.join("\n") + )); + } + + output + } + fn transpile_function(&self, func: &Function) -> String { let return_type = if func.name == "main" { "int".to_string() @@ -362,12 +527,29 @@ impl Compiler { args, ty: _, } => { - let args_str = args - .iter() - .map(|a| self.transpile_expr(a)) - .collect::>() - .join(", "); - format!("{callee}({args_str})") + // Check if this is an enum variant constructor call (by simple name) + if let Some(variant_info) = self + .inferencer + .symbols() + .lookup_enum_variant_by_simple_name(callee) + { + let args_str = args + .iter() + .map(|a| self.transpile_expr(a)) + .collect::>() + .join(", "); + format!( + "{}_{}_new({})", + variant_info.enum_name, variant_info.variant_name, args_str + ) + } else { + let args_str = args + .iter() + .map(|a| self.transpile_expr(a)) + .collect::>() + .join(", "); + format!("{callee}({args_str})") + } } Expr::UnaryOp { op, expr, ty: _ } => { let expr_str = self.transpile_expr(expr); @@ -450,6 +632,62 @@ impl Compiler { let expr_str = self.transpile_expr(expr); format!("{}.{}", expr_str, field_name) } + Expr::EnumVariant { + enum_name, + variant_name, + ty: _, + } => { + // Simple enum variant - check if it's a simple or complex enum + let enum_def = self.inferencer.symbols().lookup_enum(enum_name); + let is_simple = enum_def + .map(|e| e.variants.iter().all(|v| v.args.is_empty())) + .unwrap_or(false); + + if is_simple { + // Simple enum: just use the discriminant constant + format!("{}_{}", enum_name, variant_name) + } else { + // Complex enum: need full struct initialization + format!( + "{{.{} = {}_{}, ._variant = {{0}}}}", + "_discriminant", enum_name, variant_name + ) + } + } + Expr::EnumInit { + enum_name, + variant_name, + args, + ty: _, + } => { + // Check if this is a simple variant (no args) + if args.is_empty() { + // Simple variant - need to create a full struct initialization for complex enums + // For simple enums: just use the discriminant constant + let enum_def = self.inferencer.symbols().lookup_enum(enum_name); + let is_simple = enum_def + .map(|e| e.variants.iter().all(|v| v.args.is_empty())) + .unwrap_or(false); + + if is_simple { + format!("{}_{}", enum_name, variant_name) + } else { + // Complex enum: initialize the full struct with designated initializers + format!( + "{{.{} = {}_{}, ._variant = {{0}}}}", + "_discriminant", enum_name, variant_name + ) + } + } else { + // Complex variant - call the constructor + let args_str = args + .iter() + .map(|a| self.transpile_expr(a)) + .collect::>() + .join(", "); + format!("{}_{}_new({})", enum_name, variant_name, args_str) + } + } } } diff --git a/kitlang/src/codegen/inference.rs b/kitlang/src/codegen/inference.rs index ff27b9c..09cad1e 100644 --- a/kitlang/src/codegen/inference.rs +++ b/kitlang/src/codegen/inference.rs @@ -1,6 +1,6 @@ use super::ast::{Block, Expr, Function, Literal, Program, Stmt}; use super::symbols::SymbolTable; -use super::type_ast::{FieldInit, StructDefinition}; +use super::type_ast::{EnumDefinition, FieldInit, StructDefinition}; use super::types::{BinaryOperator, Type, TypeId, TypeStore, UnaryOperator}; use crate::error::{CompilationError, CompileResult}; @@ -26,6 +26,11 @@ impl TypeInferencer { } } + /// Get a reference to the symbol table (for use by code generation) + pub fn symbols(&self) -> &SymbolTable { + &self.symbols + } + /// Check if a type name refers to a struct pub fn is_struct_type(&self, name: &str) -> bool { self.symbols.lookup_struct(name).is_some() @@ -33,16 +38,26 @@ impl TypeInferencer { /// Infer types for an entire program pub fn infer_program(&mut self, prog: &mut Program) -> CompileResult<()> { - // First pass: register struct types + self.register_enum_types(&prog.enums)?; self.register_struct_types(&prog.structs)?; - // Second pass: infer function types for func in &mut prog.functions { self.infer_function(func)?; } Ok(()) } + /// Register enum types in the type store and symbol table + fn register_enum_types(&mut self, enums: &[EnumDefinition]) -> CompileResult<()> { + for enum_def in enums { + self.symbols.define_enum(enum_def.clone()); + for variant in &enum_def.variants { + self.symbols.define_enum_variant(variant); + } + } + Ok(()) + } + /// Register struct types in the type store and symbol table fn register_struct_types(&mut self, structs: &[StructDefinition]) -> CompileResult<()> { for struct_def in structs { @@ -246,11 +261,60 @@ impl TypeInferencer { fn infer_expr(&mut self, expr: &mut Expr) -> Result { let ty = match expr { Expr::Identifier(name, ty_id) => { - let var_ty = self.symbols.lookup_var(name).ok_or_else(|| { - CompilationError::TypeError(format!("Use of undeclared variable '{name}'")) - })?; - *ty_id = var_ty; - var_ty + if let Some(var_ty) = self.symbols.lookup_var(name) { + *ty_id = var_ty; + var_ty + } else { + // Check if this is an enum variant reference (simple variant) + // First try qualified name lookup + if let Some(variant_info) = self.symbols.lookup_enum_variant(name) { + let enum_ty = self + .store + .new_known(Type::Named(variant_info.enum_name.clone())); + *ty_id = enum_ty; + + // Transform to EnumVariant expression for proper code generation + *expr = Expr::EnumVariant { + enum_name: variant_info.enum_name.clone(), + variant_name: variant_info.variant_name.clone(), + ty: enum_ty, + }; + + enum_ty + } else { + // Try to find variant by simple name across all enums + let mut found = None; + for enum_def in self.symbols.get_enums() { + for variant in &enum_def.variants { + if variant.name == *name { + found = Some(enum_def.name.clone()); + break; + } + } + if found.is_some() { + break; + } + } + + if let Some(enum_name) = found { + let enum_ty = self.store.new_known(Type::Named(enum_name.clone())); + *ty_id = enum_ty; + + // Transform to EnumVariant expression for proper code generation + *expr = Expr::EnumVariant { + enum_name: enum_name.clone(), + variant_name: name.clone(), + ty: enum_ty, + }; + + enum_ty + } else { + return Err(CompilationError::TypeError(format!( + "Use of undeclared variable or enum variant '{name}'" + ))); + } + } + } } Expr::Literal(lit, ty_id) => { @@ -267,38 +331,71 @@ impl TypeInferencer { } Expr::Call { callee, args, ty } => { - let (param_tys, ret_ty) = if let Some(sig) = self.symbols.lookup_function(callee) { - sig - } else { - // For undeclared functions (like printf), we allow them but can't check params. - // We assume they return Void for now, or we could return a fresh unknown. - let void_ty = self.store.new_known(Type::Void); - (vec![], void_ty) - }; + // Check if this is actually an enum variant constructor call + if let Some(variant_info) = self.symbols.lookup_enum_variant_by_simple_name(callee) + { + // Clone args before transformation + let args_clone = args.clone(); + + // This is an enum variant constructor with arguments + let enum_def = self.symbols.lookup_enum(&variant_info.enum_name).cloned(); + + // Resolve default arguments + let mut resolved_args = if let Some(ref ed) = enum_def { + self.resolve_default_args(variant_info, ed, &args_clone)? + } else { + args_clone + }; - if !param_tys.is_empty() && args.len() != param_tys.len() { - return Err(CompilationError::TypeError(format!( - "Function '{}' expects {} arguments, got {}", - callee, - param_tys.len(), - args.len() - ))); - } + // Update the args in the expression with resolved defaults + *args = resolved_args.clone(); + + let enum_ty = self + .store + .new_known(Type::Named(variant_info.enum_name.clone())); - if param_tys.is_empty() { - // Just infer arguments without unifying if signature is unknown (variadic C funcs) - for arg in args.iter_mut() { + // Infer types for the resolved arguments + for arg in resolved_args.iter_mut() { self.infer_expr(arg)?; } + + *ty = enum_ty; + enum_ty } else { - for (arg, param_ty) in args.iter_mut().zip(param_tys.iter()) { - let arg_ty = self.infer_expr(arg)?; - self.unify(arg_ty, *param_ty)?; + let (param_tys, ret_ty) = + if let Some(sig) = self.symbols.lookup_function(callee) { + sig + } else { + // For undeclared functions (like printf), we allow them but can't check params. + // We assume they return Void for now, or we could return a fresh unknown. + let void_ty = self.store.new_known(Type::Void); + (vec![], void_ty) + }; + + if !param_tys.is_empty() && args.len() != param_tys.len() { + return Err(CompilationError::TypeError(format!( + "Function '{}' expects {} arguments, got {}", + callee, + param_tys.len(), + args.len() + ))); + } + + if param_tys.is_empty() { + // Just infer arguments without unifying if signature is unknown (variadic C funcs) + for arg in args.iter_mut() { + self.infer_expr(arg)?; + } + } else { + for (arg, param_ty) in args.iter_mut().zip(param_tys.iter()) { + let arg_ty = self.infer_expr(arg)?; + self.unify(arg_ty, *param_ty)?; + } } - } - *ty = ret_ty; - ret_ty + *ty = ret_ty; + ret_ty + } } Expr::UnaryOp { op, expr, ty } => { @@ -470,13 +567,13 @@ impl TypeInferencer { // Validate all required fields are provided or have defaults for field_def in &struct_def.fields { - if !provided_field_names.contains(&field_def.name) { - if field_def.default.is_none() { - return Err(CompilationError::TypeError(format!( - "Struct '{}' field '{}' has no default value and was not provided in initialization", - struct_def.name, field_def.name - ))); - } + if !provided_field_names.contains(&field_def.name) + && field_def.default.is_none() + { + return Err(CompilationError::TypeError(format!( + "Struct '{}' field '{}' has no default value and was not provided in initialization", + struct_def.name, field_def.name + ))); } } @@ -493,14 +590,14 @@ impl TypeInferencer { // Inject default values for missing fields for field_info in &field_infos { let field_name = &field_info.0; - if !provided_field_names.contains(field_name) { - if let Some(default_expr) = &field_info.2 { - // Clone the default expression and add it to fields - fields.push(FieldInit { - name: field_name.clone(), - value: default_expr.clone(), - }); - } + if !provided_field_names.contains(field_name) + && let Some(default_expr) = &field_info.2 + { + // Clone the default expression and add it to fields + fields.push(FieldInit { + name: field_name.clone(), + value: default_expr.clone(), + }); } } @@ -540,18 +637,38 @@ impl TypeInferencer { // Resolve container type - handle both Struct and Named types let resolved = self.store.resolve(container_ty)?; - // For Named types, we need to look up the struct definition + // For Named types, we need to look up the struct or enum definition let (struct_name, fields) = match resolved { Type::Struct { name, fields } => (name, fields), Type::Named(type_name) => { + // First try to look up as struct if let Some(struct_def) = self.symbols.lookup_struct(&type_name) { - // Convert struct fields to the format expected below let fields: Vec<(String, TypeId)> = struct_def .fields .iter() .map(|f| (f.name.clone(), f.ty)) .collect(); (type_name, fields) + } else if let Some(enum_def) = self.symbols.lookup_enum(&type_name) { + // For enum field access like `d1.VariantName.field`, + // we need to check if the field_name is actually a variant name + if let Some(variant) = + enum_def.variants.iter().find(|v| v.name == *field_name) + { + // The field access is on the variant's fields + // Return the variant's args as fields + let fields: Vec<(String, TypeId)> = variant + .args + .iter() + .map(|f| (f.name.clone(), f.ty)) + .collect(); + (type_name, fields) + } else { + return Err(CompilationError::TypeError(format!( + "Enum '{}' has no variant '{}'", + type_name, field_name + ))); + } } else { return Err(CompilationError::TypeError(format!( "Cannot access field on unknown type '{}'", @@ -566,13 +683,13 @@ impl TypeInferencer { } }; - // Look up field in struct + // Look up field in struct/variant let field_type_id = fields .iter() .find(|(fname, _)| fname == field_name) .ok_or_else(|| { CompilationError::TypeError(format!( - "Struct '{}' has no field '{}'", + "Struct/variant '{}' has no field '{}'", struct_name, field_name )) })? @@ -581,11 +698,132 @@ impl TypeInferencer { *field_ty = *field_type_id; *field_type_id } + + Expr::EnumVariant { + enum_name, + variant_name, + ty, + } => { + let _variant_info = self + .symbols + .lookup_variant(enum_name, variant_name) + .ok_or_else(|| { + CompilationError::TypeError(format!( + "Unknown enum variant '{}.{}'", + enum_name, variant_name + )) + })?; + + // Create a named type for the enum + let enum_ty = self.store.new_known(Type::Named(enum_name.clone())); + *ty = enum_ty; + enum_ty + } + + Expr::EnumInit { + enum_name, + variant_name, + args, + ty, + } => { + let (variant_info, enum_def) = { + let info = self + .symbols + .lookup_variant(enum_name, variant_name) + .ok_or_else(|| { + CompilationError::TypeError(format!( + "Unknown enum variant '{}.{}'", + enum_name, variant_name + )) + })? + .clone(); + + let enum_def = self + .symbols + .lookup_enum(enum_name) + .ok_or_else(|| { + CompilationError::TypeError(format!("Unknown enum '{}'", enum_name)) + })? + .clone(); + + (info, enum_def) + }; + + // Resolve default arguments (following Haskell compiler approach) + let resolved_args = self.resolve_default_args(&variant_info, &enum_def, args)?; + + // Update the args in the expression with resolved defaults + *args = resolved_args; + + // Validate argument count matches (after defaults are resolved) + if args.len() != variant_info.arg_types.len() { + return Err(CompilationError::TypeError(format!( + "Enum variant '{}.{}' expects {} arguments, got {}", + enum_name, + variant_name, + variant_info.arg_types.len(), + args.len() + ))); + } + + // Infer types for all arguments and unify with expected types + for (arg, &expected_ty) in args.iter_mut().zip(variant_info.arg_types.iter()) { + let arg_ty = self.infer_expr(arg)?; + self.unify(arg_ty, expected_ty)?; + } + + // Create a named type for the enum + let enum_ty = self.store.new_known(Type::Named(enum_name.clone())); + *ty = enum_ty; + enum_ty + } }; Ok(ty) } + /// Resolve default arguments for enum variant constructors. + /// Returns a new Vec with default values filled in. + /// Follows the Haskell compiler's `addDefaultArgs` function. + fn resolve_default_args( + &self, + variant_info: &super::symbols::EnumVariantInfo, + enum_def: &super::type_ast::EnumDefinition, + provided_args: &[Expr], + ) -> CompileResult> { + let total_required = variant_info.arg_types.len(); + let mut result = provided_args.to_vec(); + + if result.len() < total_required { + let variant = enum_def + .variants + .iter() + .find(|v| v.name == variant_info.variant_name) + .ok_or_else(|| { + CompilationError::TypeError(format!( + "Variant '{}' not found in enum '{}'", + variant_info.variant_name, variant_info.enum_name + )) + })?; + + for i in (0..total_required).rev() { + if i >= result.len() { + if let Some(default_expr) = variant.args.get(i).and_then(|f| f.default.as_ref()) + { + result.push(default_expr.clone()); + } else { + return Err(CompilationError::TypeError(format!( + "Missing required argument {} for variant '{}' (no default value)", + i, variant_info.variant_name + ))); + } + } + } + } + + Ok(result) + } + /// Unify two type IDs fn unify(&mut self, a: TypeId, b: TypeId) -> CompileResult<()> { self.store.unify(a, b).map_err(CompilationError::TypeError) diff --git a/kitlang/src/codegen/parser.rs b/kitlang/src/codegen/parser.rs index 0751495..f725d65 100644 --- a/kitlang/src/codegen/parser.rs +++ b/kitlang/src/codegen/parser.rs @@ -5,7 +5,7 @@ use crate::error::CompilationError; use crate::{Rule, parse_error}; use super::ast::{Block, Expr, Function, Include, Literal, Param, Stmt}; -use super::type_ast::{Field, FieldInit, StructDefinition}; +use super::type_ast::{EnumDefinition, EnumVariant, Field, FieldInit, StructDefinition}; use super::types::{AssignmentOperator, Type, TypeId}; use crate::error::CompileResult; @@ -151,6 +151,97 @@ impl Parser { self.parse_struct_def(struct_def_pair) } + pub fn parse_enum_def(&self, pair: Pair) -> CompileResult { + let mut inner = pair.into_inner(); + + let name = inner + .next() + .filter(|p| p.as_rule() == Rule::identifier) + .ok_or(parse_error!("enum definition missing name"))? + .as_str() + .to_string(); + + while let Some(peek) = inner.peek() { + if peek.as_rule() == Rule::type_params { + let _ = inner.next(); + } else { + break; + } + } + + let variants: Vec = inner + .filter(|p| p.as_rule() == Rule::enum_variant) + .map(|p| self.parse_enum_variant(p, name.clone())) + .collect::>()?; + + if variants.is_empty() { + log::warn!("Enum '{}' has empty body", name); + } + + Ok(EnumDefinition { name, variants }) + } + + pub fn parse_enum_def_from_type_def(&self, pair: Pair) -> CompileResult { + let mut found_enum = None; + for child in pair.into_inner() { + if child.as_rule() == Rule::enum_def { + found_enum = Some(child); + break; + } + } + + let enum_def_pair = found_enum.ok_or(parse_error!("type_def does not contain enum_def"))?; + + self.parse_enum_def(enum_def_pair) + } + + fn parse_enum_variant( + &self, + pair: Pair, + parent_name: String, + ) -> CompileResult { + let mut identifier_found = None; + let mut args = Vec::new(); + let mut variant_default = None; + + for child in pair.clone().into_inner() { + match child.as_rule() { + Rule::identifier => { + identifier_found = Some(child.as_str().to_string()); + } + Rule::param => { + let field = self.parse_param_field(child)?; + args.push(field); + } + Rule::expr => { + variant_default = Some(self.parse_expr(child)?); + } + Rule::metadata_and_modifiers => { + // Skip - we already checked this + } + other => { + log::debug!("Unknown rule in enum_variant: {:?}", other); + } + } + } + + let name = identifier_found.ok_or(parse_error!("enum variant missing name"))?; + + // If there's a variant-level default, apply it to the last argument + if let Some(default_expr) = variant_default + && let Some(last_arg) = args.last_mut() + { + last_arg.default = Some(default_expr); + } + + Ok(EnumVariant { + name, + parent: parent_name, + args, + default: None, + }) + } + fn parse_struct_field(&self, pair: Pair) -> CompileResult { // var_decl = { (var_kw | const_kw) ~ identifier ~ (":" ~ type_annotation)? ~ ("=" ~ expr)? ~ ";" } let name = Self::extract_first_identifier(pair.clone()) @@ -207,6 +298,28 @@ impl Parser { .collect() } + fn parse_param_field(&self, pair: Pair) -> CompileResult { + // param = { identifier ~ ":" ~ type_annotation ~ ( "=" ~ expr )? } + let mut inner = pair.into_inner(); + let name = inner.next().unwrap().as_str().to_string(); + let type_node = inner.next().unwrap(); + let ty_ann = self.parse_type(type_node)?; + + // Check for optional default expression + let default = inner + .next() + .map(|expr_pair| self.parse_expr(expr_pair)) + .transpose()?; + + Ok(Field { + name, + ty: TypeId::default(), + annotation: Some(ty_ann), + is_const: false, + default, + }) + } + fn parse_block(&self, pair: Pair) -> CompileResult { // block = { "{" ~ (statement)* ~ "}" } let stmts = pair @@ -532,7 +645,7 @@ impl Parser { let mut expr = self.parse_expr(inner.next().unwrap())?; // Handle chained field access (.field1.field2.field3) - while let Some(field_pair) = inner.next() { + for field_pair in inner { if field_pair.as_rule() == Rule::postfix_field { let mut field_inner = field_pair.into_inner(); let field_name = field_inner diff --git a/kitlang/src/codegen/symbols.rs b/kitlang/src/codegen/symbols.rs index c78857c..383304d 100644 --- a/kitlang/src/codegen/symbols.rs +++ b/kitlang/src/codegen/symbols.rs @@ -1,7 +1,16 @@ -use super::type_ast::{Field, StructDefinition}; +use super::type_ast::{EnumDefinition, EnumVariant, Field, StructDefinition}; use super::types::TypeId; use std::collections::HashMap; +/// Stores information about an enum variant for lookup. +#[derive(Clone, Debug)] +pub struct EnumVariantInfo { + pub enum_name: String, + pub variant_name: String, + pub arg_types: Vec, + pub has_defaults: bool, +} + /// Symbol table for tracking variable and function types during inference. /// /// Currently uses a flat scope (no nesting). Variables and functions are tracked @@ -15,6 +24,12 @@ pub struct SymbolTable { /// Maps struct names to their definitions. structs: HashMap, + + /// Maps enum names to their definitions. + enums: HashMap, + + /// Maps qualified variant names ("EnumName.VariantName") to variant info. + enum_variants: HashMap, } impl Default for SymbolTable { @@ -29,6 +44,8 @@ impl SymbolTable { vars: HashMap::new(), functions: HashMap::new(), structs: HashMap::new(), + enums: HashMap::new(), + enum_variants: HashMap::new(), } } @@ -68,4 +85,65 @@ impl SymbolTable { .get(struct_name) .and_then(|s| s.fields.iter().find(|f| f.name == field_name)) } + + /// Define an enum type. + pub fn define_enum(&mut self, def: EnumDefinition) { + self.enums.insert(def.name.clone(), def); + } + + /// Look up an enum definition by name. + pub fn lookup_enum(&self, name: &str) -> Option<&EnumDefinition> { + self.enums.get(name) + } + + /// Define an enum variant constructor. + pub fn define_enum_variant(&mut self, variant: &EnumVariant) { + let qualified_name = format!("{}.{}", variant.parent, variant.name); + let has_defaults = variant.args.iter().any(|f| f.default.is_some()); + let arg_types: Vec = variant.args.iter().map(|f| f.ty).collect(); + + self.enum_variants.insert( + qualified_name, + EnumVariantInfo { + enum_name: variant.parent.clone(), + variant_name: variant.name.clone(), + arg_types, + has_defaults, + }, + ); + } + + /// Look up an enum variant by qualified name ("EnumName.VariantName"). + pub fn lookup_enum_variant(&self, qualified_name: &str) -> Option<&EnumVariantInfo> { + self.enum_variants.get(qualified_name) + } + + /// Look up an enum variant by simple name across all enums. + pub fn lookup_enum_variant_by_simple_name( + &self, + simple_name: &str, + ) -> Option<&EnumVariantInfo> { + self.enum_variants + .values() + .find(|v| v.variant_name == simple_name) + } + + /// Look up an enum variant by enum name and variant name. + pub fn lookup_variant(&self, enum_name: &str, variant_name: &str) -> Option<&EnumVariantInfo> { + let qualified_name = format!("{}.{}", enum_name, variant_name); + self.enum_variants.get(&qualified_name) + } + + /// Get all variant names for an enum. + pub fn get_enum_variants(&self, enum_name: &str) -> Vec<&EnumVariantInfo> { + self.enum_variants + .values() + .filter(|v| v.enum_name == enum_name) + .collect() + } + + /// Get all registered enums. + pub fn get_enums(&self) -> Vec<&EnumDefinition> { + self.enums.values().collect() + } } diff --git a/kitlang/src/codegen/type_ast.rs b/kitlang/src/codegen/type_ast.rs index 9ece785..ed0de06 100644 --- a/kitlang/src/codegen/type_ast.rs +++ b/kitlang/src/codegen/type_ast.rs @@ -23,3 +23,17 @@ pub struct FieldInit { pub name: String, pub value: Expr, } + +#[derive(Clone, Debug, PartialEq)] +pub struct EnumVariant { + pub name: String, + pub parent: String, + pub args: Vec, + pub default: Option, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct EnumDefinition { + pub name: String, + pub variants: Vec, +} diff --git a/kitlang/src/grammar/kit.pest b/kitlang/src/grammar/kit.pest index d696b5a..7dcaeaf 100644 --- a/kitlang/src/grammar/kit.pest +++ b/kitlang/src/grammar/kit.pest @@ -18,7 +18,7 @@ function_decl = { } params = { param ~ ("," ~ param)* } -param = { identifier ~ ":" ~ type_annotation } +param = { identifier ~ ":" ~ type_annotation ~ ( "=" ~ expr )? } type_annotation = { function_type | pointer_type | tuple_type | base_type } function_type = { "function" ~ "(" ~ (type_annotation ~ ("," ~ type_annotation)*)? ~ ")" ~ "->" ~ type_annotation }