Files
n26/ast.lua
2026-05-04 21:04:23 +03:00

238 lines
4.9 KiB
Lua

local ETypes = require"etype"
--- Depth-first walk over a subtree.
-- `pre_callback` (when given) runs on a node before any of its children are
-- visited; `post_callback` (when given) runs after all of them. Children are
-- taken from `node.children` in index order via ipairs, so the walk over a
-- node's children stops at the first nil entry.
-- @param node root of the subtree; must have a `children` table
-- @param pre_callback optional function(node)
-- @param post_callback optional function(node)
local function generic_visitor(node, pre_callback, post_callback)
  if pre_callback then
    pre_callback(node)
  end
  local kids = node.children
  for _, kid in ipairs(kids) do
    generic_visitor(kid, pre_callback, post_callback)
  end
  if post_callback then
    post_callback(node)
  end
end
--- Shared metatable for every AST node.
-- A node keeps its kind in `kind`, its ordered sub-nodes in `children`, and
-- names specific children through `_name = index` fields that read back as
-- `node.name` (see __index below).
local ASTNode = {}

-- Appends `child` to `self.children` and returns its index (the new length).
-- Hoisted so that looking up `node.add` does not allocate a closure each time.
local function add_child(self, child)
  table.insert(self.children, child)
  return #self.children
end

--- Lookup for keys that are not raw fields of the node.
-- Resolution order:
--   1. `add` and `generic_visitor` helpers;
--   2. `node.x` -> `node.children[node._x]` when a raw `_x` field exists;
--   3. anything defined on ASTNode itself, so methods declared with
--      `function ASTNode:m()` (e.g. deep_copy) are callable as `node:m()`.
-- Keys that are neither strings nor numbers cannot name a `_x` field and skip
-- step 2 instead of raising a concatenation error.
function ASTNode:__index(k)
  if k == "add" then
    return add_child
  elseif k == "generic_visitor" then
    return generic_visitor
  end
  local key_type = type(k)
  if key_type == "string" or key_type == "number" then
    local under = rawget(self, "_" .. k)
    if under then
      return self.children[under]
    end
  end
  -- Fall back to methods/values stored on the metatable itself.
  return rawget(ASTNode, k)
end
--- Recursively copy `item`.
-- Non-table values are returned unchanged. A table is copied key by key (keys
-- are copied as well) and the copy is given the original's metatable (the
-- metatable itself is shared, not copied). Each table reachable from `item`
-- is copied exactly once, so sharing is preserved in the copy and cyclic
-- structures terminate instead of recursing forever.
-- NOTE(review): every nested table is cloned, including a node's `etype` and
-- `vreg`; `cast` compares etypes with `==`, so if ETypes relies on identity
-- the copy's etype will not equal the original's - confirm this is intended.
-- Declared local: previously this leaked a global `deep_copy`.
-- @param item value to copy
-- @param seen optional original -> copy map (internal; callers omit it)
-- @return the copy
local function deep_copy(item, seen)
  if type(item) ~= "table" then
    return item
  end
  seen = seen or {}
  local existing = seen[item]
  if existing ~= nil then
    return existing
  end
  local ret = {}
  -- Register before recursing so cycles resolve to this same copy.
  seen[item] = ret
  for k, v in pairs(item) do
    ret[deep_copy(k, seen)] = deep_copy(v, seen)
  end
  return setmetatable(ret, getmetatable(item))
end
--- Return a recursive copy of this node (fields, children and any nested
-- tables), produced by the module's deep_copy helper.
ASTNode.deep_copy = function(self)
  local copy = deep_copy(self)
  return copy
end
-- Forward declaration: render_node and render_value are mutually recursive.
local render_node

-- Renders a field or child value. AST nodes recurse at `depth`; any other
-- value goes through tostring.
local function render_value(value, depth)
  if getmetatable(value) == ASTNode then
    return render_node(value, depth)
  end
  return tostring(value)
end

