// Authored by Benjamin Koch.
// uf2updater.rs (16.15 KiB)
use core::mem::size_of;
use bitvec::{BitArr, bitarr};
use bitvec::order::Lsb0;
use defmt::*;
use embassy_boot::{Partition, AlignedBuffer, FirmwareUpdater};
use embassy_rp::{flash::Flash, peripherals};
use zerocopy::FromBytes;
use crate::{modbus_server::ModbusErrorCode, uf2::*};
/// Where a piece of UF2 payload lies relative to the 4096-byte flash page it is written to.
#[derive(Debug, PartialEq, Eq, Format)]
enum PositionInSector {
    /// Begins exactly at the start of a page.
    Start,
    /// Begins at the start of a page, but is the second half of a UF2 block whose first
    /// half ended the previous page.
    StartPartial,
    /// Lies inside a page without reaching its end.
    Middle,
    /// Runs up to the end of a page.
    End,
}
/// Flash layout used by the embassy-boot bootloader, as provided by the linker script.
struct BootLoaderPartitions {
    /// Bootloader state partition; it is erased before the DFU partition is modified.
    state: Partition,
    /// Partition of the running firmware; UF2 target addresses are checked against it.
    active: Partition,
    /// Partition that receives the new image (UF2 data for `active` is redirected here).
    dfu: Partition,
    /// Address of `__bootloader_flash_start`; added to partition offsets before they are
    /// compared with absolute UF2 target addresses.
    flash_start_addr: u32,
}
// Copied from embassy/embassy-boot/rp/src/lib.rs because the corresponding fields of
// `BootLoader` are private there.
impl Default for BootLoaderPartitions {
    /// Builds the partition table from symbols defined by the linker script.
    fn default() -> Self {
        extern "C" {
            static __bootloader_state_start: u32;
            static __bootloader_state_end: u32;
            static __bootloader_active_start: u32;
            static __bootloader_active_end: u32;
            static __bootloader_dfu_start: u32;
            static __bootloader_dfu_end: u32;
            static __bootloader_flash_start: u32;
        }

        // The linker symbols carry information only through their addresses.
        let addr_of_symbol = |symbol: &u32| symbol as *const u32 as u32;

        // SAFETY: the extern statics are only borrowed to take their addresses; their
        // contents are never read.
        let (active, dfu, state, flash_start_addr) = unsafe {
            (
                Partition::new(
                    addr_of_symbol(&__bootloader_active_start),
                    addr_of_symbol(&__bootloader_active_end),
                ),
                Partition::new(
                    addr_of_symbol(&__bootloader_dfu_start),
                    addr_of_symbol(&__bootloader_dfu_end),
                ),
                Partition::new(
                    addr_of_symbol(&__bootloader_state_start),
                    addr_of_symbol(&__bootloader_state_end),
                ),
                addr_of_symbol(&__bootloader_flash_start),
            )
        };

        Self { state, active, dfu, flash_start_addr }
    }
}
/// Total flash size in bytes (2 MiB); `UF2UpdateHandler::new` asserts its `FLASH_SIZE` matches.
const FLASH_SIZE_GLOBAL: usize = 2 * 1024 * 1024;
/// Minimum number of UF2 blocks the handler must track. Presumably: half of the flash (the
/// largest DFU image) divided into 256-byte UF2 payloads — confirm against the memory layout.
const MAX_UF2_SECTORS_MIN: usize = FLASH_SIZE_GLOBAL / (2 * 256);
/// `MAX_UF2_SECTORS_MIN` rounded up to a multiple of 32, so the `u32`-backed bitmask is made of
/// whole words.
const MAX_UF2_SECTORS: usize = (MAX_UF2_SECTORS_MIN + 31) & !31;
/// Receives a UF2 image in small chunks and writes the part destined for the active
/// partition into the DFU partition.
pub struct UF2UpdateHandler<const FLASH_SIZE: usize> {
    /// Reassembly buffer for the current 512-byte UF2 block.
    buf: AlignedBuffer<4096>,
    /// Number of bytes of the current UF2 block already stored in `buf`.
    write_pos: u32,
    /// Flash driver used for erasing and writing.
    pub flash: Flash<'static, peripherals::FLASH, FLASH_SIZE>,
    /// One bit per UF2 block number, set once the block was written or deliberately ignored.
    uf2_seen_bitmask: BitArr!(for MAX_UF2_SECTORS, in u32),
    /// `num_blocks` of the UF2 file being received (0 until a relevant block arrives).
    uf2_num_blocks: u32,
    /// For an erased but not yet completely written 4096-byte page: the next DFU address to
    /// write and the first block number whose data went into that page.
    flash_erased_address_and_first_block: Option<(u32, usize)>,
    /// Whether the bootloader state partition has been erased since the last
    /// `mark_updated`/`mark_booted`.
    bootloader_state_erased: bool,
}
impl<const FLASH_SIZE: usize> UF2UpdateHandler<FLASH_SIZE> {
    /// Creates a handler that takes ownership of the flash peripheral.
    ///
    /// # Panics
    /// Panics (`defmt::assert!`) unless `FLASH_SIZE == FLASH_SIZE_GLOBAL`; the capacity of the
    /// "seen blocks" bitmask is computed from `FLASH_SIZE_GLOBAL`.
    pub fn new(flash: peripherals::FLASH) -> Self {
        defmt::assert!(FLASH_SIZE == FLASH_SIZE_GLOBAL);
        UF2UpdateHandler {
            buf: AlignedBuffer([0; 4096]),
            write_pos: 0,
            flash: Flash::new(flash),
            uf2_seen_bitmask: bitarr!(u32, Lsb0; 0; MAX_UF2_SECTORS),
            uf2_num_blocks: 0,
            flash_erased_address_and_first_block: None,
            bootloader_state_erased: false,
        }
    }

