Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a way to set context mode when creating context #134

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,3 @@ jobs:
extra-trusted-public-keys = ocaml.nix-cache.com-1:/xI2h2+56rwFfKyyFVbkJSeGqSIYMC/Je+7XXqGKDIY=
- name: "Run nix-build"
run: nix-build ./nix/ci/test.nix --argstr ocamlVersion ${{ matrix.setup.ocamlVersion }}

45 changes: 41 additions & 4 deletions src/ssl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ type verify_error =
| Error_v_keyusage_no_certsign
| Error_v_application_verification

type bigarray =
(char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t

external get_error_string : unit -> string = "ocaml_ssl_get_error_string"
(** Kept for backwards compatibility *)

Expand Down Expand Up @@ -211,9 +208,35 @@ type context_type =
| Server_context
| Both_context

external create_context :
module Modes = struct
type t = int

(* value taken from openssl/ssl.h *)
let no_mode = 0x000
let enable_partial_write = 0x001 (* SSL_MODE_ENABLE_PARTIAL_WRITE *)
(*let accept_moving_write_buffer = 0x002: is always set because of GC*)
let auto_retry = 0x004 (* SSL_MODE_AUTO_RETRY *)
let no_auto_chain = 0x008 (* SSL_MODE_RELEASE_BUFFERS *)
let release_buffers = 0x010 (* SSL_MODE_RELEASE_BUFFERS *)
let send_clienthello_time = 0x020 (* SSL_MODE_SEND_CLIENTHELLO_TIME *)
let send_serverhello_time = 0x040 (* SSL_MODE_SEND_SERVERHELLO_TIME *)
let send_fallback_scsv = 0x080 (* SSL_MODE_SEND_FALLBACK_SCSV *)
let async = 0x100 (* SSL_MODE_ASYNC *)

let (lor) = (lor)
let (land) = (land)
let lnot = lnot
let subset a b = a land (lnot b) = no_mode
end

external set_mode : context -> Modes.t -> unit = "ocaml_ssl_set_mode"
external clear_mode : context -> Modes.t -> unit = "ocaml_ssl_clear_mode"
external get_mode : context -> Modes.t = "ocaml_ssl_get_mode"

external raw_create_context :
protocol
-> context_type
-> Modes.t
-> context
= "ocaml_ssl_create_context"

Expand Down Expand Up @@ -454,9 +477,13 @@ external set_hostflags :
external set_host : socket -> string -> unit = "ocaml_ssl_set1_host"
external set_ip : socket -> string -> unit = "ocaml_ssl_set1_ip"

type bigarray =
(char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t

(* Here is the signature of the base communication functions that are
implemented below in two versions *)
module type Ssl_base = sig
val create_context : ?modes:Modes.t -> protocol -> context_type -> context
val connect : socket -> unit
val accept : socket -> unit
val ssl_shutdown : socket -> bool
Expand All @@ -471,6 +498,9 @@ end
(* Provide the base implementation communication functions that release the
OCaml runtime lock, allowing multiple systhreads to execute concurrently. *)
module Runtime_unlock_base = struct
let create_context ?(modes = Modes.auto_retry) protocol ctype =
raw_create_context protocol ctype modes

external connect : socket -> unit = "ocaml_ssl_connect"
external accept : socket -> unit = "ocaml_ssl_accept"
external write : socket -> Bytes.t -> int -> int -> int = "ocaml_ssl_write"
Expand Down Expand Up @@ -507,6 +537,10 @@ end

(* Same as above, but doesn't release the lock. *)
module Runtime_lock_base = struct
let create_context ?(modes = Modes.(async lor enable_partial_write))
protocol ctype =
raw_create_context protocol ctype modes

external get_error : socket -> int -> ssl_error = "ocaml_ssl_get_error_code"
[@@noalloc]

Expand Down Expand Up @@ -559,6 +593,9 @@ module Runtime_lock_base = struct
= "ocaml_ssl_write_blocking"
[@@noalloc]

(** Allow SSL_write(..., n) to return r with 0 < r < n (i.e. report success
when just a single record has been written *)

let write socket buffer start length =
if start < 0 then invalid_arg "Ssl.write: start negative";
if length < 0 then invalid_arg "Ssl.write: length negative";
Expand Down
73 changes: 68 additions & 5 deletions src/ssl.mli
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ type ssl_error =
(** See
https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_verify.html *)

type bigarray =
(char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t

exception Method_error
(** The SSL method could not be initialized. *)

Expand Down Expand Up @@ -311,8 +308,67 @@ type context_type =
| Server_context (** Server connections. *)
| Both_context (** Client and server connections. *)

val create_context : protocol -> context_type -> context
(** Create a context. *)
module Modes : sig
(** set of mode *)
type t = int

val no_mode : t

(** Allow SSL_write(..., n) to return r with 0 < r < n (i.e. report success
when just a single record has been written *)
val enable_partial_write : t

(** Never bother the application with retries if the transport is blocking *)
val auto_retry : t

(** Don't attempt to automatically build certificate chain *)
val no_auto_chain : t

(** Save RAM by releasing read and write buffers when they're empty. (SSL3 and
TLS only.) Released buffers are freed. *)
val release_buffers : t

(** Send the current time in the Random fields of the ClientHello and
ServerHello records for compatibility with hypothetical implementations
that require it. *)
val send_clienthello_time : t
val send_serverhello_time : t

(** Send TLS_FALLBACK_SCSV in the ClientHello. To be set only by
applications that reconnect with a downgraded protocol version; see
draft-ietf-tls-downgrade-scsv-00 for details. DO NOT ENABLE THIS if your
application attempts a normal handshake. Only use this in explicit
fallback retries, following the guidance in
draft-ietf-tls-downgrade-scsv-00. *)
val send_fallback_scsv : t

(** Support Asynchronous operation *)
val async : t

(** put togther two sets of modes *)
val ( lor ) : t -> t -> t

(** conjunction of modes *)
val ( land ) : t -> t -> t

(** negation of modes *)
val lnot : t -> t

(** subset on modes*)
val subset : t -> t -> bool
end

(** Set the given modes in a context (does not clear preset modes) *)
val set_mode : context -> Modes.t -> unit

(** Clear the given modes in a context *)
val clear_mode : context -> Modes.t -> unit

(** Get the current mode of a context *)
val get_mode : context -> Modes.t

val create_context : ?modes:Modes.t -> protocol -> context_type -> context
(** Create a context. Default modes is Modes.(auto_retry) *)

val set_min_protocol_version : context -> protocol -> unit
(** [set_min_protocol_version ctx proto] sets the minimum supported protocol
Expand Down Expand Up @@ -571,6 +627,9 @@ val flush : socket -> unit
val read : socket -> Bytes.t -> int -> int -> int
(** [read sock buf off len] receives data from a connected SSL socket. *)

type bigarray =
(char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t

val read_into_bigarray : socket -> bigarray -> int -> int -> int
(** [read_into_bigarray sock ba off len] receives data from a connected SSL
socket. This function releases the runtime while the read takes place. *)
Expand Down Expand Up @@ -614,6 +673,10 @@ val output_int : socket -> int -> unit
i.e. handling of `EWOULDBLOCK`, `EGAIN`, etc. Additionally, the functions in
this module don't perform a copy of application data buffers. *)
module Runtime_lock : sig
val create_context : ?modes:Modes.t -> protocol -> context_type -> context
(** same as create_context above, but the default modes are
[Modes.(async lor enable_partial_write] *)

val connect : socket -> unit
(** Connect an SSL socket. *)

Expand Down
26 changes: 23 additions & 3 deletions src/ssl_stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -540,8 +540,28 @@ static void set_protocol(SSL_CTX *ssl_context, int protocol) {
}
}

CAMLprim value ocaml_ssl_create_context(value protocol, value type) {
CAMLparam2(protocol, type);
CAMLprim void ocaml_ssl_set_mode(value ctx, value modes) {
CAMLparam1(ctx);
SSL_CTX_set_mode(Ctx_val(ctx),
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | Int_val(modes));
CAMLreturn0;
}

CAMLprim void ocaml_ssl_clear_mode(value ctx, value modes) {
CAMLparam1(ctx);
SSL_CTX_clear_mode(Ctx_val(ctx),
~SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER & Int_val(modes));
CAMLreturn0;
}

CAMLprim value ocaml_ssl_get_mode(value ctx, value modes) {
CAMLparam1(ctx);
long r = SSL_CTX_get_mode(Ctx_val(ctx));
CAMLreturn(Val_int(~SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER & r));
}

CAMLprim value ocaml_ssl_create_context(value protocol, value type, value modes) {
CAMLparam3(protocol, type, modes);
CAMLlocal1(block);
SSL_CTX *ctx;
const SSL_METHOD *method = get_method(Int_val(type));
Expand All @@ -558,7 +578,7 @@ CAMLprim value ocaml_ssl_create_context(value protocol, value type) {
a write retry (since the GC may need to move it). In blocking
mode, hide SSL_ERROR_WANT_(READ|WRITE) from us. */
SSL_CTX_set_mode(ctx,
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_AUTO_RETRY);
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | Int_val(modes));
caml_acquire_runtime_system();

block = caml_alloc_custom(&ctx_ops, sizeof(SSL_CTX *), 0, 1);
Expand Down
11 changes: 11 additions & 0 deletions tests/dune
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
(modules util)
(libraries ssl threads str alcotest))

(library
(name util_rlock)
(modules util_rlock)
(libraries ssl threads str alcotest))

(test
(name ssl_test)
(modules ssl_test)
Expand Down Expand Up @@ -43,3 +48,9 @@
(modules ssl_io)
(libraries ssl alcotest util)
(deps ca.pem ca.key server.key server.pem))

(test
(name ssl_rlock_io)
(modules ssl_rlock_io)
(libraries ssl alcotest util_rlock)
(deps ca.pem ca.key server.key server.pem))
126 changes: 126 additions & 0 deletions tests/ssl_rlock_io.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
open Alcotest

module Ssl = struct
include Ssl
include Ssl.Runtime_lock
end

module Util = Util_rlock

let test_verify () =
let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 2342) in
Util.server_thread addr None |> ignore;

let context = Ssl.create_context TLSv1_3 Client_context in
let ssl = Ssl.open_connection_with_context context addr in
let verify_result =
try
Ssl.verify ssl;
""
with
| e -> Printexc.to_string e
in
let rec fn () =
try
Ssl.shutdown_connection ssl;
with
Ssl.(Connection_error(Error_want_write|Error_want_read|
Error_want_accept|Error_want_connect|Error_zero_return)) ->
fn ()
in
fn ();
check
bool
"no verify errors"
true
(Str.search_forward
(Str.regexp_string "error:00:000000:lib(0)")
verify_result
0
> 0)

let test_set_host () =
let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 2343) in
let pid = Util.server_thread addr None in

let context = Ssl.create_context TLSv1_3 Client_context in
let domain = Unix.domain_of_sockaddr addr in
let sock = Unix.socket domain Unix.SOCK_STREAM 0 in
let ssl = Ssl.embed_socket sock context in
Ssl.set_host ssl "localhost";
Unix.connect sock addr;
Unix.set_nonblock sock;
let rec fn () =
try
Ssl.connect ssl;
with
Ssl.(Connection_error(Error_want_write|Error_want_read|
Error_want_accept|Error_want_connect|Error_zero_return)) ->
fn ()
in fn ();

let verify_result =
try
Ssl.verify ssl;
""
with
| e -> Printexc.to_string e
in
let rec fn () =
try
Ssl.shutdown_connection ssl;
with
Ssl.(Connection_error(Error_want_write|Error_want_read|
Error_want_accept|Error_want_connect|Error_zero_return)) ->
fn ()
in
fn ();
check
bool
"no verify errors"
true
(Str.search_forward
(Str.regexp_string "error:00:000000:lib(0)")
verify_result
0
> 0);
Unix.kill pid Sys.sigint;
Unix.waitpid [] pid |> ignore


let test_read_write () =
let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 2344) in
let pid = Util.server_thread addr (Some (fun _ -> "received")) in

let context = Ssl.create_context TLSv1_3 Client_context in
let ssl = Ssl.open_connection_with_context context addr in
Unix.set_nonblock (Ssl.file_descr_of_socket ssl);
let send_msg = "send" in
let write_buf = Bytes.create (String.length send_msg) in
let rec fn () =
try Ssl.write ssl write_buf 0 4 |> ignore;
with Ssl.(Write_error(Error_want_write|Error_want_read|
Error_want_accept|Error_want_connect|Error_zero_return)) ->
fn ()
in fn ();
let read_buf = Bytes.create 8 in
let rec fn () =
try Ssl.read ssl read_buf 0 8 |> ignore;
with Ssl.(Read_error(Error_want_write|Error_want_read|
Error_want_accept|Error_want_connect|Error_zero_return)) ->
fn ()
in fn ();
Ssl.shutdown_connection ssl;
check string "received message" "received" (Bytes.to_string read_buf);
Unix.kill pid Sys.sigint;
Unix.waitpid [] pid |> ignore

let () =
run
"Ssl io functions with Ssl.Runtime_lock and non blocking socket"
[ ( "IO"
, [ test_case "Verify" `Quick test_verify
; test_case "Set host" `Quick test_set_host
; test_case "Read write" `Quick test_read_write
] )
]
Loading