]> code.octet-stream.net Git - netwatcher/blobdiff - src/watch_linux.rs
Make example program test dropping the WatchHandle
[netwatcher] / src / watch_linux.rs
index c1f08a25b40d1a973bec553041a24bc77105a486..dc9957211a698fc4269911e4adf674bfd4f8cfc4 100644 (file)
@@ -1,5 +1,6 @@
 use std::os::fd::AsRawFd;
 use std::os::fd::OwnedFd;
 use std::os::fd::AsRawFd;
 use std::os::fd::OwnedFd;
+use std::sync::mpsc;
 
 use nix::libc::poll;
 use nix::libc::pollfd;
 
 use nix::libc::poll;
 use nix::libc::pollfd;
@@ -10,9 +11,9 @@ use nix::sys::socket::socket;
 use nix::sys::socket::AddressFamily;
 use nix::sys::socket::MsgFlags;
 use nix::sys::socket::NetlinkAddr;
 use nix::sys::socket::AddressFamily;
 use nix::sys::socket::MsgFlags;
 use nix::sys::socket::NetlinkAddr;
-use nix::sys::socket::SockFlag;
 use nix::sys::socket::SockProtocol;
 use nix::sys::socket::SockType;
 use nix::sys::socket::SockProtocol;
 use nix::sys::socket::SockType;
+use nix::sys::socket::SOCK_NONBLOCK;
 use nix::unistd::pipe;
 
 use crate::Error;
 use nix::unistd::pipe;
 
 use crate::Error;
@@ -24,33 +25,47 @@ const RTMGRP_IPV6_IFADDR: u32 = 0x20;
 const RTMGRP_LINK: u32 = 0x01;
 
 pub(crate) struct WatchHandle {
 const RTMGRP_LINK: u32 = 0x01;
 
 pub(crate) struct WatchHandle {
-    // Dropping will close the fd which will be detected by poll
-    _pipefd: OwnedFd,
+    // Close on drop, which will be detected by poll in background thread
+    pipefd: Option<OwnedFd>,
+
+    // Detect when thread has completed
+    complete: Option<mpsc::Receiver<()>>,
+}
+
+impl Drop for WatchHandle {
+    fn drop(&mut self) {
+        drop(self.pipefd.take());
+        let _ = self.complete.take().recv();
+    }
 }
 
 pub(crate) fn watch_interfaces<F: FnMut(Update) + Send + 'static>(
     callback: F,
 ) -> Result<WatchHandle, Error> {
 }
 
 pub(crate) fn watch_interfaces<F: FnMut(Update) + Send + 'static>(
     callback: F,
 ) -> Result<WatchHandle, Error> {
-    let pipefd = start_watcher_thread(callback)?;
-    Ok(WatchHandle { _pipefd: pipefd })
+    let (pipefd, complete) = start_watcher_thread(callback)?;
+    Ok(WatchHandle {
+        pipefd: Some(pipefd),
+        complete: Some(complete),
+    })
 }
 
 fn start_watcher_thread<F: FnMut(Update) + Send + 'static>(
     mut callback: F,
 }
 
 fn start_watcher_thread<F: FnMut(Update) + Send + 'static>(
     mut callback: F,
-) -> Result<OwnedFd, Error> {
+) -> Result<(OwnedFd, mpsc::Receiver<()>), Error> {
     let sockfd = socket(
         AddressFamily::Netlink,
         SockType::Raw,
     let sockfd = socket(
         AddressFamily::Netlink,
         SockType::Raw,
-        SockFlag::empty(),
+        SOCK_NONBLOCK,
         Some(SockProtocol::NetlinkRoute),
     )
         Some(SockProtocol::NetlinkRoute),
     )
-    .map_err(|_| Error::Internal)?; // TODO: proper errors
+    .map_err(|e| Error::CreateSocket(e.to_string()))?;
+    sockfd.set_nonblocking(true);
     let sa_nl = NetlinkAddr::new(
         0,
         (RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR) as u32,
     );
     let sa_nl = NetlinkAddr::new(
         0,
         (RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR) as u32,
     );
-    bind(sockfd.as_raw_fd(), &sa_nl).map_err(|_| Error::Internal)?; // TODO: proper errors
-    let (pipe_rd, pipe_wr) = pipe().map_err(|_| Error::Internal)?;
+    bind(sockfd.as_raw_fd(), &sa_nl).map_err(|e| Error::Bind(e.to_string()))?;
+    let (pipe_rd, pipe_wr) = pipe().map_err(|e| Error::CreatePipe(e.to_string()))?;
 
     let mut prev_list = List::default();
     let mut handle_update = move |new_list: List| {
 
     let mut prev_list = List::default();
     let mut handle_update = move |new_list: List| {
@@ -70,6 +85,8 @@ fn start_watcher_thread<F: FnMut(Update) + Send + 'static>(
     // looks like we're going to have trouble listing interfaces.
     handle_update(crate::list::list_interfaces()?);
 
     // looks like we're going to have trouble listing interfaces.
     handle_update(crate::list::list_interfaces()?);
 
+    let (complete_tx, complete_rx) = mpsc::channel();
+
     std::thread::spawn(move || {
         let mut buf = [0u8; 4096];
 
     std::thread::spawn(move || {
         let mut buf = [0u8; 4096];
 
@@ -103,6 +120,8 @@ fn start_watcher_thread<F: FnMut(Update) + Send + 'static>(
                 break;
             }
         }
                 break;
             }
         }
+
+        drop(complete_tx);
     });
 
     Ok(pipe_wr)
     });
 
     Ok(pipe_wr)