/*
 * Code generator: emits NASM-syntax 32-bit x86 assembly on stdout.
 *
 * cg_go() assigns registers to variables by greedy graph colouring and then
 * calls cg_chunk(), which translates statements one at a time by matching
 * them against a small set of patterns.
 */
#include "cg.h"

#include <assert.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define REGS 4

/* Allocatable registers. Column index is log2 of the operand size in bytes
 * (0 = 8-bit, 1 = 16-bit, 2 = 32-bit); there is no 64-bit column. */
static const char *regs[REGS][3] = {{"al", "ax", "eax"}, {"bl", "bx", "ebx"}, {"cl", "cx", "ecx"}, {"dl", "dx", "edx"}};

/* Binary operators that map onto a single two-operand "op dst, src"
 * instruction computing dst = dst op src. Unlisted operators are NULL. */
static const char *BINOP_SIMPLE_INSTRS[] = {[BINOP_ADD] = "add", [BINOP_SUB] = "sub", [BINOP_BITWISE_AND] = "and", [BINOP_BITWISE_OR] = "or", [BINOP_BITWISE_XOR] = "xor"};

/* Counter used to generate unique local labels (.L0, .L1, ...). */
static size_t nextLocalLabel = 0;

/* Stack of enclosing loops: the label at the top of each loop (target of
 * continue) and the label just after it (target of break). */
#define LOOPSTACKSIZE 96
static size_t loopStackStart[LOOPSTACKSIZE];
static size_t loopStackEnd[LOOPSTACKSIZE];
static size_t loopStackIdx;

/* Reports an internal code generator error and aborts. Diagnostics go to
 * stderr because stdout carries the generated assembly. */
static void cg_fatal(const char *msg) {
	fprintf(stderr, "cg: %s\n", msg);
	abort();
}

/* snprintf into a fixed-size operand buffer, aborting instead of silently
 * truncating: a truncated operand would be emitted as wrong assembly. */
static void format_operand(char *buf, size_t size, const char *fmt, ...) {
	va_list ap;
	va_start(ap, fmt);
	int n = vsnprintf(buf, size, fmt, ap);
	va_end(ap);
	if(n < 0 || (size_t) n >= size) {
		cg_fatal("operand text does not fit in its buffer");
	}
}

/* NASM data directive (db/dw/dd/dq) for a size in bytes. */
static const char *direct(int size) {
	switch(size) {
	case 1: return "db";
	case 2: return "dw";
	case 4: return "dd";
	case 8: return "dq";
	}
	abort();
}

/* NASM size specifier (byte/word/dword/qword) for a size in bytes. */
static const char *spec(int size) {
	switch(size) {
	case 1: return "byte";
	case 2: return "word";
	case 4: return "dword";
	case 8: return "qword";
	}
	abort();
}

/* log2 of an operand size in bytes; aborts on sizes other than 1/2/4/8. */
static int log_size(int size) {
	switch(size) {
	case 1: return 0;
	case 2: return 1;
	case 4: return 2;
	case 8: return 3;
	}
	abort();
}

/* Size specifier matching the type of expression e. */
static const char *specexpr(AST *e) {
	return spec(type_size(e->expression.type));
}

/* Register name for a register-allocated variable (in DEBUG builds, its raw
 * colour as "@N"). Returns one of a few rotating static buffers, so several
 * results may be used in one printf call. */
static const char *xv(VarTableEntry *v) {
	assert(v->kind == VARTABLEENTRY_VAR);
#define XVBUFS 8
#define XVBUFSZ 16
	static char bufs[XVBUFS][XVBUFSZ];
	static int bufidx = 0;
	char *ret = bufs[bufidx];
#ifdef DEBUG
	format_operand(ret, XVBUFSZ, "@%i", v->data.var.color);
#else
	int sizeIdx = log_size(type_size(v->type));
	/* The colour indexes regs[] directly and nothing here spills, so more
	 * simultaneously live variables than REGS cannot be emitted. */
	if((size_t) v->data.var.color >= REGS) {
		cg_fatal("variable colour exceeds the number of available registers");
	}
	if((size_t) sizeIdx >= sizeof(regs[0]) / sizeof(regs[0][0])) {
		cg_fatal("variable is too wide to be held in a register");
	}
	format_operand(ret, XVBUFSZ, "%s", regs[v->data.var.color][sizeIdx]);
#endif
	bufidx = (bufidx + 1) % XVBUFS;
	return ret;
}

/* Condition-code suffix for a conditional jump on comparison op.
 * NOTE(review): only == and != are handled; any other operator yields the
 * bogus suffix "wtf" (e.g. "jwtf"). Confirm binop_is_comparison() cannot
 * admit other operators here. */
