Skip to content

Commit

Permalink
Refactor aws_lambda config settings (envoyproxy#33042)
Browse files Browse the repository at this point in the history
Signed-off-by: Juan Manuel Ollé <[email protected]>
  • Loading branch information
juanmolle authored Apr 2, 2024
1 parent c1bbeac commit d92decb
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 145 deletions.
61 changes: 23 additions & 38 deletions source/extensions/filters/http/aws_lambda/aws_lambda_filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,32 +116,16 @@ bool isContentTypeTextual(const Http::RequestOrResponseHeaderMap& headers) {
} // namespace

// TODO(nbaws) Implement Sigv4a support
Filter::Filter(const FilterSettings& settings, const FilterStats& stats,
const std::shared_ptr<Extensions::Common::Aws::Signer>& sigv4_signer,
bool is_upstream)
: settings_(settings), stats_(stats), sigv4_signer_(sigv4_signer), is_upstream_(is_upstream) {}

absl::optional<FilterSettings> Filter::getRouteSpecificSettings() const {
const auto* settings =
Http::Utility::resolveMostSpecificPerFilterConfig<FilterSettings>(decoder_callbacks_);
if (!settings) {
return absl::nullopt;
}

return *settings;
}
Filter::Filter(const FilterSettingsSharedPtr& settings, const FilterStats& stats, bool is_upstream)
: settings_(settings), stats_(stats), is_upstream_(is_upstream) {}

void Filter::resolveSettings() {
if (auto route_settings = getRouteSpecificSettings()) {
payload_passthrough_ = route_settings->payloadPassthrough();
invocation_mode_ = route_settings->invocationMode();
arn_ = route_settings->arn();
host_rewrite_ = route_settings->hostRewrite();
} else {
payload_passthrough_ = settings_.payloadPassthrough();
invocation_mode_ = settings_.invocationMode();
host_rewrite_ = settings_.hostRewrite();
FilterSettings& Filter::getSettings() {
auto* settings = const_cast<FilterSettings*>(
Http::Utility::resolveMostSpecificPerFilterConfig<FilterSettings>(decoder_callbacks_));
if (settings) {
return *settings;
}
return *settings_;
}

Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool end_stream) {
Expand All @@ -154,33 +138,29 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers,
}
}

resolveSettings();

if (!arn_) {
arn_ = settings_.arn();
}
auto& settings = getSettings();

if (!end_stream) {
request_headers_ = &headers;
return Http::FilterHeadersStatus::StopIteration;
}

if (payload_passthrough_) {
setLambdaHeaders(headers, arn_, invocation_mode_, host_rewrite_);
sigv4_signer_->signEmptyPayload(headers, arn_->region());
if (settings.payloadPassthrough()) {
setLambdaHeaders(headers, settings.arn(), settings.invocationMode(), settings.hostRewrite());
settings.signer().signEmptyPayload(headers, settings.arn().region());
return Http::FilterHeadersStatus::Continue;
}

Buffer::OwnedImpl json_buf;
jsonizeRequest(headers, nullptr, json_buf);
// We must call setLambdaHeaders *after* the JSON transformation of the request. That way we
// reflect the actual incoming request headers instead of the overwritten ones.
setLambdaHeaders(headers, arn_, invocation_mode_, host_rewrite_);
setLambdaHeaders(headers, settings.arn(), settings.invocationMode(), settings.hostRewrite());
headers.setContentLength(json_buf.length());
headers.setReferenceContentType(Http::Headers::get().ContentTypeValues.Json);
auto& hashing_util = Envoy::Common::Crypto::UtilitySingleton::get();
const auto hash = Hex::encode(hashing_util.getSha256Digest(json_buf));
sigv4_signer_->sign(headers, hash, arn_->region());
settings.signer().sign(headers, hash, settings.arn().region());
decoder_callbacks_->addDecodedData(json_buf, false);
return Http::FilterHeadersStatus::Continue;
}
Expand Down Expand Up @@ -223,7 +203,9 @@ Http::FilterDataStatus Filter::decodeData(Buffer::Instance& data, bool end_strea

const Buffer::Instance& decoding_buffer = *decoder_callbacks_->decodingBuffer();

if (!payload_passthrough_) {
auto& settings = getSettings();

if (!settings.payloadPassthrough()) {
decoder_callbacks_->modifyDecodingBuffer([this](Buffer::Instance& dec_buf) {
Buffer::OwnedImpl json_buf;
jsonizeRequest(*request_headers_, &dec_buf, json_buf);
Expand All @@ -235,15 +217,18 @@ Http::FilterDataStatus Filter::decodeData(Buffer::Instance& data, bool end_strea
request_headers_->setReferenceContentType(Http::Headers::get().ContentTypeValues.Json);
}

setLambdaHeaders(*request_headers_, arn_, invocation_mode_, host_rewrite_);
setLambdaHeaders(*request_headers_, settings.arn(), settings.invocationMode(),
settings.hostRewrite());
const auto hash = Hex::encode(hashing_util.getSha256Digest(decoding_buffer));
sigv4_signer_->sign(*request_headers_, hash, arn_->region());
settings.signer().sign(*request_headers_, hash, settings.arn().region());
stats().upstream_rq_payload_size_.recordValue(decoding_buffer.length());
return Http::FilterDataStatus::Continue;
}

Http::FilterDataStatus Filter::encodeData(Buffer::Instance& data, bool end_stream) {
if (skip_ || payload_passthrough_ || invocation_mode_ == InvocationMode::Asynchronous) {
auto& settings = getSettings();
if (skip_ || settings.payloadPassthrough() ||
settings.invocationMode() == InvocationMode::Asynchronous) {
return Http::FilterDataStatus::Continue;
}

Expand Down
55 changes: 31 additions & 24 deletions source/extensions/filters/http/aws_lambda/aws_lambda_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,67 +80,74 @@ enum class InvocationMode { Synchronous, Asynchronous };

class FilterSettings : public Router::RouteSpecificFilterConfig {
public:
FilterSettings(const Arn& arn, InvocationMode mode, bool payload_passthrough,
const std::string& host_rewrite)
~FilterSettings() override = default;

virtual const Arn& arn() const PURE;
virtual bool payloadPassthrough() const PURE;
virtual InvocationMode invocationMode() const PURE;
virtual const std::string& hostRewrite() const PURE;
virtual Extensions::Common::Aws::Signer& signer() PURE;
};

class FilterSettingsImpl : public FilterSettings {
public:
FilterSettingsImpl(const Arn& arn, InvocationMode mode, bool payload_passthrough,
const std::string& host_rewrite, Extensions::Common::Aws::SignerPtr&& signer)
: arn_(arn), invocation_mode_(mode), payload_passthrough_(payload_passthrough),
host_rewrite_(host_rewrite) {}
host_rewrite_(host_rewrite), signer_(std::move(signer)) {}

const Arn& arn() const& { return arn_; }
bool payloadPassthrough() const { return payload_passthrough_; }
InvocationMode invocationMode() const { return invocation_mode_; }
const std::string& hostRewrite() const { return host_rewrite_; }
const Arn& arn() const override { return arn_; }
bool payloadPassthrough() const override { return payload_passthrough_; }
InvocationMode invocationMode() const override { return invocation_mode_; }
const std::string& hostRewrite() const override { return host_rewrite_; }
Extensions::Common::Aws::Signer& signer() override { return *signer_; }

private:
Arn arn_;
InvocationMode invocation_mode_;
bool payload_passthrough_;
const std::string host_rewrite_;
Extensions::Common::Aws::SignerPtr signer_;
};

using FilterSettingsSharedPtr = std::shared_ptr<FilterSettings>;

class Filter : public Http::PassThroughFilter, Logger::Loggable<Logger::Id::filter> {

public:
Filter(const FilterSettings& config, const FilterStats& stats,
const std::shared_ptr<Extensions::Common::Aws::Signer>& sigv4_signer, bool is_upstream);
Filter(const FilterSettingsSharedPtr& settings, const FilterStats& stats, bool is_upstream);

Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap&, bool end_stream) override;
Http::FilterDataStatus decodeData(Buffer::Instance& data, bool end_stream) override;

Http::FilterHeadersStatus encodeHeaders(Http::ResponseHeaderMap&, bool end_stream) override;
Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override;

/**
* Calculates the function ARN, value of pass-through, etc. by checking per-filter configurations
* and general filter configuration. Ultimately, the most specific configuration wins.
* @return error message if settings are invalid. Otherwise, empty string.
*/
void resolveSettings();
FilterStats& stats() { return stats_; }

/**
* Used for unit testing only
*/
const FilterSettings& settingsForTest() const { return settings_; }
const FilterSettings& settingsForTest() const { return *settings_; }

private:
absl::optional<FilterSettings> getRouteSpecificSettings() const;
/**
* Calculates the function ARN, value of pass-through, etc. by checking per-filter configurations
* and general filter configuration. Ultimately, the most specific configuration is returned.
*/
FilterSettings& getSettings();
// Convert the HTTP request to JSON request.
void jsonizeRequest(const Http::RequestHeaderMap& headers, const Buffer::Instance* body,
Buffer::Instance& out) const;
// Convert the JSON response to a standard HTTP response.
void dejsonizeResponse(Http::ResponseHeaderMap& headers, const Buffer::Instance& body,
Buffer::Instance& out);
const FilterSettings settings_;

FilterSettingsSharedPtr settings_;
FilterStats stats_;
Http::RequestHeaderMap* request_headers_ = nullptr;
Http::ResponseHeaderMap* response_headers_ = nullptr;
std::shared_ptr<Extensions::Common::Aws::Signer> sigv4_signer_;
absl::optional<Arn> arn_;
InvocationMode invocation_mode_ = InvocationMode::Synchronous;
bool payload_passthrough_ = false;
bool skip_ = false;
bool is_upstream_ = false;
std::string host_rewrite_;
};

} // namespace AwsLambdaFilter
Expand Down
43 changes: 30 additions & 13 deletions source/extensions/filters/http/aws_lambda/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,53 @@ absl::StatusOr<Http::FilterFactoryCb> AwsLambdaFilterFactory::createFilterFactor
server_context.api(), makeOptRef(server_context), region,
Extensions::Common::Aws::Utility::fetchMetadata);

auto signer = std::make_shared<Extensions::Common::Aws::SigV4SignerImpl>(
auto signer = std::make_unique<Extensions::Common::Aws::SigV4SignerImpl>(
service_name, region, std::move(credentials_provider), server_context,
// TODO: extend API to allow specifying header exclusion. ref:
// https://github.com/envoyproxy/envoy/pull/18998
Extensions::Common::Aws::AwsSigningHeaderExclusionVector{});

FilterSettings filter_settings{*arn, getInvocationMode(proto_config),
proto_config.payload_passthrough(), proto_config.host_rewrite()};
auto filter_settings = std::make_shared<FilterSettingsImpl>(
*arn, getInvocationMode(proto_config), proto_config.payload_passthrough(),
proto_config.host_rewrite(), std::move(signer));

FilterStats stats = generateStats(stats_prefix, dual_info.scope);
return [stats, signer, filter_settings, dual_info](Http::FilterChainFactoryCallbacks& cb) {
auto filter = std::make_shared<Filter>(filter_settings, stats, signer, dual_info.is_upstream);
return [stats, filter_settings, dual_info](Http::FilterChainFactoryCallbacks& cb) -> void {
auto filter = std::make_shared<Filter>(filter_settings, stats, dual_info.is_upstream);
cb.addStreamFilter(filter);
};
}

Router::RouteSpecificFilterConfigConstSharedPtr
AwsLambdaFilterFactory::createRouteSpecificFilterConfigTyped(
const envoy::extensions::filters::http::aws_lambda::v3::PerRouteConfig& proto_config,
Server::Configuration::ServerFactoryContext&, ProtobufMessage::ValidationVisitor&) {
const envoy::extensions::filters::http::aws_lambda::v3::PerRouteConfig& per_route_config,
Server::Configuration::ServerFactoryContext& server_context,
ProtobufMessage::ValidationVisitor&) {

const auto arn = parseArn(proto_config.invoke_config().arn());
const auto arn = parseArn(per_route_config.invoke_config().arn());
if (!arn) {
throw EnvoyException(
fmt::format("aws_lambda_filter: Invalid ARN: {}", proto_config.invoke_config().arn()));
fmt::format("aws_lambda_filter: Invalid ARN: {}", per_route_config.invoke_config().arn()));
}
return std::make_shared<const FilterSettings>(
FilterSettings{*arn, getInvocationMode(proto_config.invoke_config()),
proto_config.invoke_config().payload_passthrough(),
proto_config.invoke_config().host_rewrite()});
const std::string region = arn->region();

auto credentials_provider =
std::make_shared<Extensions::Common::Aws::DefaultCredentialsProviderChain>(
server_context.api(), makeOptRef(server_context), region,
Extensions::Common::Aws::Utility::fetchMetadata);

auto signer = std::make_unique<Extensions::Common::Aws::SigV4SignerImpl>(
service_name, region, std::move(credentials_provider), server_context,
// TODO: extend API to allow specifying header exclusion. ref:
// https://github.com/envoyproxy/envoy/pull/18998
Extensions::Common::Aws::AwsSigningHeaderExclusionVector{});

auto filter_settings = std::make_shared<FilterSettingsImpl>(
*arn, getInvocationMode(per_route_config.invoke_config()),
per_route_config.invoke_config().payload_passthrough(),
per_route_config.invoke_config().host_rewrite(), std::move(signer));

return filter_settings;
}

/*
Expand Down
Loading

0 comments on commit d92decb

Please sign in to comment.