ai-content-maker/.venv/Lib/site-packages/torch/utils/tensorboard/summary.py

984 lines
34 KiB
Python

import json
import logging
import os
import struct
from typing import Any, List, Optional
import torch
import numpy as np
from google.protobuf import struct_pb2
from tensorboard.compat.proto.summary_pb2 import (
HistogramProto,
Summary,
SummaryMetadata,
)
from tensorboard.compat.proto.tensor_pb2 import TensorProto
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.plugins.custom_scalar import layout_pb2
from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
from ._convert_np import make_np
from ._utils import _prepare_video, convert_to_HWC
# Public API of this module: the summary-protobuf builders and the
# half-float <-> int helpers. Underscore-prefixed helpers are private.
__all__ = [
    "half_to_int",
    "int_to_half",
    "hparams",
    "scalar",
    "histogram_raw",
    "histogram",
    "make_histogram",
    "image",
    "image_boxes",
    "draw_boxes",
    "make_image",
    "video",
    "make_video",
    "audio",
    "custom_scalars",
    "text",
    "tensor_proto",
    "pr_curve_raw",
    "pr_curve",
    "compute_curve",
    "mesh",
]
# Module-level logger; used below for non-fatal warnings (e.g. bad hparams
# arguments, undeletable moviepy temp files).
logger = logging.getLogger(__name__)
def half_to_int(f: float) -> int:
    """Reinterpret a float's bit pattern as a signed integer.

    Intended for storing `torch.half` / `torch.bfloat16` values in the
    ``half_val`` field of a TensorProto. The value is packed with struct
    format ``"f"`` (native single precision, 32 bits) and the same bytes are
    read back as a native signed ``int`` (``"i"``).

    NOTE(review): this stores the float32 bit pattern, not a 16-bit one;
    confirm that downstream readers expect this encoding.

    Use int_to_half() to reverse the conversion.
    """
    (as_int,) = struct.unpack("i", struct.pack("f", f))
    return as_int
def int_to_half(i: int) -> float:
    """Reinterpret a signed integer's bit pattern as a float.

    Inverse of half_to_int(): the integer is packed with struct format
    ``"i"`` and the same bytes are read back as a native single-precision
    float (``"f"``).
    """
    (as_float,) = struct.unpack("f", struct.pack("i", i))
    return as_float
def _tensor_to_half_val(t: torch.Tensor) -> List[int]:
    """Flatten ``t`` and encode every element with half_to_int()."""
    return list(map(half_to_int, t.flatten().tolist()))
