diff --git a/src/lib.rs b/src/lib.rs index 2087fc8..2c2986a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,15 +3,15 @@ //! You can use [`trace_command`] to execute and sample an [`std::process::Command`]. //! //! Or you can use [`trace_child`] to start tracing an [`std::process::Child`]. -// You can also trace an arbitrary process using [`trace_pid`]. +//! You can also trace an arbitrary process using [`trace_pid`]. #![allow(clippy::field_reassign_with_default)] use object::Object; use windows::core::{GUID, PCSTR, PSTR}; use windows::Win32::Foundation::{ - CloseHandle, DuplicateHandle, GetLastError, DUPLICATE_SAME_ACCESS, ERROR_SUCCESS, - ERROR_WMI_INSTANCE_NOT_FOUND, HANDLE, INVALID_HANDLE_VALUE, WIN32_ERROR, + CloseHandle, GetLastError, ERROR_SUCCESS, ERROR_WMI_INSTANCE_NOT_FOUND, HANDLE, + INVALID_HANDLE_VALUE, WIN32_ERROR, }; use windows::Win32::Security::{ AdjustTokenPrivileges, LookupPrivilegeValueW, SE_PRIVILEGE_ENABLED, TOKEN_ADJUST_PRIVILEGES, @@ -32,8 +32,8 @@ use windows::Win32::System::Diagnostics::Etw::{ use windows::Win32::System::SystemInformation::{GetVersionExA, OSVERSIONINFOA}; use windows::Win32::System::SystemServices::SE_SYSTEM_PROFILE_NAME; use windows::Win32::System::Threading::{ - GetCurrentProcess, GetCurrentThread, OpenProcessToken, SetThreadPriority, CREATE_SUSPENDED, - THREAD_PRIORITY_TIME_CRITICAL, + GetCurrentProcess, GetCurrentThread, OpenProcess, OpenProcessToken, SetThreadPriority, + WaitForSingleObject, CREATE_SUSPENDED, PROCESS_ALL_ACCESS, THREAD_PRIORITY_TIME_CRITICAL, }; use pdb_addr2line::{pdb::PDB, ContextPdbData}; @@ -41,7 +41,7 @@ use pdb_addr2line::{pdb::PDB, ContextPdbData}; use std::ffi::OsString; use std::io::{Read, Write}; use std::mem::size_of; -use std::os::windows::{ffi::OsStringExt, prelude::AsRawHandle}; +use std::os::windows::ffi::OsStringExt; use std::path::PathBuf; use std::ptr::{addr_of, addr_of_mut}; use std::sync::atomic::{AtomicBool, Ordering}; @@ -106,8 +106,10 @@ pub enum Error { Write(std::io::Error), /// Error spawning a suspended process SpawnErr(std::io::Error), - /// Error waiting for child - WaitOnChildErr(std::io::Error), + /// Error waiting for child, abandoned + WaitOnChildErrAbandoned, + /// Error waiting for child, timed out + WaitOnChildErrTimeout, /// A call to a windows API function returned an error and we didn't know how to handle it Other(WIN32_ERROR, String, &'static str), /// We require Windows 7 or greater @@ -146,24 +148,25 @@ fn get_last_error(extra: &'static str) -> Error { Error::Other(code, code_str.to_string(), extra) } -/// `h` must be a valid handle -unsafe fn clone_handle(h: HANDLE) -> Result { - let mut target_h = HANDLE::default(); - let ret = DuplicateHandle( - GetCurrentProcess(), - h, - GetCurrentProcess(), - &mut target_h, - 0, - false, - DUPLICATE_SAME_ACCESS, - ); - if ret.0 == 0 { - return Err(get_last_error("clone_handle")); +/// A wrapper around `OpenProcess` that returns a handle with all access rights +unsafe fn handle_from_process_id(process_id: u32) -> Result { + match OpenProcess(PROCESS_ALL_ACCESS, false, process_id) { + Ok(handle) => Ok(handle), + Err(_) => Err(get_last_error("handle_from_process_id")), } - Ok(target_h) } -fn acquire_priviledges() -> Result<()> { + +unsafe fn wait_for_process_by_handle(handle: HANDLE) -> Result<()> { + let ret = WaitForSingleObject(handle, 0xFFFFFFFF); + match ret.0 { + 0 => Ok(()), + 0x00000080 => Err(Error::WaitOnChildErrAbandoned), + 0x00000102 => Err(Error::WaitOnChildErrTimeout), + _ => Err(get_last_error("wait_for_process_by_handle")), + } +} + +fn acquire_privileges() -> Result<()> { let mut privs = TOKEN_PRIVILEGES::default(); privs.PrivilegeCount = 1; privs.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED; @@ -195,8 +198,8 @@ fn acquire_priviledges() -> Result<()> { Ok(()) } /// SAFETY: is_suspended must only be true if `target_process` is suspended -unsafe fn trace_from_process( - target_process: &mut std::process::Child, +unsafe fn trace_from_process_id( + target_process_id: u32, is_suspended: bool, kernel_stacks: bool, ) -> Result { @@ -213,7 +216,7 @@ unsafe fn trace_from_process( { return Err(Error::UnsupportedOsVersion); } - acquire_priviledges()?; + acquire_privileges()?; // Set the sampling interval // Only for Win8 or more @@ -263,6 +266,7 @@ unsafe fn trace_from_process( const PROPS_SIZE: usize = size_of::() + KERNEL_LOGGER_NAMEA_LEN + 1; #[derive(Clone)] #[repr(C)] + #[allow(non_camel_case_types)] struct EVENT_TRACE_PROPERTIES_WITH_STRING { data: EVENT_TRACE_PROPERTIES, s: [u8; KERNEL_LOGGER_NAMEA_LEN + 1], @@ -341,10 +345,8 @@ unsafe fn trace_from_process( } } - let target_pid = target_process.id(); - // std Child closes the handle when it drops so we clone it - let target_proc_handle = clone_handle(HANDLE(target_process.as_raw_handle() as isize))?; - let mut context = TraceContext::new(target_proc_handle, target_pid, kernel_stacks)?; + let target_proc_handle = handle_from_process_id(target_process_id)?; + let mut context = TraceContext::new(target_proc_handle, target_process_id, kernel_stacks)?; //TODO: Do we need to Box the context? let mut log = EVENT_TRACE_LOGFILEA::default(); @@ -438,6 +440,7 @@ unsafe fn trace_from_process( #[repr(C)] #[derive(Debug)] #[allow(non_snake_case)] + #[allow(non_camel_case_types)] struct EVENT_HEADERR { Size: u16, HeaderType: u16, @@ -456,6 +459,7 @@ unsafe fn trace_from_process( #[repr(C)] #[derive(Debug)] #[allow(non_snake_case)] + #[allow(non_camel_case_types)] struct EVENT_RECORDD { EventHeader: EVENT_HEADERR, BufferContextAnonymousProcessorNumber: u8, @@ -520,8 +524,9 @@ unsafe fn trace_from_process( std::mem::transmute(NtResumeProcess); NtResumeProcess(context.target_process_handle.0); } + // Wait for it to end - target_process.wait().map_err(Error::WaitOnChildErr)?; + wait_for_process_by_handle(target_proc_handle)?; // This unblocks ProcessTrace let ret = ControlTraceA( ::default(), @@ -552,14 +557,18 @@ unsafe fn trace_from_process( /// The sampled results from a process execution pub struct CollectionResults(TraceContext); +/// Trace an existing child process based only on its process ID (pid). +/// It is recommended that you use `trace_command` instead, since it suspends the process on creation +/// and only resumes it after the trace has started, ensuring that all samples are captured. +pub fn trace_pid(process_id: u32, kernel_stacks: bool) -> Result { + let res = unsafe { trace_from_process_id(process_id, false, kernel_stacks) }; + res.map(CollectionResults) +} /// Trace an existing child process. /// It is recommended that you use `trace_command` instead, since it suspends the process on creation /// and only resumes it after the trace has started, ensuring that all samples are captured. -pub fn trace_child( - mut process: std::process::Child, - kernel_stacks: bool, -) -> Result { - let res = unsafe { trace_from_process(&mut process, false, kernel_stacks) }; +pub fn trace_child(process: std::process::Child, kernel_stacks: bool) -> Result { + let res = unsafe { trace_from_process_id(process.id(), false, kernel_stacks) }; res.map(CollectionResults) } /// Execute `command` and trace it, periodically collecting call stacks. @@ -578,7 +587,7 @@ pub fn trace_command( .creation_flags(CREATE_SUSPENDED.0) .spawn() .map_err(Error::SpawnErr)?; - let res = unsafe { trace_from_process(&mut proc, true, kernel_stacks) }; + let res = unsafe { trace_from_process_id(proc.id(), true, kernel_stacks) }; if res.is_err() { // Kill the suspended process if we had some kind of error let _ = proc.kill(); @@ -718,7 +727,7 @@ impl<'a> CallStack<'a> { /// Iterate addresses in this callstack /// /// This also performs symbol resolution if possible, and tries to find the image (DLL/EXE) it comes from - fn iter_resolved_addresses2< + fn iter_resolved_addresses< F: for<'b> FnMut(u64, u64, &'b [&'b str], Option<&'b str>) -> Result<()>, >( &'a self, @@ -742,7 +751,7 @@ impl<'a> CallStack<'a> { } let mut symbol_names = symbol_names_storage; - let module = pdb_db.range(..addr).rev().next(); + let module = pdb_db.range(..addr).next_back(); let module = match module { None => { f(addr, 0, &[], None)?; @@ -790,7 +799,7 @@ impl CollectionResults { let mut v = vec![]; for callstack in self.iter_callstacks() { - callstack.iter_resolved_addresses2( + callstack.iter_resolved_addresses( &pdb_db, &mut v, |address, displacement, symbol_names, image_name| { @@ -882,6 +891,7 @@ fn list_kernel_modules() -> Vec<(OsString, u64, u64)> { #[repr(C)] #[derive(Debug)] #[allow(non_snake_case)] + #[allow(non_camel_case_types)] struct _RTL_PROCESS_MODULE_INFORMATION { Section: *mut std::ffi::c_void, MappedBase: *mut std::ffi::c_void,