awq_marlin和gptq_marlin量化算法简要介绍
2026/6/3 1:50:01 网站建设 项目流程

量化算法的本质

量化算法的本质在于快速实现Y=XW计算,其中X,W往往都是FP16数据类型,在大模型推理过程中,输入X往往被称之为activation,而权重W被称之为weight,对于一个[M,K,N]的矩阵乘法(即X的形状为[M,K],W形状为[K,N],Y形状为[M,N])的矩阵来说,最简单的实现方式就是调用cublas仓库,这里需要重点注意的是X和W的排列方式,尤其是W的排列方式,比如说下面这段代码

cublasGemmEx(handle,CUBLAS_OP_T,CUBLAS_OP_N,N,M,K,&alpha,//alpha = 1.0fW,CUDA_R_16F,ldb,//ldb = KX,CUDA_R_16F,lda,//lda = K&beta,//beta = 0.0fY,CUDA_R_32F,ldo,//ldo = NCUBLAS_COMPUTE_32F,CUBLAS_GEMM_DEFAULT);

处理的就是Y=XW的计算过程,但是其中X是行主元数据(即X形状是[M,K],步长stride是[K,1]),而W是列主元数据(即W形状是[K,N],步长stride是[1,K],这个可以通过W = torch.randn([N,K]).t()这种方式得到)。
上面的这个矩阵乘法大家耳熟能详,但是这个矩阵乘法在大模型推理训练过程中会带来一些问题,比如说最直接的就是显存占用情况,假设我们只考虑Y=XW这个计算,如果W是一个数据量为8B(80亿参数)的矩阵,如果W的每个元素都是FP16,那么W需要占用显存(GB)= 参数量×数据比特数/(8×1024×1024×1024)=14.9,也就是说,仅仅考虑存储这个权重W就需要占用14.9GB显存,如果我们能够换一种思路,比如说把权重的数据类型换成INT8,此时显存马上可以降低一半变成7.45GB,如果进一步把权重数据类型变成INT4,那么显存继续降低变成3.73GB,也就是说,对于一个稍微普通的带显卡的笔记本,就可以实现这个推理过程了。
在输入X,也就是activation数据类型为F16的情况下,W数据类型为INT8,此时的量化称之为W8A16量化,如果W数据类型为INT4,此时量化称之为W4A16量化。
比如说W8A8量化,指的就是输入X数据类型为INT8,权重矩阵数据类型也是INT8。

矩阵的量化算法

这里我们先介绍一下矩阵的量化算法,即:一个形状为[K,N],数据类型为FP16的权重矩阵W,是如何变成另一个形状为[K,N],数据类型为INT4或者INT8的量化矩阵w_packed。下面我们以INT8来举例子说明:

per_tensor_quant_int8

这种量化最简单,先计算出全局的abs最大值global_max = max(abs(W.flatten())),此时引入一个scale = global_max /127,有了scale以后,下面使用这段伪代码

foriinrange(K):forjinrange(N):val=W[i,j]/scale val=max(-127,min(127,val))w_packed[i,j]=val.to(torch.int8)

通过上面这种方式得到的结果,我们称之为对称量化,这种对称量化方式涉及到的参数有w_packed, scale和W,其中scale是一个长度为1的数据类型为F32的tensor,与之对应的还有一个非对称量化,非对称量化会多一个zero参数,但是实际大模型量化过程中用的最多的就是对称量化。

per_channel_quant_int8

有了上面关于per_tensor_quant_int8量化的介绍,此时理解per_channel_quant_int8量化就简单多了,对称的per_channel_quant_int8量化和上面的区别在于scale的形状变成了[K,1],也就是说,原来需要计算整个矩阵abs(W)的全局最大值,现在需要针对每一行abs(W[i,:]计算最大值,python代码实现可以参考

defper_channel_quant_int8_torch(x,symmetric):ifsymmetric:x=x.float()absmax=x.abs().max(dim=-1).values absmax=absmax.clamp_min(1e-10).unsqueeze(-1)scale_x=absmax/127x_q=x.mul(127/absmax)x_q=torch.round(x_q).to(torch.int8)returnx_q,scale_x,Noneelse:w=x.float()w_min=w.min(dim=-1,keepdim=True)[0]w_max=w.max(dim=-1,keepdim=True)[0]w_scale=(w_max-w_min)/255.0w_scale=torch.clamp(w_scale,min=1e-8)w_zero=-w_min/w_scale-128.0w_q=torch.round(w/w_scale+w_zero)w_q=torch.clamp(w_q,-128,127)w_packed=w_q.to(torch.int8)returnw_packed,w_scale,w_zero

与之类似的还有per_channel_quant_fp8,per_tensor_quant_fp8,本质上没有区别,只不过最后量化结果数据类型不一致而已,至于per_channel_quant_int8和per_tensor_quant_int8的CUDA代码实现也非常简单,可以参考添加链接描述和添加链接描述
真正的难点在于矩阵乘法。

量化模型的矩阵乘法

在量化大模型推理过程中,一般会提前提供已经量化好的权重以及对应的scale,也就是说,现在需要实现的计算过程拥有下面几个参数:
X:形状为[M,K],往往是行主元,即步长stride=[K,1],数据类型为FP16或者FP32
W:形状为[K,N],如果是行主元,那么步长stride=[N,1],如果是列主元,步长stride=[1,K],数据类型可能为INT4或者是INT8
scale:当形状为[M,1]对应的是per_channel_quant,当形状为[1,]的时候,对应的是per_tensor_quant,数据类型为FP32
zero:这是optional参数,数据类型和形状往往和scale保持一致(但是awq_marlin_gemm,gptq_marlin_gemm这些量化模型可能会很不一样),如果zeros存在,那么对应的就是非对称量化,如果zeros不存在,对应的就是对称量化。
bias:这是一个optional参数,数据类型为FP32,形状往往为[N,1]

W8A8矩阵乘法

我们以W8A8对称量化算法来举例说明量化矩阵乘法,计算过程涉及的参数就变成了:
X:形状为[M,K],往往是行主元,即步长stride=[K,1],数据类型为INT8
W:形状为[K,N],这里我们考虑列主元,步长stride=[1,K],数据类型为INT8
x_scale:形状为[M,1],数据类型为FP32
w_scale:形状为[N,1],数据类型为FP32
bias:这是一个optional参数,数据类型为FP32,形状为[N,1]
需要实现的计算大概就是
Y = (x_scale * X)@ (w_scale * W) + bias,
在具体的实现过程中,有两种方案,
方案1:先调用cublas计算y_packed = x_packed@w_packed,由于此时x_packed, w_packed数据类型都是INT8的,使用cublas计算速度会特别快,这个时候相当于说,我们需要在CUDA层面额外引入一份显存来存储这个临时数据y_packed,cublas的计算流程可以参考下面这段代码

constint32_talpha_I=1;constint32_tbeta_I=0;cublasGemmEx(handle,CUBLAS_OP_T,CUBLAS_OP_N,N,M,K,&alpha_I,b,CUDA_R_8I,ldb,//ldb = Ka,CUDA_R_8I,lda,//lda = K&beta_I,y_packed,CUDA_R_32I,ldo,//ldo=NCUBLAS_COMPUTE_32I,CUBLAS_GEMM_DEFAULT);

特别注意,上面的这段代码里面要求x_packed是行主元的形状为[M,K]的指针,而w_packed是列主元的,形状为[K,N]的指针,其中x_packed可以通过torch.randn([M,K])直接生成,而w_packed可以通过torch.randn([N,K]).t()直接生成。有了y_packed以后,剩下的就是做后处理,根据x_scale和w_scale以及y_packed把结果还原出来,这部分比较简单。
可以看出,这个方案1其实需要实现两个kernel,第一个kernel调用cublas,第二个kernel进行后处理,这个方案非常直接简单,但是在性能上不占优势,CUDA代码里面,我们希望一个算子往往只占用一个kernel,这种做法肯定会比直接使用torch.matmul计算FP16的X@W要慢。
方案2:直接调用cutlass来计算整个过程,这个说起来很简单,但是实现起来非常复杂,这个的原始代码参考添加链接描述

GPTQ MARLIN矩阵乘法

Gptq marlin矩阵乘法计算的也是Y=XW,其中X,W往往都是FP16数据类型,在大模型推理过程中,输入X往往被称之为activation,而权重W被称之为weight,对于一个[M,K,N]的矩阵乘法(即X的形状为[M,K],W形状为[K,N],Y形状为[M,N])的矩阵来说。这个算法的核心目的包括:
1:把浮点权重(W)量化压缩到 4bit/8bit(体积缩小 4~8 倍);
2:把量化后的权重重排成 MARLIN 专用格式(适配 GPU 硬件执行单元);
3:保证推理速度接近浮点、精度几乎无损。

模块1:将浮点数据类型的W量化得到量化权重w_q, w_s, w_z以及根据(w_q,w_s,w_z)反量化得到的w_ref

首先根据下面的逻辑做数据重排

# 代码逻辑w=w.reshape((-1,group_size,size_n))# [K/group, group, N]w=w.permute(1,0,2)# [group, K/group, N]w=w.reshape((group_size,-1))# [group, K/group * N]

然后获取对应group的最大值,最小值以及绝对最大值。

max_val=torch.max(w,0,keepdim=True).values min_val=torch.min(w,0,keepdim=True).values abs_val=torch.max(abs(max_val),abs(min_val)

如果需要设置零点,那么就计算对应的w_s和w_z,参考

max_q_val=quant_type.max()min_q_val=quant_type.min()w_s=(max_val-min_val).clamp(min=1e-5)/quant_type.max()maybe_w_zp=(torch.round(torch.abs(min_val/w_s)).clamp(min_q_val,max_q_val).int())

这里提到的quant_type就是对应的量化类型,如果是int8量化,那么对应的quant_type就是int8,此时对应的上下界max_q_val,min_q_val就是127,-127,用数学公式表达就是

如果不需要设置零点,那么只计算w_s,对应的代码参考

w_s=torch.max(abs(max_val/(max_q_valifmax_q_val!=0elsetorch.inf)),abs(min_val/(min_q_valifmin_q_val!=0elsetorch.inf)),)

用数学公式表达就是

对应的量化权重计算参考

w_q=torch.round(w/w_s).int()+(maybe_w_zpifzero_pointselse0)w_q=torch.clamp(w_q,min_q_val,max_q_val)

用数学公式表达就是

最后根据刚才计算的w_q,w_s,w_z得到一个反量化的w_ref=(w_q - w_z)w_s,值得说明的是,此时这个w_ref和最原始的w大概率不等价。
GPTQ 把权重压成 INT4/INT8,但直接存成 [K,N] 矩阵,GPU 跑不快。这是因为
1:GPU TensorCore(MMA)一次喜欢读 16×16 小块;
2:而且要连续内存、特定顺序才能用向量加载(LDG.128);
3:原生矩阵是 “行主序”,不满足硬件读取模式。
Marlin perm 的本质: 把 INT4/INT8 权重重新切成 16×16 瓦片 → 打乱瓦片内部元素顺序 → 拼成 GPU 最喜欢的内存布局。重点:perm 不是随机乱排,是硬编码的、为了 TensorCore 读得快的固定重排。
这个marlin重排的过程很复杂,本人也不太能看懂。
awq_marlin_gemm和gptq_marlin_gemm的主要区别在于awq_marlin不支持zeros,
awq_marlin的实现源代码来自添加链接描述
gptq_marlin的实现源代码来自添加链接描述
本人针对awq,gptq的矩阵乘法做了一个识别简化,可以参考添加链接描述和添加链接描述

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询