Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ Arithmetic calculation created
0
Testing Result<int> toString:
Result<int> Success: 50
Result<int> Error: <Result[t20, string]>
Result<int> Error: -1
35 changes: 29 additions & 6 deletions compiler/internal/codegen/core_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,35 @@ func (g *LLVMGenerator) convertValueToStringByType(
return g.generateIntToString(arg)
}

// Check if it's a Result type
if strings.HasPrefix(theType, "Result<") {
// Check if it's a Result type (with either angle or square brackets)
if strings.HasPrefix(theType, "Result<") || strings.HasPrefix(theType, "Result[") {
// For Result types, check if it's a struct pointer
if ptrType, ok := arg.Type().(*types.PointerType); ok {
if structType, ok := ptrType.ElemType.(*types.StructType); ok && len(structType.Fields) == ResultFieldCount {
return g.convertResultToString(arg, structType)
}
}
// Also handle struct value directly (not pointer)
if structType, ok := arg.Type().(*types.StructType); ok && len(structType.Fields) == ResultFieldCount {
return g.convertResultToString(arg, structType)
}

// FALLBACK: For Results that are actually raw values (arithmetic operations),
// convert the underlying value directly
switch arg.Type() {
case types.I64:
// It's actually just an integer pretending to be a Result
// This happens with arithmetic operations that don't actually create Result structs
return g.generateIntToString(arg)
case types.I8Ptr:
// It's actually just a string
return arg, nil
case types.I1:
return g.generateBoolToString(arg)
default:
// Unknown Result representation - use "Error" as safe fallback
return g.createGlobalString("Error"), nil
}
}

// For other complex types, return a generic representation
Expand Down Expand Up @@ -225,10 +246,12 @@ func (g *LLVMGenerator) convertResultToString(
}

// Format as "Success(value)" using sprintf
sprintf := g.ensureSprintfDeclaration()
malloc := g.ensureMallocDeclaration()
successFormatStr := g.createGlobalString("Success(%s)")
bufferSize := constant.NewInt(types.I64, BufferSize64Bytes)
successBuffer := g.builder.NewCall(g.functions["malloc"], bufferSize)
g.builder.NewCall(g.functions["sprintf"], successBuffer, successFormatStr, valueStr)
successBuffer := g.builder.NewCall(malloc, bufferSize)
g.builder.NewCall(sprintf, successBuffer, successFormatStr, valueStr)
successStr = successBuffer

successBlock.NewBr(endBlock)
Expand All @@ -253,8 +276,8 @@ func (g *LLVMGenerator) convertResultToString(
if structType.Fields[0] == types.I8Ptr {
// String error message - format as Error(message)
errorFormatStr := g.createGlobalString("Error(%s)")
errorBuffer := g.builder.NewCall(g.functions["malloc"], bufferSize)
g.builder.NewCall(g.functions["sprintf"], errorBuffer, errorFormatStr, errorMsg)
errorBuffer := g.builder.NewCall(malloc, bufferSize)
g.builder.NewCall(sprintf, errorBuffer, errorFormatStr, errorMsg)
errorStr = errorBuffer
} else {
// Non-string error - just use "Error" for now
Expand Down
62 changes: 59 additions & 3 deletions compiler/internal/codegen/llvm.go
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,13 @@ func (g *LLVMGenerator) generateSuccessBlock(
) (value.Value, error) {
g.builder = blocks.Success

// Save the current type environment and create a new scope for this match arm
oldEnv := g.typeInferer.env
g.typeInferer.env = g.typeInferer.env.Clone()
defer func() {
g.typeInferer.env = oldEnv
}()

// Find the success arm and bind pattern variables
successArm := g.findSuccessArm(matchExpr)
if successArm != nil && len(successArm.Pattern.Fields) > 0 {
Expand Down Expand Up @@ -1982,11 +1989,36 @@ func (g *LLVMGenerator) generateErrorBlock(
) (value.Value, error) {
g.builder = blocks.Error

// Save the current type environment and create a new scope for this match arm
oldEnv := g.typeInferer.env
g.typeInferer.env = g.typeInferer.env.Clone()
defer func() {
g.typeInferer.env = oldEnv
}()

// Find the Error arm and bind pattern variables
errorArm := g.findErrorArm(matchExpr)
if errorArm != nil && len(errorArm.Pattern.Fields) > 0 {
// Bind the Result error message to the pattern variable
fieldName := errorArm.Pattern.Fields[0] // First field is the message

// Bind the error type to the pattern variable in the type environment
matchedExprType, err := g.typeInferer.InferType(matchExpr.Expression)
if err == nil {
resolvedType := g.typeInferer.ResolveType(matchedExprType)
if genericType, ok := resolvedType.(*GenericType); ok {
if genericType.name == TypeResult && len(genericType.typeArgs) >= 2 {
// Extract the error type (second type argument of Result<T, E>)
errorType := genericType.typeArgs[1]
g.typeInferer.env.Set(fieldName, errorType)
}
} else {
// If the matched expression is not a Result type, it gets auto-wrapped in Success
// The Error arm should never be reached, but bind String type for safety
g.typeInferer.env.Set(fieldName, &ConcreteType{name: TypeString})
}
}

// Create a unique global string for the error message
// Include function context to ensure uniqueness across monomorphized instances
funcContext := ""
Expand Down Expand Up @@ -2054,8 +2086,32 @@ func (g *LLVMGenerator) bindPatternVariableType(fieldName string, matchedExpr as
// Extract the success type (first type argument of Result<T, E>)
successType := genericType.typeArgs[0]
g.typeInferer.env.Set(fieldName, successType)
return
}
}

// Check for ConcreteType that represents a Result (e.g., from built-in functions)
if concreteType, ok := resolvedType.(*ConcreteType); ok {
// Check if this is a Result<T, E> type represented as a concrete type string
if len(concreteType.name) > 7 && concreteType.name[:6] == "Result" && concreteType.name[6] == '<' {
// Parse "Result<int, string>" to extract "int"
// Simple extraction: find first type arg between < and ,
startIdx := 7 // After "Result<"
endIdx := startIdx
for endIdx < len(concreteType.name) && concreteType.name[endIdx] != ',' {
endIdx++
}
if endIdx > startIdx {
successTypeName := concreteType.name[startIdx:endIdx]
g.typeInferer.env.Set(fieldName, &ConcreteType{name: successTypeName})
return
}
}
}

// If the matched expression is not a Result type, it gets auto-wrapped
// In this case, the success value type is the matched expression's type itself
g.typeInferer.env.Set(fieldName, resolvedType)
}
}

Expand Down Expand Up @@ -2143,9 +2199,9 @@ func (g *LLVMGenerator) createResultMatchPhiWithActualBlocks(

// BUGFIX: Check if both values are void (Unit) - can't create PHI with void values
if successValue != nil && errorValue != nil {
// Check if both are void types (nil represents void/Unit)
successIsVoid := (successValue == nil) || isVoidType(successValue.Type())
errorIsVoid := (errorValue == nil) || isVoidType(errorValue.Type())
// Check if both are void types
successIsVoid := isVoidType(successValue.Type())
errorIsVoid := isVoidType(errorValue.Type())

if successIsVoid && errorIsVoid {
// Both arms return Unit - return nil to represent void, don't create PHI
Expand Down
69 changes: 58 additions & 11 deletions compiler/internal/codegen/match_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package codegen

import (
"fmt"
"strings"
"strconv"

"github.com/christianfindlay/osprey/internal/ast"
)
Expand Down Expand Up @@ -167,19 +167,66 @@ func (g *LLVMGenerator) validateMatchArmWithTypeAndPosition(
func (g *LLVMGenerator) validateMatchPatternWithTypeAndPosition(
pattern ast.Pattern, discriminantType string, matchPos *ast.Position,
) error {
// Infer pattern type with position context
_, err := g.typeInferer.InferPattern(pattern)
if err != nil {
// Check if this is an unknown constructor error and enhance it with position info
if strings.Contains(err.Error(), "unknown constructor") {
// Extract the constructor name from the pattern
constructorName := pattern.Constructor
// Use the provided discriminant type instead of hardcoded "Color"
return WrapUnknownVariantWithPos(constructorName, discriminantType, matchPos)
// Wildcard patterns and variable patterns are always valid
if pattern.Constructor == "_" || pattern.Constructor == "" {
return nil
}

// Literal patterns (integers, strings, booleans) are always valid for their type
if isLiteralPattern(pattern.Constructor) {
return nil
}

// Special constructors that are always allowed (Result types, etc.)
if isSpecialConstructor(pattern.Constructor) {
return nil
}

// Check if this pattern matches a variant of the discriminant's union type
if typeDecl, exists := g.typeDeclarations[discriminantType]; exists {
// Check if the pattern constructor is a valid variant of this type
isValidVariant := false
for _, variant := range typeDecl.Variants {
if variant.Name == pattern.Constructor {
isValidVariant = true
break
}
}

return err
// If not a valid variant, return error
if !isValidVariant {
return WrapUnknownVariantWithPos(pattern.Constructor, discriminantType, matchPos)
}
}

// Pattern validation is complete - type inference and variable binding
// happen during the match expression type inference phase, not here
return nil
}

// isLiteralPattern checks if a pattern constructor is a literal value
func isLiteralPattern(constructor string) bool {
// Boolean literals
if constructor == "true" || constructor == "false" {
return true
}

// Integer literals (try parsing)
_, err := strconv.ParseInt(constructor, 10, 64)
if err == nil {
return true
}

// String literals (quoted)
if len(constructor) >= 2 && constructor[0] == '"' && constructor[len(constructor)-1] == '"' {
return true
}

return false
}

// isSpecialConstructor checks if a constructor is a special built-in constructor
func isSpecialConstructor(constructor string) bool {
// Result type constructors
return constructor == SuccessPattern || constructor == ErrorPattern
}
15 changes: 9 additions & 6 deletions compiler/tests/unit/codegen/union_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,25 @@ func TestMatchExhaustivenessShouldPass(t *testing.T) {
}
}

func TestUnknownVariantInMatchShouldPass(t *testing.T) {
// This currently passes but ideally should fail with better validation
func TestUnknownVariantInMatchShouldFail(t *testing.T) {
// This should fail with unknown variant error
source := `
type Color = Red | Green | Blue
let color = Red
let description = match color {
Red => "red"
Green => "green"
Green => "green"
Blue => "blue"
Purple => "invalid"
}
`

_, err := codegen.CompileToLLVM(source)
// Note: This currently passes but should eventually fail with proper validation
if err != nil {
t.Logf("Got expected error for unknown variant: %v", err)
if err == nil {
t.Errorf("Expected unknown variant 'Purple' to cause compilation failure")
}

if !strings.Contains(err.Error(), "Purple") || !strings.Contains(err.Error(), "not defined in type") {
t.Errorf("Expected 'unknown variant' error for Purple, got: %v", err)
}
}
4 changes: 2 additions & 2 deletions website/src/status.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Current version: **0.2.0-alpha** (released)
- Stream fusion optimization for zero-cost abstractions
- Pipe operator (`|>`) for elegant composition
- Function chaining with compile-time optimization
- **Union Types**: Algebraic data types with pattern matching and exhaustiveness checking
- **Any Type Handling**: Explicit `any` types with pattern matching requirement
- **Result Types**: Error handling without exceptions
- **Type Safety**: No implicit conversions, compile-time type checking
Expand Down Expand Up @@ -61,13 +62,12 @@ Current version: **0.2.0-alpha** (released)

### Type System Extensions
- **Record Types with Constraints**: `where` clause validation (partially implemented)
- **Union Types**: Complex algebraic data types with destructuring
- **Generic Types**: Type parameters and polymorphism
- **Module System**: Fiber-isolated modules with proper imports

### Advanced Language Features
- **Extern Declarations**: Full Rust/C interoperability (syntax ready)
- **Advanced Pattern Matching**: Constructor patterns, guards, exhaustiveness checking
- **Advanced Pattern Matching**: Constructor patterns with guards
- **Select Expressions**: Channel multiplexing for concurrent operations
- **Streaming Responses**: Large HTTP response streaming

Expand Down
Loading