984 lines
34 KiB
Python
984 lines
34 KiB
Python
|
import json
|
||
|
import logging
|
||
|
import os
|
||
|
import struct
|
||
|
|
||
|
from typing import Any, List, Optional
|
||
|
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
|
||
|
from google.protobuf import struct_pb2
|
||
|
|
||
|
from tensorboard.compat.proto.summary_pb2 import (
|
||
|
HistogramProto,
|
||
|
Summary,
|
||
|
SummaryMetadata,
|
||
|
)
|
||
|
from tensorboard.compat.proto.tensor_pb2 import TensorProto
|
||
|
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
|
||
|
from tensorboard.plugins.custom_scalar import layout_pb2
|
||
|
from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
|
||
|
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
|
||
|
|
||
|
from ._convert_np import make_np
|
||
|
from ._utils import _prepare_video, convert_to_HWC
|
||
|
|
||
|
# Public API of this module (the names exported by a star import).
__all__ = [
    "half_to_int", "int_to_half", "hparams", "scalar", "histogram_raw",
    "histogram", "make_histogram", "image", "image_boxes", "draw_boxes",
    "make_image", "video", "make_video", "audio", "custom_scalars", "text",
    "tensor_proto", "pr_curve_raw", "pr_curve", "compute_curve", "mesh",
]
|
||
|
|
||
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
|
||
|
def half_to_int(f: float) -> int:
    """Reinterpret the bits of `f`, packed as a native float32, as a signed int32.

    Used to turn half precision values (`torch.half`, `torch.bfloat16`) into
    integers for the `half_val` field of a TensorProto. `int_to_half()` is the
    inverse mapping.

    NOTE(review): the value is packed as a 32-bit float, so the result is the
    float32 bit pattern rather than a 16-bit one; confirm readers expect this.
    """
    packed = struct.pack("f", f)
    (as_int,) = struct.unpack("i", packed)
    return as_int
|
||
|
|
||
|
def int_to_half(i: int) -> float:
    """Reinterpret a signed int32 bit pattern as a native float32 value.

    This is the inverse of `half_to_int()`.
    """
    packed = struct.pack("i", i)
    (as_float,) = struct.unpack("f", packed)
    return as_float
|
||
|
|
||
|
def _tensor_to_half_val(t: torch.Tensor) -> List[int]:
    # Flatten to Python floats, then reinterpret each one via half_to_int().
    flat_values = t.flatten().tolist()
    return list(map(half_to_int, flat_values))
|
||
|
|
||
|
def _tensor_to_complex_val(t: torch.Tensor) -> List[float]:
    # view_as_real adds a trailing (real, imag) axis; flattening interleaves
    # them as [re0, im0, re1, im1, ...].
    interleaved = torch.view_as_real(t)
    return interleaved.reshape(-1).tolist()
|
||
|
|
||
|
def _tensor_to_list(t: torch.Tensor) -> List[Any]:
    # Row-major flat list of Python scalars (a 0-d tensor yields one element).
    return t.reshape(-1).tolist()
|
||
|
|
||
|
# Maps a torch dtype to a (TensorProto dtype name, TensorProto value field,
# converter) triple. The converter turns a tensor into the flat Python list
# stored in that value field.
#
# Dtype aliases are the same objects as their canonical names (torch.half is
# torch.float16, torch.float is torch.float32, torch.double is torch.float64,
# torch.short is torch.int16, torch.int is torch.int32, torch.chalf is
# torch.complex32, torch.cfloat is torch.complex64, torch.cdouble is
# torch.complex128), so each dtype is listed once. The previous literal also
# listed torch.uint8 twice; only the later ("uint32_val") entry ever took
# effect, and that is the one kept here, in the first entry's position.
_TENSOR_TYPE_MAP = {
    torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
    torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
    torch.uint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
    # NOTE(review): qint8 is a signed quantized type but is tagged DT_UINT8;
    # kept as-is for compatibility with existing event files — confirm intent.
    torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
    torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
    torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
    torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
    torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list),
}
|
||
|
|
||
|
|
||
|
def _calc_scale_factor(tensor):
    """Return 1 for uint8 data (already 0-255) and 255 for any other dtype."""
    array = tensor if isinstance(tensor, np.ndarray) else tensor.numpy()
    if array.dtype == np.uint8:
        return 1
    return 255
