Skip to content

Commit

Permalink
Merge pull request #96 from adaptyvbio/visualize
Browse files Browse the repository at this point in the history
Add visualization functions
  • Loading branch information
elkoz authored Jul 21, 2023
2 parents 113581f + 2a4d43a commit 446a8e5
Show file tree
Hide file tree
Showing 8 changed files with 290 additions and 19 deletions.
Binary file added 1afv.cif.gz
Binary file not shown.
6 changes: 6 additions & 0 deletions 1afv.fasta
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
>1AFV_1|Chains A, B|HUMAN IMMUNODEFICIENCY VIRUS TYPE 1 CAPSID PROTEIN|Human immunodeficiency virus 1 (11676)
PIVQNLQGQMVHQAISPRTLNAWVKVVEEKAFSPEVIPMFSALSEGATPQDLNTMLNTVGGHQAAMQMLKETINEEAAEWDRLHPVHAGPIAPGQMREPRGSDIAGTTSTLQEQIGWMTHNPPIPVGEIYKRWIILGLNKIVRMYSPTSIL
>1AFV_2|Chains C[auth L], E[auth M]|ANTIBODY FAB25.3 FRAGMENT (LIGHT CHAIN)|Mus musculus (10090)
DIVLTQSPASLAVSLGQRATISCRASESVDNYGISFMNWFQQKPGQPPKLLIYAASNLGSGVPARFSGSGSGTDFSLNIHPMEEEDTAMYFCQQSKEVPLTFGAGTKVELKRADAAPTVSIFPPSSEQLTSGGASVVCFLNNFYPKDINVKWKIDGSERQNGVLNSWTDQDSKDSTYSMSSTLTLTKDEYERHNSYTCEATHKTSTSPIVKSFNRNE
>1AFV_3|Chains D[auth H], F[auth K]|ANTIBODY FAB25.3 FRAGMENT (HEAVY CHAIN)|Mus musculus (10090)
QVQLQQPGSVLVRPGASVKLSCKASGYTFTSSWIHWAKQRPGQGLEWIGEIHPNSGNTNYNEKFKGKATLTVDTSSSTAYVDLSSLTSEDSAVYYCARWRYGSPYYFDYWGQGTTLTVSSAKTTPPSVYPLAPGSAAQTNSMVTLGCLVKGYFPEPVTVTWNSGSLSSGVHTFPAVLQSDLYTLSSSVTVPSSTWPSETVTCNVAHPASSTKVDKKIVPK
4 changes: 4 additions & 0 deletions proteinflow/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@


def _PMAP(x):
"""Get chemical properties of amino acid."""
return [
FEATURES_DICT["hydropathy"][x] / 5,
FEATURES_DICT["volume"][x] / 200,
Expand All @@ -239,3 +240,6 @@ def _PMAP(x):
FEATURES_DICT["acceptor"][x],
FEATURES_DICT["donor"][x],
]


