160 lines
5.6 KiB
Python
160 lines
5.6 KiB
Python
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Classes and functions for handling the DownloadData API call."""
|
||
|
|
||
|
|
||
|
import csv
|
||
|
import io
|
||
|
import math
|
||
|
|
||
|
from tensorboard.plugins.hparams import error
|
||
|
|
||
|
|
||
|
class OutputFormat:
|
||
|
"""An enum used to list the valid output formats for API calls."""
|
||
|
|
||
|
JSON = "json"
|
||
|
CSV = "csv"
|
||
|
LATEX = "latex"
|
||
|
|
||
|
|
||
|
class Handler:
|
||
|
"""Handles a DownloadData request."""
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
context,
|
||
|
experiment,
|
||
|
session_groups,
|
||
|
response_format,
|
||
|
columns_visibility,
|
||
|
):
|
||
|
"""Constructor.
|
||
|
|
||
|
Args:
|
||
|
context: A backend_context.Context instance.
|
||
|
experiment: Experiment proto.
|
||
|
session_groups: ListSessionGroupsResponse proto.
|
||
|
response_format: A string in the OutputFormat enum.
|
||
|
columns_visibility: A list of boolean values to filter columns.
|
||
|
"""
|
||
|
self._context = context
|
||
|
self._experiment = experiment
|
||
|
self._session_groups = session_groups
|
||
|
self._response_format = response_format
|
||
|
self._columns_visibility = columns_visibility
|
||
|
|
||
|
def run(self):
|
||
|
"""Handles the request specified on construction.
|
||
|
|
||
|
Returns:
|
||
|
A response body.
|
||
|
A mime type (string) for the response.
|
||
|
"""
|
||
|
experiment = self._experiment
|
||
|
session_groups = self._session_groups
|
||
|
response_format = self._response_format
|
||
|
visibility = self._columns_visibility
|
||
|
|
||
|
header = []
|
||
|
for hparam_info in experiment.hparam_infos:
|
||
|
header.append(hparam_info.display_name or hparam_info.name)
|
||
|
|
||
|
for metric_info in experiment.metric_infos:
|
||
|
header.append(metric_info.display_name or metric_info.name.tag)
|
||
|
|
||
|
def _filter_columns(row):
|
||
|
return [value for value, visible in zip(row, visibility) if visible]
|
||
|
|
||
|
header = _filter_columns(header)
|
||
|
|
||
|
rows = []
|
||
|
|
||
|
def _get_value(value):
|
||
|
if value.HasField("number_value"):
|
||
|
return value.number_value
|
||
|
if value.HasField("string_value"):
|
||
|
return value.string_value
|
||
|
if value.HasField("bool_value"):
|
||
|
return value.bool_value
|
||
|
# hyperparameter values can be optional in a session group
|
||
|
return ""
|
||
|
|
||
|
def _get_metric_id(metric):
|
||
|
return metric.group + "." + metric.tag
|
||
|
|
||
|
for group in session_groups.session_groups:
|
||
|
row = []
|
||
|
for hparam_info in experiment.hparam_infos:
|
||
|
row.append(_get_value(group.hparams[hparam_info.name]))
|
||
|
metric_values = {}
|
||
|
for metric_value in group.metric_values:
|
||
|
metric_id = _get_metric_id(metric_value.name)
|
||
|
metric_values[metric_id] = metric_value.value
|
||
|
for metric_info in experiment.metric_infos:
|
||
|
metric_id = _get_metric_id(metric_info.name)
|
||
|
row.append(metric_values.get(metric_id))
|
||
|
rows.append(_filter_columns(row))
|
||
|
|
||
|
if response_format == OutputFormat.JSON:
|
||
|
mime_type = "application/json"
|
||
|
body = dict(header=header, rows=rows)
|
||
|
elif response_format == OutputFormat.LATEX:
|
||
|
|
||
|
def latex_format(value):
|
||
|
if value is None:
|
||
|
return "-"
|
||
|
elif isinstance(value, int):
|
||
|
return "$%d$" % value
|
||
|
elif isinstance(value, float):
|
||
|
if math.isnan(value):
|
||
|
return r"$\mathrm{NaN}$"
|
||
|
if value in (float("inf"), float("-inf")):
|
||
|
return r"$%s\infty$" % ("-" if value < 0 else "+")
|
||
|
scientific = "%.3g" % value
|
||
|
if "e" in scientific:
|
||
|
coefficient, exponent = scientific.split("e")
|
||
|
return "$%s\\cdot 10^{%d}$" % (
|
||
|
coefficient,
|
||
|
int(exponent),
|
||
|
)
|
||
|
return "$%s$" % scientific
|
||
|
return value.replace("_", "\\_").replace("%", "\\%")
|
||
|
|
||
|
mime_type = "application/x-latex"
|
||
|
top_part = "\\begin{table}[tbp]\n\\begin{tabular}{%s}\n" % (
|
||
|
"l" * len(header)
|
||
|
)
|
||
|
header_part = (
|
||
|
" & ".join(map(latex_format, header)) + " \\\\ \\hline\n"
|
||
|
)
|
||
|
middle_part = "".join(
|
||
|
" & ".join(map(latex_format, row)) + " \\\\\n" for row in rows
|
||
|
)
|
||
|
bottom_part = "\\hline\n\\end{tabular}\n\\end{table}\n"
|
||
|
body = top_part + header_part + middle_part + bottom_part
|
||
|
elif response_format == OutputFormat.CSV:
|
||
|
string_io = io.StringIO()
|
||
|
writer = csv.writer(string_io)
|
||
|
writer.writerow(header)
|
||
|
writer.writerows(rows)
|
||
|
body = string_io.getvalue()
|
||
|
mime_type = "text/csv"
|
||
|
else:
|
||
|
raise error.HParamsError(
|
||
|
"Invalid reponses format: %s" % response_format
|
||
|
)
|
||
|
return body, mime_type
|