Skip to content

Commit

Permalink
Merge pull request #108 from adaptyvbio/add_pdb_id
Browse files Browse the repository at this point in the history
Add IDs to ProteinEntry
  • Loading branch information
elkoz authored Aug 10, 2023
2 parents acb67ab + 48e9f4e commit b5cfe9b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
49 changes: 43 additions & 6 deletions proteinflow/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,16 @@ class ProteinEntry:
ATOM_ORDER = {k: BACKBONE_ORDER + v for k, v in SIDECHAIN_ORDER.items()}
"""A dictionary mapping 3-letter residue names to the order of atoms in the coordinates array."""

def __init__(self, seqs, crds, masks, chain_ids, predict_masks=None, cdrs=None):
def __init__(
self,
seqs,
crds,
masks,
chain_ids,
predict_masks=None,
cdrs=None,
protein_id=None,
):
"""Initialize a `ProteinEntry` object.
Parameters
Expand All @@ -118,6 +127,8 @@ def __init__(self, seqs, crds, masks, chain_ids, predict_masks=None, cdrs=None):
indicates residues with known coordinates
cdrs : list of np.ndarray, optional
`'numpy'` arrays of shape `(L,)` where CDR residues are marked with the corresponding type (`'H1'`, `'L1'`, ...)
protein_id : str, optional
ID of the protein
"""
if crds[0].shape[1] != 14:
Expand All @@ -134,6 +145,11 @@ def __init__(self, seqs, crds, masks, chain_ids, predict_masks=None, cdrs=None):
if predict_masks is None:
predict_masks = [None for _ in chain_ids]
self.predict_mask = {x: mask for x, mask in zip(chain_ids, predict_masks)}
self.id = protein_id

def get_id(self):
"""Return the ID of the protein.

Returns
-------
protein_id : str or None
The `protein_id` value stored at initialization (`None` by default)
"""
return self.id

def interpolate_coords(self, fill_ends=True):
"""Fill in missing values in the coordinates arrays with linear interpolation.
Expand Down Expand Up @@ -537,7 +553,14 @@ def merge(self, entry):

@staticmethod
def from_arrays(
seqs, crds, masks, chain_id_dict, chain_id_array, predict_masks=None, cdrs=None
seqs,
crds,
masks,
chain_id_dict,
chain_id_array,
predict_masks=None,
cdrs=None,
protein_id=None,
):
"""Load a protein entry from arrays.
Expand All @@ -560,6 +583,8 @@ def from_arrays(
cdrs : np.ndarray, optional
A `'numpy'` array of shape `(L,)` where residues are marked
with the corresponding CDR type (encoded as integers, see `proteinflow.constants.CDR_ALPHABET`)
protein_id : str, optional
Protein ID
Returns
-------
Expand Down Expand Up @@ -597,6 +622,7 @@ def from_arrays(
chain_ids_list,
predict_masks_list,
cdrs_list,
protein_id,
)

@staticmethod
Expand All @@ -617,14 +643,15 @@ def from_dict(dictionary):
with the corresponding type (`'H1'`, `'L1'`, ...) and non-CDR residues are marked with `'-'`
- `'predict_msk'` (optional): mask array where 1 indicates residues that were generated by a model and 0
indicates residues with known coordinates, shaped `(L,)`
It can also contain a `'protein_id'` first-level key.
Returns
-------
entry : ProteinEntry
A `ProteinEntry` object
"""
chains = sorted(dictionary.keys())
chains = sorted([x for x in dictionary.keys() if x != "protein_id"])
seq = [dictionary[k]["seq"] for k in chains]
crd = [
np.concatenate([dictionary[k]["crd_bb"], dictionary[k]["crd_sc"]], axis=1)
Expand All @@ -640,6 +667,7 @@ def from_dict(dictionary):
cdrs=cdr,
chain_ids=chains,
predict_masks=predict_mask,
protein_id=dictionary.get("protein_id"),
)

@staticmethod
Expand Down Expand Up @@ -678,6 +706,7 @@ def from_pdb_entry(pdb_entry):
pdb_dict[chain]["msk"][
(pdb_dict[chain]["crd_bb"] == 0).sum(-1).sum(-1) > 0
] = 0
pdb_dict["protein_id"] = pdb_entry.pdb_id
return ProteinEntry.from_dict(pdb_dict)

@staticmethod
Expand Down Expand Up @@ -824,6 +853,7 @@ def to_dict(self):
`proteinflow.constants.CDR_ALPHABET`
- `'predict_msk'` (optional): mask array where 1 indicates residues that were generated by a model and 0
indicates residues with known coordinates, shaped `(L,)`
It can optionally also contain `protein_id` as a first-level key.
"""
data = {}
Expand All @@ -838,6 +868,8 @@ def to_dict(self):
data[chain]["cdr"] = self.cdr[chain]
if self.predict_mask[chain] is not None:
data[chain]["predict_msk"] = self.predict_mask[chain]
if self.id is not None:
data["protein_id"] = self.id
return data

def to_pdb(
Expand All @@ -846,7 +878,7 @@ def to_pdb(
only_ca=False,
skip_oxygens=False,
only_backbone=False,
title="Untitled",
title=None,
):
"""Save the protein entry to a PDB file.
Expand All @@ -860,8 +892,8 @@ def to_pdb(
If `True`, oxygen atoms are not saved
only_backbone : bool, default `False`
If `True`, only backbone atoms are saved
title : str, default 'Untitled'
Title of the PDB file
title : str, optional
Title of the PDB file (by default either the protein id or "Untitled")
"""
pdb_builder = PDBBuilder(
Expand All @@ -870,6 +902,11 @@ def to_pdb(
skip_oxygens=skip_oxygens,
only_backbone=only_backbone,
)
if title is None:
if self.id is not None:
title = self.id
else:
title = "Untitled"
pdb_builder.save_pdb(path, title=title)

def to_pickle(self, path):
Expand Down
1 change: 0 additions & 1 deletion proteinflow/data/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class ProteinLoader(DataLoader):
If batch size is larger than one, all objects are padded with zeros at the ends to reach the length of the
longest protein in the batch.
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion sample_data/7zor_copy.pdb
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
REMARK Untitled
REMARK 7zor
SEQRES 1 H 436 GLN VAL GLN LEU GLN GLN PRO GLY ALA GLU LEU VAL LYS
SEQRES 2 H 436 PRO GLY ALA SER VAL LYS MET SER CYS LYS ALA SER GLY
SEQRES 3 H 436 TYR THR PHE THR SER TYR TRP ILE THR TRP VAL ILE GLN
Expand Down

0 comments on commit b5cfe9b

Please sign in to comment.