#ifndef C10_UTIL_REGISTRY_H_ #define C10_UTIL_REGISTRY_H_ /** * Simple registry implementation that uses static variables to * register object creators during program initialization time. */ // NB: This Registry works poorly when you have other namespaces. // Make all macro invocations from inside the at namespace. #include #include #include #include #include #include #include #include #include #include #include #include namespace c10 { template inline std::string KeyStrRepr(const KeyType& /*key*/) { return "[key type printing not supported]"; } template <> inline std::string KeyStrRepr(const std::string& key) { return key; } enum RegistryPriority { REGISTRY_FALLBACK = 1, REGISTRY_DEFAULT = 2, REGISTRY_PREFERRED = 3, }; /** * @brief A template class that allows one to register classes by keys. * * The keys are usually a std::string specifying the name, but can be anything * that can be used in a std::map. * * You should most likely not use the Registry class explicitly, but use the * helper macros below to declare specific registries as well as registering * objects. */ template class Registry { public: typedef std::function Creator; Registry(bool warning = true) : registry_(), priority_(), warning_(warning) {} void Register( const SrcType& key, Creator creator, const RegistryPriority priority = REGISTRY_DEFAULT) { std::lock_guard lock(register_mutex_); // The if statement below is essentially the same as the following line: // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key // << " registered twice."; // However, TORCH_CHECK_EQ depends on google logging, and since registration // is carried out at static initialization time, we do not want to have an // explicit dependency on glog's initialization function. if (registry_.count(key) != 0) { auto cur_priority = priority_[key]; if (priority > cur_priority) { #ifdef DEBUG std::string warn_msg = "Overwriting already registered item for key " + KeyStrRepr(key); fprintf(stderr, "%s\n", warn_msg.c_str()); #endif registry_[key] = creator; priority_[key] = priority; } else if (priority == cur_priority) { std::string err_msg = "Key already registered with the same priority: " + KeyStrRepr(key); fprintf(stderr, "%s\n", err_msg.c_str()); if (terminate_) { std::exit(1); } else { throw std::runtime_error(err_msg); } } else if (warning_) { std::string warn_msg = "Higher priority item already registered, skipping registration of " + KeyStrRepr(key); fprintf(stderr, "%s\n", warn_msg.c_str()); } } else { registry_[key] = creator; priority_[key] = priority; } } void Register( const SrcType& key, Creator creator, const std::string& help_msg, const RegistryPriority priority = REGISTRY_DEFAULT) { Register(key, creator, priority); help_message_[key] = help_msg; } inline bool Has(const SrcType& key) { return (registry_.count(key) != 0); } ObjectPtrType Create(const SrcType& key, Args... args) { auto it = registry_.find(key); if (it == registry_.end()) { // Returns nullptr if the key is not registered. return nullptr; } return it->second(args...); } /** * Returns the keys currently registered as a std::vector. */ std::vector Keys() const { std::vector keys; keys.reserve(registry_.size()); for (const auto& it : registry_) { keys.push_back(it.first); } return keys; } inline const std::unordered_map& HelpMessage() const { return help_message_; } const char* HelpMessage(const SrcType& key) const { auto it = help_message_.find(key); if (it == help_message_.end()) { return nullptr; } return it->second.c_str(); } // Used for testing, if terminate is unset, Registry throws instead of // calling std::exit void SetTerminate(bool terminate) { terminate_ = terminate; } private: std::unordered_map registry_; std::unordered_map priority_; bool terminate_{true}; const bool warning_; std::unordered_map help_message_; std::mutex register_mutex_; C10_DISABLE_COPY_AND_ASSIGN(Registry); }; template class Registerer { public: explicit Registerer( const SrcType& key, Registry* registry, typename Registry::Creator creator, const std::string& help_msg = "") { registry->Register(key, creator, help_msg); } explicit Registerer( const SrcType& key, const RegistryPriority priority, Registry* registry, typename Registry::Creator creator, const std::string& help_msg = "") { registry->Register(key, creator, help_msg, priority); } template static ObjectPtrType DefaultCreator(Args... args) { return ObjectPtrType(new DerivedType(args...)); } }; /** * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function * declaration, as well as creating a convenient typename for its corresponding * registerer. */ // Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE // as import and DEFINE as export, because these registry macros will be used // in downstream shared libraries as well, and one cannot use *_API - the API // macro will be defined on a per-shared-library basis. Semantically, when one // declares a typed registry it is always going to be IMPORT, and when one // defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE), // the instantiation unit is always going to be exported. // // The only unique condition is when in the same file one does DECLARE and // DEFINE - in Windows compilers, this generates a warning that dllimport and // dllexport are mixed, but the warning is fine and linker will be properly // exporting the symbol. Same thing happens in the gflags flag declaration and // definition caes. #define C10_DECLARE_TYPED_REGISTRY( \ RegistryName, SrcType, ObjectType, PtrType, ...) \ C10_API ::c10::Registry, ##__VA_ARGS__>* \ RegistryName(); \ typedef ::c10::Registerer, ##__VA_ARGS__> \ Registerer##RegistryName #define TORCH_DECLARE_TYPED_REGISTRY( \ RegistryName, SrcType, ObjectType, PtrType, ...) \ TORCH_API ::c10::Registry, ##__VA_ARGS__>* \ RegistryName(); \ typedef ::c10::Registerer, ##__VA_ARGS__> \ Registerer##RegistryName #define C10_DEFINE_TYPED_REGISTRY( \ RegistryName, SrcType, ObjectType, PtrType, ...) \ C10_EXPORT ::c10::Registry, ##__VA_ARGS__>* \ RegistryName() { \ static ::c10::Registry, ##__VA_ARGS__>* \ registry = new ::c10:: \ Registry, ##__VA_ARGS__>(); \ return registry; \ } #define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ RegistryName, SrcType, ObjectType, PtrType, ...) \ C10_EXPORT ::c10::Registry, ##__VA_ARGS__>* \ RegistryName() { \ static ::c10::Registry, ##__VA_ARGS__>* \ registry = \ new ::c10::Registry, ##__VA_ARGS__>( \ false); \ return registry; \ } // Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated // creator with comma in its templated arguments. #define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \ static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ key, RegistryName(), ##__VA_ARGS__); #define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \ RegistryName, key, priority, ...) \ static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ key, priority, RegistryName(), ##__VA_ARGS__); #define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...) \ static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ key, \ RegistryName(), \ Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \ ::c10::demangle_type<__VA_ARGS__>()); #define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \ RegistryName, key, priority, ...) \ static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ key, \ priority, \ RegistryName(), \ Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \ ::c10::demangle_type<__VA_ARGS__>()); // C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use // std::string as the key type, because that is the most commonly used cases. #define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ C10_DECLARE_TYPED_REGISTRY( \ RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) #define TORCH_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ TORCH_DECLARE_TYPED_REGISTRY( \ RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) #define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \ C10_DEFINE_TYPED_REGISTRY( \ RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) #define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \ C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) #define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ C10_DECLARE_TYPED_REGISTRY( \ RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) #define TORCH_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ TORCH_DECLARE_TYPED_REGISTRY( \ RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) #define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ C10_DEFINE_TYPED_REGISTRY( \ RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) #define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \ RegistryName, ObjectType, ...) \ C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) // C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use std::string // as the key // type, because that is the most commonly used cases. #define C10_REGISTER_CREATOR(RegistryName, key, ...) \ C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__) #define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \ C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \ RegistryName, #key, priority, __VA_ARGS__) #define C10_REGISTER_CLASS(RegistryName, key, ...) \ C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__) #define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \ C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \ RegistryName, #key, priority, __VA_ARGS__) } // namespace c10 #endif // C10_UTIL_REGISTRY_H_