Pytorch实现SEvariants
【摘要】
import torch
import torch.nn as nn
import torchvision
class cSE_Module(nn.Module):
def __init__(...
import torch
import torch.nn as nn
import torchvision
class cSE_Module(nn.Module):
def __init__(self, channel,ratio = 16):
super(cSE_Module, self).__init__()
self.squeeze = nn.AdaptiveAvgPool2d(1)
self.excitation = nn.Sequential(
nn.Linear(in_features=channel, out_features=channel // ratio),
nn.ReLU(inplace=True),
nn.Linear(in_features=channel // ratio, out_features=channel),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.squeeze(x).view(b, c)
z = self.excitation(y).view(b, c, 1, 1)
return x * z.expand_as(x)
class sSE_Module(nn.Module):
def __init__(self, channel):
super(sSE_Module, self).__init__()
self.spatial_excitation = nn.Sequential(
nn.Conv2d(in_channels=channel, out_channels=1, kernel_size=1,stride=1,padding=0),
nn.Sigmoid()
)
def forward(self, x):
z = self.spatial_excitation(x)
return x * z.expand_as(x)
class scSE_Module(nn.Module):
def __init__(self, channel,ratio = 16):
super(scSE_Module, self).__init__()
self.cSE = cSE_Module(channel,ratio)
self.sSE = sSE_Module(channel)
def forward(self, x):
return self.cSE(x) + self.sSE(x)
if __name__=='__main__':
# model = cSE_Module(channel=16)
# model = sSE_Module(channel=16)
model = scSE_Module(channel=16)
print(model)
input = torch.randn(1, 16, 64, 64)
out = model(input)
print(out.shape)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
文章来源: wanghao.blog.csdn.net,作者:AI浩,版权归原作者所有,如需转载,请联系作者。
原文链接:wanghao.blog.csdn.net/article/details/121573827
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)