327 lines
13 KiB
C
327 lines
13 KiB
C
|
#ifndef C10_UTIL_REGISTRY_H_
|
||
|
#define C10_UTIL_REGISTRY_H_
|
||
|
|
||
|
/**
|
||
|
* Simple registry implementation that uses static variables to
|
||
|
* register object creators during program initialization time.
|
||
|
*/
|
||
|
|
||
|
// NB: This Registry works poorly when you have other namespaces.
|
||
|
// Make all macro invocations from inside the at namespace.
|
||
|
|
||
|
#include <cstdio>
|
||
|
#include <cstdlib>
|
||
|
#include <functional>
|
||
|
#include <memory>
|
||
|
#include <mutex>
|
||
|
#include <stdexcept>
|
||
|
#include <string>
|
||
|
#include <unordered_map>
|
||
|
#include <vector>
|
||
|
|
||
|
#include <c10/macros/Export.h>
|
||
|
#include <c10/macros/Macros.h>
|
||
|
#include <c10/util/Type.h>
|
||
|
|
||
|
namespace c10 {
|
||
|
|
||
|
template <typename KeyType>
|
||
|
inline std::string KeyStrRepr(const KeyType& /*key*/) {
|
||
|
return "[key type printing not supported]";
|
||
|
}
|
||
|
|
||
|
template <>
|
||
|
inline std::string KeyStrRepr(const std::string& key) {
|
||
|
return key;
|
||
|
}
|
||
|
|
||
|
enum RegistryPriority {
|
||
|
REGISTRY_FALLBACK = 1,
|
||
|
REGISTRY_DEFAULT = 2,
|
||
|
REGISTRY_PREFERRED = 3,
|
||
|
};
|
||
|
|
||
|
/**
|
||
|
* @brief A template class that allows one to register classes by keys.
|
||
|
*
|
||
|
* The keys are usually a std::string specifying the name, but can be anything
|
||
|
* that can be used in a std::map.
|
||
|
*
|
||
|
* You should most likely not use the Registry class explicitly, but use the
|
||
|
* helper macros below to declare specific registries as well as registering
|
||
|
* objects.
|
||
|
*/
|
||
|
template <class SrcType, class ObjectPtrType, class... Args>
|
||
|
class Registry {
|
||
|
public:
|
||
|
typedef std::function<ObjectPtrType(Args...)> Creator;
|
||
|
|
||
|
Registry(bool warning = true) : registry_(), priority_(), warning_(warning) {}
|
||
|
|
||
|
void Register(
|
||
|
const SrcType& key,
|
||
|
Creator creator,
|
||
|
const RegistryPriority priority = REGISTRY_DEFAULT) {
|
||
|
std::lock_guard<std::mutex> lock(register_mutex_);
|
||
|
// The if statement below is essentially the same as the following line:
|
||
|
// TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key
|
||
|
// << " registered twice.";
|
||
|
// However, TORCH_CHECK_EQ depends on google logging, and since registration
|
||
|
// is carried out at static initialization time, we do not want to have an
|
||
|
// explicit dependency on glog's initialization function.
|
||
|
if (registry_.count(key) != 0) {
|
||
|
auto cur_priority = priority_[key];
|
||
|
if (priority > cur_priority) {
|
||
|
#ifdef DEBUG
|
||
|
std::string warn_msg =
|
||
|
"Overwriting already registered item for key " + KeyStrRepr(key);
|
||
|
fprintf(stderr, "%s\n", warn_msg.c_str());
|
||
|
#endif
|
||
|
registry_[key] = creator;
|
||
|
priority_[key] = priority;
|
||
|
} else if (priority == cur_priority) {
|
||
|
std::string err_msg =
|
||
|
"Key already registered with the same priority: " + KeyStrRepr(key);
|
||
|
fprintf(stderr, "%s\n", err_msg.c_str());
|
||
|
if (terminate_) {
|
||
|
std::exit(1);
|
||
|
} else {
|
||
|
throw std::runtime_error(err_msg);
|
||
|
}
|
||
|
} else if (warning_) {
|
||
|
std::string warn_msg =
|
||
|
"Higher priority item already registered, skipping registration of " +
|
||
|
KeyStrRepr(key);
|
||
|
fprintf(stderr, "%s\n", warn_msg.c_str());
|
||
|
}
|
||
|
} else {
|
||
|
registry_[key] = creator;
|
||
|
priority_[key] = priority;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void Register(
|
||
|
const SrcType& key,
|
||
|
Creator creator,
|
||
|
const std::string& help_msg,
|
||
|
const RegistryPriority priority = REGISTRY_DEFAULT) {
|
||
|
Register(key, creator, priority);
|
||
|
help_message_[key] = help_msg;
|
||
|
}
|
||
|
|
||
|
inline bool Has(const SrcType& key) {
|
||
|
return (registry_.count(key) != 0);
|
||
|
}
|
||
|
|
||
|
ObjectPtrType Create(const SrcType& key, Args... args) {
|
||
|
auto it = registry_.find(key);
|
||
|
if (it == registry_.end()) {
|
||
|
// Returns nullptr if the key is not registered.
|
||
|
return nullptr;
|
||
|
}
|
||
|
return it->second(args...);
|
||
|
}
|
||
|
|
||
|
/**
|
||
|
* Returns the keys currently registered as a std::vector.
|
||
|
*/
|
||
|
std::vector<SrcType> Keys() const {
|
||
|
std::vector<SrcType> keys;
|
||
|
keys.reserve(registry_.size());
|
||
|
for (const auto& it : registry_) {
|
||
|
keys.push_back(it.first);
|
||
|
}
|
||
|
return keys;
|
||
|
}
|
||
|
|
||
|
inline const std::unordered_map<SrcType, std::string>& HelpMessage() const {
|
||
|
return help_message_;
|
||
|
}
|
||
|
|
||
|
const char* HelpMessage(const SrcType& key) const {
|
||
|
auto it = help_message_.find(key);
|
||
|
if (it == help_message_.end()) {
|
||
|
return nullptr;
|
||
|
}
|
||
|
return it->second.c_str();
|
||
|
}
|
||
|
|
||
|
// Used for testing, if terminate is unset, Registry throws instead of
|
||
|
// calling std::exit
|
||
|
void SetTerminate(bool terminate) {
|
||
|
terminate_ = terminate;
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
std::unordered_map<SrcType, Creator> registry_;
|
||
|
std::unordered_map<SrcType, RegistryPriority> priority_;
|
||
|
bool terminate_{true};
|
||
|
const bool warning_;
|
||
|
std::unordered_map<SrcType, std::string> help_message_;
|
||
|
std::mutex register_mutex_;
|
||
|
|
||
|
C10_DISABLE_COPY_AND_ASSIGN(Registry);
|
||
|
};
|
||
|
|
||
|
template <class SrcType, class ObjectPtrType, class... Args>
|
||
|
class Registerer {
|
||
|
public:
|
||
|
explicit Registerer(
|
||
|
const SrcType& key,
|
||
|
Registry<SrcType, ObjectPtrType, Args...>* registry,
|
||
|
typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
|
||
|
const std::string& help_msg = "") {
|
||
|
registry->Register(key, creator, help_msg);
|
||
|
}
|
||
|
|
||
|
explicit Registerer(
|
||
|
const SrcType& key,
|
||
|
const RegistryPriority priority,
|
||
|
Registry<SrcType, ObjectPtrType, Args...>* registry,
|
||
|
typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
|
||
|
const std::string& help_msg = "") {
|
||
|
registry->Register(key, creator, help_msg, priority);
|
||
|
}
|
||
|
|
||
|
template <class DerivedType>
|
||
|
static ObjectPtrType DefaultCreator(Args... args) {
|
||
|
return ObjectPtrType(new DerivedType(args...));
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/**
|
||
|
* C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function
|
||
|
* declaration, as well as creating a convenient typename for its corresponding
|
||
|
* registerer.
|
||
|
*/
|
||
|
// Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE
|
||
|
// as import and DEFINE as export, because these registry macros will be used
|
||
|
// in downstream shared libraries as well, and one cannot use *_API - the API
|
||
|
// macro will be defined on a per-shared-library basis. Semantically, when one
|
||
|
// declares a typed registry it is always going to be IMPORT, and when one
|
||
|
// defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE),
|
||
|
// the instantiation unit is always going to be exported.
|
||
|
//
|
||
|
// The only unique condition is when in the same file one does DECLARE and
|
||
|
// DEFINE - in Windows compilers, this generates a warning that dllimport and
|
||
|
// dllexport are mixed, but the warning is fine and linker will be properly
|
||
|
// exporting the symbol. Same thing happens in the gflags flag declaration and
|
||
|
// definition caes.
|
||
|
#define C10_DECLARE_TYPED_REGISTRY( \
|
||
|
RegistryName, SrcType, ObjectType, PtrType, ...) \
|
||
|
C10_API ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
|
||
|
RegistryName(); \
|
||
|
typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__> \
|
||
|
Registerer##RegistryName
|
||
|
|
||
|
#define TORCH_DECLARE_TYPED_REGISTRY( \
|
||
|
RegistryName, SrcType, ObjectType, PtrType, ...) \
|
||
|
TORCH_API ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
|
||
|
RegistryName(); \
|
||
|
typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__> \
|
||
|
Registerer##RegistryName
|
||
|
|
||
|
#define C10_DEFINE_TYPED_REGISTRY( \
|
||
|
RegistryName, SrcType, ObjectType, PtrType, ...) \
|
||
|
C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
|
||
|
RegistryName() { \
|
||
|
static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
|
||
|
registry = new ::c10:: \
|
||
|
Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>(); \
|
||
|
return registry; \
|
||
|
}
|
||
|
|
||
|
#define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
|
||
|
RegistryName, SrcType, ObjectType, PtrType, ...) \
|
||
|
C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
|
||
|
RegistryName() { \
|
||
|
static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
|
||
|
registry = \
|
||
|
new ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>( \
|
||
|
false); \
|
||
|
return registry; \
|
||
|
}
|
||
|
|
||
|
// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
|
||
|
// creator with comma in its templated arguments.
|
||
|
#define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \
|
||
|
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
|
||
|
key, RegistryName(), ##__VA_ARGS__);
|
||
|
|
||
|
#define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
|
||
|
RegistryName, key, priority, ...) \
|
||
|
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
|
||
|
key, priority, RegistryName(), ##__VA_ARGS__);
|
||
|
|
||
|
#define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...) \
|
||
|
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
|
||
|
key, \
|
||
|
RegistryName(), \
|
||
|
Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
|
||
|
::c10::demangle_type<__VA_ARGS__>());
|
||
|
|
||
|
#define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \
|
||
|
RegistryName, key, priority, ...) \
|
||
|
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
|
||
|
key, \
|
||
|
priority, \
|
||
|
RegistryName(), \
|
||
|
Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
|
||
|
::c10::demangle_type<__VA_ARGS__>());
|
||
|
|
||
|
// C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use
|
||
|
// std::string as the key type, because that is the most commonly used cases.
|
||
|
#define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
|
||
|
C10_DECLARE_TYPED_REGISTRY( \
|
||
|
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
|
||
|
|
||
|
#define TORCH_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
|
||
|
TORCH_DECLARE_TYPED_REGISTRY( \
|
||
|
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
|
||
|
|
||
|
#define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
|
||
|
C10_DEFINE_TYPED_REGISTRY( \
|
||
|
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
|
||
|
|
||
|
#define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \
|
||
|
C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
|
||
|
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
|
||
|
|
||
|
#define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
|
||
|
C10_DECLARE_TYPED_REGISTRY( \
|
||
|
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
|
||
|
|
||
|
#define TORCH_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
|
||
|
TORCH_DECLARE_TYPED_REGISTRY( \
|
||
|
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
|
||
|
|
||
|
#define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
|
||
|
C10_DEFINE_TYPED_REGISTRY( \
|
||
|
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
|
||
|
|
||
|
#define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \
|
||
|
RegistryName, ObjectType, ...) \
|
||
|
C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
|
||
|
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
|
||
|
|
||
|
// C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use std::string
|
||
|
// as the key
|
||
|
// type, because that is the most commonly used cases.
|
||
|
#define C10_REGISTER_CREATOR(RegistryName, key, ...) \
|
||
|
C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)
|
||
|
|
||
|
#define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \
|
||
|
C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
|
||
|
RegistryName, #key, priority, __VA_ARGS__)
|
||
|
|
||
|
#define C10_REGISTER_CLASS(RegistryName, key, ...) \
|
||
|
C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)
|
||
|
|
||
|
#define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \
|
||
|
C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \
|
||
|
RegistryName, #key, priority, __VA_ARGS__)
|
||
|
|
||
|
} // namespace c10
|
||
|
|
||
|
#endif // C10_UTIL_REGISTRY_H_
|