Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Precompiled SPIR-V import support #5048

Merged
merged 10 commits into from
Oct 29, 2024
107 changes: 95 additions & 12 deletions source/slang/slang-emit-spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "slang-ir-call-graph.h"
#include "slang-ir-insts.h"
#include "slang-ir-layout.h"
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-spirv-legalize.h"
#include "slang-ir-spirv-snippet.h"
#include "slang-ir-util.h"
Expand Down Expand Up @@ -2628,14 +2629,75 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
/// Emit a declaration for the given `irFunc`
SpvInst* emitFuncDeclaration(IRFunc* irFunc)
{
if (irFunc->findDecorationImpl(kIROp_SPIRVOpDecoration))
return nullptr;
// For now we aren't handling function declarations;
// we expect to deal only with fully linked modules.
// [2.4: Logical Layout of a Module]
//
// > All function declarations("declarations" are functions without a
// body; there is no forward declaration to a function with a body).
//
auto section = getSection(SpvLogicalSectionID::FunctionDeclarations);

// > A function declaration is as follows.
// > * Function declaration, using OpFunction.
// > * Function parameter declarations, using OpFunctionParameter.
// > * Function end, using OpFunctionEnd.
//

// [3.24. Function Control]
//
// TODO: We should eventually support emitting the "function control"
// mask to include inline and other hint bits based on decorations
// set on `irFunc`.
//
SpvFunctionControlMask spvFunctionControl = SpvFunctionControlMaskNone;

// [3.32.9. Function Instructions]
//
// > OpFunction
//
// Note that the type <id> of a SPIR-V function uses the
// *result* type of the function, while the actual function
// type is given as a later operand. Slan IR instead uses
// the type of a function instruction store, you know, its *type*.
//
SpvInst* spvFunc = emitOpFunction(
section,
irFunc,
irFunc->getDataType()->getResultType(),
spvFunctionControl,
irFunc->getDataType());

// > OpFunctionParameter
//
// Though parameters always belong to blocks in Slang, there are no
// blocks in a function declaration, so we will emit the parameters
// as derived from the function's type.
//
auto funcType = irFunc->getDataType();
auto paramCount = funcType->getParamCount();
for (UInt pp = 0; pp < paramCount; ++pp)
{
auto paramType = funcType->getParamType(pp);
SpvInst* spvParam = emitOpFunctionParameter(spvFunc, nullptr, paramType);
maybeEmitPointerDecoration(spvParam, paramType, false, kIROp_Param);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added pointer decoration

}

// [3.32.9. Function Instructions]
//
// > OpFunctionEnd
//
// In the SPIR-V encoding a function is logically the parent of any
// instructions up to a matching `OpFunctionEnd`. In our intermediate
// structure we will make the `OpFunctionEnd` be the last child of
// the `OpFunction`.
//
m_sink->diagnose(irFunc, Diagnostics::internalCompilerError);
SLANG_UNEXPECTED("function declaration in SPIR-V emit");
UNREACHABLE_RETURN(nullptr);
emitOpFunctionEnd(spvFunc, nullptr);

// We will emit any decorations pertinent to the function to the
// appropriate section of the module.
//
emitDecorations(irFunc, getID(spvFunc));

return spvFunc;
}

/// Emit a SPIR-V function definition for the Slang IR function `irFunc`.
Expand Down Expand Up @@ -4358,6 +4420,21 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SpvLinkageTypeExport);
break;
}
case kIROp_DownstreamModuleImportDecoration:
{
requireSPIRVCapability(SpvCapabilityLinkage);
auto name =
decoration->getParent()->findDecoration<IRExportDecoration>()->getMangledName();
emitInst(
getSection(SpvLogicalSectionID::Annotations),
decoration,
SpvOpDecorate,
dstID,
SpvDecorationLinkageAttributes,
name,
SpvLinkageTypeImport);
break;
}
// ...
}

Expand Down Expand Up @@ -5019,9 +5096,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return nullptr;
}

void maybeEmitPointerDecoration(SpvInst* varInst, IRInst* inst)
void maybeEmitPointerDecoration(SpvInst* varInst, IRType* type, bool isVar, IROp op)
{
auto ptrType = as<IRPtrType>(unwrapArray(inst->getDataType()));
auto ptrType = as<IRPtrType>(unwrapArray(type));
if (!ptrType)
return;
if (addressSpaceToStorageClass(ptrType->getAddressSpace()) ==
Expand All @@ -5033,7 +5110,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
getSection(SpvLogicalSectionID::Annotations),
nullptr,
varInst,
(as<IRVar>(inst) ? SpvDecorationAliasedPointer : SpvDecorationAliased));
(isVar ? SpvDecorationAliasedPointer : SpvDecorationAliased));
}
else
{
Expand All @@ -5049,14 +5126,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
getSection(SpvLogicalSectionID::Annotations),
nullptr,
varInst,
(inst->getOp() == kIROp_GlobalVar || inst->getOp() == kIROp_Var ||
inst->getOp() == kIROp_DebugVar
(op == kIROp_GlobalVar || op == kIROp_Var || op == kIROp_DebugVar
? SpvDecorationAliasedPointer
: SpvDecorationAliased));
}
}
}

