diff --git a/ocaml/idl/ocaml_backend/gen_server.ml b/ocaml/idl/ocaml_backend/gen_server.ml index 4e4fbe174b6..23c43f76380 100644 --- a/ocaml/idl/ocaml_backend/gen_server.ml +++ b/ocaml/idl/ocaml_backend/gen_server.ml @@ -50,13 +50,8 @@ let debug msg args = "" let has_default_args args = - let arg_has_default arg = - match arg.DT.param_default with None -> false | Some _ -> true - in - let any_defaults = - List.fold_left (fun e x -> e || x) false (List.map arg_has_default args) - in - any_defaults + let has_default arg = Option.is_some arg.DT.param_default in + List.exists has_default args (* ------------------------------------------------------------------------------------------ Code to generate a single operation in server dispatcher @@ -76,11 +71,8 @@ let count_mandatory_message_parameters (msg : message) = let operation (obj : obj) (x : message) = let msg_params = x.DT.msg_params in - let msg_params_with_default_values = - List.filter (fun p -> p.DT.param_default <> None) msg_params - in - let msg_params_without_default_values = - List.filter (fun p -> p.DT.param_default = None) msg_params + let msg_params_with_default_values, msg_params_without_default_values = + List.partition (fun p -> p.DT.param_default <> None) msg_params in let msg_without_default_values = {x with DT.msg_params= msg_params_without_default_values} @@ -119,21 +111,19 @@ let operation (obj : obj) (x : message) = else List.map O.string_of_param args_without_default_values in - let string_args = - List.map (fun s -> Printf.sprintf "%s_rpc" s) orig_string_args - in + let to_rpc_name s = Printf.sprintf "%s_rpc" s in + let string_args = List.map to_rpc_name orig_string_args in let is_non_constructor_with_defaults = (not is_ctor) && has_default_args x.DT.msg_params in - let arg_pattern = String.concat "::" string_args in let arg_pattern = if is_non_constructor_with_defaults then - arg_pattern ^ "::default_args" + String.concat " :: " (string_args @ ["default_args"]) else - arg_pattern ^ "::[]" + Printf.sprintf "[%s]" (String.concat "; " string_args) in let name_pattern_match = - Printf.sprintf "| \"%s\" | \"%s\" -> " wire_name alternative_wire_name + Printf.sprintf {|| "%s" | "%s" -> |} wire_name alternative_wire_name in (* Lookup the various fields from the constructor record *) let from_ctor_record = @@ -141,19 +131,19 @@ let operation (obj : obj) (x : message) = let of_field f = let binding = O.string_of_param (Client.param_of_field f) in let converter = Printf.sprintf "%s_of_rpc" (OU.alias_of_ty f.DT.ty) in + let wire_name = Printf.sprintf {|"%s"|} (DU.wire_name_of_field f) in let lookup_expr = match f.DT.default_value with | None -> - Printf.sprintf "(my_assoc \"%s\" __structure)" - (DU.wire_name_of_field f) + Printf.sprintf "(my_assoc %s __structure)" wire_name | Some default -> Printf.sprintf - "(if (List.mem_assoc \"%s\" __structure) then (my_assoc \"%s\" \ - __structure) else %s)" - (DU.wire_name_of_field f) (DU.wire_name_of_field f) + "((List.assoc_opt %s __structure) |> Option.value ~default:(%s))" + wire_name (Datamodel_values.to_ocaml_string default) in - Printf.sprintf " let %s = %s %s in" binding converter lookup_expr + Printf.sprintf " let %s = %s %s in" binding converter + lookup_expr in String.concat "\n" ("let __structure = match __structure_rpc with Dict d -> d | _ -> \ @@ -197,44 +187,86 @@ let operation (obj : obj) (x : message) = ] in (* Generate the unmarshalling code *) - let rec add_counts i l = - match l with [] -> [] | x :: xs -> (i, x) :: add_counts (i + 1) xs - in let has_session_arg = if is_ctor then is_session_arg Client.session else List.exists (fun a -> is_session_arg a) args_without_default_values in + let name_default_params, unmarshall_default_params = + List.mapi + (fun param_count default_param -> + let param_name = OU.ocaml_of_record_name default_param.DT.param_name in + let param_type = OU.alias_of_ty default_param.DT.param_type in + let param_rpc = to_rpc_name param_name in + let try_and_get_default = + Printf.sprintf "List.nth default_args %d" param_count + in + let default_value = + match default_param.DT.param_default with + | None -> + "** EXPECTED DEFAULT VALUE IN THIS PARAM **" + | Some default -> + Datamodel_values.to_ocaml_string default + in + ( Printf.sprintf "let %s = try %s with _ -> %s in" param_rpc + try_and_get_default default_value + , Printf.sprintf "let %s = %s_of_rpc %s in" param_name param_type + param_rpc + ) + ) + msg_params_with_default_values + |> List.split + in let rbac_check_begin = if has_session_arg then + let serialize_list lst = + String.concat "; " lst |> Printf.sprintf "[%s]" + in + let serialize_name_list lst = + List.map (Printf.sprintf {|"%s"|}) lst |> serialize_list + in + let serialize_args args = + List.map (fun (n, v) -> Printf.sprintf {|("%s", %s)|} n v) args + |> serialize_list + in + let default_arg_name_params = + if is_non_constructor_with_defaults then + List.map (fun dp -> dp.DT.param_name) msg_params_with_default_values + else + [] + in + let arg_names = orig_string_args @ default_arg_name_params in + let default_arg_values = + List.map + (fun dp -> to_rpc_name (OU.ocaml_of_record_name dp.DT.param_name)) + msg_params_with_default_values + in + let arg_values = string_args @ default_arg_values in + let args = + try List.combine arg_names arg_values + with Invalid_argument _ -> + let msg = + Printf.sprintf + "Cannot serialize call %s.%s: number of arguments doesn't match \ + with the number of names for it; in %s" + (OU.ocaml_of_obj_name obj.DT.name) + x.msg_name __LOC__ + in + failwith msg + in + let key_names = List.map fst x.msg_map_keys_roles in [ - "let arg_names = " - ^ List.fold_right - (fun arg args -> "\"" ^ arg ^ "\"::" ^ args) - orig_string_args - ( if is_non_constructor_with_defaults then - List.fold_right - (fun dp ss -> "\"" ^ dp.DT.param_name ^ "\"::" ^ ss) - msg_params_with_default_values "" - ^ "[]" - else - "[]" - ) - ^ " in" - ; "let key_names = " - ^ List.fold_right - (fun arg args -> "\"" ^ arg ^ "\"::" ^ args) - (List.map (fun (k, _) -> k) x.msg_map_keys_roles) - "[]" - ^ " in" + Printf.sprintf "let arg_names_values = %s in" (serialize_args args) + ; (* This incurs a runtime cost *) + "let arg_names, arg_values = List.split arg_names_values in" + ; Printf.sprintf "let key_names = %s in" (serialize_name_list key_names) ; "let rbac __context fn = Rbac.check session_id __call \ - ~args:(arg_names,__params) ~keys:key_names ~__context ~fn in" + ~args:(arg_names, arg_values) ~keys:key_names ~__context ~fn in" ] else - ["let rbac __context fn = fn() in"] + ["let rbac __context fn = fn () in"] in - let rbac_check_end = if has_session_arg then [] else [] in let unmarshall_code = (* If we are forwarding the call then we don't want to emit a warning because we know we don't need the arguments *) @@ -252,26 +284,7 @@ let operation (obj : obj) (x : message) = args_without_default_values ) (* and for every default value we try to get this from default_args or default it *) - @ List.map - (fun (param_count, default_param) -> - let param_name = - OU.ocaml_of_record_name default_param.DT.param_name - in - let param_type = OU.alias_of_ty default_param.DT.param_type in - let try_and_get_default = - Printf.sprintf "Server_helpers.nth %d default_args" param_count - in - let default_value = - match default_param.DT.param_default with - | None -> - "** EXPECTED DEFAULT VALUE IN THIS PARAM **" - | Some default -> - Datamodel_values.to_ocaml_string default - in - Printf.sprintf "let %s = %s_of_rpc (try %s with _ -> %s) in" - param_name param_type try_and_get_default default_value - ) - (add_counts 1 msg_params_with_default_values) + @ unmarshall_default_params in let may_be_side_effecting msg = match msg.msg_tag with @@ -304,7 +317,7 @@ let operation (obj : obj) (x : message) = "let host = ref_host_of_rpc host_rpc in" ; "let call_string = Jsonrpc.string_of_call {call with name=__call} in" ; "let marshaller = (fun x -> x) in" - ; "let local_op = fun ~__context ->(rbac __context \ + ; "let local_op = fun ~__context -> (rbac __context \ (fun()->(Custom.Host.call_extension \ ~__context:(Context.check_for_foreign_database ~__context) ~host \ ~call:call_string))) in" @@ -396,11 +409,11 @@ let operation (obj : obj) (x : message) = let all_list = if not (DU.has_been_removed x.DT.msg_lifecycle) then comments + @ name_default_params @ unmarshall_code @ session_check_exp @ rbac_check_begin @ gen_body () - @ rbac_check_end else comments @ ["let session_id = ref_session_of_rpc session_id_rpc in"] @@ -477,12 +490,10 @@ let gen_module api : O.Module.t = http_req.Http.Request.subtask_of in" ; "let http_other_config = Context.get_http_other_config \ http_req in" - ; "let may f = function | None -> None | Some x -> Some (f x) \ - in" ; "Server_helpers.exec_with_new_task \ (\"dispatch:\"^__call^\"\") ~http_other_config \ - ?subtask_of:(may Ref.of_string subtask_of) (fun __context \ - ->" + ?subtask_of:(Option.map Ref.of_string subtask_of) (fun \ + __context ->" ; "Server_helpers.dispatch_exn_wrapper (fun () -> (match \ __call with " ] @@ -518,15 +529,20 @@ let gen_module api : O.Module.t = ; " then " ^ debug "This is not a built-in rpc \"%s\"" ["__call"] ; " begin match __params with" - ; " | session_id_rpc::_->" + ; " | session_id_rpc :: _->" ; " let session_id = ref_session_of_rpc session_id_rpc in" ; " Session_check.check ~intra_pool_only:false \ ~session_id;" ; " (* based on the Host.call_extension call *)" - ; " let arg_names = \"session_id\"::__call::[] in" + ; " let call_rpc = Rpc.String __call in " + ; " let arg_names, arg_values =" + ; " [(\"session_id\", session_id_rpc); (__call, \ + call_rpc)]" + ; " |> List.split" + ; " in" ; " let key_names = [] in" ; " let rbac __context fn = Rbac.check session_id \ - \"Host.call_extension\" ~args:(arg_names,__params) \ + \"Host.call_extension\" ~args:(arg_names, arg_values) \ ~keys:key_names ~__context ~fn in" ; " Server_helpers.forward_extension ~__context rbac { \ call with Rpc.name = __call }" diff --git a/ocaml/idl/ocaml_backend/ocaml_syntax.ml b/ocaml/idl/ocaml_backend/ocaml_syntax.ml index ae7f7bae91e..01da3d662eb 100644 --- a/ocaml/idl/ocaml_backend/ocaml_syntax.ml +++ b/ocaml/idl/ocaml_backend/ocaml_syntax.ml @@ -79,18 +79,25 @@ module Let = struct | Named (name, _) -> "~" ^ name in - [ - Line ("(** " ^ x.doc ^ " *)") - ; Line - (prefix - ^ " " - ^ x.name - ^ " " - ^ String.concat " " (List.map param x.params) - ^ " =" - ) - ; Indent (List.map (fun x -> Line x) x.body) - ] + let doclines = + if x.doc <> "" then + [Line ("(** " ^ x.doc ^ " *)")] + else + [] + in + + doclines + @ [ + Line + (prefix + ^ " " + ^ x.name + ^ " " + ^ String.concat " " (List.map param x.params) + ^ " =" + ) + ; Indent (List.map (fun x -> Line x) x.body) + ] end module Type = struct diff --git a/ocaml/xapi/server_helpers.ml b/ocaml/xapi/server_helpers.ml index 4a4131198de..9324fddb71b 100644 --- a/ocaml/xapi/server_helpers.ml +++ b/ocaml/xapi/server_helpers.ml @@ -22,15 +22,6 @@ let my_assoc fld assoc_list = try List.assoc fld assoc_list with Not_found -> raise (Dispatcher_FieldNotFound fld) -exception Nth (* should never be thrown externally *) - -let rec nth n l = - match l with - | [] -> - raise Nth - | x :: xs -> - if n = 1 then x else nth (n - 1) xs - let async_wire_name = "Async." let async_length = String.length async_wire_name diff --git a/ocaml/xapi/server_helpers.mli b/ocaml/xapi/server_helpers.mli index 45da2dbf244..6651402acaa 100644 --- a/ocaml/xapi/server_helpers.mli +++ b/ocaml/xapi/server_helpers.mli @@ -42,8 +42,6 @@ val exec_with_subtask : (* used by auto-generated code in server.ml *) val my_assoc : string -> (string * 'a) list -> 'a -val nth : int -> 'a list -> 'a - val sync_ty_and_maybe_remove_prefix : string -> [> `Async | `InternalAsync | `Sync] * string