#include "ast.h"

#include <assert.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Printable name of every AST node kind, indexed by nodeKind. */
const char *AST_KIND_STR[] = { AST_KINDS(GEN_STRI) };

/*
 * Depth-first walk over the AST rooted at *nptr.
 *
 * preHandler runs on a node before its children and may replace the node by
 * writing through nptr; the children of the (possibly replaced) node are then
 * visited. postHandler runs after the children. Either handler may be NULL.
 *
 * stmt/stmtPrev are the statement currently being visited and its predecessor
 * inside chu (the innermost enclosing chunk); tlc is the top-level chunk.
 * Entering an AST_EXPR_FUNC restarts chu and tlc at the function's own chunk
 * and clears stmt/stmtPrev. ud is handed to the handlers untouched.
 * Unknown node kinds abort().
 */
void generic_visitor(AST **nptr, AST *stmt, AST *stmtPrev, AST *chu, AST *tlc, void *ud, GenericVisitorHandler preHandler, GenericVisitorHandler postHandler)
{
	if(preHandler) preHandler(nptr, stmt, stmtPrev, chu, tlc, ud);

	/* Read only after preHandler, which may have replaced the node. */
	AST *n = *nptr;

	switch(n->nodeKind) {
	case AST_CHUNK: {
		AST *sPrev = NULL;
		AST **s = &n->chunk.statementFirst;
		while(*s) {
			generic_visitor(s, *s, sPrev, n, tlc, ud, preHandler, postHandler);
			/* Re-read *s: the handlers may have replaced the statement. */
			sPrev = *s;
			s = &sPrev->statement.next;
		}
		break;
	}
	case AST_STMT_ASSIGN:
		generic_visitor(&n->stmtAssign.what, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		if(n->stmtAssign.to) {
			generic_visitor(&n->stmtAssign.to, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		}
		break;
	case AST_STMT_IF:
		generic_visitor(&n->stmtIf.expression, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		generic_visitor(&n->stmtIf.then, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		break;
	case AST_STMT_LOOP:
		generic_visitor(&n->stmtLoop.body, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		break;
	case AST_STMT_DECL:
		if(n->stmtDecl.expression) {
			generic_visitor(&n->stmtDecl.expression, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		}
		break;
	case AST_STMT_EXPR:
		generic_visitor(&n->stmtExpr.expr, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		break;
	case AST_STMT_RETURN:
		if(n->stmtReturn.val) {
			generic_visitor(&n->stmtReturn.val, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		}
		break;
	case AST_EXPR_BINARY_OP:
		generic_visitor(&n->exprBinOp.operands[0], stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		generic_visitor(&n->exprBinOp.operands[1], stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		break;
	case AST_EXPR_CALL:
		generic_visitor(&n->exprCall.what, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		/* The argument count comes from the callee's function type. */
		for(size_t i = 0; i < n->exprCall.what->expression.type->function.argCount; i++) {
			generic_visitor(&n->exprCall.args[i], stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		}
		break;
	case AST_EXPR_CAST:
		generic_visitor(&n->exprCast.what, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		break;
	case AST_EXPR_FUNC:
		/* A function body is its own top-level chunk. */
		generic_visitor(&n->exprFunc.chunk, NULL, NULL, n->exprFunc.chunk, n->exprFunc.chunk, ud, preHandler, postHandler);
		break;
	case AST_EXPR_UNARY_OP:
		generic_visitor(&n->exprUnOp.operand, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		break;
	case AST_EXPR_ARRAY:
		assert(n->expression.type->type == TYPE_TYPE_ARRAY);
		assert(n->expression.type->array.length != 0);
		for(size_t i = 0; i < n->expression.type->array.length; i++) {
			generic_visitor(&n->exprArray.items[i], stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		}
		break;
	case AST_EXPR_DOT:
		generic_visitor(&n->exprDot.a, stmt, stmtPrev, chu, tlc, ud, preHandler, postHandler);
		break;
	/* Nodes without visitable children. */
	case AST_STMT_BREAK:
	case AST_STMT_CONTINUE:
	case AST_STMT_EXT_ALIGN:
	case AST_STMT_EXT_ORG:
	case AST_STMT_EXT_SECTION:
	case AST_EXPR_VAR:
	case AST_EXPR_STACK_POINTER:
	case AST_EXPR_PRIMITIVE:
	case AST_EXPR_STRING_LITERAL:
	case AST_EXPR_EXT_SALLOC:
		break;
	default:
		abort();
	}

	if(postHandler) postHandler(nptr, stmt, stmtPrev, chu, tlc, ud);
}

/*
 * Structural equality of two expressions.
 *
 * Variables are equal when they refer to the same VarTableEntry; operators,
 * casts and primitive values are compared recursively. Node kinds not listed
 * below are reported through stahp_node().
 */
int ast_expression_equal(AST *a, AST *b)
{
	if(a->nodeKind != b->nodeKind) return 0;

	switch(a->nodeKind) {
	case AST_EXPR_PRIMITIVE:
		return a->exprPrim.val == b->exprPrim.val;
	case AST_EXPR_VAR:
		return a->exprVar.thing == b->exprVar.thing;
	case AST_EXPR_UNARY_OP:
		return a->exprUnOp.operator == b->exprUnOp.operator
			&& ast_expression_equal(a->exprUnOp.operand, b->exprUnOp.operand);
	case AST_EXPR_BINARY_OP:
		return a->exprBinOp.operator == b->exprBinOp.operator
			&& ast_expression_equal(a->exprBinOp.operands[0], b->exprBinOp.operands[0])
			&& ast_expression_equal(a->exprBinOp.operands[1], b->exprBinOp.operands[1]);
	case AST_EXPR_STACK_POINTER:
		return 1;
	case AST_EXPR_CAST:
		return ast_expression_equal(a->exprCast.what, b->exprCast.what)
			&& type_equal(a->exprCast.to, b->exprCast.to)
			&& a->exprCast.reinterpretation == b->exprCast.reinterpretation;
	default:
		stahp_node(a, "ast_expression_equal: unhandled %s", AST_KIND_STR[a->nodeKind]);
		/* Never fall off the end of a non-void function, even if stahp_node returns. */
		abort();
	}
}

/*
 * Recursive worker for ast_stmt_is_after(). isTop is nonzero only for the
 * chunk the caller passed in. NULL as s1/s2 stands for "the end of that
 * outermost chunk", so reaching the end of a *nested* body must not match it.
 */
static int stmt_is_after_rec(const AST *chunk, const AST *s1, const AST *s2, int isTop)
{
	const AST *s = chunk->chunk.statementFirst;
	while(1) {
		/* A loop's body is searched before the loop statement itself is compared. */
		if(s && s->nodeKind == AST_STMT_LOOP) {
			int i = stmt_is_after_rec(s->stmtLoop.body, s1, s2, 0);
			if(i != -1) return i;
		}

		if(!s && !isTop) break;

		if(s == s1) return 0;
		if(s == s2) return 1;

		if(!s) break;

		if(s->nodeKind == AST_STMT_IF) {
			int i = stmt_is_after_rec(s->stmtIf.then, s1, s2, 0);
			if(i != -1) return i;
		}

		s = s->statement.next;
	}
	return -1;
}

/*
 * Relative order of two statements within chunk, found by an in-order walk
 * that descends into loop bodies and if-bodies. Returns one of three values:
 *    1  s2 is reached first (s1 comes after s2),
 *    0  s1 is reached first,
 *   -1  neither is reached (unknown).
 * A NULL s1 or s2 stands for the end of chunk. For example,
 * ast_stmt_is_after(chunk, NULL, s) == 1 exactly when s occurs anywhere
 * inside chunk, including inside nested bodies.
 */
int ast_stmt_is_after(const AST *chunk, const AST *s1, const AST *s2)
{
	return stmt_is_after_rec(chunk, s1, s2, 1);
}

/*
 * This pass is necessary for optimization and register allocation.
 * Because an AST may hold outdated UD-chains, this pass MUST be run
 * before the UD-chains are used, to make sure they are valid.
 *
 * Each local variable (VTE of kind VARTABLEENTRY_VAR) holds its own UD-chain
 * that records, for every use:
 *   1. the node where the variable is used,
 *   2. the whole statement containing that use, and
 *   3. a definition that *might* be in effect at that use.
 *
 * Because several definitions may reach a single use, one UseDef is
 * appended to the chain for each reaching definition.
 *
 * Reaching definitions are tracked in a ReachingDefs, also held by each VTE.
 * In a single straight-line block of code, exactly one definition (possibly
 * "undefined") reaches each variable, which would reduce ReachingDefs to a
 * single definition pointer.
 *
 * Unfortunately, conditional blocks and loops ruin this simplicity.
 * With code like
 *     x = A
 *     if B {
 *         x = C
 *     }
 * two definitions may apply to x afterwards.
 *
 * The solution is to chain ReachingDefs, each with an optional parent.
 * Entering a block pushes an empty ReachingDefs whose parent is the enclosing
 * block's. A definition replaces the definitions of the innermost
 * ReachingDefs only; an empty ReachingDefs defers to its parent.
 *
 * How a block is exited depends on its kind. For a conditional block, the
 * block's definitions are added to the parent's (mergedefs), since either set
 * may reach the code after it. Loops are even worse. Given
 *     x = A
 *     loop {
 *         use x
 *         x = B
 *     }
 * definitions can apply to uses that come *before* them (mergedefsloop).
 *
 * A different case:
 *     x = A
 *     loop {
 *         use x
 *         y = B
 *     }
 * Because, technically, the last use of x is before y = B, x and y may be
 * assigned the same physical location, corrupting data as a result.
 * To prevent this, fake "useless" statements are inserted during parsing so
 * that the AST looks like:
 *     x = A
 *     loop {
 *         use x
 *         y = B
 *     }
 *     x;
 * Until dead code removal is implemented, this will not be a problem.
 */

/* Append ud to the end of vte's UD-chain. */
static void rawadduse(VarTableEntry *vte, UseDef *ud)
{
	assert(vte->kind == VARTABLEENTRY_VAR);
	assert(ud->next == NULL);
	assert(!!vte->data.var.usedefFirst == !!vte->data.var.usedefLast);

	if(!vte->data.var.usedefFirst) {
		vte->data.var.usedefFirst = vte->data.var.usedefLast = ud;
	} else {
		vte->data.var.usedefLast->next = ud;
		vte->data.var.usedefLast = ud;
	}
}

/*
 * Record that node `use`, inside statement `whole`, reads vte: one UseDef per
 * currently reaching definition. The reaching set is the innermost non-empty
 * ReachingDefs; empty levels (nothing defined in that block yet) defer to the
 * enclosing block. If no level has a definition, nothing is recorded.
 */
static void adduse(VarTableEntry *vte, AST *use, AST *whole)
{
	assert(vte->kind == VARTABLEENTRY_VAR);
	assert(vte->data.var.reachingDefs != NULL);

	ReachingDefs *rd = vte->data.var.reachingDefs;
	while(rd && rd->defCount == 0) rd = rd->parent;
	if(!rd) return;

	for(size_t d = 0; d < rd->defCount; d++) {
		UseDef *ud = calloc(1, sizeof(*ud));
		ud->def = rd->defs[d];
		ud->use = use;
		ud->stmt = whole;
		ud->next = NULL;
		rawadduse(vte, ud);
	}
}

/*
 * A plain assignment kills every definition reaching in the innermost block
 * and replaces them with `def`.
 */
static void overwritedefs(VarTableEntry *vte, AST *def)
{
	assert(vte->kind == VARTABLEENTRY_VAR);

	if(!vte->data.var.reachingDefs) {
		vte->data.var.reachingDefs = calloc(1, sizeof(*vte->data.var.reachingDefs));
	}

	vte->data.var.reachingDefs->defCount = 1;

	/* A non-NULL defs buffer always has room for at least one entry. */
	if(!vte->data.var.reachingDefs->defs) {
		vte->data.var.reachingDefs->defs = calloc(1, sizeof(*vte->data.var.reachingDefs->defs));
	}

	vte->data.var.reachingDefs->defs[0] = def;
}

/* Nonzero if def is already in rd's set. */
static int rd_has_def(const ReachingDefs *rd, const AST *def)
{
	for(size_t i = 0; i < rd->defCount; i++) {
		if(rd->defs[i] == def) return 1;
	}
	return 0;
}

/* Add every definition of src to dst, skipping ones dst already has. */
static void rd_append(ReachingDefs *dst, const ReachingDefs *src)
{
	for(size_t i = 0; i < src->defCount; i++) {
		if(rd_has_def(dst, src->defs[i])) continue;

		void *grown = realloc(dst->defs, sizeof(*dst->defs) * (dst->defCount + 1));
		if(!grown) abort();
		dst->defs = grown;
		dst->defs[dst->defCount++] = src->defs[i];
	}
}

/*
 * Leave a conditional block: the innermost ReachingDefs is popped and its
 * definitions are joined into the parent's. Either set may reach the code
 * after the block.
 */
static void mergedefs(VarTableEntry *vte)
{
	assert(vte->kind == VARTABLEENTRY_VAR);

	ReachingDefs *child = vte->data.var.reachingDefs;
	assert(child != NULL);
	assert(child->parent != NULL);
	ReachingDefs *parent = child->parent;

	if(child->defCount != 0) {
		if(parent->defCount == 0) {
			/*
			 * An empty parent defers to its ancestors. Materialize what it
			 * inherits first, otherwise the "block not taken" path would be
			 * lost once the parent becomes non-empty.
			 */
			const ReachingDefs *anc = parent->parent;
			while(anc && anc->defCount == 0) anc = anc->parent;
			if(anc) rd_append(parent, anc);
		}
		rd_append(parent, child);
	}

	vte->data.var.reachingDefs = parent;
	free(child->defs);
	free(child);
}

/* Enter a block: push an empty ReachingDefs whose parent is the current one. */
static void pushdefs(VarTableEntry *vte)
{
	assert(vte->kind == VARTABLEENTRY_VAR);

	ReachingDefs *rdefs = calloc(1, sizeof(*rdefs));
	rdefs->defCount = 0;
	rdefs->defs = NULL;
	rdefs->excludeParent = 0;
	rdefs->parent = vte->data.var.reachingDefs;

	vte->data.var.reachingDefs = rdefs;
}

/* pushdefs() for every variable of the top-level chunk. */
static void pushdefsall(AST *tlc)
{
	for(size_t i = 0; i < tlc->chunk.varCount; i++) {
		pushdefs(tlc->chunk.vars[i]);
	}
}

/* mergedefs() for every variable of the top-level chunk. */
static void mergedefsall(AST *tlc)
{
	for(size_t i = 0; i < tlc->chunk.varCount; i++) {
		mergedefs(tlc->chunk.vars[i]);
	}
}

/* Nonzero if vte's UD-chain already links `use` to `def`. */
static int usedef_exists(const VarTableEntry *vte, const AST *use, const AST *def)
{
	for(const UseDef *ud = vte->data.var.usedefFirst; ud; ud = ud->next) {
		if(ud->use == use && ud->def == def) return 1;
	}
	return 0;
}

/*
 * Leave a loop. Definitions that reach the end of the body also reach, via
 * the back edge, every use inside the body. Each such use therefore gets a
 * UseDef for each of them that it does not already have. The loop's
 * ReachingDefs is then merged like a conditional block.
 */
static void mergedefsloop(AST *tlc, VarTableEntry *vte, AST *daLoopStmt)
{
	(void) tlc;
	assert(vte->kind == VARTABLEENTRY_VAR);

	ReachingDefs *rd = vte->data.var.reachingDefs;

	for(size_t d = 0; d < rd->defCount; d++) {
		AST *def = rd->defs[d];

		for(UseDef *ud = vte->data.var.usedefFirst; ud; ud = ud->next) {
			if(ast_stmt_is_after(daLoopStmt->stmtLoop.body, NULL, ud->stmt) != 1) continue; /* use is outside the loop */
			if(usedef_exists(vte, ud->use, def)) continue;

			UseDef *link = calloc(1, sizeof(*link));
			link->def = def;
			link->use = ud->use;
			link->stmt = ud->stmt;
			link->next = ud->next;
			ud->next = link;

			if(link->next == NULL) {
				vte->data.var.usedefLast = link;
			}

			/* Skip the node just inserted; it already pairs this use with def. */
			ud = link;
		}
	}

	mergedefs(vte);
}

/* mergedefsloop() for every variable of the top-level chunk. */
static void mergedefsloopall(AST *tlc, AST *daLoopStmt)
{
	for(size_t i = 0; i < tlc->chunk.varCount; i++) {
		mergedefsloop(tlc, tlc->chunk.vars[i], daLoopStmt);
	}
}

/*
 * Walk node a (inside statement wholestmt), recording definitions and uses of
 * the variables of tlc. Unknown node kinds abort().
 */
static void ast_usedef_pass(AST *tlc, AST *a, AST *wholestmt)
{
	switch(a->nodeKind) {
	case AST_CHUNK:
		for(AST *s = a->chunk.statementFirst; s; s = s->statement.next) {
			ast_usedef_pass(tlc, s, s);
		}
		break;
	case AST_STMT_IF:
		pushdefsall(tlc);
		ast_usedef_pass(tlc, a->stmtIf.expression, wholestmt);
		ast_usedef_pass(tlc, a->stmtIf.then, wholestmt);
		mergedefsall(tlc);
		break;
	case AST_STMT_LOOP:
		pushdefsall(tlc);
		ast_usedef_pass(tlc, a->stmtLoop.body, wholestmt);
		mergedefsloopall(tlc, a);
		break;
	case AST_STMT_ASSIGN: {
		/*
		 * The right-hand side reads the definitions reaching *before* this
		 * assignment, so it is visited before the new definition is recorded.
		 * Otherwise `x = x + 1` would link its read of x to itself.
		 */
		if(a->stmtAssign.to) {
			ast_usedef_pass(tlc, a->stmtAssign.to, wholestmt);
		}

		AST *what = a->stmtAssign.what;
		if(what->nodeKind == AST_EXPR_VAR && what->exprVar.thing->kind == VARTABLEENTRY_VAR) {
			overwritedefs(what->exprVar.thing, a);
		}

		/* For a plain variable target, this records the assignment as a use of its own definition. */
		ast_usedef_pass(tlc, what, wholestmt);
		break;
	}
	case AST_STMT_EXPR:
		ast_usedef_pass(tlc, a->stmtExpr.expr, wholestmt);
		break;
	case AST_EXPR_VAR:
		if(a->exprVar.thing->kind == VARTABLEENTRY_VAR) {
			adduse(a->exprVar.thing, a, wholestmt);
		}
		break;
	case AST_EXPR_BINARY_OP:
		ast_usedef_pass(tlc, a->exprBinOp.operands[0], wholestmt);
		ast_usedef_pass(tlc, a->exprBinOp.operands[1], wholestmt);
		break;
	case AST_EXPR_UNARY_OP:
		ast_usedef_pass(tlc, a->exprUnOp.operand, wholestmt);
		break;
	case AST_EXPR_CALL:
		ast_usedef_pass(tlc, a->exprCall.what, wholestmt);
		for(size_t p = 0; p < a->exprCall.what->expression.type->function.argCount; p++) {
			ast_usedef_pass(tlc, a->exprCall.args[p], wholestmt);
		}
		break;
	case AST_EXPR_CAST:
		ast_usedef_pass(tlc, a->exprCast.what, wholestmt);
		break;
	case AST_STMT_DECL:
		/* NOTE(review): the initializer is not visited and no definition is recorded here; presumably variable declarations are lowered to assignments before this pass — confirm. */
		assert(a->stmtDecl.thing->kind != VARTABLEENTRY_VAR || a->stmtDecl.expression);
		break;
	case AST_STMT_RETURN:
		if(a->stmtReturn.val) {
			ast_usedef_pass(tlc, a->stmtReturn.val, wholestmt);
		}
		break;
	/* Nodes that neither define nor use variables. */
	case AST_EXPR_PRIMITIVE:
	case AST_EXPR_STRING_LITERAL:
	case AST_EXPR_STACK_POINTER:
	case AST_EXPR_EXT_SALLOC:
	case AST_STMT_BREAK:
	case AST_STMT_CONTINUE:
	case AST_STMT_EXT_ALIGN:
	case AST_STMT_EXT_ORG:
	case AST_STMT_EXT_SECTION:
		break;
	default:
		abort();
	}
}

/*
 * Rebuild the UD-chains of every variable of chunk chu from scratch.
 * NOTE(review): previous ReachingDefs/UseDef nodes are dropped, not freed.
 */
void ast_usedef_reset(AST *chu)
{
	for(size_t i = 0; i < chu->chunk.varCount; i++) {
		VarTableEntry *vte = chu->chunk.vars[i];

		assert(vte->kind == VARTABLEENTRY_VAR);

		vte->data.var.reachingDefs = NULL;
		vte->data.var.usedefFirst = NULL;
		vte->data.var.usedefLast = NULL;
	}

	pushdefsall(chu);

	ast_usedef_pass(chu, chu, NULL);
}

/* Append b to heap string a (which may be NULL); returns the new string. */
static char *cat(char *a, const char *b)
{
	if(!a) {
		return strdup(b);
	}

	size_t la = strlen(a);
	size_t lb = strlen(b);

	char *r = realloc(a, la + lb + 1);
	if(!r) abort();

	memcpy(r + la, b, lb + 1);
	return r;
}

/* printf into a freshly malloc'd string. Aborts on formatting or allocation failure. */
__attribute__((format(printf, 1, 2))) char *malp(const char *fmt, ...)
{
	va_list measure, emit;

	va_start(measure, fmt);
	/* va_copy already initializes `emit`; calling va_start on it again would be UB. */
	va_copy(emit, measure);

	int len = vsnprintf(NULL, 0, fmt, measure);
	va_end(measure);

	if(len < 0) {
		va_end(emit);
		abort();
	}

	char *str = malloc((size_t) len + 1);
	if(!str) {
		va_end(emit);
		abort();
	}

	vsnprintf(str, (size_t) len + 1, fmt, emit);
	va_end(emit);

	return str;
}

/* Human-readable, heap-allocated rendering of a type ("@unimp" for unhandled kinds). */
char *type_to_string(Type *t)
{
	if(t->type == TYPE_TYPE_PRIMITIVE) {
		char prefix = t->primitive.isFloat ? 'f' : (t->primitive.isUnsigned ? 'u' : 'i');
		return malp("%c%i", prefix, t->primitive.width);
	} else if(t->type == TYPE_TYPE_POINTER) {
		char *c = type_to_string(t->pointer.of);
		char *r = malp("%s*", c);
		free(c);
		return r;
	} else if(t->type == TYPE_TYPE_RECORD) {
		return malp("%s", t->record.name);
	} else if(t->type == TYPE_TYPE_GENERIC) {
		return malp("%s", t->generic.paramName);
	} else if(t->type == TYPE_TYPE_ARRAY) {
		char *of = type_to_string(t->array.of);
		char *len = t->array.lengthIsGeneric ? malp("") : malp("%i", t->array.length);
		char *r = malp("%s[%s]", of, len);
		free(of);
		free(len);
		return r;
	}
	return strdup("@unimp");
}

char *ast_dump(AST *tlc);

/* Heap-allocated source-like rendering of an expression ("@unimp:<kind>" for unhandled kinds). */
static char *ast_dumpe(AST *e)
{
	if(e->nodeKind == AST_EXPR_PRIMITIVE) {
		return malp("%i", e->exprPrim.val);
	} else if(e->nodeKind == AST_EXPR_VAR) {
		VarTableEntry *vte = e->exprVar.thing;

		if(vte->kind == VARTABLEENTRY_VAR) {
			return strdup(vte->data.var.name);
		} else if(vte->kind == VARTABLEENTRY_SYMBOL) {
			return strdup(vte->data.symbol.name);
		}
		abort();
	} else if(e->nodeKind == AST_EXPR_UNARY_OP) {
		const char *op = NULL;
		switch(e->exprUnOp.operator) {
		case UNOP_REF:         op = "&"; break;
		case UNOP_DEREF:       op = "*"; break;
		case UNOP_BITWISE_NOT: op = "~"; break;
		case UNOP_NEGATE:      op = "-"; break;
		default: abort();
		}
		char *c = ast_dumpe(e->exprUnOp.operand);
		char *r = malp("%s%s", op, c);
		free(c);
		return r;
	} else if(e->nodeKind == AST_EXPR_BINARY_OP) {
		char *a = ast_dumpe(e->exprBinOp.operands[0]);
		char *b = ast_dumpe(e->exprBinOp.operands[1]);
		const char *op;
		switch(e->exprBinOp.operator) {
		case BINOP_ADD:         op = "+"; break;
		case BINOP_SUB:         op = "-"; break;
		case BINOP_MUL:         op = "*"; break;
		case BINOP_DIV:         op = "/"; break;
		case BINOP_BITWISE_AND: op = "&"; break;
		case BINOP_BITWISE_OR:  op = "|"; break;
		case BINOP_BITWISE_XOR: op = "^"; break;
		case BINOP_EQUAL:       op = "=="; break;
		case BINOP_NEQUAL:      op = "!="; break;
		case BINOP_LESS:        op = "<"; break;
		case BINOP_GREATER:     op = ">"; break;
		case BINOP_LEQUAL:      op = "<="; break;
		case BINOP_GEQUAL:      op = ">="; break;
		case BINOP_MULHI:       op = "*^"; break;
		default: abort();
		}
		char *r = malp("(%s %s %s)", a, op, b);
		free(a);
		free(b);
		return r;
	} else if(e->nodeKind == AST_EXPR_STACK_POINTER) {
		return malp("@stack");
	} else if(e->nodeKind == AST_EXPR_FUNC) {
		Type *ft = e->expression.type;

		if(type_is_generic(ft)) {
			return malp("(generic)");
		}

		char *rettype = type_to_string(ft->function.ret);
		char *out = malp("%s(", rettype);
		free(rettype);

		for(size_t i = 0; i < ft->function.argCount; i++) {
			char *argtype = type_to_string(ft->function.args[i]);
			char *next = malp(i == 0 ? "%s%s" : "%s, %s", out, argtype);
			free(out);
			free(argtype);
			out = next;
		}

		char *body = ast_dump(e->exprFunc.chunk);
		char *r = malp("%s) {\n%s}", out, body);
		free(out);
		free(body);
		return r;
	} else if(e->nodeKind == AST_EXPR_CALL) {
		char *w = ast_dumpe(e->exprCall.what);
		char *out = malp("%s(", w);
		free(w);

		size_t argCount = e->exprCall.what->expression.type->function.argCount;
		for(size_t i = 0; i < argCount; i++) {
			char *arg = ast_dumpe(e->exprCall.args[i]);
			char *next = malp(i == 0 ? "%s%s" : "%s, %s", out, arg);
			free(arg);
			free(out);
			out = next;
		}

		/* Close the argument list unconditionally so zero-argument calls print "f()". */
		char *r = malp("%s)", out);
		free(out);
		return r;
	} else if(e->nodeKind == AST_EXPR_EXT_SALLOC) {
		char *w = type_to_string(e->exprExtSalloc.size);
		char *out = malp("@salloc(%s)", w);
		free(w);
		return out;
	} else if(e->nodeKind == AST_EXPR_CAST) {
		char *a = ast_dumpe(e->exprCast.what);
		char *b = type_to_string(e->exprCast.to);
		char *out = malp("(%s as %s)", a, b);
		free(a);
		free(b);
		return out;
	} else if(e->nodeKind == AST_EXPR_DOT) {
		char *a = ast_dumpe(e->exprDot.a);
		/* Binary ops and casts already print their own parentheses; unary ops do not, so "*p.x" would be ambiguous. */
		char *out = malp(e->exprDot.a->nodeKind == AST_EXPR_UNARY_OP ? "(%s).%s" : "%s.%s", a, e->exprDot.b);
		free(a);
		return out;
	}
	return malp("@unimp:%s", AST_KIND_STR[e->nodeKind]);
}

/* Heap-allocated rendering of one statement, terminated by a newline. */
static char *ast_dumps(AST *s)
{
	if(s->nodeKind == AST_STMT_DECL) {
		VarTableEntry *vte = s->stmtDecl.thing;

		if(vte->kind == VARTABLEENTRY_SYMBOL) {
			char *t = type_to_string(vte->type);
			char *e = s->stmtDecl.expression ? ast_dumpe(s->stmtDecl.expression) : strdup("");
			char *r = malp("%s%s %s: %s;\n", vte->data.symbol.isExternal ? "external " : "", t, vte->data.symbol.name, e);
			free(t);
			free(e);
			return r;
		}
	} else if(s->nodeKind == AST_STMT_ASSIGN) {
		char *a = ast_dumpe(s->stmtAssign.what);
		char *r;
		if(s->stmtAssign.to) {
			char *b = ast_dumpe(s->stmtAssign.to);
			r = malp("%s = %s;\n", a, b);
			free(b);
		} else {
			r = malp("%s = ; /* fake def */\n", a);
		}
		free(a);
		return r;
	} else if(s->nodeKind == AST_STMT_LOOP) {
		char *inner = ast_dump(s->stmtLoop.body);
		char *c = malp("loop {\n%s}\n", inner);
		free(inner);
		return c;
	} else if(s->nodeKind == AST_STMT_IF) {
		char *cond = ast_dumpe(s->stmtIf.expression);
		char *inner = ast_dump(s->stmtIf.then);
		char *c = malp("if(%s) {\n%s}\n", cond, inner);
		free(cond);
		free(inner);
		return c;
	} else if(s->nodeKind == AST_STMT_EXPR && s->stmtExpr.expr->nodeKind == AST_EXPR_VAR) {
		const char *name;

		if(s->stmtExpr.expr->exprVar.thing->kind == VARTABLEENTRY_VAR) {
			name = s->stmtExpr.expr->exprVar.thing->data.var.name;
		} else {
			name = s->stmtExpr.expr->exprVar.thing->data.symbol.name;
		}

		return malp("%s; /* loop guard */\n", name);
	} else if(s->nodeKind == AST_STMT_EXPR) {
		/* Terminate like every other statement so the next one starts on its own line. */
		char *e = ast_dumpe(s->stmtExpr.expr);
		char *r = malp("%s;\n", e);
		free(e);
		return r;
	} else if(s->nodeKind == AST_STMT_RETURN) {
		if(s->stmtReturn.val) {
			char *e = ast_dumpe(s->stmtReturn.val);
			char *c = malp("return %s;\n", e);
			free(e);
			return c;
		} else {
			return malp("return;\n");
		}
	}

	return malp("@unimp:%s\n", AST_KIND_STR[s->nodeKind]);
}

/* Heap-allocated rendering of every statement in a chunk, in order. */
char *ast_dump(AST *tlc)
{
	char *ret = strdup("");

	for(AST *stmt = tlc->chunk.statementFirst; stmt; stmt = stmt->statement.next) {
		char *piece = ast_dumps(stmt);
		ret = cat(ret, piece);
		free(piece);
	}

	return ret;
}

/* malloc'd byte-for-byte copy of len bytes at a. */
static void *memdup(void *a, size_t len)
{
	void *r = malloc(len);
	if(!r) abort();
	memcpy(r, a, len);
	return r;
}

/* Copy of a leaf expression (variable or primitive); other kinds abort(). */
AST *ast_deep_copy(AST *src)
{
	if(src->nodeKind == AST_EXPR_VAR) {
		return memdup(src, sizeof(ASTExprVar));
	} else if(src->nodeKind == AST_EXPR_PRIMITIVE) {
		return memdup(src, sizeof(ASTExprPrimitive));
	}
	abort();
}

/* Mask of the low `bits` bits; bits >= 64 keeps everything (1 << 64 would be UB). */
static uint64_t low_bits_mask(uint64_t bits)
{
	return bits >= 64 ? UINT64_MAX : (UINT64_C(1) << bits) - 1;
}

/*
 * Convert expression `what` to type `to`.
 *
 * Behavior by input:
 * - String literals become u8 arrays of matching length, or a primitive whose
 *   width is exactly 8 * length.
 * - Constant primitives cast to primitive or pointer types are folded into a
 *   new, truncated primitive.
 * - Any other castable expression is wrapped in an AST_EXPR_CAST.
 * - Uncastable input is reported via stahp_node().
 */
AST *ast_cast_expr(AST *what, Type *to)
{
	if(what == NULL) {
		/* The fail path below dereferences what, so report this case separately. */
		char *target = type_to_string(to);
		fprintf(stderr, "ast_cast_expr: cannot cast a missing expression into %s\n", target);
		free(target);
		abort();
	}

	/* Only exists at parse-time, hence not part of type system and is handled separately */
	if(what->nodeKind == AST_EXPR_STRING_LITERAL) {
		if(to->type == TYPE_TYPE_ARRAY && type_equal(primitive_parse("u8"), to->array.of) && to->array.length == what->exprStrLit.length) {
			ASTExprArray *ret = calloc(1, sizeof(*ret));
			ret->nodeKind = AST_EXPR_ARRAY;
			ret->items = calloc(to->array.length, sizeof(*ret->items));
			ret->type = to;

			for(size_t i = 0; i < to->array.length; i++) {
				uint8_t bajt = what->exprStrLit.data[i];

				ASTExprPrimitive *item = calloc(1, sizeof(*item));
				item->nodeKind = AST_EXPR_PRIMITIVE;
				item->type = to->array.of;
				item->val = bajt;

				ret->items[i] = (AST*) item;
			}

			return (AST*) ret;
		} else if(to->type == TYPE_TYPE_PRIMITIVE) {
			if(to->primitive.width != what->exprStrLit.length * 8) {
				stahp_node(what, "Size mismatch between string literal and target type");
			}

			ASTExprPrimitive *ret = calloc(1, sizeof(*ret));
			ret->nodeKind = AST_EXPR_PRIMITIVE;
			ret->type = to;

			/* Copy only the literal's bytes (reading sizeof(val) could overrun a short literal); calloc zeroed the rest. */
			size_t n = what->exprStrLit.length;
			if(n > sizeof(ret->val)) n = sizeof(ret->val);
			memcpy(&ret->val, what->exprStrLit.data, n);

			return (AST*) ret;
		} else {
			abort();
		}
	}

	// Make sure an unparametrized generic int parameter hasn't sneaked its way in
	while(what->nodeKind == AST_EXPR_VAR && what->exprVar.thing->kind == VARTABLEENTRY_CEXPR && what->exprVar.thing->data.cexpr.concrete) {
		what = what->exprVar.thing->data.cexpr.concrete;
	}
	assert(!(what->nodeKind == AST_EXPR_VAR && what->exprVar.thing->kind == VARTABLEENTRY_CEXPR));

	if(type_equal(what->expression.type, to)) return what;

	if(!type_is_castable(what->expression.type, to)) {
		goto fail;
	}

	if(what->nodeKind == AST_EXPR_PRIMITIVE && (to->type == TYPE_TYPE_PRIMITIVE || to->type == TYPE_TYPE_POINTER)) {
		ASTExprPrimitive *ret = calloc(1, sizeof(*ret));
		ret->nodeKind = AST_EXPR_PRIMITIVE;
		ret->type = to;

		uint64_t bits = to->type == TYPE_TYPE_PRIMITIVE
			? (uint64_t) to->primitive.width
			: (uint64_t) (8 * type_size(to));
		ret->val = what->exprPrim.val & low_bits_mask(bits);

		return (AST*) ret;
	} else {
		ASTExprCast *ret = calloc(1, sizeof(*ret));
		ret->nodeKind = AST_EXPR_CAST;
		ret->type = to;
		ret->what = what;
		ret->to = to;
		return (AST*) ret;
	}

fail:
	stahp_node(what, "Cannot cast type %s into %s", type_to_string(what->expression.type), type_to_string(to));
	/* Never fall off the end of a non-void function, even if stahp_node returns. */
	abort();
}

struct Spill2StackState {
	AST *targetTLC;
	VarTableEntry *target;
	size_t stackGrowth;
};

/*
 * preHandler for ast_spill_to_stack(). Acts only inside the target top-level
 * chunk, where it:
 * - grows the chunk's stack reservation,
 * - replaces every read of the target variable with *(@stack + offset),
 * - shifts every existing @stack + constant offset by the growth.
 */
static void spill2stack_visitor(AST **aptr, AST *stmt, AST *stmtPrev, AST *chunk, AST *tlc, void *ud)
{
	struct Spill2StackState *state = ud;

	if(tlc != state->targetTLC) {
		// Don't do anything.
		return;
	}

	AST *a = *aptr;

	if(a == tlc) {
		a->chunk.stackReservation += state->stackGrowth;
	} else if(a->nodeKind == AST_EXPR_VAR) {
		if(a->exprVar.thing == state->target) {
			// DO THE SPILL

			ASTExprStackPointer *rsp = calloc(1, sizeof(*rsp));
			rsp->nodeKind = AST_EXPR_STACK_POINTER;
			rsp->type = primitive_parse("u32");

			ASTExprPrimitive *offset = calloc(1, sizeof(*offset));
			offset->nodeKind = AST_EXPR_PRIMITIVE;
			offset->type = rsp->type;
			/* The visitor then descends into the new node and the branch below adds stackGrowth back, so pre-subtract it. */
			offset->val = -state->stackGrowth;

			ASTExprBinaryOp *bop = calloc(1, sizeof(*bop));
			bop->nodeKind = AST_EXPR_BINARY_OP;
			bop->type = rsp->type;
			bop->operator = BINOP_ADD;
			bop->operands[0] = (AST*) rsp;
			bop->operands[1] = (AST*) offset;

			ASTExprUnaryOp *deref = calloc(1, sizeof(*deref));
			deref->nodeKind = AST_EXPR_UNARY_OP;
			deref->type = a->expression.type;
			deref->operator = UNOP_DEREF;
			deref->operand = (AST*) bop;

			*aptr = (AST*) deref;
		}
	} else if(a->nodeKind == AST_EXPR_BINARY_OP && a->exprBinOp.operands[0]->nodeKind == AST_EXPR_STACK_POINTER && a->exprBinOp.operands[1]->nodeKind == AST_EXPR_PRIMITIVE) {
		// Guaranteed to not require more dumbification
		a->exprBinOp.operands[1]->exprPrim.val += state->stackGrowth;
	}
}

/* Move variable vte of top-level chunk tlc into a new 8-byte-aligned stack slot. */
void ast_spill_to_stack(AST *tlc, VarTableEntry *vte)
{
	assert(tlc->nodeKind == AST_CHUNK);
	assert(vte != NULL);
	assert(vte->kind == VARTABLEENTRY_VAR);

	fprintf(stderr, "Spilling %s to stack...\n", vte->data.var.name);

	struct Spill2StackState state;
	memset(&state, 0, sizeof(state));
	state.target = vte;
	state.targetTLC = tlc;
	state.stackGrowth = (type_size(vte->type) + 7) & ~7;

	generic_visitor(&tlc, NULL, NULL, tlc, tlc, &state, spill2stack_visitor, NULL);
}

/*
 * postHandler for ast_type_check().
 * - Calls must target a function type.
 * - Binary operands must be numbers; the smaller operand is cast to the type
 *   of the larger one.
 * - A binary node without a type takes its first operand's type.
 */
static void typecheck_visitor(AST **aptr, AST *stmt, AST *stmtPrev, AST *chunk, AST *tlc, void *ud)
{
	AST *a = *aptr;

	if(a->nodeKind == AST_EXPR_CALL) {
		/* Compare the callee's type *kind*, not the Type pointer itself. */
		if(a->exprCall.what->expression.type->type != TYPE_TYPE_FUNCTION) {
			stahp_node(a, "Only function types may be called.");
		}
	} else if(a->nodeKind == AST_EXPR_BINARY_OP) {
		AST **ops = a->exprBinOp.operands;

		if(!type_is_number(ops[0]->expression.type) || !type_is_number(ops[1]->expression.type)) {
			stahp_node(a, "Operands must be numbers.");
		}

		if(type_size(ops[0]->expression.type) < type_size(ops[1]->expression.type)) {
			ops[0] = ast_cast_expr(ops[0], ops[1]->expression.type);
		}

		if(type_size(ops[1]->expression.type) < type_size(ops[0]->expression.type)) {
			ops[1] = ast_cast_expr(ops[1], ops[0]->expression.type);
		}

		if(!a->exprBinOp.type) {
			a->exprBinOp.type = ops[0]->expression.type;
		}
	}
}

/* Type-check (and insert implicit casts into) everything under tlc. */
void ast_type_check(AST *tlc, VarTableEntry *vte)
{
	generic_visitor(&tlc, NULL, NULL, tlc, tlc, NULL, NULL, typecheck_visitor);
}