code.octet-stream.net Git - netwatcher/blobdiff - src/lib.rs
Clean up race condition in initial watch on Windows
[netwatcher] / src / lib.rs
index 61e49402d67afee932a878462d309b79cd82fefb..47c5ae11957466f7e233e0f2690105d431fad154 100644 (file)
@@ -1,6 +1,7 @@
 use std::{
 use std::{
-    collections::HashMap,
+    collections::{HashMap, HashSet},
     net::{IpAddr, Ipv4Addr, Ipv6Addr},
     net::{IpAddr, Ipv4Addr, Ipv6Addr},
+    ops::Sub,
 };
 
 #[cfg_attr(windows, path = "list_win.rs")]
 };
 
 #[cfg_attr(windows, path = "list_win.rs")]
@@ -11,6 +12,7 @@ mod list;
 #[cfg_attr(target_vendor = "apple", path = "watch_mac.rs")]
 mod watch;
 
 #[cfg_attr(target_vendor = "apple", path = "watch_mac.rs")]
 mod watch;
 
+#[cfg(target_vendor = "apple")]
 mod util;
 
 type IfIndex = u32;
 mod util;
 
 type IfIndex = u32;
@@ -45,13 +47,6 @@ pub struct Update {
     pub diff: UpdateDiff,
 }
 
     pub diff: UpdateDiff,
 }
 
-impl Update {
-    pub fn diff_from_previous(_prev: &Update) -> UpdateDiff {
-        // TODO: real calculation
-        UpdateDiff::default()
-    }
-}
-
 #[derive(Debug, Clone, PartialEq, Eq, Default)]
 pub struct UpdateDiff {
     pub added: Vec<IfIndex>,
 #[derive(Debug, Clone, PartialEq, Eq, Default)]
 pub struct UpdateDiff {
     pub added: Vec<IfIndex>,
@@ -72,6 +67,60 @@ pub enum Error {
     Internal,
 }
 
     Internal,
 }
 
-pub use list::list_interfaces;
-pub use watch::{watch_interfaces, WatchHandle};
+#[derive(Default, PartialEq, Eq)]
+struct List(HashMap<IfIndex, Interface>);
+
+impl List {
+    fn diff_from(&self, prev: &List) -> UpdateDiff {
+        let prev_index_set: HashSet<IfIndex> = prev.0.keys().cloned().collect();
+        let curr_index_set: HashSet<IfIndex> = self.0.keys().cloned().collect();
+        let added = curr_index_set.sub(&prev_index_set).into_iter().collect();
+        let removed = prev_index_set.sub(&curr_index_set).into_iter().collect();
+        let mut modified = HashMap::new();
+        for index in curr_index_set.intersection(&prev_index_set) {
+            if prev.0[index] == self.0[index] {
+                continue;
+            }
+            let prev_addr_set: HashSet<&IpAddr> = prev.0[index].ips.iter().collect();
+            let curr_addr_set: HashSet<&IpAddr> = self.0[index].ips.iter().collect();
+            let addrs_added: Vec<IpAddr> = curr_addr_set
+                .sub(&prev_addr_set)
+                .iter()
+                .cloned()
+                .cloned()
+                .collect();
+            let addrs_removed: Vec<IpAddr> = prev_addr_set
+                .sub(&curr_addr_set)
+                .iter()
+                .cloned()
+                .cloned()
+                .collect();
+            let hw_addr_changed = prev.0[index].hw_addr != self.0[index].hw_addr;
+            modified.insert(
+                *index,
+                InterfaceDiff {
+                    hw_addr_changed,
+                    addrs_added,
+                    addrs_removed,
+                },
+            );
+        }
+        UpdateDiff {
+            added,
+            removed,
+            modified,
+        }
+    }
+}
 
 
+/// Handle returned by [`watch_interfaces`].
+///
+/// Wraps the platform-specific `watch::WatchHandle` (the `watch` module is chosen
+/// per platform via `cfg_attr(..., path = ...)`) so that type is not exposed in the
+/// public API.
+// NOTE(review): presumably dropping this handle stops the watch, which is why
+// `_inner` is held — confirm against the platform `watch_*.rs` implementations.
+pub struct WatchHandle {
+    _inner: watch::WatchHandle,
+}
+
+pub fn list_interfaces() -> Result<HashMap<IfIndex, Interface>, Error> {
+    list::list_interfaces().map(|list| list.0)
+}
+
+pub fn watch_interfaces<F: FnMut(Update) + 'static>(callback: F) -> Result<WatchHandle, Error> {
+    watch::watch_interfaces(callback).map(|handle| WatchHandle { _inner: handle })
+}