diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/FileUploadHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/FileUploadHandler.java index c9e1fd78d7432a..3a969a57308588 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/FileUploadHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/FileUploadHandler.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.rest; import org.apache.flink.runtime.rest.handler.FileUploads; +import org.apache.flink.runtime.rest.handler.router.MultipartRoutes; import org.apache.flink.runtime.rest.handler.util.HandlerUtils; import org.apache.flink.runtime.rest.messages.ErrorResponseBody; import org.apache.flink.runtime.rest.util.RestConstants; @@ -79,15 +80,15 @@ public class FileUploadHandler extends SimpleChannelInboundHandler { private final Path uploadDir; + private final MultipartRoutes multipartRoutes; + private HttpPostRequestDecoder currentHttpPostRequestDecoder; private HttpRequest currentHttpRequest; private byte[] currentJsonPayload; private Path currentUploadDir; - private boolean addCRPrefix = false; - - public FileUploadHandler(final Path uploadDir) { + public FileUploadHandler(final Path uploadDir, final MultipartRoutes multipartRoutes) { super(true); // the clean up of temp files when jvm exits is handled by @@ -103,6 +104,7 @@ public FileUploadHandler(final Path uploadDir) { DiskAttribute.baseDirectory = DiskFileUpload.baseDirectory; this.uploadDir = requireNonNull(uploadDir); + this.multipartRoutes = requireNonNull(multipartRoutes); } @Override @@ -125,6 +127,18 @@ protected void channelRead0(final ChannelHandlerContext ctx, final HttpObject ms new HttpPostRequestDecoder(DATA_FACTORY, httpRequest); currentHttpRequest = ReferenceCountUtil.retain(httpRequest); + // We check this after initializing the multipart file upload in order for + // handleError to work correctly. + if (!multipartRoutes.isPostRoute(httpRequest.uri())) { + LOG.trace("POST request not allowed for {}.", httpRequest.uri()); + handleError( + ctx, + "POST request not allowed", + HttpResponseStatus.BAD_REQUEST, + null); + return; + } + // make sure that we still have a upload dir in case that it got deleted in // the meanwhile RestServerEndpoint.createUploadDir(uploadDir, LOG, false); @@ -151,6 +165,17 @@ protected void channelRead0(final ChannelHandlerContext ctx, final HttpObject ms && hasNext(currentHttpPostRequestDecoder)) { final InterfaceHttpData data = currentHttpPostRequestDecoder.next(); if (data.getHttpDataType() == InterfaceHttpData.HttpDataType.FileUpload) { + HttpRequest httpRequest = currentHttpRequest; + if (!multipartRoutes.isFileUploadRoute(httpRequest.uri())) { + LOG.trace("File upload not allowed for {}.", httpRequest.uri()); + handleError( + ctx, + "File upload not allowed", + HttpResponseStatus.BAD_REQUEST, + null); + return; + } + final DiskFileUpload fileUpload = (DiskFileUpload) data; checkState(fileUpload.isCompleted()); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java index 817e25521d8845..4d021e019d0625 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java @@ -28,8 +28,10 @@ import org.apache.flink.runtime.net.RedirectingSslHandler; import org.apache.flink.runtime.rest.handler.PipelineErrorHandler; import org.apache.flink.runtime.rest.handler.RestHandlerSpecification; +import org.apache.flink.runtime.rest.handler.router.MultipartRoutes; import org.apache.flink.runtime.rest.handler.router.Router; import org.apache.flink.runtime.rest.handler.router.RouterHandler; +import org.apache.flink.runtime.rest.messages.UntypedResponseMessageHeaders; import org.apache.flink.runtime.rest.versioning.RestAPIVersion; import org.apache.flink.util.AutoCloseableAsync; import org.apache.flink.util.ConfigurationException; @@ -196,6 +198,9 @@ public final void start() throws Exception { checkAllEndpointsAndHandlersAreUnique(handlers); handlers.forEach(handler -> registerHandler(router, handler, log)); + MultipartRoutes multipartRoutes = createMultipartRoutes(handlers); + log.debug("Using {} for FileUploadHandler", multipartRoutes); + ChannelInitializer initializer = new ChannelInitializer() { @@ -216,7 +221,7 @@ protected void initChannel(SocketChannel ch) throws ConfigurationException { ch.pipeline() .addLast(new HttpServerCodec()) - .addLast(new FileUploadHandler(uploadDir)) + .addLast(new FileUploadHandler(uploadDir, multipartRoutes)) .addLast( new FlinkHttpObjectAggregator( maxContentLength, responseHeaders)); @@ -635,6 +640,32 @@ private static void checkAllEndpointsAndHandlersAreUnique( } } + private MultipartRoutes createMultipartRoutes( + List> handlers) { + MultipartRoutes.Builder builder = new MultipartRoutes.Builder(); + + for (Tuple2 handler : handlers) { + if (handler.f0.getHttpMethod() == HttpMethodWrapper.POST) { + for (String url : getHandlerRoutes(handler.f0)) { + builder.addPostRoute(url); + } + } + + // The cast is necessary, because currently only UntypedResponseMessageHeaders exposes + // whether the handler accepts file uploads. + if (handler.f0 instanceof UntypedResponseMessageHeaders) { + UntypedResponseMessageHeaders header = + (UntypedResponseMessageHeaders) handler.f0; + if (header.acceptsFileUploads()) { + for (String url : getHandlerRoutes(header)) { + builder.addFileUploadRoute(url); + } + } + } + } + return builder.build(); + } + private static Iterable getHandlerRoutes(RestHandlerSpecification handlerSpec) { final List registeredRoutes = new ArrayList<>(); final String handlerUrl = handlerSpec.getTargetRestEndpointURL(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestServerEndpointITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestServerEndpointITCase.java index 2c6d3b0afef8f6..bcc1ee07f0e2ab 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestServerEndpointITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestServerEndpointITCase.java @@ -374,25 +374,12 @@ void testShouldRespectMaxContentLengthLimitForResponses() throws Exception { @TestTemplate void testFileUpload() throws Exception { final String boundary = generateMultiPartBoundary(); - final String crlf = "\r\n"; final String uploadedContent = "hello"; - final HttpURLConnection connection = openHttpConnectionForUpload(boundary); - - try (OutputStream output = connection.getOutputStream(); - PrintWriter writer = - new PrintWriter( - new OutputStreamWriter(output, StandardCharsets.UTF_8), true)) { + final HttpURLConnection connection = + openHttpConnectionForUpload( + boundary, TestUploadHeaders.INSTANCE.getTargetRestEndpointURL()); - writer.append("--" + boundary).append(crlf); - writer.append("Content-Disposition: form-data; name=\"foo\"; filename=\"bar\"") - .append(crlf); - writer.append("Content-Type: plain/text; charset=utf8").append(crlf); - writer.append(crlf).flush(); - output.write(uploadedContent.getBytes(StandardCharsets.UTF_8)); - output.flush(); - writer.append(crlf).flush(); - writer.append("--" + boundary + "--").append(crlf).flush(); - } + uploadFile(connection, uploadedContent, boundary); assertThat(connection.getResponseCode()).isEqualTo(200); final byte[] lastUploadedFileContents = testUploadHandler.getLastUploadedFileContents(); @@ -400,6 +387,32 @@ void testFileUpload() throws Exception { .isEqualTo(new String(lastUploadedFileContents, StandardCharsets.UTF_8)); } + /** + * Tests that when a handler is marked as not accepting file uploads we (1) return an error and + * (2) don't upload the file to the upload directory. + */ + @TestTemplate + void testFileUploadLimitedToAllowedUris() throws Exception { + final String boundary = generateMultiPartBoundary(); + final File uploadDir = new File(tempFolder.toString(), "flink-web-upload"); + final File[] preUploadFiles = uploadDir.listFiles(); + + // We need a handler that does not accept file uploads for this test + assertThat(TestVersionHeaders.INSTANCE.acceptsFileUploads()).isFalse(); + String uri = TestVersionHeaders.INSTANCE.getTargetRestEndpointURL(); + + final HttpURLConnection connection = openHttpConnectionForUpload(boundary, uri); + + uploadFile(connection, "hello", boundary); + + assertThat(connection.getResponseCode()).isEqualTo(400); + + // This is the important check. We don't want additional files when the handler does + // not accept file uploads. + final File[] postUploadFiles = uploadDir.listFiles(); + assertThat(postUploadFiles).isEqualTo(preUploadFiles); + } + /** * Sending multipart/form-data without a file should result in a bad request if the handler * expects a file upload. @@ -408,7 +421,9 @@ void testFileUpload() throws Exception { void testMultiPartFormDataWithoutFileUpload() throws Exception { final String boundary = generateMultiPartBoundary(); final String crlf = "\r\n"; - final HttpURLConnection connection = openHttpConnectionForUpload(boundary); + final HttpURLConnection connection = + openHttpConnectionForUpload( + boundary, TestUploadHeaders.INSTANCE.getTargetRestEndpointURL()); try (OutputStream output = connection.getOutputStream(); PrintWriter writer = @@ -715,11 +730,11 @@ private static File getTestResource(final String fileName) { return new File(resource.getFile()); } - private HttpURLConnection openHttpConnectionForUpload(final String boundary) - throws IOException { + private HttpURLConnection openHttpConnectionForUpload( + final String boundary, final String uploadUri) throws IOException { final HttpURLConnection connection = (HttpURLConnection) - new URL(serverEndpoint.getRestBaseUrl() + "/upload").openConnection(); + new URL(serverEndpoint.getRestBaseUrl() + uploadUri).openConnection(); connection.setDoOutput(true); connection.setRequestProperty("Content-Type", "multipart/form-data; boundary=" + boundary); return connection; @@ -737,6 +752,26 @@ private static String createStringOfSize(int size) { return sb.toString(); } + private static void uploadFile(HttpURLConnection connection, String content, String boundary) + throws IOException { + final String crlf = "\r\n"; + try (OutputStream output = connection.getOutputStream(); + PrintWriter writer = + new PrintWriter( + new OutputStreamWriter(output, StandardCharsets.UTF_8), true)) { + + writer.append("--" + boundary).append(crlf); + writer.append("Content-Disposition: form-data; name=\"foo\"; filename=\"bar\"") + .append(crlf); + writer.append("Content-Type: plain/text; charset=utf8").append(crlf); + writer.append(crlf).flush(); + output.write(content.getBytes(StandardCharsets.UTF_8)); + output.flush(); + writer.append(crlf).flush(); + writer.append("--" + boundary + "--").append(crlf).flush(); + } + } + private static class TestHandler extends AbstractRestHandler {