code.octet-stream.net Git - netwatcher/commitdiff
First go at a Windows watch implementation
author Thomas Karpiniec <tom.karpiniec@outlook.com>
Sat, 8 Jun 2024 06:26:06 +0000 (16:26 +1000)
committer Thomas Karpiniec <tom.karpiniec@outlook.com>
Sat, 8 Jun 2024 06:26:06 +0000 (16:26 +1000)
examples/watch.rs [new file with mode: 0644]
src/lib.rs
src/list_win.rs
src/watch_win.rs

diff --git a/examples/watch.rs b/examples/watch.rs
new file mode 100644 (file)
index 0000000..4129dc6
--- /dev/null
@@ -0,0 +1,20 @@
+//! Prints every interface update reported by `netwatcher` until the process is killed.
+
+use std::time::Duration;
+
+fn main() {
+    // The handle must stay bound for the whole of `main`: on Windows, dropping it cancels the
+    // change notification and no further updates would arrive.
+    let Ok(_handle) = netwatcher::watch_interfaces(|update| {
+        println!("Interface update!");
+        println!("State: {:?}", update.interfaces);
+        println!("Diff: {:?}", update.diff);
+    }) else {
+        eprintln!("Failed to start watching network interfaces");
+        std::process::exit(1);
+    };
+
+    loop {
+        std::thread::sleep(Duration::from_secs(60));
+    }
+}
index 61e49402d67afee932a878462d309b79cd82fefb..47c5ae11957466f7e233e0f2690105d431fad154 100644 (file)
@@ -1,6 +1,7 @@
 use std::{
-    collections::HashMap,
+    collections::{HashMap, HashSet},
     net::{IpAddr, Ipv4Addr, Ipv6Addr},
+    ops::Sub,
 };
 
 #[cfg_attr(windows, path = "list_win.rs")]
@@ -11,6 +12,7 @@ mod list;
 #[cfg_attr(target_vendor = "apple", path = "watch_mac.rs")]
 mod watch;
 
+#[cfg(target_vendor = "apple")]
 mod util;
 
 type IfIndex = u32;
@@ -45,13 +47,6 @@ pub struct Update {
     pub diff: UpdateDiff,
 }
 
