/**
 * @file device.cpp
 * @brief Contains code for device-related UI elements and logic.
 *
 * Copyright (C) 2021 Clyne Sullivan
 *
 * Distributed under the GNU GPL v3 or later. You should have received a copy of
 * the GNU General Public License along with this program.
 * If not, see <https://www.gnu.org/licenses/>.
 */

#include "stmdsp.hpp"

#include "imgui.h"
#include "imgui_internal.h"
#include "ImGuiFileDialog.h"
#include "wav.hpp"

#include <algorithm>
#include <array>
#include <atomic>
#include <charconv>
#include <cmath>
#include <deque>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <mutex>
#include <thread>

// Provided by other translation units of the application.
extern std::string tempFileName;
extern void log(const std::string& message);

extern std::vector<stmdsp::dacsample_t> deviceGenLoadFormulaEval(std::string_view formula);

// The currently connected device; reset to empty on disconnect.
std::shared_ptr<stmdsp::device> m_device;

// Labels for the sample-rate combo box, indexed by the device's sample-rate
// ID (the value passed to set_sample_rate() / returned by get_sample_rate()).
static constexpr std::array<const char *, 6> sampleRateList {{
    "8 kHz",
    "16 kHz",
    "20 kHz",
    "32 kHz",
    "48 kHz",
    "96 kHz"
}};
static const char *sampleRatePreview = sampleRateList[0];
// Sample rates in Hz, parallel to sampleRateList (same index, same order).
static constexpr std::array<unsigned int, 6> sampleRateInts {{
    8'000,
    16'000,
    20'000,
    32'000,
    48'000,
    96'000
}};
// Both tables are indexed by the same ID, so they must never diverge in size.
static_assert(sampleRateList.size() == sampleRateInts.size(),
              "sampleRateList and sampleRateInts must have the same length");

// Options toggled from the "Run" menu.
static bool measureCodeTime = false;
static bool drawSamples = false;
static bool logResults = false;
// Polled by the detached feedSigGenTask thread while the UI thread toggles
// it, so it is atomic to keep that cross-thread access free of data races.
static std::atomic<bool> genRunning {false};
// NOTE(review): drawSamplesInput and popupRequestDraw are also read by the
// drawSamplesTask thread, but ImGui needs a plain bool* for them.
static bool drawSamplesInput = false;

// Requests to open popups, set by the menu and cleared by the renderer.
// popupRequestDraw doubles as the "draw" window's open flag.
static bool popupRequestBuffer = false;
static bool popupRequestSiggen = false;
static bool popupRequestDraw = false;
static bool popupRequestLog = false;

// Guards drawSamplesQueue and drawSamplesInputQueue.
static std::timed_mutex mutexDrawSamples;
// Held around device transactions made from the worker threads (and by
// deviceStart() while stopping).
static std::timed_mutex mutexDeviceLoad;

// Destination for logged samples (chosen via the "Log results..." dialog).
static std::ofstream logSamplesFile;
// Audio clip streamed to the signal generator by feedSigGenTask.
static wav::clip wavOutput;

// Samples produced by drawSamplesTask and consumed by deviceRenderDraw.
static std::deque<stmdsp::dacsample_t> drawSamplesQueue;
static std::deque<stmdsp::dacsample_t> drawSamplesInputQueue;
static double drawSamplesTimeframe = 1.0; // seconds
// Length of the draw window's ring buffer (sample rate * timeframe).
static unsigned int drawSamplesBufferSize = 1;

/**
 * Waits one second, then logs the cycle count measured by the device.
 * The delay presumably lets the algorithm run at least once before the
 * measurement is read.
 */
static void measureCodeTask(std::shared_ptr<stmdsp::device> device)
{
    std::this_thread::sleep_for(std::chrono::seconds(1));

    if (!device)
        return;

    const auto cycles = device->continuous_start_get_measurement();
    log("Execution time: " + std::to_string(cycles) + " cycles.");
}

/**
 * Calls readFunc(device) until it yields a non-empty chunk.
 *
 * Makes one initial attempt plus up to 100 retries, 20 us apart, and stops
 * early once the device is no longer running.
 * @return The first non-empty chunk read, or an empty vector.
 */
