#include "mqtt.h"
#include "mqtt_common.h"
#include "mqtt_ota.h"

#include "environment.h"
#include "event.h"
#include "device/device_config.h"
#include "io.h"

#include <new>
#include <stdio.h>
#include <stdint.h>
#include <stddef.h>
#include <map>
#include "esp_system.h"
#include "esp_event_loop.h"
#include "esp_log.h"

#include "freertos/FreeRTOS.h"
#include "freertos/task.h"
#include "freertos/semphr.h"
#include "freertos/queue.h"
#include "freertos/event_groups.h"

#include "lwip/sockets.h"
#include "lwip/dns.h"
#include "lwip/netdb.h"

extern "C" {
    #include "mqtt_client.h"
    #include "mqtt_c.h"
}

// Log tag used by every ESP_LOG* call in this file.
#define TAG "MQTT"
// Set to true by mqtt_client_start() right before the client is created.
bool mqtt_started = false;
// Set on MQTT_EVENT_CONNECTED and cleared on MQTT_EVENT_DISCONNECTED by mqtt_event_handler();
// publish_message() refuses to publish while it is false.
bool mqtt_connected = false;

// Handle returned by mqtt_client_start_c(); used by publish_message().
esp_mqtt_client_handle_t client;

// Topic -> handler for text messages (dispatched by on_message()).
// Keys are subscribed in mqtt_event_handler() on MQTT_EVENT_CONNECTED.
std::map<std::string, message_callback_t> callback_map;
// Topic -> handler for binary / chunked messages (dispatched by on_binary_message()).
// Keys not already present in callback_map are subscribed on MQTT_EVENT_CONNECTED.
std::map<std::string, binary_message_callback_t> binary_callback_map;

// Run after each MQTT_EVENT_CONNECTED; starts as a no-op and is extended by
// add_mqtt_connected_callback().
mqtt_connected_callback_t connected_callback = [](){};

// Topic of the chunked message currently being received. mqtt_event_handler()
// reads event->topic only for the chunk at offset 0 and uses this value for the
// remaining chunks; it is cleared after the last chunk.
std::string multipart_topic;

void on_message(const std::string& topic, const std::string& message) {
    auto iterator = callback_map.find(topic);
    if (iterator != callback_map.end()) {
        auto callback = iterator->second;
        callback(message);
    }
}

void on_binary_message(const multipart_message_t& message) {
    auto iterator = binary_callback_map.find(message.topic);
    if (iterator != binary_callback_map.end()) {
        auto callback = iterator->second;
        callback(message);
    }
}

/**
 * Event handler passed to mqtt_client_start_c() by mqtt_client_start().
 *
 * MQTT_EVENT_CONNECTED    - subscribes (QoS 0) to every topic that has a text or
 *                           binary callback, publishes STATUS_MESSAGE_ONLINE to
 *                           STATUS_TOPIC (QoS 0, retain flag set), reports the
 *                           "connected" status, sets mqtt_connected and runs
 *                           connected_callback().
 * MQTT_EVENT_DISCONNECTED - clears mqtt_connected and reports "disconnected".
 * MQTT_EVENT_DATA         - forwards every chunk to the binary callbacks; a
 *                           message delivered in a single chunk is also passed
 *                           to the text callbacks.
 * All other events are only logged or ignored.
 *
 * @param event event descriptor supplied by the MQTT client
 * @return always ESP_OK
 */
