]> code.bitgloo.com Git - clyne/forspll.git/commitdiff
cache sp, minor refactors
authorClyne Sullivan <clyne@bitgloo.com>
Fri, 28 Jun 2024 14:26:03 +0000 (10:26 -0400)
committerClyne Sullivan <clyne@bitgloo.com>
Fri, 28 Jun 2024 14:26:03 +0000 (10:26 -0400)
ast.cpp
ast.hpp
llvm.cpp
llvm.hpp
main.cpp

diff --git a/ast.cpp b/ast.cpp
index a1daaace2581c0935309cbdbedaad03516c2f612..917d88f754335e78eadbe514c00eedb68f3f1b93 100644 (file)
--- a/ast.cpp
+++ b/ast.cpp
@@ -24,7 +24,6 @@ extern llvm::Function *curFunc;
 extern llvm::Value *curEnv;
 
 int ThunkAST::tcount = 0;
-int ThunkAST::envidx = 0;
 
 static inline auto loadEnv(LLVMState& llvmState, llvm::Value *index)
 {
@@ -61,7 +60,7 @@ PushAST::PushAST(const std::string& n): BaseAST(n) {}
 llvm::Value *PushAST::codegen(LLVMState& llvmState) const
 {
     if (auto [var, native] = Var::lookup(name, 1); var) {
-        auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++);
+        auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++);
         Var::addLocal(name, index);
     }
 
@@ -85,7 +84,7 @@ llvm::Value *PopAST::codegen(LLVMState& llvmState) const
     if (name == "self") {
         v = {curFunc, true};
     } else {
-        auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++);
+        auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++);
         auto pop = llvmState.createPop();
         storeEnv(llvmState, index, pop);
 
@@ -100,7 +99,7 @@ CallAST::CallAST(const std::string& n): BaseAST(n) {}
 llvm::Value *CallAST::codegen(LLVMState& llvmState) const
 {
     if (auto [var, native] = Var::lookup(name, 1); var && !native) {
-        auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++);
+        auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++);
         Var::addLocal(name, index);
     }
 
@@ -118,6 +117,8 @@ llvm::Value *CallAST::codegen(LLVMState& llvmState) const
         Var::addGlobal(name, Var {fn, true});
     }
 
+    llvmState.commitSp();
+
     bool couldRecur = Var::lookup("self").value != nullptr;
     auto localCount = Var::vars.back().size();
     if (!couldRecur || localCount == 0) {
@@ -168,12 +169,14 @@ void ThunkAST::beginGen(LLVMState& llvmState)
     entry = llvmState.createEntry(func);
     body = llvm::BasicBlock::Create(llvmState.ctx, "body", func);
     env = func->getArg(0);
+    lastSp = nullptr;
 
     llvmState.builder.SetInsertPoint(body);
 }
 
 void ThunkAST::endGen(LLVMState& llvmState)
 {
+    llvmState.commitSp();
     llvmState.builder.CreateRetVoid();
     llvmState.builder.SetInsertPoint(entry);
 
diff --git a/ast.hpp b/ast.hpp
index 380a38f7b6a942ee913fe6b3bbba5e11c55cee5b..287fd03fdf33edfbfc7958660cb7d0b97c2cacc7 100644 (file)
--- a/ast.hpp
+++ b/ast.hpp
@@ -66,13 +66,13 @@ struct CallAST : public BaseAST
 struct ThunkAST : public BaseAST
 {
     static int tcount;
-    static int envidx;
 
     std::list<std::unique_ptr<BaseAST>> ast;
     llvm::IRBuilderBase::InsertPoint parent;
     llvm::Function *func;
     llvm::BasicBlock *entry, *body;
     llvm::Value *env;
+    llvm::Value *lastSp;
 
     explicit ThunkAST();
     llvm::Value *codegen(LLVMState& llvmState) const override;
index c15d35df0aeb77f2e1a1478b8b5d7d161960cb66..38ec343cdbc91af5187e4e0e00cc3184cb3c377b 100644 (file)
--- a/llvm.cpp
+++ b/llvm.cpp
@@ -22,7 +22,7 @@ LLVMState::LLVMState():
     modul("forsp", ctx),
     builder(ctx),
     inttype(llvm::Type::getInt64Ty(ctx)),
-    stacktype(llvm::VectorType::get(inttype, 16, false)),
+    stacktype(llvm::ArrayType::get(inttype, 16)),
     ftype(llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), llvm::ArrayRef<llvm::Type *> {inttype->getPointerTo()}, false)),
     one(llvm::ConstantInt::get(inttype, 1)),
     zero(llvm::ConstantInt::get(inttype, 0))