static std::vector<stmdsp::dacsample_t> tryReceiveChunk(
    std::shared_ptr<stmdsp::device> device,
    auto readFunc)
{
    for (int attempt = 0; ; ++attempt) {
        if (auto chunk = readFunc(device.get()); !chunk.empty())
            return chunk;

        std::this_thread::sleep_for(std::chrono::microseconds(20));

        if (attempt >= 100 || !device->is_running())
            break;
    }

    return {};
}

/**
 * Time it takes the device to fill one sample buffer, scaled by `factor`.
 * @return (buffer size / sample rate) * factor seconds, or zero without a device.
 */
static std::chrono::duration<double> getBufferPeriod(
    std::shared_ptr<stmdsp::device> device,
    const double factor = 0.975)
{
    if (!device)
        return {};

    const double samplesPerBuffer = device->get_buffer_size();
    const double samplesPerSecond = sampleRateInts[device->get_sample_rate()];
    const double seconds = samplesPerBuffer / samplesPerSecond * factor;
    return std::chrono::duration<double>(seconds);
}

/**
 * Worker loop that reads processed samples from the device while it runs.
 *
 * Once per buffer period it reads a chunk via continuous_read(). The chunk is
 * queued for deviceRenderDraw() while the draw window is open, and written to
 * logSamplesFile if logging was enabled when the task started. With "Draw
 * input" checked it also queues continuous_read_input() chunks.
 *
 * NOTE(review): logSamplesFile is written here without a lock while
 * deviceStart() may close it from the UI thread; confirm and fix both sides.
 */
static void drawSamplesTask(std::shared_ptr<stmdsp::device> device)
{
    if (!device)
        return;

    // Logging is decided once, when the task starts.
    const bool doLogger = logResults && logSamplesFile.good();
    const auto bufferTime = getBufferPeriod(device);

    std::unique_lock<std::timed_mutex> lockDraw (mutexDrawSamples, std::defer_lock);
    std::unique_lock<std::timed_mutex> lockDevice (mutexDeviceLoad, std::defer_lock);

    // Append a chunk to one of the queues drained by deviceRenderDraw().
    auto addToQueue = [&lockDraw](auto& queue, const auto& chunk) {
        lockDraw.lock();
        std::copy(chunk.cbegin(), chunk.cend(), std::back_inserter(queue));
        lockDraw.unlock();
    };

    while (device->is_running()) {
        const auto next = std::chrono::high_resolution_clock::now() + bufferTime;

        // Wait at most one buffer period for exclusive device access.
        if (lockDevice.try_lock_until(next)) {
            const auto chunk = tryReceiveChunk(device,
                std::mem_fn(&stmdsp::device::continuous_read));
            lockDevice.unlock();

            // Only queue while the draw window is open to consume the
            // samples; otherwise (e.g. when only logging) the queue would
            // grow without bound.
            if (popupRequestDraw)
                addToQueue(drawSamplesQueue, chunk);
            if (doLogger) {
                for (const auto& s : chunk)
                    logSamplesFile << s << '\n';
            }
        } else {
            // Device must be busy, cooldown.
            std::this_thread::sleep_for(std::chrono::milliseconds(500));
        }

        if (drawSamplesInput && popupRequestDraw) {
            if (lockDevice.try_lock_for(std::chrono::milliseconds(1))) {
                const auto chunk2 = tryReceiveChunk(device,
                    std::mem_fn(&stmdsp::device::continuous_read_input));
                lockDevice.unlock();

                addToQueue(drawSamplesInputQueue, chunk2);
            }
        }

        std::this_thread::sleep_until(next);
    }
}

/**
 * Streams wavOutput to the device's signal generator while genRunning is set.
 *
 * Primes the generator with two buffers of mid-scale samples (2048), starts
 * it, then uploads one buffer of converted audio per buffer period. The task
 * ends when the generator is stopped or the device disconnects.
 */
