381 lines
15 KiB
Python
381 lines
15 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2023-present, the HuggingFace Inc. team.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily."""
|
||
|
|
||
|
import atexit
|
||
|
import inspect
|
||
|
import os
|
||
|
from functools import wraps
|
||
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||
|
|
||
|
from .utils import experimental, is_gradio_available
|
||
|
from .utils._deprecation import _deprecate_method
|
||
|
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
import gradio as gr
|
||
|
|
||
|
|
||
|
from fastapi import FastAPI, Request
|
||
|
from fastapi.responses import JSONResponse
|
||
|
|
||
|
|
||
|
# Module-wide server instance shared by the `webhook_endpoint` decorator; created on first use
# by `_get_global_app`.
_global_app: Optional["WebhooksServer"] = None

# True unless the `SYSTEM` environment variable is set to "spaces".
_is_local = os.environ.get("SYSTEM") != "spaces"
|
||
|
|
||
|
|
||
|
@experimental
class WebhooksServer:
    """
    The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Huggingface webhooks.
    These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to
    the app as a POST endpoint to the FastAPI router. Once all the webhooks are registered, the `launch` method has to
    be called to start the app.

    It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic
    model that contains all the information about the webhook event. The data will be parsed automatically for you.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
    WebhooksServer and deploy it on a Space.

    <Tip warning={true}>

    `WebhooksServer` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        ui (`gradio.Blocks`, optional):
            A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions
            about the configured webhooks is created.
        webhook_secret (`str`, optional):
            A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as
            you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You
            can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the
            webhook endpoints are opened without any security.

    Example:

        ```python
        import gradio as gr
        from huggingface_hub import WebhooksServer, WebhookPayload

        with gr.Blocks() as ui:
            ...

        app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")

        @app.add_webhook("/say_hello")
        async def hello(payload: WebhookPayload):
            return {"message": "hello"}

        app.launch()
        ```
    """

    def __new__(cls, *args, **kwargs) -> "WebhooksServer":
        """Refuse to instantiate the server when `gradio` is not installed."""
        if not is_gradio_available():
            raise ImportError(
                "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`"
                " first."
            )
        return super().__new__(cls)

    def __init__(
        self,
        ui: Optional["gr.Blocks"] = None,
        webhook_secret: Optional[str] = None,
    ) -> None:
        self._ui = ui

        # An explicit (non-empty) argument wins over the `WEBHOOK_SECRET` environment variable.
        self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
        # Maps the absolute route ("/webhooks/<name>") to the user function handling it.
        self.registered_webhooks: Dict[str, Callable] = {}
        _warn_on_empty_secret(self.webhook_secret)

    def add_webhook(self, path: Optional[str] = None) -> Callable:
        """
        Decorator to add a webhook to the [`WebhooksServer`] server.

        Args:
            path (`str`, optional):
                The URL path to register the webhook function. If not provided, the function name will be used as the
                path. In any case, all webhooks are registered under `/webhooks`.

        Returns:
            `Callable`: the decorated function itself when used as `@app.add_webhook`, or a decorator returning the
            decorated function when used as `@app.add_webhook(...)`.

        Raises:
            ValueError: If the provided path is already registered as a webhook.

        Example:
            ```python
            from huggingface_hub import WebhooksServer, WebhookPayload

            app = WebhooksServer()

            @app.add_webhook
            async def trigger_training(payload: WebhookPayload):
                if payload.repo.type == "dataset" and payload.event.action == "update":
                    # Trigger a training job if a dataset is updated
                    ...

            app.launch()
            ```
        """
        # Usage: directly as decorator. Example: `@app.add_webhook`
        if callable(path):
            # If path is a function, it means it was used as a decorator without arguments
            return self.add_webhook()(path)

        # Usage: provide a path. Example: `@app.add_webhook(...)`
        @wraps(FastAPI.post)
        def _inner_post(*args, **kwargs):
            func = args[0]
            abs_path = f"/webhooks/{(path or func.__name__).strip('/')}"
            if abs_path in self.registered_webhooks:
                raise ValueError(f"Webhook {abs_path} already exists.")
            self.registered_webhooks[abs_path] = func
            # Hand the function back so that the decorated name stays bound to it (and not to `None`).
            return func

        return _inner_post

    def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None:
        """Launch the Gradio app and register webhooks to the underlying FastAPI server.

        Input parameters are forwarded to Gradio when launching the app.

        Args:
            prevent_thread_lock (`bool`, optional):
                If `False` (default), the calling thread is blocked once the webhooks are registered. If `True`,
                the method returns right after printing the instructions.
            **launch_kwargs:
                Extra keyword arguments forwarded to the UI's `launch` method. `share` defaults to `True` when
                running locally and to `False` on a Space.
        """
        ui = self._ui or self._get_default_ui()

        # Start Gradio App
        #   - as non-blocking so that webhooks can be added afterwards
        #   - as shared if launch locally (to debug webhooks)
        launch_kwargs.setdefault("share", _is_local)
        self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs)

        # Register webhooks to FastAPI app
        for path, func in self.registered_webhooks.items():
            # Add secret check if required
            if self.webhook_secret is not None:
                func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret)

            # Add route to FastAPI app
            self.fastapi_app.post(path)(func)

        # Print instructions and block main thread
        url = (ui.share_url or ui.local_url).strip("/")
        message = "\nWebhooks are correctly setup and ready to use:"
        message += "\n" + "\n".join(f"  - POST {url}{webhook}" for webhook in self.registered_webhooks)
        message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks."
        print(message)

        if not prevent_thread_lock:
            ui.block_thread()

    @_deprecate_method(version="0.23", message="Use `WebhooksServer.launch` instead.")
    def run(self) -> None:
        """Deprecated alias of [`~WebhooksServer.launch`] (called without arguments)."""
        return self.launch()

    def _get_default_ui(self) -> "gr.Blocks":
        """Default UI if not provided (lists webhooks and provides basic instructions)."""
        import gradio as gr

        with gr.Blocks() as ui:
            gr.Markdown("# This is an app to process 🤗 Webhooks")
            gr.Markdown(
                "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on"
                " specific repos or to all repos belonging to particular set of users/organizations (not just your"
                " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to"
                " know more about webhooks on the Huggingface Hub."
            )
            gr.Markdown(
                f"{len(self.registered_webhooks)} webhook(s) are registered:"
                + "\n\n"
                + "\n ".join(
                    f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})"
                    for webhook_path, webhook in self.registered_webhooks.items()
                )
            )

            # Build the message explicitly: a bare `a + b if cond else c` parses as `(a + b) if cond else c`,
            # which would drop the settings link whenever the app runs on a Space.
            instructions = "Go to https://huggingface.co/settings/webhooks to setup your webhooks."
            if _is_local:
                instructions += (
                    "\nYour app is running locally. Please look at the logs to check the full URL you need to set."
                )
            else:
                instructions += (
                    "\nThis app is running on a Space. You can find the corresponding URL in the options menu"
                    " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'."
                )
            gr.Markdown(instructions)
        return ui
|
||
|
|
||
|
|
||
|
@experimental
def webhook_endpoint(path: Optional[str] = None) -> Callable:
    """Decorator to start a [`WebhooksServer`] and register the decorated function as a webhook endpoint.

    This is a helper to get started quickly. If you need more flexibility (custom landing page or webhook secret),
    you can use [`WebhooksServer`] directly. You can register multiple webhook endpoints (to the same server) by using
    this decorator multiple times.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
    server and deploy it on a Space.

    <Tip warning={true}>

    `webhook_endpoint` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        path (`str`, optional):
            The URL path to register the webhook function. If not provided, the function name will be used as the path.
            In any case, all webhooks are registered under `/webhooks`.

    Returns:
        `Callable`: the decorated function, extended with `launch` and `run` attributes that start the server
        immediately.

    Examples:
        The default usage is to register a function as a webhook endpoint. The function name will be used as the path.
        The server will be started automatically at exit (i.e. at the end of the script).

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Server is automatically started at the end of the script.
        ```

        Advanced usage: register a function as a webhook endpoint and start the server manually. This is useful if you
        are running it in a notebook.

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Start the server manually (`trigger_training.run()` is kept as an alias)
        trigger_training.launch()
        ```
    """
    if callable(path):
        # If path is a function, it means it was used as a decorator without arguments
        return webhook_endpoint()(path)

    @wraps(WebhooksServer.add_webhook)
    def _inner(func: Callable) -> Callable:
        app = _get_global_app()
        app.add_webhook(path)(func)
        if len(app.registered_webhooks) == 1:
            # Register `app.launch` to run at exit (only once). `app.run` is deprecated and would emit a
            # deprecation warning at every interpreter exit.
            atexit.register(app.launch)

        @wraps(app.launch)
        def _run_now():
            # Run the app directly (without waiting atexit)
            atexit.unregister(app.launch)
            app.launch()

        func.launch = _run_now  # type: ignore
        func.run = _run_now  # type: ignore  # historical name, kept for backward compatibility
        return func

    return _inner
|
||
|
|
||
|
|
||
|
def _get_global_app() -> WebhooksServer:
    """Return the module-wide `WebhooksServer`, creating it with default arguments on first call."""
    global _global_app

    app = _global_app
    if app is None:
        # First request: build the shared server and remember it for later calls.
        app = _global_app = WebhooksServer()
    return app
|
||
|
|
||
|
|
||
|
def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None:
    """Print a status message telling whether a webhook secret has been configured.

    When `webhook_secret` is `None`, instructions on how to define one are printed. Otherwise a
    one-line confirmation is printed.
    """
    if webhook_secret is not None:
        print("Webhook secret is correctly defined.")
        return

    notices = (
        "Webhook secret is not defined. This means your webhook endpoints will be open to everyone.",
        (
            "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
            "\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`"
        ),
        (
            "For more details about webhook secrets, please refer to"
            " https://huggingface.co/docs/hub/webhooks#webhook-secret."
        ),
    )
    for notice in notices:
        print(notice)
|
||
|
|
||
|
|
||
|
def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str:
    """Returns the anchor to a given webhook in the docs (experimental)

    The anchor is built as `<name><path>_post` with every non-word character replaced by an underscore.
    Replacing only `/` would leave characters such as `-` or `.` untouched in the anchor.
    NOTE(review): this is meant to mirror FastAPI's default operation-id generation — confirm against the
    FastAPI version in use.

    Args:
        webhook_name (`str`):
            Name of the function handling the webhook.
        webhook_path (`str`):
            Route of the webhook (e.g. `"/webhooks/my_hook"`).

    Returns:
        `str`: relative URL of the form `/docs#/default/<operation_id>`.
    """
    import re  # local import: only needed here

    operation_id = re.sub(r"\W", "_", f"{webhook_name}{webhook_path}")
    return f"/docs#/default/{operation_id}_post"
|
||
|
|
||
|
|
||
|
def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable:
    """Wraps a webhook function to check the webhook secret before calling the function.

    This is a hacky way to add the `request` parameter to the function signature. Since FastAPI based itself on route
    parameters to inject the values to the function, we need to hack the function signature to retrieve the `Request`
    object (and hence the headers). A far cleaner solution would be to use a middleware. However, since
    `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by
    Gradio internals (and not by us), we cannot add a middleware.

    This method is called only when a secret has been defined by the user. If a request is sent without the
    "x-webhook-secret", the function will return a 401 error (unauthorized). If the header is sent but is incorrect,
    the function will return a 403 error (forbidden).

    Inspired by https://stackoverflow.com/a/33112180.

    Args:
        func (`Callable`):
            The webhook function to protect. Can be sync or async.
        webhook_secret (`str`):
            The secret that incoming requests must send in the "x-webhook-secret" header.

    Returns:
        `Callable`: an async function that validates the header and then calls `func` with the remaining keyword
        arguments (plus `request` if `func` declares a parameter with that name).
    """
    import hmac  # local import: only needed for the constant-time comparison below

    initial_sig = inspect.signature(func)
    # Computed once: neither the signature nor the coroutine status changes between requests.
    wants_request = "request" in initial_sig.parameters
    is_async = inspect.iscoroutinefunction(func)
    # Compare bytes rather than str: `hmac.compare_digest` rejects non-ASCII `str` inputs. "surrogatepass" keeps
    # the encoding from failing on lone surrogates, so no new exception is introduced.
    expected_secret = webhook_secret.encode("utf-8", "surrogatepass")

    @wraps(func)
    async def _protected_func(request: Request, **kwargs):
        request_secret = request.headers.get("x-webhook-secret")
        if request_secret is None:
            return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401)
        # Constant-time comparison so that the response time does not leak how much of the secret matched.
        if not hmac.compare_digest(request_secret.encode("utf-8", "surrogatepass"), expected_secret):
            return JSONResponse({"error": "Invalid webhook secret."}, status_code=403)

        # Inject `request` in kwargs if required
        if wants_request:
            kwargs["request"] = request

        # Handle both sync and async routes
        if is_async:
            return await func(**kwargs)
        return func(**kwargs)

    # Update signature to include request
    if not wants_request:
        _protected_func.__signature__ = initial_sig.replace(  # type: ignore
            parameters=(
                inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
            )
            + tuple(initial_sig.parameters.values())
        )

    # Return protected route
    return _protected_func
|