code.bitgloo.com Git - clyne/forspll.git/commitdiff
move AST->IR outside of parsing stage
author: Clyne Sullivan <clyne@bitgloo.com>
Thu, 27 Jun 2024 19:22:47 +0000 (15:22 -0400)
committer: Clyne Sullivan <clyne@bitgloo.com>
Thu, 27 Jun 2024 19:22:47 +0000 (15:22 -0400)
ast.cpp
ast.hpp
main.cpp

diff --git a/ast.cpp b/ast.cpp
index f6e607442e777de34c0ca9f27d58c2b7bf00f42c..a1daaace2581c0935309cbdbedaad03516c2f612 100644 (file)
--- a/ast.cpp
+++ b/ast.cpp
@@ -20,7 +20,8 @@
 #include <charconv>
 #include <iostream>
 
-extern std::list<ThunkAST> scope;
+extern llvm::Function *curFunc;
+extern llvm::Value *curEnv;
 
 int ThunkAST::tcount = 0;
 int ThunkAST::envidx = 0;
@@ -28,14 +29,14 @@ int ThunkAST::envidx = 0;
 static inline auto loadEnv(LLVMState& llvmState, llvm::Value *index)
 {
     auto ptrty = llvmState.inttype->getPointerTo();
-    auto gep = llvmState.builder.CreateGEP(ptrty, scope.back().env, {index});
+    auto gep = llvmState.builder.CreateGEP(ptrty, curEnv, {index});
     return llvmState.builder.CreateLoad(ptrty, gep);
 }
 
 static inline auto storeEnv(LLVMState& llvmState, llvm::Value *index, llvm::Value *val)
 {
     auto ptrty = llvmState.inttype->getPointerTo();
-    auto var = llvmState.builder.CreateGEP(ptrty, scope.back().env, {index});
+    auto var = llvmState.builder.CreateGEP(ptrty, curEnv, {index});
     return llvmState.builder.CreateStore(val, var);
 }
 
@@ -82,7 +83,7 @@ llvm::Value *PopAST::codegen(LLVMState& llvmState) const
     Var v;
 
     if (name == "self") {
-        v = {scope.back().func, true};
+        v = {curFunc, true};
     } else {
         auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++);
         auto pop = llvmState.createPop();
@@ -120,7 +121,7 @@ llvm::Value *CallAST::codegen(LLVMState& llvmState) const
     bool couldRecur = Var::lookup("self").value != nullptr;
     auto localCount = Var::vars.back().size();
     if (!couldRecur || localCount == 0) {
-        return llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {scope.back().env});
+        return llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {curEnv});
     } else {
         int i;
         auto ptrty = llvmState.inttype->getPointerTo();
@@ -136,7 +137,7 @@ llvm::Value *CallAST::codegen(LLVMState& llvmState) const
             }
         }
 
-        auto call = llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {scope.back().env});
+        auto call = llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {curEnv});
 
         i = 0;
         for (auto& [_, v] : Var::vars.back()) {
@@ -152,10 +153,15 @@ llvm::Value *CallAST::codegen(LLVMState& llvmState) const
     }
 }
 
-ThunkAST::ThunkAST(LLVMState& llvmState):
-    ThunkAST(llvmState, std::string("__t") + std::to_string(tcount++)) {}
+ThunkAST::ThunkAST():
+    BaseAST(std::string("__t") + std::to_string(tcount++)) {}
 
-ThunkAST::ThunkAST(LLVMState& llvmState, std::string n): BaseAST(n)
+llvm::Value *ThunkAST::codegen(LLVMState& llvmState) const
+{
+    return func;
+}
+
+void ThunkAST::beginGen(LLVMState& llvmState)
 {
     parent = llvmState.builder.saveIP();
     func = llvmState.createFunction(name);
@@ -166,7 +172,7 @@ ThunkAST::ThunkAST(LLVMState& llvmState, std::string n): BaseAST(n)
     llvmState.builder.SetInsertPoint(body);
 }
 
-llvm::Value *ThunkAST::codegen(LLVMState& llvmState) const
+void ThunkAST::endGen(LLVMState& llvmState)
 {
     llvmState.builder.CreateRetVoid();
     llvmState.builder.SetInsertPoint(entry);
@@ -182,6 +188,5 @@ llvm::Value *ThunkAST::codegen(LLVMState& llvmState) const
 
     llvmState.builder.CreateBr(body);
     llvmState.builder.restoreIP(parent);
-    return func;
 }
 
diff --git a/ast.hpp b/ast.hpp
index b7802c13dca2a9d11d93faacbb303fe62533b03f..380a38f7b6a942ee913fe6b3bbba5e11c55cee5b 100644 (file)
--- a/ast.hpp
+++ b/ast.hpp
@@ -68,14 +68,17 @@ struct ThunkAST : public BaseAST
     static int tcount;
     static int envidx;
 
