Skip to content

Commit

Permalink
feat(api): Add API for building interpolated getter free functions. (#…
Browse files Browse the repository at this point in the history
…1765)

Adds a new RendererServices API, build_interpolated_getter, which allows building custom free functions that provide values for interpolated data.

The new API is invoked at material compilation time and allows developers to provide specialized functions that leverage the information known at compile time. It opens up many optimization opportunities for the compiler by replacing runtime branches with direct memory reads.

This PR is a followup to the API build_attribute_getter proposed in this PR: #1704

TestShade provides an example implementation showing how this compile time information can be used to select an appropriate function to use as the attribute provider, and how to configure the function signature.
Since the feature is enabled by default when rs_bitcode is used, the existing testsuite ensures that the feature generates the same results as the previous API.

---------

Signed-off-by: uedaki <[email protected]>
  • Loading branch information
Uedaki authored Feb 15, 2024
1 parent becc03d commit 7316669
Show file tree
Hide file tree
Showing 13 changed files with 442 additions and 57 deletions.
2 changes: 2 additions & 0 deletions src/include/OSL/llvm_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,8 @@ class OSLEXECPUBLIC LLVM_Util {
llvm::Value* op_mul(llvm::Value* a, llvm::Value* b);
llvm::Value* op_div(llvm::Value* a, llvm::Value* b);
llvm::Value* op_mod(llvm::Value* a, llvm::Value* b);
// Sign-extend an 8-bit integer value to the int type. A value that is
// already of the int type is returned unchanged.
llvm::Value* op_int8_to_int(llvm::Value* a);
// Truncate an int value to the 8-bit integer type. A value that is already
// of the 8-bit integer type is returned unchanged.
llvm::Value* op_int_to_int8(llvm::Value* a);
llvm::Value* op_float_to_int(llvm::Value* a);
llvm::Value* op_int_to_float(llvm::Value* a);
llvm::Value* op_bool_to_int(llvm::Value* a);
Expand Down
38 changes: 37 additions & 1 deletion src/include/OSL/rendererservices.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ typedef void (*PrepareClosureFunc)(RendererServices*, int id, void* data);
typedef void (*SetupClosureFunc)(RendererServices*, int id, void* data);

enum class AttributeSpecBuiltinArg {
OpaqueExecutionContext, //OpaqueExecContextPtr
OpaqueExecutionContext, // OpaqueExecContextPtr
ShadeIndex, // int
Derivatives, // bool
Type, // TypeDesc_pod
Expand All @@ -45,6 +45,17 @@ enum class AttributeSpecBuiltinArg {
using AttributeSpecArg = ArgVariant<AttributeSpecBuiltinArg>;
using AttributeGetterSpec = FunctionSpec<AttributeSpecArg>;

/// Built-in arguments that a renderer may list in an InterpolatedGetterSpec.
/// When the JIT emits the call to the getter function, each built-in is
/// replaced by the matching value for the parameter being bound. The trailing
/// comment on each enumerator names the type of the value that is passed.
enum class InterpolatedSpecBuiltinArg {
OpaqueExecutionContext, // OpaqueExecContextPtr
ShadeIndex, // int
Derivatives, // bool: whether the symbol being bound has derivatives
Type, // TypeDesc_pod: type of the parameter being bound
ParamName, // ustringhash_pod: hash of the parameter name
};

/// One argument of an interpolated getter: either one of the built-ins above
/// or a constant value.
using InterpolatedSpecArg = ArgVariant<InterpolatedSpecBuiltinArg>;
/// Description of the getter function (its name and argument list) that is
/// filled in by RendererServices::build_interpolated_getter.
using InterpolatedGetterSpec = FunctionSpec<InterpolatedSpecArg>;

// Turn off warnings about unused params for this file, since we have lots
// of declarations with stub function bodies.
OSL_PRAGMA_WARNING_PUSH
Expand All @@ -68,6 +79,7 @@ class OSLEXECPUBLIC RendererServices {
/// supports it. Feature names include:
/// "OptiX"
/// "build_attribute_getter"
/// "build_interpolated_getter"
///
/// This allows some customization of JIT generated code based on the
/// facilities and features of a particular renderer. It also allows
Expand Down Expand Up @@ -273,6 +285,30 @@ class OSLEXECPUBLIC RendererServices {
ustringhash object, TypeDesc type,
ustringhash name, int index, void* val);

/// Builds a free function that provides the value of a given interpolated
/// parameter. This occurs at shader compile time, not at execution time.
///
/// The JIT-generated call passes the arguments described by `spec`,
/// followed by one additional pointer to the storage reserved for the
/// value, and treats the function's result as a boolean success flag.
///
/// @param group
/// The shader group currently requesting the interpolated value.
///
/// @param param_name
/// The parameter name.
///
/// @param type
/// The type of the value being requested.
///
/// @param derivatives
/// True if derivatives are also being requested.
///
/// @param spec
/// The built interpolated getter. An empty function name is interpreted
/// as a missing value: no getter is called for this parameter.
///
virtual void build_interpolated_getter(const ShaderGroup& group,
const ustring& param_name,
TypeDesc type, bool derivatives,
InterpolatedGetterSpec& spec);

/// Get the named user-data from the current object and write it into
/// 'val'. If derivatives is true, the derivatives should be written into val
/// as well. Return false if no user-data with the given name and type was
Expand Down
56 changes: 56 additions & 0 deletions src/liboslexec/backendllvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,5 +570,61 @@ class BackendLLVM final : public OSOProcessorBase {
};


/// Append to `args` the LLVM constant that corresponds to the literal value
/// held by `arg`.
///
/// Only variants holding a constant are supported. An Unspecified or Builtin
/// variant (or an unrecognized type tag) triggers a debug assertion and
/// leaves `args` untouched.
template<typename TArgVariant>
void
append_constant_arg(BackendLLVM& rop, const TArgVariant& arg,
                    std::vector<llvm::Value*>& args)
{
    using ArgType = typename TArgVariant::Type;

    auto& ll           = rop.ll;
    llvm::Value* value = nullptr;

    switch (arg.type()) {
    case ArgType::Bool: value = ll.constant_bool(arg.get_bool()); break;
    case ArgType::Int8: value = ll.constant8(arg.get_int8()); break;
    case ArgType::Int16: value = ll.constant16(arg.get_int16()); break;
    case ArgType::Int32: value = ll.constant(arg.get_int32()); break;
    case ArgType::Int64: value = ll.constanti64(arg.get_int64()); break;
    case ArgType::UInt8: value = ll.constant8(arg.get_uint8()); break;
    case ArgType::UInt16: value = ll.constant16(arg.get_uint16()); break;
    case ArgType::UInt32: value = ll.constant(arg.get_uint32()); break;
    case ArgType::UInt64: value = ll.constant64(arg.get_uint64()); break;
    case ArgType::Float: value = ll.constant(arg.get_float()); break;
    case ArgType::Double: value = ll.constant64(arg.get_double()); break;
    case ArgType::Pointer: value = ll.constant_ptr(arg.get_ptr()); break;
    case ArgType::UString: value = ll.constant(arg.get_ustring()); break;
    case ArgType::UStringHash:
        // The hash is converted back to a ustring before the constant is
        // created.
        value = ll.constant(ustring(arg.get_ustringhash()));
        break;
    case ArgType::Unspecified:
    case ArgType::Builtin:
    default:
        // Not a constant: nothing is appended.
        OSL_DASSERT(false);
        return;
    }

    args.push_back(value);
}



}; // namespace pvt
OSL_NAMESPACE_EXIT
1 change: 1 addition & 0 deletions src/liboslexec/builtindecl.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ DECL(osl_naninf_check, "xiXiXhihiih")
DECL(osl_uninit_check, "xLXXhihihhihihii")
DECL(osl_get_attribute, "iXihhiiLX")
DECL(osl_bind_interpolated_param, "iXhLiXiXiXi")
DECL(osl_incr_get_userdata_calls, "xX")
DECL(osl_get_texture_options, "XX");
DECL(osl_get_noise_options, "XX");
DECL(osl_get_trace_options, "XX");
Expand Down
54 changes: 1 addition & 53 deletions src/liboslexec/llvm_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3298,59 +3298,7 @@ LLVMGEN(llvm_gen_noise)
return true;
}

/// Append to `args` the LLVM constant that corresponds to the literal value
/// held by `arg`.
///
/// Only variants holding a constant are supported. An Unspecified or Builtin
/// variant (or an unrecognized type tag) triggers a debug assertion and
/// leaves `args` untouched.
template<typename TArgVariant>
void
append_constant_arg(BackendLLVM& rop, const TArgVariant& arg,
                    std::vector<llvm::Value*>& args)
{
    using ArgType = typename TArgVariant::Type;

    auto& ll           = rop.ll;
    llvm::Value* value = nullptr;

    switch (arg.type()) {
    case ArgType::Bool: value = ll.constant_bool(arg.get_bool()); break;
    case ArgType::Int8: value = ll.constant8(arg.get_int8()); break;
    case ArgType::Int16: value = ll.constant16(arg.get_int16()); break;
    case ArgType::Int32: value = ll.constant(arg.get_int32()); break;
    case ArgType::Int64: value = ll.constanti64(arg.get_int64()); break;
    case ArgType::UInt8: value = ll.constant8(arg.get_uint8()); break;
    case ArgType::UInt16: value = ll.constant16(arg.get_uint16()); break;
    case ArgType::UInt32: value = ll.constant(arg.get_uint32()); break;
    case ArgType::UInt64: value = ll.constant64(arg.get_uint64()); break;
    case ArgType::Float: value = ll.constant(arg.get_float()); break;
    case ArgType::Double: value = ll.constant64(arg.get_double()); break;
    case ArgType::Pointer: value = ll.constant_ptr(arg.get_ptr()); break;
    case ArgType::UString: value = ll.constant(arg.get_ustring()); break;
    case ArgType::UStringHash:
        // The hash is converted back to a ustring before the constant is
        // created.
        value = ll.constant(ustring(arg.get_ustringhash()));
        break;
    case ArgType::Unspecified:
    case ArgType::Builtin:
    default:
        // Not a constant: nothing is appended.
        OSL_DASSERT(false);
        return;
    }

    args.push_back(value);
}


LLVMGEN(llvm_gen_getattribute)
{
Expand Down
106 changes: 104 additions & 2 deletions src/liboslexec/llvm_instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,9 @@ BackendLLVM::llvm_type_groupdata()

// Now add the array that tells which userdata have been initialized,
// and the space for the userdata values.
// The initialized array can contain 3 different values: 0 means the value
// is not initialized, 1 means the value cannot be found, and 2 means the
// value has been found and cached in the userdata values.
int nuserdata = (int)group().m_userdata_names.size();
if (nuserdata) {
if (llvm_debug() >= 2)
Expand All @@ -320,9 +323,9 @@ BackendLLVM::llvm_type_groupdata()
TypeDesc* types = &group().m_userdata_types[0];
int* offsets = &group().m_userdata_offsets[0];
int sz = (nuserdata + 3) & (~3);
fields.push_back(ll.type_array(ll.type_bool(), sz));
fields.push_back(ll.type_array(ll.type_int8(), sz));
m_groupdata_field_names.emplace_back("userdata_init_flags");
offset += nuserdata * sizeof(bool);
offset += nuserdata * sizeof(int8_t);
++order;
for (int i = 0; i < nuserdata; ++i) {
TypeDesc type = types[i];
Expand Down Expand Up @@ -638,6 +641,105 @@ BackendLLVM::llvm_assign_initial_value(const Symbol& sym, bool force)
// source didn't have them.
if (sym.has_derivs() && !symloc->derivs)
ll.op_memset(ll.offset_ptr(dstptr, size), 0, 2 * size);
} else if (renderer()->supports("build_interpolated_getter")) {
InterpolatedGetterSpec spec;
renderer()->build_interpolated_getter(group(), symname, type,
sym.has_derivs(), spec);
if (!spec.function_name().empty()) {
std::vector<llvm::Value*> args;
args.reserve(spec.arg_count() + 1);
// Push the arguments of the function stored in the
// InterpolatedGetterSpec. Each value can be either the enum
// InterpolatedSpecBuiltinArg or a constant value.
for (size_t index = 0; index < spec.arg_count(); index++) {
const auto& arg = spec.arg(index);
if (arg.is_holding<InterpolatedSpecBuiltinArg>()) {
switch (arg.get_builtin()) {
default: OSL_DASSERT(false); break;
case InterpolatedSpecBuiltinArg::OpaqueExecutionContext:
args.push_back(sg_void_ptr());
break;
case InterpolatedSpecBuiltinArg::ShadeIndex:
args.push_back(shadeindex());
break;
case InterpolatedSpecBuiltinArg::Derivatives:
args.push_back(ll.constant_bool(sym.has_derivs()));
break;
case InterpolatedSpecBuiltinArg::Type:
args.push_back(ll.constant(type));
break;
case InterpolatedSpecBuiltinArg::ParamName:
args.push_back(llvm_const_hash(symname));
break;
}
} else {
append_constant_arg(*this, arg, args);
}
}
// Push userdata data ptr
args.push_back(
ll.void_ptr(groupdata_field_ptr(2 + userdata_index)));
// Start of the osl_bind_interpolated_param instructions
int userdata_has_derivs
= group().m_userdata_derivs[userdata_index];
llvm::Value* userdata_data = groupdata_field_ptr(
2 + userdata_index);
llvm::Value* symbol_data = llvm_void_ptr(sym);
int symbol_data_size = sym.derivsize();
// char status = *userdata_initialized;
llvm::Value* userdata_initializedPtr = userdata_initialized_ref(
userdata_index);
llvm::Value* status = ll.op_int8_to_int(
ll.op_load(ll.type_int8(), userdata_initializedPtr));
// if (status == 0)
llvm::BasicBlock* then_block = ll.new_basic_block();
llvm::BasicBlock* after_block = ll.new_basic_block();
ll.op_branch(ll.op_eq(status, ll.constant(0)), then_block,
after_block);
{
// bool ok = interpolated_getter();
llvm::Value* ok
= ll.call_function(spec.function_name().c_str(), args);
ok = ll.op_bool_to_int(ok);
// status = 1 + ok;
status = ll.op_add(ll.constant(1), ok);
// *userdata_initialized = status;
ll.op_store(ll.op_int_to_int8(status),
userdata_initializedPtr);
if (!use_optix() && shadingsys().m_statslevel != 0) {
// sg->context->incr_get_userdata_calls();
ll.call_function("osl_incr_get_userdata_calls",
sg_void_ptr());
}
}
// endif
ll.op_branch(after_block);
status = ll.op_int8_to_int(
ll.op_load(ll.type_int8(), userdata_initializedPtr));
// if (status == 2)
then_block = ll.new_basic_block();
after_block = ll.new_basic_block();
llvm::Value* cond = ll.op_eq(status, ll.constant(2));
ll.op_branch(cond, then_block, after_block);
{
int udata_size = (userdata_has_derivs ? 3 : 1)
* type.size();
// memcpy(symbol_data, userdata_data, std::min(symbol_data_size, udata_size));
ll.op_memcpy(symbol_data, userdata_data,
std::min(symbol_data_size, udata_size));
if (symbol_data_size > udata_size) {
// memset((char*)symbol_data + udata_size, 0, symbol_data_size - udata_size);
llvm::Value* padding = ll.offset_ptr(symbol_data,
udata_size);
ll.op_memset(padding, 0, symbol_data_size - udata_size);
}
}
// endif
ll.op_branch(after_block);
got_userdata = ll.op_bool_to_int(cond);
} else {
got_userdata = ll.constant(0);
}
} else {
// No pre-placement: fall back to call to the renderer callback.
llvm::Value* args[] = {
Expand Down
26 changes: 26 additions & 0 deletions src/liboslexec/llvm_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6098,6 +6098,32 @@ LLVM_Util::op_mod(llvm::Value* a, llvm::Value* b)



/// Widen an 8-bit integer value to the int type using sign extension.
/// A value that already has the int type is passed through unchanged; any
/// other input type is a programming error and trips an assertion.
llvm::Value*
LLVM_Util::op_int8_to_int(llvm::Value* a)
{
    const auto src_type = a->getType();
    if (src_type == type_int())
        return a;  // already the requested type
    if (src_type != type_int8()) {
        OSL_ASSERT(0 && "Op has bad value type combination");
        return nullptr;
    }
    return builder().CreateSExt(a, type_int());
}



/// Narrow an int value to the 8-bit integer type by truncation.
/// A value that already has the 8-bit integer type is passed through
/// unchanged; any other input type is a programming error and trips an
/// assertion.
llvm::Value*
LLVM_Util::op_int_to_int8(llvm::Value* a)
{
    const auto src_type = a->getType();
    if (src_type == type_int8())
        return a;  // already the requested type
    if (src_type != type_int()) {
        OSL_ASSERT(0 && "Op has bad value type combination");
        return nullptr;
    }
    return builder().CreateTrunc(a, type_int8());
}



llvm::Value*
LLVM_Util::op_float_to_int(llvm::Value* a)
{
Expand Down
10 changes: 10 additions & 0 deletions src/liboslexec/rendservices.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ RendererServices::get_array_attribute(ShaderGlobals* sg, bool derivatives,



/// Default implementation of the interpolated getter builder: it does
/// nothing, so renderers that do not override it provide no getter function.
void
RendererServices::build_interpolated_getter(const ShaderGroup& group,
const ustring& param_name,
TypeDesc type, bool derivatives,
InterpolatedGetterSpec& spec)
{
// Intentionally empty: `spec` is left exactly as the caller passed it.
// NOTE(review): presumably a default-constructed spec has an empty function
// name, which callers interpret as "no getter available" -- confirm.
}



bool
RendererServices::get_userdata(bool derivatives, ustringhash name,
TypeDesc type, ShaderGlobals* sg, void* val)
Expand Down
9 changes: 9 additions & 0 deletions src/liboslexec/shadingsys.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4772,3 +4772,12 @@ osl_bind_interpolated_param(void* sg_, ustringhash_pod name_, long long type,
}
return 0; // no such user data
}



/// Shadeop invoked from JIT-generated code to record one userdata lookup:
/// it forwards to incr_get_userdata_calls() on the shading context reached
/// through the shader globals.
///
/// @param sg_  Opaque pointer that is interpreted as a ShaderGlobals*.
OSL_SHADEOP void
osl_incr_get_userdata_calls(void* sg_)
{
    static_cast<ShaderGlobals*>(sg_)->context->incr_get_userdata_calls();
}
Loading

0 comments on commit 7316669

Please sign in to comment.