nctref/src/ast.c

573 lines
15 KiB
C
Raw Normal View History

2023-08-27 19:48:06 +03:00
#include"ast.h"
#include<stdint.h>
#include<string.h>
#include<stdlib.h>
2024-11-20 16:36:17 +02:00
#include<assert.h>
#include<stdarg.h>
/*
 * Printable name of every AST node kind, indexed by nodeKind (used by the
 * dumpers for "@unimp:<kind>" output). The entries are generated from the
 * AST_KINDS list with GEN_STRI, both defined elsewhere (presumably ast.h).
 */
const char *AST_KIND_STR[] = {
AST_KINDS(GEN_STRI)
};
2023-08-27 19:48:06 +03:00
AST *ast_expression_optimize(AST *ast) {
return ast;
}
/*
 * Structural equality of two expressions. Returns 1 for:
 *   - primitives with equal values,
 *   - variables naming the same table entry,
 *   - unary/binary operations with the same operator and recursively
 *     equal operands,
 *   - any two stack-pointer nodes.
 * Every other pairing (including differing kinds and all other node kinds)
 * yields 0.
 */
int ast_expression_equal(AST *a, AST *b) {
	if(a->nodeKind != b->nodeKind) {
		return 0;
	}

	switch(a->nodeKind) {
	case AST_EXPR_PRIMITIVE:
		return a->exprPrim.val == b->exprPrim.val;
	case AST_EXPR_VAR:
		return a->exprVar.thing == b->exprVar.thing;
	case AST_EXPR_UNARY_OP:
		if(a->exprUnOp.operator != b->exprUnOp.operator) {
			return 0;
		}
		return ast_expression_equal(a->exprUnOp.operand, b->exprUnOp.operand);
	case AST_EXPR_BINARY_OP:
		if(a->exprBinOp.operator != b->exprBinOp.operator) {
			return 0;
		}
		if(!ast_expression_equal(a->exprBinOp.operands[0], b->exprBinOp.operands[0])) {
			return 0;
		}
		return ast_expression_equal(a->exprBinOp.operands[1], b->exprBinOp.operands[1]);
	case AST_EXPR_STACK_POINTER:
		return 1;
	default:
		return 0;
	}
}
// ... Ew
int ast_stmt_is_after(const AST *chunk, const AST *s1, const AST *s2) {
const AST *s = chunk->chunk.statementFirst;
2024-11-20 16:36:17 +02:00
while(1) {
if(s && s->nodeKind == AST_STMT_LOOP) {
2024-02-15 22:33:06 +02:00
int i = ast_stmt_is_after(s->stmtLoop.body, s1, s2);
if(i != -1) {
return i;
}
}
if(s == s1) {
return 0;
}
if(s == s2) {
return 1;
}
2024-11-20 16:36:17 +02:00
if(!s) break;
if(s->nodeKind == AST_STMT_IF) {
int i = ast_stmt_is_after(s->stmtIf.then, s1, s2);
if(i != -1) {
return i;
}
}
s = s->statement.next;
}
return -1;
2024-11-20 16:36:17 +02:00
}
/*
* This pass is necessary for the purposes of optimization and regalloc.
* Because an AST may hold outdated UD-chains, this pass MUST be called
* before using the UD-chains to make sure they are valid.
*
* Each local var (VTE of kind VARTABLEENTRY_VAR) holds its own UD-chain.
* Every entry of that chain records:
* 1. the AST node where the variable is used,
* 2. the whole statement containing that use, and
* 3. a definition that *might* reach that use.
*
* Because multiple definitions may be in use (reachable) at the point
* of the use, a unique UseDef for each possible definition is appended
* to the chain.
*
* Reaching definitions are tracked in a ReachingDefs structure, also held
* by each VTE. In a single, straight-line block of code, exactly one
* definition (possibly none, i.e. undefined) reaches each variable at any
* point, which would reduce the ReachingDefs structure to a single
* definition pointer.
*
* Unfortunately, conditional blocks and loops ruin this simplicity.
* If you have code like
* x = A
* if B {
* x = C
* }
* then afterward two definitions may apply to x.
*
* The solution used here is to arrange each variable's ReachingDefs as a
* stack, with each ReachingDefs having an optional parent. When we enter
* a new block of code, we create an empty ReachingDefs with the previous
* block's as its parent. A definition replaces the ones in the deepest
* ReachingDefs only.
*
* How we exit a block depends on its type. If it is conditional,
* the reaching definitions should join the parent (mergedefs).
* If the block is a loop, it is even worse. Given
* x = A
* loop {
* use x
* x = B
* }
* definitions can apply to uses that come before it!
*
* Also, a different case:
* x = A
* loop {
* use x
* y = B
* }
* Because, textually, the last use of x comes before y = B, x could be
* considered dead at that point, so x and y may be assigned the same
* physical location, corrupting x for the next iteration.
* To fix this, fake, "useless" statements are inserted during parsing
* that make the AST look as such:
* x = A
* loop {
* use x
* y = B
* }
* x;
* This will not be a problem until dead-code removal is implemented (which could strip these seemingly useless statements).
*/
/*
 * Appends an already-filled, unlinked UseDef node to the tail of vte's
 * UD-chain, keeping usedefFirst/usedefLast consistent.
 */
