From d2c9a9f306a336aa9dd8ddfb6579746e4f10b7a4 Mon Sep 17 00:00:00 2001 From: Thomas Karpiniec Date: Fri, 14 Jun 2024 22:59:11 +0100 Subject: [PATCH 1/1] Support for cancelling linux watch --- src/lib.rs | 16 +++++----- src/list_unix.rs | 4 +-- src/watch_linux.rs | 75 +++++++++++++++++++++++++++++++++------------- src/watch_win.rs | 3 +- 4 files changed, 67 insertions(+), 31 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 69de586..5c8d73d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,7 +43,7 @@ impl Interface { } /// Information delivered via callback when a network interface change is detected. -/// +/// /// This contains up-to-date information about all interfaces, plus a diff which /// details which interfaces and IP addresses have changed since the last callback. #[derive(Debug, Clone, PartialEq, Eq)] @@ -122,29 +122,31 @@ impl List { } /// A handle to keep alive as long as you wish to receive callbacks. -/// +/// /// If the callback is executing at the time the handle is dropped, drop will block until /// the callback is finished and it's guaranteed that it will not be called again. -/// +/// /// Do not drop the handle from within the callback itself. It will probably deadlock. pub struct WatchHandle { _inner: watch::WatchHandle, } /// Retrieve information about all enabled network interfaces and their IP addresses. -/// +/// /// This is a once-off operation. If you want to detect changes over time, see `watch_interfaces`. pub fn list_interfaces() -> Result, Error> { list::list_interfaces().map(|list| list.0) } /// Retrieve interface information and watch for changes, which will be delivered via callback. -/// +/// /// If setting up the watch is successful, this returns a `WatchHandle` which must be kept for /// as long as the provided callback should operate. -/// +/// /// The callback will fire once immediately with an initial interface list, and a diff as if /// there were originally no interfaces present. -pub fn watch_interfaces(callback: F) -> Result { +pub fn watch_interfaces( + callback: F, +) -> Result { watch::watch_interfaces(callback).map(|handle| WatchHandle { _inner: handle }) } diff --git a/src/list_unix.rs b/src/list_unix.rs index 45ffc63..fad4306 100644 --- a/src/list_unix.rs +++ b/src/list_unix.rs @@ -62,11 +62,11 @@ pub(crate) fn list_interfaces() -> Result { fn format_mac(bytes: &[u8]) -> Result { let mut mac = String::with_capacity(bytes.len() * 3); - for i in 0..bytes.len() { + for (i, b) in bytes.iter().enumerate() { if i != 0 { write!(mac, ":").map_err(|_| Error::Internal)?; } - write!(mac, "{:02X}", bytes[i]).map_err(|_| Error::Internal)?; + write!(mac, "{:02X}", b).map_err(|_| Error::Internal)?; } Ok(mac) } diff --git a/src/watch_linux.rs b/src/watch_linux.rs index efc3842..e584211 100644 --- a/src/watch_linux.rs +++ b/src/watch_linux.rs @@ -1,6 +1,9 @@ use std::os::fd::AsRawFd; use std::os::fd::OwnedFd; +use nix::libc::poll; +use nix::libc::pollfd; +use nix::libc::POLLIN; use nix::libc::RTMGRP_IPV4_IFADDR; use nix::libc::RTMGRP_IPV6_IFADDR; use nix::libc::RTMGRP_LINK; @@ -13,34 +16,42 @@ use nix::sys::socket::NetlinkAddr; use nix::sys::socket::SockFlag; use nix::sys::socket::SockProtocol; use nix::sys::socket::SockType; +use nix::unistd::pipe; use crate::Error; use crate::List; use crate::Update; pub(crate) struct WatchHandle { - // PROBLEM: close() doesn't cancel recv() for a netlink socket - // SOLUTION: open a pipe() and use poll() inside the thread to watch for cancellation too - sockfd: OwnedFd, + // Dropping will close the fd which will be detected by poll + _pipefd: OwnedFd, } pub(crate) fn watch_interfaces( callback: F, ) -> Result { - let sockfd = start_watcher_thread(callback)?; - Ok(WatchHandle { sockfd }) + let pipefd = start_watcher_thread(callback)?; + Ok(WatchHandle { _pipefd: pipefd }) } -fn start_watcher_thread(mut callback: F) -> Result { - let sockfd = socket(AddressFamily::Netlink, SockType::Raw, SockFlag::empty(), Some(SockProtocol::NetlinkRoute)) - .map_err(|_| Error::Internal)?; // TODO: proper errors - let sa_nl = NetlinkAddr::new(0, (RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR) as u32); +fn start_watcher_thread( + mut callback: F, +) -> Result { + let sockfd = socket( + AddressFamily::Netlink, + SockType::Raw, + SockFlag::empty(), + Some(SockProtocol::NetlinkRoute), + ) + .map_err(|_| Error::Internal)?; // TODO: proper errors + let sa_nl = NetlinkAddr::new( + 0, + (RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR) as u32, + ); bind(sockfd.as_raw_fd(), &sa_nl).map_err(|_| Error::Internal)?; // TODO: proper errors - let fd = sockfd.as_raw_fd(); - println!("netlink socket on fd {}", fd); + let (pipe_rd, pipe_wr) = pipe().map_err(|_| Error::Internal)?; std::thread::spawn(move || { - println!("watch thread running"); let mut prev_list = List::default(); let mut buf = [0u8; 4096]; let mut handle_update = move |new_list: List| { @@ -59,15 +70,37 @@ fn start_watcher_thread(mut callback: F) -> R handle_update(initial); }; - while let Ok(n) = recv(fd, &mut buf, MsgFlags::empty()) { - println!("something on the netlink socket: {} bytes", n); - let Ok(new_list) = crate::list::list_interfaces() else { - continue; - }; - handle_update(new_list); + loop { + let mut fds = [ + pollfd { + fd: sockfd.as_raw_fd(), + events: POLLIN, + revents: 0, + }, + pollfd { + fd: pipe_rd.as_raw_fd(), + events: POLLIN, + revents: 0, + }, + ]; + unsafe { + poll(&mut fds as *mut _, 2, -1); + } + if fds[0].revents != 0 { + // netlink socket had something happen + if recv(sockfd.as_raw_fd(), &mut buf, MsgFlags::empty()).is_ok() { + let Ok(new_list) = crate::list::list_interfaces() else { + continue; + }; + handle_update(new_list); + } + } + if fds[1].revents != 0 { + // pipe had something happen + break; + } } - println!("netlink recv thread terminating"); }); - - Ok(sockfd) + + Ok(pipe_wr) } diff --git a/src/watch_win.rs b/src/watch_win.rs index edd2fe3..23b274e 100644 --- a/src/watch_win.rs +++ b/src/watch_win.rs @@ -81,7 +81,8 @@ unsafe extern "system" fn notif( ) { let state_ptr = ctx as *const Mutex; unsafe { - let state_guard = &mut *state_ptr.as_ref() + let state_guard = &mut *state_ptr + .as_ref() .expect("callback ctx should never be null") .lock() .unwrap(); -- 2.39.5