205 lines
7.9 KiB
Python
205 lines
7.9 KiB
Python
|
import hashlib
|
||
|
import os
|
||
|
import site
|
||
|
import sys
|
||
|
import tarfile
|
||
|
import urllib.parse
|
||
|
from pathlib import Path
|
||
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||
|
|
||
|
from wasabi import msg
|
||
|
|
||
|
from ..errors import Errors
|
||
|
from ..util import check_spacy_env_vars, download_file, ensure_pathy, get_checksum
|
||
|
from ..util import get_hash, make_tempdir, upload_file
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from cloudpathlib import CloudPath
|
||
|
|
||
|
|
||
|
class RemoteStorage:
    """Push and pull outputs to and from a remote file storage.

    Remotes can be anything that `smart-open` can support: AWS, GCS, file system,
    ssh, etc.

    NOTE(review): the URL is converted with `ensure_pathy` and the methods are
    annotated as returning `CloudPath`; confirm which backends are actually
    supported.

    Within the remote, every archive lives at
    `<url>/<url-encoded project path>/<command_hash>/<content_hash>`.
    """

    def __init__(self, project_root: Path, url: str, *, compression="gz"):
        """Create a storage handle.

        project_root: Directory that pushed/pulled paths are relative to.
        url: Location of the remote storage; converted with `ensure_pathy`.
        compression: Suffix of the tarfile mode (e.g. "gz" -> "w:gz"/"r:gz").
            A falsy value means the archive is not compressed.
        """
        self.root = project_root
        self.url = ensure_pathy(url)
        self.compression = compression

    def push(self, path: Path, command_hash: str, content_hash: str) -> "CloudPath":
        """Compress a file or directory within a project and upload it to a remote
        storage. If an object exists at the full URL, nothing is done.

        Within the remote storage, files are addressed by their project path
        (url encoded) and two user-supplied hashes, representing their creation
        context and their file contents. If the URL already exists, the data is
        not uploaded. Paths are archived and compressed prior to upload.

        path: Path of the file or directory, relative to the project root.
        command_hash: Hash representing the creation context.
        content_hash: Hash representing the file contents.
        RETURNS: The remote URL of the archive.
        RAISES: IOError if the local path does not exist.
        """
        loc = self.root / path
        if not loc.exists():
            raise IOError(f"Cannot push {loc}: does not exist.")
        url = self.make_url(path, command_hash, content_hash)
        if url.exists():
            return url
        tmp: Path
        with make_tempdir() as tmp:
            tar_loc = tmp / self.encode_name(str(path))
            mode_string = f"w:{self.compression}" if self.compression else "w"
            with tarfile.open(tar_loc, mode=mode_string) as tar_file:
                # Store the member under its project-relative name so that
                # pull() can extract it directly into the project root.
                tar_file.add(str(loc), arcname=str(path))
            upload_file(tar_loc, url)
        return url

    def pull(
        self,
        path: Path,
        *,
        command_hash: Optional[str] = None,
        content_hash: Optional[str] = None,
    ) -> Optional["CloudPath"]:
        """Retrieve a file from the remote cache. If the file already exists,
        nothing is done.

        If the command_hash and/or content_hash are specified, only matching
        results are considered.

        path: Path of the file or directory, relative to the project root.
        command_hash: Optional hash of the creation context to match.
        content_hash: Optional hash of the file contents to match.
        RETURNS: The remote URL that was downloaded and extracted, or None if
            the destination already exists locally or no match was found.
        RAISES: ValueError (Errors.E201) if the archive contains a member that
            would be extracted outside of the project root.
        """
        dest = self.root / path
        if dest.exists():
            return None
        url = self.find(path, command_hash=command_hash, content_hash=content_hash)
        if url is None:
            return None
        # Make sure the destination's parent exists. exist_ok avoids failing
        # if the directory is created between a check and the mkdir call.
        dest.parent.mkdir(parents=True, exist_ok=True)
        tmp: Path
        with make_tempdir() as tmp:
            tar_loc = tmp / url.parts[-1]
            download_file(url, tar_loc)
            mode_string = f"r:{self.compression}" if self.compression else "r"
            with tarfile.open(tar_loc, mode=mode_string) as tar_file:
                # This requires that the path is added correctly, relative
                # to root. This is how we set things up in push()
                self._safe_extract(tar_file, self.root)
        return url

    @staticmethod
    def _is_within_directory(directory, target) -> bool:
        """Check whether `target` is `directory` itself or located inside it.

        Both paths are made absolute (which also collapses ".." segments)
        and compared component-wise. A character-wise prefix comparison is
        not sufficient: "/root-evil/x" shares the string prefix "/root" but
        is not inside "/root".
        """
        abs_directory = os.path.abspath(directory)
        abs_target = os.path.abspath(target)
        try:
            common = os.path.commonpath([abs_directory, abs_target])
        except ValueError:
            # Raised e.g. for paths on different drives: cannot be nested.
            return False
        return common == abs_directory

    @classmethod
    def _safe_extract(cls, tar, path) -> None:
        """Extract all members of an open tar file into `path`.

        Disallows members that resolve outside of `path` (CVE-2007-4559,
        directory traversal vulnerability) by raising ValueError(Errors.E201)
        before anything is extracted.
        """
        for member in tar.getmembers():
            member_path = os.path.join(path, member.name)
            if not cls._is_within_directory(path, member_path):
                raise ValueError(Errors.E201)
        if sys.version_info >= (3, 12):
            # Python 3.12+ supports extraction filters; "data" additionally
            # rejects or sanitizes dangerous archive members.
            tar.extractall(path, filter="data")
        else:
            tar.extractall(path)

    def find(
        self,
        path: Path,
        *,
        command_hash: Optional[str] = None,
        content_hash: Optional[str] = None,
    ) -> Optional["CloudPath"]:
        """Find the best matching version of a file within the storage,
        or `None` if no match can be found. If both the creation and content hash
        are specified, only exact matches will be returned. Otherwise, the most
        recent matching file is preferred.

        path: Path of the file or directory, relative to the project root.
        command_hash: Optional hash of the creation context to match.
        content_hash: Optional hash of the file contents to match.
        RETURNS: The remote URL of the best match, or None.
        """
        name = self.encode_name(str(path))
        base = self.url / name
        urls = []
        if command_hash is not None and content_hash is not None:
            url = base / command_hash / content_hash
            urls = [url] if url.exists() else []
        elif command_hash is not None:
            command_dir = base / command_hash
            if command_dir.exists():
                urls = list(command_dir.iterdir())
        else:
            # No command hash: collect the entries of every command directory,
            # optionally narrowed down to the requested content hash.
            if base.exists():
                for sub_dir in base.iterdir():
                    urls.extend(sub_dir.iterdir())
            if content_hash is not None:
                urls = [url for url in urls if url.parts[-1] == content_hash]
        if len(urls) >= 2:
            try:
                # Ascending by modification time, so the newest ends up last.
                urls.sort(key=lambda x: x.stat().st_mtime)
            except Exception:
                # Best effort: still return a match, but let the user know it
                # may not be the newest one.
                msg.warn(
                    "Unable to sort remote files by last modified. The file(s) "
                    "pulled from the cache may not be the most recent."
                )
        return urls[-1] if urls else None

    def make_url(self, path: Path, command_hash: str, content_hash: str) -> "CloudPath":
        """Construct a URL from a subpath, a creation hash and a content hash."""
        return self.url / self.encode_name(str(path)) / command_hash / content_hash

    def encode_name(self, name: str) -> str:
        """Encode a subpath into a URL-safe name."""
        return urllib.parse.quote_plus(name)
