Merge pull request #93 from adaptyvbio/large_files
Improve large files handling
elkoz authored Jul 19, 2023
2 parents 67355d5 + 7a946bb commit 70d9328
Showing 7 changed files with 66 additions and 9 deletions.
6 changes: 5 additions & 1 deletion proteinflow/__init__.py
@@ -243,6 +243,7 @@ def generate_data(
exclude_clusters=False,
exclude_based_on_cdr=None,
random_seed=42,
max_chains=10,
):
"""Download and parse PDB files that meet filtering criteria.
@@ -321,6 +322,8 @@
if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters
random_seed : int, default 42
the random seed to use for splitting
max_chains : int, default 10
the maximum number of chains per biounit
Returns
-------
@@ -350,7 +353,7 @@
missing_middle_thr=missing_middle_thr,
filter_methods=filter_methods,
remove_redundancies=remove_redundancies,
seq_identity_threshold=redundancy_thr,
redundancy_thr=redundancy_thr,
n=n,
force=force,
tag=tag,
@@ -359,6 +362,7 @@
sabdab=sabdab,
sabdab_data_path=sabdab_data_path,
require_antigen=require_antigen,
max_chains=max_chains,
)
if not skip_splitting:
split_data(
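For readers of this diff, a minimal sketch of how the new option is used from Python (not part of the commit; `"my_dataset"` is a placeholder tag, and the other filtering arguments are assumed to keep their defaults):

```python
from proteinflow import generate_data

# max_chains caps the number of protein chains per biounit at the PDB query stage
# (generate_data defaults to 10). "my_dataset" is a placeholder dataset tag.
generate_data(tag="my_dataset", max_chains=10)
```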
6 changes: 6 additions & 0 deletions proteinflow/cli.py
@@ -169,6 +169,12 @@ def download(**kwargs):
type=int,
help="The random seed to use for splitting",
)
@click.option(
"--max_chains",
default=10,
type=int,
help="The maximum number of chains per biounit",
)
@cli.command("generate", help="Generate a new ProteinFlow dataset")
def generate(**kwargs):
"""Generate a new ProteinFlow dataset."""
3 changes: 3 additions & 0 deletions proteinflow/data/__init__.py
@@ -1326,6 +1326,9 @@ def arr_index(row):
informative_mask = indices != -1
res_indices = np.where(mask == 1)[0]
unique_numbers = self.get_unique_residue_numbers(chain)
pdb_seq = self._pdb_sequence(chain)
if len(unique_numbers) != len(pdb_seq):
raise PDBError("Inconsistencies in the biopandas dataframe")
replace_dict = {x: y for x, y in zip(unique_numbers, res_indices)}
chain_crd.loc[:, "unique_residue_number"] = chain_crd[
"unique_residue_number"
18 changes: 17 additions & 1 deletion proteinflow/data/utils.py
@@ -3,6 +3,7 @@
from copy import deepcopy

import numpy as np
import pandas as pd
from biopandas.mmcif import PandasMmcif
from biotite.structure.geometry import angle, dihedral, distance
from einops import rearrange
@@ -12,6 +13,7 @@
ATOM_MAP_3,
ATOM_MAP_4,
ATOM_MAP_14,
D3TO1,
GLOBAL_PAD_CHAR,
ONE_TO_THREE_LETTER_MAP,
)
@@ -451,6 +453,7 @@ def read_mmcif(self, path: str):
"Cartn_x": "x_coord",
"Cartn_y": "y_coord",
"Cartn_z": "z_coord",
"pdbx_PDB_ins_code": "insertion",
},
axis=1,
inplace=True,
@@ -460,7 +463,20 @@

def amino3to1(self):
"""Return a dataframe with the amino acid names converted to one letter codes."""
df = super().amino3to1()
tmp = self.df["ATOM"]
cmp = "placeholder"
indices = []

residue_number_insertion = tmp["residue_number"].astype(str) + tmp["insertion"]

for num, ind in zip(residue_number_insertion, np.arange(tmp.shape[0])):
if num != cmp:
indices.append(ind)
cmp = num

transl = tmp.iloc[indices]["auth_comp_id"].map(D3TO1).fillna("?")

df = pd.concat((tmp.iloc[indices]["auth_asym_id"], transl), axis=1)
df.columns = ["chain_id", "residue_name"]
return df

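To make the rewritten `amino3to1` easier to follow, here is a self-contained sketch of the same boundary detection on a toy ATOM table that contains an insertion code; the toy DataFrame and the shortened `D3TO1` map are illustrative only:

```python
import numpy as np
import pandas as pd

# Toy ATOM table: residue 100 is followed by an inserted residue 100A, so the
# residue number alone cannot tell the two residues apart.
tmp = pd.DataFrame(
    {
        "residue_number": [100, 100, 100, 100, 101],
        "insertion": ["", "", "A", "A", ""],
        "auth_comp_id": ["ALA", "ALA", "GLY", "GLY", "SER"],
        "auth_asym_id": ["A", "A", "A", "A", "A"],
    }
)

# Same logic as the patched method: concatenate residue number and insertion code,
# then keep only the first atom row of every residue.
residue_number_insertion = tmp["residue_number"].astype(str) + tmp["insertion"]
cmp, indices = "placeholder", []
for num, ind in zip(residue_number_insertion, np.arange(tmp.shape[0])):
    if num != cmp:
        indices.append(ind)
        cmp = num

D3TO1 = {"ALA": "A", "GLY": "G", "SER": "S"}  # stand-in for the full lookup table
transl = tmp.iloc[indices]["auth_comp_id"].map(D3TO1).fillna("?")
df = pd.concat((tmp.iloc[indices]["auth_asym_id"], transl), axis=1)
df.columns = ["chain_id", "residue_name"]
print(df)  # three rows: ALA, the inserted GLY (100A), and SER
```

Keying on the residue number alone would collapse inserted residues into their neighbours, which is consistent with the new length check added in `proteinflow/data/__init__.py`.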
11 changes: 11 additions & 0 deletions proteinflow/download/__init__.py
@@ -120,6 +120,7 @@ def get_pdb_ids(
resolution_thr=3.5,
pdb_snapshot=None,
filter_methods=True,
max_chains=5,
):
"""Get PDB ids from PDB API."""
# get filtered PDB ids from PDB API
@@ -132,6 +133,10 @@
# if include_na:
# pdb_ids = pdb_ids.or_('rcsb_entry_info.polymer_composition').in_(["protein/NA", "protein/NA/oligosaccharide"])

if max_chains is not None:
pdb_ids = pdb_ids.and_(
"rcsb_assembly_info.polymer_entity_instance_count_protein"
).__le__(max_chains)
if resolution_thr is not None:
pdb_ids = pdb_ids.and_("rcsb_entry_info.resolution_combined").__le__(
resolution_thr
@@ -189,6 +194,7 @@ def download_filtered_pdb_files(
n=None,
local_folder=".",
load_live=False,
max_chains=5,
):
"""Download filtered PDB files and return a list of local file paths.
@@ -207,6 +213,8 @@
load_live : bool, default False
Whether to load the PDB files from the RCSB PDB database directly
instead of downloading them from the PDB snapshots
max_chains : int, default 5
Maximum number of chains per biounit
Returns
-------
@@ -220,6 +228,7 @@
resolution_thr=resolution_thr,
pdb_snapshot=pdb_snapshot,
filter_methods=filter_methods,
max_chains=max_chains,
)
with ThreadPoolExecutor(max_workers=8) as executor:
print("Getting a file list...")
@@ -532,6 +541,7 @@ def _load_files(
sabdab=False,
sabdab_data_path=None,
require_antigen=False,
max_chains=5,
):
"""Download filtered structure files and return a list of local file paths."""
if sabdab:
@@ -552,6 +562,7 @@
local_folder=local_folder,
load_live=load_live,
n=n,
max_chains=max_chains,
)
paths = [(x, _get_fasta_path(x)) for x in paths]
return paths, error_ids
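The new query clause in `get_pdb_ids` can be read as a cap on the protein-chain count of each assembly. A hedged sketch of the idea (assuming `rcsbsearch`'s fluent interface, mirroring the calls already used in this module; not code from the diff):

```python
from rcsbsearch import Attr  # the search package proteinflow already depends on

max_chains = 5
# Keep entries at <= 3.5 A resolution whose assembly has at most max_chains protein chains.
query = Attr("rcsb_entry_info.resolution_combined") <= 3.5
query = query.and_(
    "rcsb_assembly_info.polymer_entity_instance_count_protein"
).__le__(max_chains)
pdb_ids = list(query("entry"))  # queries are callable and yield matching identifiers
```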
30 changes: 23 additions & 7 deletions proteinflow/processing/__init__.py
@@ -26,7 +26,7 @@ def run_processing(
missing_middle_thr=0.1,
filter_methods=True,
remove_redundancies=False,
seq_identity_threshold=0.9,
redundancy_thr=0.9,
n=None,
force=False,
tag=None,
@@ -35,6 +35,7 @@
sabdab=False,
sabdab_data_path=None,
require_antigen=False,
max_chains=5,
):
"""Download and parse PDB files that meet filtering criteria.
@@ -73,7 +74,7 @@
If `True`, only files obtained with X-ray or EM will be processed
remove_redundancies : bool, default False
If `True`, removes biounits that are doubles of others sequence wise
seq_identity_threshold : float, default 0.9
redundancy_thr : float, default 0.9
The threshold upon which sequences are considered as one and the same (default: 90%)
n : int, default None
The number of files to process (for debugging purposes)
@@ -91,6 +92,8 @@
path to a zip file or a directory containing SAbDab files (only used if `sabdab` is `True`)
require_antigen : bool, default False
if `True`, only keep files with antigen chains (only used if `sabdab` is `True`)
max_chains : int, default 5
the maximum number of chains per biounit
Returns
-------
@@ -128,8 +131,9 @@
f.write(f" remove_redundancies: {remove_redundancies} \n")
f.write(f" sabdab: {sabdab} \n")
f.write(f" pdb_snapshot: {pdb_snapshot} \n")
f.write(f" max_chains: {max_chains} \n")
if remove_redundancies:
f.write(f" seq_identity_threshold: {seq_identity_threshold} \n")
f.write(f" redundancy_threshold: {redundancy_thr} \n")
if sabdab:
f.write(f" require_antigen: {require_antigen} \n")
f.write(f" sabdab_data_path: {sabdab_data_path} \n")
@@ -158,6 +162,14 @@ def process_f(
antigen = antigen.split(" | ")
fn = os.path.basename(pdb_path)
pdb_id = fn.split(".")[0]
if os.path.getsize(pdb_path) > 1e7:
_log_exception(
PDBError("PDB / mmCIF file is too large"),
LOG_FILE,
pdb_id,
TMP_FOLDER,
chain_id=chain_id,
)
try:
# local_path = download_f(pdb_id, s3_client=s3_client, load_live=load_live)
name = pdb_id if not sabdab else pdb_id + "-" + chain_id
@@ -202,14 +214,18 @@ def process_f(
sabdab=sabdab,
sabdab_data_path=sabdab_data_path,
require_antigen=require_antigen,
max_chains=max_chains,
)
for id in error_ids:
with open(LOG_FILE, "a") as f:
f.write(f"<<< Could not download PDB/mmCIF file: {id} \n")
# paths = [(os.path.join(TMP_FOLDER, "6tkb.pdb"), "H_L_nan")]
# paths = [("data/2c2m-1.pdb.gz", "data/2c2m.fasta")]
print("Filter and process...")
_ = p_map(lambda x: process_f(x, force=force, sabdab=sabdab), paths)
# _ = [process_f(x, force=force, sabdab=sabdab, show_error=True) for x in tqdm(paths)]
# _ = [
# process_f(x, force=force, sabdab=sabdab, show_error=True)
# for x in tqdm(paths)
# ]
except Exception as e:
_raise_rcsbsearch(e)

@@ -246,7 +262,7 @@ def process_f(

if remove_redundancies:
removed = _remove_database_redundancies(
OUTPUT_FOLDER, seq_identity_threshold=seq_identity_threshold
OUTPUT_FOLDER, seq_identity_threshold=redundancy_thr
)
_log_removed(removed, LOG_FILE)

@@ -290,7 +306,7 @@ filter_and_convert(
if pdb_entry.has_unnatural_amino_acids():
raise PDBError("Unnatural amino acids found")

for (chain,) in pdb_entry.get_chains():
for chain in pdb_entry.get_chains():
pdb_dict[chain] = {}
chain_crd = pdb_entry.get_sequence_df(chain)
fasta_seq = fasta_dict[chain]
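For callers of `run_processing`, the renamed keyword and the new option might be passed like this (a sketch; `"my_dataset"` is a placeholder tag and the remaining arguments are assumed to keep their defaults):

```python
from proteinflow.processing import run_processing

# The sequence-identity cutoff is now passed as redundancy_thr
# (formerly seq_identity_threshold); max_chains is forwarded to the PDB query.
run_processing(
    tag="my_dataset",
    remove_redundancies=True,
    redundancy_thr=0.9,
    max_chains=5,
)
```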
1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
"aiobotocore==2.4.2",
"awscli==1.25.60",
"bs4>=0.0.1",
"pyyaml==5.3",
"rcsbsearch",
"blosum",
"pre-commit",
