
Updates automotive reference implementation #2045

23 changes: 18 additions & 5 deletions automotive/3d-object-detection/README.md
@@ -1,14 +1,27 @@
## Reference implementation for automotive 3D detection benchmark

-## TODO: Instructions for dataset download after it is uploaded somewhere appropriate
-## TODO: Instructions for checkpoints downloads after it is uploaded somewhere appropriate
+## Dataset and model checkpoints
+Contact MLCommons support for access to the Waymo Open Dataset and the model checkpoints for the reference implementation. You will need to accept a license agreement and will then be given directions for downloading the data. Place the kitti_format folder under a directory named waymo. There are four checkpoints in total: two for PyTorch and two for ONNX.

## Running with docker
Build the container and mount the inference repo and Waymo dataset directory.
```
docker build -t auto_inference -f dockerfile.gpu .

-docker run --gpus=all -it -v <directory to inference repo>/inference/:/inference -v <directory to waymo dataset>/waymo:/waymo --rm auto_inference
+docker run --gpus=all -it -v <directory to inference repo>/inference/:/inference -v <path to waymo dataset>/waymo:/waymo --rm auto_inference
```
### Run with GPU
```
cd /inference/automotive/3d-object-detection
python main.py --dataset waymo --dataset-path /waymo/kitti_format/ --lidar-path <checkpoint_path>/pp_ep36.pth --segmentor-path <checkpoint_path>/best_deeplabv3plus_resnet50_waymo_os16.pth --mlperf_conf /inference/mlperf.conf
```

### Run with CPU and ONNX
```
python main.py --dataset waymo --dataset-path /waymo/kitti_format/ --lidar-path <checkpoint_path>/pp.onnx --segmentor-path <checkpoint_path>/deeplabv3+.onnx --mlperf_conf /inference/mlperf.conf
```

### Run the accuracy checker
```
python accuracy_waymo.py --mlperf-accuracy-file <path to accuracy file>/mlperf_log_accuracy.json --waymo-dir /waymo/kitti_format/
```
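For a quick sanity check that the dataset and repo are mounted where the commands above expect them, a small script like the following can be run inside the container (paths mirror the docker run example above; adjust if your mounts differ):
```
from pathlib import Path

# Mount points assumed from the docker run example above.
waymo_root = Path("/waymo/kitti_format")
inference_root = Path("/inference/automotive/3d-object-detection")

assert waymo_root.is_dir(), "kitti_format not found - check the /waymo mount"
assert inference_root.is_dir(), "inference repo not found - check the /inference mount"
print("mounts look good")
```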
15 changes: 8 additions & 7 deletions automotive/3d-object-detection/accuracy_waymo.py
@@ -95,13 +95,13 @@ def main():
                'bbox': [],
                'score': []
            }
-        detections[image_idx]['name'].append(LABEL2CLASSES[label])
-        detections[image_idx]['dimensions'].append(dimension)
-        detections[image_idx]['location'].append(location)
-        detections[image_idx]['rotation_y'].append(rotation_y)
-        detections[image_idx]['bbox'].append(bbox)
-        detections[image_idx]['score'].append(score)
+        if dimension[0] > 0:
+            detections[image_idx]['name'].append(LABEL2CLASSES[label])
+            detections[image_idx]['dimensions'].append(dimension)
+            detections[image_idx]['location'].append(location)
+            detections[image_idx]['rotation_y'].append(rotation_y)
+            detections[image_idx]['bbox'].append(bbox)
+            detections[image_idx]['score'].append(score)
        image_ids.add(image_idx)

    with open(args.output_file, "w") as fp:
@@ -115,6 +115,7 @@ def main():
        val_dataset.data_infos,
        CLASSES,
        cam_sync=False)
+    map_stats['Total'] = np.mean(list(map_stats.values()))

    print(map_stats)
    if args.verbose:
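The two changes above skip padded detections (rows whose first dimension is non-positive) and append an overall score across classes. A quick sketch of how the new 'Total' entry behaves, with hypothetical per-class values (the real keys come from the Waymo CLASSES map):
```
import numpy as np

# Hypothetical per-class mAP values for illustration only.
map_stats = {'Vehicle': 0.62, 'Pedestrian': 0.48, 'Cyclist': 0.41}
map_stats['Total'] = np.mean(list(map_stats.values()))  # unweighted mean over classes
print(map_stats['Total'])  # ~0.503
```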
59 changes: 27 additions & 32 deletions automotive/3d-object-detection/backend_deploy.py
@@ -78,13 +78,11 @@ def load(self):
        return self

    def predict(self, inputs):
-        # TODO: implement predict
        dimensions, locations, rotation_y, box2d, class_labels, class_scores, ids = [
        ], [], [], [], [], [], []
        with torch.inference_mode():
            device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
-            format_results = {}
            model_input = inputs[0]
            batched_pts = model_input['pts']
            scores_from_cam = []
@@ -124,32 +124,29 @@ def predict(self, inputs):
                bboxes2d, camera_bboxes = result_filter['bboxes2d'], result_filter['camera_bboxes']
                for lidar_bbox, label, score, bbox2d, camera_bbox in \
                        zip(lidar_bboxes, labels, scores, bboxes2d, camera_bboxes):
                    format_result['class'].append(label.item())
                    format_result['truncated'].append(0.0)
                    format_result['occluded'].append(0)
                    alpha = camera_bbox[6] - \
                        np.arctan2(camera_bbox[0], camera_bbox[2])
                    format_result['alpha'].append(alpha.item())
                    format_result['bbox'].append(bbox2d.tolist())
                    format_result['dimensions'].append(camera_bbox[3:6])
                    format_result['location'].append(camera_bbox[:3])
                    format_result['rotation_y'].append(
                        camera_bbox[6].item())
                    format_result['score'].append(score.item())
-                format_results['idx'] = idx
-                # write_label(format_result, os.path.join(saved_submit_path, f'{idx:06d}.txt'))
+                format_result['idx'] = idx

                if len(format_result['dimensions']) > 0:
                    format_result['dimensions'] = torch.stack(
                        format_result['dimensions'])
                    format_result['location'] = torch.stack(
                        format_result['location'])
                    dimensions.append(format_result['dimensions'])
                    locations.append(format_result['location'])
                    rotation_y.append(format_result['rotation_y'])
                    class_labels.append(format_result['class'])
                    class_scores.append(format_result['score'])
                    box2d.append(format_result['bbox'])
-                    ids.append(format_results['idx'])
+                    ids.append(format_result['idx'])
-        # return Boxes, Classes, Scores # Change to desired output
        return dimensions, locations, rotation_y, box2d, class_labels, class_scores, ids
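The restructured predict builds one format_result per frame and only emits frames that actually contain boxes; the earlier format_results dict (note the plural) was a leftover that broke id tracking. A self-contained sketch of the accumulation pattern, with stub data in place of the real PointPillars outputs:
```
# Stub data is hypothetical: two frames, the second with no detections.
batch_results_stub = [
    [([1.6, 3.9, 1.5], 0.9), ([1.7, 4.1, 1.4], 0.8)],  # frame with two boxes
    [],                                                 # frame with no boxes
]
dimensions, class_scores, ids = [], [], []
for idx, boxes in enumerate(batch_results_stub):
    frame = {'dimensions': [], 'score': [], 'idx': idx}
    for dims, score in boxes:          # per-box loop, as in the zip above
        frame['dimensions'].append(dims)
        frame['score'].append(score)
    if len(frame['dimensions']) > 0:   # skip empty frames, as in predict
        dimensions.append(frame['dimensions'])
        class_scores.append(frame['score'])
        ids.append(frame['idx'])
print(ids)  # [0] - only the non-empty frame is kept
```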
149 changes: 149 additions & 0 deletions automotive/3d-object-detection/backend_onnx.py
@@ -0,0 +1,149 @@
from typing import Optional, List, Union
import os
import torch
import logging
import backend
from collections import namedtuple
from model.painter import Painter
from model.pointpillars_core import PointPillarsPre, PointPillarsPos
import numpy as np
from tools.process import keep_bbox_from_image_range
from waymo import Waymo
import onnxruntime as ort


logging.basicConfig(level=logging.INFO)
log = logging.getLogger("backend-onnx")


def change_calib_device(calib, cuda):
    result = {}
    if cuda:
        device = 'cuda'
    else:
        device = 'cpu'
    result['R0_rect'] = calib['R0_rect'].to(device=device, dtype=torch.float)
    for i in range(5):
        result['P' + str(i)] = calib['P' + str(i)].to(device=device, dtype=torch.float)
        result['Tr_velo_to_cam_' + str(i)] = calib['Tr_velo_to_cam_' + str(i)].to(device=device, dtype=torch.float)
    return result


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


class BackendOnnx(backend.Backend):
    def __init__(
        self,
        segmentor_path,
        lidar_detector_path,
        data_path
    ):
        super(BackendOnnx, self).__init__()
        self.segmentor_path = segmentor_path
        self.lidar_detector_path = lidar_detector_path
        # self.segmentation_classes = 18
        self.detection_classes = 3
        self.data_root = data_path
        CLASSES = Waymo.CLASSES
        self.LABEL2CLASSES = {v: k for k, v in CLASSES.items()}

    def version(self):
        return torch.__version__

    def name(self):
        return "python-SUT"

    def load(self):
        device = torch.device("cpu")
        PaintArgs = namedtuple('PaintArgs', ['training_path', 'model_path', 'cam_sync'])
        painting_args = PaintArgs(os.path.join(self.data_root, 'training'), self.segmentor_path, False)
        self.painter = Painter(painting_args, onnx=True)
        self.segmentor = self.painter.model
        model_pre = PointPillarsPre()
        model_post = PointPillarsPos(self.detection_classes)
        model_pre.eval()
        model_post.eval()
        ort_sess = ort.InferenceSession(self.lidar_detector_path)
        self.lidar_detector = ort_sess
        self.model_pre = model_pre
        self.model_post = model_post
        return self

    def predict(self, inputs):
        dimensions, locations, rotation_y, box2d, class_labels, class_scores, ids = [], [], [], [], [], [], []
        with torch.inference_mode():
            model_input = inputs[0]
            batched_pts = model_input['pts']
            scores_from_cam = []
            for i in range(len(model_input['images'])):
                input_image_name = self.segmentor.get_inputs()[0].name
                input_data = {input_image_name: to_numpy(model_input['images'][i])}
                segmentation_score = self.segmentor.run(None, input_data)
                segmentation_score = [torch.from_numpy(item) for item in segmentation_score]
                scores_from_cam.append(self.painter.get_score(segmentation_score[0].squeeze(0)).cpu())
            points = self.painter.augment_lidar_class_scores_both(scores_from_cam, batched_pts, model_input['calib_info'])
            pillars, coors_batch, npoints_per_pillar = self.model_pre(batched_pts=[points])
            input_pillars_name = self.lidar_detector.get_inputs()[0].name
            input_coors_batch_name = self.lidar_detector.get_inputs()[1].name
            input_npoints_per_pillar_name = self.lidar_detector.get_inputs()[2].name
            input_data = {input_pillars_name: to_numpy(pillars),
                          input_coors_batch_name: to_numpy(coors_batch),
                          input_npoints_per_pillar_name: to_numpy(npoints_per_pillar)}
            result = self.lidar_detector.run(None, input_data)
            result = [torch.from_numpy(item) for item in result]
            batch_results = self.model_post(result)
            for j, result in enumerate(batch_results):
                format_result = {
                    'class': [],
                    'truncated': [],
                    'occluded': [],
                    'alpha': [],
                    'bbox': [],
                    'dimensions': [],
                    'location': [],
                    'rotation_y': [],
                    'score': [],
                    'idx': -1
                }

                calib_info = model_input['calib_info']
                image_info = model_input['image_info']
                idx = model_input['image_info']['image_idx']

                calib_info = change_calib_device(calib_info, False)
                result_filter = keep_bbox_from_image_range(result, calib_info, 5, image_info, False)

                lidar_bboxes = result_filter['lidar_bboxes']
                labels, scores = result_filter['labels'], result_filter['scores']
                bboxes2d, camera_bboxes = result_filter['bboxes2d'], result_filter['camera_bboxes']
                for lidar_bbox, label, score, bbox2d, camera_bbox in \
                        zip(lidar_bboxes, labels, scores, bboxes2d, camera_bboxes):
                    format_result['class'].append(label.item())
                    format_result['truncated'].append(0.0)
                    format_result['occluded'].append(0)
                    alpha = camera_bbox[6] - np.arctan2(camera_bbox[0], camera_bbox[2])
                    format_result['alpha'].append(alpha.item())
                    format_result['bbox'].append(bbox2d.tolist())
                    format_result['dimensions'].append(camera_bbox[3:6])
                    format_result['location'].append(camera_bbox[:3])
                    format_result['rotation_y'].append(camera_bbox[6].item())
                    format_result['score'].append(score.item())
                format_result['idx'] = idx

                if len(format_result['dimensions']) > 0:
                    format_result['dimensions'] = torch.stack(format_result['dimensions'])
                    format_result['location'] = torch.stack(format_result['location'])
                    dimensions.append(format_result['dimensions'])
                    locations.append(format_result['location'])
                    rotation_y.append(format_result['rotation_y'])
                    class_labels.append(format_result['class'])
                    class_scores.append(format_result['score'])
                    box2d.append(format_result['bbox'])
                    ids.append(format_result['idx'])

        return dimensions, locations, rotation_y, box2d, class_labels, class_scores, ids
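For anyone wiring up a different checkpoint, a minimal standalone sketch of the ONNX Runtime calling pattern used above (the file name and provider are assumptions, not fixed by the benchmark):
```
import onnxruntime as ort

# Hypothetical checkpoint path; the reference bundle ships pp.onnx.
sess = ort.InferenceSession("pp.onnx", providers=["CPUExecutionProvider"])

# Inspect the graph inputs, then build the feed dict by name as predict() does.
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)
# feed = {<input name>: <numpy array>, ...}
# outputs = sess.run(None, feed)  # list of numpy arrays, one per graph output
```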

20 changes: 10 additions & 10 deletions automotive/3d-object-detection/dockerfile.gpu
@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
+ARG FROM_IMAGE_NAME=pytorch/pytorch:2.2.2-cuda11.8-cudnn8-devel
FROM ${FROM_IMAGE_NAME}

ENV DEBIAN_FRONTEND=noninteractive
@@ -20,12 +20,12 @@ RUN cd /tmp && \
CFLAGS="-std=c++14" python setup.py install && \
rm -rf mlperf

-RUN pip install tqdm
-RUN pip install numba
-RUN pip install opencv-python
-RUN pip install open3d
-RUN pip install tensorboard
-RUN pip install scikit-image
-RUN pip install ninja
-RUN pip install visdom
-RUN pip install shapely
+RUN pip install tqdm==4.65.0
+RUN pip install numba==0.60.0
+RUN pip install opencv-python==4.11.0.86
+RUN pip install open3d==0.19.0
+RUN pip install scikit-image==0.25.0
+RUN pip install ninja==1.11.1
+RUN pip install shapely==2.0.6
+RUN pip install tensorboard==2.18.0
+RUN pip install onnxruntime==1.20.1
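The unpinned installs are replaced with pinned versions for reproducibility; the unused visdom dependency is dropped and onnxruntime is added for the new backend. A quick in-container check that the pins resolved, offered only as a sketch:
```
# Prints the installed version of each package pinned in the Dockerfile.
import importlib.metadata as md

for pkg in ["tqdm", "numba", "opencv-python", "open3d", "scikit-image",
            "ninja", "shapely", "tensorboard", "onnxruntime"]:
    print(pkg, md.version(pkg))
```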
5 changes: 3 additions & 2 deletions automotive/3d-object-detection/main.py
@@ -167,7 +167,9 @@ def get_backend(backend, **kwargs):
        from backend_deploy import BackendDeploy

        backend = BackendDeploy(**kwargs)
+    elif backend == 'onnx':
+        from backend_onnx import BackendOnnx
+        backend = BackendOnnx(**kwargs)
    elif backend == "debug":
        from backend_debug import BackendDebug

@@ -403,7 +405,6 @@ def flush_queries():
    log_settings.log_output = log_output_settings

    settings = lg.TestSettings()
-    settings.FromConfig(mlperf_conf, args.model_name, args.scenario)
    settings.FromConfig(user_conf, args.model_name, args.scenario)
    settings.scenario = scenario
    settings.mode = lg.TestMode.PerformanceOnly
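Dropping the mlperf.conf line matches newer LoadGen releases, where the official mlperf.conf defaults are compiled into the library and only user overrides are read from file (a reading inferred from the deletion, not stated in the PR). The resulting pattern, with placeholder model and scenario names:
```
import mlperf_loadgen as lg

settings = lg.TestSettings()
# Only user.conf is read from disk; built-in defaults cover mlperf.conf.
# "pointpainting" and "SingleStream" below are placeholders.
settings.FromConfig("user.conf", "pointpainting", "SingleStream")
settings.scenario = lg.TestScenario.SingleStream
settings.mode = lg.TestMode.PerformanceOnly
```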
27 changes: 19 additions & 8 deletions automotive/3d-object-detection/model/painter.py
@@ -1,3 +1,4 @@
+import onnxruntime as ort
import argparse
import model.segmentation as network
import os
@@ -34,24 +35,34 @@ def get_calib_from_file(calib_file):
    return data


+def to_numpy(tensor):
+    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


class Painter:
-    def __init__(self, args):
+    def __init__(self, args, onnx=False):
        self.root_split_path = args.training_path
        self.save_path = os.path.join(args.training_path, "painted_lidar/")
+        self.onnx = onnx
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)

        self.seg_net_index = 0
        self.model = None
        print(f'Using Segmentation Network -- deeplabv3plus')
        checkpoint_file = args.model_path
-        model = network.modeling.__dict__['deeplabv3plus_resnet50'](
-            num_classes=19, output_stride=16)
-        checkpoint = torch.load(checkpoint_file)
-        model.load_state_dict(checkpoint["model_state"])
-        model.eval()
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        model.to(device)
+        if self.onnx:
+            model = ort.InferenceSession(checkpoint_file)
+            self.input_image_name = model.get_inputs()[0].name
+        else:
+            model = network.modeling.__dict__['deeplabv3plus_resnet50'](
+                num_classes=19, output_stride=16)
+            checkpoint = torch.load(checkpoint_file)
+            model.load_state_dict(checkpoint["model_state"])
+            model.eval()
+            device = torch.device(
+                'cuda' if torch.cuda.is_available() else 'cpu')
+            model.to(device)
        self.model = model
        self.cam_sync = args.cam_sync

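With the onnx flag, Painter.model is either an ort.InferenceSession or a torch module, so callers must branch on it. A small illustrative helper (the field names mirror what __init__ sets above; this exact function is not in the repo):
```
import torch

def run_segmentor(painter, image):
    """Dispatch to ONNX Runtime or PyTorch depending on how Painter was built."""
    if painter.onnx:
        feed = {painter.input_image_name: image.cpu().numpy()}
        return painter.model.run(None, feed)[0]  # numpy array
    with torch.inference_mode():
        return painter.model(image)              # torch tensor
```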
4 changes: 3 additions & 1 deletion automotive/3d-object-detection/model/pointpillars.py
@@ -397,7 +397,9 @@ def get_predicted_bboxes_single(

        # 3.2 nms core
        keep_inds = ml3d.ops.nms(
-            cur_bbox_pred2d, cur_bbox_cls_pred, self.nms_thr)
+            cur_bbox_pred2d.cpu(),
+            cur_bbox_cls_pred.cpu(),
+            self.nms_thr)

        cur_bbox_cls_pred = cur_bbox_cls_pred[keep_inds]
        cur_bbox_pred = cur_bbox_pred[keep_inds]
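Moving the inputs to CPU suggests this Open3D build only accepts CPU tensors for nms, which is an inference from the change rather than something the PR documents. A minimal standalone call, with box layout and threshold assumed:
```
import torch
import open3d.ml.torch as ml3d  # requires an Open3D build with ML ops

# Two overlapping boxes; column layout (x0, y0, x1, y1, yaw) is assumed.
boxes2d = torch.tensor([[0.0, 0.0, 10.0, 10.0, 0.0],
                        [1.0, 1.0, 11.0, 11.0, 0.0]])
scores = torch.tensor([0.9, 0.8])
keep_inds = ml3d.ops.nms(boxes2d.cpu(), scores.cpu(), 0.25)  # threshold assumed
print(keep_inds)  # indices of the boxes kept after suppression
```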