ai-content-maker/.venv/Lib/site-packages/transformers/models/esm/modeling_esmfold.py

2323 lines
85 KiB
Python

# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch.nn import LayerNorm
from ...integrations.deepspeed import is_deepspeed_available
from ...modeling_outputs import ModelOutput
from ...utils import (
ContextManagers,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
logging,
replace_return_docstrings,
)
from .configuration_esm import EsmConfig
from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel
from .openfold_utils import (
OFProtein,
Rigid,
Rotation,
atom14_to_atom37,
chunk_layer,
compute_predicted_aligned_error,
compute_tm,
frames_and_literature_positions_to_atom14_pos,
make_atom14_masks,
residue_constants,
to_pdb,
torsion_angles_to_frames,
)
# Module-level logger, created through the library's own `logging` helper.
logger = logging.get_logger(__name__)
# Checkpoint / config identifiers; presumably consumed by the docstring decorators imported above — TODO confirm.
_CHECKPOINT_FOR_DOC = "facebook/esmfold_v1"
_CONFIG_FOR_DOC = "EsmConfig"
@dataclass
class EsmForProteinFoldingOutput(ModelOutput):
    """
    Output container for the ESMFold protein folding model. Every field defaults to `None`.

    Args:
        frames (`torch.FloatTensor`):
            Output frames.
        sidechain_frames (`torch.FloatTensor`):
            Output sidechain frames.
        unnormalized_angles (`torch.FloatTensor`):
            Predicted unnormalized backbone and side chain torsion angles.
        angles (`torch.FloatTensor`):
            Predicted backbone and side chain torsion angles.
        positions (`torch.FloatTensor`):
            Predicted positions of the backbone and side chain atoms.
        states (`torch.FloatTensor`):
            Hidden states from the protein folding trunk.
        s_s (`torch.FloatTensor`):
            Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem.
        s_z (`torch.FloatTensor`):
            Pairwise residue embeddings.
        distogram_logits (`torch.FloatTensor`):
            Input logits to the distogram used to compute residue distances.
        lm_logits (`torch.FloatTensor`):
            Logits output by the ESM-2 protein language model stem.
        aatype (`torch.FloatTensor`):
            Input amino acids (AlphaFold2 indices).
        atom14_atom_exists (`torch.FloatTensor`):
            Whether each atom exists in the atom14 representation.
        residx_atom14_to_atom37 (`torch.FloatTensor`):
            Mapping between atoms in the atom14 and atom37 representations.
        residx_atom37_to_atom14 (`torch.FloatTensor`):
            Mapping between atoms in the atom37 and atom14 representations.
        atom37_atom_exists (`torch.FloatTensor`):
            Whether each atom exists in the atom37 representation.
        residue_index (`torch.FloatTensor`):
            The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be
            a sequence of integers from 0 to `sequence_length`.
        lddt_head (`torch.FloatTensor`):
            Raw outputs from the lddt head used to compute plddt.
        plddt (`torch.FloatTensor`):
            Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is
            uncertain, or where the protein structure is disordered.
        ptm_logits (`torch.FloatTensor`):
            Raw logits used for computing ptm.
        ptm (`torch.FloatTensor`):
            TM-score output representing the model's high-level confidence in the overall structure.
        aligned_confidence_probs (`torch.FloatTensor`):
            Per-residue confidence scores for the aligned structure.
        predicted_aligned_error (`torch.FloatTensor`):
            Predicted error between the model's prediction and the ground truth.
        max_predicted_aligned_error (`torch.FloatTensor`):
            Per-sample maximum predicted error.
    """

    # NOTE(review): every field is annotated `torch.FloatTensor`, including index-like ones such as `aatype` and
    # `residue_index`; confirm the actual dtypes against the code that fills this container.
    frames: torch.FloatTensor = None
    sidechain_frames: torch.FloatTensor = None
    unnormalized_angles: torch.FloatTensor = None
    angles: torch.FloatTensor = None
    positions: torch.FloatTensor = None
    states: torch.FloatTensor = None
    s_s: torch.FloatTensor = None
    s_z: torch.FloatTensor = None
    distogram_logits: torch.FloatTensor = None
    lm_logits: torch.FloatTensor = None
    aatype: torch.FloatTensor = None
    atom14_atom_exists: torch.FloatTensor = None
    residx_atom14_to_atom37: torch.FloatTensor = None
    residx_atom37_to_atom14: torch.FloatTensor = None
    atom37_atom_exists: torch.FloatTensor = None
    residue_index: torch.FloatTensor = None
    lddt_head: torch.FloatTensor = None
    plddt: torch.FloatTensor = None
    ptm_logits: torch.FloatTensor = None
    ptm: torch.FloatTensor = None
    aligned_confidence_probs: torch.FloatTensor = None
    predicted_aligned_error: torch.FloatTensor = None
    max_predicted_aligned_error: torch.FloatTensor = None
# Argument documentation template for the model inputs. `{0}` is a positional placeholder for the input shape;
# presumably it is filled in by the docstring decorators imported above — TODO confirm.
# NOTE(review): the `num_recycles` entry refers to both `config.num_recycles` and `config.max_recycles`;
# confirm which attribute the config actually defines before relying on this text.
ESMFOLD_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*):
Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
num_recycles (`int`, *optional*, defaults to `None`):
Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling"
consists of passing the output of the folding trunk back in as input to the trunk. During training, the
number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
after each recycle. During inference, num_recycles should be set to the highest value that the model was
trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
used.
"""
def is_fp16_enabled():
    """
    Tell whether autocast is currently enabled and its GPU target dtype is `torch.float16`.

    Returns:
        `bool`: `False` as soon as the autocast GPU dtype is not float16; otherwise whatever
        `torch.is_autocast_enabled()` reports.
    """
    # Autocast world
    return torch.get_autocast_gpu_dtype() == torch.float16 and torch.is_autocast_enabled()
def is_deepspeed_initialized():
    """
    Return `True` only when DeepSpeed is installed and reports that it has been initialized.

    The previous version had the availability test inverted: it returned `False` whenever DeepSpeed *was*
    available and only attempted the import when it was not, so the function could never return `True`.

    Returns:
        `bool`: `False` if DeepSpeed is not available, if importing it fails, or if the installed version does not
        expose `deepspeed.utils.is_initialized`; otherwise the value returned by that function.
    """
    if not is_deepspeed_available():
        return False
    try:
        import deepspeed

        # This is not available in all DeepSpeed versions.
        return deepspeed.utils.is_initialized()
    except Exception:
        # Deliberate best-effort: any failure while probing DeepSpeed is treated as "not initialized".
        return False
def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
    """
    Stack tensors of equal rank but possibly different sizes into one padded batch tensor.

    Given N tensors with shapes (d_11, ..., d_1K), ..., (d_N1, ..., d_NK), the result has shape
    (N, max_i d_i1, ..., max_i d_iK). Each sample is written into the leading corner of its slot and the remainder
    of the slot holds `pad_v`.

    Args:
        samples: Tensors to collate. They must share the same number of dimensions and the same device.
        pad_v: Fill value for the padded region.

    Returns:
        The padded batch, using the dtype of the first sample; an empty `torch.Tensor()` when `samples` is empty.

    Raises:
        RuntimeError: If the samples do not all have the same number of dimensions.
    """
    if len(samples) == 0:
        return torch.Tensor()
    ranks = [sample.dim() for sample in samples]
    if len(set(ranks)) != 1:
        raise RuntimeError(f"Samples has varying dimensions: {ranks}")
    # Unpacking a one-element tuple fails loudly if the samples live on more than one device.
    (device,) = tuple({sample.device for sample in samples})  # assumes all on same device
    padded_shape = [max(sizes) for sizes in zip(*(sample.shape for sample in samples))]
    batch = torch.empty(len(samples), *padded_shape, dtype=samples[0].dtype, device=device)
    batch.fill_(pad_v)
    for idx, sample in enumerate(samples):
        # Address slot `idx` and, inside it, the region that matches this sample's own extent.
        region = (idx,) + tuple(slice(0, size) for size in sample.shape)
        batch[region] = sample
    return batch
def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Reshape `t` so that its last `no_dims` dimensions are merged into a single trailing dimension."""
    return t.reshape(*t.shape[:-no_dims], -1)
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """
    Permute only the trailing `len(inds)` dimensions of `tensor`, leaving the leading dimensions in place.

    Args:
        tensor: Tensor whose trailing dimensions are reordered.
        inds: New order of the trailing dimensions, given as indices relative to the first trailing dimension.

    Returns:
        The permuted tensor (as produced by `Tensor.permute`).
    """
    n_trailing = len(inds)
    n_leading = len(tensor.shape[: -1 * n_trailing])
    order = list(range(n_leading))
    # Trailing dimensions are addressed with negative indices so the leading ones need no offset.
    order.extend(idx - n_trailing for idx in inds)
    return tensor.permute(order)
