import json
import logging
import os
import struct
from typing import Any, List, Optional

import torch
import numpy as np
from google.protobuf import struct_pb2
from tensorboard.compat.proto.summary_pb2 import (
    HistogramProto,
    Summary,
    SummaryMetadata,
)
from tensorboard.compat.proto.tensor_pb2 import TensorProto
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.plugins.custom_scalar import layout_pb2
from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData

from ._convert_np import make_np
from ._utils import _prepare_video, convert_to_HWC

__all__ = [
    "half_to_int",
    "int_to_half",
    "hparams",
    "scalar",
    "histogram_raw",
    "histogram",
    "make_histogram",
    "image",
    "image_boxes",
    "draw_boxes",
    "make_image",
    "video",
    "make_video",
    "audio",
    "custom_scalars",
    "text",
    "tensor_proto",
    "pr_curve_raw",
    "pr_curve",
    "compute_curve",
    "mesh",
]

logger = logging.getLogger(__name__)


def half_to_int(f: float) -> int:
    """Cast a half-precision float value into an integer.

    Converts a half precision floating point value, such as `torch.half` or
    `torch.bfloat16`, into an integer value which can be written into the
    half_val field of a TensorProto for storage.

    To undo the effects of this conversion, use int_to_half().
    """
    # Reinterpret the IEEE-754 single-precision bit pattern as a 32-bit int.
    buf = struct.pack("f", f)
    return struct.unpack("i", buf)[0]


def int_to_half(i: int) -> float:
    """Cast an integer value to a half-precision float.

    Converts an integer value obtained from half_to_int back into a floating
    point value.
    """
    buf = struct.pack("i", i)
    return struct.unpack("f", buf)[0]


def _tensor_to_half_val(t: torch.Tensor) -> List[int]:
    # Store each element as the int bit pattern of its float value (see half_to_int).
    return [half_to_int(x) for x in t.flatten().tolist()]


def _tensor_to_complex_val(t: torch.Tensor) -> List[float]:
    # Interleave (real, imag) pairs into a flat list of floats.
    return torch.view_as_real(t).flatten().tolist()


def _tensor_to_list(t: torch.Tensor) -> List[Any]:
    return t.flatten().tolist()


# type maps: torch.Tensor type -> (protobuf type, protobuf val field, converter)
# FIX: the original dict literal listed torch.uint8 twice with different value
# fields ("int_val" first, then "uint32_val"); in a dict literal the later
# entry silently overwrites the earlier one, so only the effective
# "uint32_val" mapping is kept here.
_TENSOR_TYPE_MAP = {
    torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
    torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
    # NOTE(review): qint8 mapping to DT_UINT8 looks odd but is preserved as-is
    # from the original table — confirm against upstream before changing.
    torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
    torch.short: ("DT_INT16", "int_val", _tensor_to_list),
    torch.int: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
    torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
    torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.uint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
    torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
    torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list),
}
def _calc_scale_factor(tensor):
    """Return the multiplier that maps image data into [0, 255].

    uint8 data is assumed to already be in [0, 255] (factor 1); any other
    dtype is assumed to be in [0, 1] and is scaled by 255.
    """
    converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor
    return 1 if converted.dtype == np.uint8 else 255