+    std::list<std::unique_ptr<BaseAST>> ast;
     llvm::IRBuilderBase::InsertPoint parent;
     llvm::Function *func;
     llvm::BasicBlock *entry, *body;
     llvm::Value *env;
 
-    explicit ThunkAST(LLVMState& llvmState);
-    explicit ThunkAST(LLVMState& llvmState, std::string n);
+    explicit ThunkAST();
     llvm::Value *codegen(LLVMState& llvmState) const override;
+
+    void beginGen(LLVMState& llvmState);
+    void endGen(LLVMState& llvmState);
 };
 
 #endif // FORSPLL_AST_HPP
diff --git a/main.cpp b/main.cpp
index 2ac1ff5be7799718195bb4a441ad790e172912fc..9798d7e87dfca70eb7a250d721bb8a764cecba21 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -21,6 +21,7 @@
 #include <iostream>
 #include <list>
 #include <map>
+#include <stack>
 #include <string>
 #include <string_view>
 #include <tuple>
 #include "parser.hpp"
 #include "var.hpp"
 
+llvm::Function *curFunc;
+llvm::Value *curEnv;
+
 static LLVMState llvmState;
-std::list<ThunkAST> scope;
+static std::list<ThunkAST> scope;
+static std::stack<std::reference_wrapper<ThunkAST>> callstack;
 
 static bool parseString(std::string_view sv);
+static std::list<ThunkAST>::iterator buildThunk(std::list<ThunkAST>::iterator it);
 
 int main()
 {
-    Var::pushScope();
-
+    // 1. Parse code into ThunkASTs, i.e. functions with AST lists for bodies.
     std::string line;
     for (unsigned lineno = 1; std::cin.good(); ++lineno) {
         std::getline(std::cin, line);
@@ -48,6 +53,13 @@ int main()
         }
     }
 
+    // 2. Traverse through thunks in the order they are given in the file and build IR.
+    Var::pushScope();
+    for (auto it = scope.begin(); it != scope.end();) {
+        it = buildThunk(it);
+    }
+
+    // 3. Create main(), which initializes global vars and calls the top-level thunk.
     auto func = llvmState.createFunction("main");
     auto entry = llvmState.createEntry(func);
     auto envtype = llvm::VectorType::get(llvmState.inttype, ThunkAST::envidx, false);
@@ -59,9 +71,11 @@ int main()
     llvmState.builder.CreateCall(llvmState.ftype, t0, llvm::ArrayRef<llvm::Value *> {env});
     llvmState.builder.CreateRetVoid();
 
+    // 4. Output the IR to stdout.
     llvmState.output();
     std::cerr << "envidx: " << ThunkAST::envidx << std::endl;
-    std::cout << std::endl;
+
+    return 0;
 }
 
 bool parseString(std::string_view sv)
@@ -77,20 +91,11 @@ bool parseString(std::string_view sv)
 
             switch (tok) {
             case Token::ThunkOpen:
-                scope.emplace_back(llvmState);
-                Var::pushScope();
+                callstack.push(scope.emplace_back());
                 break;
             case Token::ThunkClose:
-                {
-                auto& thunk = scope.back();
-                auto gen = thunk.codegen(llvmState);
-                if (!gen)
-                    return false;
-                Var::popScope();
-                Var::addLocal(thunk.name, Var {gen, true});
-                expr.reset(new PushAST {thunk.name});
-                scope.pop_back();
-                }
+                expr.reset(new PushAST {callstack.top().get().name});
+                callstack.pop();
                 break;
             case Token::Quote:
                 std::cerr << "error: quoting is not supported!" << std::endl;
@@ -112,9 +117,8 @@ bool parseString(std::string_view sv)
             }
 
             if (expr) {
-                if (!scope.empty()) {
-                    if (!expr->codegen(llvmState))
-                        return false;
+                if (!callstack.empty()) {
+                    callstack.top().get().ast.emplace_back().swap(expr);
                 } else if (tok != Token::ThunkClose) {
                     std::cerr << "error: non-thunk at top level!" << std::endl;
                     return false;
@@ -128,3 +132,38 @@ bool parseString(std::string_view sv)
     return true;
 }
 
+std::list<ThunkAST>::iterator buildThunk(std::list<ThunkAST>::iterator it)
+{
+    auto next = it;
+    ++next;
+
+    it->beginGen(llvmState);
+    Var::pushScope();
+
+    curFunc = it->func;
+    curEnv = it->env;
+
+    for (auto& a : it->ast) {
+        if (a->name.starts_with("__t")) {
+            next = buildThunk(next);
+            curFunc = it->func;
+            curEnv = it->env;
+        }
+
+        if (a->codegen(llvmState) == nullptr) {
+            return scope.end();
+        }
+    }
+
+    it->endGen(llvmState);
+
+    auto gen = it->codegen(llvmState);
+    if (!gen)
+        return scope.end();
+
+    Var::popScope();
+    Var::addLocal(it->name, Var {gen, true});
+    scope.erase(it);
+    return next;
+}
+