# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard metrics plugin."""


import collections
import json

try:
    import imghdr
except ImportError:  # `imghdr` was removed in Python 3.13 (PEP 594).
    imghdr = None

from werkzeug import wrappers

from tensorboard import errors
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.data import provider
from tensorboard.plugins import base_plugin
from tensorboard.plugins.histogram import metadata as histogram_metadata
from tensorboard.plugins.image import metadata as image_metadata
from tensorboard.plugins.metrics import metadata
from tensorboard.plugins.scalar import metadata as scalar_metadata


_IMGHDR_TO_MIMETYPE = {
    "bmp": "image/bmp",
    "gif": "image/gif",
    "jpeg": "image/jpeg",
    "png": "image/png",
    "svg": "image/svg+xml",
}

_DEFAULT_IMAGE_MIMETYPE = "application/octet-stream"

_SINGLE_RUN_PLUGINS = frozenset(
    [histogram_metadata.PLUGIN_NAME, image_metadata.PLUGIN_NAME]
)

_SAMPLED_PLUGINS = frozenset([image_metadata.PLUGIN_NAME])


def _sniff_image_type(data):
    """Detects an image type from the leading bytes of `data`.

    Fallback used only when the standard library's `imghdr` module is
    unavailable. It recognizes the formats that `_IMGHDR_TO_MIMETYPE` can
    map and that are identifiable by a fixed header: JPEG, PNG, GIF, BMP.

    Args:
      data: a bytestring holding the image contents.

    Returns:
      One of "jpeg", "png", "gif", "bmp", or `None` if unrecognized.
    """
    header = bytes(data[:32])
    if header[6:10] in (b"JFIF", b"Exif") or header.startswith(
        b"\xff\xd8\xff\xdb"
    ):
        return "jpeg"
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    if header[:6] in (b"GIF87a", b"GIF89a"):
        return "gif"
    if header.startswith(b"BM"):
        return "bmp"
    return None


def _detect_image_type(data):
    """Returns the image type name for `data`, or `None` if unknown.

    Prefers `imghdr.what` when the module can be imported, so results are
    unchanged on interpreters that still ship it; otherwise falls back to
    `_sniff_image_type`.

    Args:
      data: a bytestring holding the image contents.

    Returns:
      A string key suitable for `_IMGHDR_TO_MIMETYPE`, or `None`.
    """
    if imghdr is not None:
        return imghdr.what(None, data)
    return _sniff_image_type(data)


def _get_tag_description_info(mapping):
    """Gets maps from tags to descriptions, and descriptions to runs.

    Time series with an empty description are skipped.

    Args:
      mapping: a nested map `d` such that `d[run][tag]` is a time series
        produced by DataProvider's `list_*` methods.

    Returns:
      A tuple containing
        tag_to_descriptions: A map from tag strings to a set of description
          strings.
        description_to_runs: A map from description strings to a set of run
          strings.
    """
    tag_to_descriptions = collections.defaultdict(set)
    description_to_runs = collections.defaultdict(set)
    for run, tag_to_content in mapping.items():
        for tag, metadatum in tag_to_content.items():
            description = metadatum.description
            if description:
                tag_to_descriptions[tag].add(description)
                description_to_runs[description].add(run)

    return tag_to_descriptions, description_to_runs


def _build_combined_description(descriptions, description_to_runs):
    """Creates a single description from a set of descriptions.

    Descriptions may be composites when a single tag has different
    descriptions across multiple runs.

    Args:
      descriptions: A list of description strings.
      description_to_runs: A map from description strings to a set of run
        strings.

    Returns:
      The combined description string (Markdown).
    """
    prefixed_descriptions = []
    for description in descriptions:
        runs = sorted(description_to_runs[description])
        run_or_runs = "runs" if len(runs) > 1 else "run"
        run_header = "## For " + run_or_runs + ": " + ", ".join(runs)
        description_html = run_header + "\n" + description
        prefixed_descriptions.append(description_html)

    header = "# Multiple descriptions\n"
    return header + "\n".join(prefixed_descriptions)


