Skip to content

Commit

Permalink
Avoid shutdown race condition on service startup
Browse files Browse the repository at this point in the history
Fix mirage#352.

The `lwt` and `async` service loops previously returned condition variables
that were used to signal when to shut down. This resulted in a race
condition: the shutdown signal could be broadcast before the server was
waiting on it.

Fixed by having the service loops take switches as configuration
instead.
  • Loading branch information
craigfe committed Dec 4, 2020
1 parent da9377d commit d7ff4f0
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 115 deletions.
66 changes: 33 additions & 33 deletions src/async/conduit_async.ml
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,43 @@ type ('a, 'b, 'c) service = ('a, 'b, 'c) Service.t
let serve :
type cfg t v.
?timeout:int ->
?stop:unit Async.Deferred.t ->
handler:(flow -> unit Async.Deferred.t) ->
(cfg, t, v) service ->
cfg ->
unit Async.Condition.t * (unit -> unit Async.Deferred.t) =
fun ?timeout ~handler service cfg ->
unit Async.Deferred.t =
fun ?timeout ?(stop = Async.Deferred.never ()) ~handler service cfg ->
let open Async in
let stop = Async.Condition.create () in
let main () =
Service.init service cfg >>= function
| Error err -> failwith "%a" Service.pp_error err
| Ok t -> (
let rec loop () =
let close = Async.Condition.wait stop >>| fun () -> Ok `Stop in
let accept =
Service.accept service t >>? fun flow ->
Async.(Deferred.ok (return (`Flow flow))) in
let events =
match timeout with
| None -> [ close; accept ]
| Some t ->
let t = Core.Time.Span.of_int_sec t in
let timeout = Async.after t >>| fun () -> Ok `Timeout in
[ close; accept; timeout ] in

Async.Deferred.any events >>= function
| Ok (`Flow flow) ->
Async.don't_wait_for (handler flow) ;
Async.Scheduler.yield () >>= fun () -> (loop [@tailcall]) ()
| Ok (`Stop | `Timeout) -> Service.stop service t
| Error err0 -> (
Service.stop service t >>= function
| Ok () -> Async.return (Error err0)
| Error _err1 -> Async.return (Error err0)) in
loop () >>= function
| Ok () -> Async.return ()
| Error err -> failwith "%a" Service.pp_error err) in
(stop, main)
let timeout =
match timeout with
| None -> Deferred.never
| Some t -> fun () -> after (Core.Time.Span.of_int_sec t) in
Service.init service cfg >>= function
| Error err -> failwith "%a" Service.pp_error err
| Ok t -> (
let rec loop () =
let accept = Service.accept service t in
Deferred.choose
[
choice accept (Result.map (fun f -> `Flow f));
choice (timeout ()) (fun () -> Ok `Timeout);
]
>>? function
| `Flow flow ->
don't_wait_for (handler flow) ;
Scheduler.yield () >>= loop
| `Timeout -> return (Ok `Timeout) in
let stop_result =
Deferred.choose
[ choice stop (fun () -> Ok `Stopped); choice (loop ()) (fun r -> r) ]
>>= function
| Ok (`Timeout | `Stopped) -> Service.stop service t
| Error _ as err0 -> (
Service.stop service t >>= function Ok () | Error _ -> return err0)
in
stop_result >>= function
| Ok () -> return ()
| Error err -> failwith "%a" Service.pp_error err)

let reader_and_writer_of_flow flow =
let open Async in
Expand Down
7 changes: 4 additions & 3 deletions src/async/conduit_async.mli
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@ type ('a, 'b, 'c) service = ('a, 'b, 'c) Service.t

val serve :
?timeout:int ->
?stop:unit Async.Deferred.t ->
handler:(flow -> unit Async.Deferred.t) ->
('cfg, 't, 'v) service ->
'cfg ->
unit Async.Condition.t * (unit -> unit Async.Deferred.t)
unit Async.Deferred.t
(** [serve ~handler t cfg] creates an infinite service loop from the given
configuration ['cfg]. It returns a {i promise} that is determined once the
loop has terminated; the optional [?stop] deferred can be used to stop the loop.
{[
let stop, loop = serve ~handler TCP.service cfg in
let loop = serve ~handler TCP.service cfg in
Async_unix.Signal.handle [ Core.Signal.int ] ~f:(fun _sig ->
Async.Condition.broadcast stop ()) ;
loop ()
loop
]} *)

val reader_and_writer_of_flow :
Expand Down
68 changes: 34 additions & 34 deletions src/lwt/conduit_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -53,44 +53,44 @@ let ( >>? ) = Lwt_result.bind
let serve :
type cfg service v.
?timeout:int ->
?stop:Lwt_switch.t ->
handler:(flow -> unit Lwt.t) ->
(cfg, service, v) Service.t ->
cfg ->
unit Lwt_condition.t * (unit -> unit Lwt.t) =
fun ?timeout ~handler service cfg ->
unit Lwt.t =
fun ?timeout ?stop ~handler service cfg ->
let open Lwt.Infix in
let stop = Lwt_condition.create () in
let main () =
Service.init service cfg >>= function
| Error err -> failwith "%a" Service.pp_error err
| Ok t -> (
let rec loop () =
let stop = Lwt_condition.wait stop >>= fun () -> Lwt.return_ok `Stop in
let accept =
Service.accept service t >>? fun flow -> Lwt.return_ok (`Flow flow)
in
let events =
match timeout with
| None -> [ stop; accept ]
| Some t ->
let timeout =
Lwt_unix.sleep (float_of_int t) >>= fun () ->
Lwt.return_ok `Timeout in
[ stop; accept; timeout ] in

Lwt.pick events >>= function
| Ok (`Flow flow) ->
Lwt.async (fun () -> handler flow) ;
Lwt.pause () >>= loop
| Ok (`Stop | `Timeout) -> Service.stop service t
| Error err0 -> (
Service.stop service t >>= function
| Ok () -> Lwt.return_error err0
| Error _err1 -> Lwt.return_error err0) in
loop () >>= function
| Ok () -> Lwt.return_unit
| Error err -> failwith "%a" Service.pp_error err) in
(stop, main)
let timeout () =
match timeout with
| None -> Lwt.wait () |> fst
| Some t -> Lwt_unix.sleep (float_of_int t) in
Service.init service cfg >>= function
| Error err -> failwith "%a" Service.pp_error err
| Ok t -> (
let switched_off =
let t, u = Lwt.wait () in
Lwt_switch.add_hook stop (fun () ->
Lwt.wakeup_later u (Ok `Stopped) ;
Lwt.return_unit) ;
t in
let rec loop () =
let accept =
Service.accept service t >>? fun flow -> Lwt.return_ok (`Flow flow)
in
Lwt.pick [ accept; (timeout () >|= fun () -> Ok `Timeout) ] >>? function
| `Flow flow ->
Lwt.async (fun () -> handler flow) ;
Lwt.pause () >>= loop
| `Timeout -> Lwt.return (Ok `Timeout) in
let stop_result =
Lwt.pick [ switched_off; loop () ] >>= function
| Ok (`Timeout | `Stopped) -> Service.stop service t
| Error _ as err0 -> (
Service.stop service t >>= function
| Ok () | Error _ -> Lwt.return err0) in
stop_result >>= function
| Ok () -> Lwt.return_unit
| Error err -> failwith "%a" Service.pp_error err)

module TCP = struct
open Lwt.Infix
Expand Down
25 changes: 13 additions & 12 deletions src/lwt/conduit_lwt.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,26 @@ type ('a, 'b, 'c) service = ('a, 'b, 'c) Service.t

val serve :
?timeout:int ->
?stop:Lwt_switch.t ->
handler:(flow -> unit Lwt.t) ->
('cfg, 'service, 'v) service ->
'cfg ->
unit Lwt_condition.t * (unit -> unit Lwt.t)
(** [serve ~handler service cfg] creates an usual infinite [service] loop from
the given configuration ['cfg]. It returns the {i promise} to launch the
loop and a condition variable to stop the loop.
unit Lwt.t
(** [serve ~handler service cfg] launches a [service] loop from the given
configuration ['cfg]. By default, the service loop runs indefinitely.
- If passed, [~stop] is a switch that terminates the service loop, for
example to limit execution time to 10 seconds:
{[
let stop, loop = serve ~handler TCP.service cfg in
let switch = Lwt_switch.create () in
let loop = serve ~stop:switch ~handler TCP.service cfg in
Lwt.both
( Lwt_unix.sleep 10. >>= fun () ->
Lwt_condition.broadcast stop () ;
Lwt.return () )
(loop ())
(Lwt_unix.sleep 10. >>= fun () -> Lwt_switch.turn_off switch)
loop
]}
In your example, we want to launch a server only for 10 seconds. To help the
user, the option [?timeout] allows us to wait less than [timeout] seconds. *)
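- If passed, [~timeout] specifies a maximum time, in seconds, to wait between
accepting connections; once it elapses without a new connection, the service
loop terminates and the service is stopped. *)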

module TCP : sig
(** Implementation of TCP protocol as a client.
Expand Down
20 changes: 10 additions & 10 deletions tests/ping-pong/common.ml
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
module type S = sig
include Conduit.S

type 'a condition
type switch

val serve :
?timeout:int ->
?stop:switch ->
handler:(flow -> unit io) ->
('cfg, 's, 'flow) Service.t ->
'cfg ->
unit condition * (unit -> unit io)
unit io
end

module type CONDITION = sig
type 'a t
module type SWITCH = sig
type t
end

module type IO = sig
Expand All @@ -25,10 +26,10 @@ let ( <.> ) f g x = f (g x)

module Make
(IO : IO)
(Condition : CONDITION)
(Switch : SWITCH)
(Conduit : S
with type +'a io = 'a IO.t
and type 'a condition = 'a Condition.t
and type switch := Switch.t
and type input = Cstruct.t
and type output = Cstruct.t) =
struct
Expand Down Expand Up @@ -112,10 +113,9 @@ struct

let server :
type cfg s.
(cfg, s, 'flow) Conduit.Service.t ->
cfg ->
unit Condition.t * (unit -> unit IO.t) =
fun service cfg -> Conduit.serve ~handler:transmission service cfg
?stop:Switch.t -> (cfg, s, 'flow) Conduit.Service.t -> cfg -> unit IO.t =
fun ?stop service cfg ->
Conduit.serve ?stop ~handler:transmission service cfg

(* part *)

Expand Down
24 changes: 13 additions & 11 deletions tests/ping-pong/with_async.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ open Async

let () = Mirage_crypto_rng_unix.initialize ()

(* Stop-signal type for the Async ping-pong tests: a plain [unit Deferred.t].
   This module is passed as the [Switch] argument of [Common.Make], so the
   [?stop] argument of the generated [server] function has this type. *)
module Stop = struct
type t = unit Deferred.t
end

include Common.Make
(struct
type +'a t = 'a Async.Deferred.t
Expand All @@ -13,12 +17,8 @@ include Common.Make

let yield () = Async.Deferred.return ()
end)
(Async.Condition)
(struct
type 'a condition = 'a Async.Condition.t

include Conduit_async
end)
(Stop)
(Conduit_async)

let tcp_protocol, tcp_service =
let open Conduit_async.TCP in
Expand Down Expand Up @@ -49,19 +49,21 @@ let run_with :
type cfg service flow.
(cfg, service, flow) Conduit_async.Service.t -> cfg -> string list -> unit =
fun service cfg clients ->
let stop, server = server (* ~launched ~stop *) service cfg in
let stop, signal_stop =
let open Async.Ivar in
let v = create () in
(read v, fill v) in
let server = server (* ~launched *) ~stop service cfg in
let clients =
Async.after Core.Time.Span.(of_sec 0.5) >>= fun () ->
(* XXX(dinosaure): [async] tries to go further and fibers
* can be launched before the initialization of the server.
* We wait a bit to ensure that the server is launched
* before clients. *)
let clients = List.map (client ~resolvers) clients in
Async.Deferred.all_unit clients >>= fun () ->
Condition.broadcast stop () ;
Async.return () in
Async.Deferred.all_unit clients >>| signal_stop in
Async.don't_wait_for
(Async.Deferred.all_unit [ server (); clients ] >>| fun () -> shutdown 0) ;
(Async.Deferred.all_unit [ server; clients ] >>| fun () -> shutdown 0) ;
Core.never_returns (Scheduler.go ())

let run_with_tcp clients =
Expand Down
17 changes: 5 additions & 12 deletions tests/ping-pong/with_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,7 @@ module Lwt = struct
let yield = Lwt_unix.yield
end

include Common.Make (Lwt) (Lwt_condition)
(struct
type 'a condition = 'a Lwt_condition.t

include Conduit_lwt
end)
include Common.Make (Lwt) (Lwt_switch) (Conduit_lwt)

(* Composition *)

Expand Down Expand Up @@ -62,13 +57,11 @@ let run_with :
type cfg s flow.
(cfg, s, flow) Conduit_lwt.Service.t -> cfg -> string list -> unit =
fun service cfg clients ->
let stop, server = server service cfg in
let stop = Lwt_switch.create () in
let server = server ~stop service cfg in
let clients = List.map (client ~resolvers) clients in
let clients =
Lwt.join clients >>= fun () ->
Lwt_condition.broadcast stop () ;
Lwt.return_unit in
Lwt_main.run (Lwt.join [ server (); clients ])
let clients = Lwt.join clients >>= fun () -> Lwt_switch.turn_off stop in
Lwt_main.run (Lwt.join [ server; clients ])

let run_with_tcp clients =
run_with Conduit_lwt.TCP.service
Expand Down

0 comments on commit d7ff4f0

Please sign in to comment.