From 948562ae3bca8a8ebff2323eee28a9c0f2f9e2a9 Mon Sep 17 00:00:00 2001 From: Janusz Lisiecki Date: Fri, 20 Oct 2023 15:40:26 +0200 Subject: [PATCH] Erase VFR frames Signed-off-by: Janusz Lisiecki --- dali/operators/reader/nvdecoder/nvdecoder.cc | 104 ++++++++++++------- tools/hw_decoder_bench.py | 81 ++++++++++++++- 2 files changed, 145 insertions(+), 40 deletions(-) diff --git a/dali/operators/reader/nvdecoder/nvdecoder.cc b/dali/operators/reader/nvdecoder/nvdecoder.cc index e9d2d5ae55c..d17e2641ef6 100644 --- a/dali/operators/reader/nvdecoder/nvdecoder.cc +++ b/dali/operators/reader/nvdecoder/nvdecoder.cc @@ -336,42 +336,54 @@ int NvDecoder::handle_display_(CUVIDPARSERDISPINFO* disp_info) { req_ready_ = VidReqStatus::REQ_IN_PROGRESS; } - if (current_recv_.count <= 0) { - if (recv_queue_.empty()) { - LOG_LINE << "Ditching frame " << frame << " since " - << "the receive queue is empty." << std::endl; - return kNvcuvid_success; + while (1) { + if (current_recv_.count <= 0) { + if (recv_queue_.empty()) { + LOG_LINE << "Ditching frame " << frame << " since " + << "the receive queue is empty." << std::endl; + return kNvcuvid_success; + } + LOG_LINE << "Moving on to next request, " << recv_queue_.size() + << " reqs left" << std::endl; + current_recv_ = recv_queue_.pop(); + frame = av_rescale_q(disp_info->timestamp, + nv_time_base_, current_recv_.frame_base); } - LOG_LINE << "Moving on to next request, " << recv_queue_.size() - << " reqs left" << std::endl; - current_recv_ = recv_queue_.pop(); - frame = av_rescale_q(disp_info->timestamp, - nv_time_base_, current_recv_.frame_base); - } - if (stop_) return kNvcuvid_failure; + if (stop_) return kNvcuvid_failure; - if (current_recv_.count <= 0) { - // a new req with count <= 0 probably means we are finishing - // up and should just ditch this frame - LOG_LINE << "Ditching frame " << frame << "since current_recv_.count <= 0" << std::endl; - return kNvcuvid_success; - } + if (current_recv_.count <= 0) { + // a new req with count <= 0 probably means we are finishing + // up and should just ditch this frame + LOG_LINE << "Ditching frame " << frame << "since current_recv_.count <= 0" << std::endl; + return kNvcuvid_success; + } - if (frame < current_recv_.frame) { - // TODO(spanev) This definitely needs better error handling... - // Add exception? Directly or after countdown treshold? - LOG_LINE << "Ditching frame " << frame << " since we are waiting for " - << "frame " << current_recv_.frame << std::endl; - return kNvcuvid_success; - } else if (frame > current_recv_.frame) { - LOG_LINE << "Receive frame " << frame << " that is pas the exptected " - << "frame " << current_recv_.frame << std::endl; - req_ready_ = VidReqStatus::REQ_ERROR; - stop_ = true; - // Main thread is waiting on frame_queue_ - frame_queue_.shutdown(); - return kNvcuvid_failure; + if (frame < current_recv_.frame) { + // TODO(spanev) This definitely needs better error handling... + // Add exception? Directly or after countdown treshold? + LOG_LINE << "Ditching frame " << frame << " since we are waiting for " + << "frame " << current_recv_.frame << std::endl; + return kNvcuvid_success; + } else if (frame > current_recv_.frame) { + LOG_LINE << "Receive frame " << frame << " that is over the expected " + << "frame " << current_recv_.frame + << "\e[1mGoing ahead with empty frame " << frame + << " wanted count: " << current_recv_.count + << "\e[0m" << std::endl; + + current_recv_.frame += current_recv_.stride; + current_recv_.count -= current_recv_.stride; + + // push empty if we are past the expected one and check if the one we have now + // matches the next frame + frame_queue_.push(nullptr); + if (current_recv_.count <= 0) { + req_ready_ = VidReqStatus::REQ_READY; + } + continue; + } + break; } LOG_LINE << "\e[1mGoing ahead with frame " << frame @@ -425,13 +437,27 @@ void NvDecoder::receive_frames(SequenceWrapper& sequence) { auto* frame_disp_info = frame_queue_.pop(); if (stop_) break; - auto frame = MappedFrame{frame_disp_info, decoder_, stream_}; - sequence.timestamps.push_back(frame_disp_info->timestamp * av_q2d( - nv_time_base_)); - if (stop_) break; - convert_frame(frame, sequence, i); - // synchronize before MappedFrame is destroyed and cuvidUnmapVideoFrame is called - CUDA_CALL(cudaStreamSynchronize(stream_)); + if (frame_disp_info) { + auto frame = MappedFrame{frame_disp_info, decoder_, stream_}; + sequence.timestamps.push_back(frame_disp_info->timestamp * av_q2d( + nv_time_base_)); + if (stop_) break; + convert_frame(frame, sequence, i); + // synchronize before MappedFrame is destroyed and cuvidUnmapVideoFrame is called + CUDA_CALL(cudaStreamSynchronize(stream_)); + } else { + LOG_LINE << "Padding empty frame " << i << std::endl; + sequence.timestamps.push_back(-1); + auto data_size = i * volume(sequence.frame_shape()); + auto pad_size = volume(sequence.frame_shape()) * + dali::TypeTable::GetTypeInfo(sequence.dtype).size(); + TYPE_SWITCH(dtype_, type2id, OutputType, NVDECODER_SUPPORTED_TYPES, ( + cudaMemsetAsync(sequence.sequence.mutable_data() + data_size, 0, pad_size, + stream_); + ), DALI_FAIL(make_string("Not supported output type:", dtype_, // NOLINT + "Only DALI_UINT8 and DALI_FLOAT are supported as the decoder outputs."));); + CUDA_CALL(cudaStreamSynchronize(stream_)); + } } if (captured_exception_) std::rethrow_exception(captured_exception_); diff --git a/tools/hw_decoder_bench.py b/tools/hw_decoder_bench.py index 1012f0938c8..edd87ea5582 100644 --- a/tools/hw_decoder_bench.py +++ b/tools/hw_decoder_bench.py @@ -19,6 +19,8 @@ from nvidia.dali.pipeline import pipeline_def import random import numpy as np +import os +from nvidia.dali.auto_aug import auto_augment parser = argparse.ArgumentParser(description='DALI HW decoder benchmark') parser.add_argument('-b', dest='batch_size', help='batch size', default=1, type=int) @@ -35,7 +37,7 @@ input_files_arg.add_argument('-i', dest='images_dir', help='images dir') input_files_arg.add_argument('--image_list', dest='image_list', nargs='+', default=[], help='List of images used for the benchmark.') -parser.add_argument('-p', dest='pipeline', choices=['decoder', 'rn50', 'efficientnet_inference'], +parser.add_argument('-p', dest='pipeline', choices=['decoder', 'rn50', 'efficientnet_inference', 'vit'], help='pipeline to test', default='decoder', type=str) parser.add_argument('--width_hint', dest='width_hint', default=0, type=int) @@ -139,6 +141,80 @@ def create_input_tensor(batch_size, file_list): assert arr.shape == arrays[0].shape, "Arrays must have the same shape" return np.stack(arrays) +# Updated pipeline definition for ViT and the creation of the iterator compatible with CLU. This should be the only code neccesary +# for the updated approach. Plus it implements GPU support. + + +def non_image_preprocessing(raw_text): + return np.array([int(bytes(raw_text).decode('utf-8'))]) + + +@pipeline_def(batch_size=args.batch_size, + num_threads=args.num_threads, + device_id=args.device_id, + seed=0) +def vit_pipeline(is_training=False, image_shape=(384,384,3), num_classes = 1000): + files_paths = [os.path.join(args.images_dir, f) for f in os.listdir( + args.images_dir)] + + img, clss = fn.readers.webdataset( + paths=files_paths, + index_paths=None, + ext=['jpg', 'cls'], + missing_component_behavior='error', + random_shuffle=False, + shard_id=0, + num_shards=1, + pad_last_batch=False if is_training else True, + name='webdataset_reader') + + use_gpu = args.device == 'gpu' + labels = fn.python_function(clss, function=non_image_preprocessing, num_outputs=1) + if use_gpu: + labels = labels.gpu() + labels = fn.one_hot(labels, num_classes=num_classes) + + device = 'mixed' if use_gpu else 'cpu' + img = fn.decoders.image(img, device=device, output_type=types.RGB, + hw_decoder_load=args.hw_load, + preallocate_width_hint=args.width_hint, + preallocate_height_hint=args.height_hint) + + if is_training: + img = fn.random_resized_crop(img, size=image_shape[:-1]) + img = fn.flip(img, depthwise=0, horizontal=fn.random.coin_flip()) + + # color jitter + brightness = fn.random.uniform(range=[0.6, 1.4]) + contrast = fn.random.uniform(range=[0.6, 1.4]) + saturation = fn.random.uniform(range=[0.6, 1.4]) + hue = fn.random.uniform(range=[0.9, 1.1]) + img = fn.color_twist( + img, + brightness=brightness, + contrast=contrast, + hue=hue, + saturation=saturation) + + # auto-augment + # `shape` controls the magnitude of the translation operations + img = auto_augment.auto_augment_image_net(img) + else: + img = fn.resize(img, size=image_shape[:-1]) + + # normalize + # https://github.com/NVIDIA/DALI/issues/4469 + mean = np.asarray([0.485, 0.456, 0.406])[None, None, :] + std = np.asarray([0.229, 0.224, 0.225])[None, None, :] + scale = 1 / 255. + img = fn.normalize( + img, + mean=mean / scale, + stddev=std, + scale=scale, + dtype=types.FLOAT) + + return img, labels pipes = [] if args.pipeline == 'decoder': @@ -150,6 +226,9 @@ def create_input_tensor(batch_size, file_list): elif args.pipeline == 'efficientnet_inference': for i in range(args.gpu_num): pipes.append(EfficientnetInferencePipeline(device_id=i + args.device_id)) +elif args.pipeline == 'vit': + for i in range(args.gpu_num): + pipes.append(vit_pipeline(device_id=i + args.device_id)) else: raise RuntimeError('Unsupported pipeline') for p in pipes: