#include <array>
#include <cstdint>

struct gdt_entry_bits {
    std::uint32_t limit_low              : 16;
    std::uint32_t base_low               : 24;
    std::uint32_t accessed               :  1;
    std::uint32_t read_write             :  1; // readable for code, writable for data
    std::uint32_t conforming_expand_down :  1; // conforming for code, expand down for data
    std::uint32_t code                   :  1; // 1 for code, 0 for data
    std::uint32_t code_data_segment      :  1; // should be 1 for everything but TSS and LDT
    std::uint32_t DPL                    :  2; // privilege level
    std::uint32_t present                :  1;
    std::uint32_t limit_high             :  4;
    std::uint32_t available              :  1; // only used in software; has no effect on hardware
    std::uint32_t long_mode              :  1;
    std::uint32_t big                    :  1; // 32-bit opcodes for code, uint32_t stack for data
    std::uint32_t gran                   :  1; // 1 to use 4k page addressing, 0 for byte addressing
    std::uint32_t base_high              :  8;
} __attribute__((packed));

// 32-bit x86 Task State Segment. Only the fields this file sets by name are
// spelled out; `unused` pads the structure through the last hardware field so
// that it has the full hardware size.
//   unused[0..21]  : esp1, ss1, esp2, ss2, cr3, eip, eflags, general registers,
//                    segment selectors and the LDT selector
//   unused[22]     : debug-trap flag (low 16 bits) and I/O map base (high 16 bits)
struct TSSEntry
{
    std::uint32_t prevTSS;    // previous-task link (hardware task switching)
    std::uint32_t esp0;       // stack pointer loaded on a switch to ring 0
    std::uint32_t ss0;        // stack segment loaded on a switch to ring 0
    std::uint32_t unused[23]; // remaining fields, see the layout above
} __attribute__((packed));

// The TSS descriptor limit and the I/O map base offset both depend on this size.
static_assert(sizeof(TSSEntry) == 104, "a 32-bit TSS is 104 bytes");

// The single TSS used for ring 3 -> ring 0 transitions. esp0 is filled in at
// runtime by enter_user_mode(); ss0 = 0x10 is the kernel data selector
// (GDT index 2, RPL 0).
static TSSEntry tss = {
    .prevTSS = 0,
    .esp0 = 0,
    .ss0 = 0x10,
    // unused[0..21] (esp1 .. LDT selector) stay zero.
    // unused[22] holds the trap flag (low half) and the I/O map base (high half).
    // An I/O map base of 0 would make the CPU read the I/O permission bitmap
    // from the start of this structure, whose mostly-zero bytes would grant
    // ring-3 code access to many ports. Pointing the base at the end of the
    // TSS means "no bitmap", so port I/O with CPL > IOPL raises #GP instead.
    .unused = {
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,                          // unused[0..9]
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,                          // unused[10..19]
        0, 0,                                                  // unused[20..21]
        static_cast<std::uint32_t>(sizeof(TSSEntry)) << 16,    // unused[22]
    },
};

static const std::array<gdt_entry_bits, 6> gdt {{
    {},
    /* kernel_code = */ {
        .limit_low = 0xFFFF,
        .base_low = 0x0000,
        .accessed = 0,
        .read_write = 1,
        .conforming_expand_down = 0,
        .code = 1,
        .code_data_segment = 1,
        .DPL = 0,
        .present = 1,
        .limit_high = 0xF,
        .available = 0,
        .long_mode = 0,
        .big = 1,
        .gran = 1,
        .base_high = 0x00
    },
    /* kernel_data = */ {
        .limit_low = 0xFFFF,
        .base_low = 0x0000,
        .accessed = 0,
        .read_write = 1,
        .conforming_expand_down = 0,
        .code = 0,
        .code_data_segment = 1,
        .DPL = 0,
        .present = 1,
        .limit_high = 0xF,
        .available = 0,
        .long_mode = 0,
        .big = 1,
        .gran = 1,
        .base_high = 0x00
    },
    /* user_code = */ {
        .limit_low = 0xFFFF,
        .base_low = 0x0000,
        .accessed = 0,
        .read_write = 1,
        .conforming_expand_down = 0,
        .code = 1,
        .code_data_segment = 1,
        .DPL = 3,
        .present = 1,
        .limit_high = 0xF,
        .available = 0,
        .long_mode = 0,
        .big = 1,
        .gran = 1,
        .base_high = 0x00
    },
    /* user_data = */ {
        .limit_low = 0xFFFF,
        .base_low = 0x0000,
        .accessed = 0,
        .read_write = 1,
        .conforming_expand_down = 0,
        .code = 0,
        .code_data_segment = 1,
        .DPL = 3,
        .present = 1,
        .limit_high = 0xF,
        .available = 0,
        .long_mode = 0,
        .big = 1,
        .gran = 1,
        .base_high = 0x00
    },
    /* tss = */ {
        .limit_low = sizeof(TSSEntry),
        .base_low = (std::uint32_t)&tss & 0xFFFFFF,
        .accessed = 1,
        .read_write = 0,
        .conforming_expand_down = 0,
        .code = 1,
        .code_data_segment = 0,
        .DPL = 0,
        .present = 1,
        .limit_high = 0,
        .available = 0,
        .long_mode = 0,
        .big = 0,
        .gran = 0,
        .base_high = (std::uint32_t)&tss >> 24
    }
}};

