cache sp, minor refactors

main
Clyne 6 months ago
parent 8a6a7a4311
commit 08f4a92102
Signed by: clyne
GPG Key ID: 3267C8EBF3F9AFC7

@ -24,7 +24,6 @@ extern llvm::Function *curFunc;
extern llvm::Value *curEnv; extern llvm::Value *curEnv;
int ThunkAST::tcount = 0; int ThunkAST::tcount = 0;
int ThunkAST::envidx = 0;
static inline auto loadEnv(LLVMState& llvmState, llvm::Value *index) static inline auto loadEnv(LLVMState& llvmState, llvm::Value *index)
{ {
@ -61,7 +60,7 @@ PushAST::PushAST(const std::string& n): BaseAST(n) {}
llvm::Value *PushAST::codegen(LLVMState& llvmState) const llvm::Value *PushAST::codegen(LLVMState& llvmState) const
{ {
if (auto [var, native] = Var::lookup(name, 1); var) { if (auto [var, native] = Var::lookup(name, 1); var) {
auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++); auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++);
Var::addLocal(name, index); Var::addLocal(name, index);
} }
@ -85,7 +84,7 @@ llvm::Value *PopAST::codegen(LLVMState& llvmState) const
if (name == "self") { if (name == "self") {
v = {curFunc, true}; v = {curFunc, true};
} else { } else {
auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++); auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++);
auto pop = llvmState.createPop(); auto pop = llvmState.createPop();
storeEnv(llvmState, index, pop); storeEnv(llvmState, index, pop);
@ -100,7 +99,7 @@ CallAST::CallAST(const std::string& n): BaseAST(n) {}
llvm::Value *CallAST::codegen(LLVMState& llvmState) const llvm::Value *CallAST::codegen(LLVMState& llvmState) const
{ {
if (auto [var, native] = Var::lookup(name, 1); var && !native) { if (auto [var, native] = Var::lookup(name, 1); var && !native) {
auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++); auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++);
Var::addLocal(name, index); Var::addLocal(name, index);
} }
@ -118,6 +117,8 @@ llvm::Value *CallAST::codegen(LLVMState& llvmState) const
Var::addGlobal(name, Var {fn, true}); Var::addGlobal(name, Var {fn, true});
} }
llvmState.commitSp();
bool couldRecur = Var::lookup("self").value != nullptr; bool couldRecur = Var::lookup("self").value != nullptr;
auto localCount = Var::vars.back().size(); auto localCount = Var::vars.back().size();
if (!couldRecur || localCount == 0) { if (!couldRecur || localCount == 0) {
@ -168,12 +169,14 @@ void ThunkAST::beginGen(LLVMState& llvmState)
entry = llvmState.createEntry(func); entry = llvmState.createEntry(func);
body = llvm::BasicBlock::Create(llvmState.ctx, "body", func); body = llvm::BasicBlock::Create(llvmState.ctx, "body", func);
env = func->getArg(0); env = func->getArg(0);
lastSp = nullptr;
llvmState.builder.SetInsertPoint(body); llvmState.builder.SetInsertPoint(body);
} }
void ThunkAST::endGen(LLVMState& llvmState) void ThunkAST::endGen(LLVMState& llvmState)
{ {
llvmState.commitSp();
llvmState.builder.CreateRetVoid(); llvmState.builder.CreateRetVoid();
llvmState.builder.SetInsertPoint(entry); llvmState.builder.SetInsertPoint(entry);

@ -66,13 +66,13 @@ struct CallAST : public BaseAST
struct ThunkAST : public BaseAST struct ThunkAST : public BaseAST
{ {
static int tcount; static int tcount;
static int envidx;
std::list<std::unique_ptr<BaseAST>> ast; std::list<std::unique_ptr<BaseAST>> ast;
llvm::IRBuilderBase::InsertPoint parent; llvm::IRBuilderBase::InsertPoint parent;
llvm::Function *func; llvm::Function *func;
llvm::BasicBlock *entry, *body; llvm::BasicBlock *entry, *body;
llvm::Value *env; llvm::Value *env;
llvm::Value *lastSp;
explicit ThunkAST(); explicit ThunkAST();
llvm::Value *codegen(LLVMState& llvmState) const override; llvm::Value *codegen(LLVMState& llvmState) const override;

@ -22,7 +22,7 @@ LLVMState::LLVMState():
modul("forsp", ctx), modul("forsp", ctx),
builder(ctx), builder(ctx),
inttype(llvm::Type::getInt64Ty(ctx)), inttype(llvm::Type::getInt64Ty(ctx)),
stacktype(llvm::VectorType::get(inttype, 16, false)), stacktype(llvm::ArrayType::get(inttype, 16)),
ftype(llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), llvm::ArrayRef<llvm::Type *> {inttype->getPointerTo()}, false)), ftype(llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), llvm::ArrayRef<llvm::Type *> {inttype->getPointerTo()}, false)),
one(llvm::ConstantInt::get(inttype, 1)), one(llvm::ConstantInt::get(inttype, 1)),
zero(llvm::ConstantInt::get(inttype, 0)) zero(llvm::ConstantInt::get(inttype, 0))
@ -36,24 +36,32 @@ LLVMState::LLVMState():
llvm::Value *LLVMState::createPush(llvm::Value *var) llvm::Value *LLVMState::createPush(llvm::Value *var)
{ {
auto dspval = builder.CreateLoad(inttype, llvmSp); if (!lastSp)
auto inc = builder.CreateAdd(dspval, one); lastSp = builder.CreateLoad(inttype, llvmSp);
builder.CreateStore(inc, llvmSp);
auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, dspval}); auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, lastSp});
lastSp = builder.CreateAdd(lastSp, one);
return builder.CreateStore(var, gep); return builder.CreateStore(var, gep);
} }
llvm::Value *LLVMState::createPop() llvm::Value *LLVMState::createPop()
{ {
auto dspval = builder.CreateLoad(inttype, llvmSp); if (!lastSp)
auto dec = builder.CreateSub(dspval, one); lastSp = builder.CreateLoad(inttype, llvmSp);
builder.CreateStore(dec, llvmSp);
auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, dec}); lastSp = builder.CreateSub(lastSp, one);
auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, lastSp});
return builder.CreateLoad(inttype, gep); return builder.CreateLoad(inttype, gep);
} }
void LLVMState::commitSp()
{
if (lastSp) {
builder.CreateStore(lastSp, llvmSp);
lastSp = nullptr;
}
}
llvm::Function *LLVMState::createFunction(const std::string& name) llvm::Function *LLVMState::createFunction(const std::string& name)
{ {
auto func = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, auto func = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,

@ -42,10 +42,14 @@ struct LLVMState
llvm::Constant *llvmSp; llvm::Constant *llvmSp;
llvm::Constant *llvmStack; llvm::Constant *llvmStack;
int envidx = 0;
llvm::Value *lastSp = nullptr;
LLVMState(); LLVMState();
llvm::Value *createPush(llvm::Value *var); llvm::Value *createPush(llvm::Value *var);
llvm::Value *createPop(); llvm::Value *createPop();
void commitSp();
llvm::Function *createFunction(const std::string& name); llvm::Function *createFunction(const std::string& name);
llvm::BasicBlock *createEntry(llvm::Function *func); llvm::BasicBlock *createEntry(llvm::Function *func);
llvm::Value *createVariable(const std::string& name); llvm::Value *createVariable(const std::string& name);

@ -60,20 +60,22 @@ int main()
} }
// 3. Create main() which initiate global vars and calls the top-level thunk. // 3. Create main() which initiate global vars and calls the top-level thunk.
auto func = llvmState.createFunction("main"); auto envtype = llvm::ArrayType::get(llvmState.inttype, llvmState.envidx);
auto entry = llvmState.createEntry(func);
auto envtype = llvm::VectorType::get(llvmState.inttype, ThunkAST::envidx, false);
auto [t0, _] = Var::lookup("__t0");
llvmState.builder.SetInsertPoint(entry);
auto zerovec = llvm::ConstantVector::get(llvm::ArrayRef(llvmState.zero)); auto zerovec = llvm::ConstantVector::get(llvm::ArrayRef(llvmState.zero));
auto env = new llvm::GlobalVariable(llvmState.modul, envtype, false, auto env = new llvm::GlobalVariable(llvmState.modul, envtype, false,
llvm::GlobalValue::InternalLinkage, zerovec, "env"); llvm::GlobalValue::InternalLinkage, zerovec, "env");
auto func = llvmState.createFunction("main");
auto entry = llvmState.createEntry(func);
llvmState.builder.SetInsertPoint(entry);
auto [t0, _] = Var::lookup("__t0");
llvmState.builder.CreateCall(llvmState.ftype, t0, llvm::ArrayRef<llvm::Value *> {env}); llvmState.builder.CreateCall(llvmState.ftype, t0, llvm::ArrayRef<llvm::Value *> {env});
llvmState.builder.CreateRetVoid(); llvmState.builder.CreateRetVoid();
// 4. Output the IR to stdout. // 4. Output the IR to stdout.
llvmState.output(); llvmState.output();
std::cerr << "envidx: " << ThunkAST::envidx << std::endl; std::cerr << "envidx: " << llvmState.envidx << std::endl;
return 0; return 0;
} }
@ -142,12 +144,15 @@ std::list<ThunkAST>::iterator buildThunk(std::list<ThunkAST>::iterator it)
curFunc = it->func; curFunc = it->func;
curEnv = it->env; curEnv = it->env;
llvmState.lastSp = it->lastSp;
for (auto& a : it->ast) { for (auto& a : it->ast) {
if (a->name.starts_with("__t")) { if (a->name.starts_with("__t")) {
it->lastSp = llvmState.lastSp;
next = buildThunk(next); next = buildThunk(next);
curFunc = it->func; curFunc = it->func;
curEnv = it->env; curEnv = it->env;
llvmState.lastSp = it->lastSp;
} }
if (a->codegen(llvmState) == nullptr) { if (a->codegen(llvmState) == nullptr) {

Loading…
Cancel
Save