def dict_multimap(fn, dicts):
    """
    Combine a list of identically-keyed (possibly nested) dicts into one dict by applying `fn` per key.

    Args:
        fn: Callable that receives the list of values found under one key across all dicts.
        dicts: Non-empty sequence of dicts. The keys and nesting of the first dict drive the traversal.

    Returns:
        A new dict with the keys of `dicts[0]`. Nested dicts are handled recursively.
    """
    template = dicts[0]
    merged = {}
    for key, value in template.items():
        gathered = [entry[key] for entry in dicts]
        merged[key] = dict_multimap(fn, gathered) if isinstance(value, dict) else fn(gathered)
    return merged
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """
    Fill `weights` in place with a normal distribution truncated at two standard deviations.

    The variance is `scale / max(1, weights.shape[1])`, i.e. scaled by the size of the second dimension.

    Args:
        weights: Tensor with at least two dimensions; modified in place.
        scale: Numerator of the variance scaling.
        fan: Unused by this implementation; kept for interface compatibility.

    When scipy is available, samples are drawn from `scipy.stats.truncnorm` with the standard deviation corrected
    for the truncation. Otherwise a warning is logged and an approximation is used: a plain normal draw whose
    values are clamped to `[-2 * std, 2 * std]`. The previous fallback called the out-of-place `clamp` and threw
    the result away (and used `0.0` as lower bound), so no truncation was applied at all.
    """
    shape = weights.shape
    scale = scale / max(1, shape[1])

    if not is_scipy_available():
        logger.warning(
            "This init requires scipy, but scipy was not found, default to an approximation that might not be"
            " equivalent."
        )
        std = math.sqrt(scale)
        # In-place ops on a parameter that requires grad are only legal without autograd tracking.
        with torch.no_grad():
            torch.nn.init.normal_(weights, std=std).clamp_(min=-2.0 * std, max=2.0 * std)
    else:
        from scipy.stats import truncnorm

        # Dividing by the std of the unit truncated normal makes the truncated samples have std sqrt(scale).
        std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
        samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel())
        samples = np.reshape(samples, shape)
        with torch.no_grad():
            weights.copy_(torch.tensor(samples, device=weights.device))
def ipa_point_weights_init_(weights):
    """Overwrite every entry of `weights`, in place and without gradient tracking, with a fixed constant."""
    # Named `softplus_inverse_1` in the original code, i.e. presumably the x for which softplus(x) == 1.
    softplus_inverse_1 = 0.541324854612918
    with torch.no_grad():
        weights.fill_(softplus_inverse_1)
class EsmFoldLinear(nn.Linear):
    """
    `torch.nn.Linear` subclass that remembers which initialization scheme it should receive.

    It is called exactly like `torch.nn.Linear`. The constructor itself only zeroes the bias and stores the
    requested scheme in `self.init` / `self.init_fn`; it does not apply any of the schemes to the weight.
    Implements the initializers in 1.11.4, plus some additional ones found in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                - "default": LeCun fan-in truncated normal initialization
                - "relu": He initialization w/ truncated normal distribution
                - "glorot": Fan-average Glorot uniform initialization
                - "gating": Weights=0, Bias=1
                - "normal": Normal initialization with std=1/sqrt(fan_in)
                - "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs. Overrides init if not None.

        Raises:
            ValueError: If `init` is not one of the names listed above.
        """
        super().__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        self.init = init
        self.init_fn = init_fn

        known_inits = ("default", "relu", "glorot", "gating", "normal", "final")
        if init not in known_inits:
            raise ValueError("Invalid init string.")
class EsmFoldLayerNorm(nn.Module):
    """Layer normalization over the last dimension with a dedicated code path for bfloat16 inputs."""

    def __init__(self, c_in, eps=1e-5):
        """
        Args:
            c_in: Size of the normalized (last) dimension.
            eps: Epsilon forwarded to `torch.nn.functional.layer_norm`.
        """
        super().__init__()

        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        """Normalize `x` over its last dimension and apply the learned scale and shift."""
        input_dtype = x.dtype
        # The DeepSpeed probe is only evaluated for bfloat16 inputs (short-circuit).
        manual_bf16 = input_dtype is torch.bfloat16 and not is_deepspeed_initialized()
        if not manual_bf16:
            return nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)

        # bfloat16 without DeepSpeed: switch autocast off and cast the parameters to the input dtype explicitly.
        with torch.cuda.amp.autocast(enabled=False):
            return nn.functional.layer_norm(
                x, self.c_in, self.weight.to(dtype=input_dtype), self.bias.to(dtype=input_dtype), self.eps
            )
@torch.jit.ignore
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of type bfloat16

    For bfloat16 inputs (when DeepSpeed is not initialized) the softmax runs with autocast disabled; every other
    input goes through a plain `torch.nn.functional.softmax` call.
    """
    # The DeepSpeed probe is only evaluated for bfloat16 inputs (short-circuit).
    bypass_autocast = t.dtype is torch.bfloat16 and not is_deepspeed_initialized()
    if not bypass_autocast:
        return torch.nn.functional.softmax(t, dim=dim)

    with torch.cuda.amp.autocast(enabled=False):
        return torch.nn.functional.softmax(t, dim=dim)
class EsmFoldAttention(nn.Module):
    """
    Multi-head attention assembled from `EsmFoldLinear` projections, with an optional sigmoid output gate and
    support for any number of additive attention biases.
    """

    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super().__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.
        # NOTE(review): the projections below are `c_hidden * no_heads` wide, which reads as c_hidden being
        # per-head; confirm what the comment above is meant to say.
        width = self.c_hidden * self.no_heads
        self.linear_q = EsmFoldLinear(self.c_q, width, bias=False, init="glorot")
        self.linear_k = EsmFoldLinear(self.c_k, width, bias=False, init="glorot")
        self.linear_v = EsmFoldLinear(self.c_v, width, bias=False, init="glorot")
        self.linear_o = EsmFoldLinear(width, self.c_q, init="final")

        self.linear_g = EsmFoldLinear(self.c_q, width, init="gating") if self.gating else None

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project the inputs to queries, keys and values, split them per head and scale the queries."""

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # [*, Q/K/V, H * C_hidden] -> [*, Q/K, H, C_hidden] -> [*, H, Q/K, C_hidden]
            per_head = projected.view(projected.shape[:-1] + (self.no_heads, -1))
            return per_head.transpose(-2, -3)

        q = split_heads(self.linear_q(q_x))
        k = split_heads(self.linear_k(kv_x))
        v = split_heads(self.linear_v(kv_x))

        # Scale the queries in place by 1 / sqrt(C_hidden).
        q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
        """Apply the optional query-derived gate, merge the heads and run the output projection."""
        if self.linear_g is not None:
            gate = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            gate = gate.view(gate.shape[:-1] + (self.no_heads, -1))
            o = o * gate

        # [*, Q, H * C_hidden]
        merged = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        return self.linear_o(merged)

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        lma_q_chunk_size: int = 1024,
        lma_kv_chunk_size: int = 4096,
        use_flash: bool = False,
        flash_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_memory_efficient_kernel:
                Whether to use a custom memory-efficient attention kernel. This should be the default choice for most.
                If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
            use_lma:
                Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
                stock PyTorch implementation is used instead
            lma_q_chunk_size:
                Query chunk size (for LMA)
            lma_kv_chunk_size:
                Key/Value chunk size (for LMA)
            use_flash:
                Validated against `biases` and the other flags only.
            flash_mask:
                Not read by this method.
        Returns
            [*, Q, C_q] attention update

        Raises:
            ValueError: If `use_lma` is set without both chunk sizes, if `use_flash` is combined with `biases`, or
            if more than one of the `use_<...>` flags is set.

        In this implementation the `use_<...>` flags are only validated; the computation below is always the stock
        PyTorch path.
        """
        lma_chunks_missing = lma_q_chunk_size is None or lma_kv_chunk_size is None
        if use_lma and lma_chunks_missing:
            raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")

        if use_flash and biases is not None:
            raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")

        if sum([use_memory_efficient_kernel, use_lma, use_flash]) > 1:
            raise ValueError("Choose at most one alternative attention algorithm")

        extra_biases = [] if biases is None else biases

        # [*, H, Q/K, C_hidden]
        query, key, value = self._prep_qkv(q_x, kv_x)
        key_t = permute_final_dims(key, (1, 0))

        # [*, H, Q, K]
        scores = torch.matmul(query, key_t)
        for bias in extra_biases:
            scores += bias
        weights = softmax_no_cast(scores, -1)

        # [*, H, Q, C_hidden]
        context = torch.matmul(weights, value)
        # [*, Q, H, C_hidden]
        context = context.transpose(-2, -3)

        return self._wrap_up(context, q_x)
class EsmFoldTriangleAttention(nn.Module):
    """
    Attention over a pair tensor with a mask bias and a learned per-head bias computed from the same tensor.

    With `starting=False` the two pair dimensions are swapped before the attention and swapped back afterwards.
    """

    def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                If False, the input (and mask) are transposed over the two pair dimensions around the attention.
            inf:
                Magnitude of the negative bias applied where `mask` is 0.
        """
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """Run the attention through `chunk_layer`; when `inplace_safe` is set, `x` is passed as output buffer."""
        attention = partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma)
        attention_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }
        out_buffer = x if inplace_safe else None

        return chunk_layer(
            attention,
            attention_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
            _out=out_buffer,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
            mask:
                [*, I, J] mask; defaults to all ones.
            chunk_size:
                If given, the attention is evaluated through `chunk_layer` with this chunk size.
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(x.shape[:-1])

        swap_pair_dims = not self.starting
        if swap_pair_dims:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J] -> [*, 1, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)).unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is None:
            x = self.mha(
                q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
            )
        else:
            x = self._chunk(
                x,
                biases,
                chunk_size,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
            )

        if swap_pair_dims:
            x = x.transpose(-2, -3)

        return x
