From caf0952403f0b1d2ac44acae6288a67db09c8d00 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 3 Oct 2022 10:06:19 -0600 Subject: [PATCH] WIP - DistributedRoutingTable supporting classes --- Net/DistributedRoutingTable.cs | 1716 +++++++++++++++++++++++--------- 1 file changed, 1245 insertions(+), 471 deletions(-) diff --git a/Net/DistributedRoutingTable.cs b/Net/DistributedRoutingTable.cs index cad403567..d42af91eb 100644 --- a/Net/DistributedRoutingTable.cs +++ b/Net/DistributedRoutingTable.cs @@ -1,15 +1,13 @@ -#if LFJSDFHJDSJFHSDJFH -#nullable enable -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member +#nullable enable + using System; -using System.Collections.Generic; using System.Linq; +using System.Net; using System.Runtime.InteropServices; -using System.Text; -using System.Threading.Tasks; using Vanara.Extensions; using Vanara.InteropServices; using Vanara.PInvoke; +using static Vanara.Net.DrtUtil; using static Vanara.PInvoke.Crypt32; using static Vanara.PInvoke.Drt; using static Vanara.PInvoke.Kernel32; @@ -17,55 +15,26 @@ namespace Vanara.Net; +/// Represents a distributed routing table from Win32. public class DistributedRoutingTable { - private SafeHDRT hDrt; - private SafeEventHandle evt; - private SafeRegisteredWaitHandle drtWaitEvent; - private readonly Ws2_32.SafeWSA ws = Ws2_32.SafeWSA.Initialize(); - - /// Provides a for that is disposed using . - public class SafeHDRT : SafeHANDLE - { - /// Initializes a new instance of the class and assigns an existing handle. - /// An object that represents the pre-existing handle to use. - /// to reliably release the handle during the finalization phase; otherwise, (not recommended). - public SafeHDRT(IntPtr preexistingHandle, bool ownsHandle = true) : base(preexistingHandle, ownsHandle) { } - - /// Initializes a new instance of the class. - private SafeHDRT() : base() { } - - /// Performs an implicit conversion from to . - /// The safe handle instance. - /// The result of the conversion. - public static implicit operator HDRT(SafeHDRT h) => h.handle; - - /// - protected override bool InternalReleaseHandle() { DrtClose(handle); return true; } - } - - public DistributedRoutingTable(PCCERT_CONTEXT pRootCert, PCCERT_CONTEXT pLocalCert, string hostName, ushort port) - { - IntPtr pBootstrapProvider, pSecurityProvider; - if (pRootCert.IsNull) - DrtCreateNullSecurityProvider(out pSecurityProvider).ThrowIfFailed(); - else - DrtCreateDerivedKeySecurityProvider(pRootCert, pLocalCert, out pSecurityProvider).ThrowIfFailed(); - - Init(pSecurityProvider, pBootstrapProvider); - } + private readonly SafeRegisteredWaitHandle? drtWaitEvent; + private readonly SafeEventHandle? evt; + private readonly SafeHDRT? hDrt; + private readonly SafeWSA ws = SafeWSA.Initialize(); + private DRT_SETTINGS pSettings; - public DistributedRoutingTable(ICustomDrtSecurityProvider securityProvider, ICustomBootstrapProvider bootstrapper) - { + /// Initializes a new instance of the class. + /// The security provider. + /// The bootstrapper. + public DistributedRoutingTable(DrtSecurityProvider? securityProvider, DrtBootstrapProvider bootstrapper) : + this((IntPtr)(securityProvider ?? DrtSecurityProvider.CreateNullSecurityProvider()), (IntPtr)bootstrapper) + { } - } - - private void Init(IntPtr pSecProv, IntPtr pBootProv) + private DistributedRoutingTable(IntPtr pSecProv, IntPtr pBootProv) { ushort port = 0; - DrtCreateIpv6UdpTransport(DRT_SCOPE.DRT_GLOBAL_SCOPE, 0, 300, ref port, out var hTransport).ThrowIfFailed(); - - DRT_SETTINGS pSettings = new() + pSettings = new() { dwSize = (uint)Marshal.SizeOf(typeof(DRT_SETTINGS)), cbKey = 32, @@ -73,410 +42,671 @@ private void Init(IntPtr pSecProv, IntPtr pBootProv) bProtocolMajorVersion = 0x6, bProtocolMinorVersion = 0x65, eSecurityMode = DRT_SECURITY_MODE.DRT_SECURE_CONFIDENTIALPAYLOAD, - pwzDrtInstancePrefix = "__VanaraDRT", + pwzDrtInstancePrefix = "__VanaraDRT" + Guid.NewGuid().ToString("N"), pSecurityProvider = pSecProv, - pBootstrapProvider = pBootProv + pBootstrapProvider = pBootProv, }; + DrtCreateIpv6UdpTransport(DRT_SCOPE.DRT_GLOBAL_SCOPE, 0, 300, ref port, out pSettings.hTransport).ThrowIfFailed(); + evt = CreateEvent(null, false, false); - DrtOpen(pSettings, evt, default, out var h).ThrowIfFailed(); + DrtOpen(pSettings, evt, default, out HDRT h).ThrowIfFailed(); hDrt = new((IntPtr)h); Win32Error.ThrowLastErrorIfFalse(RegisterWaitForSingleObject(out drtWaitEvent, evt, DrtEventCallback, default /*AddCtx(Drt)*/, INFINITE, WT.WT_EXECUTEDEFAULT)); } - void DrtEventCallback(IntPtr Param, bool TimedOut) + private void DrtEventCallback(IntPtr Param, bool TimedOut) { - HRESULT hr; - var Drt = GetCtx(Param); + // HRESULT hr; var Drt = GetCtx(Param); - hr = DrtGetEventDataSize(Drt.hDrt, out var ulDrtEventDataLen); - if (hr.Failed) + // hr = DrtGetEventDataSize(Drt.hDrt, out var ulDrtEventDataLen); if (hr.Failed) { if (hr != HRESULT.DRT_E_NO_MORE) Console.Write(" + // DrtGetEventDataSize failed: {0}\n", hr); goto Cleanup; } + + // using (var pEventData = new SafeCoTaskMemStruct(ulDrtEventDataLen)) { if (pEventData.IsInvalid) { Console.Write(" + // Out of memory\n"); goto Cleanup; } + + // hr = DrtGetEventData(Drt.hDrt, ulDrtEventDataLen, pEventData); if (hr.Failed) { if (hr != HRESULT.DRT_E_NO_MORE) Console.Write(" + // DrtGetEventData failed: {0}\n", hr); goto Cleanup; } + + // switch (pEventData.Value.type) { case DRT_EVENT_TYPE.DRT_EVENT_STATUS_CHANGED: switch (pEventData.Value.union.statusChange.status) + // { case DRT_STATUS.DRT_ACTIVE: SetConsoleTitle("DrtSdkSample Current Drt Status: Active"); if (g_DisplayEvents) Console.Write(" DRT + // Status Changed to Active\n"); break; case DRT_STATUS.DRT_ALONE: SetConsoleTitle("DrtSdkSample Current Drt Status: Alone"); if + // (g_DisplayEvents) Console.Write(" DRT Status Changed to Alone\n"); break; case DRT_STATUS.DRT_NO_NETWORK: + // SetConsoleTitle("DrtSdkSample Current Drt Status: No Network"); if (g_DisplayEvents) Console.Write(" DRT Status Changed to No + // Network\n"); break; case DRT_STATUS.DRT_FAULTED: SetConsoleTitle("DrtSdkSample Current Drt Status: Faulted"); if (g_DisplayEvents) + // Console.Write(" DRT Status Changed to Faulted\n"); break; } + + // break; case DRT_EVENT_TYPE.DRT_EVENT_LEAFSET_KEY_CHANGED: if (g_DisplayEvents) { switch + // (pEventData.Value.union.leafsetKeyChange.change) { case DRT_LEAFSET_KEY_CHANGE_TYPE.DRT_LEAFSET_KEY_ADDED: Console.Write(" Leafset + // Key Added Event: {0}\n", pEventData.Value.hr); break; case DRT_LEAFSET_KEY_CHANGE_TYPE.DRT_LEAFSET_KEY_DELETED: Console.Write(" + // Leafset Key Deleted Event: {0}\n", pEventData.Value.hr); break; } } + + // break; + // case DRT_EVENT_TYPE.DRT_EVENT_REGISTRATION_STATE_CHANGED: + // if (g_DisplayEvents) + // Console.Write(" Registration State Changed Event: [hr: 0x%x, registration state: %i]\n", pEventData.Value.hr, pEventData.Value.union.registrationStateChange.state); + // break; + // } + // } + //Cleanup: + // return; + } +} + +/// Abstract base class for a custom DRT bootstrap provider. +public class DrtBootstrapProvider : IDisposable +{ + /// The bootstrap provider structure. + protected DRT_BOOTSTRAP_PROVIDER prov; + + private readonly IntPtr pProv; + private readonly char pProvType; + + /// Initializes a new instance of the class. + /// The prov. + protected DrtBootstrapProvider(in DRT_BOOTSTRAP_PROVIDER prov) + { + this.prov = prov; + pProv = GCHandle.Alloc(this.prov, GCHandleType.Pinned).AddrOfPinnedObject(); + pProvType = 'h'; + } + + /// Initializes a new instance of the class. + private DrtBootstrapProvider() + { } + + /// Initializes a new instance of the class. + /// The PTR. + /// Type of the provider pointer. + private DrtBootstrapProvider(IntPtr ptr, char provType) + { + pProv = ptr; + pProvType = provType; + } + + /// Gets the local DNS host. + protected static string LocalDnsHost + { + get { - if (hr != HRESULT.DRT_E_NO_MORE) - Console.Write(" DrtGetEventDataSize failed: {0}\n", hr); - goto Cleanup; + Win32Error.ThrowLastErrorIfFalse(GetComputerNameEx(COMPUTER_NAME_FORMAT.ComputerNameDnsFullyQualified, out string? name)); + return name; } + } - using (var pEventData = new SafeCoTaskMemStruct(ulDrtEventDataLen)) + /// + /// Creates a bootstrap resolver that will use the GetAddrInfo system function to resolve the hostname of a will known node already + /// present in the DRT mesh. + /// + /// Specifies the hostname of the well known node. + /// Specifies the port to which the DRT protocol is bound on the well known node. + /// A DNS instance. + public static DrtBootstrapProvider CreateDnsBootstrapResolver(string? hostname, ushort port) + { + hostname ??= LocalDnsHost; + DrtCreateDnsBootstrapResolver(port, hostname, out IntPtr pbp).ThrowIfFailed(); + return new(pbp, 'd'); + } + + /// Creates a bootstrap resolver based on the Peer Name Resolution Protocol (PNRP). + /// The name of the peer to search for in the PNRP cloud. This string has a maximum limit of 137 unicode characters + /// + /// The name of the cloud to search for in for the DRT corresponding to the MeshName. + /// + /// This string has a maximum limit of 256 unicode characters. If left blank the PNRP Bootstrap Provider will use all PNRP clouds available. + /// + /// + /// + /// The PeerIdentity that is publishing into the PNRP cloud utilized for bootstrapping. This string has a maximum limit of 137 unicode + /// characters. The PublishingIdentity must be allowed to publish the PeerName specified. + /// + /// A PNRP instance. + /// + /// The default PNRP Bootstrap Resolver created by this function is specific to the DRT it is created for. As a result it cannot be + /// re-used across multiple DRTs. + /// + public static DrtBootstrapProvider CreatePnrpBootstrapResolver(string peerName, string? cloudName = null, string? publishingId = null) + { + DrtCreatePnrpBootstrapResolver(true, peerName, cloudName, publishingId, out IntPtr pbp).ThrowIfFailed(); + return new(pbp, 'p'); + } + + /// Performs an explicit conversion from to . + /// The prov. + /// The result of the conversion. + public static explicit operator IntPtr(DrtBootstrapProvider prov) => prov.pProv; + + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + public void Dispose() + { + if (pProv != default) { - if (pEventData.IsInvalid) + if (pProvType == 'd') { - Console.Write(" Out of memory\n"); - goto Cleanup; + DrtDeleteDnsBootstrapResolver(pProv); } - - hr = DrtGetEventData(Drt.hDrt, ulDrtEventDataLen, pEventData); - if (hr.Failed) + else if (pProvType == 'p') { - if (hr != HRESULT.DRT_E_NO_MORE) - Console.Write(" DrtGetEventData failed: {0}\n", hr); - goto Cleanup; + DrtDeletePnrpBootstrapResolver(pProv); } - - switch (pEventData.Value.type) + else if (pProvType == 'h') { - case DRT_EVENT_TYPE.DRT_EVENT_STATUS_CHANGED: - switch (pEventData.Value.union.statusChange.status) - { - case DRT_STATUS.DRT_ACTIVE: - SetConsoleTitle("DrtSdkSample Current Drt Status: Active"); - if (g_DisplayEvents) - Console.Write(" DRT Status Changed to Active\n"); - break; - case DRT_STATUS.DRT_ALONE: - SetConsoleTitle("DrtSdkSample Current Drt Status: Alone"); - if (g_DisplayEvents) - Console.Write(" DRT Status Changed to Alone\n"); - break; - case DRT_STATUS.DRT_NO_NETWORK: - SetConsoleTitle("DrtSdkSample Current Drt Status: No Network"); - if (g_DisplayEvents) - Console.Write(" DRT Status Changed to No Network\n"); - break; - case DRT_STATUS.DRT_FAULTED: - SetConsoleTitle("DrtSdkSample Current Drt Status: Faulted"); - if (g_DisplayEvents) - Console.Write(" DRT Status Changed to Faulted\n"); - break; - } - - break; - case DRT_EVENT_TYPE.DRT_EVENT_LEAFSET_KEY_CHANGED: - if (g_DisplayEvents) - { - switch (pEventData.Value.union.leafsetKeyChange.change) - { - case DRT_LEAFSET_KEY_CHANGE_TYPE.DRT_LEAFSET_KEY_ADDED: - Console.Write(" Leafset Key Added Event: {0}\n", pEventData.Value.hr); - break; - case DRT_LEAFSET_KEY_CHANGE_TYPE.DRT_LEAFSET_KEY_DELETED: - Console.Write(" Leafset Key Deleted Event: {0}\n", pEventData.Value.hr); - break; - } - } - - break; - case DRT_EVENT_TYPE.DRT_EVENT_REGISTRATION_STATE_CHANGED: - if (g_DisplayEvents) - Console.Write(" Registration State Changed Event: [hr: 0x%x, registration state: %i]\n", pEventData.Value.hr, pEventData.Value.union.registrationStateChange.state); - break; + GCHandle.FromIntPtr(pProv).Free(); } } - Cleanup: - return; } } -public interface ICustomBootstrapProvider +/// Abstract base class for a custom DRT bootstrap provider. +public abstract class DrtCustomBootstrapProvider : DrtBootstrapProvider { - /// Increments the count of references for the Bootstrap Provider with a set of DRTs. - /// Contains the pvContext value from DRT_BOOTSTRAP_PROVIDER. - /// - HRESULT Attach([In] IntPtr pvContext); + private int refCount; + + /// Initializes a new instance of the class. + protected DrtCustomBootstrapProvider(object? context = null) : base(default) + { + unsafe + { + prov.Attach = InternalAttach; + prov.Detach = InternalDetach; + prov.InitResolve = InternalInitResolve; + prov.IssueResolve = InternalIssueResolve; + prov.EndResolve = InternalEndResolve; + prov.Register = InternalRegister; + prov.Unregister = InternalUnregister; + if (context != null) + prov.pvContext = GCHandle.Alloc(context).AddrOfPinnedObject(); + } + AddRef(); + } - /// Decrements the count of references for the Bootstrap Provider with a set of DRTs. - /// Contains the pvContext value from DRT_BOOTSTRAP_PROVIDER. - void Detach([In] IntPtr pvContext); + /// Gets the context provided for all methods. + /// The context object. + protected virtual object? Context => prov.pvContext == IntPtr.Zero ? null : GCHandle.FromIntPtr(prov.pvContext).Target; + + /// Gets a value indicating whether this instance is attached. + /// + /// if this instance is attached; otherwise, . + /// + protected bool IsAttached => refCount > 0; /// Ends the resolution of an endpoint. - /// Contains the pvContext value from DRT_BOOTSTRAP_PROVIDER. /// /// The BOOTSTRAP_RESOLVE_CONTEXT received from the Resolve function of the specified bootstrap provider. /// - void EndResolve([In] IntPtr pvContext, [In] DRT_BOOTSTRAP_RESOLVE_CONTEXT ResolveContext); + protected abstract void EndResolve([In] DRT_BOOTSTRAP_RESOLVE_CONTEXT ResolveContext); /// Called by the DRT infrastructure to supply configuration information about upcoming name resolutions. - /// Contains the pvContext value from DRT_BOOTSTRAP_PROVIDER. /// Specifies if the resolve operation is being utilized for network split detection and recovery. /// Specifies the maximum time a resolve should take before timing out. This value is represented in milliseconds. /// Specifies the maximum number of results to return during the resolve operation. /// Pointer to resolver specific data. /// - /// If the bootstrap provider encounters an irrecoverable error, this parameter must be set to TRUE when the function - /// complete in order for the DRT to transition to the faulted state. The HRESULT that is made available to the higher layer - /// application for debugging will appear in the hr member of the DRT_EVENT_DATA structure associated with the event - /// signaling the transition to the faulted state. This bootstrap provider function should not return S_OK if setting the - /// fFatalError flag to TRUE. + /// If the bootstrap provider encounters an irrecoverable error, this parameter must be set to TRUE when the function complete in + /// order for the DRT to transition to the faulted state. The HRESULT that is made available to the higher layer application for + /// debugging will appear in the hr member of the DRT_EVENT_DATA structure associated with the event signaling the transition to + /// the faulted state. This bootstrap provider function should not return S_OK if setting the fFatalError flag to TRUE. /// - /// - HRESULT InitResolve([In] IntPtr pvContext, bool fSplitDetect, uint timeout, uint cMaxResults, + protected abstract HRESULT InitResolve(bool fSplitDetect, TimeSpan timeout, uint cMaxResults, out DRT_BOOTSTRAP_RESOLVE_CONTEXT ResolveContext, out bool fFatalError); /// /// Called by the DRT infrastructure to issue a resolution to determine the endpoints of nodes already active in the DRT cloud. /// - /// Contains the pvContext value from DRT_BOOTSTRAP_PROVIDER. /// Pointer to the context data that is passed back to the callback defined by the next parameter. /// A BOOTSTRAP_RESOLVE_CALLBACK that is called back for each result and DRT_E_NO_MORE. /// Pointer to resolver specific data. /// - /// If the bootstrap provider encounters an irrecoverable error, this parameter must be set to TRUE when the function - /// complete in order for the DRT to transition to the faulted state. The HRESULT that is made available to the higher layer - /// application for debugging will appear in the hr member of the DRT_EVENT_DATA structure associated with the event - /// signaling the transition to the faulted state. This bootstrap provider function should not return S_OK if setting the - /// fFatalError flag to TRUE. + /// If the bootstrap provider encounters an irrecoverable error, this parameter must be set to TRUE when the function complete in + /// order for the DRT to transition to the faulted state. The HRESULT that is made available to the higher layer application for + /// debugging will appear in the hr member of the DRT_EVENT_DATA structure associated with the event signaling the transition to + /// the faulted state. This bootstrap provider function should not return S_OK if setting the fFatalError flag to TRUE. /// - /// - HRESULT IssueResolve([In] IntPtr pvContext, [In] IntPtr pvCallbackContext, DRT_BOOTSTRAP_RESOLVE_CALLBACK callback, + protected abstract HRESULT IssueResolve(DRT_BOOTSTRAP_RESOLVE_CALLBACK callback, [In] IntPtr pvCallbackContext, [In] DRT_BOOTSTRAP_RESOLVE_CONTEXT ResolveContext, out bool fFatalError); /// /// Registers an endpoint with the bootstrapping mechanism. This process makes it possible for other nodes find the endpoint via the /// bootstrap resolver. /// - /// Contains the pvContext value from DRT_BOOTSTRAP_PROVIDER. - /// Pointer to containing the list of addresses to register with the bootstrapping mechanism. - /// - HRESULT Register([In] IntPtr pvContext, [In] IntPtr pAddressList); + /// + /// Pointer to containing the list of addresses to register with the bootstrapping mechanism. + /// + protected virtual HRESULT Register([In] IPEndPoint[]? pAddressList) => HRESULT.S_OK; /// - /// This function deregisters an endpoint with the bootstrapping mechanism. As a result, other nodes will be unable to find the - /// local node via the bootstrap resolver. + /// This function deregisters an endpoint with the bootstrapping mechanism. As a result, other nodes will be unable to find the local + /// node via the bootstrap resolver. /// - /// Contains the pvContext value from DRT_BOOTSTRAP_PROVIDER. - void Unregister([In] IntPtr pvContext); -} + protected virtual void Unregister() { } -public interface ICustomDrtSecurityProvider -{ - /// Increments the count of references for the Security Provider with a set of DRTs. - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - HRESULT Attach([In] IntPtr pvContext); + private HRESULT InternalAttach(IntPtr pvContext) + { + if (InterlockedCompareExchange(ref refCount, 1, 0) != 0) + return HRESULT.DRT_E_BOOTSTRAPPROVIDER_IN_USE; + AddRef(); + return HRESULT.S_OK; + } - /// Decrements the count of references for the Security Provider with a set of DRTs. - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - void Detach([In] IntPtr pvContext); + /// Adds the reference. + protected void AddRef() + { + if (refCount == 0) + GC.SuppressFinalize(this); + InterlockedIncrement(ref refCount); + } - /// Called to register a key with the Security Provider. - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// - /// Pointer to the DRT_REGISTRATION structure created by an application and passed to the DrtRegisterKey function. - /// - /// Pointer to the context data created by an application and passed to the DrtRegisterKey function. - /// - HRESULT RegisterKey([In] IntPtr pvContext, in DRT_REGISTRATION pRegistration, [In, Optional] IntPtr pvKeyContext); + private void InternalDetach(IntPtr pvContext) + { + InterlockedCompareExchange(ref refCount, 0, 1); + Release(); + } - /// Called to deregister a key with the Security Provider. - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// Pointer to the key to which the payload is registered. - /// Pointer to the context data created by the application and passed to DrtRegisterKey. - /// - HRESULT UnregisterKey([In] IntPtr pvContext, in DRT_DATA pKey, [In, Optional] IntPtr pvKeyContext); + /// Releases this instance. + protected void Release() + { + if (InterlockedDecrement(ref refCount) == 0) + { + if (prov.pvContext != default) + GCHandle.FromIntPtr(prov.pvContext).Free(); + GC.ReRegisterForFinalize(this); + } + } - /// - /// Called when an Authority message is received on the wire. It is responsible for validating the data received, and for unpacking - /// the service addresses, revoked flag, and nonce from the Secured Address Payload. - /// - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// - /// Pointer to the payload received on the wire that contains the service addresses, revoked flag, nonce, and any other data - /// required by the security provider. - /// - /// Pointer to the cert chain received in the authority message. - /// Pointer to the classifier received in the authority message. - /// - /// Pointer to the nonce that was sent in the original Inquire or Lookup message. This value must be compared to the - /// value embedded in the Secured Address Payload to ensure they are the same. This value is fixed at 16 bytes. - /// - /// - /// Pointer to the application data payload received in the Authority message. After validation, the original data (after - /// decryption, removal of signature, and so on.) is output as pPayload. - /// - /// - /// Pointer to the byte array that represents the protocol major version. This is packed in every DRT packet to identify the version - /// of the security provider in use when a single DRT instance is supporting multiple Security Providers. - /// - /// - /// Pointer to the byte array that represents the protocol minor version. This is packed in every DRT packet to identify the version - /// of the security provider in use when a single DRT instance is supporting multiple Security Providers. - /// - /// Pointer to the key to which the payload is registered. - /// - /// Pointer to the original payload specified by the remote application. pPayload.pb is allocated by the security provider. - /// - /// Pointer to a pointer to the number of service addresses embedded in the secured address payload. - /// - /// Pointer to a pointer to the service addresses that are embedded in the Secured Address Payload. pAddresses is allocated - /// by the security provider. - /// - /// - /// Any DRT flags currently defined only to be the revoked or deleted flag that need to be unpacked for the local DRT instance processing. - /// Note Currently the only allowed value is: DRT_PAYLOAD_REVOKED (1) - /// - /// - unsafe HRESULT ValidateAndUnpackPayload([In] IntPtr pvContext, in DRT_DATA pSecuredAddressPayload, [In, Optional] DRT_DATA* pCertChain, - [In, Optional] DRT_DATA* pClassifier, [In, Optional] DRT_DATA* pNonce, [In, Optional] DRT_DATA* pSecuredPayload, - [Out] byte* pbProtocolMajor, [Out] byte* pbProtocolMinor, out DRT_DATA pKey, [Out, Optional] DRT_DATA* pPayload, - [Out] CERT_PUBLIC_KEY_INFO** ppPublicKey, [Out, Optional] void** ppAddressList, out uint pdwFlags); + private void InternalEndResolve(IntPtr pvContext, DRT_BOOTSTRAP_RESOLVE_CONTEXT ResolveContext) => EndResolve(ResolveContext); - /// - /// Called when an Authority message is about to be sent on the wire. It is responsible for securing the data before it is sent, and - /// for packing the service addresses, revoked flag, nonce, and other application data into the Secured Address Payload. - /// - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// Contains the context passed into DrtRegisterKey when the key was registered. - /// Pointer to the byte array that represents the protocol major version. - /// Pointer to the byte array that represents the protocol minor version. - /// + private HRESULT InternalInitResolve(IntPtr pvContext, bool fSplitDetect, uint timeout, uint cMaxResults, + out DRT_BOOTSTRAP_RESOLVE_CONTEXT ResolveContext, out bool fFatalError) => + InitResolve(fSplitDetect, TimeSpan.FromMilliseconds(timeout), cMaxResults, out ResolveContext, out fFatalError); + + private HRESULT InternalIssueResolve(IntPtr pvContext, IntPtr pvCallbackContext, DRT_BOOTSTRAP_RESOLVE_CALLBACK callback, + DRT_BOOTSTRAP_RESOLVE_CONTEXT ResolveContext, out bool fFatalError) => + IssueResolve(callback, pvCallbackContext, ResolveContext, out fFatalError); + + private HRESULT InternalRegister(IntPtr pvContext, IntPtr pAddressList) => Register(ToEndPoints(pAddressList.ToNullableStructure())); + + private void InternalUnregister(IntPtr pvContext) => Unregister(); +} + +/// DNS Bootstrapper +/// +public class CustomDnsBootstapper : DrtCustomBootstrapProvider +{ + private readonly string hostname; + private readonly string port; + private readonly object m_lock = new(); + private bool m_fResolveInProgress; + private uint m_CallbackThreadId; + private bool m_fEndResolve; + private SafeEventHandle? m_hCallbackComplete; + private uint m_dwMaxResults; + + /// Initializes a new instance of the class. + /// + /// A string that contains a host (node) name or a numeric host address string. For the Internet protocol, the numeric host address + /// string is a dotted-decimal IPv4 address or an IPv6 hex address. + /// + /// + /// A string that contains either a service name or port number represented as a string. /// - /// Any DRT specific flags, currently defined only to be the revoked or deleted flag that need to be packed, secured and sent to - /// another instance for processing. + /// A service name is a string alias for a port number. For example, “http” is an alias for port 80 defined by the Internet Engineering + /// Task Force (IETF) as the default port used by web servers for the HTTP protocol. Possible values for the pServiceName parameter when + /// a port number is not specified are listed in the following file: /// - /// Note Currently the only allowed value is: DRT_PAYLOAD_REVOKED /// - /// Pointer to the key to which this payload is registered. - /// Pointer to the payload specified by the application when calling DrtRegisterKey. - /// Pointer to the service addresses that are placed in the Secured Address Payload. - /// - /// Pointer to the nonce that was sent in the original Inquire or Lookup message. This value is fixed at 16 bytes. - /// - /// - /// Pointer to the payload to send on the wire which contains the service addresses, revoked flag, nonce, and other data required by - /// the security provider. pSecuredAddressPayload.pb is allocated by the security provider. - /// - /// - /// Pointer to the classifier to send in the Authority message. pClassifier.pb is allocated by the security provider. - /// - /// - /// Pointer to the application data payload received in the Authority message. After validation, the original data (after - /// decryption, removal of signature, and so on.) is output as pPayload. pSecuredPayload.pb is allocated by the security provider. - /// - /// - /// Pointer to the cert chain to send in the Authority message. pCertChain.pb is allocated by the security provider. - /// - /// - unsafe HRESULT SecureAndPackPayload([In] IntPtr pvContext, [In, Optional] IntPtr pvKeyContext, byte bProtocolMajor, byte bProtocolMinor, - uint dwFlags, in DRT_DATA pKey, [In, Optional] DRT_DATA* pPayload, [In, Optional] IntPtr pAddressList, - in DRT_DATA pNonce, out DRT_DATA pSecuredAddressPayload, [Out, Optional] DRT_DATA* pClassifier, - [Out, Optional] DRT_DATA* pSecuredPayload, [Out, Optional] DRT_DATA* pCertChain); + public CustomDnsBootstapper(string pNodeName, string pServiceName) + { + hostname = pNodeName; + port = pServiceName; + } - /// Called to release resources previously allocated for a security provider function. - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// Specifies what data to free. - void FreeData([In] IntPtr pvContext, [In, Optional] IntPtr pv); + /// + protected override void EndResolve([In] DRT_BOOTSTRAP_RESOLVE_CONTEXT ResolveContext) + { + var fWaitForCallback = false; - /// - /// Called when the DRT sends a message containing data that must be encrypted. This function is only called when the DRT is - /// operating in the DRT_SECURE_CONFIDENTIALPAYLOAD security mode defined by DRT_SECURITY_MODE. - /// - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// Contains the credential of the peer that will receive the protected message. - /// Contains the length of the pDataBuffers and pEncryptedBuffers. - /// Contains the unencrypted buffer. - /// Contains the encrypted content upon completion of the function. - /// - /// Contains the encrypted session key that can be decrypted by the recipient of the message and used to decrypted the protected fields. - /// - /// - HRESULT EncryptData([In] IntPtr pvContext, in DRT_DATA pRemoteCredential, - [In] DRT_DATA[] pDataBuffers, [Out] DRT_DATA[] pEncryptedBuffers, out DRT_DATA pKeyToken); + var CallbackComplete = CreateEvent(default, true, false, default); - /// - /// Called when the DRT receives a message containing encrypted data. This function is only called when the DRT is operating in the - /// DRT_SECURE_CONFIDENTIALPAYLOAD security mode defined by DRT_SECURITY_MODE. - /// - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// - /// Contains the encrypted session key that can be decrypted by the recipient of the message and used to decrypt the protected fields. - /// - /// Contains the context passed into DrtRegisterKey when the key was registered. - /// Contains the size of pData buffer. - /// Contains the decrypted data upon completion of the function. - /// - HRESULT DecryptData([In] IntPtr pvContext, in DRT_DATA pKeyToken, [In] IntPtr pvKeyContext, - [In, Out] DRT_DATA[] pData); + lock (m_lock) + { + if (m_fResolveInProgress && (GetCurrentThreadId() != m_CallbackThreadId)) + { + if (!m_fEndResolve) + { + // This is the first thread to call EndResolve and we need to wait for a callback to complete so initialize the class + // member event + m_fEndResolve = true; + m_hCallbackComplete = CallbackComplete; + } + fWaitForCallback = true; + } + } - /// - /// Called when the DRT must provide a credential used to authorize the local node. This function is only called when the DRT is - /// operating in the DRT_SECURE_MEMBERSHIP and DRT_SECURE_CONFIDENTIALPAYLOAD security modes defined by DRT_SECURITY_MODE. - /// - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// Contains the serialized credential upon completion of the function. - /// - HRESULT GetSerializedCredential([In] IntPtr pvContext, out DRT_DATA pSelfCredential); + if (!CallbackComplete.IsInvalid && (m_hCallbackComplete != CallbackComplete)) + { + // This thread was not the first to call EndResolve, so its event is not in use, release it (m_hCallbackComplete is released + // in the destructor) + CallbackComplete.Dispose(); + } - /// Called when the DRT must validate a credential provided by a peer node. - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// Contains the serialized credential provided by the peer node. - /// - HRESULT ValidateRemoteCredential([In] IntPtr pvContext, in DRT_DATA pRemoteCredential); + if (fWaitForCallback && m_hCallbackComplete != null) + { + WaitForSingleObject(m_hCallbackComplete, INFINITE); + } - /// - /// Called when the DRT must sign a data blob for inclusion in a DRT protocol message. This function is only called when the DRT is - /// operating in the DRT_SECURE_MEMBERSHIP and DRT_SECURE_CONFIDENTIALPAYLOAD security modes defined by DRT_SECURITY_MODE. - /// - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// Contains the size of the pDataBuffers buffer. - /// Contains the data to be signed. - /// - /// Upon completion of this function, contains an index that can be used to select from multiple credentials for use in calculating - /// the signature. + Release(); + } + + /// + protected override HRESULT InitResolve(bool fSplitDetect, TimeSpan timeout, uint cMaxResults, out DRT_BOOTSTRAP_RESOLVE_CONTEXT pResolveContext, out bool fFatalError) + { + fFatalError = false; + pResolveContext = default; + + var hr = HRESULT.DRT_E_BOOTSTRAPPROVIDER_NOT_ATTACHED; + if (IsAttached) + { + // The cache is not scope aware so we ask for a larger number of addresses than the cache wants. In the expectation that one + // of them may be good for us + m_dwMaxResults = cMaxResults; + + AddRef(); + hr = HRESULT.S_OK; + } + + if (hr.Failed) + { + // CustomDNSResolver has no retry cases, so any failed HRESULT is fatal + fFatalError = true; + } + + return hr; + } + + /// + protected override HRESULT IssueResolve(DRT_BOOTSTRAP_RESOLVE_CALLBACK callback, [In] IntPtr pvCallbackContext, [In] DRT_BOOTSTRAP_RESOLVE_CONTEXT ResolveContext, out bool fFatalError) + { + fFatalError = false; + + if (callback is null) + { + return HRESULT.E_INVALIDARG; + } + + var hr = HRESULT.DRT_E_BOOTSTRAPPROVIDER_NOT_ATTACHED; + if (IsAttached) + { + lock (m_lock) + { + m_fResolveInProgress = true; + m_CallbackThreadId = GetCurrentThreadId(); + } + + if (m_dwMaxResults > 0) + { + var addresses = hostname.Split(new[] { ';', ' ' }, StringSplitOptions.RemoveEmptyEntries).Select(s => s.Trim()).ToArray(); + foreach (var CurrentAddress in addresses) + { + if (m_fEndResolve) + break; + + // Retrieve bootstrap possibilities + var addrInf = new ADDRINFOW + { + ai_flags = ADDRINFO_FLAGS.AI_CANONNAME, + ai_family = ADDRESS_FAMILY.AF_UNSPEC, + ai_socktype = SOCK.SOCK_STREAM + }; + + var nStat = GetAddrInfoW(CurrentAddress, port, addrInf, out var results); + if (nStat.Succeeded) + { + using (results) + { + var cbSA6 = Marshal.SizeOf(typeof(SOCKADDR_IN6)); + using var psockAddrs = new SafeNativeArray(results.Select(a => { using var ar = a.addr; return (SOCKADDR_IN6)ar; }).ToArray()); + var Addresses = new SOCKET_ADDRESS_LIST + { + iAddressCount = psockAddrs.Count, + Address = psockAddrs.Select((a, i) => new SOCKET_ADDRESS { iSockaddrLength = cbSA6, lpSockaddr = ((IntPtr)psockAddrs).Offset(cbSA6) }).ToArray() + }; + + // Call the callback to signal completion + using var pAddresses = Addresses.Pack(); + callback?.Invoke(hr, pvCallbackContext, pAddresses, false); + } + } + else + { + // GetAddrInfoW Failed but there may be more addresses in the string so keep going otherwise we return + // HRESULT.E_NO_MORE and retry next cycle + } + } + } + + // Tell the drt there will be no more results + if (!m_fEndResolve) + callback?.Invoke(HRESULT.DRT_E_NO_MORE, pvCallbackContext, default, false); + + lock (m_lock) + { + if (m_hCallbackComplete != null && !m_hCallbackComplete.IsInvalid) + { + // Notify EndResolve that callbacks have completed + m_hCallbackComplete.Set(); + } + m_fResolveInProgress = false; + } + hr = HRESULT.S_OK; + } + + if (hr.Failed) + { + // DNSResolver has no retry cases, so any failed HRESULT is fatal + fFatalError = true; + } + return hr; + } +} + +/// Base class for a DRT security provider. +public class DrtSecurityProvider : IDisposable +{ + /// The security provider structure. + protected DRT_SECURITY_PROVIDER prov; + + private readonly IntPtr pProv; + private readonly char pProvType; + + /// Initializes a new instance of the class. + /// The prov. + protected DrtSecurityProvider(in DRT_SECURITY_PROVIDER prov) + { + this.prov = prov; + pProv = GCHandle.Alloc(this.prov, GCHandleType.Pinned).AddrOfPinnedObject(); + pProvType = 'h'; + } + + private DrtSecurityProvider() + { } + + private DrtSecurityProvider(IntPtr ptr, char provType) + { + pProv = ptr; + pProvType = provType; + } + + /// Creates the derived key security provider for a Distributed Routing Table. + /// + /// Pointer to the certificate that is the "root" portion of the chain. This is used to ensure that keys derived from the same chain can + /// be verified. /// - /// Upon completion of this function, contains the signature data. - /// - HRESULT SignData([In] IntPtr pvContext, [In] DRT_DATA[] pDataBuffers, out DRT_DATA pKeyIdentifier, - out DRT_DATA pSignature); + /// Pointer to the DRT_SECURITY_PROVIDER module to be included in the DRT_SETTINGS structure. + /// A derived key instance. + /// + /// The security provider created by this function is specific to the DRT it was created for. It cannot be shared by multiple DRT instances. + /// + public static DrtSecurityProvider CreateDerivedKeySecurityProvider(PCCERT_CONTEXT pRootCert, PCCERT_CONTEXT pLocalCert) + { + DrtCreateDerivedKeySecurityProvider(pRootCert, pLocalCert, out IntPtr psp).ThrowIfFailed(); + return new(psp, 'd'); + } - /// - /// Called when the DRT must verify a signature calculated over a block of data included in a DRT message. This function is only - /// called when the DRT is operating in the DRT_SECURE_MEMBERSHIP and DRT_SECURE_CONFIDENTIALPAYLOAD security modes - /// defined by DRT_SECURITY_MODE. - /// - /// Pointer to the value held by the pvContext member of DRT_SECURITY_PROVIDER. - /// Contains the size of the pDataBuffers buffer. - /// Contains the data over which the signature was calculated. - /// Contains the credentials of the remote node used to calculate the signature. - /// Contains an index that may be used to select from multiple credentials provided in pRemoteCredentials. - /// Contains the signature to be verified. - /// - HRESULT VerifyData([In] IntPtr pvContext, [In] DRT_DATA[] pDataBuffers, in DRT_DATA pRemoteCredentials, - in DRT_DATA pKeyIdentifier, in DRT_DATA pSignature); + /// Creates a null security provider. This security provider does not require nodes to authenticate keys. + /// A null instance. + public static DrtSecurityProvider CreateNullSecurityProvider() + { + DrtCreateNullSecurityProvider(out IntPtr psp).ThrowIfFailed(); + return new DrtSecurityProvider(psp, 'n'); + } + + /// Performs an explicit conversion from to . + /// The prov. + /// The result of the conversion. + public static explicit operator IntPtr(DrtSecurityProvider prov) => prov.pProv; + + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + public void Dispose() + { + if (pProv != default) + { + if (pProvType == 'n') + { + DrtDeleteNullSecurityProvider(pProv); + } + else if (pProvType == 'd') + { + DrtDeleteDerivedKeySecurityProvider(pProv); + } + else if (pProvType == 'h') + { + GCHandle.FromIntPtr(pProv).Free(); + } + } + } } -[StructLayout(LayoutKind.Sequential)] -public abstract class DrtSecurityProvider +internal static class DrtUtil { - private readonly DRT_SECURITY_PROVIDER prov; + public static IntPtr Alloc(int size) => Marshal.AllocCoTaskMem(size); + + public static void Free(IntPtr ptr) => Marshal.FreeCoTaskMem(ptr); + + public static IntPtr ToAddrListPtr(IPEndPoint[]? pts) + { + if (pts is null || pts.Length == 0) + { + return default; + } + + SocketAddress[] sa = Array.ConvertAll(pts, p => p.Serialize()); + int ptsz = sa.Sum(a => a.Size); + int strsz = Marshal.SizeOf(typeof(SOCKET_ADDRESS_LIST)); + int sasz = Marshal.SizeOf(typeof(SOCKET_ADDRESS)); + SafeCoTaskMemHandle psal = new(strsz + sasz * (pts.Length - 1) + ptsz); + psal.Write(pts.Length); + Span salSpan = psal.AsSpan(psal.Size); + for (int i = 0, aoff = Marshal.OffsetOf(typeof(SOCKET_ADDRESS_LIST), "Address").ToInt32(), asoff = strsz + sasz * (pts.Length - 1); i < pts.Length; i++, aoff += sasz) + { + psal.Write(new SOCKET_ADDRESS { iSockaddrLength = sa[i].Size, lpSockaddr = ((IntPtr)psal).Offset(asoff) }, false, aoff); + for (int j = 0; j < sa[i].Size; j++) + { + salSpan[asoff + j] = sa[i][j]; + } + + asoff += sa[i].Size; + } + return psal.TakeOwnership(); + } + + public static DRT_DATA ToData(byte[]? data) + { + DRT_DATA ret = default; + if (data is not null) + { + ret.pb = data.MarshalToPtr(Alloc, out int cb); + ret.cb = (uint)cb; + } + return ret; + } + + public static IPEndPoint[]? ToEndPoints(SOCKET_ADDRESS_LIST? al) + { + return al.HasValue ? Array.ConvertAll(al.Value.Address, Cvt) : null; + + static IPEndPoint Cvt(SOCKET_ADDRESS a) + { + SOCKADDR sa = new(a.lpSockaddr, false, a.iSockaddrLength); + SocketAddress nsa = new((System.Net.Sockets.AddressFamily)sa.sa_family, sa.Size); + Span saspan = sa.AsBytes(); + for (int i = 2; i < sa.Size; i++) + { + nsa[i] = saspan[i]; + } + + IPEndPoint ep = new(0, 0); + return (IPEndPoint)ep.Create(nsa); + } + } +} + +/// Abstract base class for a custom DRT security provider. +public abstract class DrtCustomSecurityProvider : DrtSecurityProvider +{ + private int refCount; + /// Initializes a new instance of the class. + /// The context. + protected DrtCustomSecurityProvider(object? context) : base(default) + { + unsafe + { + prov.Attach = InternalAttach; + prov.Detach = InternalDetach; + prov.RegisterKey = InternalRegisterKey; + prov.UnregisterKey = InternalUnregisterKey; + prov.ValidateAndUnpackPayload = InternalValidateAndUnpackPayload; + prov.SecureAndPackPayload = InternalSecureAndPackPayload; + prov.FreeData = InternalFreeData; + prov.EncryptData = InternalEncryptData; + prov.DecryptData = InternalDecryptData; + prov.GetSerializedCredential = InternalGetSerializedCredential; + prov.ValidateRemoteCredential = InternalValidateRemoteCredential; + prov.SignData = InternalSignData; + prov.VerifyData = InternalVerifyData; + if (context != null) + prov.pvContext = GCHandle.Alloc(context).AddrOfPinnedObject(); + } + AddRef(); + } - protected DrtSecurityProvider() + /// Adds the reference. + protected void AddRef() { - unsafe - { - prov = new() - { - Attach = InternalAttach, - Detach = InternalDetach, - RegisterKey = InternalRegisterKey, - UnregisterKey = InternalUnregisterKey, - ValidateAndUnpackPayload = InternalValidateAndUnpackPayload, - SecureAndPackPayload = InternalSecureAndPackPayload, - FreeData = InternalFreeData, - EncryptData = InternalEncryptData, - DecryptData = InternalDecryptData, - GetSerializedCredential = InternalGetSerializedCredential, - ValidateRemoteCredential = InternalValidateRemoteCredential, - SignData = InternalSignData, - VerifyData = InternalVerifyData, - }; - } - } - - protected static IntPtr Alloc(int size) => Marshal.AllocCoTaskMem(size); - protected static void Free(IntPtr ptr) => Marshal.FreeCoTaskMem(ptr); - private static DRT_DATA ToData(byte[]? data) + if (refCount == 0) + GC.SuppressFinalize(this); + InterlockedIncrement(ref refCount); + } + + /// Releases this instance. + protected void Release() { - DRT_DATA ret = default; - if (data is not null) + if (InterlockedDecrement(ref refCount) == 0) { - ret.pb = data.MarshalToPtr(Alloc, out int cb); - ret.cb = (uint)cb; + if (prov.pvContext != default) + GCHandle.FromIntPtr(prov.pvContext).Free(); + GC.ReRegisterForFinalize(this); } - return ret; } - /// Increments the count of references for the Security Provider with a set of DRTs. - protected virtual void Attach() { } - /// Decrements the count of references for the Security Provider with a set of DRTs. - protected virtual void Detach() { } + /// Gets the context provided for all methods. + /// The context object. + protected virtual object? Context => prov.pvContext == IntPtr.Zero ? null : GCHandle.FromIntPtr(prov.pvContext).Target; + + /// Gets a value indicating whether this instance is attached. + /// + /// if this instance is attached; otherwise, . + /// + protected bool IsAttached => refCount > 0; + /// /// Called when the DRT receives a message containing encrypted data. This function is only called when the DRT is operating in the /// DRT_SECURE_CONFIDENTIALPAYLOAD security mode defined by DRT_SECURITY_MODE. @@ -486,10 +716,11 @@ protected virtual void Detach() { } /// /// Contains the context passed into DrtRegisterKey when the key was registered. /// Contains the decrypted data upon completion of the function. - protected virtual void DecryptData([In] byte[] pKeyToken, [Optional] IntPtr pvKeyContext, byte[][] pData) { } + protected virtual HRESULT DecryptData([In] byte[] pKeyToken, [Optional] IntPtr pvKeyContext, byte[][] pData) => HRESULT.S_OK; + /// - /// Called when the DRT sends a message containing data that must be encrypted. This function is only called when the DRT is - /// operating in the DRT_SECURE_CONFIDENTIALPAYLOAD security mode defined by DRT_SECURITY_MODE. + /// Called when the DRT sends a message containing data that must be encrypted. This function is only called when the DRT is operating in + /// the DRT_SECURE_CONFIDENTIALPAYLOAD security mode defined by DRT_SECURITY_MODE. /// /// Contains the credential of the peer that will receive the protected message. /// Contains the unencrypted buffer. @@ -498,58 +729,88 @@ protected virtual void DecryptData([In] byte[] pKeyToken, [Optional] IntPtr pvKe /// Contains the encrypted session key that can be decrypted by the recipient of the message and used to decrypted the protected fields. /// /// - protected abstract void EncryptData([In] byte[] pRemoteCredential, byte[][] pDataBuffers, byte[][] pEncryptedBuffers, out byte[] pKeyToken); - /// Called to register a key with the Security Provider. - /// - /// Pointer to the DRT_REGISTRATION structure created by an application and passed to the DrtRegisterKey function. - /// - /// Pointer to the context data created by an application and passed to the DrtRegisterKey function. - /// - protected virtual void RegisterKey(in DRT_REGISTRATION pRegistration, [In, Optional] IntPtr pvKeyContext) { } + protected abstract HRESULT EncryptData([In] byte[] pRemoteCredential, byte[][] pDataBuffers, byte[][] pEncryptedBuffers, out byte[]? pKeyToken); + /// Called to release resources previously allocated for a security provider function. /// Specifies what data to free. protected virtual void FreeData([In, Optional] IntPtr pv) => Free(pv); + /// /// Called when the DRT must provide a credential used to authorize the local node. This function is only called when the DRT is /// operating in the DRT_SECURE_MEMBERSHIP and DRT_SECURE_CONFIDENTIALPAYLOAD security modes defined by DRT_SECURITY_MODE. /// /// Contains the serialized credential upon completion of the function. protected virtual byte[]? GetSerializedCredential() => null; + + /// Called to register a key with the Security Provider. + /// + /// Pointer to the DRT_REGISTRATION structure created by an application and passed to the DrtRegisterKey function. + /// + /// Pointer to the context data created by an application and passed to the DrtRegisterKey function. + /// + protected virtual HRESULT RegisterKey(in DRT_REGISTRATION pRegistration, [In, Optional] IntPtr pvKeyContext) => HRESULT.S_OK; + + /// + /// Called when an Authority message is about to be sent on the wire. It is responsible for securing the data before it is sent, and for + /// packing the service addresses, revoked flag, nonce, and other application data into the Secured Address Payload. + /// + /// Contains the context passed into DrtRegisterKey when the key was registered. + /// Pointer to the byte array that represents the protocol major version. + /// Pointer to the byte array that represents the protocol minor version. + /// + /// + /// Any DRT specific flags, currently defined only to be the revoked or deleted flag that need to be packed, secured and sent to another + /// instance for processing. + /// + /// Note Currently the only allowed value is: DRT_PAYLOAD_REVOKED + /// + /// Pointer to the key to which this payload is registered. + /// Pointer to the payload specified by the application when calling DrtRegisterKey. + /// Pointer to the service addresses that are placed in the Secured Address Payload. + /// + /// Pointer to the nonce that was sent in the original Inquire or Lookup message. This value is fixed at 16 bytes. + /// + /// + /// Pointer to the payload to send on the wire which contains the service addresses, revoked flag, nonce, and other data required by the + /// security provider. pSecuredAddressPayload.pb is allocated by the security provider. + /// + /// + /// Pointer to the classifier to send in the Authority message. pClassifier.pb is allocated by the security provider. + /// + /// + /// Pointer to the application data payload received in the Authority message. After validation, the original data (after decryption, + /// removal of signature, and so on.) is output as pPayload. pSecuredPayload.pb is allocated by the security provider. + /// + /// + /// Pointer to the cert chain to send in the Authority message. pCertChain.pb is allocated by the security provider. + /// + /// + protected abstract HRESULT SecureAndPackPayload([In, Optional] IntPtr pvKeyContext, byte bProtocolMajor, byte bProtocolMinor, uint dwFlags, + [In] byte[] pKey, [In] byte[]? pPayload, [In] IPEndPoint[]? pAddressList, [In] byte[] pNonce, out byte[] pSecuredAddressPayload, + out byte[]? pClassifier, out byte[]? pSecuredPayload, out byte[]? pCertChain); + /// /// Called when the DRT must sign a data blob for inclusion in a DRT protocol message. This function is only called when the DRT is /// operating in the DRT_SECURE_MEMBERSHIP and DRT_SECURE_CONFIDENTIALPAYLOAD security modes defined by DRT_SECURITY_MODE. /// /// Contains the data to be signed. /// - /// Upon completion of this function, contains an index that can be used to select from multiple credentials for use in calculating - /// the signature. + /// Upon completion of this function, contains an index that can be used to select from multiple credentials for use in calculating the signature. /// /// Upon completion of this function, contains the signature data. /// - protected virtual void SignData(byte[][] dataBuffers, out byte[]? keyIdentifier, out byte[]? signature) => keyIdentifier = signature = null; + protected virtual HRESULT SignData(byte[][] dataBuffers, out byte[]? keyIdentifier, out byte[]? signature) + { + keyIdentifier = signature = null; + return HRESULT.S_OK; + } + /// Called to deregister a key with the Security Provider. /// Pointer to the key to which the payload is registered. /// Pointer to the context data created by an application and passed to the DrtRegisterKey function. /// Pointer to the context data created by the application and passed to DrtRegisterKey. - protected virtual void UnregisterKey(byte[] key, [In, Optional] IntPtr pvKeyContext) { } - /// Called when the DRT must validate a credential provided by a peer node. - /// Contains the serialized credential provided by the peer node. - /// - protected virtual void ValidateRemoteCredential(byte[] pRemoteCredential) { } - /// - /// Called when the DRT must verify a signature calculated over a block of data included in a DRT message. This function is only - /// called when the DRT is operating in the DRT_SECURE_MEMBERSHIP and DRT_SECURE_CONFIDENTIALPAYLOAD security modes - /// defined by DRT_SECURITY_MODE. - /// - /// Contains the data over which the signature was calculated. - /// Contains the credentials of the remote node used to calculate the signature. - /// Contains an index that may be used to select from multiple credentials provided in pRemoteCredentials. - /// Contains the signature to be verified. - /// - protected virtual void VerifyData(byte[][] pDataBuffers, byte[] remoteCredentials, byte[] keyIdentifier, byte[] signature) - { - if (signature is null || signature.Length == 0) throw HRESULT.DRT_E_INVALID_MESSAGE.GetException(); - } + protected virtual HRESULT UnregisterKey(byte[] key, [In, Optional] IntPtr pvKeyContext) => HRESULT.S_OK; + /// /// Called when an Authority message is received on the wire. It is responsible for validating the data received, and for unpacking the /// service addresses, revoked flag, and nonce from the Secured Address Payload. @@ -590,108 +851,621 @@ protected virtual void VerifyData(byte[][] pDataBuffers, byte[] remoteCredential /// Note Currently the only allowed value is: DRT_PAYLOAD_REVOKED (1) /// /// - protected abstract void ValidateAndUnpackPayload([In] byte[] pSecuredAddressPayload, [In] byte[]? pCertChain, + protected abstract HRESULT ValidateAndUnpackPayload([In] byte[] pSecuredAddressPayload, [In] byte[]? pCertChain, [In] byte[]? pClassifier, [In] byte[]? pNonce, [In] byte[]? pSecuredPayload, out byte pbProtocolMajor, out byte pbProtocolMinor, out byte[] pKey, out byte[]? pPayload, - out SafeCoTaskMemStruct ppPublicKey, out SafeCoTaskMemStruct? ppAddressList, + out SafeCoTaskMemStruct ppPublicKey, out IPEndPoint[]? ppAddressList, out uint pdwFlags); - static HRESULT Execute(Action action) + /// Called when the DRT must validate a credential provided by a peer node. + /// Contains the serialized credential provided by the peer node. + /// + protected virtual HRESULT ValidateRemoteCredential(byte[] pRemoteCredential) => HRESULT.S_OK; + + /// + /// Called when the DRT must verify a signature calculated over a block of data included in a DRT message. This function is only called + /// when the DRT is operating in the DRT_SECURE_MEMBERSHIP and DRT_SECURE_CONFIDENTIALPAYLOAD security modes defined by DRT_SECURITY_MODE. + /// + /// Contains the data over which the signature was calculated. + /// Contains the credentials of the remote node used to calculate the signature. + /// Contains an index that may be used to select from multiple credentials provided in pRemoteCredentials. + /// Contains the signature to be verified. + /// + protected virtual HRESULT VerifyData(byte[][] pDataBuffers, byte[] remoteCredentials, byte[] keyIdentifier, byte[] signature) => + signature is null || signature.Length == 0 ? HRESULT.DRT_E_INVALID_MESSAGE : (HRESULT)HRESULT.S_OK; + + private HRESULT InternalAttach(IntPtr pvContext) { - try - { - action(); - return HRESULT.S_OK; - } - catch (Exception ex) { return ex.HResult; } + if (InterlockedCompareExchange(ref refCount, 1, 0) != 0) + return HRESULT.DRT_E_SECURITYPROVIDER_IN_USE; + AddRef(); + return HRESULT.S_OK; } - HRESULT InternalAttach(IntPtr pvContext) => Execute(Attach); - HRESULT InternalDecryptData(IntPtr pvContext, in DRT_DATA pKeyToken, IntPtr pvKeyContext, uint dwBuffers, DRT_DATA[] pData) + private HRESULT InternalDecryptData(IntPtr pvContext, in DRT_DATA pKeyToken, IntPtr pvKeyContext, uint dwBuffers, DRT_DATA[] pData) => + DecryptData(pKeyToken.GetArray(), pvKeyContext, Array.ConvertAll(pData, p => p.GetArray())); + + private void InternalDetach(IntPtr pvContext) { - byte[] pkt = pKeyToken; - return Execute(() => DecryptData(pkt, pvKeyContext, Array.ConvertAll(pData, p => p.GetArray()))); + InterlockedCompareExchange(ref refCount, 0, 1); + Release(); } - void InternalDetach(IntPtr pvContext) => Execute(Detach); - HRESULT InternalEncryptData(IntPtr pvContext, in DRT_DATA pRemoteCredential, uint dwBuffers, DRT_DATA[] pDataBuffers, DRT_DATA[] pEncryptedBuffers, out DRT_DATA pKeyToken) + + private HRESULT InternalEncryptData(IntPtr pvContext, in DRT_DATA pRemoteCredential, uint dwBuffers, DRT_DATA[] pDataBuffers, DRT_DATA[] pEncryptedBuffers, out DRT_DATA pKeyToken) { - byte[] prc = pRemoteCredential; - DRT_DATA dkt = default; - var hr = Execute(() => - { - EncryptData(prc, Array.ConvertAll(pDataBuffers, p => p.GetArray()), Array.ConvertAll(pEncryptedBuffers, p => p.GetArray()), out var pkt); - dkt = ToData(pkt); - }); - pKeyToken = dkt; + var hr = EncryptData(pRemoteCredential.GetArray(), Array.ConvertAll(pDataBuffers, p => p.GetArray()), Array.ConvertAll(pEncryptedBuffers, p => p.GetArray()), out byte[]? pkt); + pKeyToken = ToData(pkt); return hr; } - void InternalFreeData(IntPtr pvContext, IntPtr pv) => Execute(() => FreeData(pv)); - HRESULT InternalGetSerializedCredential(IntPtr pvContext, out DRT_DATA pSelfCredential) + private void InternalFreeData(IntPtr pvContext, IntPtr pv) => FreeData(pv); + + private HRESULT InternalGetSerializedCredential(IntPtr pvContext, out DRT_DATA pSelfCredential) { - byte[]? output = null; - var hr = Execute(() => output = GetSerializedCredential()); - pSelfCredential = ToData(output); - return hr; + pSelfCredential = ToData(GetSerializedCredential()); + return HRESULT.S_OK; } - HRESULT InternalRegisterKey(IntPtr pvContext, in DRT_REGISTRATION pRegistration, IntPtr pvKeyContext) + private HRESULT InternalRegisterKey(IntPtr pvContext, in DRT_REGISTRATION pRegistration, IntPtr pvKeyContext) => + RegisterKey(pRegistration, pvKeyContext); + + private unsafe HRESULT InternalSecureAndPackPayload(IntPtr pvContext, IntPtr pvKeyContext, byte bProtocolMajor, byte bProtocolMinor, + uint dwFlags, in DRT_DATA pKey, DRT_DATA* pPayload, IntPtr pAddressList, in DRT_DATA pNonce, out DRT_DATA pSecuredAddressPayload, + DRT_DATA* pClassifier, DRT_DATA* pSecuredPayload, DRT_DATA* pCertChain) { - var pr = pRegistration; - return Execute(() => RegisterKey(pr, pvKeyContext)); + HRESULT hr = SecureAndPackPayload(pvContext, bProtocolMajor, bProtocolMinor, dwFlags, pKey.GetArray(), pPayload is null ? null : (*pPayload).GetArray(), + ToEndPoints(pAddressList.ToNullableStructure()), pNonce.GetArray(), out byte[]? sap, out byte[]? cl, out byte[]? sp, out byte[]? cc); + pSecuredAddressPayload = ToData(sap); + if (cl is not null) + *pClassifier = ToData(cl); + if (sp is not null) + *pSecuredPayload = ToData(sp); + if (cc is not null) + *pCertChain = ToData(cc); + return hr; } - HRESULT InternalSignData(IntPtr pvContext, uint dwBuffers, DRT_DATA[] pDataBuffers, out DRT_DATA pKeyIdentifier, out DRT_DATA pSignature) + private HRESULT InternalSignData(IntPtr pvContext, uint dwBuffers, DRT_DATA[] pDataBuffers, out DRT_DATA pKeyIdentifier, out DRT_DATA pSignature) { - byte[] id = null, sig = null; - var hr = Execute(() => SignData(Array.ConvertAll(pDataBuffers, b => b.GetArray()), out id, out sig)); - pKeyIdentifier = ToData(id); pSignature = ToData(sig); + HRESULT hr = SignData(Array.ConvertAll(pDataBuffers, b => b.GetArray()), out byte[]? id, out byte[]? sig); + pKeyIdentifier = ToData(id); + pSignature = ToData(sig); return hr; } - HRESULT InternalUnregisterKey(IntPtr pvContext, in DRT_DATA pKey, IntPtr pvKeyContext) + private HRESULT InternalUnregisterKey(IntPtr pvContext, in DRT_DATA pKey, IntPtr pvKeyContext) => + UnregisterKey(pKey.GetArray()); + + private unsafe HRESULT InternalValidateAndUnpackPayload(IntPtr pvContext, in DRT_DATA pSecuredAddressPayload, DRT_DATA* pCertChain, + DRT_DATA* pClassifier, DRT_DATA* pNonce, DRT_DATA* pSecuredPayload, byte* pbProtocolMajor, byte* pbProtocolMinor, + out DRT_DATA pKey, DRT_DATA* pPayload, CERT_PUBLIC_KEY_INFO** ppPublicKey, void** ppAddressList, out uint pdwFlags) { - byte[] pk = pKey.GetArray(); - return Execute(() => UnregisterKey(pk)); + HRESULT hr = ValidateAndUnpackPayload(pSecuredAddressPayload, pCertChain is null ? null : (*pCertChain).GetArray(), + pClassifier is null ? null : (*pClassifier).GetArray(), pNonce is null ? null : (*pNonce).GetArray(), + pSecuredPayload is null ? null : (*pSecuredPayload).GetArray(), out byte maj, out byte min, out byte[]? k, + out byte[]? pl, out SafeCoTaskMemStruct? pk, out IPEndPoint[]? al, out pdwFlags); + *pbProtocolMajor = maj; + *pbProtocolMinor = min; + pKey = ToData(k); + if (pl is not null) + *pPayload = ToData(pl); + *ppPublicKey = (CERT_PUBLIC_KEY_INFO*)(pk?.TakeOwnership() ?? IntPtr.Zero); + *ppAddressList = (void*)ToAddrListPtr(al); + return hr; } - HRESULT InternalValidateRemoteCredential(IntPtr pvContext, in DRT_DATA pRemoteCredential) + private HRESULT InternalValidateRemoteCredential(IntPtr pvContext, in DRT_DATA pRemoteCredential) => + ValidateRemoteCredential(pRemoteCredential.GetArray()); + + private HRESULT InternalVerifyData(IntPtr pvContext, uint dwBuffers, DRT_DATA[] pDataBuffers, in DRT_DATA pRemoteCredentials, in DRT_DATA pKeyIdentifier, in DRT_DATA pSignature) => + VerifyData(Array.ConvertAll(pDataBuffers, b => b.GetArray()), pRemoteCredentials.GetArray(), pKeyIdentifier.GetArray(), pSignature.GetArray()); +} + +/* +/// +/// +/// +/// +public class CustomNullSecurityProvider : DrtCustomSecurityProvider +{ + internal unsafe class CCustomNullSecuredAddressPayload : IDisposable + { + public const ALG_ID DRT_ALGORITHM = ALG_ID.CALG_SHA_256; + public const string DRT_ALGORITHM_OID = AlgOID.szOID_RSA_SHA1RSA; + public const uint DRT_DERIVED_KEY_SIZE = 32; + + // default security provider constants + public const byte DRT_SECURITY_VERSION_MAJOR = 1; + + public const byte DRT_SECURITY_VERSION_MINOR = 0; + public const uint DRT_SHA2_LENGTH = 32; + public const uint DRT_SIG_LENGTH = SHA2_SIG_LENGTH; + + // Original 0x8000 + space for extended payload (4k plus some overhead) + public const uint MAX_MESSAGE_SIZE = 0x8000 + 0x1200; + + public const uint SHA1_SIG_LENGTH = 0x80; + public const uint SHA2_SIG_LENGTH = 0x80; + + private readonly byte[] m_signature = new byte[DRT_SIG_LENGTH]; + private IPEndPoint[]? m_addressList; + private byte m_bProtocolVersionMajor; + private byte m_bProtocolVersionMinor; + private byte[] m_ddKey; + private byte[] m_ddNonce; + private bool m_fAllocated; // set if the data needs to be freed when destroyed (true when deserializing) + + //CERT_PUBLIC_KEY_INFO* + private SafeCoTaskMemStruct? m_pPublicKey; + + /// + /// Initializes a new instance of the class. + /// + /// The b major. + /// The b minor. + /// The key. + /// The nonce. + /// The p address list. + /// The flags. + public CCustomNullSecuredAddressPayload(byte bMajor, byte bMinor, byte[] key, byte[] nonce, IPEndPoint[]? pAddressList, uint flags) + { + m_bProtocolVersionMajor = bMajor; + m_bProtocolVersionMinor = bMinor; + m_ddKey = key; + m_ddNonce = nonce; + m_addressList = pAddressList; + Flags = flags; + } + + // Purpose: Retrieve or set the flags + // + // Args: dwFlags: + public uint Flags { get; set; } + + // Serialized SecureAddressPayload format: bytes name 1 protocol major version 1 protocol minor version 1 security major version 1 + // security minor version 2 key length (KL) KL key 1 signature length (SL) SL signature 1 nonce length (NL) NL nonce 4 flags + // ----- public key ----------- 1 algorithm length (AL) 2 key parameters length (PL) 2 public key length (KL) 1 unused bits AL + // algorithm (byte) PL key parameters KL public key + // ----- end public key ------- 1 address count + // ----- for each address ----- 2 address length (AL) AL address data + // ----- end each address ----- + // Function: CCustomNullSecuredAddressPayload::DeserializeAndValidate + // + // Purpose: Deserialize and validate the payload. + // + // Args: pData: data to deserialize + // pNonce: expected nonce + // pCertChain: opt. remote cert chain (if one was in the message) + // hCryptProv: crypt provider to use with remote public key + // + // Notes: The deserialized data is later retrieved via Get* methods. + public HRESULT DeserializeAndValidate(byte[] pData, byte[]? pNonce) + { + HRESULT hr = HRESULT.S_OK; + + using var deserializer = new MemoryStream(pData, true); + m_fAllocated = true; + + // protocol version + m_bProtocolVersionMajor = (byte)deserializer.ReadByte(); + m_bProtocolVersionMinor = (byte)deserializer.ReadByte(); + + // security version + var bVersionMajor = (byte)deserializer.ReadByte(); + var bVersionMinor = (byte)deserializer.ReadByte(); + + // ensure we are receiving a version we understand + if (bVersionMajor != DRT_SECURITY_VERSION_MAJOR || bVersionMinor != DRT_SECURITY_VERSION_MINOR) + { + hr = HRESULT.DRT_E_INVALID_MESSAGE; + goto cleanup; + } + + // extract key + var cb = deserializer.Read(); + m_ddKey = new byte[cb]; + deserializer.Read(m_ddKey, 0, cb); + + // extract signature + var b = (byte)deserializer.ReadByte(); + if (b != DRT_SIG_LENGTH) + { + hr = HRESULT.DRT_E_INVALID_MESSAGE; + goto cleanup; + } + + var pbSignature = deserializer.Position; + deserializer.Position += DRT_SIG_LENGTH; //deserializer.ReadArray(DRT_SIG_LENGTH, &ddSignature); + + // extract and validate nonce + cb = (byte)deserializer.ReadByte(); + m_ddNonce = new byte[cb]; + deserializer.Read(m_ddNonce, 0, cb); + + // if a nonce was supplied, ensure it matches the nonce in the message + if (pNonce != null && (hr = CompareNonce(pNonce)).Failed) + { + goto cleanup; + } + + // extract flags + Flags = deserializer.Read(); + + // extract public key + hr = ReadPublicKey(deserializer, out m_pPublicKey); + + // extract addresses + var addressList = new SOCKET_ADDRESS_LIST { iAddressCount = (byte)deserializer.ReadByte() }; + addressList.Address = new SOCKET_ADDRESS[addressList.iAddressCount]; + for (var i = 0; i < addressList.iAddressCount; i++) + { + addressList.Address[i] = new SOCKET_ADDRESS { iSockaddrLength = deserializer.Read() }; + // Store just the pointer and then pull that into the packed object + addressList.Address[i].lpSockaddr = deserializer.Pointer.Offset(deserializer.Position); + deserializer.Seek(addressList.Address[i].iSockaddrLength, System.IO.SeekOrigin.Current); + } + + m_addressList = addressList.Pack(); + + if (deserializer.Position != deserializer.Length) + { + hr = HRESULT.DRT_E_INVALID_MESSAGE; + goto cleanup; + } + + cleanup: + // the remaining allocated memory is Marshal.FreeCoTaskMem in the destructor, or ownership is passed via GetAddresses + return hr; + } + + public void Dispose() + { + m_addressList?.Dispose(); + m_pPublicKey?.Dispose(); + if (m_fAllocated) + { + Marshal.FreeCoTaskMem(m_ddKey.pb); + Marshal.FreeCoTaskMem(m_ddNonce.pb); + } + } + + // Purpose: Retrieve the addresses. This returns the memory allocated during de-serialization, so can only be called once. Since it + // will only be called once, there isn't benefit to making another copy of the data. + public void GetAddresses(out SafeCoTaskMemStruct pAddressList) + { + pAddressList = m_addressList; + // this object no longer owns the address list + m_addressList = default; + } + + // Purpose: Retrieve the key deserialized earlier. This returns memory allocated during deserialization, and passes ownership to the + // caller. This method may only be called once. + public void GetKey(out DRT_DATA pData) + { + pData = m_ddKey; + // this object no longer owns the public key + m_ddKey = default; + } + + // Purpose: Retrieve the flags + public void GetProtocolVersion(out byte pbMajor, out byte pbMinor) + { + pbMajor = m_bProtocolVersionMajor; + pbMinor = m_bProtocolVersionMinor; + } + + // Purpose: Retrieve the public key deserialized earlier. This returns memory allocated during deserialization, and passes ownership + // to the caller. This method may only be called once. + public void GetPublicKey(out SafeCoTaskMemStruct pKey) + { + pKey = m_pPublicKey; + m_pPublicKey = null; + } + + // Purpose: Serialize the SecuredAddressPayload according to the format specified above, and sign it using the specified credentials. + // + // Args: pCertChain: [out] pData: serialized/signed data. pData->pb is allocated. + // + // Notes: The data to be serialized has already been set using the Set* methods. + public HRESULT SerializeAndSign(out byte[] pData) + { + CERT_PUBLIC_KEY_INFO publicKey = default; + using var emptyAddress = new SafeCoTaskMemString("0.0.0.0", CharSet.Ansi); + publicKey.Algorithm.pszObjId = (IntPtr)emptyAddress; + publicKey.PublicKey.cbData = sizeof(uint); + var dwBaadFood = 0xbaadf00d; + publicKey.PublicKey.pbData = (IntPtr)(&dwBaadFood); + + pData = default; + + uint cbAlgorithmId = emptyAddress.Size; + + // validate that the lengths are all reasonable (fit in the space provided for their count) + var addressList = m_addressList.Value; + if (m_ddNonce.cb > byte.MaxValue || + addressList.iAddressCount > byte.MaxValue || + m_ddKey.cb > ushort.MaxValue || cbAlgorithmId > byte.MaxValue || + publicKey.Algorithm.Parameters.cbData > ushort.MaxValue || + publicKey.PublicKey.cbData > ushort.MaxValue || + publicKey.PublicKey.cUnusedBits > byte.MaxValue) + { + return HRESULT.E_INVALIDARG; + } + + // serialize away + using var mem = new SafeCoTaskMemHandle(1024); + var ddDataPtr = new NativeMemoryStream(mem); + + // protocol version + ddDataPtr.Write(m_bProtocolVersionMajor); + ddDataPtr.Write(m_bProtocolVersionMinor); + + // security version + ddDataPtr.Write(DRT_SECURITY_VERSION_MAJOR); + ddDataPtr.Write(DRT_SECURITY_VERSION_MINOR); + + // key + ddDataPtr.Write((ushort)m_ddKey.cb); + ddDataPtr.WriteFromPtr(m_ddKey.pb, m_ddKey.cb); + + // skip over the signature for now (leave it zero while we calculate the signature) + ddDataPtr.Write((byte)DRT_SIG_LENGTH); + var pbSignature = ddDataPtr.Position; // save the location of the signature for later + ddDataPtr.Position += DRT_SIG_LENGTH; + + // nonce + ddDataPtr.Write((byte)m_ddNonce.cb); + ddDataPtr.WriteFromPtr(m_ddNonce.pb, m_ddNonce.cb); + + // flags + ddDataPtr.Write(Flags); + + // public key sizes + ddDataPtr.Write((byte)cbAlgorithmId); + ddDataPtr.Write((ushort)publicKey.Algorithm.Parameters.cbData); + ddDataPtr.Write((ushort)publicKey.PublicKey.cbData); + ddDataPtr.Write((byte)publicKey.PublicKey.cUnusedBits); + + // public key data + ddDataPtr.Write(publicKey.Algorithm.pszObjId.ToString(), CharSet.Ansi); + if (publicKey.Algorithm.Parameters.cbData > 0) + ddDataPtr.WriteFromPtr(publicKey.Algorithm.Parameters.pbData, publicKey.Algorithm.Parameters.cbData); + ddDataPtr.WriteFromPtr(publicKey.PublicKey.pbData, publicKey.PublicKey.cbData); + + // addresses + ddDataPtr.Write((byte)addressList.iAddressCount); + for (var i = 0; i < addressList.iAddressCount; i++) + { + ddDataPtr.Write((ushort)addressList.Address[i].iSockaddrLength); + ddDataPtr.WriteFromPtr(addressList.Address[i].lpSockaddr, addressList.Address[i].iSockaddrLength); + } + + // pass the data back to the caller + pData = new DRT_DATA { cb = (uint)ddDataPtr.Length, pb = mem.TakeOwnership() }; + + return HRESULT.S_OK; + } + + // Purpose: Read a public key from the stream + // + // Args: [out] ppPublicKey: public key allocated as a single block of memory (with self-refertial embedded pointers) + private static HRESULT ReadPublicKey(NativeMemoryStream deserializer, out SafeCoTaskMemStruct ppPublicKey) + { + ppPublicKey = default; + + try + { + var cbAlgorithmId = (byte)deserializer.ReadByte(); + var cbParameters = deserializer.Read(); + var cbPublicKey = deserializer.Read(); + var cUnusedBits = (byte)deserializer.ReadByte(); + + var szAlgId = cbAlgorithmId == 0 ? null : deserializer.Read(CharSet.Ansi); + var pParamData = cbParameters == 0 ? new byte[0] : deserializer.ReadArray(cbParameters, false).ToArray(); + var pKeyData = cbPublicKey == 0 ? new byte[0] : deserializer.ReadArray(cbPublicKey, false).ToArray(); + + var cbTotal = sizeof(CERT_PUBLIC_KEY_INFO) + Macros.ALIGN_TO_MULTIPLE(cbAlgorithmId + 1, IntPtr.Size) + + Macros.ALIGN_TO_MULTIPLE(cbParameters, IntPtr.Size) + Macros.ALIGN_TO_MULTIPLE(cbPublicKey, IntPtr.Size); + + var pPublicKey = new SafeCoTaskMemStruct(cbTotal); + ref var rpk = ref pPublicKey.AsRef(); + var pbStructIter = ((IntPtr)pPublicKey).Offset(sizeof(CERT_PUBLIC_KEY_INFO)); // skip the structure + + // copy the algorithm id + rpk.Algorithm.pszObjId = pbStructIter; + StringHelper.Write(szAlgId, pbStructIter, out var written, true, CharSet.Ansi); + pbStructIter += (int)Macros.ALIGN_TO_MULTIPLE(written, IntPtr.Size); + + // copy the key parameters + if (cbParameters > 0) + { + rpk.Algorithm.Parameters.cbData = cbParameters; + rpk.Algorithm.Parameters.pbData = pbStructIter; + pbStructIter.Write(pParamData); + pbStructIter += (int)Macros.ALIGN_TO_MULTIPLE(pParamData.Length, IntPtr.Size); + } + + // copy the key + rpk.PublicKey.cbData = cbPublicKey; + rpk.PublicKey.cUnusedBits = cUnusedBits; + rpk.PublicKey.pbData = pbStructIter; + pbStructIter.Write(pKeyData); + + ppPublicKey = pPublicKey; + + return HRESULT.S_OK; + } + catch (Exception ex) + { + return HRESULT.FromException(ex); + } + } + + // Purpose: Compare the nonce provided by the DRT to the nonce received on the wire, returning HRESULT.DRT_E_INVALID_MESSAGE if they + // don't match. + // + // Args: pNonce: + private HRESULT CompareNonce(byte[] pNonce) => pNonce.SequenceEqual(m_ddNonce) ? (HRESULT)HRESULT.S_OK : HRESULT.DRT_E_INVALID_MESSAGE; + } + + /// Initializes a new instance of the class. + /// The context. + public CustomNullSecurityProvider(object? context) : base(context) { } + + /// + protected override HRESULT EncryptData([In] byte[] pRemoteCredential, byte[][] pDataBuffers, byte[][] pEncryptedBuffers, out byte[]? pKeyToken) { - byte[] rc = pRemoteCredential.GetArray(); - return Execute(() => ValidateRemoteCredential(rc)); + HRESULT hr = HRESULT.S_OK; + pKeyToken = default; + + //copy all input buffers into out buffers unmodified + for (uint dwIdx = 0; dwIdx < pDataBuffers.GetLength(0); dwIdx++) + { + pDataBuffers[dwIdx].CopyTo(pEncryptedBuffers[dwIdx], 0); + } + return hr; } - HRESULT InternalVerifyData(IntPtr pvContext, uint dwBuffers, DRT_DATA[] pDataBuffers, in DRT_DATA pRemoteCredentials, in DRT_DATA pKeyIdentifier, in DRT_DATA pSignature) + /// + protected override HRESULT SecureAndPackPayload([In, Optional] IntPtr pvKeyContext, byte bProtocolMajor, byte bProtocolMinor, uint dwFlags, + [In] byte[] pKey, [In] byte[]? pPayload, [In] IPEndPoint[]? pAddressList, [In] byte[] pNonce, out byte[] pSecuredAddressPayload, + out byte[]? pClassifier, out byte[]? pSecuredPayload, out byte[]? pCertChain) { - byte[] rc = pRemoteCredentials.GetArray(), id = pKeyIdentifier.GetArray(), sig = pSignature.GetArray(); - return Execute(() => VerifyData(Array.ConvertAll(pDataBuffers, b => b.GetArray()), rc, id, sig)); + // NULL out the out params + pClassifier = default; + pSecuredPayload = default; + pCertChain = default; + + // set the payload contents + var sap = new CCustomNullSecuredAddressPayload(bProtocolMajor, bProtocolMinor, pKey, pNonce, pAddressList, dwFlags); + var hr = sap.SerializeAndSign(out pSecuredAddressPayload); + if (hr.Failed) + { + goto cleanup; + } + + if (pPayload != null && pSecuredPayload != null) + { + pSecuredPayload->cb = pPayload->cb; + pSecuredPayload->pb = Marshal.AllocCoTaskMem((int)pSecuredPayload->cb); + if (pSecuredPayload->pb == default) + { + hr = HRESULT.E_OUTOFMEMORY; + goto cleanup; + } + pPayload->pb.CopyTo(pSecuredPayload->pb, pSecuredPayload->cb); + } + + // make a copy of the serialized local cert chain + if (pCertChain != null) + { + pCertChain->cb = sizeof(uint); + pCertChain->pb = Marshal.AllocCoTaskMem((int)pCertChain->cb); + if (pCertChain->pb == default) + { + hr = HRESULT.E_OUTOFMEMORY; + goto cleanup; + } + pCertChain->pb.Write(0xdeadbeefU, 0, sizeof(uint)); + } + + cleanup: + // if something failed, free all the out params and NULL them out + if (hr.Failed) + { + Marshal.FreeCoTaskMem(pSecuredAddressPayload.pb); + pSecuredAddressPayload = default; + if (pSecuredPayload != null) + { + Marshal.FreeCoTaskMem(pSecuredPayload->pb); + *pSecuredPayload = default; + } + if (pCertChain != null) + { + Marshal.FreeCoTaskMem(pCertChain->pb); + *pCertChain = default; + } + } + + return hr; } - unsafe HRESULT InternalValidateAndUnpackPayload(IntPtr pvContext, in DRT_DATA pSecuredAddressPayload, DRT_DATA* pCertChain, - DRT_DATA* pClassifier, DRT_DATA* pNonce, DRT_DATA* pSecuredPayload, byte* pbProtocolMajor, byte* pbProtocolMinor, - out DRT_DATA pKey, DRT_DATA* pPayload, CERT_PUBLIC_KEY_INFO** ppPublicKey, void** ppAddressList, out uint pdwFlags) + /// + protected override HRESULT ValidateAndUnpackPayload([In] byte[] pSecuredAddressPayload, [In] byte[]? pCertChain, [In] byte[]? pClassifier, + [In] byte[]? pNonce, [In] byte[]? pSecuredPayload, out byte pbProtocolMajor, out byte pbProtocolMinor, out byte[] pKey, + out byte[]? pPayload, out SafeCoTaskMemStruct ppPublicKey, out IPEndPoint[]? ppAddressList, out uint pdwFlags) { - byte[] sap = pSecuredAddressPayload.GetArray(); - byte[]? cc = pCertChain is null ? null : (*pCertChain).GetArray(); - byte[]? cl = pClassifier is null ? null : (*pClassifier).GetArray(); - byte[]? no = pNonce is null ? null : (*pNonce).GetArray(); - byte[]? sp = pSecuredPayload is null ? null : (*pSecuredPayload).GetArray(); - DRT_DATA lpKey = default; - uint lpdwFlags = 0; - var hr = Execute(() => { - ValidateAndUnpackPayload(sap, cc, cl, no, sp, out var maj, out var min, out var k, out var pl, out var pk, out var al, out var fl); - *pbProtocolMajor = maj; - *pbProtocolMinor = min; - lpKey = ToData(k); - if (pl is not null) *pPayload = ToData(pl); - *ppPublicKey = (CERT_PUBLIC_KEY_INFO*)(void*)pk.TakeOwnership(); - if (al is not null && !al.IsInvalid) *ppAddressList = (void*)al.TakeOwnership(); - lpdwFlags = fl; - }); - pKey = lpKey; - pdwFlags = lpdwFlags; + var sap = new CCustomNullSecuredAddressPayload(); + HRESULT hr = HRESULT.S_OK; + + // NULL out the out params + *pbProtocolMajor = 0; + *pbProtocolMinor = 0; + pKey = default; + if (pPayload != null) + *pPayload = default; + *ppPublicKey = null; + pdwFlags = 0; + + // deserialize Secured Address Payload + hr = sap.DeserializeAndValidate(pSecuredAddressPayload, pNonce); + if (hr.Failed) + { + goto cleanup; + } + + // When we asked for the payload validate signature of payload + if (pPayload != null && pSecuredPayload != null) + { + pPayload->cb = pSecuredPayload->cb; + pPayload->pb = Marshal.AllocCoTaskMem((int)pPayload->cb); + if (pPayload->pb == default) + { + hr = HRESULT.E_OUTOFMEMORY; + goto cleanup; + } + pSecuredPayload->pb.CopyTo(pPayload->pb, pPayload->cb); + } + + pdwFlags = sap.Flags; + + // everything is valid, time to extract the data + if (ppAddressList != null) + { + sap.GetAddresses(out var addr); + *ppAddressList = (void*)addr.TakeOwnership(); + } + sap.GetPublicKey(out var pk); + *ppPublicKey = (CERT_PUBLIC_KEY_INFO*)pk.TakeOwnership(); + sap.GetKey(out pKey); + sap.GetProtocolVersion(out *pbProtocolMajor, out *pbProtocolMinor); + + cleanup: + // if something failed, free all the out params and NULL them out + if (hr.Failed) + { + *pbProtocolMajor = 0; + *pbProtocolMinor = 0; + pdwFlags = 0; + Marshal.FreeCoTaskMem(pKey.pb); + pKey = default; + if (pPayload != null) + { + Marshal.FreeCoTaskMem(pPayload->pb); + *pPayload = default; + } + Marshal.FreeCoTaskMem((IntPtr)(*ppPublicKey)); + *ppPublicKey = null; + + // free all the addresses + if (ppAddressList != null) + { + Marshal.FreeCoTaskMem((IntPtr)(*ppAddressList)); + *ppAddressList = null; + } + } + return hr; } - unsafe HRESULT InternalSecureAndPackPayload(IntPtr pvContext, IntPtr pvKeyContext, byte bProtocolMajor, byte bProtocolMinor, uint dwFlags, in DRT_DATA pKey, DRT_DATA* pPayload, IntPtr pAddressList, in DRT_DATA pNonce, out DRT_DATA pSecuredAddressPayload, DRT_DATA* pClassifier, DRT_DATA* pSecuredPayload, DRT_DATA* pCertChain) => throw new NotImplementedException(); } -#endif \ No newline at end of file +*/ \ No newline at end of file