diff --git a/ndsl/stencils/testing/serialbox_to_netcdf.py b/ndsl/stencils/testing/serialbox_to_netcdf.py index 11814fff..a29139c5 100644 --- a/ndsl/stencils/testing/serialbox_to_netcdf.py +++ b/ndsl/stencils/testing/serialbox_to_netcdf.py @@ -32,6 +32,13 @@ def get_parser(): type=str, help="[Optional] Give the name of the data, will default to Generator_rankX", ) + parser.add_argument( + "-m", + "--merge", + action="store_true", + default=False, + help="merges datastreams blocked into separate savepoints", + ) return parser @@ -58,7 +65,12 @@ def get_serializer(data_path: str, rank: int, data_name: Optional[str] = None): return serialbox.Serializer(serialbox.OpenModeKind.Read, data_path, name) -def main(data_path: str, output_path: str, data_name: Optional[str] = None): +def main( + data_path: str, + output_path: str, + merge_blocks: bool, + data_name: Optional[str] = None, +): os.makedirs(output_path, exist_ok=True) namelist_filename_in = os.path.join(data_path, "input.nml") @@ -69,8 +81,21 @@ def main(data_path: str, output_path: str, data_name: Optional[str] = None): if namelist_filename_out != namelist_filename_in: shutil.copyfile(os.path.join(data_path, "input.nml"), namelist_filename_out) namelist = f90nml.read(namelist_filename_out) - total_ranks = ( - 6 * namelist["fv_core_nml"]["layout"][0] * namelist["fv_core_nml"]["layout"][1] + if namelist["fv_core_nml"]["grid_type"] <= 3: + total_ranks = ( + 6 + * namelist["fv_core_nml"]["layout"][0] + * namelist["fv_core_nml"]["layout"][1] + ) + else: + total_ranks = ( + namelist["fv_core_nml"]["layout"][0] * namelist["fv_core_nml"]["layout"][1] + ) + nx = int( + (namelist["fv_core_nml"]["npx"] - 1) / (namelist["fv_core_nml"]["layout"][0]) + ) + ny = int( + (namelist["fv_core_nml"]["npy"] - 1) / (namelist["fv_core_nml"]["layout"][1]) ) # all ranks have the same names, just look at first one @@ -96,8 +121,30 @@ def main(data_path: str, output_path: str, data_name: Optional[str] = None): rank_data[name].append( read_serialized_data(serializer, savepoint, name) ) + nblocks = len(rank_data[name]) + if merge_blocks and len(rank_data[name]) > 1: + full_data = np.array(rank_data[name]) + if len(full_data.shape) > 1: + if nx * ny == full_data.shape[0] * full_data.shape[1]: + # If we have an (i, x) array from each block reshape it + new_shape = (nx, ny) + full_data.shape[2:] + full_data = full_data.reshape(new_shape) + else: + # We have one array for all blocks + # could be a k-array or something else, so we take one copy + # TODO: is there a decent check for this? + full_data = full_data[0] + elif len(full_data.shape) == 1: + # if it's a scalar from each block then just take one + full_data = full_data[0] + else: + raise IndexError(f"{name} data appears to be empty") + rank_data[name] = [full_data] rank_list.append(rank_data) - n_savepoints = len(savepoints) # checking from last rank is fine + if merge_blocks: + n_savepoints = 1 + else: + n_savepoints = len(savepoints) # checking from last rank is fine data_vars = {} if n_savepoints > 0: encoding = {} @@ -166,7 +213,10 @@ def entry_point(): parser = get_parser() args = parser.parse_args() main( - data_path=args.data_path, output_path=args.output_path, data_name=args.data_name + data_path=args.data_path, + output_path=args.output_path, + merge_blocks=args.merge, + data_name=args.data_name, )