esp_err_t mqtt_event_handler(esp_mqtt_event_handle_t event)
{
    // Named differently from the file-level `client` to avoid shadowing it.
    esp_mqtt_client_handle_t event_client = event->client;
    int result;
    switch (event->event_id) {
        case MQTT_EVENT_CONNECTED:
            ESP_LOGI(TAG, "MQTT_EVENT_CONNECTED");

            // Subscribe on every (re)connection. The binary loop is deliberately
            // not guarded by callback_map.empty(): binary-only topics must be
            // subscribed even when there are no text callbacks.
            for (const auto& entry : callback_map) {
                result = esp_mqtt_client_subscribe(event_client, entry.first.c_str(), 0);
                ESP_LOGI(TAG, "Subscribing to topic %s (result %i)", entry.first.c_str(), result);
            }
            for (const auto& entry : binary_callback_map) {
                if (callback_map.count(entry.first) == 0) {  // only subscribe once for every topic
                    result = esp_mqtt_client_subscribe(event_client, entry.first.c_str(), 0);
                    ESP_LOGI(TAG, "Subscribing to topic %s (result %i)", entry.first.c_str(), result);
                }
            }

            result = esp_mqtt_client_publish(event_client, STATUS_TOPIC, STATUS_MESSAGE_ONLINE, strlen(STATUS_MESSAGE_ONLINE), 0, 1);
            ESP_LOGI(TAG, "Online status published (result %i)", result);
            update_mqtt_connection_status(connected);
            mqtt_connected = true;
            connected_callback();
            break;
        case MQTT_EVENT_DISCONNECTED:
            ESP_LOGI(TAG, "MQTT_EVENT_DISCONNECTED");
            mqtt_connected = false;
            update_mqtt_connection_status(disconnected);
            break;

        case MQTT_EVENT_SUBSCRIBED:
            ESP_LOGI(TAG, "MQTT_EVENT_SUBSCRIBED, msg_id=%d", event->msg_id);
            break;
        case MQTT_EVENT_UNSUBSCRIBED:
            ESP_LOGI(TAG, "MQTT_EVENT_UNSUBSCRIBED, msg_id=%d", event->msg_id);
            break;
        case MQTT_EVENT_PUBLISHED:
            ESP_LOGI(TAG, "MQTT_EVENT_PUBLISHED, msg_id=%d", event->msg_id);
            break;
        case MQTT_EVENT_DATA: {
            if (event->current_data_offset == 0) {
                // First (or only) chunk: the topic is read from the event here.
                std::string topic(event->topic, event->topic_len);
                ESP_LOGI(TAG, "MQTT_EVENT_DATA: %s (length %d)", topic.c_str(), event->data_len);
                show_activity();

                multipart_message_t multipart_message = {
                    .topic = topic,
                    .payload = event->data,
                    .length = (uint32_t)event->data_len,
                    .offset = (uint32_t)event->current_data_offset,
                    .total_length = (uint32_t)event->total_data_len
                };
                on_binary_message(multipart_message);

                if (event->data_len < event->total_data_len) {  // first part of multipart message
                    // Remember the topic for the following chunks.
                    multipart_topic = topic;
                }
                else {
                    // Complete message in a single chunk: also deliver it as text.
                    std::string message(event->data, event->data_len);
                    ESP_LOGI(TAG, "%s %s", topic.c_str(), message.c_str());
                    on_message(topic, message);
                }
            }
            else {
                // Continuation chunk: use the topic saved from the first chunk.
                ESP_LOGI(TAG, "MQTT_EVENT_DATA: %s (length %d, offset %d)", multipart_topic.c_str(), event->data_len, event->current_data_offset);

                multipart_message_t multipart_message = {
                    .topic = multipart_topic,
                    .payload = event->data,
                    .length = (uint32_t)event->data_len,
                    .offset = (uint32_t)event->current_data_offset,
                    .total_length = (uint32_t)event->total_data_len
                };
                on_binary_message(multipart_message);

                if (event->current_data_offset + event->data_len == event->total_data_len) {  // last part of multipart message
                    multipart_topic.clear();
                }
            }
            break;
        }
        case MQTT_EVENT_ERROR:
            ESP_LOGI(TAG, "MQTT_EVENT_ERROR");
            break;
        default:
            // Other event ids need no handling here.
            break;
    }
    return ESP_OK;
}

/**
 * Handles a text command received on FIRMWARE_COMMAND_TOPIC.
 *
 * The only recognized command is "restart", which calls esp_restart().
 * Any other command is ignored.
 *
 * @param command command text received from the broker
 */
void handle_firmware_command(const std::string& command) {
    const bool restart_requested = (command == "restart");
    if (!restart_requested) {
        return;
    }
    esp_restart();
}