class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
    """
    Implements Algorithms 11 and 12.

    `_outgoing` selects which of the two contraction patterns is used (see `_combine_projections`).
    """

    def __init__(self, config, _outgoing=True):
        """
        Args:
            config: Object providing `pairwise_state_dim`, used as the channel size of every sub-layer.
            _outgoing: Selects the contraction pattern.
        """
        super().__init__()
        c_hidden = config.pairwise_state_dim
        self._outgoing = _outgoing

        self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
        self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
        self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")

        self.layer_norm_in = LayerNorm(c_hidden)
        self.layer_norm_out = LayerNorm(c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(
        self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Contract the two channel-last projections `a` and `b` with a per-channel matrix product.

        Both inputs are moved to channel-first layout (with one of them transposed, depending on `_outgoing`),
        multiplied, and the product is returned in channel-last layout. With `_inplace_chunk_size` set, the product
        is computed chunk by chunk over the channel dimension and written into the (permuted view of) `a`.
        """
        if self._outgoing:
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        if _inplace_chunk_size is not None:
            # To be replaced by torch vmap
            for i in range(0, a.shape[-3], _inplace_chunk_size):
                a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
                b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
                a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
                    a_chunk,
                    b_chunk,
                )

            p = a
        else:
            p = torch.matmul(a, b)

        return permute_final_dims(p, (1, 2, 0))

    def _inference_forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_chunk_size: Optional[int] = None,
        with_add: bool = True,
    ):
        """
        Args:
            z:
                A [*, N, N, C_z] pair representation
            mask:
                A [*, N, N] pair mask
            inplace_chunk_size:
                Size of chunks used in the main computation. Increase to trade memory for speed. If `None`, the
                update is computed in one piece without chunking.
            with_add:
                If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update).
        Returns:
            A reference to the overwritten z

        More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the
        addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten
        values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size.
        Useful for inference on extremely long sequences.

        It works as follows. We will make reference to variables used in the default forward implementation below.
        Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the
        "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of z, 4) g, a z-sized mask,
        and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for
        N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate
        tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the
        tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over
        pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains
        inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring
        total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks
        directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at
        the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column
        ahead of previously overwritten columns and can be recovered directly from z. After the first iteration,
        however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache,
        a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For
        0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith
        iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead.
        Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the
        z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache.
        After the final iteration, z has been completely overwritten and contains the triangular multiplicative update.
        If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case,
        peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small
        variables.
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        def compute_projection_helper(pair, mask, a=True):
            # Gated, masked projection of `pair`, returned in channel-first layout.
            if a:
                linear_g = self.linear_a_g
                linear_p = self.linear_a_p
            else:
                linear_g = self.linear_b_g
                linear_p = self.linear_b_p

            pair = self.layer_norm_in(pair)
            p = linear_g(pair)
            p.sigmoid_()
            p *= linear_p(pair)
            p *= mask
            p = permute_final_dims(p, (2, 0, 1))
            return p

        def compute_projection(pair, mask, a=True, chunked=True):
            need_transpose = self._outgoing ^ a
            if not chunked:
                p = compute_projection_helper(pair, mask, a)
                if need_transpose:
                    p = p.transpose(-1, -2)
            else:
                # This computation is chunked so as not to exceed our 2.5x
                # budget with a large intermediate tensor
                linear_g = self.linear_a_g if a else self.linear_b_g
                c = linear_g.bias.shape[-1]
                out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
                p = pair.new_zeros(out_shape)
                for i in range(0, pair.shape[-3], inplace_chunk_size):
                    pair_chunk = compute_projection_helper(
                        pair[..., i : i + inplace_chunk_size, :, :],
                        mask[..., i : i + inplace_chunk_size, :, :],
                        a,
                    )
                    if need_transpose:
                        pair_chunk = pair_chunk.transpose(-1, -2)
                        p[..., i : i + inplace_chunk_size] = pair_chunk
                    else:
                        p[..., i : i + inplace_chunk_size, :] = pair_chunk

                    del pair_chunk

            return p

        # We start by fully manifesting a. In addition to the input, this
        # brings total memory consumption to 2x z (disregarding size of chunks)
        # [*, N, N, c]
        # Chunking needs a chunk size; without one, fall back to the unchunked projection instead of crashing.
        a = compute_projection(z, mask, True, chunked=inplace_chunk_size is not None)

        if inplace_chunk_size is not None:
            n = a.shape[-1]
            half_n = n // 2 + n % 2
            row_dim = -3
            col_dim = -2
            b_chunk_dim = row_dim if self._outgoing else col_dim

            def empty_slicer(t):
                return [slice(None) for _ in t.shape]

            def slice_tensor(t, start, end, dim):
                # Slices start:end from the dim dimension of t
                s = empty_slicer(t)
                s[dim] = slice(start, end)
                # Index with a tuple: indexing a tensor with a list of slices is deprecated in PyTorch.
                return t[tuple(s)]

            def flip_z_cache_(z_cache, z):
                # "Reorient" the z_cache (see below), filling it with quadrants
                # 3---recovered from the z_cache---and 4---recovered from z---
                # of the input tensor z.
                quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim)
                z_cache = z_cache.transpose(row_dim, col_dim)

                # If n is odd, we need to shrink the z_cache by one row
                z_cache = z_cache[..., : (n // 2), :, :]

                # Move the 3rd quadrant of z into the first half of the rotated z-cache
                first_half_slicer = empty_slicer(z_cache)
                first_half_slicer[col_dim] = slice(0, half_n)
                z_cache[tuple(first_half_slicer)] = quadrant_3

                # Get the fourth quadrant of z
                quadrant_4 = slice_tensor(z, half_n, None, row_dim)
                quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim)

                # Insert said quadrant into the rotated z-cache
                quadrant_3_slicer = empty_slicer(z_cache)
                quadrant_3_slicer[col_dim] = slice(half_n, None)

                z_cache[tuple(quadrant_3_slicer)] = quadrant_4

                return z_cache

            # Initialize the z cache to the left half of z.
            z_cache_shape = list(z.shape)
            z_cache_shape[col_dim] = half_n
            z_cache = z.new_zeros(z_cache_shape)
            z_cache_slicer = empty_slicer(z_cache)
            z_cache_slicer[col_dim] = slice(0, half_n)
            z_cache.copy_(z[tuple(z_cache_slicer)])
            z_cache_rotated = False

            # We need to reorient the z-cache at the halfway point, and we
            # don't want a single chunk to straddle that point. We contract one
            # of the chunks in the middle to address that problem.
            i_range = list(range(0, half_n, inplace_chunk_size))
            initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])]
            after_half = list(range(half_n, n, inplace_chunk_size))
            after_half_offsets = [inplace_chunk_size for _ in after_half]
            combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets)
            for i, offset in combined_range_with_offsets:
                if not z_cache_rotated and i >= half_n:
                    z_cache = flip_z_cache_(z_cache, z)
                    z_cache_rotated = True

                z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim)
                mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim)

                z_chunk_b = z_chunk_b.clone()
                if b_chunk_dim == col_dim:
                    z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
                else:  # b_chunk_dim == row_dim
                    # In this case, the b-dimension (b_chunk_dim) is partially
                    # overwritten at the end of each iteration. We need to
                    # restore the missing component from the z-cache.
                    if not z_cache_rotated:
                        z_chunk_slicer = empty_slicer(z_chunk_b)
                        z_chunk_slicer[col_dim] = slice(0, half_n)
                        z_chunk_b[tuple(z_chunk_slicer)] = slice_tensor(z_cache, i, i + offset, row_dim)
                    else:
                        z_cache_offset = i - half_n
                        z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim)

                b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False)
                del z_chunk_b

                x_chunk = torch.matmul(a, b_chunk)
                x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
                x_chunk = self.layer_norm_out(x_chunk)
                x_chunk = self.linear_z(x_chunk)

                # The g dimension (col_dim) is parallel to and ahead of the
                # overwrites in z. We can extract the g chunk normally.
                z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
                g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
                g_chunk.sigmoid_()
                del z_chunk_g

                x_chunk *= g_chunk

                # Write the columns into z in-place
                z_slicer = empty_slicer(z)
                z_slicer[col_dim] = slice(i, i + offset)
                if with_add:
                    z[tuple(z_slicer)] += x_chunk
                else:
                    z[tuple(z_slicer)] = x_chunk
        else:
            b = compute_projection(z, mask, False, False)
            x = torch.matmul(a, b)
            # Back to channel-last before the channel-wise layers, as in `forward` and the chunked branch above.
            x = permute_final_dims(x, (1, 2, 0))
            x = self.layer_norm_out(x)
            x = self.linear_z(x)
            # The gate is computed from the layer-normalized input, as in `forward` and the chunked branch above.
            g = self.linear_g(self.layer_norm_in(z))
            g.sigmoid_()
            x *= g
            if with_add:
                z += x
            else:
                z = x

        return z

    def forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        _add_with_inplace: bool = False,
        _inplace_chunk_size: Optional[int] = 256,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
            inplace_safe:
                If True, delegate to the in-place `_inference_forward` (which overwrites `z`).
            _add_with_inplace:
                Forwarded to `_inference_forward` as `with_add`.
            _inplace_chunk_size:
                Forwarded to `_inference_forward` as `inplace_chunk_size`.
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if inplace_safe:
            x = self._inference_forward(
                z,
                mask,
                inplace_chunk_size=_inplace_chunk_size,
                with_add=_add_with_inplace,
            )
            return x

        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        a = mask
        a = a * self.sigmoid(self.linear_a_g(z))
        a = a * self.linear_a_p(z)
        b = mask
        b = b * self.sigmoid(self.linear_b_g(z))
        b = b * self.linear_b_p(z)

        if is_fp16_enabled():
            # Under fp16 autocast, do the contraction in fp32 with autocast switched off.
            with torch.cuda.amp.autocast(enabled=False):
                x = self._combine_projections(a.float(), b.float())
        else:
            x = self._combine_projections(a, b)

        del a, b
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
        x = x * g

        return x
