use defmt::*;
use {defmt_rtt as _, panic_probe as _};
use futures::Future;
use heapless::Vec;
use crc::{Crc, CRC_16_MODBUS, Digest};
use embassy_rp::uart;

use crate::rs485::{RS485Handler};

/// Forward-only reader over a borrowed byte slice.
///
/// Tuple fields: `.0` is the underlying data, `.1` is the index of the next unread
/// byte, and `.2` is the error value handed out (cloned) whenever a read runs past
/// the end of the data.
struct Cursor<'a, E>(&'a [u8], usize, E);

impl<'a, E: Clone> Cursor<'a, E> {
    /// Starts reading at the first byte of `xs`; `err` is what exhausted reads return.
    fn new(xs: &'a [u8], err: E) -> Self {
        Self(xs, 0, err)
    }

    /// Returns the whole underlying slice, including bytes already consumed.
    // Not called by the Modbus code yet; kept so the cursor API is complete.
    #[allow(dead_code)]
    fn into_inner(self) -> &'a [u8] {
        let Cursor(data, _, _) = self;
        data
    }

    /// Consumes and returns the next byte, or the stored error if none remain.
    fn read_u8(&mut self) -> Result<u8, E> {
        match self.0.get(self.1) {
            Some(&byte) => {
                self.1 += 1;
                Ok(byte)
            }
            None => Err(self.2.clone()),
        }
    }

    /// Fills `buf` one byte at a time.
    ///
    /// On error, the bytes that were available have already been written to `buf`
    /// and consumed from the cursor.
    fn read(&mut self, buf: &mut [u8]) -> Result<(), E> {
        for slot in buf.iter_mut() {
            *slot = self.read_u8()?;
        }
        Ok(())
    }

    /// Reads a big-endian `u16` (most significant byte first).
    fn read_u16be(&mut self) -> Result<u16, E> {
        let hi = self.read_u8()?;
        let lo = self.read_u8()?;
        Ok(u16::from_be_bytes([hi, lo]))
    }
}


/// Modbus exception codes.
///
/// The discriminant is the byte placed after the (0x80-flagged) function code in an
/// exception response, hence the explicit `u8` representation.
#[repr(u8)]
#[derive(PartialEq, Eq, Format, Clone)]
pub enum ModbusErrorCode {
    /// ILLEGAL FUNCTION (0x01).
    IllegalFunction = 0x01,
    /// ILLEGAL DATA ADDRESS (0x02).
    IllegalDataAddress = 0x02,
    /// ILLEGAL DATA VALUE (0x03).
    IllegalDataValue = 0x03,
    /// SERVER DEVICE FAILURE (0x04).
    ServerDeviceFailure = 0x04,
    /// ACKNOWLEDGE (0x05).
    Acknowledge = 0x05,
    /// SERVER DEVICE BUSY (0x06).
    ServerDeviceBusy = 0x06,
    /// MEMORY PARITY ERROR (0x07).
    MemoryParityError = 0x07,
    /// GATEWAY PATH UNAVAILABLE (0x0A).
    GatewayPathUnavailable = 0x0A,
    /// GATEWAY TARGET DEVICE FAILED TO RESPOND (0x0B).
    GatewayTargetDeviceFailedToRespond = 0x0B,
}

/// How the server should treat a request, based on the device address in its first byte.
pub enum ModbusAdressMatch {
    /// Not addressed to us: the frame is ignored entirely.
    NotOurAddress,
    /// Addressed to us: execute the request and send a response.
    OurAddress,
    /// Broadcast: execute the request but never send a response.
    BroadcastNoReply,
}

/// Register map backing a [`ModbusServer`].
///
/// Every accessor receives the device address from the request, so one implementation
/// can serve several addresses. Returning `Err(code)` makes the server answer with a
/// Modbus exception response carrying `code` (unless the request was a broadcast).
pub trait ModbusRegisters {
    /// Decides whether a request for `device_addr` is ours, a broadcast, or foreign.
    fn is_address_match(&mut self, device_addr: u8) -> ModbusAdressMatch;

    /// Reads one discrete input (function 0x02).
    fn read_discrete_input(&mut self, device_addr: u8, addr: u16) -> Result<bool, ModbusErrorCode>;

    /// Reads one holding register (function 0x03).
    fn read_holding_register(&mut self, device_addr: u8, addr: u16) -> Result<u16, ModbusErrorCode>;

