SAM 2图像分割实战:从环境搭建到跑通第一个AI示例(含改进版代码)
2026/4/14 11:18:20 网站建设 项目流程

SAM 2图像分割实战:从环境搭建到跑通第一个AI示例(含改进版代码)

在计算机视觉领域,图像分割一直是核心技术之一。最近Meta推出的Segment Anything Model 2(SAM 2)凭借其出色的泛化能力和易用性,迅速成为开发者关注的焦点。本文将带你从零开始,快速搭建SAM 2开发环境,并通过改进后的代码示例实现第一个有效的图像分割应用。

1. 环境准备与安装

在开始之前,我们需要确保系统满足以下基本要求:

  • 操作系统:Windows 10/11或Linux(推荐Ubuntu 20.04+)
  • GPU:NVIDIA显卡(建议RTX 3060及以上,显存≥8GB)
  • CUDA:11.7或12.1(需与PyTorch版本匹配)
  • Python:3.10或更高版本

推荐使用conda创建独立环境,避免与其他项目产生依赖冲突:

conda create -n sam2 python=3.10 -y conda activate sam2

接下来安装PyTorch(注意选择与CUDA版本匹配的安装命令):

# CUDA 11.7 pip install torch==2.5.1 torchvision==0.15.2 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu117 # CUDA 12.1 pip install torch==2.5.1 torchvision==0.15.2 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121

安装SAM 2及其依赖:

git clone https://github.com/facebookresearch/sam2.git cd sam2 pip install -e .

注意:如果遇到CUDA扩展编译失败警告,可以暂时忽略。大多数基础功能仍可正常使用。

2. 模型下载与初始化

SAM 2提供了多种预训练模型,根据硬件条件选择合适的版本:

模型名称参数量显存需求适用场景
sam2.1_hiera_tiny50M4GB快速验证/移动端
sam2.1_hiera_base150M6GB平衡性能与速度
sam2.1_hiera_large500M10GB高精度需求

下载大型模型(推荐):

wget https://dl.fbaipublicfiles.com/sam2/sam2.1_hiera_large.pt -P ./checkpoints

初始化预测器的Python代码:

import torch from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor checkpoint = "./checkpoints/sam2.1_hiera_large.pt" model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" # 初始化模型 predictor = SAM2ImagePredictor( build_sam2(model_cfg, checkpoint).to('cuda') )

3. 改进版图像分割实战

官方示例代码往往过于简化,实际使用时需要调整。以下是经过优化的完整流程:

from PIL import Image import numpy as np import matplotlib.pyplot as plt def show_mask(mask, ax, random_color=False): if random_color: color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) else: color = np.array([30/255, 144/255, 255/255, 0.6]) h, w = mask.shape[-2:] mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) ax.imshow(mask_image) # 加载测试图像 image_path = "test_image.jpg" image = Image.open(image_path).convert("RGB") # 设置交互点(格式:[x,y]) input_points = np.array([[500, 375]]) # 图像中心区域 input_labels = np.array([1]) # 1表示前景点 with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): predictor.set_image(np.array(image)) masks, scores, _ = predictor.predict( point_coords=input_points, point_labels=input_labels, multimask_output=True # 输出多个候选mask ) # 可视化结果 plt.figure(figsize=(15, 10)) plt.imshow(image) for i, (mask, score) in enumerate(zip(masks, scores)): show_mask(mask, plt.gca(), random_color=True) plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18) plt.axis('off') plt.show()

关键改进点:

  1. 多mask输出:设置multimask_output=True获取多个分割方案
  2. 可视化增强:为不同mask添加随机颜色和置信度评分
  3. 类型转换:确保图像为RGB格式,避免通道问题

4. 高级技巧与问题排查

4.1 交互式分割优化

SAM 2支持多种提示方式组合使用:

# 组合使用点和框提示 input_box = np.array([100, 100, 400, 400]) # [x1,y1,x2,y2] masks, _, _ = predictor.predict( point_coords=input_points, point_labels=input_labels, box=input_box, multimask_output=False )

4.2 常见错误处理

错误现象可能原因解决方案
CUDA out of memory显存不足换用更小模型或减小输入图像尺寸
分割结果不准确提示点位置不当尝试在目标物体不同位置添加多个点
预测速度慢未启用混合精度确保使用torch.autocast("cuda")

4.3 性能优化技巧

# 图像预处理优化 def resize_long_edge(image, max_size=1024): width, height = image.size if max(width, height) > max_size: scale = max_size / max(width, height) new_size = (int(width*scale), int(height*scale)) return image.resize(new_size, Image.BILINEAR) return image optimized_image = resize_long_edge(image) predictor.set_image(np.array(optimized_image))

5. 自定义数据集测试

要测试自己的图片,只需修改图像路径并调整提示点:

custom_image = Image.open("your_image.jpg").convert("RGB") # 通过matplotlib交互获取坐标 plt.imshow(custom_image) points = plt.ginput(n=-1, timeout=-1) # 点击获取点坐标 plt.close() input_points = np.array(points) input_labels = np.array([1]*len(points)) # 全部设为前景点 with torch.inference_mode(): predictor.set_image(np.array(custom_image)) masks, _, _ = predictor.predict( point_coords=input_points, point_labels=input_labels, multimask_output=True )

实际项目中,我发现对复杂场景添加3-5个分布均匀的点通常能得到最佳分割效果。对于细长物体(如电线),沿物体走向均匀布点比集中布点效果更好。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询