Skip to content

Commit

Permalink
Merge pull request #974 from yuxiliu-arm/liu/sve-semantics-fix
Browse files Browse the repository at this point in the history
Rework the intrinsic dependencies for SVE instructions

* Add iico_ctrl to AnyActiveElement checks (for gather load and scatter store)
* Add iico_ctrl to ActivePredicateElement checks
* Add iico_order to SVE store1
  • Loading branch information
relokin authored Oct 25, 2024
2 parents 6be733a + 5490096 commit 8e7e63c
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 106 deletions.
246 changes: 141 additions & 105 deletions herd/AArch64Sem.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2092,6 +2092,12 @@ module Make
(read_reg_data sz rs ii >>= tr_input)
an ii

(* Neon/SVE/SME instructions *)
(* Binding operators used by the predicated SVE helpers that follow. *)
(* [let>*]: bind through [M.bind_control_set_data_input_first]. The helpers
   use it right after [commit_pred_txt], so the continuation is presumably
   control-dependent on that commit -- TODO confirm against EventsMonad. *)
let (let>*) = M.bind_control_set_data_input_first
(* [let>=]: the monad's ordinary bind, [M.(>>=)]. *)
let (let>=) = M.(>>=)
(* [and*]: [M.(>>|)], which pairs the results of two computations; meant to
   be combined with [let>=]. *)
let (and*) = M.(>>|)
(* [let<>=]: bind through [M.bind_data_to_output]. *)
let (let<>=) = M.bind_data_to_output

(* Utility that performs an `N`-bit load as two independent `N/2`-bit
* loads. Used by 128-bit Neon LDR.
*
Expand Down Expand Up @@ -2820,127 +2826,157 @@ module Make
let ops = List.mapi op rlist in
List.fold_right (>>::) ops (M.unitT [[()]])

(** [any_active p pred psize nelem ii mtrue mfalse] continues with [mtrue]
    when [pred], the value read from predicate register [p], has at least
    one active element, and with [mfalse] otherwise.

    The chosen branch is bound after a predicate commit labelled
    ["AnyActive(<p>)"], which adds [iico_causality_ctrl] from the predicate
    read to [mtrue] or [mfalse]. *)
let any_active p pred psize nelem ii mtrue mfalse =
  let branch_on any =
    let label = Printf.sprintf "AnyActive(%s)" (A.pp_reg p) in
    M.bind_control_set_data_input_first
      (commit_pred_txt (Some label) ii)
      (fun () -> M.choiceT any mtrue mfalse)
  in
  M.( >>= ) (get_predicate_any pred psize nelem) branch_on

(** [is_active_element p pred psize idx ii mtrue mfalse] inspects element
    [idx] of [pred], the value read from predicate register [p], and
    continues with [mtrue] when that element is active, [mfalse] otherwise.

    The chosen branch is bound after a predicate commit labelled
    ["ActiveElem(<p>, <idx>)"], which adds [iico_causality_ctrl] from the
    predicate read to [mtrue] or [mfalse]. *)
let is_active_element p pred psize idx ii mtrue mfalse =
  let branch_on active =
    let label = Printf.sprintf "ActiveElem(%s, %d)" (A.pp_reg p) idx in
    M.bind_control_set_data_input_first
      (commit_pred_txt (Some label) ii)
      (fun () -> M.choiceT active mtrue mfalse)
  in
  M.( >>= ) (get_predicate_last pred psize idx) branch_on

(* Computation built from [Act.NoAction] by [M.mk_singleton_es]. Call sites
   apply it to [ii] and replace its result with zero
   ([no_action ii >>! M.A.V.zero]) as the branch taken for inactive
   predicate elements. *)
let no_action = M.mk_singleton_es Act.NoAction

(** [para_fold_right mbind ops munit] performs [ops] in parallel, gathering
    their results into a list with [>>::], then folds [mbind] over that list
    from the right starting from [munit]. The gathering stage and the folding
    stage are joined by [M.data_output_union]. *)
let para_fold_right mbind ops munit =
  let gathered = List.fold_right ( >>:: ) ops (M.unitT []) in
  let combine results =
    let step v macc = macc >>= mbind v in
    List.fold_right step results munit
  in
  M.data_output_union gathered combine

(** [load_predicated_elem_or_zero_m sz p ma rlist ii] performs a predicated
    load into the scalable registers [rlist].

    The base address is produced by [ma] and the governing predicate is read
    from [p]. For register number [i] and element [idx], an active element is
    loaded from [base + (idx * nregs + i) * MachSize.nbytes sz], promoted and
    shifted left by [idx * esize]; an inactive element performs [no_action]
    and contributes zero. The per-element values of one register are combined
    with [Op.Or] and the result is written to that register.

    Element geometry ([nelem], [psize], [esize]) is taken from the first
    register of [rlist], so [rlist] must be non-empty ([List.hd]). *)
let load_predicated_elem_or_zero_m sz p ma rlist ii =
  let r = List.hd rlist in
  let nelem = scalable_nelem r in
  let psize = predicate_psize r in
  let esize = scalable_esize r in
  let nregs = List.length rlist in
  let>= results =
    let<>= base = ma in
    let>= pred = read_reg_predicate false p ii in
    (* Computes the value destined for register number [i]. *)
    let ops i =
      let op idx =
        let load =
          let offset = (idx * nregs + i) * MachSize.nbytes sz in
          let>= addr = M.op1 (Op.AddK offset) base in
          let>= v = do_read_mem_ret sz Annot.N aexp Access.VIR addr ii in
          let>= v = promote v in
          M.op1 (Op.LeftShift (idx * esize)) v
        in
        is_active_element p pred psize idx ii load
          (no_action ii >>! M.A.V.zero)
      in
      let ops = List.map op (Misc.interval 0 nelem) in
      para_fold_right (M.op Op.Or) ops mzero
    in
    let ops = List.map ops (Misc.interval 0 nregs) in
    List.fold_right ( >>:: ) ops (M.unitT [])
  in
  (* One register write per (register, combined value) pair. *)
  let f (r, result) macc = write_reg_scalable r result ii >>:: macc in
  List.fold_right f (List.combine rlist results) (M.unitT [()])

(** [store_predicated_elem_or_merge_m sz p ma rlist ii] performs a predicated
    store of the scalable registers [rlist].

    The base address is produced by [ma] and the governing predicate is read
    from [p]. Each source register is read only when the predicate has some
    active element ([any_active]); otherwise [mzero] stands in for its value.
    For register number [i] and element [idx], an active element is demoted
    and written to [base + (idx * nregs + i) * MachSize.nbytes sz]; an
    inactive element does nothing. Stores are chained with [M.seq_mem_list].

    Element geometry ([nelem], [psize], [esize]) is taken from the first
    register of [rlist], so [rlist] must be non-empty ([List.hd]). *)
let store_predicated_elem_or_merge_m sz p ma rlist ii =
  let r = List.hd rlist in
  let nelem = scalable_nelem r in
  let psize = predicate_psize r in
  let esize = scalable_esize r in
  let nregs = List.length rlist in
  let<>= base = ma in
  let>= pred = read_reg_predicate false p ii in
  (* Emits the stores for register [r], which is number [i] in [rlist]. *)
  let ops i r =
    let<>= v =
      any_active p pred psize nelem ii
        (read_reg_scalable true r ii)
        mzero
    in
    let op idx =
      let store =
        let offset = (idx * nregs + i) * MachSize.nbytes sz in
        let>= addr = M.op1 (Op.AddK offset) base
        and* v = scalable_getlane v idx esize >>= demote in
        write_mem sz aexp Access.VIR addr v ii
      in
      is_active_element p pred psize idx ii store (M.unitT ())
    in
    let ops = List.map op (Misc.interval 0 nelem) in
    List.fold_right M.seq_mem_list ops (M.unitT [])
  in
  let ops = List.mapi ops rlist in
  List.fold_right M.seq_mem_list ops (M.unitT [])

(** [load_gather_predicated_elem_or_zero sz p ma mo rs e k ii] performs a
    predicated gather load into the first register of [rs].

    The governing predicate is read from [p]. The base ([ma]) and the offset
    vector ([mo]) are evaluated only when the predicate has some active
    element ([any_active]); otherwise both are replaced by zero. For each
    active element [idx], lane [idx] of the offsets is demoted, extended with
    [memext_sext e k] and added to the base; the value loaded from that
    address is promoted and shifted left by [idx * esize]. An inactive
    element performs [no_action] and contributes zero. The per-element values
    are combined with [Op.Or] and written to the destination register.

    [rs] must be non-empty ([List.hd]). *)
let load_gather_predicated_elem_or_zero sz p ma mo rs e k ii =
  let r = List.hd rs in
  let psize = predicate_psize r in
  let nelem = scalable_nelem r in
  let esize = scalable_esize r in
  let>= pred = read_reg_predicate false p ii in
  let>= result =
    let<>= (base, offsets) =
      any_active p pred psize nelem ii
        (ma >>| mo)
        (M.unitT M.A.V.(zero, zero))
    in
    let op idx =
      let load =
        let>= lane = scalable_getlane offsets idx esize in
        let>= lane = demote lane in
        let>= o = memext_sext e k lane in
        let>= addr = M.add base o in
        let>= v = do_read_mem_ret sz Annot.N aexp Access.VIR addr ii in
        let>= v = promote v in
        M.op1 (Op.LeftShift (idx * esize)) v
      in
      is_active_element p pred psize idx ii load
        (no_action ii >>! M.A.V.zero)
    in
    let ops = List.map op (Misc.interval 0 nelem) in
    para_fold_right (M.op Op.Or) ops mzero
  in
  write_reg_scalable r result ii

(** [store_scatter_predicated_elem_or_merge sz p ma mo rs e k ii] performs a
    predicated scatter store of the first register of [rs].

    The governing predicate is read from [p]. The base ([ma]), the offset
    vector ([mo]) and the source register are evaluated only when the
    predicate has some active element ([any_active]); otherwise all three are
    replaced by zero. For each active element [idx], lane [idx] of the
    offsets is demoted, extended with [memext_sext e k] and added to the
    base; lane [idx] of the source value is demoted and written to that
    address. An inactive element does nothing. Stores are chained with
    [M.seq_mem_list].

    [rs] must be non-empty ([List.hd]). *)
let store_scatter_predicated_elem_or_merge sz p ma mo rs e k ii =
  let r = List.hd rs in
  let psize = predicate_psize r in
  let nelem = scalable_nelem r in
  let esize = scalable_esize r in
  let>= pred = read_reg_predicate false p ii in
  let<>= ((base, offsets), v) =
    any_active p pred psize nelem ii
      (ma >>| mo >>| read_reg_scalable true r ii)
      (M.unitT ((M.A.V.zero, M.A.V.zero), M.A.V.zero))
  in
  let op idx =
    let store =
      let>= lane = scalable_getlane offsets idx esize in
      let>= lane = demote lane in
      let>= o = memext_sext e k lane in
      let>= addr = M.add base o in
      let>= v = scalable_getlane v idx esize in
      let>= v = demote v in
      write_mem sz aexp Access.VIR addr v ii
    in
    (* The commit label must name the predicate register [p], not the data
       register [r], as at every other [is_active_element] call site. *)
    is_active_element p pred psize idx ii store (M.unitT ())
  in
  let ops = List.map op (Misc.interval 0 nelem) in
  List.fold_right M.seq_mem_list ops (M.unitT [()])

let load_predicated_slice sz r ri k p ma ii =
let dst,tile,dir,esize = match r with
Expand Down Expand Up @@ -3596,18 +3632,18 @@ module Make
| I_ST4SP(var,rs,p,rA,MemExt.Imm (k,Idx)) ->
check_sve inst;
!!!!(let sz = tr_simd_variant var in
let ma = get_ea_idx rA k ii in
store_predicated_elem_or_merge_m sz p ma rs ii >>|
M.unitT ())
let ma = get_ea_idx rA k ii in
store_predicated_elem_or_merge_m sz p ma rs ii >>|
M.unitT ())
| I_ST1SP(var,rs,p,rA,MemExt.Reg (V64,rM,MemExt.LSL,s))
| I_ST2SP(var,rs,p,rA,MemExt.Reg (V64,rM,MemExt.LSL,s))
| I_ST3SP(var,rs,p,rA,MemExt.Reg (V64,rM,MemExt.LSL,s))
| I_ST4SP(var,rs,p,rA,MemExt.Reg (V64,rM,MemExt.LSL,s)) ->
check_sve inst;
!!!!(let sz = tr_simd_variant var in
let ma = get_ea_reg rA V64 rM MemExt.LSL s ii in
store_predicated_elem_or_merge_m sz p ma rs ii >>|
M.unitT ())
let ma = get_ea_reg rA V64 rM MemExt.LSL s ii in
store_predicated_elem_or_merge_m sz p ma rs ii >>|
M.unitT ())
| I_ST1SP (var,rs,p,rA,MemExt.ZReg (rM,sext,s)) ->
check_sve inst;
!!!(let sz = tr_simd_variant var in
Expand Down
24 changes: 24 additions & 0 deletions herd/event.ml
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ val same_instance : event -> event -> bool
val data_input_union : (* input to both structures *)
event_structure -> event_structure -> event_structure option

(* Data composition of two event structures. The implementation always
   returns [Some _]: unlike [data_to_minimals] it is not wrapped in
   [check_disjoint]. *)
val data_to_output : (* input to second es output *)
event_structure -> event_structure -> event_structure option

val data_to_minimals : (* second es entries are minimal evts all iico *)
event_structure -> event_structure -> event_structure option

Expand All @@ -362,6 +365,9 @@ val same_instance : event -> event -> bool
val (=$$=) :
event_structure -> event_structure -> event_structure option

(* Data composition of two event structures whose result takes its [output]
   from [union_output] of both arguments. The implementation always returns
   [Some _] (no [check_disjoint]). *)
val data_output_union :
event_structure -> event_structure -> event_structure option

(* sequential composition, add control dependency *)
val (=**=) :
event_structure -> event_structure -> event_structure option
Expand Down Expand Up @@ -1470,6 +1476,18 @@ module Make (C:Config) (AI:Arch_herd.S) (Act:Action.S with module A = AI) :
let data_to_minimals =
check_disjoint (data_comp minimals sequence_data_output)

(* Data composition built with [get_output] and [sequence_data_output]; the
   [input] and [data_input] fields of the composed structure are then
   replaced by [union_input_seq] and [union_data_input_seq] of the two
   arguments. Always returns [Some _]. *)
let data_to_output es1 es2 =
  let composed = data_comp get_output sequence_data_output es1 es2 in
  Some
    { composed with
      input = union_input_seq es1 es2;
      data_input = union_data_input_seq es1 es2;
    }

let (=$$=) =
let out es1 es2 =
let out = get_output es1 in
Expand All @@ -1479,6 +1497,12 @@ module Make (C:Config) (AI:Arch_herd.S) (Act:Action.S with module A = AI) :
Some out in
check_disjoint (data_comp minimals_data out)

(* Data composition built with [minimals] and [sequence_data_output]; the
   [output] field of the composed structure is then replaced by
   [union_output] of the two arguments. Always returns [Some _]. *)
let data_output_union es1 es2 =
  let composed = data_comp minimals sequence_data_output es1 es2 in
  let output = union_output es1 es2 in
  Some { composed with output }

(* Composition with intra_causality_control from first to second *)

Expand Down
9 changes: 9 additions & 0 deletions herd/eventsMonad.ml
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,16 @@ Monad type:
let (>>==) : 'a t -> ('a -> 'b t) -> ('b) t
= fun s f -> data_comp (=$$=) s f

