Skip to content

Commit

Permalink
[FLINK-26808][rest] Only accept file upload at multipart routes
Browse files Browse the repository at this point in the history
  • Loading branch information
uce committed May 28, 2024
1 parent 52105a4 commit da45a78
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.flink.runtime.rest;

import org.apache.flink.runtime.rest.handler.FileUploads;
import org.apache.flink.runtime.rest.handler.router.MultipartRoutes;
import org.apache.flink.runtime.rest.handler.util.HandlerUtils;
import org.apache.flink.runtime.rest.messages.ErrorResponseBody;
import org.apache.flink.runtime.rest.util.RestConstants;
Expand Down Expand Up @@ -79,15 +80,15 @@ public class FileUploadHandler extends SimpleChannelInboundHandler<HttpObject> {

private final Path uploadDir;

private final MultipartRoutes multipartRoutes;

private HttpPostRequestDecoder currentHttpPostRequestDecoder;

private HttpRequest currentHttpRequest;
private byte[] currentJsonPayload;
private Path currentUploadDir;

private boolean addCRPrefix = false;

public FileUploadHandler(final Path uploadDir) {
public FileUploadHandler(final Path uploadDir, final MultipartRoutes multipartRoutes) {
super(true);

// the clean up of temp files when jvm exits is handled by
Expand All @@ -103,6 +104,7 @@ public FileUploadHandler(final Path uploadDir) {
DiskAttribute.baseDirectory = DiskFileUpload.baseDirectory;

this.uploadDir = requireNonNull(uploadDir);
this.multipartRoutes = requireNonNull(multipartRoutes);
}

@Override
Expand All @@ -125,6 +127,18 @@ protected void channelRead0(final ChannelHandlerContext ctx, final HttpObject ms
new HttpPostRequestDecoder(DATA_FACTORY, httpRequest);
currentHttpRequest = ReferenceCountUtil.retain(httpRequest);

// We check this after initializing the multipart file upload in order for
// handleError to work correctly.
if (!multipartRoutes.isPostRoute(httpRequest.uri())) {
LOG.trace("POST request not allowed for {}.", httpRequest.uri());
handleError(
ctx,
"POST request not allowed",
HttpResponseStatus.BAD_REQUEST,
null);
return;
}

// make sure that we still have a upload dir in case that it got deleted in
// the meanwhile
RestServerEndpoint.createUploadDir(uploadDir, LOG, false);
Expand All @@ -151,6 +165,17 @@ protected void channelRead0(final ChannelHandlerContext ctx, final HttpObject ms
&& hasNext(currentHttpPostRequestDecoder)) {
final InterfaceHttpData data = currentHttpPostRequestDecoder.next();
if (data.getHttpDataType() == InterfaceHttpData.HttpDataType.FileUpload) {
HttpRequest httpRequest = currentHttpRequest;
if (!multipartRoutes.isFileUploadRoute(httpRequest.uri())) {
LOG.trace("File upload not allowed for {}.", httpRequest.uri());
handleError(
ctx,
"File upload not allowed",
HttpResponseStatus.BAD_REQUEST,
null);
return;
}

final DiskFileUpload fileUpload = (DiskFileUpload) data;
checkState(fileUpload.isCompleted());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
import org.apache.flink.runtime.net.RedirectingSslHandler;
import org.apache.flink.runtime.rest.handler.PipelineErrorHandler;
import org.apache.flink.runtime.rest.handler.RestHandlerSpecification;
import org.apache.flink.runtime.rest.handler.router.MultipartRoutes;
import org.apache.flink.runtime.rest.handler.router.Router;
import org.apache.flink.runtime.rest.handler.router.RouterHandler;
import org.apache.flink.runtime.rest.messages.UntypedResponseMessageHeaders;
import org.apache.flink.runtime.rest.versioning.RestAPIVersion;
import org.apache.flink.util.AutoCloseableAsync;
import org.apache.flink.util.ConfigurationException;
Expand Down Expand Up @@ -196,6 +198,9 @@ public final void start() throws Exception {
checkAllEndpointsAndHandlersAreUnique(handlers);
handlers.forEach(handler -> registerHandler(router, handler, log));

MultipartRoutes multipartRoutes = createMultipartRoutes(handlers);
log.debug("Using {} for FileUploadHandler", multipartRoutes);

ChannelInitializer<SocketChannel> initializer =
new ChannelInitializer<SocketChannel>() {

Expand All @@ -216,7 +221,7 @@ protected void initChannel(SocketChannel ch) throws ConfigurationException {

ch.pipeline()
.addLast(new HttpServerCodec())
.addLast(new FileUploadHandler(uploadDir))
.addLast(new FileUploadHandler(uploadDir, multipartRoutes))
.addLast(
new FlinkHttpObjectAggregator(
maxContentLength, responseHeaders));
Expand Down Expand Up @@ -635,6 +640,32 @@ private static void checkAllEndpointsAndHandlersAreUnique(
}
}

private MultipartRoutes createMultipartRoutes(
List<Tuple2<RestHandlerSpecification, ChannelInboundHandler>> handlers) {
MultipartRoutes.Builder builder = new MultipartRoutes.Builder();

for (Tuple2<RestHandlerSpecification, ChannelInboundHandler> handler : handlers) {
if (handler.f0.getHttpMethod() == HttpMethodWrapper.POST) {
for (String url : getHandlerRoutes(handler.f0)) {
builder.addPostRoute(url);
}
}

// The cast is necessary, because currently only UntypedResponseMessageHeaders exposes
// whether the handler accepts file uploads.
if (handler.f0 instanceof UntypedResponseMessageHeaders) {
UntypedResponseMessageHeaders<?, ?> header =
(UntypedResponseMessageHeaders<?, ?>) handler.f0;
if (header.acceptsFileUploads()) {
for (String url : getHandlerRoutes(header)) {
builder.addFileUploadRoute(url);
}
}
}
}
return builder.build();
}

private static Iterable<String> getHandlerRoutes(RestHandlerSpecification handlerSpec) {
final List<String> registeredRoutes = new ArrayList<>();
final String handlerUrl = handlerSpec.getTargetRestEndpointURL();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,32 +374,45 @@ void testShouldRespectMaxContentLengthLimitForResponses() throws Exception {
@TestTemplate
void testFileUpload() throws Exception {
final String boundary = generateMultiPartBoundary();
final String crlf = "\r\n";
final String uploadedContent = "hello";
final HttpURLConnection connection = openHttpConnectionForUpload(boundary);

try (OutputStream output = connection.getOutputStream();
PrintWriter writer =
new PrintWriter(
new OutputStreamWriter(output, StandardCharsets.UTF_8), true)) {
final HttpURLConnection connection =
openHttpConnectionForUpload(
boundary, TestUploadHeaders.INSTANCE.getTargetRestEndpointURL());

writer.append("--" + boundary).append(crlf);
writer.append("Content-Disposition: form-data; name=\"foo\"; filename=\"bar\"")
.append(crlf);
writer.append("Content-Type: plain/text; charset=utf8").append(crlf);
writer.append(crlf).flush();
output.write(uploadedContent.getBytes(StandardCharsets.UTF_8));
output.flush();
writer.append(crlf).flush();
writer.append("--" + boundary + "--").append(crlf).flush();
}
uploadFile(connection, uploadedContent, boundary);

assertThat(connection.getResponseCode()).isEqualTo(200);
final byte[] lastUploadedFileContents = testUploadHandler.getLastUploadedFileContents();
assertThat(uploadedContent)
.isEqualTo(new String(lastUploadedFileContents, StandardCharsets.UTF_8));
}

/**
* Tests that when a handler is marked as not accepting file uploads we (1) return an error and
* (2) don't upload the file to the upload directory.
*/
@TestTemplate
void testFileUploadLimitedToAllowedUris() throws Exception {
final String boundary = generateMultiPartBoundary();
final File uploadDir = new File(tempFolder.toString(), "flink-web-upload");
final File[] preUploadFiles = uploadDir.listFiles();

// We need a handler that does not accept file uploads for this test
assertThat(TestVersionHeaders.INSTANCE.acceptsFileUploads()).isFalse();
String uri = TestVersionHeaders.INSTANCE.getTargetRestEndpointURL();

final HttpURLConnection connection = openHttpConnectionForUpload(boundary, uri);

uploadFile(connection, "hello", boundary);

assertThat(connection.getResponseCode()).isEqualTo(400);

// This is the important check. We don't want additional files when the handler does
// not accept file uploads.
final File[] postUploadFiles = uploadDir.listFiles();
assertThat(postUploadFiles).isEqualTo(preUploadFiles);
}

/**
* Sending multipart/form-data without a file should result in a bad request if the handler
* expects a file upload.
Expand All @@ -408,7 +421,9 @@ void testFileUpload() throws Exception {
void testMultiPartFormDataWithoutFileUpload() throws Exception {
final String boundary = generateMultiPartBoundary();
final String crlf = "\r\n";
final HttpURLConnection connection = openHttpConnectionForUpload(boundary);
final HttpURLConnection connection =
openHttpConnectionForUpload(
boundary, TestUploadHeaders.INSTANCE.getTargetRestEndpointURL());

try (OutputStream output = connection.getOutputStream();
PrintWriter writer =
Expand Down Expand Up @@ -715,11 +730,11 @@ private static File getTestResource(final String fileName) {
return new File(resource.getFile());
}

private HttpURLConnection openHttpConnectionForUpload(final String boundary)
throws IOException {
private HttpURLConnection openHttpConnectionForUpload(
final String boundary, final String uploadUri) throws IOException {
final HttpURLConnection connection =
(HttpURLConnection)
new URL(serverEndpoint.getRestBaseUrl() + "/upload").openConnection();
new URL(serverEndpoint.getRestBaseUrl() + uploadUri).openConnection();
connection.setDoOutput(true);
connection.setRequestProperty("Content-Type", "multipart/form-data; boundary=" + boundary);
return connection;
Expand All @@ -737,6 +752,26 @@ private static String createStringOfSize(int size) {
return sb.toString();
}

private static void uploadFile(HttpURLConnection connection, String content, String boundary)
throws IOException {
final String crlf = "\r\n";
try (OutputStream output = connection.getOutputStream();
PrintWriter writer =
new PrintWriter(
new OutputStreamWriter(output, StandardCharsets.UTF_8), true)) {

writer.append("--" + boundary).append(crlf);
writer.append("Content-Disposition: form-data; name=\"foo\"; filename=\"bar\"")
.append(crlf);
writer.append("Content-Type: plain/text; charset=utf8").append(crlf);
writer.append(crlf).flush();
output.write(content.getBytes(StandardCharsets.UTF_8));
output.flush();
writer.append(crlf).flush();
writer.append("--" + boundary + "--").append(crlf).flush();
}
}

private static class TestHandler
extends AbstractRestHandler<RestfulGateway, TestRequest, TestResponse, TestParameters> {

Expand Down

0 comments on commit da45a78

Please sign in to comment.