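//! `uf2updater.rs`: receives a UF2 firmware image in small chunks (Modbus-sized writes) and
//! programs the payload destined for the active partition into the embassy-boot DFU partition.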
use core::mem::size_of;

use bitvec::{BitArr, bitarr};
use bitvec::order::Lsb0;
use defmt::*;
use embassy_boot::{Partition, AlignedBuffer, FirmwareUpdater};
use embassy_rp::{flash::Flash, peripherals};
use zerocopy::FromBytes;

use crate::{modbus_server::ModbusErrorCode, uf2::*};

/// Where a piece of UF2 payload lies relative to a 4096-byte flash page.
#[derive(PartialEq, Eq, Debug, Format)]
enum PositionInSector {
    /// The payload begins exactly at the start of a page.
    Start,
    /// The tail of a payload that spilled over into the next page; it begins that page.
    StartPartial,
    /// The payload lies inside a page without touching either its start or its end.
    Middle,
    /// The payload runs up to (and is truncated at) the end of a page.
    End,
}

/// Flash layout used by the embassy-boot bootloader, taken from linker-script symbols.
struct BootLoaderPartitions {
    /// Base address of flash; UF2 target addresses are compared against
    /// `flash_start_addr + active.from .. flash_start_addr + active.to`.
    flash_start_addr: u32,
    /// Partition whose address range UF2 blocks must target to be accepted.
    active: Partition,
    /// Partition that accepted UF2 data is written into, at the same offset it has in `active`.
    dfu: Partition,
    /// Bootloader state partition; erased before the first DFU page is erased.
    state: Partition,
}
// Adapted from embassy/embassy-boot/rp/src/lib.rs: the partitions of `BootLoader` are private
// there, so the linker-script lookup is duplicated here.
impl Default for BootLoaderPartitions {
    /// Builds the partition layout from the symbols defined by the linker script.
    fn default() -> Self {
        extern "C" {
            static __bootloader_state_start: u32;
            static __bootloader_state_end: u32;
            static __bootloader_active_start: u32;
            static __bootloader_active_end: u32;
            static __bootloader_dfu_start: u32;
            static __bootloader_dfu_end: u32;
            static __bootloader_flash_start: u32;
        }

        /// Address of a linker symbol, truncated to a 32-bit flash address.
        fn symbol_address(symbol: &u32) -> u32 {
            symbol as *const u32 as u32
        }

        // SAFETY: the linker symbols are only used for their addresses; their contents are
        // never read.
        unsafe {
            BootLoaderPartitions {
                active: Partition::new(
                    symbol_address(&__bootloader_active_start),
                    symbol_address(&__bootloader_active_end),
                ),
                dfu: Partition::new(
                    symbol_address(&__bootloader_dfu_start),
                    symbol_address(&__bootloader_dfu_end),
                ),
                state: Partition::new(
                    symbol_address(&__bootloader_state_start),
                    symbol_address(&__bootloader_state_end),
                ),
                flash_start_addr: symbol_address(&__bootloader_flash_start),
            }
        }
    }
}

/// Total flash size this module supports: 2 MiB.
const FLASH_SIZE_GLOBAL: usize = 2 << 20;
/// Number of UF2 blocks needed to cover half of the flash with 256 payload bytes per block.
/// NOTE(review): assumes 256-byte payloads and an image no larger than half the flash.
const MAX_UF2_SECTORS_MIN: usize = FLASH_SIZE_GLOBAL / 2 / 256;
/// `MAX_UF2_SECTORS_MIN` rounded up to a multiple of 32, so the seen-block bitmask
/// occupies whole `u32` words.
const MAX_UF2_SECTORS: usize = (MAX_UF2_SECTORS_MIN + 31) & !31;

/// Collects a UF2 image that arrives in small chunks and programs the payload aimed at the
/// active partition into the DFU partition.
pub struct UF2UpdateHandler<const FLASH_SIZE: usize> {
    /// Flash driver used to erase and write the DFU and state partitions.
    pub flash: Flash<'static, peripherals::FLASH, FLASH_SIZE>,
    /// Staging buffer; the UF2 block being received is assembled in its first 512 bytes.
    buf: AlignedBuffer<4096>,
    /// Number of bytes of the current 512-byte UF2 block already stored in `buf`.
    write_pos: u32,
    /// `num_blocks` of the UF2 file currently being received (0 until the first block).
    uf2_num_blocks: u32,
    /// One bit per UF2 block of the current file; set once that block has been handled.
    uf2_seen_bitmask: BitArr!(for MAX_UF2_SECTORS, in u32),
    /// `Some((next DFU write address, first contributing block))` while a 4096-byte DFU page
    /// has been erased but not yet completely written.
    flash_erased_address_and_first_block: Option<(u32, usize)>,
    /// Whether the bootloader state partition has already been erased for this update.
    bootloader_state_erased: bool,
}

impl<const FLASH_SIZE: usize> UF2UpdateHandler<FLASH_SIZE> {
    /// Creates a handler that takes ownership of the `FLASH` peripheral.
    ///
    /// # Panics
    ///
    /// Panics (via `defmt::assert!`) when `FLASH_SIZE` differs from `FLASH_SIZE_GLOBAL`,
    /// because the capacity of the seen-block bitmask is derived from `FLASH_SIZE_GLOBAL`.
    pub fn new(flash: peripherals::FLASH) -> Self {
        defmt::assert!(FLASH_SIZE == FLASH_SIZE_GLOBAL);
        UF2UpdateHandler {
            buf: AlignedBuffer([0; 4096]),
            write_pos: 0,
            flash: Flash::new(flash),
            uf2_seen_bitmask: bitarr!(u32, Lsb0; 0; MAX_UF2_SECTORS),
            uf2_num_blocks: 0,
            flash_erased_address_and_first_block: None,
            bootloader_state_erased: false,
        }
    }

    /// Appends a chunk of the UF2 stream located at byte offset `pos`.
    ///
    /// Chunks are assembled into 512-byte UF2 blocks, and every completed block is handed to
    /// `process_sector`. A write at `pos == 0` abandons a partially received block. A write at
    /// an unexpected offset is rejected, but any part of it that belongs to the next 512-byte
    /// block is kept so the stream can re-synchronize.
    ///
    /// # Errors
    ///
    /// Returns `IllegalDataValue` when `data` is longer than 256 bytes, when `pos` does not
    /// continue the block being received, or when a completed block is rejected.
    pub fn write(&mut self, pos: u32, data: &[u8]) -> Result<(), ModbusErrorCode> {
        if data.len() > 256 {
            // We could handle these cases but we are limited by Modbus frames anyway.
            info!("Too much data in one call to UF2UpdateHandler::write()");
            self.write_pos = 0;
            return Err(ModbusErrorCode::IllegalDataValue);
        }

        if pos % 512 == self.write_pos {
            // ok: this chunk continues the block being received
        } else if pos == 0 {
            // We are aborting a previous write but that can be ok.
            self.write_pos = 0;
        } else {
            info!("Unexpected write address in UF2UpdateHandler: {}", pos);
            self.write_pos = 0;

            // Re-sync if this write crosses the 512-byte boundary, i.e. keep the data that
            // belongs to the start of the next block.
            let write_size_to_end_of_sector = 512 - (pos as usize % 512);
            if write_size_to_end_of_sector < data.len() {
                let tail = &data[write_size_to_end_of_sector..];
                self.buf.0[..tail.len()].copy_from_slice(tail);
                self.write_pos = tail.len() as u32;
            }

            return Err(ModbusErrorCode::IllegalDataValue);
        }

        let start = self.write_pos as usize;
        let write_size1 = core::cmp::min(512 - (start % 512), data.len());
        if write_size1 > 0 {
            self.buf.0[start..start + write_size1].copy_from_slice(&data[..write_size1]);
            self.write_pos += write_size1 as u32;

            if self.write_pos == 512 {
                let result = self.process_sector();

                // Keep whatever spilled over into the next block, even if this one was rejected.
                self.write_pos = 0;
                if data.len() > write_size1 {
                    let rest = &data[write_size1..];
                    self.buf.0[..rest.len()].copy_from_slice(rest);
                    self.write_pos = rest.len() as u32;
                }

                result?;
            }
        }

        Ok(())
    }

    /// Validates the complete 512-byte UF2 block in `self.buf` and programs its payload.
    ///
    /// The payload is split at 4096-byte page boundaries and each piece goes to
    /// `process_sector_part`. Blocks for another family, or blocks flagged as not targeting
    /// main flash, are accepted and marked as seen without being written.
    ///
    /// # Errors
    ///
    /// Returns `IllegalDataValue` for bad magic numbers, too many blocks, an out-of-range
    /// `block_no`, or an oversized `payload_size`.
    fn process_sector(&mut self) -> Result<(), ModbusErrorCode> {
        defmt::assert!(self.write_pos == 512);

        let uf2_block = &self.buf.0[0..512];

        // The buffer holds a full 512-byte UF2 block, so the header/footer can be read from it.
        let uf2_header = Uf2BlockHeader::read_from_prefix(uf2_block).unwrap();
        let uf2_footer = Uf2BlockFooter::read_from_suffix(uf2_block).unwrap();
        if uf2_header.magic_start0 != UF2_MAGIC_START0
            || uf2_header.magic_start1 != UF2_MAGIC_START1
            || uf2_footer.magic_end != UF2_MAGIC_END
        {
            warn!("Invalid magic in UF2 block");
            return Err(ModbusErrorCode::IllegalDataValue);
        }
        if uf2_header.num_blocks as usize > self.uf2_seen_bitmask.len() {
            warn!("We cannot support that many blocks in one UF2 file.");
            return Err(ModbusErrorCode::IllegalDataValue);
        }
        if uf2_header.block_no >= uf2_header.num_blocks {
            warn!("Invalid block_no in UF2 header");
            return Err(ModbusErrorCode::IllegalDataValue);
        }
        if uf2_header.payload_size as usize
            > 512 - size_of::<Uf2BlockHeader>() - size_of::<Uf2BlockFooter>()
        {
            warn!("Invalid payload_size in UF2 header");
            return Err(ModbusErrorCode::IllegalDataValue);
        }

        // Start tracking a new file when block 0 arrives or the block count changes. This has
        // to happen before any block is marked as seen (including the ignored blocks below);
        // otherwise the reset would immediately clear those marks again.
        if uf2_header.block_no == 0 || uf2_header.num_blocks != self.uf2_num_blocks {
            self.uf2_seen_bitmask.fill_with(|_| false);
            self.uf2_num_blocks = uf2_header.num_blocks;
            self.flash_erased_address_and_first_block = None;
        }

        let block_no = uf2_header.block_no as usize;

        // With the family-ID flag set, the UF2 `file_size` field carries the family ID.
        if (uf2_header.flags & UF2_FLAG_FAMILY_ID_PRESENT) == 0
            || uf2_header.file_size != RP2040_FAMILY_ID
        {
            // not for us but that shouldn't be treated as an error
            warn!("Ignoring UF2 block for different family");
            self.uf2_seen_bitmask.set(block_no, true);
            return Ok(());
        }
        if (uf2_header.flags & (UF2_FLAG_NOT_MAIN_FLASH | UF2_FLAG_FILE_CONTAINER)) != 0 {
            // not for DFU partition
            warn!("Ignoring UF2 block that is not for main flash");
            self.uf2_seen_bitmask.set(block_no, true);
            return Ok(());
        }

        let data_start = size_of::<Uf2BlockHeader>();
        let payload_size = uf2_header.payload_size as usize;
        let data_end = data_start + payload_size;
        let addr = uf2_header.target_addr as u32;
        let offset_in_page = (addr % 4096) as usize;
        let til_end_of_page = 4096 - offset_in_page;
        if offset_in_page == 0 {
            self.process_sector_part(PositionInSector::Start, block_no, addr, data_start, data_end);
        } else if payload_size < til_end_of_page {
            self.process_sector_part(PositionInSector::Middle, block_no, addr, data_start, data_end);
        } else {
            self.process_sector_part(
                PositionInSector::End,
                block_no,
                addr,
                data_start,
                data_start + til_end_of_page,
            );
            let remaining = payload_size - til_end_of_page;
            if remaining > 0 {
                // `target_addr` is untrusted: if the next page address does not fit in 32 bits
                // it cannot be inside the active partition, so skip it instead of overflowing.
                if let Some(next_page_addr) = addr.checked_add(til_end_of_page as u32) {
                    self.process_sector_part(
                        PositionInSector::StartPartial,
                        block_no,
                        next_page_addr,
                        data_start + til_end_of_page,
                        data_end,
                    );
                }
            }
        }

        Ok(())
    }

    /// Writes one page-local piece of a UF2 payload (`self.buf.0[data_start..data_end]`,
    /// destined for absolute address `addr`) to the matching offset in the DFU partition.
    ///
    /// Pieces outside the active partition are skipped and their block is marked as seen,
    /// unless the piece is a `StartPartial`. A page is erased when its first piece arrives
    /// (`Start`/`StartPartial`). Later pieces are written only if they continue that page
    /// exactly. When a page in progress is abandoned, or a flash write fails, the blocks that
    /// contributed to it are marked as unseen again so they will be re-sent.
    fn process_sector_part(
        &mut self,
        pos: PositionInSector,
        block_no: usize,
        addr: u32,
        data_start: usize,
        data_end: usize,
    ) {
        use PositionInSector::*;
        let data = &self.buf.0[data_start..data_end];

        info!("process_sector_part: block {}, {:?}, addr {:08x}, len {}", block_no, pos, addr, data.len());

        let partitions = BootLoaderPartitions::default();
        let active_start = partitions.flash_start_addr + partitions.active.from;
        let active_end = partitions.flash_start_addr + partitions.active.to;
        if addr < active_start || addr >= active_end {
            // We don't want to write this.
            if pos != StartPartial {
                self.uf2_seen_bitmask.set(block_no, true);
            }
            return;
        }
        let addr_orig = addr;
        // Same offset inside the DFU partition as inside the active partition.
        let addr = addr - active_start + partitions.dfu.from;

        let is_start = match pos {
            Start | StartPartial => true,
            Middle | End => false,
        };

        let process_current = if let Some((erased_addr, first_block)) =
            self.flash_erased_address_and_first_block
        {
            let (abort_previous, process_current) = if is_start {
                // previous page is not done but we are starting a new one -> abort previous one
                (true, true)
            } else if erased_addr == addr {
                // address matches current partially written page -> continue writing to it
                (false, true)
            } else {
                // address doesn't match -> abort previous and we can't do anything useful
                // with the current one either
                (true, false)
            };

            // If there is a partially written page, mark its blocks as not seen so we will
            // later try again.
            if abort_previous {
                info!("aborting from block {} to {} because flash_erased_address_and_first_block={:?} and current is {:?} {}",
                    first_block, block_no, self.flash_erased_address_and_first_block, pos, block_no);
                for i in first_block..block_no {
                    self.uf2_seen_bitmask.set(i, false);
                }
                self.flash_erased_address_and_first_block = None;
            }
            process_current
        } else {
            // No page in progress: only the first piece of a new page can be processed.
            is_start
        };

        if !process_current {
            return;
        }

        let already_processed = match pos {
            // A page begun by a `StartPartial` is attributed to the blocks from `block_no + 1` on.
            StartPartial => {
                block_no + 1 < self.uf2_seen_bitmask.len() && self.uf2_seen_bitmask[block_no + 1]
            }
            Start | Middle | End => self.uf2_seen_bitmask[block_no],
        };
        if already_processed {
            return;
        }

        if is_start {
            // We have to erase the page. However, let's erase the bootloader state first
            // because we don't want to swap into a partially cleared DFU partition.
            if !self.bootloader_state_erased {
                info!("erasing state partition at {:08x}", partitions.state.from);
                if let Err(e) = self.flash.erase(partitions.state.from, partitions.state.to) {
                    error!("erase: {:?} at {:08x}", e, partitions.state.from);
                    return;
                }

                self.bootloader_state_erased = true;
            }

            info!("erasing at {:08x} -> {:08x}", addr_orig, addr);
            if let Err(e) = self.flash.erase(addr, addr + 4096) {
                error!("erase: {:?} at {:08x} -> {:08x}", e, addr_orig, addr);
                return;
            }

            // From which block should we start marking as unseen if we later encounter an
            // error for this page? This is usually the current block, but if this block also
            // contains the end of the previous page, we use the next one.
            let first_block = match pos {
                Start => block_no,
                StartPartial => block_no + 1,
                Middle | End => defmt::unreachable!(),
            };

            self.flash_erased_address_and_first_block = Some((addr, first_block));
        }

        // Either set just above (start pieces) or checked to be `Some` via `process_current`.
        let (next_addr, first_block) = self.flash_erased_address_and_first_block.unwrap();
        defmt::assert!(next_addr == addr);
        //FIXME We might have to handle alignment concerns for address and data.
        info!("writing to {:08x} -> {:08x}, len={}", addr_orig, addr, data.len());
        if let Err(e) = self.flash.write(addr, data) {
            error!("write: {:?} at {:08x} -> {:08x}", e, addr_orig, addr);

            // abort current page
            for i in first_block..block_no {
                self.uf2_seen_bitmask.set(i, false);
            }
            self.flash_erased_address_and_first_block = None;
            return;
        }

        let next_addr = next_addr + data.len() as u32;
        self.flash_erased_address_and_first_block = if next_addr % 4096 == 0 {
            // The page is complete.
            None
        } else {
            Some((next_addr, first_block))
        };

        if pos != StartPartial {
            self.uf2_seen_bitmask.set(block_no, true);
        }
    }

    /// Reports which UF2 blocks still need to be sent.
    ///
    /// Returns:
    /// * `(0, 0)` if no UF2 block has been accepted yet,
    /// * `(n, n)` if the announced block count `n` exceeds the bitmask capacity,
    /// * `(1, 1)` if every block of the current file has been handled,
    /// * otherwise `(first, end)`: re-send blocks `first..end`. `first` is one before the
    ///   first missing block, because the start of that block's page may live in the
    ///   previous block. `end` is the next block already seen, or the block count.
    pub fn get_missing_block_info(&self) -> (u32, u32) {
        if self.uf2_num_blocks == 0 {
            return (0, 0);
        }
        if self.uf2_num_blocks as usize > self.uf2_seen_bitmask.len() {
            return (self.uf2_num_blocks, self.uf2_num_blocks);
        }

        match self.uf2_seen_bitmask.first_zero() {
            None => (1, 1),
            Some(pos) if pos >= self.uf2_num_blocks as usize => (1, 1),
            Some(first_missing) => {
                let (_, from_first_missing) = self.uf2_seen_bitmask.split_at(first_missing);
                let first_present = from_first_missing
                    .first_one()
                    .map(|x| (x + first_missing) as u32)
                    .unwrap_or(self.uf2_num_blocks);
                // subtract one because a partial page might start in the previous block
                let first_missing = (first_missing as u32).saturating_sub(1);
                (first_missing, first_present)
            }
        }
    }

    /// Returns `true` once every block of the current UF2 file has been handled.
    pub fn successfully_programmed(&self) -> bool {
        self.get_missing_block_info() == (1, 1)
    }

    /// Tells the bootloader to swap to the DFU partition on the next boot.
    ///
    /// # Errors
    ///
    /// Returns `ServerDeviceFailure` if the bootloader state could not be written.
    pub fn mark_updated(&mut self) -> Result<(), ModbusErrorCode> {
        self.bootloader_state_erased = false;

        // Dedicated aligned scratch buffer, so a partially received UF2 block in `self.buf`
        // is not overwritten.
        let mut aligned = AlignedBuffer([0u8; 1]);
        let mut updater = FirmwareUpdater::default();
        updater
            .mark_updated_blocking(&mut self.flash, &mut aligned.0[..])
            .map_err(|e| {
                error!("mark_updated_blocking: {:?}", e);
                ModbusErrorCode::ServerDeviceFailure
            })
    }

    /// Tells the bootloader that the running firmware booted successfully.
    ///
    /// # Errors
    ///
    /// Returns `ServerDeviceFailure` if the bootloader state could not be written.
    pub fn mark_booted(&mut self) -> Result<(), ModbusErrorCode> {
        self.bootloader_state_erased = false;

        // Dedicated aligned scratch buffer, so a partially received UF2 block in `self.buf`
        // is not overwritten.
        let mut aligned = AlignedBuffer([0u8; 1]);
        let mut updater = FirmwareUpdater::default();
        updater
            .mark_booted_blocking(&mut self.flash, &mut aligned.0[..])
            .map_err(|e| {
                error!("mark_booted_blocking: {:?}", e);
                ModbusErrorCode::ServerDeviceFailure
            })
    }
}