Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cross builtins and cross HLSL function to DirectX and SPIR-V backend #109180

Merged
merged 11 commits into from
Oct 3, 2024

Conversation

bob80905
Copy link
Contributor

This PR adds the step intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.

Used #106471 as a reference.
Fixes #99095

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen backend:DirectX HLSL HLSL Language Support backend:SPIR-V llvm:ir labels Sep 18, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Sep 18, 2024

@llvm/pr-subscribers-backend-spir-v
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-directx

Author: Joshua Batista (bob80905)

Changes

This PR adds the step intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.

Used #106471 as a reference.
Fixes #99095


Full diff: https://github.com/llvm/llvm-project/pull/109180.diff

13 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+15)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+22)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+14)
  • (added) clang/test/CodeGenHLSL/builtins/cross.hlsl (+36)
  • (added) clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl (+31)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+40)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+23-1)
  • (added) llvm/test/CodeGen/DirectX/cross.ll (+57)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cross.ll (+33)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8c5d7ad763bf97..c4735b59dfb0be 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4709,6 +4709,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void*(unsigned char)";
 }
 
+def HLSLCross: LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_cross"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_dot"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7e18aafcdd4b8a..0883ad17ce0ffd 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18639,6 +18639,21 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         IsUnsigned ? Intrinsic::dx_uclamp : Intrinsic::dx_clamp,
         ArrayRef<Value *>{OpX, OpMin, OpMax}, nullptr, "dx.clamp");
   }
+  case Builtin::BI__builtin_hlsl_cross: {
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    Value *Op1 = EmitScalarExpr(E->getArg(1));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           "step operands must have a float representation");
+    // make sure each vector has exactly 3 elements
+    auto *XVecTy1 = E->getArg(0)->getType()->getAs<VectorType>();
+    auto *XVecTy2 = E->getArg(1)->getType()->getAs<VectorType>();
+    assert(XVecTy1->getNumElements() == 3 && XVecTy2->getNumElements() == 3 &&
+           "input vectors must have 3 elements each");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getCrossIntrinsic(),
+        ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.cross");
+  }
   case Builtin::BI__builtin_hlsl_dot: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     Value *Op1 = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index a8aabca7348ffb..6722d2c7c50a2b 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -74,6 +74,7 @@ class CGHLSLRuntime {
 
   GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Cross, cross)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 6cd6a2caf19994..e4b40978dea6b9 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1563,6 +1563,28 @@ uint64_t3 reversebits(uint64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_bitreverse)
 uint64_t4 reversebits(uint64_t4);
 
+//===----------------------------------------------------------------------===//
+// cross builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T cross(T x, T y)
+/// \brief Returns the cross product of two floating-point, 3D vectors.
+/// \param x [in] The first floating-point, 3D vector.
+/// \param y [in] The second floating-point, 3D vector.
+///
+/// Result is the cross product of x and y, i.e., the resulting
+/// components are, in order :
+/// x[1] * y[2] - y[1] * x[2]
+/// x[2] * y[0] - y[2] * x[0]
+/// x[0] * y[1] - y[0] * x[1]
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_cross)
+half3 cross(half3, half3);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_cross)
+float3 cross(float3, float3);
+
 //===----------------------------------------------------------------------===//
 // rcp builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a303f211501348..e7e7a3c259e27a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1704,6 +1704,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_cross: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