static void rawadduse(VarTableEntry *vte, UseDef *ud) {
	assert(vte->kind == VARTABLEENTRY_VAR);
	assert(ud->next == NULL);
	assert(!!vte->data.var.usedefFirst == !!vte->data.var.usedefLast);

	if(vte->data.var.usedefFirst) {
		vte->data.var.usedefLast->next = ud;
	} else {
		vte->data.var.usedefFirst = ud;
	}
	/* Either way the new node becomes the tail. */
	vte->data.var.usedefLast = ud;
}
/*
 * Records one use of local variable `vte`. `use` is the node of the use and
 * `whole` is the statement containing it. One UseDef is appended per
 * definition that may reach this point.
 *
 * A freshly pushed block level (pushdefs) starts with no definitions of its
 * own. Until the block defines the variable, the definitions visible to it
 * are those of the nearest enclosing level that has any, so we walk up the
 * parent chain. Previously only the deepest level was consulted, and a use
 * inside an if/loop before any local definition got no UseDef at all.
 */
static void adduse(VarTableEntry *vte, AST *use, AST *whole) {
	assert(vte->kind == VARTABLEENTRY_VAR);
	assert(vte->data.var.reachingDefs != NULL);

	ReachingDefs *rdefs = vte->data.var.reachingDefs;
	/* NOTE(review): excludeParent is assumed to mean "do not inherit from
	 * the parent level"; it is always 0 in this file. Confirm against its
	 * other users. */
	while(rdefs->defCount == 0 && !rdefs->excludeParent && rdefs->parent) {
		rdefs = rdefs->parent;
	}

	for(size_t d = 0; d < rdefs->defCount; d++) {
		UseDef *ud = malloc(sizeof(*ud));
		if(!ud) {
			abort();
		}
		ud->def = rdefs->defs[d];
		ud->use = use;
		ud->stmt = whole;
		ud->next = NULL;
		rawadduse(vte, ud);
	}
}
/*
 * Records `def` as the only definition of `vte` reaching the current point
 * of the deepest block level. The level is created if the variable has
 * none yet.
 */
