]> code.octet-stream.net Git - m17rt/blobdiff - m17app/src/app.rs
Error handler for soundmodem components
[m17rt] / m17app / src / app.rs
index 81304559ddf6db33174da3fd74570ca046106547..b938a2cddb27e6593f1f21c326b04642f68e78bd 100644 (file)
@@ -1,16 +1,27 @@
 use crate::adapter::{PacketAdapter, StreamAdapter};
+use crate::error::{M17Error, M17Errors};
+use crate::link_setup::LinkSetup;
 use crate::tnc::Tnc;
+use crate::{LsfFrame, PacketType, StreamFrame};
 use m17core::kiss::{KissBuffer, KissCommand, KissFrame};
-use m17core::protocol::{EncryptionType, LsfFrame, PacketType};
+use m17core::protocol::EncryptionType;
 
 use log::debug;
 use std::collections::HashMap;
 use std::sync::mpsc;
 use std::sync::{Arc, RwLock};
 
+/// Stage in the life of an [`M17App`].
+///
+/// `M17App::start` moves the app from `Setup` to `Started`, and `M17App::close` moves it
+/// from `Started` to `Closed`; each refuses to run from any other stage.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+enum Lifecycle {
+    /// Initial stage: adapters can be registered but have not been started.
+    Setup,
+    /// `start()` has run; adapters added from now on are started immediately.
+    Started,
+    /// `close()` has run.
+    Closed,
+}
+
+/// Application handle that owns the adapter registry and forwards transmissions and
+/// start/close requests to the TNC writer thread.
 pub struct M17App {
+    /// Registered packet and stream adapters, keyed by the id handed out when they were added.
+    /// NOTE(review): presumably the same `Arc` given to `spawn_reader` (set up in `new`) — confirm.
     adapters: Arc<RwLock<Adapters>>,
+    /// Channel to the TNC writer thread (see `TncControlEvent`).
     event_tx: mpsc::SyncSender<TncControlEvent>,
+    /// Current lifecycle stage; decides whether adapters are started/closed as they are
+    /// added/removed, and which of `start()`/`close()` is currently allowed.
+    lifecycle: RwLock<Lifecycle>,
 }
 
 impl M17App {
@@ -23,49 +34,62 @@ impl M17App {
         Self {
             adapters: listeners,
             event_tx,
+            lifecycle: RwLock::new(Lifecycle::Setup),
         }
     }
 
-    pub fn add_packet_adapter<P: PacketAdapter + 'static>(&self, adapter: P) -> usize {
+    /// Register a packet adapter and return the id assigned to it.
+    ///
+    /// If the app is already started, the adapter's `start` is called straight away with a
+    /// fresh [`TxHandle`]. While the app is still in setup the adapter is only registered and
+    /// will be started by [`M17App::start`]; after `close()` it is only registered.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`M17Error::Adapter`] if the adapter's `start` fails. The adapter stays
+    /// registered in that case, and the error carries its id so it can be removed.
+    pub fn add_packet_adapter<P: PacketAdapter + 'static>(
+        &self,
+        adapter: P,
+    ) -> Result<usize, M17Error> {
         let adapter = Arc::new(adapter);
         let mut adapters = self.adapters.write().unwrap();
+        // Ids come from one counter shared by packet and stream adapters.
         let id = adapters.next;
         adapters.next += 1;
         adapters.packet.insert(id, adapter.clone());
+        // Release the registry lock before calling into adapter code below.
         drop(adapters);
-        adapter.adapter_registered(id, self.tx());
-        id
+        if self.lifecycle() == Lifecycle::Started {
+            adapter
+                .start(self.tx())
+                .map_err(|e| M17Error::Adapter(id, e))?;
+        }
+        Ok(id)
     }
 
-    pub fn add_stream_adapter<S: StreamAdapter + 'static>(&self, adapter: S) -> usize {
+    /// Register a stream adapter and return the id assigned to it.
+    ///
+    /// If the app is already started, the adapter's `start` is called straight away with a
+    /// fresh [`TxHandle`]. While the app is still in setup the adapter is only registered and
+    /// will be started by [`M17App::start`]; after `close()` it is only registered.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`M17Error::Adapter`] if the adapter's `start` fails. The adapter stays
+    /// registered in that case, and the error carries its id so it can be removed.
+    pub fn add_stream_adapter<S: StreamAdapter + 'static>(
+        &self,
+        adapter: S,
+    ) -> Result<usize, M17Error> {
         let adapter = Arc::new(adapter);
         let mut adapters = self.adapters.write().unwrap();
+        // Ids come from one counter shared by packet and stream adapters.
         let id = adapters.next;
         adapters.next += 1;
         adapters.stream.insert(id, adapter.clone());
+        // Release the registry lock before calling into adapter code below.
         drop(adapters);
-        adapter.adapter_registered(id, self.tx());
-        id
+        if self.lifecycle() == Lifecycle::Started {
+            adapter
+                .start(self.tx())
+                .map_err(|e| M17Error::Adapter(id, e))?;
+        }
+        Ok(id)
     }
 
