Skip to content

Commit

Permalink
SV: Set up memory writes
Browse files Browse the repository at this point in the history
Add simplification rules for bvadd and bvsub
  • Loading branch information
Alasdair committed Oct 29, 2024
1 parent 9ac9ca2 commit df978ba
Show file tree
Hide file tree
Showing 18 changed files with 239 additions and 23 deletions.
9 changes: 6 additions & 3 deletions language/jib.ott
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ name :: '' ::=
| current_exception nat :: :: current_exception
| throw_location nat :: :: throw_location
| channel chan nat :: :: channel
| memory_writes nat :: :: memory_writes
| return nat :: :: return

op :: '' ::=
Expand Down Expand Up @@ -188,11 +189,13 @@ ctyp :: 'CT_' ::=
% A vector type for non-bit vectors, and a (linked) list type.
| fvector ( nat , ctyp ) :: :: fvector
| vector ( ctyp ) :: :: vector
| list ( ctyp ) :: :: list
| list ( ctyp ) :: :: list

| ref ( ctyp ) :: :: ref
| ref ( ctyp ) :: :: ref

| poly kid :: :: poly
| poly kid :: :: poly

| memory_writes :: :: memory_writes

clexp :: 'CL_' ::=
| name : ctyp :: :: id
Expand Down
1 change: 1 addition & 0 deletions lib/concurrency_interface/common.sail
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ $endif
val sail_end_cycle = impure "cycle_count" : unit -> unit

/*! Returns the current cycle count */
$[sv_function { return_type = int }]
val sail_get_cycle_count = impure "get_cycle_count" : unit -> int