static void overwritedefs(VarTableEntry *vte, AST *def) {
	assert(vte->kind == VARTABLEENTRY_VAR);

	ReachingDefs *rdefs = vte->data.var.reachingDefs;
	if(!rdefs) {
		rdefs = calloc(1, sizeof(*rdefs));
		if(!rdefs) {
			abort();
		}
		vte->data.var.reachingDefs = rdefs;
	}

	/* A non-NULL defs pointer does not guarantee room for an element (e.g.
	 * an array resized to hold zero definitions). Size it to exactly one
	 * slot before writing. */
	void *slot = realloc(rdefs->defs, sizeof(*rdefs->defs));
	if(!slot) {
		abort();
	}
	rdefs->defs = slot;
	rdefs->defs[0] = def;
	rdefs->defCount = 1;
}
/*
 * Pops the deepest ReachingDefs level of `vte` at the end of a conditional
 * block. Afterwards, either the definitions reaching the block's entry or
 * those made inside it may apply, so the parent level becomes the union of
 * both.
 *
 * Fixes over the previous version:
 *  - parent->defCount is actually increased (the copied defs used to be
 *    invisible);
 *  - an empty parent (which inherits from its ancestors) first materializes
 *    the inherited defs, so they are not lost by the join;
 *  - duplicates are skipped and the array is never realloc'ed to 0 bytes.
 */
static void mergedefs(VarTableEntry *vte) {
	assert(vte->kind == VARTABLEENTRY_VAR);
	ReachingDefs *rdefs = vte->data.var.reachingDefs;
	assert(rdefs != NULL);
	assert(rdefs->parent != NULL);
	ReachingDefs *parent = rdefs->parent;

	/* A block that defined nothing leaves the parent untouched. */
	if(rdefs->defCount > 0) {
		/* Definitions actually visible at the parent level: its own, or
		 * those of the nearest ancestor that has any. */
		ReachingDefs *inherited = parent;
		while(inherited->defCount == 0 && !inherited->excludeParent && inherited->parent) {
			inherited = inherited->parent;
		}

		size_t capacity = inherited->defCount + rdefs->defCount;
		void *grown = realloc(parent->defs, sizeof(*parent->defs) * capacity);
		if(!grown) {
			abort();
		}
		parent->defs = grown;

		if(inherited != parent) {
			memcpy(parent->defs, inherited->defs, inherited->defCount * sizeof(*parent->defs));
			parent->defCount = inherited->defCount;
		}

		for(size_t c = 0; c < rdefs->defCount; c++) {
			int seen = 0;
			for(size_t p = 0; p < parent->defCount; p++) {
				if(parent->defs[p] == rdefs->defs[c]) {
					seen = 1;
					break;
				}
			}
			if(!seen) {
				parent->defs[parent->defCount++] = rdefs->defs[c];
			}
		}
	}

	vte->data.var.reachingDefs = parent;
	free(rdefs->defs);
	free(rdefs);
}
/*
 * Opens a new, empty ReachingDefs level for `vte` on entry to a block. The
 * previous level becomes its parent.
 */
