ky/gen.lua
2024-06-11 17:46:13 +03:00

225 lines
7.4 KiB
Lua

--- Lua back end for the ky compiler.
-- Usage: `require(...)(ast, out)` walks `ast` (node tables tagged by a
-- `type` field) and writes equivalent Lua source through `out`.
--
-- NOTE(review): `out` is assumed to be a printf-style sink, i.e.
-- `out(fmt, ...)` formatted like string.format. Every call below passes
-- variable text through `%s`/`%q` so a literal `%` never reaches the format.
-- Integer division is emitted as `//`, so the output needs Lua 5.3+.

-- Node types that compile to Lua prefix expressions and can therefore be
-- called, indexed or dotted directly. Anything else (literals, binops,
-- function literals, ...) is parenthesized first, e.g. `(function() end)()`.
local PREFIX_EXPR = { var = true, dot = true, index = true, call = true }

-- Emitted Lua operators that associate to the right; an equal-level left
-- operand of one of these needs parentheses to keep its grouping.
local RIGHT_ASSOC = { ["^"] = true, [".."] = true }

--- Render a numeric literal so that it reads back as the same value.
-- Plain `%g` keeps only 6 significant digits (1234567 -> 1.23457e+06) and
-- prints inf/nan as bare identifiers. Negative values are parenthesized so
-- they can neither fuse with a preceding `-` into a `--` comment nor bind
-- looser than `^` (`-2^2` is -4 in Lua).
-- @tparam number v
-- @treturn string Lua source text for `v`
local function format_number(v)
  if v ~= v then
    return "(0/0)"
  elseif v == math.huge then
    return "(1/0)"
  elseif v == -math.huge then
    return "(-1/0)"
  end
  local text
  if math.type and math.type(v) == "integer" then
    text = string.format("%d", v)
  else
    -- Prefer the shorter rendering when it round-trips exactly.
    text = string.format("%.15g", v)
    if tonumber(text) ~= v then
      text = string.format("%.17g", v)
    end
  end
  if text:sub(1, 1) == "-" then
    text = "(" .. text .. ")"
  end
  return text
end