$ifdef SYMBOLIC
Expand Down
8 changes: 4 additions & 4 deletions lib/concurrency_interface/emulator_memory.sail
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ val read_mem# : forall ('a: Type) 'n 'addrsize, 'n >= 0 & 'addrsize in {32, 64}.
('a, int('addrsize), bits('addrsize), int('n)) -> bits(8 * 'n)

$ifdef _EMULATOR_MEMORY_PRIMOPS
$[sv_module { reads_memory = true }]
$[sv_function]
val __read_mem# = impure "emulator_read_mem" : forall 'n 'addrsize, 'n >= 0 & 'addrsize in {32, 64}.
(int('addrsize), bits('addrsize), int('n)) -> bits(8 * 'n)

Expand All @@ -96,7 +96,7 @@ val read_mem_ifetch# : forall ('a: Type) 'n 'addrsize, 'n >= 0 & 'addrsize in {3
('a, int('addrsize), bits('addrsize), int('n)) -> bits(8 * 'n)

$ifdef _EMULATOR_MEMORY_PRIMOPS
$[sv_module { reads_memory = true }]
$[sv_function]
val __read_mem_ifetch# = impure "emulator_read_mem_ifetch" : forall 'n 'addrsize, 'n >= 0 & 'addrsize in {32, 64}.
(int('addrsize), bits('addrsize), int('n)) -> bits(8 * 'n)

Expand All @@ -113,7 +113,7 @@ val read_mem_exclusive# : forall ('a: Type) 'n 'addrsize, 'n >= 0 & 'addrsize in
('a, int('addrsize), bits('addrsize), int('n)) -> bits(8 * 'n)

$ifdef _EMULATOR_MEMORY_PRIMOPS
$[sv_module { reads_memory = true }]
$[sv_function]
val __read_mem_exclusive# = impure "emulator_read_mem_exclusive" : forall 'n 'addrsize, 'n >= 0 & 'addrsize in {32, 64}.
(int('addrsize), bits('addrsize), int('n)) -> bits(8 * 'n)

Expand Down Expand Up @@ -155,7 +155,7 @@ function write_mem_exclusive#(_, addrsize, addr, n, value) = __write_mem_exclusi
$endif
$endif

$[sv_module { reads_memory = true }]
$[sv_function]
val read_tag# = impure "emulator_read_tag" : forall 'addrsize, 'addrsize in {32, 64}. (int('addrsize), bits('addrsize)) -> bool

$[sv_module { writes_memory = true }]
Expand Down
82 changes: 82 additions & 0 deletions lib/sv/sail_modules.sv
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ typedef enum logic [0:0] {SAIL_UNIT=0} sail_unit;
// The Sail zero-width bitvector.
typedef enum logic [0:0] {SAIL_ZWBV=0} sail_zwbv;

int cycle_count;

function automatic int get_cycle_count(sail_unit u);
return cycle_count;
endfunction

module print
(input string in_str,
input string in_sail_stdout,
Expand Down Expand Up @@ -80,4 +86,80 @@ function automatic int string_length(string str);
return str.len();
endfunction // string_length

logic [7:0] sail_memory [logic [63:0]];

bit sail_tag_memory [logic [63:0]];

typedef struct {
logic [63:0] paddr;
logic [7:0] data;
} sail_write;

typedef sail_write sail_memory_writes [$];

function automatic sail_bits emulator_read_mem(logic [63:0] addrsize, sail_bits addr, sail_int n);
logic [63:0] paddr;
logic [SAIL_BITS_WIDTH-1:0] buffer;
sail_int i;

paddr = addr.bits[63:0];

for (i = n; i > 0; i = i - 1) begin
buffer = buffer << 8;
buffer[7:0] = sail_memory[paddr + (i[63:0] - 1)];
end

return '{n[SAIL_INDEX_WIDTH-1:0] * 8, buffer};
endfunction

function automatic sail_bits emulator_read_mem_ifetch(logic [63:0] addrsize, sail_bits addr, sail_int n);
return emulator_read_mem(addrsize, addr, n);
endfunction

function automatic sail_bits emulator_read_mem_exclusive(logic [63:0] addrsize, sail_bits addr, sail_int n);
return emulator_read_mem(addrsize, addr, n);
endfunction

function automatic bit emulator_read_tag(logic [63:0] addrsize, sail_bits addr);
logic [63:0] paddr;
paddr = addr.bits[63:0];
if (sail_tag_memory.exists(paddr) == 1)
return sail_tag_memory[paddr];
else
return 1'b0;
endfunction

module emulator_write_mem
(input logic [63:0] addrsize,
input sail_bits addr,
input sail_int n,
input sail_bits value,
input sail_memory_writes in_writes,
output sail_unit ret,
output sail_memory_writes out_writes
);
endmodule

module emulator_write_mem_exclusive
(input logic [63:0] addrsize,
input sail_bits addr,
input sail_int n,
input sail_bits value,
input sail_memory_writes in_writes,
output sail_unit ret,
output sail_memory_writes out_writes
);
endmodule

module emulator_write_tag
(input logic [63:0] addrsize,
input sail_bits addr,
input bit tag_value,
input sail_memory_writes in_writes,
output sail_unit ret,
output sail_memory_writes out_writes
);
endmodule


`endif
1 change: 1 addition & 0 deletions src/lib/jib_compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ let rec mangle_string_of_ctyp ctx = function
| CT_rounding_mode -> "m"
| CT_enum (id, _) -> "E" ^ string_of_id id ^ "%"
| CT_ref ctyp -> "&" ^ mangle_string_of_ctyp ctx ctyp
| CT_memory_writes -> "w"
| CT_tup ctyps -> "(" ^ Util.string_of_list "," (mangle_string_of_ctyp ctx) ctyps ^ ")"
| CT_struct (id, fields) ->
let generic_fields = Bindings.find id ctx.records |> snd |> Bindings.bindings in
Expand Down
7 changes: 4 additions & 3 deletions src/lib/jib_optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ let ssa_name i = function
| Throw_location _ -> Throw_location i
| Return _ -> Return i
| Channel (chan, _) -> Channel (chan, i)
| Memory_writes _ -> Memory_writes i

let inline cdefs should_inline instrs =
let inlines = ref (-1) in
Expand Down Expand Up @@ -469,7 +470,7 @@ let remove_tuples cdefs ctx =
List.fold_left (fun cts (_, ctyp) -> CTSet.union (all_tuples ctyp) cts) CTSet.empty id_ctyps
| CT_list ctyp | CT_vector ctyp | CT_fvector (_, ctyp) | CT_ref ctyp -> all_tuples ctyp
| CT_lint | CT_fint _ | CT_lbits | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_float _ | CT_unit | CT_bool
| CT_real | CT_bit | CT_poly _ | CT_string | CT_enum _ | CT_rounding_mode ->
| CT_real | CT_bit | CT_poly _ | CT_string | CT_enum _ | CT_rounding_mode | CT_memory_writes ->
CTSet.empty
in
let rec tuple_depth = function
Expand All @@ -478,7 +479,7 @@ let remove_tuples cdefs ctx =
List.fold_left (fun d (_, ctyp) -> max (tuple_depth ctyp) d) 0 id_ctyps
| CT_list ctyp | CT_vector ctyp | CT_fvector (_, ctyp) | CT_ref ctyp -> tuple_depth ctyp
| CT_lint | CT_fint _ | CT_lbits | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_unit | CT_bool | CT_real | CT_bit
| CT_poly _ | CT_string | CT_enum _ | CT_float _ | CT_rounding_mode ->
| CT_poly _ | CT_string | CT_enum _ | CT_float _ | CT_rounding_mode | CT_memory_writes ->
0
in
let rec fix_tuples = function
Expand All @@ -493,7 +494,7 @@ let remove_tuples cdefs ctx =
| CT_fvector (n, ctyp) -> CT_fvector (n, fix_tuples ctyp)
| CT_ref ctyp -> CT_ref (fix_tuples ctyp)
| ( CT_lint | CT_fint _ | CT_lbits | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_float _ | CT_unit | CT_bool
| CT_real | CT_bit | CT_poly _ | CT_string | CT_enum _ | CT_rounding_mode ) as ctyp ->
| CT_real | CT_bit | CT_poly _ | CT_string | CT_enum _ | CT_rounding_mode | CT_memory_writes ) as ctyp ->
ctyp
and fix_cval = function
| V_id (id, ctyp) -> V_id (id, ctyp)
Expand Down
2 changes: 2 additions & 0 deletions src/lib/jib_ssa.ml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ let ssa_name i = function
| Current_exception _ -> Current_exception i
| Throw_location _ -> Throw_location i
| Channel (c, _) -> Channel (c, i)
| Memory_writes _ -> Memory_writes i
| Return _ -> Return i

let unssa_name = function
Expand All @@ -68,6 +69,7 @@ let unssa_name = function
| Current_exception n -> (Current_exception (-1), n)
| Throw_location n -> (Throw_location (-1), n)
| Channel (c, n) -> (Channel (c, -1), n)
| Memory_writes n -> (Memory_writes (-1), n)
| Return n -> (Return (-1), n)

(**************************************************************************)
Expand Down
19 changes: 14 additions & 5 deletions src/lib/jib_util.ml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ module Name = struct
| Have_exception n, Have_exception m -> compare n m
| Current_exception n, Current_exception m -> compare n m
| Return n, Return m -> compare n m
| Memory_writes n, Memory_writes m -> compare n m
| Channel (c1, n), Channel (c2, m) -> begin
match (c1, c2) with
| Chan_stdout, Chan_stdout -> compare n m
Expand All @@ -145,6 +146,8 @@ module Name = struct
| _, Throw_location _ -> -1
| Return _, _ -> 1
| _, Return _ -> -1
| Memory_writes _, _ -> 1
| _, Memory_writes _ -> -1
end

module NameSet = Set.Make (Name)
Expand Down Expand Up @@ -199,6 +202,7 @@ let string_of_name ?deref_current_exception:(dce = false) ?(zencode = true) =
| Current_exception n when dce -> "(*current_exception)" ^ ssa_num n
| Current_exception n -> "current_exception" ^ ssa_num n
| Throw_location n -> "throw_location" ^ ssa_num n
| Memory_writes n -> "memory_writes" ^ ssa_num n
| Channel (chan, n) -> (
match chan with Chan_stdout -> "stdout" ^ ssa_num n | Chan_stderr -> "stderr" ^ ssa_num n
)
Expand Down Expand Up @@ -253,6 +257,7 @@ let rec string_of_ctyp = function
| CT_bool -> "%bool"
| CT_real -> "%real"
| CT_string -> "%string"
| CT_memory_writes -> "%memory_writes"
| CT_tup ctyps -> "(" ^ Util.string_of_list ", " string_of_ctyp ctyps ^ ")"
| CT_struct (id, _fields) -> "%struct " ^ Util.zencode_string (string_of_id id)
| CT_enum (id, _) -> "%enum " ^ Util.zencode_string (string_of_id id)
Expand Down Expand Up @@ -384,7 +389,7 @@ let string_of_instr i = Document.to_string (doc_instr i)

let rec map_ctyp f = function
| ( CT_lint | CT_fint _ | CT_constant _ | CT_lbits | CT_fbits _ | CT_sbits _ | CT_float _ | CT_rounding_mode | CT_bit
| CT_unit | CT_bool | CT_real | CT_string | CT_poly _ | CT_enum _ ) as ctyp ->
| CT_unit | CT_bool | CT_real | CT_string | CT_poly _ | CT_enum _ | CT_memory_writes ) as ctyp ->
f ctyp
| CT_tup ctyps -> f (CT_tup (List.map (map_ctyp f) ctyps))
| CT_ref ctyp -> f (CT_ref (map_ctyp f ctyp))
Expand All @@ -399,7 +404,7 @@ let rec ctyp_has pred ctyp =
||
match ctyp with
| CT_lint | CT_fint _ | CT_constant _ | CT_lbits | CT_fbits _ | CT_sbits _ | CT_float _ | CT_rounding_mode | CT_bit
| CT_unit | CT_bool | CT_real | CT_string | CT_poly _ | CT_enum _ ->
| CT_unit | CT_bool | CT_real | CT_string | CT_poly _ | CT_enum _ | CT_memory_writes ->
false
| CT_tup ctyps -> List.exists (ctyp_has pred) ctyps
| CT_ref ctyp | CT_vector ctyp | CT_fvector (_, ctyp) | CT_list ctyp -> ctyp_has pred ctyp
Expand Down Expand Up @@ -510,6 +515,9 @@ let rec ctyp_compare ctyp1 ctyp2 =
| CT_enum _, _ -> 1
| _, CT_enum _ -> -1
| CT_rounding_mode, CT_rounding_mode -> 0
| CT_rounding_mode, _ -> 1
| _, CT_rounding_mode -> -1
| CT_memory_writes, CT_memory_writes -> 0

module CT = struct
type t = ctyp
Expand Down Expand Up @@ -548,6 +556,7 @@ let rec ctyp_suprema = function
| CT_string -> CT_string
| CT_float n -> CT_float n
| CT_rounding_mode -> CT_rounding_mode
| CT_memory_writes -> CT_memory_writes
| CT_enum (id, ids) -> CT_enum (id, ids)
(* Do we really never want to never call ctyp_suprema on constructor
fields? Doing it causes issues for structs (see
Expand Down Expand Up @@ -609,7 +618,7 @@ let rec ctyp_ids = function
| CT_tup ctyps -> List.fold_left (fun ids ctyp -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctyps
| CT_vector ctyp | CT_fvector (_, ctyp) | CT_list ctyp | CT_ref ctyp -> ctyp_ids ctyp
| CT_lint | CT_fint _ | CT_constant _ | CT_lbits | CT_fbits _ | CT_sbits _ | CT_unit | CT_bool | CT_real | CT_bit
| CT_string | CT_poly _ | CT_float _ | CT_rounding_mode ->
| CT_string | CT_poly _ | CT_float _ | CT_rounding_mode | CT_memory_writes ->
IdSet.empty

let rec subst_poly substs = function
Expand All @@ -622,12 +631,12 @@ let rec subst_poly substs = function
| CT_variant (id, ctors) -> CT_variant (id, List.map (fun (ctor_id, ctyp) -> (ctor_id, subst_poly substs ctyp)) ctors)
| CT_struct (id, fields) -> CT_struct (id, List.map (fun (ctor_id, ctyp) -> (ctor_id, subst_poly substs ctyp)) fields)
| ( CT_lint | CT_fint _ | CT_constant _ | CT_unit | CT_bool | CT_bit | CT_string | CT_real | CT_lbits | CT_fbits _
| CT_sbits _ | CT_enum _ | CT_float _ | CT_rounding_mode ) as ctyp ->
| CT_sbits _ | CT_enum _ | CT_float _ | CT_rounding_mode | CT_memory_writes ) as ctyp ->
ctyp

let rec is_polymorphic = function
| CT_lint | CT_fint _ | CT_constant _ | CT_lbits | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real
| CT_string | CT_float _ | CT_rounding_mode ->
| CT_string | CT_float _ | CT_rounding_mode | CT_memory_writes ->
false
| CT_tup ctyps -> List.exists is_polymorphic ctyps
| CT_enum _ -> false
Expand Down
2 changes: 1 addition & 1 deletion src/lib/jib_visitor.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ let rec visit_ctyp vis outer_ctyp =
let aux vis no_change =
match no_change with
| CT_lint | CT_fint _ | CT_constant _ | CT_lbits | CT_sbits _ | CT_fbits _ | CT_unit | CT_bool | CT_bit | CT_string
| CT_real | CT_float _ | CT_rounding_mode | CT_poly _ ->
| CT_real | CT_float _ | CT_rounding_mode | CT_memory_writes | CT_poly _ ->
no_change
| CT_tup ctyps ->
let ctyps' = visit_ctyps vis ctyps in
Expand Down
4 changes: 3 additions & 1 deletion src/lib/smt_exp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ module Simplifier = struct
| _ -> NoChange

let is_bvfunction = function
| "bvnot" | "bvand" | "bvor" | "bvxor" | "bvshl" | "bvlshr" | "bvashr" -> true
| "bvnot" | "bvand" | "bvor" | "bvxor" | "bvshl" | "bvlshr" | "bvashr" | "bvadd" | "bvsub" -> true
| _ -> false

let rule_bvfunction_literal =
Expand All @@ -778,6 +778,8 @@ module Simplifier = struct
| "bvand", [Bitvec_lit lhs; Bitvec_lit rhs] -> change (Bitvec_lit (and_vec lhs rhs))
| "bvor", [Bitvec_lit lhs; Bitvec_lit rhs] -> change (Bitvec_lit (or_vec lhs rhs))
| "bvxor", [Bitvec_lit lhs; Bitvec_lit rhs] -> change (Bitvec_lit (xor_vec lhs rhs))
| "bvadd", [Bitvec_lit lhs; Bitvec_lit rhs] -> change (Bitvec_lit (add_vec lhs rhs))
| "bvsub", [Bitvec_lit lhs; Bitvec_lit rhs] -> change (Bitvec_lit (sub_vec lhs rhs))
| "bvshl", [lhs; Bitvec_lit rhs] when bv_is_zero rhs -> change lhs
| "bvshl", [Bitvec_lit lhs; Bitvec_lit rhs] -> begin
match sint_maybe rhs with Some shift -> change (Bitvec_lit (shiftl lhs shift)) | None -> NoChange
Expand Down
7 changes: 6 additions & 1 deletion src/sail_c_backend/c_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ let rec is_stack_ctyp ctyp =
| CT_float _ -> true
| CT_rounding_mode -> true
| CT_constant n -> Big_int.less_equal (min_int 64) n && Big_int.greater_equal n (max_int 64)
| CT_memory_writes -> false

let v_mask_lower i = V_lit (VL_bits (Util.list_init i (fun _ -> Sail2_values.B1)), CT_fbits i)

Expand Down Expand Up @@ -187,6 +188,7 @@ let rec sgen_ctyp_name = function
| CT_ref ctyp -> "ref_" ^ sgen_ctyp_name ctyp
| CT_float n -> "float" ^ string_of_int n
| CT_rounding_mode -> "rounding_mode"
| CT_memory_writes -> "sail_memory_writes"
| CT_poly _ -> "POLY" (* c_error "Tried to generate code for non-monomorphic type" *)

let sail_create ?(prefix = "") ?(suffix = "") ctyp fmt =
Expand Down Expand Up @@ -917,6 +919,7 @@ let rec sgen_ctyp = function
| CT_ref ctyp -> sgen_ctyp ctyp ^ "*"
| CT_float n -> "float" ^ string_of_int n ^ "_t"
| CT_rounding_mode -> "uint_fast8_t"
| CT_memory_writes -> "sail_memory_writes"
| CT_poly _ -> "POLY" (* c_error "Tried to generate code for non-monomorphic type" *)

let sgen_const_ctyp = function CT_string -> "const_sail_string" | ty -> sgen_ctyp ty
Expand Down Expand Up @@ -1093,6 +1096,7 @@ let rec sgen_clexp l = function
| CL_id (Have_exception _, _) -> "have_exception"
| CL_id (Current_exception _, _) -> "current_exception"
| CL_id (Throw_location _, _) -> "throw_location"
| CL_id (Memory_writes _, _) -> "memory_writes"
| CL_id (Channel _, _) -> Reporting.unreachable l __POS__ "CL_id Channel should not appear in C backend"
| CL_id (Return _, _) -> Reporting.unreachable l __POS__ "CL_id Return should have been removed"
| CL_id (Name (id, _), _) -> "&" ^ sgen_id id
Expand All @@ -1106,6 +1110,7 @@ let rec sgen_clexp_pure l = function
| CL_id (Have_exception _, _) -> "have_exception"
| CL_id (Current_exception _, _) -> "current_exception"
| CL_id (Throw_location _, _) -> "throw_location"
| CL_id (Memory_writes _, _) -> "memory_writes"
| CL_id (Channel _, _) -> Reporting.unreachable l __POS__ "CL_id Channel should not appear in C backend"
| CL_id (Return _, _) -> Reporting.unreachable l __POS__ "CL_id Return should have been removed"
| CL_id (Name (id, _), _) -> sgen_id id
Expand Down Expand Up @@ -2018,7 +2023,7 @@ let rec ctyp_dependencies = function
| CT_struct (_, ctors) -> List.concat (List.map (fun (_, ctyp) -> ctyp_dependencies ctyp) ctors)
| CT_variant (_, ctors) -> List.concat (List.map (fun (_, ctyp) -> ctyp_dependencies ctyp) ctors)
| CT_lint | CT_fint _ | CT_lbits | CT_fbits _ | CT_sbits _ | CT_unit | CT_bool | CT_real | CT_bit | CT_string
| CT_enum _ | CT_poly _ | CT_constant _ | CT_float _ | CT_rounding_mode ->
| CT_enum _ | CT_poly _ | CT_constant _ | CT_float _ | CT_rounding_mode | CT_memory_writes ->
[]

let codegen_ctg = function
Expand Down
Loading

0 comments on commit df978ba

Please sign in to comment.