Skip to content

Commit

Permalink
Refactored service connection handling. #291
Browse files Browse the repository at this point in the history
  • Loading branch information
gerardog committed Aug 27, 2023
1 parent 449f048 commit ad18727
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 191 deletions.
26 changes: 18 additions & 8 deletions src/gsudo/Commands/RunCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,28 @@ private async Task<int> RunUsingService(ElevationRequest elevationRequest)
try
{
var callingPid = ProcessHelper.GetCallerPid();

Logger.Instance.Log($"Caller PID: {callingPid}", LogLevel.Debug);

connection = await ServiceHelper.Connect().ConfigureAwait(false);
var serviceLocation = await ServiceHelper.FindAnyServiceFast().ConfigureAwait(false);
if (serviceLocation == null)
{
var serviceHandle = ServiceHelper.StartService(callingPid, singleUse: InputArguments.KillCache);
serviceLocation = await ServiceHelper.WaitForNewService(callingPid).ConfigureAwait(false);
}

if (connection == null) // service is not running or listening.
if (!InputArguments.IntegrityLevel.HasValue)
{
var service = ServiceHelper.StartService(callingPid, singleUse: InputArguments.KillCache);
connection = await ServiceHelper.Connect(callingPid).ConfigureAwait(false);
// This is the edge case where user does `gsudo -u SomeOne` and we dont know if SomeOne can elevate or not.
elevationRequest.IntegrityLevel = serviceLocation.IsHighIntegrity ? IntegrityLevel.High : IntegrityLevel.Medium;
}

if (serviceLocation==null)
throw new ApplicationException("Unable to connect to the elevated service.");

if (connection == null) // still not listening.
throw new ApplicationException("Unable to connect to the elevated service.");
connection = await ServiceHelper.Connect(serviceLocation).ConfigureAwait(false);
if (connection == null) // service is not running or listening.
{
throw new ApplicationException("Unable to connect to the elevated service.");
}

var renderer = GetRenderer(connection, elevationRequest);
Expand All @@ -133,7 +143,7 @@ private static int RunWithoutService(ElevationRequest elevationRequest)
// No need to escalate. Run in-process
Native.ConsoleApi.SetConsoleCtrlHandler(ConsoleHelper.IgnoreConsoleCancelKeyPress, true);

ConsoleHelper.SetPrompt(elevationRequest, InputArguments.GetIntegrityLevel() >= IntegrityLevel.High);
ConsoleHelper.SetPrompt(elevationRequest);

if (sameIntegrity)
{
Expand Down
2 changes: 1 addition & 1 deletion src/gsudo/Commands/ServiceCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ private async Task AcceptConnection(Connection connection)
ServiceHelper.StartService(AllowedPid, CacheDuration, AllowedSid, SingleUse);
}

ConsoleHelper.SetPrompt(request, connection.IsHighIntegrity);
ConsoleHelper.SetPrompt(request);
await applicationHost.Start(connection, request).ConfigureAwait(false);

//if (replaceService)
Expand Down
4 changes: 2 additions & 2 deletions src/gsudo/Helpers/ConsoleHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ internal static SecureString ReadConsolePassword(string userName)
return pass;
}