/**
 * Publishes `message` on `topic` with QoS 0 and the retain flag cleared.
 *
 * If the client has not been started, or is not currently connected, a
 * warning is logged and the message is dropped; nothing is queued. The result
 * of esp_mqtt_client_publish() is logged, as an error when it is negative.
 *
 * @param topic   destination topic
 * @param message payload; its length() is passed as the payload length
 */
void publish_message(const std::string& topic, const std::string& message) {
    const char* topic_str = topic.c_str();
    const char* payload_str = message.c_str();

    if (!mqtt_started) {
        ESP_LOGW(TAG, "Cannot publish message: MQTT client has not been started (%s: %s)", topic_str, payload_str);
        return;
    }
    if (!mqtt_connected) {
        ESP_LOGW(TAG, "Cannot publish message: MQTT client is not connected (%s: %s)", topic_str, payload_str);
        return;
    }

    show_activity();
    ESP_LOGI(TAG, "Publish started (%s: %s)", topic_str, payload_str);
    const int rc = esp_mqtt_client_publish(client, topic_str, payload_str, message.length(), 0, 0);
    if (rc < 0) {
        ESP_LOGE(TAG, "Publish failed (%s: %s) (result %i)", topic_str, payload_str, rc);
    } else {
        ESP_LOGI(TAG, "Publish finished (%s: %s) (result %i)", topic_str, payload_str, rc);
    }
}


void add_message_callback(const std::string& topic, message_callback_t callback) {
    if (mqtt_started) {
        ESP_LOGW(TAG, "Not subscribint to topic %s: MQTT client has already been started", topic.c_str());
    }
    if (callback_map.count(topic) > 0) {
        message_callback_t old_callback = callback_map.at(topic);
        callback_map[topic] = [old_callback, callback](const std::string& message){
            old_callback(message);
            callback(message);
        };
    }
    else {
        callback_map[topic] = callback;
    }
}

void add_binary_message_callback(const std::string& topic, binary_message_callback_t callback) {
    if (mqtt_started) {
        ESP_LOGW(TAG, "Not subscribint to topic %s: MQTT client has already been started", topic.c_str());
    }
    if (binary_callback_map.count(topic) > 0) {
        binary_message_callback_t old_callback = binary_callback_map.at(topic);
        binary_callback_map[topic] = [old_callback, callback](const multipart_message_t& message){
            old_callback(message);
            callback(message);
        };
    }
    else {
        binary_callback_map[topic] = callback;
    }
}

/**
 * Starts the MQTT client. Only the first call has any effect.
 *
 * This function:
 * - registers the built-in handlers (firmware commands on
 *   FIRMWARE_COMMAND_TOPIC, OTA data on OTA_FIRMWARE_TOPIC);
 * - marks the client as started;
 * - reports the "connecting" status;
 * - creates the client with mqtt_event_handler as its event handler.
 */
void mqtt_client_start() {
    if (mqtt_started) {
        return;
    }

    // Registered while mqtt_started is still false, so the add_*_callback
    // functions do not log their "already been started" warning.
    add_message_callback(FIRMWARE_COMMAND_TOPIC, handle_firmware_command);
    add_binary_message_callback(OTA_FIRMWARE_TOPIC, handle_ota_message);

    mqtt_started = true;
    update_mqtt_connection_status(connecting);
    client = mqtt_client_start_c(mqtt_event_handler);
}

/**
 * Appends `callback` to the chain run by mqtt_event_handler() after each
 * MQTT_EVENT_CONNECTED.
 *
 * Callbacks run in registration order. An empty callback is ignored with a
 * warning, because the chained connected_callback() would otherwise invoke an
 * empty function object on every connection.
 *
 * @param callback handler to run after the client connects
 */
void add_mqtt_connected_callback(mqtt_connected_callback_t callback) {
    if (!callback) {
        ESP_LOGW(TAG, "Ignoring empty MQTT connected callback");
        return;
    }
    const mqtt_connected_callback_t previous = connected_callback;
    connected_callback = [previous, callback]() {
        previous();
        callback();
    };
}