Skip to content

Commit

Permalink
FormatNXMX: Avoid storing duplicate static masks (#789)
Browse files Browse the repository at this point in the history
When trying to dials.import a large number of experiment files, the RAM
usage on Linux was unnecessarily high due to separate static masks being
created for each file. Since masks tend to be limited in number for each
experiment, it makes more sense to store masks with unique values only,
which is what this PR aims to do.

Closes dials/dials#2227.
  • Loading branch information
yash4karan authored Feb 21, 2025
1 parent 62e3332 commit 45f50b2
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
2 changes: 1 addition & 1 deletion AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ Robert Oeffner
Takanori Nakane
Tara Michels-Clark
Viktor Bengtsson

Yash Karan
1 change: 1 addition & 0 deletions newsfragments/789.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``dials.import``: Reduce excessive memory usage when importing many (hundreds or more) FormatNXmx files.
33 changes: 32 additions & 1 deletion src/dxtbx/format/FormatNXmx.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
from __future__ import annotations

import weakref

import h5py
import nxmx

import scitbx.array_family.flex as flex

import dxtbx.nexus
from dxtbx.format.FormatNexus import FormatNexus


class _MaskCache:
"""A singleton to hold unique static_mask objects to avoid duplications"""

def __init__(self):
self.local_mask_cache = weakref.WeakValueDictionary()

def _mask_hasher(self, mask: flex.bool) -> int:
return hash(mask.as_numpy_array().tobytes())

def store_unique_and_get(
self, mask_tuple: tuple[flex.bool, ...] | None
) -> tuple[flex.bool, ...] | None:
if mask_tuple is None:
return None
output = []
for mask in mask_tuple:
mask_hash = self._mask_hasher(mask)
mask = self.local_mask_cache.setdefault(mask_hash, mask)
output.append(mask)
return tuple(output)


# Module-level shared cache instance; _start() passes each file's static
# masks through it so that identical masks are stored only once.
mask_cache = _MaskCache()


def detector_between_sample_and_source(detector, beam):
"""Check if the detector is perpendicular to beam and
upstream of the sample."""
Expand Down Expand Up @@ -83,7 +112,9 @@ def _start(self):
self._detector_model = inverted_distance_detector(self._detector_model)

self._scan_model = dxtbx.nexus.get_dxtbx_scan(nxsample, nxdetector)
self._static_mask = dxtbx.nexus.get_static_mask(nxdetector)
self._static_mask = mask_cache.store_unique_and_get(
dxtbx.nexus.get_static_mask(nxdetector)
)
self._bit_depth_readout = nxdetector.bit_depth_readout

if self._scan_model:
Expand Down

0 comments on commit 45f50b2

Please sign in to comment.