From 08f4a92102c9739e0736965186814e856ff727a1 Mon Sep 17 00:00:00 2001 From: Clyne Sullivan Date: Fri, 28 Jun 2024 10:26:03 -0400 Subject: [PATCH] cache sp, minor refactors --- ast.cpp | 11 +++++++---- ast.hpp | 2 +- llvm.cpp | 26 +++++++++++++++++--------- llvm.hpp | 4 ++++ main.cpp | 17 +++++++++++------ 5 files changed, 40 insertions(+), 20 deletions(-) diff --git a/ast.cpp b/ast.cpp index a1daaac..917d88f 100644 --- a/ast.cpp +++ b/ast.cpp @@ -24,7 +24,6 @@ extern llvm::Function *curFunc; extern llvm::Value *curEnv; int ThunkAST::tcount = 0; -int ThunkAST::envidx = 0; static inline auto loadEnv(LLVMState& llvmState, llvm::Value *index) { @@ -61,7 +60,7 @@ PushAST::PushAST(const std::string& n): BaseAST(n) {} llvm::Value *PushAST::codegen(LLVMState& llvmState) const { if (auto [var, native] = Var::lookup(name, 1); var) { - auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++); + auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++); Var::addLocal(name, index); } @@ -85,7 +84,7 @@ llvm::Value *PopAST::codegen(LLVMState& llvmState) const if (name == "self") { v = {curFunc, true}; } else { - auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++); + auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++); auto pop = llvmState.createPop(); storeEnv(llvmState, index, pop); @@ -100,7 +99,7 @@ CallAST::CallAST(const std::string& n): BaseAST(n) {} llvm::Value *CallAST::codegen(LLVMState& llvmState) const { if (auto [var, native] = Var::lookup(name, 1); var && !native) { - auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++); + auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++); Var::addLocal(name, index); } @@ -118,6 +117,8 @@ llvm::Value *CallAST::codegen(LLVMState& llvmState) const Var::addGlobal(name, Var {fn, true}); } + llvmState.commitSp(); + bool couldRecur = Var::lookup("self").value != nullptr; auto localCount = Var::vars.back().size(); if (!couldRecur || localCount == 0) { @@ -168,12 +169,14 @@ void ThunkAST::beginGen(LLVMState& llvmState) entry = llvmState.createEntry(func); body = llvm::BasicBlock::Create(llvmState.ctx, "body", func); env = func->getArg(0); + lastSp = nullptr; llvmState.builder.SetInsertPoint(body); } void ThunkAST::endGen(LLVMState& llvmState) { + llvmState.commitSp(); llvmState.builder.CreateRetVoid(); llvmState.builder.SetInsertPoint(entry); diff --git a/ast.hpp b/ast.hpp index 380a38f..287fd03 100644 --- a/ast.hpp +++ b/ast.hpp @@ -66,13 +66,13 @@ struct CallAST : public BaseAST struct ThunkAST : public BaseAST { static int tcount; - static int envidx; std::list> ast; llvm::IRBuilderBase::InsertPoint parent; llvm::Function *func; llvm::BasicBlock *entry, *body; llvm::Value *env; + llvm::Value *lastSp; explicit ThunkAST(); llvm::Value *codegen(LLVMState& llvmState) const override; diff --git a/llvm.cpp b/llvm.cpp index c15d35d..38ec343 100644 --- a/llvm.cpp +++ b/llvm.cpp @@ -22,7 +22,7 @@ LLVMState::LLVMState(): modul("forsp", ctx), builder(ctx), inttype(llvm::Type::getInt64Ty(ctx)), - stacktype(llvm::VectorType::get(inttype, 16, false)), + stacktype(llvm::ArrayType::get(inttype, 16)), ftype(llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), llvm::ArrayRef {inttype->getPointerTo()}, false)), one(llvm::ConstantInt::get(inttype, 1)), zero(llvm::ConstantInt::get(inttype, 0)) @@ -36,24 +36,32 @@ LLVMState::LLVMState(): llvm::Value *LLVMState::createPush(llvm::Value *var) { - auto dspval = builder.CreateLoad(inttype, llvmSp); - auto inc = builder.CreateAdd(dspval, one); - builder.CreateStore(inc, llvmSp); + if (!lastSp) + lastSp = builder.CreateLoad(inttype, llvmSp); - auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, dspval}); + auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, lastSp}); + lastSp = builder.CreateAdd(lastSp, one); return builder.CreateStore(var, gep); } llvm::Value *LLVMState::createPop() { - auto dspval = builder.CreateLoad(inttype, llvmSp); - auto dec = builder.CreateSub(dspval, one); - builder.CreateStore(dec, llvmSp); + if (!lastSp) + lastSp = builder.CreateLoad(inttype, llvmSp); - auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, dec}); + lastSp = builder.CreateSub(lastSp, one); + auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, lastSp}); return builder.CreateLoad(inttype, gep); } +void LLVMState::commitSp() +{ + if (lastSp) { + builder.CreateStore(lastSp, llvmSp); + lastSp = nullptr; + } +} + llvm::Function *LLVMState::createFunction(const std::string& name) { auto func = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, diff --git a/llvm.hpp b/llvm.hpp index 09e2012..96969e8 100644 --- a/llvm.hpp +++ b/llvm.hpp @@ -42,10 +42,14 @@ struct LLVMState llvm::Constant *llvmSp; llvm::Constant *llvmStack; + int envidx = 0; + llvm::Value *lastSp = nullptr; + LLVMState(); llvm::Value *createPush(llvm::Value *var); llvm::Value *createPop(); + void commitSp(); llvm::Function *createFunction(const std::string& name); llvm::BasicBlock *createEntry(llvm::Function *func); llvm::Value *createVariable(const std::string& name); diff --git a/main.cpp b/main.cpp index 9798d7e..8e65ec8 100644 --- a/main.cpp +++ b/main.cpp @@ -60,20 +60,22 @@ int main() } // 3. Create main() which initiate global vars and calls the top-level thunk. - auto func = llvmState.createFunction("main"); - auto entry = llvmState.createEntry(func); - auto envtype = llvm::VectorType::get(llvmState.inttype, ThunkAST::envidx, false); - auto [t0, _] = Var::lookup("__t0"); - llvmState.builder.SetInsertPoint(entry); + auto envtype = llvm::ArrayType::get(llvmState.inttype, llvmState.envidx); auto zerovec = llvm::ConstantVector::get(llvm::ArrayRef(llvmState.zero)); auto env = new llvm::GlobalVariable(llvmState.modul, envtype, false, llvm::GlobalValue::InternalLinkage, zerovec, "env"); + + auto func = llvmState.createFunction("main"); + auto entry = llvmState.createEntry(func); + llvmState.builder.SetInsertPoint(entry); + + auto [t0, _] = Var::lookup("__t0"); llvmState.builder.CreateCall(llvmState.ftype, t0, llvm::ArrayRef {env}); llvmState.builder.CreateRetVoid(); // 4. Output the IR to stdout. llvmState.output(); - std::cerr << "envidx: " << ThunkAST::envidx << std::endl; + std::cerr << "envidx: " << llvmState.envidx << std::endl; return 0; } @@ -142,12 +144,15 @@ std::list::iterator buildThunk(std::list::iterator it) curFunc = it->func; curEnv = it->env; + llvmState.lastSp = it->lastSp; for (auto& a : it->ast) { if (a->name.starts_with("__t")) { + it->lastSp = llvmState.lastSp; next = buildThunk(next); curFunc = it->func; curEnv = it->env; + llvmState.lastSp = it->lastSp; } if (a->codegen(llvmState) == nullptr) {