--- AST construction helpers: node constructors, virtual registers and an
-- S-expression style printer for debugging.
local ETypes = require"etype"

--- Depth-first walk over `node` and its descendants.
-- `pre_callback(n)` runs before a node's children are visited and
-- `post_callback(n)` after; either may be nil. Children are visited in
-- `ipairs` order.
local function generic_visitor(node, pre_callback, post_callback)
  if pre_callback then pre_callback(node) end
  for _, child in ipairs(node.children) do
    generic_visitor(child, pre_callback, post_callback)
  end
  if post_callback then post_callback(node) end
end

local ASTNode = {}

--- Append `child` to `self.children` and return `#self.children`.
-- NOTE(review): when `child` is nil nothing is stored and the current length
-- is returned (e.g. `ret(nil)` records `_value = 0`, so `n.value` is nil).
local function add_child(self, child)
  table.insert(self.children, child)
  return #self.children
end

--- Field lookup for AST nodes.
-- `add` and `generic_visitor` resolve to the helpers above. For any other
-- string key `k`, a raw field `_k` holding a child index makes `node.k`
-- return that child. Anything else falls back to functions stored on
-- `ASTNode` itself (such as `deep_copy`), which a function-valued `__index`
-- would otherwise hide from `node:method()` calls.
function ASTNode:__index(k)
  if k == "add" then
    return add_child
  elseif k == "generic_visitor" then
    return generic_visitor
  end
  -- Guard: concatenating a non-string, non-number key would raise.
  if type(k) == "string" then
    local under = rawget(self, "_" .. k)
    if under then
      return self.children[under]
    end
  end
  return rawget(ASTNode, k)
end

--- Recursively copy `item`, including table keys, preserving metatables.
-- Non-table values are returned as-is.
-- NOTE(review): shared sub-tables (e.g. an etype or vreg referenced by
-- several nodes) are duplicated per reference, and cyclic tables would
-- recurse forever.
local function deep_copy(item)
  if type(item) ~= "table" then
    return item
  end
  local ret = {}
  for k, v in pairs(item) do
    ret[deep_copy(k)] = deep_copy(v)
  end
  return setmetatable(ret, getmetatable(item))
end

--- Return a deep copy of this node (see `deep_copy`).
function ASTNode:deep_copy()
  return deep_copy(self)
end

-- Current nesting level while printing; used for indentation.
local DEPTH = 0

