#include "textoutput.hpp"

#include <array>
#include <cstdint>

/// One 32-bit page-directory entry, stored as its raw value.
struct PageDirectory
{
    /// Value of a default-constructed entry: bit 1 set, bit 0 (present) clear.
    static constexpr std::uint32_t NotPresent = 0x2;

    /// Low three flag bits OR-ed into every entry that is built from an address.
    static constexpr std::uint32_t PresentFlags = 0x7;

    /// Creates an entry that does not reference any page table.
    PageDirectory(): value(NotPresent) {}

    /// Creates an entry whose value is the address of `addr` with the low
    /// three flag bits set.
    ///
    /// The pointer is converted through std::uintptr_t and then narrowed, so
    /// the struct also compiles where pointers are wider than 32 bits; in that
    /// case only the low 32 bits of the address are kept.
    PageDirectory(void *addr)
        : value(static_cast<std::uint32_t>(reinterpret_cast<std::uintptr_t>(addr))
                | PresentFlags) {}

    std::uint32_t value;
};
// The entry must stay exactly one 32-bit word wide.
static_assert(sizeof(PageDirectory) == sizeof(std::uint32_t));

// Amount of lower / upper memory still available. Both are defined in another
// translation unit; memory_alloc() subtracts each allocation's size from them.
extern std::uint32_t lowerMem;
extern std::uint32_t upperMem;
// Text sink used for status and error messages (defined elsewhere).
extern TextOutput& term;

// Address of the next byte to hand out from each region; memory_alloc()
// advances these and never moves them back.
static std::uintptr_t lowerFree = 0x400;
static std::uintptr_t upperFree = 0x100000;

// Page directory whose address memory_initialize() loads into CR3.
// Only entry 0 is ever pointed at a page table in this file.
alignas(4096)
static std::array<PageDirectory, 1024> pageDirectory;

// The single page table; memory_initialize() fills it with 1024 consecutive
// 4 KiB page addresses starting at address 0.
alignas(4096)
static std::array<std::uint32_t, 1024> pageTable;

/// Prepares the allocator's bookkeeping, identity-maps the first 4 MiB of
/// memory through a single page table and turns paging on.
void memory_initialize()
{
    // The first 1024 bytes of lower memory are never handed out (lowerFree
    // starts at 0x400), so remove them from the available count. Clamp at zero
    // so a report smaller than 1024 cannot wrap around to a huge value.
    lowerMem = (lowerMem > 1024u) ? lowerMem - 1024u : 0u;

    const auto totalKb = (lowerMem + upperMem) / 1024u;

    term.write("Claiming ");
    term.write(totalKb);
    term.write(" kB for allocations...\n");

    // Map page i to physical address i * 4 KiB (identity mapping).
    std::uint32_t addr = 0;
    for (auto& p : pageTable) {
        p = addr | 7; // present | read/write | user-accessible (bit 2 set)
        addr += 0x1000;
    }

    pageDirectory[0] = PageDirectory(pageTable.data());

    // Load CR3 with the page directory, then set CR0.PG (bit 31).
    // EAX is overwritten by the sequence, so it is declared as a read-write
    // operand instead of input-only. The "memory" clobber stops the compiler
    // from moving the table stores above past the point where paging starts;
    // "cc" covers the flags changed by `or`.
    auto scratch = reinterpret_cast<std::uintptr_t>(pageDirectory.data());
    asm volatile(R"(
        mov %%eax, %%cr3
        mov %%cr0, %%eax
        or $0x80000000, %%eax
        mov %%eax, %%cr0
    )" : "+a"(scratch) : : "memory", "cc");

    term.write("Paging enabled.\n");
}

/// Bump allocator behind the global operator new replacements.
///
/// Hands out memory from the lower region first and falls back to the upper
/// region. Memory is never reclaimed.
///
/// @param size  Number of bytes requested. A request of 0 is treated as 1 so
///              every successful call returns a distinct address.
/// @return Pointer to the block, or nullptr (after printing a message) when
///         neither region has enough room.
static void *memory_alloc(std::size_t size)
{
    // Every block is padded to this boundary so that consecutive allocations
    // stay suitably aligned for any type operator new may be asked to hold.
    // Both regions start on addresses that are already multiples of it.
    constexpr std::size_t alignment = __STDCPP_DEFAULT_NEW_ALIGNMENT__;

    const std::size_t wanted = (size == 0) ? 1 : size;

    // Rounding up must not wrap around for requests close to SIZE_MAX.
    const bool roundable = wanted <= SIZE_MAX - (alignment - 1);
    const std::size_t padded =
        roundable ? ((wanted + alignment - 1) & ~(alignment - 1)) : 0;

    void *ret = nullptr;

    if (roundable && lowerMem > padded) {
        ret = reinterpret_cast<void *>(lowerFree);
        lowerFree += padded;
        lowerMem -= padded;
    } else if (roundable && upperMem > padded) {
        ret = reinterpret_cast<void *>(upperFree);
        upperFree += padded;
        upperMem -= padded;
    } else {
        // Uh oh!
        term.write("!!! Kernel allocation failed !!!");
    }

    return ret;
}

/// Global single-object allocation: forwards the request to memory_alloc()
/// and returns its result unchanged (nullptr when the allocation fails).
void *operator new(std::size_t size)
{
    void *const block = memory_alloc(size);
    return block;
}

/// Global array allocation: forwards the request to memory_alloc() and
/// returns its result unchanged (nullptr when the allocation fails).
void *operator new[](std::size_t size)
{
    void *const block = memory_alloc(size);
    return block;
}

/// Global single-object deallocation. Does nothing: the pointer is ignored
/// and no memory is released.
void operator delete(void * /*block*/) {}

/// Global array deallocation. Does nothing: the pointer is ignored and no
/// memory is released.
void operator delete[](void * /*block*/) {}

/// Global sized single-object deallocation. Does nothing: both the pointer
/// and the size are ignored and no memory is released.
void operator delete(void * /*block*/, std::size_t /*size*/) {}

/// Global sized array deallocation. Does nothing: both the pointer and the
/// size are ignored and no memory is released.
void operator delete[](void * /*block*/, std::size_t /*size*/) {}

/// Copies `sz` bytes from `src` to `dst`, one byte at a time in ascending
/// address order, and returns `dst`.
extern "C"
void *memcpy(void *dst, const void *src, std::size_t sz)
{
    char *const out = static_cast<char *>(dst);
    const char *const in = static_cast<const char *>(src);

    for (std::size_t i = 0; i != sz; ++i)
        out[i] = in[i];

    return dst;
}

/// Stores `val`, converted to char, into each of the first `sz` bytes at
/// `dst` and returns `dst`.
extern "C"
void *memset(void *dst, int val, std::size_t sz)
{
    const char fill = static_cast<char>(val);
    char *cursor = static_cast<char *>(dst);

    for (char *const stop = cursor + sz; cursor != stop; ++cursor)
        *cursor = fill;

    return dst;
}