ai-content-maker/.venv/Lib/site-packages/trainer/logging/clearml_logger.py

70 lines
2.3 KiB
Python

import os
from typing import Any
import torch
from trainer.logging.tensorboard_logger import TensorboardLogger
from trainer.trainer_utils import is_clearml_available
from trainer.utils.distributed import rank_zero_only
if is_clearml_available():
from clearml import Task # pylint: disable=import-error
else:
raise ImportError("ClearML is not installed. Please install it with `pip install clearml`")
class ClearMLLogger(TensorboardLogger):
"""ClearML Logger using TensorBoard in the background.
TODO:
- Add hyperparameter handling
- Use ClearML logger for plots
- Handle continuing training
Args:
output_uri (str): URI of the ClearML repository.
local_path (str): Path to the local directory where the model is saved.
project_name (str): Name of the ClearML project.
task_name (str): Name of the ClearML task.
tags (str): Comma separated list of tags to add to the ClearML task.
"""
def __init__(
self,
output_uri: str,
local_path: str,
project_name: str,
task_name: str,
tags: str = None,
):
self._context = None
self.local_path = local_path
self.task_name = task_name
self.tags = tags.split(",") if tags else []
self.run = Task.init(project_name=project_name, task_name=task_name, tags=self.tags, output_uri=output_uri)
if tags:
for tag in tags.split(","):
self.run.add_tag(tag)
super().__init__("run", None)
@rank_zero_only
def add_config(self, config):
"""Upload config file(s) to ClearML."""
self.add_text("run_config", f"{config.to_json()}", 0)
self.run.connect_configuration(name="model_config", configuration=config.to_dict())
self.run.set_comment(config.run_description)
self.run.upload_artifact("model_config", config.to_dict())
self.run.upload_artifact("configs", artifact_object=os.path.join(self.local_path, "*.json"))
@rank_zero_only
def add_artifact(self, file_or_dir, name, **kwargs): # pylint: disable=unused-argument, arguments-differ
"""Upload artifact to ClearML."""
self.run.upload_artifact(name, artifact_object=file_or_dir)
@staticmethod
@rank_zero_only
def save_model(state: Any, path: str):
torch.save(state, path)