ai-content-maker/.venv/Lib/site-packages/thinc/compat.py

124 lines
3.0 KiB
Python

import warnings
from packaging.version import Version
# Probe for CuPy. An ImportError or AttributeError anywhere in the probe puts
# every CuPy-related name into its "unavailable" placeholder state.
try:  # pragma: no cover
    import cupy
    import cupy.cublas
    import cupyx

    has_cupy = True
    cublas = cupy.cublas
    cupy_version = Version(cupy.__version__)
    # A CUDARuntimeError from querying the CUDA runtime is treated as
    # "no usable GPU"; any other outcome of the query means a GPU is usable.
    try:
        cupy.cuda.runtime.getDeviceCount()
    except cupy.cuda.runtime.CUDARuntimeError:
        has_cupy_gpu = False
    else:
        has_cupy_gpu = True
    # ``fromDlpack`` was deprecated in CuPy v10.0.0 in favour of ``from_dlpack``.
    cupy_from_dlpack = (
        cupy.from_dlpack if cupy_version.major >= 10 else cupy.fromDlpack
    )
except (ImportError, AttributeError):
    cupy = cupyx = cublas = None
    cupy_from_dlpack = None
    cupy_version = Version("0.0.0")
    has_cupy = has_cupy_gpu = False
# Probe for PyTorch. Only ImportError is treated as "torch unavailable".
try:  # pragma: no cover
    import torch
    import torch.utils.dlpack

    has_torch = True
    has_torch_cuda_gpu = torch.cuda.device_count() != 0
    # ``torch.backends.mps`` may be missing, hence the hasattr guard.
    if hasattr(torch.backends, "mps"):
        has_torch_mps = torch.backends.mps.is_built()
    else:
        has_torch_mps = False
    has_torch_mps_gpu = has_torch_mps and torch.backends.mps.is_available()
    # ``has_torch_gpu`` reflects CUDA devices only; MPS has its own flags.
    has_torch_gpu = has_torch_cuda_gpu
    torch_version = Version(str(torch.__version__))
    # AMP needs torch >= 1.9.0 and must not be reported by torch itself as
    # "definitely not available".
    has_torch_amp = torch_version >= Version("1.9.0") and (
        not torch.cuda.amp.common.amp_definitely_not_available()
    )
except ImportError:  # pragma: no cover
    torch = None  # type: ignore
    has_torch = has_torch_cuda_gpu = has_torch_gpu = False
    has_torch_mps = has_torch_mps_gpu = False
    has_torch_amp = False
    torch_version = Version("0.0.0")
def enable_tensorflow():
    """Opt in to Thinc's built-in (deprecated) TensorFlow support.

    Emits a ``DeprecationWarning`` attributed to the caller. It then imports
    ``tensorflow`` and ``tensorflow.experimental.dlpack`` and rebinds the
    module-level globals ``tensorflow``, ``has_tensorflow`` and
    ``has_tensorflow_gpu``. ``has_tensorflow_gpu`` is True when TensorFlow
    reports at least one visible "GPU" device.

    Raises:
        ImportError: propagated unchanged if TensorFlow (or its dlpack
            submodule) cannot be imported.
    """
    # Declared up front so every assignment below clearly targets module state.
    global tensorflow, has_tensorflow, has_tensorflow_gpu
    warn_msg = (
        "Built-in TensorFlow support will be removed in Thinc v9. If you need "
        "TensorFlow support in the future, you can transition to using a "
        "custom copy of the current TensorFlowWrapper in your package or "
        "project."
    )
    # stacklevel=2 attributes the warning to the caller's line. With the
    # default stacklevel=1 it is attributed to this module. By default, Python
    # only displays DeprecationWarning attributed to __main__, so even a direct
    # call from a user's script would not show it.
    warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
    import tensorflow
    import tensorflow.experimental.dlpack

    has_tensorflow = True
    has_tensorflow_gpu = len(tensorflow.config.get_visible_devices("GPU")) > 0
# Placeholders until enable_tensorflow() succeeds and rebinds them.
tensorflow = None
has_tensorflow = has_tensorflow_gpu = False
def enable_mxnet():
    """Opt in to Thinc's built-in (deprecated) MXNet support.

    Emits a ``DeprecationWarning`` attributed to the caller. It then imports
    ``mxnet`` and rebinds the module-level globals ``mxnet`` and
    ``has_mxnet``.

    Raises:
        ImportError: propagated unchanged if MXNet cannot be imported; the
            globals are then left as they were.
    """
    # Declared up front so every assignment below clearly targets module state.
    global mxnet, has_mxnet
    warn_msg = (
        "Built-in MXNet support will be removed in Thinc v9. If you need "
        "MXNet support in the future, you can transition to using a "
        "custom copy of the current MXNetWrapper in your package or "
        "project."
    )
    # stacklevel=2 attributes the warning to the caller's line. With the
    # default stacklevel=1 it is attributed to this module. By default, Python
    # only displays DeprecationWarning attributed to __main__, so even a direct
    # call from a user's script would not show it.
    warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
    import mxnet

    has_mxnet = True
# Placeholders until enable_mxnet() succeeds and rebinds them.
mxnet, has_mxnet = None, False
# h5py is optional: the name stays ``None`` when it cannot be imported.
h5py = None
try:
    import h5py
except ImportError:  # pragma: no cover
    pass
# os_signpost is optional; ``has_os_signpost`` records whether it imported.
try:  # pragma: no cover
    import os_signpost
except ImportError:
    os_signpost = None
    has_os_signpost = False
else:
    has_os_signpost = True
# A usable GPU means either a CUDA device usable through CuPy or an available
# PyTorch MPS device; a CUDA device visible only to PyTorch does not count.
has_gpu = has_cupy_gpu or has_torch_mps_gpu

# Names exported by a star-import of this module (order preserved).
__all__ = ["cupy", "cupyx", "torch", "tensorflow", "mxnet", "h5py", "os_signpost"]