-impl Update {
-    pub fn diff_from_previous(_prev: &Update) -> UpdateDiff {
-        // TODO: real calculation
-        UpdateDiff::default()
-    }
-}
-
 #[derive(Debug, Clone, PartialEq, Eq, Default)]
 pub struct UpdateDiff {
     pub added: Vec<IfIndex>,
@@ -72,6 +67,90 @@ pub enum Error {
     Internal,
 }
 
-pub use list::list_interfaces;
-pub use watch::{watch_interfaces, WatchHandle};
+/// Snapshot of the interfaces reported by the platform `list` module, keyed by index.
+#[derive(Default, PartialEq, Eq)]
+struct List(HashMap<IfIndex, Interface>);
+
+impl List {
+    /// Computes the changes that turn `prev` (the older snapshot) into `self`.
+    ///
+    /// - `added`: indexes present in `self` but not in `prev`.
+    /// - `removed`: indexes present in `prev` but not in `self`.
+    /// - `modified`: indexes present in both whose `Interface` values differ, with the IP
+    ///   addresses gained and lost and whether `hw_addr` changed.
+    ///
+    /// Every `Vec` in the result is sorted, so the same pair of snapshots always produces an
+    /// equal `UpdateDiff` regardless of `HashSet` iteration order.
+    fn diff_from(&self, prev: &List) -> UpdateDiff {
+        let prev_index_set: HashSet<IfIndex> = prev.0.keys().copied().collect();
+        let curr_index_set: HashSet<IfIndex> = self.0.keys().copied().collect();
+
+        let mut added: Vec<IfIndex> = curr_index_set.sub(&prev_index_set).into_iter().collect();
+        added.sort_unstable();
+        let mut removed: Vec<IfIndex> = prev_index_set.sub(&curr_index_set).into_iter().collect();
+        removed.sort_unstable();
+
+        let mut modified = HashMap::new();
+        for index in curr_index_set.intersection(&prev_index_set) {
+            // Indexing cannot panic: `index` is a key of both maps.
+            let prev_if = &prev.0[index];
+            let curr_if = &self.0[index];
+            if prev_if == curr_if {
+                continue;
+            }
+            let prev_addr_set: HashSet<&IpAddr> = prev_if.ips.iter().collect();
+            let curr_addr_set: HashSet<&IpAddr> = curr_if.ips.iter().collect();
+            let mut addrs_added: Vec<IpAddr> =
+                curr_addr_set.sub(&prev_addr_set).into_iter().copied().collect();
+            addrs_added.sort_unstable();
+            let mut addrs_removed: Vec<IpAddr> =
+                prev_addr_set.sub(&curr_addr_set).into_iter().copied().collect();
+            addrs_removed.sort_unstable();
+            modified.insert(
+                *index,
+                InterfaceDiff {
+                    hw_addr_changed: prev_if.hw_addr != curr_if.hw_addr,
+                    addrs_added,
+                    addrs_removed,
+                },
+            );
+        }
+
+        UpdateDiff {
+            added,
+            removed,
+            modified,
+        }
+    }
+}
 
+/// Keeps an interface watch registered for as long as it is alive.
+///
+/// Returned by [`watch_interfaces`]. On Windows, dropping it cancels the underlying change
+/// notification.
+pub struct WatchHandle {
+    _inner: watch::WatchHandle,
+}
+
+/// Returns the interfaces currently reported by the platform, keyed by interface index.
+///
+/// # Errors
+///
+/// Returns [`Error`] if the platform query fails.
+pub fn list_interfaces() -> Result<HashMap<IfIndex, Interface>, Error> {
+    list::list_interfaces().map(|list| list.0)
+}
+
+/// Starts watching for interface changes, calling `callback` with an [`Update`] when the
+/// platform reports that the interface list changed.
+///
+/// On Windows, `callback` is also called once before this function returns, with the
+/// initial interface list (every interface reported as added). Keep the returned
+/// [`WatchHandle`] alive for as long as updates are wanted.
+///
+/// # Errors
+///
+/// Returns [`Error`] if the watch cannot be started.
+pub fn watch_interfaces<F: FnMut(Update) + 'static>(callback: F) -> Result<WatchHandle, Error> {
+    watch::watch_interfaces(callback).map(|handle| WatchHandle { _inner: handle })
+}
index e842244ae47c9e8f30170ac7ffd143fa95b67c41..718f2bb752b640e0c1f7b2bcc1f68379ef2c02e3 100644 (file)
@@ -16,9 +16,9 @@ use windows::Win32::Networking::WinSock::{
     AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6,
 };
 
-use crate::{Error, IfIndex, Interface};
+use crate::{Error, Interface, List};
 
-pub fn list_interfaces() -> Result<HashMap<IfIndex, Interface>, Error> {
+pub(crate) fn list_interfaces() -> Result<List, Error> {
     let mut ifs = HashMap::new();
     // Microsoft recommends a 15 KB initial buffer
     let start_size = 15 * 1024;
@@ -44,7 +44,7 @@ pub fn list_interfaces() -> Result<HashMap<IfIndex, Interface>, Error> {
                 }
                 ERROR_INVALID_PARAMETER => return Err(Error::Internal),
                 ERROR_NOT_ENOUGH_MEMORY => return Err(Error::Internal),
-                ERROR_NO_DATA => return Ok(HashMap::new()), // there aren't any
+                ERROR_NO_DATA => return Ok(List(HashMap::new())), // there aren't any
                 _ => return Err(Error::Internal), // TODO: Use FormatMessage to get a string
             }
         }