|
||
|
|
||
|
|
||
|
def _draw_single_box(
    image,
    xmin,
    ymin,
    xmax,
    ymax,
    display_str,
    color="black",
    color_text="black",
    thickness=2,
):
    """Draw one box outline, plus an optional filled label, onto a PIL image.

    Args:
      image: PIL image to draw on. It is drawn on in place and also returned.
      xmin, ymin, xmax, ymax: Box corners in pixel coordinates.
      display_str: Label text. When falsy, no label is drawn.
      color: Color of the box outline and of the label background.
      color_text: Color of the label text.
      thickness: Outline width in pixels.

    Returns:
      The same `image` object.
    """
    from PIL import ImageDraw, ImageFont

    font = ImageFont.load_default()
    draw = ImageDraw.Draw(image)
    left, right, top, bottom = xmin, xmax, ymin, ymax
    draw.line(
        [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
        width=thickness,
        fill=color,
    )
    if display_str:
        text_bottom = bottom
        # `font.getsize` was removed in Pillow 10; `getbbox` is the
        # replacement. Fall back to `getsize` only for old Pillow fonts that
        # lack `getbbox`.
        if hasattr(font, "getbbox"):
            box_left, box_top, box_right, box_bottom = font.getbbox(display_str)
            text_width = box_right - box_left
            text_height = box_bottom - box_top
        else:
            text_width, text_height = font.getsize(display_str)
        margin = np.ceil(0.05 * text_height)
        # The label background sits at the box's bottom-left corner and grows
        # upwards from the bottom edge.
        draw.rectangle(
            [
                (left, text_bottom - text_height - 2 * margin),
                (left + text_width, text_bottom),
            ],
            fill=color,
        )
        draw.text(
            (left + margin, text_bottom - text_height - margin),
            display_str,
            fill=color_text,
            font=font,
        )
    return image
|
||
|
|
||
|
|
||
|
def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
    """Output three `Summary` protocol buffers needed by hparams plugin.

    `Experiment` keeps the metadata of an experiment, such as the name of the
      hyperparameters and the name of the metrics.
    `SessionStartInfo` keeps key-value pairs of the hyperparameters
    `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS

    Args:
      hparam_dict: A dictionary that contains names of the hyperparameters
        and their values. Values may be bool, int, float, str or
        torch.Tensor (its first element is logged as a number). Entries whose
        value is None are skipped.
      metric_dict: A dictionary that contains names of the metrics
        and their values. Only the keys are used here.
      hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
        contains names of the hyperparameters and all discrete values they can hold

    Returns:
      The `Summary` protobufs for Experiment, SessionStartInfo and
        SessionEndInfo

    Raises:
      TypeError: If `hparam_dict`, `metric_dict` or `hparam_domain_discrete`
        is not a dict, or a discrete domain is not a list of values of the
        same type as the corresponding hyperparameter.
      ValueError: If a hyperparameter value has an unsupported type.
    """
    from tensorboard.plugins.hparams.api_pb2 import (
        DataType,
        Experiment,
        HParamInfo,
        MetricInfo,
        MetricName,
        Status,
    )
    from tensorboard.plugins.hparams.metadata import (
        EXPERIMENT_TAG,
        PLUGIN_DATA_VERSION,
        PLUGIN_NAME,
        SESSION_END_INFO_TAG,
        SESSION_START_INFO_TAG,
    )
    from tensorboard.plugins.hparams.plugin_data_pb2 import (
        HParamsPluginData,
        SessionEndInfo,
        SessionStartInfo,
    )

    # TODO: expose other parameters in the future.
    # hp = HParamInfo(name='lr',display_name='learning rate',
    # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
    # max_value=100))
    # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
    # description='', dataset_type=DatasetType.DATASET_VALIDATION)
    # exp = Experiment(name='123', description='456', time_created_secs=100.0,
    # hparam_infos=[hp], metric_infos=[mt], user='tw')

    if not isinstance(hparam_dict, dict):
        logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: hparam_dict should be a dictionary, nothing logged."
        )
    if not isinstance(metric_dict, dict):
        logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: metric_dict should be a dictionary, nothing logged."
        )

    hparam_domain_discrete = hparam_domain_discrete or {}
    if not isinstance(hparam_domain_discrete, dict):
        raise TypeError(
            "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
        )
    for k, v in hparam_domain_discrete.items():
        if (
            k not in hparam_dict
            or not isinstance(v, list)
            or not all(isinstance(d, type(hparam_dict[k])) for d in v)
        ):
            raise TypeError(
                f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
            )

    def _discrete_domain(key, value_field):
        # ListValue of the declared discrete values for `key` (stored in the
        # given struct_pb2.Value field), or None when no domain was declared.
        if key not in hparam_domain_discrete:
            return None
        return struct_pb2.ListValue(
            values=[
                struct_pb2.Value(**{value_field: d})
                for d in hparam_domain_discrete[key]
            ]
        )

    def _plugin_summary(tag, **plugin_fields):
        # Wrap one HParamsPluginData payload into a Summary under `tag`.
        content = HParamsPluginData(version=PLUGIN_DATA_VERSION, **plugin_fields)
        smd = SummaryMetadata(
            plugin_data=SummaryMetadata.PluginData(
                plugin_name=PLUGIN_NAME, content=content.SerializeToString()
            )
        )
        return Summary(value=[Summary.Value(tag=tag, metadata=smd)])

    hps = []
    ssi = SessionStartInfo()
    for k, v in hparam_dict.items():
        if v is None:
            continue

        # bool must be tested before int/float: bool is a subclass of int, so
        # the numeric check would otherwise capture it and log it as FLOAT64.
        if isinstance(v, bool):
            ssi.hparams[k].bool_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_BOOL"),
                    domain_discrete=_discrete_domain(k, "bool_value"),
                )
            )
            continue

        if isinstance(v, (int, float)):
            ssi.hparams[k].number_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_FLOAT64"),
                    domain_discrete=_discrete_domain(k, "number_value"),
                )
            )
            continue

        if isinstance(v, str):
            ssi.hparams[k].string_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_STRING"),
                    domain_discrete=_discrete_domain(k, "string_value"),
                )
            )
            continue

        if isinstance(v, torch.Tensor):
            # Log the first element. reshape(-1) also handles 0-d tensors,
            # where plain indexing would raise IndexError.
            v = make_np(v).reshape(-1)[0]
            ssi.hparams[k].number_value = float(v)
            hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
            continue

        raise ValueError(
            "value should be one of int, float, str, bool, or torch.Tensor"
        )

    ssi_summary = _plugin_summary(SESSION_START_INFO_TAG, session_start_info=ssi)

    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict]
    exp_summary = _plugin_summary(
        EXPERIMENT_TAG, experiment=Experiment(hparam_infos=hps, metric_infos=mts)
    )

    sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS"))
    sei_summary = _plugin_summary(SESSION_END_INFO_TAG, session_end_info=sei)

    return exp_summary, ssi_summary, sei_summary
