diff --git a/.conda/arm64/meta.yaml b/.conda/arm64/meta.yaml index 661783b..4206bdf 100644 --- a/.conda/arm64/meta.yaml +++ b/.conda/arm64/meta.yaml @@ -1,4 +1,4 @@ -{% set version = "2.2.0" %} +{% set version = "2.2.1" %} package: name: "proteinflow" version: {{ version }} diff --git a/.conda/default/meta.yaml b/.conda/default/meta.yaml index bb3de45..b125ccd 100644 --- a/.conda/default/meta.yaml +++ b/.conda/default/meta.yaml @@ -1,4 +1,4 @@ -{% set version = "2.2.0" %} +{% set version = "2.2.1" %} package: name: "proteinflow" version: {{ version }} diff --git a/proteinflow/data/torch.py b/proteinflow/data/torch.py index ea2dc22..46787cd 100644 --- a/proteinflow/data/torch.py +++ b/proteinflow/data/torch.py @@ -4,7 +4,8 @@ import random from collections import defaultdict from copy import deepcopy -from itertools import combinations +from itertools import combinations, groupby +from operator import itemgetter import numpy as np import torch @@ -617,7 +618,6 @@ def _get_masked_sequence( if self.mask_frac is not None: assert self.mask_frac > 0 and self.mask_frac < 1 k = int(len(neighbor_indices) * self.mask_frac) - k = max(k, 10) else: up = min( self.upper_limit, int(len(neighbor_indices) * 0.5) @@ -820,22 +820,45 @@ def _to_pyg_graph(self, data): pyg_data[key] = value.unsqueeze(0) return pyg_data - def _get_anchor_ind(self, data): - """Get the indices of the anchor residues.""" - masked_ind = torch.where(data["masked_res"].bool())[0] - known_ind = torch.where(data["mask"].bool())[0] - start, end = masked_ind[0], masked_ind[-1] - start = ( - known_ind[known_ind < start][-1] - if (known_ind < start).sum() > 0 - else known_ind[0] - ) - end = ( - known_ind[known_ind > end][0] - if (known_ind > end).sum() > 0 - else known_ind[-1] - ) - return start, end + @staticmethod + def get_anchor_ind(masked_res, mask): + """Get the indices of the anchor residues. + + Anchor residues are defined as the first and last known residues before and + after each continuous masked region. + + Parameters + ---------- + masked_res : torch.Tensor + A boolean tensor indicating which residues should be predicted + mask : torch.Tensor + A boolean tensor indicating which residues are known + + Returns + ------- + list + A list of indices of the anchor residues + + """ + anchor_ind = [] + masked_ind = torch.where(masked_res.bool())[0] + known_ind = torch.where(mask.bool())[0] + for _, g in groupby(enumerate(masked_ind), lambda x: x[0] - x[1]): + group = map(itemgetter(1), g) + group = list(map(int, group)) + start, end = group[0], group[-1] + start = ( + known_ind[known_ind < start][-1] + if (known_ind < start).sum() > 0 + else known_ind[0] + ) + end = ( + known_ind[known_ind > end][0] + if (known_ind > end).sum() > 0 + else known_ind[-1] + ) + anchor_ind += [start, end] + return anchor_ind def _get_antibody_mask(self, data): """Get a mask for the antibody residues.""" @@ -853,11 +876,9 @@ def _patch(self, data): """Cut the data around the anchor residues.""" # adapted from diffab pos_alpha = data["X"][:, 2] - start, end = self._get_anchor_ind(data) - anchor_points = torch.stack([pos_alpha[start], pos_alpha[end]], dim=0) - dist_anchor = torch.cdist(pos_alpha, anchor_points[[0]], p=2).min(dim=1)[ - 0 - ] # (L, ) + anchor_ind = self.get_anchor_ind(data["masked_res"], data["mask"]) + anchor_points = torch.stack([pos_alpha[ind] for ind in anchor_ind], dim=0) + dist_anchor = torch.cdist(pos_alpha, anchor_points, p=2).min(dim=1)[0] # (L, ) dist_anchor[~data["mask"].bool()] = float("+inf") initial_patch_idx = torch.topk( dist_anchor, @@ -867,9 +888,8 @@ def _patch(self, data): )[ 1 ] # (initial_patch_size, ) - patch_mask = data["masked_res"].clone() - patch_mask[start] = True - patch_mask[end] = True + patch_mask = data["masked_res"].bool().clone() + patch_mask[[int(x) for x in anchor_ind]] = True patch_mask[initial_patch_idx] = True if self.sabdab: diff --git a/pyproject.toml b/pyproject.toml index 2edb8d9..ec24b1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "proteinflow" -version = "2.2.0" +version = "2.2.1" authors = [ {name = "Liza Kozlova", email = "liza@adaptyvbio.com"}, {name = "Arthur Valentin", email = "arthur@adaptyvbio.com"}