def _draw_single_box(
    image,
    xmin,
    ymin,
    xmax,
    ymax,
    display_str,
    color="black",
    color_text="black",
    thickness=2,
):
    """Draw one (optionally labelled) bounding box onto a PIL image and return it."""
    from PIL import ImageDraw, ImageFont

    font = ImageFont.load_default()
    draw = ImageDraw.Draw(image)
    (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
    draw.line(
        [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
        width=thickness,
        fill=color,
    )
    if display_str:
        text_bottom = bottom
        # Reverse list and print from bottom to top.
        # FIX: ImageFont.getsize() was removed in Pillow 10. Use getbbox()
        # where available and fall back to getsize() on older Pillow, matching
        # the ANTIALIAS fallback pattern used elsewhere in this file.
        try:
            bb_left, bb_top, bb_right, bb_bottom = font.getbbox(display_str)
            text_width, text_height = bb_right - bb_left, bb_bottom - bb_top
        except AttributeError:
            text_width, text_height = font.getsize(display_str)
        margin = np.ceil(0.05 * text_height)
        draw.rectangle(
            [
                (left, text_bottom - text_height - 2 * margin),
                (left + text_width, text_bottom),
            ],
            fill=color,
        )
        draw.text(
            (left + margin, text_bottom - text_height - margin),
            display_str,
            fill=color_text,
            font=font,
        )
    return image


def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
    """Output three `Summary` protocol buffers needed by hparams plugin.

    `Experiment` keeps the metadata of an experiment, such as the name of the
    hyperparameters and the name of the metrics.
    `SessionStartInfo` keeps key-value pairs of the hyperparameters
    `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS

    Args:
      hparam_dict: A dictionary that contains names of the hyperparameters
        and their values.
      metric_dict: A dictionary that contains names of the metrics and their
        values.
      hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
        contains names of the hyperparameters and all discrete values they can
        hold.

    Returns:
      The `Summary` protobufs for Experiment, SessionStartInfo and
      SessionEndInfo.

    Raises:
      TypeError: if any of the dictionary arguments is malformed.
      ValueError: if a hyperparameter value has an unsupported type.
    """
    import torch
    from tensorboard.plugins.hparams.api_pb2 import (
        DataType,
        Experiment,
        HParamInfo,
        MetricInfo,
        MetricName,
        Status,
    )
    from tensorboard.plugins.hparams.metadata import (
        EXPERIMENT_TAG,
        PLUGIN_DATA_VERSION,
        PLUGIN_NAME,
        SESSION_END_INFO_TAG,
        SESSION_START_INFO_TAG,
    )
    from tensorboard.plugins.hparams.plugin_data_pb2 import (
        HParamsPluginData,
        SessionEndInfo,
        SessionStartInfo,
    )

    # TODO: expose other parameters in the future.
    # hp = HParamInfo(name='lr',display_name='learning rate',
    # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
    # max_value=100))
    # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
    # description='', dataset_type=DatasetType.DATASET_VALIDATION)
    # exp = Experiment(name='123', description='456', time_created_secs=100.0,
    # hparam_infos=[hp], metric_infos=[mt], user='tw')

    if not isinstance(hparam_dict, dict):
        logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: hparam_dict should be a dictionary, nothing logged."
        )
    if not isinstance(metric_dict, dict):
        logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: metric_dict should be a dictionary, nothing logged."
        )

    hparam_domain_discrete = hparam_domain_discrete or {}
    if not isinstance(hparam_domain_discrete, dict):
        raise TypeError(
            "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
        )
    # Each discrete domain must be a list whose elements have the same type as
    # the corresponding hparam value.
    for k, v in hparam_domain_discrete.items():
        if (
            k not in hparam_dict
            or not isinstance(v, list)
            or not all(isinstance(d, type(hparam_dict[k])) for d in v)
        ):
            raise TypeError(
                f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
            )
    hps = []

    ssi = SessionStartInfo()
    for k, v in hparam_dict.items():
        if v is None:
            continue
        # NOTE: bool is checked after int/float on purpose? In Python
        # isinstance(True, int) is True, so bools hit the numeric branch first —
        # this mirrors the original control flow exactly.
        if isinstance(v, (int, float)):
            ssi.hparams[k].number_value = v

            if k in hparam_domain_discrete:
                domain_discrete: Optional[struct_pb2.ListValue] = struct_pb2.ListValue(
                    values=[
                        struct_pb2.Value(number_value=d)
                        for d in hparam_domain_discrete[k]
                    ]
                )
            else:
                domain_discrete = None

            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_FLOAT64"),
                    domain_discrete=domain_discrete,
                )
            )
            continue

        if isinstance(v, str):
            ssi.hparams[k].string_value = v

            if k in hparam_domain_discrete:
                domain_discrete = struct_pb2.ListValue(
                    values=[
                        struct_pb2.Value(string_value=d)
                        for d in hparam_domain_discrete[k]
                    ]
                )
            else:
                domain_discrete = None

            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_STRING"),
                    domain_discrete=domain_discrete,
                )
            )
            continue

        if isinstance(v, bool):
            ssi.hparams[k].bool_value = v

            if k in hparam_domain_discrete:
                domain_discrete = struct_pb2.ListValue(
                    values=[
                        struct_pb2.Value(bool_value=d)
                        for d in hparam_domain_discrete[k]
                    ]
                )
            else:
                domain_discrete = None

            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_BOOL"),
                    domain_discrete=domain_discrete,
                )
            )
            continue

        if isinstance(v, torch.Tensor):
            v = make_np(v)[0]
            ssi.hparams[k].number_value = v
            hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
            continue
        raise ValueError(
            "value should be one of int, float, str, bool, or torch.Tensor"
        )

    content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION)
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name=PLUGIN_NAME, content=content.SerializeToString()
        )
    )
    ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)])

    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()]

    exp = Experiment(hparam_infos=hps, metric_infos=mts)
    content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION)
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name=PLUGIN_NAME, content=content.SerializeToString()
        )
    )
    exp = Summary(value=[Summary.Value(tag=EXPERIMENT_TAG, metadata=smd)])

    sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS"))
    content = HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION)
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name=PLUGIN_NAME, content=content.SerializeToString()
        )
    )
    sei = Summary(value=[Summary.Value(tag=SESSION_END_INFO_TAG, metadata=smd)])

    return exp, ssi, sei


