diff --git a/ast.cpp b/ast.cpp index f6e6074..a1daaac 100644 --- a/ast.cpp +++ b/ast.cpp @@ -20,7 +20,8 @@ #include #include -extern std::list scope; +extern llvm::Function *curFunc; +extern llvm::Value *curEnv; int ThunkAST::tcount = 0; int ThunkAST::envidx = 0; @@ -28,14 +29,14 @@ int ThunkAST::envidx = 0; static inline auto loadEnv(LLVMState& llvmState, llvm::Value *index) { auto ptrty = llvmState.inttype->getPointerTo(); - auto gep = llvmState.builder.CreateGEP(ptrty, scope.back().env, {index}); + auto gep = llvmState.builder.CreateGEP(ptrty, curEnv, {index}); return llvmState.builder.CreateLoad(ptrty, gep); } static inline auto storeEnv(LLVMState& llvmState, llvm::Value *index, llvm::Value *val) { auto ptrty = llvmState.inttype->getPointerTo(); - auto var = llvmState.builder.CreateGEP(ptrty, scope.back().env, {index}); + auto var = llvmState.builder.CreateGEP(ptrty, curEnv, {index}); return llvmState.builder.CreateStore(val, var); } @@ -82,7 +83,7 @@ llvm::Value *PopAST::codegen(LLVMState& llvmState) const Var v; if (name == "self") { - v = {scope.back().func, true}; + v = {curFunc, true}; } else { auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++); auto pop = llvmState.createPop(); @@ -120,7 +121,7 @@ llvm::Value *CallAST::codegen(LLVMState& llvmState) const bool couldRecur = Var::lookup("self").value != nullptr; auto localCount = Var::vars.back().size(); if (!couldRecur || localCount == 0) { - return llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {scope.back().env}); + return llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {curEnv}); } else { int i; auto ptrty = llvmState.inttype->getPointerTo(); @@ -136,7 +137,7 @@ llvm::Value *CallAST::codegen(LLVMState& llvmState) const } } - auto call = llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {scope.back().env}); + auto call = llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {curEnv}); i = 0; for (auto& [_, v] : Var::vars.back()) { @@ -152,10 +153,15 @@ llvm::Value *CallAST::codegen(LLVMState& llvmState) const } } -ThunkAST::ThunkAST(LLVMState& llvmState): - ThunkAST(llvmState, std::string("__t") + std::to_string(tcount++)) {} +ThunkAST::ThunkAST(): + BaseAST(std::string("__t") + std::to_string(tcount++)) {} -ThunkAST::ThunkAST(LLVMState& llvmState, std::string n): BaseAST(n) +llvm::Value *ThunkAST::codegen(LLVMState& llvmState) const +{ + return func; +} + +void ThunkAST::beginGen(LLVMState& llvmState) { parent = llvmState.builder.saveIP(); func = llvmState.createFunction(name); @@ -166,7 +172,7 @@ ThunkAST::ThunkAST(LLVMState& llvmState, std::string n): BaseAST(n) llvmState.builder.SetInsertPoint(body); } -llvm::Value *ThunkAST::codegen(LLVMState& llvmState) const +void ThunkAST::endGen(LLVMState& llvmState) { llvmState.builder.CreateRetVoid(); llvmState.builder.SetInsertPoint(entry); @@ -182,6 +188,5 @@ llvm::Value *ThunkAST::codegen(LLVMState& llvmState) const llvmState.builder.CreateBr(body); llvmState.builder.restoreIP(parent); - return func; } diff --git a/ast.hpp b/ast.hpp index b7802c1..380a38f 100644 --- a/ast.hpp +++ b/ast.hpp @@ -68,14 +68,17 @@ struct ThunkAST : public BaseAST static int tcount; static int envidx; + std::list> ast; llvm::IRBuilderBase::InsertPoint parent; llvm::Function *func; llvm::BasicBlock *entry, *body; llvm::Value *env; - explicit ThunkAST(LLVMState& llvmState); - explicit ThunkAST(LLVMState& llvmState, std::string n); + explicit ThunkAST(); llvm::Value *codegen(LLVMState& llvmState) const override; + + void beginGen(LLVMState& llvmState); + void endGen(LLVMState& llvmState); }; #endif // FORSPLL_AST_HPP diff --git a/main.cpp b/main.cpp index 2ac1ff5..9798d7e 100644 --- a/main.cpp +++ b/main.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -30,15 +31,19 @@ #include "parser.hpp" #include "var.hpp" +llvm::Function *curFunc; +llvm::Value *curEnv; + static LLVMState llvmState; -std::list scope; +static std::list scope; +static std::stack> callstack; static bool parseString(std::string_view sv); +static std::list::iterator buildThunk(std::list::iterator it); int main() { - Var::pushScope(); - + // 1. Parse code into ThunkASTs, i.e. functions with AST lists for bodies. std::string line; for (unsigned lineno = 1; std::cin.good(); ++lineno) { std::getline(std::cin, line); @@ -48,6 +53,13 @@ int main() } } + // 2. Traverse through thunks in the order they are given in the file and build IR. + Var::pushScope(); + for (auto it = scope.begin(); it != scope.end();) { + it = buildThunk(it); + } + + // 3. Create main() which initiate global vars and calls the top-level thunk. auto func = llvmState.createFunction("main"); auto entry = llvmState.createEntry(func); auto envtype = llvm::VectorType::get(llvmState.inttype, ThunkAST::envidx, false); @@ -59,9 +71,11 @@ int main() llvmState.builder.CreateCall(llvmState.ftype, t0, llvm::ArrayRef {env}); llvmState.builder.CreateRetVoid(); + // 4. Output the IR to stdout. llvmState.output(); std::cerr << "envidx: " << ThunkAST::envidx << std::endl; - std::cout << std::endl; + + return 0; } bool parseString(std::string_view sv) @@ -77,20 +91,11 @@ bool parseString(std::string_view sv) switch (tok) { case Token::ThunkOpen: - scope.emplace_back(llvmState); - Var::pushScope(); + callstack.push(scope.emplace_back()); break; case Token::ThunkClose: - { - auto& thunk = scope.back(); - auto gen = thunk.codegen(llvmState); - if (!gen) - return false; - Var::popScope(); - Var::addLocal(thunk.name, Var {gen, true}); - expr.reset(new PushAST {thunk.name}); - scope.pop_back(); - } + expr.reset(new PushAST {callstack.top().get().name}); + callstack.pop(); break; case Token::Quote: std::cerr << "error: quoting is not supported!" << std::endl; @@ -112,9 +117,8 @@ bool parseString(std::string_view sv) } if (expr) { - if (!scope.empty()) { - if (!expr->codegen(llvmState)) - return false; + if (!callstack.empty()) { + callstack.top().get().ast.emplace_back().swap(expr); } else if (tok != Token::ThunkClose) { std::cerr << "error: non-thunk at top level!" << std::endl; return false; @@ -128,3 +132,38 @@ bool parseString(std::string_view sv) return true; } +std::list::iterator buildThunk(std::list::iterator it) +{ + auto next = it; + ++next; + + it->beginGen(llvmState); + Var::pushScope(); + + curFunc = it->func; + curEnv = it->env; + + for (auto& a : it->ast) { + if (a->name.starts_with("__t")) { + next = buildThunk(next); + curFunc = it->func; + curEnv = it->env; + } + + if (a->codegen(llvmState) == nullptr) { + return scope.end(); + } + } + + it->endGen(llvmState); + + auto gen = it->codegen(llvmState); + if (!gen) + return scope.end(); + + Var::popScope(); + Var::addLocal(it->name, Var {gen, true}); + scope.erase(it); + return next; +} +