|
||
|
|
||
|
|
||
|
def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
    """Output a `Summary` protocol buffer containing a single scalar value.

    Args:
      name: A name for the generated node. Will also serve as the series name in
        TensorBoard.
      tensor: A real numeric Tensor containing a single value (any shape that
        squeezes to 0 dimensions).
      collections: Not used by this implementation; accepted for signature
        compatibility.
      new_style: Whether to use new style (tensor field) or old style (simple_value
        field). New style could lead to faster data loading.
      double_precision: Only used when `new_style` is True. Store the value as
        DT_DOUBLE in `double_val` instead of DT_FLOAT in `float_val`.

    Returns:
      A `Summary` protobuf holding the scalar.

    Raises:
      AssertionError: If `tensor` does not contain exactly one element.
    """
    tensor = make_np(tensor).squeeze()
    # Explicit raise (rather than `assert`) so the check survives `python -O`.
    # AssertionError is kept for compatibility with existing callers.
    if tensor.ndim != 0:
        raise AssertionError(
            f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
        )
    # python float is double precision in numpy
    value = float(tensor)
    if not new_style:
        return Summary(value=[Summary.Value(tag=name, simple_value=value)])

    if double_precision:
        proto = TensorProto(double_val=[value], dtype="DT_DOUBLE")
    else:
        proto = TensorProto(float_val=[value], dtype="DT_FLOAT")

    plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
    smd = SummaryMetadata(plugin_data=plugin_data)
    return Summary(
        value=[
            Summary.Value(
                tag=name,
                tensor=proto,
                metadata=smd,
            )
        ]
    )
