/// sforth, an implementation of forth
/// Copyright (C) 2024  Clyne Sullivan <clyne@bitgloo.com>
///
/// This program is free software: you can redistribute it and/or modify it
/// under the terms of the GNU General Public License as published by the Free
/// Software Foundation, either version 3 of the License, or (at your option)
/// any later version.
///
/// This program is distributed in the hope that it will be useful, but WITHOUT
/// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
/// FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
/// more details.
///
/// You should have received a copy of the GNU General Public License along
/// with this program.  If not, see <http://www.gnu.org/licenses/>.

#ifndef SFORTH_HPP
#define SFORTH_HPP

#include <algorithm>
#include <array>
#include <bit>
#include <charconv>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <new>
#include <optional>
#include <span>
#include <string_view>
#include <system_error>
#include <tuple>
#include <utility>

struct forth
{
    /// Signed, pointer-sized integer; the element type of the data stack and
    /// of the dictionary memory addressed through `here`.
    using cell = std::intptr_t;
    /// Unsigned, pointer-sized integer; used for the packed flags/length field
    /// of a dictionary entry.
    using addr = std::uintptr_t;
    /// Signature of a native word implementation. execute() calls it with the
    /// address of the word's own body as the argument.
    using func = void (*)(const void *);

    /// When true, forth::assert throws the failing error code; when false the
    /// checks are compiled out.
    static constexpr bool enable_exceptions = true;
    /// Number of cells in the data stack.
    static constexpr int data_size = 16;
    /// Number of entries in the return stack.
    static constexpr int return_size = 16;

    /// "No position" sentinel; `sourcei` holds this once the input line has
    /// been fully consumed.
    static constexpr auto npos = std::string_view::npos;

    /// Error codes thrown (as plain enum values) by forth::assert.
    enum class error {
        /// initialize() was given a null interpreter pointer.
        init_error,
        /// add() was asked to define a word with an empty name.
        parse_error,
        /// execute() was given a null body or a body whose first cell is null.
        execute_error,
        /// Only referenced by a disabled check in add() within this file.
        dictionary_overflow,
        /// A token was neither a known word nor a valid number, or the target
        /// of `postpone` does not exist.
        word_not_found,
        /// top()/pop() on an empty data stack.
        stack_underflow,
        /// push() on a full data stack.
        stack_overflow,
        /// rpop() on an empty return stack.
        return_stack_underflow,
        /// rpush() on a full return stack.
        return_stack_overflow,
        /// `postpone` was used while not compiling.
        compile_only_word
    };

    /// Runtime check used throughout the interpreter.
    ///
    /// Throws the error code `Err` (a value of type forth::error) when
    /// `condition` is false. When enable_exceptions is false the check is
    /// compiled out and the function does nothing.
    ///
    /// The name is parenthesized in the declarator so that this definition
    /// still compiles when <cassert>/<assert.h> has been included earlier:
    /// the function-like `assert` macro only expands when the identifier is
    /// immediately followed by `(`. Call sites are written `assert<...>(...)`,
    /// where the identifier is followed by `<`, so they are never affected.
    template<error Err>
    static inline void (assert)(bool condition) {
        if constexpr (enable_exceptions) {
            if (!condition)
                throw Err;
        }
    }

    struct word_base {
        static constexpr addr immediate = 1 << 8;

        const word_base *next;
        addr flags_len;

        auto name() const -> std::string_view {
            return {std::bit_cast<const char *>(this + 1)};
        }

        auto body() const -> const func * {
            const auto ptr = std::bit_cast<const std::uint8_t *>(this + 1);
            const auto fptr = ptr + (flags_len & 0xFF);
            return std::bit_cast<const func *>(fptr);
        }

        constexpr void make_immediate() {
            flags_len |= immediate;
        }
    };

    template<std::size_t L>
    struct word : public word_base {
        std::array<char, L> name;
        func body;