    /// Appends `data`, located at byte offset `pos` of the UF2 file, to the current UF2 block.
    ///
    /// Chunks are expected in order. Whenever 512 bytes (one UF2 block) have been collected,
    /// the block is handed to `process_sector`. A write at `pos == 0` silently restarts reception.
    ///
    /// # Errors
    /// Returns `IllegalDataValue` if `data` is longer than 256 bytes, if `pos` does not continue
    /// the previous write, or if a completed UF2 block is rejected by `process_sector`.
    pub fn write(&mut self, pos: u32, data: &[u8]) -> Result<(), ModbusErrorCode> {
        //info!("UF2Updater::write: {}, len={}", pos, data.len());
        if data.len() > 256 {
            // We could handle these cases but we are limited by Modbus frames anyway.
            info!("Too much data in one call to UF2UpdateHandler::write()");
            self.write_pos = 0;
            return Err(ModbusErrorCode::IllegalDataValue);
        }

        if pos % 512 != self.write_pos {
            self.write_pos = 0;
            if pos != 0 {
                info!("Unexpected write address in UF2UpdateHandler: {}", pos);
                // Re-sync: keep the part of this write that belongs to the next UF2 block so that
                // the following in-order write can continue it. A write that starts exactly on a
                // block boundary belongs to that block in its entirety.
                let offset_in_block = pos as usize % 512;
                let skip = if offset_in_block == 0 { 0 } else { 512 - offset_in_block };
                if skip < data.len() {
                    let tail = &data[skip..];
                    self.buf.0[..tail.len()].copy_from_slice(tail);
                    self.write_pos = tail.len() as u32;
                }
                return Err(ModbusErrorCode::IllegalDataValue);
            }
            // pos == 0: we are aborting a previous write but that can be ok.
        }

        let start = self.write_pos as usize;
        let head_len = core::cmp::min(512 - (start % 512), data.len());
        if head_len == 0 {
            return Ok(());
        }
        let (head, tail) = data.split_at(head_len);
        self.buf.0[start..start + head_len].copy_from_slice(head);
        self.write_pos += head_len as u32;
        if self.write_pos == 512 {
            let result = self.process_sector();
            // Keep the bytes that already belong to the next UF2 block, even if this one failed.
            self.buf.0[..tail.len()].copy_from_slice(tail);
            self.write_pos = tail.len() as u32;
            result?;
        }
        Ok(())
    }