|
||
|
|
||
|
|
||
|
def tensor_proto(tag, tensor):
    """Outputs a `Summary` protocol buffer containing the full tensor.

    The generated Summary has a Tensor.proto containing the input Tensor.

    Args:
      tag: A name for the generated node. Will also serve as the series name in
        TensorBoard.
      tensor: `torch.Tensor` to be converted to protobuf. Its dtype must be a
        key of `_TENSOR_TYPE_MAP`.

    Returns:
      A tensor protobuf in a `Summary` protobuf.

    Raises:
      ValueError: If tensor is too big to be converted to protobuf, or
        tensor data type is not supported
    """
    # element_size() is available in every torch version (Tensor.itemsize was
    # only added in torch 2.1) and returns the same per-element byte count.
    if tensor.numel() * tensor.element_size() >= (1 << 31):
        raise ValueError(
            "tensor is bigger than protocol buffer's hard limit of 2GB in size"
        )

    type_entry = _TENSOR_TYPE_MAP.get(tensor.dtype)
    if type_entry is None:
        raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")
    dtype, field_name, conversion_fn = type_entry

    # The value field name depends on the dtype, hence the ** construction.
    proto = TensorProto(
        **{
            "dtype": dtype,
            "tensor_shape": TensorShapeProto(
                dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape]
            ),
            field_name: conversion_fn(tensor),
        },
    )

    plugin_data = SummaryMetadata.PluginData(plugin_name="tensor")
    smd = SummaryMetadata(plugin_data=plugin_data)
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=proto)])
|
||
|
|
||
|
|
||
|
def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a pre-computed histogram.

    No statistics are computed here. All fields are copied verbatim into a
    HistogramProto attached to a single summary value.

    Args:
      name: Tag of the summary value; also the series name in TensorBoard.
      min: Minimum value (float or int).
      max: Maximum value (float or int).
      num: Number of values (int).
      sum: Sum of all values (float or int).
      sum_squares: Sum of squares of all values (float or int).
      bucket_limits: Numeric sequence with the upper limit of each bucket.
      bucket_counts: Numeric sequence with the number of values per bucket.

    Returns:
      A `Summary` protobuf with one histogram value.
    """
    histo_fields = dict(
        min=min,
        max=max,
        num=num,
        sum=sum,
        sum_squares=sum_squares,
        bucket_limit=bucket_limits,
        bucket=bucket_counts,
    )
    histo = HistogramProto(**histo_fields)
    entry = Summary.Value(tag=name, histo=histo)
    return Summary(value=[entry])
|
||
|
|
||
|
|
||
|
def histogram(name, values, bins, max_bins=None):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a histogram.

    The generated
    [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
    has one summary value containing a histogram for `values`.

    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      values: A real numeric `Tensor` or array of any shape. It is converted to
        a float numpy array and flattened to build the histogram.
      bins: Passed to `np.histogram` as its `bins` argument (a bin count, an
        array of bin edges, or a binning strategy name).
      max_bins: Optional limit on the number of buckets. When `np.histogram`
        produces more, adjacent buckets are merged.

    Returns:
      A `Summary` protobuf with a single histogram value.

    Raises:
      ValueError: If `values` is empty. Non-finite values are not specially
        handled here; `np.histogram` may reject them.
    """
    values = make_np(values)
    hist = make_histogram(values.astype(float), bins, max_bins)
    return Summary(value=[Summary.Value(tag=name, histo=hist)])
|
||
|
|
||
|
|
||
|
def _merge_bins(counts, limits, max_bins):
    """Merge runs of adjacent bins so that no more than `max_bins` remain.

    `counts` has one entry per bin and `limits` has one more (the bin edges,
    as returned by `np.histogram`). The inputs are returned unchanged when
    `max_bins` is None or already satisfied.
    """
    num_bins = len(counts)
    if max_bins is None or num_bins <= max_bins:
        return counts, limits
    if max_bins < 1:
        raise ValueError(f"max_bins must be a positive integer, got {max_bins}")
    # Ceiling division: floor division could leave more than `max_bins`
    # groups (e.g. 10 bins with max_bins=3 -> groups of 3 -> 4 buckets).
    subsampling = -(-num_bins // max_bins)
    subsampling_remainder = num_bins % subsampling
    if subsampling_remainder != 0:
        # Pad with empty bins so the counts split evenly into groups.
        counts = np.pad(
            counts,
            pad_width=[[0, subsampling - subsampling_remainder]],
            mode="constant",
            constant_values=0,
        )
    counts = counts.reshape(-1, subsampling).sum(axis=-1)
    # Keep the left edge of every group, plus the overall right edge.
    new_limits = np.empty((counts.size + 1,), limits.dtype)
    new_limits[:-1] = limits[:-1:subsampling]
    new_limits[-1] = limits[-1]
    return counts, new_limits


def make_histogram(values, bins, max_bins=None):
    """Convert values into a histogram proto using logic from histogram.cc.

    Args:
      values: Non-empty numpy array of any shape; it is flattened.
      bins: Passed to `np.histogram` as its `bins` argument.
      max_bins: Optional upper bound on the number of buckets; adjacent
        buckets are merged until at most `max_bins` remain.

    Returns:
      A `HistogramProto` with min/max/num/sum/sum_squares of `values`, and the
      buckets trimmed to the span of non-empty bins.

    Raises:
      ValueError: If `values` is empty, `max_bins` is not positive, or the
        resulting histogram is empty.
    """
    if values.size == 0:
        raise ValueError("The input has no element.")
    values = values.reshape(-1)
    counts, limits = np.histogram(values, bins=bins)
    counts, limits = _merge_bins(counts, limits, max_bins)

    # Find the first and the last bin defining the support of the histogram:
    cum_counts = np.cumsum(np.greater(counts, 0))
    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
    start = int(start)
    end = int(end) + 1
    del cum_counts

    # TensorBoard only includes the right bin limits. To still have the leftmost limit
    # included, we include an empty bin left.
    # If start == 0, we need to add an empty one left, otherwise we can just include the bin left to the
    # first nonzero-count bin:
    counts = (
        counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
    )
    limits = limits[start : end + 1]

    if counts.size == 0 or limits.size == 0:
        raise ValueError("The histogram is empty, please file a bug report.")

    sum_sq = values.dot(values)
    return HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist(),
    )
