-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add auth static provider (user:password) (#42)
* Add static authorization * Add automatic user creation in tests * Update Ydb.Proto dependency, * Add more testcases, add throw of InvalidCredentialsException * Add more verbosity in tests
- Loading branch information
Showing
14 changed files
with
474 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
namespace Ydb.Sdk.Auth; | ||
|
||
public interface IUseDriverConfig | ||
{ | ||
public Task ProvideConfig(DriverConfig driverConfig); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
using System.IdentityModel.Tokens.Jwt; | ||
using Microsoft.Extensions.Logging; | ||
using Microsoft.Extensions.Logging.Abstractions; | ||
using Ydb.Sdk.Services.Auth; | ||
|
||
namespace Ydb.Sdk.Auth; | ||
|
||
public class StaticCredentialsProvider : ICredentialsProvider, IUseDriverConfig | ||
{ | ||
private readonly ILogger _logger; | ||
|
||
private readonly string _user; | ||
private readonly string? _password; | ||
|
||
private Driver? _driver; | ||
|
||
public int MaxRetries = 5; | ||
|
||
private readonly object _lock = new(); | ||
|
||
private volatile TokenData? _token; | ||
private volatile Task? _refreshTask; | ||
|
||
public float RefreshRatio = .1f; | ||
|
||
/// <summary> | ||
/// | ||
/// </summary> | ||
/// <param name="user">User of the database</param> | ||
/// <param name="password">Password of the user. If user has no password use null </param> | ||
/// <param name="loggerFactory"></param> | ||
public StaticCredentialsProvider(string user, string? password, ILoggerFactory? loggerFactory = null) | ||
{ | ||
_user = user; | ||
_password = password; | ||
loggerFactory ??= NullLoggerFactory.Instance; | ||
_logger = loggerFactory.CreateLogger<StaticCredentialsProvider>(); | ||
} | ||
|
||
private async Task Initialize() | ||
{ | ||
_token = await ReceiveToken(); | ||
} | ||
|
||
public string GetAuthInfo() | ||
{ | ||
var token = _token; | ||
|
||
if (token is null) | ||
{ | ||
lock (_lock) | ||
{ | ||
if (_token is not null) return _token.Token; | ||
_logger.LogWarning( | ||
"Blocking for initial token acquirement, please use explicit Initialize async method."); | ||
|
||
Initialize().Wait(); | ||
|
||
return _token!.Token; | ||
} | ||
} | ||
|
||
if (token.IsExpired()) | ||
{ | ||
lock (_lock) | ||
{ | ||
if (!_token!.IsExpired()) return _token.Token; | ||
_logger.LogWarning("Blocking on expired token."); | ||
|
||
_token = ReceiveToken().Result; | ||
|
||
return _token.Token; | ||
} | ||
} | ||
|
||
if (!token.IsRefreshNeeded() || _refreshTask is not null) return _token!.Token; | ||
lock (_lock) | ||
{ | ||
if (!_token!.IsRefreshNeeded() || _refreshTask is not null) return _token!.Token; | ||
_logger.LogInformation("Refreshing token."); | ||
|
||
_refreshTask = Task.Run(RefreshToken); | ||
} | ||
|
||
return _token!.Token; | ||
} | ||
|
||
private async Task RefreshToken() | ||
{ | ||
var token = await ReceiveToken(); | ||
|
||
lock (_lock) | ||
{ | ||
_token = token; | ||
_refreshTask = null; | ||
} | ||
} | ||
|
||
private async Task<TokenData> ReceiveToken() | ||
{ | ||
var retryAttempt = 0; | ||
while (true) | ||
{ | ||
try | ||
{ | ||
_logger.LogTrace($"Attempting to receive token, attempt: {retryAttempt}"); | ||
|
||
var token = await FetchToken(); | ||
|
||
_logger.LogInformation($"Received token, expires at: {token.ExpiresAt}"); | ||
|
||
return token; | ||
} | ||
catch (InvalidCredentialsException e) | ||
{ | ||
_logger.LogWarning($"Invalid credentials, {e}"); | ||
throw; | ||
} | ||
catch (Exception e) | ||
{ | ||
_logger.LogDebug($"Failed to fetch token, {e}"); | ||
|
||
if (retryAttempt >= MaxRetries) | ||
{ | ||
_logger.LogWarning($"Can't fetch token, {e}"); | ||
throw; | ||
} | ||
|
||
await Task.Delay(TimeSpan.FromSeconds(Math.Pow(2, retryAttempt))); | ||
_logger.LogInformation($"Failed to fetch token, attempt {retryAttempt}"); | ||
++retryAttempt; | ||
} | ||
} | ||
} | ||
|
||
private async Task<TokenData> FetchToken() | ||
{ | ||
if (_driver is null) | ||
{ | ||
_logger.LogError("Driver in for static auth not provided"); | ||
throw new NullReferenceException(); | ||
} | ||
|
||
var client = new AuthClient(_driver); | ||
var loginResponse = await client.Login(_user, _password); | ||
if (loginResponse.Status.StatusCode == StatusCode.Unauthorized) | ||
{ | ||
throw new InvalidCredentialsException(Issue.IssuesToString(loginResponse.Status.Issues)); | ||
} | ||
|
||
loginResponse.Status.EnsureSuccess(); | ||
var token = loginResponse.Result.Token; | ||
var jwt = new JwtSecurityToken(token); | ||
return new TokenData(token, jwt.ValidTo, RefreshRatio); | ||
} | ||
|
||
public async Task ProvideConfig(DriverConfig driverConfig) | ||
{ | ||
_driver = await Driver.CreateInitialized( | ||
new DriverConfig( | ||
driverConfig.Endpoint, | ||
driverConfig.Database, | ||
new AnonymousProvider(), | ||
driverConfig.DefaultTransportTimeout, | ||
driverConfig.DefaultStreamingTransportTimeout, | ||
driverConfig.CustomServerCertificate)); | ||
|
||
await Initialize(); | ||
} | ||
|
||
private class TokenData | ||
{ | ||
public TokenData(string token, DateTime expiresAt, float refreshInterval) | ||
{ | ||
var now = DateTime.UtcNow; | ||
|
||
Token = token; | ||
ExpiresAt = expiresAt; | ||
|
||
if (expiresAt <= now) | ||
{ | ||
RefreshAt = expiresAt; | ||
} | ||
else | ||
{ | ||
RefreshAt = now + (expiresAt - now) * refreshInterval; | ||
|
||
if (RefreshAt < now) | ||
{ | ||
RefreshAt = expiresAt; | ||
} | ||
} | ||
} | ||
|
||
public string Token { get; } | ||
public DateTime ExpiresAt { get; } | ||
|
||
private DateTime RefreshAt { get; } | ||
|
||
public bool IsExpired() | ||
{ | ||
return DateTime.UtcNow >= ExpiresAt; | ||
} | ||
|
||
public bool IsRefreshNeeded() | ||
{ | ||
return DateTime.UtcNow >= RefreshAt; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
using Ydb.Sdk.Client; | ||
|
||
namespace Ydb.Sdk.Services.Auth; | ||
|
||
public partial class AuthClient : ClientBase | ||
{ | ||
public AuthClient(Driver driver) : base(driver) | ||
{ | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
using Ydb.Auth; | ||
using Ydb.Auth.V1; | ||
using Ydb.Sdk.Client; | ||
|
||
namespace Ydb.Sdk.Services.Auth; | ||
|
||
public class LoginSettings : OperationRequestSettings | ||
{ | ||
} | ||
|
||
public class LoginResponse : ResponseWithResultBase<LoginResponse.ResultData> | ||
{ | ||
internal LoginResponse(Status status, ResultData? result = null) | ||
: base(status, result) | ||
{ | ||
} | ||
|
||
public class ResultData | ||
{ | ||
public string Token { get; } | ||
|
||
internal ResultData(string token) | ||
{ | ||
Token = token; | ||
} | ||
|
||
|
||
internal static ResultData FromProto(LoginResult resultProto) | ||
{ | ||
var token = resultProto.Token; | ||
return new ResultData(token); | ||
} | ||
} | ||
} | ||
|
||
public partial class AuthClient | ||
{ | ||
public async Task<LoginResponse> Login(string user, string? password, LoginSettings? settings = null) | ||
{ | ||
settings ??= new LoginSettings(); | ||
var request = new LoginRequest | ||
{ | ||
OperationParams = MakeOperationParams(settings), | ||
User = user | ||
}; | ||
if (password is not null) | ||
{ | ||
request.Password = password; | ||
} | ||
|
||
try | ||
{ | ||
var response = await Driver.UnaryCall( | ||
method: AuthService.LoginMethod, | ||
request: request, | ||
settings: settings | ||
); | ||
|
||
var status = UnpackOperation(response.Data.Operation, out LoginResult? resultProto); | ||
|
||
LoginResponse.ResultData? result = null; | ||
|
||
if (status.IsSuccess && resultProto is not null) | ||
{ | ||
result = LoginResponse.ResultData.FromProto(resultProto); | ||
} | ||
|
||
return new LoginResponse(status, result); | ||
} | ||
catch (Driver.TransportException e) | ||
{ | ||
return new LoginResponse(e.Status); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.