static void pushdefs(VarTableEntry *vte) {
	assert(vte->kind == VARTABLEENTRY_VAR);
	ReachingDefs *fresh = malloc(sizeof(*fresh));
	fresh->parent = vte->data.var.reachingDefs;
	fresh->excludeParent = 0;
	fresh->defs = NULL;
	fresh->defCount = 0;
	vte->data.var.reachingDefs = fresh;
}
/* Opens a new ReachingDefs level for every variable of chunk `tlc`. */
static void pushdefsall(AST *tlc) {
	size_t count = tlc->chunk.varCount;
	for(size_t v = 0; v != count; ++v) {
		pushdefs(tlc->chunk.vars[v]);
	}
}
/* Closes the current (conditional) ReachingDefs level of every variable of chunk `tlc`. */
static void mergedefsall(AST *tlc) {
	size_t count = tlc->chunk.varCount;
	for(size_t v = 0; v != count; ++v) {
		mergedefs(tlc->chunk.vars[v]);
	}
}
static void mergedefsloop(AST *tlc, VarTableEntry *vte, AST *daLoopStmt) {
assert(vte->kind == VARTABLEENTRY_VAR);
for(size_t d = 0; d < vte->data.var.reachingDefs->defCount; d++) {
UseDef *ud = vte->data.var.usedefFirst;
while(ud) {
if(ast_stmt_is_after(daLoopStmt->stmtLoop.body, NULL, ud->stmt) == 1 && ud->def != vte->data.var.reachingDefs->defs[d]) {
UseDef *udnew = calloc(1, sizeof(*udnew));
udnew->next = ud->next;
ud->next = udnew;
udnew->def = vte->data.var.reachingDefs->defs[d];
udnew->use = ud->use;
udnew->stmt = ud->stmt;
if(udnew->next == NULL) {
vte->data.var.usedefLast = udnew;
}
}
ud = ud->next;
}
}
mergedefs(vte);
}
/* Closes the ReachingDefs level of loop `daLoopStmt` for every variable of chunk `tlc`. */
static void mergedefsloopall(AST *tlc, AST *daLoopStmt) {
	size_t count = tlc->chunk.varCount;
	for(size_t v = 0; v != count; ++v) {
		mergedefsloop(tlc, tlc->chunk.vars[v], daLoopStmt);
	}
}
/*
 * Recursive worker of ast_usedef_reset(). It visits `a` in execution order:
 * every use of a local variable is appended to that variable's UD-chain, and
 * every assignment to one replaces its reaching definitions.
 *
 *   tlc       - top-level chunk whose vars[] get ReachingDefs levels
 *               pushed/merged at if/loop boundaries
 *   a         - node being visited
 *   wholestmt - innermost statement containing `a` (stored in UseDef.stmt)
 *
 * Aborts on node kinds it does not know.
 */
