142 lines
4.1 KiB
Python
142 lines
4.1 KiB
Python
# coding=utf-8
|
|
# Copyright 2023-present the HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from .integrations import (
|
|
is_optuna_available,
|
|
is_ray_tune_available,
|
|
is_sigopt_available,
|
|
is_wandb_available,
|
|
run_hp_search_optuna,
|
|
run_hp_search_ray,
|
|
run_hp_search_sigopt,
|
|
run_hp_search_wandb,
|
|
)
|
|
from .trainer_utils import (
|
|
HPSearchBackend,
|
|
default_hp_space_optuna,
|
|
default_hp_space_ray,
|
|
default_hp_space_sigopt,
|
|
default_hp_space_wandb,
|
|
)
|
|
from .utils import logging
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class HyperParamSearchBackendBase:
|
|
name: str
|
|
pip_package: str = None
|
|
|
|
@staticmethod
|
|
def is_available():
|
|
raise NotImplementedError
|
|
|
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
|
raise NotImplementedError
|
|
|
|
def default_hp_space(self, trial):
|
|
raise NotImplementedError
|
|
|
|
def ensure_available(self):
|
|
if not self.is_available():
|
|
raise RuntimeError(
|
|
f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
|
|
)
|
|
|
|
@classmethod
|
|
def pip_install(cls):
|
|
return f"`pip install {cls.pip_package or cls.name}`"
|
|
|
|
|
|
class OptunaBackend(HyperParamSearchBackendBase):
|
|
name = "optuna"
|
|
|
|
@staticmethod
|
|
def is_available():
|
|
return is_optuna_available()
|
|
|
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
|
return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
|
|
|
|
def default_hp_space(self, trial):
|
|
return default_hp_space_optuna(trial)
|
|
|
|
|
|
class RayTuneBackend(HyperParamSearchBackendBase):
|
|
name = "ray"
|
|
pip_package = "'ray[tune]'"
|
|
|
|
@staticmethod
|
|
def is_available():
|
|
return is_ray_tune_available()
|
|
|
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
|
return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
|
|
|
|
def default_hp_space(self, trial):
|
|
return default_hp_space_ray(trial)
|
|
|
|
|
|
class SigOptBackend(HyperParamSearchBackendBase):
|
|
name = "sigopt"
|
|
|
|
@staticmethod
|
|
def is_available():
|
|
return is_sigopt_available()
|
|
|
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
|
return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
|
|
|
|
def default_hp_space(self, trial):
|
|
return default_hp_space_sigopt(trial)
|
|
|
|
|
|
class WandbBackend(HyperParamSearchBackendBase):
|
|
name = "wandb"
|
|
|
|
@staticmethod
|
|
def is_available():
|
|
return is_wandb_available()
|
|
|
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
|
return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
|
|
|
|
def default_hp_space(self, trial):
|
|
return default_hp_space_wandb(trial)
|
|
|
|
|
|
ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
|
|
HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
|
|
}
|
|
|
|
|
|
def default_hp_search_backend() -> str:
|
|
available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
|
|
if len(available_backends) > 0:
|
|
name = available_backends[0].name
|
|
if len(available_backends) > 1:
|
|
logger.info(
|
|
f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
|
|
)
|
|
return name
|
|
raise RuntimeError(
|
|
"No hyperparameter search backend available.\n"
|
|
+ "\n".join(
|
|
f" - To install {backend.name} run {backend.pip_install()}"
|
|
for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
|
|
)
|
|
)
|