Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cross builtins and cross HLSL function to DirectX and SPIR-V backend #109180

Merged
merged 11 commits into from
Oct 3, 2024
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4709,6 +4709,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void*(unsigned char)";
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like an unrelated change sliding back in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the result of a merge, I'll remove it.


def HLSLCross: LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_cross"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot"];
let Attributes = [NoThrow, Const];
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def err_invalid_vector_long_double_decl_spec : Error<
"cannot use 'long double' with '__vector'">;
def err_invalid_vector_complex_decl_spec : Error<
"cannot use '_Complex' with '__vector'">;
def err_invalid_vector_size : Error<
"expected vector size of '%0', but vector size is '%1'">;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little surprised we don't have a similar diagnostic, but it seems that we don't. Maybe we could make a small tweak to an existing message and reuse the message. What if we changed err_vector_incorrect_num_initializers to:

def err_vector_incorrect_num_initializers : Error<
  "%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">;

Then you could add a 0 argument to the one place it is called, and change your new diagnostic to:

Suggested change
def err_invalid_vector_size : Error<
"expected vector size of '%0', but vector size is '%1'">;
def err_incorrect_vector_element_count : Error<
err_vector_incorrect_num_initializers.Summary>;

def warn_vector_long_decl_spec_combination : Warning<
"use of 'long' with '__vector' is deprecated">, InGroup<Deprecated>;

Expand Down
15 changes: 15 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18639,6 +18639,21 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
IsUnsigned ? Intrinsic::dx_uclamp : Intrinsic::dx_clamp,
ArrayRef<Value *>{OpX, OpMin, OpMax}, nullptr, "dx.clamp");
}
case Builtin::BI__builtin_hlsl_cross: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
Value *Op1 = EmitScalarExpr(E->getArg(1));
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasFloatingRepresentation() &&
"cross operands must have a float representation");
// make sure each vector has exactly 3 elements
auto *XVecTy1 = E->getArg(0)->getType()->getAs<VectorType>();
auto *XVecTy2 = E->getArg(1)->getType()->getAs<VectorType>();
assert(XVecTy1->getNumElements() == 3 && XVecTy2->getNumElements() == 3 &&
"input vectors must have 3 elements each");
return Builder.CreateIntrinsic(
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getCrossIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.cross");
}
case Builtin::BI__builtin_hlsl_dot: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
Value *Op1 = EmitScalarExpr(E->getArg(1));
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class CGHLSLRuntime {

GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
GENERATE_HLSL_INTRINSIC_FUNCTION(Cross, cross)
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
Expand Down
22 changes: 22 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -1563,6 +1563,28 @@ uint64_t3 reversebits(uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_bitreverse)
uint64_t4 reversebits(uint64_t4);

//===----------------------------------------------------------------------===//
// cross builtins
//===----------------------------------------------------------------------===//

/// \fn T cross(T x, T y)
/// \brief Returns the cross product of two floating-point, 3D vectors.
/// \param x [in] The first floating-point, 3D vector.
/// \param y [in] The second floating-point, 3D vector.
///
/// Result is the cross product of x and y, i.e., the resulting
/// components are, in order :
/// x[1] * y[2] - y[1] * x[2]
/// x[2] * y[0] - y[2] * x[0]
/// x[0] * y[1] - y[0] * x[1]

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_cross)
half3 cross(half3, half3);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_cross)
float3 cross(float3, float3);

//===----------------------------------------------------------------------===//
// rcp builtins
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 29 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1704,6 +1704,35 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_cross: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
if (CheckVectorElementCallArgs(&SemaRef, TheCall))
return true;
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
return true;
// ensure both args have 3 elements
int NumElementsArg1 =
TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements();
int NumElementsArg2 =
TheCall->getArg(1)->getType()->getAs<VectorType>()->getNumElements();
if (NumElementsArg1 != 3) {
SemaRef.Diag(TheCall->getBeginLoc(), diag::err_invalid_vector_size)
<< NumElementsArg1 << 3;
return true;
}
if (NumElementsArg2 != 3) {
SemaRef.Diag(TheCall->getBeginLoc(), diag::err_invalid_vector_size)
<< NumElementsArg2 << 3;
return true;
}