--- Render a node as `(kind field=value ...` followed by one indented line
-- per child not already shown inline, then `)`.
-- Fields are printed sorted by name. A field `_x` holds a child index and is
-- printed as `x=<child>`; that child is then skipped in the child list.
function ASTNode:__tostring()
  local specials = {}
  local inline_children = {}
  for k, v in pairs(self) do
    if k ~= "children" and k ~= "kind" then
      specials[#specials + 1] = {k, v}
      -- Only `_`-prefixed fields store child indices; other fields (such as
      -- a numeric `length` or `value`) must not hide children from output.
      if type(k) == "string" and k:sub(1, 1) == "_" then
        inline_children[v] = true
      end
    end
  end
  table.sort(specials, function(a, b) return a[1] < b[1] end)

  local parts = {"(", self.kind}
  -- ipairs: keep the sorted order (pairs order is unspecified).
  for _, sp in ipairs(specials) do
    local key, value = sp[1], sp[2]
    parts[#parts + 1] = " "
    if key:sub(1, 1) == "_" then
      parts[#parts + 1] = key:sub(2) .. "=" .. tostring(self.children[value])
    else
      parts[#parts + 1] = key .. "=" .. tostring(value)
    end
  end

  DEPTH = DEPTH + 1
  for i, child in ipairs(self.children) do
    if not inline_children[i] then
      parts[#parts + 1] = "\n"
      parts[#parts + 1] = string.rep(" ", DEPTH)
      parts[#parts + 1] = tostring(child)
    end
  end
  DEPTH = DEPTH - 1

  parts[#parts + 1] = ")"
  return table.concat(parts)
end

-- Monotonic counter; each new VReg gets the next id.
local NEXT_VREG_ID = 0
local VReg = {}
VReg.__index = VReg

--- A VReg prints as its numeric id.
function VReg:__tostring()
  return tostring(self.id)
end

--- Create a virtual register with a fresh id.
-- @param name required name
-- @param etype optional; must satisfy `ETypes.is` when given
local function new_vreg(name, etype)
  assert(name)
  assert(not etype or ETypes.is(etype))
  NEXT_VREG_ID = NEXT_VREG_ID + 1
  return setmetatable({name = name, etype = etype, id = NEXT_VREG_ID}, VReg)
end

--- Create an empty node of the given kind.
local function node(kind)
  return setmetatable({kind = kind, children = {}}, ASTNode)
end

--- True when `n` is a table whose metatable is ASTNode.
local function is(n)
  return type(n) == "table" and getmetatable(n) == ASTNode
end

--- Create the root node.
local function root()
  return node("root")
end

--- Declaration node; `expr` becomes child `expr`.
local function decl(name, expr, export)
  local n = node("decl")
  n.name = name
  n._expr = n:add(expr)
  n.export = export
  return n
end

--- Binary expression `a op b`; both operands must be AST nodes.
local function binop(etype, a, op, b)
  assert(not etype or ETypes.is(etype))
  assert(is(a) and is(b))
  local n = node("expr-binop")
  n.etype = etype
  n.op = op
  n._a = n:add(a)
  n._b = n:add(b)
  return n
end

--- Unary expression `op a`; the operand must be an AST node.
local function unop(etype, op, a)
  assert(not etype or ETypes.is(etype))
  assert(is(a))
  local n = node("expr-unop")
  n.etype = etype
  n.op = op
  n._a = n:add(a)
  return n
end

--- Integer literal; `etype`, when given, must be of kind scalar or pointer.
local function int(etype, value)
  assert(not etype or etype.kind == "scalar" or etype.kind == "pointer")
  local n = node("expr-int")
  n.etype = etype
  n.value = value
  return n
end

--- Expression of unknown value with the given etype.
local function unknown(etype)
  local n = node("expr-unknown")
  n.etype = etype
  return n
end

--- Function expression; `etype`, when given, must be of kind func.
local function func(etype)
  assert(not etype or etype.kind == "func")
  local n = node("expr-func")
  n.etype = etype
  return n
end

--- Array expression; `etype` is required.
local function array(etype, length)
  assert(etype)
  local n = node("expr-array")
  n.etype = etype
  n.length = length
  return n
end

--- Variable reference to `vreg`. When `etype` is nil, the vreg's own etype
-- is used and must be valid.
local function var(etype, vreg)
  assert(ETypes.is(etype)
    or (etype == nil and vreg.etype and ETypes.is(vreg.etype)))
  local n = node("expr-var")
  n.etype = etype or vreg.etype
  n.vreg = vreg
  return n
end

--- String literal with `ETypes.string()` as its etype.
local function stringh(value)
  local n = node("expr-string")
  n.value = value
  n.etype = ETypes.string()
  return n
end

--- Assignment statement `dest = src`; both are required.
local function assign(dest, src)
  assert(dest and src)
  local n = node("stmt-assign")
  n._dest = n:add(dest)
  n._src = n:add(src)
  return n
end

--- Label statement.
local function label(label_name)
  local n = node("stmt-label")
  n.name = label_name
  return n
end

--- Jump to `label_name`; `condition` is stored as child `condition`
-- (may be nil, in which case `n.condition` is nil).
local function jump(condition, label_name)
  local n = node("stmt-jump")
  n._condition = n:add(condition)
  n.target = label_name
  return n
end

--- Return statement; `value` may be nil.
local function ret(value)
  local n = node("stmt-return")
  n._value = n:add(value)
  return n
end

--- Convert `what` to `etype`.
-- Returns `what` unchanged when `etype` is nil or identical (by reference)
-- to `what.etype`. A string literal cast to a scalar whose bit width equals
-- 8 * string length is folded into an integer literal, with the first
-- character as the least significant byte. Otherwise wraps `what` in an
-- `expr-cast` node.
-- NOTE(review): for 8-byte strings the folded value may exceed the exact
-- integer range of the host Lua number type — confirm the target runtime.
local function cast(what, etype)
  if etype == nil then return what end
  if what.etype == etype then return what end
  if what.kind == "expr-string" and etype.kind == "scalar" then
    assert(#what.value * 8 == etype.bits, "String length does not match integer size")
    local val = 0
    for j = #what.value, 1, -1 do
      val = val * 256 + what.value:sub(j, j):byte()
    end
    return int(etype, val)
  end
  local n = node("expr-cast")
  n._child = n:add(what)
  n.etype = etype
  return n
end

return {
  ASTNode = ASTNode,
  VReg = new_vreg,
  node = node,
  root = root,
  decl = decl,
  int = int,
  binop = binop,
  unop = unop,
  func = func,
  array = array,
  var = var,
  assign = assign,
  label = label,
  jump = jump,
  ret = ret,
  unknown = unknown,
  string = stringh,
  cast = cast,
  -- Additive exports: node type check and the (formerly global) copier.
  is = is,
  deep_copy = deep_copy,
}