# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard HParams plugin.

See `http_api.md` in this directory for specifications of the routes for
this plugin.
"""


import json

import werkzeug
from werkzeug import wrappers

from tensorboard import plugin_util
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import backend_context
from tensorboard.plugins.hparams import download_data
from tensorboard.plugins.hparams import error
from tensorboard.plugins.hparams import get_experiment
from tensorboard.plugins.hparams import list_metric_evals
from tensorboard.plugins.hparams import list_session_groups
from tensorboard.plugins.hparams import metadata
from google.protobuf import json_format
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.util import tb_logging


logger = tb_logging.get_logger()


class HParamsPlugin(base_plugin.TBPlugin):
    """HParams Plugin for TensorBoard.

    It supports both GETs and POSTs. See 'http_api.md' for more details.
    """

    plugin_name = metadata.PLUGIN_NAME

    def __init__(self, context):
        """Instantiates HParams plugin via TensorBoard core.

        Args:
          context: A base_plugin.TBContext instance.
        """
        self._context = backend_context.Context(context)

    def get_plugin_apps(self):
        """See base class."""
        return {
            "/download_data": self.download_data_route,
            "/experiment": self.get_experiment_route,
            "/session_groups": self.list_session_groups_route,
            "/metric_evals": self.list_metric_evals_route,
        }

    def is_active(self):
        return False  # `list_plugins` as called by TB core suffices

    def frontend_metadata(self):
        return base_plugin.FrontendMetadata(element_name="tf-hparams-dashboard")

    # ---- /download_data --------------------------------------------------------
    @wrappers.Request.application
    def download_data_route(self, request):
        """Serves an experiment's session groups as a downloadable document.

        Reads the `format` query parameter (passed through to
        `download_data.Handler`), the JSON-encoded `columnsVisibility` query
        parameter, and a `ListSessionGroupsRequest` (POST body, or the
        `request` query parameter for GET).

        Raises:
          werkzeug.exceptions.BadRequest: if an input is missing/malformed or a
            handler raises `error.HParamsError`.
        """
        ctx = plugin_util.context(request.environ)
        experiment_id = plugin_util.experiment_id(request.environ)
        try:
            response_format = request.args.get("format")
            # Missing or malformed JSON is reported as an HParamsError (-> 400)
            # instead of escaping as an unhandled TypeError/ValueError (-> 500).
            columns_visibility = _parse_columns_visibility(request)
            request_proto = _parse_request_argument(
                request, api_pb2.ListSessionGroupsRequest
            )
            session_groups = list_session_groups.Handler(
                ctx, self._context, experiment_id, request_proto
            ).run()
            # NOTE(review): this hands the ListSessionGroupsRequest proto to
            # get_experiment.Handler, whereas /experiment hands it a
            # GetExperimentRequest -- confirm the handler accepts both.
            experiment = get_experiment.Handler(
                ctx, self._context, experiment_id, request_proto
            ).run()
            body, mime_type = download_data.Handler(
                self._context,
                experiment,
                session_groups,
                response_format,
                columns_visibility,
            ).run()
            return http_util.Respond(request, body, mime_type)
        except error.HParamsError as e:
            raise _bad_request(e) from e

    # ---- /experiment -----------------------------------------------------------
    @wrappers.Request.application
    def get_experiment_route(self, request):
        """Responds with the result of `get_experiment.Handler` as JSON.

        The `GetExperimentRequest` is read from the POST body, or from the
        `request` query parameter for GET.

        Raises:
          werkzeug.exceptions.BadRequest: on `error.HParamsError`.
        """
        ctx = plugin_util.context(request.environ)
        experiment_id = plugin_util.experiment_id(request.environ)
        try:
            request_proto = _parse_request_argument(
                request, api_pb2.GetExperimentRequest
            )
            return http_util.Respond(
                request,
                json_format.MessageToJson(
                    get_experiment.Handler(
                        ctx, self._context, experiment_id, request_proto
                    ).run(),
                    # NOTE(review): newer protobuf releases replace this
                    # argument with `always_print_fields_with_no_presence`;
                    # confirm against the pinned protobuf version.
                    including_default_value_fields=True,
                ),
                "application/json",
            )
        except error.HParamsError as e:
            raise _bad_request(e) from e

    # ---- /session_groups -------------------------------------------------------
    @wrappers.Request.application
    def list_session_groups_route(self, request):
        """Responds with the result of `list_session_groups.Handler` as JSON.

        The `ListSessionGroupsRequest` is read from the POST body, or from the
        `request` query parameter for GET.

        Raises:
          werkzeug.exceptions.BadRequest: on `error.HParamsError`.
        """
        ctx = plugin_util.context(request.environ)
        experiment_id = plugin_util.experiment_id(request.environ)
        try:
            request_proto = _parse_request_argument(
                request, api_pb2.ListSessionGroupsRequest
            )
            return http_util.Respond(
                request,
                json_format.MessageToJson(
                    list_session_groups.Handler(
                        ctx, self._context, experiment_id, request_proto
                    ).run(),
                    including_default_value_fields=True,
                ),
                "application/json",
            )
        except error.HParamsError as e:
            raise _bad_request(e) from e

    # ---- /metric_evals ---------------------------------------------------------
    @wrappers.Request.application
    def list_metric_evals_route(self, request):
        """Responds with the output of `list_metric_evals.Handler`.

        The `ListMetricEvalsRequest` is read from the POST body, or from the
        `request` query parameter for GET.

        Raises:
          werkzeug.exceptions.NotFound: if the scalars plugin is not loaded.
          werkzeug.exceptions.BadRequest: on `error.HParamsError`.
        """
        ctx = plugin_util.context(request.environ)
        experiment_id = plugin_util.experiment_id(request.environ)
        try:
            request_proto = _parse_request_argument(
                request, api_pb2.ListMetricEvalsRequest
            )
            scalars_plugin = self._get_scalars_plugin()
            if not scalars_plugin:
                raise werkzeug.exceptions.NotFound("Scalars plugin not loaded")
            return http_util.Respond(
                request,
                list_metric_evals.Handler(
                    ctx, request_proto, scalars_plugin, experiment_id
                ).run(),
                "application/json",
            )
        except error.HParamsError as e:
            raise _bad_request(e) from e

    def _get_scalars_plugin(self):
        """Tries to get the scalars plugin.

        Returns:
          The scalars plugin or None if it is not yet registered.
        """
        return self._context.tb_context.plugin_name_to_instance.get(
            scalars_metadata.PLUGIN_NAME
        )


def _bad_request(e):
    """Logs an HParams error and wraps it in a 400 Bad Request exception.

    Args:
      e: The `error.HParamsError` raised while serving a route.

    Returns:
      A `werkzeug.exceptions.BadRequest` whose description is `str(e)`; the
      caller is responsible for raising it.
    """
    logger.error("HParams error: %s", e)
    return werkzeug.exceptions.BadRequest(description=str(e))


def _parse_columns_visibility(request):
    """Decodes the JSON-encoded `columnsVisibility` query parameter.

    Args:
      request: The incoming werkzeug request.

    Returns:
      The value produced by `json.loads` on the parameter.

    Raises:
      error.HParamsError: if the parameter is absent or is not valid JSON.
    """
    raw = request.args.get("columnsVisibility")
    if raw is None:
        raise error.HParamsError(
            "Missing required query parameter: columnsVisibility"
        )
    try:
        return json.loads(raw)
    except ValueError as e:  # json.JSONDecodeError subclasses ValueError.
        raise error.HParamsError(
            "Expected columnsVisibility to be JSON-formatted, but got {}".format(
                raw
            )
        ) from e


def _parse_request_argument(request, proto_class):
    """Parses a JSON-encoded request proto from an HTTP request.

    For POST requests the body is parsed; otherwise the `request` query
    parameter is.

    Args:
      request: The incoming werkzeug request.
      proto_class: The protobuf message class to parse into.

    Returns:
      A new `proto_class` instance populated from the JSON.

    Raises:
      error.HParamsError: if the JSON is missing or does not parse as
        `proto_class`.
    """
    request_json = (
        request.data
        if request.method == "POST"
        else request.args.get("request")
    )
    try:
        return json_format.Parse(request_json, proto_class())
    # if request_json is None, json_format.Parse will throw an AttributeError:
    # 'NoneType' object has no attribute 'decode'.
    except (AttributeError, json_format.ParseError) as e:
        raise error.HParamsError(
            "Expected a JSON-formatted request data of type: {}, but got {} ".format(
                proto_class, request_json
            )
        ) from e