|
||
|
|
||
|
|
||
|
def image(tag, tensor, rescale=1, dataformats="NCHW"):
    """Output a `Summary` protocol buffer with one encoded image.

    Args:
      tag: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      tensor: Image data (torch.Tensor or numpy array) laid out according to
        `dataformats`. It is converted to HWC with `convert_to_HWC`.
        uint8 data is used as-is (values in [0, 255]). Any other dtype is
        multiplied by 255, i.e. it is expected to be in [0, 1]. Values outside
        [0, 255] after scaling are clipped.
      rescale: Factor applied to the image height and width before encoding.
      dataformats: Layout string of `tensor` (default "NCHW"), e.g. "CHW" or
        "HWC".

    Returns:
      A `Summary` protobuf with a single image value.
    """
    tensor = make_np(tensor)
    tensor = convert_to_HWC(tensor, dataformats)
    # Do not assume that user passes in values in [0, 255], use data type to detect
    scale_factor = _calc_scale_factor(tensor)
    tensor = tensor.astype(np.float32)
    tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
    image = make_image(tensor, rescale=rescale)
    return Summary(value=[Summary.Value(tag=tag, image=image)])
|
||
|
|
||
|
|
||
|
def image_boxes(
    tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
):
    """Output a `Summary` protocol buffer with an image annotated with boxes.

    Args:
      tag: Tag of the summary value.
      tensor_image: Image data in `dataformats` layout. uint8 is used as-is;
        other dtypes are multiplied by 255 and clipped to [0, 255].
      tensor_boxes: Boxes to draw, forwarded to `make_image` as `rois`.
      rescale: Resize factor forwarded to `make_image`.
      dataformats: Layout string of `tensor_image` (default "CHW").
      labels: Optional per-box label strings.
    """
    hwc_image = convert_to_HWC(make_np(tensor_image), dataformats)
    boxes = make_np(tensor_boxes)
    scaled = hwc_image.astype(np.float32) * _calc_scale_factor(hwc_image)
    encoded = make_image(
        scaled.clip(0, 255).astype(np.uint8),
        rescale=rescale,
        rois=boxes,
        labels=labels,
    )
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
|
||
|
|
||
|
|
||
|
def draw_boxes(disp_image, boxes, labels=None):
    """Draw each row of `boxes` (xmin, ymin, xmax, ymax) onto `disp_image` in red.

    `labels`, when given, supplies the text drawn for the box at the same
    index. Returns the (possibly new) image object from the last draw call.
    """
    for idx in range(boxes.shape[0]):
        xmin, ymin, xmax, ymax = boxes[idx, 0], boxes[idx, 1], boxes[idx, 2], boxes[idx, 3]
        label = None if labels is None else labels[idx]
        disp_image = _draw_single_box(
            disp_image,
            xmin,
            ymin,
            xmax,
            ymax,
            display_str=label,
            color="Red",
        )
    return disp_image
