Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions lib/lua.ex
Original file line number Diff line number Diff line change
Expand Up @@ -569,22 +569,20 @@ defmodule Lua do
@doc """
Calls a function in Lua's state

# TODO: Restore once string stdlib is implemented
#iex> {:ok, [ret], _lua} = Lua.call_function(Lua.new(), [:string, :lower], ["HELLO ROBERT"])
#iex> ret
#"hello robert"
iex> {:ok, [ret], _lua} = Lua.call_function(Lua.new(), [:string, :lower], ["HELLO ROBERT"])
iex> ret
"hello robert"

iex> lua = Lua.new()
iex> lua = Lua.set!(lua, [:double], fn [val] -> [val * 2] end)
iex> {:ok, [_ret], _lua} = Lua.call_function(lua, [:double], [5])

References to functions can also be passed

# TODO: Restore once string stdlib is implemented
#iex> {[ref], lua} = Lua.eval!("return string.lower", decode: false)
#iex> {:ok, [ret], _lua} = Lua.call_function(lua, ref, ["FUNCTION REF"])
#iex> ret
#"function ref"
iex> {[ref], lua} = Lua.eval!(Lua.new(), "return string.lower", decode: false)
iex> {:ok, [ret], _lua} = Lua.call_function(lua, ref, ["FUNCTION REF"])
iex> ret
"function ref"

iex> {[ref], lua} = Lua.eval!(Lua.new(), "return function(x) return x end", decode: false)
iex> {:ok, [ret], _lua} = Lua.call_function(lua, ref, [42])
Expand Down Expand Up @@ -660,12 +658,8 @@ defmodule Lua do
defmodule MyAPI do
use Lua.API, scope: "example"

# TODO: Restore once string stdlib is implemented
# deflua foo(value), state do
# Lua.call_function!(state, [:string, :lower], [value])
# end
deflua foo(value), state do
Lua.call_function!(state, [:my_func], [value])
Lua.call_function!(state, [:string, :lower], [value])
end
end
```
Expand Down
25 changes: 25 additions & 0 deletions lib/lua/compiler/codegen.ex
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,31 @@ defmodule Lua.Compiler.Codegen do
end
end

# LocalFunc: local function name(params) body end
defp gen_statement(%Statement.LocalFunc{name: name} = local_func, ctx) do
# Generate closure for the function
{closure_instructions, closure_reg, ctx} = gen_closure_from_node(local_func, ctx)

# Get the local variable's register from scope
dest_reg = ctx.scope.locals[name]

# Move closure to the local's register if needed
move_instructions =
if closure_reg == dest_reg do
[]
else
[Instruction.move(dest_reg, closure_reg)]
end

{closure_instructions ++ move_instructions, ctx}
end

# Do: do...end block
defp gen_statement(%Statement.Do{body: body}, ctx) do
# Simply generate code for the inner block
gen_block(body, ctx)
end

# Stub for other statements
defp gen_statement(_stmt, ctx), do: {[], ctx}

Expand Down
20 changes: 20 additions & 0 deletions lib/lua/compiler/scope.ex
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,26 @@ defmodule Lua.Compiler.Scope do
resolve_block(body, state)
end

defp resolve_statement(%Statement.LocalFunc{name: name, params: params, body: body} = local_func, state) do
# First, allocate a register for the local function name
reg = state.next_register
state = %{state | locals: Map.put(state.locals, name, reg)}
state = %{state | next_register: reg + 1}

# Update max_register in current function scope
func_scope = state.functions[state.current_function]
func_scope = %{func_scope | max_register: max(func_scope.max_register, state.next_register)}
state = %{state | functions: Map.put(state.functions, state.current_function, func_scope)}

# Then resolve the function body scope (like FuncDecl)
resolve_function_scope(local_func, params, body, state)
end

defp resolve_statement(%Statement.Do{body: body}, state) do
# Do blocks just resolve their inner body
resolve_block(body, state)
end

# For now, stub out other statement types - we'll implement them incrementally
defp resolve_statement(_stmt, state), do: state

Expand Down
161 changes: 161 additions & 0 deletions test/lua/compiler/integration_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1749,4 +1749,165 @@ defmodule Lua.Compiler.IntegrationTest do
assert Enum.sort(decoded) == [{"a", 1}, {"b", 2}]
end
end

describe "local function declarations" do
test "basic local function" do
code = """
local function add(a, b)
return a + b
end
return add(3, 4)
"""

assert {:ok, ast} = Parser.parse(code)
assert {:ok, proto} = Compiler.compile(ast)
assert {:ok, results, _state} = VM.execute(proto)

assert results == [7]
end

test "local function with closure" do
code = """
local function make_counter()
local count = 0
local function increment()
count = count + 1
return count
end
return increment
end
local counter = make_counter()
return counter(), counter(), counter()
"""

assert {:ok, ast} = Parser.parse(code)
assert {:ok, proto} = Compiler.compile(ast)
assert {:ok, results, _state} = VM.execute(proto)

assert results == [1, 2, 3]
end

@tag :skip
test "local function recursive (requires self-reference upvalue support)" do
code = """
local function factorial(n)
if n <= 1 then
return 1
else
return n * factorial(n - 1)
end
end
return factorial(5)
"""

assert {:ok, ast} = Parser.parse(code)
assert {:ok, proto} = Compiler.compile(ast)
assert {:ok, results, _state} = VM.execute(proto)

assert results == [120]
end

test "local function can be reassigned" do
code = """
local function f()
return 1
end
local x = f()
f = function() return 2 end
local y = f()
return x, y
"""

assert {:ok, ast} = Parser.parse(code)
assert {:ok, proto} = Compiler.compile(ast)
assert {:ok, results, _state} = VM.execute(proto)

assert results == [1, 2]
end
end

describe "do...end blocks" do
test "basic do block" do
code = """
local x = 1
do
local y = 2
x = x + y
end
return x
"""

assert {:ok, ast} = Parser.parse(code)
assert {:ok, proto} = Compiler.compile(ast)
assert {:ok, results, _state} = VM.execute(proto)

assert results == [3]
end

@tag :skip
test "do block creates new scope (requires proper scope cleanup)" do
code = """
local x = 1
do
local x = 2
end
return x
"""

assert {:ok, ast} = Parser.parse(code)
assert {:ok, proto} = Compiler.compile(ast)
assert {:ok, results, _state} = VM.execute(proto)

assert results == [1]
end

test "nested do blocks" do
code = """
local x = 1
do
x = 2
do
x = 3
end
end
return x
"""

assert {:ok, ast} = Parser.parse(code)
assert {:ok, proto} = Compiler.compile(ast)
assert {:ok, results, _state} = VM.execute(proto)

assert results == [3]
end

test "empty do block" do
code = """
local x = 1
do
end
return x
"""

assert {:ok, ast} = Parser.parse(code)
assert {:ok, proto} = Compiler.compile(ast)
assert {:ok, results, _state} = VM.execute(proto)

assert results == [1]
end

test "do block with return" do
code = """
do
return 42
end
return 99
"""

assert {:ok, ast} = Parser.parse(code)
assert {:ok, proto} = Compiler.compile(ast)
assert {:ok, results, _state} = VM.execute(proto)

assert results == [42]
end
end
end