static const char *xj(BinaryOp op) {
	switch(op) {
	case BINOP_EQUAL: return "e";
	case BINOP_NEQUAL: return "ne";
	default: return "wtf";
	}
}

/* Returns e's variable if e is a plain reference to a register-allocated
 * (VARTABLEENTRY_VAR) variable, else NULL. */
static VarTableEntry *as_register_var(AST *e) {
	if(e->nodeKind == AST_EXPR_VAR && e->exprVar.thing->kind == VARTABLEENTRY_VAR) {
		return e->exprVar.thing;
	}
	return NULL;
}

/* Returns the symbol name if e is `&sym` with sym a VARTABLEENTRY_SYMBOL,
 * else NULL. */
static const char *symbol_ref_name(AST *e) {
	if(e->nodeKind == AST_EXPR_UNARY_OP && e->exprUnOp.operator == UNOP_REF) {
		AST *target = e->exprUnOp.operand;
		if(target->nodeKind == AST_EXPR_VAR && target->exprVar.thing->kind == VARTABLEENTRY_SYMBOL) {
			return target->exprVar.thing->data.symbol.name;
		}
	}
	return NULL;
}

/* Renders expression e as a single x86 operand, or returns NULL if e has no
 * single-operand form. Supported shapes:
 *   var             -> register
 *   sym             -> [sym]
 *   constant        -> "<size> <value>"
 *   &sym            -> sym              (address as immediate)
 *   *var            -> [reg]
 *   *(&sym + var)   -> [sym + reg]
 *   *(&sym + k*var) -> [sym + k * reg]
 * Results live in rotating static buffers (see xv). */
static const char *xop(AST *e) {
#define XOPBUFS 16
#define XOPBUFSZ 256
	static char bufs[XOPBUFS][XOPBUFSZ];
	static int bufidx = 0;
	char *ret = bufs[bufidx];

	if(e->nodeKind == AST_EXPR_VAR) {
		VarTableEntry *v = e->exprVar.thing;
		if(v->kind == VARTABLEENTRY_VAR) {
			return xv(v);
		} else if(v->kind == VARTABLEENTRY_SYMBOL) {
			format_operand(ret, XOPBUFSZ, "[%s]", v->data.symbol.name);
		} else {
			abort();
		}
	} else if(e->nodeKind == AST_EXPR_PRIMITIVE) {
		format_operand(ret, XOPBUFSZ, "%s %i", specexpr(e), e->exprPrim.val);
	} else if(e->nodeKind == AST_EXPR_UNARY_OP && e->exprUnOp.operator == UNOP_REF) {
		const char *name = symbol_ref_name(e);
		if(!name) {
			return NULL;
		}
		format_operand(ret, XOPBUFSZ, "%s", name);
	} else if(e->nodeKind == AST_EXPR_UNARY_OP && e->exprUnOp.operator == UNOP_DEREF) {
		AST *addr = e->exprUnOp.operand;
		VarTableEntry *ptr = as_register_var(addr);
		if(ptr) {
			format_operand(ret, XOPBUFSZ, "[%s]", xv(ptr));
		} else if(addr->nodeKind == AST_EXPR_BINARY_OP && addr->exprBinOp.operator == BINOP_ADD) {
			const char *base = symbol_ref_name(addr->exprBinOp.operands[0]);
			AST *offset = addr->exprBinOp.operands[1];
			if(!base) {
				return NULL;
			}
			VarTableEntry *index = as_register_var(offset);
			if(index) {
				format_operand(ret, XOPBUFSZ, "[%s + %s]", base, xv(index));
			} else if(offset->nodeKind == AST_EXPR_BINARY_OP
					&& offset->exprBinOp.operator == BINOP_MUL
					&& offset->exprBinOp.operands[0]->nodeKind == AST_EXPR_PRIMITIVE
					&& (index = as_register_var(offset->exprBinOp.operands[1])) != NULL) {
				format_operand(ret, XOPBUFSZ, "[%s + %i * %s]", base, offset->exprBinOp.operands[0]->exprPrim.val, xv(index));
			} else {
				return NULL;
			}
		} else {
			return NULL;
		}
	} else {
		return NULL;
	}

	bufidx = (bufidx + 1) % XOPBUFS;
	return ret;
}

/* xop() for operands that must be representable: aborts with a diagnostic
 * instead of handing NULL to printf("%s"), which is undefined behaviour. */
static const char *xop_req(AST *e) {
	const char *s = xop(e);
	if(!s) {
		cg_fatal("unsupported operand expression");
	}
	return s;
}

