Skip to content

Commit

Permalink
Add static authorization
Browse files Browse the repository at this point in the history
  • Loading branch information
XmasApple committed Sep 25, 2023
1 parent 8a6d072 commit e239301
Show file tree
Hide file tree
Showing 13 changed files with 404 additions and 14 deletions.
6 changes: 6 additions & 0 deletions src/Ydb.Sdk/src/Auth/IUseDriverConfig.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Ydb.Sdk.Auth;

public interface IUseDriverConfig
{
public Task ProvideConfig(DriverConfig driverConfig);
}
160 changes: 160 additions & 0 deletions src/Ydb.Sdk/src/Auth/IamProviderBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
namespace Ydb.Sdk.Auth;

using System;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

public abstract class IamProviderBase : ICredentialsProvider
{
private static readonly TimeSpan IamRefreshInterval = TimeSpan.FromMinutes(5);

private static readonly TimeSpan IamRefreshGap = TimeSpan.FromMinutes(1);

private const int IamMaxRetries = 5;

private readonly object _lock = new();

private readonly ILogger _logger;

private volatile IamTokenData? _iamToken;
private volatile Task? _refreshTask;

protected IamProviderBase(ILoggerFactory? loggerFactory)
{
loggerFactory ??= NullLoggerFactory.Instance;
_logger = loggerFactory.CreateLogger<IamProviderBase>();
}

public async Task Initialize()
{
_iamToken = await ReceiveIamToken();
}

public string? GetAuthInfo()
{
var iamToken = _iamToken;

if (iamToken is null)
{
lock (_lock)
{
if (_iamToken is not null) return _iamToken.Token;
_logger.LogWarning("Blocking for initial IAM token acquirement" +
", please use explicit Initialize async method.");

_iamToken = ReceiveIamToken().Result;

return _iamToken.Token;
}
}

if (iamToken.IsExpired())
{
lock (_lock)
{
if (!_iamToken!.IsExpired()) return _iamToken.Token;
_logger.LogWarning("Blocking on expired IAM token.");

_iamToken = ReceiveIamToken().Result;

return _iamToken.Token;
}
}

if (!iamToken.IsExpiring() || _refreshTask is not null) return _iamToken!.Token;
lock (_lock)
{
if (!_iamToken!.IsExpiring() || _refreshTask is not null) return _iamToken!.Token;
_logger.LogInformation("Refreshing IAM token.");

_refreshTask = Task.Run(RefreshIamToken);
}

return _iamToken!.Token;
}

private async Task RefreshIamToken()
{
var iamToken = await ReceiveIamToken();

lock (_lock)
{
_iamToken = iamToken;
_refreshTask = null;
}
}

protected async Task<IamTokenData> ReceiveIamToken()
{
var retryAttempt = 0;
while (true)
{
try
{
_logger.LogTrace($"Attempting to receive IAM token, attempt: {retryAttempt}");

var iamToken = await FetchToken();

_logger.LogInformation($"Received IAM token, expires at: {iamToken.ExpiresAt}");

return iamToken;
}
catch (Exception e)
{
_logger.LogDebug($"Failed to fetch IAM token, {e}");

if (retryAttempt >= IamMaxRetries)
{
throw;
}

await Task.Delay(TimeSpan.FromSeconds(Math.Pow(2, retryAttempt)));
++retryAttempt;
}
}
}

protected abstract Task<IamTokenData> FetchToken();

protected class IamTokenData
{
public IamTokenData(string token, DateTime expiresAt)
{
var now = DateTime.UtcNow;

Token = token;
ExpiresAt = expiresAt;

if (expiresAt <= now)
{
RefreshAt = expiresAt;
}
else
{
var refreshSeconds = new Random().Next((int)IamRefreshInterval.TotalSeconds);
RefreshAt = expiresAt - IamRefreshGap - TimeSpan.FromSeconds(refreshSeconds);

if (RefreshAt < now)
{
RefreshAt = expiresAt;
}
}
}

public string Token { get; }
public DateTime ExpiresAt { get; }

public DateTime RefreshAt { get; }

public bool IsExpired()
{
return DateTime.UtcNow >= ExpiresAt;
}

public bool IsExpiring()
{
return DateTime.UtcNow >= RefreshAt;
}
}
}
54 changes: 54 additions & 0 deletions src/Ydb.Sdk/src/Auth/StaticProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System.IdentityModel.Tokens.Jwt;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Ydb.Sdk.Services.Auth;

namespace Ydb.Sdk.Auth;