internal static void SetPrompt(ElevationRequest elevationRequest, bool isElevated)
internal static void SetPrompt(ElevationRequest elevationRequest)
{
if (!string.IsNullOrEmpty(elevationRequest.Prompt))
{
if (!isElevated)
if (elevationRequest.IntegrityLevel < IntegrityLevel.High)
Environment.SetEnvironmentVariable("PROMPT", Environment.GetEnvironmentVariable("PROMPT", EnvironmentVariableTarget.User) ?? Environment.GetEnvironmentVariable("PROMPT", EnvironmentVariableTarget.Machine) ?? "$P$G");
else
Environment.SetEnvironmentVariable("PROMPT", Environment.ExpandEnvironmentVariables(elevationRequest.Prompt));
Expand Down
266 changes: 174 additions & 92 deletions src/gsudo/Helpers/ServiceHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,66 +9,147 @@

namespace gsudo.Helpers
{
class ServiceLocation
{
public string PipeName { get; set; }
public bool IsHighIntegrity { get; set; }
}

internal static class ServiceHelper
{
internal static IRpcClient GetRpcClient()
{
// future Tcp implementations should be plugged here.
return new NamedPipeClient();
}

internal static async Task<Connection> Connect(int? callingPid = null)
{
IRpcClient rpcClient = GetRpcClient();

try
{
return await rpcClient.Connect(callingPid).ConfigureAwait(false);
}
catch (System.IO.IOException) { }
catch (TimeoutException) { }
catch (Exception ex)
{
if (callingPid.HasValue)
Logger.Instance.Log(ex.ToString(), LogLevel.Warning);
}

return null;
}

internal static SafeProcessHandle StartService(int? allowedPid, TimeSpan? cacheDuration = null, string allowedSid = null, bool singleUse = false)
}

/// <summary>
/// Establishes a connection to a named pipe server.
/// </summary>
/// <param name="clientPid">Optional client process ID.</param>
/// <returns>A <see cref="Connection"/> object representing the connected named pipe, or null if connection fails.</returns>
public static async Task<ServiceLocation> WaitForNewService(int clientPid)
{
var currentSid = WindowsIdentity.GetCurrent().User.Value;

allowedPid = allowedPid ?? Process.GetCurrentProcess().GetCacheableRootProcessId();
allowedSid = allowedSid ?? Process.GetProcessById(allowedPid.Value)?.GetProcessUser()?.User.Value ?? currentSid;

string verb;
SafeProcessHandle ret;

Logger.Instance.Log($"Caller SID: {allowedSid}", LogLevel.Debug);

int timeoutMilliseconds = 10000;
ServiceLocation service;

string user = WindowsIdentity.GetCurrent().User.Value;
do
{
service = FindServiceByIntegrity(clientPid, user);
if (service != null)
return service;

// Retry until service has started.
await Task.Delay(50).ConfigureAwait(false);
timeoutMilliseconds -= 50;
}
while (service == null && timeoutMilliseconds > 0);

return service;
}

public static async Task<ServiceLocation> FindAnyServiceFast()

Check warning on line 52 in src/gsudo/Helpers/ServiceHelper.cs

View workflow job for this annotation

GitHub Actions / Test

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.

Check warning on line 52 in src/gsudo/Helpers/ServiceHelper.cs

View workflow job for this annotation

GitHub Actions / build / Test

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.

Check warning on line 52 in src/gsudo/Helpers/ServiceHelper.cs

View workflow job for this annotation

GitHub Actions / Test

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
string user = WindowsIdentity.GetCurrent().User.Value;
var callerProcessId = Process.GetCurrentProcess().Id;
// Loop to search for a cache for the current process or its ancestors
int maxIterations = 20; // To avoid potential PID tree loops where an ancestor process has the same PID. (gerardog/gsudo#155)
while (callerProcessId > 0 && maxIterations-- > 0)
{
callerProcessId = ProcessHelper.GetParentProcessId(callerProcessId);

var service = FindServiceByIntegrity(callerProcessId, user);
if (service != null)
return service;
}
return null;
}

private static ServiceLocation FindServiceByIntegrity(int? clientPid, string user)
{
var anyIntegrity = InputArguments.UserName != null;
var tryHighIntegrity = !InputArguments.IntegrityLevel.HasValue || InputArguments.IntegrityLevel.Value >= IntegrityLevel.High;
var tryLowIntegrity = !InputArguments.IntegrityLevel.HasValue || InputArguments.IntegrityLevel.Value < IntegrityLevel.High;
if (tryHighIntegrity)
{
var pipeName = NamedPipeClient.TryGetServicePipe(user, clientPid.Value, true);
if (pipeName != null)
{
return new ServiceLocation
{
PipeName = pipeName,
IsHighIntegrity = true
};
}
}

if (tryLowIntegrity)
{
var pipeName = NamedPipeClient.TryGetServicePipe(user, clientPid.Value, false);
if (pipeName != null)
{
return new ServiceLocation
{
PipeName = pipeName,
IsHighIntegrity = false
};
}
}
return null;
}

internal static async Task<Connection> Connect(ServiceLocation service)
{
IRpcClient rpcClient = GetRpcClient();

try
{
return await rpcClient.Connect(service).ConfigureAwait(false);
}
catch (System.IO.IOException) { }
catch (TimeoutException) { }
catch (Exception ex)
{
Logger.Instance.Log(ex.ToString(), LogLevel.Warning);
}

return null;
}

internal static SafeProcessHandle StartService(int? allowedPid, TimeSpan? cacheDuration = null, string allowedSid = null, bool singleUse = false)
{
var currentSid = WindowsIdentity.GetCurrent().User.Value;

allowedPid = allowedPid ?? Process.GetCurrentProcess().GetCacheableRootProcessId();
allowedSid = allowedSid ?? Process.GetProcessById(allowedPid.Value)?.GetProcessUser()?.User.Value ?? currentSid;

string verb;
SafeProcessHandle ret;

Logger.Instance.Log($"Caller SID: {allowedSid}", LogLevel.Debug);

var @params = InputArguments.Debug ? "--debug " : string.Empty;
if (!InputArguments.RunAsSystem && InputArguments.IntegrityLevel.HasValue) @params += $"-i {InputArguments.IntegrityLevel.Value} ";
if (InputArguments.RunAsSystem) @params += "-s ";
if (InputArguments.TrustedInstaller) @params += "--ti ";
if (InputArguments.UserName != null) @params += $"-u {InputArguments.UserName} ";

verb = "gsudoservice";

if (!cacheDuration.HasValue || singleUse)
if (!InputArguments.RunAsSystem && InputArguments.IntegrityLevel.HasValue) @params += $"-i {InputArguments.IntegrityLevel.Value} ";
if (InputArguments.RunAsSystem) @params += "-s ";
if (InputArguments.TrustedInstaller) @params += "--ti ";
if (InputArguments.UserName != null) @params += $"-u {InputArguments.UserName} ";

verb = "gsudoservice";

if (!cacheDuration.HasValue || singleUse)
{
if (!Settings.CacheMode.Value.In(CredentialsCache.CacheMode.Auto) || singleUse)
{
verb = "gsudoelevate";
cacheDuration = TimeSpan.Zero;
}
else
cacheDuration = Settings.CacheDuration;
}

bool isAdmin = SecurityHelper.IsHighIntegrity();

if (!Settings.CacheMode.Value.In(CredentialsCache.CacheMode.Auto) || singleUse)
{
verb = "gsudoelevate";
cacheDuration = TimeSpan.Zero;
}
else
cacheDuration = Settings.CacheDuration;
}

bool isAdmin = SecurityHelper.IsHighIntegrity();

string commandLine = $"{@params}{verb} {allowedPid} {allowedSid} {Settings.LogLevel} {Settings.TimeSpanWithInfiniteToString(cacheDuration.Value)}";

string ownExe = ProcessHelper.GetOwnExeName();
Expand Down Expand Up @@ -98,45 +179,46 @@ internal static SafeProcessHandle StartService(int? allowedPid, TimeSpan? cacheD
}
else
{
ret = ProcessFactory.StartElevatedDetached(ownExe, commandLine, !InputArguments.Debug).GetSafeProcessHandle();
}

Logger.Instance.Log("Service process started.", LogLevel.Debug);
return ret;
}

private static void StartTrustedInstallerService(string commandLine, int pid)
{
string name = $"gsudo TI Cache for PID {pid}";

string args = $"/Create /ru \"NT SERVICE\\TrustedInstaller\" /TN \"{name}\" /TR \"\\\"{ProcessHelper.GetOwnExeName()}\\\" {commandLine}\" /sc ONCE /st 00:00 /f\"";
Logger.Instance.Log($"Running: schtasks {args}", LogLevel.Debug);
Process p;

p = InputArguments.Debug
? ProcessFactory.StartAttached("schtasks", args)
: ProcessFactory.StartRedirected("schtasks", args, null);

p.WaitForExit();
if (p.ExitCode != 0) throw new ApplicationException($"Error creating a scheduled task for TrustedInstaller: {p.ExitCode}");

try
{
args = $"/run /I /TN \"{name}\"";
p = InputArguments.Debug
? ProcessFactory.StartAttached("schtasks", args)
: ProcessFactory.StartRedirected("schtasks", args, null);
p.WaitForExit();
if (p.ExitCode != 0) throw new ApplicationException($"Error starting scheduled task for TrustedInstaller: {p.ExitCode}");
}
finally
{
args = $"/delete /F /TN \"{name}\"";
p = InputArguments.Debug
? ProcessFactory.StartAttached("schtasks", args)
: ProcessFactory.StartRedirected("schtasks", args, null);
p.WaitForExit();
}
}
}
ret = ProcessFactory.StartElevatedDetached(ownExe, commandLine, !InputArguments.Debug).GetSafeProcessHandle();
}

