Skip to content

Commit

Permalink
TypeTable/TypeInfo optimization (#5634)
Browse files Browse the repository at this point in the history
* TypeTable/TypeInfo optimization

- TypeInfo uses string_view for type name
- TypeTable stores types in an array
- TypeTable read access is lockless

---------

Signed-off-by: Michal Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Sep 17, 2024
1 parent 869afb3 commit 8904209
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 77 deletions.
22 changes: 13 additions & 9 deletions dali/operators/math/expressions/expression_tree.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -84,9 +84,10 @@ class ExprNode {

virtual std::string GetOutputDesc() const {
const auto &op_type = TypeTable::GetTypeInfo(GetTypeId()).name();
std::string result = GetAbbreviation(GetNodeType());
result += IsScalarLike(GetShape()) ? "C:" : "T:";
return result + op_type;
return make_string(
GetAbbreviation(GetNodeType()),
IsScalarLike(GetShape()) ? "C:" : "T:",
op_type);
}

virtual NodeType GetNodeType() const = 0;
Expand Down Expand Up @@ -140,15 +141,18 @@ class ExprFunc : public ExprNode {

std::string GetNodeDesc() const override {
const auto &op_type = TypeTable::GetTypeInfo(GetTypeId()).name();
std::string result = func_name_ + (IsScalarLike(GetShape()) ? ":C:" : ":T:") + op_type + "(";
std::stringstream result;
result << func_name_ << (IsScalarLike(GetShape()) ? ":C:" : ":T:") << op_type << "(";

for (int i = 0; i < GetSubexpressionCount(); i++) {
result += (*this)[i].GetOutputDesc();
result << (*this)[i].GetOutputDesc();
if (i < GetSubexpressionCount() - 1) {
result += " ";
result << " ";
}
}
result += ")";
return result;

result << ")";
return result.str();
}

NodeType GetNodeType() const override {
Expand Down
4 changes: 2 additions & 2 deletions dali/pipeline/data/types.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@

#define DALI_TYPENAME_REGISTERER(Type, dtype) \
{ \
return to_string(dtype); \
return dali::TypeName(dtype); \
}

#define DALI_TYPEID_REGISTERER(Type, dtype) \
Expand Down
170 changes: 109 additions & 61 deletions dali/pipeline/data/types.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -15,16 +15,20 @@
#ifndef DALI_PIPELINE_DATA_TYPES_H_
#define DALI_PIPELINE_DATA_TYPES_H_

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstring>
#include <functional>
#include <list>
#include <mutex>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <vector>
#include "dali/core/util.h"
#include "dali/core/common.h"
#include "dali/core/spinlock.h"
#include "dali/core/float16.h"
Expand Down Expand Up @@ -123,7 +127,8 @@ enum DALIDataType : int {
DALI_PYTHON_OBJECT = 24,
DALI_TENSOR_LAYOUT_VEC = 25,
DALI_DATA_TYPE_VEC = 26,
DALI_DATATYPE_END = 1000
DALI_NUM_BUILTIN_TYPES,
DALI_CUSTOM_TYPE_START = 1001
};

inline const char *GetBuiltinTypeName(DALIDataType t) {
Expand Down Expand Up @@ -397,7 +402,7 @@ class DLL_PUBLIC TypeInfo {
return type_size_;
}

DLL_PUBLIC inline const string &name() const {
DLL_PUBLIC inline std::string_view name() const {
return name_;
}

Expand All @@ -410,12 +415,12 @@ class DLL_PUBLIC TypeInfo {

DALIDataType id_ = DALI_NO_TYPE;
size_t type_size_ = 0;
std::string name_ = dali::to_string(DALI_NO_TYPE);
std::string_view name_ = GetBuiltinTypeName(DALI_NO_TYPE);
};

template <typename T>
struct TypeNameHelper {
static string GetTypeName() {
static std::string_view GetTypeName() {
return typeid(T).name();
}
};
Expand All @@ -427,23 +432,23 @@ class DLL_PUBLIC TypeTable {
public:
// Returns the DALIDataType id associated with T. The id is obtained from RegisterType<T> on the
// first call and cached in a function-local static, so later calls just return the cached value.
// A fresh id is taken from the table's running counter and passed to RegisterType as a proposal;
// RegisterType returns whichever id ends up associated with T.
// NOTE(review): the counter is a plain int member and it is incremented here, i.e. before
// RegisterType acquires its lock. First calls for two different T on different threads could
// therefore modify it concurrently - confirm whether it should be atomic or advanced under
// the registration lock.
template <typename T>
DLL_PUBLIC static DALIDataType GetTypeId() {
auto &inst = instance();
static DALIDataType type_id = inst.RegisterType<T>(static_cast<DALIDataType>(++inst.index_));
static DALIDataType type_id = instance().RegisterType<T>(
static_cast<DALIDataType>(instance().next_id_++));
return type_id;
}

template <typename T>
DLL_PUBLIC static string GetTypeName() {
DLL_PUBLIC static std::string_view GetTypeName() {
return TypeNameHelper<T>::GetTypeName();
}

// Looks up the TypeInfo registered under `dtype` in the table owned by `instance()`.
// Returns nullptr when no entry can be found for the id; never throws for an unknown id.
DLL_PUBLIC static const TypeInfo *TryGetTypeInfo(DALIDataType dtype) {
auto &inst = instance();
std::lock_guard<spinlock> guard(inst.lock_);
auto id_it = inst.type_info_map_.find(dtype);
if (id_it == inst.type_info_map_.end())
// Array-based lookup: the id, offset by DALI_NO_TYPE, is used directly as an index into the
// currently published vector of TypeInfo pointers. No lock is taken on this path.
// NOTE(review): type_info_map_ is declared as a plain `TypeInfoMap *` and is reassigned by
// RegisterType; reading it here without synchronization is formally a data race in the C++
// memory model - consider std::atomic<TypeInfoMap *> with an acquire load. TODO confirm.
auto *types = instance().type_info_map_;
// NOTE(review): assert() compiles to nothing under NDEBUG, so if a lookup can happen before the
// first registration, the `types->size()` below would dereference a null pointer - verify.
assert(types);
// A dtype below DALI_NO_TYPE wraps around to a huge unsigned index and is rejected by the
// bounds check below.
size_t idx = dtype - DALI_NO_TYPE;
if (idx >= types->size())
return nullptr;
return &id_it->second;
// In-range slots that were never assigned hold nullptr (the vector is filled by resize()).
return (*types)[idx];
}

DLL_PUBLIC static const TypeInfo &GetTypeInfo(DALIDataType dtype) {
Expand All @@ -465,42 +470,77 @@ class DLL_PUBLIC TypeTable {

// Registers type T under the proposed id `dtype` in this TypeTable and returns the id that is
// actually associated with T. If T is already present in type_map_, the previously stored id
// is returned and `dtype` is ignored.
template <typename T>
DALIDataType RegisterType(DALIDataType dtype) {
std::lock_guard<spinlock> guard(lock_);
// Check the map for this type's id
auto id_it = type_map_.find(typeid(T));

if (id_it == type_map_.end()) {
type_map_[typeid(T)] = dtype;
TypeInfo t;
t.SetType<T>(dtype);
type_info_map_[dtype] = t;
// The registration body runs once per instantiation of this function (it initializes a
// function-local static); all table mutations happen under insert_lock_.
static DALIDataType id = [dtype, this]() {
std::lock_guard guard(insert_lock_);
size_t idx = dtype - DALI_NO_TYPE;
// We need the map because this function (and the static variable) may be instantiated
// in multiple shared objects whereas the map instance is tied to one well defined
// instance of the TypeTable returned by `instance()`.
auto [it, inserted] = type_map_.emplace(typeid(T), dtype);
if (!inserted)
return it->second;
// Grow the id -> TypeInfo* table when there is none yet or `idx` does not fit. A new vector is
// appended to type_info_maps_ and the old one is kept, so pointers to old tables stay valid.
if (!type_info_map_ || idx >= type_info_map_->size()) {
constexpr size_t kMinCapacity = next_pow2(DALI_CUSTOM_TYPE_START + 100);
// we don't need to look at the previous capacity to achieve std::vector-like growth
size_t capacity = next_pow2(idx + 1);
if (capacity < kMinCapacity)
capacity = kMinCapacity;
auto &m = type_info_maps_.emplace_back();
m.resize(capacity);
if (type_info_map_)  // copy the old map into the new one
std::copy(type_info_map_->begin(), type_info_map_->end(), m.begin());
// The new map contains everything that the old map did - we can "publish" it.
// Make sure that the compiler doesn't reorder after the "publishing".
// NOTE(review): a release fence followed by a plain (non-atomic) pointer store does not
// establish a happens-before relation with readers that load the pointer without
// synchronization; an atomic store-release paired with an acquire load would. TODO confirm.
std::atomic_thread_fence(std::memory_order_release);
// Publish the new map.
type_info_map_ = &m;
}
// TypeInfo objects live in a std::list, so the address stored in the table stays valid.
TypeInfo &info = type_infos_.emplace_back();
info.SetType<T>(dtype);
// NOTE(review): at this point type_map_ and type_infos_ have already been modified; if
// DALI_FAIL aborts the registration (presumably by throwing - verify), those entries stay.
if ((*type_info_map_)[idx] != nullptr)
DALI_FAIL("The type id ", idx, " is already taken by type ",
(*type_info_map_)[idx]->name());
// NOTE(review): this writes a slot of the table that may already be visible to lock-free
// readers; the slot store itself is a plain store - confirm this is acceptable.
(*type_info_map_)[idx] = &info;

return dtype;
} else {
return id_it->second;
}
}();
return id;
}


spinlock lock_;
using TypeInfoMap = std::vector<TypeInfo*>;
// The "current" type map - it's just a vector that maps type_id (adjusted and treated as index)
// to a TypeInfo pointer.
TypeInfoMap *type_info_map_ = nullptr;

std::mutex insert_lock_;
// All type info maps - old ones are never deleted to avoid locks when only read access is needed.
std::list<TypeInfoMap> type_info_maps_;
// The actual type info objects. Each type has exactly one TypeInfo - even if we need to grow
// the storage - hence, we need to store TypeInfo* in the map (see the TypeInfoMap alias) and
// we need to store TypeInfo instances in a container that never invalidates pointers
// (e.g. a list).
std::list<TypeInfo> type_infos_;
// This is necessary because it turns out that static field in RegisterType has many instances
// in a program built with multiple shared libraries.
std::unordered_map<std::type_index, DALIDataType> type_map_;
// Unordered maps do not work with enums,
// so we need to use underlying type instead of DALIDataType
std::unordered_map<std::underlying_type_t<DALIDataType>, TypeInfo> type_info_map_;
int index_ = DALI_DATATYPE_END;

int next_id_ = DALI_CUSTOM_TYPE_START;
DLL_PUBLIC static TypeTable &instance();
};

template <typename T, typename A>
struct TypeNameHelper<std::vector<T, A> > {
static string GetTypeName() {
return "list of " + TypeTable::GetTypeName<T>();
static std::string_view GetTypeName() {
static const std::string name = "list of " + std::string(TypeTable::GetTypeName<T>());
return name;
}
};

template <typename T, size_t N>
struct TypeNameHelper<std::array<T, N> > {
static string GetTypeName() {
return "list of " + TypeTable::GetTypeName<T>();
static std::string_view GetTypeName() {
static const std::string name = "list of " + std::string(TypeTable::GetTypeName<T>());
return name;
}
};

Expand All @@ -513,8 +553,9 @@ template <typename T>
void TypeInfo::SetType(DALIDataType dtype) {
// Note: We enforce the fact that NoType is invalid by
// explicitly setting its type size as 0
type_size_ = std::is_same<T, NoType>::value ? 0 : sizeof(T);
if (!std::is_same<T, NoType>::value) {
constexpr bool is_no_type = std::is_same_v<T, NoType>;
type_size_ = is_no_type ? 0 : sizeof(T);
if constexpr (!is_no_type) {
id_ = dtype != DALI_NO_TYPE ? dtype : TypeTable::GetTypeId<T>();
} else {
id_ = DALI_NO_TYPE;
Expand Down Expand Up @@ -555,17 +596,34 @@ DLL_PUBLIC inline bool IsValidType(const TypeInfo &type) {
return !IsType<NoType>(type);
}

// Returns a non-owning view of the name of `dtype`.
// Built-in names are tried first, then the types registered in TypeTable.
// When neither source knows the id, the literal "<unknown>" is returned.
inline std::string_view TypeName(DALIDataType dtype) {
  const char *builtin_name = GetBuiltinTypeName(dtype);
  if (builtin_name != nullptr)
    return builtin_name;

  if (const auto *registered = TypeTable::TryGetTypeInfo(dtype))
    return registered->name();

  return "<unknown>";
}

// Converts `dtype` to an owning string.
// Known ids yield the name reported by TypeName(); ids that TypeName() reports as "<unknown>"
// yield "unknown type: " followed by the numeric value of the id.
inline std::string to_string(DALIDataType dtype) {
  const std::string_view type_name = TypeName(dtype);
  if (type_name != "<unknown>")
    return std::string(type_name);
  return "unknown type: " + std::to_string(static_cast<int>(dtype));
}

// Used to define a type for use in dali. Inserts the type into the
// TypeTable w/ a unique id and creates a method to get the name of
// the type as a string. This does not work for non-fundamental types,
// as we do not have any mechanism for calling the constructor of the
// type when the buffer allocates the memory.
#define DALI_REGISTER_TYPE(Type, dtype) \
template <> DLL_PUBLIC string TypeTable::GetTypeName<Type>() \
DALI_TYPENAME_REGISTERER(Type, dtype); \
template <> DLL_PUBLIC DALIDataType TypeTable::GetTypeId<Type>() \
DALI_TYPEID_REGISTERER(Type, dtype); \
DALI_STATIC_TYPE_MAPPING(Type, dtype); \
#define DALI_REGISTER_TYPE(Type, dtype) \
template <> DLL_PUBLIC std::string_view TypeTable::GetTypeName<Type>() \
DALI_TYPENAME_REGISTERER(Type, dtype); \
template <> DLL_PUBLIC DALIDataType TypeTable::GetTypeId<Type>() \
DALI_TYPEID_REGISTERER(Type, dtype); \
DALI_STATIC_TYPE_MAPPING(Type, dtype); \
DALI_REGISTER_TYPE_IMPL(Type, dtype);

// Instantiate some basic types
Expand Down Expand Up @@ -600,25 +658,15 @@ DALI_REGISTER_TYPE(std::vector<float>, DALI_FLOAT_VEC);
DALI_REGISTER_TYPE(std::vector<TensorLayout>, DALI_TENSOR_LAYOUT_VEC);
DALI_REGISTER_TYPE(std::vector<DALIDataType>, DALI_DATA_TYPE_VEC);


inline std::string to_string(DALIDataType dtype) {
if (const char *builtin = GetBuiltinTypeName(dtype))
return builtin;
auto *info = TypeTable::TryGetTypeInfo(dtype);
if (info)
return info->name();
return "unknown type: " + std::to_string(static_cast<int>(dtype));
}

inline std::ostream &operator<<(std::ostream &os, DALIDataType dtype) {
if (const char *builtin = GetBuiltinTypeName(dtype))
return os << builtin;
auto *info = TypeTable::TryGetTypeInfo(dtype);
if (info)
return os << info->name();
// Use string concatenation so that the result is the same as in to_string, unaffected by
// formatting & other settings in `os`.
return os << ("unknown type: " + std::to_string(static_cast<int>(dtype)));
std::string_view name = TypeName(dtype);
if (name == "<unknown>") {
// Use string concatenation so that the result is the same as in to_string, unaffected by
// formatting & other settings in `os`.
return os << ("unknown type: " + std::to_string(static_cast<int>(dtype)));
} else {
return os << name;
}
}

#define DALI_INTEGRAL_TYPES uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t
Expand Down
Loading

0 comments on commit 8904209

Please sign in to comment.