def _tensor_to_complex_val(t: torch.Tensor) -> List[float]:
return torch.view_as_real(t).flatten().tolist()
def _tensor_to_list(t: torch.Tensor) -> List[Any]:
return t.flatten().tolist()
# type maps: torch.Tensor type -> (protobuf type, protobuf val field,
# converter producing the list stored in that field).
# Each dtype must appear exactly once: a repeated key silently overrides the
# earlier entry. torch.uint8 used to be listed twice, and its "uint32_val"
# duplicate shadowed the "int_val" entry. TensorProto documents int_val as the
# storage field for DT_UINT8, so the int_val entry is the one kept.
_TENSOR_TYPE_MAP = {
    torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
    torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
    torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
    torch.short: ("DT_INT16", "int_val", _tensor_to_list),
    torch.int: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
    torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
    torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    # NOTE(review): TensorProto documents int_val (not uint32_val) for
    # DT_UINT8/DT_QUINT8 as well; confirm what readers of these expect.
    torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
    torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list),
}
def _calc_scale_factor(tensor):
converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor
return 1 if converted.dtype == np.uint8 else 255
def _draw_single_box(
    image,
    xmin,
    ymin,
    xmax,
    ymax,
    display_str,
    color="black",
    color_text="black",
    thickness=2,
):
    """Draw one rectangle, and optionally a text label, onto a PIL image.

    Args:
        image: PIL image, drawn on in place through ``ImageDraw.Draw``.
        xmin, ymin, xmax, ymax: Box corners in pixel coordinates.
        display_str: Optional label. When truthy, it is drawn on a filled
            background sitting on the box's bottom edge, starting at its
            left edge.
        color: Outline colour and label background colour.
        color_text: Label text colour.
        thickness: Outline width in pixels.

    Returns:
        The same ``image`` object.
    """
    from PIL import ImageDraw, ImageFont

    font = ImageFont.load_default()
    draw = ImageDraw.Draw(image)
    (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
    draw.line(
        [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
        width=thickness,
        fill=color,
    )
    if display_str:
        text_bottom = bottom
        if hasattr(font, "getbbox"):
            # ``getsize`` was deprecated in Pillow 9.2 and removed in 10.0;
            # derive the text extent from the bounding box instead.
            bbox_left, bbox_top, bbox_right, bbox_bottom = font.getbbox(display_str)
            text_width = bbox_right - bbox_left
            text_height = bbox_bottom - bbox_top
        else:
            # Older Pillow fonts only provide ``getsize``.
            text_width, text_height = font.getsize(display_str)
        margin = np.ceil(0.05 * text_height)
        draw.rectangle(
            [
                (left, text_bottom - text_height - 2 * margin),
                (left + text_width, text_bottom),
            ],
            fill=color,
        )
        draw.text(
            (left + margin, text_bottom - text_height - margin),
            display_str,
            fill=color_text,
            font=font,
        )
    return image
def _hparam_domain_discrete(hparam_domain_discrete, key, value_field):
    """Return the declared discrete domain of hyperparameter ``key``, or None.

    Each entry of ``hparam_domain_discrete[key]`` is wrapped in a
    ``struct_pb2.Value`` whose ``value_field`` (``"number_value"``,
    ``"string_value"`` or ``"bool_value"``) is set to that entry.
    """
    if key not in hparam_domain_discrete:
        return None
    return struct_pb2.ListValue(
        values=[
            struct_pb2.Value(**{value_field: d}) for d in hparam_domain_discrete[key]
        ]
    )


def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
    """Output three `Summary` protocol buffers needed by hparams plugin.

    `Experiment` keeps the metadata of an experiment, such as the name of the
    hyperparameters and the name of the metrics.
    `SessionStartInfo` keeps key-value pairs of the hyperparameters.
    `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS.

    Args:
        hparam_dict: A dictionary that contains names of the hyperparameters
            and their values. Values may be bool, int, float, str or
            torch.Tensor (only the first element of a tensor is logged).
            Entries whose value is None are skipped.
        metric_dict: A dictionary that contains names of the metrics
            and their values (only the names are recorded here).
        hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
            contains names of the hyperparameters and all discrete values they can hold.

    Returns:
        The `Summary` protobufs for Experiment, SessionStartInfo and
        SessionEndInfo, in that order.

    Raises:
        TypeError: If any of the dict arguments has the wrong type, or a
            discrete domain does not match its hyperparameter's type.
        ValueError: If a hyperparameter value has an unsupported type.
    """
    from tensorboard.plugins.hparams.api_pb2 import (
        DataType,
        Experiment,
        HParamInfo,
        MetricInfo,
        MetricName,
        Status,
    )
    from tensorboard.plugins.hparams.metadata import (
        EXPERIMENT_TAG,
        PLUGIN_DATA_VERSION,
        PLUGIN_NAME,
        SESSION_END_INFO_TAG,
        SESSION_START_INFO_TAG,
    )
    from tensorboard.plugins.hparams.plugin_data_pb2 import (
        HParamsPluginData,
        SessionEndInfo,
        SessionStartInfo,
    )

    # TODO: expose other HParamInfo / MetricInfo / Experiment fields
    # (display_name, domain_interval, dataset_type, description, ...).
    if not isinstance(hparam_dict, dict):
        logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: hparam_dict should be a dictionary, nothing logged."
        )
    if not isinstance(metric_dict, dict):
        logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: metric_dict should be a dictionary, nothing logged."
        )
    hparam_domain_discrete = hparam_domain_discrete or {}
    if not isinstance(hparam_domain_discrete, dict):
        raise TypeError(
            "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
        )
    for k, v in hparam_domain_discrete.items():
        if (
            k not in hparam_dict
            or not isinstance(v, list)
            or not all(isinstance(d, type(hparam_dict[k])) for d in v)
        ):
            raise TypeError(
                f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
            )

    hps = []
    ssi = SessionStartInfo()
    for k, v in hparam_dict.items():
        if v is None:
            continue
        # bool must be tested before int/float: bool is a subclass of int, so
        # the numeric branch would otherwise swallow booleans and record them
        # as DATA_TYPE_FLOAT64.
        if isinstance(v, bool):
            ssi.hparams[k].bool_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_BOOL"),
                    domain_discrete=_hparam_domain_discrete(
                        hparam_domain_discrete, k, "bool_value"
                    ),
                )
            )
            continue
        if isinstance(v, (int, float)):
            ssi.hparams[k].number_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_FLOAT64"),
                    domain_discrete=_hparam_domain_discrete(
                        hparam_domain_discrete, k, "number_value"
                    ),
                )
            )
            continue
        if isinstance(v, str):
            ssi.hparams[k].string_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_STRING"),
                    domain_discrete=_hparam_domain_discrete(
                        hparam_domain_discrete, k, "string_value"
                    ),
                )
            )
            continue
        if isinstance(v, torch.Tensor):
            # reshape(-1) also accepts 0-dim tensors such as torch.tensor(0.1),
            # which plain ``[0]`` indexing rejects. Only the first element is
            # logged.
            v = float(make_np(v).reshape(-1)[0])
            ssi.hparams[k].number_value = v
            hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
            continue
        raise ValueError(
            "value should be one of int, float, str, bool, or torch.Tensor"
        )

    def _plugin_summary(tag, **content_fields):
        # Wrap one HParamsPluginData payload in a metadata-only Summary.
        content = HParamsPluginData(version=PLUGIN_DATA_VERSION, **content_fields)
        smd = SummaryMetadata(
            plugin_data=SummaryMetadata.PluginData(
                plugin_name=PLUGIN_NAME, content=content.SerializeToString()
            )
        )
        return Summary(value=[Summary.Value(tag=tag, metadata=smd)])

    ssi_summary = _plugin_summary(SESSION_START_INFO_TAG, session_start_info=ssi)
    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict]
    exp_summary = _plugin_summary(
        EXPERIMENT_TAG, experiment=Experiment(hparam_infos=hps, metric_infos=mts)
    )
    sei_summary = _plugin_summary(
        SESSION_END_INFO_TAG,
        session_end_info=SessionEndInfo(status=Status.Value("STATUS_SUCCESS")),
    )
    return exp_summary, ssi_summary, sei_summary