static void ast_usedef_pass(AST *tlc, AST *a, AST *wholestmt) {
	if(a->nodeKind == AST_CHUNK) {
		for(AST *s = a->chunk.statementFirst; s; s = s->statement.next) {
			ast_usedef_pass(tlc, s, s);
		}
	} else if(a->nodeKind == AST_STMT_IF) {
		/* The condition always executes, so its uses belong to the
		 * enclosing level, not to the conditional branch. */
		ast_usedef_pass(tlc, a->stmtIf.expression, wholestmt);
		pushdefsall(tlc);
		ast_usedef_pass(tlc, a->stmtIf.then, wholestmt);
		mergedefsall(tlc);
	} else if(a->nodeKind == AST_STMT_LOOP) {
		pushdefsall(tlc);
		ast_usedef_pass(tlc, a->stmtLoop.body, wholestmt);
		mergedefsloopall(tlc, a);
	} else if(a->nodeKind == AST_STMT_ASSIGN) {
		/* The right-hand side is evaluated before the store. Its uses must
		 * see the definitions preceding this assignment, not the assignment
		 * itself (think `x = x + 1`). */
		if(a->stmtAssign.to) {
			ast_usedef_pass(tlc, a->stmtAssign.to, wholestmt);
		}
		if(a->stmtAssign.what->nodeKind == AST_EXPR_VAR && a->stmtAssign.what->exprVar.thing->kind == VARTABLEENTRY_VAR) {
			overwritedefs(a->stmtAssign.what->exprVar.thing, a);
		}
		/* The target is visited after the kill, so a plain-variable target
		 * is linked to this assignment as its definition. */
		ast_usedef_pass(tlc, a->stmtAssign.what, wholestmt);
	} else if(a->nodeKind == AST_STMT_EXPR) {
		ast_usedef_pass(tlc, a->stmtExpr.expr, wholestmt);
	} else if(a->nodeKind == AST_EXPR_VAR) {
		if(a->exprVar.thing->kind == VARTABLEENTRY_VAR) {
			adduse(a->exprVar.thing, a, wholestmt);
		}
	} else if(a->nodeKind == AST_EXPR_BINARY_OP) {
		ast_usedef_pass(tlc, a->exprBinOp.operands[0], wholestmt);
		ast_usedef_pass(tlc, a->exprBinOp.operands[1], wholestmt);
	} else if(a->nodeKind == AST_EXPR_UNARY_OP) {
		ast_usedef_pass(tlc, a->exprUnOp.operand, wholestmt);
	} else if(a->nodeKind == AST_EXPR_CALL) {
		ast_usedef_pass(tlc, a->exprCall.what, wholestmt);
		for(size_t p = 0; p < a->exprCall.what->expression.type->function.argCount; p++) {
			ast_usedef_pass(tlc, a->exprCall.args[p], wholestmt);
		}
	} else if(a->nodeKind == AST_EXPR_PRIMITIVE) {
	} else if(a->nodeKind == AST_EXPR_STRING_LITERAL) {
	} else if(a->nodeKind == AST_EXPR_CAST) {
		ast_usedef_pass(tlc, a->exprCast.what, wholestmt);
	} else if(a->nodeKind == AST_EXPR_STACK_POINTER) {
	} else if(a->nodeKind == AST_STMT_BREAK) {
	} else if(a->nodeKind == AST_STMT_CONTINUE) {
	} else if(a->nodeKind == AST_STMT_EXT_ALIGN) {
	} else if(a->nodeKind == AST_STMT_EXT_ORG) {
	} else if(a->nodeKind == AST_STMT_EXT_SECTION) {
	} else if(a->nodeKind == AST_STMT_DECL) {
		/* Local-variable declarations must carry an initializer, but no
		 * def/use is recorded here. NOTE(review): presumably the parser also
		 * emits an assignment for it. Confirm. */
		assert(a->stmtDecl.thing->kind != VARTABLEENTRY_VAR || a->stmtDecl.expression);
	} else if(a->nodeKind == AST_STMT_RETURN) {
		if(a->stmtReturn.val) {
			ast_usedef_pass(tlc, a->stmtReturn.val, wholestmt);
		}
	} else {
		abort();
	}
}
/*
 * Rebuilds the UD-chains of every variable of top-level chunk `chu` from
 * scratch. This must run before the chains are consumed, since the AST may
 * have changed since they were built. All of chu's vars must be
 * VARTABLEENTRY_VAR.
 *
 * NOTE(review): previous chains and ReachingDefs are only dropped (set to
 * NULL), not freed, so repeated resets leak them. Freeing is not done here
 * because it is unclear whether these fields are always initialized before
 * the first reset. Confirm before adding it.
 */
void ast_usedef_reset(AST *chu) {
	for(size_t i = 0; i < chu->chunk.varCount; i++) {
		VarTableEntry *vte = chu->chunk.vars[i];
		assert(vte->kind == VARTABLEENTRY_VAR);
		vte->data.var.reachingDefs = NULL;
		vte->data.var.usedefFirst = NULL;
		vte->data.var.usedefLast = NULL;
	}
	pushdefsall(chu);
	/* A plain call: `return expr;` is a constraint violation in a void
	 * function (C11 6.8.6.4), even when expr has type void. */
	ast_usedef_pass(chu, chu, NULL);
}
/*
 * Appends string `b` to heap string `a` and returns the (possibly moved)
 * result. A NULL `a` is treated as empty, giving a fresh copy of `b`.
 * Ownership of `a` passes to this function; the caller owns the result.
 * Aborts on allocation failure instead of losing `a` to a NULL realloc
 * result.
 */
static char *cat(char *a, const char *b) {
	size_t lenA = a ? strlen(a) : 0;
	size_t lenB = strlen(b);
	/* realloc(NULL, n) behaves like malloc(n), covering the NULL case. */
	char *grown = realloc(a, lenA + lenB + 1);
	if(!grown) {
		abort();
	}
	/* Copy b together with its terminating NUL. */
	memcpy(grown + lenA, b, lenB + 1);
	return grown;
}
/*
 * printf into a newly allocated, exactly sized string (caller frees).
 * The argument list is measured with one va_list and rendered with a
 * va_copy of it. The copy must not be re-initialized with va_start
 * (C11 7.16.1), which the previous version did.
 * Aborts on an encoding error or allocation failure.
 */
