Skip to content

Commit

Permalink
maximask cube support
Browse files Browse the repository at this point in the history
  • Loading branch information
mpaillassa committed Mar 21, 2024
1 parent d5693c3 commit 5cad7f4
Showing 1 changed file with 73 additions and 30 deletions.
103 changes: 73 additions & 30 deletions maximask_and_maxitrack/maximask/maximask.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,49 +290,92 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model):
elif task == "process":

# prediction array
h, w = hdu_data.shape
hdu_shape = hdu_data.shape
if np.all(hdu_data == 0):
return np.zeros_like(hdu_data, dtype=np.uint8)
else:
if self.sing_mask:
preds = np.zeros([h, w], dtype=np.int16)
preds = np.zeros_like(hdu_data, dtype=np.int16)
elif self.thresholds is not None:
preds = np.zeros([h, w, np.sum(self.class_flags)], dtype=np.uint8)
preds = np.zeros(
list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.uint8
)
else:
preds = np.zeros([h, w, np.sum(self.class_flags)], dtype=np.float32)

# get list of block coordinate to process
block_coord_list = self.get_block_coords(h, w)
preds = np.zeros(
list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.float32
)

# preprocessing
log.info("Preprocessing...")
hdu_data, t = utils.image_norm(hdu_data)
log.info(f"Preprocessing done in {t:.2f}s, {h*w/(t*1e06):.2f}MPix/s")

# process all the blocks by batches
# the process_batch method writes the predictions in preds by reference
nb_blocks = len(block_coord_list)
if nb_blocks <= self.batch_size:
# only one (possibly not full) batch to process
self.process_batch(hdu_data, preds, tf_model, block_coord_list)
else:
# several batches to process + one last possibly not full
nb_batch = nb_blocks // self.batch_size
for b in tqdm.tqdm(range(nb_batch), desc="INFERENCE: "):
batch_coord_list = block_coord_list[
b * self.batch_size : (b + 1) * self.batch_size
]
self.process_batch(hdu_data, preds, tf_model, batch_coord_list)
rest = nb_blocks - nb_batch * self.batch_size
if rest:
batch_coord_list = block_coord_list[-rest:]
self.process_batch(hdu_data, preds, tf_model, batch_coord_list)

if not self.sing_mask:
preds = np.transpose(preds, (2, 0, 1))
log.info(
f"Preprocessing done in {t:.2f}s, {np.prod(hdu_shape)/(t*1e06):.2f}MPix/s"
)

# process the HDU 3D or 2D data
if len(hdu_shape) == 3:
c, h, w = hdu_shape
for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"):

# make temporary 2D prediction array to get results by reference
if self.sing_mask:
tmp_preds = np.zeros_like([h, w], dtype=np.int16)
elif self.thresholds is not None:
tmp_preds = np.zeros(
[h, w, np.sum(self.class_flags)], dtype=np.uint8
)
else:
tmp_preds = np.zeros(
[h, w, np.sum(self.class_flags)], dtype=np.float32
)

# make predictions and forward them to the final prediction array
ch_im_data = hdu_data[ch]
self.process_image(ch_im_data, tmp_preds, tf_model)
preds[ch] = tmp_preds

elif len(hdu_shape) == 2:
self.process_image(hdu_data, preds, tf_model)

return preds

def process_image(self, im_data, preds, tf_model):
"""Processes 2D image data.
Args:
im_data (np.ndarray): 2D image data to process.
preds (np.ndarray): corresponding 2D MaxiMask predictions to fill.
tf_model (tf.keras.Model): MaxiMask tensorflow model.
"""

# get list of block coordinate to process
h, w = im_data.shape
block_coord_list = self.get_block_coords(h, w)

# process all the blocks by batches
# the process_batch method writes the predictions in preds by reference
nb_blocks = len(block_coord_list)
if nb_blocks <= self.batch_size:
# only one (possibly not full) batch to process
self.process_batch(im_data, preds, tf_model, block_coord_list)
else:
# several batches to process + one last possibly not full
nb_batch = nb_blocks // self.batch_size
for b in tqdm.tqdm(range(nb_batch), desc="INFERENCE: "):
batch_coord_list = block_coord_list[
b * self.batch_size : (b + 1) * self.batch_size
]
self.process_batch(im_data, preds, tf_model, batch_coord_list)
rest = nb_blocks - nb_batch * self.batch_size
if rest:
batch_coord_list = block_coord_list[-rest:]
self.process_batch(im_data, preds, tf_model, batch_coord_list)

if not self.sing_mask:
preds = np.transpose(preds, (2, 0, 1))

return preds

def get_block_coords(self, h, w):
"""Gets the coordinate list of blocks to process.
Expand Down

0 comments on commit 5cad7f4

Please sign in to comment.