]> code.octet-stream.net Git - m17rt/blobdiff - m17app/src/app.rs
Simplify adapter lifecycle and introduce a lot of error propagation
[m17rt] / m17app / src / app.rs
index 0abfab640b51937b9076b45cea922ea43c4e954a..dc2138da0605879ef6b34878af5e477b35075143 100644 (file)
@@ -1,5 +1,5 @@
 use crate::adapter::{PacketAdapter, StreamAdapter};
-use crate::error::M17Error;
+use crate::error::{M17Error, M17Errors};
 use crate::link_setup::LinkSetup;
 use crate::tnc::Tnc;
 use crate::{LsfFrame, PacketType, StreamFrame};
@@ -11,9 +11,17 @@ use std::collections::HashMap;
 use std::sync::mpsc;
 use std::sync::{Arc, RwLock};
 
+#[derive(Debug, Clone, PartialEq, Eq, Copy)]
+enum Lifecycle { // Phase of M17App; decides whether adapters get start()/close() calls.
+    Setup,   // Initial phase: adapters are added/removed without start()/close().
+    Started, // After start(): added adapters are started, removed ones are closed.
+    Closed,  // After close(): start() and close() both reject this phase.
+}
+
 pub struct M17App {
     adapters: Arc<RwLock<Adapters>>,
     event_tx: mpsc::SyncSender<TncControlEvent>,
+    lifecycle: RwLock<Lifecycle>, // Current phase; see lifecycle()/set_lifecycle().
 }
 
 impl M17App {
@@ -26,41 +34,62 @@ impl M17App {
         Self {
             adapters: listeners,
             event_tx,
+            lifecycle: RwLock::new(Lifecycle::Setup),
         }
     }
 
-    pub fn add_packet_adapter<P: PacketAdapter + 'static>(&self, adapter: P) -> usize {
+    pub fn add_packet_adapter<P: PacketAdapter + 'static>(
+        &self,
+        adapter: P,
+    ) -> Result<usize, M17Error> { // Ok(id) is the handle for remove_packet_adapter.
         let adapter = Arc::new(adapter);
         let mut adapters = self.adapters.write().unwrap();
         let id = adapters.next;
         adapters.next += 1;
         adapters.packet.insert(id, adapter.clone());
         drop(adapters);
-        adapter.adapter_registered(id, self.tx());
-        id
+        if self.lifecycle() == Lifecycle::Started { // Setup: start() starts it later; Closed: never.
+            adapter
+                .start(self.tx())
+                .map_err(|e| M17Error::Adapter(id, e))?; // Stays registered under `id`.
+        }
+        Ok(id)
     }
 
-    pub fn add_stream_adapter<S: StreamAdapter + 'static>(&self, adapter: S) -> usize {
+    pub fn add_stream_adapter<S: StreamAdapter + 'static>(
+        &self,
+        adapter: S,
+    ) -> Result<usize, M17Error> { // Ok(id) is the handle for remove_stream_adapter.
         let adapter = Arc::new(adapter);
         let mut adapters = self.adapters.write().unwrap();
         let id = adapters.next;
         adapters.next += 1;
         adapters.stream.insert(id, adapter.clone());
         drop(adapters);
-        adapter.adapter_registered(id, self.tx());
-        id
+        if self.lifecycle() == Lifecycle::Started { // Setup: start() starts it later; Closed: never.
+            adapter
+                .start(self.tx())
+                .map_err(|e| M17Error::Adapter(id, e))?; // Stays registered under `id`.
+        }
+        Ok(id)
     }
 