# Palette of hex colors used for visualization: each chain is assigned the
# color at its position in this list, wrapping around when there are more
# chains than colors.
COLORS = ["#62B9DC", "#EAB1C5", "#0C9094", "#8090E6", "#96E396", "#FCAC97", "#740E66"]
191 changes: 172 additions & 19 deletions proteinflow/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
"""
import os
import pickle
import tempfile
import urllib
import warnings
from collections import defaultdict

import numpy as np
import pandas as pd
import py3Dmol
import requests
from Bio import pairwise2
from biopandas.pdb import PandasPdb
Expand All @@ -31,6 +34,7 @@
CDR_ALPHABET,
CDR_REVERSE,
CDR_VALUES,
COLORS,
D3TO1,
MAIN_ATOM_DICT,
SIDECHAIN_ORDER,
Expand All @@ -40,6 +44,7 @@
PDBBuilder,
PDBError,
_annotate_sse,
_Atom,
_dihedral_angle,
_retrieve_chain_names,
)
Expand Down Expand Up @@ -506,6 +511,44 @@ def from_dict(dictionary):
cdr = [dictionary[k].get("cdr", None) for k in chains]
return ProteinEntry(seqs=seq, crds=crd, masks=mask, cdrs=cdr, chain_ids=chains)

@staticmethod
def from_pdb_entry(pdb_entry):
    """Build a `ProteinEntry` from an already parsed `PDBEntry`.

    Parameters
    ----------
    pdb_entry : PDBEntry
        A `PDBEntry` object (CDR annotations are added when it is a
        `SAbDabEntry`)

    Returns
    -------
    entry : ProteinEntry
        A `ProteinEntry` object

    """
    sequences = pdb_entry.get_fasta()
    with_cdr = isinstance(pdb_entry, SAbDabEntry)
    chain_dicts = {}
    for (chain,) in pdb_entry.get_chains():
        sequence = sequences[chain]
        mask = pdb_entry.get_mask([chain])[chain]

        record = {}
        if with_cdr:
            record["cdr"] = pdb_entry.get_cdr([chain])[chain]
        record["seq"] = sequence
        record["msk"] = mask

        # the first four atom slots are stored as "crd_bb", the rest as "crd_sc"
        coords = pdb_entry.get_coordinates_array(chain)
        backbone = coords[:, :4, :]
        record["crd_bb"] = backbone
        record["crd_sc"] = coords[:, 4:, :]

        # a residue with any "crd_bb" coordinate equal to exactly 0 is masked out
        # (the mask array is modified in place)
        incomplete = (backbone == 0).sum(-1).sum(-1) > 0
        mask[incomplete] = 0

        chain_dicts[chain] = record
    return ProteinEntry.from_dict(chain_dicts)

@staticmethod
def from_pdb(
pdb_path,
Expand Down Expand Up @@ -546,29 +589,48 @@ def from_pdb(
)
else:
pdb_entry = PDBEntry(pdb_path=pdb_path, fasta_path=fasta_path)
pdb_dict = {}
fasta_dict = pdb_entry.get_fasta()
return ProteinEntry.from_pdb_entry(pdb_entry)

for (chain,) in pdb_entry.get_chains():
pdb_dict[chain] = {}
fasta_seq = fasta_dict[chain]
@staticmethod
def from_id(
pdb_id,
local_folder=".",
heavy_chain=None,
light_chain=None,
antigen_chains=None,
):
"""Load a protein entry from a PDB ID.
# align fasta and pdb and check criteria)
mask = pdb_entry.get_mask([chain])[chain]
if isinstance(pdb_entry, SAbDabEntry):
pdb_dict[chain]["cdr"] = pdb_entry.get_cdr([chain])[chain]
pdb_dict[chain]["seq"] = fasta_seq
pdb_dict[chain]["msk"] = mask
Parameters
----------
pdb_id : str
PDB ID of the protein
local_folder : str, default "."
Path to the local folder where the PDB file is saved
heavy_chain : str, optional
Chain ID of the heavy chain (to load a SAbDab entry)
light_chain : str, optional
Chain ID of the light chain (to load a SAbDab entry)
antigen_chains : list of str, optional
Chain IDs of the antigen chains (to load a SAbDab entry)
# go over rows of coordinates
crd_arr = pdb_entry.get_coordinates_array(chain)
Returns
-------
entry : ProteinEntry
A `ProteinEntry` object
pdb_dict[chain]["crd_bb"] = crd_arr[:, :4, :]
pdb_dict[chain]["crd_sc"] = crd_arr[:, 4:, :]
pdb_dict[chain]["msk"][
(pdb_dict[chain]["crd_bb"] == 0).sum(-1).sum(-1) > 0
] = 0
return ProteinEntry.from_dict(pdb_dict)
"""
if heavy_chain is not None or light_chain is not None:
pdb_entry = SAbDabEntry.from_id(
pdb_id=pdb_id,
local_folder=local_folder,
heavy_chain=heavy_chain,
light_chain=light_chain,
antigen_chains=antigen_chains,
)
else:
# NOTE(review): `local_folder` is not forwarded in this branch, unlike the
# SAbDab branch above — confirm whether `PDBEntry.from_id` accepts it.
pdb_entry = PDBEntry.from_id(pdb_id=pdb_id)
return ProteinEntry.from_pdb_entry(pdb_entry)

@staticmethod
def from_pickle(path):
Expand Down Expand Up @@ -991,6 +1053,46 @@ def get_chain_id_array(self, chains=None, encode=True):
start_index += chain_length
return index_array

def _get_highlight_mask_dict(self, highlight_mask=None):
    """Split a flat highlight mask into one mask per chain.

    For every chain, only the positions that belong to that chain and
    whose entry mask is non-zero are kept. An empty dictionary is
    returned when no highlight mask is given.
    """
    chain_ids = self.get_chain_id_array(encode=False)
    known = self.get_mask().astype(bool)
    if highlight_mask is None:
        return {}
    return {
        chain: highlight_mask[known & (chain_ids == chain)]
        for chain in self.get_chains()
    }

def _get_atom_dicts(self, highlight_mask=None, style="cartoon"):
    """Return the styled atom dictionaries used to render this protein.

    The entry is written to a temporary PDB file and parsed back as a
    `PDBEntry`, whose `_get_atom_dicts` does the actual work.
    """
    per_chain_masks = self._get_highlight_mask_dict(highlight_mask)
    with tempfile.NamedTemporaryFile(suffix=".pdb") as pdb_file:
        self.to_pdb(pdb_file.name)
        parsed_entry = PDBEntry(pdb_file.name)
        atom_dicts = parsed_entry._get_atom_dicts(
            highlight_mask_dict=per_chain_masks, style=style
        )
    return atom_dicts

def visualize(self, highlight_mask=None, style="cartoon"):
    """Visualize the protein in a notebook.

    Parameters
    ----------
    highlight_mask : np.ndarray, optional
        A `'numpy'` array of shape `(L,)` with the residues to highlight
        marked with 1 and the rest marked with 0
    style : str, default 'cartoon'
        The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'

    """
    highlight_mask_dict = self._get_highlight_mask_dict(highlight_mask)
    # round-trip through a temporary PDB file so that the rendering code of
    # `PDBEntry` can be reused
    with tempfile.NamedTemporaryFile(suffix=".pdb") as tmp:
        self.to_pdb(tmp.name)
        pdb_entry = PDBEntry(tmp.name)
        # forward `style` as well; it used to be dropped here, so the
        # requested style was silently replaced by the default
        pdb_entry.visualize(highlight_mask_dict=highlight_mask_dict, style=style)


class PDBEntry:
"""A class for parsing PDB entries."""
Expand Down Expand Up @@ -1355,6 +1457,57 @@ def get_unique_residue_numbers(self, chain):
"""
return self.get_pdb_df(chain)["unique_residue_number"].unique().tolist()