def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
    """Build a `Summary` protobuf holding one scalar value.

    Args:
        name: Tag of the summary value; also the series name in TensorBoard.
        tensor: Anything ``make_np`` accepts that squeezes to exactly one
            element.
        collections: Unused; accepted for API compatibility.
        new_style: When True, store the value in a TensorProto tagged for the
            "scalars" plugin. When False, store it in ``simple_value``.
        double_precision: With ``new_style``, store a DT_DOUBLE instead of a
            DT_FLOAT tensor. Ignored otherwise.

    Returns:
        A `Summary` protobuf.

    Raises:
        AssertionError: If ``tensor`` does not squeeze to 0 dimensions.
    """
    tensor = make_np(tensor).squeeze()
    assert (
        tensor.ndim == 0
    ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
    # Python float is double precision, matching numpy's float64.
    value = float(tensor)
    if not new_style:
        return Summary(value=[Summary.Value(tag=name, simple_value=value)])
    if double_precision:
        value_proto = TensorProto(double_val=[value], dtype="DT_DOUBLE")
    else:
        value_proto = TensorProto(float_val=[value], dtype="DT_FLOAT")
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="scalars")
    )
    summary_value = Summary.Value(tag=name, tensor=value_proto, metadata=metadata)
    return Summary(value=[summary_value])