|
||
|
|
||
|
|
||
|
def get_content_hash(loc: Path) -> str:
    """Return the content hash for `loc`, i.e. the value `get_checksum`
    computes for that path.
    """
    checksum = get_checksum(loc)
    return checksum
|
||
|
|
||
|
|
||
|
def get_command_hash(
    site_hash: str, env_hash: str, deps: List[Path], cmd: List[str]
) -> str:
    """Create a hash representing the execution of a command.

    The hash is the MD5 hex digest of the concatenation of, in this order:
    the site hash, the environment hash, the checksums of the dependencies
    (taken in sorted path order) and the parts of the command.

    site_hash: Hash describing the installed packages.
    env_hash: Hash describing the relevant environment variables.
    deps: Paths of the command's dependencies.
    cmd: The command, as a list of strings.
    RETURNS: The hex digest identifying this command execution.
    """
    check_spacy_env_vars()
    pieces = [site_hash, env_hash]
    # Sort the dependencies so that their listing order does not matter.
    pieces += [get_checksum(dep_path) for dep_path in sorted(deps)]
    pieces += cmd
    digest = hashlib.md5("".join(pieces).encode("utf8"))
    return digest.hexdigest()
|
||
|
|
||
|
|
||
|
def get_site_hash() -> str:
    """Hash the current Python environment's site-packages contents, including
    the name and version of the libraries.

    The names of all `*.dist-info` entries found in the site-packages
    directories (and the user site directory, if enabled) are collected with
    the ".dist-info" suffix removed, de-duplicated, sorted and hashed.

    RETURNS: The MD5 hex digest of the collected names.
    """
    # Copy so that the list returned by the site module is never mutated.
    site_dirs = list(site.getsitepackages())
    if site.ENABLE_USER_SITE:
        # getusersitepackages() returns a single path string; extending the
        # list with it would add one entry per character.
        site_dirs.append(site.getusersitepackages())
    packages = set()
    for site_dir in site_dirs:
        site_path = Path(site_dir)
        # Configured site directories (notably the user site) do not
        # necessarily exist on disk.
        if not site_path.is_dir():
            continue
        for subpath in site_path.iterdir():
            entry_name = subpath.parts[-1]
            if entry_name.endswith("dist-info"):
                packages.add(entry_name.replace(".dist-info", ""))
    package_bytes = "".join(sorted(packages)).encode("utf8")
    return hashlib.md5(package_bytes).hexdigest()
|
||
|
|
||
|
|
||
|
def get_env_hash(env: Dict[str, str]) -> str:
    """Construct a hash of the environment variables that will be passed into
    the commands.

    Values in the env dict may be references to the current os.environ, using
    the syntax $ENV_VAR to mean os.environ[ENV_VAR]. A referenced variable
    that is not set resolves to the empty string.

    env: Mapping of variable names to literal values or $-references.
    RETURNS: The result of `get_hash` for the resolved mapping.
    """
    resolved = {
        name: os.environ.get(raw[1:], "") if raw.startswith("$") else raw
        for name, raw in env.items()
    }
    return get_hash(resolved)
|