static void feedSigGenTask(std::shared_ptr<stmdsp::device> device)
{
    if (!device)
        return;

    const auto delay = getBufferPeriod(device);
    const auto uploadDelay = getBufferPeriod(device, 0.001);

    std::vector<stmdsp::dacsample_t> wavBuf (device->get_buffer_size() * 2, 2048);

    std::unique_lock<std::timed_mutex> lockDevice (mutexDeviceLoad, std::defer_lock);

    lockDevice.lock();
    device->siggen_upload(wavBuf.data(), wavBuf.size());
    wavBuf.resize(wavBuf.size() / 2);
    device->siggen_start();
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
    lockDevice.unlock();

    std::vector<int16_t> wavIntBuf (wavBuf.size());

    while (genRunning && device->connected()) {
        const auto next = std::chrono::high_resolution_clock::now() + delay;

        wavOutput.next(wavIntBuf.data(), wavIntBuf.size());
        // Signed 16-bit PCM -> 0..4095, centered on 2048.
        std::transform(wavIntBuf.cbegin(), wavIntBuf.cend(), wavBuf.begin(),
            [](int16_t s) { return static_cast<stmdsp::dacsample_t>(s / 16 + 2048); });

        lockDevice.lock();
        // Retry until the device accepts the buffer, but give up once the
        // generator is stopped or the device is gone; otherwise this would
        // spin forever while holding mutexDeviceLoad.
        while (!device->siggen_upload(wavBuf.data(), wavBuf.size())) {
            if (!genRunning || !device->connected())
                break;
            std::this_thread::sleep_for(uploadDelay);
        }
        lockDevice.unlock();

        std::this_thread::sleep_until(next);
    }
}

/**
 * Polls the device status once per second until it disconnects, logging a
 * message for any error the device reports.
 */
static void statusTask(std::shared_ptr<stmdsp::device> device)
{
    if (!device)
        return;

    while (device->connected()) {
        const auto [status, error] = [&device] {
            std::scoped_lock guard (mutexDeviceLoad);
            return device->get_status();
        }();

        const char *message = nullptr;
        switch (error) {
        case stmdsp::Error::None:
            break;
        case stmdsp::Error::NotIdle:
            message = "Error: Device already running...";
            break;
        case stmdsp::Error::ConversionAborted:
            message = "Error: Algorithm unloaded, a fault occurred!";
            break;
        default:
            message = "Error: Device had an issue...";
            break;
        }

        if (message != nullptr)
            log(message);

        std::this_thread::sleep_for(std::chrono::seconds(1));
    }
}

// Device actions triggered from the UI code below (defined at end of file).
static void deviceConnect();
static void deviceStart();
static void deviceAlgorithmUpload();
static void deviceAlgorithmUnload();
static void deviceGenLoadList(std::string_view list);
static void deviceGenLoadFormula(std::string_view formula);

