当前位置: 首页 > news >正文

U-net

文章目录

  • 1、U-net 简介
  • 2、U-net网络详解
    • 2.1、U-net结构图
    • 2.2、U-net主要创新
    • 2.3、U-net网络优势
  • 3、目前常用方法U-net改动
  • 4、U-net网络程序代码

1、U-net 简介

U-net 发表于 2015 年,属于 FCN 的一种变体 。Unet 的初衷是为了解决生物医学图像方面的问题,由于效果确实很好后来也被广泛的应用在语义分割的各个方向,比如卫星图像分割,工业瑕疵检测等。

U-net采用 Encoder-Decoder 和 跳跃链接 的结构,结构简单但很有效。Encoder 负责特征提取,也可以将自己熟悉的各种特征提取网络放在这个位置,例如,vgg、resnet等常用网络。

U-Net是比较早的使用全卷积网络进行语义分割的算法之一,论文中使用包含压缩路径和扩展路径的对称U形结构在当时非常具有创新性,且一定程度上影响了后面若干个分割网络的设计,

该网络的名字也是取自其U形形状。

2、U-net网络详解

2.1、U-net结构图

如下图所示,U-net网络结构就像一个U形状,因此得名。

网络结构的左半部分是论文中的压缩路径(contracting path),主要负责提取图像中的特征。其中共有四个block组成,包含了两次卷积和最大池化操作,每次下采样之后的通道数翻倍,
特征图尺寸缩小一半。得到了512维度大小为32x32的特征,再次经过两次卷积然后送入decoder路径(无池化操作)。

网络结构的右半部分是论文中的扩展路径(expansive path),主要用以恢复特征尺寸并融合高分辨率的特征信息。其中同样包含了四个Block,每个block包含了一次上采样操作+拼接+两次卷积操作。
上采样常见的有转置卷积和插值法,论文中使用的是转置卷积操作。每次上采样之后,通道数减半,特征尺寸增大一倍,然后和对应的高分辨率特征图进行拼接以恢复维度,(由于卷积的过程中没有填充,导致尺寸不匹配,通过Crop裁剪以满足相同大小)。

最后得到64维度大小为388x388的特征图。

网络结构的最后一层是通过卷积操作进行降维并分类,根据数据集的类别个数选择输出对应的通道个数的特征图。论文中是二分类任务,所以输出通道为2。
在这里插入图片描述

2.2、U-net主要创新

采取将低级特征图与后面的高级特征图进行融合操作

完全对称的U型结构使得前后特征融合更为彻底,使得高分辨率信息与低分辨率信息在目标图片中增加

结合了下采样时的低分辨率信息(提供物体类别识别依据)和上采样时的高分辨率信息(提供精准分割定位依据),此外还通过融合操作(跳跃结构)填补底层信息以提高分割精度.(分辨率就是图片的尺寸)

2.3、U-net网络优势

网络层越深得到的特征图,有着更大的视野,浅层卷积关注纹理特征,深层网络关注本质特征,所以深层浅层特征都是有各自的意义;

另外一点是通过反卷积得到的更大的尺寸的特征图的边缘,是缺少信息的,毕竟每一次下采样提炼特征的同时,也必然会损失一些边缘特征,而失去的特征并不能从上采样中找回,因此通过特征的拼接(跳跃连接),来实现边缘特征的一个找回。

在医疗影像上效果好的原因分析:

医疗影像语义较为简单、结构固定。因此语义信息相比自动驾驶等较为单一,因此并不需要去筛选过滤无用的信息。医疗影像的所有特征都很重要,因此低级特征和高级语义特征都很重要,所以U型结构的skip connection结构(特征拼接)更好派上用场。
医学影像的数据较少,获取难度大,数据量可能只有几百甚至不到100,大型网络容易过拟合。

3、目前常用方法U-net改动

从U-net的结构图中可以看出,编码器和解码器的特征尺寸是不同的,所以需要经过裁剪才可以进行后续的拼接操作,增加了模型设计的难度和普适性。目前主流的方法是保持输入输出尺寸相同,并且采用插值法来进行上采样操作,卷积操作加上了BN。此外,对编码器使用常见的网络来提取图像特征,vgg16、resnet18等。

如下图所示:
在这里插入图片描述

4、U-net网络程序代码

from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
 
## 两次3x3卷积操作
class DoubleConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        if mid_channels is None:
            mid_channels = out_channels
        super(DoubleConv, self).__init__(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
 
## 池化+两次卷积
class Down(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__(
            nn.MaxPool2d(2, stride=2),
            DoubleConv(in_channels, out_channels)
        )
 
## 上采样+拼接+两次卷积
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
 
    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # [N, C, H, W]
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]
 
        # padding_left, padding_right, padding_top, padding_bottom
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
 
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x
 
## 最后的分类层
class OutConv(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1)
        )
 
 
class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 2,
                 bilinear: bool = True,
                 base_c: int = 64):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear
 
        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down(base_c * 8, base_c * 16 // factor)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv(base_c, num_classes)
 
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x)
 
        return {"out": logits}

相关文章:

  • 印尼做网站的教学 中文/西安seo外包优化
  • wordpress主题 时光/深圳互联网营销
  • wordpress 读取数据的地方/前端性能优化
  • 网站开发在网页插入音频/厦门seo优化公司
  • 保定做网站的公司/查看关键词被搜索排名的软件
  • 晋城网站设计人/做seo推广公司
  • pytest-需要模块相应的库
  • VSCode连GitHub的代理服务器配置和获取历史版本命令
  • 笔试训练(4)
  • “世界上最鸽派”的央行转鹰,透露了什么信号?
  • 从“小螺栓血案”谈装配体模型连接螺栓6个正确的处理方法
  • 讯飞听见SaaS服务迈入全新时代
  • mapstruct 无法生成字段映射code
  • 全志V853常用模型跑分数据
  • 迅为3A5000开发板龙芯自主指令集从里到外100%全国产设计方案
  • 安卓面经<15/30>之SharedPreference解析
  • LeetCode 10. 正则表达式匹配(C++)*
  • 《计算机体系结构量化研究方法》第2章-存储器层次结构设计 2.1 引言