175 lines
4.8 KiB
C++
175 lines
4.8 KiB
C++
// Copyright © 2022 Apple Inc.
|
|
|
|
#pragma once
|
|
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
|
#include <c10/macros/Macros.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <ATen/Context.h>
|
|
#include <ATen/mps/MPSStream.h>
|
|
#include <ATen/mps/MPSEvent.h>
|
|
|
|
#ifdef __OBJC__
|
|
#include <Foundation/Foundation.h>
|
|
#include <Metal/Metal.h>
|
|
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
|
#endif
|
|
|
|
#include <ATen/Tensor.h>
|
|
#include <c10/core/MemoryFormat.h>
|
|
#include <c10/core/Storage.h>
|
|
#include <c10/core/TensorImpl.h>
|
|
#include <sys/_types/_size_t.h>
|
|
#include <memory>
|
|
#include <c10/core/UndefinedTensorImpl.h>
|
|
#include <c10/util/intrusive_ptr.h>
|
|
|
|
|
|
namespace at::mps {
|
|
|
|
/// Handle type used by the MPS guard's event API (a raw pointer to MPSEvent).
using mpsEvent_t = MPSEvent*;
|
|
|
|
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
|
|
// https://github.com/pytorch/pytorch/issues/77170
|
|
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
|
static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
|
|
|
|
// constructor
|
|
MPSGuardImpl() {}
|
|
explicit MPSGuardImpl(c10::DeviceType t) {
|
|
TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS);
|
|
}
|
|
|
|
// returns the type
|
|
c10::DeviceType type() const override {
|
|
return c10::DeviceType::MPS;
|
|
}
|
|
|
|
Device exchangeDevice(Device d) const override {
|
|
return Device(c10::DeviceType::MPS, 0);
|
|
}
|
|
|
|
Device getDevice() const override {
|
|
return Device(c10::DeviceType::MPS, 0);
|
|
}
|
|
|
|
c10::optional<Device> uncheckedGetDevice() const noexcept {
|
|
return Device(c10::DeviceType::MPS, 0);
|
|
}
|
|
|
|
void setDevice(Device d) const override {
|
|
TORCH_INTERNAL_ASSERT(d.is_mps());
|
|
}
|
|
|
|
void uncheckedSetDevice(Device d) const noexcept override {
|
|
// TODO: Currently setting only device 0
|
|
}
|
|
|
|
Stream getStream(Device d) const noexcept override {
|
|
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
|
|
}
|
|
|
|
Stream getDefaultStream(Device d) const override {
|
|
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
|
|
}
|
|
|
|
// NB: These do NOT set the current device
|
|
Stream exchangeStream(Stream s) const noexcept override {
|
|
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
|
|
}
|
|
DeviceIndex deviceCount() const noexcept override {
|
|
if (at::hasMPS()) {
|
|
//TODO: extend it for multi-device case
|
|
return 1;
|
|
} else {
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
// Event-related functions
|
|
void createEvent(
|
|
mpsEvent_t* event,
|
|
const EventFlag flag) const;
|
|
|
|
void destroyEvent(
|
|
void* event,
|
|
const DeviceIndex device_index) const noexcept override;
|
|
|
|
void record(
|
|
void** event,
|
|
const Stream& stream,
|
|
const DeviceIndex device_index,
|
|
const EventFlag flag) const override;
|
|
|
|
void block(
|
|
void* event,
|
|
const Stream& stream) const override;
|
|
|
|
bool queryEvent(void* event) const override;
|
|
|
|
};
|
|
|
|
/// A variant of OptionalDeviceGuard that is specialized for MPS.
|
|
struct OptionalMPSGuard {
|
|
explicit OptionalMPSGuard() : guard_() {}
|
|
|
|
explicit OptionalMPSGuard(c10::optional<Device> device_opt)
|
|
: guard_(device_opt) {}
|
|
|
|
/// Set the current MPS device to the passed device index, if it is not
|
|
/// nullopt
|
|
explicit OptionalMPSGuard(c10::optional<DeviceIndex> device_index_opt)
|
|
: guard_(device_index_opt) {}
|
|
|
|
// Copy is not allowed
|
|
OptionalMPSGuard(const OptionalMPSGuard&) = delete;
|
|
OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
|
|
OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
|
|
OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;
|
|
|
|
/// Sets the MPS device to the given device, initializing the guard if it
|
|
/// is not already initialized. Errors if the given device is not a MPS
|
|
/// device.
|
|
void set_device(Device device) {
|
|
guard_.set_device(device);
|
|
}
|
|
|
|
/// Sets the MPS device to the given device, initializing the guard if it is
|
|
/// not already initialized. Errors if the given device is not a MPS device.
|
|
void reset_device(Device device) {
|
|
guard_.reset_device(device);
|
|
}
|
|
|
|
/// Sets the MPS device to the given device index, initializing the guard if
|
|
/// it is not already initialized.
|
|
void set_index(DeviceIndex device_index) {
|
|
guard_.set_index(device_index);
|
|
}
|
|
|
|
/// Returns the device that was set immediately prior to initialization of the
|
|
/// guard, or nullopt if the guard is uninitialized.
|
|
c10::optional<Device> original_device() const {
|
|
return guard_.original_device();
|
|
}
|
|
|
|
/// Returns the most recent device that was set using this device guard,
|
|
/// either from construction, or via set_device, if the guard is initialized,
|
|
/// or nullopt if the guard is uninitialized.
|
|
c10::optional<Device> current_device() const {
|
|
return guard_.current_device();
|
|
}
|
|
|
|
/// Restore the original MPS device, resetting this guard to uninitialized
|
|
/// state.
|
|
void reset() {
|
|
guard_.reset();
|
|
}
|
|
|
|
private:
|
|
c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
|
|
};
|
|
|
|
|
|
// Registers MPSGuardImpl as the guard implementation for c10::DeviceType::MPS.
// NOTE(review): this expansion sits in a header, so presumably every
// translation unit that includes it performs the registration — confirm the
// macro is safe to expand more than once (or move it to the .mm source).
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl);
|
|
|
|
} // namespace at::mps
|