+      return true;
+    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    // return type is the same as the input type
+    TheCall->setType(ArgTyA);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_dot: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/cross.hlsl b/clang/test/CodeGenHLSL/builtins/cross.hlsl
new file mode 100644
index 00000000000000..047e7fef7136fe
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/cross.hlsl
@@ -0,0 +1,36 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF \
+// RUN:   -DFNATTRS=noundef -DTARGET=dx
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
+// RUN:   -DFNATTRS=noundef -DTARGET=dx
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF \
+// RUN:   -DFNATTRS="spir_func noundef" -DTARGET=spv
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
+// RUN:   -DFNATTRS="spir_func noundef" -DTARGET=spv
+
+// NATIVE_HALF: define [[FNATTRS]] <3 x half> @
+// NATIVE_HALF: call <3 x half> @llvm.[[TARGET]].cross.v3f16(<3 x half>
+// NO_HALF: call <3 x float> @llvm.[[TARGET]].cross.v3f32(<3 x float>
+// NATIVE_HALF: ret <3 x half> %hlsl.cross
+// NO_HALF: ret <3 x float> %hlsl.cross
+half3 test_cross_half3(half3 p0, half3 p1)
+{
+    return cross(p0, p1);
+}
+
+// CHECK: define [[FNATTRS]] <3 x float> @
+// CHECK: %hlsl.cross = call <3 x float> @llvm.[[TARGET]].cross.v3f32(
+// CHECK: ret <3 x float> %hlsl.cross
+float3 test_cross_float3(float3 p0, float3 p1)
+{
+    return cross(p0, p1);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl
new file mode 100644
index 00000000000000..40ab4b533a495e
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+  return __builtin_hlsl_cross();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+void test_too_many_arg(float3 p0)
+{
+  return __builtin_hlsl_cross(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+  return __builtin_hlsl_cross(p1, p1);
+  // expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_cross_int_to_float_promotion(int p1)
+{
+  return __builtin_hlsl_cross(p1, p1);
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_cross_int2_to_float2_promotion(int2 p1)
+{
+  return __builtin_hlsl_cross(p1, p1);
+  // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 3ce7b8b987ef86..f4242772bab20c 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -44,6 +44,7 @@ def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
 def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index a4c01952927175..480b391bd54fdf 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -60,6 +60,7 @@ let TargetPrefix = "spv" in {
       Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
   def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
   def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
+  def int_spv_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
   def int_spv_frac : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
   def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
     [IntrNoMem, IntrWillReturn] >;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index dd73b895b14d37..e921dffede38f8 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_all:
   case Intrinsic::dx_any:
   case Intrinsic::dx_clamp:
+  case Intrinsic::dx_cross:
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_lerp:
   case Intrinsic::dx_length:
@@ -73,6 +74,42 @@ static Value *expandAbs(CallInst *Orig) {
                                  "dx.max");
 }
 
+static Value *expandCrossIntrinsic(CallInst *Orig) {
+
+  VectorType *VT = cast<VectorType>(Orig->getType());
+  if (cast<FixedVectorType>(VT)->getNumElements() != 3)
+    report_fatal_error(Twine("return vector must have exactly 3 elements"),
+                       /* gen_crash_diag=*/false);
+
+  Value *op0 = Orig->getOperand(0);
+  Value *op1 = Orig->getOperand(1);
+  IRBuilder<> Builder(Orig);
+
+  Value *op0_x = Builder.CreateExtractElement(op0, (uint64_t)0);
+  Value *op0_y = Builder.CreateExtractElement(op0, 1);
+  Value *op0_z = Builder.CreateExtractElement(op0, 2);
+
+  Value *op1_x = Builder.CreateExtractElement(op1, (uint64_t)0);
+  Value *op1_y = Builder.CreateExtractElement(op1, 1);
+  Value *op1_z = Builder.CreateExtractElement(op1, 2);
+
+  auto MulSub = [&](Value *x0, Value *y0, Value *x1, Value *y1) -> Value * {
+    Value *xy = Builder.CreateFMul(x0, y1);
+    Value *yx = Builder.CreateFMul(y0, x1);
+    return Builder.CreateFSub(xy, yx);
+  };
+
+  Value *yz_zy = MulSub(op0_y, op0_z, op1_y, op1_z);
+  Value *zx_xz = MulSub(op0_z, op0_x, op1_z, op1_x);
+  Value *xy_yx = MulSub(op0_x, op0_y, op1_x, op1_y);
+
+  Value *cross = UndefValue::get(VT);
+  cross = Builder.CreateInsertElement(cross, yz_zy, (uint64_t)0);
+  cross = Builder.CreateInsertElement(cross, zx_xz, 1);
+  cross = Builder.CreateInsertElement(cross, xy_yx, 2);
+  return cross;
+}
+
 // Create appropriate DXIL float dot intrinsic for the given A and B operands
 // The appropriate opcode will be determined by the size of the operands
 // The dot product is placed in the position indicated by Orig
@@ -434,6 +471,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::dx_any:
     Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
     break;
+  case Intrinsic::dx_cross:
+    Result = expandCrossIntrinsic(Orig);
+    break;
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_clamp:
     Result = expandClampIntrinsic(Orig, IntrinsicId);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b526c9f29f1e6a..677119840709aa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -167,7 +167,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
 
   bool selectCmp(Register ResVReg, const SPIRVType *ResType,
                  unsigned comparisonOpcode, MachineInstr &I) const;
-
+  bool selectCross(Register ResVReg, const SPIRVType *ResType,
+                   MachineInstr &I) const;
   bool selectICmp(Register ResVReg, const SPIRVType *ResType,
                   MachineInstr &I) const;
   bool selectFCmp(Register ResVReg, const SPIRVType *ResType,
@@ -1465,6 +1466,25 @@ bool SPIRVInstructionSelector::selectAny(Register ResVReg,
   return selectAnyOrAll(ResVReg, ResType, I, SPIRV::OpAny);
 }
 
+bool SPIRVInstructionSelector::selectCross(Register ResVReg,
+                                           const SPIRVType *ResType,
+                                           MachineInstr &I) const {
+
+  assert(I.getNumOperands() == 4);
+  assert(I.getOperand(2).isReg());
+  assert(I.getOperand(3).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(GL::Cross)
+      .addUse(I.getOperand(2).getReg())
+      .addUse(I.getOperand(3).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectFmix(Register ResVReg,
                                           const SPIRVType *ResType,
                                           MachineInstr &I) const {
@@ -2458,6 +2478,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectAll(ResVReg, ResType, I);
   case Intrinsic::spv_any:
     return selectAny(ResVReg, ResType, I);
+  case Intrinsic::spv_cross:
+    return selectCross(ResVReg, ResType, I);
   case Intrinsic::spv_lerp:
     return selectFmix(ResVReg, ResType, I);
   case Intrinsic::spv_length:
diff --git a/llvm/test/CodeGen/DirectX/cross.ll b/llvm/test/CodeGen/DirectX/cross.ll
new file mode 100644
index 00000000000000..90847ac635dbba
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/cross.ll
@@ -0,0 +1,57 @@
+; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefix=CHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefix=CHECK
+
+; Make sure dxil operation function calls for cross are generated for half/float.
+
+declare <3 x half> @llvm.dx.cross.v3f16(<3 x half>, <3 x half>)
+declare <3 x float> @llvm.dx.cross.v3f32(<3 x float>, <3 x float>)
+
+define noundef <3 x half> @test_cross_half3(<3 x half> noundef %p0, <3 x half> noundef %p1) {
+entry:
+  ; CHECK: %0 = extractelement <3 x half> %p0, i64 0
+  ; CHECK: %1 = extractelement <3 x half> %p0, i64 1
+  ; CHECK: %2 = extractelement <3 x half> %p0, i64 2
+  ; CHECK: %3 = extractelement <3 x half> %p1, i64 0
+  ; CHECK: %4 = extractelement <3 x half> %p1, i64 1
+  ; CHECK: %5 = extractelement <3 x half> %p1, i64 2
+  ; CHECK: %6 = fmul half %1, %5
+  ; CHECK: %7 = fmul half %2, %4
+  ; CHECK: %8 = fsub half %6, %7
+  ; CHECK: %9 = fmul half %2, %3
+  ; CHECK: %10 = fmul half %0, %5
+  ; CHECK: %11 = fsub half %9, %10
+  ; CHECK: %12 = fmul half %0, %4
+  ; CHECK: %13 = fmul half %1, %3
+  ; CHECK: %14 = fsub half %12, %13
+  ; CHECK: %15 = insertelement <3 x half> undef, half %8, i64 0
+  ; CHECK: %16 = insertelement <3 x half> %15, half %11, i64 1
+  ; CHECK: %17 = insertelement <3 x half> %16, half %14, i64 2
+  ; CHECK: ret <3 x half> %17
+  %hlsl.cross = call <3 x half> @llvm.dx.cross.v3f16(<3 x half> %p0, <3 x half> %p1)
+  ret <3 x half> %hlsl.cross
+}
+
+define noundef <3 x float> @test_cross_float3(<3 x float> noundef %p0, <3 x float> noundef %p1) {
+entry:
+  ; CHECK: %0 = extractelement <3 x float> %p0, i64 0
+  ; CHECK: %1 = extractelement <3 x float> %p0, i64 1
+  ; CHECK: %2 = extractelement <3 x float> %p0, i64 2
+  ; CHECK: %3 = extractelement <3 x float> %p1, i64 0
+  ; CHECK: %4 = extractelement <3 x float> %p1, i64 1
+  ; CHECK: %5 = extractelement <3 x float> %p1, i64 2
+  ; CHECK: %6 = fmul float %1, %5
+  ; CHECK: %7 = fmul float %2, %4
+  ; CHECK: %8 = fsub float %6, %7
+  ; CHECK: %9 = fmul float %2, %3
+  ; CHECK: %10 = fmul float %0, %5
+  ; CHECK: %11 = fsub float %9, %10
+  ; CHECK: %12 = fmul float %0, %4
+  ; CHECK: %13 = fmul float %1, %3
+  ; CHECK: %14 = fsub float %12, %13
+  ; CHECK: %15 = insertelement <3 x float> undef, float %8, i64 0
+  ; CHECK: %16 = insertelement <3 x float> %15, float %11, i64 1
+  ; CHECK: %17 = insertelement <3 x float> %16, float %14, i64 2
+  ; CHECK: ret <3 x float> %17
+  %hlsl.cross = call <3 x float> @llvm.dx.cross.v3f32(<3 x float> %p0, <3 x float> %p1)
+  ret <3 x float> %hlsl.cross
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cross.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cross.ll
new file mode 100644
index 00000000000000..2e0eb8c429ac27
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cross.ll
@@ -0,0 +1,33 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIRV operation function calls for cross are lowered correctly.
+
+; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec3_float_16:]] = OpTypeVector %[[#float_16]] 3
+; CHECK-DAG: %[[#vec3_float_32:]] = OpTypeVector %[[#float_32]] 3
+
+define noundef <3 x half> @cross_half4(<3 x half> noundef %a, <3 x half> noundef %b) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec3_float_16]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_16]]
+  ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_16]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec3_float_16]] %[[#op_ext_glsl]] Cross %[[#arg0]] %[[#arg1]]
+  %hlsl.cross = call <3 x half> @llvm.spv.cross.v4f16(<3 x half> %a, <3 x half> %b)
+  ret <3 x half> %hlsl.cross
+}
+
+define noundef <3 x float> @cross_float4(<3 x float> noundef %a, <3 x float> noundef %b) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec3_float_32]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_32]]
+  ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_32]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec3_float_32]] %[[#op_ext_glsl]] Cross %[[#arg0]] %[[#arg1]]
+  %hlsl.cross = call <3 x float> @llvm.spv.cross.v4f32(<3 x float> %a, <3 x float> %b)
+  ret <3 x float> %hlsl.cross
+}
+
+declare <3 x half> @llvm.spv.cross.v4f16(<3 x half>, <3 x half>)
+declare <3 x float> @llvm.spv.cross.v4f32(<3 x float>, <3 x float>)

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 18, 2024

@llvm/pr-subscribers-clang-codegen

Author: Joshua Batista (bob80905)

Changes

This PR adds the step intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.

Used #106471 as a reference.
Fixes #99095


Full diff: https://github.com/llvm/llvm-project/pull/109180.diff

13 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+15)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+22)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+14)
  • (added) clang/test/CodeGenHLSL/builtins/cross.hlsl (+36)
  • (added) clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl (+31)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+40)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+23-1)
  • (added) llvm/test/CodeGen/DirectX/cross.ll (+57)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cross.ll (+33)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8c5d7ad763bf97..c4735b59dfb0be 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4709,6 +4709,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void*(unsigned char)";
 }
 