|
||
|
|
||
|
|
||
|
def make_image(tensor, rescale=1, rois=None, labels=None):
    """Convert a numpy representation of an image to Image protobuf.

    Args:
      tensor: uint8 array unpacked as (height, width, channels).
      rescale: Factor applied to height and width before PNG encoding.
      rois: Optional boxes drawn with `draw_boxes` before resizing.
      labels: Optional per-box labels forwarded to `draw_boxes`.

    Returns:
      A `Summary.Image` holding the PNG bytes. Its height/width describe the
      encoded (rescaled) image; colorspace is the channel count.
    """
    import io

    from PIL import Image

    height, width, channel = tensor.shape
    scaled_height = int(height * rescale)
    scaled_width = int(width * rescale)
    image = Image.fromarray(tensor)
    if rois is not None:
        image = draw_boxes(image, rois, labels=labels)
    # Image.ANTIALIAS was removed in Pillow 10; Resampling.LANCZOS is the
    # same filter in newer releases.
    try:
        resample = Image.Resampling.LANCZOS
    except AttributeError:
        resample = Image.ANTIALIAS
    image = image.resize((scaled_width, scaled_height), resample)

    with io.BytesIO() as output:
        image.save(output, format="PNG")
        image_string = output.getvalue()

    # Report the dimensions of the encoded image, not the pre-rescale input.
    return Summary.Image(
        height=scaled_height,
        width=scaled_width,
        colorspace=channel,
        encoded_image_string=image_string,
    )
|
||
|
|
||
|
|
||
|
def video(tag, tensor, fps=4):
    """Output a `Summary` whose image field holds `tensor` encoded by `make_video`.

    uint8 input is used as-is; other dtypes are multiplied by 255. Values are
    clipped to [0, 255] before encoding at `fps` frames per second.
    """
    frames = _prepare_video(make_np(tensor))
    # If user passes in uint8, then we don't need to rescale by 255
    factor = _calc_scale_factor(frames)
    frames = (frames.astype(np.float32) * factor).clip(0, 255).astype(np.uint8)
    encoded = make_video(frames, fps)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
|
||
|
|
||
|
|
||
|
def make_video(tensor, fps):
    """Encode a frame array as an animated GIF inside a `Summary.Image`.

    Args:
      tensor: uint8 array unpacked as (frames, height, width, channels).
      fps: Frames per second of the GIF.

    Returns:
      A `Summary.Image` with the GIF bytes, or None (after printing a message)
      when moviepy is not usable.
    """
    try:
        import moviepy  # noqa: F401
    except ImportError:
        print("add_video needs package moviepy")
        return
    try:
        from moviepy import editor as mpy
    except ImportError:
        # moviepy >= 2.0 removed the `editor` module and exposes the clip
        # classes at the package top level.
        import moviepy as mpy

        if not hasattr(mpy, "ImageSequenceClip"):
            print(
                "moviepy is installed, but can't import moviepy.editor.",
                "Some packages could be missing [imageio, requests]",
            )
            return
    import tempfile

    _, h, w, c = tensor.shape

    # encode sequence of images into gif string
    clip = mpy.ImageSequenceClip(list(tensor), fps=fps)

    # Close our handle right away; moviepy opens the path itself.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        filename = tmp.name
    try:
        # write_gif's keyword arguments changed across moviepy versions:
        # `logger` (newer 1.x), `progress_bar` (older), `verbose` only
        # (oldest), and no `verbose` at all (2.x). Try each in turn and
        # re-raise if even the last signature is rejected.
        attempts = (
            {"verbose": False, "logger": None},
            {"verbose": False, "progress_bar": False},
            {"verbose": False},
            {"logger": None},
        )
        for attempt, kwargs in enumerate(attempts):
            try:
                clip.write_gif(filename, **kwargs)
                break
            except TypeError:
                if attempt == len(attempts) - 1:
                    raise

        with open(filename, "rb") as f:
            tensor_string = f.read()
    finally:
        # Remove the temp file even when encoding fails.
        try:
            os.remove(filename)
        except OSError:
            logger.warning("The temporary file used by moviepy cannot be deleted.")

    return Summary.Image(
        height=h, width=w, colorspace=c, encoded_image_string=tensor_string
    )
