Vision-Language-Action:LMDrive大语言模型(LLM)核心推理单元
2026/7/6 1:51:08 网站建设 项目流程

LMDrive 大语言模型(LLM)核心推理单元

1. 整体架构概览

LMDrive 使用LLaVA-7B作为核心语言模型,基于 Meta LLaMA 架构。LLM 在系统中扮演特征提取器角色,而非直接生成文本。

数据流

文本指令 + Q-Former特征 → LLM输入嵌入 → LlamaDecoderLayer × 32 → 隐藏状态 → Waypoint预测器 + 终点分类器

2. 核心算法组件

2.1 LlamaRMSNorm(均方根归一化)

算法原理

  • 与传统 LayerNorm 不同,RMSNorm 不减去均值,仅对标准差进行归一化
  • 公式:y = x / sqrt(E[x²] + ε) * weight
  • 优点:计算更快、训练更稳定,避免均值偏移问题

代码实现modeling_llama.py:95-109):

classLlamaRMSNorm(nn.Module):def__init__(self,hidden_size,eps=1e-6):super().__init__()self.weight=nn.Parameter(torch.ones(hidden_size))self.variance_epsilon=epsdefforward(self,hidden_states):input_dtype=hidden_states.dtype hidden_states=hidden_states.to(torch.float32)variance=hidden_states.pow(2).mean(-1,keepdim=True)hidden_states=hidden_states*torch.rsqrt(variance+self.variance_epsilon)returnself.weight*hidden_states.to(input_dtype)

举例说明

  • 输入:[1.0, 2.0, 3.0](hidden_dim=3)
  • 方差计算:(1+4+9)/3 = 14/3 ≈ 4.67
  • 归一化:[1/2.16, 2/2.16, 3/2.16] ≈ [0.46, 0.93, 1.39]
  • 乘以可学习权重:[0.46*w1, 0.93*w2, 1.39*w3]

2.2 LlamaRotaryEmbedding(RoPE 旋转位置编码)

算法原理

  • 通过旋转矩阵对 Query/Key 进行位置编码,实现相对位置信息
  • 公式:q' = q × cos(θ) + rotate_half(q) × sin(θ)
  • 支持动态 NTK 缩放,扩展序列长度上限

代码实现modeling_llama.py:115-148):