    /// Reads one input register (function 0x04).
    fn read_input_register(&mut self, device_addr: u8, addr: u16) -> Result<u16, ModbusErrorCode>;

    /// Writes one coil (functions 0x05 and 0x0F).
    fn write_coil(&mut self, device_addr: u8, addr: u16, value: bool) -> Result<(), ModbusErrorCode>;

    /// Writes one register (functions 0x06 and 0x10).
    ///
    /// Returns the value actually stored. Function 0x06 echoes it back in its response.
    fn write_register(&mut self, device_addr: u8, addr: u16, value: u16) -> Result<u16, ModbusErrorCode>;
}

/// Receive-side framing state: what is known about the length of the frame in progress.
#[derive(PartialEq, Eq, Format)]
enum ModbusFrameLength {
    /// The length can be determined once this many bytes (in total) have arrived.
    NeedMoreData(u16),
    /// The frame is complete after this many bytes, trailing 2-byte CRC included.
    Length(u16),
    /// The function code is not recognised, so the length cannot be computed.
    Unknown,
}

// FIXME: This only understands *request* layouts; it won't work for response frames!
/// Determines the total length of the Modbus RTU request frame that `rxbuf` starts.
///
/// `rxbuf` holds the bytes received so far: address, function code, and so on.
///
/// Returns:
/// - `NeedMoreData(n)` when the length can only be determined after `n` bytes have arrived.
/// - `Length(n)` with the total frame length, including the 2-byte CRC.
/// - `Unknown` for function codes whose request layout is not handled here.
fn get_modbus_frame_length(rxbuf: &[u8]) -> ModbusFrameLength {
    use ModbusFrameLength::*;

    if rxbuf.len() < 3 {
        return NeedMoreData(3);
    }

    match rxbuf[1] {
        // Fixed layout: address, function, two 16-bit fields, CRC.
        0x01..=0x06 => Length(8),
        // Write multiple coils/registers: address, function, start (2), quantity (2),
        // byte count (1) at offset 6, `byte count` data bytes, CRC (2).
        // Any buffer that already contains offset 6 is enough (not just exactly 7 bytes).
        0x0f | 0x10 => match rxbuf.get(6) {
            Some(&byte_count) => Length(9 + u16::from(byte_count)),
            None => NeedMoreData(7),
        },
        _ => Unknown,
    }
}

/// CRC-16/MODBUS engine used both to check requests and to sign responses.
const CRC: Crc<u16> = Crc::<u16>::new(&CRC_16_MODBUS);
/// Capacity in bytes of the receive and the transmit buffer.
const BUF_LENGTH: usize = 256;

/// Modbus RTU server that frames bytes from an RS-485 link and answers from `REGS`.
pub struct ModbusServer<REGS: ModbusRegisters> {
    /// Bytes of the frame currently being received. Bytes beyond capacity are dropped,
    /// though they are still counted and included in the CRC.
    rxbuf: Vec<u8, BUF_LENGTH>,
    /// Running CRC over every received byte of the current frame, including its own CRC.
    rxcrc: Digest<'static, u16>,
    /// What is currently known about the length of the frame being received.
    rx_expected_bytes: ModbusFrameLength,
    /// Number of bytes received for the current frame.
    rx_received_bytes: u16,
    /// Response being built for the last complete request.
    txbuf: Vec<u8, BUF_LENGTH>,
    /// Application register map.
    regs: REGS,
}

impl<REGS: ModbusRegisters> ModbusServer<REGS> {
    /// Creates a server with empty buffers that serves requests from `regs`.
    pub fn new(regs: REGS) -> ModbusServer<REGS> {
        ModbusServer {
            rxbuf: Vec::new(),
            rxcrc: CRC.digest(),
            rx_expected_bytes: ModbusFrameLength::NeedMoreData(3),
            rx_received_bytes: 0,
            txbuf: Vec::new(),
            regs,
        }
    }

    /// Replaces `txbuf` with an exception response (without CRC) for the request in `rxbuf`.
    ///
    /// The response is the request's address, its function code with bit 7 set, and `code`.
    /// Panics if `rxbuf` holds fewer than two bytes.
    fn modbus_reply_error(&mut self, code: ModbusErrorCode) {
        self.txbuf.clear();
        self.txbuf.push(self.rxbuf[0]).unwrap();
        self.txbuf.push(self.rxbuf[1] | 0x80).unwrap();
        self.txbuf.push(code as u8).unwrap();
    }

