Skip to content

Commit

Permalink
Review fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Janusz Lisiecki <[email protected]>
  • Loading branch information
JanuszL committed Sep 10, 2024
1 parent f621fe6 commit ec7fa68
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class VideoLoaderDecoderBase {
stride_(spec.GetArgument<int>("stride")),
step_(spec.GetArgument<int>("step")) {
has_labels_ = spec.TryGetRepeatedArgument(labels_, "labels");
has_frame_no_ = spec.GetArgument<bool>("enable_frame_num");
has_frame_idx_ = spec.GetArgument<bool>("enable_frame_num");
DALI_ENFORCE(
!has_labels_ || labels_.size() == filenames_.size(),
make_string(
Expand All @@ -63,7 +63,7 @@ class VideoLoaderDecoderBase {
std::vector<std::string> filenames_;
std::vector<int> labels_;
bool has_labels_ = false;
bool has_frame_no_ = false;
bool has_frame_idx_ = false;

Index current_index_ = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void VideoLoaderDecoderCpu::ReadSample(VideoSample<CPUBackend> &sample) {
if (has_labels_) {
sample.label_ = labels_[sample_span.video_idx_];
}
if (has_frame_no_) {
if (has_frame_idx_) {
sample.first_frame_ = sample_span.start_;
}
}
Expand Down
12 changes: 5 additions & 7 deletions dali/operators/reader/video_reader_decoder_cpu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,9 @@ void VideoReaderDecoderCpu::RunImpl(SampleWorkspace &ws) {

namespace detail {
inline int VideoReaderDecoderOutputFn(const OpSpec &spec) {
int num_outputs = 1;
if (spec.HasArgument("labels")) num_outputs++;
bool enable_frame_num = spec.GetArgument<bool>("enable_frame_num");
if (enable_frame_num) num_outputs++;
return num_outputs;
bool has_labels = spec.HasArgument("labels")
bool has_frame_num_output = spec.GetArgument<bool>("enable_frame_num");
return 1 + has_labels + has_frame_num_output;
}
} // namespace detail

Expand Down Expand Up @@ -82,8 +80,8 @@ even in the variable frame rate scenario.)code")
R"code(Frames to load per sequence.)code",
DALI_INT32)
.AddOptionalArg("enable_frame_num",
R"code(If set, returns the first frame number in the decoded sequence
as a separate output.)code",
R"code(If set, returns the index of the first frame in the decoded sequence
as an additional output.)code",
false)
.AddOptionalArg("step",
R"code(Frame interval between each sequence.
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/reader/video_reader_decoder_cpu_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class VideoReaderDecoderCpu

private:
bool has_labels_ = false;
bool has_frame_no_ = false;
bool has_frame_idx_ = false;
};

} // namespace dali
Expand Down
6 changes: 3 additions & 3 deletions dali/operators/reader/video_reader_decoder_gpu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace dali {
VideoReaderDecoderGpu::VideoReaderDecoderGpu(const OpSpec &spec)
: DataReader<GPUBackend, VideoSampleGpu, VideoSampleGpu, true>(spec),
has_labels_(spec.HasArgument("labels")),
has_frame_no_(spec.GetArgument<bool>("enable_frame_num")) {
has_frame_idx_(spec.GetArgument<bool>("enable_frame_num")) {
loader_ = InitLoader<VideoLoaderDecoderGpu>(spec);
this->SetInitialSnapshot();
}
Expand Down Expand Up @@ -59,7 +59,7 @@ bool VideoReaderDecoderGpu::SetupImpl(
};
out_index++;
}
if (has_frame_no_) {
if (has_frame_idx_) {
output_desc[out_index] = {
uniform_list_shape<1>(batch_size, {1}),
DALI_INT32
Expand Down Expand Up @@ -105,7 +105,7 @@ void VideoReaderDecoderGpu::RunImpl(Workspace &ws) {
ws.stream());
out_index++;
}
if (has_frame_no_) {
if (has_frame_idx_) {
auto &frame_no_output = ws.Output<GPUBackend>(out_index);
SmallVector<int, 32> frame_no_output_cpu;

Expand Down
2 changes: 1 addition & 1 deletion dali/operators/reader/video_reader_decoder_gpu_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class VideoReaderDecoderGpu : public DataReader<GPUBackend, VideoSampleGpu, Vide

private:
bool has_labels_ = false;
bool has_frame_no_ = false;
bool has_frame_idx_ = false;
};

} // namespace dali
Expand Down
16 changes: 8 additions & 8 deletions dali/operators/reader/video_reader_decoder_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class VideoReaderDecoderBaseTest : public VideoTestBase {
int frame_id, const uint8_t *frame, TestVideo &ground_truth) = 0;

template<typename Backend>
int GetFrameNo(dali::TensorList<Backend> &device_frame_no);
int GetFrameIdx(dali::TensorList<Backend> &device_frame_idx);

private:
template<typename Backend>
Expand Down Expand Up @@ -151,7 +151,7 @@ class VideoReaderDecoderBaseTest : public VideoTestBase {

auto &frame_video_output = ws.Output<Backend>(0);
const auto sample = frame_video_output.template tensor<uint8_t>(0);
int frame_no = GetFrameNo(ws.Output<Backend>(1));
int frame_no = GetFrameIdx(ws.Output<Backend>(1));

// We want to access correct order, so we compare only the first frame of the sequence
AssertFrame(frame_no, sample, ground_truth_video);
Expand All @@ -173,9 +173,9 @@ void VideoReaderDecoderBaseTest::RunShuffleTest<dali::CPUBackend>() {
}

template<>
int VideoReaderDecoderBaseTest::GetFrameNo(
dali::TensorList<dali::CPUBackend> &device_frame_no) {
const auto frame_no = device_frame_no.template tensor<int>(0);
int VideoReaderDecoderBaseTest::GetFrameIdx(
dali::TensorList<dali::CPUBackend> &device_frame_idx) {
const auto frame_no = device_frame_idx.template tensor<int>(0);
int frame_no_buffer = -1;
std::copy_n(frame_no, 1, &frame_no_buffer);
return frame_no_buffer;
Expand All @@ -195,9 +195,9 @@ void VideoReaderDecoderBaseTest::RunShuffleTest<dali::GPUBackend>() {
}

template<>
int VideoReaderDecoderBaseTest::GetFrameNo(
dali::TensorList<dali::GPUBackend> &device_frame_no) {
const auto frame_no = device_frame_no.template tensor<int>(0);
int VideoReaderDecoderBaseTest::GetFrameIdx(
dali::TensorList<dali::GPUBackend> &device_frame_idx) {
const auto frame_no = device_frame_idx.template tensor<int>(0);
int frame_no_buffer = -1;
MemCopy(&frame_no_buffer, frame_no, sizeof(int));
return frame_no_buffer;
Expand Down

0 comments on commit ec7fa68

Please sign in to comment.