commit 037d18e9fde8d4f93c1e675bdfdcf43d313b31a6 Author: mid <> Date: Mon May 4 21:04:23 2026 +0300 Initial commit diff --git a/ast.lua b/ast.lua new file mode 100644 index 0000000..cdfc4db --- /dev/null +++ b/ast.lua @@ -0,0 +1,237 @@ +local ETypes = require"etype" + +local function generic_visitor(node, pre_callback, post_callback) + if pre_callback then pre_callback(node) end + + for k, v in ipairs(node.children) do + generic_visitor(v, pre_callback, post_callback) + end + + if post_callback then post_callback(node) end +end + +local ASTNode = {} +function ASTNode:__index(k) + if k == "add" then + return function(self, child) + table.insert(self.children, child) + return #self.children + end + elseif k == "generic_visitor" then + return generic_visitor + end + + local under = rawget(self, "_" .. k) + if under then + return self.children[under] + end +end +function deep_copy(item) + if type(item) == "table" then + local ret = {} + for k, v in pairs(item) do + ret[deep_copy(k)] = deep_copy(v) + end + return setmetatable(ret, getmetatable(item)) + end + return item +end +function ASTNode:deep_copy() + return deep_copy(self) +end +local DEPTH = 0 +function ASTNode:__tostring() + local specials = {} + local special_keys = {} + for k, v in pairs(self) do + if k ~= "children" and k ~= "kind" then + table.insert(specials, {k, v}) + special_keys[v] = true + end + end + table.sort(specials, function(a, b) return a[1] < b[1] end) + + local children = {"(", self.kind} + + for _, sp in pairs(specials) do + table.insert(children, " ") + if sp[1]:sub(1, 1) == "_" then + table.insert(children, sp[1]:sub(2) .. "=" .. tostring(self.children[sp[2]])) + else + table.insert(children, sp[1] .. "=" .. tostring(sp[2])) + end + end + + DEPTH = DEPTH + 1 + for k, v in ipairs(self.children) do + if not special_keys[k] then + table.insert(children, "\n") + table.insert(children, string.rep(" ", DEPTH)) + table.insert(children, tostring(v)) + end + end + DEPTH = DEPTH - 1 + + children[#children + 1] = ")" + return table.concat(children) +end + +local NEXT_VREG_ID = 0 +local VReg = {} +VReg.__index = VReg +function VReg:__tostring() + return tostring(self.id) +end +local function new_vreg(name, etype) + assert(name) + assert(not etype or ETypes.is(etype)) + NEXT_VREG_ID = NEXT_VREG_ID + 1 + return setmetatable({name = name, etype = etype, id = NEXT_VREG_ID}, VReg) +end + +local function node(kind) + return setmetatable({kind = kind, children = {}}, ASTNode) +end + +local function is(n) + return type(n) == "table" and getmetatable(n) == ASTNode +end + +local function root() + return node("root") +end + +local function decl(name, expr, export) + local n = node("decl") + n.name = name + n._expr = n:add(expr) + n.export = export + return n +end + +local function binop(etype, a, op, b) + assert(not etype or ETypes.is(etype)) + assert(is(a) and is(b)) + + local n = node("expr-binop") + n.etype = etype + n.op = op + n._a = n:add(a) + n._b = n:add(b) + return n +end + +local function unop(etype, op, a) + assert(not etype or ETypes.is(etype)) + assert(is(a)) + + local n = node("expr-unop") + n.etype = etype + n.op = op + n._a = n:add(a) + return n +end + +local function int(etype, value) + assert(not etype or etype.kind == "scalar" or etype.kind == "pointer") + + local n = node("expr-int") + n.etype = etype + n.value = value + return n +end + +local function unknown(etype) + local n = node("expr-unknown") + n.etype = etype + return n +end + +local function func(etype) + assert(not etype or etype.kind == "func") + + local n = node("expr-func") + n.etype = etype + return n +end + +local function array(etype, length) + assert(etype) + + local n = node("expr-array") + n.etype = etype + n.length = length + return n +end + +local function var(etype, vreg) + assert(ETypes.is(etype) or (etype == nil and vreg.etype and ETypes.is(vreg.etype))) + + local n = node("expr-var") + n.etype = etype or vreg.etype + n.vreg = vreg + return n +end + +local function stringh(value) + local n = node("expr-string") + n.value = value + n.etype = ETypes.string() + return n +end + +local function assign(dest, src) + assert(dest and src) + + local n = node("stmt-assign") + n._dest = n:add(dest) + n._src = n:add(src) + return n +end + +local function label(label_name) + local n = node("stmt-label") + n.name = label_name + return n +end + +local function jump(condition, label_name) + local n = node("stmt-jump") + n._condition = n:add(condition) + n.target = label_name + return n +end + +local function ret(value) + local n = node("stmt-return") + n._value = n:add(value) + return n +end + +local function cast(what, etype) + if etype == nil then + return what + end + + if what.etype == etype then + return what + end + + if what.kind == "expr-string" and etype.kind == "scalar" then + assert(#what.value * 8 == etype.bits, "String length does not match integer size") + + local val = 0 + for j = #what.value, 1, -1 do + val = val * 256 + what.value:sub(j, j):byte() + end + + return int(etype, val) + end + + local n = node("expr-cast") + n._child = n:add(what) + n.etype = etype + return n +end + +return {ASTNode = ASTNode, VReg = new_vreg, node = node, root = root, decl = decl, int = int, binop = binop, unop = unop, func = func, array = array, var = var, assign = assign, label = label, jump = jump, ret = ret, unknown = unknown, string = stringh, cast = cast} diff --git a/cg.lua b/cg.lua new file mode 100644 index 0000000..5367f07 --- /dev/null +++ b/cg.lua @@ -0,0 +1,349 @@ +local CG = {} +CG.__index = CG + +local AST = require"ast" +local Target = require"target" +local ETypes = require"etype" + +local SIMPLE_BINOPS = {["+"] = "add", ["-"] = "sub", ["^"] = "xor", ["|"] = "or", ["&"] = "and"} +local SIMPLE_UNOPS = {["~"] = "not", ["-"] = "neg"} + +local COMP_SIGNED = {["=="] = "e", ["!="] = "ne", ["<"] = "l", [">"] = "g", [">="] = "ge", ["<="] = "le"} +local COMP_UNSIGNED = {["=="] = "e", ["!="] = "ne", ["<"] = "b", [">"] = "a", [">="] = "be", ["<="] = "ae"} + +function CG:xop(ex) + if ex.kind == "expr-var" then + return ex.vreg.cgi.register + elseif ex.kind == "expr-int" then + return tostring(ex.value) + elseif ex.kind == "expr-unop" and ex.op == "*" then + -- Memory ops + if ex.a.kind == "expr-binop" and ex.a.op == "+" and ex.a.a.kind == "expr-var" and ex.a.b.kind == "expr-int" then + -- *(var1 + int) + return "[" .. ex.a.a.vreg.cgi.register .. " + " .. ex.a.b.value .. "]" + elseif ex.a.kind == "expr-binop" and ex.a.op == "+" and ex.a.a.kind == "expr-var" and ex.a.b.kind == "expr-var" then + -- *(var1 + var2) + return "[" .. ex.a.a.vreg.cgi.register .. " + " .. ex.a.b.vreg.cgi.register .. "]" + end + end + + error("Unimplemented " .. tostring(ex)) +end + +function CG:emit(chunk) + if chunk.cgi.stack_reservation > 0 then + print("sub esp, " .. chunk.cgi.stack_reservation) + end + + for _, stmt in ipairs(chunk.children) do + + if stmt.kind == "stmt-assign" then + if stmt.src.kind == "expr-binop" then + assert(SIMPLE_BINOPS[stmt.src.op]) + assert(self:xop(stmt.dest) == self:xop(stmt.src.a)) + + print(SIMPLE_BINOPS[stmt.src.op] .. " " .. self:xop(stmt.dest) .. ", " .. self:xop(stmt.src.b)) + else + print("mov " .. self:xop(stmt.dest) .. ", " .. self:xop(stmt.src)) + end + elseif stmt.kind == "stmt-jump" then + assert(stmt.condition.kind == "expr-binop") + + local cond = stmt.condition + local op = cond.op + + assert(cond.a.etype == cond.b.etype) + assert(cond.a.etype.kind == "scalar") + assert(op == ">" or op == "<" or op == "==" or op == "!=" or op == ">=" or op == "<=") + + print("cmp " .. self:xop(cond.a) .. ", " .. self:xop(cond.b)) + + local unsigned = cond.a.etype.unsigned + + print("j" .. (unsigned and COMP_UNSIGNED or COMP_SIGNED)[op] .. " .L" .. stmt.target) + elseif stmt.kind == "stmt-label" then + print(".L" .. stmt.name .. ":") + elseif stmt.kind == "stmt-return" then + if chunk.cgi.stack_reservation > 0 then + print("add esp, " .. chunk.cgi.stack_reservation) + end + + print("ret") + else + error("Unimplemented " .. tostring(stmt)) + end + + end +end + +function CG:reg_alloc(chunk, vreguses) + local vreg_live_start = {} + + for stmt_idx, stmt in ipairs(chunk.children) do + if stmt.kind == "stmt-assign" and stmt.dest.kind == "expr-var" then + local vreg = stmt.dest.vreg + + if vreg_live_start[vreg] then + vreg_live_start[vreg] = math.min(vreg_live_start[vreg], stmt_idx) + else + vreg_live_start[vreg] = stmt_idx + end + end + end + + local vreg_live_fin = {} + + for vreg, uses in pairs(vreguses) do + vreg_live_fin[vreg] = require"set".max(uses) + end + + -- Determine register classes as defined in target.lua + for vreg in pairs(vreg_live_start) do + vreg.cgi = {} + if vreg.etype:byte_size() == 1 then + vreg.cgi.register_class = "reg8" + else + vreg.cgi.register_class = "regn8" + end + end + + local edges = {} + for vreg1, start1 in pairs(vreg_live_start) do + edges[vreg1] = {} + + for vreg2, start2 in pairs(vreg_live_start) do + if vreg1 ~= vreg2 then + local fin1 = vreg_live_fin[vreg1] + local fin2 = vreg_live_fin[vreg2] + + local live_range_intersection = ((start1 <= start2 and start2 <= fin1) or (start1 <= fin2 and fin2 <= fin1)) + local resource_intersection = (Target.REG_CLASSES[vreg1.cgi.register_class].mask & Target.REG_CLASSES[vreg2.cgi.register_class].mask) ~= 0 + + if live_range_intersection and resource_intersection then + edges[vreg1][vreg2] = true + end + end + end + end + + for vreg in pairs(vreg_live_start) do + local found_reg + for _, reg in ipairs(Target.REG_CLASSES[vreg.cgi.register_class].items) do + local available = true + for vreg2 in pairs(vreg_live_start) do + if vreg2.cgi.register == reg then + available = false + break + end + end + if available then + found_reg = reg + break + end + end + + if not found_reg then + error("Spilling!") + end + + vreg.cgi.register = found_reg + end +end + +function CG:compute_uses(chunk) + local stmt_idx = 0 + + local ret = {} + + chunk:generic_visitor(function(n) + if n.kind:sub(1, 5) == "stmt-" then + stmt_idx = stmt_idx + 1 + elseif n.kind == "expr-var" then + if not ret[n.vreg] then + ret[n.vreg] = {} + end + + ret[n.vreg][stmt_idx] = true + end + end) + + return ret +end + +function CG:compute_defs(chunk) + local totaldefs = {} + for stmt_idx, stmt in ipairs(chunk.children) do + if stmt.kind == "stmt-assign" and stmt.dest.kind == "expr-var" then + if not totaldefs[stmt.dest.vreg] then + totaldefs[stmt.dest.vreg] = {} + end + table.insert(totaldefs[stmt.dest.vreg], stmt_idx) + end + end + + local outdefs = {} + local indefs = {} + + local changed = {} + for stmt_idx, stmt in ipairs(chunk.children) do + outdefs[stmt_idx] = {} + changed[stmt_idx] = true + end + + while #changed > 0 do + local stmt_idx + for s in pairs(changed) do stmt_idx = s break end + changed[stmt_idx] = nil + local stmt = chunk.children[stmt_idx] + + local predecessors = {} + if stmt_idx > 1 then + predecessors[stmt_idx - 1] = true + end + if stmt.kind == "stmt-label" then + for stmt2_idx, stmt2 in pairs(chunk.children) do + if stmt2.kind == "stmt-jump" and stmt2.target == stmt.name then + predecessors[stmt2_idx] = true + end + end + end + + indefs[stmt_idx] = {} + for pred in pairs(predecessors) do + require"set".merge(indefs[stmt_idx], outdefs[pred]) + end + + local newout = {} + local kill = nil + if stmt.kind == "stmt-assign" and stmt.dest.kind == "expr-var" then + kill = stmt.dest.vreg + newout[stmt_idx] = true + end + for indef in pairs(indefs[stmt_idx]) do + if chunk.children[indef].kind ~= "stmt-assign" or chunk.children[indef].dest.kind ~= "expr-var" or chunk.children[indef].dest.vreg ~= kill then + newout[indef] = true + end + end + + if require"set".equal(outdefs[stmt_idx], newout) then + break + end + + local successors = {} + if stmt_idx < #chunk.children then + successors[stmt_idx + 1] = true + end + if stmt.kind == "jump" then + for stmt2_idx, stmt2 in pairs(chunk.children) do + if stmt2.kind == "label" and stmt2.name == stmt.target then + successors[stmt2_idx] = true + end + end + end + + for succ in pairs(successors) do + changed[succ] = true + end + + outdefs[stmt_idx] = newout + end + + return indefs, outdefs +end + +function CG:get_stack_vreg(chunk) + for vreg in pairs(self:compute_uses(chunk)) do + if vreg.cgi.register == "esp" then + return vreg + end + end + + local vreg = AST.VReg("@stack", ETypes.scalar(true, Target.GPR_SIZE)) + vreg.cgi = {} + vreg.cgi.register = "esp" + return vreg +end + +function CG:pass_save(chunk) + local saves = chunk.etype.modifiers["save"] + + if not saves then + return + end + + local used_saves = {} + for vreg in pairs(self:compute_uses(chunk)) do + if saves[vreg.cgi.register] then + table.insert(used_saves, vreg) + end + end + + if #used_saves == 0 then + return + end + + local stack_vreg = self:get_stack_vreg(chunk) + + for i = 1, #used_saves do + table.insert(chunk.children, i, + AST.assign( + AST.unop(stack_vreg.etype.to, "*", + AST.binop(stack_vreg.etype, AST.var(nil, stack_vreg), "+", AST.int(stack_vreg.etype, (i - 1) * Target.GPR_SIZE // 8))), + AST.var(nil, used_saves[i]))) + + chunk.cgi.stack_reservation = chunk.cgi.stack_reservation + 4 + end + + local handled_rets = {} + for stmt_idx, stmt in ipairs(chunk.children) do + if stmt.kind == "stmt-return" then + if not handled_rets[stmt] then + for i = 1, #used_saves do + table.insert(chunk.children, stmt_idx, + AST.assign( + AST.var(nil, used_saves[i]), + AST.unop(stack_vreg.etype.to, "*", + AST.binop(stack_vreg.etype, AST.var(nil, stack_vreg), "+", AST.int(stack_vreg.etype, (i - 1) * Target.GPR_SIZE // 8))))) + end + handled_rets[stmt] = true + end + end + end +end + +function CG:process(chunk) + assert(not chunk.cgi) + chunk.cgi = {} + chunk.cgi.stack_reservation = 0 + + --local indefs, outdefs = self:compute_defs(chunk) + local uses = self:compute_uses(chunk) + self:reg_alloc(chunk, uses) + + self:pass_save(chunk) +end + +local function apply(modules) + local cg = setmetatable({}, CG) + + for module, root in pairs(modules) do + -- Go through all declarations in module + for _, decl in pairs(root.children) do + if decl.export then + print("global " .. decl.name) + end + print(decl.name .. ":") + + local ex = decl.expr + + if ex.etype.kind == "func" then + cg:process(ex) + cg:emit(ex) + else + error("Unimplemented " .. decl.expr.etype) + end + end + end +end + +return {apply = apply} diff --git a/etype.lua b/etype.lua new file mode 100644 index 0000000..238f4fe --- /dev/null +++ b/etype.lua @@ -0,0 +1,80 @@ +local EType = {} +EType.__index = EType + +function EType:__eq(other) + for k, v in pairs(self) do + if rawget(other, k) ~= v then + return false + end + end + for k, v in pairs(other) do + if rawget(self, k) ~= v then + return false + end + end + return true +end + +function EType:__tostring() + if self.kind == "scalar" then + return (self.unsigned and "u" or "i") .. self.bits + elseif self.kind == "func" then + return (tostring(self.input) .. "->" .. tostring(self.output)) + elseif self.kind == "struct" then + return ("(" .. table.concat(self.fields, ",") .. ")") + elseif self.kind == "pointer" then + return tostring(self.to) .. "*" + elseif self.kind == "array" then + return tostring(self.element_etype) .. "[" .. (self.length or "?") .. "]" + elseif self.kind == "string" then + return "string" + end + error("Unimplemented " .. self.kind) +end + +function EType:byte_size() + if self.kind == "scalar" then + return (self.bits + 7) // 8 + end + error("Unimplemented " .. self.kind) +end + +local ETypes = {} + +function ETypes.scalar(unsigned, bits) + return setmetatable({kind = "scalar", bits = bits, unsigned = unsigned}, EType) +end + +function ETypes.func(input_etype, output_etype) + assert(input_etype ~= nil, "Nil function input") + assert(output_etype ~= nil, "Nil function output") + return setmetatable({kind = "func", input = input_etype, output = output_etype, modifiers = {}}, EType) +end + +function ETypes.array(element_etype, length) + return setmetatable({kind = "array", element_etype = element_etype, length = length}, EType) +end + +function ETypes.struct(fields) + return setmetatable({kind = "struct", fields = fields, modifiers = {}}, EType) +end + +function ETypes.string() + return setmetatable({kind = "string"}, EType) +end + +function ETypes.ref(child) + assert(child, "Nil pointee type") + return setmetatable({kind = "pointer", to = child}, EType) +end + +function ETypes.deref(ptr) + assert(ptr and ptr.kind == "pointer", "Non-pointer type") + return ptr.to +end + +function ETypes.is(etype) + return type(etype) == "table" and getmetatable(etype) == EType +end + +return ETypes diff --git a/lexer.lua b/lexer.lua new file mode 100644 index 0000000..8fae0c7 --- /dev/null +++ b/lexer.lua @@ -0,0 +1,117 @@ +local Lexer = {} +Lexer.__index = Lexer + +function Lexer:eat(byte_count) + for c in self.source:sub(self.i, self.i + byte_count - 1):gmatch"." do + if c == "\n" then + self.current_row = self.current_row + 1 + self.current_col = 1 + else + self.current_col = self.current_col + 1 + end + end + + self.i = self.i + byte_count +end + +function Lexer:match(pattern) + local m = self.source:sub(self.i):match("^" .. pattern) + if m then + self:eat(#m) + self.last_match = m + return m + end +end + +function Lexer:add(token_type, token_data) + table.insert(self.result, {type = token_type, data = token_data, x = self.current_col, y = self.current_row}) +end + +function Lexer:get() + if self:match"[0-9]+r[0-9a-zA-Z]+" or self:match"[0-9]+" then + self:add("num", self.last_match) + elseif self:match"if" then + self:add("if", nil) + elseif self:match"string" then + self:add("stringkw", nil) + elseif self:match"loop" then + self:add("loop", nil) + elseif self:match"return" then + self:add("return", nil) + elseif self:match"export" then + self:add("export", nil) + elseif self:match"[a-zA-Z_][a-zA-Z0-9_]*" then + self:add("ident", self.last_match) + elseif self:match"@[a-zA-Z_][a-zA-Z0-9_]*" then + self:add("mod", self.last_match) + elseif self:match"'" then + local j = self.i + local value = "" + while j <= #self.source do + local b = self.source:sub(j, j) + + if b == "'" then + break + elseif b == "\\" then + j = j + 1 + b = self.source:sub(j, j) + if b == "n" then + b = "\n" + elseif b == "b" then + b = "\b" + elseif b == "e" then + b = "\x1B" + elseif b == "r" then + b = "\r" + elseif b == "t" then + b = "\t" + elseif b == "x" then + b = string.char(tonumber(self.source:sub(j + 1, j + 2), 16)) + j = j + 2 + else + error("Unknown escape sequence") + end + end + + value = value .. b + j = j + 1 + end + self:eat(j - self.i + 1) + self:add("string", value) + elseif self:match"==" then + self:add("==", nil) + elseif self:match"!=" then + self:add("!=", nil) + elseif self:match">=" then + self:add(">=", nil) + elseif self:match"<=" then + self:add("<=", nil) + elseif self:match":" then + self:add(":", nil) + elseif self:match"%(" then + self:add("(", nil) + elseif self:match"%)" then + self:add(")", nil) + elseif self:match"->" then + self:add("->", nil) + elseif self:match"{" then + self:add("{", nil) + elseif self:match"}" then + self:add("}", nil) + elseif self:match"%s" then + --self:add("ws", nil) + else + self:add(self.source:sub(self.i, self.i), nil) + self:eat(1) + end +end + +function Lexer:go() + while self.i <= #self.source do + self:get() + end +end + +return function(source_code) + return setmetatable({i = 1, source = source_code, result = {}, current_row = 1, current_col = 1}, Lexer) +end diff --git a/logger.lua b/logger.lua new file mode 100644 index 0000000..665e797 --- /dev/null +++ b/logger.lua @@ -0,0 +1,17 @@ +local Logger = {} + +local print = require"cprint" + +local COLORS = { + err = "\x1b[31;49;1m", + warn = "\x1b[33;49;1m", + info = "\x1b[39;49;1m", +} + +COLOR_RESET = "\x1b[0m" + +function Logger.log(msg_type, fmt, ...) + print((COLORS[msg_type] or COLOR_RESET) .. msg_type .. ": " .. COLOR_RESET .. string.format(fmt, ...)) +end + +return Logger diff --git a/main.lua b/main.lua new file mode 100644 index 0000000..134e1c6 --- /dev/null +++ b/main.lua @@ -0,0 +1,21 @@ +local Lexer = require"lexer" +local Parser = require"parser" + +-- import write: (u32 fd, u8* buf, ugpr count) -> igpr @c; +local lexer = Lexer([[ + export main: () -> () @save(string[?] ['ebx']) { + a = u32 5 + b = u32 10 + if a > u32 0 { + a = a - u32 16rBABE + } + return + } +]]) +lexer:go() + +local parser = Parser(lexer.result) + +local modules = { main = parser:parse_root() } + +require"cg".apply(modules) diff --git a/parser.lua b/parser.lua new file mode 100644 index 0000000..029c8f7 --- /dev/null +++ b/parser.lua @@ -0,0 +1,555 @@ +local AST = require"ast" + +local ETypes = require"etype" +local Target = require"target" + +local Logger = require"logger" + +local NEXT_LABEL_ID = 1 + +local Scope = {} +Scope.__index = Scope +function Scope.new(parent) + return setmetatable({parent = parent, items = {}}, Scope) +end +function Scope:find(name) + if self.items[name] then + return self.items[name] + end + if self.parent then + return self.parent:find(name) + end + return nil +end +function Scope:add(name) + self.items[name] = AST.VReg(name) + return self.items[name] +end + +local Parser = {} +Parser.__index = Parser + +function Parser:parse_root() + local root = AST.root() + + while true do + if self:peek(0).type == "eof" then + break + end + + local declaration, err = self:parse_declaration() + + if not declaration then + self:log("err", err or "can't parse") + break + end + + table.insert(root.children, declaration) + end + + return root +end + +function Parser:parse_declaration() + local old_i = self.i + + local export = self:maybe"export" + + if not self:maybe"ident" then + self.i = old_i + return nil, "expected identifier" + end + + local name = self:last().data + + if not self:maybe":" then + return nil, "expected :" + end + + local expr, err = self:parse_expr(0) + + if not expr then + return nil, err + end + + return AST.decl(name, expr, export) +end + +function Parser:parse_expr(precedence) + local old_idx = self.i + + if precedence == 0 then + local ret, err = self:parse_expr(precedence + 1) + + if not ret then + self.i = old_idx + return nil, err + end + + while self:maybe">" or self:maybe"<" or self:maybe"==" or self:maybe"!=" or self:maybe">=" or self:maybe"<=" do + local op = self:last().type + + local a = ret + local b, err = self:parse_expr(precedence + 1) + + if not b then + self.i = old_idx + return nil, err + end + + assert(a.etype == b.etype) + + ret = AST.binop(a.etype, a, op, b) + end + + return ret + elseif precedence == 1 then + local ret, err = self:parse_expr(precedence + 1) + + if not ret then + self.i = old_idx + return nil, err + end + + while self:maybe"+" or self:maybe"-" do + local op = self:last().type + + local a = ret + local b, err = self:parse_expr(precedence + 1) + + if not b then + self.i = old_idx + return nil, err + end + + assert(a.etype == b.etype) + + ret = AST.binop(a.etype, a, op, b) + end + + return ret + elseif precedence == 2 then + local ret, err = self:parse_expr(precedence + 1) + + if not ret then + self.i = old_idx + return nil, err + end + + while self:maybe"*" or self:maybe"/" do + local op = self:last().type + + local a = ret + local b, err = self:parse_expr(precedence + 1) + + if not b then + self.i = old_idx + return nil, err + end + + assert(a.etype == b.etype) + + ret = AST.binop(a.etype, a, op, b) + end + + return ret + elseif precedence == 3 then + local ret, err + while self:maybe"&" or self:maybe"*" do + local op = self:last().type + + local a, err = self:parse_expr(precedence) + if not a then + self.i = old_idx + return nil, err + end + + ret = AST.unop(op == "&" and ETypes.ref(a.etype) or ETypes.deref(a.etype), op, a) + end + + if not ret then + ret, err = self:parse_expr(precedence + 1) + end + + if not ret then + self.i = old_idx + return nil, err + end + + return ret + elseif precedence == 4 then + local a, err = self:parse_expr(precedence + 1) + + if not a then + self.i = old_idx + return nil, err + end + + while self:maybe"[" do + if a.etype.kind == "pointer" then + a = AST.unop(a.etype.to, "*", a) + end + + assert(a.etype.kind == "array", "Indexing a non-array") + + local b, err = self:parse_expr(0) + + if not b then + self.i = old_idx + return nil, err + end + + a = AST.unop(a.etype.element_etype, "*", AST.binop(nil, a, "+", b)) + + assert(self:maybe"]") + end + + return a + elseif precedence == 5 then + local asdf = self.i + + -- Okay for etype to be nil + local etype = self:parse_etype(0) + + if self:maybe"?" then + return AST.unknown(etype) + elseif self:maybe'string' then + return AST.cast(AST.string(self:last().data), etype) + elseif self:maybe"num" then + local tok = self:last().data + local base = tonumber(tok:match"^([0-9]+)r") or 10 + local val = tok:match"r([0-9a-zA-Z]+)$" or tok + return AST.int(etype, tonumber(val, base)) + elseif self:maybe"ident" then + assert(not etype) + + local name = self:last().data + local vreg = self.scope:find(name) + if not vreg then + self.i = old_idx + return nil, "Undeclared variable " .. name + end + + return AST.var(vreg.etype, vreg) + elseif self:maybe"[" then + assert(etype) + + local children = {} + + if not self:maybe"]" then + while true do + local ex = self:parse_expr(0) + assert(ex) + + if etype then + ex = AST.cast(ex, etype.element_etype) + end + + table.insert(children, ex) + + if self:maybe"]" then + break + end + + if not self:maybe"," then + local last = self:last() + self.i = old_idx + return nil, "expected comma", last + end + end + end + + local n = AST.array(etype, 0) + for _, child in ipairs(children) do + table.insert(n.children, child) + end + return n + elseif self:maybe"{" then + local n = AST.func(etype) + while not self:maybe"}" do + local status, err = self:parse_stmt(n) + + if not status then + self.i = old_idx + return nil, err + end + end + return n + end + end + + self.i = old_idx + return nil +end + +function Parser:parse_stmt(chunk) + local old_idx = self.i + + if self:maybe";" then + return true + elseif self:maybe"return" then + local ex = self:parse_expr(0) + + -- ex can be null + + table.insert(chunk.children, AST.ret(ex)) + + return true + elseif self:maybe"if" then + local condition = self:parse_expr(0) + + local lbl_id = NEXT_LABEL_ID + NEXT_LABEL_ID = NEXT_LABEL_ID + 1 + + table.insert(chunk.children, AST.jump(condition, lbl_id)) + + if not self:maybe"{" then + self.i = old_idx + return false, "expected {" + end + + while not self:maybe"}" do + local status, err = self:parse_stmt(chunk) + if not status then + self.i = old_idx + return nil, err + end + end + + table.insert(chunk.children, AST.label(lbl_id)) + + return true + elseif self:peek(0).type == "ident" and self:peek(1).type == "=" then + local name = self:next().data + self:next() + + local expr = self:parse_expr(0) + + local vreg = self.scope:find(name) + if not vreg then + assert(expr.etype) + + vreg = self.scope:add(name) + vreg.etype = expr.etype + else + assert(vreg.etype == expr.etype, "Type mismatch in assignment") + end + + table.insert(chunk.children, AST.assign(AST.var(vreg.etype, vreg), expr)) + + return true + end + + -- Try parsing assignment + + local ex = self:parse_expr(0) + if not ex then + self.i = old_idx + return nil, "expected expression" + end + + if not self:maybe"=" then + self.i = old_idx + return nil, "expected =" + end + + local ex2 = self:parse_expr(0) + if not ex2 then + self.i = old_idx + return nil, "expected expression" + end + + table.insert(chunk.children, AST.assign(ex, AST.cast(ex2, ex.etype))) + + return true +end + +function Parser:parse_etype(precedence) + local old_idx = self.i + + if precedence == 0 then + local ret = self:parse_etype(precedence + 1) + + if not ret then + return nil + end + + while self:maybe"->" do + local a = ret + local b = self:parse_etype(precedence + 1) + + ret = ETypes.func(a, b) + + while self:maybe("mod") do + local mod_type = self:last().data + + if mod_type == "@save" then + if not self:maybe"(" then + local last = self:last() + self.i = old_idx + return nil, "expected (", last + end + + local ex = self:parse_expr(0) + if not ex or ex.kind ~= "expr-array" then + local last = self:last() + self.i = old_idx + return nil, "expected array", last + end + + if not self:maybe")" then + local last = self:last() + self.i = old_idx + return nil, "expected )", last + end + + local items = {} + for _, child in ipairs(ex.children) do + assert(child.kind == "expr-string") + + if not require"target".REGS[child.value] then + self:log("warn", "skipping unknown register " .. child.value) + end + + items[child.value] = true + end + + ret.modifiers["save"] = items + else + self:log("warn", "skipping unknown modifier " .. mod_type) + + if self:maybe"(" then + local depth = 1 + while true do + local t = self:next() + if t.type == "(" then + depth = depth + 1 + elseif t.type == ")" then + depth = depth - 1 + if depth == 0 then + break + end + elseif t.type == "eof" then + break + end + end + end + end + end + end + + return ret + elseif precedence == 1 then + local a = self:parse_etype(precedence + 1) + + if not a then + return nil + end + + while self:maybe"*" or self:maybe"[" do + local old_idx = self.i + + if self:last().type == "*" then + a = ETypes.ref(a) + else + local length_expr = self:parse_expr(0) + + if not length_expr or not self:maybe"]" or (length_expr.kind ~= "expr-unknown" and length_expr.kind ~= "expr-int") then + -- This cannot be an array type, most likely a case like (string[?] [...]) + self.i = old_idx - 1 + break + end + + local length + if length_expr.kind == "expr-int" then + length = length_expr.value + end + + a = ETypes.array(a, length) + end + end + return a + elseif precedence == 2 then + if self:maybe"(" then + if self:maybe")" then + return ETypes.struct({}) + else + self:go_back() + end + end + + if self:maybe"stringkw" then + return ETypes.string() + end + + if not self:maybe"ident" then + self.i = old_idx + return nil + end + + local str = self:last().data + + local ret + if str:match"[ui][0-9]+" then + ret = ETypes.scalar(str:sub(1, 1) == "u", tonumber(str:sub(2))) + elseif str:match"[ui]gpr" then + ret = ETypes.scalar(str:sub(1, 1) == "u", Target.GPR_SIZE) + else + self.i = old_idx + return nil + end + + return ret + end + + self.i = old_idx + return nil +end + +function Parser:go_back() + if self.i <= 1 then + error("Already at the beginning") + end + self.i = self.i - 1 +end + +function Parser:last() + return self.tokens[self.i - 1] +end + +function Parser:peek(idx) + if self.i + idx > #self.tokens then + return {type = "eof"} + end + return self.tokens[self.i + idx] +end + +function Parser:maybe(token_type) + if self.i > #self.tokens then + return false + end + + if self.tokens[self.i].type == token_type then + self.i = self.i + 1 + return true + end + + return false +end + +function Parser:next() + local ret = self:peek(0) + if ret.type ~= "eof" then + self.i = self.i + 1 + end + return ret +end + +function Parser:log(msg_type, fmt, ...) + local tok = self:last() + Logger.log(msg_type, tok.y .. ":" .. tok.x .. ", " .. fmt, ...) +end + +return function(tokens) + return setmetatable({i = 1, tokens = tokens, scope = Scope.new(nil)}, Parser) +end diff --git a/set.lua b/set.lua new file mode 100644 index 0000000..0b31579 --- /dev/null +++ b/set.lua @@ -0,0 +1,41 @@ +local function equal(a, b) + for k in pairs(a) do + if not b[k] then + return false + end + end + for k in pairs(b) do + if not a[k] then + return false + end + end + return true +end + +local function min(s) + local a + for b in pairs(s) do + if not a or b < a then + a = b + end + end + return a +end + +local function max(s) + local a + for b in pairs(s) do + if not a or a < b then + a = b + end + end + return a +end + +local function merge(to, from) + for k in pairs(from) do + to[k] = true + end +end + +return {equal = equal, min = min, max = max, merge = merge} diff --git a/target.lua b/target.lua new file mode 100644 index 0000000..6801af9 --- /dev/null +++ b/target.lua @@ -0,0 +1,67 @@ +local REGS = { + al = {bits = 8, mask = 0x1}, + ah = {bits = 8, mask = 0x2}, + ax = {bits = 16, mask = 0x3}, + eax = {bits = 32, mask = 0x3}, + + bl = {bits = 8, mask = 0x4}, + bh = {bits = 8, mask = 0x8}, + bx = {bits = 16, mask = 0xC}, + ebx = {bits = 32, mask = 0xC}, + + cl = {bits = 8, mask = 0x10}, + ch = {bits = 8, mask = 0x20}, + cx = {bits = 16, mask = 0x30}, + ecx = {bits = 32, mask = 0x30}, + + dl = {bits = 8, mask = 0x40}, + dh = {bits = 8, mask = 0x80}, + dx = {bits = 16, mask = 0xC0}, + edx = {bits = 32, mask = 0xC0}, + + di = {bits = 16, mask = 0x100}, + edi = {bits = 32, mask = 0x100}, + + si = {bits = 16, mask = 0x200}, + esi = {bits = 32, mask = 0x200}, + + bp = {bits = 16, mask = 0x400}, + ebp = {bits = 32, mask = 0x400}, + + sp = {bits = 16, mask = 0x800}, + esp = {bits = 32, mask = 0x800}, + + ds = {bits = 16, mask = 0x100}, + es = {bits = 16, mask = 0x2000}, + fs = {bits = 16, mask = 0x4000}, + gs = {bits = 16, mask = 0x8000}, +} + +local REG_CLASSES = { + reg8 = { + items = {"al", "ah", "bl", "bh", "cl", "ch", "dl", "dh"} + }, + regn8 = { + items = {"eax", "ebx", "ecx", "edx", "ax", "bx", "cx", "dx", "di", "si", "bp", "edi", "esi", "ebp"} + }, + seg = { + items = {"ds", "es", "fs", "gs"} + }, + rmptr = { + items = {"bx", "bp", "di", "si"} + } +} + +for _, reg_class in pairs(REG_CLASSES) do + local mask = 0 + for _, reg in ipairs(reg_class.items) do + mask = mask | REGS[reg].mask + end + reg_class.mask = mask +end + +return { + GPR_SIZE = 32, + REGS = REGS, + REG_CLASSES = REG_CLASSES, +}