# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import LayerNorm

from ...integrations.deepspeed import is_deepspeed_available
from ...modeling_outputs import ModelOutput
from ...utils import (
    ContextManagers,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_scipy_available,
    logging,
    replace_return_docstrings,
)
from .configuration_esm import EsmConfig
from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel
from .openfold_utils import (
    OFProtein,
    Rigid,
    Rotation,
    atom14_to_atom37,
    chunk_layer,
    compute_predicted_aligned_error,
    compute_tm,
    frames_and_literature_positions_to_atom14_pos,
    make_atom14_masks,
    residue_constants,
    to_pdb,
    torsion_angles_to_frames,
)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/esmfold_v1"
_CONFIG_FOR_DOC = "EsmConfig"


@dataclass
class EsmForProteinFoldingOutput(ModelOutput):
    """
    Output type of [`EsmForProteinFolding`].

    Args:
        frames (`torch.FloatTensor`):
            Output frames.
        sidechain_frames (`torch.FloatTensor`):
            Output sidechain frames.
        unnormalized_angles (`torch.FloatTensor`):
            Predicted unnormalized backbone and side chain torsion angles.
        angles (`torch.FloatTensor`):
            Predicted backbone and side chain torsion angles.
        positions (`torch.FloatTensor`):
            Predicted positions of the backbone and side chain atoms.
        states (`torch.FloatTensor`):
            Hidden states from the protein folding trunk.
        s_s (`torch.FloatTensor`):
            Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem.
        s_z (`torch.FloatTensor`):
            Pairwise residue embeddings.
        distogram_logits (`torch.FloatTensor`):
            Input logits to the distogram used to compute residue distances.
        lm_logits (`torch.FloatTensor`):
            Logits output by the ESM-2 protein language model stem.
        aatype (`torch.FloatTensor`):
            Input amino acids (AlphaFold2 indices).
        atom14_atom_exists (`torch.FloatTensor`):
            Whether each atom exists in the atom14 representation.
        residx_atom14_to_atom37 (`torch.FloatTensor`):
            Mapping between atoms in the atom14 and atom37 representations.
        residx_atom37_to_atom14 (`torch.FloatTensor`):
            Mapping between atoms in the atom37 and atom14 representations.
        atom37_atom_exists (`torch.FloatTensor`):
            Whether each atom exists in the atom37 representation.
        residue_index (`torch.FloatTensor`):
            The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be
            a sequence of integers from 0 to `sequence_length`.
        lddt_head (`torch.FloatTensor`):
            Raw outputs from the lddt head used to compute plddt.
        plddt (`torch.FloatTensor`):
            Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction
            is uncertain, or where the protein structure is disordered.
        ptm_logits (`torch.FloatTensor`):
            Raw logits used for computing ptm.
        ptm (`torch.FloatTensor`):
            TM-score output representing the model's high-level confidence in the overall structure.
        aligned_confidence_probs (`torch.FloatTensor`):
            Per-residue confidence scores for the aligned structure.
        predicted_aligned_error (`torch.FloatTensor`):
            Predicted error between the model's prediction and the ground truth.
        max_predicted_aligned_error (`torch.FloatTensor`):
            Per-sample maximum predicted error.
    """

    frames: torch.FloatTensor = None
    sidechain_frames: torch.FloatTensor = None
    unnormalized_angles: torch.FloatTensor = None
    angles: torch.FloatTensor = None
    positions: torch.FloatTensor = None
    states: torch.FloatTensor = None
    s_s: torch.FloatTensor = None
    s_z: torch.FloatTensor = None
    distogram_logits: torch.FloatTensor = None
    lm_logits: torch.FloatTensor = None
    aatype: torch.FloatTensor = None
    atom14_atom_exists: torch.FloatTensor = None
    residx_atom14_to_atom37: torch.FloatTensor = None
    residx_atom37_to_atom14: torch.FloatTensor = None
    atom37_atom_exists: torch.FloatTensor = None
    residue_index: torch.FloatTensor = None
    lddt_head: torch.FloatTensor = None
    plddt: torch.FloatTensor = None
    ptm_logits: torch.FloatTensor = None
    ptm: torch.FloatTensor = None
    aligned_confidence_probs: torch.FloatTensor = None
    predicted_aligned_error: torch.FloatTensor = None
    max_predicted_aligned_error: torch.FloatTensor = None
ESMFOLD_INPUTS_DOCSTRING = r"""
|
||
|
Args:
|
||
|
input_ids (`torch.LongTensor` of shape `({0})`):
|
||
|
Indices of input sequence tokens in the vocabulary.
|
||
|
|
||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||
|
|
||
|
[What are input IDs?](../glossary#input-ids)
|
||
|
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||
|
|
||
|
- 1 for tokens that are **not masked**,
|
||
|
- 0 for tokens that are **masked**.
|
||
|
|
||
|
[What are attention masks?](../glossary#attention-mask)
|
||
|
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
||
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||
|
config.max_position_embeddings - 1]`.
|
||
|
|
||
|
[What are position IDs?](../glossary#position-ids)
|
||
|
masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*):
|
||
|
Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
|
||
|
num_recycles (`int`, *optional*, defaults to `None`):
|
||
|
Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling"
|
||
|
consists of passing the output of the folding trunk back in as input to the trunk. During training, the
|
||
|
number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
|
||
|
after each recycle. During inference, num_recycles should be set to the highest value that the model was
|
||
|
trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
|
||
|
used.
|
||
|
"""
|
||
|
|
||
|
|
||
|
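
# Illustrative usage of the inputs documented above (a sketch, not executed here;
# `EsmForProteinFolding` is defined further down in this file, and
# `add_special_tokens=False` keeps token positions aligned with residues):
#
#     >>> from transformers import AutoTokenizer
#     >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
#     >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
#     >>> inputs = tokenizer("MLKNVQVQLV", return_tensors="pt", add_special_tokens=False)
#     >>> outputs = model(**inputs, num_recycles=3)  # recycle the trunk three times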

def is_fp16_enabled():
    # Autocast world
    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
    fp16_enabled = fp16_enabled and torch.is_autocast_enabled()

    return fp16_enabled


def is_deepspeed_initialized():
    # DeepSpeed can't be initialized if it isn't even installed. (The original
    # condition was inverted, short-circuiting whenever DeepSpeed *was* available.)
    if not is_deepspeed_available():
        return False
    else:
        try:
            import deepspeed

            # This is not available in all DeepSpeed versions.
            return deepspeed.utils.is_initialized()
        except Exception:
            return False

def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
    """
    Takes a list of tensors with the following dimensions:
        [(d_11, ..., d_1K),
         (d_21, ..., d_2K), ..., (d_N1, ..., d_NK)]
    and stacks + pads them into a single tensor of shape:
        (N, max_i d_i1, ..., max_i d_iK)
    """
    if len(samples) == 0:
        return torch.Tensor()
    if len({x.dim() for x in samples}) != 1:
        raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}")
    (device,) = tuple({x.device for x in samples})  # assumes all on same device
    max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
    result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device)
    result.fill_(pad_v)
    for i in range(len(samples)):
        result_i = result[i]
        t = samples[i]
        result_i[tuple(slice(0, k) for k in t.shape)] = t
    return result
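
# Worked example (illustrative): two 2-D tensors of different shapes are padded
# up to the per-dimension maxima and stacked along a new leading batch dimension:
#
#     >>> a, b = torch.ones(2, 3), torch.ones(4, 2)
#     >>> collate_dense_tensors([a, b], pad_v=0).shape
#     torch.Size([2, 4, 3])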

def flatten_final_dims(t: torch.Tensor, no_dims: int):
    return t.reshape(t.shape[:-no_dims] + (-1,))


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])
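
# Shape sketch (illustrative): both helpers act only on trailing dimensions,
# leaving any leading batch dimensions untouched:
#
#     >>> t = torch.randn(5, 2, 3, 4)
#     >>> permute_final_dims(t, (2, 0, 1)).shape  # last three dims reordered
#     torch.Size([5, 4, 2, 3])
#     >>> flatten_final_dims(t, 2).shape          # last two dims merged
#     torch.Size([5, 2, 12])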

def dict_multimap(fn, dicts):
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if isinstance(v, dict):
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict
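
# Worked example (illustrative): `fn` receives the list of values collected for
# each key across the input dicts, recursing into nested dicts:
#
#     >>> dict_multimap(torch.stack, [{"a": torch.zeros(2)}, {"a": torch.ones(2)}])["a"].shape
#     torch.Size([2, 2])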

def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    shape = weights.shape
    scale = scale / max(1, shape[1])

    if not is_scipy_available():
        logger.warning(
            "This init requires scipy, but scipy was not found, default to an approximation that might not be"
            " equivalent."
        )
        std = math.sqrt(scale)
        torch.nn.init.normal_(weights, std=std)
        # Approximate the truncation in place. The previous chained
        # `.clamp(min=0.0, max=2.0 * std)` discarded its result (clamp is not
        # in-place) and would have clipped asymmetrically, unlike the scipy
        # branch below, which truncates symmetrically at two standard deviations.
        weights.clamp_(min=-2.0 * std, max=2.0 * std)

    else:
        from scipy.stats import truncnorm

        std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
        samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel())
        samples = np.reshape(samples, shape)
        weights.copy_(torch.tensor(samples, device=weights.device))

def ipa_point_weights_init_(weights):
    with torch.no_grad():
        # softplus(0.5413...) == 1.0 (this constant is the inverse softplus of 1),
        # so the per-head point weights start out at exactly 1 after the softplus.
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)

class EsmFoldLinear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.

    Implements the initializers in Section 1.11.4 of the AlphaFold 2 supplement, plus some additional ones found in
    the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                - "default": LeCun fan-in truncated normal initialization
                - "relu": He initialization with truncated normal distribution
                - "glorot": Fan-average Glorot uniform initialization
                - "gating": Weights=0, Bias=1
                - "normal": Normal initialization with std=1/sqrt(fan_in)
                - "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs. Overrides init if not None.
        """
        super().__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)
        self.init = init
        self.init_fn = init_fn

        if init not in ["default", "relu", "glorot", "gating", "normal", "final"]:
            raise ValueError("Invalid init string.")

class EsmFoldLayerNorm(nn.Module):
    def __init__(self, c_in, eps=1e-5):
        super().__init__()

        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        d = x.dtype
        if d is torch.bfloat16 and not is_deepspeed_initialized():
            with torch.cuda.amp.autocast(enabled=False):
                out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps)
        else:
            out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)

        return out

@torch.jit.ignore
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of type bfloat16
    """
    d = t.dtype
    if d is torch.bfloat16 and not is_deepspeed_initialized():
        with torch.cuda.amp.autocast(enabled=False):
            s = torch.nn.functional.softmax(t, dim=dim)
    else:
        s = torch.nn.functional.softmax(t, dim=dim)

    return s

class EsmFoldAttention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors.
    """

    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super().__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.

        self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final")

        self.linear_g = None
        if self.gating:
            self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating")

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # [*, H, Q/K, C_hidden]
        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)

        q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
        if self.linear_g is not None:
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        lma_q_chunk_size: int = 1024,
        lma_kv_chunk_size: int = 4096,
        use_flash: bool = False,
        flash_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_memory_efficient_kernel:
                Whether to use a custom memory-efficient attention kernel. This should be the default choice for most.
                If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
            use_lma:
                Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
                stock PyTorch implementation is used instead
            lma_q_chunk_size:
                Query chunk size (for LMA)
            lma_kv_chunk_size:
                Key/Value chunk size (for LMA)
        Returns:
            [*, Q, C_q] attention update
        """
        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
            raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")

        if use_flash and biases is not None:
            raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")

        attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
        if sum(attn_options) > 1:
            raise ValueError("Choose at most one alternative attention algorithm")

        if biases is None:
            biases = []

        # NOTE: the "use_<...>" flags above are only validated; this port always
        # runs the stock PyTorch attention path below.

        # [*, H, Q/K, C_hidden]
        query, key, value = self._prep_qkv(q_x, kv_x)
        key = permute_final_dims(key, (1, 0))

        # [*, H, Q, K]
        output = torch.matmul(query, key)
        for b in biases:
            output += b
        output = softmax_no_cast(output, -1)

        # [*, H, Q, C_hidden]
        output = torch.matmul(output, value)
        output = output.transpose(-2, -3)
        output = self._wrap_up(output, q_x)

        return output
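
# Shape sketch (illustrative):
#
#     >>> attn = EsmFoldAttention(c_q=64, c_k=64, c_v=64, c_hidden=16, no_heads=4)
#     >>> q_x, kv_x = torch.randn(2, 10, 64), torch.randn(2, 12, 64)
#     >>> attn(q_x, kv_x).shape  # 10 queries attend over 12 key/value positions
#     torch.Size([2, 10, 64])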

class EsmFoldTriangleAttention(nn.Module):
    def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
        """
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        "triangle! triangle!"
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }

        return chunk_layer(
            partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
            _out=x if inplace_safe else None,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(
                x.shape[:-1],
            )

        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(
                x,
                biases,
                chunk_size,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
            )
        else:
            x = self.mha(
                q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
            )

        if not self.starting:
            x = x.transpose(-2, -3)

        return x

class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
    """
    Implements Algorithms 11 and 12.
    """

    def __init__(self, config, _outgoing=True):
        super().__init__()
        c_hidden = config.pairwise_state_dim
        self._outgoing = _outgoing

        self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
        self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
        self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")

        self.layer_norm_in = LayerNorm(c_hidden)
        self.layer_norm_out = LayerNorm(c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(
        self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        if self._outgoing:
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        if _inplace_chunk_size is not None:
            # To be replaced by torch vmap
            for i in range(0, a.shape[-3], _inplace_chunk_size):
                a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
                b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
                a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
                    a_chunk,
                    b_chunk,
                )

            p = a
        else:
            p = torch.matmul(a, b)

        return permute_final_dims(p, (1, 2, 0))

    def _inference_forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_chunk_size: Optional[int] = None,
        with_add: bool = True,
    ):
        """
        Args:
            z:
                A [*, N, N, C_z] pair representation
            mask:
                A [*, N, N] pair mask
            inplace_chunk_size:
                Size of chunks used in the main computation. Increase to trade memory for speed.
            with_add:
                If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update).
        Returns:
            A reference to the overwritten z

        More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the
        addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of
        overwritten values to lower peak memory consumption of this module from 5x the size of the input tensor z to
        2.5x its size. Useful for inference on extremely long sequences.

        It works as follows. We will make reference to variables used in the default forward implementation below.
        Naively, the triangle multiplicative update requires the manifestation of 5 tensors the size of z: 1) z, the
        "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of z, 4) g, a z-sized
        mask, and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for
        N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate
        tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the
        tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over
        pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact
        contains inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would
        bring total memory consumption down to 3x the size of z. However, more memory can be saved by writing output
        chunks directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column"
        of z at the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one
        column ahead of previously overwritten columns and can be recovered directly from z. After the first
        iteration, however, the ith row of z is always at least partially overwritten. For this reason, we introduce
        the z-cache, a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd
        quadrants) of z. For 0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at
        the beginning of the ith iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th
        quadrants of z instead. Though the 3rd quadrant of the original z is entirely overwritten at this point, it
        can be recovered from the z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from
        the reoriented z-cache. After the final iteration, z has been completely overwritten and contains the
        triangular multiplicative update. If with_add is True, it instead contains the sum of z and the triangular
        multiplicative update. In either case, peak memory consumption is just 2.5x the size of z, disregarding memory
        used for chunks and other small variables.
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        def compute_projection_helper(pair, mask, a=True):
            if a:
                linear_g = self.linear_a_g
                linear_p = self.linear_a_p
            else:
                linear_g = self.linear_b_g
                linear_p = self.linear_b_p

            pair = self.layer_norm_in(pair)
            p = linear_g(pair)
            p.sigmoid_()
            p *= linear_p(pair)
            p *= mask
            p = permute_final_dims(p, (2, 0, 1))
            return p

        def compute_projection(pair, mask, a=True, chunked=True):
            need_transpose = self._outgoing ^ a
            if not chunked:
                p = compute_projection_helper(pair, mask, a)
                if need_transpose:
                    p = p.transpose(-1, -2)
            else:
                # This computation is chunked so as not to exceed our 2.5x
                # budget with a large intermediate tensor
                linear_g = self.linear_a_g if a else self.linear_b_g
                c = linear_g.bias.shape[-1]
                out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
                p = pair.new_zeros(out_shape)
                for i in range(0, pair.shape[-3], inplace_chunk_size):
                    pair_chunk = compute_projection_helper(
                        pair[..., i : i + inplace_chunk_size, :, :],
                        mask[..., i : i + inplace_chunk_size, :, :],
                        a,
                    )
                    if need_transpose:
                        pair_chunk = pair_chunk.transpose(-1, -2)
                        p[..., i : i + inplace_chunk_size] = pair_chunk
                    else:
                        p[..., i : i + inplace_chunk_size, :] = pair_chunk

                    del pair_chunk

            return p

        # We start by fully manifesting a. In addition to the input, this
        # brings total memory consumption to 2x z (disregarding size of chunks)
        # [*, N, N, c]
        a = compute_projection(z, mask, True, chunked=True)

        if inplace_chunk_size is not None:
            n = a.shape[-1]
            half_n = n // 2 + n % 2
            row_dim = -3
            col_dim = -2
            b_chunk_dim = row_dim if self._outgoing else col_dim

            def empty_slicer(t):
                return [slice(None) for _ in t.shape]

            def slice_tensor(t, start, end, dim):
                # Slices start:end from the dim dimension of t
                s = empty_slicer(t)
                s[dim] = slice(start, end)
                return t[s]

            def flip_z_cache_(z_cache, z):
                # "Reorient" the z_cache (see below), filling it with quadrants
                # 3---recovered from the z_cache---and 4---recovered from z---
                # of the input tensor z.
                quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim)
                z_cache = z_cache.transpose(row_dim, col_dim)

                # If n is odd, we need to shrink the z_cache by one row
                z_cache = z_cache[..., : (n // 2), :, :]

                # Move the 3rd quadrant of z into the first half of the
                # reoriented cache
                first_half_slicer = empty_slicer(z_cache)
                first_half_slicer[col_dim] = slice(0, half_n)
                z_cache[first_half_slicer] = quadrant_3

                # Get the fourth quadrant of z
                quadrant_4 = slice_tensor(z, half_n, None, row_dim)
                quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim)

                # Insert said quadrant into the rotated z-cache
                quadrant_3_slicer = empty_slicer(z_cache)
                quadrant_3_slicer[col_dim] = slice(half_n, None)

                z_cache[quadrant_3_slicer] = quadrant_4

                return z_cache

            # Initialize the z cache to the left half of z.
            z_cache_shape = list(z.shape)
            z_cache_shape[col_dim] = half_n
            z_cache = z.new_zeros(z_cache_shape)
            z_cache_slicer = empty_slicer(z_cache)
            z_cache_slicer[col_dim] = slice(0, half_n)
            z_cache.copy_(z[z_cache_slicer])
            z_cache_rotated = False

            # We need to reorient the z-cache at the halfway point, and we
            # don't want a single chunk to straddle that point. We contract one
            # of the chunks in the middle to address that problem.
            i_range = list(range(0, half_n, inplace_chunk_size))
            initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])]
            after_half = list(range(half_n, n, inplace_chunk_size))
            after_half_offsets = [inplace_chunk_size for _ in after_half]
            combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets)
            for i, offset in combined_range_with_offsets:
                if not z_cache_rotated and i >= half_n:
                    z_cache = flip_z_cache_(z_cache, z)
                    z_cache_rotated = True

                z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim)
                mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim)

                z_chunk_b = z_chunk_b.clone()
                if b_chunk_dim == col_dim:
                    z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
                else:  # b_chunk_dim == row_dim
                    # In this case, the b-dimension (b_chunk_dim) is partially
                    # overwritten at the end of each iteration. We need to
                    # restore the missing component from the z-cache.
                    if not z_cache_rotated:
                        z_chunk_slicer = empty_slicer(z_chunk_b)
                        z_chunk_slicer[col_dim] = slice(0, half_n)
                        z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim)
                    else:
                        z_cache_offset = i - half_n
                        z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim)

                b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False)
                del z_chunk_b

                x_chunk = torch.matmul(a, b_chunk)
                x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
                x_chunk = self.layer_norm_out(x_chunk)
                x_chunk = self.linear_z(x_chunk)

                # The g dimension (col_dim) is parallel to and ahead of the
                # overwrites in z. We can extract the g chunk normally.
                z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
                g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
                g_chunk.sigmoid_()
                del z_chunk_g

                x_chunk *= g_chunk

                # Write the columns into z in-place
                z_slicer = empty_slicer(z)
                z_slicer[col_dim] = slice(i, i + offset)
                if with_add:
                    z[z_slicer] += x_chunk
                else:
                    z[z_slicer] = x_chunk
        else:
            b = compute_projection(z, mask, False, False)
            x = torch.matmul(a, b)
            x = self.layer_norm_out(x)
            x = self.linear_z(x)
            g = self.linear_g(z)
            g.sigmoid_()
            x *= g
            if with_add:
                z += x
            else:
                z = x

        return z

    def forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        _add_with_inplace: bool = False,
        _inplace_chunk_size: Optional[int] = 256,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if inplace_safe:
            x = self._inference_forward(
                z,
                mask,
                inplace_chunk_size=_inplace_chunk_size,
                with_add=_add_with_inplace,
            )
            return x

        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        a = mask
        a = a * self.sigmoid(self.linear_a_g(z))
        a = a * self.linear_a_p(z)
        b = mask
        b = b * self.sigmoid(self.linear_b_g(z))
        b = b * self.linear_b_p(z)

        if is_fp16_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                x = self._combine_projections(a.float(), b.float())
        else:
            x = self._combine_projections(a, b)

        del a, b
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
        x = x * g

        return x
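
# Call-mode sketch (illustrative; `config` and the dimensions are assumptions):
#
#     >>> update = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
#     >>> z = torch.randn(1, 64, 64, config.pairwise_state_dim)
#     >>> out = update(z)                      # standard path: fresh tensors
#     >>> out = update(z, inplace_safe=True,   # chunked path: overwrites z to
#     ...              _add_with_inplace=True) # cap peak memory at ~2.5x z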

class EsmFoldPreTrainedModel(EsmPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Subclass `EsmPreTrainedModel` to deal with special init
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, EsmFoldLinear):
            with torch.no_grad():
                if module.init_fn is not None:
                    module.init_fn(module.weight, module.bias)
                elif module.init == "default":
                    trunc_normal_init_(module.weight, scale=1.0)
                elif module.init == "relu":
                    trunc_normal_init_(module.weight, scale=2.0)
                elif module.init == "glorot":
                    nn.init.xavier_uniform_(module.weight, gain=1)
                elif module.init == "gating":
                    module.weight.fill_(0.0)
                    # `module.bias` is a tensor (or None); test against None rather
                    # than truthiness, which raises for multi-element tensors.
                    if module.bias is not None:
                        module.bias.fill_(1.0)
                elif module.init == "normal":
                    torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
                elif module.init == "final":
                    module.weight.fill_(0.0)
        elif isinstance(module, EsmFoldInvariantPointAttention):
            ipa_point_weights_init_(module.head_weights)
        elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
            torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight)
            torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias)
            torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight)
            torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias)
            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight)
            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias)
            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight)
            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias)

            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight)
            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias)
            torch.nn.init.zeros_(module.pair_to_sequence.linear.weight)
            torch.nn.init.zeros_(module.seq_attention.o_proj.weight)
            torch.nn.init.zeros_(module.seq_attention.o_proj.bias)
            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight)
            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias)
            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight)
            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias)
        else:
            super()._init_weights(module)

class EsmFoldSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_width, gated=False):
        super().__init__()
        assert embed_dim == num_heads * head_width

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_width = head_width

        self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.gated = gated
        if gated:
            self.g_proj = nn.Linear(embed_dim, embed_dim)
            torch.nn.init.zeros_(self.g_proj.weight)
            torch.nn.init.ones_(self.g_proj.bias)

        self.rescale_factor = self.head_width**-0.5

        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, mask=None, bias=None, indices=None):
        """
        Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths,
        use mask.

        Inputs:
            x: batch of input sequences (.. x L x C)
            mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k)
            bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)

        Outputs:
            sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
        """

        t = self.proj(x).view(*x.shape[:2], self.num_heads, -1)
        t = t.permute(0, 2, 1, 3)
        q, k, v = t.chunk(3, dim=-1)

        q = self.rescale_factor * q
        a = torch.einsum("...qc,...kc->...qk", q, k)

        # Add external attention bias.
        if bias is not None:
            a = a + bias.permute(0, 3, 1, 2)

        # Do not attend to padding tokens.
        if mask is not None:
            mask = mask[:, None, None]
            a = a.masked_fill(mask == False, -np.inf)  # noqa: E712

        a = nn.functional.softmax(a, dim=-1)

        y = torch.einsum("...hqk,...hkc->...qhc", a, v)
        y = y.reshape(*y.shape[:2], -1)

        if self.gated:
            y = self.g_proj(x).sigmoid() * y
        y = self.o_proj(y)

        return y, a.permute(0, 3, 1, 2)

class EsmFoldDropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask along a particular dimension.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        super().__init__()

        self.r = r
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        return x * self.dropout(x.new_ones(shape))
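
# Behavior sketch (illustrative): with batch_dim=1 a single dropout mask of
# shape (B, 1, C) is drawn and broadcast across dimension 1, so entire "rows"
# are dropped together rather than independent elements:
#
#     >>> drop = EsmFoldDropout(0.5, batch_dim=1)
#     >>> x = torch.ones(2, 4, 3)
#     >>> _ = drop(x)  # mask shape (2, 1, 3), shared along dim 1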

class EsmFoldSequenceToPair(nn.Module):
    def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
        super().__init__()

        self.layernorm = nn.LayerNorm(sequence_state_dim)
        self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
        self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)

        torch.nn.init.zeros_(self.proj.bias)
        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, sequence_state):
        """
        Inputs:
            sequence_state: B x L x sequence_state_dim

        Output:
            pairwise_state: B x L x L x pairwise_state_dim

        Intermediate state:
            B x L x L x 2*inner_dim
        """

        assert len(sequence_state.shape) == 3

        s = self.layernorm(sequence_state)
        s = self.proj(s)
        q, k = s.chunk(2, dim=-1)

        prod = q[:, None, :, :] * k[:, :, None, :]
        diff = q[:, None, :, :] - k[:, :, None, :]

        x = torch.cat([prod, diff], dim=-1)
        x = self.o_proj(x)

        return x
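
# Shape sketch (illustrative): pair (i, j) features are built from the
# elementwise product and difference of the per-residue q/k projections:
#
#     >>> s2p = EsmFoldSequenceToPair(sequence_state_dim=32, inner_dim=16, pairwise_state_dim=64)
#     >>> s2p(torch.randn(2, 10, 32)).shape
#     torch.Size([2, 10, 10, 64])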

class EsmFoldPairToSequence(nn.Module):
    def __init__(self, pairwise_state_dim, num_heads):
        super().__init__()

        self.layernorm = nn.LayerNorm(pairwise_state_dim)
        self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)

    def forward(self, pairwise_state):
        """
        Inputs:
            pairwise_state: B x L x L x pairwise_state_dim

        Output:
            pairwise_bias: B x L x L x num_heads
        """
        assert len(pairwise_state.shape) == 4
        z = self.layernorm(pairwise_state)
        pairwise_bias = self.linear(z)
        return pairwise_bias

class EsmFoldResidueMLP(nn.Module):
    def __init__(self, embed_dim, inner_dim, dropout=0):
        super().__init__()

        self.mlp = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return x + self.mlp(x)

class EsmFoldTriangularSelfAttentionBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        sequence_state_dim = config.sequence_state_dim
        pairwise_state_dim = config.pairwise_state_dim
        sequence_num_heads = sequence_state_dim // config.sequence_head_width
        pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width

        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim)
        self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads)

        self.seq_attention = EsmFoldSelfAttention(
            sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True
        )
        self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
        self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)

        self.tri_att_start = EsmFoldTriangleAttention(
            pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True
        )
        self.tri_att_end = EsmFoldTriangleAttention(
            pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False
        )

        self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout)
        self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout)

        self.drop = nn.Dropout(config.dropout)
        self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
        self.col_drop = EsmFoldDropout(config.dropout * 2, 1)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """
        Inputs:
            sequence_state: B x L x sequence_state_dim
            pairwise_state: B x L x L x pairwise_state_dim
            mask: B x L boolean tensor of valid positions

        Output:
            sequence_state: B x L x sequence_state_dim
            pairwise_state: B x L x L x pairwise_state_dim
        """
        if len(sequence_state.shape) != 3:
            raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.")
        if len(pairwise_state.shape) != 4:
            raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.")
        if mask is not None and len(mask.shape) != 2:
            raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")

        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pairwise_state_dim = pairwise_state.shape[3]

        if sequence_state_dim != self.config.sequence_state_dim:
            raise ValueError(
                "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got "
                f"{sequence_state_dim} != {self.config.sequence_state_dim}."
            )
        if pairwise_state_dim != self.config.pairwise_state_dim:
            raise ValueError(
                "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got "
                f"{pairwise_state_dim} != {self.config.pairwise_state_dim}."
            )
        if batch_dim != pairwise_state.shape[0]:
            raise ValueError(
                f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != "
                f"{pairwise_state.shape[0]}."
            )
        if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]:
            raise ValueError(
                f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != "
                f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}."
            )

        # Update sequence state
        bias = self.pair_to_sequence(pairwise_state)

        # Self attention with bias + mlp.
        y = self.layernorm_1(sequence_state)
        y, _ = self.seq_attention(y, mask=mask, bias=bias)
        sequence_state = sequence_state + self.drop(y)
        sequence_state = self.mlp_seq(sequence_state)

        # Update pairwise state
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Axial attention with triangular bias.
        tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
        pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask))
        pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask))
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state

class EsmCategoricalMixture:
    def __init__(self, param, bins=50, start=0, end=1):
        # All tensors are of shape ..., bins.
        self.logits = param
        bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
        self.v_bins = (bins[:-1] + bins[1:]) / 2

    def log_prob(self, true):
        # Shapes are:
        #     self.logits: ... x bins
        #     true       : ...
        true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
        nll = self.logits.log_softmax(-1)
        return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)

    def mean(self):
        return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)


def categorical_lddt(logits, bins=50):
    # Logits are ..., 37, bins.
    return EsmCategoricalMixture(logits, bins=bins).mean()
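
# Worked example (illustrative): with uniform logits, the mixture mean is the
# average of the bin centers, i.e. the midpoint of [start, end]:
#
#     >>> logits = torch.zeros(37, 50)          # uniform over 50 bins on [0, 1]
#     >>> EsmCategoricalMixture(logits).mean()  # ~0.5 at every position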

def get_axial_mask(mask):
    """
    Helper to convert B x L mask of valid positions to axial mask used in row column attentions.

    Input:
        mask: B x L tensor of booleans

    Output:
        mask: (B*L) x L tensor of booleans, one row of L validity flags per (batch, row) pair
    """

    if mask is None:
        return None

    if len(mask.shape) != 2:
        raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
    batch_dim, seq_dim = mask.shape
    m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
    m = m.reshape(batch_dim * seq_dim, seq_dim)
    return m

class EsmFoldRelativePosition(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.bins = config.position_bins

        # Note an additional offset is used so that the 0th position
        # is reserved for masked pairs.
        self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim)

    def forward(self, residue_index, mask=None):
        """
        Input:
            residue_index: B x L tensor of indices (dtype=torch.long)
            mask: B x L tensor of booleans

        Output:
            pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
        """
        if residue_index.dtype != torch.long:
            raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
        if mask is not None and residue_index.shape != mask.shape:
            raise ValueError(
                f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
            )

        diff = residue_index[:, None, :] - residue_index[:, :, None]
        diff = diff.clamp(-self.bins, self.bins)
        diff = diff + self.bins + 1  # Add 1 to adjust for padding index.

        if mask is not None:
            mask = mask[:, None, :] * mask[:, :, None]
            diff[mask == False] = 0  # noqa: E712

        output = self.embedding(diff)
        return output
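
# Index arithmetic sketch (illustrative): pairwise offsets are clamped to
# [-bins, bins] and shifted into [1, 2*bins + 1]; index 0 is reserved for
# masked pairs. E.g. with bins=32:
#
#     >>> diff = torch.tensor([[-100, 0, 100]])
#     >>> diff.clamp(-32, 32) + 32 + 1
#     tensor([[ 1, 33, 65]])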

class EsmFoldAngleResnetBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu")
        self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final")

        self.relu = nn.ReLU()

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        s_initial = a

        a = self.relu(a)
        a = self.linear_1(a)
        a = self.relu(a)
        a = self.linear_2(a)

        return a + s_initial

class EsmFoldAngleResnet(nn.Module):
    """
    Implements Algorithm 20, lines 11-14
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
        self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)

        self.layers = nn.ModuleList()
        for _ in range(config.num_resnet_blocks):
            layer = EsmFoldAngleResnetBlock(config)
            self.layers.append(layer)

        self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)

        self.relu = nn.ReLU()

    def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            s:
                [*, C_hidden] single embedding
            s_initial:
                [*, C_hidden] single embedding as of the start of the StructureModule
        Returns:
            [*, no_angles, 2] predicted angles
        """
        # NOTE: The ReLU's applied to the inputs are absent from the supplement
        # pseudocode but present in the source. For maximal compatibility with
        # the pretrained weights, I'm going with the source.

        # [*, C_hidden]
        s_initial = self.relu(s_initial)
        s_initial = self.linear_initial(s_initial)
        s = self.relu(s)
        s = self.linear_in(s)
        s = s + s_initial

        for layer in self.layers:
            s = layer(s)

        s = self.relu(s)

        # [*, no_angles * 2]
        s = self.linear_out(s)

        # [*, no_angles, 2]
        s = s.view(s.shape[:-1] + (-1, 2))

        unnormalized_s = s
        # Normalize each angle's 2-vector onto the unit circle so it can be read
        # as the (sin, cos) representation of the torsion angle.
        norm_denom = torch.sqrt(
            torch.clamp(
                torch.sum(s**2, dim=-1, keepdim=True),
                min=self.config.epsilon,
            )
        )
        s = s / norm_denom

        return unnormalized_s, s
class EsmFoldInvariantPointAttention(nn.Module):
|
||
|
"""
|
||
|
Implements Algorithm 22.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, config):
|
||
|
super().__init__()
|
||
|
self.config = config
|
||
|
|
||
|
c_s = config.sequence_dim
|
||
|
c_z = config.pairwise_dim
|
||
|
self.hidden_dim = config.ipa_dim
|
||
|
self.num_heads = config.num_heads_ipa
|
||
|
self.num_qk_points = config.num_qk_points
|
||
|
self.num_v_points = config.num_v_points
|
||
|
|
||
|
# These linear layers differ from their specifications in the
|
||
|
# supplement. There, they lack bias and use Glorot initialization.
|
||
|
# Here as in the official source, they have bias and use the default
|
||
|
# Lecun initialization.
|
||
|
hc = config.ipa_dim * config.num_heads_ipa
|
||
|
self.linear_q = EsmFoldLinear(c_s, hc)
|
||
|
self.linear_kv = EsmFoldLinear(c_s, 2 * hc)
|
||
|
|
||
|
hpq = config.num_heads_ipa * config.num_qk_points * 3
|
||
|
self.linear_q_points = EsmFoldLinear(c_s, hpq)
|
||
|
|
||
|
hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
|
||
|
self.linear_kv_points = EsmFoldLinear(c_s, hpkv)
|
||
|
|
||
|
self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)
|
||
|
|
||
|
self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa)))
|
||
|
|
||
|
concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
|
||
|
self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")
|
||
|
|
||
|
self.softmax = nn.Softmax(dim=-1)
|
||
|
self.softplus = nn.Softplus()
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
s: torch.Tensor,
|
||
|
z: Optional[torch.Tensor],
|
||
|
r: Rigid,
|
||
|
mask: torch.Tensor,
|
||
|
_offload_inference: bool = False,
|
||
|
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
|
||
|
) -> torch.Tensor:
|
||
|
"""
|
||
|
Args:
|
||
|
s:
|
||
|
[*, N_res, C_s] single representation
|
||
|
z:
|
||
|
[*, N_res, N_res, C_z] pair representation
|
||
|
r:
|
||
|
[*, N_res] transformation object
|
||
|
mask:
|
||
|
[*, N_res] mask
|
||
|
Returns:
|
||
|
[*, N_res, C_s] single representation update
|
||
|
"""
|
||
|
        # When inference offloading is enabled, the pair representation arrives
        # via _z_reference_list instead of z (see EsmFoldStructureModule.forward,
        # which passes z=None in that case).
        if _offload_inference and _z_reference_list is not None:
            z = _z_reference_list
        else:
            z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.num_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.hidden_dim, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3))

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if _offload_inference:
            assert sys.getrefcount(z[0]) == 2
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        if is_fp16_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                a = torch.matmul(
                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
                )
        else:
            a = torch.matmul(
                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
            )

        a *= math.sqrt(1.0 / (3 * self.hidden_dim))
        a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))
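        # Each of the three attention terms (scalar q.k, the pair bias b, and
        # the point-distance term computed below) is weighted by sqrt(1/3) so
        # their summed variance stays comparable, as in AlphaFold2 Algorithm
        # 22; the scalar term also carries the usual 1/sqrt(c) scaling.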

        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        pt_att = pt_att**2

        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1)))
        head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
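        # This folds the AlphaFold2 constants w_C = sqrt(2 / (9 * N_qk)), which
        # normalizes the summed squared point distances, and the per-term
        # weight w_L = sqrt(1 / 3) into a single factor.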
        pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        square_mask = self.config.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        a = a + pt_att
        a = a + square_mask.unsqueeze(-3)
        a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        o_pt = torch.sum(
            (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
            dim=-2,
        )

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2)

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if _offload_inference:
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
        )

        return s


class EsmFoldBackboneUpdate(nn.Module):
    """
    Implements part of Algorithm 23.
    """

    def __init__(self, config):
        super().__init__()

        self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
        Returns:
            [*, N_res, 6] update vector
        """
        # [*, N_res, 6]
        update = self.linear(s)
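        # The six components are consumed by Rigid.compose_q_update_vec in the
        # structure module: the first three parameterize a quaternion update
        # (1, b, c, d) and the last three a translation update (AlphaFold2
        # Algorithm 23).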

        return update


class EsmFoldStructureModuleTransitionLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
        self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
        self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final")

        self.relu = nn.ReLU()

    def forward(self, s):
        s_initial = s
        s = self.linear_1(s)
        s = self.relu(s)
        s = self.linear_2(s)
        s = self.relu(s)
        s = self.linear_3(s)

        s = s + s_initial

        return s


class EsmFoldStructureModuleTransition(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList()
        for _ in range(config.num_transition_layers):
            layer = EsmFoldStructureModuleTransitionLayer(config)
            self.layers.append(layer)

        self.dropout = nn.Dropout(config.dropout_rate)
        self.layer_norm = LayerNorm(config.sequence_dim)

    def forward(self, s):
        for layer in self.layers:
            s = layer(s)

        s = self.dropout(s)
        s = self.layer_norm(s)

        return s


class EsmFoldStructureModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Buffers to be lazily initialized later
        # self.default_frames
        # self.group_idx
        # self.atom_mask
        # self.lit_positions

        self.layer_norm_s = LayerNorm(config.sequence_dim)
        self.layer_norm_z = LayerNorm(config.pairwise_dim)

        self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)

        self.ipa = EsmFoldInvariantPointAttention(config)

        self.ipa_dropout = nn.Dropout(config.dropout_rate)
        self.layer_norm_ipa = LayerNorm(config.sequence_dim)

        self.transition = EsmFoldStructureModuleTransition(config)
        self.bb_update = EsmFoldBackboneUpdate(config)
        self.angle_resnet = EsmFoldAngleResnet(config)

    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        _offload_inference=False,
    ):
        """
        Args:
            evoformer_output_dict:
                Dictionary containing:
                    "single":
                        [*, N_res, C_s] single representation
                    "pair":
                        [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
                Optional [*, N_res] sequence mask
        Returns:
            A dictionary of outputs
        """
        s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(evoformer_output_dict["pair"])

        z_reference_list = None
        if _offload_inference:
            assert sys.getrefcount(evoformer_output_dict["pair"]) == 2
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
            z_reference_list = [z]
            z = None

        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)

        # [*, N]
        rigids = Rigid.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            self.training,
            fmt="quat",
        )
|
||
|
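        # All residue frames start at the identity transform ("black-hole
        # initialization" in AlphaFold2) and are iteratively refined over
        # num_blocks passes below.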
        outputs = []
        for i in range(self.config.num_blocks):
            # [*, N, C_s]
            s = s + self.ipa(
                s,
                z,
                rigids,
                mask,
                _offload_inference=_offload_inference,
                _z_reference_list=z_reference_list,
            )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N]
            rigids = rigids.compose_q_update_vec(self.bb_update(s))

            # To hew as closely as possible to AlphaFold, we convert our
            # quaternion-based transformations to rotation-matrix ones
            # here
            backb_to_global = Rigid(
                Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None),
                rigids.get_trans(),
            )

            backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor)

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype)

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype)

            scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor)

            preds = {
                "frames": scaled_rigids.to_tensor_7(),
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
                "states": s,
            }

            outputs.append(preds)

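            # As in AlphaFold2, gradients are not propagated through the
            # rotation part of the frames between iterations, which stabilizes
            # training of the structure module.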
            rigids = rigids.stop_rot_gradient()

        del z, z_reference_list

        if _offload_inference:
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs

    def _init_residue_constants(self, float_dtype, device):
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    residue_constants.restype_rigid_group_default_frame,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "group_idx"):
            self.register_buffer(
                "group_idx",
                torch.tensor(
                    residue_constants.restype_atom14_to_rigid_group,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "atom_mask"):
            self.register_buffer(
                "atom_mask",
                torch.tensor(
                    residue_constants.restype_atom14_mask,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.tensor(
                    residue_constants.restype_atom14_rigid_group_positions,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    def torsion_angles_to_frames(self, r, alpha, f):
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(alpha.dtype, alpha.device)
        # Separated purely to make testing less annoying
        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(self, r, f):  # [*, N, 8]  # [*, N]
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
        return frames_and_literature_positions_to_atom14_pos(
            r,
            f,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )


class EsmFoldingTrunk(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        c_s = config.sequence_state_dim
        c_z = config.pairwise_state_dim

        self.pairwise_positional_embedding = EsmFoldRelativePosition(config)

        self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])

        self.recycle_bins = 15
        self.recycle_s_norm = nn.LayerNorm(c_s)
        self.recycle_z_norm = nn.LayerNorm(c_z)
        self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
        self.recycle_disto.weight[0].detach().zero_()
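        # Zeroing the bin-0 embedding means the first pass, where recycle_bins
        # is all zeros, adds nothing to the pair state.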

        self.structure_module = EsmFoldStructureModule(config.structure_module)
        self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
        self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)

        self.chunk_size = config.chunk_size

    def set_chunk_size(self, chunk_size):
        # This parameter means the axial attention will be computed
        # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
        # It's equivalent to running a for loop over chunks of the dimension we're iterating over,
        # where chunk_size is the size of each chunk, so 128 would mean processing the dimension in
        # chunks of length 128.
        self.chunk_size = chunk_size
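        # A minimal usage sketch (assuming a loaded EsmForProteinFolding
        # instance named `model`): `model.trunk.set_chunk_size(64)` trades
        # speed for memory on long sequences, and
        # `model.trunk.set_chunk_size(None)` disables chunking again.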

    def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
        """
        Inputs:
            seq_feats: B x L x C tensor of sequence features
            pair_feats: B x L x L x C tensor of pair features
            residx: B x L long tensor giving the position in the sequence
            mask: B x L boolean tensor indicating valid residues

        Output:
            structure: a dictionary of structure-module outputs (frames, positions, states, s_s, s_z, ...)
        """

        device = seq_feats.device
        s_s_0 = seq_feats
        s_z_0 = pair_feats

        if no_recycles is None:
            no_recycles = self.config.max_recycles
        else:
            if no_recycles < 0:
                raise ValueError("Number of recycles must not be negative.")
            no_recycles += 1  # First 'recycle' is just the standard forward pass through the model.

        def trunk_iter(s, z, residx, mask):
            z = z + self.pairwise_positional_embedding(residx, mask=mask)

            for block in self.blocks:
                s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
            return s, z

        s_s = s_s_0
        s_z = s_z_0
        recycle_s = torch.zeros_like(s_s)
        recycle_z = torch.zeros_like(s_z)
        recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

        for recycle_idx in range(no_recycles):
            with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
                # === Recycling ===
                recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
                recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
                recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)

                s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)

                # === Structure module ===
                structure = self.structure_module(
                    {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
                    true_aa,
                    mask.float(),
                )

                recycle_s = s_s
                recycle_z = s_z
                # Distogram needs the N, CA, C coordinates, and bin constants same as AlphaFold.
                recycle_bins = EsmFoldingTrunk.distogram(
                    structure["positions"][-1][:, :, :3],
                    3.375,
                    21.375,
                    self.recycle_bins,
                )

        structure["s_s"] = s_s
        structure["s_z"] = s_z

        return structure

    @staticmethod
    def distogram(coords, min_bin, max_bin, num_bins):
        # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
        boundaries = torch.linspace(
            min_bin,
            max_bin,
            num_bins - 1,
            device=coords.device,
        )
        boundaries = boundaries**2
        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
        # Infer CB coordinates.
        b = CA - N
        c = C - CA
        a = b.cross(c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
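        # The constants above place a virtual C-beta atom from the N, CA, C
        # frame; a virtual CB is standard for distograms since glycine has no
        # real C-beta.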
        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
        bins = torch.sum(dists > boundaries, dim=-1)  # [..., L, L]
        return bins


# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare
# the outputs for downstream use.


@add_start_docstrings(
    """
    ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
    by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
    the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
    protein(s).
    """,
    ESM_START_DOCSTRING,
)
class EsmForProteinFolding(EsmPreTrainedModel):
    _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]

|
||
|
    def __init__(self, config):
        super().__init__(config)

        self.config = config

        self.distogram_bins = 64

        self.esm = EsmModel(config, add_pooling_layer=False)

        self.esm.requires_grad_(False)
        if self.config.esmfold_config.fp16_esm:
            self.esm.half()

        self.esm_feats = self.config.hidden_size
        self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads
        self.esm_layers = self.config.num_hidden_layers
        self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))
        self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))
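        # One learned logit per ESM-2 hidden state (all layers plus the input
        # embedding); forward softmaxes these to average the stacked states.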

        trunk_config = self.config.esmfold_config.trunk
        c_s = trunk_config.sequence_state_dim
        c_z = trunk_config.pairwise_state_dim
        self.esm_s_mlp = nn.Sequential(
            LayerNorm(self.esm_feats),
            nn.Linear(self.esm_feats, c_s),
            nn.ReLU(),
            nn.Linear(c_s, c_s),
        )

        # 0 is padding, N is unknown residues, N + 1 is mask.
        self.n_tokens_embed = residue_constants.restype_num + 3
        self.pad_idx = 0
        self.unk_idx = self.n_tokens_embed - 2
        self.mask_idx = self.n_tokens_embed - 1
        self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>")
        self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>")
        self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>")
        self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>")
        if self.config.esmfold_config.embed_aa:
            self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)

        self.trunk = EsmFoldingTrunk(trunk_config)

        self.distogram_head = nn.Linear(c_z, self.distogram_bins)
        self.ptm_head = nn.Linear(c_z, self.distogram_bins)
        self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
        self.lddt_bins = 50
        structure_module_config = trunk_config.structure_module
        self.lddt_head = nn.Sequential(
            nn.LayerNorm(structure_module_config.sequence_dim),
            nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim),
            nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
            nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
        )

    @staticmethod
    def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:
        # Remember that the ESM-side ordering is shifted from residue_constants by 1 (0 is padding).
        esm_reorder = [vocab_list.index("<pad>")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x]
        return torch.tensor(esm_reorder)

    @add_start_docstrings_to_model_forward(ESMFOLD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=EsmForProteinFoldingOutput, config_class=EsmConfig)
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        masking_pattern: Optional[torch.Tensor] = None,
        num_recycles: Optional[int] = None,
    ) -> EsmForProteinFoldingOutput:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, EsmForProteinFolding

        >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
        >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False)  # A tiny random peptide
        >>> outputs = model(**inputs)
        >>> folded_positions = outputs.positions
        ```

        """
        cfg = self.config.esmfold_config

        aa = input_ids  # B x L
        B = aa.shape[0]
        L = aa.shape[1]
        device = input_ids.device
        if attention_mask is None:
            attention_mask = torch.ones_like(aa, device=device)
        if position_ids is None:
            position_ids = torch.arange(L, device=device).expand_as(input_ids)

        # === ESM ===
        esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)

        if masking_pattern is not None:
            masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
        else:
            masked_aa = aa
            mlm_targets = None

        # We get sequence and pair representations from whatever version of ESM /
        # configuration we are using. The sequence representation esm_s is always
        # present. The pair embedding esm_z may be present depending on the
        # configuration of the model. If esm_z is not used by the model then it
        # is returned as None here.
        esm_s = self.compute_language_model_representations(esmaa)

        # Convert esm_s and esm_z, if present, to the precision used by the trunk and
        # the structure module. These tensors may be a lower precision if, for example,
        # we're running the language model in fp16 precision.
        esm_s = esm_s.to(self.esm_s_combine.dtype)

        if cfg.esm_ablate_sequence:
            esm_s = esm_s * 0

        esm_s = esm_s.detach()

        # === preprocessing ===
        esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
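        # The learned softmax over esm_s_combine collapses the stacked ESM
        # hidden states [B, L, num_layers + 1, C] into a single [B, L, C]
        # embedding.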
        s_s_0 = self.esm_s_mlp(esm_s)

        s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)

        if self.config.esmfold_config.embed_aa:
            s_s_0 += self.embedding(masked_aa)

        structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles)
        # Documenting what we expect:
        structure = {
            k: v
            for k, v in structure.items()
            if k
            in [
                "s_z",
                "s_s",
                "frames",
                "sidechain_frames",
                "unnormalized_angles",
                "angles",
                "positions",
                "states",
            ]
        }

        # Add BERT mask for the loss to use, if available.
        if mlm_targets is not None:
            structure["mlm_targets"] = mlm_targets

        disto_logits = self.distogram_head(structure["s_z"])
        disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
        structure["distogram_logits"] = disto_logits

        lm_logits = self.lm_head(structure["s_s"])
        structure["lm_logits"] = lm_logits

structure["aatype"] = aa
|
||
|
make_atom14_masks(structure)
|
||
|
# Of course, this doesn't respect the true mask because it doesn't know about it...
|
||
|
# We're not going to properly mask change of index tensors:
|
||
|
# "residx_atom14_to_atom37",
|
||
|
# "residx_atom37_to_atom14",
|
||
|
for k in [
|
||
|
"atom14_atom_exists",
|
||
|
"atom37_atom_exists",
|
||
|
]:
|
||
|
structure[k] *= attention_mask.unsqueeze(-1)
|
||
|
structure["residue_index"] = position_ids
|
||
|
|
||
|
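        # structure["states"] stacks the single representation from every
        # structure-module block; the head predicts a 50-bin lDDT distribution
        # per atom37 slot, and pLDDT is read from the final block.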
        lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins)
        structure["lddt_head"] = lddt_head
        plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
        structure["plddt"] = plddt

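        # The pTM head reuses the distogram bin count: compute_tm converts the
        # binned pairwise-error distribution into a predicted TM-score, and
        # compute_predicted_aligned_error adds the full PAE outputs.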
        ptm_logits = self.ptm_head(structure["s_z"])
        structure["ptm_logits"] = ptm_logits
        structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
        structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins))

        return EsmForProteinFoldingOutput(**structure)

    def af2_idx_to_esm_idx(self, aa, mask):
        # avoid indexing on different devices
        if self.af2_to_esm.device != aa.device:
            self.af2_to_esm = self.af2_to_esm.to(aa.device)
        aa = (aa + 1).masked_fill(mask != 1, 0)
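        # The +1 shift reserves index 0 for padding (matching the
        # [<pad>] + restypes_with_x order used to build af2_to_esm), and masked
        # positions are zeroed before the lookup.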
        return self.af2_to_esm[aa]

    def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
        device = next(self.parameters()).device
        B, L = esmaa.shape  # B = batch size, L = sequence length.

        if self.config.esmfold_config.bypass_lm:
            # [B, L, num_layers + 1, C], matching the stacked hidden states below.
            esm_s = torch.zeros(B, L, self.esm_s_combine.size(0), self.esm_feats, device=device)
            return esm_s

        bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx
        bos = esmaa.new_full((B, 1), bosi)
        eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx)
        esmaa = torch.cat([bos, esmaa, eos], dim=1)
        # Use the first padding index as eos during inference.
        esmaa[range(B), (esmaa != 1).sum(1)] = eosi

        # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map)
        # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models,
        # esm_z is always None
        esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
        esm_s = torch.stack(esm_hidden_states, dim=2)

        esm_s = esm_s[:, 1:-1]  # B, L, nLayers, C

        return esm_s

    def bert_mask(self, aa, esmaa, mask, pattern):
        new_aa = aa.clone()
        target = aa.clone()
        new_esmaa = esmaa.clone()
        new_aa[pattern == 1] = self.mask_idx
        target[pattern != 1] = 0
        new_esmaa[pattern == 1] = self.esm_dict_mask_idx
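        # Positions that were not masked are zeroed in the target (0 is the
        # padding index) so a downstream MLM loss can ignore them.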
        return new_aa, new_esmaa, target

    @torch.no_grad()
    def infer(
        self,
        seqs: Union[str, List[str]],
        position_ids=None,
    ):
        if isinstance(seqs, str):
            lst = [seqs]
        else:
            lst = seqs
        # Returns the raw outputs of the model given an input sequence.
        device = next(self.parameters()).device
        aatype = collate_dense_tensors(
            [
                torch.from_numpy(
                    residue_constants.sequence_to_onehot(
                        sequence=seq,
                        mapping=residue_constants.restype_order_with_x,
                        map_unknown_to_x=True,
                    )
                )
                .to(device)
                .argmax(dim=1)
                for seq in lst
            ]
        )  # B=1 x L
        mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
        position_ids = (
            torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
            if position_ids is None
            else position_ids.to(device)
        )
        if position_ids.ndim == 1:
            position_ids = position_ids.unsqueeze(0)
        return self.forward(
            aatype,
            mask,
            position_ids=position_ids,
        )

    @staticmethod
    def output_to_pdb(output: Dict) -> List[str]:
        """Returns the PDB (file) strings from the model given the model output."""
        output = {k: v.to("cpu").numpy() for k, v in output.items()}
        pdbs = []
        final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
        final_atom_mask = output["atom37_atom_exists"]
        for i in range(output["aatype"].shape[0]):
            aa = output["aatype"][i]
            pred_pos = final_atom_positions[i]
            mask = final_atom_mask[i]
            resid = output["residue_index"][i] + 1
            pred = OFProtein(
                aatype=aa,
                atom_positions=pred_pos,
                atom_mask=mask,
                residue_index=resid,
                b_factors=output["plddt"][i],
            )
            pdbs.append(to_pdb(pred))
        return pdbs

    def infer_pdb(self, seqs, *args, **kwargs) -> str:
        """Returns the PDB (file) string from the model given an input sequence."""
        assert isinstance(seqs, str)
        output = self.infer(seqs, *args, **kwargs)
        return self.output_to_pdb(output)[0]

    def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]:
        """Returns the PDB (file) strings from the model given input sequences."""
        output = self.infer(seqs, *args, **kwargs)
        return self.output_to_pdb(output)
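
# A minimal end-to-end usage sketch (the sequence below is a hypothetical
# example peptide):
#
#     from transformers import EsmForProteinFolding
#
#     model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
#     pdb_string = model.infer_pdb("MLKNVQVQLV")
#     with open("prediction.pdb", "w") as f:
#         f.write(pdb_string)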