class EsmFoldPreTrainedModel(EsmPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Subclass `EsmPreTrainedModel` to deal with special init
    def _init_weights(self, module):
        """
        Initialize the weights of `module`.

        - `EsmFoldLinear`: uses `module.init_fn` when set, otherwise the scheme named by `module.init`.
        - `EsmFoldInvariantPointAttention`: fills `head_weights` via `ipa_point_weights_init_`.
        - `EsmFoldTriangularSelfAttentionBlock`: zeroes the listed output projections of its sub-layers.
        - Anything else is delegated to the parent class.
        """
        if isinstance(module, EsmFoldLinear):
            with torch.no_grad():
                if module.init_fn is not None:
                    module.init_fn(module.weight, module.bias)
                elif module.init == "default":
                    trunc_normal_init_(module.weight, scale=1.0)
                elif module.init == "relu":
                    trunc_normal_init_(module.weight, scale=2.0)
                elif module.init == "glorot":
                    nn.init.xavier_uniform_(module.weight, gain=1)
                elif module.init == "gating":
                    module.weight.fill_(0.0)
                    # Test for presence, not truthiness: `bool()` of a multi-element Parameter raises.
                    if module.bias is not None:
                        module.bias.fill_(1.0)
                elif module.init == "normal":
                    torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
                elif module.init == "final":
                    module.weight.fill_(0.0)
        elif isinstance(module, EsmFoldInvariantPointAttention):
            ipa_point_weights_init_(module.head_weights)
        elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
            torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight)
            torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias)
            torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight)
            torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias)
            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight)
            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias)
            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight)
            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias)

            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight)
            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias)
            # No bias to zero here: only the weight of this projection is touched.
            torch.nn.init.zeros_(module.pair_to_sequence.linear.weight)
            torch.nn.init.zeros_(module.seq_attention.o_proj.weight)
            torch.nn.init.zeros_(module.seq_attention.o_proj.bias)
            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight)
            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias)
            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight)
            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias)
        else:
            super()._init_weights(module)
class EsmFoldSelfAttention(nn.Module):
    """Multi-head self attention over a sequence, with optional padding mask, pairwise bias and output gating."""

    def __init__(self, embed_dim, num_heads, head_width, gated=False):
        """
        Args:
            embed_dim: Channel size of the input and output; must equal `num_heads * head_width`.
            num_heads: Number of attention heads.
            head_width: Channel size of each head.
            gated: If True, the attention output is multiplied by a sigmoid gate computed from the input.
        """
        super().__init__()
        assert embed_dim == num_heads * head_width

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_width = head_width

        self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.gated = gated
        if gated:
            self.g_proj = nn.Linear(embed_dim, embed_dim)
            # Zero weight and unit bias: the gate starts out input-independent.
            torch.nn.init.zeros_(self.g_proj.weight)
            torch.nn.init.ones_(self.g_proj.bias)

        self.rescale_factor = self.head_width**-0.5

        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, mask=None, bias=None, indices=None):
        """
        Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths,
        use mask.

        Inputs:
            x: batch of input sequences (.. x L x C)
            mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k)
            bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)
            indices: not used by this method.

        Outputs:
            sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
        """
        # [B, L, H, 3 * head_width] -> [B, H, L, 3 * head_width]
        t = self.proj(x).view(*x.shape[:2], self.num_heads, -1)
        t = t.permute(0, 2, 1, 3)
        q, k, v = t.chunk(3, dim=-1)

        q = self.rescale_factor * q
        # [B, H, Lq, Lk]
        a = torch.einsum("...qc,...kc->...qk", q, k)

        # Add external attention bias.
        if bias is not None:
            a = a + bias.permute(0, 3, 1, 2)

        # Do not attend to padding tokens.
        if mask is not None:
            mask = mask[:, None, None]
            a = a.masked_fill(mask == False, -np.inf)  # noqa: E712

        a = nn.functional.softmax(a, dim=-1)

        y = torch.einsum("...hqk,...hkc->...qhc", a, v)
        y = y.reshape(*y.shape[:2], -1)

        if self.gated:
            y = self.g_proj(x).sigmoid() * y
        y = self.o_proj(y)

        # [B, H, Lq, Lk] -> [B, Lq, Lk, H], matching the documented layout of the attention maps.
        # (The previous `permute(0, 3, 1, 2)` produced [B, Lk, H, Lq] instead.)
        return y, a.permute(0, 2, 3, 1)
class EsmFoldDropout(nn.Module):
    """
    Dropout whose random mask is shared (broadcast) along one or more chosen dimensions of the input.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r: dropout probability handed to `nn.Dropout`.
            batch_dim: dimension index (or list of indices) along which a single mask entry is shared.
        """
        super().__init__()

        self.r = r
        # Always store the shared dimensions as a list so `forward` can iterate over them.
        self.batch_dim = [batch_dim] if isinstance(batch_dim, int) else batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply `x` by a dropout mask of size 1 along every shared dimension."""
        shared_dims = self.batch_dim
        mask_shape = list(x.shape)
        if shared_dims is not None:
            for dim in shared_dims:
                mask_shape[dim] = 1
        # Dropping a tensor of ones yields the (rescaled) keep mask, which broadcasts against `x`.
        keep_mask = self.dropout(x.new_ones(mask_shape))
        return x * keep_mask
class EsmFoldSequenceToPair(nn.Module):
    """
    Builds a pairwise representation from a per-residue representation by combining every pair of positions through
    an element-wise product and an element-wise difference.
    """

    def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
        super().__init__()

        self.layernorm = nn.LayerNorm(sequence_state_dim)
        # One projection yields both the "q" and the "k" halves (inner_dim each).
        self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
        self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)

        for layer in (self.proj, self.o_proj):
            torch.nn.init.zeros_(layer.bias)

    def forward(self, sequence_state):
        """
        Inputs:
            sequence_state: B x L x sequence_state_dim

        Output:
            pairwise_state: B x L x L x pairwise_state_dim

        Intermediate state:
            B x L x L x 2*inner_dim
        """
        assert len(sequence_state.shape) == 3

        projected = self.proj(self.layernorm(sequence_state))
        q, k = projected.chunk(2, dim=-1)

        # Broadcast q along dim 1 and k along dim 2 so entry [b, i, j] combines k[b, i] with q[b, j].
        q_row = q[:, None, :, :]
        k_col = k[:, :, None, :]
        pair_features = torch.cat([q_row * k_col, q_row - k_col], dim=-1)

        return self.o_proj(pair_features)
class EsmFoldPairToSequence(nn.Module):
    """
    Projects a pairwise representation to one scalar per attention head, for use as an additive attention bias.
    """

    def __init__(self, pairwise_state_dim, num_heads):
        super().__init__()

        self.layernorm = nn.LayerNorm(pairwise_state_dim)
        self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)

    def forward(self, pairwise_state):
        """
        Inputs:
            pairwise_state: B x L x L x pairwise_state_dim

        Output:
            pairwise_bias: B x L x L x num_heads
        """
        assert len(pairwise_state.shape) == 4
        # Normalise each pair feature vector, then map it to per-head scalars.
        return self.linear(self.layernorm(pairwise_state))
class EsmFoldResidueMLP(nn.Module):
    """
    Residual feed-forward block: LayerNorm -> Linear -> ReLU -> Linear -> Dropout, added back onto the input.
    """

    def __init__(self, embed_dim, inner_dim, dropout=0):
        super().__init__()

        # The layer order is part of the contract: other code addresses the last Linear as `mlp[-2]`.
        stages = [
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Return `x` plus the MLP update computed from `x`."""
        update = self.mlp(x)
        return x + update
