From: Thomas Karpiniec <tom.karpiniec@outlook.com>
Date: Fri, 14 Jun 2024 21:59:11 +0000 (+0100)
Subject: Support for cancelling linux watch
X-Git-Tag: v0.1.0~15
X-Git-Url: https://code.octet-stream.net/netwatcher/commitdiff_plain/d2c9a9f306a336aa9dd8ddfb6579746e4f10b7a4?ds=sidebyside;hp=4abfa61b20e567fdd69ac3ca47a9c218971a30ff

Support for cancelling linux watch
---

diff --git a/src/lib.rs b/src/lib.rs
index 69de586..5c8d73d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -43,7 +43,7 @@ impl Interface {
 }
 
 /// Information delivered via callback when a network interface change is detected.
-/// 
+///
 /// This contains up-to-date information about all interfaces, plus a diff which
 /// details which interfaces and IP addresses have changed since the last callback.
 #[derive(Debug, Clone, PartialEq, Eq)]
@@ -122,29 +122,31 @@ impl List {
 }
 
 /// A handle to keep alive as long as you wish to receive callbacks.
-/// 
+///
 /// If the callback is executing at the time the handle is dropped, drop will block until
 /// the callback is finished and it's guaranteed that it will not be called again.
-/// 
+///
 /// Do not drop the handle from within the callback itself. It will probably deadlock.
 pub struct WatchHandle {
     _inner: watch::WatchHandle,
 }
 
 /// Retrieve information about all enabled network interfaces and their IP addresses.
-/// 
+///
 /// This is a once-off operation. If you want to detect changes over time, see `watch_interfaces`.
 pub fn list_interfaces() -> Result<HashMap<IfIndex, Interface>, Error> {
     list::list_interfaces().map(|list| list.0)
 }
 
 /// Retrieve interface information and watch for changes, which will be delivered via callback.
-/// 
+///
 /// If setting up the watch is successful, this returns a `WatchHandle` which must be kept for
 /// as long as the provided callback should operate.
-/// 
+///
 /// The callback will fire once immediately with an initial interface list, and a diff as if
 /// there were originally no interfaces present.
-pub fn watch_interfaces<F: FnMut(Update) + Send + 'static>(callback: F) -> Result<WatchHandle, Error> {
+pub fn watch_interfaces<F: FnMut(Update) + Send + 'static>(
+    callback: F,
+) -> Result<WatchHandle, Error> {
     watch::watch_interfaces(callback).map(|handle| WatchHandle { _inner: handle })
 }
