From fe388bb9cc4b159224dc2f7e33a9a7faa013fa35 Mon Sep 17 00:00:00 2001 From: Rudi Grinberg Date: Thu, 27 Jun 2024 20:15:35 +0100 Subject: [PATCH] fix: new lwt server bugs (#1048) * cannot call direct access recursively * we need to catch conn_reset Signed-off-by: Rudi Grinberg --- .../src/cohttp_server_lwt_unix.ml | 152 ++++++++---------- 1 file changed, 71 insertions(+), 81 deletions(-) diff --git a/cohttp-server-lwt-unix/src/cohttp_server_lwt_unix.ml b/cohttp-server-lwt-unix/src/cohttp_server_lwt_unix.ml index dd44e2d020..c372abf873 100644 --- a/cohttp-server-lwt-unix/src/cohttp_server_lwt_unix.ml +++ b/cohttp-server-lwt-unix/src/cohttp_server_lwt_unix.ml @@ -16,24 +16,6 @@ open Lwt.Syntax -module Direct_access = struct - let rec write_sub (da : Lwt_io.direct_access) s ~pos ~len = - if len = 0 then Lwt.return_unit - else - let remaining = da.da_max - da.da_ptr in - if remaining = 0 then - let* (_ : int) = da.da_perform () in - write_sub da s ~pos ~len - else - let write_len = min remaining len in - Lwt_bytes.blit_from_string s pos da.da_buffer da.da_ptr write_len; - da.da_ptr <- da.da_ptr + write_len; - write_sub da s ~pos:(pos + write_len) ~len:(len - write_len) - - let write da s = write_sub da ~pos:0 ~len:(String.length s) s - let write_char da c = write da (String.make 1 c) -end - module Body = struct module Substring = struct type t = { base : string; pos : int; len : int } @@ -63,50 +45,50 @@ module Body = struct let stream ?(encoding = Encoding.Chunked) f : t = (encoding, `Stream f) let chunk_size = 4096 - let write_chunk da (sub : Substring.t) = - let* () = Direct_access.write da (Printf.sprintf "%x\r\n" sub.len) in - let* () = Direct_access.write_sub da sub.base ~pos:sub.pos ~len:sub.len in - Direct_access.write da "\r\n" + let write_chunk oc (sub : Substring.t) = + let* () = Lwt_io.write oc (Printf.sprintf "%x\r\n" sub.len) in + let* () = Lwt_io.write_from_string_exactly oc sub.base sub.pos sub.len in + Lwt_io.write oc "\r\n" let next_chunk base ~pos = let len = String.length base in if pos >= len then None else Some { Substring.base; pos; len = min chunk_size (len - pos) } - let rec write_string_as_chunks da s ~pos = + let rec write_string_as_chunks oc s ~pos = match next_chunk s ~pos with - | None -> Direct_access.write da "\r\n" + | None -> Lwt_io.write oc "\r\n" | Some chunk -> - let* () = write_chunk da chunk in + let* () = write_chunk oc chunk in let pos = pos + chunk.len in - write_string_as_chunks da s ~pos + write_string_as_chunks oc s ~pos - let rec write_fixed_stream da f = + let rec write_fixed_stream oc f = let* chunk = f () in match chunk with | None -> Lwt.return_unit | Some { Substring.base; pos; len } -> - let* () = Direct_access.write_sub da base ~pos ~len in - write_fixed_stream da f + let* () = Lwt_io.write_from_string_exactly oc base pos len in + write_fixed_stream oc f - let rec write_chunks_stream da f = + let rec write_chunks_stream oc f = let* chunk = f () in match chunk with - | None -> Direct_access.write da "\r\n" + | None -> Lwt_io.write oc "\r\n" | Some chunk -> - let* () = write_chunk da chunk in - write_chunks_stream da f + let* () = write_chunk oc chunk in + write_chunks_stream oc f - let write ((encoding, body) : t) da = + let write ((encoding, body) : t) oc = match body with | `String s -> ( match encoding with - | Fixed _ -> Direct_access.write da s - | Chunked -> write_string_as_chunks da s ~pos:0) + | Fixed _ -> Lwt_io.write oc s + | Chunked -> write_string_as_chunks oc s ~pos:0) | `Stream f -> ( match encoding with - | Fixed _ -> write_fixed_stream da f - | Chunked -> write_chunks_stream da f) + | Fixed _ -> write_fixed_stream oc f + | Chunked -> write_chunks_stream oc f) end module Input_channel = struct @@ -120,30 +102,45 @@ module Input_channel = struct let ( >>| ) = ( >|= ) end) (struct - type src = Lwt_io.direct_access - - let rec refill (da : Lwt_io.direct_access) buf ~pos ~len = - Lwt.catch - (fun () -> - let available = da.da_max - da.da_ptr in - if available = 0 then - let* read = da.da_perform () in - if read = 0 then Lwt.return `Eof else refill da buf ~pos ~len - else - let read_len = min available len in - Lwt_bytes.blit_to_bytes da.da_buffer da.da_ptr buf pos read_len; - da.da_ptr <- da.da_ptr + read_len; - Lwt.return (`Ok read_len)) - (function - | Lwt_io.Channel_closed _ -> Lwt.return `Eof | exn -> raise exn) + type src = Lwt_io.input_channel + + let rec refill ic buf ~pos ~len = + let open Lwt.Infix in + if Lwt_io.is_closed ic then Lwt.return `Eof + else + Lwt.catch + (fun () -> + Lwt_io.direct_access ic (fun da -> + let available = da.da_max - da.da_ptr in + if available = 0 then + let+ read = da.da_perform () in + if read = 0 then `Eof else `Refill + else + let read_len = min available len in + Lwt_bytes.blit_to_bytes da.da_buffer da.da_ptr buf pos + read_len; + da.da_ptr <- da.da_ptr + read_len; + Lwt.return (`Ok read_len))) + (function + | Unix.Unix_error (ECONNRESET, _, _) | Lwt_io.Channel_closed _ + -> + let* () = Lwt_io.close ic in + Lwt.return `Eof + | exn -> raise exn) + >>= function + | `Eof -> + let* () = Lwt_io.close ic in + Lwt.return `Eof + | `Ok n -> Lwt.return (`Ok n) + | `Refill -> refill ic buf ~pos ~len end) - type t = { buf : Bytebuffer.t; da : Lwt_io.direct_access } + type t = { buf : Bytebuffer.t; ic : Lwt_io.input_channel } - let create ?(buf_len = 0x4000) da = { buf = Bytebuffer.create buf_len; da } - let read_line_opt t = Refill.read_line t.buf t.da - let read t count = Refill.read t.buf t.da count - let refill t = Refill.refill t.buf t.da + let create ?(buf_len = 0x4000) ic = { buf = Bytebuffer.create buf_len; ic } + let read_line_opt t = Refill.read_line t.buf t.ic + let read t count = Refill.read t.buf t.ic count + let refill t = Refill.refill t.buf t.ic let remaining t = Bytebuffer.length t.buf let with_input_buffer (t : t) ~f = @@ -289,25 +286,20 @@ module Context = struct Http.Header.add_transfer_encoding response.headers encoding in let* () = - Lwt_io.direct_access t.oc (fun (da : Lwt_io.direct_access) -> - let* () = - Direct_access.write da (Http.Version.to_string response.version) - in - let* () = Direct_access.write_char da ' ' in - let* () = - Direct_access.write da (Http.Status.to_string response.status) - in - let* () = Direct_access.write da "\r\n" in - let* () = - Http.Header.to_list headers - |> Lwt_list.iter_s (fun (k, v) -> - let* () = Direct_access.write da k in - let* () = Direct_access.write da ": " in - let* () = Direct_access.write da v in - Direct_access.write da "\r\n") - in - let* () = Direct_access.write da "\r\n" in - Body.write body da) + let* () = Lwt_io.write t.oc (Http.Version.to_string response.version) in + let* () = Lwt_io.write_char t.oc ' ' in + let* () = Lwt_io.write t.oc (Http.Status.to_string response.status) in + let* () = Lwt_io.write t.oc "\r\n" in + let* () = + Http.Header.to_list headers + |> Lwt_list.iter_s (fun (k, v) -> + let* () = Lwt_io.write t.oc k in + let* () = Lwt_io.write t.oc ": " in + let* () = Lwt_io.write t.oc v in + Lwt_io.write t.oc "\r\n") + in + let* () = Lwt_io.write t.oc "\r\n" in + Body.write body t.oc in Lwt.wakeup_later t.response_send response; Lwt_io.flush t.oc @@ -365,6 +357,4 @@ let handle_connection { callback; on_exn } (ic, oc) = in if keep_alive then loop callback ic oc else Lwt.return_unit in - Lwt_io.direct_access ic (fun da -> - let ic = Input_channel.create da in - loop callback ic oc) + loop callback (Input_channel.create ic) oc