From e4c96d9a46afd6ba36edb28af60f7b74080855db Mon Sep 17 00:00:00 2001 From: Liza Kozlova Date: Mon, 24 Jul 2023 17:30:47 +0000 Subject: [PATCH 1/3] fix: make opacity more flexible + fix some data class bugs --- proteinflow/constants.py | 1 + proteinflow/data/__init__.py | 67 ++++++++++++++++++++++++------------ proteinflow/data/utils.py | 7 +++- proteinflow/visualize.py | 23 +++++++++---- 4 files changed, 69 insertions(+), 29 deletions(-) diff --git a/proteinflow/constants.py b/proteinflow/constants.py index 6e7c6b6..4bd3111 100644 --- a/proteinflow/constants.py +++ b/proteinflow/constants.py @@ -223,6 +223,7 @@ "F": "PHE", "Y": "TYR", "W": "TRP", + "-": "GLY", } ATOM_MAP_4 = {a: ["N", "C", "CA", "O"] for a in ONE_TO_THREE_LETTER_MAP.keys()} ATOM_MAP_1 = {a: ["CA"] for a in ONE_TO_THREE_LETTER_MAP.keys()} diff --git a/proteinflow/data/__init__.py b/proteinflow/data/__init__.py index e9ccf5e..07a1006 100644 --- a/proteinflow/data/__init__.py +++ b/proteinflow/data/__init__.py @@ -13,17 +13,16 @@ import os import pickle import tempfile -import urllib import warnings from collections import defaultdict import numpy as np import pandas as pd import py3Dmol -import requests from Bio import pairwise2 from biopandas.pdb import PandasPdb from methodtools import lru_cache +from torch import Tensor from proteinflow.constants import ( _PMAP, @@ -121,6 +120,10 @@ def __init__(self, seqs, crds, masks, chain_ids, predict_masks=None, cdrs=None): `'numpy'` arrays of shape `(L,)` where CDR residues are marked with the corresponding type (`'H1'`, `'L1'`, ...) """ + if crds[0].shape[1] != 14: + raise ValueError( + "Coordinates array must have 14 atoms in the order of N, C, CA, O, sidechain atoms" + ) self.seq = {x: seq for x, seq in zip(chain_ids, seqs)} self.crd = {x: crd for x, crd in zip(chain_ids, crds)} self.mask = {x: mask for x, mask in zip(chain_ids, masks)} @@ -363,9 +366,6 @@ def get_mask(self, chains=None, cdr=None, original=False): mask : np.ndarray Mask array where 1 indicates residues with known coordinates and 0 indicates missing values - chains : list of str, optional - If specified, only the masks of the specified chains are returned (in the same order); - otherwise, all masks are concatenated in alphabetical order of the chain IDs """ if cdr is not None and self.cdr is None: @@ -465,7 +465,16 @@ def decode_cdr(cdr): residues are marked with `'-'` """ - return np.array([CDR_ALPHABET[x] for x in cdr]) + cdr = ProteinEntry._to_numpy(cdr) + return np.array([CDR_ALPHABET[x] for x in cdr.astype(int)]) + + @staticmethod + def _to_numpy(arr): + if isinstance(arr, Tensor): + arr = arr.detach().cpu().numpy() + if isinstance(arr, list): + arr = np.array(arr) + return arr @staticmethod def decode_sequence(seq): @@ -483,7 +492,8 @@ def decode_sequence(seq): Amino acid sequence of the protein (one-letter code) """ - return "".join([ALPHABET[x] for x in seq]) + seq = ProteinEntry._to_numpy(seq) + return "".join([ALPHABET[x] for x in seq.astype(int)]) def rename_chains(self, chain_dict): """Rename the chains of the protein. @@ -532,7 +542,7 @@ def from_arrays( seqs : np.ndarray Amino acid sequences of the protein (encoded as integers, see `proteinflow.constants.ALPHABET`), `'numpy'` array of shape `(L,)` crds : np.ndarray - Coordinates of the protein, `'numpy'` array of shape `(L, 14, 3)` + Coordinates of the protein, `'numpy'` array of shape `(L, 14, 3)` or `(L, 4, 3)` masks : np.ndarray Mask array where 1 indicates residues with known coordinates and 0 indicates missing values, `'numpy'` array of shape `(L,)` @@ -557,18 +567,24 @@ def from_arrays( crds_list = [] masks_list = [] chain_ids_list = [] - predict_masks_list = None if predict_masks is None else {} - cdrs_list = None if cdrs is None else {} - for ind, chain_id in chain_id_dict.items(): + predict_masks_list = None if predict_masks is None else [] + cdrs_list = None if cdrs is None else [] + if crds.shape[1] != 14: + crds_ = np.zeros((crds.shape[0], 14, 3)) + crds_[:, :4, :] = ProteinEntry._to_numpy(crds) + crds = crds_ + for chain_id, ind in chain_id_dict.items(): chain_ids_list.append(chain_id) chain_mask = chain_id_array == ind seqs_list.append(ProteinEntry.decode_sequence(seqs[chain_mask])) - crds_list.append(ProteinEntry.decode_cdr(crds[chain_mask])) - masks_list.append(masks[chain_mask]) + crds_list.append(ProteinEntry._to_numpy(crds[chain_mask])) + masks_list.append(ProteinEntry._to_numpy(masks[chain_mask])) if predict_masks is not None: - predict_masks_list.append(predict_masks[chain_mask]) + predict_masks_list.append( + ProteinEntry._to_numpy(predict_masks[chain_mask]) + ) if cdrs is not None: - cdrs_list.append(cdrs[chain_mask]) + cdrs_list.append(ProteinEntry.decode_cdr(cdrs[chain_mask])) return ProteinEntry( seqs_list, crds_list, @@ -1205,14 +1221,17 @@ def visualize(self, highlight_mask=None, style="cartoon", opacity=1): `self.predict_mask` is not `None`, the predicted residues are highlighted style : str, default 'cartoon' The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross' - opacity : float, default 1 - Opacity of the visualization + opacity : float or dict, default 1 + Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values) """ if highlight_mask is not None: highlight_mask_dict = self._get_highlight_mask_dict(highlight_mask) elif list(self.predict_mask.values())[0] is not None: - highlight_mask_dict = self.predict_mask + highlight_mask_dict = { + chain: self.predict_mask[chain][self.get_mask([chain]).astype(bool)] + for chain in self.get_chains() + } else: highlight_mask_dict = None with tempfile.NamedTemporaryFile(suffix=".pdb") as tmp: @@ -1674,14 +1693,18 @@ def _get_atom_dicts(self, highlight_mask_dict=None, style="cartoon", opacity=1): self._pdb_sequence(chain) ), "Mask length does not match sequence length" for at in outstr: + if isinstance(opacity, dict): + op_ = opacity[at["chain"]] + else: + op_ = opacity if at["resid"] != chain_last_res[at["chain"]]: chain_last_res[at["chain"]] = at["resid"] chain_counters[at["chain"]] += 1 - at["pymol"] = {style: {"color": colors[at["chain"]], "opacity": opacity}} + at["pymol"] = {style: {"color": colors[at["chain"]], "opacity": op_}} if highlight_mask_dict is not None and at["chain"] in highlight_mask_dict: num = chain_counters[at["chain"]] if highlight_mask_dict[at["chain"]][num - 1] == 1: - at["pymol"] = {style: {"color": "red", "opacity": opacity}} + at["pymol"] = {style: {"color": "red", "opacity": op_}} return outstr def visualize(self, highlight_mask_dict=None, style="cartoon", opacity=1): @@ -1694,8 +1717,8 @@ def visualize(self, highlight_mask_dict=None, style="cartoon", opacity=1): the atoms corresponding to 1s will be highlighted in red style : str, default 'cartoon' The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross' - opacity : float, default 1 - Opacity of the visualization + opacity : float or dict, default 1 + Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values) """ outstr = self._get_atom_dicts(highlight_mask_dict, style=style, opacity=opacity) diff --git a/proteinflow/data/utils.py b/proteinflow/data/utils.py index 9e50b5a..7caa6f8 100644 --- a/proteinflow/data/utils.py +++ b/proteinflow/data/utils.py @@ -41,6 +41,11 @@ def __init__( """ seq = protein_entry.get_sequence() coords = protein_entry.get_coordinates() + mask = protein_entry.get_mask().astype(bool) + seq = "".join(np.array(list(seq))[mask]) + coords = coords[mask] + if (coords[:, 4:] == 0).all(): + only_backbone = True if only_ca: coords = coords[:, 2, :].unsqueeze(1) elif skip_oxygens: @@ -66,7 +71,7 @@ def __init__( self.seq = seq self.mapping = self._make_mapping_from_seq() - self.chain_ids = protein_entry.get_chain_id_array(encode=False) + self.chain_ids = protein_entry.get_chain_id_array(encode=False)[mask] self.chain_ids_unique = protein_entry.get_chains() # PDB Formatting Information diff --git a/proteinflow/visualize.py b/proteinflow/visualize.py index 7c0a7e1..10b49e2 100644 --- a/proteinflow/visualize.py +++ b/proteinflow/visualize.py @@ -124,17 +124,22 @@ def show_merged_pickle(file_paths, highlight_masks=None, style="cartoon", opacit the chains to be concatenated in alphabetical order. style : str, optional The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross' - opacity : float, default 1 + opacity : float or list, default 1 The opacity of the visualization. """ create_fn = ProteinEntry.from_pickle entries = [create_fn(path) for path in file_paths] alphabet = list(string.ascii_uppercase) + opacity_dict = {} + if isinstance(opacity, float): + opacity = [opacity] * len(entries) for i, entry in enumerate(entries): - entry.rename_chains({chain: alphabet.pop(0) for chain in entry.get_chains()}) + update_dict = {chain: alphabet.pop(0) for chain in entry.get_chains()} + entry.rename_chains(update_dict) + opacity_dict.update({chain: opacity[i] for chain in entry.get_chains()}) if highlight_masks is not None and highlight_masks[i] is None: - highlight_masks[i] = np.zeros(len(entry)) + highlight_masks[i] = np.zeros(entry.get_mask().sum()) merged_entry = entries[0] for entry in entries[1:]: merged_entry.merge(entry) @@ -142,7 +147,9 @@ def show_merged_pickle(file_paths, highlight_masks=None, style="cartoon", opacit highlight_mask = np.concatenate(highlight_masks, axis=0) else: highlight_mask = None - merged_entry.visualize(style=style, highlight_mask=highlight_mask, opacity=opacity) + merged_entry.visualize( + style=style, highlight_mask=highlight_mask, opacity=opacity_dict + ) def show_merged_pdb(file_paths, highlight_mask_dicts=None, style="cartoon", opacity=1): @@ -157,7 +164,7 @@ def show_merged_pdb(file_paths, highlight_mask_dicts=None, style="cartoon", opac 1s and 0s, where 1s indicate the atoms to highlight. style : str, optional The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross' - opacity : float, default 1 + opacity : float or list, default 1 The opacity of the visualization. """ @@ -165,9 +172,13 @@ def show_merged_pdb(file_paths, highlight_mask_dicts=None, style="cartoon", opac entries = [create_fn(path) for path in file_paths] alphabet = list(string.ascii_uppercase) highlight_mask_dict = {} if highlight_mask_dicts is not None else None + opacity_dict = {} + if isinstance(opacity, float): + opacity = [opacity] * len(entries) for i, entry in enumerate(entries): update_dict = {chain: alphabet.pop(0) for chain in entry.get_chains()} entry.rename_chains(update_dict) + opacity_dict.update({chain: opacity[i] for chain in entry.get_chains()}) if highlight_mask_dicts is not None: highlight_mask_dict.update( { @@ -179,5 +190,5 @@ def show_merged_pdb(file_paths, highlight_mask_dicts=None, style="cartoon", opac for entry in entries[1:]: merged_entry.merge(entry) merged_entry.visualize( - style=style, highlight_mask_dict=highlight_mask_dict, opacity=opacity + style=style, highlight_mask_dict=highlight_mask_dict, opacity=opacity_dict ) From 35bb3b1f54ab3bdc3e72fc94402507b34a2b13b2 Mon Sep 17 00:00:00 2001 From: Liza Kozlova Date: Mon, 24 Jul 2023 18:02:39 +0000 Subject: [PATCH 2/3] fix: minor bugs --- proteinflow/data/__init__.py | 4 ++++ proteinflow/data/torch.py | 46 ------------------------------------ 2 files changed, 4 insertions(+), 46 deletions(-) diff --git a/proteinflow/data/__init__.py b/proteinflow/data/__init__.py index 07a1006..adf8dd1 100644 --- a/proteinflow/data/__init__.py +++ b/proteinflow/data/__init__.py @@ -530,6 +530,10 @@ def merge(self, entry): self.mask_original[chain] = entry.mask_original[chain] self.cdr[chain] = entry.cdr[chain] self.predict_mask[chain] = entry.predict_mask[chain] + if not all([x is None for x in entry.predict_mask.values()]): + for k, v in self.predict_mask.items(): + if v is None: + self.predict_mask[k] = np.zeros(len(self.get_sequence(k))) @staticmethod def from_arrays( diff --git a/proteinflow/data/torch.py b/proteinflow/data/torch.py index 6771035..cf8d105 100644 --- a/proteinflow/data/torch.py +++ b/proteinflow/data/torch.py @@ -19,52 +19,6 @@ class _PadCollate: """A variant of `collate_fn` that pads according to the longest sequence in a batch of sequences.""" - def __init__( - self, - mask_residues=True, - lower_limit=15, - upper_limit=100, - mask_frac=None, - mask_whole_chains=False, - force_binding_sites_frac=0.15, - mask_all_cdrs=False, - ): - """Initialize a _PadCollate object. - - Parameters - ---------- - batch : dict - a batch generated by `ProteinDataset` and `PadCollate` - lower_limit : int, default 15 - the minimum number of residues to mask - upper_limit : int, default 100 - the maximum number of residues to mask - mask_frac : float, optional - if given, the `lower_limit` and `upper_limit` are ignored and the number of residues to mask is `mask_frac` times the length of the chain - mask_whole_chains : bool, default False - if `True`, `upper_limit`, `force_binding_sites` and `lower_limit` are ignored and the whole chain is masked instead - force_binding_sites_frac : float, default 0.15 - if > 0, in the fraction of cases where a chain from a polymer is sampled, the center of the masked region will be - forced to be in a binding site - mask_all_cdrs : bool, default False - if `True`, all CDRs are masked - - Returns - ------- - chain_M : torch.Tensor - a `(B, L)` shaped binary tensor where 1 denotes the part that needs to be predicted and - 0 is everything else - - """ - super().__init__() - self.mask_residues = mask_residues - self.lower_limit = lower_limit - self.upper_limit = upper_limit - self.mask_frac = mask_frac - self.mask_whole_chains = mask_whole_chains - self.force_binding_sites_frac = force_binding_sites_frac - self.mask_all_cdrs = mask_all_cdrs - def pad_collate(self, batch): # find longest sequence out = {} From 4a353bce61471737b6d278cfee036374cca24824 Mon Sep 17 00:00:00 2001 From: Liza Kozlova Date: Mon, 24 Jul 2023 18:13:21 +0000 Subject: [PATCH 3/3] fix: show predict mask in animation by default --- proteinflow/data/__init__.py | 29 +++++++++++++++++++++++++++++ proteinflow/visualize.py | 2 ++ 2 files changed, 31 insertions(+) diff --git a/proteinflow/data/__init__.py b/proteinflow/data/__init__.py index adf8dd1..beadb1a 100644 --- a/proteinflow/data/__init__.py +++ b/proteinflow/data/__init__.py @@ -1214,6 +1214,33 @@ def _get_atom_dicts(self, highlight_mask=None, style="cartoon", opacity=1): highlight_mask_dict=highlight_mask_dict, style=style, opacity=opacity ) + def get_predict_mask(self, chains=None): + """Get the prediction mask of the protein. + + The prediction mask is a `'numpy'` array of shape `(L,)` with ones + corresponding to residues that were generated by a model and zeros to + residues with known coordinates. If the prediction mask is not available, + `None` is returned. + + Parameters + ---------- + chains : list of str, optional + If specified, only the prediction mask of the specified chains is returned (in the same order); + otherwise, all features are concatenated in alphabetical order of the chain IDs + + Returns + ------- + predict_mask : np.ndarray + A `'numpy'` array of shape `(L,)` with ones corresponding to residues that were generated by a model and + zeros to residues with known coordinates + + """ + if list(self.predict_mask.values())[0] is None: + return None + chains = self._get_chains_list(chains) + predict_mask = np.concatenate([self.predict_mask[chain] for chain in chains]) + return predict_mask + def visualize(self, highlight_mask=None, style="cartoon", opacity=1): """Visualize the protein in a notebook. @@ -1229,9 +1256,11 @@ def visualize(self, highlight_mask=None, style="cartoon", opacity=1): Opacity of the visualization (can be a dictionary mapping from chain IDs to opacity values) """ + print(f"{highlight_mask=}") if highlight_mask is not None: highlight_mask_dict = self._get_highlight_mask_dict(highlight_mask) elif list(self.predict_mask.values())[0] is not None: + print("HERE") highlight_mask_dict = { chain: self.predict_mask[chain][self.get_mask([chain]).astype(bool)] for chain in self.get_chains() diff --git a/proteinflow/visualize.py b/proteinflow/visualize.py index 10b49e2..1fb4806 100644 --- a/proteinflow/visualize.py +++ b/proteinflow/visualize.py @@ -70,6 +70,8 @@ def show_animation_from_pickle( models = "" for i, mol in enumerate(entries): models += "MODEL " + str(i) + "\n" + if highlight_mask is None: + highlight_mask = mol.get_predict_mask() atoms = mol._get_atom_dicts( highlight_mask=highlight_mask, style=style, opacity=opacity )