Copy link
Member

@farzonl farzonl Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should do a vector size == 3 check to prevent misuse of the builtin? Downsize is the check isn't triggerable from the hlsl function name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are also other asserts to catch when the size isn't 3, I don't think this is needed.

Copy link
Member

@farzonl farzonl Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If thats the case could you add a test to clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl As long as this gets caugh by an error in Sema then we should be good.

float2 builtin_cross_float2(float2  p1, float2  p2)
{
  return __builtin_hlsl_cross(p1, p2);
}

float4 builtin_cross_float4(float4  p1, float4  p2)
{
  return __builtin_hlsl_cross(p1, p2);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a diagnostic and put the test in cross-errors.hlsl instead.

Copy link
Contributor

@damyanp damyanp Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't find a float4 builtin_cross_float4(float4 p1, float4 p2) test case in cross-errors.hlsl?

Do float4s implicitly truncate to float3 perhaps?

ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
// return type is the same as the input type
TheCall->setType(ArgTyA);
break;
}
case Builtin::BI__builtin_hlsl_dot: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
Expand Down
36 changes: 36 additions & 0 deletions clang/test/CodeGenHLSL/builtins/cross.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF \
// RUN: -DFNATTRS=noundef -DTARGET=dx
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
// RUN: -DFNATTRS=noundef -DTARGET=dx
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF \
// RUN: -DFNATTRS="spir_func noundef" -DTARGET=spv
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
// RUN: -DFNATTRS="spir_func noundef" -DTARGET=spv
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's harmless, but given that these run lines are long and hard to compare as-is, I'll note for next time that -x hlsl isn't necessary as long as the file ends in .hlsl


// NATIVE_HALF: define [[FNATTRS]] <3 x half> @
// NATIVE_HALF: call <3 x half> @llvm.[[TARGET]].cross.v3f16(<3 x half>
// NO_HALF: call <3 x float> @llvm.[[TARGET]].cross.v3f32(<3 x float>
// NATIVE_HALF: ret <3 x half> %hlsl.cross
// NO_HALF: ret <3 x float> %hlsl.cross
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably clearer to put the NATIVE_HALF checks consecutively, followed by the NO_HALF checks - the way they're intertwined here is a bit hard to read. Also, it looks like the define check is missing for the NO_HALF case.

half3 test_cross_half3(half3 p0, half3 p1)
{
return cross(p0, p1);
}

// CHECK: define [[FNATTRS]] <3 x float> @
// CHECK: %hlsl.cross = call <3 x float> @llvm.[[TARGET]].cross.v3f32(
// CHECK: ret <3 x float> %hlsl.cross
float3 test_cross_float3(float3 p0, float3 p1)
{
return cross(p0, p1);
}
43 changes: 43 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -disable-llvm-passes -verify -verify-ignore-unexpected
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this producing or expected to produce unexpected diagnostics? I see this flag in a few other uses of this in SemaHLSL, but not so much in other Sema* test directories. Perhaps it's getting copied over needlessly?

If there are some tricky diagnostics getting produced, it might be better to limit them by assigning tis to a specific type of diagnostic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this revealed a missing open curly brace on line 18!
I got rid of this extra flag, seems to be a needless copy.


void test_too_few_arg()
{
return __builtin_hlsl_cross();
// expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
}