diff --git a/src/list_unix.rs b/src/list_unix.rs
index 45ffc63..fad4306 100644
--- a/src/list_unix.rs
+++ b/src/list_unix.rs
@@ -62,11 +62,11 @@ pub(crate) fn list_interfaces() -> Result<List, Error> {
 
 fn format_mac(bytes: &[u8]) -> Result<String, Error> {
     let mut mac = String::with_capacity(bytes.len() * 3);
-    for i in 0..bytes.len() {
+    for (i, b) in bytes.iter().enumerate() {
         if i != 0 {
             write!(mac, ":").map_err(|_| Error::Internal)?;
         }
-        write!(mac, "{:02X}", bytes[i]).map_err(|_| Error::Internal)?;
+        write!(mac, "{:02X}", b).map_err(|_| Error::Internal)?;
     }
     Ok(mac)
 }
diff --git a/src/watch_linux.rs b/src/watch_linux.rs
index efc3842..e584211 100644
--- a/src/watch_linux.rs
+++ b/src/watch_linux.rs
@@ -1,6 +1,12 @@
 use std::os::fd::AsRawFd;
 use std::os::fd::OwnedFd;
+use std::os::unix::net::UnixStream;
+use std::thread::JoinHandle;
 
+use nix::errno::Errno;
+use nix::libc::poll;
+use nix::libc::pollfd;
+use nix::libc::POLLIN;
 use nix::libc::RTMGRP_IPV4_IFADDR;
 use nix::libc::RTMGRP_IPV6_IFADDR;
 use nix::libc::RTMGRP_LINK;
@@ -13,34 +19,71 @@ use nix::sys::socket::NetlinkAddr;
 use nix::sys::socket::SockFlag;
 use nix::sys::socket::SockProtocol;
 use nix::sys::socket::SockType;
 
 use crate::Error;
 use crate::List;
 use crate::Update;
 
 pub(crate) struct WatchHandle {
-    // PROBLEM: close() doesn't cancel recv() for a netlink socket
-    // SOLUTION: open a pipe() and use poll() inside the thread to watch for cancellation too
-    sockfd: OwnedFd,
+    // Write end of the cancellation channel; closing it wakes the watcher
+    // thread's poll() so the thread exits.
+    cancel_fd: Option<OwnedFd>,
+    // Joined on drop so that once drop returns, no callback is running and
+    // none will start.
+    thread: Option<JoinHandle<()>>,
 }
+
+impl Drop for WatchHandle {
+    fn drop(&mut self) {
+        // Close the write end first so the thread's poll() wakes up...
+        drop(self.cancel_fd.take());
+        // ...then wait for any in-progress callback to finish and the thread to exit.
+        if let Some(thread) = self.thread.take() {
+            // If the callback panicked, join() reports it as an Err; don't
+            // re-raise that panic from drop.
+            let _ = thread.join();
+        }
+    }
+}
 
 pub(crate) fn watch_interfaces<F: FnMut(Update) + Send + 'static>(
     callback: F,
 ) -> Result<WatchHandle, Error> {
-    let sockfd = start_watcher_thread(callback)?;
-    Ok(WatchHandle { sockfd })
+    let (cancel_fd, thread) = start_watcher_thread(callback)?;
+    Ok(WatchHandle {
+        cancel_fd: Some(cancel_fd),
+        thread: Some(thread),
+    })
 }
 
-fn start_watcher_thread<F: FnMut(Update) + Send + 'static>(mut callback: F) -> Result<OwnedFd, Error> {
-    let sockfd = socket(AddressFamily::Netlink, SockType::Raw, SockFlag::empty(), Some(SockProtocol::NetlinkRoute))
-        .map_err(|_| Error::Internal)?; // TODO: proper errors
-    let sa_nl = NetlinkAddr::new(0, (RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR) as u32);
+/// Opens a netlink route socket subscribed to link and IPv4/IPv6 address changes
+/// and spawns a thread that re-lists interfaces each time a notification arrives.
+///
+/// Returns the write end of the cancellation channel (closing it stops the thread)
+/// together with the thread's join handle.
+fn start_watcher_thread<F: FnMut(Update) + Send + 'static>(
+    mut callback: F,
+) -> Result<(OwnedFd, JoinHandle<()>), Error> {
+    let sockfd = socket(
+        AddressFamily::Netlink,
+        SockType::Raw,
+        // Don't leak the socket into child processes.
+        SockFlag::SOCK_CLOEXEC,
+        Some(SockProtocol::NetlinkRoute),
+    )
+    .map_err(|_| Error::Internal)?; // TODO: proper errors
+    let sa_nl = NetlinkAddr::new(
+        0,
+        (RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR) as u32,
+    );
     bind(sockfd.as_raw_fd(), &sa_nl).map_err(|_| Error::Internal)?; // TODO: proper errors
-    let fd = sockfd.as_raw_fd();
-    println!("netlink socket on fd {}", fd);
+    // Cancellation channel: closing the write end makes the read end report
+    // POLLIN/POLLHUP, waking the thread's poll(). std creates both ends close-on-exec,
+    // so a child process can't inherit the write end and keep it open after the
+    // handle is dropped (which would leave drop waiting on the thread forever).
+    let (cancel_rd, cancel_wr) = UnixStream::pair().map_err(|_| Error::Internal)?;
 
-    std::thread::spawn(move || {
-        println!("watch thread running");
+    let thread = std::thread::spawn(move || {
         let mut prev_list = List::default();
         let mut buf = [0u8; 4096];
         let mut handle_update = move |new_list: List| {
@@ -59,15 +102,52 @@ fn start_watcher_thread<F: FnMut(Update) + Send + 'static>(mut callback: F) -> R
             handle_update(initial);
         };
 
-        while let Ok(n) = recv(fd, &mut buf, MsgFlags::empty()) {
-            println!("something on the netlink socket: {} bytes", n);
-            let Ok(new_list) = crate::list::list_interfaces() else {
-                continue;
-            };
-            handle_update(new_list);
+        loop {
+            let mut fds = [
+                pollfd {
+                    fd: sockfd.as_raw_fd(),
+                    events: POLLIN,
+                    revents: 0,
+                },
+                pollfd {
+                    fd: cancel_rd.as_raw_fd(),
+                    events: POLLIN,
+                    revents: 0,
+                },
+            ];
+            // SAFETY: `fds` is a fully initialised array of `pollfd` and the count passed
+            // matches its length; both descriptors are owned by this closure and stay open.
+            let ready = unsafe { poll(fds.as_mut_ptr(), fds.len() as _, -1) };
+            if ready < 0 {
+                if Errno::last() == Errno::EINTR {
+                    continue;
+                }
+                // Retrying any other poll() failure would just spin; give up instead.
+                break;
+            }
+            // Check for cancellation first: once the handle has been dropped the callback
+            // must not run again, even if a netlink notification arrived at the same time.
+            if fds[1].revents != 0 {
+                break;
+            }
+            if fds[0].revents != 0 {
+                // Consume one notification. Its contents aren't parsed; it only signals
+                // that the interface list should be fetched again.
+                match recv(sockfd.as_raw_fd(), &mut buf, MsgFlags::empty()) {
+                    // ENOBUFS: the kernel dropped notifications because the receive
+                    // buffer overflowed, so a refresh is needed all the same.
+                    Ok(_) | Err(Errno::ENOBUFS) => {}
+                    Err(Errno::EINTR) => continue,
+                    // The socket is unusable; stop rather than busy-loop on the error.
+                    Err(_) => break,
+                }
+                let Ok(new_list) = crate::list::list_interfaces() else {
+                    continue;
+                };
+                handle_update(new_list);
+            }
         }
-        println!("netlink recv thread terminating");
     });
-    
-    Ok(sockfd)
+
+    Ok((cancel_wr.into(), thread))
 }
diff --git a/src/watch_win.rs b/src/watch_win.rs
index edd2fe3..23b274e 100644
--- a/src/watch_win.rs
+++ b/src/watch_win.rs
@@ -81,7 +81,8 @@ unsafe extern "system" fn notif(
 ) {
     let state_ptr = ctx as *const Mutex<WatchState>;
     unsafe {
-        let state_guard = &mut *state_ptr.as_ref()
+        let state_guard = &mut *state_ptr
+            .as_ref()
             .expect("callback ctx should never be null")
             .lock()
             .unwrap();