import torch
from torch import nn


class conv_block(nn.Module):  # two 3x3 convs; only the channels change, the spatial size stays the same
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class up_conv(nn.Module):  # upsampling: doubles the spatial size
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x


class Attention_block(nn.Module):
    # e.g. self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        # 1x1 conv on the gating signal (decoder features)
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # 1x1 conv on the skip connection (encoder features)
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # collapse to a single channel and apply Sigmoid to get the attention map
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # 1x1 conv on the gating signal from the decoder path
        g1 = self.W_g(g)
        # 1x1 conv on the skip connection from the encoder path
        x1 = self.W_x(x)
        # element-wise addition followed by ReLU (not a concat)
        psi = self.relu(g1 + x1)
        # reduce the channels to 1 and apply Sigmoid to get the weight map
        psi = self.psi(psi)
        # return the re-weighted skip connection x
        return x * psi
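

# A minimal shape check for the attention gate (an illustrative sketch, not part
# of the original post): with F_g=F_l=512 and F_int=256, the gate takes a gating
# tensor g and a skip tensor x of shape [N,512,H,W] and returns the re-weighted
# skip tensor with the same shape.
if __name__ == "__main__":
    _att = Attention_block(F_g=512, F_l=512, F_int=256)
    _g = torch.randn(1, 512, 28, 28)  # decoder (gating) features
    _x = torch.randn(1, 512, 28, 28)  # encoder (skip) features
    assert _att(_g, _x).shape == (1, 512, 28, 28)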


class AttentionUnet(nn.Module):
    def __init__(self, img_ch=3, output_ch=1):
        super(AttentionUnet, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # encoder
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # decoder with attention gates on the skip connections
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        # 1x1 conv down to the desired number of output channels
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # encoder path (shapes below are for a [3,224,224] input)
        x1 = self.Conv1(x)      # [3,224,224]  -> [64,224,224]   only the channels change, spatial size stays the same
        x2 = self.Maxpool(x1)   # [64,224,224] -> [64,112,112]   spatial size halved, channels unchanged
        x2 = self.Conv2(x2)     # [64,112,112] -> [128,112,112]
        x3 = self.Maxpool(x2)   # [128,112,112] -> [128,56,56]
        x3 = self.Conv3(x3)     # [128,56,56] -> [256,56,56]
        x4 = self.Maxpool(x3)   # [256,56,56] -> [256,28,28]
        x4 = self.Conv4(x4)     # [256,28,28] -> [512,28,28]
        x5 = self.Maxpool(x4)   # [512,28,28] -> [512,14,14]
        x5 = self.Conv5(x5)     # [512,14,14] -> [1024,14,14]

        # decoder + concat path
        d5 = self.Up5(x5)                # [1024,14,14] -> [512,28,28]  spatial size doubled, channels halved
        x4 = self.Att5(g=d5, x=x4)       # gate the skip connection: d5 and x4 are both [512,28,28], output is [512,28,28]
        d5 = torch.cat((x4, d5), dim=1)  # concatenate along the channel dim -> [1024,28,28]
        d5 = self.Up_conv5(d5)           # channels halved, spatial size unchanged -> [512,28,28]

        d4 = self.Up4(d5)                # [512,28,28] -> [256,56,56]
        x3 = self.Att4(g=d4, x=x3)       # d4 and x3 are both [256,56,56]
        d4 = torch.cat((x3, d4), dim=1)  # -> [512,56,56]
        d4 = self.Up_conv4(d4)           # [512,56,56] -> [256,56,56]

        d3 = self.Up3(d4)                # [256,56,56] -> [128,112,112]
        x2 = self.Att3(g=d3, x=x2)       # d3 and x2 are both [128,112,112]
        d3 = torch.cat((x2, d3), dim=1)  # -> [256,112,112]
        d3 = self.Up_conv3(d3)           # [256,112,112] -> [128,112,112]

        d2 = self.Up2(d3)                # [128,112,112] -> [64,224,224]
        x1 = self.Att2(g=d2, x=x1)       # d2 and x1 are both [64,224,224]
        d2 = torch.cat((x1, d2), dim=1)  # -> [128,224,224]
        d2 = self.Up_conv2(d2)           # [128,224,224] -> [64,224,224]

        d1 = self.Conv_1x1(d2)           # [64,224,224] -> [output_ch,224,224] (one channel with the default output_ch=1)
        d1 = self.sigmoid(d1)            # per-pixel probabilities in (0, 1)
        return d1
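

# Quick end-to-end smoke test (an illustrative sketch, not from the original
# post): build the network with its default channels (img_ch=3, output_ch=1)
# and push a dummy 224x224 batch through it; the output keeps the input's
# spatial size and has output_ch channels.
if __name__ == "__main__":
    model = AttentionUnet(img_ch=3, output_ch=1)
    model.eval()
    dummy = torch.randn(1, 3, 224, 224)  # [N, C, H, W]
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 1, 224, 224])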