bug fix for dataset_merger #9

Merged
135 changes: 29 additions & 106 deletions environments/dataset_merger.py
@@ -4,11 +4,12 @@
import argparse
import os
import shutil
import pdb

import numpy as np
from tqdm import tqdm

# List of all possible labels identifying a task,
# for experiments in Continual Learning scenarios.
CONTINUAL_LEARNING_LABELS = ['CC', 'SC', 'EC', 'SQC']
CL_LABEL_KEY = "continual_learning_label"

@@ -22,14 +23,13 @@ def main():
parser.add_argument('-f', '--force', action='store_true', default=False,
help='Force the merge, even if it overrides something else,'
' including the destination if it exists')
parser.add_argument('--timesteps', type=int, nargs=2, default=[-1,-1],
help="To have a certain number of frames for two data sets ")
group = parser.add_mutually_exclusive_group()
group.add_argument('--merge', type=str, nargs=3, metavar=('source_1', 'source_2', 'destination'),
default=argparse.SUPPRESS,
help='Merge two datasets by appending the episodes, deleting sources right after.')

args = parser.parse_args()

if 'merge' in args:
# let's make sure everything is in order
assert os.path.exists(args.merge[0]), "Error: dataset '{}' could not be found".format(args.merge[0])
@@ -47,63 +47,24 @@ def main():
# create the output
os.mkdir(args.merge[2])


#os.rename(args.merge[0] + "/dataset_config.json", args.merge[2] + "/dataset_config.json")
#os.rename(args.merge[0] + "/env_globals.json", args.merge[2] + "/env_globals.json")
shutil.copy2(args.merge[0] + "/dataset_config.json",args.merge[2] + "/dataset_config.json")
shutil.copy2(args.merge[0] + "/env_globals.json", args.merge[2] + "/env_globals.json")

# copy files from first source
num_timesteps_1, num_timesteps_2 = args.timesteps
local_path = os.getcwd()
all_records = sorted(glob.glob(args.merge[0] + "/record_[0-9]*/*"))
previous_records = all_records[0]
for ts_counter_1, record in enumerate(all_records):

# if we already have more timesteps than requested, wait until the current episode is over before stopping
if(num_timesteps_1>0 and ts_counter_1 >num_timesteps_1):
if(os.path.dirname(previous_records).split('_')[-1] != os.path.dirname(record).split('_')[-1]):
break
os.rename(args.merge[0] + "/dataset_config.json", args.merge[2] + "/dataset_config.json")
os.rename(args.merge[0] + "/env_globals.json", args.merge[2] + "/env_globals.json")

for record in sorted(glob.glob(args.merge[0] + "/record_[0-9]*/*")):
s = args.merge[2] + "/" + record.split("/")[-2] + '/' + record.split("/")[-1]
s = os.path.join(local_path,s)
record = os.path.join(local_path, record)
try:
shutil.copy2(record, s)
except FileNotFoundError:
os.mkdir(os.path.dirname(s))
shutil.copy2(record, s)
previous_records = record
num_episode_dataset_1 = int(previous_records.split("/")[-2][7:])
if (num_timesteps_1 == -1):
num_episode_dataset_1 += 1
ts_counter_1 += 1
os.renames(record, s)

# copy files from second source
all_records = sorted(glob.glob(args.merge[1] + "/record_[0-9]*/*"))
previous_records = all_records[0]
for ts_counter_2, record in enumerate(all_records):
num_episode_dataset_1 = int(record.split("/")[-2][7:]) + 1

if (num_timesteps_2 > 0 and ts_counter_2 > num_timesteps_2):
if (os.path.dirname(previous_records).split('_')[-1] != os.path.dirname(record).split('_')[-1]):
break
# copy files from second source
for record in sorted(glob.glob(args.merge[1] + "/record_[0-9]*/*")):
episode = str(num_episode_dataset_1 + int(record.split("/")[-2][7:]))
new_episode = record.split("/")[-2][:-len(episode)] + episode
s = args.merge[2] + "/" + new_episode + '/' + record.split("/")[-1]
s = os.path.join(local_path, s)
record = os.path.join(local_path, record)
try:
shutil.copy2(record, s)
except FileNotFoundError:
os.mkdir(os.path.dirname(s))
shutil.copy2(record, s)
previous_records = record

num_episode_dataset_2 = int(previous_records.split("/")[-2][7:])
if(num_timesteps_2==-1):
num_episode_dataset_2 +=1
ts_counter_2 +=1

ts_counter = [ts_counter_1, ts_counter_2]
os.renames(record, s)
num_episode_dataset_2 = int(record.split("/")[-2][7:]) + 1
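
The one-line folder renumbering above (`record.split("/")[-2][:-len(episode)] + episode`) is easy to misread, so here is a small standalone sketch of what it does; the folder name and episode count below are hypothetical examples, not values taken from this PR.

```python
# Minimal sketch of the episode renumbering used when appending the second dataset.
# Assumed example values: 25 episodes were copied from the first dataset, and the
# current record folder from the second dataset is "record_003".
num_episode_dataset_1 = 25
folder = "record_003"

# Offset the episode index of the second dataset by the number of episodes
# already present from the first dataset.
episode = str(num_episode_dataset_1 + int(folder[len("record_"):]))  # "28"

# Replace only the trailing digits so the zero padding of "record_XXX" is kept.
new_folder = folder[:-len(episode)] + episode

print(new_folder)  # record_028
```

Note that this trick preserves the zero padding only while the shifted episode number has no more digits than the original `record_XXX` suffix.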

# load and correct ground_truth
ground_truth = {}
ground_truth_load = np.load(args.merge[0] + "/ground_truth.npz")
@@ -115,20 +76,12 @@ def main():
index_margin_str = len("/record_")
directory_str = args.merge[2][index_slash+1:]

len_info_1 = [len(ground_truth_load[k]) for k in ground_truth_load.keys()]
num_eps_total_1, num_ts_total_1 = min(len_info_1), max(len_info_1)
len_info_2 = [len(ground_truth_load_2[k]) for k in ground_truth_load_2.keys()]
num_eps_total_2, num_ts_total_2 = min(len_info_2), max(len_info_2)


for idx_, gt_load in enumerate([ground_truth_load, ground_truth_load_2], 1):

for arr in gt_load.files:

if arr == "images_path":
# here, we want to rename just the folder containing the records, hence the black magic

for i in tqdm(range(ts_counter[idx_-1]),#range(len(gt_load["images_path"])),
for i in tqdm(range(len(gt_load["images_path"])),
desc="Update of paths (Folder " + str(1+idx_) + ")"):
# find the "record_" position
path = gt_load["images_path"][i]
@@ -142,39 +95,21 @@ def main():
else:
new_record_path = path[end_pos:]
ground_truth["images_path"].append(directory_str + new_record_path)

else:
# anything that isn't images_path doesn't need to be changed
gt_arr = gt_load[arr]

if idx_ > 1:
num_episode_dataset = num_episode_dataset_2

# HERE: check before overwriting that the target is random!
if gt_load[arr].shape[0] < num_episode_dataset:
gt_arr = np.repeat(gt_load[arr], num_episode_dataset, axis=0)

if idx_ > 1:
# This is the second dataset
if (len(gt_arr) == num_eps_total_2):
# This is an episode-level (non-changing) variable
ground_truth[arr] = np.concatenate((ground_truth[arr],
gt_arr[:num_episode_dataset_2]), axis=0)
elif (len(gt_arr) == num_ts_total_2): # a timesteps changing variable
ground_truth[arr] = np.concatenate((ground_truth[arr],
gt_arr[:ts_counter_2]), axis=0)
else:
assert 0 == 1, "No compatible variable in the stored ground truth for the second dataset {}" \
.format(args.merge[1])
ground_truth[arr] = np.concatenate((ground_truth[arr], gt_arr), axis=0)
else:
# This is the first dataset
if(len(gt_arr) == num_eps_total_1):
# This is an episode-level (non-changing) variable
ground_truth[arr] = gt_arr[:num_episode_dataset_1]
elif(len(gt_arr) == num_ts_total_1): # a timesteps changing variable
ground_truth[arr] = gt_arr[:ts_counter_1]
else:
assert 0 ==1 , "No compatible variable in the stored ground truth for the first dataset {}"\
.format(args.merge[0])
ground_truth[arr] = gt_arr

# save the corrected ground_truth
np.savez(args.merge[2] + "/ground_truth.npz", **ground_truth)
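
For clarity, the path update in the loop above essentially keeps only the `record_.../<file>` suffix of each stored image path and prefixes it with the name of the destination directory. A simplified sketch of that idea follows; the paths are hypothetical, and the episode renumbering applied to the second dataset is ignored here.

```python
# Simplified sketch of the images_path rewrite (hypothetical paths).
path = "dataset_A/record_007/frame_000042.jpg"    # a path stored in ground_truth.npz
destination = "data/merged_dataset"                # args.merge[2] in the script

# Name of the destination folder, mirroring args.merge[2][index_slash + 1:]
directory_str = destination[destination.find("/") + 1:]   # "merged_dataset"

# Keep everything from "/record_" onwards and re-anchor it under the new folder.
record_pos = path.find("/record_")
new_path = directory_str + path[record_pos:]

print(new_path)  # merged_dataset/record_007/frame_000042.jpg
```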
@@ -186,50 +121,38 @@ def main():

dataset_1_size = preprocessed_load["actions"].shape[0]
dataset_2_size = preprocessed_load_2["actions"].shape[0]

# Concatenating additional information: indices of episode start, action probabilities, CL labels...
for idx, prepro_load in enumerate([preprocessed_load, preprocessed_load_2]):
for arr in prepro_load.files:
pr_arr = prepro_load[arr]

to_class = None
if arr == "episode_starts":
to_class = bool
elif arr == "actions_proba" or arr =="rewards":
elif arr == "actions_proba" or arr == "rewards":
to_class = float
else:
to_class = int
# all of this data changes per timestep (rather than per episode)
if preprocessed.get(arr, None) is None: #for the first dataset
preprocessed[arr] = pr_arr.astype(to_class)[:ts_counter_1]
else:# for the second dataset
if preprocessed.get(arr, None) is None:
preprocessed[arr] = pr_arr.astype(to_class)
else:
preprocessed[arr] = np.concatenate((preprocessed[arr].astype(to_class),
pr_arr[:ts_counter_2].astype(to_class)), axis=0)
pr_arr.astype(to_class)), axis=0)
if 'continual_learning_labels' in args:
if preprocessed.get(CL_LABEL_KEY, None) is None:
preprocessed[CL_LABEL_KEY] = \
np.array([args.continual_learning_labels[idx] for _ in range(ts_counter_1)])
np.array([args.continual_learning_labels[idx] for _ in range(dataset_1_size)])
else:
preprocessed[CL_LABEL_KEY] = \
np.concatenate((preprocessed[CL_LABEL_KEY], np.array([args.continual_learning_labels[idx]
for _ in range(ts_counter_2)])), axis=0)

print("The total timesteps: ", ts_counter_1+ts_counter_2)
print("The total episodes: ", num_episode_dataset_1+num_episode_dataset_2)
for k in preprocessed:
print(k)
print(preprocessed[k].shape)

for k in ground_truth:
print(k)
print(ground_truth[k].shape)


for _ in range(dataset_2_size)])), axis=0)

np.savez(args.merge[2] + "/preprocessed_data.npz", ** preprocessed)


# remove the old folders
# shutil.rmtree(args.merge[0])
# shutil.rmtree(args.merge[1])
shutil.rmtree(args.merge[0])
shutil.rmtree(args.merge[1])


if __name__ == '__main__':
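
After the merge, a quick sanity check of the output archives can catch renumbering or concatenation mistakes. The snippet below is an illustrative check, not part of this PR; the merged directory path is hypothetical, and it assumes every preprocessed array is stored per timestep, as the comment in the diff suggests.

```python
import numpy as np

# Hypothetical location of the merged dataset produced by --merge.
merged_dir = "data/merged_dataset"

ground_truth = np.load(merged_dir + "/ground_truth.npz")
preprocessed = np.load(merged_dir + "/preprocessed_data.npz")

# Every rewritten image path should now start with the merged folder's name.
folder_name = merged_dir.split("/")[-1]
assert all(str(p).startswith(folder_name) for p in ground_truth["images_path"])

# If all preprocessed arrays are per-timestep, their lengths should match after the merge.
n_timesteps = len(preprocessed["actions"])
for key in ("episode_starts", "rewards"):
    if key in preprocessed.files:
        assert len(preprocessed[key]) == n_timesteps, key

print("merged dataset looks consistent:", n_timesteps, "timesteps")
```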