diff --git a/lib/lua/compiler/codegen.ex b/lib/lua/compiler/codegen.ex index f31c24e..d2ca3fb 100644 --- a/lib/lua/compiler/codegen.ex +++ b/lib/lua/compiler/codegen.ex @@ -111,23 +111,64 @@ defmodule Lua.Compiler.Codegen do {value_instructions, result_reg, ctx} = gen_expr(value, ctx) {value_instructions ++ [Instruction.return_instr(result_reg, 1)], ctx} - multiple -> - base_reg = ctx.next_reg - - {all_instructions, ctx} = - multiple - |> Enum.with_index() - |> Enum.reduce({[], ctx}, fn {value, i}, {instructions, ctx} -> - target_reg = base_reg + i - {value_instructions, value_reg, ctx} = gen_expr(value, ctx) - - move = - if value_reg == target_reg, do: [], else: [Instruction.move(target_reg, value_reg)] - - {instructions ++ value_instructions ++ move, ctx} - end) + [_, _ | _] = multiple -> + # Check if last value is vararg - needs special handling + {init_values, last_value} = Enum.split(multiple, -1) + [last] = last_value + + case last do + %Expr.Vararg{} when init_values != [] -> + # return a, b, ... - load a,b then all varargs + base_reg = ctx.next_reg + + {init_instructions, ctx} = + init_values + |> Enum.with_index() + |> Enum.reduce({[], ctx}, fn {value, i}, {instructions, ctx} -> + target_reg = base_reg + i + {value_instructions, value_reg, ctx} = gen_expr(value, ctx) + + move = + if value_reg == target_reg do + [] + else + [Instruction.move(target_reg, value_reg)] + end + + {instructions ++ value_instructions ++ move, ctx} + end) + + # Load all varargs starting after the init values + vararg_base = base_reg + length(init_values) + vararg_instruction = Instruction.vararg(vararg_base, 0) + + # Return with -1 to indicate variable number of results + {init_instructions ++ [vararg_instruction, Instruction.return_instr(base_reg, -1)], + ctx} - {all_instructions ++ [Instruction.return_instr(base_reg, length(multiple))], ctx} + _ -> + # Normal multi-value return + base_reg = ctx.next_reg + + {all_instructions, ctx} = + multiple + |> Enum.with_index() + |> Enum.reduce({[], ctx}, fn {value, i}, {instructions, ctx} -> + target_reg = base_reg + i + {value_instructions, value_reg, ctx} = gen_expr(value, ctx) + + move = + if value_reg == target_reg do + [] + else + [Instruction.move(target_reg, value_reg)] + end + + {instructions ++ value_instructions ++ move, ctx} + end) + + {all_instructions ++ [Instruction.return_instr(base_reg, length(multiple))], ctx} + end end end @@ -676,40 +717,92 @@ defmodule Lua.Compiler.Codegen do ctx = %{ctx | next_reg: base_reg + 1} - # Generate code for arguments into temp registers above the arg window. - # We skip over the arg slots (base+1..base+arg_count) so temps don't clobber them. - arg_count = length(args) - ctx = %{ctx | next_reg: base_reg + 1 + arg_count} - - {arg_instructions, arg_regs, ctx} = - Enum.reduce(args, {[], [], ctx}, fn arg, {instructions, regs, ctx} -> - {arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx) - {instructions ++ arg_instructions, regs ++ [arg_reg], ctx} - end) - - # Move each arg result to its expected position (base+1+i) - move_instructions = - arg_regs - |> Enum.with_index() - |> Enum.flat_map(fn {arg_reg, i} -> - expected_reg = base_reg + 1 + i + # Check if last arg is vararg - needs special handling + {has_vararg_last, init_args} = + if length(args) > 0 do + [last | _] = Enum.reverse(args) - if arg_reg == expected_reg do - [] - else - [Instruction.move(expected_reg, arg_reg)] + case last do + %Expr.Vararg{} -> {true, Enum.slice(args, 0..-2//1)} + _ -> {false, args} end - end) - - # Generate call instruction (single return value for now) - call_instruction = Instruction.call(base_reg, arg_count, 1) + else + {false, []} + end - # Result will be in base_reg - {function_instructions ++ - move_function ++ - arg_instructions ++ - move_instructions ++ - [call_instruction], base_reg, ctx} + if has_vararg_last do + # f(a, b, ...) - load a, b then all varargs + arg_count = length(init_args) + ctx = %{ctx | next_reg: base_reg + 1 + arg_count} + + {arg_instructions, arg_regs, ctx} = + Enum.reduce(init_args, {[], [], ctx}, fn arg, {instructions, regs, ctx} -> + {arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx) + {instructions ++ arg_instructions, regs ++ [arg_reg], ctx} + end) + + # Move each arg result to its expected position (base+1+i) + move_instructions = + arg_regs + |> Enum.with_index() + |> Enum.flat_map(fn {arg_reg, i} -> + expected_reg = base_reg + 1 + i + + if arg_reg == expected_reg do + [] + else + [Instruction.move(expected_reg, arg_reg)] + end + end) + + # Load all varargs starting after init args + vararg_base = base_reg + 1 + arg_count + vararg_instruction = Instruction.vararg(vararg_base, 0) + + # Call with -(init_args+1) to encode both varargs and fixed arg count + # Negative values encode: -1 means 0 fixed + varargs, -2 means 1 fixed + varargs, etc. + call_instruction = Instruction.call(base_reg, -(arg_count + 1), 1) + + {function_instructions ++ + move_function ++ + arg_instructions ++ + move_instructions ++ + [vararg_instruction, call_instruction], base_reg, ctx} + else + # Normal function call without varargs + arg_count = length(args) + ctx = %{ctx | next_reg: base_reg + 1 + arg_count} + + {arg_instructions, arg_regs, ctx} = + Enum.reduce(args, {[], [], ctx}, fn arg, {instructions, regs, ctx} -> + {arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx) + {instructions ++ arg_instructions, regs ++ [arg_reg], ctx} + end) + + # Move each arg result to its expected position (base+1+i) + move_instructions = + arg_regs + |> Enum.with_index() + |> Enum.flat_map(fn {arg_reg, i} -> + expected_reg = base_reg + 1 + i + + if arg_reg == expected_reg do + [] + else + [Instruction.move(expected_reg, arg_reg)] + end + end) + + # Generate call instruction (single return value for now) + call_instruction = Instruction.call(base_reg, arg_count, 1) + + # Result will be in base_reg + {function_instructions ++ + move_function ++ + arg_instructions ++ + move_instructions ++ + [call_instruction], base_reg, ctx} + end end defp gen_expr(%Expr.Table{fields: fields}, ctx) do @@ -730,25 +823,80 @@ defmodule Lua.Compiler.Codegen do if list_fields == [] do {[], ctx} else - # Reserve contiguous slots for the list values - start_reg = ctx.next_reg - ctx = %{ctx | next_reg: start_reg + array_hint} - - {value_instructions, ctx} = - list_fields - |> Enum.with_index() - |> Enum.reduce({[], ctx}, fn {val_expr, i}, {instructions, ctx} -> - target_reg = start_reg + i - {value_instructions, val_reg, ctx} = gen_expr(val_expr, ctx) - - move = - if val_reg == target_reg, do: [], else: [Instruction.move(target_reg, val_reg)] + # Check if last field is vararg + {init_fields, last_field} = + if length(list_fields) > 0 do + Enum.split(list_fields, -1) + else + {[], []} + end + + [last | _] = last_field + + case last do + %Expr.Vararg{} when init_fields != [] -> + # Table with {a, b, ...} + # Reserve contiguous slots for the init values + start_reg = ctx.next_reg + ctx = %{ctx | next_reg: start_reg + length(init_fields)} + + {init_instructions, ctx} = + init_fields + |> Enum.with_index() + |> Enum.reduce({[], ctx}, fn {val_expr, i}, {instructions, ctx} -> + target_reg = start_reg + i + {value_instructions, val_reg, ctx} = gen_expr(val_expr, ctx) + + move = + if val_reg == target_reg do + [] + else + [Instruction.move(target_reg, val_reg)] + end + + {instructions ++ value_instructions ++ move, ctx} + end) + + # Load all varargs starting after init values + vararg_base = start_reg + length(init_fields) + vararg_instruction = Instruction.vararg(vararg_base, 0) + + # set_list with count 0 means variable number of values + set_list_instruction = Instruction.set_list(dest, start_reg, 0, 0) + {init_instructions ++ [vararg_instruction, set_list_instruction], ctx} + + %Expr.Vararg{} -> + # Table with just {...} + start_reg = ctx.next_reg + vararg_instruction = Instruction.vararg(start_reg, 0) + set_list_instruction = Instruction.set_list(dest, start_reg, 0, 0) + {[vararg_instruction, set_list_instruction], ctx} - {instructions ++ value_instructions ++ move, ctx} - end) - - set_list_instruction = Instruction.set_list(dest, start_reg, array_hint, 0) - {value_instructions ++ [set_list_instruction], ctx} + _ -> + # Normal list fields (no vararg) + start_reg = ctx.next_reg + ctx = %{ctx | next_reg: start_reg + array_hint} + + {value_instructions, ctx} = + list_fields + |> Enum.with_index() + |> Enum.reduce({[], ctx}, fn {val_expr, i}, {instructions, ctx} -> + target_reg = start_reg + i + {value_instructions, val_reg, ctx} = gen_expr(val_expr, ctx) + + move = + if val_reg == target_reg do + [] + else + [Instruction.move(target_reg, val_reg)] + end + + {instructions ++ value_instructions ++ move, ctx} + end) + + set_list_instruction = Instruction.set_list(dest, start_reg, array_hint, 0) + {value_instructions ++ [set_list_instruction], ctx} + end end # Compile record fields diff --git a/lib/lua/vm/executor.ex b/lib/lua/vm/executor.ex index 59d0de1..c3fd98f 100644 --- a/lib/lua/vm/executor.ex +++ b/lib/lua/vm/executor.ex @@ -402,11 +402,28 @@ defmodule Lua.VM.Executor do func_value = elem(regs, base) # Collect arguments from registers base+1..base+arg_count + # arg_count < 0 encodes fixed args + varargs: + # -1 means 0 fixed + varargs, -2 means 1 fixed + varargs, etc. args = - if arg_count > 0 do - for i <- 1..arg_count, do: elem(regs, base + i) - else - [] + cond do + arg_count > 0 -> + for i <- 1..arg_count, do: elem(regs, base + i) + + arg_count < 0 -> + # Collect fixed args + all varargs + # Decode: -1 => 0 fixed, -2 => 1 fixed, -3 => 2 fixed, etc. + fixed_arg_count = -(arg_count + 1) + varargs = Map.get(proto, :varargs, []) + total_args = fixed_arg_count + length(varargs) + + if total_args > 0 do + for i <- 1..total_args, do: elem(regs, base + i) + else + [] + end + + true -> + [] end {results, state} = @@ -511,13 +528,22 @@ defmodule Lua.VM.Executor do end # vararg - load vararg values into registers + # count == 0 means load all varargs, count > 0 means load exactly count values defp do_execute([{:vararg, base, count} | rest], regs, upvalues, proto, state) do varargs = Map.get(proto, :varargs, []) regs = - Enum.reduce(0..(count - 1), regs, fn i, regs -> - put_elem(regs, base + i, Enum.at(varargs, i)) - end) + if count == 0 do + # Load all varargs + Enum.reduce(Enum.with_index(varargs), regs, fn {val, i}, regs -> + put_elem(regs, base + i, val) + end) + else + # Load exactly count values + Enum.reduce(0..(count - 1), regs, fn i, regs -> + put_elem(regs, base + i, Enum.at(varargs, i)) + end) + end do_execute(rest, regs, upvalues, proto, state) end @@ -529,12 +555,30 @@ defmodule Lua.VM.Executor do end # return - defp do_execute([{:return, base, count} | _rest], regs, _upvalues, _proto, state) do + # count == -1 means return from base including all varargs + # count == 0 means return nil + # count > 0 means return exactly count values + defp do_execute([{:return, base, count} | _rest], regs, _upvalues, proto, state) do results = - if count == 0 do - [nil] - else - for i <- 0..(count - 1), do: elem(regs, base + i) + cond do + count == 0 -> + [nil] + + count == -1 -> + # Return values from base including varargs + # We need to collect values until we've covered the vararg range + varargs = Map.get(proto, :varargs, []) + tuple_size = tuple_size(regs) + max_index = min(tuple_size - 1, base + length(varargs) + proto.param_count - 1) + + if max_index < base do + [] + else + for i <- base..max_index, do: elem(regs, i) + end + + count > 0 -> + for i <- 0..(count - 1), do: elem(regs, base + i) end {results, regs, state} @@ -966,6 +1010,7 @@ defmodule Lua.VM.Executor do end # set_list — bulk store: table[offset+i] = R[start+i-1] for i in 1..count + # count == 0 means store all values from start until nil or end of tuple defp do_execute( [{:set_list, table_reg, start, count, offset} | rest], regs, @@ -978,10 +1023,41 @@ defmodule Lua.VM.Executor do state = State.update_table(state, {:tref, id}, fn table -> new_data = - Enum.reduce(1..count, table.data, fn i, data -> - value = elem(regs, start + i - 1) - Map.put(data, offset + i, value) - end) + if count == 0 do + # Variable number of values - collect from start register onwards + # This happens with varargs in table constructors like {a, b, ...} + # The previous vararg instruction loaded all varargs into registers, + # so we need to collect values until we've collected all of them + + # Count how many values to collect by checking registers + tuple_size = tuple_size(regs) + + # Collect values from start until we reach a nil or end of data + # We know varargs were just loaded, so collect until we see + # consecutive nils or reach tuple end + values_to_collect = + start..(tuple_size - 1) + |> Enum.take_while(fn reg_idx -> + reg_idx < tuple_size && elem(regs, reg_idx) != nil + end) + |> length() + + # Now collect those values + if values_to_collect > 0 do + Enum.reduce(0..(values_to_collect - 1), table.data, fn i, data -> + value = elem(regs, start + i) + Map.put(data, offset + i + 1, value) + end) + else + table.data + end + else + # Fixed number of values + Enum.reduce(1..count, table.data, fn i, data -> + value = elem(regs, start + i - 1) + Map.put(data, offset + i, value) + end) + end %{table | data: new_data} end) diff --git a/test/lua_test.exs b/test/lua_test.exs index cfdd52d..3dd5225 100644 --- a/test/lua_test.exs +++ b/test/lua_test.exs @@ -1461,6 +1461,94 @@ defmodule LuaTest do end end + describe "varargs" do + setup do + %{lua: Lua.new(sandboxed: [])} + end + + test "simple varargs function", %{lua: lua} do + code = """ + function f(...) + return ... + end + return f(1, 2, 3) + """ + + assert {[1, 2, 3], _} = Lua.eval!(lua, code) + end + + test "varargs with regular parameters", %{lua: lua} do + code = """ + function f(a, b, ...) + return a, b, ... + end + return f(1, 2, 3, 4, 5) + """ + + assert {[1, 2, 3, 4, 5], _} = Lua.eval!(lua, code) + end + + test "varargs in table constructor", %{lua: lua} do + code = """ + function f(...) + return {...} + end + t = f(1, 2, 3) + return t[1], t[2], t[3] + """ + + assert {[1, 2, 3], _} = Lua.eval!(lua, code) + end + + test "mixed values and varargs in table", %{lua: lua} do + code = """ + function f(...) + local t = {10, 20, ...} + return t[1], t[2], t[3], t[4] + end + return f(30, 40) + """ + + assert {[10, 20, 30, 40], _} = Lua.eval!(lua, code) + end + + test "varargs with select", %{lua: lua} do + code = """ + function f(...) + return select('#', ...), select(2, ...) + end + return f(10, 20, 30) + """ + + assert {[3, 20], _} = Lua.eval!(lua, code) + end + + test "varargs in function call", %{lua: lua} do + code = """ + function g(a, b, c) + return a + b + c + end + function f(...) + return g(...) + end + return f(1, 2, 3) + """ + + assert {[6], _} = Lua.eval!(lua, code) + end + + test "empty varargs", %{lua: lua} do + code = """ + function f(...) + return select('#', ...) + end + return f() + """ + + assert {[0], _} = Lua.eval!(lua, code) + end + end + defp test_file(name) do Path.join(["test", "fixtures", name]) end