Skip to content

Commit

Permalink
[asl] Check type validity on user input
Browse files Browse the repository at this point in the history
  • Loading branch information
HadrienRenaud committed Mar 5, 2024
1 parent 68c0fbb commit 347d7d4
Show file tree
Hide file tree
Showing 16 changed files with 977 additions and 614 deletions.
9 changes: 8 additions & 1 deletion asllib/AST.mli
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ and type_desc =
| T_Bool
| T_Enum of identifier list
| T_Tuple of ty list
| T_Array of expr * ty
| T_Array of array_index * ty
| T_Record of field list
| T_Exception of field list
| T_Named of identifier (** A type variable. *)
Expand Down Expand Up @@ -208,6 +208,13 @@ and bitfield =
| BitField_Type of identifier * slice list * ty
(** A name, its corresponding slice and the type of the bitfield. *)

(** The type of indexes for an array. *)
and array_index =
| ArrayLength_Expr of expr
(** An integer expression giving the length of the array. *)
| ArrayLength_Enum of identifier * int
(** An enumeration name and its length. *)

and field = identifier * ty
(** A field of a record-like structure. *)

Expand Down
13 changes: 11 additions & 2 deletions asllib/ASTUtils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ and use_ty acc t =
| T_Int (WellConstrained cs) -> use_constraints acc cs
| T_Tuple li -> List.fold_left use_ty acc li
| T_Record fields | T_Exception fields -> fold_named_list use_ty acc fields
| T_Array (e, t') -> use_ty (use_e acc e) t'
| T_Array (ArrayLength_Expr e, t') -> use_ty (use_e acc e) t'
| T_Array (ArrayLength_Enum (s, _), t') -> use_ty (ISet.add s acc) t'
| T_Bits (e, bit_fields) -> use_bitfields (use_e acc e) bit_fields

and use_bitfields acc bitfields = List.fold_left use_bitfield acc bitfields
Expand Down Expand Up @@ -380,6 +381,14 @@ and constraint_equal eq c1 c2 =
and constraints_equal eq cs1 cs2 =
cs1 == cs2 || list_equal (constraint_equal eq) cs1 cs2

and array_length_equal eq l1 l2 =
match (l1, l2) with
| ArrayLength_Expr e1, ArrayLength_Expr e2 -> expr_equal eq e1 e2
| ArrayLength_Enum (s1, _), ArrayLength_Enum (s2, _) -> String.equal s1 s2
| ArrayLength_Enum (_, _), ArrayLength_Expr _
| ArrayLength_Expr _, ArrayLength_Enum (_, _) ->
false

and type_equal eq t1 t2 =
t1.desc == t2.desc
||
Expand All @@ -396,7 +405,7 @@ and type_equal eq t1 t2 =
| T_Bits (w1, bf1), T_Bits (w2, bf2) ->
bitwidth_equal eq w1 w2 && bitfields_equal eq bf1 bf2
| T_Array (l1, t1), T_Array (l2, t2) ->
expr_equal eq l1 l2 && type_equal eq t1 t2
array_length_equal eq l1 l2 && type_equal eq t1 t2
| T_Named s1, T_Named s2 -> String.equal s1 s2
| T_Enum li1, T_Enum li2 ->
(* TODO: order of fields? *) list_equal String.equal li1 li2
Expand Down
46 changes: 23 additions & 23 deletions asllib/Interpreter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -573,23 +573,26 @@ module Make (B : Backend.S) (C : Config) = struct
let* v' = B.get_index i v in
let* here = in_values v' ty' in
and' prev here
and i = i + 1 in
(i, m)
in
(i + 1, m)
in
List.fold_left fold (0, m_true) tys |> snd
| T_Array (e, ty') ->
let* v = eval_expr_sef env e in
let n =
match B.v_to_int v with
| Some i -> i
| None -> fatal_from loc @@ Error.UnsupportedExpr e
| T_Array (index, ty') ->
let* length =
match index with
| ArrayLength_Enum (_, i) -> return i
| ArrayLength_Expr e -> (
let* v_length = eval_expr_sef env e in
match B.v_to_int v_length with
| Some i -> return i
| None -> fatal_from loc @@ Error.UnsupportedExpr e)
in
let rec loop i prev =
if i = n then prev
if i >= length then prev
else
let* v' = B.get_index i v in
let* here = in_values v' ty' in
loop (succ i) (and' prev here)
loop (i + 1) (and' prev here)
in
loop 0 m_true
| T_Named _ -> assert false
Expand Down Expand Up @@ -1294,21 +1297,18 @@ module Make (B : Backend.S) (C : Config) = struct
(Error.NotYetImplemented "Base value of string types.")
| T_Tuple li ->
List.map (base_value env) li |> sync_list >>= B.create_vector
| T_Array (e_length, ty) ->
| T_Array (length, ty) ->
let* v = base_value env ty in
let* length =
match e_length.desc with
| E_Var x when IMap.mem x env.global.static.declared_types ->
IMap.find x env.global.static.constants_values
|> B.v_of_literal |> return
| _ -> eval_expr_sef env e_length
in
let length =
match B.v_to_int length with
| None -> Error.fatal_from t (Error.UnsupportedExpr e_length)
| Some i -> i
let* i_length =
match length with
| ArrayLength_Enum (_, i) -> return i
| ArrayLength_Expr e -> (
let* length = eval_expr_sef env e in
match B.v_to_int length with
| None -> Error.fatal_from t (Error.UnsupportedExpr e)
| Some i -> return i)
in
List.init length (Fun.const v) |> B.create_vector
List.init i_length (Fun.const v) |> B.create_vector

(* Begin TopLevel *)
let run_typed_env env (ast : AST.t) (static_env : StaticEnv.env) : B.value m =
Expand Down
4 changes: 3 additions & 1 deletion asllib/PP.ml
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,10 @@ and pp_ty f t =
(pp_comma_list pp_print_string)
enum_ty
| T_Tuple ty_list -> fprintf f "@[(%a)@]" (pp_comma_list pp_ty) ty_list
| T_Array (e, elt_type) ->
| T_Array (ArrayLength_Expr e, elt_type) ->
fprintf f "@[array [%a] of %a@]" pp_expr e pp_ty elt_type
| T_Array (ArrayLength_Enum (s, _), elt_type) ->
fprintf f "@[array [%s] of %a@]" s pp_ty elt_type
| T_Record record_ty ->
fprintf f "@[<hv 2>record {@ %a@;<1 -2>}@]" pp_fields record_ty
| T_Exception record_ty ->
Expand Down
12 changes: 8 additions & 4 deletions asllib/Parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -376,12 +376,16 @@ let ty :=
| STRING; { T_String }
| BIT; { t_bit }
| BITS; ~=pared(expr); ~=bitfields_opt; < T_Bits >
| ENUMERATION; l=braced(tclist(IDENTIFIER)); < T_Enum >
| l=plist(ty); < T_Tuple >
| ARRAY; e=bracketed(expr); OF; t=ty; < T_Array >
| name=IDENTIFIER; < T_Named >
| ARRAY; e=bracketed(expr); OF; t=ty; { T_Array (ArrayLength_Expr e, t) }
)

let ty_decl := ty |
annotated (
| ENUMERATION; l=braced(tclist(IDENTIFIER)); < T_Enum >
| RECORD; l=fields_opt; < T_Record >
| EXCEPTION; l=fields_opt; < T_Exception >
| name=IDENTIFIER; < T_Named >
)

(* Constructs on ty *)
Expand Down Expand Up @@ -568,7 +572,7 @@ let decl ==
}

| terminated_by(SEMI_COLON,
| TYPE; x=IDENTIFIER; OF; t=ty; ~=subtype_opt; < D_TypeDecl >
| TYPE; x=IDENTIFIER; OF; t=ty_decl; ~=subtype_opt; < D_TypeDecl >
| TYPE; x=IDENTIFIER; s=annotated(subtype); < make_ty_decl_subtype >

| keyword=storage_keyword; name=ignored_or_identifier;
Expand Down
8 changes: 6 additions & 2 deletions asllib/Serialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ and pp_ty =
| T_Tuple li ->
addb f "T_Tuple ";
pp_list pp_ty f li
| T_Array (e, elt_type) ->
bprintf f "T_Array (%a, %a)" pp_expr e pp_ty elt_type
| T_Array (length, elt_type) ->
bprintf f "T_Array (%a, %a)" pp_array_length length pp_ty elt_type
| T_Record li ->
addb f "T_Record ";
pp_id_assoc pp_ty f li
Expand All @@ -182,6 +182,10 @@ and pp_ty =
in
fun f s -> pp_annotated pp_desc f s

and pp_array_length f = function
| ArrayLength_Expr e -> bprintf f "ArrayLength_Expr (%a)" pp_expr e
| ArrayLength_Enum (s, i) -> bprintf f "ArrayLength_Enum (%S, %i)" s i

and pp_bitfield f = function
| BitField_Simple (name, slices) ->
bprintf f "BitField_Simple (%S, %a)" name pp_slice_list slices
Expand Down
3 changes: 3 additions & 0 deletions asllib/StaticEnv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ let add_global_storage x ty gdk env =
}

let add_type x ty env =
let () =
if false then Format.eprintf "Adding type %s as %a.@." x PP.pp_ty ty
in
{
env with
global =
Expand Down
44 changes: 39 additions & 5 deletions asllib/StaticInterpreter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,13 @@ module Normalize = struct
| StrictNegative
| NotNull

let sign_of_z c =
match Z.sign c with
| 1 -> StrictPositive
| 0 -> Null
| -1 -> StrictNegative
| _ -> assert false

module PolynomialOrdered = struct
type t = polynomial

Expand Down Expand Up @@ -363,6 +370,13 @@ module Normalize = struct
| StrictPositive -> Negative
| StrictNegative -> Positive

let sign_minus = function
| (NotNull | Null) as s -> s
| Positive -> Negative
| Negative -> Positive
| StrictPositive -> StrictNegative
| StrictNegative -> StrictPositive

exception ConjunctionBottomInterrupt

let sign_and _p s1 s2 =
Expand Down Expand Up @@ -616,24 +630,44 @@ module Normalize = struct
if e1 == e_true then e2 else if e2 == e_true then e1 else binop BAND e1 e2

let e_cond e1 e2 e3 = E_Cond (e1, e2, e3) |> add_pos_from loc
let unop op e = E_Unop (op, e) |> add_pos_from loc

let monomial_to_expr (Prod map) =
let ( ** ) e1 e2 =
if e1 = one then e2 else if e2 = one then e1 else binop MUL e1 e2
if e1 == one then e2 else if e2 == one then e1 else binop MUL e1 e2
in
let ( ^^ ) e = function
| 0 -> one
| 1 -> e
| 2 -> e ** e
| p -> if e = one then one else binop POW e (expr_of_int p)
| p -> binop POW e (expr_of_int p)
in
AMap.fold (fun s p e -> (e_var s ^^ p) ** e) map

let polynomial_to_expr (Sum map) =
let ( ++ ) e1 e2 =
if e1 == zero then e2 else if e2 == zero then e1 else binop PLUS e1 e2
let add s1 e1 s2 e2 =
match (s1, s2) with
| _, Null -> e1
| Null, _ -> e2
| StrictPositive, StrictPositive | StrictNegative, StrictNegative ->
binop PLUS e1 e2
| StrictPositive, StrictNegative | StrictNegative, StrictPositive ->
binop MINUS e1 e2
| _ -> assert false
in
let res, sign =
MMap.fold
(fun m c (e, sign) ->
let c' = Z.abs c and sign' = sign_of_z c in
let m' = monomial_to_expr m (expr_of_z c') in
(add sign' m' sign e, sign'))
map (zero, Null)
in
MMap.fold (fun m c e -> monomial_to_expr m (expr_of_z c) ++ e) map zero
match sign with
| Null -> zero
| StrictPositive -> res
| StrictNegative -> unop NEG res
| _ -> assert false

let sign_to_binop = function
| Null -> EQ_OP
Expand Down
Loading

0 comments on commit 347d7d4

Please sign in to comment.