virtual_desktop_manager/
change_elevation.rs1use deelevate::{Command, PrivilegeLevel, Token};
2use std::{
3 any::Any,
4 borrow::Cow,
5 ffi::OsString,
6 io::{Read, Result as IoResult},
7 net::{Shutdown, TcpStream},
8 sync::{
9 atomic::{AtomicBool, Ordering},
10 mpsc::{self, TryRecvError},
11 OnceLock,
12 },
13 time::Duration,
14};
15
/// A process handle paired with the operations that can be performed on it.
///
/// The operations are stored as boxed closures so that handles of different
/// concrete types can be used through the same `Process` type once the state
/// has been erased with [`Process::into_any`].
#[allow(clippy::type_complexity)]
struct Process<T = Box<dyn Any>> {
    /// The wrapped handle; a `Box<dyn Any>` after [`Process::into_any`].
    state: T,
    /// Called by [`Process::wait_for`] with the state and the caller's
    /// optional duration.
    wait_for: Box<dyn Fn(&T, Option<u32>) -> IoResult<u32>>,
    /// Called by [`Process::exit_code`] with the state.
    exit_code: Box<dyn Fn(&T) -> IoResult<u32>>,
}
impl<T> Process<T> {
    /// Erases the concrete state type, producing a `Process<Box<dyn Any>>`.
    ///
    /// The returned closures downcast the boxed state back to `T` before
    /// forwarding to the original closures. The downcast cannot fail because
    /// the box is created here from a `T`.
    fn into_any(self) -> Process
    where
        T: Any,
    {
        let Process {
            state,
            wait_for: typed_wait,
            exit_code: typed_exit,
        } = self;
        let erased_state: Box<dyn Any> = Box::new(state);

        Process {
            state: erased_state,
            wait_for: Box::new(move |erased, duration| {
                let concrete = erased.downcast_ref::<T>().unwrap();
                typed_wait(concrete, duration)
            }),
            exit_code: Box::new(move |erased| {
                let concrete = erased.downcast_ref::<T>().unwrap();
                typed_exit(concrete)
            }),
        }
    }

    /// Invokes the stored `wait_for` closure with the state and `duration`,
    /// returning whatever the closure returns.
    pub fn wait_for(&self, duration: Option<u32>) -> IoResult<u32> {
        let Self {
            state, wait_for, ..
        } = self;
        wait_for(state, duration)
    }

