ky/des.lua
2024-06-11 17:46:13 +03:00

82 lines
2.6 KiB
Lua

local unique_name = require"bon".unique_name
--- Rewrites a `<list>.pop()` call statement into primitive statements.
-- The receiver expression is bound to a fresh temporary, and `null` is
-- then stored at index `lengthof(temp)`:
--
--   let temp = <receiver>
--   temp[lengthof(temp)] = null
--
-- NOTE(review): this presumably drops the last element, which assumes the
-- target language indexes lists from 1 — confirm.
-- No value is produced for the call, so this is only applied to `pop()`
-- used as a statement (see the statement loop in `pass_destandardify`).
-- @param stmt a `call` node whose `what` is a `dot` node on a list value
-- @return array of replacement statements, in execution order
local function lower_list_pop(stmt)
	assert(stmt.type == 'call')
	local receiver = stmt.what.a
	local temp = {et = receiver.et, name = unique_name()}
	-- A fresh `var` node is built for each use of the temporary.
	local function temp_ref()
		return {type = 'var', which = temp, et = temp.et}
	end
	return {
		{type = 'let', var = temp, expr = receiver},
		{
			type = 'assign',
			dest = {
				type = 'index',
				what = temp_ref(),
				idx = {type = 'lengthof', sub = temp_ref(), et = {type = 'integer'}},
			},
			src = {type = 'null', et = {type = 'null'}},
		},
	}
end

-- Descriptor for the `pop` builtin. It carries no per-call state, so a
-- single shared instance is returned for every match.
local LIST_POP = {change_stmt = lower_list_pop}

--- Recognizes a callee expression that names a builtin method to lower.
-- @param c callee expression (the `what` field of a `call` node)
-- @return a descriptor `{change_stmt = fn}` when `c` is `<list>.pop`,
--   otherwise `false`
local function is_builtin_method(c)
	if c.type ~= 'dot' then
		return false
	end
	local et = c.a.et
	if et and et.type == 'list' and c.b == 'pop' then
		return LIST_POP
	end
	return false
end

--- Destandardify pass: lowers builtin method calls in place.
-- Walks the AST rooted at `chu`. Every statement-level call recognized by
-- `is_builtin_method` (currently `<list>.pop()`) is replaced by the
-- statements its `change_stmt` returns.
-- Recursion descends into:
--   * chunks, including the values of `static` type members
--   * function bodies (`chunk`)
--   * every `if` branch body and the `elsa` body
--   * `fori` and `loop` bodies (`chu`)
-- Other node types are left untouched, and expressions are not traversed.
-- The AST is mutated in place; nothing is returned.
-- @param chu AST node
-- @raise "Invalid AST" when a type member is neither `field` nor `static`
local function pass_destandardify(chu)
	local kind = chu.type
	if kind == 'chunk' then
		for _, t in pairs(chu.types) do
			for _, m in pairs(t.members) do
				if m.type == 'static' then
					pass_destandardify(m.value)
				elseif m.type ~= 'field' then
					-- Fields need no processing; anything else is malformed.
					error("Invalid AST")
				end
			end
		end
		local stmts = chu.stmts
		local i = 1
		while i <= #stmts do
			local stmt = stmts[i]
			local builtin = stmt.type == 'call' and is_builtin_method(stmt.what)
			if builtin then
				-- Splice the replacement statements in place of the call.
				-- `i` is not advanced, so the inserted statements are
				-- visited next by this same loop.
				table.remove(stmts, i)
				for offset, new_stmt in ipairs(builtin.change_stmt(stmt)) do
					table.insert(stmts, i + offset - 1, new_stmt)
				end
			else
				pass_destandardify(stmt)
				i = i + 1
			end
		end
	elseif kind == 'function' then
		pass_destandardify(chu.chunk)
	elseif kind == 'if' then
		-- Branches live in the array part of the node; each has a `chu` body.
		for i = 1, #chu do
			pass_destandardify(chu[i].chu)
		end
		if chu.elsa then
			pass_destandardify(chu.elsa)
		end
	elseif kind == 'fori' or kind == 'loop' then
		pass_destandardify(chu.chu)
	end
end
-- The module's value is the pass function itself.
return pass_destandardify