From 037d18e9fde8d4f93c1e675bdfdcf43d313b31a6 Mon Sep 17 00:00:00 2001 From: mid <> Date: Mon, 4 May 2026 21:04:23 +0300 Subject: [PATCH] Initial commit --- ast.lua | 237 +++++++++++++++++++++++ cg.lua | 349 +++++++++++++++++++++++++++++++++ etype.lua | 80 ++++++++ lexer.lua | 117 +++++++++++ logger.lua | 17 ++ main.lua | 21 ++ parser.lua | 555 +++++++++++++++++++++++++++++++++++++++++++++++++++++ set.lua | 41 ++++ target.lua | 67 +++++++ 9 files changed, 1484 insertions(+) create mode 100644 ast.lua create mode 100644 cg.lua create mode 100644 etype.lua create mode 100644 lexer.lua create mode 100644 logger.lua create mode 100644 main.lua create mode 100644 parser.lua create mode 100644 set.lua create mode 100644 target.lua diff --git a/ast.lua b/ast.lua new file mode 100644 index 0000000..cdfc4db --- /dev/null +++ b/ast.lua @@ -0,0 +1,237 @@ +local ETypes = require"etype" + +local function generic_visitor(node, pre_callback, post_callback) + if pre_callback then pre_callback(node) end + + for k, v in ipairs(node.children) do + generic_visitor(v, pre_callback, post_callback) + end + + if post_callback then post_callback(node) end +end + +local ASTNode = {} +function ASTNode:__index(k) + if k == "add" then + return function(self, child) + table.insert(self.children, child) + return #self.children + end + elseif k == "generic_visitor" then + return generic_visitor + end + + local under = rawget(self, "_" .. k) + if under then + return self.children[under] + end +end +function deep_copy(item) + if type(item) == "table" then + local ret = {} + for k, v in pairs(item) do + ret[deep_copy(k)] = deep_copy(v) + end + return setmetatable(ret, getmetatable(item)) + end + return item +end +function ASTNode:deep_copy() + return deep_copy(self) +end +local DEPTH = 0 +function ASTNode:__tostring() + local specials = {} + local special_keys = {} + for k, v in pairs(self) do + if k ~= "children" and k ~= "kind" then + table.insert(specials, {k, v}) + special_keys[v] = true + end + end + table.sort(specials, function(a, b) return a[1] < b[1] end) + + local children = {"(", self.kind} + + for _, sp in pairs(specials) do + table.insert(children, " ") + if sp[1]:sub(1, 1) == "_" then + table.insert(children, sp[1]:sub(2) .. "=" .. tostring(self.children[sp[2]])) + else + table.insert(children, sp[1] .. "=" .. tostring(sp[2])) + end + end + + DEPTH = DEPTH + 1 + for k, v in ipairs(self.children) do + if not special_keys[k] then + table.insert(children, "\n") + table.insert(children, string.rep(" ", DEPTH)) + table.insert(children, tostring(v)) + end + end + DEPTH = DEPTH - 1 + + children[#children + 1] = ")" + return table.concat(children) +end + +local NEXT_VREG_ID = 0 +local VReg = {} +VReg.__index = VReg +function VReg:__tostring() + return tostring(self.id) +end +local function new_vreg(name, etype) + assert(name) + assert(not etype or ETypes.is(etype)) + NEXT_VREG_ID = NEXT_VREG_ID + 1 + return setmetatable({name = name, etype = etype, id = NEXT_VREG_ID}, VReg) +end + +local function node(kind) + return setmetatable({kind = kind, children = {}}, ASTNode) +end + +local function is(n) + return type(n) == "table" and getmetatable(n) == ASTNode +end + +local function root() + return node("root") +end + +local function decl(name, expr, export) + local n = node("decl") + n.name = name + n._expr = n:add(expr) + n.export = export + return n +end + +local function binop(etype, a, op, b) + assert(not etype or ETypes.is(etype)) + assert(is(a) and is(b)) + + local n = node("expr-binop") + n.etype = etype + n.op = op + n._a = n:add(a) + n._b = n:add(b) + return n +end + +local function unop(etype, op, a) + assert(not etype or ETypes.is(etype)) + assert(is(a)) + + local n = node("expr-unop") + n.etype = etype + n.op = op + n._a = n:add(a) + return n +end + +local function int(etype, value) + assert(not etype or etype.kind == "scalar" or etype.kind == "pointer") + + local n = node("expr-int") + n.etype = etype + n.value = value + return n +end + +local function unknown(etype) + local n = node("expr-unknown") + n.etype = etype + return n +end + +local function func(etype) + assert(not etype or etype.kind == "func") + + local n = node("expr-func") + n.etype = etype + return n +end + +local function array(etype, length) + assert(etype) + + local n = node("expr-array") + n.etype = etype + n.length = length + return n +end + +local function var(etype, vreg) + assert(ETypes.is(etype) or (etype == nil and vreg.etype and ETypes.is(vreg.etype))) + + local n = node("expr-var") + n.etype = etype or vreg.etype + n.vreg = vreg + return n +end + +local function stringh(value) + local n = node("expr-string") + n.value = value + n.etype = ETypes.string() + return n +end + +local function assign(dest, src) + assert(dest and src) + + local n = node("stmt-assign") + n._dest = n:add(dest) + n._src = n:add(src) + return n +end + +local function label(label_name) + local n = node("stmt-label") + n.name = label_name + return n +end + +local function jump(condition, label_name) + local n = node("stmt-jump") + n._condition = n:add(condition) + n.target = label_name + return n +end + +local function ret(value) + local n = node("stmt-return") + n._value = n:add(value) + return n +end + +local function cast(what, etype) + if etype == nil then + return what + end + + if what.etype == etype then + return what + end + + if what.kind == "expr-string" and etype.kind == "scalar" then + assert(#what.value * 8 == etype.bits, "String length does not match integer size") + + local val = 0 + for j = #what.value, 1, -1 do + val = val * 256 + what.value:sub(j, j):byte() + end + + return int(etype, val) + end + + local n = node("expr-cast") + n._child = n:add(what) + n.etype = etype + return n +end + +return {ASTNode = ASTNode, VReg = new_vreg, node = node, root = root, decl = decl, int = int, binop = binop, unop = unop, func = func, array = array, var = var, assign = assign, label = label, jump = jump, ret = ret, unknown = unknown, string = stringh, cast = cast} diff --git a/cg.lua b/cg.lua new file mode 100644 index 0000000..5367f07 --- /dev/null +++ b/cg.lua @@ -0,0 +1,349 @@ +local CG = {} +CG.__index = CG + +local AST = require"ast" +local Target = require"target" +local ETypes = require"etype" + +local SIMPLE_BINOPS = {["+"] = "add", ["-"] = "sub", ["^"] = "xor", ["|"] = "or", ["&"] = "and"} +local SIMPLE_UNOPS = {["~"] = "not", ["-"] = "neg"} + +local COMP_SIGNED = {["=="] = "e", ["!="] = "ne", ["<"] = "l", [">"] = "g", [">="] = "ge", ["<="] = "le"} +local COMP_UNSIGNED = {["=="] = "e", ["!="] = "ne", ["<"] = "b", [">"] = "a", [">="] = "be", ["<="] = "ae"} + +function CG:xop(ex) + if ex.kind == "expr-var" then + return ex.vreg.cgi.register + elseif ex.kind == "expr-int" then + return tostring(ex.value) + elseif ex.kind == "expr-unop" and ex.op == "*" then + -- Memory ops + if ex.a.kind == "expr-binop" and ex.a.op == "+" and ex.a.a.kind == "expr-var" and ex.a.b.kind == "expr-int" then + -- *(var1 + int) + return "[" .. ex.a.a.vreg.cgi.register .. " + " .. ex.a.b.value .. "]" + elseif ex.a.kind == "expr-binop" and ex.a.op == "+" and ex.a.a.kind == "expr-var" and ex.a.b.kind == "expr-var" then + -- *(var1 + var2) + return "[" .. ex.a.a.vreg.cgi.register .. " + " .. ex.a.b.vreg.cgi.register .. "]" + end + end + + error("Unimplemented " .. tostring(ex)) +end + +function CG:emit(chunk) + if chunk.cgi.stack_reservation > 0 then + print("sub esp, " .. chunk.cgi.stack_reservation) + end + + for _, stmt in ipairs(chunk.children) do + + if stmt.kind == "stmt-assign" then + if stmt.src.kind == "expr-binop" then + assert(SIMPLE_BINOPS[stmt.src.op]) + assert(self:xop(stmt.dest) == self:xop(stmt.src.a)) + + print(SIMPLE_BINOPS[stmt.src.op] .. " " .. self:xop(stmt.dest) .. ", " .. self:xop(stmt.src.b)) + else + print("mov " .. self:xop(stmt.dest) .. ", " .. self:xop(stmt.src)) + end + elseif stmt.kind == "stmt-jump" then + assert(stmt.condition.kind == "expr-binop") + + local cond = stmt.condition + local op = cond.op + + assert(cond.a.etype == cond.b.etype) + assert(cond.a.etype.kind == "scalar") + assert(op == ">" or op == "<" or op == "==" or op == "!=" or op == ">=" or op == "<=") + + print("cmp " .. self:xop(cond.a) .. ", " .. self:xop(cond.b)) + + local unsigned = cond.a.etype.unsigned + + print("j" .. (unsigned and COMP_UNSIGNED or COMP_SIGNED)[op] .. " .L" .. stmt.target) + elseif stmt.kind == "stmt-label" then + print(".L" .. stmt.name .. ":") + elseif stmt.kind == "stmt-return" then + if chunk.cgi.stack_reservation > 0 then + print("add esp, " .. chunk.cgi.stack_reservation) + end + + print("ret") + else + error("Unimplemented " .. tostring(stmt)) + end + + end +end + +function CG:reg_alloc(chunk, vreguses) + local vreg_live_start = {} + + for stmt_idx, stmt in ipairs(chunk.children) do + if stmt.kind == "stmt-assign" and stmt.dest.kind == "expr-var" then + local vreg = stmt.dest.vreg + + if vreg_live_start[vreg] then + vreg_live_start[vreg] = math.min(vreg_live_start[vreg], stmt_idx) + else + vreg_live_start[vreg] = stmt_idx + end + end + end + + local vreg_live_fin = {} + + for vreg, uses in pairs(vreguses) do + vreg_live_fin[vreg] = require"set".max(uses) + end + + -- Determine register classes as defined in target.lua + for vreg in pairs(vreg_live_start) do + vreg.cgi = {} + if vreg.etype:byte_size() == 1 then + vreg.cgi.register_class = "reg8" + else + vreg.cgi.register_class = "regn8" + end + end + + local edges = {} + for vreg1, start1 in pairs(vreg_live_start) do + edges[vreg1] = {} + + for vreg2, start2 in pairs(vreg_live_start) do + if vreg1 ~= vreg2 then + local fin1 = vreg_live_fin[vreg1] + local fin2 = vreg_live_fin[vreg2] + + local live_range_intersection = ((start1 <= start2 and start2 <= fin1) or (start1 <= fin2 and fin2 <= fin1)) + local resource_intersection = (Target.REG_CLASSES[vreg1.cgi.register_class].mask & Target.REG_CLASSES[vreg2.cgi.register_class].mask) ~= 0 + + if live_range_intersection and resource_intersection then + edges[vreg1][vreg2] = true + end + end + end + end + + for vreg in pairs(vreg_live_start) do + local found_reg + for _, reg in ipairs(Target.REG_CLASSES[vreg.cgi.register_class].items) do + local available = true + for vreg2 in pairs(vreg_live_start) do + if vreg2.cgi.register == reg then + available = false + break + end + end + if available then + found_reg = reg + break + end + end + + if not found_reg then + error("Spilling!") + end + + vreg.cgi.register = found_reg + end +end + +function CG:compute_uses(chunk) + local stmt_idx = 0 + + local ret = {} + + chunk:generic_visitor(function(n) + if n.kind:sub(1, 5) == "stmt-" then + stmt_idx = stmt_idx + 1 + elseif n.kind == "expr-var" then + if not ret[n.vreg] then + ret[n.vreg] = {} + end + + ret[n.vreg][stmt_idx] = true + end + end) + + return ret +end + +function CG:compute_defs(chunk) + local totaldefs = {} + for stmt_idx, stmt in ipairs(chunk.children) do + if stmt.kind == "stmt-assign" and stmt.dest.kind == "expr-var" then + if not totaldefs[stmt.dest.vreg] then + totaldefs[stmt.dest.vreg] = {} + end + table.insert(totaldefs[stmt.dest.vreg], stmt_idx) + end + end + + local outdefs = {} + local indefs = {} + + local changed = {} + for stmt_idx, stmt in ipairs(chunk.children) do + outdefs[stmt_idx] = {} + changed[stmt_idx] = true + end + + while #changed > 0 do + local stmt_idx + for s in pairs(changed) do stmt_idx = s break end + changed[stmt_idx] = nil + local stmt = chunk.children[stmt_idx] + + local predecessors = {} + if stmt_idx > 1 then + predecessors[stmt_idx - 1] = true + end + if stmt.kind == "stmt-label" then + for stmt2_idx, stmt2 in pairs(chunk.children) do + if stmt2.kind == "stmt-jump" and stmt2.target == stmt.name then + predecessors[stmt2_idx] = true + end + end + end + + indefs[stmt_idx] = {} + for pred in pairs(predecessors) do + require"set".merge(indefs[stmt_idx], outdefs[pred]) + end + + local newout = {} + local kill = nil + if stmt.kind == "stmt-assign" and stmt.dest.kind == "expr-var" then + kill = stmt.dest.vreg + newout[stmt_idx] = true + end + for indef in pairs(indefs[stmt_idx]) do + if chunk.children[indef].kind ~= "stmt-assign" or chunk.children[indef].dest.kind ~= "expr-var" or chunk.children[indef].dest.vreg ~= kill then + newout[indef] = true + end + end + + if require"set".equal(outdefs[stmt_idx], newout) then + break + end + + local successors = {} + if stmt_idx < #chunk.children then + successors[stmt_idx + 1] = true + end + if stmt.kind == "jump" then + for stmt2_idx, stmt2 in pairs(chunk.children) do + if stmt2.kind == "label" and stmt2.name == stmt.target then + successors[stmt2_idx] = true + end + end + end + + for succ in pairs(successors) do + changed[succ] = true + end + + outdefs[stmt_idx] = newout + end + + return indefs, outdefs +end + +function CG:get_stack_vreg(chunk) + for vreg in pairs(self:compute_uses(chunk)) do + if vreg.cgi.register == "esp" then + return vreg + end + end + + local vreg = AST.VReg("@stack", ETypes.scalar(true, Target.GPR_SIZE)) + vreg.cgi = {} + vreg.cgi.register = "esp" + return vreg +end + +function CG:pass_save(chunk) + local saves = chunk.etype.modifiers["save"] + + if not saves then + return + end + + local used_saves = {} + for vreg in pairs(self:compute_uses(chunk)) do + if saves[vreg.cgi.register] then + table.insert(used_saves, vreg) + end + end + + if #used_saves == 0 then + return + end + + local stack_vreg = self:get_stack_vreg(chunk) + + for i = 1, #used_saves do + table.insert(chunk.children, i, + AST.assign( + AST.unop(stack_vreg.etype.to, "*", + AST.binop(stack_vreg.etype, AST.var(nil, stack_vreg), "+", AST.int(stack_vreg.etype, (i - 1) * Target.GPR_SIZE // 8))), + AST.var(nil, used_saves[i]))) + + chunk.cgi.stack_reservation = chunk.cgi.stack_reservation + 4 + end + + local handled_rets = {} + for stmt_idx, stmt in ipairs(chunk.children) do + if stmt.kind == "stmt-return" then + if not handled_rets[stmt] then + for i = 1, #used_saves do + table.insert(chunk.children, stmt_idx, + AST.assign( + AST.var(nil, used_saves[i]), + AST.unop(stack_vreg.etype.to, "*", + AST.binop(stack_vreg.etype, AST.var(nil, stack_vreg), "+", AST.int(stack_vreg.etype, (i - 1) * Target.GPR_SIZE // 8))))) + end + handled_rets[stmt] = true + end + end + end +end + +function CG:process(chunk) + assert(not chunk.cgi) + chunk.cgi = {} + chunk.cgi.stack_reservation = 0 + + --local indefs, outdefs = self:compute_defs(chunk) + local uses = self:compute_uses(chunk) + self:reg_alloc(chunk, uses) + + self:pass_save(chunk) +end + +local function apply(modules) + local cg = setmetatable({}, CG) + + for module, root in pairs(modules) do + -- Go through all declarations in module + for _, decl in pairs(root.children) do + if decl.export then + print("global " .. decl.name) + end + print(decl.name .. ":") + + local ex = decl.expr + + if ex.etype.kind == "func" then + cg:process(ex) + cg:emit(ex) + else + error("Unimplemented " .. decl.expr.etype) + end + end + end +end + +return {apply = apply} diff --git a/etype.lua b/etype.lua new file mode 100644 index 0000000..238f4fe --- /dev/null +++ b/etype.lua @@ -0,0 +1,80 @@ +local EType = {} +EType.__index = EType + +function EType:__eq(other) + for k, v in pairs(self) do + if rawget(other, k) ~= v then + return false + end + end + for k, v in pairs(other) do + if rawget(self, k) ~= v then + return false + end + end + return true +end + +function EType:__tostring() + if self.kind == "scalar" then + return (self.unsigned and "u" or "i") .. self.bits + elseif self.kind == "func" then + return (tostring(self.input) .. "->" .. tostring(self.output)) + elseif self.kind == "struct" then + return ("(" .. table.concat(self.fields, ",") .. ")") + elseif self.kind == "pointer" then + return tostring(self.to) .. "*" + elseif self.kind == "array" then + return tostring(self.element_etype) .. "[" .. (self.length or "?") .. "]" + elseif self.kind == "string" then + return "string" + end + error("Unimplemented " .. self.kind) +end + +function EType:byte_size() + if self.kind == "scalar" then + return (self.bits + 7) // 8 + end + error("Unimplemented " .. self.kind) +end + +local ETypes = {} + +function ETypes.scalar(unsigned, bits) + return setmetatable({kind = "scalar", bits = bits, unsigned = unsigned}, EType) +end + +function ETypes.func(input_etype, output_etype) + assert(input_etype ~= nil, "Nil function input") + assert(output_etype ~= nil, "Nil function output") + return setmetatable({kind = "func", input = input_etype, output = output_etype, modifiers = {}}, EType) +end + +function ETypes.array(element_etype, length) + return setmetatable({kind = "array", element_etype = element_etype, length = length}, EType) +end + +function ETypes.struct(fields) + return setmetatable({kind = "struct", fields = fields, modifiers = {}}, EType) +end + +function ETypes.string() + return setmetatable({kind = "string"}, EType) +end + +function ETypes.ref(child) + assert(child, "Nil pointee type") + return setmetatable({kind = "pointer", to = child}, EType) +end + +function ETypes.deref(ptr) + assert(ptr and ptr.kind == "pointer", "Non-pointer type") + return ptr.to +end + +function ETypes.is(etype) + return type(etype) == "table" and getmetatable(etype) == EType +end + +return ETypes diff --git a/lexer.lua b/lexer.lua new file mode 100644 index 0000000..8fae0c7 --- /dev/null +++ b/lexer.lua @@ -0,0 +1,117 @@ +local Lexer = {} +Lexer.__index = Lexer + +function Lexer:eat(byte_count) + for c in self.source:sub(self.i, self.i + byte_count - 1):gmatch"." do + if c == "\n" then + self.current_row = self.current_row + 1 + self.current_col = 1 + else + self.current_col = self.current_col + 1 + end + end + + self.i = self.i + byte_count +end + +function Lexer:match(pattern) + local m = self.source:sub(self.i):match("^" .. pattern) + if m then + self:eat(#m) + self.last_match = m + return m + end +end + +function Lexer:add(token_type, token_data) + table.insert(self.result, {type = token_type, data = token_data, x = self.current_col, y = self.current_row}) +end + +function Lexer:get() + if self:match"[0-9]+r[0-9a-zA-Z]+" or self:match"[0-9]+" then + self:add("num", self.last_match) + elseif self:match"if" then + self:add("if", nil) + elseif self:match"string" then + self:add("stringkw", nil) + elseif self:match"loop" then + self:add("loop", nil) + elseif self:match"return" then + self:add("return", nil) + elseif self:match"export" then + self:add("export", nil) + elseif self:match"[a-zA-Z_][a-zA-Z0-9_]*" then + self:add("ident", self.last_match) + elseif self:match"@[a-zA-Z_][a-zA-Z0-9_]*" then + self:add("mod", self.last_match) + elseif self:match"'" then + local j = self.i + local value = "" + while j <= #self.source do + local b = self.source:sub(j, j) + + if b == "'" then + break + elseif b == "\\" then + j = j + 1 + b = self.source:sub(j, j) + if b == "n" then + b = "\n" + elseif b == "b" then + b = "\b" + elseif b == "e" then + b = "\x1B" + elseif b == "r" then + b = "\r" + elseif b == "t" then + b = "\t" + elseif b == "x" then + b = string.char(tonumber(self.source:sub(j + 1, j + 2), 16)) + j = j + 2 + else + error("Unknown escape sequence") + end + end + + value = value .. b + j = j + 1 + end + self:eat(j - self.i + 1) + self:add("string", value) + elseif self:match"==" then + self:add("==", nil) + elseif self:match"!=" then + self:add("!=", nil) + elseif self:match">=" then + self:add(">=", nil) + elseif self:match"<=" then + self:add("<=", nil) + elseif self:match":" then + self:add(":", nil) + elseif self:match"%(" then + self:add("(", nil) + elseif self:match"%)" then + self:add(")", nil) + elseif self:match"->" then + self:add("->", nil) + elseif self:match"{" then + self:add("{", nil) + elseif self:match"}" then + self:add("}", nil) + elseif self:match"%s" then + --self:add("ws", nil) + else + self:add(self.source:sub(self.i, self.i), nil) + self:eat(1) + end +end + +function Lexer:go() + while self.i <= #self.source do + self:get() + end +end + +return function(source_code) + return setmetatable({i = 1, source = source_code, result = {}, current_row = 1, current_col = 1}, Lexer) +end diff --git a/logger.lua b/logger.lua new file mode 100644 index 0000000..665e797 --- /dev/null +++ b/logger.lua @@ -0,0 +1,17 @@ +local Logger = {} + +local print = require"cprint" + +local COLORS = { + err = "\x1b[31;49;1m", + warn = "\x1b[33;49;1m", + info = "\x1b[39;49;1m", +} + +COLOR_RESET = "\x1b[0m" + +function Logger.log(msg_type, fmt, ...) + print((COLORS[msg_type] or COLOR_RESET) .. msg_type .. ": " .. COLOR_RESET .. string.format(fmt, ...)) +end + +return Logger diff --git a/main.lua b/main.lua new file mode 100644 index 0000000..134e1c6 --- /dev/null +++ b/main.lua @@ -0,0 +1,21 @@ +local Lexer = require"lexer" +local Parser = require"parser" + +-- import write: (u32 fd, u8* buf, ugpr count) -> igpr @c; +local lexer = Lexer([[ + export main: () -> () @save(string[?] ['ebx']) { + a = u32 5 + b = u32 10 + if a > u32 0 { + a = a - u32 16rBABE + } + return + } +]]) +lexer:go() + +local parser = Parser(lexer.result) + +local modules = { main = parser:parse_root() } + +require"cg".apply(modules) diff --git a/parser.lua b/parser.lua new file mode 100644 index 0000000..029c8f7 --- /dev/null +++ b/parser.lua @@ -0,0 +1,555 @@ +local AST = require"ast" + +local ETypes = require"etype" +local Target = require"target" + +local Logger = require"logger" + +local NEXT_LABEL_ID = 1 + +local Scope = {} +Scope.__index = Scope +function Scope.new(parent) + return setmetatable({parent = parent, items = {}}, Scope) +end +function Scope:find(name) + if self.items[name] then + return self.items[name] + end + if self.parent then + return self.parent:find(name) + end + return nil +end +function Scope:add(name) + self.items[name] = AST.VReg(name) + return self.items[name] +end + +local Parser = {} +Parser.__index = Parser + +function Parser:parse_root() + local root = AST.root() + + while true do + if self:peek(0).type == "eof" then + break + end + + local declaration, err = self:parse_declaration() + + if not declaration then + self:log("err", err or "can't parse") + break + end + + table.insert(root.children, declaration) + end + + return root +end + +function Parser:parse_declaration() + local old_i = self.i + + local export = self:maybe"export" + + if not self:maybe"ident" then + self.i = old_i + return nil, "expected identifier" + end + + local name = self:last().data + + if not self:maybe":" then + return nil, "expected :" + end + + local expr, err = self:parse_expr(0) + + if not expr then + return nil, err + end + + return AST.decl(name, expr, export) +end + +function Parser:parse_expr(precedence) + local old_idx = self.i + + if precedence == 0 then + local ret, err = self:parse_expr(precedence + 1) + + if not ret then + self.i = old_idx + return nil, err + end + + while self:maybe">" or self:maybe"<" or self:maybe"==" or self:maybe"!=" or self:maybe">=" or self:maybe"<=" do + local op = self:last().type + + local a = ret + local b, err = self:parse_expr(precedence + 1) + + if not b then + self.i = old_idx + return nil, err + end + + assert(a.etype == b.etype) + + ret = AST.binop(a.etype, a, op, b) + end + + return ret + elseif precedence == 1 then + local ret, err = self:parse_expr(precedence + 1) + + if not ret then + self.i = old_idx + return nil, err + end + + while self:maybe"+" or self:maybe"-" do + local op = self:last().type + + local a = ret + local b, err = self:parse_expr(precedence + 1) + + if not b then + self.i = old_idx + return nil, err + end + + assert(a.etype == b.etype) + + ret = AST.binop(a.etype, a, op, b) + end + + return ret + elseif precedence == 2 then + local ret, err = self:parse_expr(precedence + 1) + + if not ret then + self.i = old_idx + return nil, err + end + + while self:maybe"*" or self:maybe"/" do + local op = self:last().type + + local a = ret + local b, err = self:parse_expr(precedence + 1) + + if not b then + self.i = old_idx + return nil, err + end + + assert(a.etype == b.etype) + + ret = AST.binop(a.etype, a, op, b) + end + + return ret + elseif precedence == 3 then + local ret, err + while self:maybe"&" or self:maybe"*" do + local op = self:last().type + + local a, err = self:parse_expr(precedence) + if not a then + self.i = old_idx + return nil, err + end + + ret = AST.unop(op == "&" and ETypes.ref(a.etype) or ETypes.deref(a.etype), op, a) + end + + if not ret then + ret, err = self:parse_expr(precedence + 1) + end + + if not ret then + self.i = old_idx + return nil, err + end + + return ret + elseif precedence == 4 then + local a, err = self:parse_expr(precedence + 1) + + if not a then + self.i = old_idx + return nil, err + end + + while self:maybe"[" do + if a.etype.kind == "pointer" then + a = AST.unop(a.etype.to, "*", a) + end + + assert(a.etype.kind == "array", "Indexing a non-array") + + local b, err = self:parse_expr(0) + + if not b then + self.i = old_idx + return nil, err + end + + a = AST.unop(a.etype.element_etype, "*", AST.binop(nil, a, "+", b)) + + assert(self:maybe"]") + end + + return a + elseif precedence == 5 then + local asdf = self.i + + -- Okay for etype to be nil + local etype = self:parse_etype(0) + + if self:maybe"?" then + return AST.unknown(etype) + elseif self:maybe'string' then + return AST.cast(AST.string(self:last().data), etype) + elseif self:maybe"num" then + local tok = self:last().data + local base = tonumber(tok:match"^([0-9]+)r") or 10 + local val = tok:match"r([0-9a-zA-Z]+)$" or tok + return AST.int(etype, tonumber(val, base)) + elseif self:maybe"ident" then + assert(not etype) + + local name = self:last().data + local vreg = self.scope:find(name) + if not vreg then + self.i = old_idx + return nil, "Undeclared variable " .. name + end + + return AST.var(vreg.etype, vreg) + elseif self:maybe"[" then + assert(etype) + + local children = {} + + if not self:maybe"]" then + while true do + local ex = self:parse_expr(0) + assert(ex) + + if etype then + ex = AST.cast(ex, etype.element_etype) + end + + table.insert(children, ex) + + if self:maybe"]" then + break + end + + if not self:maybe"," then + local last = self:last() + self.i = old_idx + return nil, "expected comma", last + end + end + end + + local n = AST.array(etype, 0) + for _, child in ipairs(children) do + table.insert(n.children, child) + end + return n + elseif self:maybe"{" then + local n = AST.func(etype) + while not self:maybe"}" do + local status, err = self:parse_stmt(n) + + if not status then + self.i = old_idx + return nil, err + end + end + return n + end + end + + self.i = old_idx + return nil +end + +function Parser:parse_stmt(chunk) + local old_idx = self.i + + if self:maybe";" then + return true + elseif self:maybe"return" then + local ex = self:parse_expr(0) + + -- ex can be null + + table.insert(chunk.children, AST.ret(ex)) + + return true + elseif self:maybe"if" then + local condition = self:parse_expr(0) + + local lbl_id = NEXT_LABEL_ID + NEXT_LABEL_ID = NEXT_LABEL_ID + 1 + + table.insert(chunk.children, AST.jump(condition, lbl_id)) + + if not self:maybe"{" then + self.i = old_idx + return false, "expected {" + end + + while not self:maybe"}" do + local status, err = self:parse_stmt(chunk) + if not status then + self.i = old_idx + return nil, err + end + end + + table.insert(chunk.children, AST.label(lbl_id)) + + return true + elseif self:peek(0).type == "ident" and self:peek(1).type == "=" then + local name = self:next().data + self:next() + + local expr = self:parse_expr(0) + + local vreg = self.scope:find(name) + if not vreg then + assert(expr.etype) + + vreg = self.scope:add(name) + vreg.etype = expr.etype + else + assert(vreg.etype == expr.etype, "Type mismatch in assignment") + end + + table.insert(chunk.children, AST.assign(AST.var(vreg.etype, vreg), expr)) + + return true + end + + -- Try parsing assignment + + local ex = self:parse_expr(0) + if not ex then + self.i = old_idx + return nil, "expected expression" + end + + if not self:maybe"=" then + self.i = old_idx + return nil, "expected =" + end + + local ex2 = self:parse_expr(0) + if not ex2 then + self.i = old_idx + return nil, "expected expression" + end + + table.insert(chunk.children, AST.assign(ex, AST.cast(ex2, ex.etype))) + + return true +end + +function Parser:parse_etype(precedence) + local old_idx = self.i + + if precedence == 0 then + local ret = self:parse_etype(precedence + 1) + + if not ret then + return nil + end + + while self:maybe"->" do + local a = ret + local b = self:parse_etype(precedence + 1) + + ret = ETypes.func(a, b) + + while self:maybe("mod") do + local mod_type = self:last().data + + if mod_type == "@save" then + if not self:maybe"(" then + local last = self:last() + self.i = old_idx + return nil, "expected (", last + end + + local ex = self:parse_expr(0) + if not ex or ex.kind ~= "expr-array" then + local last = self:last() + self.i = old_idx + return nil, "expected array", last + end + + if not self:maybe")" then + local last = self:last() + self.i = old_idx + return nil, "expected )", last + end + + local items = {} + for _, child in ipairs(ex.children) do + assert(child.kind == "expr-string") + + if not require"target".REGS[child.value] then + self:log("warn", "skipping unknown register " .. child.value) + end + + items[child.value] = true + end + + ret.modifiers["save"] = items + else + self:log("warn", "skipping unknown modifier " .. mod_type) + + if self:maybe"(" then + local depth = 1 + while true do + local t = self:next() + if t.type == "(" then + depth = depth + 1 + elseif t.type == ")" then + depth = depth - 1 + if depth == 0 then + break + end + elseif t.type == "eof" then + break + end + end + end + end + end + end + + return ret + elseif precedence == 1 then + local a = self:parse_etype(precedence + 1) + + if not a then + return nil + end + + while self:maybe"*" or self:maybe"[" do + local old_idx = self.i + + if self:last().type == "*" then + a = ETypes.ref(a) + else + local length_expr = self:parse_expr(0) + + if not length_expr or not self:maybe"]" or (length_expr.kind ~= "expr-unknown" and length_expr.kind ~= "expr-int") then + -- This cannot be an array type, most likely a case like (string[?] [...]) + self.i = old_idx - 1 + break + end + + local length + if length_expr.kind == "expr-int" then + length = length_expr.value + end + + a = ETypes.array(a, length) + end + end + return a + elseif precedence == 2 then + if self:maybe"(" then + if self:maybe")" then + return ETypes.struct({}) + else + self:go_back() + end + end + + if self:maybe"stringkw" then + return ETypes.string() + end + + if not self:maybe"ident" then + self.i = old_idx + return nil + end + + local str = self:last().data + + local ret + if str:match"[ui][0-9]+" then + ret = ETypes.scalar(str:sub(1, 1) == "u", tonumber(str:sub(2))) + elseif str:match"[ui]gpr" then + ret = ETypes.scalar(str:sub(1, 1) == "u", Target.GPR_SIZE) + else + self.i = old_idx + return nil + end + + return ret + end + + self.i = old_idx + return nil +end + +function Parser:go_back() + if self.i <= 1 then + error("Already at the beginning") + end + self.i = self.i - 1 +end + +function Parser:last() + return self.tokens[self.i - 1] +end + +function Parser:peek(idx) + if self.i + idx > #self.tokens then + return {type = "eof"} + end + return self.tokens[self.i + idx] +end + +function Parser:maybe(token_type) + if self.i > #self.tokens then + return false + end + + if self.tokens[self.i].type == token_type then + self.i = self.i + 1 + return true + end + + return false +end + +function Parser:next() + local ret = self:peek(0) + if ret.type ~= "eof" then + self.i = self.i + 1 + end + return ret +end + +function Parser:log(msg_type, fmt, ...) + local tok = self:last() + Logger.log(msg_type, tok.y .. ":" .. tok.x .. ", " .. fmt, ...) +end + +return function(tokens) + return setmetatable({i = 1, tokens = tokens, scope = Scope.new(nil)}, Parser) +end diff --git a/set.lua b/set.lua new file mode 100644 index 0000000..0b31579 --- /dev/null +++ b/set.lua @@ -0,0 +1,41 @@ +local function equal(a, b) + for k in pairs(a) do + if not b[k] then + return false + end + end + for k in pairs(b) do + if not a[k] then + return false + end + end + return true +end + +local function min(s) + local a + for b in pairs(s) do + if not a or b < a then + a = b + end + end + return a +end + +local function max(s) + local a + for b in pairs(s) do + if not a or a < b then + a = b + end + end + return a +end + +local function merge(to, from) + for k in pairs(from) do + to[k] = true + end +end + +return {equal = equal, min = min, max = max, merge = merge} diff --git a/target.lua b/target.lua new file mode 100644 index 0000000..6801af9 --- /dev/null +++ b/target.lua @@ -0,0 +1,67 @@ +local REGS = { + al = {bits = 8, mask = 0x1}, + ah = {bits = 8, mask = 0x2}, + ax = {bits = 16, mask = 0x3}, + eax = {bits = 32, mask = 0x3}, + + bl = {bits = 8, mask = 0x4}, + bh = {bits = 8, mask = 0x8}, + bx = {bits = 16, mask = 0xC}, + ebx = {bits = 32, mask = 0xC}, + + cl = {bits = 8, mask = 0x10}, + ch = {bits = 8, mask = 0x20}, + cx = {bits = 16, mask = 0x30}, + ecx = {bits = 32, mask = 0x30}, + + dl = {bits = 8, mask = 0x40}, + dh = {bits = 8, mask = 0x80}, + dx = {bits = 16, mask = 0xC0}, + edx = {bits = 32, mask = 0xC0}, + + di = {bits = 16, mask = 0x100}, + edi = {bits = 32, mask = 0x100}, + + si = {bits = 16, mask = 0x200}, + esi = {bits = 32, mask = 0x200}, + + bp = {bits = 16, mask = 0x400}, + ebp = {bits = 32, mask = 0x400}, + + sp = {bits = 16, mask = 0x800}, + esp = {bits = 32, mask = 0x800}, + + ds = {bits = 16, mask = 0x100}, + es = {bits = 16, mask = 0x2000}, + fs = {bits = 16, mask = 0x4000}, + gs = {bits = 16, mask = 0x8000}, +} + +local REG_CLASSES = { + reg8 = { + items = {"al", "ah", "bl", "bh", "cl", "ch", "dl", "dh"} + }, + regn8 = { + items = {"eax", "ebx", "ecx", "edx", "ax", "bx", "cx", "dx", "di", "si", "bp", "edi", "esi", "ebp"} + }, + seg = { + items = {"ds", "es", "fs", "gs"} + }, + rmptr = { + items = {"bx", "bp", "di", "si"} + } +} + +for _, reg_class in pairs(REG_CLASSES) do + local mask = 0 + for _, reg in ipairs(reg_class.items) do + mask = mask | REGS[reg].mask + end + reg_class.mask = mask +end + +return { + GPR_SIZE = 32, + REGS = REGS, + REG_CLASSES = REG_CLASSES, +}