def tensor_proto(tag, tensor):
    """Outputs a `Summary` protocol buffer containing the full tensor.

    The tensor's dtype selects the TensorProto dtype, value field and
    element converter from ``_TENSOR_TYPE_MAP``.

    Args:
        tag: Tag of the summary value; also the series name in TensorBoard.
        tensor: torch.Tensor to serialise.

    Returns:
        A `Summary` protobuf tagged for the "tensor" plugin.

    Raises:
        ValueError: If the tensor is 2GB or larger (protobuf's hard limit),
            or its dtype is not in ``_TENSOR_TYPE_MAP``.
    """
    if tensor.numel() * tensor.itemsize >= (1 << 31):
        raise ValueError(
            "tensor is bigger than protocol buffer's hard limit of 2GB in size"
        )
    type_info = _TENSOR_TYPE_MAP.get(tensor.dtype)
    if type_info is None:
        raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")
    proto_dtype, field_name, convert = type_info
    shape = TensorShapeProto(dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape])
    proto_kwargs = {
        "dtype": proto_dtype,
        "tensor_shape": shape,
        field_name: convert(tensor),
    }
    value_proto = TensorProto(**proto_kwargs)
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="tensor")
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=metadata, tensor=value_proto)])
def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
    # pylint: disable=line-too-long
    """Build a `Summary` protobuf from precomputed histogram statistics.

    See the TensorFlow
    [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
    definition; the result holds one value with a ``histo`` field.

    Args:
        name: Tag of the summary value; also the series name in TensorBoard.
        min: Minimum of the data (float or int).
        max: Maximum of the data (float or int).
        num: Number of data points (int).
        sum: Sum of all data points (float or int).
        sum_squares: Sum of squares of all data points (float or int).
        bucket_limits: Sequence of upper bucket edges.
        bucket_counts: Sequence of per-bucket counts.

    Returns:
        A `Summary` protobuf.
    """
    histo = HistogramProto(
        min=min,
        max=max,
        num=num,
        sum=sum,
        sum_squares=sum_squares,
        bucket_limit=bucket_limits,
        bucket=bucket_counts,
    )
    summary_value = Summary.Value(tag=name, histo=histo)
    return Summary(value=[summary_value])
def histogram(name, values, bins, max_bins=None):
    # pylint: disable=line-too-long
    """Build a histogram `Summary` protobuf from raw values.

    Args:
        name: Tag of the summary value; also the series name in TensorBoard.
        values: Anything ``make_np`` accepts, of any shape; converted to
            ``float`` before binning.
        bins: Forwarded to ``make_histogram`` / ``np.histogram``.
        max_bins: Optional upper bound on the number of buckets, forwarded
            to ``make_histogram``.

    Returns:
        A `Summary` protobuf with a single ``histo`` value.
    """
    as_float = make_np(values).astype(float)
    histo = make_histogram(as_float, bins, max_bins)
    return Summary(value=[Summary.Value(tag=name, histo=histo)])
def make_histogram(values, bins, max_bins=None):
    """Convert values into a histogram proto using logic from histogram.cc.

    Args:
        values: numpy array of any shape; it is flattened before binning.
        bins: Forwarded to ``np.histogram`` (bin count, edges, or strategy).
        max_bins: Optional upper bound on the number of bins. Adjacent bins
            are merged in equal-sized groups so that at most ``max_bins``
            remain.

    Returns:
        A ``HistogramProto``. Only the buckets from the first to the last
        non-empty bin are kept, plus one leading bucket so the leftmost limit
        is preserved.

    Raises:
        ValueError: If ``values`` is empty or the trimmed histogram is empty.
    """
    if values.size == 0:
        raise ValueError("The input has no element.")
    values = values.reshape(-1)
    counts, limits = np.histogram(values, bins=bins)
    num_bins = len(counts)
    if max_bins is not None and num_bins > max_bins:
        # Round the group size up so ceil(num_bins / subsampling) <= max_bins.
        # Floor division could keep too many bins (e.g. 30 bins with
        # max_bins=20 gave a group size of 1, i.e. no merging at all).
        subsampling = (num_bins + max_bins - 1) // max_bins
        subsampling_remainder = num_bins % subsampling
        if subsampling_remainder != 0:
            # Zero-pad so the counts split evenly into groups.
            counts = np.pad(
                counts,
                pad_width=[[0, subsampling - subsampling_remainder]],
                mode="constant",
                constant_values=0,
            )
        counts = counts.reshape(-1, subsampling).sum(axis=-1)
        new_limits = np.empty((counts.size + 1,), limits.dtype)
        new_limits[:-1] = limits[:-1:subsampling]
        new_limits[-1] = limits[-1]
        limits = new_limits
    # Find the first and the last bin defining the support of the histogram:
    cum_counts = np.cumsum(np.greater(counts, 0))
    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
    start = int(start)
    end = int(end) + 1
    del cum_counts
    # TensorBoard only includes the right bin limits. To keep the leftmost
    # limit, include one bin to the left of the first non-empty bin. If that
    # bin does not exist (start == 0), prepend an empty one.
    counts = (
        counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
    )
    limits = limits[start : end + 1]
    if counts.size == 0 or limits.size == 0:
        raise ValueError("The histogram is empty, please file a bug report.")
    sum_sq = values.dot(values)
    return HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist(),
    )
