代码功能概述
这段代码实现了一个基于卷积神经网络 (CNN) 和循环神经网络 (LSTM) 结合的模型,用于图像序列预测任务。具体来说,输入是连续的16帧图像,经过编码器 (Encoder) 提取特征后,输入到LSTM进行时间序列学习,最后通过解码器 (Decoder) 还原成原始图像尺寸(或生成第17帧)。模型采用了图像序列作为输入,并以第17帧作为目标图像,使用均方误差 (MSE) 作为损失函数进行训练。
代码详细讲解
1. 数据预处理部分
数据路径生成:
importosdir='/home/lab226/wdf/imgsrc'fp=open('./img_path.txt','w+')imgfile_list=os.listdir('/home/lab226/wdf/imgsrc')imgfile_list.sort(key=lambdax:int(x[:]))# 按照数字顺序对文件夹进行排序seqsize=17forimgfileinimgfile_list:filepath=os.path.join(dir,imgfile)img_list=os.listdir(filepath)img_list.sort(key=lambdax:int(x[:-4]))# 按照数字顺序对图像文件进行排序foriinrange(0,len(img_list)-seqsize,8):forjinrange(i,i+seqsize):img=img_list[j]path=os.path.join(filepath,img)ifj==i+seqsize-1:fp.write(path+'\n')# 写入第17帧图像路径else:fp.write(path+' ')# 写入连续的16帧图像路径fp.close()功能:
- 从文件夹中读取所有视频文件夹路径,并按顺序排序。
- 对于每个视频文件夹,按照图像文件的数字顺序读取帧图像。
- 使用滑动窗口(步长为8)从每个视频序列中提取连续的17帧图像路径,生成一个文本文件
img_path.txt,每行包含一段连续的17帧图像路径,其中最后一帧是目标帧。
2. 数据加载部分
自定义数据集类:
classSeqDataset(Dataset):def__init__(self,txt,transform=None,target_transform=None,loader=default_loader):fh=open(txt,'r')imgseqs=[]forlineinfh:line=line.strip('\n')imgseqs.append(line)self.num_samples=len(imgseqs)self.imgseqs=imgseqs self.transform=transform self.target_transform=target_transform self.loader=loaderdef__getitem__(self,index):current_index=np.random.choice(range(0,self.num_samples))imgs_path=self.imgseqs[current_index].split()current_imgs_path=imgs_path[:len(imgs_path)-1]current_label_path=imgs_path[len(imgs_path)-1]current_label=self.loader(current_label_path)current_imgs=[self.loader(frame)forframeincurrent_imgs_path]ifself.transform:current_imgs=[self.transform(img)forimgincurrent_imgs]current_label=self.transform(current_label)batch_cur_imgs=np.stack(current_imgs,axis=0)returnbatch_cur_imgs,current_label功能:
- 该类继承自
Dataset,从txt文件中读取图像路径。 - 每次调用
__getitem__()时,从路径中获取连续的16帧图像和第17帧作为目标图像,并应用必要的变换(如转换为Tensor)。 - 返回值是一个包含连续16帧图像的批次和目标图像的元组。
数据加载器:
train_data=SeqDataset(txt='./img_path.txt',transform=data_transforms)train_loader=DataLoader(train_data,shuffle=True,num_workers=20,batch_size=BATCH_SIZE)功能:
- 使用
DataLoader加载训练数据,将图像和标签数据组织成批次供模型训练使用。
3. 模型介绍
该模型由EncoderMUG2d_LSTM(编码器)和DecoderMUG2d(解码器)组成,能够对图像序列进行预测。
编码器(Encoder):
classEncoderMUG2d_LSTM(nn.Module):def__init__(self,input_nc=3,encode_dim=1024,lstm_hidden_size=1024,seq_len=SEQ_SIZE,num_lstm_layers=1,bidirectional=False):super(EncoderMUG2d_LSTM,self).__init__()self.encoder=nn.Sequential(nn.Conv2d(input_nc,32,4,2,1),nn.BatchNorm2d(32),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(32,64,4,2,1),nn.BatchNorm2d(64),nn.LeakyReLU(0.2,inplace=True),#... (省略其他卷积层)nn.Conv2d(512,1024,4,2,1),nn.BatchNorm2d(1024),nn.LeakyReLU(0.2,inplace=True),)self.fc=nn.Linear(1024,encode_dim)self.lstm=nn.LSTM(encode_dim,encode_dim,batch_first=True)功能:
EncoderMUG2d_LSTM通过卷积层提取图像的空间特征,并将输出平展为一个向量,通过全连接层 (fc) 输出一个编码向量。之后,该编码向量传入LSTM中进行时序特征学习。
解码器(Decoder):
classDecoderMUG2d(nn.Module):def__init__(self,output_nc=3,encode_dim=1024):super(DecoderMUG2d,self).__init__()self.project=nn.Sequential(nn.Linear(encode_dim,1024*1*1),nn.ReLU(inplace=True))self.decoder=nn.Sequential(nn.ConvTranspose2d(1024,512,4),nn.BatchNorm2d(512),nn.ReLU(True),#... (省略其他反卷积层)nn.ConvTranspose2d(16,output_nc,4,stride=2,padding=1),nn.Sigmoid(),)功能:
DecoderMUG2d将从LSTM获得的编码信息映射回原始图像尺寸,通过反卷积层逐步还原图像。
4. 训练过程
训练过程包括以下几个步骤:
初始化模型和优化器:
model=net()optimizer=optim.Adam(model.parameters(),lr=learning_rate)loss_func=nn.MSELoss()模型训练:
forepochinrange(10):forbatch_x,batch_yintrain_loader:inputs,label=Variable(batch_x).cuda(),Variable(batch_y).cuda()output=model(inputs)loss=loss_func(output,label)/label.shape[0]optimizer.zero_grad()loss.backward()optimizer.step()
功能:
- 使用 Adam 优化器和 MSE 损失函数对模型进行训练。
- 每个训练批次中,输入连续的16帧图像,预测第17帧,计算损失并更新网络参数。
- 保存图像和模型:
if(epoch+1)%5==0:pic=to_img(output.cpu().data)img=to_img(label.cpu().data)save_image(pic,'./conv_autoencoder/decode_image_{}.png'.format(epoch+1))save_image(img,'./conv_autoencoder/raw_image_{}.png'.format(epoch+1))
功能:
- 每5个周期保存一次预测图像和原图,以便观察模型的训练效果。
- 训练完成后,保存模型的权重。
总结
这段代码实现了一个基于卷积神经网络和循环神经网络(LSTM)的图像序列预测模型。模型通过卷积网络提取图像特征,将提取的特征输入到LSTM中进行时序学习,最后通过解码器恢复图像。通过滑动窗口从视频中提取连续帧的图像,并利用MSE损失函数对模型进行训练和优化。