fine-turning of VGG
一、 fine-turning
由于数据集的限制,我们可以使用预训练的模型,来重新fine-turning(微调)。
使用卷积网络作为特征提取器,冻结卷积操作层,这是因为卷积层提取的特征对于许多任务都有用处,使用新的数据集训练新定义的全连接层。
何时以及如何Fine-tune
决定如何使用迁移学习的因素有很多,这是最重要的只有两个:新数据集的大小、以及新数据和原数据集的相似程度。有一点一定记住:网络前几层学到的是通用特征,后面几层学到的是与类别相关的特征。这里有使用的四个场景:
1、新数据集比较小且和原数据集相似。因为新数据集比较小,如果fine-tune可能会过拟合;又因为新旧数据集类似,我们期望他们高层特征类似,可以使用预训练网络当做特征提取器,用提取的特征训练线性分类器。
2、新数据集大且和原数据集相似。因为新数据集足够大,可以fine-tune整个网络。
3、新数据集小且和原数据集不相似。新数据集小,最好不要fine-tune,和原数据集不类似,最好也不使用高层特征。这时可是使用前面层的特征来训练SVM分类器。
4、新数据集大且和原数据集不相似。因为新数据集足够大,可以重新训练。但是实践中fine-tune预训练模型还是有益的。新数据集足够大,可以fine-tine整个网络。
我们这次的作业属于数据集与原来相似而且数据集很小的情况,可以使用预训练网络当做特征提取器,用提取的特征训练线性分类器。
二、代码重点
数据处理
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
vgg_format = transforms.Compose([
transforms.CenterCrop(224),
#.中心裁剪:transforms.CenterCrop
#class torchvision.transforms.CenterCrop(size)
#功能:依据给定的size从中心裁剪
#参数:
#size- (sequence or int),若为sequence,则为(h,w),若为int,则(size,size)
transforms.ToTensor(),
normalize,
])
data_dir = './dogscats'
dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
for x in ['train', 'valid']}
dset_sizes = {x: len(dsets[x]) for x in ['train', 'valid']}
dset_classes = dsets['train'].classes
修改全连接层,冻结卷积层的参数
for param in model_vgg_new.parameters():
param.requires_grad = False #训练时不更改参数
model_vgg_new.classifier._modules['6'] = nn.Linear(4096, 2) #全连接输出两类猫或者狗
model_vgg_new.classifier._modules['7'] = torch.nn.LogSoftmax(dim = 1) # 数据处理
创建损失函数和优化器,训练模型
criterion = nn.NLLLoss() #设置损失函数
lr = 0.001 # 学习率
optimizer_vgg = torch.optim.SGD(model_vgg_new.classifier[6].parameters(),lr = lr) # 随机梯度下降
#训练模型 (模板 建议直接背诵)
def train_model(model,dataloader,size,epochs=1,optimizer=None):
model.train()
for epoch in range(epochs):
running_loss = 0.0
running_corrects = 0
count = 0
for inputs,classes in dataloader:
inputs = inputs.to(device)
classes = classes.to(device)
outputs = model(inputs)
loss = criterion(outputs,classes)
optimizer = optimizer
optimizer.zero_grad()
loss.backward()
optimizer.step()
_,preds = torch.max(outputs.data,1)
# statistics
running_loss += loss.data.item()
running_corrects += torch.sum(preds == classes.data)
count += len(inputs)
print('Training: No. ', count, ' process ... total: ', size)
epoch_loss = running_loss / size
epoch_acc = running_corrects.data.item() / size
print('Loss: {:.4f} Acc: {:.4f}'.format(
epoch_loss, epoch_acc))
三、代码优化
1.数据处理
vgg_format = transforms.Compose([
#transforms.CenterCrop(224),
transforms.Resize((224,224)),
#这里选择缩放而不是中心裁剪,因为简单地选择重心裁剪会让图像的一些特征直接丢失,严重的情况下直接无法捕捉到物体(cat or dog),这样的情况下卷积也没有什么作用了
transforms.ToTensor(),
normalize,
])
下面的图片就可以看出使用缩放而不是中心裁剪的原因:原本图片的