ai-content-maker/.venv/Lib/site-packages/tensorboard/plugins/hparams/hparams_plugin.py

206 lines
7.6 KiB
Python
Raw Permalink Normal View History

2024-05-03 04:18:51 +03:00
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard HParams plugin.
See `http_api.md` in this directory for specifications of the routes for
this plugin.
"""
import json
import werkzeug
from werkzeug import wrappers
from tensorboard import plugin_util
from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import backend_context
from tensorboard.plugins.hparams import download_data
from tensorboard.plugins.hparams import error
from tensorboard.plugins.hparams import get_experiment
from tensorboard.plugins.hparams import list_metric_evals
from tensorboard.plugins.hparams import list_session_groups
from tensorboard.plugins.hparams import metadata
from google.protobuf import json_format
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin
from tensorboard.plugins.scalar import metadata as scalars_metadata
from tensorboard.util import tb_logging
logger = tb_logging.get_logger()
class HParamsPlugin(base_plugin.TBPlugin):
"""HParams Plugin for TensorBoard.
It supports both GETs and POSTs. See 'http_api.md' for more details.
"""
plugin_name = metadata.PLUGIN_NAME
def __init__(self, context):
"""Instantiates HParams plugin via TensorBoard core.
Args:
context: A base_plugin.TBContext instance.
"""
self._context = backend_context.Context(context)
def get_plugin_apps(self):
"""See base class."""
return {
"/download_data": self.download_data_route,
"/experiment": self.get_experiment_route,
"/session_groups": self.list_session_groups_route,
"/metric_evals": self.list_metric_evals_route,
}
def is_active(self):
return False # `list_plugins` as called by TB core suffices
def frontend_metadata(self):
return base_plugin.FrontendMetadata(element_name="tf-hparams-dashboard")
# ---- /download_data- -------------------------------------------------------
@wrappers.Request.application
def download_data_route(self, request):
ctx = plugin_util.context(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
response_format = request.args.get("format")
columns_visibility = json.loads(
request.args.get("columnsVisibility")
)
request_proto = _parse_request_argument(
request, api_pb2.ListSessionGroupsRequest
)
session_groups = list_session_groups.Handler(
ctx, self._context, experiment_id, request_proto
).run()
experiment = get_experiment.Handler(
ctx, self._context, experiment_id, request_proto
).run()
body, mime_type = download_data.Handler(
self._context,
experiment,
session_groups,
response_format,
columns_visibility,
).run()
return http_util.Respond(request, body, mime_type)
except error.HParamsError as e:
logger.error("HParams error: %s" % e)
raise werkzeug.exceptions.BadRequest(description=str(e))
# ---- /experiment -----------------------------------------------------------
@wrappers.Request.application
def get_experiment_route(self, request):
ctx = plugin_util.context(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
request_proto = _parse_request_argument(
request, api_pb2.GetExperimentRequest
)
return http_util.Respond(
request,
json_format.MessageToJson(
get_experiment.Handler(
ctx, self._context, experiment_id, request_proto
).run(),
including_default_value_fields=True,
),
"application/json",
)
except error.HParamsError as e:
logger.error("HParams error: %s" % e)
raise werkzeug.exceptions.BadRequest(description=str(e))
# ---- /session_groups -------------------------------------------------------
@wrappers.Request.application
def list_session_groups_route(self, request):
ctx = plugin_util.context(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
request_proto = _parse_request_argument(
request, api_pb2.ListSessionGroupsRequest
)
return http_util.Respond(
request,
json_format.MessageToJson(
list_session_groups.Handler(
ctx, self._context, experiment_id, request_proto
).run(),
including_default_value_fields=True,
),
"application/json",
)
except error.HParamsError as e:
logger.error("HParams error: %s" % e)
raise werkzeug.exceptions.BadRequest(description=str(e))
# ---- /metric_evals ---------------------------------------------------------
@wrappers.Request.application
def list_metric_evals_route(self, request):
ctx = plugin_util.context(request.environ)
experiment_id = plugin_util.experiment_id(request.environ)
try:
request_proto = _parse_request_argument(
request, api_pb2.ListMetricEvalsRequest
)
scalars_plugin = self._get_scalars_plugin()
if not scalars_plugin:
raise werkzeug.exceptions.NotFound("Scalars plugin not loaded")
return http_util.Respond(
request,
list_metric_evals.Handler(
ctx, request_proto, scalars_plugin, experiment_id
).run(),
"application/json",
)
except error.HParamsError as e:
logger.error("HParams error: %s" % e)
raise werkzeug.exceptions.BadRequest(description=str(e))
def _get_scalars_plugin(self):
"""Tries to get the scalars plugin.
Returns:
The scalars plugin or None if it is not yet registered.
"""
return self._context.tb_context.plugin_name_to_instance.get(
scalars_metadata.PLUGIN_NAME
)
def _parse_request_argument(request, proto_class):
request_json = (
request.data
if request.method == "POST"
else request.args.get("request")
)
try:
return json_format.Parse(request_json, proto_class())
# if request_json is None, json_format.Parse will throw an AttributeError:
# 'NoneType' object has no attribute 'decode'.
except (AttributeError, json_format.ParseError) as e:
raise error.HParamsError(
"Expected a JSON-formatted request data of type: {}, but got {} ".format(
proto_class, request_json
)
) from e