Skip to content

Commit

Permalink
Merge pull request #489 from TheLortex/simultaneous-close-fix
Browse files Browse the repository at this point in the history
Fixing memory leaks when a simultaneous close is happening
  • Loading branch information
dinosaure committed Jul 27, 2022
2 parents afa354f + aac0e02 commit efbebdf
Show file tree
Hide file tree
Showing 11 changed files with 311 additions and 167 deletions.
9 changes: 9 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
### v7.1.2 (2022-07-27)

* TCP: fix memory leaks on connection close in three scenarios (#489 @TheLortex)
- simultanous close: set up the timewait timer in the `Closing(1) - Recv_ack(2) -> Time_wait`
state transition
- client sends a RST instead of a FIN: enable sending a challenge ACK even when the reception
thread is stopped
- client doesn't ACK server's FIN: enable the retransmit timer in the `Closing(_)` state

### v7.1.1 (2022-05-24)

* Ndpv6: demote more logs to debug level (#480 @reynir)
Expand Down
12 changes: 7 additions & 5 deletions src/tcp/flow.ml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ module Log = (val Logs.src_log src : Logs.LOG)
module Make(Ip: Tcpip.Ip.S)(Time:Mirage_time.S)(Clock:Mirage_clock.MCLOCK)(Random:Mirage_random.S) =
struct

module RXS = Segment.Rx(Time)
module TXS = Segment.Tx(Time)(Clock)
module ACK = Ack.Immediate
module RXS = Segment.Rx(Time)(ACK)
module TXS = Segment.Tx(Time)(Clock)
module UTX = User_buffer.Tx(Time)(Clock)
module WIRE = Wire.Make(Ip)
module STATE = State.Make(Time)
Expand Down Expand Up @@ -75,6 +75,8 @@ struct
connects: (WIRE.t, ((connection, error) result Lwt.u * Sequence.t * Tcpip.Tcp.Keepalive.t option)) Hashtbl.t;
}

let num_open_channels t = Hashtbl.length t.channels

let listen t ~port ?keepalive cb =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port))
Expand Down Expand Up @@ -356,11 +358,11 @@ struct
let txq, _tx_t =
TXS.create ~xmit:(Tx.xmit_pcb t.ip id) ~wnd ~state ~rx_ack ~tx_ack ~tx_wnd_update
in
(* The user application transmit buffer *)
let utx = UTX.create ~wnd ~txq ~max_size:16384l in
let rxq = RXS.create ~rx_data ~wnd ~state ~tx_ack in
(* Set up ACK module *)
let ack = ACK.t ~send_ack ~last:(Sequence.succ rx_isn) in
(* The user application transmit buffer *)
let utx = UTX.create ~wnd ~txq ~max_size:16384l in
let rxq = RXS.create ~rx_data ~ack ~wnd ~state ~tx_ack in
(* Set up the keepalive state if requested *)
let keepalive = match keepalive with
| None -> None
Expand Down
6 changes: 6 additions & 0 deletions src/tcp/flow.mli
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,10 @@ module Make (IP:Tcpip.Ip.S)
(R:Mirage_random.S) : sig
include Tcpip.Tcp.S with type ipaddr = IP.ipaddr
val connect : IP.t -> t Lwt.t

(**/**)
(* the number of open connections *)
val num_open_channels : t -> int
(**/**)

end
14 changes: 6 additions & 8 deletions src/tcp/segment.ml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ let rec reset_seq segs =
It also looks for control messages and dispatches them to
the Rtx queue to ack messages or close channels.
*)
module Rx(Time:Mirage_time.S) = struct
module Rx(Time:Mirage_time.S)(ACK: Ack.M) = struct
open Tcp_packet
module StateTick = State.Make(Time)

Expand Down Expand Up @@ -82,14 +82,15 @@ module Rx(Time:Mirage_time.S) = struct
type t = {
mutable segs: S.t;
rx_data: (Cstruct.t list option * Sequence.t option) Lwt_mvar.t; (* User receive channel *)
ack: ACK.t;
tx_ack: (Sequence.t * int) Lwt_mvar.t; (* Acks of our transmitted segs *)
wnd: Window.t;
state: State.t;
}

let create ~rx_data ~wnd ~state ~tx_ack =
let create ~rx_data ~ack ~wnd ~state ~tx_ack =
let segs = S.empty in
{ segs; rx_data; tx_ack; wnd; state }
{ segs; rx_data; ack; tx_ack; wnd; state }

let pp fmt t =
let pp_v fmt seg =
Expand Down Expand Up @@ -133,10 +134,7 @@ module Rx(Time:Mirage_time.S) = struct