-    pub fn remove_packet_adapter(&self, id: usize) {
+    /// Unregister the packet adapter with the given `id`.
+    ///
+    /// If the app is currently started, the adapter's `close` is called after it has been
+    /// removed from the registry. An unknown `id` is ignored.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`M17Error::Adapter`] if the adapter's `close` fails.
+    pub fn remove_packet_adapter(&self, id: usize) -> Result<(), M17Error> {
-        if let Some(a) = self.adapters.write().unwrap().packet.remove(&id) {
-            a.adapter_removed();
+        // Take the adapter out in its own statement so the write guard is dropped right here.
+        // A guard created in an `if let` scrutinee lives until the end of the whole `if let`,
+        // which would run `close()` while still holding the registry lock and stall anything
+        // else that locks it (e.g. the reader thread, add/remove calls). The add_* methods
+        // likewise drop the lock before calling adapter code.
+        let removed = self.adapters.write().unwrap().packet.remove(&id);
+        if let Some(a) = removed {
+            if self.lifecycle() == Lifecycle::Started {
+                a.close().map_err(|e| M17Error::Adapter(id, e))?;
+            }
         }
+        Ok(())
     }
 
-    pub fn remove_stream_adapter(&self, id: usize) {
+    /// Unregister the stream adapter with the given `id`.
+    ///
+    /// If the app is currently started, the adapter's `close` is called after it has been
+    /// removed from the registry. An unknown `id` is ignored.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`M17Error::Adapter`] if the adapter's `close` fails.
+    pub fn remove_stream_adapter(&self, id: usize) -> Result<(), M17Error> {
-        if let Some(a) = self.adapters.write().unwrap().stream.remove(&id) {
-            a.adapter_removed();
+        // Take the adapter out in its own statement so the write guard is dropped right here.
+        // A guard created in an `if let` scrutinee lives until the end of the whole `if let`,
+        // which would run `close()` while still holding the registry lock and stall anything
+        // else that locks it (e.g. the reader thread, add/remove calls). The add_* methods
+        // likewise drop the lock before calling adapter code.
+        let removed = self.adapters.write().unwrap().stream.remove(&id);
+        if let Some(a) = removed {
+            if self.lifecycle() == Lifecycle::Started {
+                a.close().map_err(|e| M17Error::Adapter(id, e))?;
+            }
         }
-    }
-
-    pub fn transmit_packet(&self, packet_type: PacketType, payload: &[u8]) {
-        // hang on where do we get the LSF details from? We need a destination obviously
-        // our source address needs to be configured here too
-        // also there is possible CAN, encryption, meta payload
-
-        // we will immediately convert this into a KISS payload before sending into channel so we only need borrow on data
+        Ok(())
     }
 
     /// Create a handle that can be used to transmit data on the TNC
@@ -75,12 +99,68 @@ impl M17App {
         }
     }
 
-    pub fn start(&self) {
+    /// Start the app: start every registered adapter, then ask the TNC to start.
+    ///
+    /// Only valid once, while the app is still in setup. One adapter failing does not
+    /// prevent the remaining adapters from being started, and the TNC start request is
+    /// sent either way.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`M17Error::InvalidStart`] as the sole error if the app has already been
+    /// started or closed; otherwise one [`M17Error::Adapter`] per adapter whose `start` failed.
+    pub fn start(&self) -> Result<(), M17Errors> {
+        if self.lifecycle() != Lifecycle::Setup {
+            return Err(M17Errors(vec![M17Error::InvalidStart]));
+        }
+        self.set_lifecycle(Lifecycle::Started);
+
+        // Start packet adapters, then stream adapters, under the registry read lock,
+        // recording every failure against its adapter id.
+        let registry = self.adapters.read().unwrap();
+        let packet_failures = registry
+            .packet
+            .iter()
+            .filter_map(|(id, adapter)| adapter.start(self.tx()).err().map(|e| (*id, e)));
+        let stream_failures = registry
+            .stream
+            .iter()
+            .filter_map(|(id, adapter)| adapter.start(self.tx()).err().map(|e| (*id, e)));
+        let failures: Vec<M17Error> = packet_failures
+            .chain(stream_failures)
+            .map(|(id, e)| M17Error::Adapter(id, e))
+            .collect();
+        drop(registry);
+
         let _ = self.event_tx.send(TncControlEvent::Start);
+        match failures.len() {
+            0 => Ok(()),
+            _ => Err(M17Errors(failures)),
+        }
     }
 
-    pub fn close(&self) {
+    /// Close the app: close every registered adapter, then ask the TNC to close.
+    ///
+    /// Only valid while the app is started. One adapter failing does not prevent the
+    /// remaining adapters from being closed, and the TNC close request is sent either way.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`M17Error::InvalidClose`] as the sole error if the app is not currently
+    /// started; otherwise one [`M17Error::Adapter`] per adapter whose `close` failed.
+    pub fn close(&self) -> Result<(), M17Errors> {
+        if self.lifecycle() != Lifecycle::Started {
+            return Err(M17Errors(vec![M17Error::InvalidClose]));
+        }
+        self.set_lifecycle(Lifecycle::Closed);
+
+        // Close packet adapters, then stream adapters, under the registry read lock,
+        // recording every failure against its adapter id.
+        let registry = self.adapters.read().unwrap();
+        let packet_failures = registry
+            .packet
+            .iter()
+            .filter_map(|(id, adapter)| adapter.close().err().map(|e| (*id, e)));
+        let stream_failures = registry
+            .stream
+            .iter()
+            .filter_map(|(id, adapter)| adapter.close().err().map(|e| (*id, e)));
+        let failures: Vec<M17Error> = packet_failures
+            .chain(stream_failures)
+            .map(|(id, e)| M17Error::Adapter(id, e))
+            .collect();
+        drop(registry);
+
+        // TODO: blocking function to indicate TNC has finished closing
+        // then we could call this in a signal handler to ensure PTT is dropped before quit
         let _ = self.event_tx.send(TncControlEvent::Close);
+        match failures.len() {
+            0 => Ok(()),
+            _ => Err(M17Errors(failures)),
+        }
+    }
+
+    /// Read the current lifecycle stage.
+    fn lifecycle(&self) -> Lifecycle {
+        let stage = self.lifecycle.read().unwrap();
+        *stage
+    }
+
+    /// Overwrite the current lifecycle stage.
+    fn set_lifecycle(&self, lifecycle: Lifecycle) {
+        let mut stage = self.lifecycle.write().unwrap();
+        *stage = lifecycle;
     }
 }
 
@@ -89,13 +169,40 @@ pub struct TxHandle {
 }
 
 impl TxHandle {
-    // add more methods here for stream outgoing
+    pub fn transmit_packet(
+        &self,
+        link_setup: &LinkSetup,
+        packet_type: PacketType,
+        payload: &[u8],
+    ) -> Result<(), M17Error> {
+        let (pack_type, pack_type_len) = packet_type.as_proto();
+        if pack_type_len + payload.len() > 823 {
+            return Err(M17Error::PacketTooLarge {
+                provided: payload.len(),
+                capacity: 823 - pack_type_len,
+            });
+        }
+        let mut full_payload = vec![];
+        full_payload.extend_from_slice(&pack_type[0..pack_type_len]);
+        full_payload.extend_from_slice(payload);
+        let crc = m17core::crc::m17_crc(&full_payload);
+        full_payload.extend_from_slice(&crc.to_be_bytes());
+        let kiss_frame = KissFrame::new_full_packet(&link_setup.raw.0, &full_payload).unwrap();
+        let _ = self.event_tx.send(TncControlEvent::Kiss(kiss_frame));
+        Ok(())
+    }
 
-    pub fn transmit_stream_start(&self /* lsf?, payload? what needs to be configured ?! */) {}
+    pub fn transmit_stream_start(&self, link_setup: &LinkSetup) {
+        let kiss_frame = KissFrame::new_stream_setup(&link_setup.raw.0).unwrap();
+        let _ = self.event_tx.send(TncControlEvent::Kiss(kiss_frame));
+    }
 
     // as long as there is only one TNC it is implied there is only ever one stream transmission in flight
 
-    pub fn transmit_stream_next(&self, /* next payload,  */ end_of_stream: bool) {}
+    pub fn transmit_stream_next(&self, stream: &StreamFrame) {
+        let kiss_frame = KissFrame::new_stream_data(stream).unwrap();
+        let _ = self.event_tx.send(TncControlEvent::Kiss(kiss_frame));
+    }
 }
 
 /// Synchronised structure for listeners subscribing to packets and streams.
@@ -119,19 +226,20 @@ impl Adapters {
 }
 
 /// Carries a request from a method on M17App to the TNC's writer thread, which will execute it.
+// NOTE(review): `Kiss(KissFrame)` is presumably much larger than the unit variants, which is
+// what `clippy::large_enum_variant` reports; it appears to be left unboxed on purpose
+// (likely to avoid a heap allocation per outgoing frame) — confirm.
+#[allow(clippy::large_enum_variant)]
 enum TncControlEvent {
+    /// Write this KISS frame to the TNC.
     Kiss(KissFrame),
+    /// Call the TNC's `start()`.
     Start,
+    /// Call the TNC's `close()`.
     Close,
 }
 
-fn spawn_reader<T: Tnc + Send + 'static>(mut tnc: T, adapters: Arc<RwLock<Adapters>>) {
+fn spawn_reader<T: Tnc>(mut tnc: T, adapters: Arc<RwLock<Adapters>>) {
     std::thread::spawn(move || {
         let mut kiss_buffer = KissBuffer::new();
         let mut stream_running = false;
         loop {
-            let mut buf = kiss_buffer.buf_remaining();
-            let n = match tnc.read(&mut buf) {
+            let buf = kiss_buffer.buf_remaining();
+            let n = match tnc.read(buf) {
                 Ok(n) => n,
                 Err(_) => break,
             };
@@ -156,7 +264,7 @@ fn spawn_reader<T: Tnc + Send + 'static>(mut tnc: T, adapters: Arc<RwLock<Adapte
                             continue;
                         }
                         let lsf = LsfFrame(payload[0..30].try_into().unwrap());
-                        if lsf.crc() != 0 {
+                        if lsf.check_crc() != 0 {
                             debug!("LSF in full packet frame did not pass CRC");
                             continue;
                         }
@@ -185,8 +293,8 @@ fn spawn_reader<T: Tnc + Send + 'static>(mut tnc: T, adapters: Arc<RwLock<Adapte
                             adapters.read().unwrap().packet.values().cloned().collect();
                         for s in subs {
                             s.packet_received(
-                                lsf.clone(),
-                                packet_type.clone(),
+                                LinkSetup::new_raw(lsf.clone()),
+                                packet_type,
                                 packet_payload.clone(),
                             );
                         }
@@ -199,7 +307,7 @@ fn spawn_reader<T: Tnc + Send + 'static>(mut tnc: T, adapters: Arc<RwLock<Adapte
                         };
                         if n == 30 {
                             let lsf = LsfFrame(payload[0..30].try_into().unwrap());
-                            if lsf.crc() != 0 {
+                            if lsf.check_crc() != 0 {
                                 debug!("initial LSF in stream did not pass CRC");
                                 continue;
                             }
@@ -207,7 +315,7 @@ fn spawn_reader<T: Tnc + Send + 'static>(mut tnc: T, adapters: Arc<RwLock<Adapte
                             let subs: Vec<_> =
                                 adapters.read().unwrap().stream.values().cloned().collect();
                             for s in subs {
-                                s.stream_began(lsf.clone());
+                                s.stream_began(LinkSetup::new_raw(lsf.clone()));
                             }
                         } else if n == 26 {
                             if !stream_running {
@@ -241,29 +349,102 @@ fn spawn_reader<T: Tnc + Send + 'static>(mut tnc: T, adapters: Arc<RwLock<Adapte
     });
 }
 
-fn spawn_writer<T: Tnc + Send + 'static>(mut tnc: T, event_rx: mpsc::Receiver<TncControlEvent>) {
+/// Spawn the TNC writer thread, which executes each [`TncControlEvent`] received on
+/// `event_rx` against `tnc`, in order.
+///
+/// The thread ends when every sender has been dropped, or (after logging the error) when a
+/// KISS frame cannot be written to the TNC.
+fn spawn_writer<T: Tnc>(mut tnc: T, event_rx: mpsc::Receiver<TncControlEvent>) {
     std::thread::spawn(move || {
         while let Ok(ev) = event_rx.recv() {
             match ev {
                 TncControlEvent::Kiss(k) => {
-                    if let Err(e) = tnc.write_all(&k.as_bytes()) {
-                        debug!("kiss send err: {:?}", e);
+                    if let Err(e) = tnc.write_all(k.as_bytes()) {
+                        // Record the cause: once this thread returns, every later send on the
+                        // control channel fails and callers ignore that failure.
+                        debug!("kiss send err: {:?}", e);
                         return;
                     }
                 }
                 TncControlEvent::Start => {
-                    if let Err(e) = tnc.start() {
-                        debug!("tnc start err: {:?}", e);
-                        return;
-                    }
+                    tnc.start();
                 }
                 TncControlEvent::Close => {
-                    if let Err(e) = tnc.close() {
-                        debug!("tnc close err: {:?}", e);
-                        return;
-                    }
+                    tnc.close();
                 }
             }
         }
     });
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::error::AdapterError;
+    use crate::{link_setup::M17Address, test_util::NullTnc};
+
+    use super::*;
+
+    /// Broadcast-to-broadcast packet link setup shared by the packet tests.
+    fn broadcast_link_setup() -> LinkSetup {
+        LinkSetup::new_packet(&M17Address::new_broadcast(), &M17Address::new_broadcast())
+    }
+
+    /// Payloads are accepted up to the capacity left by the packet type and rejected beyond it.
+    #[test]
+    fn packet_payload_len() {
+        let app = M17App::new(NullTnc);
+        let link_setup = broadcast_link_setup();
+        let send =
+            |len: usize| app.tx().transmit_packet(&link_setup, PacketType::Raw, &vec![0u8; len]);
+
+        assert!(matches!(send(100), Ok(())));
+        // A raw packet type leaves exactly 822 bytes for the payload.
+        assert!(matches!(send(822), Ok(())));
+        assert!(matches!(
+            send(823),
+            Err(M17Error::PacketTooLarge {
+                provided: 823,
+                capacity: 822
+            })
+        ));
+        assert!(matches!(
+            send(900),
+            Err(M17Error::PacketTooLarge {
+                provided: 900,
+                capacity: 822
+            })
+        ));
+    }
+
+    #[test]
+    fn adapter_lifecycle() {
+        #[derive(Debug, PartialEq)]
+        enum Event {
+            Started,
+            Closed,
+        }
+        macro_rules! event_impl {
+            ($target:ty, $trait:ty) => {
+                impl $trait for $target {
+                    fn start(&self, _handle: TxHandle) -> Result<(), AdapterError> {
+                        self.0.send(Event::Started)?;
+                        Ok(())
+                    }
+
+                    fn close(&self) -> Result<(), AdapterError> {
+                        self.0.send(Event::Closed)?;
+                        Ok(())
+                    }
+                }
+            };
+        }
+        struct FakePacket(mpsc::SyncSender<Event>);
+        struct FakeStream(mpsc::SyncSender<Event>);
+        event_impl!(FakePacket, PacketAdapter);
+        event_impl!(FakeStream, StreamAdapter);
+
+        let app = M17App::new(NullTnc);
+        let (tx_p, rx_p) = mpsc::sync_channel(128);
+        let (tx_s, rx_s) = mpsc::sync_channel(128);
+        let packet = FakePacket(tx_p);
+        let stream = FakeStream(tx_s);
+
+        let id_p = app.add_packet_adapter(packet).unwrap();
+        let id_s = app.add_stream_adapter(stream).unwrap();
+        app.start().unwrap();
+        app.close().unwrap();
+        app.remove_packet_adapter(id_p).unwrap();
+        app.remove_stream_adapter(id_s).unwrap();
+
+        assert_eq!(rx_p.try_recv(), Ok(Event::Started));
+        assert_eq!(rx_p.try_recv(), Ok(Event::Closed));
+        assert_eq!(rx_p.try_recv(), Err(mpsc::TryRecvError::Disconnected));
+
+        assert_eq!(rx_s.try_recv(), Ok(Event::Started));
+        assert_eq!(rx_s.try_recv(), Ok(Event::Closed));
+        assert_eq!(rx_s.try_recv(), Err(mpsc::TryRecvError::Disconnected));
+    }
+
+    /// `start()` and `close()` each succeed once, in that order, and are rejected otherwise.
+    #[test]
+    fn lifecycle_transitions() {
+        macro_rules! assert_only_error {
+            ($res:expr, $err:pat) => {
+                assert!(matches!(
+                    $res,
+                    Err(M17Errors(ref errs)) if matches!(errs.as_slice(), [$err])
+                ));
+            };
+        }
+
+        let app = M17App::new(NullTnc);
+        assert_only_error!(app.close(), M17Error::InvalidClose);
+        app.start().unwrap();
+        assert_only_error!(app.start(), M17Error::InvalidStart);
+        app.close().unwrap();
+        assert_only_error!(app.close(), M17Error::InvalidClose);
+        assert_only_error!(app.start(), M17Error::InvalidStart);
+    }
+}