Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs and improve visualization #100

Merged
merged 5 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions proteinflow/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@
"F": "PHE",
"Y": "TYR",
"W": "TRP",
"-": "GLY",
}
ATOM_MAP_4 = {a: ["N", "C", "CA", "O"] for a in ONE_TO_THREE_LETTER_MAP.keys()}
ATOM_MAP_1 = {a: ["CA"] for a in ONE_TO_THREE_LETTER_MAP.keys()}
Expand Down
100 changes: 78 additions & 22 deletions proteinflow/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@
import os
import pickle
import tempfile
import urllib
import warnings
from collections import defaultdict

import numpy as np
import pandas as pd
import py3Dmol
import requests
from Bio import pairwise2
from biopandas.pdb import PandasPdb
from methodtools import lru_cache
from torch import Tensor

from proteinflow.constants import (
_PMAP,
Expand Down Expand Up @@ -121,6 +120,10 @@ def __init__(self, seqs, crds, masks, chain_ids, predict_masks=None, cdrs=None):
`'numpy'` arrays of shape `(L,)` where CDR residues are marked with the corresponding type (`'H1'`, `'L1'`, ...)

"""
if crds[0].shape[1] != 14:
raise ValueError(
"Coordinates array must have 14 atoms in the order of N, C, CA, O, sidechain atoms"
)
self.seq = {x: seq for x, seq in zip(chain_ids, seqs)}
self.crd = {x: crd for x, crd in zip(chain_ids, crds)}
self.mask = {x: mask for x, mask in zip(chain_ids, masks)}
Expand Down Expand Up @@ -363,9 +366,6 @@ def get_mask(self, chains=None, cdr=None, original=False):
mask : np.ndarray
Mask array where 1 indicates residues with known coordinates and 0
indicates missing values
chains : list of str, optional
If specified, only the masks of the specified chains are returned (in the same order);
otherwise, all masks are concatenated in alphabetical order of the chain IDs

"""
if cdr is not None and self.cdr is None:
Expand Down Expand Up @@ -465,7 +465,16 @@ def decode_cdr(cdr):
residues are marked with `'-'`

"""
return np.array([CDR_ALPHABET[x] for x in cdr])
cdr = ProteinEntry._to_numpy(cdr)
return np.array([CDR_ALPHABET[x] for x in cdr.astype(int)])

@staticmethod
def _to_numpy(arr):
    """Convert a tensor or a plain Python sequence to a `'numpy'` array.

    Parameters
    ----------
    arr : torch.Tensor or list or tuple or np.ndarray
        The object to convert

    Returns
    -------
    arr : np.ndarray
        The converted array; objects that are neither a `torch.Tensor`, a `list`
        nor a `tuple` (e.g. an existing `'numpy'` array) are returned unchanged

    """
    if isinstance(arr, Tensor):
        # drop the autograd graph and move to host memory before converting
        return arr.detach().cpu().numpy()
    # tuples are converted too, so that callers can rely on array methods
    # such as `.astype` being available on the result
    if isinstance(arr, (list, tuple)):
        return np.array(arr)
    return arr

@staticmethod
def decode_sequence(seq):
Expand All @@ -483,7 +492,8 @@ def decode_sequence(seq):
Amino acid sequence of the protein (one-letter code)

"""
return "".join([ALPHABET[x] for x in seq])
seq = ProteinEntry._to_numpy(seq)
return "".join([ALPHABET[x] for x in seq.astype(int)])

def rename_chains(self, chain_dict):
"""Rename the chains of the protein.
Expand Down Expand Up @@ -520,6 +530,10 @@ def merge(self, entry):
self.mask_original[chain] = entry.mask_original[chain]
self.cdr[chain] = entry.cdr[chain]
self.predict_mask[chain] = entry.predict_mask[chain]
if not all([x is None for x in entry.predict_mask.values()]):
for k, v in self.predict_mask.items():
if v is None:
self.predict_mask[k] = np.zeros(len(self.get_sequence(k)))

@staticmethod
def from_arrays(
Expand All @@ -532,7 +546,7 @@ def from_arrays(
seqs : np.ndarray
Amino acid sequences of the protein (encoded as integers, see `proteinflow.constants.ALPHABET`), `'numpy'` array of shape `(L,)`
crds : np.ndarray
Coordinates of the protein, `'numpy'` array of shape `(L, 14, 3)`
Coordinates of the protein, `'numpy'` array of shape `(L, 14, 3)` or `(L, 4, 3)`
masks : np.ndarray
Mask array where 1 indicates residues with known coordinates and 0
indicates missing values, `'numpy'` array of shape `(L,)`
Expand All @@ -557,18 +571,24 @@ def from_arrays(
crds_list = []
masks_list = []
chain_ids_list = []
predict_masks_list = None if predict_masks is None else {}
cdrs_list = None if cdrs is None else {}
for ind, chain_id in chain_id_dict.items():
predict_masks_list = None if predict_masks is None else []
cdrs_list = None if cdrs is None else []
if crds.shape[1] != 14:
crds_ = np.zeros((crds.shape[0], 14, 3))
crds_[:, :4, :] = ProteinEntry._to_numpy(crds)
crds = crds_
for chain_id, ind in chain_id_dict.items():
chain_ids_list.append(chain_id)
chain_mask = chain_id_array == ind
seqs_list.append(ProteinEntry.decode_sequence(seqs[chain_mask]))
crds_list.append(ProteinEntry.decode_cdr(crds[chain_mask]))
masks_list.append(masks[chain_mask])
crds_list.append(ProteinEntry._to_numpy(crds[chain_mask]))
masks_list.append(ProteinEntry._to_numpy(masks[chain_mask]))
if predict_masks is not None:
predict_masks_list.append(predict_masks[chain_mask])
predict_masks_list.append(
ProteinEntry._to_numpy(predict_masks[chain_mask])
)
if cdrs is not None:
cdrs_list.append(cdrs[chain_mask])
cdrs_list.append(ProteinEntry.decode_cdr(cdrs[chain_mask]))
return ProteinEntry(
seqs_list,
crds_list,
Expand Down Expand Up @@ -1194,6 +1214,33 @@ def _get_atom_dicts(self, highlight_mask=None, style="cartoon", opacity=1):
highlight_mask_dict=highlight_mask_dict, style=style, opacity=opacity
)

def get_predict_mask(self, chains=None):
    """Get the prediction mask of the protein.

    The prediction mask is a `'numpy'` array of shape `(L,)` with ones
    corresponding to residues that were generated by a model and zeros to
    residues with known coordinates. If the prediction mask is not available
    (for the whole entry or for any of the requested chains), `None` is returned.

    Parameters
    ----------
    chains : list of str, optional
        If specified, only the prediction masks of the specified chains are returned (in the same order);
        otherwise, the masks of all chains are concatenated in the order given by `self._get_chains_list`

    Returns
    -------
    predict_mask : np.ndarray or None
        A `'numpy'` array of shape `(L,)` with ones corresponding to residues that were generated by a model and
        zeros to residues with known coordinates, or `None` if no prediction mask is available

    """
    # `all` over an empty dict is True, so an entry without any chains also yields None
    if all(mask is None for mask in self.predict_mask.values()):
        return None
    chains = self._get_chains_list(chains)
    chain_masks = [self.predict_mask[chain] for chain in chains]
    # check the chains that were actually requested: `np.concatenate` cannot handle `None`
    if any(mask is None for mask in chain_masks):
        return None
    return np.concatenate(chain_masks)

def visualize(self, highlight_mask=None, style="cartoon", opacity=1):
"""Visualize the protein in a notebook.

Expand All @@ -1205,14 +1252,19 @@ def visualize(self, highlight_mask=None, style="cartoon", opacity=1):
`self.predict_mask` is not `None`, the predicted residues are highlighted
style : str, default 'cartoon'
The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
opacity : float, default 1
Opacity of the visualization
opacity : float or dict, default 1
Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values)