    /// Validates the UF2 block in `buf[..512]` and writes its payload into the DFU partition.
    ///
    /// Blocks for another family, or not meant for main flash, are accepted but only marked as
    /// seen. A block with `block_no == 0` or with a `num_blocks` different from the stored one
    /// starts a new file: the seen-bitmask and any partially written page are forgotten.
    ///
    /// # Errors
    /// Returns `IllegalDataValue` for bad magic numbers, more blocks than we can track, an
    /// out-of-range `block_no` or an oversized `payload_size`.
    fn process_sector(&mut self) -> Result<(), ModbusErrorCode> {
        defmt::assert!(self.write_pos == 512);
        let uf2_block = &self.buf.0[0..512];
        //info!("process_sector: {}", uf2_block);
        // The slice is a full 512-byte UF2 block, so header and footer can always be read.
        let uf2_header = Uf2BlockHeader::read_from_prefix(uf2_block).unwrap();
        let uf2_footer = Uf2BlockFooter::read_from_suffix(uf2_block).unwrap();

        if uf2_header.magic_start0 != UF2_MAGIC_START0
            || uf2_header.magic_start1 != UF2_MAGIC_START1
            || uf2_footer.magic_end != UF2_MAGIC_END
        {
            warn!("Invalid magic in UF2 block");
            return Err(ModbusErrorCode::IllegalDataValue);
        }
        if uf2_header.num_blocks as usize > self.uf2_seen_bitmask.len() {
            warn!("We cannot support that many blocks in one UF2 file.");
            return Err(ModbusErrorCode::IllegalDataValue);
        }
        if uf2_header.block_no >= uf2_header.num_blocks {
            warn!("Invalid block_no in UF2 header");
            return Err(ModbusErrorCode::IllegalDataValue);
        }
        let max_payload_size = 512 - size_of::<Uf2BlockHeader>() - size_of::<Uf2BlockFooter>();
        if uf2_header.payload_size as usize > max_payload_size {
            warn!("Invalid payload_size in UF2 header");
            return Err(ModbusErrorCode::IllegalDataValue);
        }

        let block_no = uf2_header.block_no as usize;
        // When UF2_FLAG_FAMILY_ID_PRESENT is set, the `file_size` field holds the family ID.
        if (uf2_header.flags & UF2_FLAG_FAMILY_ID_PRESENT) == 0
            || uf2_header.file_size != RP2040_FAMILY_ID
        {
            // not for us but that shouldn't be treated as an error
            warn!("Ignoring UF2 block for different family");
            self.uf2_seen_bitmask.set(block_no, true);
            return Ok(());
        }
        if (uf2_header.flags & (UF2_FLAG_NOT_MAIN_FLASH | UF2_FLAG_FILE_CONTAINER)) != 0 {
            // not for DFU partition
            warn!("Ignoring UF2 block that is not for main flash");
            self.uf2_seen_bitmask.set(block_no, true);
            return Ok(());
        }

        if uf2_header.block_no == 0 || uf2_header.num_blocks != self.uf2_num_blocks {
            self.uf2_seen_bitmask.fill_with(|_| false);
            self.uf2_num_blocks = uf2_header.num_blocks;
            self.flash_erased_address_and_first_block = None;
        }

        // Split the payload at 4096-byte page boundaries. The payload is smaller than a page,
        // so at most one boundary can be crossed.
        let payload_size = uf2_header.payload_size as usize;
        let data_start = size_of::<Uf2BlockHeader>();
        let data_end = data_start + payload_size;
        let addr = uf2_header.target_addr as usize;
        let offset_in_page = addr % 4096;
        let til_end_of_page = 4096 - offset_in_page;
        if offset_in_page == 0 {
            self.process_sector_part(
                PositionInSector::Start, block_no, addr as u32, data_start, data_end);
        } else if payload_size < til_end_of_page {
            self.process_sector_part(
                PositionInSector::Middle, block_no, addr as u32, data_start, data_end);
        } else {
            let split = data_start + til_end_of_page;
            self.process_sector_part(
                PositionInSector::End, block_no, addr as u32, data_start, split);
            if payload_size > til_end_of_page {
                self.process_sector_part(
                    PositionInSector::StartPartial, block_no,
                    (addr + til_end_of_page) as u32, split, data_end);
            }
        }
        Ok(())
    }