+def HLSLCross: LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_cross"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_dot"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7e18aafcdd4b8a..0883ad17ce0ffd 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18639,6 +18639,21 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         IsUnsigned ? Intrinsic::dx_uclamp : Intrinsic::dx_clamp,
         ArrayRef<Value *>{OpX, OpMin, OpMax}, nullptr, "dx.clamp");
   }
+  case Builtin::BI__builtin_hlsl_cross: {
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    Value *Op1 = EmitScalarExpr(E->getArg(1));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           "step operands must have a float representation");
+    // make sure each vector has exactly 3 elements
+    auto *XVecTy1 = E->getArg(0)->getType()->getAs<VectorType>();
+    auto *XVecTy2 = E->getArg(1)->getType()->getAs<VectorType>();
+    assert(XVecTy1->getNumElements() == 3 && XVecTy2->getNumElements() == 3 &&
+           "input vectors must have 3 elements each");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getCrossIntrinsic(),
+        ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.cross");
+  }
   case Builtin::BI__builtin_hlsl_dot: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     Value *Op1 = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index a8aabca7348ffb..6722d2c7c50a2b 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -74,6 +74,7 @@ class CGHLSLRuntime {
 
   GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Cross, cross)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 6cd6a2caf19994..e4b40978dea6b9 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1563,6 +1563,28 @@ uint64_t3 reversebits(uint64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_bitreverse)
 uint64_t4 reversebits(uint64_t4);
 
