ai-content-maker/.venv/Lib/site-packages/tensorboard/plugins/hparams/summary.py

207 lines
8.3 KiB
Python
Raw Permalink Normal View History

2024-05-03 04:18:51 +03:00
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Summary creation methods for the HParams plugin.
Typical usage for exporting summaries in a hyperparameters-tuning experiment:
1. Create the experiment (once) by calling experiment_pb() and exporting
the resulting summary into a top-level (empty) run.
2. In each training session in the experiment, call session_start_pb() before
the session starts, exporting the resulting summary into a uniquely named
run for the session, say <session_name>.
3. Train the model in the session, exporting each metric as a scalar summary
in runs of the form <session_name>/<sub_dir>, where <sub_dir> can be empty
(in which case the run is just the <session_name>) and depends on the
metric. The name of such a metric is a (group, tag) pair given by
(<sub_dir>, tag) where tag is the tag of the scalar summary.
When calling experiment_pb in step 1, you'll need to pass all the metric
names used in the experiment.
4. When the session completes, call session_end_pb() and export the resulting
summary into the same session run <session_name>.
"""
import time
import tensorflow as tf
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import metadata
from tensorboard.plugins.hparams import plugin_data_pb2
def experiment_pb(
    hparam_infos, metric_infos, user="", description="", time_created_secs=None
):
    """Builds the summary that defines a hyperparameter-tuning experiment.

    Args:
      hparam_infos: Array of api_pb2.HParamInfo messages describing the
        hyperparameters used in the experiment.
      metric_infos: Array of api_pb2.MetricInfo messages describing the
        metrics used in the experiment. See the documentation at the top of
        this file for how to populate this.
      user: String. An id for the user running the experiment.
      description: String. A description for the experiment. May contain
        markdown.
      time_created_secs: float. The time the experiment is created, in
        seconds since the UNIX epoch. If None, the current time is used.

    Returns:
      A summary protobuffer containing the experiment definition.
    """
    # Fall back to "now" only when the caller gave no creation time.
    created_at = (
        time.time() if time_created_secs is None else time_created_secs
    )
    plugin_data = plugin_data_pb2.HParamsPluginData(
        experiment=api_pb2.Experiment(
            description=description,
            user=user,
            time_created_secs=created_at,
            hparam_infos=hparam_infos,
            metric_infos=metric_infos,
        )
    )
    return _summary(metadata.EXPERIMENT_TAG, plugin_data)
def session_start_pb(
    hparams, model_uri="", monitor_url="", group_name="", start_time_secs=None
):
    """Builds the summary carrying a SessionStartInfo protobuffer.

    The summary holds the metadata of one training session. Create one such
    summary per training session, each exported to a different run.

    Args:
      hparams: A dictionary with string keys. Maps each hyperparameter name
        to the value used in the session. Supported value types are `bool`,
        `int`, `float`, `str`, `list`, `tuple`.

        The type of value must correspond to the type of hyperparameter
        (defined in the corresponding api_pb2.HParamInfo member of the
        Experiment protobuf) as follows:

          +-----------------+---------------------------------+
          |Hyperparameter   | Allowed (Python) value types    |
          |type             |                                 |
          +-----------------+---------------------------------+
          |DATA_TYPE_BOOL   | bool                            |
          |DATA_TYPE_FLOAT64| int, float                      |
          |DATA_TYPE_STRING | str, tuple, list                |
          +-----------------+---------------------------------+

        Tuple and list instances are stored as their string
        representation.
      model_uri: See the comment for the field with the same name of
        plugin_data_pb2.SessionStartInfo.
      monitor_url: See the comment for the field with the same name of
        plugin_data_pb2.SessionStartInfo.
      group_name: See the comment for the field with the same name of
        plugin_data_pb2.SessionStartInfo.
      start_time_secs: float. The time to use as the session start time,
        in seconds since the UNIX epoch. If None, the current time is used.

    Returns:
      The summary protobuffer mentioned above.

    Raises:
      TypeError: If a value in `hparams` has an unsupported type.
    """
    started_at = time.time() if start_time_secs is None else start_time_secs
    info = plugin_data_pb2.SessionStartInfo(
        model_uri=model_uri,
        monitor_url=monitor_url,
        group_name=group_name,
        start_time_secs=started_at,
    )
    for name, value in hparams.items():
        # Decide the target field first; the map entry is only touched once
        # the value type is known to be supported.
        # bool must be tested before int: isinstance(True, int) is True.
        if isinstance(value, bool):
            field, stored = "bool_value", value
        elif isinstance(value, (float, int)):
            field, stored = "number_value", value
        elif isinstance(value, str):
            field, stored = "string_value", value
        elif isinstance(value, (list, tuple)):
            field, stored = "string_value", str(value)
        else:
            raise TypeError(
                "hparams[%s]=%s has type: %s which is not supported"
                % (name, value, type(value))
            )
        setattr(info.hparams[name], field, stored)
    plugin_data = plugin_data_pb2.HParamsPluginData(session_start_info=info)
    return _summary(metadata.SESSION_START_INFO_TAG, plugin_data)
def session_end_pb(status, end_time_secs=None):
    """Builds the summary carrying a SessionEndInfo protobuffer.

    The summary holds status information for a completed training session
    and should be exported after the session is completed. Create one such
    summary per training session, each exported to a different run.

    Args:
      status: A tensorboard.hparams.Status enumeration value denoting the
        status of the session.
      end_time_secs: float. The time to use as the session end time, in
        seconds since the UNIX epoch. If None, the current time is used.

    Returns:
      The summary protobuffer mentioned above.
    """
    ended_at = time.time() if end_time_secs is None else end_time_secs
    plugin_data = plugin_data_pb2.HParamsPluginData(
        session_end_info=plugin_data_pb2.SessionEndInfo(
            status=status, end_time_secs=ended_at
        )
    )
    return _summary(metadata.SESSION_END_INFO_TAG, plugin_data)
def _summary(tag, hparams_plugin_data):
    """Wraps an HParamsPluginData message in a single-value summary.

    Helper function.

    Args:
      tag: string. The tag to use for the summary value.
      hparams_plugin_data: The HParamsPluginData message to use.

    Returns:
      A `tf.compat.v1.Summary` with one value carrying `tag`, the plugin
      metadata, and the null tensor.
    """
    # The metadata is produced as one proto class and re-parsed as the
    # `tf.compat.v1` class via its serialized bytes.
    serialized = metadata.create_summary_metadata(
        hparams_plugin_data
    ).SerializeToString()
    result = tf.compat.v1.Summary()
    result.value.add(
        tag=tag,
        metadata=tf.compat.v1.SummaryMetadata.FromString(serialized),
        tensor=_TF_NULL_TENSOR,
    )
    return result
# Like `metadata.NULL_TENSOR`, but with the TensorFlow version of the
# proto. Slight kludge needed to expose the `TensorProto` type: the class is
# taken from the type of a throwaway `tf.make_tensor_proto(0)` result, and the
# value is rebuilt from the serialized bytes of `metadata.NULL_TENSOR`.
# `_summary` looks this name up at call time, so defining it here at the end
# of the module is sufficient.
_TF_NULL_TENSOR = type(tf.make_tensor_proto(0)).FromString(
metadata.NULL_TENSOR.SerializeToString()
)