@@ -98,7 +98,7 @@ pub fn list_interfaces() -> Result<HashMap<IfIndex, Interface>, Error> {
         }
     }
 
-    Ok(ifs)
+    Ok(List(ifs))
 }
 
 #[cfg(test)]
@@ -107,7 +107,7 @@ mod test {
 
     #[test]
     fn list() {
-        let ifaces = list_interfaces().unwrap();
+        let ifaces = list_interfaces().unwrap().0; // `.0` is the map inside the crate-private `List`
         println!("{:?}", ifaces);
     }
 }
index e495e9a4467c809cc29ffe918fc6bbf97e9d0125..c88cd7c1263b32bbfa87f8b9b06b702b4f7ed5bc 100644 (file)
@@ -1,8 +1,146 @@
-use crate::Update;\r
-\r
-pub struct WatchHandle;\r
-\r
-pub fn watch_interfaces<F: FnMut(Update)>(callback: F) -> WatchHandle {\r
-    drop(callback);\r
-    WatchHandle\r
-}\r
+//! Windows implementation of interface watching, built on `NotifyUnicastIpAddressChange`.
+
+use std::ffi::c_void;
+use std::panic::{catch_unwind, AssertUnwindSafe};
+use std::pin::Pin;
+use std::sync::{Mutex, PoisonError};
+
+use windows::Win32::Foundation::ERROR_INVALID_HANDLE;
+use windows::Win32::Foundation::ERROR_INVALID_PARAMETER;
+use windows::Win32::Foundation::ERROR_NOT_ENOUGH_MEMORY;
+use windows::Win32::Foundation::NO_ERROR;
+use windows::Win32::NetworkManagement::IpHelper::CancelMibChangeNotify2;
+use windows::Win32::NetworkManagement::IpHelper::MIB_NOTIFICATION_TYPE;
+use windows::Win32::NetworkManagement::IpHelper::MIB_UNICASTIPADDRESS_ROW;
+use windows::Win32::{
+    Foundation::{BOOLEAN, HANDLE},
+    NetworkManagement::IpHelper::NotifyUnicastIpAddressChange,
+    Networking::WinSock::AF_UNSPEC,
+};
+
+use crate::Error;
+use crate::List;
+use crate::Update;
+
+/// State shared between a [`WatchHandle`] and the OS notification callback.
+pub struct WatchState {
+    /// The last result that we captured, for diffing
+    prev_list: List,
+    /// User's callback
+    cb: Box<dyn FnMut(Update) + 'static>,
+}
+
+/// Owns a `NotifyUnicastIpAddressChange` registration and the state its callback uses.
+pub struct WatchHandle {
+    /// Registration handle, passed to `CancelMibChangeNotify2` on drop.
+    hnd: HANDLE,
+    /// Boxed and pinned so the address given to Windows as the callback context is stable.
+    _state: Pin<Box<Mutex<WatchState>>>,
+}
+
+impl Drop for WatchHandle {
+    /// Cancels the registration. `drop` runs before the fields are dropped, so `_state` is
+    /// still alive while the cancellation happens.
+    fn drop(&mut self) {
+        // NOTE(review): confirm `CancelMibChangeNotify2` waits for in-flight callbacks;
+        // otherwise one could still be using `_state` after it is freed.
+        unsafe {
+            let _ = CancelMibChangeNotify2(self.hnd);
+        }
+    }
+}
+
+/// Starts watching unicast IP address changes on all interfaces.
+///
+/// `callback` is invoked once, before this function returns, with the current interface
+/// list (every interface reported as added). After that, each OS notification re-lists the
+/// interfaces and invokes `callback` with the diff whenever the list has changed.
+///
+/// NOTE(review): the notification callback presumably runs on a thread owned by Windows,
+/// yet `F` is not required to be `Send`; consider adding that bound crate-wide (TODO confirm).
+pub(crate) fn watch_interfaces<F: FnMut(Update) + 'static>(
+    mut callback: F,
+) -> Result<WatchHandle, Error> {
+    let null_list = List::default();
+    let prev_list = crate::list::list_interfaces()?;
+    callback(Update {
+        interfaces: prev_list.0.clone(),
+        diff: prev_list.diff_from(&null_list),
+    });
+
+    // TODO: Changes between the listing above and the registration below are not reported
+    // until the next notification arrives. Can we close this race?
+    let state = Box::pin(Mutex::new(WatchState {
+        prev_list,
+        cb: Box::new(callback),
+    }));
+    // Valid until `WatchHandle::drop` cancels the registration: on success the pinned box is
+    // moved into the returned handle, which keeps the heap allocation alive.
+    let state_ctx = &*state.as_ref() as *const Mutex<WatchState> as *const c_void;
+
+    let mut hnd = HANDLE::default();
+    // SAFETY: `notif` has the callback signature the API expects, `state_ctx` points at a
+    // live `Mutex<WatchState>` (see above) and `hnd` is a valid out-pointer.
+    let res = unsafe {
+        NotifyUnicastIpAddressChange(
+            AF_UNSPEC,
+            Some(notif),
+            Some(state_ctx),
+            // No initial notification: the initial state was already reported above.
+            BOOLEAN(0),
+            &mut hnd,
+        )
+    };
+    match res {
+        NO_ERROR => Ok(WatchHandle { hnd, _state: state }),
+        ERROR_INVALID_HANDLE => Err(Error::Internal),
+        ERROR_INVALID_PARAMETER => Err(Error::Internal),
+        ERROR_NOT_ENOUGH_MEMORY => Err(Error::Internal),
+        _ => Err(Error::Internal), // TODO: Use FormatMessage and get real error
+    }
+}
+
+/// Callback registered with `NotifyUnicastIpAddressChange`.
+///
+/// The changed row and notification type are ignored: every notification triggers a full
+/// re-listing (see `handle_change`).
+unsafe extern "system" fn notif(
+    ctx: *const c_void,
+    _row: *const MIB_UNICASTIPADDRESS_ROW,
+    _notification_type: MIB_NOTIFICATION_TYPE,
+) {
+    // Unwinding out of an `extern "system"` fn is undefined behaviour (newer toolchains abort
+    // instead), so any panic - including one from the user's callback - is contained here.
+    // SAFETY: `ctx` is the context pointer registered by `watch_interfaces`.
+    let _ = catch_unwind(AssertUnwindSafe(|| unsafe { handle_change(ctx) }));
+}
+
+/// Re-lists the interfaces and, if anything changed, reports the diff to the user's callback.
+///
+/// # Safety
+///
+/// `ctx` must point at the pinned `Mutex<WatchState>` owned by a live `WatchHandle`.
+unsafe fn handle_change(ctx: *const c_void) {
+    println!("There was a change!");
+    // SAFETY: upheld by the caller, per the contract above.
+    let state = unsafe { &*(ctx as *const Mutex<WatchState>) };
+    // Listing while holding the lock serializes concurrent notifications, so an older snapshot
+    // can never overwrite a newer `prev_list`. A poisoned lock only means an earlier callback
+    // panicked; `prev_list` is written after the callback returns, so the state is consistent.
+    let mut state = state.lock().unwrap_or_else(PoisonError::into_inner);
+    let Ok(new_list) = crate::list::list_interfaces() else {
+        println!("Failed to get list of interfaces on change");
+        return;
+    };
+    if new_list == state.prev_list {
+        // TODO: Hitting this a lot, is it true?
+        println!("Interfaces seem to be the same, ignoring");
+        return;
+    }
+    let update = Update {
+        interfaces: new_list.0.clone(),
+        diff: new_list.diff_from(&state.prev_list),
+    };
+    (state.cb)(update);
+    state.prev_list = new_list;
+}