// virtual_desktop_manager/change_elevation.rs

use deelevate::{Command, PrivilegeLevel, Token};
use std::{
    any::Any,
    borrow::Cow,
    ffi::OsString,
    io::{Read, Result as IoResult},
    net::{Shutdown, TcpStream},
    sync::{
        atomic::{fence, AtomicBool, Ordering},
        mpsc::{self, TryRecvError},
        OnceLock,
    },
    time::Duration,
};

/// Type-erased stand-in for `deelevate::process::Process`, which is private to
/// `deelevate` and therefore cannot be named here.
///
/// The concrete handle is kept in `state`, and the two operations this module needs
/// are stored as boxed closures. `into_any` erases the handle type so that any handle
/// can be held as the default `Process<Box<dyn Any>>`.
// The boxed `Fn` field types trip `type_complexity`, but spelling them out is the
// whole point of this wrapper, so the lint is silenced deliberately.
#[allow(clippy::type_complexity)]
struct Process<T = Box<dyn Any>> {
    /// The wrapped process handle.
    state: T,
    /// Wait for the specified duration (in milliseconds!) to pass.
    /// Use None to wait forever.
    wait_for: Box<dyn Fn(&T, Option<u32>) -> IoResult<u32>>,
    /// Retrieves the exit code from the process
    exit_code: Box<dyn Fn(&T) -> IoResult<u32>>,
}
impl<T> Process<T> {
    /// Erases the state type, turning a `Process<T>` into a `Process<Box<dyn Any>>`.
    ///
    /// The new closures downcast the boxed state back to `T` and then call the original
    /// closures, so `wait_for` and `exit_code` behave exactly as before.
    fn into_any(self) -> Process
    where
        T: Any,
    {
        // Destructure so each closure owns exactly one field, independent of the
        // edition-specific rules for capturing `self.field` in closures.
        let Self {
            state,
            wait_for: typed_wait_for,
            exit_code: typed_exit_code,
        } = self;
        Process {
            state: Box::new(state),
            wait_for: Box::new(move |erased, duration| {
                // Invariant: the boxed state was created from a `T` just above.
                let state = erased
                    .downcast_ref::<T>()
                    .expect("type-erased process state must still be a `T`");
                typed_wait_for(state, duration)
            }),
            exit_code: Box::new(move |erased| {
                // Invariant: the boxed state was created from a `T` just above.
                let state = erased
                    .downcast_ref::<T>()
                    .expect("type-erased process state must still be a `T`");
                typed_exit_code(state)
            }),
        }
    }

    /// Wait for the specified duration (in milliseconds!) to pass.
    /// Use None to wait forever.
    pub fn wait_for(&self, duration: Option<u32>) -> IoResult<u32> {
        (self.wait_for)(&self.state, duration)
    }

    /// Retrieves the exit code from the process
    pub fn exit_code(&self) -> IoResult<u32> {
        (self.exit_code)(&self.state)
    }
}
/// Wraps a `deelevate` process handle in the type-erased `Process` wrapper.
///
/// The argument may be any expression whose type has `wait_for(Option<u32>)` and
/// `exit_code()` methods returning `IoResult<u32>`. The macro evaluates to a
/// `Process<Box<dyn Any>>`.
macro_rules! into_process {
    ($handle:expr) => {{
        let typed = Process {
            state: $handle,
            wait_for: Box::new(|proc_ref, timeout_ms| proc_ref.wait_for(timeout_ms)),
            exit_code: Box::new(|proc_ref| proc_ref.exit_code()),
        };
        typed.into_any()
    }};
}

/// Application hooks used by [`set_elevation`] when relaunching the program.
///
/// `Send` is required because `set_elevation` calls `confirm_message` and `exit`
/// from a separate scoped thread.
pub trait SetElevationHandler: Send {
    /// Arguments for the relaunched process, not including the executable path
    /// (`set_elevation` inserts that as the first argument).
    ///
    /// `port` is the loopback TCP port that the relaunched process must connect to
    /// in order to confirm that it started.
    fn get_args(&mut self, port: u16) -> Vec<OsString>;