def _get_atom_dicts(self, highlight_mask_dict=None, style="cartoon"):
    """Build one styled `_Atom` dictionary per row of the coordinate table.

    Atoms are ordered by chain and residue number. Every atom gets a
    `"pymol"` entry mapping `style` to the color of its chain; atoms of
    residues marked with 1 in `highlight_mask_dict` are colored red instead.
    """
    assert style in ["cartoon", "sphere", "stick", "line", "cross"]
    ordered = self.crd_df.sort_values(["chain_id", "residue_number"], inplace=False)
    atoms = [_Atom(row) for _, row in ordered.iterrows()]
    palette = {
        name: COLORS[pos % len(COLORS)] for pos, name in enumerate(self.get_chains())
    }
    use_highlight = highlight_mask_dict is not None
    if use_highlight:
        for chain, mask in highlight_mask_dict.items():
            assert len(mask) == len(
                self._pdb_sequence(chain)
            ), "Mask length does not match sequence length"

    # Running 1-based residue position per chain.
    # NOTE(review): a new residue is detected only when the residue number
    # changes, so consecutive residues sharing a number would be counted
    # once — confirm this cannot happen for the inputs used here.
    residue_position = defaultdict(int)
    previous_resid = defaultdict(lambda: None)
    for atom in atoms:
        chain = atom["chain"]
        if atom["resid"] != previous_resid[chain]:
            previous_resid[chain] = atom["resid"]
            residue_position[chain] += 1
        color = palette[chain]
        if use_highlight and chain in highlight_mask_dict:
            if highlight_mask_dict[chain][residue_position[chain] - 1] == 1:
                color = "red"
        atom["pymol"] = {style: {"color": color}}
    return atoms

