Skip to content

Commit

Permalink
Fix regression in batched SSE2 patch.
Browse files Browse the repository at this point in the history
Signed-off-by: Tuomas Tonteri <[email protected]>
  • Loading branch information
johnfea committed Aug 16, 2024
1 parent e0197db commit 9e5c674
Showing 1 changed file with 59 additions and 16 deletions.
75 changes: 59 additions & 16 deletions src/liboslexec/llvm_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3698,12 +3698,21 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
// Convert <4 x i1> -> <4 x i32>
llvm::Value* w4_int_mask = builder().CreateSExt(mask,
type_wide_int());

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value. However the only 256bit
// version works on floats, so we will cast from int32 to
// float beforehand
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
llvm::Value* w4_float_mask = builder().CreateBitCast(w4_int_mask,
w4_float_type);

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value.
llvm::Function* func = llvm::Intrinsic::getDeclaration(
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
module(), llvm::Intrinsic::x86_sse_movmsk_ps);

llvm::Value* args[1] = { w4_int_mask };
llvm::Value* args[1] = { w4_float_mask };
llvm::Value* int8_mask;
int8_mask = builder().CreateCall(func, toArrayRef(args));
return int8_mask;
Expand All @@ -3727,18 +3736,28 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
auto w4_int_masks = op_quarter_16x(wide_int_mask);

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value.
// to build a 32 bit mask value. However the only 128bit
// version works on floats, so we will cast from int32 to
// float beforehand
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
std::array<llvm::Value*, 4> w4_float_masks = {
{ builder().CreateBitCast(w4_int_masks[0], w4_float_type),
builder().CreateBitCast(w4_int_masks[1], w4_float_type),
builder().CreateBitCast(w4_int_masks[2], w4_float_type),
builder().CreateBitCast(w4_int_masks[3], w4_float_type) }
};

llvm::Function* func = llvm::Intrinsic::getDeclaration(
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
module(), llvm::Intrinsic::x86_sse_movmsk_ps);

llvm::Value* args[1] = { w4_int_masks[0] };
llvm::Value* args[1] = { w4_float_masks[0] };
std::array<llvm::Value*, 4> int4_masks;
int4_masks[0] = builder().CreateCall(func, toArrayRef(args));
args[0] = w4_int_masks[1];
args[0] = w4_float_masks[1];
int4_masks[1] = builder().CreateCall(func, toArrayRef(args));
args[0] = w4_int_masks[2];
args[0] = w4_float_masks[2];
int4_masks[2] = builder().CreateCall(func, toArrayRef(args));
args[0] = w4_int_masks[3];
args[0] = w4_float_masks[3];
int4_masks[3] = builder().CreateCall(func, toArrayRef(args));

llvm::Value* bits12_15 = op_shl(int4_masks[3], constant(12));
Expand All @@ -3759,14 +3778,22 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
auto w4_int_masks = op_split_8x(wide_int_mask);

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value.
// to build a 32 bit mask value. However the only 128bit
// version works on floats, so we will cast from int32 to
// float beforehand
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
std::array<llvm::Value*, 2> w4_float_masks = {
{ builder().CreateBitCast(w4_int_masks[0], w4_float_type),
builder().CreateBitCast(w4_int_masks[1], w4_float_type) }
};

llvm::Function* func = llvm::Intrinsic::getDeclaration(
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
module(), llvm::Intrinsic::x86_sse_movmsk_ps);

llvm::Value* args[1] = { w4_int_masks[0] };
llvm::Value* args[1] = { w4_float_masks[0] };
std::array<llvm::Value*, 2> int4_masks;
int4_masks[0] = builder().CreateCall(func, toArrayRef(args));
args[0] = w4_int_masks[1];
args[0] = w4_float_masks[1];
int4_masks[1] = builder().CreateCall(func, toArrayRef(args));

llvm::Value* bits4_7 = op_shl(int4_masks[1], constant(4));
Expand All @@ -3782,12 +3809,20 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
llvm::Value* w4_int_mask = builder().CreateSExt(mask,
type_wide_int());

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value. However the only 256bit
// version works on floats, so we will cast from int32 to
// float beforehand
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
llvm::Value* w4_float_mask = builder().CreateBitCast(w4_int_mask,
w4_float_type);

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value.
llvm::Function* func = llvm::Intrinsic::getDeclaration(
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
module(), llvm::Intrinsic::x86_sse_movmsk_ps);

llvm::Value* args[1] = { w4_int_mask };
llvm::Value* args[1] = { w4_float_mask };
llvm::Value* int4_mask = builder().CreateCall(func,
toArrayRef(args));

Expand Down Expand Up @@ -3833,12 +3868,20 @@ LLVM_Util::mask4_as_int8(llvm::Value* mask)
// Convert <4 x i1> -> <4 x i32>
llvm::Value* w4_int_mask = builder().CreateSExt(mask, type_wide_int());

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value. However the only 256bit
// version works on floats, so we will cast from int32 to
// float beforehand
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
llvm::Value* w4_float_mask = builder().CreateBitCast(w4_int_mask,
w4_float_type);

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value.
llvm::Function* func = llvm::Intrinsic::getDeclaration(
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
module(), llvm::Intrinsic::x86_sse_movmsk_ps);

llvm::Value* args[1] = { w4_int_mask };
llvm::Value* args[1] = { w4_float_mask };
llvm::Value* int32 = builder().CreateCall(func, toArrayRef(args));
llvm::Value* i8 = builder().CreateIntCast(int32, type_int8(), true);

Expand Down

0 comments on commit 9e5c674

Please sign in to comment.