Skip to content

Commit

Permalink
Add auth static provider (user:password) (#42)
Browse files Browse the repository at this point in the history
* Add static authorization

* Add automatic user creation in tests

* Update Ydb.Proto dependency,

* Add more testcases, add throw of InvalidCredentialsException

* Add more verbosity in tests
  • Loading branch information
XmasApple authored Oct 3, 2023
1 parent 8a6d072 commit d1dc15b
Show file tree
Hide file tree
Showing 14 changed files with 474 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ jobs:
- name: Integration test
run: |
cd src
dotnet test --filter "Category=Integration" -f ${{ matrix.dotnet-target-framework }}
dotnet test --filter "Category=Integration" -f ${{ matrix.dotnet-target-framework }} -l "console;verbosity=normal"
File renamed without changes.
6 changes: 6 additions & 0 deletions src/Ydb.Sdk/src/Auth/IUseDriverConfig.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Ydb.Sdk.Auth;

public interface IUseDriverConfig
{
public Task ProvideConfig(DriverConfig driverConfig);
}
210 changes: 210 additions & 0 deletions src/Ydb.Sdk/src/Auth/StaticCredentialsProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
using System.IdentityModel.Tokens.Jwt;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Ydb.Sdk.Services.Auth;

namespace Ydb.Sdk.Auth;

public class StaticCredentialsProvider : ICredentialsProvider, IUseDriverConfig
{
private readonly ILogger _logger;

private readonly string _user;
private readonly string? _password;

private Driver? _driver;

public int MaxRetries = 5;

private readonly object _lock = new();

private volatile TokenData? _token;
private volatile Task? _refreshTask;

public float RefreshRatio = .1f;

/// <summary>
///
/// </summary>
/// <param name="user">User of the database</param>
/// <param name="password">Password of the user. If user has no password use null </param>
/// <param name="loggerFactory"></param>
public StaticCredentialsProvider(string user, string? password, ILoggerFactory? loggerFactory = null)
{
_user = user;
_password = password;
loggerFactory ??= NullLoggerFactory.Instance;
_logger = loggerFactory.CreateLogger<StaticCredentialsProvider>();
}

private async Task Initialize()
{
_token = await ReceiveToken();
}

public string GetAuthInfo()
{
var token = _token;

if (token is null)
{
lock (_lock)
{
if (_token is not null) return _token.Token;
_logger.LogWarning(
"Blocking for initial token acquirement, please use explicit Initialize async method.");

Initialize().Wait();

return _token!.Token;
}
}

if (token.IsExpired())
{
lock (_lock)
{
if (!_token!.IsExpired()) return _token.Token;
_logger.LogWarning("Blocking on expired token.");

_token = ReceiveToken().Result;

return _token.Token;
}
}

if (!token.IsRefreshNeeded() || _refreshTask is not null) return _token!.Token;
lock (_lock)
{
if (!_token!.IsRefreshNeeded() || _refreshTask is not null) return _token!.Token;
_logger.LogInformation("Refreshing token.");

_refreshTask = Task.Run(RefreshToken);
}

return _token!.Token;
}

private async Task RefreshToken()
{
var token = await ReceiveToken();

lock (_lock)
{
_token = token;
_refreshTask = null;
}
}

private async Task<TokenData> ReceiveToken()
{
var retryAttempt = 0;
while (true)
{
try
{
_logger.LogTrace($"Attempting to receive token, attempt: {retryAttempt}");

var token = await FetchToken();

_logger.LogInformation($"Received token, expires at: {token.ExpiresAt}");

return token;
}
catch (InvalidCredentialsException e)
{
_logger.LogWarning($"Invalid credentials, {e}");
throw;
}
catch (Exception e)
{
_logger.LogDebug($"Failed to fetch token, {e}");

if (retryAttempt >= MaxRetries)
{
_logger.LogWarning($"Can't fetch token, {e}");
throw;
}

await Task.Delay(TimeSpan.FromSeconds(Math.Pow(2, retryAttempt)));
_logger.LogInformation($"Failed to fetch token, attempt {retryAttempt}");
++retryAttempt;
}
}
}

private async Task<TokenData> FetchToken()
{
if (_driver is null)
{
_logger.LogError("Driver in for static auth not provided");
throw new NullReferenceException();
}

var client = new AuthClient(_driver);
var loginResponse = await client.Login(_user, _password);
if (loginResponse.Status.StatusCode == StatusCode.Unauthorized)
{
throw new InvalidCredentialsException(Issue.IssuesToString(loginResponse.Status.Issues));
}

loginResponse.Status.EnsureSuccess();
var token = loginResponse.Result.Token;
var jwt = new JwtSecurityToken(token);
return new TokenData(token, jwt.ValidTo, RefreshRatio);
}

public async Task ProvideConfig(DriverConfig driverConfig)
{
_driver = await Driver.CreateInitialized(
new DriverConfig(
driverConfig.Endpoint,
driverConfig.Database,
new AnonymousProvider(),
driverConfig.DefaultTransportTimeout,
driverConfig.DefaultStreamingTransportTimeout,
driverConfig.CustomServerCertificate));

await Initialize();
}

private class TokenData
{
public TokenData(string token, DateTime expiresAt, float refreshInterval)
{
var now = DateTime.UtcNow;

Token = token;
ExpiresAt = expiresAt;

if (expiresAt <= now)
{
RefreshAt = expiresAt;
}
else
{
RefreshAt = now + (expiresAt - now) * refreshInterval;

if (RefreshAt < now)
{
RefreshAt = expiresAt;
}
}
}

public string Token { get; }
public DateTime ExpiresAt { get; }

private DateTime RefreshAt { get; }

public bool IsExpired()
{
return DateTime.UtcNow >= ExpiresAt;
}

public bool IsRefreshNeeded()
{
return DateTime.UtcNow >= RefreshAt;
}
}
}
7 changes: 7 additions & 0 deletions src/Ydb.Sdk/src/Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Microsoft.Extensions.Logging.Abstractions;
using Ydb.Discovery;
using Ydb.Discovery.V1;
using Ydb.Sdk.Auth;

namespace Ydb.Sdk;

Expand Down Expand Up @@ -72,6 +73,12 @@ public ValueTask DisposeAsync()

public async Task Initialize()
{
if (_config.Credentials is IUseDriverConfig useDriverConfig)
{
await useDriverConfig.ProvideConfig(_config);
_logger.LogInformation("DriverConfig provided to IUseDriverConfig interface");
}

_logger.LogInformation("Started initial endpoint discovery");

try
Expand Down
10 changes: 10 additions & 0 deletions src/Ydb.Sdk/src/Services/Auth/AuthClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using Ydb.Sdk.Client;

namespace Ydb.Sdk.Services.Auth;

public partial class AuthClient : ClientBase
{
public AuthClient(Driver driver) : base(driver)
{
}
}
75 changes: 75 additions & 0 deletions src/Ydb.Sdk/src/Services/Auth/Login.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using Ydb.Auth;
using Ydb.Auth.V1;
using Ydb.Sdk.Client;

namespace Ydb.Sdk.Services.Auth;

public class LoginSettings : OperationRequestSettings
{
}

public class LoginResponse : ResponseWithResultBase<LoginResponse.ResultData>
{
internal LoginResponse(Status status, ResultData? result = null)
: base(status, result)
{
}

public class ResultData
{
public string Token { get; }

internal ResultData(string token)
{
Token = token;
}


internal static ResultData FromProto(LoginResult resultProto)
{
var token = resultProto.Token;
return new ResultData(token);
}
}
}

public partial class AuthClient
{
public async Task<LoginResponse> Login(string user, string? password, LoginSettings? settings = null)
{
settings ??= new LoginSettings();
var request = new LoginRequest
{
OperationParams = MakeOperationParams(settings),
User = user
};
if (password is not null)
{
request.Password = password;
}

try
{
var response = await Driver.UnaryCall(
method: AuthService.LoginMethod,
request: request,
settings: settings
);

var status = UnpackOperation(response.Data.Operation, out LoginResult? resultProto);

LoginResponse.ResultData? result = null;

if (status.IsSuccess && resultProto is not null)
{
result = LoginResponse.ResultData.FromProto(resultProto);
}

return new LoginResponse(status, result);
}
catch (Driver.TransportException e)
{
return new LoginResponse(e.Status);
}
}
}
3 changes: 2 additions & 1 deletion src/Ydb.Sdk/src/Ydb.Sdk.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Ydb.Protos" Version="1.0.3" />
<PackageReference Include="Ydb.Protos" Version="1.0.4" />
<PackageReference Include="Portable.BouncyCastle" Version="1.9.0" />
<PackageReference Include="System.IdentityModel.Tokens.Jwt" Version="7.0.0" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework.Equals('net6.0'))">
Expand Down
Loading

0 comments on commit d1dc15b

Please sign in to comment.