    /// Decodes the complete request in `rxbuf`, applies it to `regs`, and writes the
    /// response (address, function, data; no CRC) into `txbuf`.
    ///
    /// Read functions always fill `txbuf`. Write functions only do so when `should_reply`.
    ///
    /// # Errors
    /// Returns the Modbus exception code to answer with:
    /// - `IllegalFunction` for unsupported function codes.
    /// - `IllegalDataValue` for truncated requests or out-of-range quantities/values.
    /// - `IllegalDataAddress` when `start + quantity` exceeds the 16-bit address space.
    /// - `ServerDeviceFailure` for inconsistent frame lengths or a full response buffer.
    /// - Any error returned by the register map.
    fn handle_modbus_frame2(&mut self, should_reply: bool) -> Result<(), ModbusErrorCode> {
        use ModbusErrorCode::*;

        /// Appends one byte; a full buffer is reported as `ServerDeviceFailure`.
        fn push(txbuf: &mut Vec<u8, BUF_LENGTH>, x: u8) -> Result<(), ModbusErrorCode> {
            txbuf.push(x).map_err(|_| ModbusErrorCode::ServerDeviceFailure)
        }
        /// Appends all of `xs`.
        fn push_many(txbuf: &mut Vec<u8, BUF_LENGTH>, xs: &[u8]) -> Result<(), ModbusErrorCode> {
            for &x in xs {
                push(txbuf, x)?;
            }
            Ok(())
        }
        /// Appends `x` most significant byte first.
        fn push_u16be(txbuf: &mut Vec<u8, BUF_LENGTH>, x: u16) -> Result<(), ModbusErrorCode> {
            push_many(txbuf, &x.to_be_bytes())
        }
        /// Rejects ranges `start .. start + quantity` that do not fit in 0..=0xFFFF. After
        /// this check, `start + i` for any `i < quantity` cannot overflow.
        fn check_address_range(start: u16, quantity: u16) -> Result<(), ModbusErrorCode> {
            if u32::from(start) + u32::from(quantity) > 0x1_0000 {
                Err(ModbusErrorCode::IllegalDataAddress)
            } else {
                Ok(())
            }
        }

        let rxbuf = &self.rxbuf;
        let mut rx = Cursor::new(rxbuf.as_slice(), IllegalDataValue);
        let txbuf = &mut self.txbuf;
        info!("Modbus frame: {:?}", rxbuf.as_slice());

        txbuf.clear();
        let capacity = txbuf.capacity();

        let device_addr = rx.read_u8()?;

        // see https://modbus.org/docs/Modbus_Application_Protocol_V1_1b3.pdf
        let function = rx.read_u8()?;
        match function {
            0x02 => {
                // read discrete inputs
                if rxbuf.len() != 8 {
                    // we shouldn't get here
                    return Err(ServerDeviceFailure);
                }
                let start = rx.read_u16be()?;
                let quantity = rx.read_u16be()?;
                // Validate before any arithmetic on `quantity` so hostile values cannot overflow.
                if quantity < 1 || quantity > 2000 {
                    return Err(IllegalDataValue);
                }
                let byte_count = (quantity + 7) / 8;
                if byte_count as usize > capacity - 5 {
                    return Err(IllegalDataValue);
                }
                check_address_range(start, quantity)?;
                push_many(txbuf, &rxbuf[0..2])?;
                push(txbuf, byte_count as u8)?;
                for i in 0..byte_count {
                    // Input `i*8 + j` goes into bit `j` (LSB first); all 8 bits are used.
                    let mut packed: u8 = 0;
                    for j in 0..8u16 {
                        let bit = i * 8 + j;
                        if bit < quantity && self.regs.read_discrete_input(device_addr, start + bit)? {
                            packed |= 1u8 << j;
                        }
                    }
                    push(txbuf, packed)?;
                }
                Ok(())
            },
            0x03 | 0x04 => {
                // read holding registers (0x03) / read input registers (0x04)
                if rxbuf.len() != 8 {
                    // we shouldn't get here
                    return Err(ServerDeviceFailure);
                }
                let start = rx.read_u16be()?;
                let quantity = rx.read_u16be()?;
                // The spec allows 1..=0x7D registers per request.
                if quantity < 1 || quantity > 0x7d {
                    return Err(IllegalDataValue);
                }
                let byte_count = quantity * 2;
                if byte_count as usize > capacity - 5 {
                    return Err(IllegalDataValue);
                }
                check_address_range(start, quantity)?;
                push_many(txbuf, &rxbuf[0..2])?;
                push(txbuf, byte_count as u8)?;
                for i in 0..quantity {
                    let addr = start + i;
                    let value = if function == 0x03 {
                        self.regs.read_holding_register(device_addr, addr)?
                    } else {
                        self.regs.read_input_register(device_addr, addr)?
                    };
                    push_u16be(txbuf, value)?;
                }
                Ok(())
            },
            0x05 => {
                // write single coil
                if rxbuf.len() != 8 {
                    // we shouldn't get here
                    return Err(ServerDeviceFailure);
                }
                let start = rx.read_u16be()?;
                let value = match rx.read_u16be()? {
                    0x0000 => false,
                    0xff00 => true,
                    _ => return Err(IllegalDataValue),
                };
                self.regs.write_coil(device_addr, start, value)?;
                if should_reply {
                    // The response echoes address, function, coil address and value.
                    push_many(txbuf, &rxbuf[0..6])?;
                }
                Ok(())
            },
            0x06 => {
                // write register
                if rxbuf.len() != 8 {
                    // we shouldn't get here
                    return Err(ServerDeviceFailure);
                }
                let start = rx.read_u16be()?;
                let value = rx.read_u16be()?;
                let actual_new_value = self.regs.write_register(device_addr, start, value)?;
                if should_reply {
                    push_many(txbuf, &rxbuf[0..4])?;
                    push_u16be(txbuf, actual_new_value)?;
                }
                Ok(())
            },
            0x0F => {
                // write multiple coils
                if rxbuf.len() < 9 {
                    // we shouldn't get here
                    return Err(ServerDeviceFailure);
                }
                let start = rx.read_u16be()?;
                let quantity = rx.read_u16be()?;
                let byte_count = rx.read_u8()?;
                // `||` short-circuits, so `(quantity + 7)` is only computed for quantity <= 0x07b0.
                if quantity < 1 || quantity > 0x07b0 || u16::from(byte_count) != (quantity + 7) / 8 {
                    return Err(IllegalDataValue);
                }
                if rxbuf.len() != 9 + byte_count as usize {
                    // we shouldn't get here
                    return Err(ServerDeviceFailure);
                }
                check_address_range(start, quantity)?;
                // Coil data starts after address, function, start, quantity and byte count.
                let data = &rxbuf[7..7 + byte_count as usize];
                for (i, &packed) in (0u16..).zip(data) {
                    for j in 0..8u16 {
                        let bit = i * 8 + j;
                        if bit < quantity {
                            let value = ((packed >> j) & 1) != 0;
                            self.regs.write_coil(device_addr, start + bit, value)?;
                        }
                    }
                }
                if should_reply {
                    push_many(txbuf, &rxbuf[0..6])?;
                }
                Ok(())
            },
            0x10 => {
                // write multiple registers
                if rxbuf.len() < 9 {
                    // we shouldn't get here
                    return Err(ServerDeviceFailure);
                }
                let start = rx.read_u16be()?;
                let quantity = rx.read_u16be()?;
                let byte_count = rx.read_u8()?;
                // `||` short-circuits, so `quantity * 2` is only computed for quantity <= 0x7b.
                if quantity < 1 || quantity > 0x7b || u16::from(byte_count) != quantity * 2 {
                    return Err(IllegalDataValue);
                }
                if rxbuf.len() != 9 + byte_count as usize {
                    // we shouldn't get here
                    return Err(ServerDeviceFailure);
                }
                check_address_range(start, quantity)?;
                for i in 0..quantity {
                    let value = rx.read_u16be()?;
                    self.regs.write_register(device_addr, start + i, value)?;
                }
                if should_reply {
                    push_many(txbuf, &rxbuf[0..6])?;
                }
                Ok(())
            },
            _ => Err(IllegalFunction),
        }
    }

