From 794d02397a313bf41955a0acfcae8a4914695d63 Mon Sep 17 00:00:00 2001 From: Liza Kozlova Date: Thu, 3 Aug 2023 15:53:44 +0000 Subject: [PATCH 1/3] fix: patching bugs + add get_anchor_ind function --- proteinflow/data/torch.py | 72 +++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/proteinflow/data/torch.py b/proteinflow/data/torch.py index ea2dc22..46787cd 100644 --- a/proteinflow/data/torch.py +++ b/proteinflow/data/torch.py @@ -4,7 +4,8 @@ import random from collections import defaultdict from copy import deepcopy -from itertools import combinations +from itertools import combinations, groupby +from operator import itemgetter import numpy as np import torch @@ -617,7 +618,6 @@ def _get_masked_sequence( if self.mask_frac is not None: assert self.mask_frac > 0 and self.mask_frac < 1 k = int(len(neighbor_indices) * self.mask_frac) - k = max(k, 10) else: up = min( self.upper_limit, int(len(neighbor_indices) * 0.5) @@ -820,22 +820,45 @@ def _to_pyg_graph(self, data): pyg_data[key] = value.unsqueeze(0) return pyg_data - def _get_anchor_ind(self, data): - """Get the indices of the anchor residues.""" - masked_ind = torch.where(data["masked_res"].bool())[0] - known_ind = torch.where(data["mask"].bool())[0] - start, end = masked_ind[0], masked_ind[-1] - start = ( - known_ind[known_ind < start][-1] - if (known_ind < start).sum() > 0 - else known_ind[0] - ) - end = ( - known_ind[known_ind > end][0] - if (known_ind > end).sum() > 0 - else known_ind[-1] - ) - return start, end + @staticmethod + def get_anchor_ind(masked_res, mask): + """Get the indices of the anchor residues. + + Anchor residues are defined as the first and last known residues before and + after each continuous masked region. + + Parameters + ---------- + masked_res : torch.Tensor + A boolean tensor indicating which residues should be predicted + mask : torch.Tensor + A boolean tensor indicating which residues are known + + Returns + ------- + list + A list of indices of the anchor residues + + """ + anchor_ind = [] + masked_ind = torch.where(masked_res.bool())[0] + known_ind = torch.where(mask.bool())[0] + for _, g in groupby(enumerate(masked_ind), lambda x: x[0] - x[1]): + group = map(itemgetter(1), g) + group = list(map(int, group)) + start, end = group[0], group[-1] + start = ( + known_ind[known_ind < start][-1] + if (known_ind < start).sum() > 0 + else known_ind[0] + ) + end = ( + known_ind[known_ind > end][0] + if (known_ind > end).sum() > 0 + else known_ind[-1] + ) + anchor_ind += [start, end] + return anchor_ind def _get_antibody_mask(self, data): """Get a mask for the antibody residues.""" @@ -853,11 +876,9 @@ def _patch(self, data): """Cut the data around the anchor residues.""" # adapted from diffab pos_alpha = data["X"][:, 2] - start, end = self._get_anchor_ind(data) - anchor_points = torch.stack([pos_alpha[start], pos_alpha[end]], dim=0) - dist_anchor = torch.cdist(pos_alpha, anchor_points[[0]], p=2).min(dim=1)[ - 0 - ] # (L, ) + anchor_ind = self.get_anchor_ind(data["masked_res"], data["mask"]) + anchor_points = torch.stack([pos_alpha[ind] for ind in anchor_ind], dim=0) + dist_anchor = torch.cdist(pos_alpha, anchor_points, p=2).min(dim=1)[0] # (L, ) dist_anchor[~data["mask"].bool()] = float("+inf") initial_patch_idx = torch.topk( dist_anchor, @@ -867,9 +888,8 @@ def _patch(self, data): )[ 1 ] # (initial_patch_size, ) - patch_mask = data["masked_res"].clone() - patch_mask[start] = True - patch_mask[end] = True + patch_mask = data["masked_res"].bool().clone() + patch_mask[[int(x) for x in anchor_ind]] = True patch_mask[initial_patch_idx] = True if self.sabdab: From 0d0a271015bf5125606e853b7610ddca43b476cb Mon Sep 17 00:00:00 2001 From: Liza Kozlova Date: Thu, 3 Aug 2023 15:54:12 +0000 Subject: [PATCH 2/3] chore: bump version --- .conda/arm64/meta.yaml | 2 +- .conda/default/meta.yaml | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.conda/arm64/meta.yaml b/.conda/arm64/meta.yaml index 661783b..553900b 100644 --- a/.conda/arm64/meta.yaml +++ b/.conda/arm64/meta.yaml @@ -1,4 +1,4 @@ -{% set version = "2.2.0" %} +{% set version = "--2.2.1" %} package: name: "proteinflow" version: {{ version }} diff --git a/.conda/default/meta.yaml b/.conda/default/meta.yaml index bb3de45..6150afb 100644 --- a/.conda/default/meta.yaml +++ b/.conda/default/meta.yaml @@ -1,4 +1,4 @@ -{% set version = "2.2.0" %} +{% set version = "--2.2.1" %} package: name: "proteinflow" version: {{ version }} diff --git a/pyproject.toml b/pyproject.toml index 2edb8d9..2cfb088 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "proteinflow" -version = "2.2.0" +version = "--2.2.1" authors = [ {name = "Liza Kozlova", email = "liza@adaptyvbio.com"}, {name = "Arthur Valentin", email = "arthur@adaptyvbio.com"} From 3cd436e513cb323e46e7e3f2510c42a5508367e7 Mon Sep 17 00:00:00 2001 From: Liza Kozlova Date: Thu, 3 Aug 2023 15:59:56 +0000 Subject: [PATCH 3/3] chore: bump version --- .conda/arm64/meta.yaml | 2 +- .conda/default/meta.yaml | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.conda/arm64/meta.yaml b/.conda/arm64/meta.yaml index 553900b..4206bdf 100644 --- a/.conda/arm64/meta.yaml +++ b/.conda/arm64/meta.yaml @@ -1,4 +1,4 @@ -{% set version = "--2.2.1" %} +{% set version = "2.2.1" %} package: name: "proteinflow" version: {{ version }} diff --git a/.conda/default/meta.yaml b/.conda/default/meta.yaml index 6150afb..b125ccd 100644 --- a/.conda/default/meta.yaml +++ b/.conda/default/meta.yaml @@ -1,4 +1,4 @@ -{% set version = "--2.2.1" %} +{% set version = "2.2.1" %} package: name: "proteinflow" version: {{ version }} diff --git a/pyproject.toml b/pyproject.toml index 2cfb088..ec24b1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "proteinflow" -version = "--2.2.1" +version = "2.2.1" authors = [ {name = "Liza Kozlova", email = "liza@adaptyvbio.com"}, {name = "Arthur Valentin", email = "arthur@adaptyvbio.com"}