def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
    """Output a `Summary` protocol buffer containing a single scalar value.

    The generated Summary has a Tensor.proto containing the input Tensor.

    Args:
      name: A name for the generated node. Will also serve as the series name in
        TensorBoard.
      tensor: A real numeric Tensor containing a single value.
      collections: Optional list of graph collections keys. The new summary op is
        added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
      new_style: Whether to use new style (tensor field) or old style (simple_value
        field). New style could lead to faster data loading.
      double_precision: When using new style, store the value as DT_DOUBLE
        instead of DT_FLOAT.

    Returns:
      A scalar `Tensor` of type `string`. Which contains a `Summary` protobuf.

    Raises:
      ValueError: If tensor has the wrong shape or type.
    """
    tensor = make_np(tensor).squeeze()
    assert (
        tensor.ndim == 0
    ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
    # python float is double precision in numpy
    scalar = float(tensor)
    if new_style:
        tensor_proto = TensorProto(float_val=[scalar], dtype="DT_FLOAT")
        if double_precision:
            tensor_proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE")

        plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
        smd = SummaryMetadata(plugin_data=plugin_data)
        return Summary(
            value=[
                Summary.Value(
                    tag=name,
                    tensor=tensor_proto,
                    metadata=smd,
                )
            ]
        )
    else:
        return Summary(value=[Summary.Value(tag=name, simple_value=scalar)])
def tensor_proto(tag, tensor):
    """Outputs a `Summary` protocol buffer containing the full tensor.

    The generated Summary has a Tensor.proto containing the input Tensor.

    Args:
      tag: A name for the generated node. Will also serve as the series name in
        TensorBoard.
      tensor: Tensor to be converted to protobuf

    Returns:
      A tensor protobuf in a `Summary` protobuf.

    Raises:
      ValueError: If tensor is too big to be converted to protobuf, or tensor
        data type is not supported.
    """
    # Protobuf messages have a hard 2 GB serialization limit.
    if tensor.numel() * tensor.itemsize >= (1 << 31):
        raise ValueError(
            "tensor is bigger than protocol buffer's hard limit of 2GB in size"
        )

    if tensor.dtype not in _TENSOR_TYPE_MAP:
        raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")

    dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype]
    shape_proto = TensorShapeProto(
        dim=[TensorShapeProto.Dim(size=extent) for extent in tensor.shape]
    )
    # The value field name depends on the dtype, so build kwargs dynamically.
    proto_fields = {
        "dtype": dtype,
        "tensor_shape": shape_proto,
        field_name: conversion_fn(tensor),
    }
    tensor_proto = TensorProto(**proto_fields)

    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="tensor")
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor_proto)])


