# LR-MAE: Locate while Reconstructing with Masked Autoencoders for Point Cloud Self-supervised Learning
This repository provides the official implementation of Locate while Reconstructing with Masked Autoencoders for Point Cloud Self-supervised Learning.
As an efficient self-supervised pre-training approach, the masked autoencoder (MAE) has shown promising improvements across various 3D point cloud understanding tasks. However, the pretext task of existing point-based MAEs is to reconstruct the geometry of masked points only, so they learn features at a lower semantic level, which is ill-suited to high-level downstream tasks. To address this challenge, we propose a novel self-supervised approach named Locate while Reconstructing with Masked Autoencoders (LR-MAE). Specifically, a multi-head decoder is designed to localize the global positions of masked patches while reconstructing the masked points, with the aim of learning semantic features that better align with downstream tasks. Moreover, we design a random query patch detection strategy for 3D object detection in the pre-training stage, which significantly boosts model performance with a faster convergence speed. Extensive experiments show that LR-MAE achieves superior performance on various point cloud understanding tasks. When fine-tuned on downstream datasets, LR-MAE outperforms the Point-MAE baseline by 3.65% in classification accuracy on the ScanObjectNN dataset and significantly exceeds the 3DETR baseline by 6.1%.
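The joint pretext objective described above can be sketched as follows. This is a conceptual illustration only, not the repository's implementation: all function names and the `loc_weight` factor are hypothetical, and we assume a symmetric Chamfer distance for reconstructing masked points plus a mean-squared error for localizing the global centers of masked patches.

```python
import numpy as np

def chamfer_distance(pred, gt):
    """Symmetric Chamfer distance between point sets pred (N, 3) and gt (M, 3)."""
    # Pairwise squared distances, shape (N, M).
    d = ((pred[:, None, :] - gt[None, :, :]) ** 2).sum(-1)
    # Nearest-neighbor term in both directions.
    return d.min(axis=1).mean() + d.min(axis=0).mean()

def lrmae_pretext_loss(pred_points, gt_points, pred_centers, gt_centers,
                       loc_weight=1.0):
    """Illustrative joint objective: reconstruct the masked points and
    localize the global positions (centers) of the masked patches.

    pred_points / gt_points: lists of per-patch point sets, each (N, 3).
    pred_centers / gt_centers: (P, 3) arrays of patch centers.
    """
    # Reconstruction branch: average Chamfer distance over masked patches.
    rec = np.mean([chamfer_distance(p, g)
                   for p, g in zip(pred_points, gt_points)])
    # Localization branch: MSE between predicted and true patch centers.
    loc = np.mean((pred_centers - gt_centers) ** 2)
    return rec + loc_weight * loc
```

The localization branch is what distinguishes this sketch from a plain point-MAE objective: the decoder is supervised not only on local geometry but also on where each masked patch sits globally.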
## Installation
Our code is tested with PyTorch 1.8.0, CUDA 11.1, and Python 3.7.0.

```shell
conda create -y -n lrmae python=3.7
conda activate lrmae
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
```
```shell
# Chamfer Distance & EMD
cd ./extensions/chamfer_dist
python setup.py install --user
cd ../emd
python setup.py install --user
```
```shell
# PointNet++
pip install "git+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"
# if the above fails, you can try:
git clone https://github.com/erikwijmans/Pointnet2_PyTorch.git
cd Pointnet2_PyTorch
pip install pointnet2_ops_lib/.
```
```shell
# GPU kNN
pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl
# if the above fails, you can try installing the wheel from a mirror:
pip install KNN_CUDA-0.2-py3-none-any.whl -i https://pypi.tuna.tsinghua.edu.cn/simple
```
Optionally, you can install a Cythonized implementation of GIoU for faster training:

```shell
conda install cython
cd ./detection/utils && python cython_compile.py build_ext --inplace
```
## Datasets
Before running the code, you need to download the datasets and modify the corresponding file paths in the code. For convenience, the download links of the required datasets are collected below:
- ShapeNet55/34: [link].
- ScanObjectNN: [link].
- ModelNet40: [link].
- ShapeNetPart: [link].
- SUN RGB-D: [link].
- ScanNet: [link].
## Pre-training
To pre-train LR-MAE, run:

```shell
CUDA_VISIBLE_DEVICES=0 python main.py --config cfgs/pretrain.yaml --exp_name ./pretrain_upmae
```
For 3D object detection, pre-train with the random query patch detection strategy:

```shell
cd ./detection
CUDA_VISIBLE_DEVICES=0,1,2,3 python pretrain_upmae.py --dataset_name upmaesunrgbd --checkpoint_dir ./checkpoint_upmae --model_name up_mae --ngpus 4
```
## Fine-tuning
- ModelNet40

```shell
CUDA_VISIBLE_DEVICES=0 python main.py --config cfgs/finetune_modelnet.yaml --finetune_model --exp_name ./modelnet1k_ft --ckpts ./experiments/pretrain/cfgs/pretrain_upmae/ckpt-epoch-300.pth
# if you want to test the model with voting, run:
CUDA_VISIBLE_DEVICES=1 python main.py --config cfgs/finetune_modelnet.yaml --test --exp_name ./modelnet1k_ft_vote --ckpts path/to/model
```

- ScanObjectNN (OBJ-BG)

```shell
CUDA_VISIBLE_DEVICES=0 python main.py --config cfgs/finetune_scan_objbg.yaml --finetune_model --exp_name ./scan_objbg_upmae_ft --ckpts ./experiments/pretrain/cfgs/pretrain_upmae/ckpt-epoch-300.pth
```

- ScanObjectNN (OBJ-ONLY)

```shell
CUDA_VISIBLE_DEVICES=0 python main.py --config cfgs/finetune_scan_objonly.yaml --finetune_model --exp_name ./scan_objonly_upmae_ft --ckpts ./experiments/pretrain/cfgs/pretrain_upmae/ckpt-epoch-300.pth
```

- ScanObjectNN (PB-T50-RS)

```shell
CUDA_VISIBLE_DEVICES=0 python main.py --config cfgs/finetune_scan_hardest.yaml --finetune_model --exp_name ./scan_hardest_upmae_ft --ckpts ./experiments/pretrain/cfgs/pretrain_upmae/ckpt-epoch-300.pth
```
- ShapeNet-Part

```shell
cd ./segmentation
CUDA_VISIBLE_DEVICES=0 python main.py --ckpts ../experiments/pretrain/cfgs/pretrain_upmae/ckpt-epoch-300.pth --log_dir ./shapenetpart1 --seed 1 --root data/path --learning_rate 0.0002 --epoch 300
```
- ScanNet

```shell
cd ./detection
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
--model_name up_mae_3detr --ngpus 4 --nqueries 256 \
--batchsize_per_gpu 12 \
--pretrain_ckpt checkpoint_upmae/ckpt-last.pth \
--dataset_name scannet \
--max_epoch 1080 \
--matcher_giou_cost 2 \
--matcher_cls_cost 1 \
--matcher_center_cost 0 \
--matcher_objectness_cost 0 \
--loss_giou_weight 1 \
--loss_no_object_weight 0.25 \
--checkpoint_dir ./checkpoint_mae_q256_scannet
```

- SUN RGB-D

```shell
cd ./detection
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
--model_name up_mae_3detr --ngpus 4 --nqueries 256 \
--batchsize_per_gpu 10 \
--pretrain_ckpt checkpoint_upmae/ckpt-last.pth \
--dataset_name sunrgbd \
--base_lr 7e-4 \
--matcher_giou_cost 3 \
--matcher_cls_cost 1 \
--matcher_center_cost 5 \
--matcher_objectness_cost 5 \
--loss_giou_weight 0 \
--loss_no_object_weight 0.1 \
--seed 2 \
--checkpoint_dir ./checkpoint_mae_q256_sunrgbd
```
## Acknowledgements
Our code builds on prior works such as 3DETR and Point-MAE. We thank the authors for their efforts.