|
||
|
|
||
|
|
||
|
def audio(tag, tensor, sample_rate=44100):
    """Output a `Summary` with `tensor` encoded as mono 16-bit PCM WAV.

    `tensor` is squeezed and must then be 1-D. Samples are expected in
    [-1, 1]; if any magnitude exceeds 1, a warning is printed and the samples
    are clipped.
    """
    import io
    import wave

    samples = make_np(tensor).squeeze()
    if abs(samples).max() > 1:
        print("warning: audio amplitude out of range, auto clipped.")
        samples = samples.clip(-1, 1)
    assert samples.ndim == 1, "input tensor should be 1 dimensional."
    # Scale to the int16 range and force little-endian, as WAV requires.
    pcm = (samples * np.iinfo(np.int16).max).astype("<i2")

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)
        writer.setframerate(sample_rate)
        writer.writeframes(pcm.data)
    encoded = buffer.getvalue()
    buffer.close()

    clip = Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=pcm.shape[-1],
        encoded_audio_string=encoded,
        content_type="audio/wav",
    )
    return Summary(value=[Summary.Value(tag=tag, audio=clip)])
|
||
|
|
||
|
|
||
|
def custom_scalars(layout):
    """Build the `custom_scalars__config__` Summary describing a chart layout.

    `layout` maps a category title to a dict of chart title -> spec, where
    spec[0] is the chart kind and spec[1] the tags. Kind "Margin" requires
    exactly three tags (value, lower, upper); any other kind produces a
    multiline chart over the tags.
    """

    def _build_chart(title, spec):
        chart_tags = spec[1]
        if spec[0] == "Margin":
            assert len(chart_tags) == 3
            margin = layout_pb2.MarginChartContent(
                series=[
                    layout_pb2.MarginChartContent.Series(
                        value=chart_tags[0], lower=chart_tags[1], upper=chart_tags[2]
                    )
                ]
            )
            return layout_pb2.Chart(title=title, margin=margin)
        multiline = layout_pb2.MultilineChartContent(tag=chart_tags)
        return layout_pb2.Chart(title=title, multiline=multiline)

    categories = [
        layout_pb2.Category(
            title=category_title,
            chart=[_build_chart(title, spec) for title, spec in charts.items()],
        )
        for category_title, charts in layout.items()
    ]

    serialized_layout = layout_pb2.Layout(category=categories).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="custom_scalars")
    )
    payload = TensorProto(
        dtype="DT_STRING",
        string_val=[serialized_layout],
        tensor_shape=TensorShapeProto(),
    )
    return Summary(
        value=[
            Summary.Value(tag="custom_scalars__config__", tensor=payload, metadata=smd)
        ]
    )
|
||
|
|
||
|
|
||
|
def text(tag, text):
    """Output a `Summary` holding `text` as a 1-element UTF-8 string tensor.

    The value's tag is `tag` with "/text_summary" appended, and its metadata
    targets the "text" plugin.
    """
    plugin_content = TextPluginData(version=0).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="text", content=plugin_content)
    )
    payload = TensorProto(
        dtype="DT_STRING",
        string_val=[text.encode(encoding="utf_8")],
        tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
    )
    summary_value = Summary.Value(tag=tag + "/text_summary", metadata=smd, tensor=payload)
    return Summary(value=[summary_value])
|
||
|
|
||
|
|
||
|
def pr_curve_raw(
    tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
):
    """Output a PR-curve `Summary` from precomputed per-threshold statistics.

    The six sequences are stacked into a DT_FLOAT tensor whose first two
    dimensions are recorded in the proto shape. `num_thresholds` is capped at
    127, and `weights` is accepted but unused.
    """
    # weird, value > 127 breaks protobuf
    num_thresholds = min(num_thresholds, 127)
    stacked = np.stack((tp, fp, tn, fn, precision, recall))
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content
        )
    )
    shape = TensorShapeProto(
        dim=[
            TensorShapeProto.Dim(size=stacked.shape[0]),
            TensorShapeProto.Dim(size=stacked.shape[1]),
        ]
    )
    payload = TensorProto(
        dtype="DT_FLOAT",
        float_val=stacked.reshape(-1).tolist(),
        tensor_shape=shape,
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=payload)])
|
||
|
|
||
|
|
||
|
def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Output a PR-curve `Summary` computed from labels and predictions.

    Statistics come from `compute_curve` with `num_thresholds` capped at 127.
    They are stored as a DT_FLOAT tensor whose first two dimensions are
    recorded in the proto shape.
    """
    # weird, value > 127 breaks protobuf
    num_thresholds = min(num_thresholds, 127)
    curve = compute_curve(
        labels, predictions, num_thresholds=num_thresholds, weights=weights
    )
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content
        )
    )
    shape = TensorShapeProto(
        dim=[
            TensorShapeProto.Dim(size=curve.shape[0]),
            TensorShapeProto.Dim(size=curve.shape[1]),
        ]
    )
    payload = TensorProto(
        dtype="DT_FLOAT",
        float_val=curve.reshape(-1).tolist(),
        tensor_shape=shape,
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=payload)])
|
||
|
|
||
|
|
||
|
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
|
||
|
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute precision/recall statistics at evenly spaced thresholds.

    Args:
      labels: Array-like of ground-truth labels, converted to float64 (so
        booleans or 0/1 values are expected).
      predictions: Array-like of scores. Each score is mapped to bucket
        floor(score * (num_thresholds - 1)). Buckets outside
        [0, num_thresholds - 1] (e.g. scores outside [0, 1]) fall outside the
        histogram range and are not counted.
      num_thresholds: Number of thresholds. None falls back to 127 (the cap
        applied by `pr_curve`); previously None raised a TypeError.
      weights: Optional per-example weights (scalar or array broadcastable
        against `labels`). Defaults to 1.0.

    Returns:
      Array of shape (6, num_thresholds) whose rows are tp, fp, tn, fn,
      precision and recall.
    """
    _MINIMUM_COUNT = 1e-7

    if num_thresholds is None:
        num_thresholds = 127
    if weights is None:
        weights = 1.0
    # Accept lists and other array-likes, not only numpy arrays.
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)

    # Compute bins of true positives and false positives.
    bucket_indices = np.floor(predictions * (num_thresholds - 1)).astype(np.int32)
    float_labels = labels.astype(np.float64)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum: counts at or above each threshold.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    # _MINIMUM_COUNT avoids division by zero where no examples are predicted.
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
    return np.stack((tp, fp, tn, fn, precision, recall))
