Skip to content

Commit

Permalink
basic auth: optimize code and improve the exception/response message (e…
Browse files Browse the repository at this point in the history
…nvoyproxy#30759)

* basic auth: optimize code and improve the exception/response message

Signed-off-by: wbpcode <[email protected]>

* address comments and minor update

Signed-off-by: wbpcode <[email protected]>

---------

Signed-off-by: wbpcode <[email protected]>
  • Loading branch information
code authored Nov 8, 2023
1 parent 1c67529 commit 1e8d60a
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 87 deletions.
82 changes: 42 additions & 40 deletions source/extensions/filters/http/basic_auth/basic_auth_filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@ std::string computeSHA1(absl::string_view password) {

} // namespace

FilterConfig::FilterConfig(UserMapConstPtr users, const std::string& stats_prefix,
Stats::Scope& scope)
FilterConfig::FilterConfig(UserMap&& users, const std::string& stats_prefix, Stats::Scope& scope)
: users_(std::move(users)), stats_(generateStats(stats_prefix + "basic_auth.", scope)) {}

bool FilterConfig::validateUser(absl::string_view username, absl::string_view password) const {
auto user = users_->find(username);
if (user == users_->end()) {
auto user = users_.find(username);
if (user == users_.end()) {
return false;
}

Expand All @@ -43,46 +42,49 @@ BasicAuthFilter::BasicAuthFilter(FilterConfigConstSharedPtr config) : config_(st

Http::FilterHeadersStatus BasicAuthFilter::decodeHeaders(Http::RequestHeaderMap& headers, bool) {
auto auth_header = headers.get(Http::CustomHeaders::get().Authorization);
if (!auth_header.empty()) {
absl::string_view auth_value = auth_header[0]->value().getStringView();

if (absl::StartsWith(auth_value, "Basic ")) {
// Extract and decode the Base64 part of the header.
absl::string_view base64Token = auth_value.substr(6);
const std::string decoded = Base64::decodeWithoutPadding(base64Token);

// The decoded string is in the format "username:password".
const size_t colon_pos = decoded.find(':');

if (colon_pos != std::string::npos) {
absl::string_view decoded_view = decoded;
absl::string_view username = decoded_view.substr(0, colon_pos);
absl::string_view password = decoded_view.substr(colon_pos + 1);

if (config_->validateUser(username, password)) {
config_->stats().allowed_.inc();
return Http::FilterHeadersStatus::Continue;
} else {
config_->stats().denied_.inc();
decoder_callbacks_->sendLocalReply(
Http::Code::Unauthorized,
"User authentication failed. Invalid username/password combination", nullptr,
absl::nullopt, "invalid_credential_for_basic_auth");
return Http::FilterHeadersStatus::StopIteration;
}
}
}

if (auth_header.empty()) {
return onDenied("User authentication failed. Missing username and password.",
"no_credential_for_basic_auth");
}

config_->stats().denied_.inc();
decoder_callbacks_->sendLocalReply(Http::Code::Unauthorized,
"User authentication failed. Missing username and password",
nullptr, absl::nullopt, "no_credential_for_basic_auth");
return Http::FilterHeadersStatus::StopIteration;
absl::string_view auth_value = auth_header[0]->value().getStringView();

if (!absl::StartsWith(auth_value, "Basic ")) {
return onDenied("User authentication failed. Expected 'Basic' authentication scheme.",
"invalid_scheme_for_basic_auth");
}

// Extract and decode the Base64 part of the header.
absl::string_view base64_token = auth_value.substr(6);
const std::string decoded = Base64::decodeWithoutPadding(base64_token);

// The decoded string is in the format "username:password".
const size_t colon_pos = decoded.find(':');
if (colon_pos == std::string::npos) {
return onDenied("User authentication failed. Invalid basic credential format.",
"invalid_format_for_basic_auth");
}

absl::string_view decoded_view = decoded;
absl::string_view username = decoded_view.substr(0, colon_pos);
absl::string_view password = decoded_view.substr(colon_pos + 1);

if (!config_->validateUser(username, password)) {
return onDenied("User authentication failed. Invalid username/password combination.",
"invalid_credential_for_basic_auth");
}

config_->stats().allowed_.inc();
return Http::FilterHeadersStatus::Continue;
}

void BasicAuthFilter::setDecoderFilterCallbacks(Http::StreamDecoderFilterCallbacks& callbacks) {
decoder_callbacks_ = &callbacks;
Http::FilterHeadersStatus BasicAuthFilter::onDenied(absl::string_view body,
absl::string_view response_code_details) {
config_->stats().denied_.inc();
decoder_callbacks_->sendLocalReply(Http::Code::Unauthorized, body, nullptr, absl::nullopt,
response_code_details);
return Http::FilterHeadersStatus::StopIteration;
}

} // namespace BasicAuth
Expand Down
12 changes: 6 additions & 6 deletions source/extensions/filters/http/basic_auth/basic_auth_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,14 @@ struct User {
std::string hash;
};

using UserMapConstPtr =
std::unique_ptr<const absl::flat_hash_map<std::string, User>>; // username, User
using UserMap = absl::flat_hash_map<std::string, User>;

/**
* Configuration for the Basic Auth filter.
*/
class FilterConfig {
public:
FilterConfig(UserMapConstPtr users, const std::string& stats_prefix, Stats::Scope& scope);
FilterConfig(UserMap&& users, const std::string& stats_prefix, Stats::Scope& scope);
const BasicAuthStats& stats() const { return stats_; }
bool validateUser(absl::string_view username, absl::string_view password) const;

Expand All @@ -53,7 +52,7 @@ class FilterConfig {
return BasicAuthStats{ALL_BASIC_AUTH_STATS(POOL_COUNTER_PREFIX(scope, prefix))};
}

UserMapConstPtr users_;
const UserMap users_;
BasicAuthStats stats_;
};
using FilterConfigConstSharedPtr = std::shared_ptr<const FilterConfig>;
Expand All @@ -66,11 +65,12 @@ class BasicAuthFilter : public Http::PassThroughDecoderFilter,

// Http::StreamDecoderFilter
Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, bool) override;
void setDecoderFilterCallbacks(Http::StreamDecoderFilterCallbacks& callbacks) override;

private:
Http::FilterHeadersStatus onDenied(absl::string_view body,
absl::string_view response_code_details);

// The callback function.
Http::StreamDecoderFilterCallbacks* decoder_callbacks_;
FilterConfigConstSharedPtr config_;
};

Expand Down
53 changes: 32 additions & 21 deletions source/extensions/filters/http/basic_auth/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,47 @@ using envoy::extensions::filters::http::basic_auth::v3::BasicAuth;

namespace {

UserMapConstPtr readHtpasswd(const std::string& htpasswd) {
std::unique_ptr<absl::flat_hash_map<std::string, User>> users =
std::make_unique<absl::flat_hash_map<std::string, User>>();
UserMap readHtpasswd(const std::string& htpasswd) {
UserMap users;

std::istringstream htpsswd_ss(htpasswd);
std::string line;

while (std::getline(htpsswd_ss, line)) {
// TODO(wbpcode): should we trim the spaces or empty chars?

// Skip empty lines and comments.
if (line.empty() || line[0] == '#') {
continue;
}

const size_t colon_pos = line.find(':');
if (colon_pos == std::string::npos) {
throw EnvoyException("basic auth: invalid htpasswd format, username:password is expected");
}

if (colon_pos != std::string::npos) {
std::string name = line.substr(0, colon_pos);
std::string hash = line.substr(colon_pos + 1);
std::string name = line.substr(0, colon_pos);
std::string hash = line.substr(colon_pos + 1);

if (name.empty()) {
throw EnvoyException("basic auth: invalid user name");
}
if (name.empty() || hash.empty()) {
throw EnvoyException("basic auth: empty user name or password");
}

if (users.contains(name)) {
throw EnvoyException("basic auth: duplicate users");
}

if (absl::StartsWith(hash, "{SHA}")) {
hash = hash.substr(5);
// The base64 encoded SHA1 hash is 28 bytes long
if (hash.length() != 28) {
throw EnvoyException("basic auth: invalid SHA hash length");
}
if (!absl::StartsWith(hash, "{SHA}")) {
throw EnvoyException("basic auth: unsupported htpasswd format: please use {SHA}");
}

users->insert({name, {name, hash}});
continue;
}
hash = hash.substr(5);
// The base64 encoded SHA1 hash is 28 bytes long
if (hash.length() != 28) {
throw EnvoyException("basic auth: invalid htpasswd format, invalid SHA hash length");
}

throw EnvoyException("basic auth: unsupported htpasswd format: please use {SHA}");
users.insert({name, {name, hash}});
}

return users;
Expand All @@ -52,8 +63,8 @@ UserMapConstPtr readHtpasswd(const std::string& htpasswd) {
Http::FilterFactoryCb BasicAuthFilterFactory::createFilterFactoryFromProtoTyped(
const BasicAuth& proto_config, const std::string& stats_prefix,
Server::Configuration::FactoryContext& context) {
const std::string htpasswd = Config::DataSource::read(proto_config.users(), false, context.api());
UserMapConstPtr users = readHtpasswd(htpasswd);
UserMap users =
readHtpasswd(Config::DataSource::read(proto_config.users(), false, context.api()));
FilterConfigConstSharedPtr config =
std::make_unique<FilterConfig>(std::move(users), stats_prefix, context.scope());
return [config](Http::FilterChainFactoryCallbacks& callbacks) -> void {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ TEST_P(BasicAuthIntegrationTestAllProtocols, NoCredential) {
ASSERT_TRUE(response->waitForEndStream());
ASSERT_TRUE(response->complete());
EXPECT_EQ("401", response->headers().getStatusValue());
EXPECT_EQ("User authentication failed. Missing username and password", response->body());
EXPECT_EQ("User authentication failed. Missing username and password.", response->body());
}

// Request without wrong password
Expand All @@ -91,7 +91,7 @@ TEST_P(BasicAuthIntegrationTestAllProtocols, WrongPasswrod) {
ASSERT_TRUE(response->waitForEndStream());
ASSERT_TRUE(response->complete());
EXPECT_EQ("401", response->headers().getStatusValue());
EXPECT_EQ("User authentication failed. Invalid username/password combination", response->body());
EXPECT_EQ("User authentication failed. Invalid username/password combination.", response->body());
}

// Request with none-existed user
Expand All @@ -110,7 +110,7 @@ TEST_P(BasicAuthIntegrationTestAllProtocols, NoneExistedUser) {
ASSERT_TRUE(response->waitForEndStream());
ASSERT_TRUE(response->complete());
EXPECT_EQ("401", response->headers().getStatusValue());
EXPECT_EQ("User authentication failed. Invalid username/password combination", response->body());
EXPECT_EQ("User authentication failed. Invalid username/password combination.", response->body());
}
} // namespace
} // namespace BasicAuth
Expand Down
42 changes: 32 additions & 10 deletions test/extensions/filters/http/basic_auth/config_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ TEST(Factory, ValidConfig) {
const std::string yaml = R"(
users:
inline_string: |-
# comment line
user1:{SHA}tESsBmE/yNY3lb6a0L6vVQEZNqw=
user2:{SHA}EJ9LPFDXsN9ynSmbxvjp75Bmlx8=
)";
Expand Down Expand Up @@ -45,8 +46,27 @@ TEST(Factory, InvalidConfigNoColon) {

NiceMock<Server::Configuration::MockFactoryContext> context;

EXPECT_THROW(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException);
EXPECT_THROW_WITH_MESSAGE(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException,
"basic auth: invalid htpasswd format, username:password is expected");
}

TEST(Factory, InvalidConfigDuplicateUsers) {
const std::string yaml = R"(
users:
inline_string: |-
user1:{SHA}tESsBmE/yNY3lb6a0L6vVQEZNqw=
user1:{SHA}EJ9LPFDXsN9ynSmbxvjp75Bmlx8=
)";

BasicAuthFilterFactory factory;
ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto();
TestUtility::loadFromYaml(yaml, *proto_config);

NiceMock<Server::Configuration::MockFactoryContext> context;

EXPECT_THROW_WITH_MESSAGE(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException, "basic auth: duplicate users");
}

TEST(Factory, InvalidConfigNoUser) {
Expand All @@ -63,8 +83,8 @@ TEST(Factory, InvalidConfigNoUser) {

NiceMock<Server::Configuration::MockFactoryContext> context;

EXPECT_THROW(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException);
EXPECT_THROW_WITH_MESSAGE(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException, "basic auth: empty user name or password");
}

TEST(Factory, InvalidConfigNoPassword) {
Expand All @@ -81,8 +101,8 @@ TEST(Factory, InvalidConfigNoPassword) {

NiceMock<Server::Configuration::MockFactoryContext> context;

EXPECT_THROW(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException);
EXPECT_THROW_WITH_MESSAGE(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException, "basic auth: empty user name or password");
}

TEST(Factory, InvalidConfigNoHash) {
Expand All @@ -99,8 +119,9 @@ TEST(Factory, InvalidConfigNoHash) {

NiceMock<Server::Configuration::MockFactoryContext> context;

EXPECT_THROW(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException);
EXPECT_THROW_WITH_MESSAGE(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException,
"basic auth: invalid htpasswd format, invalid SHA hash length");
}

TEST(Factory, InvalidConfigNotSHA) {
Expand All @@ -117,8 +138,9 @@ TEST(Factory, InvalidConfigNotSHA) {

NiceMock<Server::Configuration::MockFactoryContext> context;

EXPECT_THROW(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException);
EXPECT_THROW_WITH_MESSAGE(factory.createFilterFactoryFromProto(*proto_config, "stats", context),
EnvoyException,
"basic auth: unsupported htpasswd format: please use {SHA}");
}

} // namespace BasicAuth
Expand Down
Loading

0 comments on commit 1e8d60a

Please sign in to comment.