/* Mnemonic for a binary operator with a direct "op dst, src" form, or NULL. */
static const char *simple_instr(BinaryOp op) {
	size_t n = sizeof(BINOP_SIMPLE_INSTRS) / sizeof(BINOP_SIMPLE_INSTRS[0]);
	return (size_t) op < n ? BINOP_SIMPLE_INSTRS[op] : NULL;
}

void cg_chunk(AST *a);

/* Emits a symbol declaration: extern/global directives plus either its
 * initialised data or a reserved, uninitialised region. */
static void cg_symbol_decl(AST *s) {
	VarTableEntry *v = s->stmtDecl.thing;

	if(v->data.symbol.isExternal) {
		printf("extern %s\n", v->data.symbol.name);
		return;
	}

	if(!v->data.symbol.isLocal) {
		printf("global %s\n", v->data.symbol.name);
	}

	AST *init = s->stmtDecl.expression;
	if(!init) {
		printf("%s resb %lu\n", v->data.symbol.name, (unsigned long) type_size(v->type));
		return;
	}

	printf("%s:", v->data.symbol.name);
	if(v->type->type == TYPE_TYPE_PRIMITIVE) {
		assert(init->nodeKind == AST_EXPR_PRIMITIVE);
		printf("%s %i", direct(type_size(v->type)), init->exprPrim.val);
	} else if(v->type->type == TYPE_TYPE_ARRAY && v->type->array.of->type == TYPE_TYPE_PRIMITIVE) {
		printf("%s ", direct(type_size(v->type->array.of)));
		for(size_t i = 0; i < v->type->array.length; i++) {
			/* Comma-separate without a trailing comma. */
			printf("%s%i", i ? "," : "", init->exprArray.items[i]->exprPrim.val);
		}
	} else {
		/* NOTE(review): looks like a placeholder for unsupported initialiser
		 * types; the output is not the intended data. */
		printf("A");
	}
	putchar('\n');
}

/* Emits an assignment, trying single-instruction forms before plain mov:
 *   x = x +/- 1          -> inc/dec
 *   v = v1 + v2          -> lea v, [v1 + v2]   (all register variables)
 *   v = &sym + v1        -> lea v, [sym + v1]
 *   x = x op y           -> op x, y            (op in BINOP_SIMPLE_INSTRS)
 *   otherwise            -> mov x, y */
static void cg_assign(AST *s) {
	AST *dst = s->stmtAssign.what;
	AST *src = s->stmtAssign.to;

	if(src->nodeKind == AST_EXPR_BINARY_OP
			&& ast_expression_equal(dst, src->exprBinOp.operands[0])
			&& (src->exprBinOp.operator == BINOP_ADD || src->exprBinOp.operator == BINOP_SUB)
			&& src->exprBinOp.operands[1]->nodeKind == AST_EXPR_PRIMITIVE
			&& src->exprBinOp.operands[1]->exprPrim.val == 1) {
		static const char *instrs[] = {"inc", "dec"};
		printf("%s %s %s\n", instrs[src->exprBinOp.operator == BINOP_SUB], specexpr(dst), xop_req(dst));
		return;
	}

	if(dst->nodeKind == AST_EXPR_VAR && src->nodeKind == AST_EXPR_BINARY_OP && src->exprBinOp.operator == BINOP_ADD) {
		AST *lhs = src->exprBinOp.operands[0];
		VarTableEntry *rhsReg = as_register_var(src->exprBinOp.operands[1]);
		VarTableEntry *lhsReg = as_register_var(lhs);
		const char *base = symbol_ref_name(lhs);

		if(lhsReg && rhsReg) {
			printf("lea %s, [%s + %s]\n", xv(dst->exprVar.thing), xv(lhsReg), xv(rhsReg));
			return;
		}
		if(base && rhsReg) {
			printf("lea %s, [%s + %s]\n", xv(dst->exprVar.thing), base, xv(rhsReg));
			return;
		}
	}

	if(src->nodeKind == AST_EXPR_BINARY_OP && ast_expression_equal(dst, src->exprBinOp.operands[0])) {
		const char *instr = simple_instr(src->exprBinOp.operator);
		if(instr) {
			printf("%s %s, %s\n", instr, xop_req(dst), xop_req(src->exprBinOp.operands[1]));
			return;
		}
	}

	printf("mov %s, %s\n", xop_req(dst), xop_req(src));
}

/* Emits an infinite loop; break/continue inside the body jump to the labels
 * pushed on the loop stack. */