|
||
|
|
||
|
|
||
|
def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Create a tensor summary with summary metadata.

    Args:
      name: Uniquely identifiable name of the summary op. Could be replaced by
        combination of name and type to make it unique even outside of this
        summary.
      display_name: Will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data. Markdown
        is supported.
      tensor: Tensor (or array-like accepted by `torch.as_tensor`) to display
        in summary. Its values are stored as DT_FLOAT.
      content_type: Type of content inside the Tensor.
      components: Bitmask representing present parts (vertices, colors, etc.) that
        belong to the summary.
      json_config: A string, JSON-serialized dictionary of ThreeJS classes
        configuration.

    Returns:
      A `Summary.Value` holding the tensor and its mesh-plugin metadata.
    """
    from tensorboard.plugins.mesh import metadata

    tensor = torch.as_tensor(tensor)

    tensor_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        tensor.shape,
        description,
        json_config=json_config,
    )

    # Record every dimension so the proto shape matches both the flattened
    # float_val length and the shape passed to the metadata above (the
    # original hard-coded exactly three dimensions).
    proto = TensorProto(
        dtype="DT_FLOAT",
        float_val=tensor.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=size) for size in tensor.shape]
        ),
    )

    return Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=proto,
        metadata=tensor_metadata,
    )
|
||
|
|
||
|
|
||
|
def _get_json_config(config_dict):
    """Serialize `config_dict` to JSON with sorted keys; return "{}" for None."""
    if config_dict is None:
        return "{}"
    return json.dumps(config_dict, sort_keys=True)
|
||
|
|
||
|
|
||
|
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
|
||
|
def mesh(
    tag, vertices, colors, faces, config_dict, display_name=None, description=None
):
    """Output a merged `Summary` protocol buffer with a mesh/point cloud.

    One summary value is produced for each of vertices, faces and colors that
    is not None, in that order. All of them share the same components bitmask
    and JSON config.

    Args:
      tag: A name for this summary operation.
      vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
        coordinates of vertices.
      colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
        vertex.
      faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
        vertices within each triangle.
      config_dict: Dictionary with ThreeJS classes names and configuration.
      display_name: If set, will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data. Markdown
        is supported.

    Returns:
      Merged summary for mesh/point cloud representation.
    """
    from tensorboard.plugins.mesh import metadata
    from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData

    json_config = _get_json_config(config_dict)

    candidates = (
        (vertices, MeshPluginData.VERTEX),
        (faces, MeshPluginData.FACE),
        (colors, MeshPluginData.COLOR),
    )
    present = [(data, kind) for data, kind in candidates if data is not None]
    components = metadata.get_components_bitmask([kind for _, kind in present])

    return Summary(
        value=[
            _get_tensor_summary(
                tag,
                display_name,
                description,
                data,
                kind,
                components,
                json_config,
            )
            for data, kind in present
        ]
    )
|