Skip to content

Commit

Permalink
Merge pull request #443 from mirage/cancelation
Browse files Browse the repository at this point in the history
Add cancelation on tcpip.stack-socket
  • Loading branch information
dinosaure authored Mar 26, 2021
2 parents d3d6f40 + 509cb2c commit f0b0094
Showing 1 changed file with 35 additions and 24 deletions.
59 changes: 35 additions & 24 deletions src/stack-unix/tcpip_stack_socket.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ open Lwt.Infix
let src = Logs.Src.create "tcpip-stack-socket" ~doc:"Platform's native TCP/IP stack"
module Log = (val Logs.src_log src : Logs.LOG)

let ignore_canceled = function
| Lwt.Canceled -> Lwt.return_unit
| exn -> raise exn

module V4 = struct
module TCPV4 = Tcpv4_socket
module UDPV4 = Udpv4_socket
Expand All @@ -27,6 +31,8 @@ module V4 = struct
type t = {
udpv4 : UDPV4.t;
tcpv4 : TCPV4.t;
stop : unit Lwt.u;
switched_off : unit Lwt.t;
}

let udpv4 { udpv4; _ } = udpv4
Expand All @@ -44,7 +50,7 @@ module V4 = struct
UDPV4.get_udpv4_listening_fd t.udpv4 port >>= fun fd ->
let buf = Cstruct.create 4096 in
let rec loop () =
(* TODO cancellation *)
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) ->
let buf = Cstruct.sub buf 0 len in
Expand All @@ -59,7 +65,8 @@ module V4 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())
Lwt.catch loop ignore_canceled >>= fun () ->
Lwt_unix.close fd)

let listen_tcpv4 ?keepalive t ~port callback =
if port < 0 || port > 65535 then
Expand All @@ -73,6 +80,7 @@ module V4 = struct
Lwt.async (fun () ->
(* TODO cancellation *)
let rec loop () =
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_unix.accept fd >|= fun (afd, _) ->
(match keepalive with
Expand All @@ -91,17 +99,16 @@ module V4 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())
Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)

let listen _t =
let t, _ = Lwt.task () in
t (* TODO cancellation *)
let listen t = t.switched_off

let connect udpv4 tcpv4 =
Log.info (fun f -> f "IPv4 socket stack: connect");
Lwt.return { tcpv4; udpv4 }
let switched_off, stop = Lwt.wait () in
Lwt.return { tcpv4; udpv4; stop; switched_off; }

let disconnect _ = Lwt.return_unit
let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit
end

module V6 = struct
Expand All @@ -112,6 +119,8 @@ module V6 = struct
type t = {
udp : UDP.t;
tcp : TCP.t;
stop : unit Lwt.u;
switched_off : unit Lwt.t;
}

let udp { udp; _ } = udp
Expand All @@ -129,7 +138,7 @@ module V6 = struct
UDP.get_udpv6_listening_fd t.udp port >>= fun fd ->
let buf = Cstruct.create 4096 in
let rec loop () =
(* TODO cancellation *)
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) ->
let buf = Cstruct.sub buf 0 len in
Expand All @@ -144,7 +153,7 @@ module V6 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())
Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)

let listen_tcp ?keepalive t ~port callback =
if port < 0 || port > 65535 then
Expand All @@ -159,6 +168,7 @@ module V6 = struct
Lwt.async (fun () ->
(* TODO cancellation *)
let rec loop () =
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_unix.accept fd >|= fun (afd, _) ->
(match keepalive with
Expand All @@ -177,17 +187,16 @@ module V6 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())
Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)

let listen _t =
let t, _ = Lwt.task () in
t (* TODO cancellation *)
let listen t = t.switched_off

let connect udp tcp =
Log.info (fun f -> f "IPv6 socket stack: connect");
Lwt.return { tcp; udp }
let switched_off, stop = Lwt.wait () in
Lwt.return { tcp; udp; stop; switched_off; }

let disconnect _ = Lwt.return_unit
let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit
end

module V4V6 = struct
Expand All @@ -198,6 +207,8 @@ module V4V6 = struct
type t = {
udp : UDP.t;
tcp : TCP.t;
stop : unit Lwt.u;
switched_off : unit Lwt.t;
}

let udp { udp; _ } = udp
Expand All @@ -217,7 +228,7 @@ module V4V6 = struct
Lwt.async (fun () ->
let buf = Cstruct.create 4096 in
let rec loop () =
(* TODO cancellation *)
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) ->
let buf = Cstruct.sub buf 0 len in
Expand All @@ -232,7 +243,7 @@ module V4V6 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())) fds)
Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)) fds)

let listen_tcp ?keepalive t ~port callback =
if port < 0 || port > 65535 then
Expand Down Expand Up @@ -269,6 +280,7 @@ module V4V6 = struct
Lwt.async (fun () ->
(* TODO cancellation *)
let rec loop () =
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_unix.accept fd >|= fun (afd, _) ->
(match keepalive with
Expand All @@ -287,15 +299,14 @@ module V4V6 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())) fds
Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)) fds

let listen _t =
let t, _ = Lwt.task () in
t (* TODO cancellation *)
let listen t = t.switched_off

let connect udp tcp =
Log.info (fun f -> f "Dual IPv4 and IPv6 socket stack: connect");
Lwt.return { tcp; udp }
let switched_off, stop = Lwt.wait () in
Lwt.return { tcp; udp; stop; switched_off; }

let disconnect _ = Lwt.return_unit
let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit
end

0 comments on commit f0b0094

Please sign in to comment.