void deviceRenderWidgets()
{
    static char *siggenBuffer = nullptr;
    static int siggenOption = 0;

    if (popupRequestSiggen) {
        siggenBuffer = new char[65536];
        *siggenBuffer = '\0';
        ImGui::OpenPopup("siggen");
        popupRequestSiggen = false;
    } else if (popupRequestBuffer) {
        ImGui::OpenPopup("buffer");
        popupRequestBuffer = false;
    } else if (popupRequestLog) {
        ImGuiFileDialog::Instance()->OpenModal(
            "ChooseFileLogGen", "Choose File", ".csv", ".");
        popupRequestLog = false;
    }

    if (ImGui::BeginPopup("siggen")) {
        if (ImGui::RadioButton("List", &siggenOption, 0))
            siggenBuffer[0] = '\0';
        ImGui::SameLine();
        if (ImGui::RadioButton("Formula", &siggenOption, 1))
            siggenBuffer[0] = '\0';
        ImGui::SameLine();
        if (ImGui::RadioButton("Audio File", &siggenOption, 2))
            siggenBuffer[0] = '\0';

        switch (siggenOption) {
        case 0:
            ImGui::Text("Enter a list of numbers:");
            ImGui::PushStyleColor(ImGuiCol_FrameBg, {.8, .8, .8, 1});
            ImGui::InputText("", siggenBuffer, 65536);
            ImGui::PopStyleColor();
            break;
        case 1:
            ImGui::Text("Enter a formula. f(x) = ");
            ImGui::PushStyleColor(ImGuiCol_FrameBg, {.8, .8, .8, 1});
            ImGui::InputText("", siggenBuffer, 65536);
            ImGui::PopStyleColor();
            break;
        case 2:
            if (ImGui::Button("Choose File")) {
                // This dialog will override the siggen popup, closing it.
                ImGuiFileDialog::Instance()->OpenModal(
                    "ChooseFileLogGen", "Choose File", ".wav", ".");
            }
            break;
        }

        if (ImGui::Button("Cancel")) {
            delete[] siggenBuffer;
            ImGui::CloseCurrentPopup();
        }

        if (ImGui::Button("Save")) {
            switch (siggenOption) {
            case 0:
                deviceGenLoadList(siggenBuffer);
                break;
            case 1:
                deviceGenLoadFormula(siggenBuffer);
                break;
            case 2:
                break;
            }

            delete[] siggenBuffer;
            ImGui::CloseCurrentPopup();
        }

        ImGui::EndPopup();
    }

    if (ImGui::BeginPopup("buffer")) {
        static char bufferSizeStr[5] = "4096";
        ImGui::Text("Please enter a new sample buffer size (100-4096):");
        ImGui::PushStyleColor(ImGuiCol_FrameBg, {.8, .8, .8, 1});
        ImGui::InputText("", bufferSizeStr, sizeof(bufferSizeStr), ImGuiInputTextFlags_CharsDecimal);
        ImGui::PopStyleColor();
        if (ImGui::Button("Save")) {
            if (m_device) {
                int n = std::clamp(std::stoi(bufferSizeStr), 100, 4096);
                m_device->continuous_set_buffer_size(n);
            }
            ImGui::CloseCurrentPopup();
        }
        ImGui::SameLine();
        if (ImGui::Button("Cancel"))
            ImGui::CloseCurrentPopup();
        ImGui::EndPopup();
    }

    if (ImGuiFileDialog::Instance()->Display("ChooseFileLogGen",
                                             ImGuiWindowFlags_NoCollapse,
                                             ImVec2(460, 540)))
    {
        if (ImGuiFileDialog::Instance()->IsOk()) {
            auto filePathName = ImGuiFileDialog::Instance()->GetFilePathName();
            auto ext = filePathName.substr(filePathName.size() - 4);

            if (ext.compare(".wav") == 0) {
                wavOutput = wav::clip(filePathName.c_str());
                if (wavOutput.valid())
                    log("Audio file loaded.");
                else
                    log("Error: Bad WAV audio file.");

                delete[] siggenBuffer;
            } else if (ext.compare(".csv") == 0) {
                logSamplesFile = std::ofstream(filePathName);
                if (logSamplesFile.good())
                    log("Log file ready.");
            }
        }

        ImGuiFileDialog::Instance()->Close();
    }
}