// Load the GDT and switch every segment register to the flat kernel
// selectors (CS = 0x08; DS, ES, FS, GS, SS = 0x10). Then load the task
// register with the TSS selector (0x28).
void gdt_initialize()
{
    // Operand for `lgdt`: a 16-bit limit followed by the 32-bit linear base.
    // The limit is the offset of the table's last valid byte, i.e. size - 1.
    struct __attribute__((packed)) GdtPointer {
        std::uint16_t limit;
        std::uint32_t base;
    };
    static_assert(sizeof(GdtPointer) == 6, "lgdt expects a 6-byte operand");

    const GdtPointer gdtr {
        static_cast<std::uint16_t>(gdt.size() * sizeof(gdt[0]) - 1),
        static_cast<std::uint32_t>(reinterpret_cast<std::uintptr_t>(gdt.data())),
    };

    // - `lgdt` must stay first: %0 may be addressed relative to ESP, and the
    //   pushes below move ESP.
    // - CS can only be reloaded by a far control transfer. Push a
    //   selector:offset pair and jump indirectly through it.
    // - `1:` is a numeric local label, so the asm can be emitted more than once
    //   (inlining, cloning) without a duplicate-symbol error.
    // - EAX is scratch. The CPU also writes into the GDT (accessed/busy bits).
    //   Hence the "eax" and "memory" clobbers.
    asm volatile(R"(
        lgdt %0
        pushl $0x08
        pushl $1f
        ljmp *(%%esp)
    1:
        add $8, %%esp
        mov $0x10, %%eax
        mov %%eax, %%ds
        mov %%eax, %%es
        mov %%eax, %%fs
        mov %%eax, %%gs
        mov %%eax, %%ss

        mov $0x28, %%ax
        ltr %%ax
    )" : : "m"(gdtr) : "eax", "memory");
}

// Drop to ring 3 and start executing `func` there. Never returns to the caller.
//
// The current stack pointer is stored in tss.esp0, so interrupts taken from
// ring 3 switch back to this stack. DS/ES/FS/GS are loaded with the user data
// selector (0x23). An iret frame is then built (SS = 0x23, ESP, EFLAGS,
// CS = 0x1b, EIP = func) and consumed by `iret`.
//
// NOTE(review): the user-mode ESP is the current kernel ESP, and esp0 points at
// (nearly) the same address. A ring-3 interrupt will therefore push its frame
// over the user's stack. A dedicated user stack (plus user-accessible mappings
// for it and for `func`) is presumably required; confirm against the paging setup.
// NOTE(review): if `func` returns, it pops a return address that was never
// pushed for it.
void enter_user_mode(void (*func)())
{
    asm volatile("mov %%esp, %0" : "=r" (tss.esp0));

    // EAX captures ESP *before* the frame is pushed; that value becomes the
    // user ESP. (`push %%esp` would instead push ESP as it is after the
    // preceding `push $0x23`.)
    asm volatile(R"(
        mov $0x23, %%ax
        mov %%ax, %%ds
        mov %%ax, %%es
        mov %%ax, %%fs
        mov %%ax, %%gs
        mov %%esp, %%eax
        push $0x23
        push %%eax
        pushf
        push $0x1b
        push %0
        iret
    )" :: "b"(func) : "eax", "memory");

    // `iret` transfers control to `func`; execution never reaches this point.
    __builtin_unreachable();
}