"""
print(f"{highlight_mask=}")
if highlight_mask is not None:
highlight_mask_dict = self._get_highlight_mask_dict(highlight_mask)
elif list(self.predict_mask.values())[0] is not None:
highlight_mask_dict = self.predict_mask
print("HERE")
highlight_mask_dict = {
chain: self.predict_mask[chain][self.get_mask([chain]).astype(bool)]
for chain in self.get_chains()
}
else:
highlight_mask_dict = None
with tempfile.NamedTemporaryFile(suffix=".pdb") as tmp:
Expand Down Expand Up @@ -1674,14 +1726,18 @@ def _get_atom_dicts(self, highlight_mask_dict=None, style="cartoon", opacity=1):
self._pdb_sequence(chain)
), "Mask length does not match sequence length"
for at in outstr:
if isinstance(opacity, dict):
op_ = opacity[at["chain"]]
else:
op_ = opacity
if at["resid"] != chain_last_res[at["chain"]]:
chain_last_res[at["chain"]] = at["resid"]
chain_counters[at["chain"]] += 1
at["pymol"] = {style: {"color": colors[at["chain"]], "opacity": opacity}}
at["pymol"] = {style: {"color": colors[at["chain"]], "opacity": op_}}
if highlight_mask_dict is not None and at["chain"] in highlight_mask_dict:
num = chain_counters[at["chain"]]
if highlight_mask_dict[at["chain"]][num - 1] == 1:
at["pymol"] = {style: {"color": "red", "opacity": opacity}}
at["pymol"] = {style: {"color": "red", "opacity": op_}}
return outstr

def visualize(self, highlight_mask_dict=None, style="cartoon", opacity=1):
Expand All @@ -1694,8 +1750,8 @@ def visualize(self, highlight_mask_dict=None, style="cartoon", opacity=1):
the atoms corresponding to 1s will be highlighted in red
style : str, default 'cartoon'
The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
opacity : float, default 1
Opacity of the visualization
opacity : float or dict, default 1
Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values)

"""
outstr = self._get_atom_dicts(highlight_mask_dict, style=style, opacity=opacity)
Expand Down
46 changes: 0 additions & 46 deletions proteinflow/data/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,52 +19,6 @@
class _PadCollate:
"""A variant of `collate_fn` that pads according to the longest sequence in a batch of sequences."""