classLlamaRotaryEmbedding(nn.Module):def__init__(self,dim,max_position_embeddings=2048,base=10000):super().__init__()self.dim=dim inv_freq=1.0/(base**(torch.arange(0,dim,2).float()/dim))self.register_buffer("inv_freq",inv_freq)defforward(self,x,seq_len=None):t=torch.arange(seq_len,dtype=self.inv_freq.dtype)freqs=torch.einsum("i,j->ij",t,self.inv_freq)emb=torch.cat((freqs,freqs),dim=-1)returnemb.cos(),emb.sin()defrotate_half(x):x1=x[...,:x.shape[-1]//2]x2=x[...,x.shape[-1]//2:]returntorch.cat((-x2,x1),dim=-1)

举例说明

  • 假设 head_dim=4,位置 pos=2:
    • inv_freq = [1/(10000^(0/4)), 1/(10000^(2/4))] = [1, 0.01]
    • freqs = [2×1, 2×0.01] = [2, 0.02]
    • cos = [cos(2), cos(0.02)] ≈ [-0.416, 1.0]
    • sin = [sin(2), sin(0.02)] ≈ [0.909, 0.02]
  • 对 Q/K 向量[q0, q1, q2, q3]
    • q0' = q0 × cos(2) - q2 × sin(2)
    • q1' = q1 × cos(0.02) - q3 × sin(0.02)

2.3 LlamaAttention(因果多头自注意力)

算法原理

  • 标准多头注意力机制,配合因果掩码防止未来信息泄露
  • 支持 KV 缓存加速推理(past_key_value
  • 可选 GQA(Grouped Query Attention)通过repeat_kv实现

代码实现modeling_llama.py:258-412):

classLlamaAttention(nn.Module):def__init__(self,config):super().__init__()self.num_heads=config.num_attention_heads self.head_dim=config.hidden_size//self.num_heads self.q_proj=nn.Linear(config.hidden_size,self.num_heads*self.head_dim)self.k_proj=nn.Linear(config.hidden_size,self.num_key_value_heads*self.head_dim)self.v_proj=nn.Linear(config.hidden_size,self.num_key_value_heads*self.head_dim)self.o_proj=nn.Linear(self.num_heads*self.head_dim,config.hidden_size)self._init_rope()defforward(self,hidden_states,attention_mask=None,position_ids=None,past_key_value=None,use_cache=False):bsz,q_len,_=hidden_states.size()query_states=self.q_proj(hidden_states)# [B, L, n_heads * head_dim]key_states=self.k_proj(hidden_states)# [B, L, n_kv_heads * head_dim]value_states=self.v_proj(hidden_states)# [B, L, n_kv_heads * head_dim]query_states=query_states.view(bsz,q_len,self.num_heads,self.head_dim).transpose(1,2)key_states=key_states.view(bsz,q_len,self.num_key_value_heads,self.head_dim).transpose(1,2)value_states=value_states.view(bsz,q_len,self.num_key_value_heads,self.head_dim).transpose(1,2)# 应用 RoPEcos,sin=self.rotary_emb(value_states,seq_len=kv_seq_len)query_states,key_states=apply_rotary_pos_emb(query_states,key_states,cos,sin,position_ids)# KV缓存拼接ifpast_key_valueisnotNone:key_states=torch.cat([past_key_value[0],key_states],dim=2)value_states=torch.cat([past_key_value[1],value_states],dim=2)# GQA: 扩展 KV 到所有头key_states=repeat_kv(key_states,self.num_key_value_groups)value_states=repeat_kv(value_states,self.num_key_value_groups)# 注意力计算attn_weights=torch.matmul(query_states,key_states.transpose(2,3))/math.sqrt(self.head_dim)ifattention_maskisnotNone:attn_weights=attn_weights+attention_mask attn_weights=nn.functional.softmax(attn_weights,dim=-1)attn_output=torch.matmul(attn_weights,value_states)attn_output=attn_output.transpose(1,2).reshape(bsz,q_len,self.hidden_size)attn_output=self.o_proj(attn_output)returnattn_output,attn_weights,(key_states,value_states)ifuse_cacheelseNone

关键机制

机制作用
因果掩码防止当前 token 关注未来 token
RoPE注入相对位置信息
KV缓存推理时复用历史 KV,减少重复计算
GQA通过共享 KV 头减少内存占用

2.4 LlamaMLP(SwiGLU 前馈网络)

算法原理

  • 使用 SwiGLU(Swish-Gated Linear Unit)激活函数
  • 结构:down_proj(SiLU(gate_proj(x)) × up_proj(x))
  • 相比 ReLU,SwiGLU 提供更平滑的梯度流动

代码实现modeling_llama.py:212-243):

classLlamaMLP(nn.Module):def__init__(self,config):super().__init__()self.gate_proj=nn.Linear(config.hidden_size,config.intermediate_size,bias=False)self.up_proj=nn.Linear(config.hidden_size,config.intermediate_size,bias=False)self.down_proj=nn.Linear(config.intermediate_size,config.hidden_size,bias=False)self.act_fn=ACT2FN[config.hidden_act]# 'silu'defforward(self,x):returnself.down_proj(self.act_fn(self.gate_proj(x))*self.up_proj(x))

SwiGLU 公式

y = W3 × (SiLU(W1 × x) ⊙ W2 × x)

其中表示逐元素相乘,SiLU(x) = x × σ(x)


2.5 LlamaDecoderLayer(解码器层)

算法原理

  • 采用 Pre-Norm 架构:归一化在注意力/MLP 之前
  • 残差连接模式:x + Attention(LN(x))x + MLP(LN(x))

代码实现modeling_llama.py:597-663):