def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a histogram built from raw data.

    The generated
    [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
    has one summary value containing a histogram for `values`.

    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      min: A float or int min value
      max: A float or int max value
      num: Int number of values
      sum: Float or int sum of all values
      sum_squares: Float or int sum of squares for all values
      bucket_limits: A numeric `Tensor` with upper value per bucket
      bucket_counts: A numeric `Tensor` with number of values per bucket

    Returns:
      A scalar `Tensor` of type `string`. The serialized `Summary` protocol
      buffer.
    """
    proto = HistogramProto(
        min=min,
        max=max,
        num=num,
        sum=sum,
        sum_squares=sum_squares,
        bucket_limit=bucket_limits,
        bucket=bucket_counts,
    )
    histogram_value = Summary.Value(tag=name, histo=proto)
    return Summary(value=[histogram_value])
def histogram(name, values, bins, max_bins=None):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a histogram.

    The generated
    [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
    has one summary value containing a histogram for `values`.

    This op reports an `InvalidArgument` error if any value is not finite.

    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      values: A real numeric `Tensor`. Any shape. Values to use to
        build the histogram.
      bins: Bin specification forwarded to `np.histogram` (count, edges, or
        strategy name).
      max_bins: If set, the histogram is subsampled down to at most this many
        bins (see make_histogram).

    Returns:
      A scalar `Tensor` of type `string`. The serialized `Summary` protocol
      buffer.
    """
    values = make_np(values)
    hist = make_histogram(values.astype(float), bins, max_bins)
    return Summary(value=[Summary.Value(tag=name, histo=hist)])


def make_histogram(values, bins, max_bins=None):
    """Convert values into a histogram proto using logic from histogram.cc."""
    if values.size == 0:
        raise ValueError("The input has no element.")

    values = values.reshape(-1)
    counts, limits = np.histogram(values, bins=bins)
    num_bins = len(counts)

    # If the histogram exceeds max_bins, merge every `subsampling` adjacent
    # bins into one (padding with empty bins so the count divides evenly).
    if max_bins is not None and num_bins > max_bins:
        subsampling = num_bins // max_bins
        subsampling_remainder = num_bins % subsampling
        if subsampling_remainder != 0:
            counts = np.pad(
                counts,
                pad_width=[[0, subsampling - subsampling_remainder]],
                mode="constant",
                constant_values=0,
            )
        counts = counts.reshape(-1, subsampling).sum(axis=-1)
        # Keep every subsampling-th left edge, plus the original right edge.
        new_limits = np.empty((counts.size + 1,), limits.dtype)
        new_limits[:-1] = limits[:-1:subsampling]
        new_limits[-1] = limits[-1]
        limits = new_limits

    # Find the first and the last bin defining the support of the histogram:
    cum_counts = np.cumsum(np.greater(counts, 0))
    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
    start = int(start)
    end = int(end) + 1
    del cum_counts

    # TensorBoard only includes the right bin limits. To still have the leftmost limit
    # included, we include an empty bin left.
    # If start == 0, we need to add an empty one left, otherwise we can just include the bin left to the
    # first nonzero-count bin:
    counts = (
        counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
    )
    limits = limits[start : end + 1]

    if counts.size == 0 or limits.size == 0:
        raise ValueError("The histogram is empty, please file a bug report.")

    # values.dot(values) == sum of squares for the 1-D flattened array.
    sum_sq = values.dot(values)
    return HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist(),
    )


def image(tag, tensor, rescale=1, dataformats="NCHW"):
    """Output a `Summary` protocol buffer with images.

    The summary has up to `max_images` summary values containing images. The
    images are built from `tensor` which must be 3-D with shape `[height, width,
    channels]` and where `channels` can be:

    *  1: `tensor` is interpreted as Grayscale.
    *  3: `tensor` is interpreted as RGB.
    *  4: `tensor` is interpreted as RGBA.

    The `name` in the outputted Summary.Value protobufs is generated based on the
    name, with a suffix depending on the max_outputs setting:

    *  If `max_outputs` is 1, the summary value tag is '*name*/image'.
    *  If `max_outputs` is greater than 1, the summary value tags are
       generated sequentially as '*name*/image/0', '*name*/image/1', etc.

    Args:
      tag: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      tensor: A 3-D `uint8` or `float32` `Tensor` of shape `[height, width,
        channels]` where `channels` is 1, 3, or 4.
        'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8).
        The image() function will scale the image values to [0, 255] by applying
        a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values
        will be clipped.
      rescale: Multiplier applied to the image's spatial dimensions before
        encoding.
      dataformats: Layout of `tensor` (e.g. "NCHW", "HWC"); converted to HWC.

    Returns:
      A scalar `Tensor` of type `string`. The serialized `Summary` protocol
      buffer.
    """
    tensor = make_np(tensor)
    tensor = convert_to_HWC(tensor, dataformats)
    # Do not assume that user passes in values in [0, 255], use data type to detect
    scale_factor = _calc_scale_factor(tensor)
    tensor = tensor.astype(np.float32)
    tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
    image = make_image(tensor, rescale=rescale)
    return Summary(value=[Summary.Value(tag=tag, image=image)])
def image_boxes(
    tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
):
    """Output a `Summary` protocol buffer with an image overlaid with boxes."""
    np_image = convert_to_HWC(make_np(tensor_image), dataformats)
    np_boxes = make_np(tensor_boxes)
    # Scale float images from [0, 1] to [0, 255]; uint8 passes through.
    scaled = np_image.astype(np.float32) * _calc_scale_factor(np_image)
    encoded = make_image(
        scaled.clip(0, 255).astype(np.uint8),
        rescale=rescale,
        rois=np_boxes,
        labels=labels,
    )
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])