def image(tag, tensor, rescale=1, dataformats="NCHW"):
    """Build a `Summary` protobuf holding one PNG-encoded image.

    Args:
        tag: Tag of the summary value; also the series name in TensorBoard.
        tensor: Image data accepted by ``make_np``, laid out as described by
            ``dataformats``. It is converted to HWC with ``convert_to_HWC``.
            ``uint8`` data is taken as [0, 255]; any other dtype is taken as
            [0, 1] and multiplied by 255. Out-of-range values are clipped.
        rescale: Spatial scale factor forwarded to ``make_image``.
        dataformats: Layout string forwarded to ``convert_to_HWC``.

    Returns:
        A `Summary` protobuf with a single ``image`` value.
    """
    hwc = convert_to_HWC(make_np(tensor), dataformats)
    # Infer the value range from the dtype instead of assuming [0, 255].
    factor = _calc_scale_factor(hwc)
    pixels = (hwc.astype(np.float32) * factor).clip(0, 255).astype(np.uint8)
    encoded = make_image(pixels, rescale=rescale)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
def image_boxes(
    tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
):
    """Build an image `Summary` protobuf with bounding boxes drawn on it.

    The image is converted to HWC, scaled to [0, 255] by dtype (see
    ``_calc_scale_factor``) and clipped. The boxes (xyxy rows) and optional
    labels are forwarded to ``make_image`` for drawing.
    """
    hwc = convert_to_HWC(make_np(tensor_image), dataformats)
    boxes = make_np(tensor_boxes)
    scaled = hwc.astype(np.float32) * _calc_scale_factor(hwc)
    pixels = scaled.clip(0, 255).astype(np.uint8)
    encoded = make_image(pixels, rescale=rescale, rois=boxes, labels=labels)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
def draw_boxes(disp_image, boxes, labels=None):
    """Draw every box of ``boxes`` onto ``disp_image`` in red.

    Args:
        disp_image: PIL image to draw on.
        boxes: Indexable as ``boxes[i, j]``; each row is
            ``(xmin, ymin, xmax, ymax)``.
        labels: Optional sequence with one label per box.

    Returns:
        The image returned by the last ``_draw_single_box`` call
        (``disp_image`` itself when there are no boxes).
    """
    for idx in range(boxes.shape[0]):
        xmin, ymin, xmax, ymax = boxes[idx, 0], boxes[idx, 1], boxes[idx, 2], boxes[idx, 3]
        label = labels[idx] if labels is not None else None
        disp_image = _draw_single_box(
            disp_image, xmin, ymin, xmax, ymax, display_str=label, color="Red"
        )
    return disp_image
def make_image(tensor, rescale=1, rois=None, labels=None):
    """Convert a numpy representation of an image to Image protobuf.

    Args:
        tensor: ``uint8`` array unpacked as ``(height, width, channel)``.
        rescale: Factor applied to both spatial dimensions before encoding.
        rois: Optional boxes drawn onto the image with ``draw_boxes``.
        labels: Optional per-box labels forwarded to ``draw_boxes``.

    Returns:
        A ``Summary.Image`` with the PNG-encoded, rescaled image.
    """
    import io

    from PIL import Image

    height, width, channel = tensor.shape
    scaled_height = int(height * rescale)
    scaled_width = int(width * rescale)
    image = Image.fromarray(tensor)
    if rois is not None:
        image = draw_boxes(image, rois, labels=labels)
    try:
        resample = Image.Resampling.LANCZOS
    except AttributeError:
        # Pillow < 9.1 has no Image.Resampling enum.
        resample = Image.ANTIALIAS
    image = image.resize((scaled_width, scaled_height), resample)
    with io.BytesIO() as output:
        image.save(output, format="PNG")
        image_string = output.getvalue()
    # Report the size of the image that was actually encoded. It differs
    # from the input whenever rescale != 1.
    return Summary.Image(
        height=scaled_height,
        width=scaled_width,
        colorspace=channel,
        encoded_image_string=image_string,
    )
