U-Net网络的Pytorch实现
1.文章原文地址
U-Net: Convolutional Networks for Biomedical Image Segmentation
2.文章摘要
普遍认为,成功训练深度神经网络需要大量带标注的训练样本。在本文中,我们提出了一种网络结构以及相应的训练策略,该策略依靠数据增强使可用的标注样本得到更加有效的利用。这个网络由一个用于捕捉上下文信息的收缩路径和一个与之对称、能够实现精确定位的扩张路径组成。我们的结果表明,这个网络可以进行端到端的训练,使用非常少的图像就可以达到非常好的效果,并且在ISBI挑战赛的电子显微镜图像神经结构分割任务上超过了此前的最佳方法(滑动窗口卷积网络)。使用同一网络在透射光显微镜图像上进行训练,我们以大幅度的优势赢得了2015年的ISBI细胞追踪挑战赛。而且,这个网络非常快,在一块当前的GPU上,分割一张512×512的图像所花费的时间少于一秒。完整的代码以及训练好的网络(基于Caffe)可见 http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.
3.网络结构
4.Pytorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F


class unetConv2(nn.Module):
    """Two stacked 3x3 unpadded convolutions, each followed by ReLU.

    Because the convolutions use ``padding=0``, every call shrinks the
    spatial size by 4 pixels in each dimension (2 per convolution).

    Args:
        in_size: number of input channels.
        out_size: number of output channels of both convolutions.
        is_batchnorm: if true, insert ``BatchNorm2d`` between each
            convolution and its ReLU.
    """

    def __init__(self, in_size, out_size, is_batchnorm):
        super(unetConv2, self).__init__()
        self.conv1 = self._make_stage(in_size, out_size, is_batchnorm)
        self.conv2 = self._make_stage(out_size, out_size, is_batchnorm)

    @staticmethod
    def _make_stage(in_channels, out_channels, is_batchnorm):
        """Build one conv -> (optional batch norm) -> ReLU stage."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0),
        ]
        if is_batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply both convolution stages and return the result."""
        return self.conv2(self.conv1(inputs))


class unetUp(nn.Module):
    """Decoder step: upsample, merge with the skip connection, convolve.

    Args:
        in_size: channels of the low-resolution input (``inputs2``).
        out_size: channels produced by this step. The skip tensor
            (``inputs1``) is expected to carry ``out_size`` channels as well,
            which is how ``unet`` wires it.
        is_deconv: if true, upsample with a learned ``ConvTranspose2d`` that
            also reduces the channels to ``out_size``; otherwise use bilinear
            upsampling, which leaves the channel count at ``in_size``.
    """

    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        # After concatenation the tensor holds the skip channels (out_size)
        # plus the upsampled channels: out_size for the transposed conv,
        # in_size for bilinear upsampling. Sizing the conv for in_size in both
        # cases made the bilinear variant fail with a channel mismatch.
        concat_channels = in_size if is_deconv else in_size + out_size
        self.conv = unetConv2(concat_channels, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        """Merge skip tensor ``inputs1`` with the upsampled ``inputs2``.

        ``inputs1`` is resized to the spatial size of the upsampled tensor
        with ``F.pad``: negative amounts crop it (the usual case with
        unpadded convolutions), positive amounts zero-pad it.
        """
        outputs2 = self.up(inputs2)

        # Height and width are handled independently so non-square inputs
        # work, and the remainder of an odd difference goes to the trailing
        # side so the resized tensor matches outputs2 exactly.
        diff_h = outputs2.size(2) - inputs1.size(2)
        diff_w = outputs2.size(3) - inputs1.size(3)
        # F.pad takes the last dimension first: (left, right, top, bottom).
        padding = [
            diff_w // 2, diff_w - diff_w // 2,
            diff_h // 2, diff_h - diff_h // 2,
        ]
        outputs1 = F.pad(inputs1, padding)

        return self.conv(torch.cat([outputs1, outputs2], 1))


class unet(nn.Module):
    """U-Net with four downsampling and four upsampling stages.

    All 3x3 convolutions are unpadded, so the output is spatially smaller
    than the input (for example 572x572 in, 388x388 out).

    Args:
        feature_scale: divisor applied to the base channel widths
            ``[64, 128, 256, 512, 1024]``.
        n_classes: number of output channels of the final 1x1 convolution.
        is_deconv: passed to every ``unetUp`` to select transposed
            convolution (true) or bilinear upsampling (false).
        in_channels: number of channels of the input image.
        is_batchnorm: whether the encoder and center blocks use batch norm
            (the decoder blocks never do).
    """

    def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
        super(unet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling path
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)

        # upsampling path
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)

        # final 1x1 conv (without any concat)
        self.final = nn.Conv2d(filters[0], n_classes, kernel_size=1)

    def forward(self, inputs):
        """Return per-pixel class scores of shape (N, n_classes, H', W')."""
        conv1 = self.conv1(inputs)
        maxpool1 = self.maxpool1(conv1)

        conv2 = self.conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)

        conv3 = self.conv3(maxpool2)
        maxpool3 = self.maxpool3(conv3)

        conv4 = self.conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)

        center = self.center(maxpool4)

        # Each decoder step consumes the matching encoder output as its skip.
        up4 = self.up_concat4(conv4, center)
        up3 = self.up_concat3(conv3, up4)
        up2 = self.up_concat2(conv2, up3)
        up1 = self.up_concat1(conv1, up2)

        return self.final(up1)