def _get_tag_to_description(mapping):
    """Returns a map of tags to descriptions.

    Tags whose time series all have empty descriptions are omitted.

    Args:
      mapping: a nested map `d` such that `d[run][tag]` is a time series
        produced by DataProvider's `list_*` methods.

    Returns:
      A map from tag strings to description HTML strings. E.g.
      (illustrative; exact markup is whatever
      `plugin_util.markdown_to_safe_html` produces)
      {
          "loss": "<h1>Multiple descriptions</h1>...",
          "loss2": "<p>The lossy details</p>",
      }
    """
    tag_to_descriptions, description_to_runs = _get_tag_description_info(
        mapping
    )

    result = {}
    for tag in tag_to_descriptions:
        # Sort so the combined description is deterministic.
        descriptions = sorted(tag_to_descriptions[tag])
        if len(descriptions) == 1:
            description = descriptions[0]
        else:
            description = _build_combined_description(
                descriptions, description_to_runs
            )
        result[tag] = plugin_util.markdown_to_safe_html(description)

    return result


def _get_run_tag_info(mapping):
    """Returns a map of run names to a list of tag names.

    Args:
      mapping: a nested map `d` such that `d[run][tag]` is a time series
        produced by DataProvider's `list_*` methods.

    Returns:
      A map from run strings to a sorted list of tag strings. E.g.
        {"loss001a": ["actor/loss", "critic/loss"], ...}
    """
    return {run: sorted(mapping[run]) for run in mapping}


def _format_basic_mapping(mapping):
    """Prepares a scalar or histogram mapping for client consumption.

    Args:
      mapping: a nested map `d` such that `d[run][tag]` is a time series
        produced by DataProvider's `list_*` methods.

    Returns:
      A dict with the following fields:
        runTagInfo: the return type of `_get_run_tag_info`
        tagDescriptions: the return type of `_get_tag_to_description`
    """
    return {
        "runTagInfo": _get_run_tag_info(mapping),
        "tagDescriptions": _get_tag_to_description(mapping),
    }


def _format_image_blob_sequence_datum(sorted_datum_list, sample):
    """Formats image metadata from a list of BlobSequenceDatum's for clients.

    This expects that frontend clients need to access images based on the
    run+tag+sample.

    Args:
      sorted_datum_list: a list of DataProvider's `BlobSequenceDatum`, sorted
        by step. This can be produced via DataProvider's
        `read_blob_sequences`.
      sample: zero-indexed integer for the requested sample.

    Returns:
      A list of `ImageStepDatum` (see http_api.md). Steps that do not have
      the requested sample are omitted; a negative `sample` yields an empty
      list.
    """
    # A negative sample would otherwise index into the width/height entries
    # (or wrap around to the end), returning a blob that was not requested.
    if sample < 0:
        return []

    # For images, ignore the first 2 items of a BlobSequenceDatum's values,
    # which correspond to width, height.
    index = sample + 2
    step_data = []
    for datum in sorted_datum_list:
        if len(datum.values) <= index:
            continue

        step_data.append(
            {
                "step": datum.step,
                "wallTime": datum.wall_time,
                "imageId": datum.values[index].blob_key,
            }
        )
    return step_data


def _get_tag_run_image_info(mapping):
    """Returns a map of tag names to run information.

    Args:
      mapping: the result of DataProvider's `list_blob_sequences`.

    Returns:
      A nested map from tag strings to run strings to image info, where image
      info is an object of form {"maxSamplesPerStep": num}. For example,
      {
          "reshaped": {
              "test": {"maxSamplesPerStep": 1},
              "train": {"maxSamplesPerStep": 1}
          },
          "convolved": {"test": {"maxSamplesPerStep": 50}},
      }
    """
    tag_run_image_info = collections.defaultdict(dict)
    for run, tag_to_content in mapping.items():
        for tag, metadatum in tag_to_content.items():
            tag_run_image_info[tag][run] = {
                "maxSamplesPerStep": metadatum.max_length - 2  # width, height
            }
    return dict(tag_run_image_info)


def _format_image_mapping(mapping):
    """Prepares an image mapping for client consumption.

    Args:
      mapping: the result of DataProvider's `list_blob_sequences`.

    Returns:
      A dict with the following fields:
        tagRunSampledInfo: the return type of `_get_tag_run_image_info`
        tagDescriptions: the return type of `_get_tag_to_description`
    """
    return {
        "tagDescriptions": _get_tag_to_description(mapping),
        "tagRunSampledInfo": _get_tag_run_image_info(mapping),
    }


