diff --git a/proteinflow/__init__.py b/proteinflow/__init__.py
index f5e9e9a..da7a172 100644
--- a/proteinflow/__init__.py
+++ b/proteinflow/__init__.py
@@ -243,6 +243,7 @@ def generate_data(
     exclude_clusters=False,
     exclude_based_on_cdr=None,
     random_seed=42,
+    max_chains=10,
 ):
     """Download and parse PDB files that meet filtering criteria.

@@ -321,6 +322,8 @@ def generate_data(
         if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters
     random_seed : int, default 42
         the random seed to use for splitting
+    max_chains : int, default 10
+        the maximum number of chains per biounit

     Returns
     -------
@@ -350,7 +353,7 @@ def generate_data(
         missing_middle_thr=missing_middle_thr,
         filter_methods=filter_methods,
         remove_redundancies=remove_redundancies,
-        seq_identity_threshold=redundancy_thr,
+        redundancy_thr=redundancy_thr,
         n=n,
         force=force,
         tag=tag,
@@ -359,6 +362,7 @@ def generate_data(
         sabdab=sabdab,
         sabdab_data_path=sabdab_data_path,
         require_antigen=require_antigen,
+        max_chains=max_chains,
     )
     if not skip_splitting:
         split_data(
diff --git a/proteinflow/cli.py b/proteinflow/cli.py
index e37dd0a..f8c5a5f 100644
--- a/proteinflow/cli.py
+++ b/proteinflow/cli.py
@@ -169,6 +169,12 @@ def download(**kwargs):
     type=int,
     help="The random seed to use for splitting",
 )
+@click.option(
+    "--max_chains",
+    default=10,
+    type=int,
+    help="The maximum number of chains per biounit",
+)
 @cli.command("generate", help="Generate a new ProteinFlow dataset")
 def generate(**kwargs):
     """Generate a new ProteinFlow dataset."""
diff --git a/proteinflow/data/__init__.py b/proteinflow/data/__init__.py
index 25ddfe3..07dcf78 100644
--- a/proteinflow/data/__init__.py
+++ b/proteinflow/data/__init__.py
@@ -1326,6 +1326,9 @@ def arr_index(row):
         informative_mask = indices != -1
         res_indices = np.where(mask == 1)[0]
         unique_numbers = self.get_unique_residue_numbers(chain)
+        pdb_seq = self._pdb_sequence(chain)
+        if len(unique_numbers) != len(pdb_seq):
+            raise PDBError("Inconsistencies in the biopandas dataframe")
         replace_dict = {x: y for x, y in zip(unique_numbers, res_indices)}
         chain_crd.loc[:, "unique_residue_number"] = chain_crd[
             "unique_residue_number"
diff --git a/proteinflow/data/utils.py b/proteinflow/data/utils.py
index 5637fb0..45df353 100644
--- a/proteinflow/data/utils.py
+++ b/proteinflow/data/utils.py
@@ -3,6 +3,7 @@
 from copy import deepcopy

 import numpy as np
+import pandas as pd
 from biopandas.mmcif import PandasMmcif
 from biotite.structure.geometry import angle, dihedral, distance
 from einops import rearrange
@@ -12,6 +13,7 @@
     ATOM_MAP_3,
     ATOM_MAP_4,
     ATOM_MAP_14,
+    D3TO1,
     GLOBAL_PAD_CHAR,
     ONE_TO_THREE_LETTER_MAP,
 )
@@ -451,6 +453,7 @@ def read_mmcif(self, path: str):
                 "Cartn_x": "x_coord",
                 "Cartn_y": "y_coord",
                 "Cartn_z": "z_coord",
+                "pdbx_PDB_ins_code": "insertion",
             },
             axis=1,
             inplace=True,
@@ -460,7 +463,20 @@ def read_mmcif(self, path: str):

     def amino3to1(self):
         """Return a dataframe with the amino acid names converted to one letter codes."""
-        df = super().amino3to1()
+        tmp = self.df["ATOM"]
+        cmp = "placeholder"
+        indices = []
+
+        residue_number_insertion = tmp["residue_number"].astype(str) + tmp["insertion"]
+
+        for num, ind in zip(residue_number_insertion, np.arange(tmp.shape[0])):
+            if num != cmp:
+                indices.append(ind)
+            cmp = num
+
+        transl = tmp.iloc[indices]["auth_comp_id"].map(D3TO1).fillna("?")
+
+        df = pd.concat((tmp.iloc[indices]["auth_asym_id"], transl), axis=1)
         df.columns = ["chain_id", "residue_name"]
         return df

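A quick illustration of the `amino3to1` override above: instead of calling the parent class's conversion, it keys residues on the concatenation of residue number and insertion code, so that residues differing only by an insertion code (common in antibody numbering) are not collapsed into one. Below is a minimal, self-contained sketch of that grouping logic on toy data; the DataFrame contents and the stubbed `D3TO1` map are illustrative, while the package imports the real map from `proteinflow.constants`:

    import numpy as np
    import pandas as pd

    # Stub of the three-letter -> one-letter residue map.
    D3TO1 = {"GLY": "G", "SER": "S", "ALA": "A"}

    # Residues 100 and 100A differ only by insertion code; keying on
    # residue_number alone would merge them into a single residue.
    tmp = pd.DataFrame(
        {
            "residue_number": [100, 100, 100, 101],
            "insertion": ["", "A", "A", ""],
            "auth_comp_id": ["GLY", "SER", "SER", "ALA"],
            "auth_asym_id": ["A", "A", "A", "A"],
        }
    )

    cmp = "placeholder"
    indices = []
    residue_number_insertion = tmp["residue_number"].astype(str) + tmp["insertion"]
    for num, ind in zip(residue_number_insertion, np.arange(tmp.shape[0])):
        if num != cmp:  # first atom row of a new residue
            indices.append(ind)
        cmp = num

    transl = tmp.iloc[indices]["auth_comp_id"].map(D3TO1).fillna("?")
    print(transl.tolist())  # ['G', 'S', 'A'] -- three residues, not two
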
diff --git a/proteinflow/download/__init__.py b/proteinflow/download/__init__.py
index 48cd86e..d630c8f 100644
--- a/proteinflow/download/__init__.py
+++ b/proteinflow/download/__init__.py
@@ -120,6 +120,7 @@ def get_pdb_ids(
     resolution_thr=3.5,
     pdb_snapshot=None,
     filter_methods=True,
+    max_chains=5,
 ):
     """Get PDB ids from PDB API."""
     # get filtered PDB ids from PDB API
@@ -132,6 +133,10 @@ def get_pdb_ids(
     # if include_na:
     #     pdb_ids = pdb_ids.or_('rcsb_entry_info.polymer_composition').in_(["protein/NA", "protein/NA/oligosaccharide"])

+    if max_chains is not None:
+        pdb_ids = pdb_ids.and_(
+            "rcsb_assembly_info.polymer_entity_instance_count_protein"
+        ).__le__(max_chains)
     if resolution_thr is not None:
         pdb_ids = pdb_ids.and_("rcsb_entry_info.resolution_combined").__le__(
             resolution_thr
@@ -189,6 +194,7 @@ def download_filtered_pdb_files(
     n=None,
     local_folder=".",
     load_live=False,
+    max_chains=5,
 ):
     """Download filtered PDB files and return a list of local file paths.

@@ -207,6 +213,8 @@ def download_filtered_pdb_files(
     load_live : bool, default False
         Whether to load the PDB files from the RCSB PDB database directly
         instead of downloading them from the PDB snapshots
+    max_chains : int, default 5
+        Maximum number of chains per biounit

     Returns
     -------
@@ -220,6 +228,7 @@ def download_filtered_pdb_files(
         resolution_thr=resolution_thr,
         pdb_snapshot=pdb_snapshot,
         filter_methods=filter_methods,
+        max_chains=max_chains,
     )
     with ThreadPoolExecutor(max_workers=8) as executor:
         print("Getting a file list...")
@@ -532,6 +541,7 @@ def _load_files(
     sabdab=False,
     sabdab_data_path=None,
     require_antigen=False,
+    max_chains=5,
 ):
     """Download filtered structure files and return a list of local file paths."""
     if sabdab:
@@ -552,6 +562,7 @@ def _load_files(
             local_folder=local_folder,
             load_live=load_live,
             n=n,
+            max_chains=max_chains,
         )
         paths = [(x, _get_fasta_path(x)) for x in paths]
     return paths, error_ids
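For reference, the new `max_chains` clause composes with the existing fluent attribute queries in `get_pdb_ids`. A rough standalone sketch of an equivalent query, assuming the `rcsbsearch` package this module already depends on; the attribute names and the `.and_(...).__le__(...)` call style mirror the hunk above, while the cutoff values and the `exec("assembly")` return type are illustrative:

    from rcsbsearch import Attr

    # Keep entries whose assembly contains at most 5 protein chain
    # instances and whose combined resolution is at most 3.5 A.
    query = (
        Attr("rcsb_assembly_info.polymer_entity_instance_count_protein")
        .__le__(5)
        .and_("rcsb_entry_info.resolution_combined")
        .__le__(3.5)
    )
    for assembly_id in query.exec("assembly"):  # e.g. "1XYZ-1"
        print(assembly_id)
        break
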
diff --git a/proteinflow/processing/__init__.py b/proteinflow/processing/__init__.py
index 81d29c8..04e48db 100644
--- a/proteinflow/processing/__init__.py
+++ b/proteinflow/processing/__init__.py
@@ -26,7 +26,7 @@ def run_processing(
     missing_middle_thr=0.1,
     filter_methods=True,
     remove_redundancies=False,
-    seq_identity_threshold=0.9,
+    redundancy_thr=0.9,
     n=None,
     force=False,
     tag=None,
@@ -35,6 +35,7 @@ def run_processing(
     sabdab=False,
     sabdab_data_path=None,
     require_antigen=False,
+    max_chains=5,
 ):
     """Download and parse PDB files that meet filtering criteria.

@@ -73,7 +74,7 @@ def run_processing(
         If `True`, only files obtained with X-ray or EM will be processed
     remove_redundancies : bool, default False
         If `True`, removes biounits that are doubles of others sequence wise
-    seq_identity_threshold : float, default 0.9
+    redundancy_thr : float, default 0.9
         The threshold upon which sequences are considered as one and the same (default: 90%)
     n : int, default None
         The number of files to process (for debugging purposes)
@@ -91,6 +92,8 @@ def run_processing(
         path to a zip file or a directory containing SAbDab files (only used if `sabdab` is `True`)
     require_antigen : bool, default False
         if `True`, only keep files with antigen chains (only used if `sabdab` is `True`)
+    max_chains : int, default 5
+        the maximum number of chains per biounit

     Returns
     -------
@@ -128,8 +131,9 @@ def run_processing(
         f.write(f" remove_redundancies: {remove_redundancies} \n")
         f.write(f" sabdab: {sabdab} \n")
         f.write(f" pdb_snapshot: {pdb_snapshot} \n")
+        f.write(f" max_chains: {max_chains} \n")
         if remove_redundancies:
-            f.write(f" seq_identity_threshold: {seq_identity_threshold} \n")
+            f.write(f" redundancy_threshold: {redundancy_thr} \n")
         if sabdab:
             f.write(f" require_antigen: {require_antigen} \n")
             f.write(f" sabdab_data_path: {sabdab_data_path} \n")
@@ -158,6 +162,15 @@ def process_f(
             antigen = antigen.split(" | ")
         fn = os.path.basename(pdb_path)
         pdb_id = fn.split(".")[0]
+        if os.path.getsize(pdb_path) > 1e7:
+            _log_exception(
+                PDBError("PDB / mmCIF file is too large"),
+                LOG_FILE,
+                pdb_id,
+                TMP_FOLDER,
+                chain_id=chain_id,
+            )
+            return
         try:
             # local_path = download_f(pdb_id, s3_client=s3_client, load_live=load_live)
             name = pdb_id if not sabdab else pdb_id + "-" + chain_id
@@ -202,14 +215,18 @@ def process_f(
             sabdab=sabdab,
             sabdab_data_path=sabdab_data_path,
             require_antigen=require_antigen,
+            max_chains=max_chains,
         )
         for id in error_ids:
             with open(LOG_FILE, "a") as f:
                 f.write(f"<<< Could not download PDB/mmCIF file: {id} \n")
-        # paths = [(os.path.join(TMP_FOLDER, "6tkb.pdb"), "H_L_nan")]
+        # paths = [("data/2c2m-1.pdb.gz", "data/2c2m.fasta")]
         print("Filter and process...")
         _ = p_map(lambda x: process_f(x, force=force, sabdab=sabdab), paths)
-        # _ = [process_f(x, force=force, sabdab=sabdab, show_error=True) for x in tqdm(paths)]
+        # _ = [
+        #     process_f(x, force=force, sabdab=sabdab, show_error=True)
+        #     for x in tqdm(paths)
+        # ]

     except Exception as e:
         _raise_rcsbsearch(e)
@@ -246,7 +263,7 @@ def process_f(

     if remove_redundancies:
         removed = _remove_database_redundancies(
-            OUTPUT_FOLDER, seq_identity_threshold=seq_identity_threshold
+            OUTPUT_FOLDER, seq_identity_threshold=redundancy_thr
         )
         _log_removed(removed, LOG_FILE)

@@ -290,7 +307,7 @@ def filter_and_convert(
     if pdb_entry.has_unnatural_amino_acids():
         raise PDBError("Unnatural amino acids found")

-    for (chain,) in pdb_entry.get_chains():
+    for chain in pdb_entry.get_chains():
         pdb_dict[chain] = {}
         chain_crd = pdb_entry.get_sequence_df(chain)
         fasta_seq = fasta_dict[chain]
diff --git a/pyproject.toml b/pyproject.toml
index 954d6df..495721b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "aiobotocore==2.4.2",
     "awscli==1.25.60",
     "bs4>=0.0.1",
+    "pyyaml==5.3",
     "rcsbsearch",
     "blosum",
     "pre-commit",
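With these changes the chain-count cap is exposed end to end. A hypothetical invocation of the new option (the tag name is illustrative and the remaining arguments keep their defaults):

    proteinflow generate --tag test --max_chains 10

or, from Python:

    from proteinflow import generate_data

    generate_data(tag="test", max_chains=10)

One design note: the CLI flag and `generate_data` default to 10 chains, while `run_processing`, `get_pdb_ids`, `download_filtered_pdb_files`, and `_load_files` default to 5, so callers invoking the lower-level functions directly get a stricter cap unless they pass `max_chains` explicitly.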