return function(ast, out)
  local compile

  -- Compile an expression used as a call target or dot/colon receiver,
  -- parenthesizing it unless it already is a Lua prefix expression.
  local function compile_prefix(node)
    if PREFIX_EXPR[node.type] then
      compile(node)
    else
      out("(")
      compile(node)
      out(")")
    end
  end

  -- Compile a binop operand, wrapped in parentheses when `wrap` is true.
  local function compile_operand(node, wrap)
    if wrap then
      out("(")
    end
    compile(node)
    if wrap then
      out(")")
    end
  end

  -- Emit a whole module: imports, local declarations, var initializers,
  -- class tables with their static members, and a final `return{...}`
  -- exporting every var and type.
  local function compile_toplevel(chu)
    for _, import in pairs(chu.imports) do
      out("local %s = require(%q)\n", import[#import], table.concat(import, "."))
    end
    -- Declare every module-level name (vars AND types) before any code runs,
    -- so functions defined below capture these locals as upvalues instead of
    -- resolving not-yet-declared names as (nil) globals.
    for name in pairs(chu.vars) do
      out("local %s\n", name)
    end
    for name in pairs(chu.types) do
      out("local %s\n", name)
    end
    -- NOTE(review): `pairs` order is unspecified, so an initializer that reads
    -- another module var at load time may run before that var is assigned.
    for name, expr in pairs(chu.vars) do
      if expr then
        out("%s=", name)
        compile(expr)
        out(";")
      end
    end
    for name, t in pairs(chu.types) do
      -- The class table is its instances' metatable (`__index` = itself) and
      -- also its own metatable, presumably so static metamethods (e.g. a
      -- constructor, which takes a leading `_` argument) apply to the class.
      -- NOTE(review): with `__index` pointing back at the same table, reading
      -- an absent field on the class or an instance raises an `__index` loop
      -- error instead of yielding nil — confirm this is intended.
      out("%s={};%s.__index=%s;setmetatable(%s, %s);", name, name, name, name, name)
      for _, m in pairs(t.members) do
        if m.type == "field" then
          -- Instance fields need no class-level code.
        elseif m.type == "static" then
          out("%s.%s=", name, m.name)
          compile(m.value)
          out(";")
        else
          error("Invalid AST")
        end
      end
    end
    out("return{")
    for name in pairs(chu.vars) do
      out("%s=%s,", name, name)
    end
    for name in pairs(chu.types) do
      out("%s=%s,", name, name)
    end
    out("}")
  end

  --- Emit Lua source for one AST node (statement or expression).
  -- Raises an error for node types this generator does not know.
  function compile(chu)
    local kind = chu.type
    if kind == "chunk" then
      if chu.toplevel then
        compile_toplevel(chu)
      else
        for _, stmt in ipairs(chu.stmts) do
          compile(stmt)
          if stmt.type == "call" then
            out(";") -- Safety delimiter.
          end
        end
      end
    elseif kind == "function" then
      local params = {}
      if chu.is_constructor then
        -- Constructors receive an extra leading argument that is discarded.
        params[1] = "_"
      end
      for _, argname in ipairs(chu.args) do
        params[#params + 1] = argname
      end
      -- Built as a list so a zero-argument constructor yields `function(_)`
      -- rather than the invalid `function(_,)`.
      out("function(%s)", table.concat(params, ","))
      -- SAFE_MODE is read from the global environment.
      if SAFE_MODE then
        -- Emit a runtime assertion for each argument whose declared type can
        -- be checked with `type()`.
        for i, argname in pairs(chu.args) do
          local et = chu.et.args[i]
          if et then
            local checks = {}
            if et.type == "integer" then
              checks[#checks + 1] = 'type(' .. argname .. ')=="number"'
              checks[#checks + 1] = argname .. '%1==0'
            elseif et.type == "number" or et.type == "string" or et.type == "boolean" then
              checks[#checks + 1] = 'type(' .. argname .. ')=="' .. et.type .. '"'
            elseif et.type == "list" then
              checks[#checks + 1] = 'type(' .. argname .. ')=="table"'
            end
            if #checks > 0 then
              out("assert(%s, %q);", table.concat(checks, " and "), string.format("Invalid argument %q", argname))
            end
          end
        end
      end
      if chu.is_constructor then
        out("local self=setmetatable({}, %s);", chu.et.ret.name)
      end
      compile(chu.chunk)
      if chu.is_constructor then
        out("return self;")
      end
      out(" end")
    elseif kind == "let" then
      local name = chu.var.name
      if chu.expr and chu.expr.type == "function" then
        -- Declare before assigning (like Lua's `local function`, and like
        -- top-level vars) so the function body can call itself recursively.
        out("local %s;%s=", name, name)
        compile(chu.expr)
      else
        out("local %s", name)
        if chu.expr then
          out("=")
          compile(chu.expr)
        end
      end
      out(";")
    elseif kind == "num" then
      out("%s", format_number(chu.value))
    elseif kind == "var" then
      out("%s", chu.which.name)
    elseif kind == "binop" then
      local a, b = chu.a, chu.b
      local op = chu.op
      if op == "**" then
        op = "^"
      elseif op == "/" and (not a.et or a.et.type == "integer")
        and (not b.et or b.et.type == "integer") then
        -- NOTE(review): an operand with no inferred type (`et`) counts as an
        -- integer here, so such divisions floor — confirm intended.
        op = "//"
      end
      -- A higher `level` binds more loosely, so a looser child always needs
      -- parentheses. At equal level the right operand needs them too
      -- (`a-(b-c)` must not become `a-b-c`), and so does the left operand of
      -- a right-associative Lua operator (`(a^b)^c`).
      compile_operand(a, a.type == "binop"
        and (a.level > chu.level or (a.level == chu.level and RIGHT_ASSOC[op] == true)))
      -- Surrounding spaces stop the operator fusing with neighbouring tokens
      -- (e.g. `1..x` lexes as a malformed numeral).
      out(" %s ", op)
      compile_operand(b, b.type == "binop" and b.level >= chu.level)
    elseif kind == "lengthof" then
      out("#(")
      compile(chu.sub)
      out(")")
    elseif kind == "dot" then
      -- `(x+y).f` / `("s"):m()` need the receiver parenthesized.
      compile_prefix(chu.a)
      out("%s%s", chu.colon and ":" or ".", chu.b)
    elseif kind == "index" then
      out("(")
      compile(chu.what)
      out(")[")
      compile(chu.idx)
      out("]")
    elseif kind == "call" then
      -- Non-prefix callees (function literals, binops, literals) must be
      -- parenthesized: `function() end()` is not valid Lua.
      compile_prefix(chu.what)
      out("(")
      for i, arg in ipairs(chu.args) do
        if i > 1 then
          out(",")
        end
        compile(arg)
      end
      out(")")
    elseif kind == "dict" then
      -- NOTE(review): `pairs` order is unspecified, so entry order in the
      -- output (and which of two equal keys wins) may vary between runs.
      out("{")
      for k, v in pairs(chu.mappings) do
        out("[")
        compile(k)
        out("]=")
        compile(v)
        out(",")
      end
      out("}")
    elseif kind == "list" then
      out("{")
      for i, v in ipairs(chu.values) do
        out("[%i]=", i)
        compile(v)
        out(",")
      end
      out("}")
    elseif kind == "string" then
      out("%q", chu.value)
    elseif kind == "null" then
      out("nil")
    elseif kind == "return" then
      -- Lua only accepts `return` as the last statement of a block; the
      -- `do ... end` wrapper keeps the output valid when statements follow
      -- (such as the `return self;` appended to constructors).
      out("do return ")
      if chu.expr then
        compile(chu.expr)
      end
      out(" end ")
    elseif kind == "if" then
      -- Branches live in the node's array part; `elsa` is the else chunk.
      if #chu == 0 then
        error("Invalid AST")
      end
      for i = 1, #chu do
        out(i == 1 and "if " or "elseif ")
        compile(chu[i].pred)
        -- Leading space: `x==1then` would lex as a malformed number.
        out(" then ")
        compile(chu[i].chu)
      end
      if chu.elsa then
        out("else ")
        compile(chu.elsa)
      end
      out("end;")
    elseif kind == "fori" then
      -- The upper bound is exclusive; parenthesize it so `-1` applies to the
      -- whole expression (e.g. `n<<1` must not become `n<<(1-1)`).
      out("for %s=", chu.varname)
      compile(chu.from)
      out(",(")
      compile(chu.to)
      out(")-1 do ")
      compile(chu.chu)
      out("end ")
    elseif kind == "loop" then
      out("while true do ")
      compile(chu.chu)
      out("end ")
    elseif kind == "break" then
      out("break ")
    elseif kind == "noop" then
      -- Nothing to emit.
    elseif kind == "exprstat" then
      compile(chu.expr)
      out(";")
    elseif kind == "assign" then
      compile(chu.dest)
      out("=")
      compile(chu.src)
      out(";")
    else
      error(string.format("Invalid AST node (type %q)", chu.type))
    end
  end

  return compile(ast)
end