def video(tag, tensor, fps=4):
    """Build a `Summary` protobuf holding a video encoded as an animated GIF.

    The input goes through ``make_np`` and ``_prepare_video``. It is scaled to
    [0, 255] by dtype (``uint8`` is left as is), clipped, and encoded with
    ``make_video`` at ``fps`` frames per second.
    """
    frames = _prepare_video(make_np(tensor))
    # uint8 input is already in [0, 255], so it is not multiplied by 255.
    factor = _calc_scale_factor(frames)
    frames = (frames.astype(np.float32) * factor).clip(0, 255).astype(np.uint8)
    gif = make_video(frames, fps)
    return Summary(value=[Summary.Value(tag=tag, image=gif)])
def _write_gif_quietly(clip, filename):
    """Write ``clip`` to ``filename`` as a GIF with progress output disabled.

    moviepy's ``write_gif`` keyword arguments changed across releases, so the
    known signatures are tried in turn. A ``TypeError`` from the last attempt
    propagates.
    """
    attempts = (
        {"verbose": False, "logger": None},  # moviepy 1.x
        {"logger": None},  # moviepy >= 2.0 dropped ``verbose``
        {"verbose": False, "progress_bar": False},  # older moviepy
        {"verbose": False},  # oldest moviepy
    )
    for kwargs in attempts[:-1]:
        try:
            clip.write_gif(filename, **kwargs)
            return
        except TypeError:
            continue
    clip.write_gif(filename, **attempts[-1])


def make_video(tensor, fps):
    """Encode a ``(time, height, width, channel)`` uint8 array as a GIF.

    Returns:
        A ``Summary.Image`` holding the GIF bytes, or ``None`` (after
        printing a message) when moviepy or its dependencies are missing.
    """
    try:
        import moviepy  # noqa: F401
    except ImportError:
        print("add_video needs package moviepy")
        return
    try:
        from moviepy.editor import ImageSequenceClip
    except ImportError:
        try:
            # moviepy >= 2.0 removed ``moviepy.editor`` and exposes the clip
            # classes at package level.
            from moviepy import ImageSequenceClip
        except ImportError:
            print(
                "moviepy is installed, but can't import moviepy.editor.",
                "Some packages could be missing [imageio, requests]",
            )
            return
    import tempfile

    _, h, w, c = tensor.shape
    # encode sequence of images into gif string
    clip = ImageSequenceClip(list(tensor), fps=fps)
    # Close our handle immediately so moviepy can reopen the path (required
    # on Windows); the file itself is kept (delete=False) and removed below.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        filename = tmp.name
    try:
        _write_gif_quietly(clip, filename)
        with open(filename, "rb") as f:
            tensor_string = f.read()
    finally:
        # Remove the temp file even if encoding failed.
        try:
            os.remove(filename)
        except OSError:
            logger.warning("The temporary file used by moviepy cannot be deleted.")
    return Summary.Image(
        height=h, width=w, colorspace=c, encoded_image_string=tensor_string
    )
def audio(tag, tensor, sample_rate=44100):
    """Build a `Summary` protobuf holding mono 16-bit PCM WAV audio.

    The input is squeezed and must end up 1-D. Amplitudes outside [-1, 1]
    are clipped, with a printed warning, then scaled to int16.
    """
    import io
    import wave

    samples = make_np(tensor).squeeze()
    if abs(samples).max() > 1:
        print("warning: audio amplitude out of range, auto clipped.")
        samples = samples.clip(-1, 1)
    assert samples.ndim == 1, "input tensor should be 1 dimensional."
    pcm = (samples * np.iinfo(np.int16).max).astype("<i2")
    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(2)
            writer.setframerate(sample_rate)
            writer.writeframes(pcm.data)
        # The wave writer does not close a file object it was handed, so the
        # buffer is still readable here.
        audio_string = buffer.getvalue()
    audio_proto = Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=pcm.shape[-1],
        encoded_audio_string=audio_string,
        content_type="audio/wav",
    )
    return Summary(value=[Summary.Value(tag=tag, audio=audio_proto)])
