Skip to content

Commit

Permalink
i#5036 A64 scatter/gather, part 7: Expand replicating loads (#6483)
Browse files Browse the repository at this point in the history
Adds support to drx_expand_scatter_gather() for SVE scalar+scalar and
scalar+immediate replicating predicated contiguous load and store
instructions, along with tests.

Issue: #5036
  • Loading branch information
jackgallagher-arm committed Dec 1, 2023
1 parent b9441b3 commit df27443
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@ Hello, world!
---- <application exited with code 0> ----
Basic counts tool results:
Total counts:
655 total \(fetched\) instructions
243 total unique \(fetched\) instructions
685 total \(fetched\) instructions
255 total unique \(fetched\) instructions
0 total non-fetched instructions
0 total prefetches
#if (__ARM_FEATURE_SVE_BITS == 128)
1069 total data loads
1137 total data loads
861 total data stores
#elif (__ARM_FEATURE_SVE_BITS == 256)
1967 total data loads
2035 total data loads
1595 total data stores
#elif (__ARM_FEATURE_SVE_BITS == 512)
3763 total data loads
3831 total data loads
3063 total data stores
#endif
0 total icache flushes
Expand All @@ -22,18 +22,18 @@ Total counts:
.* total scheduling markers
.*
Thread .* counts:
655 \(fetched\) instructions
243 unique \(fetched\) instructions
685 \(fetched\) instructions
255 unique \(fetched\) instructions
0 non-fetched instructions
0 prefetches
#if (__ARM_FEATURE_SVE_BITS == 128)
1069 data loads
1137 data loads
861 data stores
#elif (__ARM_FEATURE_SVE_BITS == 256)
1967 data loads
2035 data loads
1595 data stores
#elif (__ARM_FEATURE_SVE_BITS == 512)
3763 data loads
3831 data loads
3063 data stores
#endif
0 icache flushes
Expand Down
40 changes: 30 additions & 10 deletions clients/drcachesim/tests/allasm_scattergather_aarch64.asm
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,21 @@ test_scalar_plus_immediate:

ret