(* Bind whose event structures are combined with [E.data_output_union]. *)
let data_output_union (s : 'a t) (f : 'a -> 'b t) : 'b t =
  data_comp E.data_output_union s f

let asl_data s f = data_comp E.data_po_seq s f

let (>>*=) : 'a t -> ('a -> 'b t) -> ('b) t
= fun s f -> data_comp (=**=) s f

let control_input_union s f = data_comp E.control_input_union s f
(* Bind whose event structures are combined with [E.control_input_next]. *)
let control_input_next s f = data_comp E.control_input_next s f

let (>>*==) : 'a t -> ('a -> 'b t) -> ('b) t
= fun s f -> data_comp (=*$$=) s f
Expand All @@ -291,6 +295,8 @@ Monad type:

let bind_data_to_minimals s f = data_comp E.data_to_minimals s f

(* Bind whose event structures are combined with [E.data_to_output]. *)
let bind_data_to_output s f = data_comp E.data_to_output s f

(* Triple composition *)
let comp_comp comp_str m1 m2 m3 eiid =
let eiid,(acts1,spec1) = m1 eiid in
Expand Down Expand Up @@ -790,6 +796,9 @@ Monad type:
let seq_mem : 'a t -> 'b t -> ('a * 'b) t
= fun s1 s2 -> combi Misc.pair E.seq_mem s1 s2

(* List-building variant of [seq_mem]: the event structures are combined
   with [E.seq_mem] and the first result is consed ([Misc.cons]) onto the
   list produced by the second computation. *)
let seq_mem_list (s1 : 'a t) (s2 : 'a list t) : 'a list t =
  combi Misc.cons E.seq_mem s1 s2

(* Force monad value *)
let forceT (v : 'a) : 'b t -> 'a t =
let f (_, vcl, es) = (v, vcl, es) in
Expand Down
Loading

0 comments on commit 8e7e63c

Please sign in to comment.