def __init__(
self,
mask_residues=True,
lower_limit=15,
upper_limit=100,
mask_frac=None,
mask_whole_chains=False,
force_binding_sites_frac=0.15,
mask_all_cdrs=False,
):
"""Initialize a _PadCollate object.

Parameters
----------
batch : dict
a batch generated by `ProteinDataset` and `PadCollate`
lower_limit : int, default 15
the minimum number of residues to mask
upper_limit : int, default 100
the maximum number of residues to mask
mask_frac : float, optional
if given, the `lower_limit` and `upper_limit` are ignored and the number of residues to mask is `mask_frac` times the length of the chain
mask_whole_chains : bool, default False
if `True`, `upper_limit`, `force_binding_sites` and `lower_limit` are ignored and the whole chain is masked instead
force_binding_sites_frac : float, default 0.15
if > 0, in the fraction of cases where a chain from a polymer is sampled, the center of the masked region will be
forced to be in a binding site
mask_all_cdrs : bool, default False
if `True`, all CDRs are masked

Returns
-------
chain_M : torch.Tensor
a `(B, L)` shaped binary tensor where 1 denotes the part that needs to be predicted and
0 is everything else

"""
super().__init__()
self.mask_residues = mask_residues
self.lower_limit = lower_limit
self.upper_limit = upper_limit
self.mask_frac = mask_frac
self.mask_whole_chains = mask_whole_chains
self.force_binding_sites_frac = force_binding_sites_frac
self.mask_all_cdrs = mask_all_cdrs

