From d52cd119eabe7f44ca1b808aabfac1ee4353c265 Mon Sep 17 00:00:00 2001
From: Benjamin Koch <bbbsnowball@gmail.com>
Date: Sat, 27 May 2023 22:28:40 +0200
Subject: [PATCH] refactor Modbus RX code: use cursor class for reading from
 vector

---
 firmware/rust1/src/modbus_server.rs | 81 ++++++++++++++++++++++-------
 1 file changed, 61 insertions(+), 20 deletions(-)

diff --git a/firmware/rust1/src/modbus_server.rs b/firmware/rust1/src/modbus_server.rs
index be277c4..9d8500b 100644
--- a/firmware/rust1/src/modbus_server.rs
+++ b/firmware/rust1/src/modbus_server.rs
@@ -7,8 +7,45 @@ use embassy_rp::uart;
 
 use crate::rs485::{RS485Handler};
 
+struct Cursor<'a, E>(&'a [u8], usize, E);
+
+impl<'a, E: Clone> Cursor<'a, E> {
+    fn new(xs: &'a [u8], err: E) -> Self {
+        Cursor(xs, 0, err)
+    }
+
+    #[allow(dead_code)]
+    fn into_inner(self: Self) -> &'a [u8] {
+        self.0
+    }
+
+    fn read_u8(self: &mut Self) -> Result<u8, E> {
+        if self.1 < self.0.len() {
+            let x = self.0[self.1];
+            self.1 += 1;
+            Ok(x)
+        } else {
+            Err(self.2.clone())
+        }
+    }
+
+    fn read(self: &mut Self, buf: &mut [u8]) -> Result<(), E> {
+        for i in 0..buf.len() {
+            buf[i] = self.read_u8()?;
+        }
+        Ok(())
+    }
+
+    fn read_u16be(self: &mut Self) -> Result<u16, E> {
+        let mut buf = [0; 2];
+        self.read(&mut buf)?;
+        Ok(u16::from_be_bytes(buf))
+    }
+}
+
+
 #[repr(u8)]