    /// Processes the complete request in `rxbuf`.
    ///
    /// Afterwards, `txbuf` holds either the full RTU response (CRC included) or nothing:
    /// - `NotOurAddress` frames are ignored.
    /// - `BroadcastNoReply` frames are executed but never answered.
    fn handle_modbus_frame(&mut self) {
        let should_reply = match self.regs.is_address_match(self.rxbuf[0]) {
            ModbusAdressMatch::NotOurAddress => return,
            ModbusAdressMatch::OurAddress => true,
            ModbusAdressMatch::BroadcastNoReply => false,
        };

        let result = self.handle_modbus_frame2(should_reply);

        if !should_reply {
            if let Err(code) = result {
                warn!("Error result when processing Modbus frame but we won't reply because it is a broadcast: {:?}", code);
            }
            // Read handlers fill `txbuf` even for broadcasts. Drop it so no (CRC-less)
            // reply goes out for a broadcast.
            self.txbuf.clear();
            return;
        }

        match result {
            Ok(()) => {
                if self.txbuf.capacity() - self.txbuf.len() < 2 {
                    // We don't have enough space for the CRC so reply with error instead.
                    self.modbus_reply_error(ModbusErrorCode::ServerDeviceFailure);
                }
            },
            Err(code) => self.modbus_reply_error(code),
        }

        // The RTU CRC is appended low byte first.
        let [lo, hi] = CRC.checksum(self.txbuf.as_slice()).to_le_bytes();
        self.txbuf.push(lo).unwrap();
        self.txbuf.push(hi).unwrap();
    }
}