+//===----------------------------------------------------------------------===//
+// cross builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T cross(T x, T y)
+/// \brief Returns the cross product of two floating-point, 3D vectors.
+/// \param x [in] The first floating-point, 3D vector.
+/// \param y [in] The second floating-point, 3D vector.
+///
+/// Result is the cross product of x and y, i.e., the resulting
+/// components are, in order :
+/// x[1] * y[2] - y[1] * x[2]
+/// x[2] * y[0] - y[2] * x[0]
+/// x[0] * y[1] - y[0] * x[1]
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_cross)
+half3 cross(half3, half3);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_cross)
+float3 cross(float3, float3);
+
 //===----------------------------------------------------------------------===//
 // rcp builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a303f211501348..e7e7a3c259e27a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1704,6 +1704,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_cross: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
+      return true;
+    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    // return type is the same as the input type
+    TheCall->setType(ArgTyA);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_dot: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/cross.hlsl b/clang/test/CodeGenHLSL/builtins/cross.hlsl
new file mode 100644
index 00000000000000..047e7fef7136fe
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/cross.hlsl
@@ -0,0 +1,36 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF \
+// RUN:   -DFNATTRS=noundef -DTARGET=dx
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
+// RUN:   -DFNATTRS=noundef -DTARGET=dx
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF \
+// RUN:   -DFNATTRS="spir_func noundef" -DTARGET=spv
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
+// RUN:   -DFNATTRS="spir_func noundef" -DTARGET=spv
+
+// NATIVE_HALF: define [[FNATTRS]] <3 x half> @
+// NATIVE_HALF: call <3 x half> @llvm.[[TARGET]].cross.v3f16(<3 x half>
+// NO_HALF: call <3 x float> @llvm.[[TARGET]].cross.v3f32(<3 x float>
+// NATIVE_HALF: ret <3 x half> %hlsl.cross
+// NO_HALF: ret <3 x float> %hlsl.cross
+half3 test_cross_half3(half3 p0, half3 p1)
+{
+    return cross(p0, p1);
+}
+
+// CHECK: define [[FNATTRS]] <3 x float> @
+// CHECK: %hlsl.cross = call <3 x float> @llvm.[[TARGET]].cross.v3f32(
+// CHECK: ret <3 x float> %hlsl.cross
+float3 test_cross_float3(float3 p0, float3 p1)
+{
+    return cross(p0, p1);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl
new file mode 100644
index 00000000000000..40ab4b533a495e
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+  return __builtin_hlsl_cross();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+void test_too_many_arg(float3 p0)
+{
+  return __builtin_hlsl_cross(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+  return __builtin_hlsl_cross(p1, p1);
+  // expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_cross_int_to_float_promotion(int p1)
+{
+  return __builtin_hlsl_cross(p1, p1);
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_cross_int2_to_float2_promotion(int2 p1)
+{
+  return __builtin_hlsl_cross(p1, p1);
+  // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 3ce7b8b987ef86..f4242772bab20c 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -44,6 +44,7 @@ def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
 def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index a4c01952927175..480b391bd54fdf 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -60,6 +60,7 @@ let TargetPrefix = "spv" in {
       Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
   def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
   def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
+  def int_spv_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
   def int_spv_frac : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
   def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
     [IntrNoMem, IntrWillReturn] >;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index dd73b895b14d37..e921dffede38f8 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_all:
   case Intrinsic::dx_any:
   case Intrinsic::dx_clamp:
+  case Intrinsic::dx_cross:
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_lerp:
   case Intrinsic::dx_length:
@@ -73,6 +74,42 @@ static Value *expandAbs(CallInst *Orig) {
                                  "dx.max");
 }
 
+static Value *expandCrossIntrinsic(CallInst *Orig) {
+
+  VectorType *VT = cast<VectorType>(Orig->getType());
+  if (cast<FixedVectorType>(VT)->getNumElements() != 3)
+    report_fatal_error(Twine("return vector must have exactly 3 elements"),
+                       /* gen_crash_diag=*/false);
+
+  Value *op0 = Orig->getOperand(0);
+  Value *op1 = Orig->getOperand(1);
+  IRBuilder<> Builder(Orig);
+
+  Value *op0_x = Builder.CreateExtractElement(op0, (uint64_t)0);
+  Value *op0_y = Builder.CreateExtractElement(op0, 1);
+  Value *op0_z = Builder.CreateExtractElement(op0, 2);
+
+  Value *op1_x = Builder.CreateExtractElement(op1, (uint64_t)0);
+  Value *op1_y = Builder.CreateExtractElement(op1, 1);
+  Value *op1_z = Builder.CreateExtractElement(op1, 2);
+
+  auto MulSub = [&](Value *x0, Value *y0, Value *x1, Value *y1) -> Value * {
+    Value *xy = Builder.CreateFMul(x0, y1);
+    Value *yx = Builder.CreateFMul(y0, x1);
+    return Builder.CreateFSub(xy, yx);
+  };
+
+  Value *yz_zy = MulSub(op0_y, op0_z, op1_y, op1_z);
+  Value *zx_xz = MulSub(op0_z, op0_x, op1_z, op1_x);
+  Value *xy_yx = MulSub(op0_x, op0_y, op1_x, op1_y);
+
+  Value *cross = UndefValue::get(VT);
+  cross = Builder.CreateInsertElement(cross, yz_zy, (uint64_t)0);
+  cross = Builder.CreateInsertElement(cross, zx_xz, 1);
+  cross = Builder.CreateInsertElement(cross, xy_yx, 2);
+  return cross;
+}
+
 // Create appropriate DXIL float dot intrinsic for the given A and B operands
 // The appropriate opcode will be determined by the size of the operands
 // The dot product is placed in the position indicated by Orig
@@ -434,6 +471,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::dx_any:
     Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
     break;
+  case Intrinsic::dx_cross:
+    Result = expandCrossIntrinsic(Orig);
+    break;
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_clamp:
     Result = expandClampIntrinsic(Orig, IntrinsicId);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b526c9f29f1e6a..677119840709aa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -167,7 +167,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
 
   bool selectCmp(Register ResVReg, const SPIRVType *ResType,
                  unsigned comparisonOpcode, MachineInstr &I) const;
-
+  bool selectCross(Register ResVReg, const SPIRVType *ResType,
+                   MachineInstr &I) const;
   bool selectICmp(Register ResVReg, const SPIRVType *ResType,
                   MachineInstr &I) const;
   bool selectFCmp(Register ResVReg, const SPIRVType *ResType,
@@ -1465,6 +1466,25 @@ bool SPIRVInstructionSelector::selectAny(Register ResVReg,
   return selectAnyOrAll(ResVReg, ResType, I, SPIRV::OpAny);
 }
 
+bool SPIRVInstructionSelector::selectCross(Register ResVReg,
+                                           const SPIRVType *ResType,
+                                           MachineInstr &I) const {
+
+  assert(I.getNumOperands() == 4);
+  assert(I.getOperand(2).isReg());
+  assert(I.getOperand(3).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(GL::Cross)
+      .addUse(I.getOperand(2).getReg())
+      .addUse(I.getOperand(3).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectFmix(Register ResVReg,
                                           const SPIRVType *ResType,
                                           MachineInstr &I) const {
@@ -2458,6 +2478,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectAll(ResVReg, ResType, I);
   case Intrinsic::spv_any:
     return selectAny(ResVReg, ResType, I);
+  case Intrinsic::spv_cross:
+    return selectCross(ResVReg, ResType, I);
   case Intrinsic::spv_lerp:
     return selectFmix(ResVReg, ResType, I);
   case Intrinsic::spv_length:
diff --git a/llvm/test/CodeGen/DirectX/cross.ll b/llvm/test/CodeGen/DirectX/cross.ll
new file mode 100644
index 00000000000000..90847ac635dbba
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/cross.ll
@@ -0,0 +1,57 @@
+; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefix=CHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefix=CHECK
+
+; Make sure dxil operation function calls for cross are generated for half/float.
+
+declare <3 x half> @llvm.dx.cross.v3f16(<3 x half>, <3 x half>)
+declare <3 x float> @llvm.dx.cross.v3f32(<3 x float>, <3 x float>)
+
+define noundef <3 x half> @test_cross_half3(<3 x half> noundef %p0, <3 x half> noundef %p1) {
+entry:
+  ; CHECK: %0 = extractelement <3 x half> %p0, i64 0
+  ; CHECK: %1 = extractelement <3 x half> %p0, i64 1
+  ; CHECK: %2 = extractelement <3 x half> %p0, i64 2
+  ; CHECK: %3 = extractelement <3 x half> %p1, i64 0
+  ; CHECK: %4 = extractelement <3 x half> %p1, i64 1
+  ; CHECK: %5 = extractelement <3 x half> %p1, i64 2
+  ; CHECK: %6 = fmul half %1, %5
+  ; CHECK: %7 = fmul half %2, %4
+  ; CHECK: %8 = fsub half %6, %7
+  ; CHECK: %9 = fmul half %2, %3
+  ; CHECK: %10 = fmul half %0, %5
+  ; CHECK: %11 = fsub half %9, %10
+  ; CHECK: %12 = fmul half %0, %4
+  ; CHECK: %13 = fmul half %1, %3
+  ; CHECK: %14 = fsub half %12, %13
+  ; CHECK: %15 = insertelement <3 x half> undef, half %8, i64 0
+  ; CHECK: %16 = insertelement <3 x half> %15, half %11, i64 1
+  ; CHECK: %17 = insertelement <3 x half> %16, half %14, i64 2
+  ; CHECK: ret <3 x half> %17
+  %hlsl.cross = call <3 x half> @llvm.dx.cross.v3f16(<3 x half> %p0, <3 x half> %p1)
+  ret <3 x half> %hlsl.cross
+}
+
+define noundef <3 x float> @test_cross_float3(<3 x float> noundef %p0, <3 x float> noundef %p1) {
+entry:
+  ; CHECK: %0 = extractelement <3 x float> %p0, i64 0
+  ; CHECK: %1 = extractelement <3 x float> %p0, i64 1
+  ; CHECK: %2 = extractelement <3 x float> %p0, i64 2
+  ; CHECK: %3 = extractelement <3 x float> %p1, i64 0
+  ; CHECK: %4 = extractelement <3 x float> %p1, i64 1
+  ; CHECK: %5 = extractelement <3 x float> %p1, i64 2
+  ; CHECK: %6 = fmul float %1, %5
+  ; CHECK: %7 = fmul float %2, %4
+  ; CHECK: %8 = fsub float %6, %7
+  ; CHECK: %9 = fmul float %2, %3
+  ; CHECK: %10 = fmul float %0, %5
+  ; CHECK: %11 = fsub float %9, %10
+  ; CHECK: %12 = fmul float %0, %4
+  ; CHECK: %13 = fmul float %1, %3
+  ; CHECK: %14 = fsub float %12, %13
+  ; CHECK: %15 = insertelement <3 x float> undef, float %8, i64 0
+  ; CHECK: %16 = insertelement <3 x float> %15, float %11, i64 1
+  ; CHECK: %17 = insertelement <3 x float> %16, float %14, i64 2
+  ; CHECK: ret <3 x float> %17
+  %hlsl.cross = call <3 x float> @llvm.dx.cross.v3f32(<3 x float> %p0, <3 x float> %p1)
+  ret <3 x float> %hlsl.cross
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cross.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cross.ll
new file mode 100644
index 00000000000000..2e0eb8c429ac27
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cross.ll
@@ -0,0 +1,33 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIRV operation function calls for cross are lowered correctly.
+
+; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec3_float_16:]] = OpTypeVector %[[#float_16]] 3
+; CHECK-DAG: %[[#vec3_float_32:]] = OpTypeVector %[[#float_32]] 3
+
+define noundef <3 x half> @cross_half4(<3 x half> noundef %a, <3 x half> noundef %b) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec3_float_16]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_16]]
+  ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_16]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec3_float_16]] %[[#op_ext_glsl]] Cross %[[#arg0]] %[[#arg1]]
+  %hlsl.cross = call <3 x half> @llvm.spv.cross.v4f16(<3 x half> %a, <3 x half> %b)
+  ret <3 x half> %hlsl.cross
+}
+
+define noundef <3 x float> @cross_float4(<3 x float> noundef %a, <3 x float> noundef %b) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec3_float_32]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_32]]
+  ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_32]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec3_float_32]] %[[#op_ext_glsl]] Cross %[[#arg0]] %[[#arg1]]
+  %hlsl.cross = call <3 x float> @llvm.spv.cross.v4f32(<3 x float> %a, <3 x float> %b)
+  ret <3 x float> %hlsl.cross
+}
+
+declare <3 x half> @llvm.spv.cross.v4f16(<3 x half>, <3 x half>)
+declare <3 x float> @llvm.spv.cross.v4f32(<3 x float>, <3 x float>)

Value *Op1 = EmitScalarExpr(E->getArg(1));
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasFloatingRepresentation() &&
"step operands must have a float representation");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"step operands must have a float representation");
"cross operands must have a float representation");