@@ -36,24 +36,32 @@ LLVMState::LLVMState():
 
 llvm::Value *LLVMState::createPush(llvm::Value *var)
 {
-    auto dspval = builder.CreateLoad(inttype, llvmSp);
-    auto inc = builder.CreateAdd(dspval, one);
-    builder.CreateStore(inc, llvmSp);
+    if (!lastSp)
+        lastSp = builder.CreateLoad(inttype, llvmSp);
 
-    auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, dspval});
+    auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, lastSp});
+    lastSp = builder.CreateAdd(lastSp, one);
     return builder.CreateStore(var, gep);
 }
 
 llvm::Value *LLVMState::createPop()
 {
-    auto dspval = builder.CreateLoad(inttype, llvmSp);
-    auto dec = builder.CreateSub(dspval, one);
-    builder.CreateStore(dec, llvmSp);
+    if (!lastSp)
+        lastSp = builder.CreateLoad(inttype, llvmSp);
 
-    auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, dec});
+    lastSp = builder.CreateSub(lastSp, one);
+    auto gep = builder.CreateGEP(stacktype, llvmStack, {zero, lastSp});
     return builder.CreateLoad(inttype, gep);
 }
 
+void LLVMState::commitSp()
+{
+    if (lastSp) {
+        builder.CreateStore(lastSp, llvmSp);
+        lastSp = nullptr;
+    }
+}
+
 llvm::Function *LLVMState::createFunction(const std::string& name)
 {
     auto func = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
index 09e20122532151723985574391c5f84a8360d8d3..96969e832a8b84a1d55e381e52d76df2696b6b4f 100644 (file)
--- a/llvm.hpp
+++ b/llvm.hpp
@@ -42,10 +42,14 @@ struct LLVMState
     llvm::Constant *llvmSp;
     llvm::Constant *llvmStack;
 
+    int envidx = 0;
+    llvm::Value *lastSp = nullptr;
+
     LLVMState();
 
     llvm::Value *createPush(llvm::Value *var);
     llvm::Value *createPop();
+    void commitSp();
     llvm::Function *createFunction(const std::string& name);
     llvm::BasicBlock *createEntry(llvm::Function *func);
     llvm::Value *createVariable(const std::string& name);
index 9798d7e87dfca70eb7a250d721bb8a764cecba21..8e65ec84bdbc174b30f74fd2a53ac1c6755bea30 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -60,20 +60,22 @@ int main()
     }
 
     // 3. Create main() which initiate global vars and calls the top-level thunk.
-    auto func = llvmState.createFunction("main");
-    auto entry = llvmState.createEntry(func);
-    auto envtype = llvm::VectorType::get(llvmState.inttype, ThunkAST::envidx, false);
-    auto [t0, _] = Var::lookup("__t0");
-    llvmState.builder.SetInsertPoint(entry);
+    auto envtype = llvm::ArrayType::get(llvmState.inttype, llvmState.envidx);
     auto zerovec = llvm::ConstantVector::get(llvm::ArrayRef(llvmState.zero));
     auto env = new llvm::GlobalVariable(llvmState.modul, envtype, false,
         llvm::GlobalValue::InternalLinkage, zerovec, "env");
+
+    auto func = llvmState.createFunction("main");
+    auto entry = llvmState.createEntry(func);
+    llvmState.builder.SetInsertPoint(entry);
+
+    auto [t0, _] = Var::lookup("__t0");
     llvmState.builder.CreateCall(llvmState.ftype, t0, llvm::ArrayRef<llvm::Value *> {env});
     llvmState.builder.CreateRetVoid();
 
     // 4. Output the IR to stdout.
     llvmState.output();
-    std::cerr << "envidx: " << ThunkAST::envidx << std::endl;
+    std::cerr << "envidx: " << llvmState.envidx << std::endl;
 
     return 0;
 }
@@ -142,12 +144,15 @@ std::list<ThunkAST>::iterator buildThunk(std::list<ThunkAST>::iterator it)
 
     curFunc = it->func;
     curEnv = it->env;
+    llvmState.lastSp = it->lastSp;
 
     for (auto& a : it->ast) {
         if (a->name.starts_with("__t")) {
+            it->lastSp = llvmState.lastSp;
             next = buildThunk(next);
             curFunc = it->func;
             curEnv = it->env;
+            llvmState.lastSp = it->lastSp;
         }
 
         if (a->codegen(llvmState) == nullptr) {