def visualize(self, highlight_mask_dict=None, style="cartoon"):
    """Visualize the protein in a notebook.

    Parameters
    ----------
    highlight_mask_dict : dict, optional
        A dictionary mapping from chain IDs to a mask of 0s and 1s of the same length as the chain sequence;
        the atoms corresponding to 1s will be highlighted in red
    style : str, default 'cartoon'
        The style of the visualization; one of 'cartoon', 'sphere', 'stick', 'line', 'cross'

    """
    atoms = self._get_atom_dicts(highlight_mask_dict, style=style)
    # Styles are applied per atom by serial number. The atom numbers taken
    # from the coordinate table are not guaranteed to run 1..N in the sorted
    # order, so the atoms are renumbered here to keep the serial written to
    # the PDB string and the serial used in the selection in agreement.
    for serial, atom in enumerate(atoms, start=1):
        atom["idx"] = serial
    vis_string = "".join(str(atom) for atom in atoms)
    view = py3Dmol.view(width=400, height=300)
    view.addModelsAsFrames(vis_string)
    for atom in atoms:
        view.setStyle(
            {"model": -1, "serial": atom["idx"]},
            atom["pymol"],
        )
    view.zoomTo()
    view.show()


class SAbDabEntry(PDBEntry):
"""A class for parsing SAbDab entries."""
Expand Down
30 changes: 30 additions & 0 deletions proteinflow/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,13 @@ def read_mmcif(self, path: str):
"label_comp_id": "residue_name",
"label_seq_id": "residue_number",
"label_atom_id": "atom_name",
"id": "atom_number",
"group_PDB": "record_name",
"Cartn_x": "x_coord",
"Cartn_y": "y_coord",
"Cartn_z": "z_coord",
"pdbx_PDB_ins_code": "insertion",
"type_symbol": "element_symbol",
},
axis=1,
inplace=True,
Expand Down Expand Up @@ -519,3 +521,31 @@ def _retrieve_chain_names(entry):
return [_retrieve_author_chain(e) for e in entry[7:].split(", ")]

return [_retrieve_author_chain(entry[6:])]


class _Atom(dict):
def __init__(self, row):
self["type"] = row["record_name"]
self["idx"] = row["atom_number"]
self["name"] = row["atom_name"]
self["resname"] = row["residue_name"]
self["resid"] = row["residue_number"]
self["chain"] = row["chain_id"]
self["x"] = row["x_coord"]
self["y"] = row["y_coord"]
self["z"] = row["z_coord"]
self["sym"] = row["element_symbol"]

def __str__(self):
line = list(" " * 80)

line[0:6] = self["type"].ljust(6)
line[6:11] = str(self["idx"]).ljust(5)
line[12:16] = self["name"].ljust(4)
line[17:20] = self["resname"].ljust(3)
line[22:26] = str(self["resid"]).ljust(4)
line[30:38] = str(self["x"]).rjust(8)
line[38:46] = str(self["y"]).rjust(8)
line[46:54] = str(self["z"]).rjust(8)
line[76:78] = self["sym"].rjust(2)
return "".join(line) + "\n"
2 changes: 2 additions & 0 deletions proteinflow/download/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def download_pdb(pdb_id, local_folder=".", sabdab=False):
Path to the downloaded file
"""
pdb_id = pdb_id.lower()
if sabdab:
try:
url = f"https://opig.stats.ox.ac.uk/webapps/sabdab-sabpred/sabdab/pdb/{pdb_id}/?scheme=chothia"
Expand Down Expand Up @@ -105,6 +106,7 @@ def download_fasta(pdb_id, local_folder="."):
Path to the downloaded file
"""
pdb_id = pdb_id.lower()
if "-" in pdb_id:
pdb_id = pdb_id.split("-")[0]
downloadurl = "https://www.rcsb.org/fasta/entry/"
Expand Down
Loading

0 comments on commit 446a8e5

Please sign in to comment.