Skip to content

Commit

Permalink
Use Runspace scoped registrations (#3)
Browse files Browse the repository at this point in the history
Move away from thread local storage for registrations to a new method
that scopes the storage to a Runspace. This is because a Runspace where
a module is loaded is not guaranteed to always run on the same thread
breaking the previously registered forges.
  • Loading branch information
jborean93 authored Mar 21, 2024
1 parent e4fd2f0 commit f528ca3
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/RemoteForge/Commands/RemoteForgeCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ protected override void ProcessRecord()

protected override void EndProcessing()
{
foreach (RemoteForgeRegistration forge in RemoteForgeRegistration.Registrations.ToArray())
foreach (RemoteForgeRegistration forge in RegistrationStorage.GetFromTLS().Registrations.ToArray())
{
WriteVerbose($"Checking for forge '{forge.Name}' matches requested Name");

Expand Down
2 changes: 1 addition & 1 deletion src/RemoteForge/OnImportAndRemove.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public void OnImport()

public void OnRemove(PSModuleInfo module)
{
RemoteForgeRegistration.Registrations.Clear();
RegistrationStorage.GetFromTLS().Registrations.Clear();
}

private static SSHConnectionInfo CreateSshConnectionInfo(string info)
Expand Down
50 changes: 40 additions & 10 deletions src/RemoteForge/RemoteForgeRegistration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,50 @@
using System.Management.Automation;
using System.Management.Automation.Runspaces;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading;

namespace RemoteForge;

public delegate RunspaceConnectionInfo RemoteForgeFactory(string info);

public sealed class RemoteForgeRegistration
internal class RunspaceSpecificStorage<T>
{
private readonly ConditionalWeakTable<Runspace, Lazy<T>> _map = new();

private readonly Func<T> _factory;

private readonly LazyThreadSafetyMode _mode = LazyThreadSafetyMode.ExecutionAndPublication;

public RunspaceSpecificStorage(Func<T> factory)
{
_factory = factory;
}

public T GetFromTLS()
=> GetForRunspace(Runspace.DefaultRunspace);

public T GetForRunspace(Runspace runspace)
{
return _map.GetValue(
runspace,
_ => new Lazy<T>(() => _factory(), _mode))
.Value;
}
}

internal sealed class RegistrationStorage
{
// We use a thread local storage value so the registrations are scoped to
// a Runspace rather than be process wide.
[ThreadStatic]
private static List<RemoteForgeRegistration>? _registrations;
private static RunspaceSpecificStorage<RegistrationStorage> _registrations = new(() => new());

internal static List<RemoteForgeRegistration> Registrations => _registrations ??= new();
public List<RemoteForgeRegistration> Registrations = new();

public static RegistrationStorage GetFromTLS() => _registrations.GetFromTLS();
}


public sealed class RemoteForgeRegistration
{
public string Name { get; }
public string? Description { get; }
internal RemoteForgeFactory CreateFactory { get; }
Expand Down Expand Up @@ -85,7 +115,7 @@ internal static RemoteForgeRegistration Register(
}
else if (force)
{
Registrations.Remove(forge);
RegistrationStorage.GetFromTLS().Registrations.Remove(forge);
}
else
{
Expand All @@ -94,7 +124,7 @@ internal static RemoteForgeRegistration Register(
}

RemoteForgeRegistration registration = new(name, description, factory);
Registrations.Add(registration);
RegistrationStorage.GetFromTLS().Registrations.Add(registration);

return registration;
}
Expand All @@ -103,7 +133,7 @@ public static void Unregister(string name)
{
if (TryGetForgeRegistration(name, out RemoteForgeRegistration? forge))
{
Registrations.Remove(forge);
RegistrationStorage.GetFromTLS().Registrations.Remove(forge);
}
else
{
Expand Down Expand Up @@ -151,7 +181,7 @@ private static bool TryGetForgeRegistration(
[NotNullWhen(true)] out RemoteForgeRegistration? registration)
{
string lowerId = name.ToLowerInvariant();
foreach (RemoteForgeRegistration forge in Registrations)
foreach (RemoteForgeRegistration forge in RegistrationStorage.GetFromTLS().Registrations)
{
if (forge.Name.ToLowerInvariant() == lowerId)
{
Expand Down
3 changes: 3 additions & 0 deletions tests/units/RemoteForgeRegistrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ public class RemoteForgeRegistrationTests : IDisposable

public RemoteForgeRegistrationTests() : base()
{
Runspace.DefaultRunspace = RunspaceFactory.CreateRunspace(InitialSessionState.CreateDefault2());
Runspace.DefaultRunspace.Open();
_remoteForgeModule.OnImport();
}

Expand Down Expand Up @@ -94,5 +96,6 @@ public void GetWSManConnectionInfoInvalidHostname(string info)
public void Dispose()
{
_remoteForgeModule.OnRemove(null!);
Runspace.DefaultRunspace?.Dispose();
}
}
3 changes: 3 additions & 0 deletions tests/units/StringForgeConnectionInfoPSSessionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public class StringForgeConnectionInfoPSSessionTests : IDisposable

public StringForgeConnectionInfoPSSessionTests() : base()
{
Runspace.DefaultRunspace = RunspaceFactory.CreateRunspace(InitialSessionState.CreateDefault2());
Runspace.DefaultRunspace.Open();
_remoteForgeModule.OnImport();
}

Expand Down Expand Up @@ -115,5 +117,6 @@ public void ToStringFromCustomConnectionInfo()
public void Dispose()
{
_remoteForgeModule.OnRemove(null!);
Runspace.DefaultRunspace?.Dispose();
}
}

0 comments on commit f528ca3

Please sign in to comment.