impl<REGS: ModbusRegisters> RS485Handler for ModbusServer<REGS> {
    //type CommandFuture = !;
    const TX_BUF_LENGTH: usize = BUF_LENGTH;

    fn on_rx<F>(self: &mut Self, rx: Result<u8, uart::Error>, reply: Option<F>)
        where F: FnOnce(&[u8]) {
        match rx {
            Ok(rx_char) => {
                //info!("RX {:?}", rx_char);

                self.rxcrc.update(&[rx_char]);
                if !self.rxbuf.is_full() {
                    self.rxbuf.push(rx_char).unwrap_or_default();
                }
                self.rx_received_bytes += 1;

                if let ModbusFrameLength::NeedMoreData(x) = self.rx_expected_bytes {
                    if x == self.rx_received_bytes {
                        self.rx_expected_bytes = get_modbus_frame_length(self.rxbuf.as_slice());
                        match self.rx_expected_bytes {
                            ModbusFrameLength::Unknown => {
                                //FIXME Wait for pause.
                            },
                            _ => {}
                        }
                    }
                }
                if let ModbusFrameLength::Length(x) = self.rx_expected_bytes {
                    if x == self.rx_received_bytes {
                        let calculated_crc = self.rxcrc.clone().finalize();
                        const CORRECT_CRC: u16 = 0;  // because we include the CRC bytes in our calculation
                        if calculated_crc != CORRECT_CRC {
                            info!("CRC: {:04x} (should be zero)", calculated_crc);
                        }

                        //FIXME In case of CRC mismatch, wait for gap/idle of >=1.5 chars.
                        const OUR_ADDRESS: u8 = 1;
                        if self.rxbuf[0] == OUR_ADDRESS && calculated_crc == CORRECT_CRC {
                            self.txbuf.clear();
                            self.handle_modbus_frame();

                            if !self.txbuf.is_empty() {
                                info!("Modbus reply: {:?}", self.txbuf);
                                match reply {
                                    Option::Some(reply) => reply(self.txbuf.as_slice()),
                                    Option::None => warn!("Cannot send reply because a reply is already in progress!"),
                                }
                            }
                        }

                        self.rxbuf.clear();
                        self.rxcrc = CRC.digest();
                        self.rx_expected_bytes = ModbusFrameLength::NeedMoreData(3);
                        self.rx_received_bytes = 0;
                    }
                }
            },
            Err(e) => {
                info!("RX error {:?}", e);
        
                //FIXME wait for gap/idle of >=1.5 chars.
                self.rxbuf.clear();
                self.rxcrc = CRC.digest();
                self.rx_expected_bytes = ModbusFrameLength::NeedMoreData(3);
                self.rx_received_bytes = 0;
            }
        }
    }

    fn on_idle(self: &mut Self) {
        if !self.rxbuf.is_empty() {
            warn!("Partial frame in rx buffer, cut short by inter-byte gap: {:?}, {:?}", self.rx_expected_bytes, self.rxbuf);
        }

        self.rxbuf.clear();
        self.rxcrc = CRC.digest();
        self.rx_expected_bytes = ModbusFrameLength::NeedMoreData(3);
        self.rx_received_bytes = 0;
    }

    fn on_tx_done(self: &mut Self) {
        //TODO
    }

    fn on_autobaud_success(self: &mut Self, baudrate: f32) {
        //TODO
        info!("Guessed baud rate: {}", baudrate);
    }
}