__attribute__((format(printf, 1, 2))) static char *malp(const char *fmt, ...) {
	va_list measure, render;
	va_start(measure, fmt);
	va_copy(render, measure);

	int len = vsnprintf(NULL, 0, fmt, measure);
	va_end(measure);
	if(len < 0) {
		va_end(render);
		abort();
	}

	char *str = malloc((size_t) len + 1);
	if(!str) {
		va_end(render);
		abort();
	}
	vsnprintf(str, (size_t) len + 1, fmt, render);
	va_end(render);
	return str;
}
/*
 * Renders a type as a newly allocated string (caller frees). Primitives
 * become a class letter ('f' float, 'u' unsigned, 'i' signed) followed by
 * the width, e.g. "u8"; pointers append '*' to their pointee. Anything else
 * becomes "@unimp".
 */
char *type_to_string(Type *t) {
	switch(t->type) {
	case TYPE_TYPE_PRIMITIVE: {
		char klass;
		if(t->primitive.isFloat) {
			klass = 'f';
		} else if(t->primitive.isUnsigned) {
			klass = 'u';
		} else {
			klass = 'i';
		}
		return malp("%c%i", klass, t->primitive.width);
	}
	case TYPE_TYPE_POINTER: {
		char *pointee = type_to_string(t->pointer.of);
		char *result = malp("%s*", pointee);
		free(pointee);
		return result;
	}
	default:
		return strdup("@unimp");
	}
}
char *ast_dump(AST *tlc);

/*
 * Renders expression `e` as a newly allocated, human-readable string
 * (caller frees). Binary operations are fully parenthesized. Function
 * literals show their signature and dumped body; calls show callee and
 * arguments. Kinds without a renderer yield "@unimp:<kind>". Unknown
 * operators or variable-table kinds abort().
 */
