/**
 * forspll - LLVM-based Forsp compiler
 * Copyright (C) 2024  Clyne Sullivan <clyne@bitgloo.com>
 *
 * This program is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation, either version 3 of the License, or (at your option)
 * any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program.  If not, see <http://www.gnu.org/licenses/>.
 */
#include "ast.hpp"

#include <charconv>
#include <cstddef>
#include <iostream>
#include <system_error>
#include <vector>

// The LLVM function currently being generated; PopAST binds "self" to it.
// Defined and updated outside this file.
extern llvm::Function *curFunc;
// The env pointer used by loadEnv()/storeEnv() and passed as the sole
// argument of every emitted call. Defined and updated outside this file.
extern llvm::Value *curEnv;

// Running counter used to give each thunk a unique name ("__t0", "__t1", ...).
int ThunkAST::tcount = 0;
// Next unused env slot index; incremented each time a slot is allocated.
int ThunkAST::envidx = 0;

static inline auto loadEnv(LLVMState& llvmState, llvm::Value *index)
{
    auto ptrty = llvmState.inttype->getPointerTo();
    auto gep = llvmState.builder.CreateGEP(ptrty, curEnv, {index});
    return llvmState.builder.CreateLoad(ptrty, gep);
}

static inline auto storeEnv(LLVMState& llvmState, llvm::Value *index, llvm::Value *val)
{
    auto ptrty = llvmState.inttype->getPointerTo();
    auto var = llvmState.builder.CreateGEP(ptrty, curEnv, {index});
    return llvmState.builder.CreateStore(val, var);
}

NumberAST::NumberAST(const std::string& n): BaseAST(n) {}

/**
 * Emit code that pushes this integer literal onto the runtime stack.
 *
 * The whole token must be a base-10 integer that fits in an int.
 * On failure an error is printed to stderr and nullptr is returned.
 */
llvm::Value *NumberAST::codegen(LLVMState& llvmState) const
{
    // data()/size() are valid even for an empty token, unlike front()/back().
    const char *first = name.data();
    const char *last = first + name.size();

    int value = 0;
    auto [ptr, ec] = std::from_chars(first, last, value);

    if (ec == std::errc::result_out_of_range) {
        // from_chars leaves `value` unmodified in this case, so it must not
        // be emitted.
        std::cerr << "error: number out of range: " << name << std::endl;
        return nullptr;
    } else if (ec != std::errc() || ptr != last) {
        // Not a number at all, or trailing non-digit characters.
        std::cerr << "error: not a number: " << name << std::endl;
        return nullptr;
    }

    auto val = llvmState.createInt(value);
    return llvmState.createPush(val);
}

PushAST::PushAST(const std::string& n):
    BaseAST(n)
{
}

/**
 * Emit code that pushes the value bound to `name` onto the runtime stack.
 *
 * Steps:
 *  - If `name` resolves via Var::lookup(name, 1), a fresh env slot is first
 *    reserved for it in the current scope.
 *  - The name is then resolved with Var::lookupLocal.
 *  - Env-slot bindings are loaded from the env; native bindings are pushed
 *    directly.
 *
 * Prints an error and returns nullptr if the name is not bound.
 */
llvm::Value *PushAST::codegen(LLVMState& llvmState) const
{
    if (Var::lookup(name, 1).value) {
        auto slot = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++);
        Var::addLocal(name, slot);
    }

    auto [bound, isNative] = Var::lookupLocal(name);
    if (!bound) {
        std::cerr << "error: not defined: " << name << std::endl;
        return nullptr;
    }

    if (!isNative)
        bound = loadEnv(llvmState, bound);

    return llvmState.createPush(bound);
}

PopAST::PopAST(const std::string& n):
    BaseAST(n)
{
}

/**
 * Bind `name` in the current scope and return the bound value.
 *
 * "self" is bound natively to the function currently being generated
 * (curFunc), and no pop is emitted. Any other name gets a fresh env slot,
 * and the top of the runtime stack is popped into it.
 */
llvm::Value *PopAST::codegen(LLVMState& llvmState) const
{
    if (name == "self")
        return Var::addLocal(name, Var {curFunc, true}).value;

    auto slot = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++);
    auto popped = llvmState.createPop();
    storeEnv(llvmState, slot, popped);

    return Var::addLocal(name, slot).value;
}

CallAST::CallAST(const std::string& n): BaseAST(n) {}

/**
 * Emit a call to the thunk/function bound to `name`, passing curEnv.
 *
 * Resolution:
 *  - A non-native binding found via Var::lookup(name, 1) is first given a
 *    fresh env slot in the current scope.
 *  - `name` is then resolved with Var::lookup: env-slot bindings are loaded
 *    from the env, native bindings are called directly.
 *  - An unknown name is declared as an external function (with a warning)
 *    and registered as a native global.
 *
 * When "self" is bound, the callee receives the same curEnv. Every non-native
 * local of the current scope is therefore copied to a stack buffer before the
 * call and written back afterwards (presumably so recursive calls cannot
 * clobber them).
 */
