#pragma once #include #include namespace at::impl { enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED }; struct TORCH_API PythonTorchFunctionTLS { static void set_disabled_state(TorchFunctionDisabledState disabled_state_); static TorchFunctionDisabledState get_disabled_state(); static void push_onto_stack(std::shared_ptr mode); static const std::shared_ptr pop_stack(); static const std::shared_ptr& get_stack_at(int64_t idx); static int64_t stack_len(); static const PythonTorchFunctionTLS& get_state(); static void set_state(const PythonTorchFunctionTLS& state); private: // The mode TLS is split into // - disabled_state, which says which part of torch function are disabled // - stack_, which is a vector of modes representing the stack of user // defined modes TorchFunctionDisabledState disabled_state_ = TorchFunctionDisabledState::ENABLED; std::vector> stack_; }; TORCH_API bool torch_function_mode_enabled(); } // namespace at::impl