Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inconsistency of quaternion formats and use a better loss for classifying gripper openess #6

Merged
merged 5 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ Remember to use the latest `calvin_env` module, which fixes bugs of `turn_off_le
> cd RLBench; git checkout -b peract --track origin/peract; pip install -r requirements.txt; pip install -e .; cd ..;
```

Remember to modify the success condition of `close_jar` task in RLBench, as the original condition is incorrect. See this [pull request](https://github.com/MohitShridhar/RLBench/pull/1) for more detail.

# Data Preparation

See [Preparing RLBench dataset](./docs/DATA_PREPARATION_RLBENCH.md) and [Preparing CALVIN dataset](./docs/DATA_PREPARATION_CALVIN.md).
Expand Down Expand Up @@ -116,6 +118,10 @@ First, donwload the weights and put under `train_logs/`
* For RLBench, run the bashscripts to test the policy. See [Getting started with RLBench](./docs/GETTING_STARTED_RLBENCH.md#step-3-test-the-policy) for detail.
* For CALVIN, you can run [this bashcript](./scripts/test_trajectory_calvin.sh).

**Important note:** Our released model weights of 3D Diffuser Actor assume input quaternions are in `wxyz` format. Yet, we didn't notice that CALVIN and RLBench simulation use different quaternion formats (`wxyz` and `xyzw`). We have updated our code base with an additional argument `quaternion_format` to switch between these two formats. We have verified the change by re-training and testing 3D Diffuser Actor on GNFactor with `xyzw` quaternions. The model achieves similar performance as the released checkpoint. Please see this [post](https://github.com/nickgkan/3d_diffuser_actor/issues/3#issue-2164855979) for more detail.

For users to train 3D Diffuser Actor from scratch, we update the training scripts with the correct `xyzw` quaternion format. For users to test our released model, we keep the `wxyz` quaternion format in the testing scripts ([Peract](./online_evaluation_rlbench/eval_peract.sh), [GNFactor](./online_evaluation_rlbench/eval_gnfactor.sh)).


# Getting started

Expand Down
20 changes: 18 additions & 2 deletions diffuser_actor/trajectory_optimization/diffuser_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ def __init__(self,
fps_subsampling_factor=5,
gripper_loc_bounds=None,
rotation_parametrization='6D',
quaternion_format='xyzw',
diffusion_timesteps=100,
nhist=3,
relative=False,
lang_enhanced=False):
super().__init__()
self._rotation_parametrization = rotation_parametrization
self._quaternion_format = quaternion_format
self._relative = relative
self.use_instruction = use_instruction
self.encoder = Encoder(
Expand Down Expand Up @@ -227,6 +229,9 @@ def compute_trajectory(
trajectory = self.unconvert_rot(trajectory)
# unnormalize position
trajectory[:, :, :3] = self.unnormalize_pos(trajectory[:, :, :3])
# Convert gripper status to probaility
if trajectory.shape[-1] > 7:
trajectory[..., 7] = trajectory[..., 7].sigmoid()

return trajectory

Expand All @@ -243,6 +248,9 @@ def unnormalize_pos(self, pos):
def convert_rot(self, signal):
signal[..., 3:7] = normalise_quat(signal[..., 3:7])
if self._rotation_parametrization == '6D':
# The following code expects wxyz quaternion format!
if self._quaternion_format == 'xyzw':
signal[..., 3:7] = signal[..., (6, 3, 4, 5)]
rot = quaternion_to_matrix(signal[..., 3:7])
res = signal[..., 7:] if signal.size(-1) > 7 else None
if len(rot.shape) == 4:
Expand Down Expand Up @@ -273,6 +281,9 @@ def unconvert_rot(self, signal):
signal = torch.cat([signal[..., :3], quat], dim=-1)
if res is not None:
signal = torch.cat((signal, res), -1)
# The above code handled wxyz quaternion format!
if self._quaternion_format == 'xyzw':
signal[..., 3:7] = signal[..., (4, 5, 6, 3)]
return signal

def convert2rel(self, pcd, curr_gripper):
Expand All @@ -296,13 +307,18 @@ def forward(
):
"""
Arguments:
gt_trajectory: (B, trajectory_length, 3+6+X)
gt_trajectory: (B, trajectory_length, 3+4+X)
trajectory_mask: (B, trajectory_length)
timestep: (B, 1)
rgb_obs: (B, num_cameras, 3, H, W) in [0, 1]
pcd_obs: (B, num_cameras, 3, H, W) in world coordinates
instruction: (B, max_instruction_length, 512)
curr_gripper: (B, nhist, output_dim)
curr_gripper: (B, nhist, 3+4+X)

Note:
Regardless of rotation parametrization, the input rotation
is ALWAYS expressed as a quaternion form.
The model converts it to 6D internally if needed.
"""
if self._relative:
pcd_obs, curr_gripper = self.convert2rel(pcd_obs, curr_gripper)
Expand Down
2 changes: 1 addition & 1 deletion main_keypose.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def compute_loss(self, pred, sample):

self._compute_rotation_loss(pred, gt_action[:, 3:7], losses)

losses["gripper"] = F.binary_cross_entropy_with_logits(pred["gripper"], gt_action[:, 7:8])
losses["gripper"] = F.binary_cross_entropy(pred["gripper"], gt_action[:, 7:8])
losses["gripper"] *= self.gripper_loss_coeff

return losses
Expand Down
4 changes: 3 additions & 1 deletion main_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Arguments(tap.Tap):
num_vis_ins_attn_layers: int = 2
use_instruction: int = 0
rotation_parametrization: str = 'quat'
quaternion_format: str = 'wxyz'
diffusion_timesteps: int = 100
keypose_only: int = 0
num_history: int = 0
Expand Down Expand Up @@ -150,6 +151,7 @@ def get_model(self):
fps_subsampling_factor=self.args.fps_subsampling_factor,
gripper_loc_bounds=self.args.gripper_loc_bounds,
rotation_parametrization=self.args.rotation_parametrization,
quaternion_format=self.args.quaternion_format,
diffusion_timesteps=self.args.diffusion_timesteps,
nhist=self.args.num_history,
relative=bool(self.args.relative_action),
Expand Down Expand Up @@ -325,7 +327,7 @@ def compute_metrics(pred, gt, mask):
select_mask = (quat_l1 < quat_l1_).float()
quat_l1 = (select_mask * quat_l1 + (1 - select_mask) * quat_l1_)
# gripper openess
openess = ((pred[..., 7:].sigmoid() >= 0.5) == (gt[..., 7:] > 0.0)).bool()
openess = ((pred[..., 7:] >= 0.5) == (gt[..., 7:] > 0.0)).bool()
tr = 'traj_'

# Trajectory metrics
Expand Down
2 changes: 2 additions & 0 deletions online_evaluation_rlbench/eval_gnfactor.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ lang_enhanced=0
relative_action=0
seed=0
checkpoint=train_logs/diffuser_actor_gnfactor.pth
quaternion_format=wxyz

num_ckpts=${#tasks[@]}
for ((i=0; i<$num_ckpts; i++)); do
Expand Down Expand Up @@ -49,6 +50,7 @@ for ((i=0; i<$num_ckpts; i++)); do
--seed $seed \
--gripper_loc_bounds_file $gripper_loc_bounds_file \
--gripper_loc_bounds_buffer 0.08 \
--quaternion_format $quaternion_format \
--interpolation_length $interpolation_length \
--dense_interpolation 1
done
Expand Down
2 changes: 2 additions & 0 deletions online_evaluation_rlbench/eval_peract.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ lang_enhanced=0
relative_action=0
seed=0
checkpoint=train_logs/diffuser_actor_peract.pth
quaternion_format=wxyz

num_ckpts=${#tasks[@]}
for ((i=0; i<$num_ckpts; i++)); do
Expand Down Expand Up @@ -49,6 +50,7 @@ for ((i=0; i<$num_ckpts; i++)); do
--seed $seed \
--gripper_loc_bounds_file $gripper_loc_bounds_file \
--gripper_loc_bounds_buffer 0.04 \
--quaternion_format $quaternion_format \
--interpolation_length $interpolation_length \
--dense_interpolation 1
done
Expand Down
2 changes: 2 additions & 0 deletions online_evaluation_rlbench/evaluate_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Arguments(tap.Tap):
num_vis_ins_attn_layers: int = 2
use_instruction: int = 1
rotation_parametrization: str = '6D'
quaternion_format: str = 'xyzw'


def load_models(args):
Expand Down Expand Up @@ -97,6 +98,7 @@ def load_models(args):
fps_subsampling_factor=args.fps_subsampling_factor,
gripper_loc_bounds=gripper_loc_bounds,
rotation_parametrization=args.rotation_parametrization,
quaternion_format=args.quaternion_format,
diffusion_timesteps=args.diffusion_timesteps,
nhist=args.num_history,
relative=bool(args.relative_action),
Expand Down
2 changes: 2 additions & 0 deletions scripts/train_keypose_gnfactor.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ B=8
C=120
ngpus=6
max_episodes_per_task=20
quaternion_format=xyzw

CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node $ngpus --master_port $RANDOM \
main_trajectory.py \
Expand Down Expand Up @@ -41,4 +42,5 @@ CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node $ngpus --master_port $RANDOM \
--num_history $num_history \
--cameras front\
--max_episodes_per_task $max_episodes_per_task \
--quaternion_format $quaternion_format \
--run_log_dir diffusion_multitask-C$C-B$B-lr$lr-DI$dense_interpolation-$interpolation_length-H$num_history-DT$diffusion_timesteps
2 changes: 2 additions & 0 deletions scripts/train_keypose_peract.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ diffusion_timesteps=100
B=8
C=120
ngpus=6
quaternion_format=xyzw

CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node $ngpus --master_port $RANDOM \
main_trajectory.py \
Expand Down Expand Up @@ -39,4 +40,5 @@ CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node $ngpus --master_port $RANDOM \
--num_history $num_history \
--cameras left_shoulder right_shoulder wrist front\
--max_episodes_per_task -1 \
--quaternion_format $quaternion_format \
--run_log_dir diffusion_multitask-C$C-B$B-lr$lr-DI$dense_interpolation-$interpolation_length-H$num_history-DT$diffusion_timesteps
2 changes: 2 additions & 0 deletions scripts/train_trajectory_gnfactor.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ B=8
C=192
ngpus=6
max_episodes_per_task=20
quaternion_format=xyzw

# CUDA_LAUNCH_BLOCKING=1 python -m torch.distributed.launch --nproc_per_node $ngpus --master_port $RANDOM \
CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node $ngpus --master_port $RANDOM \
Expand Down Expand Up @@ -42,4 +43,5 @@ CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node $ngpus --master_port $RANDOM \
--num_history $num_history \
--cameras front\
--max_episodes_per_task $max_episodes_per_task \
--quaternion_format $quaternion_format\
--run_log_dir diffusion_multitask-C$C-B$B-lr$lr-DI$dense_interpolation-$interpolation_length-H$num_history-DT$diffusion_timesteps
2 changes: 2 additions & 0 deletions scripts/train_trajectory_peract.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ diffusion_timesteps=100
B=7
C=192
ngpus=7
quaternion_format=xyzw

CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node $ngpus --master_port $RANDOM \
main_trajectory.py \
Expand Down Expand Up @@ -39,4 +40,5 @@ CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node $ngpus --master_port $RANDOM \
--num_history $num_history \
--cameras left_shoulder right_shoulder wrist front\
--max_episodes_per_task -1 \
--quaternion_format $quaternion_format\
--run_log_dir diffusion_multitask-C$C-B$B-lr$lr-DI$dense_interpolation-$interpolation_length-H$num_history-DT$diffusion_timesteps