static void cg_loop(AST *s) {
	if(loopStackIdx >= LOOPSTACKSIZE) {
		cg_fatal("loops nested too deeply");
	}

	size_t lblStart = nextLocalLabel++;
	size_t lblEnd = nextLocalLabel++;

	loopStackStart[loopStackIdx] = lblStart;
	loopStackEnd[loopStackIdx] = lblEnd;
	loopStackIdx++;

	printf(".L%zu:\n", lblStart);
	cg_chunk(s->stmtLoop.body);
	printf("jmp .L%zu\n", lblStart);
	printf(".L%zu:\n", lblEnd);

	loopStackIdx--;
}

/* Emits an if statement: compare, jump past the body on the opposite
 * condition, then the body. */
static void cg_if(AST *s) {
	AST *cond = s->stmtIf.expression;
	assert(cond->nodeKind == AST_EXPR_BINARY_OP && binop_is_comparison(cond->exprBinOp.operator));

	size_t lblSkip = nextLocalLabel++;

	printf("cmp %s %s, %s\n", specexpr(cond->exprBinOp.operands[0]), xop_req(cond->exprBinOp.operands[0]), xop_req(cond->exprBinOp.operands[1]));
	printf("j%s .L%zu\n", xj(binop_comp_opposite(cond->exprBinOp.operator)), lblSkip);
	cg_chunk(s->stmtIf.then);
	printf(".L%zu:\n", lblSkip);
}

/* Emits a call whose result is discarded: eax/ecx/edx are saved around it,
 * arguments are pushed right to left and popped by the caller.
 * NOTE(review): ebx is also allocatable but not saved here, presumably
 * relying on it being callee-saved (cdecl) — confirm. */
static void cg_call(AST *e) {
	puts("push eax");
	puts("push ecx");
	puts("push edx");

	int argCount = e->exprCall.what->expression.type->function.argCount;
	size_t argSize = 0;
	for(int i = argCount - 1; i >= 0; i--) {
		printf("push %s\n", xop_req(e->exprCall.args[i]));
		/* Round each argument up to a multiple of 4 bytes. */
		argSize += (type_size(e->exprCall.args[i]->expression.type) + 3) & ~3;
	}

	assert(e->exprCall.what->nodeKind == AST_EXPR_VAR && e->exprCall.what->exprVar.thing->kind == VARTABLEENTRY_SYMBOL);
	printf("call %s\n", e->exprCall.what->exprVar.thing->data.symbol.name);
	printf("add esp, %zu\n", argSize);

	puts("pop edx");
	puts("pop ecx");
	puts("pop eax");
}

/* Emits assembly for every statement of chunk a, in order. Statements kinds
 * not handled below produce no output. */
void cg_chunk(AST *a) {
	for(AST *s = a->chunk.statementFirst; s; s = s->statement.next) {
		if(s->nodeKind == AST_STMT_EXT_SECTION) {
			Token t = s->stmtExtSection.name;
			printf("section %.*s\n", (int) t.length, t.content);
		} else if(s->nodeKind == AST_STMT_EXT_ORG) {
			printf("org %lu\n", (unsigned long) s->stmtExtOrg.val);
		} else if(s->nodeKind == AST_STMT_EXT_ALIGN) {
			uint32_t val = s->stmtExtAlign.val;
			if(val & (val - 1)) {
				/* nasm does not support non-power-of-two alignments, so pad
				 * manually up to the next multiple of val. */
				printf("times ($ - $$ + %u) / %u * %u - ($ - $$) db 0\n", val - 1, val, val);
			} else {
				printf("align %u\n", val);
			}
		} else if(s->nodeKind == AST_STMT_DECL && s->stmtDecl.thing->kind == VARTABLEENTRY_SYMBOL) {
			cg_symbol_decl(s);
		} else if(s->nodeKind == AST_STMT_ASSIGN) {
			cg_assign(s);
		} else if(s->nodeKind == AST_STMT_LOOP) {
			cg_loop(s);
		} else if(s->nodeKind == AST_STMT_BREAK || s->nodeKind == AST_STMT_CONTINUE) {
			if(loopStackIdx == 0) {
				cg_fatal("break/continue outside of a loop");
			}
			const size_t *targets = s->nodeKind == AST_STMT_BREAK ? loopStackEnd : loopStackStart;
			printf("jmp .L%zu\n", targets[loopStackIdx - 1]);
		} else if(s->nodeKind == AST_STMT_IF) {
			cg_if(s);
		} else if(s->nodeKind == AST_STMT_EXPR) {
			AST *e = s->stmtExpr.expr;
			if(e->nodeKind == AST_EXPR_CALL) {
				cg_call(e);
			}
		}
	}
}

