Merge pull request #31 from mpaillassa/cube_inputs
Cube hdu input support
mpaillassa authored Mar 21, 2024
2 parents 838d9e6 + 5cad7f4 commit cffbd85
Showing 3 changed files with 152 additions and 60 deletions.
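
For context, a FITS HDU can carry either a 2D image (shape (h, w)) or a 3D cube of channel images (shape (c, h, w)); this change lets both be fed to the models. A minimal sketch of inspecting those shapes with astropy (file name hypothetical):

from astropy.io import fits

# list the data shape carried by each HDU of a FITS file
with fits.open("example.fits") as hdus:  # hypothetical file
    for idx, hdu in enumerate(hdus):
        if hdu.data is None:
            continue  # e.g. a header-only primary HDU
        # 2D image HDU: (h, w); 3D cube HDU: (c, h, w)
        print(idx, hdu.data.shape)
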
103 changes: 73 additions & 30 deletions maximask_and_maxitrack/maximask/maximask.py
@@ -290,49 +290,92 @@ def process_hdu(self, file_name, hdu_idx, task, tf_model):
elif task == "process":

# prediction array
h, w = hdu_data.shape
hdu_shape = hdu_data.shape
if np.all(hdu_data == 0):
return np.zeros_like(hdu_data, dtype=np.uint8)
else:
if self.sing_mask:
preds = np.zeros([h, w], dtype=np.int16)
preds = np.zeros_like(hdu_data, dtype=np.int16)
elif self.thresholds is not None:
preds = np.zeros([h, w, np.sum(self.class_flags)], dtype=np.uint8)
preds = np.zeros(
list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.uint8
)
else:
preds = np.zeros([h, w, np.sum(self.class_flags)], dtype=np.float32)

# get list of block coordinate to process
block_coord_list = self.get_block_coords(h, w)
preds = np.zeros(
list(hdu_shape) + [np.sum(self.class_flags)], dtype=np.float32
)

# preprocessing
log.info("Preprocessing...")
hdu_data, t = utils.image_norm(hdu_data)
log.info(f"Preprocessing done in {t:.2f}s, {h*w/(t*1e06):.2f}MPix/s")

# process all the blocks by batches
# the process_batch method writes the predictions in preds by reference
nb_blocks = len(block_coord_list)
if nb_blocks <= self.batch_size:
# only one (possibly not full) batch to process
self.process_batch(hdu_data, preds, tf_model, block_coord_list)
else:
# several batches to process + one last possibly not full
nb_batch = nb_blocks // self.batch_size
for b in tqdm.tqdm(range(nb_batch), desc="INFERENCE: "):
batch_coord_list = block_coord_list[
b * self.batch_size : (b + 1) * self.batch_size
]
self.process_batch(hdu_data, preds, tf_model, batch_coord_list)
rest = nb_blocks - nb_batch * self.batch_size
if rest:
batch_coord_list = block_coord_list[-rest:]
self.process_batch(hdu_data, preds, tf_model, batch_coord_list)

if not self.sing_mask:
preds = np.transpose(preds, (2, 0, 1))
log.info(
f"Preprocessing done in {t:.2f}s, {np.prod(hdu_shape)/(t*1e06):.2f}MPix/s"
)

# process the HDU data (3D or 2D)
if len(hdu_shape) == 3:
c, h, w = hdu_shape
for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"):

# make temporary 2D prediction array to get results by reference
if self.sing_mask:
tmp_preds = np.zeros([h, w], dtype=np.int16)
elif self.thresholds is not None:
tmp_preds = np.zeros(
[h, w, np.sum(self.class_flags)], dtype=np.uint8
)
else:
tmp_preds = np.zeros(
[h, w, np.sum(self.class_flags)], dtype=np.float32
)

# make predictions and forward them to the final prediction array
ch_im_data = hdu_data[ch]
self.process_image(ch_im_data, tmp_preds, tf_model)
preds[ch] = tmp_preds

elif len(hdu_shape) == 2:
self.process_image(hdu_data, preds, tf_model)

return preds
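
The 3D branch above reuses the 2D machinery channel by channel, collecting each temporary result into the cube-shaped prediction array. A minimal standalone sketch of that pattern (fill_per_channel and predict_2d are hypothetical names, not part of this commit):

import numpy as np

def fill_per_channel(cube, n_classes, predict_2d):
    # cube: (c, h, w); predict_2d fills a (h, w, n_classes) array in place
    c, h, w = cube.shape
    preds = np.zeros((c, h, w, n_classes), dtype=np.float32)
    for ch in range(c):
        tmp = np.zeros((h, w, n_classes), dtype=np.float32)
        predict_2d(cube[ch], tmp)  # results come back by reference
        preds[ch] = tmp
    return preds
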

def process_image(self, im_data, preds, tf_model):
"""Processes 2D image data.
Args:
im_data (np.ndarray): 2D image data to process.
preds (np.ndarray): corresponding 2D MaxiMask predictions to fill.
tf_model (tf.keras.Model): MaxiMask tensorflow model.
"""

# get the list of block coordinates to process
h, w = im_data.shape
block_coord_list = self.get_block_coords(h, w)

# process all the blocks in batches
# the process_batch method writes the predictions into preds by reference
nb_blocks = len(block_coord_list)
if nb_blocks <= self.batch_size:
# only one (possibly not full) batch to process
self.process_batch(im_data, preds, tf_model, block_coord_list)
else:
# several batches to process, plus a last possibly partial one
nb_batch = nb_blocks // self.batch_size
for b in tqdm.tqdm(range(nb_batch), desc="INFERENCE: "):
batch_coord_list = block_coord_list[
b * self.batch_size : (b + 1) * self.batch_size
]
self.process_batch(im_data, preds, tf_model, batch_coord_list)
rest = nb_blocks - nb_batch * self.batch_size
if rest:
batch_coord_list = block_coord_list[-rest:]
self.process_batch(im_data, preds, tf_model, batch_coord_list)

if not self.sing_mask:
preds = np.transpose(preds, (2, 0, 1))

return preds
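
The batch loop above (full batches plus a final, possibly partial, remainder) is equivalent to straightforward slice chunking; a self-contained sketch under that assumption (iter_batches is a hypothetical helper):

def iter_batches(items, batch_size):
    # yield consecutive slices of at most batch_size items,
    # including a final, possibly not full, batch
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# e.g. 10 blocks with batch_size 4 -> batch sizes 4, 4, 2
assert [len(b) for b in iter_batches(list(range(10)), 4)] == [4, 4, 2]
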

def get_block_coords(self, h, w):
"""Gets the coordinate list of blocks to process.
85 changes: 60 additions & 25 deletions maximask_and_maxitrack/maxitrack/maxitrack.py
@@ -81,8 +81,6 @@ def process_file(self, file_name, tf_model):
tf_model (tf.keras.Model): MaxiTrack tensorflow model.
"""

all_preds = []

# make hdu tasks
hdu_task_list = self.make_hdu_tasks(file_name)

@@ -102,25 +100,34 @@
else:
log.info("Using all available HDUs")

# go through all HDUs
for hdu_idx, hdu_type, hdu_shape in hdu_task_list:
log.info(f"HDU {hdu_idx}")
# go through all HDUs and write results
all_2d_hdu_preds = []
with open("maxitrack.out", "a") as fd:

# get raw predictions
preds, t = self.process_hdu(file_name, hdu_idx, tf_model)
log.info(
f"Whole processing time (incl. preprocessing): {t:.2f}s, {np.prod(hdu_shape)/(t*1e06):.2f}MPix/s"
)
for hdu_idx, hdu_shape in hdu_task_list:
log.info(f"HDU {hdu_idx}")

# append the results
for pred in preds:
all_preds.append(pred)
# get raw predictions
preds, t = self.process_hdu(file_name, hdu_idx, tf_model)
log.info(
f"Whole processing time (incl. preprocessing): {t:.2f}s, {np.prod(hdu_shape)/(t*1e06):.2f}MPix/s"
)

final_res = np.mean(all_preds)
# if this is a 3D HDU, output one score per channel image
if len(preds) > 1:
for ch in range(len(preds)):
fd.write(
f"{file_name} HDU {hdu_idx} Channel {ch} {preds[ch]:.4f}\n"
)
# if this is a 2D HDU, assume all 2D HDUs show the same field and aggregate a single score over them
elif len(preds) == 1:
all_2d_hdu_preds.append(preds)

# write the aggregated score of 2D HDUs
if len(all_2d_hdu_preds):
final_pred = np.mean(all_2d_hdu_preds)
fd.write(f"{file_name} {final_pred:.4f}\n")

# write file
with open("maxitrack.out", "a") as fd:
fd.write(f"{file_name} {final_res:.4f}\n")
else:
log.info(f"Skipping {file_name} because no HDU was found to be processed")

@@ -145,7 +152,7 @@ def make_hdu_tasks(self, file_name):
check, hdu_type = utils.check_hdu(specified_hdu, self.im_size)
if check:
hdu_shape = specified_hdu.data.shape
hdu_task_list.append([spec_hdu_idx, hdu_type, hdu_shape])
hdu_task_list.append([spec_hdu_idx, hdu_shape])
else:
log.info(
f"Ignoring HDU {spec_hdu_idx} because not adequate data format"
@@ -157,7 +164,7 @@
check, hdu_type = utils.check_hdu(file_hdu[k], self.im_size)
if check:
hdu_shape = file_hdu[k].data.shape
hdu_task_list.append([k, hdu_type, hdu_shape])
hdu_task_list.append([k, hdu_shape])
else:
log.info(f"Ignoring HDU {k} because not adequate data format")

@@ -172,7 +179,7 @@ def process_hdu(self, file_name, hdu_idx, tf_model):
hdu_idx (int): index of the HDU to process.
tf_model (tf.keras.Model): MaxiTrack tensorflow model.
Returns:
out_array (np.ndarray): MaxiTrack predictions over the image.
outputs (list): MaxiTrack predictions over the image, one score per channel for 3D data.
"""

# make file name
@@ -184,15 +191,43 @@
with fits.open(file_name) as file_hdu:
hdu = file_hdu[hdu_idx]
im_data = hdu.data

# get list of block coordinate to process
h, w = im_data.shape
block_coord_list = self.get_block_coords(h, w)
im_data_shape = im_data.shape

# preprocessing
log.info("Preprocessing...")
im_data, t = utils.image_norm(im_data)
log.info(f"Preprocessing done in {t:.2f}s, {h*w/(t*1e06):.2f}MPix/s")
log.info(
f"Preprocessing done in {t:.2f}s, {np.prod(im_data_shape)/(t*1e06):.2f}MPix/s"
)

# process the HDU data (3D or 2D)
outputs = []
if len(im_data_shape) == 3:
c = im_data.shape[0]
for ch in tqdm.tqdm(range(c), desc="CUBE CHANNELS"):
ch_im_data = im_data[ch]
predictions = self.process_image(ch_im_data, tf_model)
outputs.append(np.mean(predictions))

elif len(im_data_shape) == 2:
predictions = self.process_image(im_data, tf_model)
outputs.append(np.mean(predictions))

return outputs
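
So process_hdu now returns a one-element list for a 2D HDU and one score per channel for a cube; a toy check of that contract (summarize is a hypothetical stand-in, values random):

import numpy as np

def summarize(preds_per_channel):
    # mimic process_hdu's return: one mean score per processed image
    return [float(np.mean(p)) for p in preds_per_channel]

assert len(summarize([np.random.rand(8, 8)])) == 1          # 2D HDU
assert len(summarize(list(np.random.rand(3, 8, 8)))) == 3   # 3-channel cube
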

def process_image(self, im_data, tf_model):
"""Processes 2D image data.
Args:
im_data (np.ndarray): 2D image data to process.
tf_model (tf.keras.Model): MaxiTrack tensorflow model.
Returns:
out_array (np.ndarray): MaxiTrack predictions over the image.
"""

# get the list of block coordinates to process
h, w = im_data.shape
block_coord_list = self.get_block_coords(h, w)

# process all the blocks in batches
nb_blocks = len(block_coord_list)
24 changes: 19 additions & 5 deletions maximask_and_maxitrack/utils.py
@@ -149,9 +149,16 @@ def check_hdu(hdu, min_size):
(bool): whether the hdu is to be processed or not.
"""

# get HDU information
infos = hdu._summary()

# check size validity
ds = infos[4]
size_b = len(ds) == 2 and ds[0] > min_size and ds[1] > min_size
size_b_2d = len(ds) == 2 and ds[0] > min_size and ds[1] > min_size
size_b_3d = len(ds) == 3 and ds[1] > min_size and ds[2] > min_size
size_b = size_b_2d or size_b_3d

# check data type validity
dt = infos[5]
data_type_b = (
"float16" in dt
@@ -347,9 +354,16 @@ def image_norm(im):
np.place(im, im > 80000, 80000)
np.place(im, im < -500, -500)

# normalization
bg_map, si_map = background_est(im)
im -= bg_map
im /= si_map
# normalize a single image, or all channels if 3D
im_shape = im.shape
if len(im_shape) == 3:
for ch in range(im_shape[0]):
bg_map, si_map = background_est(im[ch])
im[ch] -= bg_map
im[ch] /= si_map
elif len(im_shape) == 2:
bg_map, si_map = background_est(im)
im -= bg_map
im /= si_map

return im
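
The 3D branch of image_norm maps the existing 2D background estimation over channels, modifying the cube in place; a condensed sketch of that loop (norm_channels is a hypothetical wrapper, background_est as in this module):

def norm_channels(cube, background_est):
    # subtract the background map and divide by the scale map,
    # channel by channel, in place
    for ch in range(cube.shape[0]):
        bg_map, si_map = background_est(cube[ch])
        cube[ch] -= bg_map
        cube[ch] /= si_map
    return cube
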
