# mypy: ignore-errors

from inspect import getattr_static

from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from .base import VariableTracker


class SDPAParamsVariable(VariableTracker):
    """Represents the c++ params struct for scaled dot product attention.
    This is a read-only container.

    Attribute reads are traced into the FX graph through ``self.proxy``.
    Reconstruction rebuilds the struct by calling ``torch._C._SDPAParams``
    with the tracked ``param_vars``.
    """

    @staticmethod
    def create(tx, value, source):
        """Wrap a concrete SDPAParams ``value`` as a traced in-graph call.

        Each field of ``value`` is wrapped with ``VariableBuilder`` under an
        ``AttrSource`` rooted at ``source``. The fields are query, key, value,
        attn_mask, dropout and is_causal, in that order. The wrapped fields
        become the positional arguments of
        ``TorchInGraphFunctionVariable(SDPAParams).call_function``, and
        whatever that call returns is returned here.
        """
        from torch.backends.cuda import SDPAParams

        from ..source import AttrSource
        from .builder import VariableBuilder
        from .torch import TorchInGraphFunctionVariable

        # Order matters: it becomes the positional-argument order of the
        # SDPAParams constructor call below.
        field_names = (
            "query",
            "key",
            "value",
            "attn_mask",
            "dropout",
            "is_causal",
        )
        param_vars = [
            VariableBuilder(tx, AttrSource(source, field))(getattr(value, field))
            for field in field_names
        ]
        return TorchInGraphFunctionVariable(SDPAParams).call_function(
            tx, param_vars, {}
        )

    def __init__(self, proxy, param_vars, **kwargs):
        """Store the FX ``proxy`` for this struct and its constructor args.

        ``param_vars`` is the list of tracked constructor arguments that
        ``reconstruct`` replays. ``reconstruct`` asserts it is not ``None``,
        so ``None`` is presumably allowed for instances that are never
        reconstructed (TODO confirm against callers).
        """
        self.proxy = proxy
        self.param_vars = param_vars
        super().__init__(**kwargs)

    def reconstruct(self, codegen):
        """Emit bytecode that calls ``torch._C._SDPAParams(*param_vars)``.

        Only valid when there is no ``source``. Presumably a sourced
        instance would be reloaded from its source instead (TODO confirm).
        """
        assert self.source is None
        assert self.param_vars is not None
        codegen.load_import_from("torch._C", "_SDPAParams")
        codegen.foreach(self.param_vars)
        # NOTE(review): the second argument appears to control pushing a
        # NULL for the call convention; verify against create_call_function.
        codegen.extend_output(create_call_function(len(self.param_vars), True))

    def as_proxy(self):
        """Return the FX proxy that represents this struct in the graph."""
        return self.proxy

    def var_getattr(self, tx, name: str) -> VariableTracker:
        """Trace ``self.<name>`` as an FX getattr on the struct's proxy.

        Raises:
            Unsupported: if ``name`` is not a static attribute of
                ``torch._C._SDPAParams``.
        """
        import torch._C

        from ..source import AttrSource
        from .builder import wrap_fx_proxy
        from .misc import GetAttrVariable

        # getattr_static looks the name up without invoking descriptors or
        # __getattr__, so only statically declared attributes pass.
        try:
            getattr_static(torch._C._SDPAParams, name)
        except AttributeError:
            # The AttributeError adds nothing for the user; suppress the
            # implicit "during handling of the above exception" chain.
            raise Unsupported(
                f"Unsupported torch._C._SDPAParams attribute {name}"
            ) from None

        proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
        if self.source is not None:
            return wrap_fx_proxy(
                tx=tx, proxy=proxy, source=AttrSource(self.source, name)
            )
        return wrap_fx_proxy(tx=tx, proxy=proxy)

    @staticmethod
    def is_sdpa_params(value):
        """Return True iff ``value`` is the ``SDPAParams`` class itself.

        This is an identity check against the class object, not an
        ``isinstance`` check on instances.
        """
        from torch.backends.cuda import SDPAParams

        return value is SDPAParams