        template<std::size_t N>
        consteval word(const char (&nam)[N],
            func bod = nullptr,
            const word_base *prev = nullptr,
            addr flags = 0):
            word_base{prev, L | flags}, name{}, body{bod}
        {
            std::copy(nam, nam + N, name.begin());
        }
    };

    /// Deduction guide: picks the name-field size from the literal's length.
    /// N includes the terminating NUL; the result is the smallest multiple of
    /// sizeof(cell) that is strictly greater than N (sizeof(cell) is assumed
    /// to be a power of two by the masking).
    template<std::size_t N>
    word(const char (&nam)[N], func b = nullptr, const word_base *w = nullptr,
        addr flags = 0) -> word<(N + sizeof(cell)) & ~(sizeof(cell) - 1)>;

    /// Pushes `v` onto the data stack, which grows downward from the end of
    /// `dstack`. Raises error::stack_overflow when no free cell is left.
    void push(cell v) {
        const bool has_room = sp != dstack.data();
        assert<error::stack_overflow>(has_room);
        sp -= 1;
        *sp = v;
    }

    /// Pushes several values in argument order, so the last argument ends up
    /// on top of the data stack.
    void push(cell v, auto... vs) {
        // Binary left fold over the comma operator: evaluated left to right.
        (push(v), ..., push(vs));
    }

    /// Pushes `v` onto the return stack, which grows downward from the end of
    /// `rstack`. Raises error::return_stack_overflow when it is full.
    void rpush(func *v) {
        const bool has_room = rp != rstack.data();
        assert<error::return_stack_overflow>(has_room);
        rp -= 1;
        *rp = v;
    }

    /// Returns a mutable reference to the cell on top of the data stack.
    /// Raises error::stack_underflow when the stack is empty.
    cell& top() {
        const cell *const stack_bottom = dstack.data() + dstack.size();
        assert<error::stack_underflow>(sp != stack_bottom);
        return *sp;
    }

    /// Removes and returns the cell on top of the data stack.
    /// Raises error::stack_underflow when the stack is empty.
    cell pop() {
        const cell *const stack_bottom = dstack.data() + dstack.size();
        assert<error::stack_underflow>(sp != stack_bottom);
        const cell value = *sp;
        ++sp;
        return value;
    }

    /// Removes and returns the entry on top of the return stack.
    /// Raises error::return_stack_underflow when the stack is empty.
    auto rpop() -> func * {
        func *const *const stack_bottom = rstack.data() + rstack.size();
        assert<error::return_stack_underflow>(rp != stack_bottom);
        func *const value = *rp;
        ++rp;
        return value;
    }

    template<int N>
    auto pop() {
        static_assert(N > 0, "pop<N>() with N <= 0");

        auto t = std::tuple {pop()};
        if constexpr (N > 1)
            return std::tuple_cat(t, pop<N - 1>());
        else
            return t;
    }

    /// Appends a new entry named `name` to the dictionary at `here` and makes
    /// it `latest`.
    ///
    /// The entry consists of a word_base header followed by the name, padded
    /// with NULs to a whole number of cells. If `entry` is non-null it is
    /// stored as the first cell of the word's body.
    ///
    /// @param name  Name of the new word; must be non-empty and short enough
    ///              that the padded name field fits in one byte (<= 0xFF).
    /// @param entry Optional native implementation for the word.
    /// @return *this, so calls can be chained.
    /// @throws error::parse_error if `name` is empty or too long.
    forth& add(std::string_view name, func entry = nullptr) {
        // Room for the name plus a terminating NUL, rounded up to a cell.
        const auto namesz = (name.size() + 1 + sizeof(cell) - 1) & ~(sizeof(cell) - 1);
        const auto size = (sizeof(word_base) + namesz) / sizeof(cell);

        assert<error::parse_error>(!name.empty());
        // The length shares flags_len with the flag bits (word_base::immediate
        // is bit 8) and body() masks it with 0xFF, so a longer field would
        // corrupt the flags and misplace the body.
        assert<error::parse_error>(namesz <= 0xFF);
        // NOTE(review): the overflow check below is disabled and refers to
        // `state` and `dictionary`, which are not members of this struct.
        // `end` looks like the intended bound - confirm before enabling.
        //assert<error::dictionary_overflow>(state->here + size < &dictionary.back());

        const auto h = std::exchange(here, here + size);
        latest = new (h) word_base (latest, namesz);

        // word_base::name() reads the name as a C string, so the padding must
        // be zeroed explicitly rather than relying on the dictionary memory
        // happening to be zero.
        const auto text = std::bit_cast<char *>(h) + sizeof(word_base);
        const auto tail = std::copy(name.begin(), name.end(), text);
        std::fill(tail, text + namesz, '\0');

        if (entry)
            *here++ = std::bit_cast<cell>(entry);
        return *this;
    }