Logger.Instance.Log("Service process started.", LogLevel.Debug);
return ret;
}

private static void StartTrustedInstallerService(string commandLine, int pid)
{
string name = $"gsudo TI Cache for PID {pid}";

string args = $"/Create /ru \"NT SERVICE\\TrustedInstaller\" /TN \"{name}\" /TR \"\\\"{ProcessHelper.GetOwnExeName()}\\\" {commandLine}\" /sc ONCE /st 00:00 /f\"";
Logger.Instance.Log($"Running: schtasks {args}", LogLevel.Debug);
Process p;

p = InputArguments.Debug
? ProcessFactory.StartAttached("schtasks", args)
: ProcessFactory.StartRedirected("schtasks", args, null);

p.WaitForExit();
if (p.ExitCode != 0) throw new ApplicationException($"Error creating a scheduled task for TrustedInstaller: {p.ExitCode}");

try
{
args = $"/run /I /TN \"{name}\"";
p = InputArguments.Debug
? ProcessFactory.StartAttached("schtasks", args)
: ProcessFactory.StartRedirected("schtasks", args, null);
p.WaitForExit();
if (p.ExitCode != 0) throw new ApplicationException($"Error starting scheduled task for TrustedInstaller: {p.ExitCode}");
}
finally
{
args = $"/delete /F /TN \"{name}\"";
p = InputArguments.Debug
? ProcessFactory.StartAttached("schtasks", args)
: ProcessFactory.StartRedirected("schtasks", args, null);
p.WaitForExit();
}
}
}
}