class EsmFoldTriangularSelfAttentionBlock(nn.Module):
    """
    One trunk block: updates the sequence state with pair-biased self attention and an MLP, then updates the pairwise
    state with the new sequence state, two triangle multiplicative updates, two triangle attentions and an MLP.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        seq_dim = config.sequence_state_dim
        pair_dim = config.pairwise_state_dim
        seq_heads = seq_dim // config.sequence_head_width
        pair_heads = pair_dim // config.pairwise_head_width

        self.layernorm_1 = nn.LayerNorm(seq_dim)

        # Communication between the two representations.
        self.sequence_to_pair = EsmFoldSequenceToPair(seq_dim, pair_dim // 2, pair_dim)
        self.pair_to_sequence = EsmFoldPairToSequence(pair_dim, seq_heads)

        self.seq_attention = EsmFoldSelfAttention(seq_dim, seq_heads, config.sequence_head_width, gated=True)

        # Triangle updates on the pairwise state.
        self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
        self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)
        self.tri_att_start = EsmFoldTriangleAttention(
            pair_dim, config.pairwise_head_width, pair_heads, inf=1e9, starting=True
        )
        self.tri_att_end = EsmFoldTriangleAttention(
            pair_dim, config.pairwise_head_width, pair_heads, inf=1e9, starting=False
        )

        self.mlp_seq = EsmFoldResidueMLP(seq_dim, 4 * seq_dim, dropout=config.dropout)
        self.mlp_pair = EsmFoldResidueMLP(pair_dim, 4 * pair_dim, dropout=config.dropout)

        # Plain dropout for the sequence track; mask-sharing dropout (doubled rate) for the pair track.
        self.drop = nn.Dropout(config.dropout)
        self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
        self.col_drop = EsmFoldDropout(config.dropout * 2, 1)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """
        Inputs:
            sequence_state: B x L x sequence_state_dim
            pairwise_state: B x L x L x pairwise_state_dim
            mask: B x L boolean tensor of valid positions
            chunk_size: forwarded to the two triangle attention modules
            **__kwargs: extra keyword arguments are accepted and ignored

        Output:
            sequence_state: B x L x sequence_state_dim
            pairwise_state: B x L x L x pairwise_state_dim

        Raises:
            ValueError: if the ranks, feature sizes, batch sizes or sequence lengths of the inputs are inconsistent.
        """
        # --- Input validation (rank first, then sizes) ---
        seq_rank = len(sequence_state.shape)
        if seq_rank != 3:
            raise ValueError(f"`sequence_state` should be a 3d-tensor, got {seq_rank} dims.")
        pair_rank = len(pairwise_state.shape)
        if pair_rank != 4:
            raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {pair_rank} dims.")
        if mask is not None:
            mask_rank = len(mask.shape)
            if mask_rank != 2:
                raise ValueError(f"`mask` should be a 2d-tensor, got {mask_rank} dims.")

        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pair_batch, pair_rows, pair_cols, pairwise_state_dim = pairwise_state.shape

        expected_seq_dim = self.config.sequence_state_dim
        if sequence_state_dim != expected_seq_dim:
            raise ValueError(
                "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got "
                f"{sequence_state_dim} != {expected_seq_dim}."
            )
        expected_pair_dim = self.config.pairwise_state_dim
        if pairwise_state_dim != expected_pair_dim:
            raise ValueError(
                "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got "
                f"{pairwise_state_dim} != {expected_pair_dim}."
            )
        if batch_dim != pair_batch:
            raise ValueError(
                f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != "
                f"{pair_batch}."
            )
        if not (seq_dim == pair_rows and seq_dim == pair_cols):
            raise ValueError(
                f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != "
                f"{pair_rows} or {pair_cols}."
            )

        # --- Sequence track: pair-biased self attention followed by the residual MLP ---
        attn_bias = self.pair_to_sequence(pairwise_state)
        attn_out, _ = self.seq_attention(self.layernorm_1(sequence_state), mask=mask, bias=attn_bias)
        sequence_state = sequence_state + self.drop(attn_out)
        sequence_state = self.mlp_seq(sequence_state)

        # --- Pair track: inject the updated sequence state ---
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Outer product of the sequence mask with itself marks valid (i, j) pairs.
        tri_mask = None if mask is None else mask.unsqueeze(2) * mask.unsqueeze(1)

        pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask))
        pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask))

        start_update = self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        pairwise_state = pairwise_state + self.row_drop(start_update)
        end_update = self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        pairwise_state = pairwise_state + self.col_drop(end_update)

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state
class EsmCategoricalMixture:
    """
    Categorical distribution over `bins` equally wide intervals of `[start, end]`, parameterised by logits. Each
    interval is represented by its midpoint.
    """

    def __init__(self, param, bins=50, start=0, end=1):
        """
        Args:
            param: logits of shape `... x bins`.
            bins: number of intervals.
            start: lower edge of the first interval.
            end: upper edge of the last interval.
        """
        # All tensors are of shape ..., bins.
        self.logits = param
        edges = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
        # Midpoint of every interval, shape (bins,).
        self.v_bins = (edges[:-1] + edges[1:]) / 2

    def log_prob(self, true):
        """Return the log-probability of the bin whose midpoint is closest to each value in `true`."""
        # Shapes are:
        #     self.logits: ... x bins
        #     true       : ...
        # `v_bins` is 1-D, so ordinary broadcasting against `true[..., None]` gives `... x bins`;
        # no explicit indexing with a list of `None` is needed.
        true_index = (true.unsqueeze(-1) - self.v_bins).abs().argmin(-1)
        nll = self.logits.log_softmax(-1)
        return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)

    def mean(self):
        """Return the expected bin midpoint under the softmax of the logits (shape `...`)."""
        return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)
def categorical_lddt(logits, bins=50):
    """Return the mean of the `EsmCategoricalMixture` built from `logits` with `bins` bins."""
    # Logits are ..., 37, bins.
    mixture = EsmCategoricalMixture(logits, bins=bins)
    return mixture.mean()
def get_axial_mask(mask):
    """
    Helper to convert B x L mask of valid positions to axial mask used in row column attentions.

    Input:
        mask: B x L tensor of booleans, or None

    Output:
        mask: (B * L) x L tensor in which every row of the input is repeated L times, or None if `mask` is None

    Raises:
        ValueError: if `mask` is not 2-dimensional.
    """
    if mask is None:
        return None

    rank = len(mask.shape)
    if rank != 2:
        raise ValueError(f"`mask` should be a 2d-tensor, got {rank} dims.")

    batch_dim, seq_dim = mask.shape
    # Insert a broadcast axis, repeat each row `seq_dim` times, then fold that axis into the batch.
    repeated = mask[:, None, :].expand(batch_dim, seq_dim, seq_dim)
    return repeated.reshape(batch_dim * seq_dim, seq_dim)
class EsmFoldRelativePosition(nn.Module):
    """
    Embeds the clamped relative offset between every pair of residue indices into the pairwise feature space.
    """

    def __init__(self, config):
        super().__init__()
        self.bins = config.position_bins

        # Note an additional offset is used so that the 0th position
        # is reserved for masked pairs.
        num_embeddings = 2 * self.bins + 2
        self.embedding = torch.nn.Embedding(num_embeddings, config.pairwise_state_dim)

    def forward(self, residue_index, mask=None):
        """
        Input:
            residue_index: B x L tensor of indices (dtype=torch.long)
            mask: B x L tensor of booleans

        Output:
            pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings

        Raises:
            ValueError: if `residue_index` is not `torch.long` or its shape differs from `mask`.
        """
        if residue_index.dtype != torch.long:
            raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
        if mask is not None and residue_index.shape != mask.shape:
            raise ValueError(
                f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
            )

        # offsets[b, i, j] = residue_index[b, j] - residue_index[b, i], limited to [-bins, bins].
        offsets = residue_index.unsqueeze(1) - residue_index.unsqueeze(2)
        offsets = offsets.clamp(-self.bins, self.bins)
        # Shift into [1, 2 * bins + 1]; slot 0 is kept for masked pairs.
        bucket_ids = offsets + self.bins + 1

        if mask is not None:
            pair_mask = mask.unsqueeze(1) * mask.unsqueeze(2)
            bucket_ids = bucket_ids.masked_fill(pair_mask == False, 0)  # noqa: E712

        return self.embedding(bucket_ids)
class EsmFoldAngleResnetBlock(nn.Module):
    """
    Pre-activation residual block: ReLU -> Linear -> ReLU -> Linear, with the block input added to the result.
    """

    def __init__(self, config):
        super().__init__()

        width = config.resnet_dim
        self.linear_1 = EsmFoldLinear(width, width, init="relu")
        self.linear_2 = EsmFoldLinear(width, width, init="final")

        self.relu = nn.ReLU()

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        """Apply the two activated linear layers and add the untouched input back on."""
        residual = a

        hidden = self.linear_1(self.relu(a))
        hidden = self.linear_2(self.relu(hidden))

        return hidden + residual
class EsmFoldAngleResnet(nn.Module):
    """
    Implements Algorithm 20, lines 11-14
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
        self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)

        self.layers = nn.ModuleList([EsmFoldAngleResnetBlock(config) for _ in range(config.num_resnet_blocks)])

        # Two outputs per angle; they are later normalised to unit length.
        self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)

        self.relu = nn.ReLU()

    def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            s:
                [*, C_hidden] single embedding
            s_initial:
                [*, C_hidden] single embedding as of the start of the StructureModule
        Returns:
            A pair `(unnormalized, normalized)`, each [*, no_angles, 2]: the raw predicted angle vectors and the
            same vectors divided by their (clamped) norm.
        """
        # NOTE: The ReLU's applied to the inputs are absent from the supplement
        # pseudocode but present in the source. For maximal compatibility with
        # the pretrained weights, I'm going with the source.

        # [*, C_hidden]
        initial_branch = self.linear_initial(self.relu(s_initial))
        current_branch = self.linear_in(self.relu(s))
        hidden = current_branch + initial_branch

        for resnet_block in self.layers:
            hidden = resnet_block(hidden)

        # [*, no_angles * 2]
        hidden = self.linear_out(self.relu(hidden))

        # [*, no_angles, 2]
        unnormalized_s = hidden.view(hidden.shape[:-1] + (-1, 2))

        # Clamp the squared norm from below so the division stays finite for near-zero vectors.
        squared_norm = torch.sum(unnormalized_s**2, dim=-1, keepdim=True)
        norm_denom = torch.sqrt(torch.clamp(squared_norm, min=self.config.epsilon))
        normalized_s = unnormalized_s / norm_denom

        return unnormalized_s, normalized_s
class EsmFoldInvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.

    Attention over residues whose logits combine three terms: a scalar query/key product, a bias projected from the
    pair representation, and the squared distance between query and key points expressed in the global frame.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        c_s = config.sequence_dim
        c_z = config.pairwise_dim
        self.hidden_dim = config.ipa_dim
        self.num_heads = config.num_heads_ipa
        self.num_qk_points = config.num_qk_points
        self.num_v_points = config.num_v_points

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = config.ipa_dim * config.num_heads_ipa
        self.linear_q = EsmFoldLinear(c_s, hc)
        self.linear_kv = EsmFoldLinear(c_s, 2 * hc)

        hpq = config.num_heads_ipa * config.num_qk_points * 3
        self.linear_q_points = EsmFoldLinear(c_s, hpq)

        hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
        self.linear_kv_points = EsmFoldLinear(c_s, hpkv)

        self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)

        # One learnable weight per head for the point-distance term (passed through softplus in `forward`).
        self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa)))

        # Concatenated output: scalar values, 3 point coordinates + 1 point norm per value point, pair values.
        concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
        self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation. May be None when `_offload_inference` is set and the
                tensor is supplied through `_z_reference_list` instead.
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
            _offload_inference:
                If True, the pair representation is moved to the CPU while the attention logits are computed and
                moved back before it is used again.
            _z_reference_list:
                One-element list holding the pair representation; used in place of `z` when offloading so that
                this function holds the only remaining reference to the tensor.
        Returns:
            [*, N_res, C_s] single representation update
        """
        # When offloading, the caller hands over the pair tensor inside a list (and may pass `z=None`), so the
        # list must be used here. Otherwise wrap `z` in a fresh list so both paths index `z[0]` below.
        if _offload_inference and _z_reference_list is not None:
            z = _z_reference_list
        else:
            z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.num_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.hidden_dim, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3))

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if _offload_inference:
            # The list entry must be the only live reference (the second count is getrefcount's own argument);
            # otherwise moving it to the CPU would not release the device copy.
            assert sys.getrefcount(z[0]) == 2
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        if is_fp16_enabled():
            # Compute the scalar logits in fp32 with autocast disabled.
            with torch.cuda.amp.autocast(enabled=False):
                a = torch.matmul(
                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
                )
        else:
            a = torch.matmul(
                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
            )

        a *= math.sqrt(1.0 / (3 * self.hidden_dim))
        a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))

        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        pt_att = pt_att**2

        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1)))
        head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
        pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]: 0 for valid pairs, -inf-like for pairs involving a masked residue.
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        square_mask = self.config.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        a = a + pt_att
        a = a + square_mask.unsqueeze(-3)
        a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        o_pt = torch.sum(
            (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
            dim=-2,
        )

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        # Map the attended points back into each residue's local frame.
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2)

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if _offload_inference:
            # Bring the pair representation back to the compute device for the pair-value aggregation.
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
        )

        return s