let send_challenge_ack q =
(* TODO: rfc5961 ACK Throttling *)
(* Is this the correct way trigger an ack? *)
if Lwt_mvar.is_empty q.rx_data
then Lwt_mvar.put q.rx_data (Some [], Some Sequence.zero)
else Lwt.return_unit
ACK.pushack q.ack Sequence.zero

(* Given an input segment, the window information, and a receive
queue, update the window, extract any ready segments into the
Expand Down Expand Up @@ -285,7 +283,7 @@ module Tx (Time:Mirage_time.S) (Clock:Mirage_clock.MCLOCK) = struct
let ontimer xmit st segs wnd seq =
match State.state st with
| State.Syn_rcvd _ | State.Established | State.Fin_wait_1 _
| State.Close_wait | State.Last_ack _ ->
| State.Close_wait | State.Closing _ | State.Last_ack _ ->
begin match peek_opt_l segs with
| None -> Lwt.return Tcptimer.Stoptimer
| Some rexmit_seg ->
Expand Down
3 changes: 2 additions & 1 deletion src/tcp/segment.mli
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
the Rtx queue to ack messages or close channels.
*)

module Rx (T:Mirage_time.S) : sig
module Rx (T:Mirage_time.S)(ACK:Ack.M) : sig

type segment = { header: Tcp_packet.t; payload: Cstruct.t }
(** Individual received TCP segment *)
Expand All @@ -38,6 +38,7 @@ module Rx (T:Mirage_time.S) : sig

val create:
rx_data:(Cstruct.t list option * Sequence.t option) Lwt_mvar.t ->
ack:ACK.t ->
wnd:Window.t ->
state:State.t ->
tx_ack:(Sequence.t * int) Lwt_mvar.t ->
Expand Down
13 changes: 9 additions & 4 deletions src/tcp/state.ml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ module Make(Time:Mirage_time.S) = struct
t.on_close ();
Lwt.return_unit

let transition_to_timewait t =
Lwt.async (fun () -> timewait t time_wait_time);
Time_wait

let tick t (i:action) =
let diffone x y = Sequence.succ y = x in
let tstr s (i:action) =
Expand Down Expand Up @@ -148,10 +152,11 @@ module Make(Time:Mirage_time.S) = struct
| Fin_wait_1 _, Recv_rst -> t.on_close (); Reset
| Fin_wait_2 i, Recv_ack _ -> Fin_wait_2 (i + 1)
| Fin_wait_2 _, Recv_rst -> t.on_close (); Reset
| Fin_wait_2 _, Recv_fin ->
Lwt.async (fun () -> timewait t time_wait_time);
Time_wait
| Closing a, Recv_ack b -> if diffone b a then Time_wait else Closing a
| Fin_wait_2 _, Recv_fin -> transition_to_timewait t
| Closing a, Recv_ack b ->
if diffone b a then
transition_to_timewait t
else Closing a
| Closing _, Timeout -> t.on_close (); Closed
| Closing _, Recv_rst -> t.on_close (); Reset
| Time_wait, Timeout -> t.on_close (); Closed
Expand Down
150 changes: 150 additions & 0 deletions test/low_level.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
open Lwt.Infix

(*
* Connects two stacks to the same backend.
* One is a complete v4 stack (the system under test, referred to as [sut]).
* The other gives us low level access to inject crafted TCP packets,
* and sends and receives crafted packets to check the [sut] behavior.
*)
module VNETIF_STACK = Vnetif_common.VNETIF_STACK(Vnetif_backends.Basic)

module Time = Vnetif_common.Time
module V = Vnetif.Make(Vnetif_backends.Basic)
module E = Ethernet.Make(V)
module A = Arp.Make(E)(Time)
module I = Static_ipv4.Make(Mirage_random_test)(Vnetif_common.Clock)(E)(A)
module Wire = Tcp.Wire
module WIRE = Wire.Make(I)
module Tcp_wire = Tcp.Tcp_wire
module Tcp_unmarshal = Tcp.Tcp_packet.Unmarshal
module Sequence = Tcp.Sequence

let sut_cidr = Ipaddr.V4.Prefix.of_string_exn "10.0.0.101/24"
let server_ip = Ipaddr.V4.of_string_exn "10.0.0.100"
let server_cidr = Ipaddr.V4.Prefix.make 24 server_ip
let gateway = Ipaddr.V4.of_string_exn "10.0.0.1"

let header_size = Ethernet.Packet.sizeof_ethernet



(* defaults when injecting packets *)
let options = []
let window = 5120

(* Helper functions *)
let reply_id_from ~src ~dst data =
let sport = Tcp_wire.get_tcp_src_port data in
let dport = Tcp_wire.get_tcp_dst_port data in
WIRE.v ~dst_port:sport ~dst:src ~src_port:dport ~src:dst