if __name__ == "__main__":
    # torchsummary is only needed for this demo, so it is imported here to
    # keep the model classes importable without it.
    from torchsummary import summary

    # Keep the model on the same device that summary() is told to probe on.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = unet(feature_scale=1).to(device)
    report = summary(model, (3, 572, 572), device=device)
    # NOTE(review): torchsummary appears to print the table itself and
    # return None; only print when a value actually comes back.
    if report is not None:
        print(report)
1 ---------------------------------------------------------------- 2 Layer (type) Output Shape Param # 3 ================================================================ 4 Conv2d-1 [-1, 64, 570, 570] 1,792 5 BatchNorm2d-2 [-1, 64, 570, 570] 128 6 ReLU-3 [-1, 64, 570, 570] 0 7 Conv2d-4 [-1, 64, 568, 568] 36,928 8 BatchNorm2d-5 [-1, 64, 568, 568] 128 9 ReLU-6 [-1, 64, 568, 568] 0 10 unetConv2-7 [-1, 64, 568, 568] 0 11 MaxPool2d-8 [-1, 64, 284, 284] 0 12 Conv2d-9 [-1, 128, 282, 282] 73,856 13 BatchNorm2d-10 [-1, 128, 282, 282] 256 14 ReLU-11 [-1, 128, 282, 282] 0 15 Conv2d-12 [-1, 128, 280, 280] 147,584 16 BatchNorm2d-13 [-1, 128, 280, 280] 256 17 ReLU-14 [-1, 128, 280, 280] 0 18 unetConv2-15 [-1, 128, 280, 280] 0 19 MaxPool2d-16 [-1, 128, 140, 140] 0 20 Conv2d-17 [-1, 256, 138, 138] 295,168 21 BatchNorm2d-18 [-1, 256, 138, 138] 512 22 ReLU-19 [-1, 256, 138, 138] 0 23 Conv2d-20 [-1, 256, 136, 136] 590,080 24 BatchNorm2d-21 [-1, 256, 136, 136] 512 25 ReLU-22 [-1, 256, 136, 136] 0 26 unetConv2-23 [-1, 256, 136, 136] 0 27 MaxPool2d-24 [-1, 256, 68, 68] 0 28 Conv2d-25 [-1, 512, 66, 66] 1,180,160 29 BatchNorm2d-26 [-1, 512, 66, 66] 1,024 30 ReLU-27 [-1, 512, 66, 66] 0 31 Conv2d-28 [-1, 512, 64, 64] 2,359,808 32 BatchNorm2d-29 [-1, 512, 64, 64] 1,024 33 ReLU-30 [-1, 512, 64, 64] 0 34 unetConv2-31 [-1, 512, 64, 64] 0 35 MaxPool2d-32 [-1, 512, 32, 32] 0 36 Conv2d-33 [-1, 1024, 30, 30] 4,719,616 37 BatchNorm2d-34 [-1, 1024, 30, 30] 2,048 38 ReLU-35 [-1, 1024, 30, 30] 0 39 Conv2d-36 [-1, 1024, 28, 28] 9,438,208 40 BatchNorm2d-37 [-1, 1024, 28, 28] 2,048 41 ReLU-38 [-1, 1024, 28, 28] 0 42 unetConv2-39 [-1, 1024, 28, 28] 0 43 ConvTranspose2d-40 [-1, 512, 56, 56] 2,097,664 44 Conv2d-41 [-1, 512, 54, 54] 4,719,104 45 ReLU-42 [-1, 512, 54, 54] 0 46 Conv2d-43 [-1, 512, 52, 52] 2,359,808 47 ReLU-44 [-1, 512, 52, 52] 0 48 unetConv2-45 [-1, 512, 52, 52] 0 49 unetUp-46 [-1, 512, 52, 52] 0 50 ConvTranspose2d-47 [-1, 256, 104, 104] 524,544 51 Conv2d-48 [-1, 256, 102, 102] 1,179,904 52 
ReLU-49 [-1, 256, 102, 102] 0 53 Conv2d-50 [-1, 256, 100, 100] 590,080 54 ReLU-51 [-1, 256, 100, 100] 0 55 unetConv2-52 [-1, 256, 100, 100] 0 56 unetUp-53 [-1, 256, 100, 100] 0 57 ConvTranspose2d-54 [-1, 128, 200, 200] 131,200 58 Conv2d-55 [-1, 128, 198, 198] 295,040 59 ReLU-56 [-1, 128, 198, 198] 0 60 Conv2d-57 [-1, 128, 196, 196] 147,584 61 ReLU-58 [-1, 128, 196, 196] 0 62 unetConv2-59 [-1, 128, 196, 196] 0 63 unetUp-60 [-1, 128, 196, 196] 0 64 ConvTranspose2d-61 [-1, 64, 392, 392] 32,832 65 Conv2d-62 [-1, 64, 390, 390] 73,792 66 ReLU-63 [-1, 64, 390, 390] 0 67 Conv2d-64 [-1, 64, 388, 388] 36,928 68 ReLU-65 [-1, 64, 388, 388] 0 69 unetConv2-66 [-1, 64, 388, 388] 0 70 unetUp-67 [-1, 64, 388, 388] 0 71 Conv2d-68 [-1, 21, 388, 388] 1,365 72 ================================================================ 73 Total params: 31,040,981 74 Trainable params: 31,040,981 75 Non-trainable params: 0 76 ---------------------------------------------------------------- 77 Input size (MB): 3.74 78 Forward/backward pass size (MB): 3158.15 79 Params size (MB): 118.41 80 Estimated Total Size (MB): 3280.31
参考