class EsmFoldBackboneUpdate(nn.Module):
    """
    Implements part of Algorithm 23.

    A single linear layer that maps the single representation to a 6-dimensional update vector per residue.
    """

    def __init__(self, config):
        super().__init__()

        self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
        Returns:
            [*, N_res, 6] update vector (a single tensor, not a tuple)
        """
        # [*, N_res, 6]
        update = self.linear(s)

        return update
class EsmFoldStructureModuleTransitionLayer(nn.Module):
    """
    Residual three-layer MLP (Linear -> ReLU -> Linear -> ReLU -> Linear) that keeps the feature width unchanged.
    """

    def __init__(self, config):
        super().__init__()

        width = config.sequence_dim
        self.linear_1 = EsmFoldLinear(width, width, init="relu")
        self.linear_2 = EsmFoldLinear(width, width, init="relu")
        self.linear_3 = EsmFoldLinear(width, width, init="final")

        self.relu = nn.ReLU()

    def forward(self, s):
        """Return `s` plus the output of the three-layer MLP applied to `s`."""
        hidden = self.relu(self.linear_1(s))
        hidden = self.relu(self.linear_2(hidden))
        hidden = self.linear_3(hidden)

        return hidden + s
class EsmFoldStructureModuleTransition(nn.Module):
    """
    Stack of `config.num_transition_layers` transition layers followed by dropout and layer normalisation.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(
            [EsmFoldStructureModuleTransitionLayer(config) for _ in range(config.num_transition_layers)]
        )

        self.dropout = nn.Dropout(config.dropout_rate)
        self.layer_norm = LayerNorm(config.sequence_dim)

    def forward(self, s):
        """Run `s` through every transition layer in order, then dropout, then layer norm."""
        hidden = s
        for transition_layer in self.layers:
            hidden = transition_layer(hidden)

        return self.layer_norm(self.dropout(hidden))