def custom_scalars(layout):
    """Build the configuration `Summary` for the custom_scalars plugin.

    Args:
        layout: ``{category_title: {chart_title: spec}}``. ``spec[0]`` is the
            chart kind and ``spec[1]`` its tags. Kind ``"Margin"`` requires
            exactly three tags (value, lower, upper); any other kind produces
            a multiline chart over the tags.

    Returns:
        A `Summary` whose single value, tagged "custom_scalars__config__",
        carries the serialized ``layout_pb2.Layout`` as a DT_STRING tensor.
    """
    categories = []
    for category_title, chart_specs in layout.items():
        charts = []
        for chart_title, spec in chart_specs.items():
            chart_kind, tags = spec[0], spec[1]
            if chart_kind != "Margin":
                content = layout_pb2.MultilineChartContent(tag=tags)
                charts.append(layout_pb2.Chart(title=chart_title, multiline=content))
                continue
            assert len(tags) == 3
            series = layout_pb2.MarginChartContent.Series(
                value=tags[0], lower=tags[1], upper=tags[2]
            )
            content = layout_pb2.MarginChartContent(series=[series])
            charts.append(layout_pb2.Chart(title=chart_title, margin=content))
        categories.append(layout_pb2.Category(title=category_title, chart=charts))
    layout_proto = layout_pb2.Layout(category=categories)
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="custom_scalars")
    )
    config_tensor = TensorProto(
        dtype="DT_STRING",
        string_val=[layout_proto.SerializeToString()],
        tensor_shape=TensorShapeProto(),
    )
    summary_value = Summary.Value(
        tag="custom_scalars__config__", tensor=config_tensor, metadata=metadata
    )
    return Summary(value=[summary_value])
def text(tag, text):
    """Build a text-plugin `Summary` holding ``text`` as a UTF-8 string tensor.

    The summary value is tagged ``tag + "/text_summary"``.
    """
    plugin_content = TextPluginData(version=0).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="text", content=plugin_content
        )
    )
    encoded = text.encode(encoding="utf_8")
    string_tensor = TensorProto(
        dtype="DT_STRING",
        string_val=[encoded],
        tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
    )
    summary_value = Summary.Value(
        tag=tag + "/text_summary", metadata=metadata, tensor=string_tensor
    )
    return Summary(value=[summary_value])
def pr_curve_raw(
    tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
):
    """Build a pr_curves `Summary` from precomputed curve rows.

    The six inputs are stacked into a ``(6, K)`` float tensor in the order
    tp, fp, tn, fn, precision, recall. ``num_thresholds`` is capped at 127.
    ``weights`` is unused and accepted for API compatibility.
    """
    # Values above 127 reportedly break the protobuf, so cap them.
    num_thresholds = min(num_thresholds, 127)
    rows = np.stack((tp, fp, tn, fn, precision, recall))
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content
        )
    )
    n_rows, n_cols = rows.shape[0], rows.shape[1]
    curve_tensor = TensorProto(
        dtype="DT_FLOAT",
        float_val=rows.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=n_rows), TensorShapeProto.Dim(size=n_cols)]
        ),
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=metadata, tensor=curve_tensor)])
def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Build a pr_curves `Summary` from labels and predictions.

    ``num_thresholds`` is capped at 127. The curve rows come from
    ``compute_curve`` and are stored as a 2-D DT_FLOAT tensor.
    """
    # Values above 127 reportedly break the protobuf, so cap them.
    num_thresholds = min(num_thresholds, 127)
    rows = compute_curve(
        labels, predictions, num_thresholds=num_thresholds, weights=weights
    )
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content
        )
    )
    n_rows, n_cols = rows.shape[0], rows.shape[1]
    curve_tensor = TensorProto(
        dtype="DT_FLOAT",
        float_val=rows.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=n_rows), TensorShapeProto.Dim(size=n_cols)]
        ),
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=metadata, tensor=curve_tensor)])
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute raw precision-recall curve data.

    Predictions are bucketed by ``floor(p * (num_thresholds - 1))`` into a
    ``num_thresholds``-bin histogram. Per-threshold true/false positive
    counts are the reverse cumulative sums of the per-bin weighted counts.

    Args:
        labels: Array-like of ground-truth labels. ``labels * weights``
            weights true positives and ``(1 - labels) * weights`` weights
            false positives.
        predictions: Array-like of scores, presumably in [0, 1]. Scores
            whose bucket index falls outside ``[0, num_thresholds - 1]`` are
            dropped by ``np.histogram``.
        num_thresholds: Number of thresholds. ``None`` (previously a crash
            on ``None - 1``) now means 127, the default used by ``pr_curve``.
        weights: Optional scalar or per-example array; ``None`` means 1.0.

    Returns:
        ``np.ndarray`` of shape ``(6, num_thresholds)`` with rows
        tp, fp, tn, fn, precision, recall.
    """
    _MINIMUM_COUNT = 1e-7
    if num_thresholds is None:
        num_thresholds = 127
    if weights is None:
        weights = 1.0
    # Accept lists / tuples as well as ndarrays.
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    # Compute bins of true positives and false positives.
    bucket_indices = np.floor(predictions * (num_thresholds - 1)).astype(np.int32)
    float_labels = labels.astype(np.float64)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )
    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    # Clamp the denominators to avoid division by zero.
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
    return np.stack((tp, fp, tn, fn, precision, recall))