2 changes: 1 addition & 1 deletion src/gsudo/ProcessRenderers/TokenSwitchRenderer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ internal TokenSwitchRenderer(Connection connection, ElevationRequest elevationRe

_connection = connection;
_elevationRequest = elevationRequest;
ConsoleHelper.SetPrompt(elevationRequest, connection.IsHighIntegrity);
ConsoleHelper.SetPrompt(elevationRequest);

ProcessApi.CreateProcessFlags dwCreationFlags = ProcessApi.CreateProcessFlags.CREATE_SUSPENDED;

Expand Down
4 changes: 1 addition & 3 deletions src/gsudo/Rpc/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ class Connection : IDisposable
{
private PipeStream _dataStream;
private PipeStream _controlStream;
public bool IsHighIntegrity { get; }
public Connection(PipeStream ControlStream, PipeStream DataStream, bool isHighIntegrity)
public Connection(PipeStream ControlStream, PipeStream DataStream)
{
_dataStream = DataStream;
_controlStream = ControlStream;
IsHighIntegrity = isHighIntegrity;
}

public Stream DataStream => _dataStream;
Expand Down
5 changes: 3 additions & 2 deletions src/gsudo/Rpc/IRpcClient.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
using Microsoft.Win32.SafeHandles;
using gsudo.Helpers;
using Microsoft.Win32.SafeHandles;
using System.Threading.Tasks;

namespace gsudo.Rpc
{
internal interface IRpcClient
{
Task<Connection> Connect(int? clientPid = null);
Task<Connection> Connect(ServiceLocation service);
}
}
Loading

0 comments on commit ad18727

Please sign in to comment.