def draw_boxes(disp_image, boxes, labels=None):
    """Draw every box (xyxy format, one row per box) onto a PIL image."""
    for idx in range(boxes.shape[0]):
        disp_image = _draw_single_box(
            disp_image,
            boxes[idx, 0],
            boxes[idx, 1],
            boxes[idx, 2],
            boxes[idx, 3],
            display_str=None if labels is None else labels[idx],
            color="Red",
        )
    return disp_image


def make_image(tensor, rescale=1, rois=None, labels=None):
    """Convert a numpy representation of an image to Image protobuf."""
    import io

    from PIL import Image

    height, width, channel = tensor.shape
    pil_image = Image.fromarray(tensor)
    if rois is not None:
        pil_image = draw_boxes(pil_image, rois, labels=labels)

    # Pillow renamed ANTIALIAS to Resampling.LANCZOS; support both versions.
    try:
        resample = Image.Resampling.LANCZOS
    except AttributeError:
        resample = Image.ANTIALIAS

    target_size = (int(width * rescale), int(height * rescale))
    pil_image = pil_image.resize(target_size, resample)

    with io.BytesIO() as buffer:
        pil_image.save(buffer, format="PNG")
        image_string = buffer.getvalue()

    return Summary.Image(
        height=height,
        width=width,
        colorspace=channel,
        encoded_image_string=image_string,
    )
def video(tag, tensor, fps=4):
    """Encode a batch of frames as an animated GIF inside an image `Summary`."""
    frames = _prepare_video(make_np(tensor))
    # If user passes in uint8, then we don't need to rescale by 255
    factor = _calc_scale_factor(frames)
    frames = frames.astype(np.float32)
    frames = (frames * factor).clip(0, 255).astype(np.uint8)
    return Summary(value=[Summary.Value(tag=tag, image=make_video(frames, fps))])


def make_video(tensor, fps):
    """Render a (T, H, W, C) uint8 frame stack to a GIF and wrap it in Summary.Image."""
    try:
        import moviepy  # noqa: F401
    except ImportError:
        print("add_video needs package moviepy")
        return
    try:
        from moviepy import editor as mpy
    except ImportError:
        print(
            "moviepy is installed, but can't import moviepy.editor.",
            "Some packages could be missing [imageio, requests]",
        )
        return
    import tempfile

    t, h, w, c = tensor.shape

    # encode sequence of images into gif string
    clip = mpy.ImageSequenceClip(list(tensor), fps=fps)

    gif_path = tempfile.NamedTemporaryFile(suffix=".gif", delete=False).name
    try:
        # newer version of moviepy use logger instead of progress_bar argument.
        clip.write_gif(gif_path, verbose=False, logger=None)
    except TypeError:
        try:
            # older version of moviepy does not support progress_bar argument.
            clip.write_gif(gif_path, verbose=False, progress_bar=False)
        except TypeError:
            clip.write_gif(gif_path, verbose=False)

    with open(gif_path, "rb") as f:
        tensor_string = f.read()

    try:
        os.remove(gif_path)
    except OSError:
        logger.warning("The temporary file used by moviepy cannot be deleted.")

    return Summary.Image(
        height=h, width=w, colorspace=c, encoded_image_string=tensor_string
    )


def audio(tag, tensor, sample_rate=44100):
    # NOTE(review): the remainder of this function (int16 conversion and WAV
    # encoding into the Summary) lies in a garbled region of the file and is
    # not reproduced here; only the visible validation head is preserved.
    array = make_np(tensor)
    array = array.squeeze()
    if abs(array).max() > 1:
        print("warning: audio amplitude out of range, auto clipped.")
        array = array.clip(-1, 1)
    assert array.ndim == 1, "input tensor should be 1 dimensional."