def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Create a tensor summary with summary metadata.

    Args:
        name: Uniquely identifiable name of the summary op. Could be replaced by
            combination of name and type to make it unique even outside of this
            summary.
        display_name: Will be used as the display name in TensorBoard.
            Defaults to `name`.
        description: A longform readable description of the summary data. Markdown
            is supported.
        tensor: Tensor-like data (converted with ``torch.as_tensor``) to display
            in the summary; stored as DT_FLOAT.
        content_type: Type of content inside the Tensor.
        components: Bitmask representing present parts (vertices, colors, etc.) that
            belong to the summary.
        json_config: A string, JSON-serialized dictionary of ThreeJS classes
            configuration.

    Returns:
        Tensor summary with metadata.
    """
    from tensorboard.plugins.mesh import metadata

    values = torch.as_tensor(tensor)
    tensor_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        values.shape,
        description,
        json_config=json_config,
    )
    tensor_pb = TensorProto(
        dtype="DT_FLOAT",
        float_val=values.reshape(-1).tolist(),
        # Record every dimension. ``mesh`` documents inputs of shape
        # [dim_1, ..., dim_n, 3]; hard-coding three dims truncated the shape for
        # higher ranks (disagreeing with float_val) and raised for lower ranks.
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=size) for size in values.shape]
        ),
    )
    tensor_summary = Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=tensor_pb,
        metadata=tensor_metadata,
    )
    return tensor_summary
def _get_json_config(config_dict):
"""Parse and returns JSON string from python dictionary."""
json_config = "{}"
if config_dict is not None:
json_config = json.dumps(config_dict, sort_keys=True)
return json_config
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
def mesh(
    tag, vertices, colors, faces, config_dict, display_name=None, description=None
):
    """Output a merged `Summary` protocol buffer with a mesh/point cloud.

    Args:
        tag: A name for this summary operation.
        vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
            coordinates of vertices.
        colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
            vertex.
        faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
            vertices within each triangle.
        config_dict: Dictionary with ThreeJS classes names and configuration.
        display_name: If set, will be used as the display name in TensorBoard.
            Defaults to `name`.
        description: A longform readable description of the summary data. Markdown
            is supported.

    Returns:
        Merged summary for mesh/point cloud representation: one value per
        non-None input, in vertex, face, color order.
    """
    from tensorboard.plugins.mesh import metadata
    from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData

    json_config = _get_json_config(config_dict)
    candidates = (
        (vertices, MeshPluginData.VERTEX),
        (faces, MeshPluginData.FACE),
        (colors, MeshPluginData.COLOR),
    )
    present = [(data, kind) for data, kind in candidates if data is not None]
    components = metadata.get_components_bitmask([kind for _, kind in present])
    summaries = [
        _get_tensor_summary(
            tag,
            display_name,
            description,
            data,
            kind,
            components,
            json_config,
        )
        for data, kind in present
    ]
    return Summary(value=summaries)