101 lines
3.5 KiB
C++
101 lines
3.5 KiB
C++
// Copyright © 2023 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <ATen/mps/MPSStream.h>

#include <condition_variable>
#include <cstdint>
#include <ctime>
#include <functional>
#include <memory>
#include <mutex>
#include <stack>
#include <unordered_map>
|
|
|
|
namespace at::mps {
|
|
|
|
// NOTE: don't create instances of this class directly.
|
|
// Use MPSEventPool to acquire instances of MPSEvent.
|
|
class MPSEvent {
|
|
public:
|
|
explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
|
|
~MPSEvent();
|
|
|
|
// records an event on the stream
|
|
void record(bool needsLock, bool syncEvent = false);
|
|
// makes all future work submitted to the stream wait for this event.
|
|
bool wait(bool needsLock, bool syncEvent = false);
|
|
// schedules a notifyListener callback for the event.
|
|
bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
|
|
// checks if events are already signaled.
|
|
bool query() const;
|
|
// blocks the CPU thread until all the GPU work that were scheduled
|
|
// prior to recording this event are completed.
|
|
bool synchronize();
|
|
// resets this event with new parameters in case it gets reused from the event pool
|
|
void reset(MPSStream* stream, bool enable_timing);
|
|
// returns the unique ID of the event instance
|
|
id_t getID() const { return m_id; }
|
|
// returns the completion timestamp of the event
|
|
uint64_t getCompletionTime() const { return m_completion_time; }
|
|
// if already recorded, waits for cpu_sync_cv to be signaled
|
|
void waitForCpuSync();
|
|
|
|
private:
|
|
id_t m_id;
|
|
// enables measuring the completion time of the notifyListener of this event
|
|
bool m_enable_timing;
|
|
uint64_t m_signalCounter = 0;
|
|
MPSStream* m_stream = nullptr;
|
|
MTLSharedEvent_t m_event = nullptr;
|
|
MTLSharedEventListener* m_listener = nullptr;
|
|
// used to sync the events created on this Stream with CPU
|
|
std::mutex m_cpu_sync_mutex{};
|
|
std::condition_variable m_cpu_sync_cv{};
|
|
// CondVar predicate to sync the events created on this Stream with CPU
|
|
bool m_cpu_sync_completed = false;
|
|
// used to compute elapsed time
|
|
uint64_t m_completion_time = 0;
|
|
|
|
void recordLocked(bool syncEvent);
|
|
bool waitLocked(bool syncEvent);
|
|
bool notifyLocked(MTLSharedEventNotificationBlock block);
|
|
void notifyCpuSync();
|
|
static uint64_t getTime() {
|
|
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
|
|
}
|
|
};
|
|
|
|
// Owning handle to an MPSEvent whose deleter is a type-erased callable, so the
// release behaviour is chosen by whoever creates the pointer.
// NOTE(review): presumably the deleter hands the event back to MPSEventPool
// instead of destroying it -- confirm in the pool implementation.
using MPSEventPtr = std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>>;
|
|
|
|
class MPSEventPool {
|
|
public:
|
|
explicit MPSEventPool(MPSStream* default_stream);
|
|
~MPSEventPool();
|
|
|
|
MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
|
|
void emptyCache();
|
|
|
|
// these are mainly used for MPSHooks and torch.mps.Event() bindings
|
|
id_t acquireEvent(bool enable_timing);
|
|
void releaseEvent(id_t event_id);
|
|
void recordEvent(id_t event_id, bool syncEvent);
|
|
void waitForEvent(id_t event_id, bool syncEvent);
|
|
void synchronizeEvent(id_t event_id);
|
|
bool queryEvent(id_t event_id);
|
|
// returns elapsed time between two recorded events in milliseconds
|
|
double elapsedTime(id_t start_event_id, id_t end_event_id);
|
|
|
|
private:
|
|
MPSStream* m_default_stream = nullptr;
|
|
std::recursive_mutex m_mutex;
|
|
std::stack<std::unique_ptr<MPSEvent>> m_pool{};
|
|
// dictionary to associate event IDs with event objects
|
|
// used to retain in-use events out of the pool
|
|
// for torch.mps.Event() bindings.
|
|
std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
|
|
uint64_t m_event_counter = 0;
|
|
std::function<void(MPSEvent*)> m_default_deleter;
|
|
|
|
MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
|
|
};
|
|
|
|
// Returns the MPSEventPool through a std::shared_ptr. Shared ownership is used
// so that the pool gets destroyed only after the instances that depend on it.
|
|
std::shared_ptr<MPSEventPool> getMPSEventPool();
|
|
|
|
} // namespace at::mps
|