From 5cad7f4d28764ee3269a2b76cfe2245c14bd0a65 Mon Sep 17 00:00:00 2001 From: mpaillassa Date: Thu, 21 Mar 2024 18:35:15 +0900 Subject: [PATCH] maximask cube support --- maximask_and_maxitrack/maximask/maximask.py | 103 ++++++++++++++------ 1 file changed, 73 insertions(+), 30 deletions(-) diff --git a/maximask_and_maxitrack/maximask/maximask.py b/maximask_and_maxitrack/maximask/maximask.py index 7dbf034..7487777 100644 --- a/maximask_and_maxitrack/maximask/maximask.py +++ b/maximask_and_maxitrack/maximask/maximask.py @@ -290,49 +290,92 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model): elif task == "process": # prediction array - h, w = hdu_data.shape + hdu_shape = hdu_data.shape if np.all(hdu_data == 0): return np.zeros_like(hdu_data, dtype=np.uint8) else: if self.sing_mask: - preds = np.zeros([h, w], dtype=np.int16) + preds = np.zeros_like(hdu_data, dtype=np.int16) elif self.thresholds is not None: - preds = np.zeros([h, w, np.sum(self.class_flags)], dtype=np.uint8) + preds = np.zeros( + list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.uint8 + ) else: - preds = np.zeros([h, w, np.sum(self.class_flags)], dtype=np.float32) - - # get list of block coordinate to process - block_coord_list = self.get_block_coords(h, w) + preds = np.zeros( + list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.float32 + ) # preprocessing log.info("Preprocessing...") hdu_data, t = utils.image_norm(hdu_data) - log.info(f"Preprocessing done in {t:.2f}s, {h*w/(t*1e06):.2f}MPix/s") - - # process all the blocks by batches - # the process_batch method writes the predictions in preds by reference - nb_blocks = len(block_coord_list) - if nb_blocks <= self.batch_size: - # only one (possibly not full) batch to process - self.process_batch(hdu_data, preds, tf_model, block_coord_list) - else: - # several batches to process + one last possibly not full - nb_batch = nb_blocks // self.batch_size - for b in tqdm.tqdm(range(nb_batch), desc="INFERENCE: "): - batch_coord_list = block_coord_list[ - b * self.batch_size : (b + 1) * self.batch_size - ] - self.process_batch(hdu_data, preds, tf_model, batch_coord_list) - rest = nb_blocks - nb_batch * self.batch_size - if rest: - batch_coord_list = block_coord_list[-rest:] - self.process_batch(hdu_data, preds, tf_model, batch_coord_list) - - if not self.sing_mask: - preds = np.transpose(preds, (2, 0, 1)) + log.info( + f"Preprocessing done in {t:.2f}s, {np.prod(hdu_shape)/(t*1e06):.2f}MPix/s" + ) + + # process the HDU 3D or 2D data + if len(hdu_shape) == 3: + c, h, w = hdu_shape + for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"): + + # make temporary 2D prediction array to get results by reference + if self.sing_mask: + tmp_preds = np.zeros_like([h, w], dtype=np.int16) + elif self.thresholds is not None: + tmp_preds = np.zeros( + [h, w, np.sum(self.class_flags)], dtype=np.uint8 + ) + else: + tmp_preds = np.zeros( + [h, w, np.sum(self.class_flags)], dtype=np.float32 + ) + + # make predictions and forward them to the final prediction array + ch_im_data = hdu_data[ch] + self.process_image(ch_im_data, tmp_preds, tf_model) + preds[ch] = tmp_preds + + elif len(hdu_shape) == 2: + self.process_image(hdu_data, preds, tf_model) return preds + def process_image(self, im_data, preds, tf_model): + """Processes 2D image data. + + Args: + im_data (np.ndarray): 2D image data to process. + preds (np.ndarray): corresponding 2D MaxiMask predictions to fill. + tf_model (tf.keras.Model): MaxiMask tensorflow model. + """ + + # get list of block coordinate to process + h, w = im_data.shape + block_coord_list = self.get_block_coords(h, w) + + # process all the blocks by batches + # the process_batch method writes the predictions in preds by reference + nb_blocks = len(block_coord_list) + if nb_blocks <= self.batch_size: + # only one (possibly not full) batch to process + self.process_batch(im_data, preds, tf_model, block_coord_list) + else: + # several batches to process + one last possibly not full + nb_batch = nb_blocks // self.batch_size + for b in tqdm.tqdm(range(nb_batch), desc="INFERENCE: "): + batch_coord_list = block_coord_list[ + b * self.batch_size : (b + 1) * self.batch_size + ] + self.process_batch(im_data, preds, tf_model, batch_coord_list) + rest = nb_blocks - nb_batch * self.batch_size + if rest: + batch_coord_list = block_coord_list[-rest:] + self.process_batch(im_data, preds, tf_model, batch_coord_list) + + if not self.sing_mask: + preds = np.transpose(preds, (2, 0, 1)) + + return preds + def get_block_coords(self, h, w): """Gets the coordinate list of blocks to process.