diff --git a/environments/dataset_merger.py b/environments/dataset_merger.py
index af4b3d9aa..add6ec9a7 100644
--- a/environments/dataset_merger.py
+++ b/environments/dataset_merger.py
@@ -4,11 +4,12 @@
 import argparse
 import os
 import shutil
-import pdb
 
 import numpy as np
 from tqdm import tqdm
 
+# List of all possible labels identifying a task,
+# for experiments in Continual Learning scenarios.
 CONTINUAL_LEARNING_LABELS = ['CC', 'SC', 'EC', 'SQC']
 CL_LABEL_KEY = "continual_learning_label"
 
@@ -22,14 +23,13 @@ def main():
     parser.add_argument('-f', '--force', action='store_true', default=False,
                         help='Force the merge, even if it overrides something else,'
                              ' including the destination if it exist')
-    parser.add_argument('--timesteps', type=int, nargs=2, default=[-1,-1],
-                        help="To have a certain number of frames for two data sets ")
     group = parser.add_mutually_exclusive_group()
     group.add_argument('--merge', type=str, nargs=3, metavar=('source_1', 'source_2', 'destination'),
                        default=argparse.SUPPRESS,
                        help='Merge two datasets by appending the episodes, deleting sources right after.')
 
     args = parser.parse_args()
+
     if 'merge' in args:
         # let make sure everything is in order
         assert os.path.exists(args.merge[0]), "Error: dataset '{}' could not be found".format(args.merge[0])
@@ -47,63 +47,24 @@ def main():
 
         # create the output
         os.mkdir(args.merge[2])
-
-        #os.rename(args.merge[0] + "/dataset_config.json", args.merge[2] + "/dataset_config.json")
-        #os.rename(args.merge[0] + "/env_globals.json", args.merge[2] + "/env_globals.json")
-        shutil.copy2(args.merge[0] + "/dataset_config.json",args.merge[2] + "/dataset_config.json")
-        shutil.copy2(args.merge[0] + "/env_globals.json", args.merge[2] + "/env_globals.json")
 
         # copy files from first source
-        num_timesteps_1, num_timesteps_2 = args.timesteps
-        local_path = os.getcwd()
-        all_records = sorted(glob.glob(args.merge[0] + "/record_[0-9]*/*"))
-        previous_records = all_records[0]
-        for ts_counter_1, record in enumerate(all_records):
-
-            #if the timesteps is larger than needed, we wait until this episode is over
-            if(num_timesteps_1>0 and ts_counter_1 >num_timesteps_1):
-                if(os.path.dirname(previous_records).split('_')[-1] != os.path.dirname(record).split('_')[-1]):
-                    break
+        os.rename(args.merge[0] + "/dataset_config.json", args.merge[2] + "/dataset_config.json")
+        os.rename(args.merge[0] + "/env_globals.json", args.merge[2] + "/env_globals.json")
+
+        for record in sorted(glob.glob(args.merge[0] + "/record_[0-9]*/*")):
             s = args.merge[2] + "/" + record.split("/")[-2] + '/' + record.split("/")[-1]
-            s = os.path.join(local_path,s)
-            record = os.path.join(local_path, record)
-            try:
-                shutil.copy2(record, s)
-            except FileNotFoundError:
-                os.mkdir(os.path.dirname(s))
-                shutil.copy2(record, s)
-            previous_records = record
-        num_episode_dataset_1 = int(previous_records.split("/")[-2][7:])
-        if (num_timesteps_1 == -1):
-            num_episode_dataset_1 += 1
-            ts_counter_1 += 1
+            os.renames(record, s)
 
-        # copy files from second source
-        all_records = sorted(glob.glob(args.merge[1] + "/record_[0-9]*/*"))
-        previous_records = all_records[0]
-        for ts_counter_2, record in enumerate(all_records):
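+        # `record` still points at the last episode folder after the loop above,
+        # so the episode count of the first dataset is its index + 1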
+        num_episode_dataset_1 = int(record.split("/")[-2][7:]) + 1
 
-            if (num_timesteps_2 > 0 and ts_counter_2 > num_timesteps_2):
-                if (os.path.dirname(previous_records).split('_')[-1] != os.path.dirname(record).split('_')[-1]):
-                    break
+        # copy files from second source
+        for record in sorted(glob.glob(args.merge[1] + "/record_[0-9]*/*")):
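+            # shift the episode index by the episode count of the first dataset
+            # and rebuild the folder name, preserving the zero-padding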
             episode = str(num_episode_dataset_1 + int(record.split("/")[-2][7:]))
             new_episode = record.split("/")[-2][:-len(episode)] + episode
             s = args.merge[2] + "/" + new_episode + '/' + record.split("/")[-1]
-            s = os.path.join(local_path, s)
-            record = os.path.join(local_path, record)
-            try:
-                shutil.copy2(record, s)
-            except FileNotFoundError:
-                os.mkdir(os.path.dirname(s))
-                shutil.copy2(record, s)
-            previous_records = record
-
-        num_episode_dataset_2 = int(previous_records.split("/")[-2][7:])
-        if(num_timesteps_2==-1):
-            num_episode_dataset_2 +=1
-            ts_counter_2 +=1
-
-        ts_counter = [ts_counter_1, ts_counter_2]
+            os.renames(record, s)
+        num_episode_dataset_2 = int(record.split("/")[-2][7:]) + 1
+
         # load and correct ground_truth
         ground_truth = {}
         ground_truth_load = np.load(args.merge[0] + "/ground_truth.npz")
@@ -115,20 +76,12 @@ def main():
         index_margin_str = len("/record_")
         directory_str = args.merge[2][index_slash+1:]
 
-        len_info_1 = [len(ground_truth_load[k]) for k in ground_truth_load.keys()]
-        num_eps_total_1, num_ts_total_1 = min(len_info_1), max(len_info_1)
-        len_info_2 = [len(ground_truth_load_2[k]) for k in ground_truth_load_2.keys()]
-        num_eps_total_2, num_ts_total_2 = min(len_info_2), max(len_info_2)
-
         for idx_, gt_load in enumerate([ground_truth_load, ground_truth_load_2], 1):
             for arr in gt_load.files:
                 if arr == "images_path":
                     # here, we want to rename just the folder containing the records, hence the black magic
-                    for i in tqdm(range(ts_counter[idx_-1]),#range(len(gt_load["images_path"])),
+                    for i in tqdm(range(len(gt_load["images_path"])),
                                   desc="Update of paths (Folder " + str(1+idx_) + ")"):
                         # find the "record_" position
                         path = gt_load["images_path"][i]
@@ -142,39 +95,21 @@
                         else:
                             new_record_path = path[end_pos:]
                         ground_truth["images_path"].append(directory_str + new_record_path)
-
                 else:
                     # anything that isnt image_path, we dont need to change
                     gt_arr = gt_load[arr]
 
                     if idx_ > 1:
                         num_episode_dataset = num_episode_dataset_2
 
+                    # if the target is fixed rather than random, the stored array is
+                    # shorter than the episode count: repeat it once per episode
+                    # before overwriting/concatenating below
                     if gt_load[arr].shape[0] < num_episode_dataset:
                         gt_arr = np.repeat(gt_load[arr], num_episode_dataset, axis=0)
 
                     if idx_ > 1:
-                        # This is the first dataset
-                        if (len(gt_arr) == num_eps_total_2):
-                            # This is a episode non-change variable
-                            ground_truth[arr] = np.concatenate((ground_truth[arr],
-                                                                gt_arr[:num_episode_dataset_2]), axis=0)
-                        elif (len(gt_arr) == num_ts_total_2):  # a timesteps changing variable
-                            ground_truth[arr] = np.concatenate((ground_truth[arr],
-                                                                gt_arr[:ts_counter_2]), axis=0)
-                        else:
-                            assert 0 == 1, "No compatible variable in the stored ground truth for the second dataset {}" \
-                                .format(args.merge[1])
+                        ground_truth[arr] = np.concatenate((ground_truth[arr], gt_arr), axis=0)
                     else:
-                        # This is the first dataset
-                        if(len(gt_arr) == num_eps_total_1):
-                            #This is a episode non-change variable
-                            ground_truth[arr] = gt_arr[:num_episode_dataset_1]
-                        elif(len(gt_arr) == num_ts_total_1):  # a timesteps changing variable
-                            ground_truth[arr] = gt_arr[:ts_counter_1]
-                        else:
-                            assert 0 ==1 , "No compatible variable in the stored ground truth for the first dataset {}"\
-                                .format(args.merge[0])
+                        ground_truth[arr] = gt_arr
 
         # save the corrected ground_truth
         np.savez(args.merge[2] + "/ground_truth.npz", **ground_truth)
@@ -186,6 +121,8 @@ def main():
 
         dataset_1_size = preprocessed_load["actions"].shape[0]
         dataset_2_size = preprocessed_load_2["actions"].shape[0]
+
+        # Concatenating additional information: indices of episode starts, action probabilities, CL labels...
         for idx, prepro_load in enumerate([preprocessed_load, preprocessed_load_2]):
             for arr in prepro_load.files:
                 pr_arr = prepro_load[arr]
@@ -193,43 +130,29 @@
                 to_class = None
                 if arr == "episode_starts":
                     to_class = bool
-                elif arr == "actions_proba" or arr =="rewards":
+                elif arr == "actions_proba" or arr == "rewards":
                     to_class = float
                 else:
                     to_class = int
-                # all data is of timesteps changing (instead of episode changing)
-                if preprocessed.get(arr, None) is None: #for the first dataset
-                    preprocessed[arr] = pr_arr.astype(to_class)[:ts_counter_1]
-                else:# for the second dataset
+                if preprocessed.get(arr, None) is None:
+                    preprocessed[arr] = pr_arr.astype(to_class)
+                else:
                     preprocessed[arr] = np.concatenate((preprocessed[arr].astype(to_class),
-                                                        pr_arr[:ts_counter_2].astype(to_class)), axis=0)
+                                                        pr_arr.astype(to_class)), axis=0)
 
                 if 'continual_learning_labels' in args:
                     if preprocessed.get(CL_LABEL_KEY, None) is None:
                         preprocessed[CL_LABEL_KEY] = \
-                            np.array([args.continual_learning_labels[idx] for _ in range(ts_counter_1)])
+                            np.array([args.continual_learning_labels[idx] for _ in range(dataset_1_size)])
                     else:
                         preprocessed[CL_LABEL_KEY] = \
                             np.concatenate((preprocessed[CL_LABEL_KEY], np.array([args.continual_learning_labels[idx]
-                                                                                  for _ in range(ts_counter_2)])), axis=0)
-
-        print("The total timesteps: ", ts_counter_1+ts_counter_2)
-        print("The total episodes: ", num_episode_dataset_1+num_episode_dataset_2)
-        for k in preprocessed:
-            print(k)
-            print(preprocessed[k].shape)
-
-        for k in ground_truth:
-            print(k)
-            print(ground_truth[k].shape)
-
-
+                                                                                  for _ in range(dataset_2_size)])), axis=0)
 
         np.savez(args.merge[2] + "/preprocessed_data.npz", ** preprocessed)
 
         # remove the old folders
-        # shutil.rmtree(args.merge[0])
-        # shutil.rmtree(args.merge[1])
+        shutil.rmtree(args.merge[0])
+        shutil.rmtree(args.merge[1])
 
 
 if __name__ == '__main__':
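
Example invocation once this patch is applied (the dataset paths are illustrative, and the label flag spelling is an assumption inferred from the `args.continual_learning_labels` attribute used above). Note that the script deletes both source folders once the merged dataset has been written:

    python environments/dataset_merger.py \
        --merge data/src_1 data/src_2 data/merged \
        --continual-learning-labels CC SC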