void maybeEmitPointerDecoration(SpvInst* varInst, IRInst* inst)
{
maybeEmitPointerDecoration(varInst, inst->getDataType(), as<IRVar>(inst), inst->getOp());
}

SpvInst* emitParam(SpvInstParent* parent, IRInst* inst)
{
auto paramSpvInst = emitOpFunctionParameter(parent, inst, inst->getFullType());
Expand Down Expand Up @@ -7534,6 +7615,8 @@ SlangResult emitSPIRVFromIR(
}
#endif

removeAvailableInDownstreamModuleDecorations(CodeGenTarget::SPIRV, irModule);

auto shouldPreserveParams = codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(
CompilerOptionName::PreserveParameters);
auto generateWholeProgram = codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(PublicDecoration, public, 0, 0)
INST(HLSLExportDecoration, hlslExport, 0, 0)
INST(DownstreamModuleExportDecoration, downstreamModuleExport, 0, 0)
INST(DownstreamModuleImportDecoration, downstreamModuleImport, 0, 0)
INST(PatchConstantFuncDecoration, patchConstantFunc, 1, 0)
INST(OutputControlPointsDecoration, outputControlPoints, 1, 0)
INST(OutputTopologyDecoration, outputTopology, 1, 0)
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ IR_SIMPLE_DECORATION(HLSLMeshPayloadDecoration)
IR_SIMPLE_DECORATION(GlobalInputDecoration)
IR_SIMPLE_DECORATION(GlobalOutputDecoration)
IR_SIMPLE_DECORATION(DownstreamModuleExportDecoration)
IR_SIMPLE_DECORATION(DownstreamModuleImportDecoration)

struct IRAvailableInDownstreamIRDecoration : IRDecoration
{
Expand Down
9 changes: 4 additions & 5 deletions source/slang/slang-ir-redundancy-removal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
void removeAvailableInDownstreamModuleDecorations(CodeGenTarget target, IRModule* module)
{
List<IRInst*> toRemove;
auto builder = IRBuilder(module);
for (auto globalInst : module->getGlobalInsts())
{
if (auto funcInst = as<IRFunc>(globalInst))
Expand All @@ -181,13 +182,11 @@ void removeAvailableInDownstreamModuleDecorations(CodeGenTarget target, IRModule
(dec->getTarget() == target))
{
// Gut the function definition, turning it into a declaration
for (auto inst : funcInst->getChildren())
for (auto block : funcInst->getBlocks())
{
if (inst->getOp() == kIROp_Block)
{
toRemove.add(inst);
}
toRemove.add(block);
}
builder.addDecoration(funcInst, kIROp_DownstreamModuleImportDecoration);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/library/export-library-generics.slang
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public int normalFuncUsesGeneric(int a)
return genericFunc(obj);
}

public int normalFunc(int a)
public int normalFunc(int a, float b)
{
return a - 2;
return a - floor(b);
}
10 changes: 10 additions & 0 deletions tests/library/module-library-pointer-param.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//TEST_IGNORE_FILE:

// module-library-pointer-param.slang

module "module-library-pointer-param";

public int ptrFunc(int* a)
{
return *a;
}
2 changes: 1 addition & 1 deletion tests/library/precompiled-dxil-generics.slang
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ struct Attributes
[shader("anyhit")]
void anyhit(inout Payload payload, Attributes attrib)
{
payload.val = normalFunc(x * y) + normalFuncUsesGeneric(y);
payload.val = normalFunc(floor(x * y), x) + normalFuncUsesGeneric(y);
}
2 changes: 1 addition & 1 deletion tests/library/precompiled-spirv-generics.slang
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ struct Attributes
[shader("anyhit")]
void anyhit(inout Payload payload, Attributes attrib)
{
payload.val = normalFunc(x * y) + normalFuncUsesGeneric(y);
payload.val = normalFunc(floor(x * y), x) + normalFuncUsesGeneric(y);
}
31 changes: 31 additions & 0 deletions tests/library/precompiled-spirv-pointer-param.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// precompiled-spirv-pointer-param.slang

// A test that uses slang-modules with embedded precompiled SPIRV and a library containing
// a function with a pointer parameter.
// The test compiles a library slang (module-library-pointer-param.slang) with -embed-downstream-ir then links the
// library to entrypoint slang (this file).
// The test passes if there is no errror thrown.
// TODO: Check if final linkage used only the precompiled spirv.

//TEST:COMPILE: tests/library/module-library-pointer-param.slang -o tests/library/module-library-pointer-param.slang-module -target spirv -embed-downstream-ir -incomplete-library
//TEST:COMPILE: tests/library/precompiled-spirv-pointer-param.slang -target spirv -stage anyhit -entry anyhit -o tests/library/linked.spirv

import "module-library-pointer-param";

struct Payload
{
int val;
}

struct Attributes
{
float2 bary;
}

[vk::push_constant] int* g_int;

[shader("anyhit")]
void anyhit(inout Payload payload, Attributes attrib)
{
payload.val = ptrFunc(g_int);
}
Loading