Skip to content

Commit

Permalink
[FLINK-26808][rest] Limit file uploads to allowed handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
uce committed May 25, 2024
1 parent 5b535e1 commit 1f80848
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import java.nio.file.Path;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;

import static java.util.Objects.requireNonNull;
Expand All @@ -79,15 +80,16 @@ public class FileUploadHandler extends SimpleChannelInboundHandler<HttpObject> {

private final Path uploadDir;

// File uploads are only allowed to these URIs. Others return an error.
private final Set<String> allowedUploadUris;

private HttpPostRequestDecoder currentHttpPostRequestDecoder;

private HttpRequest currentHttpRequest;
private byte[] currentJsonPayload;
private Path currentUploadDir;

private boolean addCRPrefix = false;

public FileUploadHandler(final Path uploadDir) {
public FileUploadHandler(final Path uploadDir, final Set<String> allowedUploadUris) {
super(true);

// the clean up of temp files when jvm exits is handled by
Expand All @@ -103,6 +105,7 @@ public FileUploadHandler(final Path uploadDir) {
DiskAttribute.baseDirectory = DiskFileUpload.baseDirectory;

this.uploadDir = requireNonNull(uploadDir);
this.allowedUploadUris = requireNonNull(allowedUploadUris);
}

@Override
Expand All @@ -125,6 +128,19 @@ protected void channelRead0(final ChannelHandlerContext ctx, final HttpObject ms
new HttpPostRequestDecoder(DATA_FACTORY, httpRequest);
currentHttpRequest = ReferenceCountUtil.retain(httpRequest);

// We check this after initializing the multipart file upload in order for
// handleError to work correctly as it expects to be only called with
// certain fields set.
if (!allowedUploadUris.contains(httpRequest.uri())) {
LOG.trace("File upload not allowed for {}.", httpRequest.uri());
handleError(
ctx,
"File upload not allowed",
HttpResponseStatus.BAD_REQUEST,
null);
return;
}

// make sure that we still have a upload dir in case that it got deleted in
// the meanwhile
RestServerEndpoint.createUploadDir(uploadDir, LOG, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.flink.runtime.rest.handler.RestHandlerSpecification;
import org.apache.flink.runtime.rest.handler.router.Router;
import org.apache.flink.runtime.rest.handler.router.RouterHandler;
import org.apache.flink.runtime.rest.messages.UntypedResponseMessageHeaders;
import org.apache.flink.runtime.rest.versioning.RestAPIVersion;
import org.apache.flink.util.AutoCloseableAsync;
import org.apache.flink.util.ConfigurationException;
Expand Down Expand Up @@ -196,6 +197,10 @@ public final void start() throws Exception {
checkAllEndpointsAndHandlersAreUnique(handlers);
handlers.forEach(handler -> registerHandler(router, handler, log));

// Prepare allow list for file uploads
Set<String> fileUploadAllowedUris = extractAllowedFileUploadUris(handlers);
log.trace("Using {} as allow list for file uploads", fileUploadAllowedUris);

ChannelInitializer<SocketChannel> initializer =
new ChannelInitializer<SocketChannel>() {

Expand All @@ -216,7 +221,8 @@ protected void initChannel(SocketChannel ch) throws ConfigurationException {

ch.pipeline()
.addLast(new HttpServerCodec())
.addLast(new FileUploadHandler(uploadDir))
.addLast(
new FileUploadHandler(uploadDir, fileUploadAllowedUris))
.addLast(
new FlinkHttpObjectAggregator(
maxContentLength, responseHeaders));
Expand Down Expand Up @@ -513,35 +519,14 @@ private static void registerHandler(
Router router,
Tuple2<RestHandlerSpecification, ChannelInboundHandler> specificationHandler,
Logger log) {
final String handlerURL = specificationHandler.f0.getTargetRestEndpointURL();
// setup versioned urls
for (final RestAPIVersion supportedVersion :
specificationHandler.f0.getSupportedAPIVersions()) {
final String versionedHandlerURL =
'/' + supportedVersion.getURLVersionPrefix() + handlerURL;
for (String url : getHandlerUrls(specificationHandler.f0)) {
log.debug(
"Register handler {} under {}@{}.",
specificationHandler.f1,
specificationHandler.f0.getHttpMethod(),
versionedHandlerURL);
url);
registerHandler(
router,
versionedHandlerURL,
specificationHandler.f0.getHttpMethod(),
specificationHandler.f1);
if (supportedVersion.isDefaultVersion()) {
// setup unversioned url for convenience and backwards compatibility
log.debug(
"Register handler {} under {}@{}.",
specificationHandler.f1,
specificationHandler.f0.getHttpMethod(),
handlerURL);
registerHandler(
router,
handlerURL,
specificationHandler.f0.getHttpMethod(),
specificationHandler.f1);
}
router, url, specificationHandler.f0.getHttpMethod(), specificationHandler.f1);
}
}

Expand Down Expand Up @@ -653,6 +638,41 @@ private static void checkAllEndpointsAndHandlersAreUnique(
}
}

private static Set<String> extractAllowedFileUploadUris(
List<Tuple2<RestHandlerSpecification, ChannelInboundHandler>> handlers) {
Set<String> fileUploadAllowedUris = new HashSet<>();
for (Tuple2<RestHandlerSpecification, ChannelInboundHandler> handler : handlers) {
// The cast is not pretty, but in the current design only UntypedResponseMessageHeaders
// exposes whether the handler accepts file uploads. Note that handlers that accept
// a file upload but who are not of this type will not be allowed file uploads.
if (handler.f0 instanceof UntypedResponseMessageHeaders) {
UntypedResponseMessageHeaders<?, ?> header =
(UntypedResponseMessageHeaders<?, ?>) handler.f0;
if (header.acceptsFileUploads()) {
for (String url : getHandlerUrls(header)) {
fileUploadAllowedUris.add(url);
}
}
}
}
return fileUploadAllowedUris;
}

private static Iterable<String> getHandlerUrls(RestHandlerSpecification handlerSpec) {
final List<String> registeredUrls = new ArrayList<>();
final String handlerUrl = handlerSpec.getTargetRestEndpointURL();
for (RestAPIVersion<?> supportedVersion : handlerSpec.getSupportedAPIVersions()) {
String versionedUrl = '/' + supportedVersion.getURLVersionPrefix() + handlerUrl;
registeredUrls.add(versionedUrl);

if (supportedVersion.isDefaultVersion()) {
// Unversioned URL for convenience and backwards compatibility
registeredUrls.add(handlerUrl);
}
}
return registeredUrls;
}

/**
* Comparator for Rest URLs.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,32 +374,45 @@ void testShouldRespectMaxContentLengthLimitForResponses() throws Exception {
@TestTemplate
void testFileUpload() throws Exception {
final String boundary = generateMultiPartBoundary();
final String crlf = "\r\n";
final String uploadedContent = "hello";
final HttpURLConnection connection = openHttpConnectionForUpload(boundary);

try (OutputStream output = connection.getOutputStream();
PrintWriter writer =
new PrintWriter(
new OutputStreamWriter(output, StandardCharsets.UTF_8), true)) {
final HttpURLConnection connection =
openHttpConnectionForUpload(
boundary, TestUploadHeaders.INSTANCE.getTargetRestEndpointURL());

writer.append("--" + boundary).append(crlf);
writer.append("Content-Disposition: form-data; name=\"foo\"; filename=\"bar\"")
.append(crlf);
writer.append("Content-Type: plain/text; charset=utf8").append(crlf);
writer.append(crlf).flush();
output.write(uploadedContent.getBytes(StandardCharsets.UTF_8));
output.flush();
writer.append(crlf).flush();
writer.append("--" + boundary + "--").append(crlf).flush();
}
uploadFile(connection, uploadedContent, boundary);

assertThat(connection.getResponseCode()).isEqualTo(200);
final byte[] lastUploadedFileContents = testUploadHandler.getLastUploadedFileContents();
assertThat(uploadedContent)
.isEqualTo(new String(lastUploadedFileContents, StandardCharsets.UTF_8));
}

/**
* Tests that when a handler is marked as not accepting file uploads we (1) return an error and
* (2) don't upload the file to the upload directory.
*/
@TestTemplate
void testFileUploadLimitedToAllowedUris() throws Exception {
final String boundary = generateMultiPartBoundary();
final File uploadDir = new File(tempFolder.toString(), "flink-web-upload");
final File[] preUploadFiles = uploadDir.listFiles();

// We need a handler that does not accept file uploads for this test
assertThat(TestVersionHeaders.INSTANCE.acceptsFileUploads()).isFalse();
String uri = TestVersionHeaders.INSTANCE.getTargetRestEndpointURL();

final HttpURLConnection connection = openHttpConnectionForUpload(boundary, uri);

uploadFile(connection, "hello", boundary);

assertThat(connection.getResponseCode()).isEqualTo(400);

// This is the important check. We don't want additional files when the handler does
// not accept file uploads.
final File[] postUploadFiles = uploadDir.listFiles();
assertThat(postUploadFiles).isEqualTo(preUploadFiles);
}

/**
* Sending multipart/form-data without a file should result in a bad request if the handler
* expects a file upload.
Expand All @@ -408,7 +421,9 @@ void testFileUpload() throws Exception {
void testMultiPartFormDataWithoutFileUpload() throws Exception {
final String boundary = generateMultiPartBoundary();
final String crlf = "\r\n";
final HttpURLConnection connection = openHttpConnectionForUpload(boundary);
final HttpURLConnection connection =
openHttpConnectionForUpload(
boundary, TestUploadHeaders.INSTANCE.getTargetRestEndpointURL());

try (OutputStream output = connection.getOutputStream();
PrintWriter writer =
Expand Down Expand Up @@ -715,11 +730,11 @@ private static File getTestResource(final String fileName) {
return new File(resource.getFile());
}

private HttpURLConnection openHttpConnectionForUpload(final String boundary)
throws IOException {
private HttpURLConnection openHttpConnectionForUpload(
final String boundary, final String uploadUri) throws IOException {
final HttpURLConnection connection =
(HttpURLConnection)
new URL(serverEndpoint.getRestBaseUrl() + "/upload").openConnection();
new URL(serverEndpoint.getRestBaseUrl() + uploadUri).openConnection();
connection.setDoOutput(true);
connection.setRequestProperty("Content-Type", "multipart/form-data; boundary=" + boundary);
return connection;
Expand All @@ -737,6 +752,26 @@ private static String createStringOfSize(int size) {
return sb.toString();
}

private static void uploadFile(HttpURLConnection connection, String content, String boundary)
throws IOException {
final String crlf = "\r\n";
try (OutputStream output = connection.getOutputStream();
PrintWriter writer =
new PrintWriter(
new OutputStreamWriter(output, StandardCharsets.UTF_8), true)) {

writer.append("--" + boundary).append(crlf);
writer.append("Content-Disposition: form-data; name=\"foo\"; filename=\"bar\"")
.append(crlf);
writer.append("Content-Type: plain/text; charset=utf8").append(crlf);
writer.append(crlf).flush();
output.write(content.getBytes(StandardCharsets.UTF_8));
output.flush();
writer.append(crlf).flush();
writer.append("--" + boundary + "--").append(crlf).flush();
}
}

private static class TestHandler
extends AbstractRestHandler<RestfulGateway, TestRequest, TestResponse, TestParameters> {

Expand Down

0 comments on commit 1f80848

Please sign in to comment.