Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing bug when writing primitives in metal mesh shaders #5069

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions source/slang/slang-emit-metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,8 @@ bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inO
emitOperand(setIndices->getIndex(), getInfo(EmitOp::General));
m_writer->emit("*");
m_writer->emitUInt64(numIndices);
m_writer->emit("+");
m_writer->emitUInt64(i);
m_writer->emit(",(");
emitOperand(setIndices->getElementValue(), getInfo(EmitOp::General));
m_writer->emit(")[");
Expand Down
71 changes: 71 additions & 0 deletions tests/metal/simple-mesh.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//TEST:SIMPLE(filecheck=METAL): -entry meshMain -stage mesh -target metal

//
// Mesh shader
//

const static float2 positions[3] = {
float2(0.0, -0.5),
float2(0.5, 0.5),
float2(-0.5, 0.5)
};

const static float3 colors[3] = {
float3(1.0, 1.0, 0.0),
float3(0.0, 1.0, 1.0),
float3(1.0, 0.0, 1.0)
};

struct MeshPayload
{
int exponent;
};


struct Vertex
{
float4 pos : SV_Position;
float3 color : Color;
int index : Index;
int value : Value;
};

struct Primitive
{
uint prim : SV_PrimitiveID;
};

const static uint MAX_VERTS = 12;
const static uint MAX_PRIMS = 4;

[outputtopology("triangle")]
[numthreads(12, 1, 1)]
void meshMain(
in uint tig: SV_GroupIndex,
in payload MeshPayload meshPayload,
// METAL: const MeshPayload_0 object_data* meshPayload_0
OutputVertices<Vertex, MAX_VERTS> verts,
OutputIndices<uint3, MAX_PRIMS> triangles,
OutputPrimitives<Primitive, MAX_PRIMS> primitives
)
{
const uint numVertices = 12;
const uint numPrimitives = 4;
SetMeshOutputCounts(numVertices, numPrimitives);

if (tig < numVertices)
{
const int tri = tig / 3;
verts[tig] = { float4(positions[tig % 3], 0, 1), colors[tig % 3], tri, int(pow(tri, meshPayload.exponent)) };
}

if (tig < numPrimitives)
{
// METAL: _slang_mesh.set_index({{.*}}+0,{{.*}}[0]);
// METAL: _slang_mesh.set_index({{.*}}+1,{{.*}}[1]);
// METAL: _slang_mesh.set_index({{.*}}+2,{{.*}}[2]);
triangles[tig] = tig * 3 + uint3(0, 1, 2);
// METAL: _slang_mesh.set_primitive({{.*}}
primitives[tig] = { tig };
}
}
95 changes: 5 additions & 90 deletions tests/metal/simple-task.slang
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
//TEST:SIMPLE(filecheck=CHECK): -entry taskMain -stage amplification -target metal

//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer

uniform RWStructuredBuffer<float> outputBuffer;
//TEST:SIMPLE(filecheck=METAL): -entry taskMain -stage amplification -target metal

cbuffer Uniforms
{
Expand All @@ -18,95 +14,14 @@ struct MeshPayload
int exponent;
};

// CHECK: MeshPayload_0 object_data* _slang_mesh_payload
// CHECK: mesh_grid_properties _slang_mgp
// METAL: MeshPayload_0 object_data* _slang_mesh_payload
// METAL: mesh_grid_properties _slang_mgp
[numthreads(1,1,1)]
void taskMain()
{
// CHECK: _slang_mesh_payload
// CHECK: _slang_mgp.set_threadgroups_per_grid
// METAL: _slang_mesh_payload
// METAL: _slang_mgp.set_threadgroups_per_grid
MeshPayload p;
p.exponent = 3;
DispatchMesh(1, 1, 1, p);
}

//
// Mesh shader
//

const static float2 positions[3] = {
float2(0.0, -0.5),
float2(0.5, 0.5),
float2(-0.5, 0.5)
};

const static float3 colors[3] = {
float3(1.0, 1.0, 0.0),
float3(0.0, 1.0, 1.0),
float3(1.0, 0.0, 1.0)
};

struct Vertex
{
float4 pos : SV_Position;
float3 color : Color;
int index : Index;
int value : Value;
};

struct Primitive
{
uint prim : SV_PrimitiveID;
};

const static uint MAX_VERTS = 12;
const static uint MAX_PRIMS = 4;

[outputtopology("triangle")]
[numthreads(12, 1, 1)]
void meshMain(
in uint tig: SV_GroupIndex,
in payload MeshPayload meshPayload,
// Check that we correctly generate the specific 'in payload' that HLSL
// requires:
// HLSL: , in payload MeshPayload
OutputVertices<Vertex, MAX_VERTS> verts,
OutputIndices<uint3, MAX_PRIMS> triangles,
OutputPrimitives<Primitive, MAX_PRIMS> primitives
)
{
const uint numVertices = 12;
const uint numPrimitives = 4;
SetMeshOutputCounts(numVertices, numPrimitives);

if (tig < numVertices)
{
const int tri = tig / 3;
verts[tig] = { float4(positions[tig % 3], 0, 1), colors[tig % 3], tri, int(pow(tri, meshPayload.exponent)) };
}

if (tig < numPrimitives)
{
triangles[tig] = tig * 3 + uint3(0, 1, 2);
primitives[tig] = { tig };
}
}

//
// Fragment Shader
//

struct Fragment
{
float4 color : SV_Target;
};

Fragment fragmentMain(Vertex input)
{
outputBuffer[input.index] = input.value;

Fragment output;
output.color = float4(input.color, 1.0);
return output;
}

Loading