-    pub fn remove_packet_adapter(&self, id: usize) {
-        if let Some(a) = self.adapters.write().unwrap().packet.remove(&id) {
-            a.adapter_removed();
+    pub fn remove_packet_adapter(&self, id: usize) -> Result<(), M17Error> { // Ok for unknown ids.
+        // Unregister in its own statement so the write guard drops before close() runs.
+        let removed = self.adapters.write().unwrap().packet.remove(&id);
+        if let (Some(a), Lifecycle::Started) = (removed, self.lifecycle()) {
+            a.close().map_err(|e| M17Error::Adapter(id, e))?; // Only while the app is Started.
         }
+        Ok(())
     }
 
-    pub fn remove_stream_adapter(&self, id: usize) {
-        if let Some(a) = self.adapters.write().unwrap().stream.remove(&id) {
-            a.adapter_removed();
+    pub fn remove_stream_adapter(&self, id: usize) -> Result<(), M17Error> { // Ok for unknown ids.
+        // Unregister in its own statement so the write guard drops before close() runs.
+        let removed = self.adapters.write().unwrap().stream.remove(&id);
+        if let (Some(a), Lifecycle::Started) = (removed, self.lifecycle()) {
+            a.close().map_err(|e| M17Error::Adapter(id, e))?; // Only while the app is Started.
         }
+        Ok(())
     }
 
     /// Create a handle that can be used to transmit data on the TNC
@@ -70,14 +99,68 @@ impl M17App {
         }
     }
 
-    pub fn start(&self) {
+    pub fn start(&self) -> Result<(), M17Errors> { // Err: InvalidStart or per-adapter failures.
+        if self.lifecycle() != Lifecycle::Setup { // Only valid once, from the Setup phase.
+            return Err(M17Errors(vec![M17Error::InvalidStart]));
+        }
+        self.set_lifecycle(Lifecycle::Started); // add_*_adapter now starts new adapters.
+        let mut errs = vec![]; // One failing adapter does not stop the others from starting.
+        {
+            let adapters = self.adapters.read().unwrap(); // Held while every start() runs.
+            for (i, p) in &adapters.packet {
+                if let Err(e) = p.start(self.tx()) {
+                    errs.push(M17Error::Adapter(*i, e));
+                }
+            }
+            for (i, s) in &adapters.stream {
+                if let Err(e) = s.start(self.tx()) {
+                    errs.push(M17Error::Adapter(*i, e));
+                }
+            }
+        }
         let _ = self.event_tx.send(TncControlEvent::Start);
+        if errs.is_empty() { // TncControlEvent::Start was sent regardless.
+            Ok(())
+        } else {
+            Err(M17Errors(errs))
+        }
     }
 
-    pub fn close(&self) {
+    pub fn close(&self) -> Result<(), M17Errors> { // Err: InvalidClose or per-adapter failures.
+        if self.lifecycle() != Lifecycle::Started { // Only valid in the Started phase.
+            return Err(M17Errors(vec![M17Error::InvalidClose]));
+        }
+        self.set_lifecycle(Lifecycle::Closed); // remove_*_adapter now skips close().
+        let mut errs = vec![]; // Every adapter gets close() even if an earlier one failed.
+        {
+            let adapters = self.adapters.read().unwrap(); // Held while every close() runs.
+            for (i, p) in &adapters.packet {
+                if let Err(e) = p.close() {
+                    errs.push(M17Error::Adapter(*i, e));
+                }
+            }
+            for (i, s) in &adapters.stream {
+                if let Err(e) = s.close() {
+                    errs.push(M17Error::Adapter(*i, e));
+                }
+            }
+        }
         // TODO: blocking function to indicate TNC has finished closing
         // then we could call this in a signal handler to ensure PTT is dropped before quit
         let _ = self.event_tx.send(TncControlEvent::Close);
+        if errs.is_empty() { // TncControlEvent::Close was sent regardless.
+            Ok(())
+        } else {
+            Err(M17Errors(errs))
+        }
+    }
+
+    fn lifecycle(&self) -> Lifecycle { // Snapshot of the current phase.
+        *self.lifecycle.read().unwrap()
+    }
+
+    fn set_lifecycle(&self, lifecycle: Lifecycle) { // Unconditional overwrite.
+        *self.lifecycle.write().unwrap() = lifecycle; // NOTE(review): check+set not atomic
     }
 }
 
@@ -295,6 +378,7 @@ fn spawn_writer<T: Tnc>(mut tnc: T, event_rx: mpsc::Receiver<TncControlEvent>) {
 
 #[cfg(test)]
 mod tests {
+    use crate::error::AdapterError;
     use crate::{link_setup::M17Address, test_util::NullTnc};
 
     use super::*;
@@ -307,18 +391,67 @@ mod tests {
             PacketType::Raw,
             &[0u8; 100],
         );
-        assert_eq!(res, Ok(()));
+        assert!(matches!(res, Ok(())));
         let res = app.tx().transmit_packet(
             &LinkSetup::new_packet(&M17Address::new_broadcast(), &M17Address::new_broadcast()),
             PacketType::Raw,
             &[0u8; 900],
         );
-        assert_eq!(
+        assert!(matches!(
             res,
             Err(M17Error::PacketTooLarge {
                 provided: 900,
                 capacity: 822
             })
-        );
+        ));
+    }
+
+    #[test]
+    fn adapter_lifecycle() { // Each adapter sees exactly one start() and one close().
+        #[derive(Debug, PartialEq)]
+        enum Event {
+            Started,
+            Closed,
+        }
+        macro_rules! event_impl { // Identical fake impl for both adapter traits.
+            ($target:ty, $trait:ty) => {
+                impl $trait for $target {
+                    fn start(&self, _handle: TxHandle) -> Result<(), AdapterError> {
+                        self.0.send(Event::Started)?;
+                        Ok(())
+                    }
+
+                    fn close(&self) -> Result<(), AdapterError> {
+                        self.0.send(Event::Closed)?;
+                        Ok(())
+                    }
+                }
+            };
+        }
+        struct FakePacket(mpsc::SyncSender<Event>);
+        struct FakeStream(mpsc::SyncSender<Event>);
+        event_impl!(FakePacket, PacketAdapter);
+        event_impl!(FakeStream, StreamAdapter);
+
+        let app = M17App::new(NullTnc);
+        let (tx_p, rx_p) = mpsc::sync_channel(128);
+        let (tx_s, rx_s) = mpsc::sync_channel(128);
+        let packet = FakePacket(tx_p);
+        let stream = FakeStream(tx_s);
+
+        let id_p = app.add_packet_adapter(packet).unwrap();
+        let id_s = app.add_stream_adapter(stream).unwrap();
+        app.start().unwrap();
+        app.close().unwrap();
+        app.remove_packet_adapter(id_p).unwrap(); // Closed: removal must not close() again.
+        app.remove_stream_adapter(id_s).unwrap(); // Removed fakes drop their senders.
+
+        assert_eq!(rx_p.try_recv(), Ok(Event::Started));
+        assert_eq!(rx_p.try_recv(), Ok(Event::Closed));
+        assert_eq!(rx_p.try_recv(), Err(mpsc::TryRecvError::Disconnected));
+
+        assert_eq!(rx_s.try_recv(), Ok(Event::Started));
+        assert_eq!(rx_s.try_recv(), Ok(Event::Closed));
+        assert_eq!(rx_s.try_recv(), Err(mpsc::TryRecvError::Disconnected));
     }
 }