Skip to content
Snippets Groups Projects
mqtt_ota.cpp 2.96 KiB
#include "mqtt_ota.h"
#include "mqtt_common.h"

#include "esp_log.h"
#include "esp_system.h"
#include "esp_ota_ops.h"

#include <qthing.h>

#define TAG "MQTT_OTA"


// Built-in OTA listener, defined further down in this file; it is also the
// initial value of ota_handler.
void default_ota_event(ota_event_t event);

// Listener invoked for every OTA event. add_ota_callback() replaces it with a
// wrapper that calls the previously installed listener first, then the new one.
ota_callback_t ota_handler = default_ota_event;

// Handle and target partition of the most recently started OTA session.
// Both start out empty (0 / nullptr) until the first upload chunk arrives.
esp_ota_handle_t ota_handle{0};
const esp_partition_t *partition = nullptr;

void handle_ota_message(const multipart_message_t& message) {
    ESP_LOGD(TAG, "length=%d offset=%d", message.length, message.offset);

    esp_err_t err;

    if (message.offset == 0) {  // first ota message
        partition = esp_ota_get_next_update_partition(NULL);
        err = esp_ota_begin(partition, message.total_length, &ota_handle);
        if (err != ESP_OK) ESP_LOGW(TAG, "BEGIN OTA FAILED!");

        ota_event_t event = {
            .state = start,
            .error = err,
            .bytes_written = 0,
            .bytes_total = (uint32_t)message.total_length
        };
        ota_handler(event);
    }

    err = esp_ota_write(ota_handle, (const void*) message.payload, message.length);
    if (err != ESP_OK) ESP_LOGW(TAG, "WRITE OTA FAILED!");

    ota_event_t event = {
        .state = progress,
        .error = err,
        .bytes_written = (uint32_t)(message.length + message.offset),
        .bytes_total = (uint32_t)message.total_length
    };
    ota_handler(event);

    if (message.offset + message.length == message.total_length) {  // last ota message
        err = esp_ota_end(ota_handle);
        if (err != ESP_OK) ESP_LOGW(TAG, "END OTA FAILED!");

        ota_event_t event = {
            .state = success,
            .error = err,
            .bytes_written = (uint32_t)message.total_length,
            .bytes_total = (uint32_t)message.total_length
        };
        ota_handler(event);

        err = esp_ota_set_boot_partition(partition);
        if (err != ESP_OK) ESP_LOGW(TAG, "FINALIZE OTA FAILED!");
    }

}

// Last percentage published by default_ota_event(); -1 means nothing has been
// published for the current upload yet.
int8_t last_pct = -1;
// Holds "0".."100" plus NUL. NOTE(review): kept at file scope, presumably in
// case publish_message() does not copy the payload; confirm.
char pct_buffer[4];

/**
 * Built-in OTA listener.
 *
 * Logs every phase. During `progress` it publishes the integer completion
 * percentage to OTA_PROGRESS_TOPIC, but only when it is higher than the last
 * published value. `start` resets that tracking.
 */
void default_ota_event(ota_event_t event) {
    switch(event.state) {
        case start:
            ESP_LOGI(TAG, "OTA Start");
            last_pct = -1;
            break;
        case progress: {
            // Integer math clamped to 0..100. The float version was undefined
            // behaviour when bytes_total is 0 (inf/NaN cast to int) or when
            // the ratio falls outside int8_t.
            int8_t pct = 0;
            if (event.bytes_total > 0) {
                uint64_t scaled = (uint64_t)event.bytes_written * 100u / event.bytes_total;
                pct = (int8_t)(scaled > 100u ? 100u : scaled);
            }

            if (pct > last_pct) {
                // Explicit casts: uint32_t is not `unsigned int` on every
                // toolchain, so passing it to %d/%u directly is a format mismatch.
                ESP_LOGI(TAG, "OTA Progress %d%% (%u/%u KiB)", pct,
                         (unsigned)(event.bytes_written / 1024u),
                         (unsigned)(event.bytes_total / 1024u));
                snprintf(pct_buffer, sizeof(pct_buffer), "%d", pct);
                publish_message(OTA_PROGRESS_TOPIC, pct_buffer);
                last_pct = pct;
            }
            break;
        }
        case success:
            ESP_LOGI(TAG, "OTA Successful");
            break;
        case error:
            ESP_LOGW(TAG, "OTA Error 0x%x", (unsigned)event.error);
            break;
    }
}

/**
 * Register an additional OTA event listener.
 *
 * The new listener is chained after every listener installed so far
 * (initially default_ota_event). Each event therefore reaches all listeners in
 * registration order.
 *
 * An empty handler is rejected here with a warning. Storing it would only
 * fail at the next OTA event (bad_function_call / abort), far from the
 * registration mistake.
 */
void add_ota_callback(ota_callback_t handler) {
    if (!handler) {
        ESP_LOGW(TAG, "Ignoring empty OTA callback");
        return;
    }

    ota_callback_t previous = ota_handler;
    ota_handler = [previous, handler](ota_event_t event) {
        // ota_handler is a public global and may have been cleared elsewhere.
        if (previous) {
            previous(event);
        }
        handler(event);
    };
}