--- Render node `n` as an S-expression-like string.
-- Output is `(kind key=value ...` followed by one line per unnamed child and
-- a closing `)`. Fields other than `children` and `kind` are printed sorted
-- by key. A `_x` field holds a child index and is printed inline as
-- `x=<child>`. Children not referenced by any `_x` field are printed on
-- their own line, indented one space per nesting level. Depth is passed
-- explicitly, rather than kept in a module counter, so an error raised
-- mid-render cannot leave stale indentation behind.
-- @param n node to render
-- @param depth nesting level of `n` (0 for the top-level call)
function render_node(n, depth)
  local fields = {}
  local named = {}
  for key, value in pairs(n) do
    if key ~= "children" and key ~= "kind" then
      fields[#fields + 1] = {key, value}
      -- Only `_x` fields are child indices; a plain field such as
      -- `length = 1` must not hide child #1.
      if type(key) == "string" and key:sub(1, 1) == "_" then
        named[value] = true
      end
    end
  end
  table.sort(fields, function(a, b) return a[1] < b[1] end)

  local out = {"(", n.kind}
  -- ipairs (not pairs) so the sorted order is actually honored.
  for _, field in ipairs(fields) do
    local key, value = field[1], field[2]
    if key:sub(1, 1) == "_" then
      out[#out + 1] = " " .. key:sub(2) .. "=" .. render_value(n.children[value], depth)
    else
      out[#out + 1] = " " .. key .. "=" .. render_value(value, depth)
    end
  end
  for index, child in ipairs(n.children) do
    if not named[index] then
      out[#out + 1] = "\n" .. string.rep(" ", depth + 1) .. render_value(child, depth + 1)
    end
  end
  out[#out + 1] = ")"
  return table.concat(out)
end

--- String form of a node and its subtree (see render_node for the format).
function ASTNode:__tostring()
  return render_node(self, 0)
end
-- Last id handed out by new_vreg; ids start at 1 and increase by one per
-- allocation for as long as this module stays loaded.
local NEXT_VREG_ID = 0

--- Metatable for virtual registers (tables with name, etype and id fields).
local VReg = {}
VReg.__index = VReg

-- A virtual register prints as just its numeric id.
function VReg:__tostring()
  return tostring(self.id)
end

--- Allocate a new virtual register.
-- @param name required; asserted to be truthy
-- @param etype optional type; must satisfy ETypes.is when provided
-- @return VReg with fields `name`, `etype` and a fresh `id`
local function new_vreg(name, etype)
  assert(name)
  if etype then
    assert(ETypes.is(etype))
  end
  local id = NEXT_VREG_ID + 1
  NEXT_VREG_ID = id
  local reg = {name = name, etype = etype, id = id}
  return setmetatable(reg, VReg)
end
--- Create an empty AST node of the given kind (no fields, no children).
local function node(kind)
  local fresh = {children = {}}
  fresh.kind = kind
  return setmetatable(fresh, ASTNode)
end
--- True when `n` is a table whose metatable is ASTNode; false otherwise.
local function is(n)
  if type(n) ~= "table" then
    return false
  end
  return getmetatable(n) == ASTNode
end
--- Create an empty node of kind "root".
local function root()
  local top = node("root")
  return top
end
--- Declaration node binding `name` to `expr`.
-- `expr` is appended as a child and is readable as `node.expr`; it is not
-- validated here.
-- @param name declared name (stored as-is)
-- @param expr initializer
-- @param export export flag (stored as-is)
local function decl(name, expr, export)
  local d = node("decl")
  d.name, d.export = name, export
  d._expr = d:add(expr)
  return d
end
--- Binary expression node `a <op> b`.
-- @param etype optional result type; must satisfy ETypes.is when given
-- @param a left operand node (child `a`)
-- @param op operator (stored as-is)
-- @param b right operand node (child `b`)
local function binop(etype, a, op, b)
  if etype then
    assert(ETypes.is(etype))
  end
  assert(is(a) and is(b))
  local expr = node("expr-binop")
  expr.etype, expr.op = etype, op
  expr._a = expr:add(a)
  expr._b = expr:add(b)
  return expr
end
--- Unary expression node `<op> a`.
-- @param etype optional result type; must satisfy ETypes.is when given
-- @param op operator (stored as-is)
-- @param a operand node (child `a`)
local function unop(etype, op, a)
  if etype then
    assert(ETypes.is(etype))
  end
  assert(is(a))
  local expr = node("expr-unop")
  expr.etype, expr.op = etype, op
  expr._a = expr:add(a)
  return expr
end
--- Integer literal node.
-- @param etype optional type whose `kind` must be "scalar" or "pointer"
-- @param value literal value (stored as-is)
local function int(etype, value)
  if etype then
    local k = etype.kind
    assert(k == "scalar" or k == "pointer")
  end
  local lit = node("expr-int")
  lit.etype, lit.value = etype, value
  return lit
end
--- Placeholder expression node carrying only a type (not validated here).
local function unknown(etype)
  local placeholder = node("expr-unknown")
  placeholder.etype = etype
  return placeholder
end
--- Function expression node.
-- @param etype optional type; when given its `kind` must be "func"
local function func(etype)
  if etype then
    assert(etype.kind == "func")
  end
  local fn = node("expr-func")
  fn.etype = etype
  return fn
end
--- Array expression node.
-- @param etype required (asserted truthy)
-- @param length stored as-is
local function array(etype, length)
  assert(etype)
  local arr = node("expr-array")
  arr.etype, arr.length = etype, length
  return arr
end
--- Variable reference node for virtual register `vreg`.
-- Either `etype` must satisfy ETypes.is, or `etype` must be nil and
-- `vreg.etype` must satisfy ETypes.is; in the latter case the node takes
-- its type from the register.
local function var(etype, vreg)
  local ok = ETypes.is(etype)
  if not ok and etype == nil then
    ok = vreg.etype and ETypes.is(vreg.etype)
  end
  assert(ok)
  local ref = node("expr-var")
  if etype then
    ref.etype = etype
  else
    ref.etype = vreg.etype
  end
  ref.vreg = vreg
  return ref
end
--- String literal node; its type is always ETypes.string().
local function stringh(value)
  local str = node("expr-string")
  str.etype = ETypes.string()
  str.value = value
  return str
end
--- Assignment statement `dest = src`; both operands are asserted truthy and
-- become children `dest` and `src`, in that order.
local function assign(dest, src)
  assert(dest)
  assert(src)
  local stmt = node("stmt-assign")
  stmt._dest = stmt:add(dest)
  stmt._src = stmt:add(src)
  return stmt
end
--- Label statement; the label's name is stored in the `name` field.
local function label(label_name)
  local lbl = node("stmt-label")
  lbl.name = label_name
  return lbl
end
--- Jump statement to `label_name` (stored in `target`).
-- `condition` is not validated. It is appended as child `condition`; when
-- it is nil, nothing is appended and the stored index is the current child
-- count.
local function jump(condition, label_name)
  local j = node("stmt-jump")
  j.target = label_name
  j._condition = j:add(condition)
  return j
end
--- Return statement; `value` (not validated) becomes child `value`.
local function ret(value)
  local stmt = node("stmt-return")
  stmt._value = stmt:add(value)
  return stmt
end
--- Coerce expression `what` to type `etype`.
-- `what` is returned unchanged when `etype` is nil or equals (==)
-- `what.etype`. A string literal cast to a "scalar" type is folded into an
-- integer literal. Its bit width must equal 8 * string length (asserted).
-- The first character becomes the least-significant byte. Any other cast
-- wraps `what` in an "expr-cast" node (child `child`).
local function cast(what, etype)
  if etype == nil or what.etype == etype then
    return what
  end
  if what.kind == "expr-string" and etype.kind == "scalar" then
    local text = what.value
    assert(#text * 8 == etype.bits, "String length does not match integer size")
    -- Walk from the last byte down so earlier bytes end up less significant.
    local acc = 0
    for pos = #text, 1, -1 do
      acc = acc * 256 + text:byte(pos)
    end
    return int(etype, acc)
  end
  local wrapped = node("expr-cast")
  wrapped.etype = etype
  wrapped._child = wrapped:add(what)
  return wrapped
end
-- Public API of this module. Note `VReg` is the register constructor
-- (new_vreg) and `string` is the string-literal constructor (stringh).
return {
  ASTNode = ASTNode,
  VReg = new_vreg,
  node = node,
  root = root,
  decl = decl,
  int = int,
  binop = binop,
  unop = unop,
  func = func,
  array = array,
  var = var,
  assign = assign,
  label = label,
  jump = jump,
  ret = ret,
  unknown = unknown,
  string = stringh,
  cast = cast,
}