classLlamaDecoderLayer(nn.Module):def__init__(self,config):super().__init__()self.self_attn=LlamaAttention(config)ifnot_flash_attn_2_enabledelseLlamaFlashAttention2(config)self.mlp=LlamaMLP(config)self.input_layernorm=LlamaRMSNorm(config.hidden_size)self.post_attention_layernorm=LlamaRMSNorm(config.hidden_size)defforward(self,hidden_states,attention_mask=None,position_ids=None,past_key_value=None,use_cache=False):residual=hidden_states# 自注意力子层hidden_states=self.input_layernorm(hidden_states)hidden_states,_,present_key_value=self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,use_cache=use_cache,)hidden_states=residual+hidden_states# MLP 子层residual=hidden_states hidden_states=self.post_attention_layernorm(hidden_states)hidden_states=self.mlp(hidden_states)hidden_states=residual+hidden_statesreturn(hidden_states,present_key_value)ifuse_cacheelse(hidden_states,)

前向传播流程

输入 x → LN1(x) → Attention(LN1(x)) → x + Attention(...) → LN2(x+Attention) → MLP(LN2(...)) → x + Attention(...) + MLP(...) → 输出

3. LMDrive 特定适配

3.1 LoRA 微调

原理:在 Transformer 的q_projv_proj层插入低秩适配器,冻结主模型参数,仅训练适配器。

配置drive.py:145-155):

loraconfig=LoraConfig(r=16,# 秩lora_alpha=32,# 缩放因子target_modules=["q_proj","v_proj"],lora_dropout=0.05,bias="none",task_type="CAUSAL_LM")self.llm_model=get_peft_model(self.llm_model,loraconfig)

优点

  • 参数量大幅减少(仅约 0.1% 的参数参与训练)
  • 避免灾难性遗忘
  • 训练速度快、显存占用低

3.2 Waypoints 预测器

原理:将 LLM 最后一层隐藏状态映射为 5 个未来轨迹点(10 维,每点 x,y 坐标)。

实现drive.py:125-129):

self.waypoints_predictor=nn.Sequential(nn.Linear(self.llm_model.config.hidden_size,self.llm_model.config.hidden_size),nn.ReLU(),nn.Linear(self.llm_model.config.hidden_size,10)# 5 waypoints × (x,y))

推理流程

LLM隐藏状态 [B, L, 4096] → waypoints_predictor → [B, L, 10] → 提取有效帧位置 → [N, 10]
GRU Decoder 变体

原理:采用自回归方式迭代预测 5 个轨迹点,每步将前一预测点作为输入反馈给 GRUCell。

实现drive.py:116-123):

self.waypoints_fc=nn.Sequential(nn.Linear(self.llm_model.config.hidden_size,self.llm_model.config.hidden_size),nn.ReLU(),nn.Linear(self.llm_model.config.hidden_size,64))self.waypoints_predictor=nn.GRUCell(input_size=2,hidden_size=64)# 输入: (x,y)坐标self.waypoints_output=nn.Linear(64,2)

推理流程drive.py:469-494):

waypoints_feature=self.waypoints_fc(hidden_states.reshape(-1,4096))# [B*L, 64]x=torch.zeros(size=(bs*n_tokens,2))# 初始位置output_wp=[]for_inrange(5):# 迭代预测5个waypointwaypoints_feature=self.waypoints_predictor(x,waypoints_feature)# GRUCell更新dx=self.waypoints_output(waypoints_feature)# 预测位移x=dx+x# 累加位移output_wp.append(x)predicted_waypoints=torch.cat(output_wp,dim=1)# [B*L, 10]

对比

方法特点适用场景
MLP单次前馈,并行计算训练速度快,简单场景
GRU自回归迭代,考虑时序依赖复杂轨迹,需要平滑性

3.3 End 预测器

原理:分类器判断当前帧是否为轨迹终点。

实现drive.py:130-134):

self.end_predictor=nn.Sequential(nn.Linear(self.llm_model.config.hidden_size,self.llm_model.config.hidden_size),nn.ReLU(),nn.Linear(self.llm_model.config.hidden_size,2)# [continue, end])

损失计算CrossEntropyLoss(predicted_end_prob, gt_end_flags)


3.4 联合损失函数

总损失drive.py:514):