@@ -44,6 +44,7 @@ def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
Copy link
Member

@farzonl farzonl Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are trying to keep this alphabetical, but don't put this between the clamp intrinsics.

return true;
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
return true;

Copy link
Member

@farzonl farzonl Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should do a vector size == 3 check to prevent misuse of the builtin? Downsize is the check isn't triggerable from the hlsl function name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are also other asserts to catch when the size isn't 3, I don't think this is needed.

Copy link
Member

@farzonl farzonl Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If thats the case could you add a test to clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl As long as this gets caugh by an error in Sema then we should be good.

float2 builtin_cross_float2(float2  p1, float2  p2)
{
  return __builtin_hlsl_cross(p1, p2);
}

float4 builtin_cross_float4(float4  p1, float4  p2)
{
  return __builtin_hlsl_cross(p1, p2);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a diagnostic and put the test in cross-errors.hlsl instead.

Copy link
Contributor

@damyanp damyanp Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't find a float4 builtin_cross_float4(float4 p1, float4 p2) test case in cross-errors.hlsl?

Do float4s implicitly truncate to float3 perhaps?

Value *op1 = Orig->getOperand(1);
IRBuilder<> Builder(Orig);

Value *op0_x = Builder.CreateExtractElement(op0, (uint64_t)0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: in general I don't think this matters but since cross product is crossing over elements
with x1*y2 - x2*y1 I think its worth naming these elements like so Builder.CreateExtractElement(op0, 0, "x0");
This will effect the SSA names so it isn't just %1-%n

@bob80905 bob80905 self-assigned this Sep 23, 2024
bool2 builtin_cross_int2_to_float2_promotion(int2 p1)
{
return __builtin_hlsl_cross(p1, p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if its

float3  builtin_cross_float3_int3(float3 p1, int3 p2)
{
  return __builtin_hlsl_cross(p1, p2);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

emits an error that all args must be the same type. Added this case to cross-errors.hlsl

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok it got promoted in DXC so wanted t check: https://hlsl.godbolt.org/z/3nvGve7WP

@@ -0,0 +1,57 @@
; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefix=CHECK
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typicall we do --check-prefix if one run is different than the other. in this case they are the same. My recommendation is to drop one of these runs and remove the check prefix.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point still stands why do you need two runs?

Comment on lines 315 to 316
def err_invalid_vector_size : Error<
"expected vector size of '%0', but vector size is '%1'">;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little surprised we don't have a similar diagnostic, but it seems that we don't. Maybe we could make a small tweak to an existing message and reuse the message. What if we changed err_vector_incorrect_num_initializers to:

def err_vector_incorrect_num_initializers : Error<
  "%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">;

Then you could add a 0 argument to the one place it is called, and change your new diagnostic to:

Suggested change
def err_invalid_vector_size : Error<
"expected vector size of '%0', but vector size is '%1'">;
def err_incorrect_vector_element_count : Error<
err_vector_incorrect_num_initializers.Summary>;

@@ -10506,7 +10506,9 @@ def err_second_argument_to_cwsc_not_pointer : Error<
"second argument to __builtin_call_with_static_chain must be of pointer type">;

def err_vector_incorrect_num_initializers : Error<
"%select{too many|too few}0 elements in vector initialization (expected %1 elements, have %2)">;
"%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">;
def err_incorrect_vector_element_count : Error<
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you updated all the uses of this to use the other diagnostic, which is fine, but then we can remove this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also maybe rename err_vector_incorrect_num_initializers -> err_vector_incorrect_num_elements?

let Spellings = ["__builtin_hlsl_create_handle"];
let Attributes = [NoThrow, Const];
let Prototype = "void*(unsigned char)";
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like an unrelated change sliding back in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the result of a merge, I'll remove it.

@@ -1977,7 +1977,7 @@ void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
if (!VerifyOnly)
SemaRef.Diag(IList->getBeginLoc(),
diag::err_vector_incorrect_num_initializers)
<< (numEltsInit < maxElements) << maxElements << numEltsInit;
<< (numEltsInit < maxElements) << maxElements << numEltsInit << 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: please comment the meaning of the number

Suggested change
<< (numEltsInit < maxElements) << maxElements << numEltsInit << 0;
<< (numEltsInit < maxElements) << maxElements << numEltsInit << /*initialization*/ 0;


SemaRef.Diag(TheCall->getBeginLoc(),
diag::err_vector_incorrect_num_initializers)
<< LessOrMore << 3 << NumElementsArg2 << 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
<< LessOrMore << 3 << NumElementsArg2 << 1;
<< LessOrMore << 3 << NumElementsArg2 << /*operand*/ 1;

int LessOrMore = NumElementsArg1 > 3 ? 1 : 0;
SemaRef.Diag(TheCall->getBeginLoc(),
diag::err_vector_incorrect_num_initializers)
<< LessOrMore << 3 << NumElementsArg1 << 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
<< LessOrMore << 3 << NumElementsArg1 << 1;
<< LessOrMore << 3 << NumElementsArg1 << /*operand*/ 1;

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few last small things, but mostly looks good.

Comment on lines 4712 to 4716
def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_create_handle"];
let Attributes = [NoThrow, Const];
let Prototype = "void*(unsigned char)";
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was removed recently, looks like your change is accidentally reintroducing it.

return true;
}
if (NumElementsArg2 != 3) {
int LessOrMore = NumElementsArg1 > 3 ? 1 : 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be NumElementsArg2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

Comment on lines 20 to 24
// NATIVE_HALF: define [[FNATTRS]] <3 x half> @
// NATIVE_HALF: call <3 x half> @llvm.[[TARGET]].cross.v3f16(<3 x half>
// NO_HALF: call <3 x float> @llvm.[[TARGET]].cross.v3f32(<3 x float>
// NATIVE_HALF: ret <3 x half> %hlsl.cross
// NO_HALF: ret <3 x float> %hlsl.cross
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably clearer to put the NATIVE_HALF checks consecutively, followed by the NO_HALF checks - the way they're intertwined here is a bit hard to read. Also, it looks like the define check is missing for the NO_HALF case.

auto MulSub = [&](Value *x0, Value *y0, Value *x1, Value *y1) -> Value * {
Value *xy = Builder.CreateFMul(x0, y1);
Value *yx = Builder.CreateFMul(y0, x1);
return Builder.CreateFSub(xy, yx);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we copy the value name from Orig for the final result here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, since we want distinct variable names when constructing the final result vector. Each insertelement needs a unique variable name so that contents of the vector are correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried it out and turns out each variable was assigned a unique variable name, contrary to what I expected. So your comment has been incorporated.

Copy link
Contributor

@damyanp damyanp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, and I see you've addressed all of the feedback.

Would be good to check @bogner is happy with the resolution of this issue before merging.

Copy link
Contributor

@pow2clk pow2clk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. I just have a couple nits that you might include if you create another commit for another reason.

@@ -0,0 +1,43 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -disable-llvm-passes -verify -verify-ignore-unexpected
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this producing or expected to produce unexpected diagnostics? I see this flag in a few other uses of this in SemaHLSL, but not so much in other Sema* test directories. Perhaps it's getting copied over needlessly?

If there are some tricky diagnostics getting produced, it might be better to limit them by assigning tis to a specific type of diagnostic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this revealed a missing open curly brace on line 18!
I got rid of this extra flag, seems to be a needless copy.

// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
// RUN: -DFNATTRS="spir_func noundef" -DTARGET=spv
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's harmless, but given that these run lines are long and hard to compare as-is, I'll note for next time that -x hlsl isn't necessary as long as the file ends in .hlsl

Value *op1 = Orig->getOperand(1);
IRBuilder<> Builder(Orig);

Value *op0_x = Builder.CreateExtractElement(op0, (uint64_t)0, "x0");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious why the 0 needs a cast and the others don't.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the cast I get:
error C2668: 'llvm::IRBuilderBase::CreateExtractElement': ambiguous call to overloaded function note: while trying to match the argument list '(llvm::Value *, int)'
as a compilation error.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't C++ fun!

@bob80905 bob80905 merged commit c098435 into llvm:main Oct 3, 2024
9 checks passed
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
…end (llvm#109180)

This PR adds the step intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.

Used llvm#106471 as a reference.
Fixes llvm#99095
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX backend:SPIR-V backend:X86 clang:codegen clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

Implement the cross HLSL Function
7 participants