    /// Invokes the stored `exit_code` closure with the state, returning
    /// whatever the closure returns.
    pub fn exit_code(&self) -> IoResult<u32> {
        let Self {
            state, exit_code, ..
        } = self;
        exit_code(state)
    }
}
/// Wraps a handle that has `wait_for(duration)` and `exit_code()` methods in
/// a `Process` and erases its concrete type via `into_any`.
macro_rules! into_process {
    ($handle:expr) => {{
        // `state` is listed first so the closure parameter types can be
        // inferred from the handle's type.
        let typed = Process {
            state: $handle,
            wait_for: Box::new(|child, timeout| child.wait_for(timeout)),
            exit_code: Box::new(|child| child.exit_code()),
        };
        typed.into_any()
    }};
}
61
/// Application hooks used by [`set_elevation`] while restarting the program
/// with a different privilege level.
pub trait SetElevationHandler: Send {
    /// Returns the arguments for the restarted process. `set_elevation`
    /// inserts the path of the current executable in front of them.
    ///
    /// `port` is the local TCP port `set_elevation` listens on for the
    /// confirmation message.
    fn get_args(&mut self, port: u16) -> Vec<OsString>;
    /// Terminates the current process; never returns. `set_elevation` calls
    /// this after it has received the expected confirmation message.
    fn exit(&mut self) -> !;
    /// Returns the exact bytes `set_elevation` expects to read from the TCP
    /// connection before it calls [`SetElevationHandler::exit`].
    fn confirm_message(&mut self) -> Cow<'_, [u8]>;
}
67
68pub fn set_elevation(
69 app: &mut dyn SetElevationHandler,
70 should_elevate: bool,
71) -> Result<(), String> {
72 let token = Token::with_current_process()
73 .map_err(|e| format!("failed to get token for current process: {e}"))?;
74 let level = token
75 .privilege_level()
76 .map_err(|e| format!("failed to get privilege level of token: {e}"))?;
77
78 let target_token = if should_elevate {
79 match level {
80 PrivilegeLevel::NotPrivileged => token
81 .as_medium_integrity_safer_token()
82 .map_err(|e| format!("failed to create token with elevated privilege: {e}"))?,
83 PrivilegeLevel::HighIntegrityAdmin | PrivilegeLevel::Elevated => return Ok(()),
84 }
85 } else {
86 match level {
87 PrivilegeLevel::NotPrivileged => return Ok(()),
88 PrivilegeLevel::Elevated => Token::with_shell_process()
89 .map_err(|e| format!("failed to find token for shell process: {e}"))?,
90 PrivilegeLevel::HighIntegrityAdmin => token
91 .as_medium_integrity_safer_token()
92 .map_err(|e| format!("failed to change privilege level of token: {e}"))?,
93 }
94 };
95
96 let mut command = Command::with_environment_for_token(&target_token)
97 .map_err(|e| format!("failed to create environment for child process: {e}"))?;
98
99 let current_exe = std::env::current_exe()
100 .map_err(|e| format!("failed to resolve path to current executable: {e}"))?;
101
102 let tcp = std::net::TcpListener::bind("127.0.0.1:0")
103 .map_err(|e| format!("failed to open a local TCP connection: {e}"))?;
104 let addr = tcp
105 .local_addr()
106 .map_err(|e| format!("failed to get info about local TCP connection: {e}"))?;
107
108 command.set_argv({
109 let mut args = app.get_args(addr.port());
110 args.insert(0, OsString::from(current_exe));
111 args
112 });
113 let proc = if should_elevate {
114 command
115 .shell_execute("runas")
116 .map_err(|e| format!("failed to spawn elevated process: {e}"))?
117 } else {
118 command
119 .spawn_with_token(&target_token)
120 .map_err(|e| format!("failed to spawn child process: {e}"))?
121 };
122 let proc = into_process!(proc);
123
124 let cancel = AtomicBool::new(false);
125 let shared_stream = OnceLock::<TcpStream>::new();
126 std::thread::scope(|s| -> Result<(), String> {
127 let (tx, rx) = mpsc::channel();
128 let handle = s.spawn(|| {
129 let tx = tx;
130 let (stream, _addr) = match tcp.accept() {
131 Ok(v) => v,
132 Err(e) => {
133 let _ = tx.send(format!("failed to accept TCP connection: {e}"));
134 return
135 }
136 };
137
138 let _ = shared_stream.set(stream);
139 let mut stream = shared_stream.get().unwrap();
140
141 if cancel.load(Ordering::Acquire) {
142 return;
143 }
144
145 let confirm_msg = app.confirm_message();
146 let mut data = vec![0; confirm_msg.len()];
147 if let Err(e) = stream.read_exact(&mut data) {
148 let _ = tx.send(format!("failed to read from TCP stream: {e}"));
149 return;
150 }
151
152 if data.as_slice() != &*confirm_msg {
153 let _ = tx.send(format!(
154 "Invalid data sent over TCP stream while waiting for restart confirmation message: {}",
155 String::from_utf8_lossy(&data)
156 ));
157 return;
158 }
159
160 app.exit();
161 });
162 let wait_result = loop {
163 match proc.wait_for(Some(1000)) {
164 Ok(_code) => break Ok(()),
166 Err(e) if e.kind() == std::io::ErrorKind::TimedOut => match rx.try_recv() {
168 Ok(err) => {
169 handle.join().unwrap(); return Err(err); }
172 Err(TryRecvError::Disconnected) => {
174 return Err(
175 "Failed to wait for restarted process to confirm it had started"
176 .to_string(),
177 )
178 }
179 Err(TryRecvError::Empty) => {}
181 },
182 Err(e) => break Err(format!("failed to wait for child process to exit: {e}")),
183 }
184 };
185 cancel.store(true, Ordering::Release);
189 if let Some(stream) = shared_stream.get() {
190 let _ = stream.shutdown(Shutdown::Both);
192 } else {
193 let _ = TcpStream::connect_timeout(&addr, Duration::from_millis(3000));
195 }
196 let _ = rx.recv_timeout(Duration::from_millis(10_000));
198
199 wait_result
200 })?;
201
202 let code = proc
203 .exit_code()
204 .map_err(|e| format!("failed to get exit code for child process: {e}"))?;
205
206 Err(format!("Failed to spawn child process (exit code: {code})"))
207}
208
209pub struct AdminRestart;
210impl AdminRestart {
211 const RESTARTED_ARG: &'static str = "restarted";
212 const RESTART_TCP_MSG: &'static str = "restarted-backup-manager";
213
214 pub fn handle_startup(&self) {
215 if std::env::args().nth(1).as_deref() == Some(Self::RESTARTED_ARG) {
216 use std::io::Write;
217
218 tracing::info!(
219 args = ?std::env::args().skip(2).collect::<Vec<_>>(),
220 "Program was restarted"
221 );
222
223 let port: u16 = std::env::args()
224 .nth(2)
225 .expect("2nd arg should be a port number")
226 .parse()
227 .expect("2nd arg should be a 16bit number");
228
229 tracing::debug!(
230 "Notifying parent process at port {port} that we have successfully started"
231 );
232
233 let mut stream = std::net::TcpStream::connect_timeout(
234 &([127, 0, 0, 1], port).into(),
235 std::time::Duration::from_millis(1500),
236 )
237 .expect("failed to connect to parent process");
238
239 tracing::trace!("Writing message to parent process to confirm that we have started");
240
241 stream
242 .write_all(Self::RESTART_TCP_MSG.as_bytes())
243 .expect("failed to write data to parent process");
244
245 drop(stream);
246 std::thread::sleep(std::time::Duration::from_millis(500));
249 }
250 }
251}
252impl SetElevationHandler for AdminRestart {
253 fn get_args(&mut self, port: u16) -> Vec<OsString> {
254 vec![
255 OsString::from(Self::RESTARTED_ARG),
256 OsString::from(port.to_string()),
257 ]
258 }
259
260 fn exit(&mut self) -> ! {
261 std::process::exit(0);
262 }
263
264 fn confirm_message(&mut self) -> Cow<'_, [u8]> {
265 Cow::Borrowed(Self::RESTART_TCP_MSG.as_bytes())
266 }
267}