self.waypoints_loss=torch.nn.L1Loss()# L1损失:轨迹点回归self.end_loss=torch.nn.CrossEntropyLoss()# 交叉熵:终点分类loss=waypoints_loss+end_loss*0.2# 加权联合损失

损失权重设计

  • Waypoints 损失(L1Loss):主要优化轨迹点精度,权重为 1.0
  • End 损失(CrossEntropyLoss):辅助优化终点判断,权重为 0.2

原因:自动驾驶任务中,轨迹点精度是首要目标,终点判断是次要目标但对安全至关重要,因此给予较小权重防止过度拟合。


3.5 LLM 作为特征提取器

关键设计:LMDrive 中LlamaForCausalLM.forward返回hidden_states(第1065行),而非 logits。

原因

  • 自动驾驶任务需要回归连续轨迹点,而非生成文本
  • 直接使用隐藏状态更灵活,可通过 MLP 头进行任意下游任务
  • 避免 LM Head 的冗余计算

代码证据modeling_llama.py:1065):

returnhidden_states# 直接返回隐藏状态,而非 logits

4. 输入数据构造

4.1 文本-图像特征拼接

拼接策略drive.py:169-224):

llm_inputs=[input_embeds[i][:text_len],# 文本前缀image_embeds[i].view(t*n,-1),# Q-Former精炼的视觉特征 (t帧 × n个query)input_embeds[i][text_len:]# 文本后缀(padding)]

Attention Mask

  • 文本部分:[1, 1, ..., 1]
  • 有效图像帧:[1, 1, ..., 1]
  • 无效图像帧(padding):[0, 0, ..., 0]

5. 完整推理流程

┌─────────────────────────────────────────────────────────────────────────┐ │ LMDrive LLM 推理流程 │ ├─────────────────────────────────────────────────────────────────────────┤ │ │ │ 1. 视觉编码器提取特征 │ │ [多视角图像 + LiDAR] → Memfuser → [B*t, 65, 768] │ │ │ │ 2. Q-Former 精炼特征 │ │ [B*t, 65, 768] → Q-Former → [B*t, 4, 768] → llm_proj → [B*t, 4, 4096]│ │ │ │ 3. 构造 LLM 输入 │ │ 文本嵌入 [B, L, 4096] + 视觉特征 [B, t, 4, 4096] → [B, L+t*4, 4096] │ │ │ │ 4. LLM 前向传播 │ │ [B, L+t*4, 4096] → LlamaModel(32层Decoder) → [B, L+t*4, 4096] │ │ └── 每层: RMSNorm → Attention(RoPE) → Residual → RMSNorm → MLP → Residual│ │ │ │ 5. Waypoint 预测 │ │ hidden_states[有效位置] → waypoints_predictor → [N, 10] │ │ │ │ 6. End 预测 │ │ hidden_states[有效位置] → end_predictor → [N, 2] → argmax → [N] │ │ │ └─────────────────────────────────────────────────────────────────────────┘

6. 关键技术参数

参数LLaVA-7B说明
hidden_size4096隐藏层维度
num_hidden_layers32解码器层数
num_attention_heads32注意力头数
head_dim128每头维度
intermediate_size11008MLP中间层维度
max_position_embeddings2048最大序列长度
vocab_size32000词表大小
hidden_actsilu激活函数(SwiGLU)

7. 总结

LMDrive 的 LLM 核心推理单元基于LLaMA 架构,包含以下关键组件:

  1. LlamaRMSNorm:Pre-Norm 归一化,无均值中心化
  2. RoPE:旋转位置编码,支持动态 NTK 缩放
  3. 因果自注意力:支持 KV 缓存和 GQA
  4. SwiGLU MLP:平滑梯度流动
  5. DecoderLayer:Pre-Norm 残差架构

在 LMDrive 中,LLM 被用作特征提取器而非文本生成器,通过:

  • LoRA 微调适配自动驾驶任务
  • Waypoints 预测器生成 5 个未来轨迹点
  • End 预测器判断轨迹终点

这种设计充分利用了 LLM 的强大语义理解能力,同时避免了文本生成的冗余计算,高效适配自动驾驶的回归任务需求。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询