Skip to content

Commit

Permalink
remove Iam
Browse files Browse the repository at this point in the history
  • Loading branch information
XmasApple committed Sep 27, 2023
1 parent dd9939d commit cec765c
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 195 deletions.
File renamed without changes.
160 changes: 0 additions & 160 deletions src/Ydb.Sdk/src/Auth/IamProviderBase.cs

This file was deleted.

152 changes: 146 additions & 6 deletions src/Ydb.Sdk/src/Auth/StaticCredentialsProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace Ydb.Sdk.Auth;

public class StaticCredentialsProvider : IamProviderBase, IUseDriverConfig
public class StaticCredentialsProvider : ICredentialsProvider, IUseDriverConfig
{
private readonly ILogger _logger;

Expand All @@ -14,23 +14,121 @@ public class StaticCredentialsProvider : IamProviderBase, IUseDriverConfig

private Driver? _driver;

private static readonly TimeSpan RefreshInterval = TimeSpan.FromMinutes(5);

private static readonly TimeSpan RefreshGap = TimeSpan.FromMinutes(1);

private const int MaxRetries = 5;

private readonly object _lock = new();

private volatile TokenData? _token;
private volatile Task? _refreshTask;

/// <summary>
///
/// </summary>
/// <param name="user">User of the database</param>
/// <param name="password">Password of the user. If user has no password use null or "" </param>
/// <param name="loggerFactory"></param>
public StaticCredentialsProvider(string user, string? password, ILoggerFactory? loggerFactory = null) : base(
loggerFactory)
public StaticCredentialsProvider(string user, string? password, ILoggerFactory? loggerFactory = null)
{
_user = user;
_password = password ?? "";
loggerFactory ??= NullLoggerFactory.Instance;
_logger = loggerFactory.CreateLogger<StaticCredentialsProvider>();
}

protected override async Task<IamTokenData> FetchToken()
public async Task Initialize()
{
_token = await ReceiveToken();
}

public string GetAuthInfo()
{
var token = _token;

if (token is null)
{
lock (_lock)
{
if (_token is not null) return _token.Token;
_logger.LogWarning(
"Blocking for initial token acquirement, please use explicit Initialize async method.");

_token = ReceiveToken().Result;

return _token.Token;
}
}

if (token.IsExpired())
{
lock (_lock)
{
if (!_token!.IsExpired()) return _token.Token;
_logger.LogWarning("Blocking on expired token.");

_token = ReceiveToken().Result;

return _token.Token;
}
}

if (!token.IsExpiring() || _refreshTask is not null) return _token!.Token;
lock (_lock)
{
if (!_token!.IsExpiring() || _refreshTask is not null) return _token!.Token;
_logger.LogInformation("Refreshing token.");

_refreshTask = Task.Run(RefreshToken);
}

return _token!.Token;
}

private async Task RefreshToken()
{
var token = await ReceiveToken();

lock (_lock)
{
_token = token;
_refreshTask = null;
}
}

private async Task<TokenData> ReceiveToken()
{
var retryAttempt = 0;
while (true)
{
try
{
_logger.LogTrace($"Attempting to receive token, attempt: {retryAttempt}");

var iamToken = await FetchToken();

_logger.LogInformation($"Received token, expires at: {iamToken.ExpiresAt}");

return iamToken;
}
catch (Exception e)
{
_logger.LogDebug($"Failed to fetch token, {e}");

if (retryAttempt >= MaxRetries)
{
throw;
}

await Task.Delay(TimeSpan.FromSeconds(Math.Pow(2, retryAttempt)));
++retryAttempt;
}
}
}

private async Task<TokenData> FetchToken()
{
if (_driver is null)
{
Expand All @@ -42,12 +140,13 @@ protected override async Task<IamTokenData> FetchToken()
var loginResponse = await client.Login(_user, _password);
if (loginResponse.Status.StatusCode == StatusCode.Unauthorized)
{
throw new InvalidCredentialsException(loginResponse.Status.Issues.ToString() ?? "Unknown");
throw new InvalidCredentialsException(Issue.IssuesToString(loginResponse.Status.Issues));
}

loginResponse.Status.EnsureSuccess();
var token = loginResponse.Result.Token;
var jwt = new JwtSecurityToken(token);
return new IamTokenData(token, jwt.ValidTo);
return new TokenData(token, jwt.ValidTo);
}

public async Task ProvideConfig(DriverConfig driverConfig)
Expand All @@ -61,4 +160,45 @@ public async Task ProvideConfig(DriverConfig driverConfig)
driverConfig.DefaultStreamingTransportTimeout,
driverConfig.CustomServerCertificate));
}

private class TokenData
{
public TokenData(string token, DateTime expiresAt)
{
var now = DateTime.UtcNow;

Token = token;
ExpiresAt = expiresAt;

if (expiresAt <= now)
{
RefreshAt = expiresAt;
}
else
{
var refreshSeconds = new Random().Next((int)RefreshInterval.TotalSeconds);
RefreshAt = expiresAt - RefreshGap - TimeSpan.FromSeconds(refreshSeconds);

if (RefreshAt < now)
{
RefreshAt = expiresAt;
}
}
}

public string Token { get; }
public DateTime ExpiresAt { get; }

private DateTime RefreshAt { get; }

public bool IsExpired()
{
return DateTime.UtcNow >= ExpiresAt;
}

public bool IsExpiring()
{
return DateTime.UtcNow >= RefreshAt;
}
}
}
7 changes: 6 additions & 1 deletion src/Ydb.Sdk/src/Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,14 @@ public async Task Initialize()
if (_config.Credentials is IUseDriverConfig useDriverConfig)
{
await useDriverConfig.ProvideConfig(_config);
if (_config.Credentials is StaticCredentialsProvider staticCredentialsProvider)
{
await staticCredentialsProvider.Initialize();
}

_logger.LogInformation("DriverConfig provided to IUseDriverConfig interface");
}

_logger.LogInformation("Started initial endpoint discovery");

try
Expand Down
Loading

0 comments on commit cec765c

Please sign in to comment.