test_replicating_loads:
ld1rqb DEST_REG1.b, B_MASK_REG/z, [BUFFER_REG, X_INDEX_REG] // 16
ld1rqh DEST_REG1.h, H_MASK_REG/z, [BUFFER_REG, X_INDEX_REG, lsl #1] // 8
ld1rqw DEST_REG1.s, S_MASK_REG/z, [BUFFER_REG, X_INDEX_REG, lsl #2] // 4
ld1rqd DEST_REG1.d, D_MASK_REG/z, [BUFFER_REG, X_INDEX_REG, lsl #3] // 2
// Total: 30

ld1rqb DEST_REG1.b, B_MASK_REG/z, [BUFFER_REG, #0] // 16
ld1rqh DEST_REG1.h, H_MASK_REG/z, [BUFFER_REG, #0] // 8
ld1rqw DEST_REG1.s, S_MASK_REG/z, [BUFFER_REG, #0] // 4
ld1rqd DEST_REG1.d, D_MASK_REG/z, [BUFFER_REG, #0] // 2
// Total: 30

ret

_start:
#ifdef __APPLE__
adrp BUFFER_REG, buffer@PAGE
Expand Down Expand Up @@ -407,8 +422,10 @@ _start:

bl test_scalar_plus_immediate // +(374 * vl_bytes/16) loads
// +(322 * vl_bytes/16) stores
bl test_replicating_loads // +60 loads
// +0 stores
// Running total:
// Loads: (136 + 14 + 374 + 374) * vl_bytes/16 = 898 * vl_bytes/16
// Loads: (136 + 14 + 374 + 374) * vl_bytes/16 + 60 = 898 * vl_bytes/16 + 60
// Stores: (82 + 8 + 322 + 322) * vl_bytes/16 = 734 * vl_bytes/16

/* Run all the instructions with no active elements */
Expand All @@ -422,9 +439,10 @@ _start:
bl test_vector_plus_immediate // +0 loads, +0 stores
bl test_scalar_plus_scalar // +0 loads, +0 stores
bl test_scalar_plus_immediate // +0 loads, +0 stores
bl test_replicating_loads // +0 loads, +0 stores

// Running total (unchanged from above):
// Loads: 898 * vl_bytes/16
// Loads: (898 * vl_bytes/16) + 60
// Stores: 734 * vl_bytes/16

/* Run all instructions with one active element */
Expand All @@ -437,26 +455,28 @@ _start:
bl test_vector_plus_immediate // +7 loads, +4 stores
bl test_scalar_plus_scalar // +56 loads, +46 stores
bl test_scalar_plus_immediate // +56 loads, +46 stores
bl test_replicating_loads // +8 loads, +0 stores

// Running total:
// Loads: (898 * vl_bytes/16) + 52 + 7 + 56 + 56 = (898 * vl_bytes/16) + 171
// Loads: (898 * vl_bytes/16) + 60 + 52 + 7 + 56 + 56 + 8 = (898 * vl_bytes/16) + 239
// Stores: (734 * vl_bytes/16) + 41 + 4 + 46 + 46 = (734 * vl_bytes/16) + 127

// The functions in this file have the following instructions counts:
// _start 37
// _start 40
// test_scalar_plus_vector 84
// test_vector_plus_immediate 12
// test_scalar_plus_scalar 55
// test_scalar_plus_immediate 55
// So there are 37 + 84 + 12 + 55 + 55 = 243 unique instructions
// We run the test_* functions 3 times each so the totoal instruction executed is
// ((84 + 12 + 55 + 55) * 3) + 37 = (206 * 3) + 37 = 655
// test_replicating_loads 9
// So there are 40 + 84 + 12 + 55 + 55 + 9 = 255 unique instructions
// We run the test_* functions 3 times each so the total instruction executed is
// ((84 + 12 + 55 + 55 + 9) * 3) + 40 = (215 * 3) + 37 = 685

// Totals:
// Loads: (898 * vl_bytes/16) + 171
// Loads: (898 * vl_bytes/16) + 239
// Stores: (734 * vl_bytes/16) + 127
// Instructions: 703
// Unique instructions: 259
// Instructions: 685
// Unique instructions: 255

// Exit.
mov w0, #1 // stdout
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
Hello, world!
Basic counts tool results:
Total counts:
655 total \(fetched\) instructions
243 total unique \(fetched\) instructions
685 total \(fetched\) instructions
255 total unique \(fetched\) instructions
0 total non-fetched instructions
0 total prefetches
#if (__ARM_FEATURE_SVE_BITS == 128)
1069 total data loads
1137 total data loads
861 total data stores
#elif (__ARM_FEATURE_SVE_BITS == 256)
1967 total data loads
2035 total data loads
1595 total data stores
#elif (__ARM_FEATURE_SVE_BITS == 512)
3763 total data loads
3831 total data loads
3063 total data stores
#endif
0 total icache flushes
Expand All @@ -21,18 +21,18 @@ Total counts:
.* total scheduling markers
.*
Thread .* counts:
655 \(fetched\) instructions
243 unique \(fetched\) instructions
685 \(fetched\) instructions
255 unique \(fetched\) instructions
0 non-fetched instructions
0 prefetches
#if (__ARM_FEATURE_SVE_BITS == 128)
1069 data loads
1137 data loads
861 data stores
#elif (__ARM_FEATURE_SVE_BITS == 256)
1967 data loads
2035 data loads
1595 data stores
#elif (__ARM_FEATURE_SVE_BITS == 512)
3763 data loads
3831 data loads
3063 data stores
#endif
0 icache flushes
Expand Down
10 changes: 10 additions & 0 deletions clients/drcachesim/tests/scattergather-aarch64.templatex
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ ldnt1w scalar\+scalar: PASS
ld1sw scalar\+scalar: PASS
ld1d scalar\+scalar: PASS
ldnt1d scalar\+scalar: PASS
ld1rqb scalar\+scalar: PASS
ld1rqh scalar\+scalar: PASS
ld1rqw scalar\+scalar: PASS
ld1rqd scalar\+scalar: PASS
ld2b scalar\+scalar: PASS
ld2h scalar\+scalar: PASS
ld2w scalar\+scalar: PASS
Expand Down Expand Up @@ -212,6 +216,12 @@ ld1d scalar\+immediate 64bit element: PASS
ld1d scalar\+immediate 64bit element \(min index\): PASS
ld1d scalar\+immediate 64bit element \(max index\): PASS
ldnt1d scalar\+immediate 64bit element: PASS
ld1rqb scalar\+immediate: PASS
ld1rqh scalar\+immediate: PASS
ld1rqw scalar\+immediate: PASS
ld1rqd scalar\+immediate: PASS
ld1rqd scalar\+immediate \(min index\): PASS
ld1rqd scalar\+immediate \(max index\): PASS
ld2b scalar\+immediate: PASS
ld2h scalar\+immediate: PASS
ld2w scalar\+immediate: PASS
Expand Down
102 changes: 81 additions & 21 deletions ext/drx/scatter_gather_aarch64.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ typedef struct _per_thread_t {
* This corresponds to the spill slot storage in per_thread_t.
*/
typedef struct _spill_slot_state_t {
#define NUM_PRED_SLOTS 1
#define NUM_PRED_SLOTS 2
reg_id_t pred_slots[NUM_PRED_SLOTS];

#define NUM_VECTOR_SLOTS 1
Expand Down Expand Up @@ -621,12 +621,37 @@ static void
expand_contiguous(void *drcontext, instrlist_t *bb, instr_t *sg_instr,
const scatter_gather_info_t *sg_info, reg_id_t new_base,
reg_id_t scalar_index, reg_id_t scalar_src_or_dst,
reg_id_t scratch_pred, reg_id_t scratch_vec, app_pc orig_app_pc)
reg_id_t scratch_pred, reg_id_t governing_pred, reg_id_t scratch_vec,
app_pc orig_app_pc)
{
#define EMIT(op, ...) \
instrlist_preinsert( \
bb, sg_instr, INSTR_XL8(INSTR_CREATE_##op(drcontext, __VA_ARGS__), orig_app_pc))

if (sg_info->is_replicating && proc_get_vector_length_bytes() > 16) {
/* This instruction loads a fixed size 16-byte vector which is replicated to
* all quadword elements on hardware with a vector length > 16 bytes.
* Only the bottom 16 bits of the governing predicate register are used so we
* need to mask out any higher bits than that.
*/
DR_ASSERT(sg_info->scatter_gather_size == OPSZ_16);

/* Set scratch_pred to a value with the first 16 elements active */
/* ptrue scratch_pred.b, vl16 */
EMIT(ptrue_sve, opnd_create_reg_element_vector(scratch_pred, OPSZ_1),
opnd_create_immed_pred_constr(DR_PRED_CONSTR_VL16));

/* Create a new governing predicate by applying the mask we created in
* scratch_pred to the instruction's mask_reg.
*/

/* and governing_pred.b, mask_reg/z, mask_reg.b, scratch_pred.b */
EMIT(and_sve_pred_b, opnd_create_reg_element_vector(governing_pred, OPSZ_1),
opnd_create_predicate_reg(sg_info->mask_reg, /*merging=*/false),
opnd_create_reg_element_vector(sg_info->mask_reg, OPSZ_1),
opnd_create_reg_element_vector(scratch_pred, OPSZ_1));
}

/* Calculate the new base address in scratch_gpr0.
* Note that we can't use drutil_insert_get_mem_addr() here because we don't want the
* BSD licensed drx to have a dependency on the LGPL licensed drutil.
Expand Down Expand Up @@ -686,6 +711,7 @@ expand_contiguous(void *drcontext, instrlist_t *bb, instr_t *sg_instr,
} else {
/* scalar+scalar: Keep the original modifier copied from sg_info */
}
modified_sg_info.mask_reg = governing_pred;

/* Note that modified_sg_info might not describe a valid SVE instruction.
* For example if we are expanding:
Expand All @@ -702,6 +728,21 @@ expand_contiguous(void *drcontext, instrlist_t *bb, instr_t *sg_instr,
/* Expand the instruction as if it were a scalar+vector scatter/gather instruction */
expand_scatter_gather(drcontext, bb, sg_instr, &modified_sg_info, scalar_index,
scalar_src_or_dst, scratch_pred, orig_app_pc);

if (sg_info->is_replicating && proc_get_vector_length_bytes() > 16) {
/* All supported replicating loads load a 16-byte vector. */
DR_ASSERT(sg_info->scatter_gather_size == OPSZ_16);

/* Replicate the first quadword element (16 bytes) to the other elements in the
* vector.
*/

/* dup gather_dst.q, gather_dst.q[0]*/
EMIT(dup_sve_idx,
opnd_create_reg_element_vector(sg_info->gather_dst_reg, OPSZ_16),
opnd_create_reg_element_vector(sg_info->gather_dst_reg, OPSZ_16),
opnd_create_immed_uint(0, OPSZ_2b));
}
#undef EMIT
}

Expand All @@ -713,13 +754,20 @@ expand_contiguous(void *drcontext, instrlist_t *bb, instr_t *sg_instr,
reg_id_t
reserve_sve_register(void *drcontext, instrlist_t *bb, instr_t *where,
reg_id_t scratch_gpr0, reg_id_t min_register, reg_id_t max_register,
size_t slot_tls_offset, opnd_size_t reg_size, uint slot_num)
size_t slot_tls_offset, opnd_size_t reg_size, uint slot_num,
reg_id_t *already_allocated_regs, uint num_already_allocated)
{
/* Search the instruction for an unused register we will use as a temp. */
reg_id_t reg;
for (reg = min_register; reg <= max_register; ++reg) {
if (!instr_uses_reg(where, reg))
break;
if (!instr_uses_reg(where, reg)) {
bool reg_already_allocated = false;
for (uint i = 0; !reg_already_allocated && i < num_already_allocated; i++) {
reg_already_allocated = already_allocated_regs[i] == reg;
}
if (!reg_already_allocated)
break;
}
}
DR_ASSERT(!instr_uses_reg(where, reg));

Expand Down Expand Up @@ -759,10 +807,11 @@ reserve_pred_register(void *drcontext, instrlist_t *bb, instr_t *where,
/* Some instructions require the predicate to be in the range p0 - p7. This includes
* LASTB which we use to extract elements from the vector register.
*/
const reg_id_t reg = reserve_sve_register(
drcontext, bb, where, scratch_gpr0, DR_REG_P0, DR_REG_P7,
offsetof(per_thread_t, scratch_pred_spill_slots),
opnd_size_from_bytes(proc_get_vector_length_bytes() / 8), slot);
const reg_id_t reg =
reserve_sve_register(drcontext, bb, where, scratch_gpr0, DR_REG_P0, DR_REG_P7,
offsetof(per_thread_t, scratch_pred_spill_slots),
opnd_size_from_bytes(proc_get_vector_length_bytes() / 8),
slot, slot_state->pred_slots, slot);

slot_state->pred_slots[slot] = reg;
return reg;
Expand All @@ -783,7 +832,8 @@ reserve_vector_register(void *drcontext, instrlist_t *bb, instr_t *where,
const reg_id_t reg =
reserve_sve_register(drcontext, bb, where, scratch_gpr0, DR_REG_Z0, DR_REG_Z31,
offsetof(per_thread_t, scratch_vector_spill_slots_aligned),
opnd_size_from_bytes(proc_get_vector_length_bytes()), slot);
opnd_size_from_bytes(proc_get_vector_length_bytes()), slot,
slot_state->vector_slots, slot);

slot_state->vector_slots[slot] = reg;
return reg;
Expand Down Expand Up @@ -897,10 +947,6 @@ drx_expand_scatter_gather(void *drcontext, instrlist_t *bb, DR_PARAM_OUT bool *e
/* TODO i#5036: Add support for first-fault and non-fault accesses. */
return true;
}
if (sg_info.is_replicating) {
/* TODO i#5036: Add support for ld1rq* replicating loads. */
return true;
}

const bool is_contiguous =
!(reg_is_z(sg_info.base_reg) || reg_is_z(sg_info.index_reg));
Expand Down Expand Up @@ -967,6 +1013,12 @@ drx_expand_scatter_gather(void *drcontext, instrlist_t *bb, DR_PARAM_OUT bool *e
&spill_slot_state);
}

reg_id_t governing_pred = sg_info.mask_reg;
if (sg_info.is_replicating && proc_get_vector_length_bytes() > 16) {
governing_pred = reserve_pred_register(drcontext, bb, sg_instr, scratch_gpr,
&spill_slot_state);
}

const app_pc orig_app_pc = instr_get_app_pc(sg_instr);

emulated_instr_t emulated_instr;
Expand All @@ -980,8 +1032,8 @@ drx_expand_scatter_gather(void *drcontext, instrlist_t *bb, DR_PARAM_OUT bool *e
if (is_contiguous) {
/* scalar+scalar or scalar+immediate predicated contiguous access */
expand_contiguous(drcontext, bb, sg_instr, &sg_info, contiguous_new_base,
scratch_gpr, scalar_src_or_dst, scratch_pred, scratch_vec,
orig_app_pc);
scratch_gpr, scalar_src_or_dst, scratch_pred, governing_pred,
scratch_vec, orig_app_pc);
} else {
/* scalar+vector or vector+immediate scatter/gather */
expand_scatter_gather(drcontext, bb, sg_instr, &sg_info, scratch_gpr,
Expand All @@ -990,13 +1042,21 @@ drx_expand_scatter_gather(void *drcontext, instrlist_t *bb, DR_PARAM_OUT bool *e

drmgr_insert_emulation_end(drcontext, bb, sg_instr);

if (scratch_vec != DR_REG_INVALID) {
unreserve_vector_register(drcontext, bb, sg_instr, scratch_gpr, scratch_vec,
&spill_slot_state);
for (uint i = 0; i < NUM_VECTOR_SLOTS; i++) {
const reg_id_t reg = spill_slot_state.vector_slots[i];
if (reg != DR_REG_NULL) {
unreserve_vector_register(drcontext, bb, sg_instr, scratch_gpr, reg,
&spill_slot_state);
}
}

unreserve_pred_register(drcontext, bb, sg_instr, scratch_gpr, scratch_pred,
&spill_slot_state);
for (uint i = 0; i < NUM_PRED_SLOTS; i++) {
const reg_id_t reg = spill_slot_state.pred_slots[i];
if (reg != DR_REG_NULL) {
unreserve_pred_register(drcontext, bb, sg_instr, scratch_gpr, reg,
&spill_slot_state);
}
}

if (drreg_unreserve_register(drcontext, bb, sg_instr, scratch_gpr) != DRREG_SUCCESS) {
DR_ASSERT_MSG(false, "drreg_unreserve_register should not fail");
Expand Down
Loading

0 comments on commit df27443

Please sign in to comment.