Erase VFR frames
Signed-off-by: Janusz Lisiecki <[email protected]>
JanuszL committed Oct 20, 2023
1 parent 15a4956 commit 948562a
Showing 2 changed files with 145 additions and 40 deletions.
dali/operators/reader/nvdecoder/nvdecoder.cc (65 additions, 39 deletions)
@@ -336,42 +336,54 @@ int NvDecoder::handle_display_(CUVIDPARSERDISPINFO* disp_info) {
     req_ready_ = VidReqStatus::REQ_IN_PROGRESS;
   }
 
-  if (current_recv_.count <= 0) {
-    if (recv_queue_.empty()) {
-      LOG_LINE << "Ditching frame " << frame << " since "
-               << "the receive queue is empty." << std::endl;
-      return kNvcuvid_success;
-    }
-    LOG_LINE << "Moving on to next request, " << recv_queue_.size()
-             << " reqs left" << std::endl;
-    current_recv_ = recv_queue_.pop();
-    frame = av_rescale_q(disp_info->timestamp,
-                         nv_time_base_, current_recv_.frame_base);
-  }
-
-  if (stop_) return kNvcuvid_failure;
-
-  if (current_recv_.count <= 0) {
-    // a new req with count <= 0 probably means we are finishing
-    // up and should just ditch this frame
-    LOG_LINE << "Ditching frame " << frame << "since current_recv_.count <= 0" << std::endl;
-    return kNvcuvid_success;
-  }
-
-  if (frame < current_recv_.frame) {
-    // TODO(spanev) This definitely needs better error handling...
-    // Add exception? Directly or after countdown treshold?
-    LOG_LINE << "Ditching frame " << frame << " since we are waiting for "
-             << "frame " << current_recv_.frame << std::endl;
-    return kNvcuvid_success;
-  } else if (frame > current_recv_.frame) {
-    LOG_LINE << "Receive frame " << frame << " that is pas the exptected "
-             << "frame " << current_recv_.frame << std::endl;
-    req_ready_ = VidReqStatus::REQ_ERROR;
-    stop_ = true;
-    // Main thread is waiting on frame_queue_
-    frame_queue_.shutdown();
-    return kNvcuvid_failure;
-  }
+  while (1) {
+    if (current_recv_.count <= 0) {
+      if (recv_queue_.empty()) {
+        LOG_LINE << "Ditching frame " << frame << " since "
+                 << "the receive queue is empty." << std::endl;
+        return kNvcuvid_success;
+      }
+      LOG_LINE << "Moving on to next request, " << recv_queue_.size()
+               << " reqs left" << std::endl;
+      current_recv_ = recv_queue_.pop();
+      frame = av_rescale_q(disp_info->timestamp,
+                           nv_time_base_, current_recv_.frame_base);
+    }
+
+    if (stop_) return kNvcuvid_failure;
+
+    if (current_recv_.count <= 0) {
+      // a new req with count <= 0 probably means we are finishing
+      // up and should just ditch this frame
+      LOG_LINE << "Ditching frame " << frame << " since current_recv_.count <= 0" << std::endl;
+      return kNvcuvid_success;
+    }
+
+    if (frame < current_recv_.frame) {
+      // TODO(spanev) This definitely needs better error handling...
+      // Add exception? Directly or after countdown threshold?
+      LOG_LINE << "Ditching frame " << frame << " since we are waiting for "
+               << "frame " << current_recv_.frame << std::endl;
+      return kNvcuvid_success;
+    } else if (frame > current_recv_.frame) {
+      LOG_LINE << "Received frame " << frame << " that is past the expected "
+               << "frame " << current_recv_.frame
+               << "\e[1mGoing ahead with empty frame " << frame
+               << " wanted count: " << current_recv_.count
+               << "\e[0m" << std::endl;
+
+      current_recv_.frame += current_recv_.stride;
+      current_recv_.count -= current_recv_.stride;
+
+      // push an empty frame if we are past the expected one and check whether
+      // the frame we have now matches the next expected one
+      frame_queue_.push(nullptr);
+      if (current_recv_.count <= 0) {
+        req_ready_ = VidReqStatus::REQ_READY;
+      }
+      continue;
+    }
+    break;
+  }
 
   LOG_LINE << "\e[1mGoing ahead with frame " << frame
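Taken together, the loop above changes how out-of-order timestamps are handled: instead of flagging REQ_ERROR and shutting down frame_queue_ when a decoded frame lands past the expected one (which happens with variable frame rate, VFR, streams), handle_display_ now pushes a nullptr placeholder for each missed frame, advances the request by its stride, and keeps matching. The second hunk below turns those placeholders into zero-padded output frames.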
@@ -425,13 +437,27 @@ void NvDecoder::receive_frames(SequenceWrapper& sequence) {
 
     auto* frame_disp_info = frame_queue_.pop();
     if (stop_) break;
-    auto frame = MappedFrame{frame_disp_info, decoder_, stream_};
-    sequence.timestamps.push_back(frame_disp_info->timestamp * av_q2d(
-                                  nv_time_base_));
-    if (stop_) break;
-    convert_frame(frame, sequence, i);
-    // synchronize before MappedFrame is destroyed and cuvidUnmapVideoFrame is called
-    CUDA_CALL(cudaStreamSynchronize(stream_));
+    if (frame_disp_info) {
+      auto frame = MappedFrame{frame_disp_info, decoder_, stream_};
+      sequence.timestamps.push_back(frame_disp_info->timestamp * av_q2d(
+                                    nv_time_base_));
+      if (stop_) break;
+      convert_frame(frame, sequence, i);
+      // synchronize before MappedFrame is destroyed and cuvidUnmapVideoFrame is called
+      CUDA_CALL(cudaStreamSynchronize(stream_));
+    } else {
+      LOG_LINE << "Padding empty frame " << i << std::endl;
+      sequence.timestamps.push_back(-1);
+      auto data_size = i * volume(sequence.frame_shape());
+      auto pad_size = volume(sequence.frame_shape()) *
+                      dali::TypeTable::GetTypeInfo(sequence.dtype).size();
+      TYPE_SWITCH(dtype_, type2id, OutputType, NVDECODER_SUPPORTED_TYPES, (
+        cudaMemsetAsync(sequence.sequence.mutable_data<OutputType>() + data_size, 0, pad_size,
+                        stream_);
+      ), DALI_FAIL(make_string("Not supported output type: ", dtype_,  // NOLINT
+                   " Only DALI_UINT8 and DALI_FLOAT are supported as the decoder outputs.")););
+      CUDA_CALL(cudaStreamSynchronize(stream_));
+    }
   }
   if (captured_exception_)
     std::rethrow_exception(captured_exception_);
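The net effect for VFR input: a display callback that lands past the requested frame no longer kills the pipeline; the missing slots are zero-filled and their timestamps recorded as -1. A minimal sketch of how downstream code could mask such padded frames out of a decoded sequence (not part of this commit; the function name and shapes are illustrative):

```python
import numpy as np

def drop_padded_frames(frames, timestamps):
    """frames: (F, H, W, C) array; timestamps: (F,) array.
    Padded (erased VFR) frames are zero-filled and carry timestamp -1."""
    timestamps = np.asarray(timestamps)
    valid = timestamps >= 0
    return frames[valid], timestamps[valid]

# usage with a fake 3-frame sequence whose middle frame was padded
frames = np.zeros((3, 2, 2, 3), dtype=np.uint8)
timestamps = np.array([0.0, -1.0, 0.066])
kept, ts = drop_padded_frames(frames, timestamps)
assert kept.shape[0] == 2 and (ts >= 0).all()
```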
tools/hw_decoder_bench.py (80 additions, 1 deletion)
@@ -19,6 +19,8 @@
 from nvidia.dali.pipeline import pipeline_def
 import random
 import numpy as np
+import os
+from nvidia.dali.auto_aug import auto_augment
 
 parser = argparse.ArgumentParser(description='DALI HW decoder benchmark')
 parser.add_argument('-b', dest='batch_size', help='batch size', default=1, type=int)
@@ -35,7 +37,7 @@
 input_files_arg.add_argument('-i', dest='images_dir', help='images dir')
 input_files_arg.add_argument('--image_list', dest='image_list', nargs='+', default=[],
                              help='List of images used for the benchmark.')
-parser.add_argument('-p', dest='pipeline', choices=['decoder', 'rn50', 'efficientnet_inference'],
+parser.add_argument('-p', dest='pipeline', choices=['decoder', 'rn50', 'efficientnet_inference', 'vit'],
                     help='pipeline to test', default='decoder',
                     type=str)
 parser.add_argument('--width_hint', dest='width_hint', default=0, type=int)
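The new vit choice wires the ViT data pipeline defined in the next hunk into the benchmark's -p switch.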
@@ -139,6 +141,80 @@ def create_input_tensor(batch_size, file_list):
     assert arr.shape == arrays[0].shape, "Arrays must have the same shape"
     return np.stack(arrays)
 
+# Updated pipeline definition for ViT and the creation of the iterator
+# compatible with CLU. This should be the only code necessary for the updated
+# approach. It also implements GPU support.
+
+
+def non_image_preprocessing(raw_text):
+    return np.array([int(bytes(raw_text).decode('utf-8'))])
+
+
+@pipeline_def(batch_size=args.batch_size,
+              num_threads=args.num_threads,
+              device_id=args.device_id,
+              seed=0)
+def vit_pipeline(is_training=False, image_shape=(384, 384, 3), num_classes=1000):
+    files_paths = [os.path.join(args.images_dir, f) for f in os.listdir(args.images_dir)]
+
+    img, clss = fn.readers.webdataset(
+        paths=files_paths,
+        index_paths=None,
+        ext=['jpg', 'cls'],
+        missing_component_behavior='error',
+        random_shuffle=False,
+        shard_id=0,
+        num_shards=1,
+        pad_last_batch=False if is_training else True,
+        name='webdataset_reader')
+
+    use_gpu = args.device == 'gpu'
+    labels = fn.python_function(clss, function=non_image_preprocessing, num_outputs=1)
+    if use_gpu:
+        labels = labels.gpu()
+    labels = fn.one_hot(labels, num_classes=num_classes)
+
+    device = 'mixed' if use_gpu else 'cpu'
+    img = fn.decoders.image(img, device=device, output_type=types.RGB,
+                            hw_decoder_load=args.hw_load,
+                            preallocate_width_hint=args.width_hint,
+                            preallocate_height_hint=args.height_hint)
+
+    if is_training:
+        img = fn.random_resized_crop(img, size=image_shape[:-1])
+        img = fn.flip(img, depthwise=0, horizontal=fn.random.coin_flip())
+
+        # color jitter
+        brightness = fn.random.uniform(range=[0.6, 1.4])
+        contrast = fn.random.uniform(range=[0.6, 1.4])
+        saturation = fn.random.uniform(range=[0.6, 1.4])
+        hue = fn.random.uniform(range=[0.9, 1.1])
+        img = fn.color_twist(
+            img,
+            brightness=brightness,
+            contrast=contrast,
+            hue=hue,
+            saturation=saturation)
+
+        # auto-augment
+        # `shape` controls the magnitude of the translation operations
+        img = auto_augment.auto_augment_image_net(img)
+    else:
+        img = fn.resize(img, size=image_shape[:-1])
+
+    # normalize
+    # https://github.com/NVIDIA/DALI/issues/4469
+    mean = np.asarray([0.485, 0.456, 0.406])[None, None, :]
+    std = np.asarray([0.229, 0.224, 0.225])[None, None, :]
+    scale = 1 / 255.
+    img = fn.normalize(
+        img,
+        mean=mean / scale,
+        stddev=std,
+        scale=scale,
+        dtype=types.FLOAT)
+
+    return img, labels
+
+
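A note on the normalization parameters in the hunk above: fn.normalize computes out = scale * (in - mean) / stddev, so passing mean / scale together with scale = 1/255 is equivalent to first rescaling the uint8 input to [0, 1] and then applying the standard ImageNet mean/std (this is the trick discussed in the linked DALI issue). A quick NumPy check of that identity, with illustrative values:

```python
import numpy as np

# fn.normalize computes: out = scale * (x - mean') / stddev.
# With mean' = mean / scale and scale = 1/255 this reduces to (x/255 - mean) / std.
x = np.random.randint(0, 256, size=(4, 4, 3)).astype(np.float32)  # stand-in uint8 image
mean = np.asarray([0.485, 0.456, 0.406])[None, None, :]
std = np.asarray([0.229, 0.224, 0.225])[None, None, :]
scale = 1 / 255.
assert np.allclose(scale * (x - mean / scale) / std, (x / 255. - mean) / std)
```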
 pipes = []
 if args.pipeline == 'decoder':
@@ -150,6 +226,9 @@
 elif args.pipeline == 'efficientnet_inference':
     for i in range(args.gpu_num):
         pipes.append(EfficientnetInferencePipeline(device_id=i + args.device_id))
+elif args.pipeline == 'vit':
+    for i in range(args.gpu_num):
+        pipes.append(vit_pipeline(device_id=i + args.device_id))
 else:
     raise RuntimeError('Unsupported pipeline')
 for p in pipes:
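The hunk is truncated at the build loop. For context, a minimal sketch of how such DALI benchmark pipelines are typically built and run; the script's actual timing loop lies outside the shown diff, so the structure and iteration count here are assumptions:

```python
# Build each per-GPU pipeline, then exercise it (purely illustrative).
for p in pipes:
    p.build()

for _ in range(10):  # warm-up / measurement iterations, count illustrative
    for p in pipes:
        p.run()
```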