static char *ast_dumpe(AST *e) {
	if(e->nodeKind == AST_EXPR_PRIMITIVE) {
		return malp("%i", e->exprPrim.val);
	} else if(e->nodeKind == AST_EXPR_VAR) {
		VarTableEntry *vte = e->exprVar.thing;
		if(vte->kind == VARTABLEENTRY_VAR) {
			return strdup(vte->data.var.name);
		} else if(vte->kind == VARTABLEENTRY_SYMBOL) {
			return strdup(vte->data.symbol.name);
		} else abort();
	} else if(e->nodeKind == AST_EXPR_UNARY_OP) {
		const char *op = NULL;
		switch(e->exprUnOp.operator) {
		case UNOP_REF:
			op = "&";
			break;
		case UNOP_DEREF:
			op = "*";
			break;
		case UNOP_BITWISE_NOT:
			op = "~";
			break;
		case UNOP_NEGATE:
			op = "-";
			break;
		default:
			abort();
		}
		char *c = ast_dumpe(e->exprUnOp.operand);
		char *r = malp("%s%s", op, c);
		free(c);
		return r;
	} else if(e->nodeKind == AST_EXPR_BINARY_OP) {
		char *a = ast_dumpe(e->exprBinOp.operands[0]);
		char *b = ast_dumpe(e->exprBinOp.operands[1]);
		const char *op;
		switch(e->exprBinOp.operator) {
		case BINOP_ADD:
			op = "+";
			break;
		case BINOP_SUB:
			op = "-";
			break;
		case BINOP_MUL:
			op = "*";
			break;
		case BINOP_DIV:
			op = "/";
			break;
		case BINOP_BITWISE_AND:
			op = "&";
			break;
		case BINOP_BITWISE_OR:
			op = "|";
			break;
		case BINOP_BITWISE_XOR:
			op = "^";
			break;
		case BINOP_EQUAL:
			op = "==";
			break;
		case BINOP_NEQUAL:
			op = "!=";
			break;
		default:
			abort();
		}
		char *r = malp("(%s %s %s)", a, op, b);
		free(a);
		free(b);
		return r;
	} else if(e->nodeKind == AST_EXPR_STACK_POINTER) {
		return malp("@stack");
	} else if(e->nodeKind == AST_EXPR_FUNC) {
		/* size_t index: argCount is compared against size_t elsewhere, and
		 * `i + 1 == argCount` avoids `argCount - 1` underflowing at 0. */
		size_t argCount = e->expression.type->function.argCount;
		char *rettype = type_to_string(e->expression.type->function.ret);
		char *out = malp("%s(", rettype);
		free(rettype);
		for(size_t i = 0; i < argCount; i++) {
			char *argtype = type_to_string(e->expression.type->function.args[i]);
			/* ", " between arguments, none after the last. */
			char *next = malp(i + 1 == argCount ? "%s%s" : "%s%s, ", out, argtype);
			free(out);
			free(argtype);
			out = next;
		}
		char *body = ast_dump(e->exprFunc.chunk);
		/* Never pass NULL to "%s" (UB), e.g. for an empty body. */
		char *r = malp("%s) {\n%s}", out, body ? body : "");
		free(out);
		free(body);
		return r;
	} else if(e->nodeKind == AST_EXPR_CALL) {
		/* The argument count comes from the callee's function type, as in
		 * ast_usedef_pass(). */
		size_t argCount = e->exprCall.what->expression.type->function.argCount;
		char *callee = ast_dumpe(e->exprCall.what);
		char *out = malp("%s(", callee);
		free(callee);
		for(size_t i = 0; i < argCount; i++) {
			char *arg = ast_dumpe(e->exprCall.args[i]);
			char *next = malp(i + 1 == argCount ? "%s%s" : "%s%s, ", out, arg);
			free(out);
			free(arg);
			out = next;
		}
		char *r = malp("%s)", out);
		free(out);
		return r;
	}
	return malp("@unimp:%s", AST_KIND_STR[e->nodeKind]);
}

char *ast_dump(AST *tlc);
/*
 * Renders one statement, terminated by "\n", as a newly allocated string
 * (caller frees). Statements without a renderer (including declarations of
 * non-symbol entries) yield "@unimp:<kind>\n".
 */