/* qsort comparator: ascending by degree * priority. Uses comparisons rather
 * than subtraction, which could overflow, wrap for unsigned fields, or
 * truncate for floating-point ones. */
static int comparator(const void *A, const void *B) {
	const VarTableEntry *va = *(VarTableEntry *const *) A;
	const VarTableEntry *vb = *(VarTableEntry *const *) B;
	if(va->data.var.degree * va->data.var.priority < vb->data.var.degree * vb->data.var.priority) {
		return -1;
	}
	if(va->data.var.degree * va->data.var.priority > vb->data.var.degree * vb->data.var.priority) {
		return 1;
	}
	return 0;
}

/* True if v1's live range (first to last use/def) starts or ends inside
 * v2's. Presumably ast_stmt_is_after(a, x, y) == 1 means x comes strictly
 * after y — confirm; if so, ranges sharing only an endpoint do not count. */
static int starts_or_ends_inside(AST *a, VarTableEntry *v1, VarTableEntry *v2) {
	return (ast_stmt_is_after(a, v1->data.var.usedefFirst->stmt, v2->data.var.usedefFirst->stmt) == 1
			&& ast_stmt_is_after(a, v2->data.var.usedefLast->stmt, v1->data.var.usedefFirst->stmt) == 1)
		|| (ast_stmt_is_after(a, v1->data.var.usedefLast->stmt, v2->data.var.usedefFirst->stmt) == 1
			&& ast_stmt_is_after(a, v2->data.var.usedefLast->stmt, v1->data.var.usedefLast->stmt) == 1);
}

/* Smallest colour c >= 0 that no neighbour of v (per the edge list) holds. */
static int first_free_color(VarTableEntry *v, VarTableEntry *(*adjs)[2], size_t adjCount) {
	for(int c = 0;; c++) {
		int taken = 0;
		for(size_t k = 0; k < adjCount && !taken; k++) {
			if(adjs[k][0] == v && adjs[k][1]->data.var.color == c) {
				taken = 1;
			} else if(adjs[k][1] == v && adjs[k][0]->data.var.color == c) {
				taken = 1;
			}
		}
		if(!taken) {
			return c;
		}
	}
}

/* Register-allocates the variables of chunk a, emits its code, then frees
 * a->chunk.vars (the field is reset so it cannot be used after free).
 *
 * Allocation is greedy graph colouring in the spirit of Welsh-Powell:
 * build the interference graph from overlapping live ranges, compute
 * degrees, sort, and give each variable the smallest colour unused by its
 * neighbours.
 * NOTE(review): classic Welsh-Powell visits vertices in descending degree
 * order; comparator sorts ascending by degree * priority — confirm intended.
 * NOTE(review): neighbours not yet coloured are compared by whatever colour
 * value they currently hold; if that starts at 0 rather than a sentinel,
 * colour 0 may be skipped unnecessarily. */
void cg_go(AST *a) {
	typedef VarTableEntry *Adjacency[2];

	VarTableEntry **vars = a->chunk.vars;
	size_t varCount = a->chunk.varCount;

	/* Interference edges; each unordered pair of variables appears once. */
	Adjacency *adjs = NULL;
	size_t adjCount = 0;
	size_t adjCap = 0;

	for(size_t i = 0; i < varCount; i++) {
		for(size_t j = i + 1; j < varCount; j++) {
			VarTableEntry *v1 = vars[i];
			VarTableEntry *v2 = vars[j];

			/* 1D interval intersection, tested in both directions so that
			 * containment in either range is detected. */
			if(!starts_or_ends_inside(a, v1, v2) && !starts_or_ends_inside(a, v2, v1)) {
				continue;
			}

			if(adjCount == adjCap) {
				adjCap = adjCap ? adjCap * 2 : 16;
				Adjacency *grown = realloc(adjs, sizeof(*adjs) * adjCap);
				if(!grown) {
					cg_fatal("out of memory");
				}
				adjs = grown;
			}
			adjs[adjCount][0] = v1;
			adjs[adjCount][1] = v2;
			adjCount++;
		}
	}

	for(size_t k = 0; k < adjCount; k++) {
		adjs[k][0]->data.var.degree++;
		adjs[k][1]->data.var.degree++;
	}

	if(varCount > 0) {
		qsort(vars, varCount, sizeof(*vars), comparator);
	}

	for(size_t i = 0; i < varCount; i++) {
		vars[i]->data.var.color = first_free_color(vars[i], adjs, adjCount);
	}

	free(adjs);

	cg_chunk(a);

	free(vars);
	a->chunk.vars = NULL;
	a->chunk.varCount = 0;
}