    void parse_line(std::string_view sv) {
        source = sv.data();
        sourcei = sv.find_first_not_of(" \t\r\n");

        while (sourcei != npos) {
            const auto word = parse();

            if (auto ent = get(word); !ent) {
                cell n;
                const auto [p, e] = std::from_chars(word.cbegin(), word.cend(),
                    n, base);

                assert<error::word_not_found>(e == std::errc() && p == word.cend());

                push(n);

                if (compiling)
                    execute((*get("literal"))->body());
            } else {
                auto body = (*ent)->body();

                if (compiling && ((*ent)->flags_len & word_base::immediate) == 0) {
                    *here++ = std::bit_cast<cell>(body);
                } else {
                    execute(body);
                }
            }
        }
    }

    /// Extracts the next whitespace-delimited token from the current source
    /// line and advances `sourcei` to the start of the following token (or
    /// npos when nothing is left).
    ///
    /// @return The token, or an empty view when no input remains. Returning
    ///         an empty token lets parsing words such as `'`, `:` and
    ///         `postpone` report their own errors when they are the last
    ///         thing on a line, instead of `substr` throwing
    ///         std::out_of_range.
    auto parse() -> std::string_view {
        if (source == nullptr) {
            sourcei = npos;
            return {};
        }

        const std::string_view sv {source};

        if (sourcei >= sv.size()) {
            sourcei = npos;
            return {};
        }

        const auto e = sv.find_first_of(" \t\r\n", sourcei);
        const auto word = e != npos ? sv.substr(sourcei, e - sourcei)
                                    : sv.substr(sourcei);

        // find_first_not_of yields npos when `e` is npos or only whitespace
        // follows, which ends parse_line()'s loop.
        sourcei = sv.find_first_not_of(" \t\r\n", e);
        return word;
    }

    void execute(const func *body) {
        assert<error::execute_error>(body && *body);
        (*body)(body);
    }

    /// Looks up `sv` in the dictionary, searching from the most recent
    /// definition backwards so newer words shadow older ones.
    ///
    /// @return The matching entry, or an empty optional if none matches.
    auto get(std::string_view sv) -> std::optional<const word_base *> {
        const word_base *entry = latest;
        while (entry != nullptr && sv != entry->name())
            entry = entry->next;

        if (entry == nullptr)
            return std::nullopt;
        return entry;
    }

    template<forth **fthp>
    static void prologue(func *body) {
        static auto& fth = **fthp;

        fth.rpush(fth.ip);

        for (fth.ip = body + 1; *fth.ip; fth.ip++)
            fth.execute(std::bit_cast<func *>(*fth.ip));

        fth.ip = fth.rpop();
    }