    /// Writes `buf[data_start..data_end]`, destined for absolute address `addr`, into the DFU
    /// partition. The range lies within one 4096-byte page; `pos` tells where within it.
    ///
    /// A page is erased when the data for its start arrives and is then filled by consecutive
    /// parts. If that sequence is broken, the blocks that contributed to the unfinished page are
    /// marked as not seen again so that they get re-sent. Data outside the active partition is
    /// skipped (but marked as seen). Flash errors are logged and leave the affected blocks
    /// marked as not seen.
    fn process_sector_part(&mut self, pos: PositionInSector, block_no: usize, addr: u32,
                           data_start: usize, data_end: usize) {
        use PositionInSector::*;
        let data = &self.buf.0[data_start..data_end];
        info!("process_sector_part: block {}, {:?}, addr {:08x}, len {}",
            block_no, pos, addr, data.len());

        let partitions = BootLoaderPartitions::default();
        let active_start = partitions.flash_start_addr + partitions.active.from;
        let active_end = partitions.flash_start_addr + partitions.active.to;
        if addr < active_start || addr >= active_end {
            // We don't want to write this.
            if pos != StartPartial {
                self.uf2_seen_bitmask.set(block_no, true);
            }
            return;
        }
        let addr_orig = addr;
        // Same offset as inside the active partition, but within the DFU partition.
        let addr = addr - active_start + partitions.dfu.from;

        let is_start = matches!(pos, Start | StartPartial);
        let process_current = match self.flash_erased_address_and_first_block {
            // No page in progress: only the start of a new page can be used.
            None => is_start,
            Some((erased_addr, first_block)) => {
                // Only a continuation at exactly the expected address keeps the page in progress.
                let continues_page = !is_start && erased_addr == addr;
                if !continues_page {
                    // Mark the blocks of the partially written page as not seen so that we will
                    // later try again.
                    info!("aborting from block {} to {} because flash_erased_address_and_first_block={:?} and current is {:?} {}",
                        first_block, block_no, self.flash_erased_address_and_first_block, pos, block_no);
                    for i in first_block..block_no {
                        self.uf2_seen_bitmask.set(i, false);
                    }
                    self.flash_erased_address_and_first_block = None;
                }
                // A new page start is processed even after aborting; a mismatching
                // continuation is useless without its page.
                is_start || continues_page
            }
        };
        if !process_current {
            return;
        }

        let already_processed = match pos {
            // The bit of `block_no` is owned by the End part of this block, so use the next one.
            StartPartial => block_no + 1 < self.uf2_seen_bitmask.len()
                && self.uf2_seen_bitmask[block_no + 1],
            _ => self.uf2_seen_bitmask[block_no],
        };
        if already_processed {
            return;
        }

        if is_start {
            // We have to erase the page. However, let's erase the bootloader state first
            // because we don't want to swap into a partially cleared DFU partition.
            if !self.bootloader_state_erased {
                info!("erasing state partition at {:08x}", partitions.state.from);
                if let Err(e) = self.flash.erase(partitions.state.from, partitions.state.to) {
                    error!("erase: {:?} at {:08x}", e, partitions.state.from);
                    return;
                }
                self.bootloader_state_erased = true;
            }
            info!("erasing at {:08x} -> {:08x}", addr_orig, addr);
            if let Err(e) = self.flash.erase(addr, addr + 4096) {
                error!("erase: {:?} at {:08x} -> {:08x}", e, addr_orig, addr);
                return;
            }
            // From which block should we start marking as unseen if we later encounter an error
            // for this page? This is usually the current block, but if this block also contains
            // the end of the previous page, we use the next one.
            let first_block = match pos {
                Start => block_no,
                StartPartial => block_no + 1,
                Middle | End => defmt::unreachable!(),
            };
            self.flash_erased_address_and_first_block = Some((addr, first_block));
        }

        // Always `Some` here: either set just above or matched against `addr` earlier.
        let (mut next_addr, first_block) = self.flash_erased_address_and_first_block.unwrap();
        defmt::assert!(next_addr == addr);
        //FIXME We might have to handle alignment concerns for address and data.
        info!("writing to {:08x} -> {:08x}, len={}", addr_orig, addr, data.len());
        if let Err(e) = self.flash.write(addr, data) {
            error!("write: {:?} at {:08x} -> {:08x}", e, addr_orig, addr);
            // abort current page
            for i in first_block..block_no {
                self.uf2_seen_bitmask.set(i, false);
            }
            self.flash_erased_address_and_first_block = None;
            return;
        }
        next_addr += data.len() as u32;
        self.flash_erased_address_and_first_block = if next_addr % 4096 == 0 {
            None
        } else {
            Some((next_addr, first_block))
        };
        if pos != StartPartial {
            self.uf2_seen_bitmask.set(block_no, true);
        }
    }

