Skip to content

Commit

Permalink
Merge pull request #4829 from psafont/private/paus/migrate-receive
Browse files Browse the repository at this point in the history
  • Loading branch information
psafont authored Oct 28, 2022
2 parents e3511f7 + b57f70c commit 406251e
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 102 deletions.
174 changes: 95 additions & 79 deletions ocaml/idl/ocaml_backend/gen_server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,8 @@ let debug msg args =
""

let has_default_args args =
let arg_has_default arg =
match arg.DT.param_default with None -> false | Some _ -> true
in
let any_defaults =
List.fold_left (fun e x -> e || x) false (List.map arg_has_default args)
in
any_defaults
let has_default arg = Option.is_some arg.DT.param_default in
List.exists has_default args

(* ------------------------------------------------------------------------------------------
Code to generate a single operation in server dispatcher
Expand All @@ -76,11 +71,8 @@ let count_mandatory_message_parameters (msg : message) =

let operation (obj : obj) (x : message) =
let msg_params = x.DT.msg_params in
let msg_params_with_default_values =
List.filter (fun p -> p.DT.param_default <> None) msg_params
in
let msg_params_without_default_values =
List.filter (fun p -> p.DT.param_default = None) msg_params
let msg_params_with_default_values, msg_params_without_default_values =
List.partition (fun p -> p.DT.param_default <> None) msg_params
in
let msg_without_default_values =
{x with DT.msg_params= msg_params_without_default_values}
Expand Down Expand Up @@ -119,41 +111,39 @@ let operation (obj : obj) (x : message) =
else
List.map O.string_of_param args_without_default_values
in
let string_args =
List.map (fun s -> Printf.sprintf "%s_rpc" s) orig_string_args
in
let to_rpc_name s = Printf.sprintf "%s_rpc" s in
let string_args = List.map to_rpc_name orig_string_args in
let is_non_constructor_with_defaults =
(not is_ctor) && has_default_args x.DT.msg_params
in
let arg_pattern = String.concat "::" string_args in
let arg_pattern =
if is_non_constructor_with_defaults then
arg_pattern ^ "::default_args"
String.concat " :: " (string_args @ ["default_args"])
else
arg_pattern ^ "::[]"
Printf.sprintf "[%s]" (String.concat "; " string_args)
in
let name_pattern_match =
Printf.sprintf "| \"%s\" | \"%s\" -> " wire_name alternative_wire_name
Printf.sprintf {|| "%s" | "%s" -> |} wire_name alternative_wire_name
in
(* Lookup the various fields from the constructor record *)
let from_ctor_record =
let fields = Client.ctor_fields obj in
let of_field f =
let binding = O.string_of_param (Client.param_of_field f) in
let converter = Printf.sprintf "%s_of_rpc" (OU.alias_of_ty f.DT.ty) in
let wire_name = Printf.sprintf {|"%s"|} (DU.wire_name_of_field f) in
let lookup_expr =
match f.DT.default_value with
| None ->
Printf.sprintf "(my_assoc \"%s\" __structure)"
(DU.wire_name_of_field f)
Printf.sprintf "(my_assoc %s __structure)" wire_name
| Some default ->
Printf.sprintf
"(if (List.mem_assoc \"%s\" __structure) then (my_assoc \"%s\" \
__structure) else %s)"
(DU.wire_name_of_field f) (DU.wire_name_of_field f)
"((List.assoc_opt %s __structure) |> Option.value ~default:(%s))"
wire_name
(Datamodel_values.to_ocaml_string default)
in
Printf.sprintf " let %s = %s %s in" binding converter lookup_expr
Printf.sprintf " let %s = %s %s in" binding converter
lookup_expr
in
String.concat "\n"
("let __structure = match __structure_rpc with Dict d -> d | _ -> \
Expand Down Expand Up @@ -197,44 +187,86 @@ let operation (obj : obj) (x : message) =
]
in
(* Generate the unmarshalling code *)
let rec add_counts i l =
match l with [] -> [] | x :: xs -> (i, x) :: add_counts (i + 1) xs
in
let has_session_arg =
if is_ctor then
is_session_arg Client.session
else
List.exists (fun a -> is_session_arg a) args_without_default_values
in
let name_default_params, unmarshall_default_params =
List.mapi
(fun param_count default_param ->
let param_name = OU.ocaml_of_record_name default_param.DT.param_name in
let param_type = OU.alias_of_ty default_param.DT.param_type in
let param_rpc = to_rpc_name param_name in
let try_and_get_default =
Printf.sprintf "List.nth default_args %d" param_count
in
let default_value =
match default_param.DT.param_default with
| None ->
"** EXPECTED DEFAULT VALUE IN THIS PARAM **"
| Some default ->
Datamodel_values.to_ocaml_string default
in
( Printf.sprintf "let %s = try %s with _ -> %s in" param_rpc
try_and_get_default default_value
, Printf.sprintf "let %s = %s_of_rpc %s in" param_name param_type
param_rpc
)
)
msg_params_with_default_values
|> List.split
in
let rbac_check_begin =
if has_session_arg then
let serialize_list lst =
String.concat "; " lst |> Printf.sprintf "[%s]"
in
let serialize_name_list lst =
List.map (Printf.sprintf {|"%s"|}) lst |> serialize_list
in
let serialize_args args =
List.map (fun (n, v) -> Printf.sprintf {|("%s", %s)|} n v) args
|> serialize_list
in
let default_arg_name_params =
if is_non_constructor_with_defaults then
List.map (fun dp -> dp.DT.param_name) msg_params_with_default_values
else
[]
in
let arg_names = orig_string_args @ default_arg_name_params in
let default_arg_values =
List.map
(fun dp -> to_rpc_name (OU.ocaml_of_record_name dp.DT.param_name))
msg_params_with_default_values
in
let arg_values = string_args @ default_arg_values in
let args =
try List.combine arg_names arg_values
with Invalid_argument _ ->
let msg =
Printf.sprintf
"Cannot serialize call %s.%s: number of arguments doesn't match \
with the number of names for it; in %s"
(OU.ocaml_of_obj_name obj.DT.name)
x.msg_name __LOC__
in
failwith msg
in
let key_names = List.map fst x.msg_map_keys_roles in
[
"let arg_names = "
^ List.fold_right
(fun arg args -> "\"" ^ arg ^ "\"::" ^ args)
orig_string_args
( if is_non_constructor_with_defaults then
List.fold_right
(fun dp ss -> "\"" ^ dp.DT.param_name ^ "\"::" ^ ss)
msg_params_with_default_values ""
^ "[]"
else
"[]"
)
^ " in"
; "let key_names = "
^ List.fold_right
(fun arg args -> "\"" ^ arg ^ "\"::" ^ args)
(List.map (fun (k, _) -> k) x.msg_map_keys_roles)
"[]"
^ " in"
Printf.sprintf "let arg_names_values = %s in" (serialize_args args)
; (* This incurs a runtime cost *)
"let arg_names, arg_values = List.split arg_names_values in"
; Printf.sprintf "let key_names = %s in" (serialize_name_list key_names)
; "let rbac __context fn = Rbac.check session_id __call \
~args:(arg_names,__params) ~keys:key_names ~__context ~fn in"
~args:(arg_names, arg_values) ~keys:key_names ~__context ~fn in"
]
else
["let rbac __context fn = fn() in"]
["let rbac __context fn = fn () in"]
in
let rbac_check_end = if has_session_arg then [] else [] in
let unmarshall_code =
(* If we are forwarding the call then we don't want to emit a warning
because we know we don't need the arguments *)
Expand All @@ -252,26 +284,7 @@ let operation (obj : obj) (x : message) =
args_without_default_values
)
(* and for every default value we try to get this from default_args or default it *)
@ List.map
(fun (param_count, default_param) ->
let param_name =
OU.ocaml_of_record_name default_param.DT.param_name
in
let param_type = OU.alias_of_ty default_param.DT.param_type in
let try_and_get_default =
Printf.sprintf "Server_helpers.nth %d default_args" param_count
in
let default_value =
match default_param.DT.param_default with
| None ->
"** EXPECTED DEFAULT VALUE IN THIS PARAM **"
| Some default ->
Datamodel_values.to_ocaml_string default
in
Printf.sprintf "let %s = %s_of_rpc (try %s with _ -> %s) in"
param_name param_type try_and_get_default default_value
)
(add_counts 1 msg_params_with_default_values)
@ unmarshall_default_params
in
let may_be_side_effecting msg =
match msg.msg_tag with
Expand Down Expand Up @@ -304,7 +317,7 @@ let operation (obj : obj) (x : message) =
"let host = ref_host_of_rpc host_rpc in"
; "let call_string = Jsonrpc.string_of_call {call with name=__call} in"
; "let marshaller = (fun x -> x) in"
; "let local_op = fun ~__context ->(rbac __context \
; "let local_op = fun ~__context -> (rbac __context \
(fun()->(Custom.Host.call_extension \
~__context:(Context.check_for_foreign_database ~__context) ~host \
~call:call_string))) in"
Expand Down Expand Up @@ -396,11 +409,11 @@ let operation (obj : obj) (x : message) =
let all_list =
if not (DU.has_been_removed x.DT.msg_lifecycle) then
comments
@ name_default_params
@ unmarshall_code
@ session_check_exp
@ rbac_check_begin
@ gen_body ()
@ rbac_check_end
else
comments
@ ["let session_id = ref_session_of_rpc session_id_rpc in"]
Expand Down Expand Up @@ -477,12 +490,10 @@ let gen_module api : O.Module.t =
http_req.Http.Request.subtask_of in"
; "let http_other_config = Context.get_http_other_config \
http_req in"
; "let may f = function | None -> None | Some x -> Some (f x) \
in"
; "Server_helpers.exec_with_new_task \
(\"dispatch:\"^__call^\"\") ~http_other_config \
?subtask_of:(may Ref.of_string subtask_of) (fun __context \
->"
?subtask_of:(Option.map Ref.of_string subtask_of) (fun \
__context ->"
; "Server_helpers.dispatch_exn_wrapper (fun () -> (match \
__call with "
]
Expand Down Expand Up @@ -518,15 +529,20 @@ let gen_module api : O.Module.t =
; " then "
^ debug "This is not a built-in rpc \"%s\"" ["__call"]
; " begin match __params with"
; " | session_id_rpc::_->"
; " | session_id_rpc :: _->"
; " let session_id = ref_session_of_rpc session_id_rpc in"
; " Session_check.check ~intra_pool_only:false \
~session_id;"
; " (* based on the Host.call_extension call *)"
; " let arg_names = \"session_id\"::__call::[] in"
; " let call_rpc = Rpc.String __call in "
; " let arg_names, arg_values ="
; " [(\"session_id\", session_id_rpc); (__call, \
call_rpc)]"
; " |> List.split"
; " in"
; " let key_names = [] in"
; " let rbac __context fn = Rbac.check session_id \
\"Host.call_extension\" ~args:(arg_names,__params) \
\"Host.call_extension\" ~args:(arg_names, arg_values) \
~keys:key_names ~__context ~fn in"
; " Server_helpers.forward_extension ~__context rbac { \
call with Rpc.name = __call }"
Expand Down
31 changes: 19 additions & 12 deletions ocaml/idl/ocaml_backend/ocaml_syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,25 @@ module Let = struct
| Named (name, _) ->
"~" ^ name
in
[
Line ("(** " ^ x.doc ^ " *)")
; Line
(prefix
^ " "
^ x.name
^ " "
^ String.concat " " (List.map param x.params)
^ " ="
)
; Indent (List.map (fun x -> Line x) x.body)
]
let doclines =
if x.doc <> "" then
[Line ("(** " ^ x.doc ^ " *)")]
else
[]
in

doclines
@ [
Line
(prefix
^ " "
^ x.name
^ " "
^ String.concat " " (List.map param x.params)
^ " ="
)
; Indent (List.map (fun x -> Line x) x.body)
]
end

module Type = struct
Expand Down
9 changes: 0 additions & 9 deletions ocaml/xapi/server_helpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,6 @@ let my_assoc fld assoc_list =
try List.assoc fld assoc_list
with Not_found -> raise (Dispatcher_FieldNotFound fld)

exception Nth (* should never be thrown externally *)

let rec nth n l =
match l with
| [] ->
raise Nth
| x :: xs ->
if n = 1 then x else nth (n - 1) xs

let async_wire_name = "Async."

let async_length = String.length async_wire_name
Expand Down
2 changes: 0 additions & 2 deletions ocaml/xapi/server_helpers.mli
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ val exec_with_subtask :
(* used by auto-generated code in server.ml *)
val my_assoc : string -> (string * 'a) list -> 'a

val nth : int -> 'a list -> 'a

val sync_ty_and_maybe_remove_prefix :
string -> [> `Async | `InternalAsync | `Sync] * string

Expand Down

0 comments on commit 406251e

Please sign in to comment.