From c6ac3af1fcfe278a143d305d0f7deacf5ae8c6c8 Mon Sep 17 00:00:00 2001 From: Craig Ferguson Date: Fri, 4 Dec 2020 14:53:03 +0100 Subject: [PATCH] Avoid shutdown race condition on service startup Fix https://github.com/mirage/ocaml-conduit/issues/352. The `lwt` and `async` service loops previously took condition variables that are used to signal when to shutdown. This resulted in a race condition: the shutdown signal can be broadcast before the server is waiting on it. Fixed by having the service loops take switches as configuration instead. --- src/async/conduit_async.ml | 66 +++++++++++++++++----------------- src/async/conduit_async.mli | 7 ++-- src/lwt/conduit_lwt.ml | 68 +++++++++++++++++------------------ src/lwt/conduit_lwt.mli | 25 ++++++------- tests/ping-pong/common.ml | 20 +++++------ tests/ping-pong/with_async.ml | 24 +++++++------ tests/ping-pong/with_lwt.ml | 17 +++------ 7 files changed, 112 insertions(+), 115 deletions(-) diff --git a/src/async/conduit_async.ml b/src/async/conduit_async.ml index d501f034..0e593f8c 100644 --- a/src/async/conduit_async.ml +++ b/src/async/conduit_async.ml @@ -18,43 +18,43 @@ type ('a, 'b, 'c) service = ('a, 'b, 'c) Service.t let serve : type cfg t v. ?timeout:int -> + ?stop:unit Async.Deferred.t -> handler:(flow -> unit Async.Deferred.t) -> (cfg, t, v) service -> cfg -> - unit Async.Condition.t * (unit -> unit Async.Deferred.t) = - fun ?timeout ~handler service cfg -> + unit Async.Deferred.t = + fun ?timeout ?(stop = Async.Deferred.never ()) ~handler service cfg -> let open Async in - let stop = Async.Condition.create () in - let main () = - Service.init service cfg >>= function - | Error err -> failwith "%a" Service.pp_error err - | Ok t -> ( - let rec loop () = - let close = Async.Condition.wait stop >>| fun () -> Ok `Stop in - let accept = - Service.accept service t >>? 
fun flow -> - Async.(Deferred.ok (return (`Flow flow))) in - let events = - match timeout with - | None -> [ close; accept ] - | Some t -> - let t = Core.Time.Span.of_int_sec t in - let timeout = Async.after t >>| fun () -> Ok `Timeout in - [ close; accept; timeout ] in - - Async.Deferred.any events >>= function - | Ok (`Flow flow) -> - Async.don't_wait_for (handler flow) ; - Async.Scheduler.yield () >>= fun () -> (loop [@tailcall]) () - | Ok (`Stop | `Timeout) -> Service.stop service t - | Error err0 -> ( - Service.stop service t >>= function - | Ok () -> Async.return (Error err0) - | Error _err1 -> Async.return (Error err0)) in - loop () >>= function - | Ok () -> Async.return () - | Error err -> failwith "%a" Service.pp_error err) in - (stop, main) + let timeout = + match timeout with + | None -> Deferred.never + | Some t -> fun () -> after (Core.Time.Span.of_int_sec t) in + Service.init service cfg >>= function + | Error err -> failwith "%a" Service.pp_error err + | Ok t -> ( + let rec loop () = + let accept = Service.accept service t in + Deferred.choose + [ + choice accept (Result.map (fun f -> `Flow f)); + choice (timeout ()) (fun () -> Ok `Timeout); + ] + >>? 
function + | `Flow flow -> + don't_wait_for (handler flow) ; + Scheduler.yield () >>= loop + | `Timeout -> return (Ok `Timeout) in + let stop_result = + Deferred.choose + [ choice stop (fun () -> Ok `Stopped); choice (loop ()) (fun r -> r) ] + >>= function + | Ok (`Timeout | `Stopped) -> Service.stop service t + | Error _ as err0 -> ( + Service.stop service t >>= function Ok () | Error _ -> return err0) + in + stop_result >>= function + | Ok () -> return () + | Error err -> failwith "%a" Service.pp_error err) let reader_and_writer_of_flow flow = let open Async in diff --git a/src/async/conduit_async.mli b/src/async/conduit_async.mli index 9636d474..c4ad7916 100644 --- a/src/async/conduit_async.mli +++ b/src/async/conduit_async.mli @@ -15,19 +15,20 @@ type ('a, 'b, 'c) service = ('a, 'b, 'c) Service.t val serve : ?timeout:int -> + ?stop:unit Async.Deferred.t -> handler:(flow -> unit Async.Deferred.t) -> ('cfg, 't, 'v) service -> 'cfg -> - unit Async.Condition.t * (unit -> unit Async.Deferred.t) + unit Async.Deferred.t (** [serve ~handler t cfg] creates an infinite service loop from the given configuration ['cfg]. It returns the {i promise} to launch the loop and a condition variable to stop the loop. {[ - let stop, loop = serve ~handler TCP.service cfg in + let loop = serve ~handler TCP.service cfg in Async_unix.Signal.handle [ Core.Signal.int ] ~f:(fun _sig -> Async.Condition.broadcast stop ()) ; - loop () + loop ]} *) val reader_and_writer_of_flow : diff --git a/src/lwt/conduit_lwt.ml b/src/lwt/conduit_lwt.ml index a11ae406..d7b19842 100644 --- a/src/lwt/conduit_lwt.ml +++ b/src/lwt/conduit_lwt.ml @@ -53,44 +53,44 @@ let ( >>? ) = Lwt_result.bind let serve : type cfg service v. 
?timeout:int -> + ?stop:Lwt_switch.t -> handler:(flow -> unit Lwt.t) -> (cfg, service, v) Service.t -> cfg -> - unit Lwt_condition.t * (unit -> unit Lwt.t) = - fun ?timeout ~handler service cfg -> + unit Lwt.t = + fun ?timeout ?stop ~handler service cfg -> let open Lwt.Infix in - let stop = Lwt_condition.create () in - let main () = - Service.init service cfg >>= function - | Error err -> failwith "%a" Service.pp_error err - | Ok t -> ( - let rec loop () = - let stop = Lwt_condition.wait stop >>= fun () -> Lwt.return_ok `Stop in - let accept = - Service.accept service t >>? fun flow -> Lwt.return_ok (`Flow flow) - in - let events = - match timeout with - | None -> [ stop; accept ] - | Some t -> - let timeout = - Lwt_unix.sleep (float_of_int t) >>= fun () -> - Lwt.return_ok `Timeout in - [ stop; accept; timeout ] in - - Lwt.pick events >>= function - | Ok (`Flow flow) -> - Lwt.async (fun () -> handler flow) ; - Lwt.pause () >>= loop - | Ok (`Stop | `Timeout) -> Service.stop service t - | Error err0 -> ( - Service.stop service t >>= function - | Ok () -> Lwt.return_error err0 - | Error _err1 -> Lwt.return_error err0) in - loop () >>= function - | Ok () -> Lwt.return_unit - | Error err -> failwith "%a" Service.pp_error err) in - (stop, main) + let timeout () = + match timeout with + | None -> Lwt.wait () |> fst + | Some t -> Lwt_unix.sleep (float_of_int t) in + Service.init service cfg >>= function + | Error err -> failwith "%a" Service.pp_error err + | Ok t -> ( + let switched_off = + let t, u = Lwt.wait () in + Lwt_switch.add_hook stop (fun () -> + Lwt.wakeup_later u (Ok `Stopped) ; + Lwt.return_unit) ; + t in + let rec loop () = + let accept = + Service.accept service t >>? fun flow -> Lwt.return_ok (`Flow flow) + in + Lwt.pick [ accept; (timeout () >|= fun () -> Ok `Timeout) ] >>? 
function + | `Flow flow -> + Lwt.async (fun () -> handler flow) ; + Lwt.pause () >>= loop + | `Timeout -> Lwt.return (Ok `Timeout) in + let stop_result = + Lwt.pick [ switched_off; loop () ] >>= function + | Ok (`Timeout | `Stopped) -> Service.stop service t + | Error _ as err0 -> ( + Service.stop service t >>= function + | Ok () | Error _ -> Lwt.return err0) in + stop_result >>= function + | Ok () -> Lwt.return_unit + | Error err -> failwith "%a" Service.pp_error err) module TCP = struct open Lwt.Infix diff --git a/src/lwt/conduit_lwt.mli b/src/lwt/conduit_lwt.mli index a328f46e..fd5da5b9 100644 --- a/src/lwt/conduit_lwt.mli +++ b/src/lwt/conduit_lwt.mli @@ -19,25 +19,26 @@ type ('a, 'b, 'c) service = ('a, 'b, 'c) Service.t val serve : ?timeout:int -> + ?stop:Lwt_switch.t -> handler:(flow -> unit Lwt.t) -> ('cfg, 'service, 'v) service -> 'cfg -> - unit Lwt_condition.t * (unit -> unit Lwt.t) -(** [serve ~handler service cfg] creates an usual infinite [service] loop from - the given configuration ['cfg]. It returns the {i promise} to launch the - loop and a condition variable to stop the loop. + unit Lwt.t +(** [serve ~handler service cfg] launches a [service] loop from the given + configuration ['cfg]. By default, the service loop runs indefinitely. + + - If passed, [~stop] is a switch that terminates the service loop, for + example to limit execution time to 10 seconds: {[ - let stop, loop = serve ~handler TCP.service cfg in + let switch = Lwt_switch.create () in + let loop = serve ~stop:switch ~handler TCP.service cfg in Lwt.both - ( Lwt_unix.sleep 10. >>= fun () -> - Lwt_condition.broadcast stop () ; - Lwt.return () ) - (loop ()) + (Lwt_unix.sleep 10. >>= fun () -> Lwt_switch.turn_off switch) + loop ]} - - In your example, we want to launch a server only for 10 seconds. To help the - user, the option [?timeout] allows us to wait less than [timeout] seconds. *) + - If passed, [~timeout] specifies a maximum time to wait between accepting + connections. 
*) module TCP : sig (** Implementation of TCP protocol as a client. diff --git a/tests/ping-pong/common.ml b/tests/ping-pong/common.ml index 149c315f..38f48c55 100644 --- a/tests/ping-pong/common.ml +++ b/tests/ping-pong/common.ml @@ -1,18 +1,19 @@ module type S = sig include Conduit.S - type 'a condition + type switch val serve : ?timeout:int -> + ?stop:switch -> handler:(flow -> unit io) -> ('cfg, 's, 'flow) Service.t -> 'cfg -> - unit condition * (unit -> unit io) + unit io end -module type CONDITION = sig - type 'a t +module type SWITCH = sig + type t end module type IO = sig @@ -25,10 +26,10 @@ let ( <.> ) f g x = f (g x) module Make (IO : IO) - (Condition : CONDITION) + (Switch : SWITCH) (Conduit : S with type +'a io = 'a IO.t - and type 'a condition = 'a Condition.t + and type switch := Switch.t and type input = Cstruct.t and type output = Cstruct.t) = struct @@ -112,10 +113,9 @@ struct let server : type cfg s. - (cfg, s, 'flow) Conduit.Service.t -> - cfg -> - unit Condition.t * (unit -> unit IO.t) = - fun service cfg -> Conduit.serve ~handler:transmission service cfg + ?stop:Switch.t -> (cfg, s, 'flow) Conduit.Service.t -> cfg -> unit IO.t = + fun ?stop service cfg -> + Conduit.serve ?stop ~handler:transmission service cfg (* part *) diff --git a/tests/ping-pong/with_async.ml b/tests/ping-pong/with_async.ml index b38b1fd4..5552350a 100644 --- a/tests/ping-pong/with_async.ml +++ b/tests/ping-pong/with_async.ml @@ -3,6 +3,10 @@ open Async let () = Mirage_crypto_rng_unix.initialize () +module Stop = struct + type t = unit Deferred.t +end + include Common.Make (struct type +'a t = 'a Async.Deferred.t @@ -13,12 +17,8 @@ include Common.Make let yield () = Async.Deferred.return () end) - (Async.Condition) - (struct - type 'a condition = 'a Async.Condition.t - - include Conduit_async - end) + (Stop) + (Conduit_async) let tcp_protocol, tcp_service = let open Conduit_async.TCP in @@ -49,7 +49,11 @@ let run_with : type cfg service flow. 
(cfg, service, flow) Conduit_async.Service.t -> cfg -> string list -> unit = fun service cfg clients -> - let stop, server = server (* ~launched ~stop *) service cfg in + let stop, signal_stop = + let open Async.Ivar in + let v = create () in + (read v, fill v) in + let server = server (* ~launched *) ~stop service cfg in let clients = Async.after Core.Time.Span.(of_sec 0.5) >>= fun () -> (* XXX(dinosaure): [async] tries to go further and fibers @@ -57,11 +61,9 @@ let run_with : * We waiting a bit to ensure that the server is launched * before clients. *) let clients = List.map (client ~resolvers) clients in - Async.Deferred.all_unit clients >>= fun () -> - Condition.broadcast stop () ; - Async.return () in + Async.Deferred.all_unit clients >>| signal_stop in Async.don't_wait_for - (Async.Deferred.all_unit [ server (); clients ] >>| fun () -> shutdown 0) ; + (Async.Deferred.all_unit [ server; clients ] >>| fun () -> shutdown 0) ; Core.never_returns (Scheduler.go ()) let run_with_tcp clients = diff --git a/tests/ping-pong/with_lwt.ml b/tests/ping-pong/with_lwt.ml index f094abbe..202b1ef9 100644 --- a/tests/ping-pong/with_lwt.ml +++ b/tests/ping-pong/with_lwt.ml @@ -10,12 +10,7 @@ module Lwt = struct let yield = Lwt_unix.yield end -include Common.Make (Lwt) (Lwt_condition) - (struct - type 'a condition = 'a Lwt_condition.t - - include Conduit_lwt - end) +include Common.Make (Lwt) (Lwt_switch) (Conduit_lwt) (* Composition *) @@ -62,13 +57,11 @@ let run_with : type cfg s flow. 
(cfg, s, flow) Conduit_lwt.Service.t -> cfg -> string list -> unit = fun service cfg clients -> - let stop, server = server service cfg in + let stop = Lwt_switch.create () in + let server = server ~stop service cfg in let clients = List.map (client ~resolvers) clients in - let clients = - Lwt.join clients >>= fun () -> - Lwt_condition.broadcast stop () ; - Lwt.return_unit in - Lwt_main.run (Lwt.join [ server (); clients ]) + let clients = Lwt.join clients >>= fun () -> Lwt_switch.turn_off stop in + Lwt_main.run (Lwt.join [ server; clients ]) let run_with_tcp clients = run_with Conduit_lwt.TCP.service