void deviceRenderDraw()
{
    if (popupRequestDraw) {
        static std::vector<stmdsp::dacsample_t> buffer;
        static decltype(buffer.begin()) bufferCursor;
        static std::vector<stmdsp::dacsample_t> bufferInput;
        static decltype(bufferInput.begin()) bufferInputCursor;
        static unsigned int yMinMax = 4095;

        ImGui::Begin("draw", &popupRequestDraw);
        ImGui::Text("Draw input ");
        ImGui::SameLine();
        ImGui::Checkbox("", &drawSamplesInput);
        ImGui::SameLine();
        ImGui::Text("Time: %0.3f sec", drawSamplesTimeframe);
        ImGui::SameLine();
        if (ImGui::Button("-", {30, 0})) {
            drawSamplesTimeframe = std::max(drawSamplesTimeframe / 2., 0.0078125);
            auto sr = sampleRateInts[m_device->get_sample_rate()];
            auto tf = drawSamplesTimeframe;
            drawSamplesBufferSize = std::round(sr * tf);
        }
        ImGui::SameLine();
        if (ImGui::Button("+", {30, 0})) {
            drawSamplesTimeframe = std::min(drawSamplesTimeframe * 2, 32.);
            auto sr = sampleRateInts[m_device->get_sample_rate()];
            auto tf = drawSamplesTimeframe;
            drawSamplesBufferSize = std::round(sr * tf);
        }
        ImGui::SameLine();
        ImGui::Text("Y: +/-%1.2fV", 3.3f * (static_cast<float>(yMinMax) / 4095.f));
        ImGui::SameLine();
        if (ImGui::Button(" - ", {30, 0})) {
            yMinMax = std::max(63u, yMinMax >> 1);
        }
        ImGui::SameLine();
        if (ImGui::Button(" + ", {30, 0})) {
            yMinMax = std::min(4095u, (yMinMax << 1) | 1);
        }

        static unsigned long csize = 0;
        if (buffer.size() != drawSamplesBufferSize) {
            buffer.resize(drawSamplesBufferSize);
            bufferInput.resize(drawSamplesBufferSize);
            bufferCursor = buffer.begin();
            bufferInputCursor = bufferInput.begin();
            csize = drawSamplesBufferSize / (60. * drawSamplesTimeframe) * 1.025;
        }

        {
            std::scoped_lock lock (mutexDrawSamples);
            auto count = std::min(drawSamplesQueue.size(), csize);
            for (auto i = count; i; --i) {
                *bufferCursor = drawSamplesQueue.front();
                drawSamplesQueue.pop_front();
                if (++bufferCursor == buffer.end())
                    bufferCursor = buffer.begin();
            }
            
            if (drawSamplesInput) {
                auto count = std::min(drawSamplesInputQueue.size(), csize);
                for (auto i = count; i; --i) {
                    *bufferInputCursor = drawSamplesInputQueue.front();
                    drawSamplesInputQueue.pop_front();
                    if (++bufferInputCursor == bufferInput.end())
                        bufferInputCursor = bufferInput.begin();
                }
            }
        }

        auto drawList = ImGui::GetWindowDrawList();
        ImVec2 p0 = ImGui::GetWindowPos();
        auto size = ImGui::GetWindowSize();
        p0.y += 65;
        size.y -= 70;
        drawList->AddRectFilled(p0, {p0.x + size.x, p0.y + size.y}, IM_COL32(0, 0, 0, 255));

        const float di = static_cast<float>(buffer.size()) / size.x;
        const float dx = std::ceil(size.x / static_cast<float>(buffer.size()));
        ImVec2 pp = p0;
        float i = 0;
        while (pp.x < p0.x + size.x) {
            unsigned int idx = i;
            float n = std::clamp((buffer[idx] - 2048.) / yMinMax, -0.5, 0.5);
            i += di;

            ImVec2 next (pp.x + dx, p0.y + size.y * (0.5 - n));
            drawList->AddLine(pp, next, ImGui::GetColorU32(IM_COL32(255, 0, 0, 255)));
            pp = next;
        }

        if (drawSamplesInput) {
            ImVec2 pp = p0;
            float i = 0;
            while (pp.x < p0.x + size.x) {
                unsigned int idx = i;
                float n = std::clamp((bufferInput[idx] - 2048.) / yMinMax, -0.5, 0.5);
                i += di;

                ImVec2 next (pp.x + dx, p0.y + size.y * (0.5 - n));
                drawList->AddLine(pp, next, ImGui::GetColorU32(IM_COL32(0, 0, 255, 255)));
                pp = next;
            }
        }

        ImGui::End();
    }
}

/**
 * Renders the "Run" menu: connection, start/stop, algorithm upload/unload,
 * run options, buffer size and signal generator controls.
 */