array = (array * np.iinfo(np.int16).max).astype(" 127: # weird, value > 127 breaks protobuf num_thresholds = 127 data = np.stack((tp, fp, tn, fn, precision, recall)) pr_curve_plugin_data = PrCurvePluginData( version=0, num_thresholds=num_thresholds ).SerializeToString() plugin_data = SummaryMetadata.PluginData( plugin_name="pr_curves", content=pr_curve_plugin_data ) smd = SummaryMetadata(plugin_data=plugin_data) tensor = TensorProto( dtype="DT_FLOAT", float_val=data.reshape(-1).tolist(), tensor_shape=TensorShapeProto( dim=[ TensorShapeProto.Dim(size=data.shape[0]), TensorShapeProto.Dim(size=data.shape[1]), ] ), ) return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)]) def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None): # weird, value > 127 breaks protobuf num_thresholds = min(num_thresholds, 127) data = compute_curve( labels, predictions, num_thresholds=num_thresholds, weights=weights ) pr_curve_plugin_data = PrCurvePluginData( version=0, num_thresholds=num_thresholds ).SerializeToString() plugin_data = SummaryMetadata.PluginData( plugin_name="pr_curves", content=pr_curve_plugin_data ) smd = SummaryMetadata(plugin_data=plugin_data) tensor = TensorProto( dtype="DT_FLOAT", float_val=data.reshape(-1).tolist(), tensor_shape=TensorShapeProto( dim=[ TensorShapeProto.Dim(size=data.shape[0]), TensorShapeProto.Dim(size=data.shape[1]), ] ), ) return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)]) # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py def compute_curve(labels, predictions, num_thresholds=None, weights=None): _MINIMUM_COUNT = 1e-7 if weights is None: weights = 1.0 # Compute bins of true positives and false positives. 
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute per-threshold tp/fp/tn/fn/precision/recall for a PR curve.

    Predictions in [0, 1] are bucketed into `num_thresholds` bins; reverse
    cumulative sums then give the counts at each threshold. Returns a
    (6, num_thresholds) array stacked as (tp, fp, tn, fn, precision, recall).
    (The signature and leading statements reproduce the function head visible
    in the garbled preceding region of this file.)
    """
    _MINIMUM_COUNT = 1e-7

    if weights is None:
        weights = 1.0

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(np.float64)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
    return np.stack((tp, fp, tn, fn, precision, recall))


def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Create a tensor summary with summary metadata.

    Args:
      name: Uniquely identifiable name of the summary op. Could be replaced by
        combination of name and type to make it unique even outside of this
        summary.
      display_name: Will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data. Markdown
        is supported.
      tensor: Tensor to display in summary.
      content_type: Type of content inside the Tensor.
      components: Bitmask representing present parts (vertices, colors, etc.)
        that belong to the summary.
      json_config: A string, JSON-serialized dictionary of ThreeJS classes
        configuration.

    Returns:
      Tensor summary with metadata.
    """
    import torch
    from tensorboard.plugins.mesh import metadata

    tensor = torch.as_tensor(tensor)

    tensor_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        tensor.shape,
        description,
        json_config=json_config,
    )

    # Serialize the flattened values with an explicit rank-3 shape.
    proto = TensorProto(
        dtype="DT_FLOAT",
        float_val=tensor.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[
                TensorShapeProto.Dim(size=tensor.shape[0]),
                TensorShapeProto.Dim(size=tensor.shape[1]),
                TensorShapeProto.Dim(size=tensor.shape[2]),
            ]
        ),
    )

    return Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=proto,
        metadata=tensor_metadata,
    )


def _get_json_config(config_dict):
    """Parse and returns JSON string from python dictionary."""
    if config_dict is None:
        return "{}"
    return json.dumps(config_dict, sort_keys=True)


# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
def mesh(
    tag, vertices, colors, faces, config_dict, display_name=None, description=None
):
    """Output a merged `Summary` protocol buffer with a mesh/point cloud.

    Args:
      tag: A name for this summary operation.
      vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
        coordinates of vertices.
      faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
        vertices within each triangle.
      colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for
        each vertex.
      display_name: If set, will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data. Markdown
        is supported.
      config_dict: Dictionary with ThreeJS classes names and configuration.

    Returns:
      Merged summary for mesh/point cloud representation.
    """
    from tensorboard.plugins.mesh import metadata
    from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData

    json_config = _get_json_config(config_dict)

    # Keep only the components the caller actually supplied, in the fixed
    # VERTEX, FACE, COLOR order.
    present = [
        (data, content_type)
        for data, content_type in (
            (vertices, MeshPluginData.VERTEX),
            (faces, MeshPluginData.FACE),
            (colors, MeshPluginData.COLOR),
        )
        if data is not None
    ]
    components = metadata.get_components_bitmask(
        [content_type for _, content_type in present]
    )

    summaries = [
        _get_tensor_summary(
            tag, display_name, description, data, content_type, components, json_config
        )
        for data, content_type in present
    ]
    return Summary(value=summaries)