let ack_for data =
match Tcp_unmarshal.of_cstruct data with
| Error s -> Alcotest.fail ("attempting to ack data: " ^ s)
| Ok (packet, data) ->
let open Tcp.Tcp_packet in
let data_len =
Sequence.of_int ((Cstruct.length data) +
(if packet.fin then 1 else 0) +
(if packet.syn then 1 else 0)) in
let sequence = packet.sequence in
let ack_n = Sequence.(add sequence data_len) in
ack_n

let ack data =
Some(ack_for data)

let ack_in_future data off =
Some Sequence.(add (ack_for data) (of_int off))

let ack_from_past data off =
Some Sequence.(sub (ack_for data) (of_int off))

let fail_result_not_expected fail = function
| Error _err ->
fail "error not expected"
| Ok `Eof ->
fail "eof"
| Ok (`Data data) ->
Alcotest.fail (Format.asprintf "data not expected but received: %a"
Cstruct.hexdump_pp data)



let create_sut_stack backend =
VNETIF_STACK.create_stack ~cidr:sut_cidr ~gateway backend

let create_raw_stack backend =
V.connect backend >>= fun netif ->
E.connect netif >>= fun ethif ->
A.connect ethif >>= fun arpv4 ->
I.connect ~cidr:server_cidr ~gateway ethif arpv4 >>= fun ip ->
Lwt.return (netif, ethif, arpv4, ip)

type 'state fsm_result =
| Fsm_next of 'state
| Fsm_done
| Fsm_error of string

(* This could be moved to a common module and reused for other low level tcp tests *)

(* setups network and run a given sut and raw fsm *)
let run backend fsm sut () =
let initial_state, fsm_handler = fsm in
create_sut_stack backend >>= fun stackv4 ->
create_raw_stack backend >>= fun (netif, ethif, arp, rawip) ->
let error_mbox = Lwt_mvar.create_empty () in
let stream, pushf = Lwt_stream.create () in
Lwt.pick [
VNETIF_STACK.Stackv4.listen stackv4;

(* Consume TCP packets one by one, in sequence *)
let rec fsm_thread state =
Lwt_stream.next stream >>= fun (src, dst, data) ->
fsm_handler rawip state ~src ~dst data >>= function
| Fsm_next s ->
fsm_thread s
| Fsm_done ->
Lwt.return_unit
| Fsm_error err ->
Lwt_mvar.put error_mbox err >>= fun () ->
(* it will be terminated anyway when the error is picked up *)
fsm_thread state in

Lwt.async (fun () ->
(V.listen netif ~header_size
(E.input
~arpv4:(A.input arp)
~ipv4:(I.input
~tcp: (fun ~src ~dst data -> pushf (Some(src,dst,data)); Lwt.return_unit)
~udp:(fun ~src:_ ~dst:_ _data -> Lwt.return_unit)
~default:(fun ~proto ~src ~dst _data ->
Logs.debug (fun f -> f "default handler invoked for packet from %a to %a, protocol %d -- dropping" Ipaddr.V4.pp src Ipaddr.V4.pp dst proto); Lwt.return_unit)
rawip
)
~ipv6:(fun _buf ->
Logs.debug (fun f -> f "IPv6 packet -- dropping");
Lwt.return_unit)
ethif) ) >|= fun _ -> ());

(* Either both fsm and the sut terminates, or a timeout occurs, or one of the sut/fsm informs an error *)
Lwt.pick [
(Time.sleep_ns (Duration.of_sec 5) >>= fun () ->
Lwt.return_some "timed out");

(Lwt.join [
(fsm_thread initial_state);

(* time to let the other end connects to the network and listen.
* Otherwise initial syn might need to be repeated slowing down the test *)
(Time.sleep_ns (Duration.of_ms 100) >>= fun () ->
sut stackv4 (Lwt_mvar.put error_mbox) >>= fun _ ->
Time.sleep_ns (Duration.of_ms 100));
] >>= fun () -> Lwt.return_none);

(Lwt_mvar.take error_mbox >>= fun cause ->
Lwt.return_some cause);
] >|= function
| None -> ()
| Some err -> Alcotest.fail err
]
1 change: 1 addition & 0 deletions test/test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ let suite = [
"iperf" , Test_iperf.suite ;
"iperf_ipv6" , Test_iperf_ipv6.suite ;
"keepalive" , Test_keepalive.suite ;
"simultaneous_close", Test_simulatenous_close.suite
]

let run test () =
Expand Down
Loading

0 comments on commit efbebdf

Please sign in to comment.