static char *ast_dumps(AST *s) {
	if(s->nodeKind == AST_STMT_DECL) {
		VarTableEntry *vte = s->stmtDecl.thing;
		if(vte->kind == VARTABLEENTRY_SYMBOL) {
			char *t = type_to_string(vte->type);
			char *e = s->stmtDecl.expression ? ast_dumpe(s->stmtDecl.expression) : strdup("");
			char *r = malp("%s%s %s: %s;\n", vte->data.symbol.isExternal ? "external " : "", t, vte->data.symbol.name, e);
			free(t);
			free(e);
			return r;
		}
	} else if(s->nodeKind == AST_STMT_ASSIGN) {
		char *a = ast_dumpe(s->stmtAssign.what);
		/* `to` may be NULL. Print nothing after '=' (like a declaration
		 * without an initializer) rather than dereferencing NULL. */
		char *b = s->stmtAssign.to ? ast_dumpe(s->stmtAssign.to) : strdup("");
		char *r = malp("%s = %s;\n", a, b);
		free(a);
		free(b);
		return r;
	} else if(s->nodeKind == AST_STMT_LOOP) {
		char *inner = ast_dump(s->stmtLoop.body);
		/* Never pass NULL to "%s" (UB), e.g. for an empty body. */
		char *c = malp("loop {\n%s}\n", inner ? inner : "");
		free(inner);
		return c;
	} else if(s->nodeKind == AST_STMT_IF) {
		char *cond = ast_dumpe(s->stmtIf.expression);
		char *inner = ast_dump(s->stmtIf.then);
		char *c = malp("if(%s) {\n%s}\n", cond, inner ? inner : "");
		free(cond);
		free(inner);
		return c;
	} else if(s->nodeKind == AST_STMT_EXPR && s->stmtExpr.expr->nodeKind == AST_EXPR_VAR) {
		const char *name;
		if(s->stmtExpr.expr->exprVar.thing->kind == VARTABLEENTRY_VAR) {
			name = s->stmtExpr.expr->exprVar.thing->data.var.name;
		} else {
			name = s->stmtExpr.expr->exprVar.thing->data.symbol.name;
		}
		return malp("%s; /* loop guard */\n", name);
	} else if(s->nodeKind == AST_STMT_EXPR) {
		/* Any other expression statement, e.g. a call for its effects. */
		char *e = ast_dumpe(s->stmtExpr.expr);
		char *c = malp("%s;\n", e);
		free(e);
		return c;
	} else if(s->nodeKind == AST_STMT_BREAK) {
		return malp("break;\n");
	} else if(s->nodeKind == AST_STMT_CONTINUE) {
		return malp("continue;\n");
	} else if(s->nodeKind == AST_STMT_RETURN) {
		if(s->stmtReturn.val) {
			char *e = ast_dumpe(s->stmtReturn.val);
			char *c = malp("return %s;\n", e);
			free(e);
			return c;
		} else {
			return malp("return;\n");
		}
	}
	return malp("@unimp:%s\n", AST_KIND_STR[s->nodeKind]);
}
/*
 * Renders every statement of chunk `tlc`, in order, into one newly
 * allocated string (caller frees). An empty chunk yields "" rather than
 * NULL. The dumpers splice this result into "%s" conversions, where a NULL
 * argument is undefined behavior.
 */
char *ast_dump(AST *tlc) {
	AST *stmt = tlc->chunk.statementFirst;
	char *ret = NULL;
#define CAT(s) do { char *b = s; ret = cat(ret, (b)); free(b); } while(0)
	while(stmt) {
		CAT(ast_dumps(stmt));
		stmt = stmt->statement.next;
	}
	if(!ret) {
		ret = strdup("");
	}
	return ret;
}
/*
 * Returns a heap copy of the first `len` bytes at `a` (caller frees), or
 * NULL if allocation fails. The source is only read, hence const.
 */
static void *memdup(const void *a, size_t len) {
	void *r = malloc(len);
	/* memcpy into a NULL destination is UB even for len == 0. */
	if(r) {
		memcpy(r, a, len);
	}
	return r;
}
/*
 * Copies an expression tree. Each node is copied with only the size of its
 * own union member. Variable nodes share their VarTableEntry with the
 * original, and other pointer fields inside a node (e.g. expression.type)
 * are shared too. Unary and binary operations are copied recursively
 * (these previously aborted). Any other node kind, or an allocation
 * failure, aborts.
 */
AST *ast_deep_copy(AST *src) {
	if(src->nodeKind == AST_EXPR_VAR) {
		return memdup(src, sizeof(ASTExprVar));
	} else if(src->nodeKind == AST_EXPR_PRIMITIVE) {
		return memdup(src, sizeof(ASTExprPrimitive));
	} else if(src->nodeKind == AST_EXPR_UNARY_OP) {
		AST *copy = memdup(src, sizeof(src->exprUnOp));
		if(!copy) {
			abort();
		}
		copy->exprUnOp.operand = ast_deep_copy(src->exprUnOp.operand);
		return copy;
	} else if(src->nodeKind == AST_EXPR_BINARY_OP) {
		AST *copy = memdup(src, sizeof(src->exprBinOp));
		if(!copy) {
			abort();
		}
		copy->exprBinOp.operands[0] = ast_deep_copy(src->exprBinOp.operands[0]);
		copy->exprBinOp.operands[1] = ast_deep_copy(src->exprBinOp.operands[1]);
		return copy;
	}
	abort();
}