Skip to content

Commit

Permalink
Coq: more efficient equality decision procedures for enums
Browse files Browse the repository at this point in the history
  • Loading branch information
bacam committed Apr 11, 2024
1 parent 0c28da2 commit 47a7b18
Showing 1 changed file with 93 additions and 7 deletions.
100 changes: 93 additions & 7 deletions src/sail_coq_backend/pretty_print_coq.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2377,7 +2377,7 @@ let rec doc_range ctxt (BF_aux(r,_)) = match r with
*)

(* TODO: check use of empty_ctxt below doesn't cause problems due to missing info *)
let doc_typdef types_mod avoid_target_names generic_eq_types (TD_aux (td, (l, annot))) =
let doc_typdef types_mod avoid_target_names generic_eq_types enum_number_defs (TD_aux (td, (l, annot))) =
let bare_ctxt = { empty_ctxt with avoid_target_names } in
match td with
| TD_abbrev (id, typq, A_aux (A_typ typ, _)) ->
Expand Down Expand Up @@ -2632,7 +2632,58 @@ let doc_typdef types_mod avoid_target_names generic_eq_types (TD_aux (td, (l, an
let typ_pp =
(doc_op coloneq) (concat [string "Inductive"; space; id_pp]) (ifflat empty (pipe ^^ space) ^^ enums_doc)
in
let eq1_pp = string "Scheme Equality for" ^^ space ^^ id_pp ^^ dot in
(* If we have conversion functions to Z, put them here and
derive a decision procedure that's efficient even for
large enums. *)
let eq1_pp =
let fallback = string "Scheme Equality for" ^^ space ^^ id_pp ^^ dot in
match (Bindings.find_opt id (fst enum_number_defs), Bindings.find_opt id (snd enum_number_defs)) with
| Some (num_of_id, num_of_pp), Some (of_num_id, of_num_pp) ->
let num_of_id_pp = doc_id bare_ctxt num_of_id in
let of_num_id_pp = doc_id bare_ctxt of_num_id in
let lemma1 =
separate hardline
[
string "Lemma " ^^ id_pp ^^ string "_num_of_roundtrip "
^^ parens (string "x : " ^^ id_pp)
^^ string " : " ^^ of_num_id_pp ^^ space
^^ parens (num_of_id_pp ^^ string " x")
^^ string " = x.";
string "destruct x; reflexivity.";
string "Qed.";
]
in
let lemma2 =
separate hardline
[
string "Lemma " ^^ num_of_id_pp ^^ string "_injective "
^^ parens (string "x y : " ^^ id_pp)
^^ string " : " ^^ num_of_id_pp ^^ string " x = " ^^ num_of_id_pp ^^ string " y -> x = y.";
string "intro.";
string "rewrite <- (" ^^ id_pp ^^ string "_num_of_roundtrip x).";
string "rewrite <- (" ^^ id_pp ^^ string "_num_of_roundtrip y).";
string "congruence.";
string "Qed.";
]
in
let eq_pp =
separate hardline
[
string "Definition " ^^ id_pp ^^ string "_eq_dec (x y : " ^^ id_pp
^^ string ") : {x = y} + {x <> y}.";
string "refine (match Z.eq_dec (" ^^ num_of_id_pp ^^ string " x) (" ^^ num_of_id_pp
^^ string " y) with";
string "| left e => left (" ^^ num_of_id_pp ^^ string "_injective x y e)";
string "| right ne => right _";
string "end).";
string "congruence.";
string "Defined.";
]
in
num_of_pp ^^ of_num_pp ^^ separate hardline [lemma1; lemma2; eq_pp]
| Some (_, pp), None | None, Some (_, pp) -> pp ^^ fallback
| None, None -> fallback
in
let eq2_pp =
string "#[export]" ^^ hardline
^^ group
Expand Down Expand Up @@ -3364,14 +3415,15 @@ let doc_val avoid_target_names pat exp =
^^ group (separate space [string "#[export] Hint Unfold"; idpp; colon; string "sail."])
^^ hardline

let doc_def types_mod unimplemented avoid_target_names generic_eq_types effect_info (DEF_aux (aux, def_annot) as def) =
let doc_def types_mod unimplemented avoid_target_names generic_eq_types enum_number_defs effect_info
(DEF_aux (aux, def_annot) as def) =
match aux with
| DEF_val v_spec -> doc_val_spec def_annot unimplemented avoid_target_names effect_info v_spec
| DEF_fixity _ -> empty
| DEF_overload _ -> empty
| DEF_type t_def ->
if List.mem (string_of_id (id_of_type_def t_def)) !opt_extern_types <> !opt_generate_extern_types then empty
else doc_typdef types_mod avoid_target_names generic_eq_types t_def
else doc_typdef types_mod avoid_target_names generic_eq_types enum_number_defs t_def
| DEF_register dec -> group (doc_dec avoid_target_names dec)
| DEF_default df -> empty
| DEF_fundef fdef -> group (doc_fundef types_mod avoid_target_names effect_info fdef) ^/^ hardline
Expand Down Expand Up @@ -3540,8 +3592,6 @@ let pp_ast_coq (types_file, types_modules) (defs_file, defs_modules) type_defs_m
in
let is_typ_def = function DEF_aux (DEF_type _, _) -> true | _ -> false in
let exc_typ = find_exc_typ defs in
let typdefs, defs = List.partition is_typ_def defs in
let statedefs, defs = List.partition is_state_def defs in
let unimplemented = find_unimplemented defs in
let avoid_target_names = builtin_target_names defs in
let bare_doc_id = doc_id { empty_ctxt with avoid_target_names } in
Expand Down Expand Up @@ -3603,7 +3653,43 @@ let pp_ast_coq (types_file, types_modules) (defs_file, defs_modules) type_defs_m
@ mr_m
)
in
let doc_def = doc_def type_defs_module unimplemented avoid_target_names generic_eq_types effect_info in
let enums = Type_check.Env.get_enums type_env in
let defs, enum_number_defs =
let doc_def =
doc_def type_defs_module unimplemented avoid_target_names generic_eq_types (Bindings.empty, Bindings.empty)
effect_info
in
let num_of_map, of_num_map, rdefs =
List.fold_left
(fun (num_of_map, of_num_map, rdefs) def ->
match def with
| DEF_aux (DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, _), _)]), _)), _) -> begin
match Type_check.Env.get_val_spec id type_env with
| _, Typ_aux (Typ_fn ([arg_typ], ret_typ), _) -> begin
match (arg_typ, ret_typ) with
| Typ_aux (Typ_id arg_id, _), _
when Bindings.mem arg_id enums && string_of_id id = "num_of_" ^ string_of_id arg_id ->
(Bindings.add arg_id (id, doc_def def) num_of_map, of_num_map, rdefs)
| _, Typ_aux (Typ_id ret_id, _)
when Bindings.mem ret_id enums && string_of_id id = string_of_id ret_id ^ "_of_num" ->
(num_of_map, Bindings.add ret_id (id, doc_def def) of_num_map, rdefs)
| _ -> (num_of_map, of_num_map, def :: rdefs)
end
| _ -> (num_of_map, of_num_map, def :: rdefs)
end
| _ -> (num_of_map, of_num_map, def :: rdefs)
)
(Bindings.empty, Bindings.empty, []) defs
in
(List.rev rdefs, (num_of_map, of_num_map))
in

let typdefs, defs = List.partition is_typ_def defs in
let statedefs, defs = List.partition is_state_def defs in

let doc_def =
doc_def type_defs_module unimplemented avoid_target_names generic_eq_types enum_number_defs effect_info
in
let () =
if !opt_undef_axioms || IdSet.is_empty unimplemented then ()
else
Expand Down

0 comments on commit 47a7b18

Please sign in to comment.