void deviceRenderMenu()
{
    if (!ImGui::BeginMenu("Run"))
        return;

    bool isConnected = static_cast<bool>(m_device);
    const bool isRunning = isConnected && m_device->is_running();

    // Labels are derived from the live state every frame rather than cached
    // in statics, which could drift out of sync with the device (e.g. after a
    // failed connect or if the device stops by itself).
    if (ImGui::MenuItem(isConnected ? "Disconnect" : "Connect", nullptr, false, !isRunning)) {
        deviceConnect();
        isConnected = static_cast<bool>(m_device);
    }

    ImGui::Separator();
    if (ImGui::MenuItem(isRunning ? "Stop" : "Start", nullptr, false, isConnected))
        deviceStart();

    if (ImGui::MenuItem("Upload algorithm", nullptr, false, isConnected && !isRunning))
        deviceAlgorithmUpload();
    if (ImGui::MenuItem("Unload algorithm", nullptr, false, isConnected && !isRunning))
        deviceAlgorithmUnload();

    ImGui::Separator();
    const bool optionsLocked = !isConnected || isRunning;
    if (optionsLocked)
        ImGui::PushDisabled();
    ImGui::Checkbox("Measure Code Time", &measureCodeTime);
    if (ImGui::Checkbox("Draw samples", &drawSamples)) {
        if (drawSamples)
            popupRequestDraw = true;
    }
    if (ImGui::Checkbox("Log results...", &logResults)) {
        if (logResults)
            popupRequestLog = true;
        else if (logSamplesFile.is_open())
            logSamplesFile.close();
    }
    if (optionsLocked)
        ImGui::PopDisabled();

    if (ImGui::MenuItem("Set buffer size...", nullptr, false, isConnected && !isRunning))
        popupRequestBuffer = true;

    ImGui::Separator();
    if (ImGui::MenuItem("Load signal generator", nullptr, false, isConnected && !m_device->is_siggening()))
        popupRequestSiggen = true;

    const char *siggenLabel = genRunning ? "Stop signal generator" : "Start signal generator";
    if (ImGui::MenuItem(siggenLabel, nullptr, false, isConnected) && m_device) {
        if (!genRunning) {
            genRunning = true;
            if (wavOutput.valid())
                std::thread(feedSigGenTask, m_device).detach();
            else
                m_device->siggen_start();
            log("Generator started.");
        } else {
            genRunning = false;
            m_device->siggen_stop();
            log("Generator stopped.");
        }
    }

    ImGui::EndMenu();
}

void deviceRenderToolbar()
{
    ImGui::SameLine();
    if (ImGui::Button("Upload"))
        deviceAlgorithmUpload();
    ImGui::SameLine();
    ImGui::SetNextItemWidth(100);

    const bool enable = m_device && !m_device->is_running() && !m_device->is_siggening();
    if (!enable)
        ImGui::PushDisabled();
    if (ImGui::BeginCombo("", sampleRatePreview)) {
        for (unsigned int i = 0; i < sampleRateList.size(); ++i) {
            if (ImGui::Selectable(sampleRateList[i])) {
                sampleRatePreview = sampleRateList[i];
                do {
                    m_device->set_sample_rate(i);
                    std::this_thread::sleep_for(std::chrono::milliseconds(10));
                } while (m_device->get_sample_rate() != i);

                drawSamplesBufferSize = std::round(sampleRateInts[i] * drawSamplesTimeframe);
            }
        }
        ImGui::EndCombo();
    }
    if (!enable)
        ImGui::PopDisabled();
}

void deviceConnect()
{
    static std::thread statusThread;

    if (!m_device) {
        stmdsp::scanner scanner;
        if (auto devices = scanner.scan(); !devices.empty()) {
            try {
                m_device.reset(new stmdsp::device(devices.front()));
            } catch (...) {
                log("Failed to connect (check permissions?).");
                m_device.reset();
            }

            if (m_device) {
                if (m_device->connected()) {
                    auto sri = m_device->get_sample_rate();
                    sampleRatePreview = sampleRateList[sri];
                    drawSamplesBufferSize = std::round(sampleRateInts[sri] * drawSamplesTimeframe);
                    log("Connected!");
                    statusThread = std::thread(statusTask, m_device);
                    statusThread.detach();
                } else {
                    m_device.reset();
                    log("Failed to connect.");
                }
            }
        } else {
            log("No devices found.");
        }
    } else {
        m_device->disconnect();
        if (statusThread.joinable())
            statusThread.join();
        m_device.reset();
        log("Disconnected.");
    }
}