class MetricsPlugin(base_plugin.TBPlugin):
    """Metrics Plugin for TensorBoard."""

    plugin_name = metadata.PLUGIN_NAME

    def __init__(self, context):
        """Instantiates MetricsPlugin.

        Args:
          context: A base_plugin.TBContext instance. MetricsLoader checks that
            it contains a valid `data_provider`.
        """
        self._data_provider = context.data_provider

        # For histograms, use a round number + 1 since sampling includes both
        # start and end steps, so N+1 samples corresponds to dividing the step
        # sequence into N intervals.
        sampling_hints = context.sampling_hints or {}
        self._plugin_downsampling = {
            "scalars": sampling_hints.get(scalar_metadata.PLUGIN_NAME, 1000),
            "histograms": sampling_hints.get(
                histogram_metadata.PLUGIN_NAME, 51
            ),
            "images": sampling_hints.get(image_metadata.PLUGIN_NAME, 10),
        }
        self._scalar_version_checker = plugin_util._MetadataVersionChecker(
            data_kind="scalar time series",
            latest_known_version=0,
        )
        self._histogram_version_checker = plugin_util._MetadataVersionChecker(
            data_kind="histogram time series",
            latest_known_version=0,
        )
        self._image_version_checker = plugin_util._MetadataVersionChecker(
            data_kind="image time series",
            latest_known_version=0,
        )

    def frontend_metadata(self):
        """Returns the frontend metadata (Angular component, tab name)."""
        return base_plugin.FrontendMetadata(
            is_ng_component=True, tab_name="Time Series"
        )

    def get_plugin_apps(self):
        """Returns a map from route paths to their WSGI handlers."""
        return {
            "/tags": self._serve_tags,
            "/timeSeries": self._serve_time_series,
            "/imageData": self._serve_image_data,
        }

    def data_plugin_names(self):
        """Returns the names of the data plugins this plugin reads from."""
        return (
            scalar_metadata.PLUGIN_NAME,
            histogram_metadata.PLUGIN_NAME,
            image_metadata.PLUGIN_NAME,
        )

    def is_active(self):
        return False  # 'data_plugin_names' suffices.

    @wrappers.Request.application
    def _serve_tags(self, request):
        """Serves tag metadata for the request's experiment as JSON."""
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        index = self._tags_impl(ctx, experiment=experiment)
        return http_util.Respond(request, index, "application/json")

    def _tags_impl(self, ctx, experiment=None):
        """Returns tag metadata for a given experiment's logged metrics.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: optional string ID of the request's experiment.

        Returns:
          A nested dict 'd' with keys in ("scalars", "histograms", "images")
            and values being the return type of _format_*mapping.
        """
        scalar_mapping = self._data_provider.list_scalars(
            ctx,
            experiment_id=experiment,
            plugin_name=scalar_metadata.PLUGIN_NAME,
        )
        scalar_mapping = self._filter_by_version(
            scalar_mapping,
            scalar_metadata.parse_plugin_metadata,
            self._scalar_version_checker,
        )

        histogram_mapping = self._data_provider.list_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=histogram_metadata.PLUGIN_NAME,
        )
        # Treat a `None` result from the data provider as "no data".
        if histogram_mapping is None:
            histogram_mapping = {}
        histogram_mapping = self._filter_by_version(
            histogram_mapping,
            histogram_metadata.parse_plugin_metadata,
            self._histogram_version_checker,
        )

        image_mapping = self._data_provider.list_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=image_metadata.PLUGIN_NAME,
        )
        # Treat a `None` result from the data provider as "no data".
        if image_mapping is None:
            image_mapping = {}
        image_mapping = self._filter_by_version(
            image_mapping,
            image_metadata.parse_plugin_metadata,
            self._image_version_checker,
        )

        result = {}
        result["scalars"] = _format_basic_mapping(scalar_mapping)
        result["histograms"] = _format_basic_mapping(histogram_mapping)
        result["images"] = _format_image_mapping(image_mapping)
        return result

    def _filter_by_version(self, mapping, parse_metadata, version_checker):
        """Filter `DataProvider.list_*` output by summary metadata version.

        Args:
          mapping: a nested map `d` such that `d[run][tag]` is a time series
            produced by DataProvider's `list_*` methods.
          parse_metadata: a callable that parses a time series'
            `plugin_content` into an object with a `version` attribute.
          version_checker: an object whose `ok(version, run, tag)` method
            returns whether the time series should be kept.

        Returns:
          A map with the same runs as `mapping` (runs are kept even when all
          of their tags are filtered out), containing only the accepted
          `d[run][tag]` entries.
        """
        result = {run: {} for run in mapping}
        for run, tag_to_content in mapping.items():
            for tag, metadatum in tag_to_content.items():
                md = parse_metadata(metadatum.plugin_content)
                if not version_checker.ok(md.version, run, tag):
                    continue
                result[run][tag] = metadatum
        return result

    @wrappers.Request.application
    def _serve_time_series(self, request):
        """Serves time series data for a JSON-encoded list of requests.

        The list is read from the `requests` form field for POST requests and
        from the `requests` query parameter otherwise.

        Raises:
          errors.InvalidArgumentError: if `requests` is missing, is not valid
            JSON, or does not decode to a list of JSON objects.
        """
        ctx = plugin_util.context(request.environ)
        experiment = plugin_util.experiment_id(request.environ)
        if request.method == "POST":
            series_requests_string = request.form.get("requests")
        else:
            series_requests_string = request.args.get("requests")
        if not series_requests_string:
            raise errors.InvalidArgumentError("Missing 'requests' field")
        try:
            series_requests = json.loads(series_requests_string)
        except ValueError as e:
            raise errors.InvalidArgumentError(
                "Unable to parse 'requests' as JSON"
            ) from e

        # Anything other than a list of objects would fail later with an
        # unhelpful internal error, so reject it here as a client error.
        if not isinstance(series_requests, list) or not all(
            isinstance(series_request, dict)
            for series_request in series_requests
        ):
            raise errors.InvalidArgumentError(
                "'requests' must be a JSON list of objects"
            )

        response = self._time_series_impl(ctx, experiment, series_requests)
        return http_util.Respond(request, response, "application/json")

    def _time_series_impl(self, ctx, experiment, series_requests):
        """Constructs a list of responses from a list of series requests.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: string ID of the request's experiment.
          series_requests: a list of `TimeSeriesRequest` dicts (see
            http_api.md).

        Returns:
          A list of `TimeSeriesResponse` dicts (see http_api.md).
        """
        responses = [
            self._get_time_series(ctx, experiment, request)
            for request in series_requests
        ]
        return responses

    def _create_base_response(self, series_request):
        """Builds the part of a response that echoes the request.

        Args:
          series_request: a `TimeSeriesRequest` dict (see http_api.md).

        Returns:
          A dict with the request's "plugin" and "tag", plus "run" when it is
          a string and "sample" when it is an integer.
        """
        tag = series_request.get("tag")
        run = series_request.get("run")
        plugin = series_request.get("plugin")
        sample = series_request.get("sample")
        response = {"plugin": plugin, "tag": tag}
        if isinstance(run, str):
            response["run"] = run
        if isinstance(sample, int):
            response["sample"] = sample

        return response

    def _get_invalid_request_error(self, series_request):
        """Validates a series request.

        Args:
          series_request: a `TimeSeriesRequest` dict (see http_api.md).

        Returns:
          A string describing the first problem found, or `None` if the
          request is valid.
        """
        tag = series_request.get("tag")
        plugin = series_request.get("plugin")
        run = series_request.get("run")
        sample = series_request.get("sample")

        if not isinstance(tag, str):
            return "Missing tag"

        # A tuple (not a set) so that unhashable values are rejected rather
        # than raising `TypeError`.
        if plugin not in (
            scalar_metadata.PLUGIN_NAME,
            histogram_metadata.PLUGIN_NAME,
            image_metadata.PLUGIN_NAME,
        ):
            return "Invalid plugin"

        if plugin in _SINGLE_RUN_PLUGINS and not isinstance(run, str):
            return "Missing run"

        if plugin in _SAMPLED_PLUGINS and not isinstance(sample, int):
            return "Missing sample"

        return None

    def _get_time_series(self, ctx, experiment, series_request):
        """Returns time series data for a given tag, plugin.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: string ID of the request's experiment.
          series_request: a `TimeSeriesRequest` (see http_api.md).

        Returns:
          A `TimeSeriesResponse` dict (see http_api.md). It carries an
          "error" string instead of "runToSeries" when the request is
          invalid.
        """
        tag = series_request.get("tag")
        run = series_request.get("run")
        plugin = series_request.get("plugin")
        sample = series_request.get("sample")
        response = self._create_base_response(series_request)
        request_error = self._get_invalid_request_error(series_request)
        if request_error:
            response["error"] = request_error
            return response

        # No run filter means "all runs".
        runs = [run] if run else None
        run_to_series = None
        if plugin == scalar_metadata.PLUGIN_NAME:
            run_to_series = self._get_run_to_scalar_series(
                ctx, experiment, tag, runs
            )
        elif plugin == histogram_metadata.PLUGIN_NAME:
            run_to_series = self._get_run_to_histogram_series(
                ctx, experiment, tag, runs
            )
        elif plugin == image_metadata.PLUGIN_NAME:
            run_to_series = self._get_run_to_image_series(
                ctx, experiment, tag, sample, runs
            )

        response["runToSeries"] = run_to_series
        return response

    def _get_run_to_scalar_series(self, ctx, experiment, tag, runs):
        """Builds a run-to-scalar-series dict for client consumption.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: a string experiment id.
          tag: string of the requested tag.
          runs: optional list of run names as strings.

        Returns:
          A map from string run names to `ScalarStepDatum` (see http_api.md).
        """
        mapping = self._data_provider.read_scalars(
            ctx,
            experiment_id=experiment,
            plugin_name=scalar_metadata.PLUGIN_NAME,
            downsample=self._plugin_downsampling["scalars"],
            run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
        )

        run_to_series = {}
        for result_run, tag_data in mapping.items():
            if tag not in tag_data:
                continue
            values = [
                {
                    "wallTime": datum.wall_time,
                    "step": datum.step,
                    "value": datum.value,
                }
                for datum in tag_data[tag]
            ]
            run_to_series[result_run] = values

        return run_to_series

    def _format_histogram_datum_bins(self, datum):
        """Formats a histogram datum's bins for client consumption.

        Args:
          datum: a DataProvider's TensorDatum.

        Returns:
          A list of `HistogramBin`s (see http_api.md).
        """
        numpy_list = datum.numpy.tolist()
        # Each row is read as (min, max, count).
        bins = [{"min": x[0], "max": x[1], "count": x[2]} for x in numpy_list]
        return bins

    def _get_run_to_histogram_series(self, ctx, experiment, tag, runs):
        """Builds a run-to-histogram-series dict for client consumption.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: a string experiment id.
          tag: string of the requested tag.
          runs: optional list of run names as strings.

        Returns:
          A map from string run names to `HistogramStepDatum` (see
          http_api.md).
        """
        mapping = self._data_provider.read_tensors(
            ctx,
            experiment_id=experiment,
            plugin_name=histogram_metadata.PLUGIN_NAME,
            downsample=self._plugin_downsampling["histograms"],
            run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
        )

        run_to_series = {}
        for result_run, tag_data in mapping.items():
            if tag not in tag_data:
                continue
            values = [
                {
                    "wallTime": datum.wall_time,
                    "step": datum.step,
                    "bins": self._format_histogram_datum_bins(datum),
                }
                for datum in tag_data[tag]
            ]
            run_to_series[result_run] = values

        return run_to_series

    def _get_run_to_image_series(self, ctx, experiment, tag, sample, runs):
        """Builds a run-to-image-series dict for client consumption.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          experiment: a string experiment id.
          tag: string of the requested tag.
          sample: zero-indexed integer for the requested sample.
          runs: optional list of run names as strings.

        Returns:
          A `RunToSeries` dict (see http_api.md). Runs with no data for the
          requested sample are omitted.
        """
        mapping = self._data_provider.read_blob_sequences(
            ctx,
            experiment_id=experiment,
            plugin_name=image_metadata.PLUGIN_NAME,
            downsample=self._plugin_downsampling["images"],
            run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
        )

        run_to_series = {}
        for result_run, tag_data in mapping.items():
            if tag not in tag_data:
                continue
            blob_sequence_datum_list = tag_data[tag]
            series = _format_image_blob_sequence_datum(
                blob_sequence_datum_list, sample
            )
            if series:
                run_to_series[result_run] = series

        return run_to_series

    @wrappers.Request.application
    def _serve_image_data(self, request):
        """Serves an individual image.

        Raises:
          errors.InvalidArgumentError: if the `imageId` query parameter is
            missing or empty.
        """
        ctx = plugin_util.context(request.environ)
        # Use `.get` so that an absent parameter reaches the explicit check
        # below instead of raising from the lookup itself.
        blob_key = request.args.get("imageId")
        if not blob_key:
            raise errors.InvalidArgumentError("Missing 'imageId' field")

        (data, content_type) = self._image_data_impl(ctx, blob_key)
        return http_util.Respond(request, data, content_type)

    def _image_data_impl(self, ctx, blob_key):
        """Gets the image data for a blob key.

        Args:
          ctx: A `tensorboard.context.RequestContext` value.
          blob_key: a string identifier for a DataProvider blob.

        Returns:
          A tuple containing:
            data: a raw bytestring of the requested image's contents.
            content_type: a string HTTP content type; falls back to
              `_DEFAULT_IMAGE_MIMETYPE` when the image type is unrecognized.
        """
        data = self._data_provider.read_blob(ctx, blob_key=blob_key)
        image_type = _detect_image_type(data)
        content_type = _IMGHDR_TO_MIMETYPE.get(
            image_type, _DEFAULT_IMAGE_MIMETYPE
        )
        return (data, content_type)