class EsmFoldStructureModule(nn.Module):
    """
    Iteratively refines per-residue rigid frames from the single and pair representations and converts them into
    torsion angles, side-chain frames and atom14 positions. The same IPA / transition / backbone-update / angle
    modules are reused for all `config.num_blocks` iterations.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Buffers to be lazily initialized later
        # self.default_frames
        # self.group_idx
        # self.atom_mask
        # self.lit_positions

        self.layer_norm_s = LayerNorm(config.sequence_dim)
        self.layer_norm_z = LayerNorm(config.pairwise_dim)

        self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)

        self.ipa = EsmFoldInvariantPointAttention(config)

        self.ipa_dropout = nn.Dropout(config.dropout_rate)
        self.layer_norm_ipa = LayerNorm(config.sequence_dim)

        self.transition = EsmFoldStructureModuleTransition(config)
        self.bb_update = EsmFoldBackboneUpdate(config)
        self.angle_resnet = EsmFoldAngleResnet(config)

    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        _offload_inference=False,
    ):
        """
        Args:
            evoformer_output_dict:
                Dictionary containing:
                    "single":
                        [*, N_res, C_s] single representation
                    "pair":
                        [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
                Optional [*, N_res] sequence mask. Defaults to all ones.
            _offload_inference:
                If True, `evoformer_output_dict["pair"]` is moved to the CPU for the duration of the call and
                restored to the device of `s` before returning (the dictionary is modified in place).
        Returns:
            A dictionary of outputs with keys "frames", "sidechain_frames", "unnormalized_angles", "angles",
            "positions" and "states" (the per-block values combined with `dict_multimap(torch.stack, ...)`), plus
            "single", the final single representation.
        """
        s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(evoformer_output_dict["pair"])

        z_reference_list = None
        if _offload_inference:
            # The dictionary must hold the only reference to the pair tensor (the second count is getrefcount's
            # own argument) so that moving it to the CPU actually frees the device copy.
            assert sys.getrefcount(evoformer_output_dict["pair"]) == 2
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
            # Hand the normalised pair tensor over inside a list and drop the local name, so the list entry is
            # the only remaining reference.
            # NOTE(review): from here on `z` is None; confirm the IPA reads `_z_reference_list` in this mode.
            z_reference_list = [z]
            z = None

        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)

        # [*, N]
        rigids = Rigid.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            self.training,
            fmt="quat",
        )
        outputs = []
        for i in range(self.config.num_blocks):
            # [*, N, C_s]
            s = s + self.ipa(
                s,
                z,
                rigids,
                mask,
                _offload_inference=_offload_inference,
                _z_reference_list=z_reference_list,
            )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N]
            rigids = rigids.compose_q_update_vec(self.bb_update(s))

            # To hew as closely as possible to AlphaFold, we convert our
            # quaternion-based transformations to rotation-matrix ones
            # here
            backb_to_global = Rigid(
                Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None),
                rigids.get_trans(),
            )

            backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor)

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype)

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype)

            # The reported frames use the scaled translation; `rigids` itself stays unscaled for the next block.
            scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor)

            preds = {
                "frames": scaled_rigids.to_tensor_7(),
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
                "states": s,
            }

            outputs.append(preds)

            rigids = rigids.stop_rot_gradient()

        # Drop the local references to the normalised pair tensor before restoring the offloaded one.
        del z, z_reference_list

        if _offload_inference:
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs

    def _init_residue_constants(self, float_dtype, device):
        """
        Register the residue-constant lookup tables as non-persistent buffers the first time they are needed.

        Each buffer is created only if the attribute does not exist yet, so the dtype and device of the first call
        are the ones used by this function.
        """
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    residue_constants.restype_rigid_group_default_frame,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "group_idx"):
            # Index table: no float dtype is forced here, unlike the other three buffers.
            self.register_buffer(
                "group_idx",
                torch.tensor(
                    residue_constants.restype_atom14_to_rigid_group,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "atom_mask"):
            self.register_buffer(
                "atom_mask",
                torch.tensor(
                    residue_constants.restype_atom14_mask,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.tensor(
                    residue_constants.restype_atom14_rigid_group_positions,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    def torsion_angles_to_frames(self, r, alpha, f):
        """Thin wrapper around the module-level `torsion_angles_to_frames` that supplies `self.default_frames`."""
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(alpha.dtype, alpha.device)
        # Separated purely to make testing less annoying
        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(self, r, f):  # [*, N, 8]  # [*, N]
        """
        Thin wrapper around the module-level `frames_and_literature_positions_to_atom14_pos` that supplies the four
        residue-constant buffers.
        """
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
        return frames_and_literature_positions_to_atom14_pos(
            r,
            f,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )
class EsmFoldingTrunk(nn.Module):
    """
    Folding trunk: a stack of triangular self-attention blocks followed by the structure module, run repeatedly
    with recycling of the previous pass's sequence state, pairwise state and binned pairwise distances.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        c_s = config.sequence_state_dim
        c_z = config.pairwise_state_dim

        self.pairwise_positional_embedding = EsmFoldRelativePosition(config)

        self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])

        # Recycling inputs: normalised previous states plus an embedding of the binned distances.
        self.recycle_bins = 15
        self.recycle_s_norm = nn.LayerNorm(c_s)
        self.recycle_z_norm = nn.LayerNorm(c_z)
        self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
        # Bin 0 is what the first pass sees (all-zero bins), so its embedding starts at zero.
        self.recycle_disto.weight[0].detach().zero_()

        self.structure_module = EsmFoldStructureModule(config.structure_module)
        self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
        self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)

        self.chunk_size = config.chunk_size

    def set_chunk_size(self, chunk_size):
        """Set the chunk size that is forwarded to every trunk block."""
        # This parameter means the axial attention will be computed
        # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
        # It's equivalent to running a for loop over chunks of the dimension we're iterative over,
        # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks.
        self.chunk_size = chunk_size

    def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
        """
        Inputs:
            seq_feats: B x L x C tensor of sequence features
            pair_feats: B x L x L x C tensor of pair features
            true_aa: passed to the structure module as its `aatype` argument
            residx: B x L long tensor giving the position in the sequence
            mask: B x L boolean tensor indicating valid residues
            no_recycles: number of recycling passes after the first pass, or None to run
                `config.max_recycles` passes in total

        Output:
            The structure module's output dictionary from the last pass, extended with the final sequence state
            ("s_s") and pairwise state ("s_z").

        Raises:
            ValueError: if `no_recycles` is negative.
        """
        device = seq_feats.device
        s_s_0, s_z_0 = seq_feats, pair_feats

        if no_recycles is None:
            total_passes = self.config.max_recycles
        elif no_recycles < 0:
            raise ValueError("Number of recycles must not be negative.")
        else:
            # First 'recycle' is just the standard forward pass through the model.
            total_passes = no_recycles + 1

        s_s, s_z = s_s_0, s_z_0
        recycle_s = torch.zeros_like(s_s)
        recycle_z = torch.zeros_like(s_z)
        recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

        final_pass = total_passes - 1
        for pass_idx in range(total_passes):
            # Only the last pass is run with gradients enabled.
            grad_guards = [] if pass_idx == final_pass else [torch.no_grad()]
            with ContextManagers(grad_guards):
                # === Recycling ===
                recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
                recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
                recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)

                # === Trunk: positional embedding, then every block in order ===
                s_s = s_s_0 + recycle_s
                s_z = s_z_0 + recycle_z
                s_z = s_z + self.pairwise_positional_embedding(residx, mask=mask)
                for trunk_block in self.blocks:
                    s_s, s_z = trunk_block(
                        s_s, s_z, mask=mask, residue_index=residx, chunk_size=self.chunk_size
                    )

                # === Structure module ===
                sm_inputs = {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)}
                structure = self.structure_module(sm_inputs, true_aa, mask.float())

                recycle_s, recycle_z = s_s, s_z
                # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
                backbone_atoms = structure["positions"][-1][:, :, :3]
                recycle_bins = EsmFoldingTrunk.distogram(backbone_atoms, 3.375, 21.375, self.recycle_bins)

        structure["s_s"] = s_s
        structure["s_z"] = s_z

        return structure

    @staticmethod
    def distogram(coords, min_bin, max_bin, num_bins):
        """
        Bin the pairwise distances between inferred CB atoms.

        Args:
            coords: [... L x 3 x 3] tensor holding the N, CA and C coordinates of every residue.
            min_bin: first bin boundary.
            max_bin: last bin boundary.
            num_bins: number of bins; `num_bins - 1` boundaries are spaced linearly between the two limits.

        Returns:
            [... L x L] integer tensor counting, for every residue pair, how many boundaries the CB-CB distance
            exceeds.
        """
        # Boundaries are squared so they can be compared against squared distances directly.
        boundaries = (
            torch.linspace(
                min_bin,
                max_bin,
                num_bins - 1,
                device=coords.device,
            )
            ** 2
        )

        N, CA, C = (atom.squeeze(-2) for atom in coords.chunk(3, dim=-2))

        # Infer CB coordinates.
        b = CA - N
        c = C - CA
        a = torch.cross(b, c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA

        deltas = CB[..., None, :, :] - CB[..., :, None, :]
        sq_dists = deltas.pow(2).sum(dim=-1, keepdim=True)
        return torch.sum(sq_dists > boundaries, dim=-1)  # [..., L, L]
# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare
# the outputs for downstream use.
@add_start_docstrings(
"""
ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
protein(s).
""",
ESM_START_DOCSTRING,
)
class EsmForProteinFolding(EsmPreTrainedModel):
_no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
def __init__(self, config):
    """Build the frozen ESM stem, the feature projections, the folding trunk and the output heads."""
    super().__init__(config)

    self.config = config
    fold_config = self.config.esmfold_config

    self.distogram_bins = 64

    # The language model is used as a frozen feature extractor.
    self.esm = EsmModel(config, add_pooling_layer=False)
    self.esm.requires_grad_(False)
    if fold_config.fp16_esm:
        self.esm.half()

    self.esm_feats = self.config.hidden_size
    self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads
    self.esm_layers = self.config.num_hidden_layers
    self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))
    # One mixing weight per hidden state returned by the language model (layers + embeddings).
    self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))

    trunk_config = fold_config.trunk
    c_s = trunk_config.sequence_state_dim
    c_z = trunk_config.pairwise_state_dim
    self.esm_s_mlp = nn.Sequential(
        LayerNorm(self.esm_feats),
        nn.Linear(self.esm_feats, c_s),
        nn.ReLU(),
        nn.Linear(c_s, c_s),
    )

    # 0 is padding, N is unknown residues, N + 1 is mask.
    self.n_tokens_embed = residue_constants.restype_num + 3
    self.pad_idx = 0
    self.unk_idx = self.n_tokens_embed - 2
    self.mask_idx = self.n_tokens_embed - 1

    # Positions of the special tokens in the language model vocabulary.
    vocab = self.config.vocab_list
    self.esm_dict_cls_idx = vocab.index("<cls>")
    self.esm_dict_mask_idx = vocab.index("<mask>")
    self.esm_dict_eos_idx = vocab.index("<eos>")
    self.esm_dict_padding_idx = vocab.index("<pad>")

    if fold_config.embed_aa:
        self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)

    self.trunk = EsmFoldingTrunk(trunk_config)

    self.distogram_head = nn.Linear(c_z, self.distogram_bins)
    self.ptm_head = nn.Linear(c_z, self.distogram_bins)
    self.lm_head = nn.Linear(c_s, self.n_tokens_embed)

    self.lddt_bins = 50
    sm_config = trunk_config.structure_module
    lddt_hidden = fold_config.lddt_head_hid_dim
    self.lddt_head = nn.Sequential(
        nn.LayerNorm(sm_config.sequence_dim),
        nn.Linear(sm_config.sequence_dim, lddt_hidden),
        nn.Linear(lddt_hidden, lddt_hidden),
        nn.Linear(lddt_hidden, 37 * self.lddt_bins),
    )
@staticmethod
def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:
    """
    Build the lookup table from (shifted) AF2 residue indices to ESM vocabulary indices.

    Entry 0 is the index of `<pad>`; entry `i + 1` is the vocabulary index of the i-th token in
    `residue_constants.restypes_with_x`.
    """
    # Remember that t is shifted from residue_constants by 1 (0 is padding).
    esm_reorder = [vocab_list.index("<pad>")]
    esm_reorder.extend(vocab_list.index(residue) for residue in residue_constants.restypes_with_x)
    return torch.tensor(esm_reorder)
@add_start_docstrings_to_model_forward(ESMFOLD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=EsmForProteinFoldingOutput, config_class=EsmConfig)
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    masking_pattern: Optional[torch.Tensor] = None,
    num_recycles: Optional[int] = None,
) -> EsmForProteinFoldingOutput:
    r"""
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, EsmForProteinFolding

    >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False)  # A tiny random peptide
    >>> outputs = model(**inputs)
    >>> folded_positions = outputs.positions
    ```

    """
    # NOTE: the "Returns:" line above is rewritten by `replace_return_docstrings`; keep it on its own line.
    cfg = self.config.esmfold_config

    aa = input_ids  # B x L
    B = aa.shape[0]
    L = aa.shape[1]
    device = input_ids.device
    # Defaults: every position is valid, and positions are 0..L-1 for every sequence.
    if attention_mask is None:
        attention_mask = torch.ones_like(aa, device=device)
    if position_ids is None:
        position_ids = torch.arange(L, device=device).expand_as(input_ids)

    # === ESM ===
    esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)

    if masking_pattern is not None:
        masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
    else:
        masked_aa = aa
        mlm_targets = None

    # Sequence representation from the language model: one feature vector per residue and per
    # returned hidden state.
    esm_s = self.compute_language_model_representations(esmaa)

    # Convert esm_s to the precision used by the trunk and the structure module. It may be in a
    # lower precision if, for example, we're running the language model in fp16 precision.
    esm_s = esm_s.to(self.esm_s_combine.dtype)

    if cfg.esm_ablate_sequence:
        esm_s = esm_s * 0

    # The language model is frozen: no gradient flows back into it.
    esm_s = esm_s.detach()

    # === preprocessing ===
    # Softmax-weighted combination of the hidden states, then projection to the trunk width.
    esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
    s_s_0 = self.esm_s_mlp(esm_s)

    s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)

    if self.config.esmfold_config.embed_aa:
        s_s_0 += self.embedding(masked_aa)

    structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles)
    # Documenting what we expect:
    structure = {
        k: v
        for k, v in structure.items()
        if k
        in [
            "s_z",
            "s_s",
            "frames",
            "sidechain_frames",
            "unnormalized_angles",
            "angles",
            "positions",
            "states",
        ]
    }

    # Add BERT mask for the loss to use, if available. `mlm_targets` is a tensor, so it must be compared
    # against None: evaluating the truthiness of a multi-element tensor raises a RuntimeError.
    # NOTE(review): confirm that `EsmForProteinFoldingOutput` accepts an `mlm_targets` field.
    if mlm_targets is not None:
        structure["mlm_targets"] = mlm_targets

    disto_logits = self.distogram_head(structure["s_z"])
    # Symmetrise the distogram logits over the two residue axes.
    disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
    structure["distogram_logits"] = disto_logits

    lm_logits = self.lm_head(structure["s_s"])
    structure["lm_logits"] = lm_logits

    structure["aatype"] = aa
    # Adds the atom14/atom37 existence masks and index maps to `structure` in place.
    make_atom14_masks(structure)
    # Of course, this doesn't respect the true mask because it doesn't know about it...
    # We're not going to properly mask change of index tensors:
    #    "residx_atom14_to_atom37",
    #    "residx_atom37_to_atom14",
    for k in [
        "atom14_atom_exists",
        "atom37_atom_exists",
    ]:
        structure[k] *= attention_mask.unsqueeze(-1)
    structure["residue_index"] = position_ids

    # (num_blocks, B, L, 37, lddt_bins): one categorical distribution per atom slot.
    lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins)
    structure["lddt_head"] = lddt_head
    # pLDDT is taken from the last structure-module block only.
    plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
    structure["plddt"] = plddt

    ptm_logits = self.ptm_head(structure["s_z"])
    structure["ptm_logits"] = ptm_logits
    structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
    structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins))

    return EsmForProteinFoldingOutput(**structure)
