从SAM到3D高斯:我是如何用SAGA在NeRF‘蕨类’数据集上实现‘点击即分割’的
第一次看到SAGA(Segment Any 3D Gaussians)论文时,我就被这个将SAM(Segment Anything)与3D高斯点云结合的想法吸引了。作为一个长期在3D视觉领域摸爬滚打的实践者,我决定亲自尝试在nerf_llff_data/fern数据集上复现这个CVPR2023的工作。没想到,这段从理论到实践的旅程充满了意外发现和宝贵经验。
1. 环境配置:意料之外的第一个挑战
论文提供的environment.yml看似简单,但实际配置过程却让我踩了不少坑。官方建议使用conda env create --file environment.yml一键安装,但在我的Ubuntu 20.04系统上,这条路并不顺畅。
关键依赖安装步骤:
# 创建基础环境 conda create -n gaussian_splatting python==3.7.13 conda activate gaussian_splatting # 安装特定版本的PyTorch pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 # 手动编译安装pytorch3d cd third_party unzip pytorch3d-0.7.1.zip cd pytorch3d-0.7.1 pip install -e .最棘手的是pytorch3d的安装。直接pip安装会失败,必须从源码编译。我后来发现,这是因为0.7.1版本对CUDA和PyTorch版本有严格限制。安装过程中,有几个关键点需要注意:
- 确保CUDA版本与PyTorch版本匹配
- 安装顺序很重要,必须先装PyTorch再装pytorch3d
- 所有
pip install -e .命令必须在对应目录下执行
提示:如果遇到"Could not build wheels"错误,通常是因为缺少系统依赖。在Ubuntu上,安装
build-essential和python3-dev通常能解决问题。
2. 数据预处理:down_sample参数的陷阱
处理fern数据集时,我遇到了第一个技术难点——down_sample参数的一致性问题。官方代码中,extract_features.py和extract_segment_everything_masks.py都涉及这个参数,但默认设置会导致严重问题。
问题本质:
extract_features.py默认将图像resize到1024x1024extract_segment_everything_masks.py默认使用down_sample=4- 两者处理的图像尺寸不一致会导致特征对齐失败
我的解决方案是修改extract_segment_everything_masks.py,强制统一尺寸:
# 在mask_generator.generate前添加resize img = cv2.resize(img, dsize=(1024, 1024), fx=1, fy=1, interpolation=cv2.INTER_LINEAR) masks = mask_generator.generate(img)这个修改确保了特征提取和mask生成使用相同尺寸的图像。有趣的是,这个细节在论文和官方代码中都没有明确说明,只有在实际运行时才会发现问题。
3. 训练过程:从2D到3D的特征映射
SAGA的核心创新在于通过MLP将SAM的2D特征映射到3D高斯点云空间。训练分为两个阶段:
3D高斯点云训练:
python train_scene.py -s nerf_llff_data/fern对比特征训练:
python train_contrastive_feature.py -m SegAnyGAussians/output/fern
训练中的关键观察:
| 训练阶段 | 耗时 | 显存占用 | 注意事项 |
|---|---|---|---|
| 3DGS训练 | ~2小时 | 10GB | 学习率需要微调 |
| 特征训练 | ~1小时 | 8GB | batch size影响收敛速度 |
训练完成后,最令人兴奋的部分来了——交互式分割。
4. 交互式分割:点击的艺术
在prompt_segmenting.ipynb中,需要手动设置几个关键参数:
DATA_ROOT = 'nerf_llff_data/fern' MODEL_PATH = './output/fern/' input_point = np.array([[500, 400]]) # 这个坐标需要根据具体图像调整选择input_point是个需要技巧的过程。我发现在蕨类数据集中,选择叶片中心位置通常能得到最好的分割效果。太靠近边缘的点容易导致分割不完整。
后处理两步曲:
- Statistical Filtering:去除噪声点
- Growing:扩展分割区域
通过可视化对比,可以清晰看到每一步的效果变化:
- 初始分割结果往往包含多余背景
- Filtering后噪声明显减少
- Growing使分割区域更完整
# 后处理代码示例 filtered_points, filtered_mask, thresh = postprocess_grad_based_statistical_filtering( pcd=selected_xyz.clone(), precomputed_mask=mask_.clone(), feature_gaussians=feature_gaussians, view=view, sam_mask=ref_mask.clone(), pipeline_args=pipeline.extract(args))5. 结果可视化:从数据到洞察
最终的分割结果保存在final_mask.pt中。为了让结果更直观,我修改了颜色渲染部分:
def load_point_colors_from_pcd(point_num, pcd_path, mask): # 读取原始点云颜色 pcd = o3d.io.read_point_cloud(pcd_path) colors = np.asarray(pcd.colors)[mask] return colors这个修改使得分割后的点云保留了原始颜色,视觉效果更加直观。对比原始点云和分割结果,可以清晰看到SAGA的精准分割能力。
在fern数据集上,SAGA展现了几点独特优势:
- 对复杂植物结构的良好分割能力
- 即使部分遮挡也能保持分割一致性
- 交互简单,只需单点输入
整个项目跑下来,最深的体会是:理论论文和实际代码之间往往存在gap,而填补这些gap正是工程实践的价值所在。特别是down_sample参数问题,如果没有实际动手尝试,很难发现这个隐藏的陷阱。