    /// Installs the built-in words into the interpreter pointed to by `*fthp`.
    ///
    /// The built-ins are compile-time dictionary entries chained together;
    /// `latest` is set to the last of them and `end` to `end_value`. The
    /// native implementations are captureless lambdas that reach the
    /// interpreter through a function-local static reference, which is bound
    /// on the first call for a given `fthp`.
    ///
    /// @tparam fthp     Address of the pointer to the interpreter instance.
    /// @param end_value Value stored in the interpreter's `end` member.
    /// @throws error::init_error if `*fthp` is null.
    template<forth** fthp>
    static void initialize(cell *end_value)
    {
        assert<error::init_error>(*fthp);

        static auto& fth = **fthp;

        // Runtime part of a literal: pushes the cell that follows it in the
        // compiled code and steps the instruction pointer over that cell.
        constexpr static func lit_impl = [](auto) {
            auto ptr = std::bit_cast<cell *>(++fth.ip);
            fth.push(*ptr);
        };
        // Pushes the address of the interpreter object itself.
        auto f_dict   = [](auto) { fth.push(std::bit_cast<cell>(&fth)); };
        // Binary operators: in a compound assignment the right operand is
        // sequenced first, so pop() yields the old top and top() then refers
        // to the value beneath it.
        // NOTE(review): `/` and `mod` do not check for a zero divisor.
        auto f_add    = [](auto) { fth.top() += fth.pop(); };
        auto f_minus  = [](auto) { fth.top() -= fth.pop(); };
        auto f_times  = [](auto) { fth.top() *= fth.pop(); };
        auto f_divide = [](auto) { fth.top() /= fth.pop(); };
        auto f_mod    = [](auto) { fth.top() %= fth.pop(); };
        auto f_bitand = [](auto) { fth.top() &= fth.pop(); };
        auto f_bitor  = [](auto) { fth.top() |= fth.pop(); };
        auto f_bitxor = [](auto) { fth.top() ^= fth.pop(); };
        auto f_lshift = [](auto) { fth.top() <<= fth.pop(); };
        auto f_rshift = [](auto) { fth.top() >>= fth.pop(); };
        // Leave / enter compilation state.
        auto f_lbrac  = [](auto) { fth.compiling = false; };
        auto f_rbrac  = [](auto) { fth.compiling = true; };
        // Marks the most recently defined word as immediate.
        auto f_imm    = [](auto) {
            const_cast<word_base *>(fth.latest)->make_immediate(); };
        // Compiles the value on top of the stack as a literal.
        auto f_lit    = [](auto) {
            //assert<error::compile_only_word>(fth.compiling);
            *fth.here++ = std::bit_cast<cell>(&lit_impl);
            *fth.here++ = fth.pop(); };
        // Cell fetch / store; for the stores the address is on top.
        auto f_peek = [](auto) { fth.push(*std::bit_cast<cell *>(fth.pop())); };
        auto f_poke = [](auto) {
            auto [p, v] = fth.pop<2>();
            *std::bit_cast<cell *>(p) = v; };
        // Character fetch reads through `unsigned char` so that bytes >= 0x80
        // are zero-extended regardless of whether plain `char` is signed.
        auto f_cpeek = [](auto) {
            fth.push(*std::bit_cast<unsigned char *>(fth.pop())); };
        auto f_cpoke = [](auto) {
            auto [p, v] = fth.pop<2>();
            *std::bit_cast<char *>(p) = v; };
        // Stack shuffling. pop<N>() returns the former top first.
        auto f_swap = [](auto) { auto [a, b] = fth.pop<2>(); fth.push(a, b); };
        auto f_drop = [](auto) { fth.pop(); };
        auto f_dup  = [](auto) { fth.push(fth.top()); };
        auto f_rot  = [](auto) { auto [a, b, c] = fth.pop<3>(); fth.push(b, a, c); };
        // Comparisons leave -1 for true and 0 for false.
        auto f_eq   = [](auto) { auto v = fth.pop(); fth.top() = -(fth.top() == v); };
        auto f_lt   = [](auto) { auto v = fth.pop(); fth.top() = -(fth.top() < v); };
        // Parses the next token and pushes that word's body address, or 0 if
        // the word is unknown.
        auto f_tick = [](auto) {
            auto w = fth.parse();

            if (auto g = fth.get(w); g)
                fth.push(std::bit_cast<cell>((*g)->body()));
            else
                fth.push(0); };
        // Starts a colon definition: creates the entry, stores the prologue
        // as its first body cell and enters compilation state.
        auto f_colon = [](auto) {
            const auto prologue = forth::prologue<fthp>;
            auto w = fth.parse();
            fth.add(w);
            *fth.here++ = std::bit_cast<cell>(prologue);
            fth.compiling = true; };
        // Ends a definition with the null cell that stops the prologue loop.
        auto f_semic = [](auto) { *fth.here++ = 0; fth.compiling = false; };
        // Comment: discards the rest of the input line.
        auto f_comm  = [](auto) { fth.sourcei = npos; };
        auto f_cell  = [](auto) { fth.push(sizeof(cell)); };
        // Unconditional jump: the cell after the instruction holds the
        // target. One is subtracted because the prologue loop increments
        // `ip` after every executed cell.
        auto f_jmp   = [](auto) {
            auto ptr = ++fth.ip;
            fth.ip = *std::bit_cast<func **>(ptr) - 1;
        };
        // Jump taken only when the popped value is zero; otherwise execution
        // continues after the target cell.
        auto f_jmp0  = [](auto) {
            auto ptr = ++fth.ip;

            if (fth.pop() == 0)
                fth.ip = *std::bit_cast<func **>(ptr) - 1;
        };
        // Compiles the body address of the next token's word into the
        // current definition.
        auto f_postpone = [](auto) {
            assert<error::compile_only_word>(fth.compiling);

            auto w = fth.parse();
            auto g = fth.get(w);

            assert<error::word_not_found>(g.has_value());

            *fth.here++ = std::bit_cast<cell>((*g)->body());
        };

        // The built-in dictionary: each entry links to the one before it.
        constexpr static word w_dict {"_d", f_dict};
        constexpr static word w_liti {"_lit", lit_impl, &w_dict};
        constexpr static word w_add {"+", f_add, &w_liti};
        constexpr static word w_minus {"-", f_minus, &w_add};
        constexpr static word w_times {"*", f_times, &w_minus};
        constexpr static word w_divide {"/", f_divide, &w_times};
        constexpr static word w_mod {"mod", f_mod, &w_divide};
        constexpr static word w_bitand {"and", f_bitand, &w_mod};
        constexpr static word w_bitor {"or", f_bitor, &w_bitand};
        constexpr static word w_bitxor {"xor", f_bitxor, &w_bitor};
        constexpr static word w_lshift {"lshift", f_lshift, &w_bitxor};
        constexpr static word w_rshift {"rshift", f_rshift, &w_lshift};
        constexpr static word w_lbrac {"[", f_lbrac, &w_rshift, word_base::immediate};
        constexpr static word w_rbrac {"]", f_rbrac, &w_lbrac};
        constexpr static word w_imm {"immediate", f_imm, &w_rbrac};
        constexpr static word w_lit {"literal", f_lit, &w_imm, word_base::immediate};
        constexpr static word w_peek {"@", f_peek, &w_lit};
        constexpr static word w_poke {"!", f_poke, &w_peek};
        constexpr static word w_cpeek {"c@", f_cpeek, &w_poke};
        constexpr static word w_cpoke {"c!", f_cpoke, &w_cpeek};
        constexpr static word w_swap {"swap", f_swap, &w_cpoke};
        constexpr static word w_drop {"drop", f_drop, &w_swap};
        constexpr static word w_dup {"dup", f_dup, &w_drop};
        constexpr static word w_rot {"rot", f_rot, &w_dup};
        constexpr static word w_eq {"=", f_eq, &w_rot};
        constexpr static word w_lt {"<", f_lt, &w_eq};
        constexpr static word w_tick {"\'", f_tick, &w_lt};
        constexpr static word w_colon {":", f_colon, &w_tick};
        constexpr static word w_semic {";", f_semic, &w_colon, word_base::immediate};
        constexpr static word w_comm {"\\", f_comm, &w_semic, word_base::immediate};
        constexpr static word w_cell {"cell", f_cell, &w_comm};
        constexpr static word w_jmp {"_jmp", f_jmp, &w_cell};
        constexpr static word w_jmp0 {"_jmp0", f_jmp0, &w_jmp};
        constexpr static word w_postp {"postpone", f_postpone, &w_jmp0, word_base::immediate};

        fth.latest = &w_postp;
        fth.end = end_value;
    }