def af2_idx_to_esm_idx(self, aa, mask):
    """
    Translate residue indices into ESM vocabulary indices through the `self.af2_to_esm` lookup table.

    Args:
        aa (`torch.Tensor`):
            Integer residue indices. Each index is shifted by one before the lookup.
        mask (`torch.Tensor`):
            Tensor used to select valid positions; every position where `mask != 1` is looked up at table
            slot 0 instead (presumably the padding token -- confirm against how `af2_to_esm` is built).

    Returns:
        `torch.Tensor`: `self.af2_to_esm` indexed by the shifted, masked indices.
    """
    lookup = self.af2_to_esm
    # Keep the table on the same device as the indices so the gather below does not cross devices.
    if lookup.device != aa.device:
        lookup = lookup.to(aa.device)
        self.af2_to_esm = lookup
    shifted = aa + 1
    shifted = shifted.masked_fill(mask != 1, 0)
    return lookup[shifted]
def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
    """
    Run the ESM language model on `esmaa` and stack the hidden states of all layers.

    Args:
        esmaa (`torch.Tensor` of shape `(B, L)`):
            ESM vocabulary indices; padded positions are expected to hold `self.esm_dict_padding_idx`.

    Returns:
        `torch.Tensor` of shape `(B, L, nLayers, C)`: the per-layer hidden states with the BOS/EOS positions
        stripped. When `config.esmfold_config.bypass_lm` is set, the language model is skipped and an all-zero
        tensor of shape `(B, L, self.esm_s_combine.shape[0], self.esm_feats)` is returned instead.
    """
    device = next(self.parameters()).device
    B, L = esmaa.shape  # B = batch size, L = sequence length.

    if self.config.esmfold_config.bypass_lm:
        # One slot per entry of `esm_s_combine`, so the caller's softmax-weighted layer combination
        # still sees a (B, L, nLayers, C) tensor.
        return torch.zeros(B, L, self.esm_s_combine.shape[0], self.esm_feats, device=device)

    pad_idx = self.esm_dict_padding_idx
    bos = esmaa.new_full((B, 1), self.esm_dict_cls_idx)
    # The extra trailing column starts out as padding; the real EOS is written just below.
    eos = esmaa.new_full((B, 1), pad_idx)
    esmaa = torch.cat([bos, esmaa, eos], dim=1)
    # Use the first padding index as eos during inference.
    # NOTE(review): counting non-pad tokens only yields that index when padding is trailing -- confirm.
    first_pad = (esmaa != pad_idx).sum(1)
    esmaa[range(B), first_pad] = self.esm_dict_eos_idx

    # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models,
    # esm_z is always None and only the hidden states are requested.
    esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != pad_idx, output_hidden_states=True)[
        "hidden_states"
    ]
    esm_s = torch.stack(esm_hidden_states, dim=2)
    # Drop the BOS position and the appended last position: B, L, nLayers, C
    return esm_s[:, 1:-1]
def bert_mask(self, aa, esmaa, mask, pattern):
    """
    Apply a BERT-style masking pattern to both index tensors and build the matching target.

    Args:
        aa (`torch.Tensor`):
            Residue indices; left untouched (a copy is modified).
        esmaa (`torch.Tensor`):
            ESM vocabulary indices; left untouched (a copy is modified).
        mask:
            Not used by this method; kept for interface compatibility.
        pattern (`torch.Tensor`):
            Positions equal to 1 are the ones that get masked.

    Returns:
        `tuple`: `(new_aa, new_esmaa, target)` where masked positions of `new_aa` hold `self.mask_idx`, masked
        positions of `new_esmaa` hold `self.esm_dict_mask_idx`, and `target` is a copy of `aa` zeroed
        everywhere `pattern != 1`.
    """
    selected = pattern == 1

    masked_aa = aa.clone()
    masked_aa[selected] = self.mask_idx

    masked_esmaa = esmaa.clone()
    masked_esmaa[selected] = self.esm_dict_mask_idx

    # Only the masked positions keep their original residue index in the target.
    target = aa.clone()
    target[pattern != 1] = 0

    return masked_aa, masked_esmaa, target
@torch.no_grad()
def infer(
    self,
    seqs: Union[str, List[str]],
    position_ids=None,
):
    """
    Return the raw outputs of the model for one or several amino-acid sequences, with gradients disabled.

    Args:
        seqs (`str` or `List[str]`):
            A single sequence or a list of sequences. A bare string is treated as a one-element batch.
        position_ids (`torch.Tensor`, *optional*):
            Residue position indices. When omitted, `0..L-1` is used for every sequence, where `L` is the second
            dimension of the collated index tensor. A 1-D tensor gets a leading batch dimension added.

    Returns:
        Whatever `self.forward` returns for the encoded batch.
    """
    sequences = [seqs] if isinstance(seqs, str) else seqs
    device = next(self.parameters()).device

    # Encode each sequence as one-hot (unknown residues map to X), then collapse to integer indices.
    encoded = []
    for sequence in sequences:
        onehot = residue_constants.sequence_to_onehot(
            sequence=sequence,
            mapping=residue_constants.restype_order_with_x,
            map_unknown_to_x=True,
        )
        encoded.append(torch.from_numpy(onehot).to(device).argmax(dim=1))
    # NOTE(review): `collate_dense_tensors` presumably pads to a common length -- it is defined elsewhere.
    aatype = collate_dense_tensors(encoded)
    mask = collate_dense_tensors([aatype.new_ones(len(sequence)) for sequence in sequences])

    if position_ids is None:
        position_ids = torch.arange(aatype.shape[1], device=device).expand(len(sequences), -1)
    else:
        position_ids = position_ids.to(device)
    if position_ids.ndim == 1:
        position_ids = position_ids.unsqueeze(0)

    return self.forward(aatype, mask, position_ids=position_ids)
@staticmethod
def output_to_pdb(output: Dict) -> List[str]:
    """
    Return one PDB (file) string per batch element of a model output.

    Args:
        output (`Dict`):
            Model output whose values are tensors; every value is moved to CPU and converted to NumPy. The keys
            `"positions"`, `"atom37_atom_exists"`, `"aatype"`, `"residue_index"` and `"plddt"` are read.

    Returns:
        `List[str]`: the PDB strings, in batch order.
    """
    arrays = {key: value.to("cpu").numpy() for key, value in output.items()}
    # Only the last entry along the first axis of "positions" is converted to the atom37 layout.
    atom37_positions = atom14_to_atom37(arrays["positions"][-1], arrays)
    atom37_mask = arrays["atom37_atom_exists"]
    batch_size = arrays["aatype"].shape[0]
    return [
        to_pdb(
            OFProtein(
                aatype=arrays["aatype"][idx],
                atom_positions=atom37_positions[idx],
                atom_mask=atom37_mask[idx],
                # Stored residue indices are shifted by one (presumably to 1-based PDB numbering).
                residue_index=arrays["residue_index"][idx] + 1,
                b_factors=arrays["plddt"][idx],
            )
        )
        for idx in range(batch_size)
    ]
def infer_pdb(self, seqs, *args, **kwargs) -> str:
    """
    Return the PDB (file) string predicted for a single input sequence.

    Args:
        seqs (`str`):
            The sequence to fold; must be a `str`.
        *args, **kwargs:
            Forwarded unchanged to `self.infer`.

    Returns:
        `str`: the first element of `self.output_to_pdb` applied to the inference output.
    """
    assert isinstance(seqs, str)
    pdb_strings = self.output_to_pdb(self.infer(seqs, *args, **kwargs))
    return pdb_strings[0]
def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]:
    """
    Return the PDB (file) strings predicted for a batch of input sequences.

    Args:
        seqs (`List[str]`):
            The sequences to fold.
        *args, **kwargs:
            Forwarded unchanged to `self.infer`.

    Returns:
        `List[str]`: the result of `self.output_to_pdb` applied to the inference output.
    """
    return self.output_to_pdb(self.infer(seqs, *args, **kwargs))