public class StaticProvider : IamProviderBase, IUseDriverConfig
{
private readonly ILogger _logger;

private readonly string _user;
private readonly string? _password;

private Driver? _driver;


public StaticProvider(string user, string? password = null, ILoggerFactory? loggerFactory = null) : base(
loggerFactory)
{
_user = user;
_password = password;
loggerFactory ??= NullLoggerFactory.Instance;
_logger = loggerFactory.CreateLogger<StaticProvider>();
}

protected override async Task<IamTokenData> FetchToken()
{
if (_driver is null)
{
_logger.LogError("Driver in for static auth not provided");
throw new NullReferenceException();
}

var client = new AuthClient(_driver);
var loginResponse = await client.Login(_user, _password);
loginResponse.Status.EnsureSuccess();
var token = loginResponse.Result.Token;
var jwt = new JwtSecurityToken(token);
return new IamTokenData(token, jwt.ValidTo);
}

public async Task ProvideConfig(DriverConfig driverConfig)
{
_driver = await Driver.CreateInitialized(
new DriverConfig(
driverConfig.Endpoint,
driverConfig.Database,
new AnonymousProvider(),
driverConfig.DefaultTransportTimeout,
driverConfig.DefaultStreamingTransportTimeout,
driverConfig.CustomServerCertificate));
}
}
7 changes: 7 additions & 0 deletions src/Ydb.Sdk/src/Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Microsoft.Extensions.Logging.Abstractions;
using Ydb.Discovery;
using Ydb.Discovery.V1;
using Ydb.Sdk.Auth;

namespace Ydb.Sdk;

Expand Down Expand Up @@ -72,6 +73,12 @@ public ValueTask DisposeAsync()

public async Task Initialize()
{
if (_config.Credentials is IUseDriverConfig useDriverConfig)
{
await useDriverConfig.ProvideConfig(_config);
_logger.LogInformation("DriverConfig provided to IUseDriverConfig interface");
}

_logger.LogInformation("Started initial endpoint discovery");

try
Expand Down
10 changes: 10 additions & 0 deletions src/Ydb.Sdk/src/Services/Auth/AuthClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using Ydb.Sdk.Client;

namespace Ydb.Sdk.Services.Auth;

public partial class AuthClient : ClientBase
{
public AuthClient(Driver driver) : base(driver)
{
}
}
72 changes: 72 additions & 0 deletions src/Ydb.Sdk/src/Services/Auth/Login.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using Ydb.Auth;
using Ydb.Auth.V1;
using Ydb.Sdk.Client;

namespace Ydb.Sdk.Services.Auth;

public class LoginSettings : OperationRequestSettings
{
}

public class LoginResponse : ResponseWithResultBase<LoginResponse.ResultData>
{
internal LoginResponse(Status status, ResultData? result = null)
: base(status, result)
{
}

public class ResultData
{
public string Token { get; }

internal ResultData(string token)
{
Token = token;
}


internal static ResultData FromProto(LoginResult resultProto)
{
var token = resultProto.Token;
return new ResultData(token);
}
}
}

public partial class AuthClient
{
public async Task<LoginResponse> Login(string user, string? password, LoginSettings? settings = null)
{
settings ??= new LoginSettings();
var request = new LoginRequest
{
OperationParams = MakeOperationParams(settings),
Password = password,
User = user
};

try
{
var response = await Driver.UnaryCall(
method: AuthService.LoginMethod,
request: request,
settings: settings
);

var status = UnpackOperation(response.Data.Operation, out LoginResult? resultProto);

LoginResponse.ResultData? result = null;

if (status.IsSuccess && resultProto is not null)
{
result = LoginResponse.ResultData.FromProto(resultProto);
}

return new LoginResponse(status, result);
}
catch (Driver.TransportException e)
{
return new LoginResponse(e.Status);
}
}
}
7 changes: 6 additions & 1 deletion src/Ydb.Sdk/src/Ydb.Sdk.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Ydb.Protos" Version="1.0.3" />
<!-- <PackageReference Include="Ydb.Protos" Version="1.0.3" />-->
<PackageReference Include="Portable.BouncyCastle" Version="1.9.0" />
<PackageReference Include="System.IdentityModel.Tokens.Jwt" Version="7.0.0" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework.Equals('net6.0'))">
Expand All @@ -36,4 +37,8 @@
<ItemGroup Condition="$(TargetFramework.Equals('net7.0'))">
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\..\ydb-dotnet-genproto\src\Ydb.Protos\Ydb.Protos.csproj" />
</ItemGroup>

</Project>
Loading

0 comments on commit e239301

Please sign in to comment.