Skip to content

Commit

Permalink
fix: new lwt server bugs (#1048)
Browse files Browse the repository at this point in the history
* cannot call direct access recursively
* we need to catch conn_reset

Signed-off-by: Rudi Grinberg <[email protected]>

<!-- ps-id: c1719f95-0851-4245-8948-61a975d97f1d -->
  • Loading branch information
rgrinberg committed Jun 27, 2024
1 parent 7ecb79a commit fe388bb
Showing 1 changed file with 71 additions and 81 deletions.
152 changes: 71 additions & 81 deletions cohttp-server-lwt-unix/src/cohttp_server_lwt_unix.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,6 @@

open Lwt.Syntax

module Direct_access = struct
let rec write_sub (da : Lwt_io.direct_access) s ~pos ~len =
if len = 0 then Lwt.return_unit
else
let remaining = da.da_max - da.da_ptr in
if remaining = 0 then
let* (_ : int) = da.da_perform () in
write_sub da s ~pos ~len
else
let write_len = min remaining len in
Lwt_bytes.blit_from_string s pos da.da_buffer da.da_ptr write_len;
da.da_ptr <- da.da_ptr + write_len;
write_sub da s ~pos:(pos + write_len) ~len:(len - write_len)

let write da s = write_sub da ~pos:0 ~len:(String.length s) s
let write_char da c = write da (String.make 1 c)
end

module Body = struct
module Substring = struct
type t = { base : string; pos : int; len : int }
Expand Down Expand Up @@ -63,50 +45,50 @@ module Body = struct
let stream ?(encoding = Encoding.Chunked) f : t = (encoding, `Stream f)
let chunk_size = 4096

let write_chunk da (sub : Substring.t) =
let* () = Direct_access.write da (Printf.sprintf "%x\r\n" sub.len) in
let* () = Direct_access.write_sub da sub.base ~pos:sub.pos ~len:sub.len in
Direct_access.write da "\r\n"
let write_chunk oc (sub : Substring.t) =
let* () = Lwt_io.write oc (Printf.sprintf "%x\r\n" sub.len) in
let* () = Lwt_io.write_from_string_exactly oc sub.base sub.pos sub.len in
Lwt_io.write oc "\r\n"

let next_chunk base ~pos =
let len = String.length base in
if pos >= len then None
else Some { Substring.base; pos; len = min chunk_size (len - pos) }

let rec write_string_as_chunks da s ~pos =
let rec write_string_as_chunks oc s ~pos =
match next_chunk s ~pos with
| None -> Direct_access.write da "\r\n"
| None -> Lwt_io.write oc "\r\n"
| Some chunk ->
let* () = write_chunk da chunk in
let* () = write_chunk oc chunk in
let pos = pos + chunk.len in
write_string_as_chunks da s ~pos
write_string_as_chunks oc s ~pos

let rec write_fixed_stream da f =
let rec write_fixed_stream oc f =
let* chunk = f () in
match chunk with
| None -> Lwt.return_unit
| Some { Substring.base; pos; len } ->
let* () = Direct_access.write_sub da base ~pos ~len in
write_fixed_stream da f
let* () = Lwt_io.write_from_string_exactly oc base pos len in
write_fixed_stream oc f

let rec write_chunks_stream da f =
let rec write_chunks_stream oc f =
let* chunk = f () in
match chunk with
| None -> Direct_access.write da "\r\n"
| None -> Lwt_io.write oc "\r\n"
| Some chunk ->
let* () = write_chunk da chunk in
write_chunks_stream da f
let* () = write_chunk oc chunk in
write_chunks_stream oc f

let write ((encoding, body) : t) da =
let write ((encoding, body) : t) oc =
match body with
| `String s -> (
match encoding with
| Fixed _ -> Direct_access.write da s
| Chunked -> write_string_as_chunks da s ~pos:0)
| Fixed _ -> Lwt_io.write oc s
| Chunked -> write_string_as_chunks oc s ~pos:0)
| `Stream f -> (
match encoding with
| Fixed _ -> write_fixed_stream da f
| Chunked -> write_chunks_stream da f)
| Fixed _ -> write_fixed_stream oc f
| Chunked -> write_chunks_stream oc f)
end

module Input_channel = struct
Expand All @@ -120,30 +102,45 @@ module Input_channel = struct
let ( >>| ) = ( >|= )
end)
(struct
type src = Lwt_io.direct_access

let rec refill (da : Lwt_io.direct_access) buf ~pos ~len =
Lwt.catch
(fun () ->
let available = da.da_max - da.da_ptr in
if available = 0 then
let* read = da.da_perform () in
if read = 0 then Lwt.return `Eof else refill da buf ~pos ~len
else
let read_len = min available len in
Lwt_bytes.blit_to_bytes da.da_buffer da.da_ptr buf pos read_len;
da.da_ptr <- da.da_ptr + read_len;
Lwt.return (`Ok read_len))
(function
| Lwt_io.Channel_closed _ -> Lwt.return `Eof | exn -> raise exn)
type src = Lwt_io.input_channel

let rec refill ic buf ~pos ~len =
let open Lwt.Infix in
if Lwt_io.is_closed ic then Lwt.return `Eof
else
Lwt.catch
(fun () ->
Lwt_io.direct_access ic (fun da ->
let available = da.da_max - da.da_ptr in
if available = 0 then
let+ read = da.da_perform () in
if read = 0 then `Eof else `Refill
else
let read_len = min available len in
Lwt_bytes.blit_to_bytes da.da_buffer da.da_ptr buf pos
read_len;
da.da_ptr <- da.da_ptr + read_len;
Lwt.return (`Ok read_len)))
(function
| Unix.Unix_error (ECONNRESET, _, _) | Lwt_io.Channel_closed _
->
let* () = Lwt_io.close ic in
Lwt.return `Eof
| exn -> raise exn)
>>= function
| `Eof ->
let* () = Lwt_io.close ic in
Lwt.return `Eof
| `Ok n -> Lwt.return (`Ok n)
| `Refill -> refill ic buf ~pos ~len
end)

type t = { buf : Bytebuffer.t; da : Lwt_io.direct_access }
type t = { buf : Bytebuffer.t; ic : Lwt_io.input_channel }

let create ?(buf_len = 0x4000) da = { buf = Bytebuffer.create buf_len; da }
let read_line_opt t = Refill.read_line t.buf t.da
let read t count = Refill.read t.buf t.da count
let refill t = Refill.refill t.buf t.da
let create ?(buf_len = 0x4000) ic = { buf = Bytebuffer.create buf_len; ic }
let read_line_opt t = Refill.read_line t.buf t.ic
let read t count = Refill.read t.buf t.ic count
let refill t = Refill.refill t.buf t.ic
let remaining t = Bytebuffer.length t.buf

let with_input_buffer (t : t) ~f =
Expand Down Expand Up @@ -289,25 +286,20 @@ module Context = struct
Http.Header.add_transfer_encoding response.headers encoding
in
let* () =
Lwt_io.direct_access t.oc (fun (da : Lwt_io.direct_access) ->
let* () =
Direct_access.write da (Http.Version.to_string response.version)
in
let* () = Direct_access.write_char da ' ' in
let* () =
Direct_access.write da (Http.Status.to_string response.status)
in
let* () = Direct_access.write da "\r\n" in
let* () =
Http.Header.to_list headers
|> Lwt_list.iter_s (fun (k, v) ->
let* () = Direct_access.write da k in
let* () = Direct_access.write da ": " in
let* () = Direct_access.write da v in
Direct_access.write da "\r\n")
in
let* () = Direct_access.write da "\r\n" in
Body.write body da)
let* () = Lwt_io.write t.oc (Http.Version.to_string response.version) in
let* () = Lwt_io.write_char t.oc ' ' in
let* () = Lwt_io.write t.oc (Http.Status.to_string response.status) in
let* () = Lwt_io.write t.oc "\r\n" in
let* () =
Http.Header.to_list headers
|> Lwt_list.iter_s (fun (k, v) ->
let* () = Lwt_io.write t.oc k in
let* () = Lwt_io.write t.oc ": " in
let* () = Lwt_io.write t.oc v in
Lwt_io.write t.oc "\r\n")
in
let* () = Lwt_io.write t.oc "\r\n" in
Body.write body t.oc
in
Lwt.wakeup_later t.response_send response;
Lwt_io.flush t.oc
Expand Down Expand Up @@ -365,6 +357,4 @@ let handle_connection { callback; on_exn } (ic, oc) =
in
if keep_alive then loop callback ic oc else Lwt.return_unit
in
Lwt_io.direct_access ic (fun da ->
let ic = Input_channel.create da in
loop callback ic oc)
loop callback (Input_channel.create ic) oc

0 comments on commit fe388bb

Please sign in to comment.