diff --git a/api/server/handler.go b/api/server/handler.go index b4d3913..fe54be7 100644 --- a/api/server/handler.go +++ b/api/server/handler.go @@ -125,7 +125,7 @@ func (m *HttpServer) handleBlobGetByID(w http.ResponseWriter, r *http.Request, i w.Header().Set(httpHeaderContentTypeOptionsNoSniff()) w.Header().Set(httpHeaderContentLength(blobDesc.Size)) w.Header().Set("Content-Disposition", fmt.Sprintf(`attachement; filename="%s.bin"`, id.String())) - http.ServeContent(w, r, "", blobDesc.ModificationTime, blobReader) + serveContent(w, r, "", blobDesc.ModificationTime, blobReader) logger.Debug("Blob fetched successfully") } diff --git a/api/server/serve_content.go b/api/server/serve_content.go new file mode 100644 index 0000000..8dfacb2 --- /dev/null +++ b/api/server/serve_content.go @@ -0,0 +1,569 @@ +package server + +import ( + "errors" + "fmt" + "io" + "io/fs" + "mime" + "mime/multipart" + "net/http" + "net/textproto" + "path/filepath" + "strconv" + "strings" + "time" +) + +type WriterToN interface { + WriteToN(io.Writer, int64) (int64, error) +} + +// errSeeker is returned by ServeContent's sizeFunc when the content +// doesn't seek properly. The underlying Seeker's error text isn't +// included in the sizeFunc reply so it's not sent over HTTP to end +// users. +var errSeeker = errors.New("seeker can't seek") + +// errNoOverlap is returned by serveContent's parseRange if first-byte-pos of +// all of the byte-range-spec values is greater than the content size. +var errNoOverlap = errors.New("invalid range: failed to overlap") + +// serveContext is a copy of http.ServeContent with the minor modification that +// is can call content.WriteToN if available, and it does not sniff content to +// determine content type. +// +// See: https://pkg.go.dev/net/http#ServeContent +func serveContent(w http.ResponseWriter, r *http.Request, name string, modtime time.Time, content io.ReadSeeker) { + sizeFunc := func() (int64, error) { + size, err := content.Seek(0, io.SeekEnd) + if err != nil { + return 0, errSeeker + } + _, err = content.Seek(0, io.SeekStart) + if err != nil { + return 0, errSeeker + } + return size, nil + } + + setLastModified(w, modtime) + done, rangeReq := checkPreconditions(w, r, modtime) + if done { + return + } + + code := http.StatusOK + + // If Content-Type isn't set, use the file's extension to find it. Do not + // sniff content to find type. Default to "application/octet-stream" if no + // content type is set. + ctypes, haveType := w.Header()["Content-Type"] + var ctype string + if !haveType { + ctype = mime.TypeByExtension(filepath.Ext(name)) + if ctype == "" { + ctype = "application/octet-stream" + } + w.Header().Set("Content-Type", ctype) + } else if len(ctypes) > 0 { + ctype = ctypes[0] + } + + size, err := sizeFunc() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if size < 0 { + // Should never happen but just to be sure + http.Error(w, "negative content size computed", http.StatusInternalServerError) + return + } + + // handle Content-Range header. + sendSize := size + var sendContent io.Reader = content + ranges, err := parseRange(rangeReq, size) + switch err { + case nil: + case errNoOverlap: + if size == 0 { + // Some clients add a Range header to all requests to + // limit the size of the response. If the file is empty, + // ignore the range header and respond with a 200 rather + // than a 416. + ranges = nil + break + } + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + fallthrough + default: + http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable) + return + } + + if sumRangesSize(ranges) > size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 7233, Section 4.1: + // "If a single part is being transferred, the server + // generating the 206 response MUST generate a + // Content-Range header field, describing what range + // of the selected representation is enclosed, and a + // payload consisting of the range. + // ... + // A server MUST NOT generate a multipart response to + // a request for a single range, since a client that + // does not request multiple parts might not support + // multipart responses." + ra := ranges[0] + if _, err := content.Seek(ra.start, io.SeekStart); err != nil { + http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable) + return + } + sendSize = ra.length + code = http.StatusPartialContent + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + sendSize = rangesMIMESize(ranges, ctype, size) + code = http.StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := content.Seek(ra.start, io.SeekStart); err != nil { + pw.CloseWithError(err) + return + } + wtn, ok := content.(WriterToN) + if ok { + if _, err := wtn.WriteToN(part, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } else { + if _, err := io.CopyN(part, content, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + } + mw.Close() + pw.Close() + }() + } + + w.Header().Set("Accept-Ranges", "bytes") + if w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) + } + + w.WriteHeader(code) + + if r.Method != "HEAD" { + wtn, ok := content.(WriterToN) + if ok { + wtn.WriteToN(w, sendSize) + } else { + io.CopyN(w, sendContent, sendSize) + } + } +} + +// scanETag determines if a syntactically valid ETag is present at s. If so, +// the ETag and remaining text after consuming ETag is returned. Otherwise, +// it returns "", "". +func scanETag(s string) (etag string, remain string) { + s = textproto.TrimString(s) + start := 0 + if strings.HasPrefix(s, "W/") { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return "", "" + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. + for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. + case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return s[:i+1], s[i+1:] + default: + return "", "" + } + } + return "", "" +} + +// etagStrongMatch reports whether a and b match using strong ETag comparison. +// Assumes a and b are valid ETags. +func etagStrongMatch(a, b string) bool { + return a == b && a != "" && a[0] == '"' +} + +// etagWeakMatch reports whether a and b match using weak ETag comparison. +// Assumes a and b are valid ETags. +func etagWeakMatch(a, b string) bool { + return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/") +} + +// condResult is the result of an HTTP request precondition check. +// See https://tools.ietf.org/html/rfc7232 section 3. +type condResult int + +const ( + condNone condResult = iota + condTrue + condFalse +) + +func checkIfMatch(w http.ResponseWriter, r *http.Request) condResult { + im := r.Header.Get("If-Match") + if im == "" { + return condNone + } + for { + im = textproto.TrimString(im) + if len(im) == 0 { + break + } + if im[0] == ',' { + im = im[1:] + continue + } + if im[0] == '*' { + return condTrue + } + etag, remain := scanETag(im) + if etag == "" { + break + } + if etagStrongMatch(etag, w.Header().Get("Etag")) { + return condTrue + } + im = remain + } + + return condFalse +} + +func checkIfUnmodifiedSince(r *http.Request, modtime time.Time) condResult { + ius := r.Header.Get("If-Unmodified-Since") + if ius == "" || isZeroTime(modtime) { + return condNone + } + t, err := http.ParseTime(ius) + if err != nil { + return condNone + } + + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if ret := modtime.Compare(t); ret <= 0 { + return condTrue + } + return condFalse +} + +func checkIfNoneMatch(w http.ResponseWriter, r *http.Request) condResult { + inm := r.Header.Get("If-None-Match") + if inm == "" { + return condNone + } + buf := inm + for { + buf = textproto.TrimString(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + continue + } + if buf[0] == '*' { + return condFalse + } + etag, remain := scanETag(buf) + if etag == "" { + break + } + if etagWeakMatch(etag, w.Header().Get("Etag")) { + return condFalse + } + buf = remain + } + return condTrue +} + +func checkIfModifiedSince(r *http.Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ims := r.Header.Get("If-Modified-Since") + if ims == "" || isZeroTime(modtime) { + return condNone + } + t, err := http.ParseTime(ims) + if err != nil { + return condNone + } + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if ret := modtime.Compare(t); ret <= 0 { + return condFalse + } + return condTrue +} + +func checkIfRange(w http.ResponseWriter, r *http.Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ir := r.Header.Get("If-Range") + if ir == "" { + return condNone + } + etag, _ := scanETag(ir) + if etag != "" { + if etagStrongMatch(etag, w.Header().Get("Etag")) { + return condTrue + } else { + return condFalse + } + } + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. + if modtime.IsZero() { + return condFalse + } + t, err := http.ParseTime(ir) + if err != nil { + return condFalse + } + if t.Unix() == modtime.Unix() { + return condTrue + } + return condFalse +} + +var unixEpochTime = time.Unix(0, 0) + +// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). +func isZeroTime(t time.Time) bool { + return t.IsZero() || t.Equal(unixEpochTime) +} + +func setLastModified(w http.ResponseWriter, modtime time.Time) { + if !isZeroTime(modtime) { + w.Header().Set("Last-Modified", modtime.UTC().Format(http.TimeFormat)) + } +} + +func writeNotModified(w http.ResponseWriter) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + delete(h, "Content-Encoding") + if h.Get("Etag") != "" { + delete(h, "Last-Modified") + } + w.WriteHeader(http.StatusNotModified) +} + +// checkPreconditions evaluates request preconditions and reports whether a precondition +// resulted in sending StatusNotModified or StatusPreconditionFailed. +func checkPreconditions(w http.ResponseWriter, r *http.Request, modtime time.Time) (done bool, rangeHeader string) { + // This function carefully follows RFC 7232 section 6. + ch := checkIfMatch(w, r) + if ch == condNone { + ch = checkIfUnmodifiedSince(r, modtime) + } + if ch == condFalse { + w.WriteHeader(http.StatusPreconditionFailed) + return true, "" + } + switch checkIfNoneMatch(w, r) { + case condFalse: + if r.Method == "GET" || r.Method == "HEAD" { + writeNotModified(w) + return true, "" + } else { + w.WriteHeader(http.StatusPreconditionFailed) + return true, "" + } + case condNone: + if checkIfModifiedSince(r, modtime) == condFalse { + writeNotModified(w) + return true, "" + } + } + + rangeHeader = r.Header.Get("Range") + if rangeHeader != "" && checkIfRange(w, r, modtime) == condFalse { + rangeHeader = "" + } + return false, rangeHeader +} + +// toHTTPError returns a non-specific HTTP error message and status code +// for a given non-nil error value. It's important that toHTTPError does not +// actually return err.Error(), since msg and httpStatus are returned to users, +// and historically Go's ServeContent always returned just "404 Not Found" for +// all errors. We don't want to start leaking information in error messages. +func toHTTPError(err error) (msg string, httpStatus int) { + if errors.Is(err, fs.ErrNotExist) { + return "404 page not found", http.StatusNotFound + } + if errors.Is(err, fs.ErrPermission) { + return "403 Forbidden", http.StatusForbidden + } + // Default: + return "500 Internal Server Error", http.StatusInternalServerError +} + +// httpRange specifies the byte range to be sent to the client. +type httpRange struct { + start, length int64 +} + +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + +// parseRange parses a Range header string as per RFC 7233. +// errNoOverlap is returned if none of the ranges overlap. +func parseRange(s string, size int64) ([]httpRange, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges []httpRange + noOverlap := false + for _, ra := range strings.Split(s[len(b):], ",") { + ra = textproto.TrimString(ra) + if ra == "" { + continue + } + start, end, ok := strings.Cut(ra, "-") + if !ok { + return nil, errors.New("invalid range") + } + start, end = textproto.TrimString(start), textproto.TrimString(end) + var r httpRange + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file, + // and we are dealing with + // which has to be a non-negative integer as per + // RFC 7233 Section 2.1 "Byte-Ranges". + if end == "" || end[0] == '-' { + return nil, errors.New("invalid range") + } + i, err := strconv.ParseInt(end, 10, 64) + if i < 0 || err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r.start = size - i + r.length = size - r.start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i < 0 { + return nil, errors.New("invalid range") + } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + noOverlap = true + continue + } + r.start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.length = size - r.start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.start > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r.length = i - r.start + 1 + } + } + ranges = append(ranges, r) + } + if noOverlap && len(ranges) == 0 { + // The specified ranges did not overlap with the content. + return nil, errNoOverlap + } + return ranges, nil +} + +// countingWriter counts how many bytes have been written to it. +type countingWriter int64 + +func (w *countingWriter) Write(p []byte) (n int, err error) { + *w += countingWriter(len(p)) + return len(p), nil +} + +// rangesMIMESize returns the number of bytes it takes to encode the +// provided ranges as a multipart response. +func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { + var w countingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + encSize += ra.length + } + mw.Close() + encSize += int64(w) + return +} + +func sumRangesSize(ranges []httpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +} diff --git a/integration/singularity/reader.go b/integration/singularity/reader.go index 3593c94..f746da9 100644 --- a/integration/singularity/reader.go +++ b/integration/singularity/reader.go @@ -38,7 +38,7 @@ func (r *SingularityReader) Read(p []byte) (int, error) { buf := bytes.NewBuffer(p) buf.Reset() - n, err := r.writeToN(buf, readLen) + n, err := r.WriteToN(buf, readLen) return int(n), err } @@ -49,10 +49,10 @@ func (r *SingularityReader) WriteTo(w io.Writer) (int64, error) { return 0, io.EOF } // Read all remaining bytes and write them to w. - return r.writeToN(w, r.size-r.offset) + return r.WriteToN(w, r.size-r.offset) } -func (r *SingularityReader) writeToN(w io.Writer, readLen int64) (int64, error) { +func (r *SingularityReader) WriteToN(w io.Writer, readLen int64) (int64, error) { var read int64 // If there is a rangeReader from the previous read that can be used to // continue reading more data, then use it instead of doing another @@ -68,24 +68,16 @@ func (r *SingularityReader) writeToN(w io.Writer, readLen int64) (int64, error) r.offset += n readLen -= n read += n - if r.rangeReader.remaining == 0 { - // No data left in range reader. - r.rangeReader.close() - r.rangeReader = nil - } if readLen == 0 { // Read all requested data from leftover in rangeReader. return read, nil } // No more leftover data to read, but readLen additional bytes // still needed. Will read more data from next range(s). - } else { - // Trying to read from outside of rangeReader's range. Must have - // seeked out of current range. Close rangeReader and read new - // range. - r.rangeReader.close() - r.rangeReader = nil } + // No more leftover data in rangeReader, or seek to done since last read. + r.rangeReader.close() + r.rangeReader = nil } rangeReadLen := readLen