From cc3e849e9ff8568609f160d4bcdb83966ffdb6c5 Mon Sep 17 00:00:00 2001 From: Thomas Karpiniec Date: Sat, 8 Jun 2024 16:26:06 +1000 Subject: [PATCH] First go at a Windows watch implementation --- examples/watch.rs | 13 ++++++ src/lib.rs | 69 +++++++++++++++++++++++----- src/list_win.rs | 10 ++--- src/watch_win.rs | 111 ++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 180 insertions(+), 23 deletions(-) create mode 100644 examples/watch.rs diff --git a/examples/watch.rs b/examples/watch.rs new file mode 100644 index 0000000..4129dc6 --- /dev/null +++ b/examples/watch.rs @@ -0,0 +1,13 @@ +use std::time::Duration; + +fn main() { + let _handle = netwatcher::watch_interfaces(|update| { + println!("Interface update!"); + println!("State: {:?}", update.interfaces); + println!("Diff: {:?}", update.diff); + }); + + loop { + std::thread::sleep(Duration::from_secs(60)); + } +} diff --git a/src/lib.rs b/src/lib.rs index 61e4940..47c5ae1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr}, + ops::Sub, }; #[cfg_attr(windows, path = "list_win.rs")] @@ -11,6 +12,7 @@ mod list; #[cfg_attr(target_vendor = "apple", path = "watch_mac.rs")] mod watch; +#[cfg(target_vendor = "apple")] mod util; type IfIndex = u32; @@ -45,13 +47,6 @@ pub struct Update { pub diff: UpdateDiff, } -impl Update { - pub fn diff_from_previous(_prev: &Update) -> UpdateDiff { - // TODO: real calculation - UpdateDiff::default() - } -} - #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct UpdateDiff { pub added: Vec, @@ -72,6 +67,60 @@ pub enum Error { Internal, } -pub use list::list_interfaces; -pub use watch::{watch_interfaces, WatchHandle}; +#[derive(Default, PartialEq, Eq)] +struct List(HashMap); + +impl List { + fn diff_from(&self, prev: &List) -> UpdateDiff { + let prev_index_set: HashSet = prev.0.keys().cloned().collect(); + let curr_index_set: HashSet = self.0.keys().cloned().collect(); + let added = curr_index_set.sub(&prev_index_set).into_iter().collect(); + let removed = prev_index_set.sub(&curr_index_set).into_iter().collect(); + let mut modified = HashMap::new(); + for index in curr_index_set.intersection(&prev_index_set) { + if prev.0[index] == self.0[index] { + continue; + } + let prev_addr_set: HashSet<&IpAddr> = prev.0[index].ips.iter().collect(); + let curr_addr_set: HashSet<&IpAddr> = self.0[index].ips.iter().collect(); + let addrs_added: Vec = curr_addr_set + .sub(&prev_addr_set) + .iter() + .cloned() + .cloned() + .collect(); + let addrs_removed: Vec = prev_addr_set + .sub(&curr_addr_set) + .iter() + .cloned() + .cloned() + .collect(); + let hw_addr_changed = prev.0[index].hw_addr != self.0[index].hw_addr; + modified.insert( + *index, + InterfaceDiff { + hw_addr_changed, + addrs_added, + addrs_removed, + }, + ); + } + UpdateDiff { + added, + removed, + modified, + } + } +} +pub struct WatchHandle { + _inner: watch::WatchHandle, +} + +pub fn list_interfaces() -> Result, Error> { + list::list_interfaces().map(|list| list.0) +} + +pub fn watch_interfaces(callback: F) -> Result { + watch::watch_interfaces(callback).map(|handle| WatchHandle { _inner: handle }) +} diff --git a/src/list_win.rs b/src/list_win.rs index e842244..718f2bb 100644 --- a/src/list_win.rs +++ b/src/list_win.rs @@ -16,9 +16,9 @@ use windows::Win32::Networking::WinSock::{ AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6, }; -use crate::{Error, IfIndex, Interface}; +use crate::{Error, Interface, List}; -pub fn list_interfaces() -> Result, Error> { +pub(crate) fn list_interfaces() -> Result { let mut ifs = HashMap::new(); // Microsoft recommends a 15 KB initial buffer let start_size = 15 * 1024; @@ -44,7 +44,7 @@ pub fn list_interfaces() -> Result, Error> { } ERROR_INVALID_PARAMETER => return Err(Error::Internal), ERROR_NOT_ENOUGH_MEMORY => return Err(Error::Internal), - ERROR_NO_DATA => return Ok(HashMap::new()), // there aren't any + ERROR_NO_DATA => return Ok(List(HashMap::new())), // there aren't any _ => return Err(Error::Internal), // TODO: Use FormatMessage to get a string } } @@ -98,7 +98,7 @@ pub fn list_interfaces() -> Result, Error> { } } - Ok(ifs) + Ok(List(ifs)) } #[cfg(test)] @@ -107,7 +107,7 @@ mod test { #[test] fn list() { - let ifaces = list_interfaces().unwrap(); + let ifaces = list_interfaces().unwrap().0; println!("{:?}", ifaces); } } diff --git a/src/watch_win.rs b/src/watch_win.rs index e495e9a..c88cd7c 100644 --- a/src/watch_win.rs +++ b/src/watch_win.rs @@ -1,8 +1,103 @@ -use crate::Update; - -pub struct WatchHandle; - -pub fn watch_interfaces(callback: F) -> WatchHandle { - drop(callback); - WatchHandle -} +use std::ffi::c_void; +use std::pin::Pin; +use std::sync::Mutex; + +use windows::Win32::Foundation::ERROR_INVALID_HANDLE; +use windows::Win32::Foundation::ERROR_INVALID_PARAMETER; +use windows::Win32::Foundation::ERROR_NOT_ENOUGH_MEMORY; +use windows::Win32::Foundation::NO_ERROR; +use windows::Win32::NetworkManagement::IpHelper::CancelMibChangeNotify2; +use windows::Win32::NetworkManagement::IpHelper::MIB_NOTIFICATION_TYPE; +use windows::Win32::NetworkManagement::IpHelper::MIB_UNICASTIPADDRESS_ROW; +use windows::Win32::{ + Foundation::{BOOLEAN, HANDLE}, + NetworkManagement::IpHelper::NotifyUnicastIpAddressChange, + Networking::WinSock::AF_UNSPEC, +}; + +use crate::Error; +use crate::List; +use crate::Update; + +pub struct WatchState { + /// The last result that we captured, for diffing + prev_list: List, + /// User's callback + cb: Box, +} + +pub struct WatchHandle { + hnd: HANDLE, + _state: Pin>>, +} + +impl Drop for WatchHandle { + fn drop(&mut self) { + unsafe { + let _ = CancelMibChangeNotify2(self.hnd); + } + } +} + +pub(crate) fn watch_interfaces( + mut callback: F, +) -> Result { + let null_list = List::default(); + let prev_list = crate::list::list_interfaces()?; + callback(Update { + interfaces: prev_list.0.clone(), + diff: prev_list.diff_from(&null_list), + }); + + // TODO: Can wo do something about the race condition? + let state = Box::pin(Mutex::new(WatchState { + prev_list, + cb: Box::new(callback), + })); + let state_ctx = &*state.as_ref() as *const _ as *const c_void; + + let mut hnd = HANDLE::default(); + let res = unsafe { + NotifyUnicastIpAddressChange( + AF_UNSPEC, + Some(notif), + Some(state_ctx), + BOOLEAN(0), + &mut hnd, + ) + }; + match res { + NO_ERROR => Ok(WatchHandle { hnd, _state: state }), + ERROR_INVALID_HANDLE => Err(Error::Internal), + ERROR_INVALID_PARAMETER => Err(Error::Internal), + ERROR_NOT_ENOUGH_MEMORY => Err(Error::Internal), + _ => Err(Error::Internal), // TODO: Use FormatMessage and get real error + } +} + +unsafe extern "system" fn notif( + ctx: *const c_void, + _row: *const MIB_UNICASTIPADDRESS_ROW, + _notification_type: MIB_NOTIFICATION_TYPE, +) { + println!("There was a change!"); + let Ok(new_list) = crate::list::list_interfaces() else { + println!("Failed to get list of interfaces on change"); + return; + }; + let state_ptr = ctx as *const Mutex; + unsafe { + let state_guard = &mut *state_ptr.as_ref().unwrap().lock().unwrap(); + if new_list == state_guard.prev_list { + // TODO: Hitting this a lot, is it true? + println!("Interfaces seem to be the same, ignoring"); + return; + } + let update = Update { + interfaces: new_list.0.clone(), + diff: new_list.diff_from(&state_guard.prev_list), + }; + (state_guard.cb)(update); + state_guard.prev_list = new_list; + } +} -- 2.39.5