// Copyright © 2023 Apple Inc. #pragma once #include #include #include namespace at::mps { // NOTE: don't create instances of this class directly. // Use MPSEventPool to acquire instances of MPSEvent. class MPSEvent { public: explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing); ~MPSEvent(); // records an event on the stream void record(bool needsLock, bool syncEvent = false); // makes all future work submitted to the stream wait for this event. bool wait(bool needsLock, bool syncEvent = false); // schedules a notifyListener callback for the event. bool notify(bool needsLock, MTLSharedEventNotificationBlock block); // checks if events are already signaled. bool query() const; // blocks the CPU thread until all the GPU work that were scheduled // prior to recording this event are completed. bool synchronize(); // resets this event with new parameters in case it gets reused from the event pool void reset(MPSStream* stream, bool enable_timing); // returns the unique ID of the event instance id_t getID() const { return m_id; } // returns the completion timestamp of the event uint64_t getCompletionTime() const { return m_completion_time; } // if already recorded, waits for cpu_sync_cv to be signaled void waitForCpuSync(); private: id_t m_id; // enables measuring the completion time of the notifyListener of this event bool m_enable_timing; uint64_t m_signalCounter = 0; MPSStream* m_stream = nullptr; MTLSharedEvent_t m_event = nullptr; MTLSharedEventListener* m_listener = nullptr; // used to sync the events created on this Stream with CPU std::mutex m_cpu_sync_mutex{}; std::condition_variable m_cpu_sync_cv{}; // CondVar predicate to sync the events created on this Stream with CPU bool m_cpu_sync_completed = false; // used to compute elapsed time uint64_t m_completion_time = 0; void recordLocked(bool syncEvent); bool waitLocked(bool syncEvent); bool notifyLocked(MTLSharedEventNotificationBlock block); void notifyCpuSync(); static uint64_t getTime() { return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW); } }; typedef std::unique_ptr> MPSEventPtr; class MPSEventPool { public: explicit MPSEventPool(MPSStream* default_stream); ~MPSEventPool(); MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream); void emptyCache(); // these are mainly used for MPSHooks and torch.mps.Event() bindings id_t acquireEvent(bool enable_timing); void releaseEvent(id_t event_id); void recordEvent(id_t event_id, bool syncEvent); void waitForEvent(id_t event_id, bool syncEvent); void synchronizeEvent(id_t event_id); bool queryEvent(id_t event_id); // returns elapsed time between two recorded events in milliseconds double elapsedTime(id_t start_event_id, id_t end_event_id); private: MPSStream* m_default_stream = nullptr; std::recursive_mutex m_mutex; std::stack> m_pool{}; // dictionary to associate event IDs with event objects // used to retain in-use events out of the pool // for torch.mps.Event() bindings. std::unordered_map m_in_use_events{}; uint64_t m_event_counter = 0; std::function m_default_deleter; MPSEvent* getInUseEvent(id_t event_id, bool locked = true); }; // shared_ptr is used to get MPSEventPool destroyed after dependent instances std::shared_ptr getMPSEventPool(); } // namespace at::mps