/**
 * forspll - LLVM-based Forsp compiler
 * Copyright (C) 2024 Clyne Sullivan
 *
 * This program is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation, either version 3 of the License, or (at your option)
 * any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program. If not, see <https://www.gnu.org/licenses/>.
 */

#include "ast.hpp"

#include <charconv>
#include <iostream>
#include <string>
#include <system_error>

// Code-generation state declared here and defined elsewhere.
// curFunc is what a popped `self` binds to (presumably the function currently
// being emitted); curEnv is the environment pointer that every env load/store
// below indexes into and that every emitted call passes to its callee.
extern llvm::Function *curFunc;
extern llvm::Value *curEnv;

// Incremented once per ThunkAST to give each thunk a unique name ("__t<N>").
int ThunkAST::tcount = 0;

/**
 * Emit a load of environment slot `index` from curEnv.
 *
 * The environment is addressed as an array whose elements are pointers to
 * `inttype`; the loaded element has that pointer type.
 */
static inline auto loadEnv(LLVMState& llvmState, llvm::Value *index)
{
    auto ptrty = llvmState.inttype->getPointerTo();
    auto gep = llvmState.builder.CreateGEP(ptrty, curEnv, {index});
    return llvmState.builder.CreateLoad(ptrty, gep);
}

/**
 * Emit a store of `val` into environment slot `index` of curEnv.
 *
 * Uses the same element layout as loadEnv (array of pointers to `inttype`).
 */
static inline auto storeEnv(LLVMState& llvmState, llvm::Value *index, llvm::Value *val)
{
    auto ptrty = llvmState.inttype->getPointerTo();
    auto var = llvmState.builder.CreateGEP(ptrty, curEnv, {index});
    return llvmState.builder.CreateStore(val, var);
}

NumberAST::NumberAST(const std::string& n):
    BaseAST(n) {}

/**
 * Parse the literal text in `name` as a base-10 `int` and emit a push of it.
 *
 * @return The result of createPush(), or nullptr (after printing a
 *         diagnostic) when `name` is not entirely a valid integer or the
 *         number does not fit in an `int`.
 */
llvm::Value *NumberAST::codegen(LLVMState& llvmState) const
{
    // data()/size() are valid even for an empty string, unlike front()/back().
    const char *first = name.data();
    const char *last = first + name.size();

    // Initialised because std::from_chars leaves `value` unmodified on any
    // failure; no indeterminate value may ever reach createInt().
    int value = 0;
    const auto [ptr, ec] = std::from_chars(first, last, value);

    // No leading number at all (this also covers an empty name), or trailing
    // characters after the number.
    if (ec == std::errc::invalid_argument || ptr != last) {
        std::cerr << "error: not a number: " << name << std::endl;
        return nullptr;
    }

    // Well-formed digits that do not fit in `int`. from_chars still advances
    // ptr past the digits here, so the check above alone would accept it.
    if (ec == std::errc::result_out_of_range) {
        std::cerr << "error: number out of range: " << name << std::endl;
        return nullptr;
    }

    auto val = llvmState.createInt(value);
    return llvmState.createPush(val);
}

PushAST::PushAST(const std::string& n):
    BaseAST(n) {}

/**
 * Emit a push of the value bound to `name`.
 *
 * If `name` is found by Var::lookup(name, 1) (presumably a search of the
 * scopes enclosing the current thunk -- TODO confirm against Var in ast.hpp),
 * it is first re-bound as a local:
 *  - a non-native variable gets a freshly reserved env slot; ThunkAST::endGen
 *    emits the copy of the outer slot into it in the thunk's entry block;
 *  - a native value is bound to the local directly.
 *
 * @return The result of createPush(), or nullptr (after printing a
 *         diagnostic) if `name` is not bound locally.
 */
llvm::Value *PushAST::codegen(LLVMState& llvmState) const
{
    if (auto [var, native] = Var::lookup(name, 1); var) {
        if (!native) {
            auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++);
            Var::addLocal(name, index);
        } else {
            Var::addLocal(name, Var {var, native});
        }
    }

    if (auto [var, native] = Var::lookupLocal(name); var) {
        // Non-native locals hold an env slot index; fetch the actual value.
        if (!native)
            var = loadEnv(llvmState, var);
        return llvmState.createPush(var);
    } else {
        std::cerr << "error: not defined: " << name << std::endl;
        return nullptr;
    }
}

PopAST::PopAST(const std::string& n):
    BaseAST(n) {}

/**
 * Bind `name` as a local of the current scope.
 *
 * `self` is bound natively to curFunc. Any other name gets a freshly reserved
 * env slot, and the value produced by createPop() is stored into that slot.
 *
 * @return The `value` member of the Var returned by Var::addLocal().
 */
llvm::Value *PopAST::codegen(LLVMState& llvmState) const
{
    Var v;

    if (name == "self") {
        v = {curFunc, true};
    } else {
        auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++);
        auto pop = llvmState.createPop();
        storeEnv(llvmState, index, pop);
        v = index;
    }

    return Var::addLocal(name, v).value;
}

CallAST::CallAST(const std::string& n):
    BaseAST(n) {}

