84 lines
2.5 KiB
C
84 lines
2.5 KiB
C
|
#pragma once
|
||
|
|
||
|
#include <c10/macros/Export.h>
|
||
|
|
||
|
#include <cstdint>
|
||
|
#include <memory>
|
||
|
|
||
|
namespace c10 {
|
||
|
|
||
|
// Key under which a piece of thread-local debug info is stored.
// Backed by uint8_t; every enumerator value is written out explicitly so the
// numbering is visible at a glance.
enum class C10_API_ENUM DebugInfoKind : uint8_t {
  PRODUCER_INFO = 0,
  MOBILE_RUNTIME_INFO = 1,
  PROFILER_STATE = 2,
  INFERENCE_CONTEXT = 3, // for inference usage
  PARAM_COMMS_INFO = 4,

  // Kinds below are reserved for tests.
  TEST_INFO = 5, // used only in tests
  TEST_INFO_2 = 6, // used only in tests
};
|
||
|
|
||
|
class C10_API DebugInfoBase {
|
||
|
public:
|
||
|
DebugInfoBase() = default;
|
||
|
virtual ~DebugInfoBase() = default;
|
||
|
};
|
||
|
|
||
|
// Thread local debug information is propagated across the forward
|
||
|
// (including async fork tasks) and backward passes and is supposed
|
||
|
// to be utilized by the user's code to pass extra information from
|
||
|
// the higher layers (e.g. model id) down to the lower levels
|
||
|
// (e.g. to the operator observers used for debugging, logging,
|
||
|
// profiling, etc)
|
||
|
class C10_API ThreadLocalDebugInfo {
|
||
|
public:
|
||
|
static DebugInfoBase* get(DebugInfoKind kind);
|
||
|
|
||
|
// Get current ThreadLocalDebugInfo
|
||
|
static std::shared_ptr<ThreadLocalDebugInfo> current();
|
||
|
|
||
|
// Internal, use DebugInfoGuard/ThreadLocalStateGuard
|
||
|
static void _forceCurrentDebugInfo(
|
||
|
std::shared_ptr<ThreadLocalDebugInfo> info);
|
||
|
|
||
|
// Push debug info struct of a given kind
|
||
|
static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
|
||
|
// Pop debug info, throws in case the last pushed
|
||
|
// debug info is not of a given kind
|
||
|
static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind);
|
||
|
// Peek debug info, throws in case the last pushed debug info is not of the
|
||
|
// given kind
|
||
|
static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind);
|
||
|
|
||
|
private:
|
||
|
std::shared_ptr<DebugInfoBase> info_;
|
||
|
DebugInfoKind kind_;
|
||
|
std::shared_ptr<ThreadLocalDebugInfo> parent_info_;
|
||
|
|
||
|
friend class DebugInfoGuard;
|
||
|
};
|
||
|
|
||
|
// DebugInfoGuard is used to set debug information,
|
||
|
// ThreadLocalDebugInfo is semantically immutable, the values are set
|
||
|
// through the scope-based guard object.
|
||
|
// Nested DebugInfoGuard adds/overrides existing values in the scope,
|
||
|
// restoring the original values after exiting the scope.
|
||
|
// Users can access the values through the ThreadLocalDebugInfo::get() call;
|
||
|
class C10_API DebugInfoGuard {
|
||
|
public:
|
||
|
DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
|
||
|
|
||
|
explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info);
|
||
|
|
||
|
~DebugInfoGuard();
|
||
|
|
||
|
DebugInfoGuard(const DebugInfoGuard&) = delete;
|
||
|
DebugInfoGuard(DebugInfoGuard&&) = delete;
|
||
|
|
||
|
private:
|
||
|
bool active_ = false;
|
||
|
std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr;
|
||
|
};
|
||
|
|
||
|
} // namespace c10
|