llvm::Value *CallAST::codegen(LLVMState& llvmState) const
{
    if (auto [var, native] = Var::lookup(name, 1); var && !native) {
        auto index = llvm::ConstantInt::get(llvmState.inttype, ThunkAST::envidx++);
        Var::addLocal(name, index);
    }

    llvm::Value *fn;
    if (auto [var, native] = Var::lookup(name); var) {
        if (!native)
            var = loadEnv(llvmState, var);

        fn = var;
    } else {
        std::cerr << "warning: anticipating external function: "
            << name << std::endl;

        fn = llvmState.createFunction(name);
        Var::addGlobal(name, Var {fn, true});
    }

    // Env slot indices of this scope's non-native locals that must survive
    // a possibly recursive call. Native locals have no env slot to save.
    std::vector<llvm::Value *> slots;
    if (Var::lookup("self").value != nullptr) {
        for (auto& [_, v] : Var::vars.back()) {
            if (!v.native)
                slots.push_back(v.value);
        }
    }

    if (slots.empty())
        return llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {curEnv});

    // Spill buffer with one element per saved slot. The element type is the
    // same pointer type used by loadEnv/storeEnv and by the GEPs below, so
    // its size always matches what is stored into it.
    auto ptrty = llvmState.inttype->getPointerTo();
    auto type = llvm::ArrayType::get(ptrty, slots.size());
    auto mem = llvmState.builder.CreateAlloca(type, nullptr);

    // Address of element `i` of the spill buffer.
    auto spillAddr = [&](std::size_t i) {
        auto index = llvm::ConstantInt::get(llvmState.inttype, i);
        return llvmState.builder.CreateGEP(ptrty, mem, {index});
    };

    for (std::size_t i = 0; i < slots.size(); ++i) {
        auto addr = spillAddr(i);
        auto saved = loadEnv(llvmState, slots[i]);
        llvmState.builder.CreateStore(saved, addr);
    }

    auto call = llvmState.builder.CreateCall(llvmState.ftype, fn, llvm::ArrayRef {curEnv});

    for (std::size_t i = 0; i < slots.size(); ++i) {
        auto restored = llvmState.builder.CreateLoad(ptrty, spillAddr(i));
        storeEnv(llvmState, slots[i], restored);
    }

    return call;
}

// Each thunk gets a unique name "__t<N>" from the shared tcount counter.
ThunkAST::ThunkAST():
    BaseAST("__t" + std::to_string(tcount++))
{
}

/**
 * A thunk evaluates to the LLVM function created for it in beginGen().
 * No instructions are emitted here.
 */
llvm::Value *ThunkAST::codegen(LLVMState& llvmState) const
{
    (void)llvmState;
    return this->func;
}

/**
 * Start generating code for this thunk.
 *
 * Steps:
 *  - Save the builder's current insertion point (endGen() restores it).
 *  - Create the thunk's function and obtain its entry block via createEntry().
 *  - Add a separate "body" block.
 *  - Record the function's first argument as this thunk's env.
 *  - Point the builder at the body block.
 */
void ThunkAST::beginGen(LLVMState& llvmState)
{
    parent = llvmState.builder.saveIP();
    func = llvmState.createFunction(name);
    entry = llvmState.createEntry(func);
    // The thunk's own code goes into "body". `entry` stays free for endGen(),
    // which fills it in and then branches to "body".
    body = llvm::BasicBlock::Create(llvmState.ctx, "body", func);
    env = func->getArg(0);

    llvmState.builder.SetInsertPoint(body);
}

/**
 * Finish generating this thunk.
 *
 * Steps:
 *  - Terminate the body with `ret void`.
 *  - Fill the entry block with initialisation of captured bindings.
 *  - Branch from the entry block to the body.
 *  - Restore the insertion point saved by beginGen().
 *
 * Each non-native local that also resolves via Var::lookup(n, 1) is
 * initialised from that outer binding:
 *  - An env-slot binding is loaded from the env.
 *  - A native binding (a value held directly, e.g. a function) is stored
 *    as-is.
 *
 * Native locals (e.g. "self") are bound directly, have no env slot, and are
 * skipped.
 */
void ThunkAST::endGen(LLVMState& llvmState)
{
    llvmState.builder.CreateRetVoid();
    llvmState.builder.SetInsertPoint(entry);

    for (auto& [n, v] : Var::vars.back()) {
        // v.value is not a slot index for native locals; using it as a GEP
        // index would produce invalid IR.
        if (v.native)
            continue;

        if (auto [c, cNative] = Var::lookup(n, 1); c) {
            // A native outer value is the value itself, not a slot index.
            // NOTE(review): assumes native values are usable from this
            // function (true for llvm::Function constants) — confirm.
            llvm::Value *src = c;
            if (!cNative)
                src = loadEnv(llvmState, c);

            storeEnv(llvmState, v.value, src);
        }
    }

    llvmState.builder.CreateBr(body);
    llvmState.builder.restoreIP(parent);
}