-#[derive(PartialEq, Eq, Format)]
+#[derive(PartialEq, Eq, Format, Clone)]
 pub enum ModbusErrorCode {
     IllegalFunction = 1,
     IllegalDataAddress = 2,
@@ -94,6 +131,7 @@ impl<REGS: ModbusRegisters> ModbusServer<REGS> {
         use ModbusErrorCode::*;
 
         let rxbuf = &self.rxbuf;
+        let mut rx = Cursor::new(&rxbuf, ModbusErrorCode::IllegalDataValue);
         let txbuf = &mut self.txbuf;
         info!("Modbus frame: {:?}", rxbuf.as_slice());
 
@@ -115,13 +153,16 @@ impl<REGS: ModbusRegisters> ModbusServer<REGS> {
         fn push_u16be(txbuf: &mut Vec<u8, BUF_LENGTH>, x: u16) -> Result<(), ModbusErrorCode> {
             push_many(txbuf, &x.to_be_bytes())
         }
-    
+
+        let _device_addr = rx.read_u8()?;
+
         // see https://modbus.org/docs/Modbus_Application_Protocol_V1_1b3.pdf
-        match rxbuf[1] {
+        let function = rx.read_u8()?;
+        match function {
             0x02 => {
                 // read discrete inputs
-                let start = ((rxbuf[2] as u16) << 8) | rxbuf[3] as u16;
-                let quantity = ((rxbuf[4] as u16) << 8) | rxbuf[5] as u16;
+                let start = rx.read_u16be()?;
+                let quantity = rx.read_u16be()?;
                 let byte_count = (quantity+7)/8;
                 if quantity < 1 || quantity > 2000 || byte_count as usize > capacity - 5 {
                     return Err(IllegalDataValue);
@@ -145,8 +186,8 @@ impl<REGS: ModbusRegisters> ModbusServer<REGS> {
                     // we shouldn't get here
                     return Err(ServerDeviceFailure);
                 }
-                let start = ((rxbuf[2] as u16) << 8) | rxbuf[3] as u16;
-                let quantity = ((rxbuf[4] as u16) << 8) | rxbuf[5] as u16;
+                let start = rx.read_u16be()?;
+                let quantity = rx.read_u16be()?;
                 let byte_count = quantity * 2;
                 if byte_count as usize > (capacity - 5) || quantity >= 0x7d {
                     return Err(IllegalDataValue);
@@ -165,8 +206,8 @@ impl<REGS: ModbusRegisters> ModbusServer<REGS> {
                     // we shouldn't get here
                     return Err(ServerDeviceFailure);
                 }
-                let start = ((rxbuf[2] as u16) << 8) | rxbuf[3] as u16;
-                let quantity = ((rxbuf[4] as u16) << 8) | rxbuf[5] as u16;
+                let start = rx.read_u16be()?;
+                let quantity = rx.read_u16be()?;
                 if quantity as usize > (capacity - 5) / 2 || quantity >= 0x7d {
                     return Err(IllegalDataValue); // is that right?
                 }
@@ -184,8 +225,8 @@ impl<REGS: ModbusRegisters> ModbusServer<REGS> {
                     // we shouldn't get here
                     return Err(ServerDeviceFailure);
                 }
-                let start = ((rxbuf[2] as u16) << 8) | rxbuf[3] as u16;
-                let value = ((rxbuf[4] as u16) << 8) | rxbuf[5] as u16;
+                let start = rx.read_u16be()?;
+                let value = rx.read_u16be()?;
                 let value = if value == 0x0000 {
                     false
                 } else if value == 0xff00 {
@@ -205,8 +246,8 @@ impl<REGS: ModbusRegisters> ModbusServer<REGS> {
                     // we shouldn't get here
                     return Err(ServerDeviceFailure);
                 }
-                let start = ((rxbuf[2] as u16) << 8) | rxbuf[3] as u16;
-                let value = ((rxbuf[4] as u16) << 8) | rxbuf[5] as u16;
+                let start = rx.read_u16be()?;
+                let value = rx.read_u16be()?;
                 let actual_new_value = self.regs.write_register(rxbuf[0], start, value)?;
                 if should_reply {
                     push_many(txbuf, &rxbuf[0..4])?;
@@ -220,9 +261,9 @@ impl<REGS: ModbusRegisters> ModbusServer<REGS> {
                     // we shouldn't get here
                     return Err(ServerDeviceFailure);
                 }
-                let start = ((rxbuf[2] as u16) << 8) | rxbuf[3] as u16;
-                let quantity = ((rxbuf[4] as u16) << 8) | rxbuf[5] as u16;
-                let byte_count = rxbuf[6];
+                let start = rx.read_u16be()?;
+                let quantity = rx.read_u16be()?;
+                let byte_count = rx.read_u8()?;
                 let expected_byte_count = (quantity + 7) / 8;
                 if quantity < 1 || quantity > 0x07b0 || byte_count as u16 != expected_byte_count {
                     return Err(IllegalDataValue);
@@ -250,9 +291,9 @@ impl<REGS: ModbusRegisters> ModbusServer<REGS> {
                     // we shouldn't get here
                     return Err(ServerDeviceFailure);
                 }
-                let start = ((rxbuf[2] as u16) << 8) | rxbuf[3] as u16;
-                let quantity = ((rxbuf[4] as u16) << 8) | rxbuf[5] as u16;
-                let byte_count = rxbuf[6];
+                let start = rx.read_u16be()?;
+                let quantity = rx.read_u16be()?;
+                let byte_count = rx.read_u8()?;
                 let expected_byte_count = quantity * 2;
                 if quantity < 1 || quantity > 0x07b || byte_count as u16 != expected_byte_count {
                     return Err(IllegalDataValue);
@@ -262,7 +303,7 @@ impl<REGS: ModbusRegisters> ModbusServer<REGS> {
                     return Err(ServerDeviceFailure);
                 }
                 for i in 0..quantity {
-                    let value = ((rxbuf[7 + 2 * i as usize] as u16) << 8) | rxbuf[7 + 2 * i as usize + 1] as u16;
+                    let value = rx.read_u16be()?;
                     self.regs.write_register(rxbuf[0], start + i, value)?;
                 }
                 if should_reply {
-- 
GitLab