    /// Reports which UF2 blocks still need to be sent.
    ///
    /// Returns `(0, 0)` while no relevant UF2 block has been received, and `(1, 1)` once every
    /// block of the current file has been handled. Otherwise returns `(first, end)`, with `end`
    /// exclusive: `end` is the next already-seen block after the first missing one (or
    /// `num_blocks`). `first` is one less than the first missing block, because that earlier
    /// block may hold the beginning of a page that was only partially written.
    pub fn get_missing_block_info(&mut self) -> (u32, u32) {
        if self.uf2_num_blocks == 0 {
            return (0, 0);
        }
        if self.uf2_num_blocks as usize > self.uf2_seen_bitmask.len() {
            // Should not happen: process_sector rejects such files before storing num_blocks.
            return (self.uf2_num_blocks, self.uf2_num_blocks);
        }
        match self.uf2_seen_bitmask.first_zero() {
            Some(first_missing) if first_missing < self.uf2_num_blocks as usize => {
                let (_, rest) = self.uf2_seen_bitmask.split_at(first_missing);
                let first_present = rest
                    .first_one()
                    .map(|x| (x + first_missing) as u32)
                    .unwrap_or(self.uf2_num_blocks);
                // subtract one because a partial sector might be in the previous block
                ((first_missing as u32).saturating_sub(1), first_present)
            }
            // Every block of the file has been seen.
            _ => (1, 1),
        }
    }

    /// Returns `true` once every block of the current UF2 file has been handled.
    pub fn successfully_programmed(&mut self) -> bool {
        self.get_missing_block_info() == (1, 1)
    }

    /// Marks the DFU partition as updated via embassy-boot's `mark_updated_blocking`.
    ///
    /// # Errors
    /// Returns `ServerDeviceFailure` if the bootloader state cannot be written.
    pub fn mark_updated(&mut self) -> Result<(), ModbusErrorCode> {
        self.bootloader_state_erased = false;
        let mut updater = FirmwareUpdater::default();
        // Dedicated scratch buffer: reusing `self.buf` could clobber a partially received UF2 block.
        let mut magic = AlignedBuffer([0; 1]);
        updater
            .mark_updated_blocking(&mut self.flash, &mut magic.0[..])
            .map_err(|e| {
                error!("mark_updated_blocking: {:?}", e);
                ModbusErrorCode::ServerDeviceFailure
            })
    }

    /// Marks the running firmware as booted via embassy-boot's `mark_booted_blocking`.
    ///
    /// # Errors
    /// Returns `ServerDeviceFailure` if the bootloader state cannot be written.
    pub fn mark_booted(&mut self) -> Result<(), ModbusErrorCode> {
        self.bootloader_state_erased = false;
        let mut updater = FirmwareUpdater::default();
        // Dedicated scratch buffer: reusing `self.buf` could clobber a partially received UF2 block.
        let mut magic = AlignedBuffer([0; 1]);
        updater
            .mark_booted_blocking(&mut self.flash, &mut magic.0[..])
            .map_err(|e| {
                error!("mark_booted_blocking: {:?}", e);
                ModbusErrorCode::ServerDeviceFailure
            })
    }
}