/**
 * Toggles continuous processing on the connected device.
 *
 * Starting: begins a measured run (with measureCodeTask) if "Measure Code
 * Time" is set; otherwise a normal run, plus drawSamplesTask when drawing,
 * logging or a wav clip is active. Stopping: halts the device while holding
 * both mutexes, then closes the log file if logging was enabled.
 */
void deviceStart()
{
    if (!m_device) {
        log("No device connected.");
        return;
    }

    if (!m_device->is_running()) {
        if (measureCodeTime) {
            m_device->continuous_start_measure();
            std::thread(measureCodeTask, m_device).detach();
        } else {
            m_device->continuous_start();
            const bool needSampleReader = drawSamples || logResults || wavOutput.valid();
            if (needSampleReader)
                std::thread(drawSamplesTask, m_device).detach();
        }
        log("Running.");
        return;
    }

    {
        // Hold both locks so no worker is touching the device or the draw
        // queues while the device is being stopped.
        std::scoped_lock lock (mutexDrawSamples, mutexDeviceLoad);
        std::this_thread::sleep_for(std::chrono::microseconds(150));
        m_device->continuous_stop();
    }

    if (logResults) {
        logSamplesFile.close();
        logResults = false;
        log("Log file saved and closed.");
    }
    log("Ready.");
}

/**
 * Uploads the compiled algorithm (tempFileName + ".o") to the idle device.
 * Does nothing while the device is running.
 */
void deviceAlgorithmUpload()
{
    if (!m_device) {
        log("No device connected.");
        return;
    }
    if (m_device->is_running())
        return;

    // Binary mode: the object file must reach the device byte-for-byte, and
    // text mode may translate line endings on some platforms (e.g. Windows).
    std::ifstream algo (tempFileName + ".o", std::ios::binary);
    if (!algo.good()) {
        log("Algorithm must be compiled first.");
        return;
    }

    std::string contents;
    contents.assign(std::istreambuf_iterator<char>(algo),
                    std::istreambuf_iterator<char>());

    m_device->upload_filter(reinterpret_cast<unsigned char *>(contents.data()),
                            contents.size());
    log("Algorithm uploaded.");
}

/**
 * Unloads the algorithm from the device; only acts while the device is idle.
 */
void deviceAlgorithmUnload()
{
    if (!m_device) {
        log("No device connected.");
        return;
    }

    if (m_device->is_running())
        return;

    m_device->unload_filter();
    log("Algorithm unloaded.");
}

void deviceGenLoadList(const std::string_view list)
{
    std::vector<stmdsp::dacsample_t> samples;

    auto it = list.cbegin();
    while (it != list.cend() && samples.size() < stmdsp::SAMPLES_MAX * 2) {
        const auto end = list.find_first_not_of("0123456789",
            std::distance(list.cbegin(), it));
        const auto itend = end != std::string_view::npos ? list.cbegin() + end
                                                         : list.cend();
        unsigned long n;
        const auto [ptr, ec] = std::from_chars(it, itend, n);
        if (ec != std::errc())
            break;

        samples.push_back(n & 4095);
        it = itend;
    }

    if (samples.size() <= stmdsp::SAMPLES_MAX * 2) {
        // DAC buffer must be of even size
        if (samples.size() % 2 != 0)
            samples.push_back(samples.back());

        if (m_device)
            m_device->siggen_upload(samples.data(), samples.size());
        log("Generator ready.");
    } else {
        log("Error: Too many samples for signal generator.");
    }
}

/**
 * Evaluates `formula` into generator samples and uploads them to the device
 * (if connected). An empty evaluation result is reported as a bad formula.
 */
void deviceGenLoadFormula(std::string_view formula)
{
    auto samples = deviceGenLoadFormulaEval(formula);

    if (samples.empty()) {
        log("Error: Bad formula.");
        return;
    }

    if (m_device)
        m_device->siggen_upload(samples.data(), samples.size());

    log("Generator ready.");
}