    /// Returns a short human-readable description of `err`, or
    /// "unknown error" for a value that is not one of the enumerators.
    static auto error_string(error err) noexcept -> std::string_view {
        switch (err) {
        case error::init_error:
            return "init error";
        case error::parse_error:
            return "parse error";
        case error::execute_error:
            return "execute error";
        case error::dictionary_overflow:
            return "dictionary overflow";
        case error::word_not_found:
            return "word not found";
        case error::stack_underflow:
            return "stack underflow";
        case error::stack_overflow:
            return "stack overflow";
        case error::return_stack_underflow:
            return "return stack underflow";
        case error::return_stack_overflow:
            return "return stack overflow";
        case error::compile_only_word:
            return "compile only word";
        }

        // Reached only for values outside the enumeration.
        return "unknown error";
    }

    /// Creates an interpreter with empty stacks: both stack pointers start
    /// one past the last element of their arrays, since the stacks grow
    /// downward.
    constexpr forth() {
        sp = dstack.data() + dstack.size();
        rp = rstack.data() + rstack.size();
    }

    // Interpreter state. The order of the members up to `base` is pinned by
    // the static_asserts that follow the struct.

    /// Data stack pointer; points at the top cell, or one past `dstack` when
    /// the stack is empty.
    cell *sp;
    /// Return stack pointer; same convention as `sp`, over `rstack`.
    func **rp;
    /// Instruction pointer into the colon definition currently being run.
    func *ip = nullptr;
    /// Next free dictionary cell; initially the address just past this object.
    cell *here = std::bit_cast<cell *>(this + 1);
    /// Most recently defined word; head of the chain searched by get().
    const word_base *latest = nullptr;
    /// Start of the line currently being interpreted.
    const char *source = nullptr;
    /// Offset into `source` of the next unparsed character, or npos.
    std::size_t sourcei = npos;
    /// Non-zero while compiling a definition, zero while interpreting.
    cell compiling = false;
    /// Set by initialize() from its `end_value` argument; not read elsewhere
    /// in this file.
    cell *end = nullptr;
    /// Numeric base used when parsing number tokens.
    cell base = 10;
    /// Storage for the data stack.
    std::array<cell, data_size> dstack;
    /// Storage for the return stack.
    std::array<func *, return_size> rstack;
};

// Pin the memory layout: `flags_len` is the second cell of a dictionary
// header, and the interpreter's state members occupy consecutive cells in
// declaration order. The `_d` word pushes the address of the forth object,
// presumably so Forth code can reach these fields by cell offset - confirm
// against the Forth-side sources before reordering members.
static_assert(offsetof(forth::word_base, flags_len) == 1 * sizeof(forth::cell));
static_assert(offsetof(forth, rp)        == 1 * sizeof(forth::cell));
static_assert(offsetof(forth, ip)        == 2 * sizeof(forth::cell));
static_assert(offsetof(forth, here)      == 3 * sizeof(forth::cell));
static_assert(offsetof(forth, latest)    == 4 * sizeof(forth::cell));
static_assert(offsetof(forth, source)    == 5 * sizeof(forth::cell));
static_assert(offsetof(forth, sourcei)   == 6 * sizeof(forth::cell));
static_assert(offsetof(forth, compiling) == 7 * sizeof(forth::cell));
static_assert(offsetof(forth, end)       == 8 * sizeof(forth::cell));
static_assert(offsetof(forth, base)      == 9 * sizeof(forth::cell));

#endif // SFORTH_HPP