]> code.octet-stream.net Git - netwatcher/blobdiff - src/watch_linux.rs
Commit to reporting error if original interface listing fails
[netwatcher] / src / watch_linux.rs
index 4e4049b8856ba87dece34d25677f1116c49ae7a5..95ef10638e5c5551ee9743106508031ca0462c6f 100644 (file)
@@ -1,7 +1,9 @@
 use std::os::fd::AsRawFd;
 use std::os::fd::OwnedFd;
 
-use nix::libc::nlmsghdr;
+use nix::libc::poll;
+use nix::libc::pollfd;
+use nix::libc::POLLIN;
 use nix::libc::RTMGRP_IPV4_IFADDR;
 use nix::libc::RTMGRP_IPV6_IFADDR;
 use nix::libc::RTMGRP_LINK;
@@ -14,44 +16,93 @@ use nix::sys::socket::NetlinkAddr;
 use nix::sys::socket::SockFlag;
 use nix::sys::socket::SockProtocol;
 use nix::sys::socket::SockType;
+use nix::unistd::pipe;
 
 use crate::Error;
+use crate::List;
 use crate::Update;
 
 pub(crate) struct WatchHandle {
-    // PROBLEM: close() doesn't cancel recv() for a netlink socket
-    sockfd: OwnedFd,
+    // Dropping will close the fd which will be detected by poll
+    _pipefd: OwnedFd,
 }
 
-pub(crate) fn watch_interfaces<F: FnMut(Update) + 'static>(
+pub(crate) fn watch_interfaces<F: FnMut(Update) + Send + 'static>(
     callback: F,
 ) -> Result<WatchHandle, Error> {
-    let sockfd = start_watcher_thread(callback)?;
-    Ok(WatchHandle { sockfd })
+    let pipefd = start_watcher_thread(callback)?;
+    Ok(WatchHandle { _pipefd: pipefd })
 }
 
-fn start_watcher_thread<F: FnMut(Update) + 'static>(callback: F) -> Result<OwnedFd, Error> {
-    let sockfd = socket(AddressFamily::Netlink, SockType::Raw, SockFlag::empty(), Some(SockProtocol::NetlinkRoute))
-        .map_err(|_| Error::Internal)?; // TODO: proper errors
-    let sa_nl = NetlinkAddr::new(0, (RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR) as u32);
+fn start_watcher_thread<F: FnMut(Update) + Send + 'static>(
+    mut callback: F,
+) -> Result<OwnedFd, Error> {
+    let sockfd = socket(
+        AddressFamily::Netlink,
+        SockType::Raw,
+        SockFlag::empty(),
+        Some(SockProtocol::NetlinkRoute),
+    )
+    .map_err(|_| Error::Internal)?; // TODO: proper errors
+    let sa_nl = NetlinkAddr::new(
+        0,
+        (RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR) as u32,
+    );
     bind(sockfd.as_raw_fd(), &sa_nl).map_err(|_| Error::Internal)?; // TODO: proper errors
-    let fd = sockfd.as_raw_fd();
-    println!("netlink socket on fd {}", fd);
+    let (pipe_rd, pipe_wr) = pipe().map_err(|_| Error::Internal)?;
+
+    let mut prev_list = List::default();
+    let mut handle_update = move |new_list: List| {
+        if new_list == prev_list {
+            return;
+        }
+        let update = Update {
+            interfaces: new_list.0.clone(),
+            diff: new_list.diff_from(&prev_list),
+        };
+        (callback)(update);
+        prev_list = new_list;
+    };
+
+    // Now that netlink socket is open, provide an initial update.
+    // By having this outside the thread we can return an error synchronously if it
+    // looks like we're going to have trouble listing interfaces.
+    handle_update(crate::list::list_interfaces()?);
 
     std::thread::spawn(move || {
-        println!("watch thread running");
         let mut buf = [0u8; 4096];
-        // recvmsg?
-        while let Ok(n) = recv(fd, &mut buf, MsgFlags::empty()) {
-            println!("something on the netlink socket: {} bytes", n);
-            let nlmsg_ptr = &buf as *const _ as *const nlmsghdr;
-            let nlmsg = unsafe { &*nlmsg_ptr };
-            // Right conventionally there's some trick here involving macros NLMSG_OK
-            // I can presumably do this using NetlinkGeneric too
-            // It's unclear whether this is worse or not - need to know what those macros do
+
+        loop {
+            let mut fds = [
+                pollfd {
+                    fd: sockfd.as_raw_fd(),
+                    events: POLLIN,
+                    revents: 0,
+                },
+                pollfd {
+                    fd: pipe_rd.as_raw_fd(),
+                    events: POLLIN,
+                    revents: 0,
+                },
+            ];
+            unsafe {
+                poll(&mut fds as *mut _, 2, -1);
+            }
+            if fds[0].revents != 0 {
+                // netlink socket had something happen
+                if recv(sockfd.as_raw_fd(), &mut buf, MsgFlags::empty()).is_ok() {
+                    let Ok(new_list) = crate::list::list_interfaces() else {
+                        continue;
+                    };
+                    handle_update(new_list);
+                }
+            }
+            if fds[1].revents != 0 {
+                // pipe had something happen
+                break;
+            }
         }
-        println!("netlink recv thread terminating");
     });
-    
-    Ok(sockfd)
+
+    Ok(pipe_wr)
 }