/**
 * Emit a call to the function bound to `name`, passing curEnv as the sole
 * argument.
 *
 * Steps:
 *  1. If `name` is found non-native by Var::lookup(name, 1), reserve a local
 *     env slot for it. ThunkAST::endGen fills that slot from the outer one.
 *  2. Resolve the callee via Var::lookup(name). Non-native bindings are
 *     loaded from the env. An unbound name is declared as an external
 *     function and registered as a native global, with a warning.
 *  3. Call llvmState.commitSp() before the call. This presumably flushes
 *     the cached stack pointer for the callee; TODO confirm.
 *  4. If `self` is bound and the current scope has locals, save the
 *     non-native locals' env slots around the call (see below).
 *
 * @return The emitted call instruction.
 */
llvm::Value *CallAST::codegen(LLVMState& llvmState) const
{
    if (auto [var, native] = Var::lookup(name, 1); var && !native) {
        auto index = llvm::ConstantInt::get(llvmState.inttype, llvmState.envidx++);
        Var::addLocal(name, index);
    }

    llvm::Value *fn;
    if (auto [var, native] = Var::lookup(name); var) {
        if (!native)
            var = loadEnv(llvmState, var);
        fn = var;
    } else {
        std::cerr << "warning: anticipating external function: " << name << std::endl;
        fn = llvmState.createFunction(name);
        Var::addGlobal(name, Var {fn, true});
    }

    llvmState.commitSp();

    bool couldRecur = Var::lookup("self").value != nullptr;
    auto localCount = Var::vars.back().size();

    if (!couldRecur || localCount == 0)
        return llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {curEnv});

    // The callee receives curEnv itself, and each variable's env slot index is
    // fixed at compile time (reserved via envidx). When `self` is bound, the
    // callee may presumably be this same thunk and overwrite this scope's
    // slots. So copy the non-native locals into a stack buffer before the
    // call and copy them back afterwards. The buffer is sized for all locals,
    // although only the non-native ones are stored.
    auto ptrty = llvmState.inttype->getPointerTo();
    auto type = llvm::VectorType::get(llvmState.inttype, localCount, false);
    auto mem = llvmState.builder.CreateAlloca(type, nullptr);

    int saveSlot = 0;
    for (auto& [_, v] : Var::vars.back()) {
        if (!v.native) {
            auto index = llvm::ConstantInt::get(llvmState.inttype, saveSlot++);
            auto m = llvmState.builder.CreateGEP(ptrty, mem, {index});
            llvmState.builder.CreateStore(loadEnv(llvmState, v.value), m);
        }
    }

    auto call = llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {curEnv});

    // Restore in the same iteration order, so slot numbers line up with the
    // save loop above.
    int restoreSlot = 0;
    for (auto& [_, v] : Var::vars.back()) {
        if (!v.native) {
            auto index = llvm::ConstantInt::get(llvmState.inttype, restoreSlot++);
            auto m = llvmState.builder.CreateGEP(ptrty, mem, {index});
            auto l = llvmState.builder.CreateLoad(ptrty, m);
            storeEnv(llvmState, v.value, l);
        }
    }

    return call;
}

// Each thunk is named "__t" followed by a unique counter value.
ThunkAST::ThunkAST():
    BaseAST(std::string("__t") + std::to_string(tcount++)) {}

/**
 * A thunk evaluates to its generated llvm::Function. The body itself is
 * emitted between beginGen() and endGen(), not here.
 */
llvm::Value *ThunkAST::codegen(LLVMState& llvmState) const
{
    return func;
}

/**
 * Start emitting this thunk's function.
 *
 * The steps are:
 *  - save the builder's current insertion point so endGen() can restore it;
 *  - create the function plus an entry block (via createEntry) and a
 *    separate "body" block;
 *  - record argument 0 as this thunk's env;
 *  - reset lastSp;
 *  - leave the builder positioned in the body block.
 */
void ThunkAST::beginGen(LLVMState& llvmState)
{
    parent = llvmState.builder.saveIP();
    func = llvmState.createFunction(name);
    entry = llvmState.createEntry(func);
    body = llvm::BasicBlock::Create(llvmState.ctx, "body", func);
    env = func->getArg(0);
    lastSp = nullptr;

    llvmState.builder.SetInsertPoint(body);
}

/**
 * Finish this thunk's function.
 *
 * The body block is terminated with commitSp() followed by `ret void`. Then,
 * in the entry block, every local whose name is also found non-native by
 * Var::lookup(n, 1) has the outer env slot copied into its own slot. The
 * entry block then branches to the body, and the builder is returned to the
 * insertion point saved by beginGen().
 */
void ThunkAST::endGen(LLVMState& llvmState)
{
    llvmState.commitSp();
    llvmState.builder.CreateRetVoid();

    // The entry block is filled last because only now is the full set of
    // locals referenced by the body known.
    llvmState.builder.SetInsertPoint(entry);
    for (auto& [n, v] : Var::vars.back()) {
        if (auto [c, nat] = Var::lookup(n, 1); c && !nat) {
            auto src = loadEnv(llvmState, c);
            storeEnv(llvmState, v.value, src);
        }
    }
    llvmState.builder.CreateBr(body);

    llvmState.builder.restoreIP(parent);
}