Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Draft of validation utilites - focusing on type checking #5400

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@
[submodule "third_party/cvcuda"]
path = third_party/cvcuda
url = https://github.com/CVCUDA/CV-CUDA.git
[submodule "third_party/fmt"]
path = third_party/fmt
url = [email protected]:fmtlib/fmt.git
8 changes: 8 additions & 0 deletions cmake/Dependencies.common.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,11 @@ if(BUILD_NVIMAGECODEC)
endif()
endif()
endif()

##################################################################
# {fmt}
##################################################################
check_and_add_cmake_submodule(${PROJECT_SOURCE_DIR}/third_party/fmt EXCLUDE_FROM_ALL)
set_target_properties(fmt PROPERTIES POSITION_INDEPENDENT_CODE ON)
list(APPEND DALI_LIBS fmt)
list(APPEND DALI_EXCLUDES libfmt.a)
9 changes: 5 additions & 4 deletions dali/operators/audio/preemphasis_filter_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <utility>
#include <vector>
#include "dali/operators/audio/preemphasis_filter_op.h"
#include "dali/pipeline/operator/error_reporting.h"

namespace dali {

Expand Down Expand Up @@ -93,11 +94,11 @@ void PreemphasisFilterCPU::RunImplTyped(Workspace &ws) {

void PreemphasisFilterCPU::RunImpl(Workspace &ws) {
const auto &input = ws.Input<CPUBackend>(0);
TYPE_SWITCH(input.type(), type2id, InputType, PREEMPH_TYPES, (
TYPE_SWITCH(output_type_, type2id, OutputType, PREEMPH_TYPES, (
TYPE_SWITCH(input.type(), type2id, InputType, (PREEMPH_TYPES), (
TYPE_SWITCH(output_type_, type2id, OutputType, (PREEMPH_TYPES), (
RunImplTyped<OutputType, InputType>(ws);
), DALI_FAIL(make_string("Unsupported output type: ", output_type_))); // NOLINT
), DALI_FAIL(make_string("Unsupported input type: ", input.type()))); // NOLINT
), (validate::OutputType<PREEMPH_TYPES>(spec_, ws, 0))); // NOLINT
), (validate::InputType<PREEMPH_TYPES>(spec_, ws, 0))); // NOLINT
}

DALI_REGISTER_OPERATOR(PreemphasisFilter, PreemphasisFilterCPU, CPU);
Expand Down
4 changes: 2 additions & 2 deletions dali/operators/audio/preemphasis_filter_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ void PreemphasisFilterGPU::RunImplTyped(Workspace &ws) {

void PreemphasisFilterGPU::RunImpl(Workspace &ws) {
const auto &input = ws.Input<GPUBackend>(0);
TYPE_SWITCH(input.type(), type2id, InputType, PREEMPH_TYPES, (
TYPE_SWITCH(output_type_, type2id, OutputType, PREEMPH_TYPES, (
TYPE_SWITCH(input.type(), type2id, InputType, (PREEMPH_TYPES), (
TYPE_SWITCH(output_type_, type2id, OutputType, (PREEMPH_TYPES), (
RunImplTyped<OutputType, InputType>(ws);
), DALI_FAIL(make_string("Unsupported output type: ", output_type_))); // NOLINT
), DALI_FAIL(make_string("Unsupported input type: ", input.type()))); // NOLINT
Expand Down
5 changes: 3 additions & 2 deletions dali/operators/audio/preemphasis_filter_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
#include "dali/pipeline/data/types.h"
#include "dali/pipeline/operator/checkpointing/stateless_operator.h"
#include "dali/pipeline/operator/operator.h"
#include "dali/pipeline/operator/error_reporting.h"

#define PREEMPH_TYPES \
(uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t, float, double)
uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t, float, double

namespace dali {
namespace detail {
Expand All @@ -46,7 +47,7 @@ class PreemphasisFilter : public StatelessOperator<Backend> {

explicit PreemphasisFilter(const OpSpec &spec)
: StatelessOperator<Backend>(spec),
output_type_(spec.GetArgument<DALIDataType>(arg_names::kDtype)) {
output_type_(validate::Dtype<PREEMPH_TYPES>(spec)) {
auto border_str = spec.GetArgument<std::string>(detail::kBorder);
if (border_str == "zero") {
border_type_ = BorderType::Zero;
Expand Down
7 changes: 7 additions & 0 deletions dali/pipeline/executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "dali/pipeline/graph/op_graph_storage.h"
#include "dali/pipeline/operator/builtin/conditional/split_merge.h"
#include "dali/pipeline/operator/common.h"
#include "dali/pipeline/operator/error_reporting.h"
#include "dali/pipeline/workspace/workspace.h"
#include "dali/pipeline/workspace/workspace_data_factory.h"

Expand Down Expand Up @@ -487,6 +488,12 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunHelper(OpNode &op_node, Workspac
if (had_empty_layout) empty_layout_in_idxs.push_back(i);
}

// TODO(klecki): Extract this to a separate function, this is just an example.
for (auto &argument_input : ws.ArgumentInputs()) {
// Check the types of argument inputs before they are accessed
validate::ArgumentType(spec, ws, argument_input.name);
}

bool should_allocate = false;
{
DomainTimeRange tr("[DALI][Executor] Setup");
Expand Down
136 changes: 132 additions & 4 deletions dali/pipeline/operator/error_reporting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,35 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <fmt/format.h>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
// for fmt::join
#include <fmt/ranges.h>
// for ostream support
#include <fmt/ostream.h>

#include "dali/core/error_handling.h"
#include "dali/pipeline/data/backend.h"
#include "dali/pipeline/data/types.h"
#include "dali/pipeline/operator/error_reporting.h"
#include "dali/pipeline/operator/name_utils.h"
#include "dali/pipeline/operator/op_spec.h"

// template <> struct fmt::formatter<dali::DALIDataType> : fmt::ostream_formatter {};
// template <> struct fmt::formatter<dali::DALIDataType> :
// fmt::tostring_formatter<dali::DALIDataType> {};


namespace dali {

auto format_as(dali::DALIDataType type) {
return dali::to_string(type);
}

std::vector<PythonStackFrame> GetOperatorOriginInfo(const OpSpec &spec) {
auto origin_stack_filename = spec.GetRepeatedArgument<std::string>("_origin_stack_filename");
auto origin_stack_lineno = spec.GetRepeatedArgument<int>("_origin_stack_lineno");
Expand Down Expand Up @@ -67,8 +84,7 @@ void PropagateError(ErrorInfo error) {
catch (DaliError &e) {
e.UpdateMessage(make_string(error.context_info, e.what(), error.additional_message));
throw;
}
catch (DALIException &e) {
} catch (DALIException &e) {
// We drop the C++ stack trace at this point and go back to runtime_error.
throw std::runtime_error(
make_string(error.context_info, e.what(),
Expand Down Expand Up @@ -109,9 +125,121 @@ std::string GetErrorContextMessage(const OpSpec &spec) {
formatted_origin_stack + "\n") :
" "; // we need space before "encountered"

return make_string("Error in ", device, " operator `", op_name, "`",
optional_stack_mention, "encountered:\n\n");
return make_string("Error in ", device, " operator `", op_name, "`", optional_stack_mention,
"encountered:\n\n");
}


namespace validate {

std::string SepIfNotEmpty(const std::string &str, const std::string &sep = " ") {
if (str.empty()) {
return "";
}
return sep;
}

DALIDataType Type(DALIDataType actual_type, DALIDataType expected_type, const std::string &name,
const std::string &additional_msg) {
if (actual_type == expected_type) {
return actual_type;
}

throw DaliTypeError(fmt::format("Unexpected type for {}. Got type: `{}` but expected: `{}`.{}{}",
name, actual_type, expected_type, SepIfNotEmpty(additional_msg),
additional_msg));
}

DALIDataType Type(DALIDataType actual_type, span<const DALIDataType> expected_types,
const std::string &name, const std::string &additional_msg) {
if (std::size(expected_types) == 1) {
return Type(actual_type, expected_types[0], name, additional_msg);
}
for (auto expected_type : expected_types) {
if (actual_type == expected_type) {
return actual_type;
}
}

throw DaliTypeError(fmt::format(
"Unexpected type for {}. Got type: `{}` but expected one of: `{}`.{}{}", name, actual_type,
fmt::join(expected_types, "`, `"), SepIfNotEmpty(additional_msg), additional_msg));
}

DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx,
DALIDataType allowed_type, const std::string &additional_msg) {
DALIDataType dtype = ws.GetInputDataType(input_idx);
return Type(dtype, allowed_type, FormatInput(spec, input_idx), additional_msg);
}

DALIDataType InputType(const OpSpec &spec, const Workspace &ws, int input_idx,
span<const DALIDataType> allowed_types, const std::string &additional_msg) {
DALIDataType dtype = ws.GetInputDataType(input_idx);
return Type(dtype, allowed_types, FormatInput(spec, input_idx), additional_msg);
}

DALIDataType Dtype(const OpSpec &spec, DALIDataType allowed_type, bool allow_unspecified,
const std::string &additional_msg) {
if (allow_unspecified && !spec.HasArgument("dtype")) {
return DALI_NO_TYPE;
} else if (!allow_unspecified && !spec.HasArgument("dtype")) {
throw DaliValueError(fmt::format("{} was not specified.{}{}",
FormatArgument(spec, "dtype", true),
SepIfNotEmpty(additional_msg), additional_msg));
}
return Type(spec.GetArgument<DALIDataType>("dtype"), allowed_type, FormatArgument(spec, "dtype"),
additional_msg);
}

DALIDataType Dtype(const OpSpec &spec, span<const DALIDataType> allowed_types,
bool allow_unspecified, const std::string &additional_msg) {
if (allow_unspecified && !spec.HasArgument("dtype")) {
return DALI_NO_TYPE;
} else if (!allow_unspecified && !spec.HasArgument("dtype")) {
throw DaliValueError(fmt::format("{} was not specified.{}{}",
FormatArgument(spec, "dtype", true),
SepIfNotEmpty(additional_msg), additional_msg));
}
return Type(spec.GetArgument<DALIDataType>("dtype"), allowed_types, FormatArgument(spec, "dtype"),
additional_msg);
}

void Dim(int actual_dim, int expected_dim, const std::string &name,
const std::string &additional_msg) {
if (actual_dim == expected_dim) {
return;
}
throw DaliValueError(fmt::format("Got dim: `{}` for {}, but expected: `{}`.{}{}", actual_dim,
name, expected_dim, SepIfNotEmpty(additional_msg),
additional_msg));
}

DALIDataType Dtype(const OpSpec &spec, const Workspace &ws, bool (*is_valid)(DALIDataType),
const std::string &explanation) {
return DALI_NO_TYPE; // TODO(klecki): implement
}

DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx,
DALIDataType allowed_type, const std::string &additional_msg) {
DALIDataType dtype = ws.GetOutputDataType(output_idx);
return Type(dtype, allowed_type, FormatOutput(spec, output_idx), additional_msg);
}

DALIDataType OutputType(const OpSpec &spec, const Workspace &ws, int output_idx,
span<const DALIDataType> allowed_types, const std::string &additional_msg) {
DALIDataType dtype = ws.GetOutputDataType(output_idx);
return Type(dtype, allowed_types, FormatOutput(spec, output_idx), additional_msg);
}

DALIDataType ArgumentType(const OpSpec &spec, const Workspace &ws, const std::string &arg_name,
const std::string &additional_msg) {
DALIDataType expected_type = spec.GetSchema().GetArgumentType(arg_name);
if (!spec.HasTensorArgument(arg_name)) {
return expected_type;
}
return Type(ws.ArgumentInput(arg_name).type(), expected_type, FormatArgument(spec, arg_name),
additional_msg);
}

} // namespace validate
} // namespace dali
Loading
Loading