def pad_collate(self, batch):
# find longest sequence
out = {}
Expand Down
7 changes: 6 additions & 1 deletion proteinflow/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def __init__(
"""
seq = protein_entry.get_sequence()
coords = protein_entry.get_coordinates()
mask = protein_entry.get_mask().astype(bool)
seq = "".join(np.array(list(seq))[mask])
coords = coords[mask]
if (coords[:, 4:] == 0).all():
only_backbone = True
if only_ca:
coords = coords[:, 2, :].unsqueeze(1)
elif skip_oxygens:
Expand All @@ -66,7 +71,7 @@ def __init__(
self.seq = seq
self.mapping = self._make_mapping_from_seq()

self.chain_ids = protein_entry.get_chain_id_array(encode=False)
self.chain_ids = protein_entry.get_chain_id_array(encode=False)[mask]
self.chain_ids_unique = protein_entry.get_chains()

# PDB Formatting Information
Expand Down
25 changes: 19 additions & 6 deletions proteinflow/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def show_animation_from_pickle(
models = ""
for i, mol in enumerate(entries):
models += "MODEL " + str(i) + "\n"
if highlight_mask is None:
highlight_mask = mol.get_predict_mask()
atoms = mol._get_atom_dicts(
highlight_mask=highlight_mask, style=style, opacity=opacity
)
Expand Down Expand Up @@ -124,25 +126,32 @@ def show_merged_pickle(file_paths, highlight_masks=None, style="cartoon", opacit
the chains to be concatenated in alphabetical order.
style : str, optional
The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
opacity : float, default 1
opacity : float or list, default 1
The opacity of the visualization.

"""
create_fn = ProteinEntry.from_pickle
entries = [create_fn(path) for path in file_paths]
alphabet = list(string.ascii_uppercase)
opacity_dict = {}
if isinstance(opacity, float):
opacity = [opacity] * len(entries)
for i, entry in enumerate(entries):
entry.rename_chains({chain: alphabet.pop(0) for chain in entry.get_chains()})
update_dict = {chain: alphabet.pop(0) for chain in entry.get_chains()}
entry.rename_chains(update_dict)
opacity_dict.update({chain: opacity[i] for chain in entry.get_chains()})
if highlight_masks is not None and highlight_masks[i] is None:
highlight_masks[i] = np.zeros(len(entry))
highlight_masks[i] = np.zeros(entry.get_mask().sum())
merged_entry = entries[0]
for entry in entries[1:]:
merged_entry.merge(entry)
if highlight_masks is not None:
highlight_mask = np.concatenate(highlight_masks, axis=0)
else:
highlight_mask = None
merged_entry.visualize(style=style, highlight_mask=highlight_mask, opacity=opacity)
merged_entry.visualize(
style=style, highlight_mask=highlight_mask, opacity=opacity_dict
)


def show_merged_pdb(file_paths, highlight_mask_dicts=None, style="cartoon", opacity=1):
Expand All @@ -157,17 +166,21 @@ def show_merged_pdb(file_paths, highlight_mask_dicts=None, style="cartoon", opac
1s and 0s, where 1s indicate the atoms to highlight.
style : str, optional
The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'
opacity : float, default 1
opacity : float or list, default 1
The opacity of the visualization.

"""
create_fn = PDBEntry
entries = [create_fn(path) for path in file_paths]
alphabet = list(string.ascii_uppercase)
highlight_mask_dict = {} if highlight_mask_dicts is not None else None
opacity_dict = {}
if isinstance(opacity, float):
opacity = [opacity] * len(entries)
for i, entry in enumerate(entries):
update_dict = {chain: alphabet.pop(0) for chain in entry.get_chains()}
entry.rename_chains(update_dict)
opacity_dict.update({chain: opacity[i] for chain in entry.get_chains()})
if highlight_mask_dicts is not None:
highlight_mask_dict.update(
{
Expand All @@ -179,5 +192,5 @@ def show_merged_pdb(file_paths, highlight_mask_dicts=None, style="cartoon", opac
for entry in entries[1:]:
merged_entry.merge(entry)
merged_entry.visualize(
style=style, highlight_mask_dict=highlight_mask_dict, opacity=opacity
style=style, highlight_mask_dict=highlight_mask_dict, opacity=opacity_dict
)