    /// The exact bytes the relaunched process must send over that TCP connection.
    /// Any other data is reported as an error.
    fn confirm_message(&mut self) -> Cow<'_, [u8]>;

    /// Called after the confirmation message has been received; never returns.
    fn exit(&mut self) -> !;
}

68pub fn set_elevation(
69    app: &mut dyn SetElevationHandler,
70    should_elevate: bool,
71) -> Result<(), String> {
72    let token = Token::with_current_process()
73        .map_err(|e| format!("failed to get token for current process: {e}"))?;
74    let level = token
75        .privilege_level()
76        .map_err(|e| format!("failed to get privilege level of token: {e}"))?;
77
78    let target_token = if should_elevate {
79        match level {
80            PrivilegeLevel::NotPrivileged => token
81                .as_medium_integrity_safer_token()
82                .map_err(|e| format!("failed to create token with elevated privilege: {e}"))?,
83            PrivilegeLevel::HighIntegrityAdmin | PrivilegeLevel::Elevated => return Ok(()),
84        }
85    } else {
86        match level {
87            PrivilegeLevel::NotPrivileged => return Ok(()),
88            PrivilegeLevel::Elevated => Token::with_shell_process()
89                .map_err(|e| format!("failed to find token for shell process: {e}"))?,
90            PrivilegeLevel::HighIntegrityAdmin => token
91                .as_medium_integrity_safer_token()
92                .map_err(|e| format!("failed to change privilege level of token: {e}"))?,
93        }
94    };
95
96    let mut command = Command::with_environment_for_token(&target_token)
97        .map_err(|e| format!("failed to create environment for child process: {e}"))?;
98
99    let current_exe = std::env::current_exe()
100        .map_err(|e| format!("failed to resolve path to current executable: {e}"))?;
101
102    let tcp = std::net::TcpListener::bind("127.0.0.1:0")
103        .map_err(|e| format!("failed to open a local TCP connection: {e}"))?;
104    let addr = tcp
105        .local_addr()
106        .map_err(|e| format!("failed to get info about local TCP connection: {e}"))?;
107
108    command.set_argv({
109        let mut args = app.get_args(addr.port());
110        args.insert(0, OsString::from(current_exe));
111        args
112    });
113    let proc = if should_elevate {
114        command
115            .shell_execute("runas")
116            .map_err(|e| format!("failed to spawn elevated process: {e}"))?
117    } else {
118        command
119            .spawn_with_token(&target_token)
120            .map_err(|e| format!("failed to spawn child process: {e}"))?
121    };
122    let proc = into_process!(proc);
123
124    let cancel = AtomicBool::new(false);
125    let shared_stream = OnceLock::<TcpStream>::new();
126    std::thread::scope(|s| -> Result<(), String> {
127        let (tx, rx) = mpsc::channel();
128        let handle = s.spawn(|| {
129            let tx = tx;
130            let (stream, _addr) = match tcp.accept() {
131                Ok(v) => v,
132                Err(e) => {
133                    let _ = tx.send(format!("failed to accept TCP connection: {e}"));
134                    return
135                }
136            };
137
138            let _ = shared_stream.set(stream);
139            let mut stream = shared_stream.get().unwrap();
140
141            if cancel.load(Ordering::Acquire) {
142                return;
143            }
144
145            let confirm_msg = app.confirm_message();
146            let mut data = vec![0; confirm_msg.len()];
147            if let Err(e) = stream.read_exact(&mut data) {
148                let _ = tx.send(format!("failed to read from TCP stream: {e}"));
149                return;
150            }
151
152            if data.as_slice() != &*confirm_msg {
153                let _ = tx.send(format!(
154                    "Invalid data sent over TCP stream while waiting for restart confirmation message: {}",
155                    String::from_utf8_lossy(&data)
156                ));
157                return;
158            }
159
160            app.exit();
161        });
162        let wait_result = loop {
163            match proc.wait_for(Some(1000)) {
164                // Child process started:
165                Ok(_code) => break Ok(()),
166                // Handle errors in the other thread:
167                Err(e) if e.kind() == std::io::ErrorKind::TimedOut => match rx.try_recv() {
168                    Ok(err) => {
169                        handle.join().unwrap(); // Wait for thread to exit...
170                        return Err(err); // Then return its error
171                    }
172                    // Other thread exited unexpectedly:
173                    Err(TryRecvError::Disconnected) => {
174                        return Err(
175                            "Failed to wait for restarted process to confirm it had started"
176                                .to_string(),
177                        )
178                    }
179                    // Other thread hasn't made progress:
180                    Err(TryRecvError::Empty) => {}
181                },
182                Err(e) => break Err(format!("failed to wait for child process to exit: {e}")),
183            }
184        };
185        // Child process exited or we timed out (failed to change elevation!)
186
187        // Notify other thread to exit:
188        cancel.store(true, Ordering::Release);
189        if let Some(stream) = shared_stream.get() {
190            // Cancel work on the other thread.
191            let _ = stream.shutdown(Shutdown::Both);
192        } else {
193            // Connect to the listener so that it unblocks:
194            let _ = TcpStream::connect_timeout(&addr, Duration::from_millis(3000));
195        }
196        // Attempt to wait for other thread to exit:
197        let _ = rx.recv_timeout(Duration::from_millis(10_000));
198
199        wait_result
200    })?;
201
202    let code = proc
203        .exit_code()
204        .map_err(|e| format!("failed to get exit code for child process: {e}"))?;
205
206    Err(format!("Failed to spawn child process (exit code: {code})"))
207}
208
209pub struct AdminRestart;
210impl AdminRestart {
211    const RESTARTED_ARG: &'static str = "restarted";
212    const RESTART_TCP_MSG: &'static str = "restarted-backup-manager";
213
214    pub fn handle_startup(&self) {
215        if std::env::args().nth(1).as_deref() == Some(Self::RESTARTED_ARG) {
216            use std::io::Write;
217
218            tracing::info!(
219                args = ?std::env::args().skip(2).collect::<Vec<_>>(),
220                "Program was restarted"
221            );
222
223            let port: u16 = std::env::args()
224                .nth(2)
225                .expect("2nd arg should be a port number")
226                .parse()
227                .expect("2nd arg should be a 16bit number");
228
229            tracing::debug!(
230                "Notifying parent process at port {port} that we have successfully started"
231            );
232
233            let mut stream = std::net::TcpStream::connect_timeout(
234                &([127, 0, 0, 1], port).into(),
235                std::time::Duration::from_millis(1500),
236            )
237            .expect("failed to connect to parent process");
238
239            tracing::trace!("Writing message to parent process to confirm that we have started");
240
241            stream
242                .write_all(Self::RESTART_TCP_MSG.as_bytes())
243                .expect("failed to write data to parent process");
244
245            drop(stream);
246            // Wait for parent process to exit (only one instance of the app
247            // should be running):
248            std::thread::sleep(std::time::Duration::from_millis(500));
249        }
250    }
251}
252impl SetElevationHandler for AdminRestart {
253    fn get_args(&mut self, port: u16) -> Vec<OsString> {
254        vec![
255            OsString::from(Self::RESTARTED_ARG),
256            OsString::from(port.to_string()),
257        ]
258    }
259
260    fn exit(&mut self) -> ! {
261        std::process::exit(0);
262    }
263
264    fn confirm_message(&mut self) -> Cow<'_, [u8]> {
265        Cow::Borrowed(Self::RESTART_TCP_MSG.as_bytes())
266    }
267}