void test_too_many_arg(float3 p0)
{
return __builtin_hlsl_cross(p0, p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
}

bool builtin_bool_to_float_type_promotion(bool p1)
{
return __builtin_hlsl_cross(p1, p1);
// expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
}

bool builtin_cross_int_to_float_promotion(int p1)
{
return __builtin_hlsl_cross(p1, p1);
// expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
}

bool2 builtin_cross_int2_to_float2_promotion(int2 p1)
{
return __builtin_hlsl_cross(p1, p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if its

float3  builtin_cross_float3_int3(float3 p1, int3 p2)
{
  return __builtin_hlsl_cross(p1, p2);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

emits an error that all args must be the same type. Added this case to cross-errors.hlsl

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok it got promoted in DXC so wanted t check: https://hlsl.godbolt.org/z/3nvGve7WP

}

float2 builtin_cross_float2(float2 p1, float2 p2)
{
return __builtin_hlsl_cross(p1, p2);
// expected-error@-1 {{expected vector size of '2', but vector size is '3'}}
}

float3 builtin_cross_float3_int3(float3 p1, int3 p2)
{
return __builtin_hlsl_cross(p1, p2);
// expected-error@-1 {{all arguments to '__builtin_hlsl_cross' must have the same type}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>
def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;

def int_dx_dot2 :
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ let TargetPrefix = "spv" in {
Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_spv_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_spv_frac : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
Expand Down
40 changes: 40 additions & 0 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_all:
case Intrinsic::dx_any:
case Intrinsic::dx_clamp:
case Intrinsic::dx_cross:
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
case Intrinsic::dx_length:
Expand Down Expand Up @@ -73,6 +74,42 @@ static Value *expandAbs(CallInst *Orig) {
"dx.max");
}

static Value *expandCrossIntrinsic(CallInst *Orig) {

VectorType *VT = cast<VectorType>(Orig->getType());
if (cast<FixedVectorType>(VT)->getNumElements() != 3)
report_fatal_error(Twine("return vector must have exactly 3 elements"),
/* gen_crash_diag=*/false);

Value *op0 = Orig->getOperand(0);
Value *op1 = Orig->getOperand(1);
IRBuilder<> Builder(Orig);

Value *op0_x = Builder.CreateExtractElement(op0, (uint64_t)0, "x0");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious why the 0 needs a cast and the others don't.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the cast I get:
error C2668: 'llvm::IRBuilderBase::CreateExtractElement': ambiguous call to overloaded function note: while trying to match the argument list '(llvm::Value *, int)'
as a compilation error.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't C++ fun!

Value *op0_y = Builder.CreateExtractElement(op0, 1, "x1");
Value *op0_z = Builder.CreateExtractElement(op0, 2, "x2");

Value *op1_x = Builder.CreateExtractElement(op1, (uint64_t)0, "y0");
Value *op1_y = Builder.CreateExtractElement(op1, 1, "y1");
Value *op1_z = Builder.CreateExtractElement(op1, 2, "y2");

auto MulSub = [&](Value *x0, Value *y0, Value *x1, Value *y1) -> Value * {
Value *xy = Builder.CreateFMul(x0, y1);
Value *yx = Builder.CreateFMul(y0, x1);
return Builder.CreateFSub(xy, yx);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we copy the value name from Orig for the final result here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, since we want distinct variable names when constructing the final result vector. Each insertelement needs a unique variable name so that contents of the vector are correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried it out and turns out each variable was assigned a unique variable name, contrary to what I expected. So your comment has been incorporated.

};

Value *yz_zy = MulSub(op0_y, op0_z, op1_y, op1_z);
Value *zx_xz = MulSub(op0_z, op0_x, op1_z, op1_x);
Value *xy_yx = MulSub(op0_x, op0_y, op1_x, op1_y);

Value *cross = UndefValue::get(VT);
cross = Builder.CreateInsertElement(cross, yz_zy, (uint64_t)0);
cross = Builder.CreateInsertElement(cross, zx_xz, 1);
cross = Builder.CreateInsertElement(cross, xy_yx, 2);
return cross;
}

// Create appropriate DXIL float dot intrinsic for the given A and B operands
// The appropriate opcode will be determined by the size of the operands
// The dot product is placed in the position indicated by Orig
Expand Down Expand Up @@ -434,6 +471,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::dx_any:
Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
break;
case Intrinsic::dx_cross:
Result = expandCrossIntrinsic(Orig);
break;
case Intrinsic::dx_uclamp:
case Intrinsic::dx_clamp:
Result = expandClampIntrinsic(Orig, IntrinsicId);
Expand Down
24 changes: 23 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ class SPIRVInstructionSelector : public InstructionSelector {

bool selectCmp(Register ResVReg, const SPIRVType *ResType,
unsigned comparisonOpcode, MachineInstr &I) const;

bool selectCross(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectICmp(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectFCmp(Register ResVReg, const SPIRVType *ResType,
Expand Down Expand Up @@ -1465,6 +1466,25 @@ bool SPIRVInstructionSelector::selectAny(Register ResVReg,
return selectAnyOrAll(ResVReg, ResType, I, SPIRV::OpAny);
}

bool SPIRVInstructionSelector::selectCross(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {

assert(I.getNumOperands() == 4);
assert(I.getOperand(2).isReg());
assert(I.getOperand(3).isReg());
MachineBasicBlock &BB = *I.getParent();

return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
.addImm(GL::Cross)
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg())
.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectFmix(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down Expand Up @@ -2458,6 +2478,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectAll(ResVReg, ResType, I);
case Intrinsic::spv_any:
return selectAny(ResVReg, ResType, I);
case Intrinsic::spv_cross:
return selectCross(ResVReg, ResType, I);
case Intrinsic::spv_lerp:
return selectFmix(ResVReg, ResType, I);
case Intrinsic::spv_length:
Expand Down
56 changes: 56 additions & 0 deletions llvm/test/CodeGen/DirectX/cross.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s

; Make sure dxil operation function calls for cross are generated for half/float.

declare <3 x half> @llvm.dx.cross.v3f16(<3 x half>, <3 x half>)
declare <3 x float> @llvm.dx.cross.v3f32(<3 x float>, <3 x float>)

define noundef <3 x half> @test_cross_half3(<3 x half> noundef %p0, <3 x half> noundef %p1) {
entry:
; CHECK: %x0 = extractelement <3 x half> %p0, i64 0
; CHECK: %x1 = extractelement <3 x half> %p0, i64 1
; CHECK: %x2 = extractelement <3 x half> %p0, i64 2
; CHECK: %y0 = extractelement <3 x half> %p1, i64 0
; CHECK: %y1 = extractelement <3 x half> %p1, i64 1
; CHECK: %y2 = extractelement <3 x half> %p1, i64 2
; CHECK: %0 = fmul half %x1, %y2
; CHECK: %1 = fmul half %x2, %y1
; CHECK: %2 = fsub half %0, %1
; CHECK: %3 = fmul half %x2, %y0
; CHECK: %4 = fmul half %x0, %y2
; CHECK: %5 = fsub half %3, %4
; CHECK: %6 = fmul half %x0, %y1
; CHECK: %7 = fmul half %x1, %y0
; CHECK: %8 = fsub half %6, %7
; CHECK: %9 = insertelement <3 x half> undef, half %2, i64 0
; CHECK: %10 = insertelement <3 x half> %9, half %5, i64 1
; CHECK: %11 = insertelement <3 x half> %10, half %8, i64 2
; CHECK: ret <3 x half> %11
%hlsl.cross = call <3 x half> @llvm.dx.cross.v3f16(<3 x half> %p0, <3 x half> %p1)
ret <3 x half> %hlsl.cross
}

define noundef <3 x float> @test_cross_float3(<3 x float> noundef %p0, <3 x float> noundef %p1) {
entry:
; CHECK: %x0 = extractelement <3 x float> %p0, i64 0
; CHECK: %x1 = extractelement <3 x float> %p0, i64 1
; CHECK: %x2 = extractelement <3 x float> %p0, i64 2
; CHECK: %y0 = extractelement <3 x float> %p1, i64 0
; CHECK: %y1 = extractelement <3 x float> %p1, i64 1
; CHECK: %y2 = extractelement <3 x float> %p1, i64 2
; CHECK: %0 = fmul float %x1, %y2
; CHECK: %1 = fmul float %x2, %y1
; CHECK: %2 = fsub float %0, %1
; CHECK: %3 = fmul float %x2, %y0
; CHECK: %4 = fmul float %x0, %y2
; CHECK: %5 = fsub float %3, %4
; CHECK: %6 = fmul float %x0, %y1
; CHECK: %7 = fmul float %x1, %y0
; CHECK: %8 = fsub float %6, %7
; CHECK: %9 = insertelement <3 x float> undef, float %2, i64 0
; CHECK: %10 = insertelement <3 x float> %9, float %5, i64 1
; CHECK: %11 = insertelement <3 x float> %10, float %8, i64 2
; CHECK: ret <3 x float> %11
%hlsl.cross = call <3 x float> @llvm.dx.cross.v3f32(<3 x float> %p0, <3 x float> %p1)
ret <3 x float> %hlsl.cross
}
Loading
Loading