Merge pull request #9 from GaspardQin/revert-8-circular_movement_omnibot_data_fusioner_issue

bug fix for dataset_merger
sun-te authored Jun 28, 2019
2 parents d8fdb05 + e622e1f commit e823cbb
Showing 1 changed file with 29 additions and 106 deletions.
135 changes: 29 additions & 106 deletions environments/dataset_merger.py
@@ -4,11 +4,12 @@
 import argparse
 import os
 import shutil
-import pdb

 import numpy as np
 from tqdm import tqdm

+# List of all possible labels identifying a task,
+# for experiments in Continual Learning scenari.
 CONTINUAL_LEARNING_LABELS = ['CC', 'SC', 'EC', 'SQC']
 CL_LABEL_KEY = "continual_learning_label"

@@ -22,14 +23,13 @@ def main():
     parser.add_argument('-f', '--force', action='store_true', default=False,
                         help='Force the merge, even if it overrides something else,'
                              ' including the destination if it exist')
-    parser.add_argument('--timesteps', type=int, nargs=2, default=[-1,-1],
-                        help="To have a certain number of frames for two data sets ")
     group = parser.add_mutually_exclusive_group()
     group.add_argument('--merge', type=str, nargs=3, metavar=('source_1', 'source_2', 'destination'),
                        default=argparse.SUPPRESS,
                        help='Merge two datasets by appending the episodes, deleting sources right after.')
+
     args = parser.parse_args()

     if 'merge' in args:
         # let make sure everything is in order
         assert os.path.exists(args.merge[0]), "Error: dataset '{}' could not be found".format(args.merge[0])
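For context, a minimal sketch (not part of the diff; dataset paths are hypothetical) of the command-line surface that remains once --timesteps is gone, and of why the `if 'merge' in args:` guard works: with default=argparse.SUPPRESS, the merge attribute only exists on the namespace when the flag was actually passed.

    import argparse

    # Rebuild just the two surviving options to show the parsing behaviour.
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--force', action='store_true', default=False)
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--merge', type=str, nargs=3,
                       metavar=('source_1', 'source_2', 'destination'),
                       default=argparse.SUPPRESS)

    # Hypothetical invocation: merge two source folders into a destination.
    args = parser.parse_args(['--merge', 'data/src_1', 'data/src_2', 'data/merged', '-f'])
    assert 'merge' in args and args.force

    # Without the flag, SUPPRESS means the attribute is absent entirely.
    assert 'merge' not in parser.parse_args([])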
@@ -47,63 +47,24 @@ def main():
         # create the output
         os.mkdir(args.merge[2])

-        #os.rename(args.merge[0] + "/dataset_config.json", args.merge[2] + "/dataset_config.json")
-        #os.rename(args.merge[0] + "/env_globals.json", args.merge[2] + "/env_globals.json")
-        shutil.copy2(args.merge[0] + "/dataset_config.json",args.merge[2] + "/dataset_config.json")
-        shutil.copy2(args.merge[0] + "/env_globals.json", args.merge[2] + "/env_globals.json")
-
         # copy files from first source
-        num_timesteps_1, num_timesteps_2 = args.timesteps
-        local_path = os.getcwd()
-        all_records = sorted(glob.glob(args.merge[0] + "/record_[0-9]*/*"))
-        previous_records = all_records[0]
-        for ts_counter_1, record in enumerate(all_records):
-
-            #if the timesteps is larger than needed, we wait until this episode is over
-            if(num_timesteps_1>0 and ts_counter_1 >num_timesteps_1):
-                if(os.path.dirname(previous_records).split('_')[-1] != os.path.dirname(record).split('_')[-1]):
-                    break
+        os.rename(args.merge[0] + "/dataset_config.json", args.merge[2] + "/dataset_config.json")
+        os.rename(args.merge[0] + "/env_globals.json", args.merge[2] + "/env_globals.json")
+
+        for record in sorted(glob.glob(args.merge[0] + "/record_[0-9]*/*")):
             s = args.merge[2] + "/" + record.split("/")[-2] + '/' + record.split("/")[-1]
-            s = os.path.join(local_path,s)
-            record = os.path.join(local_path, record)
-            try:
-                shutil.copy2(record, s)
-            except FileNotFoundError:
-                os.mkdir(os.path.dirname(s))
-                shutil.copy2(record, s)
-            previous_records = record
-        num_episode_dataset_1 = int(previous_records.split("/")[-2][7:])
-        if (num_timesteps_1 == -1):
-            num_episode_dataset_1 += 1
-        ts_counter_1 += 1
+            os.renames(record, s)

-        # copy files from second source
-        all_records = sorted(glob.glob(args.merge[1] + "/record_[0-9]*/*"))
-        previous_records = all_records[0]
-        for ts_counter_2, record in enumerate(all_records):
+        num_episode_dataset_1 = int(record.split("/")[-2][7:]) + 1

-            if (num_timesteps_2 > 0 and ts_counter_2 > num_timesteps_2):
-                if (os.path.dirname(previous_records).split('_')[-1] != os.path.dirname(record).split('_')[-1]):
-                    break
+        # copy files from second source
+        for record in sorted(glob.glob(args.merge[1] + "/record_[0-9]*/*")):
             episode = str(num_episode_dataset_1 + int(record.split("/")[-2][7:]))
             new_episode = record.split("/")[-2][:-len(episode)] + episode
             s = args.merge[2] + "/" + new_episode + '/' + record.split("/")[-1]
-            s = os.path.join(local_path, s)
-            record = os.path.join(local_path, record)
-            try:
-                shutil.copy2(record, s)
-            except FileNotFoundError:
-                os.mkdir(os.path.dirname(s))
-                shutil.copy2(record, s)
-            previous_records = record
-
-        num_episode_dataset_2 = int(previous_records.split("/")[-2][7:])
-        if(num_timesteps_2==-1):
-            num_episode_dataset_2 +=1
-        ts_counter_2 +=1
-
-        ts_counter = [ts_counter_1, ts_counter_2]
+            os.renames(record, s)
+        num_episode_dataset_2 = int(record.split("/")[-2][7:]) + 1

         # load and correct ground_truth
         ground_truth = {}
         ground_truth_load = np.load(args.merge[0] + "/ground_truth.npz")
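The restored loops lean on os.renames, which creates any missing intermediate directories for the destination (and prunes emptied source directories), so the try/except FileNotFoundError + os.mkdir fallback that the copy-based version needed disappears. A worked sketch of the episode renumbering applied to the second dataset, with hypothetical values:

    # Dataset 1 contributed 100 episodes; move a frame from episode 12 of dataset 2.
    num_episode_dataset_1 = 100
    record = "data/src_2/record_012/frame_000042.jpg"  # hypothetical record path

    folder = record.split("/")[-2]                          # "record_012"
    episode = str(num_episode_dataset_1 + int(folder[7:]))  # 100 + 12 -> "112"
    new_episode = folder[:-len(episode)] + episode          # zero padding kept
    assert new_episode == "record_112"

Note the slicing only preserves the folder's zero padding while the shifted episode number fits in the same number of digits; a much larger offset would eat into the "record_" prefix.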
@@ -115,20 +76,12 @@ def main():
         index_margin_str = len("/record_")
         directory_str = args.merge[2][index_slash+1:]

-        len_info_1 = [len(ground_truth_load[k]) for k in ground_truth_load.keys()]
-        num_eps_total_1, num_ts_total_1 = min(len_info_1), max(len_info_1)
-        len_info_2 = [len(ground_truth_load_2[k]) for k in ground_truth_load_2.keys()]
-        num_eps_total_2, num_ts_total_2 = min(len_info_2), max(len_info_2)
-
-
         for idx_, gt_load in enumerate([ground_truth_load, ground_truth_load_2], 1):
-
             for arr in gt_load.files:
-
                 if arr == "images_path":
                     # here, we want to rename just the folder containing the records, hence the black magic

-                    for i in tqdm(range(ts_counter[idx_-1]),#range(len(gt_load["images_path"])),
+                    for i in tqdm(range(len(gt_load["images_path"])),
                                   desc="Update of paths (Folder " + str(1+idx_) + ")"):
                         # find the "record_" position
                         path = gt_load["images_path"][i]
@@ -142,39 +95,21 @@ def main():
                         else:
                             new_record_path = path[end_pos:]
                         ground_truth["images_path"].append(directory_str + new_record_path)
                 else:
                     # anything that isnt image_path, we dont need to change
                     gt_arr = gt_load[arr]

                     if idx_ > 1:
                         num_episode_dataset = num_episode_dataset_2

                     # HERE check before overwritting that the target is random !+
                     if gt_load[arr].shape[0] < num_episode_dataset:
                         gt_arr = np.repeat(gt_load[arr], num_episode_dataset, axis=0)

                     if idx_ > 1:
-                        # This is the first dataset
-                        if (len(gt_arr) == num_eps_total_2):
-                            # This is a episode non-change variable
-                            ground_truth[arr] = np.concatenate((ground_truth[arr],
-                                                                gt_arr[:num_episode_dataset_2]), axis=0)
-                        elif (len(gt_arr) == num_ts_total_2):  # a timesteps changing variable
-                            ground_truth[arr] = np.concatenate((ground_truth[arr],
-                                                                gt_arr[:ts_counter_2]), axis=0)
-                        else:
-                            assert 0 == 1, "No compatible variable in the stored ground truth for the second dataset {}" \
-                                .format(args.merge[1])
+                        ground_truth[arr] = np.concatenate((ground_truth[arr], gt_arr), axis=0)
                     else:
-                        # This is the first dataset
-                        if(len(gt_arr) == num_eps_total_1):
-                            #This is a episode non-change variable
-                            ground_truth[arr] = gt_arr[:num_episode_dataset_1]
-                        elif(len(gt_arr) == num_ts_total_1):  # a timesteps changing variable
-                            ground_truth[arr] = gt_arr[:ts_counter_1]
-                        else:
-                            assert 0 ==1 , "No compatible variable in the stored ground truth for the first dataset {}"\
-                                .format(args.merge[0])
+                        ground_truth[arr] = gt_arr

         # save the corrected ground_truth
         np.savez(args.merge[2] + "/ground_truth.npz", **ground_truth)
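With the per-variable length dispatch deleted, every non-image ground-truth array is now concatenated whole; it is the np.repeat guard kept above that upsamples arrays stored once per dataset (rather than once per episode) so the unconditional concatenate stays shape-compatible. A sketch with hypothetical shapes:

    import numpy as np

    # A ground-truth entry stored once instead of once per episode.
    num_episode_dataset = 3
    gt_arr = np.array([[0.1, 0.2]])  # shape (1, 2)

    if gt_arr.shape[0] < num_episode_dataset:
        gt_arr = np.repeat(gt_arr, num_episode_dataset, axis=0)

    assert gt_arr.shape == (3, 2)  # one row per episode, safe to concatenate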
@@ -186,50 +121,38 @@ def main():

         dataset_1_size = preprocessed_load["actions"].shape[0]
         dataset_2_size = preprocessed_load_2["actions"].shape[0]

         # Concatenating additional information: indices of episode start, action probabilities, CL labels...
         for idx, prepro_load in enumerate([preprocessed_load, preprocessed_load_2]):
             for arr in prepro_load.files:
                 pr_arr = prepro_load[arr]

                 to_class = None
                 if arr == "episode_starts":
                     to_class = bool
-                elif arr == "actions_proba" or arr =="rewards":
+                elif arr == "actions_proba" or arr == "rewards":
                     to_class = float
                 else:
                     to_class = int
-                # all data is of timesteps changing (instead of episode changing)
-                if preprocessed.get(arr, None) is None: #for the first dataset
-                    preprocessed[arr] = pr_arr.astype(to_class)[:ts_counter_1]
-                else:# for the second dataset
+                if preprocessed.get(arr, None) is None:
+                    preprocessed[arr] = pr_arr.astype(to_class)
+                else:
                     preprocessed[arr] = np.concatenate((preprocessed[arr].astype(to_class),
-                                                        pr_arr[:ts_counter_2].astype(to_class)), axis=0)
+                                                        pr_arr.astype(to_class)), axis=0)
                 if 'continual_learning_labels' in args:
                     if preprocessed.get(CL_LABEL_KEY, None) is None:
                         preprocessed[CL_LABEL_KEY] = \
-                            np.array([args.continual_learning_labels[idx] for _ in range(ts_counter_1)])
+                            np.array([args.continual_learning_labels[idx] for _ in range(dataset_1_size)])
                     else:
                         preprocessed[CL_LABEL_KEY] = \
                             np.concatenate((preprocessed[CL_LABEL_KEY], np.array([args.continual_learning_labels[idx]
-                                                                                  for _ in range(ts_counter_2)])), axis=0)
-
-        print("The total timesteps: ", ts_counter_1+ts_counter_2)
-        print("The total episodes: ", num_episode_dataset_1+num_episode_dataset_2)
-        for k in preprocessed:
-            print(k)
-            print(preprocessed[k].shape)
-
-        for k in ground_truth:
-            print(k)
-            print(ground_truth[k].shape)
-
+                                                                                  for _ in range(dataset_2_size)])), axis=0)

         np.savez(args.merge[2] + "/preprocessed_data.npz", ** preprocessed)

         # remove the old folders
-        # shutil.rmtree(args.merge[0])
-        # shutil.rmtree(args.merge[1])
+        shutil.rmtree(args.merge[0])
+        shutil.rmtree(args.merge[1])


 if __name__ == '__main__':
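The label fix above sizes the continual-learning column by each dataset's preprocessed timestep count (dataset_1_size / dataset_2_size, taken from the "actions" arrays) instead of the removed ts_counter variables. A sketch with hypothetical sizes and two labels from CONTINUAL_LEARNING_LABELS:

    import numpy as np

    dataset_1_size, dataset_2_size = 1000, 800  # hypothetical timestep counts

    labels = np.array(['CC' for _ in range(dataset_1_size)])
    labels = np.concatenate((labels,
                             np.array(['SC' for _ in range(dataset_2_size)])), axis=0)

    assert labels.shape == (1800,)  # one label per merged timestep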
