init
This commit is contained in:
373
model/TDCNet/TDCNetwork.py
Normal file
373
model/TDCNet/TDCNetwork.py
Normal file
@@ -0,0 +1,373 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from model.TDCNet.TDCSTA import CrossAttention, SelfAttention
|
||||
from model.TDCNet.backbone3d import Backbone3D
|
||||
from model.TDCNet.backbonetd import BackboneTD
|
||||
from model.TDCNet.darknet import BaseConv, CSPDarknet, DWConv
|
||||
|
||||
|
||||
class Feature_Backbone(nn.Module):
|
||||
def __init__(self, depth=1.0, width=1.0, in_features=("dark3", "dark4", "dark5"), in_channels=[256, 512, 1024], depthwise=False, act="silu"):
|
||||
super().__init__()
|
||||
self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
|
||||
self.in_features = in_features
|
||||
|
||||
def forward(self, input):
|
||||
out_features = self.backbone.forward(input)
|
||||
[feat1, feat2, feat3] = [out_features[f] for f in self.in_features]
|
||||
return [feat1, feat2, feat3]
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
# Standard bottleneck
|
||||
def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu", ):
|
||||
super().__init__()
|
||||
hidden_channels = int(out_channels * expansion)
|
||||
Conv = BaseConv # if depthwise else BaseConv
|
||||
# --------------------------------------------------#
|
||||
# 利用1x1卷积进行通道数的缩减。缩减率一般是50%
|
||||
# --------------------------------------------------#
|
||||
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
|
||||
# --------------------------------------------------#
|
||||
# 利用3x3卷积进行通道数的拓张。并且完成特征提取
|
||||
# --------------------------------------------------#
|
||||
self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
|
||||
# self.conv2=nn.Identity()
|
||||
self.use_add = shortcut and in_channels == out_channels
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv2(self.conv1(x))
|
||||
if self.use_add:
|
||||
y = y + x
|
||||
return y
|
||||
|
||||
|
||||
class FusionLayer(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, expansion=0.5, depthwise=False, act="silu", ):
|
||||
# ch_in, ch_out, number, shortcut, groups, expansion
|
||||
super().__init__()
|
||||
hidden_channels = int(out_channels * expansion)
|
||||
n = 1
|
||||
# --------------------------------------------------#
|
||||
# 主干部分的初次卷积
|
||||
# --------------------------------------------------#
|
||||
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
|
||||
# --------------------------------------------------#
|
||||
# 大的残差边部分的初次卷积
|
||||
# --------------------------------------------------#
|
||||
self.conv2 = BaseConv(hidden_channels, hidden_channels, 1, stride=1, act=act) # in_channel
|
||||
# -----------------------------------------------#
|
||||
# 对堆叠的结果进行卷积的处理
|
||||
# self.deepfeature=nn.Sequential(BaseConv(hidden_channels, hidden_channels//2, 1, stride=1, act=act),
|
||||
# BaseConv(hidden_channels//2, hidden_channels, 3, stride=1, act=act))
|
||||
# -----------------------------------------------#
|
||||
# module_list = [Bottleneck(hidden_channels, hidden_channels, True, 1.0, depthwise, act=act) for _ in range(n)]
|
||||
# self.deepfeature = nn.Sequential(*module_list)
|
||||
self.conv3 = BaseConv(hidden_channels, out_channels, 1, stride=1, act=act) # 2*hidden_channel
|
||||
|
||||
# --------------------------------------------------#
|
||||
# 根据循环的次数构建上述Bottleneck残差结构
|
||||
# --------------------------------------------------#
|
||||
# module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)]
|
||||
# self.m = nn.Sequential(*module_list)
|
||||
|
||||
def forward(self, x):
|
||||
# -------------------------------#
|
||||
# x_1是主干部分
|
||||
# -------------------------------#
|
||||
# x_1 = self.conv1(x)
|
||||
x = self.conv1(x)
|
||||
# -------------------------------#
|
||||
# x_2是大的残差边部分
|
||||
# -------------------------------#
|
||||
# x_2 = self.conv2(x)
|
||||
x = self.conv2(x)
|
||||
# -----------------------------------------------#
|
||||
# 主干部分利用残差结构堆叠继续进行特征提取
|
||||
# -----------------------------------------------#
|
||||
# x_1 = self.deepfeature(x_1)
|
||||
# -----------------------------------------------#
|
||||
# 主干部分和大的残差边部分进行堆叠
|
||||
# -----------------------------------------------#
|
||||
# x = torch.cat((x_1, x_2), dim=1)
|
||||
# -----------------------------------------------#
|
||||
# 对堆叠的结果进行卷积的处理
|
||||
# -----------------------------------------------#
|
||||
return self.conv3(x)
|
||||
|
||||
|
||||
class Feature_Fusion(nn.Module):
|
||||
def __init__(self, in_channels=[128, 256, 512], depthwise=False, act="silu"):
|
||||
super().__init__()
|
||||
Conv = DWConv if depthwise else BaseConv
|
||||
|
||||
self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
|
||||
|
||||
# -------------------------------------------#
|
||||
# 20, 20, 1024 -> 20, 20, 512
|
||||
# -------------------------------------------#
|
||||
# self.lateral_conv0 = BaseConv(2 * int(in_channels[2]), int(in_channels[1]), 1, 1, act=act)
|
||||
self.lateral_conv0 = BaseConv(in_channels[1] + in_channels[2], in_channels[1], 1, 1, act=act)
|
||||
|
||||
# -------------------------------------------#
|
||||
# 40, 40, 1024 -> 40, 40, 512
|
||||
# -------------------------------------------#
|
||||
self.C3_p4 = FusionLayer(
|
||||
int(2 * in_channels[1]),
|
||||
int(in_channels[1]),
|
||||
depthwise=depthwise,
|
||||
act=act,
|
||||
)
|
||||
|
||||
# -------------------------------------------#
|
||||
# 40, 40, 512 -> 40, 40, 256
|
||||
# -------------------------------------------#
|
||||
# self.reduce_conv1 = BaseConv(int(2 * in_channels[1]), int(in_channels[0]), 1, 1, act=act)
|
||||
self.reduce_conv1 = BaseConv(int(in_channels[0] + in_channels[1]), int(in_channels[0]), 1, 1, act=act)
|
||||
# -------------------------------------------#
|
||||
# 80, 80, 512 -> 80, 80, 256
|
||||
# -------------------------------------------#
|
||||
self.C3_p3 = FusionLayer(
|
||||
int(2 * in_channels[0]),
|
||||
int(in_channels[0]),
|
||||
depthwise=depthwise,
|
||||
act=act,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
out_features = input # self.backbone.forward(input)
|
||||
[feat1, feat2, feat3] = out_features # [out_features[f] for f in self.in_features]
|
||||
|
||||
# -------------------------------------------#
|
||||
# 20, 20, 1024 -> 20, 20, 512
|
||||
# -------------------------------------------#
|
||||
# P5 = self.lateral_conv0(feat3)
|
||||
# -------------------------------------------#
|
||||
# 20, 20, 512 -> 40, 40, 512
|
||||
# -------------------------------------------#
|
||||
P5_upsample = self.upsample(feat3)
|
||||
# -------------------------------------------#
|
||||
# 40, 40, 512 + 40, 40, 512 -> 40, 40, 1024
|
||||
# -------------------------------------------#
|
||||
P5_upsample = torch.cat([P5_upsample, feat2], 1)
|
||||
# pdb.set_trace()
|
||||
# -------------------------------------------#
|
||||
# 40, 40, 1024 -> 40, 40, 512
|
||||
# -------------------------------------------#
|
||||
P4 = self.lateral_conv0(P5_upsample)
|
||||
# P5_upsample = self.C3_p4(P5_upsample)
|
||||
|
||||
# -------------------------------------------#
|
||||
# 40, 40, 512 -> 40, 40, 256
|
||||
# -------------------------------------------#
|
||||
# P4 = self.reduce_conv1(P5_upsample)
|
||||
# -------------------------------------------#
|
||||
# 40, 40, 256 -> 80, 80, 256
|
||||
# -------------------------------------------#
|
||||
P4_upsample = self.upsample(P4)
|
||||
# -------------------------------------------#
|
||||
# 80, 80, 256 + 80, 80, 256 -> 80, 80, 512
|
||||
# -------------------------------------------#
|
||||
P4_upsample = torch.cat([P4_upsample, feat1], 1)
|
||||
# -------------------------------------------#
|
||||
# 80, 80, 512 -> 80, 80, 256
|
||||
# -------------------------------------------#
|
||||
P3_out = self.reduce_conv1(P4_upsample)
|
||||
# P3_out = self.C3_p3(P4_upsample)
|
||||
|
||||
return P3_out
|
||||
|
||||
|
||||
class YOLOXHead(nn.Module):
|
||||
def __init__(self, num_classes, width=1.0, in_channels=[16, 32, 64], act="silu"):
|
||||
super().__init__()
|
||||
Conv = BaseConv
|
||||
|
||||
self.cls_convs = nn.ModuleList()
|
||||
self.reg_convs = nn.ModuleList()
|
||||
self.cls_preds = nn.ModuleList()
|
||||
self.reg_preds = nn.ModuleList()
|
||||
self.obj_preds = nn.ModuleList()
|
||||
self.stems = nn.ModuleList()
|
||||
|
||||
for i in range(len(in_channels)):
|
||||
self.stems.append(BaseConv(in_channels=int(in_channels[i]), out_channels=int(256 * width), ksize=1, stride=1, act=act)) # 128-> 256 通道整合
|
||||
self.cls_convs.append(nn.Sequential(*[
|
||||
Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act),
|
||||
Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act),
|
||||
]))
|
||||
self.cls_preds.append(
|
||||
nn.Conv2d(in_channels=int(256 * width), out_channels=num_classes, kernel_size=1, stride=1, padding=0)
|
||||
)
|
||||
|
||||
self.reg_convs.append(nn.Sequential(*[
|
||||
Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act),
|
||||
Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act)
|
||||
]))
|
||||
self.reg_preds.append(
|
||||
nn.Conv2d(in_channels=int(256 * width), out_channels=4, kernel_size=1, stride=1, padding=0)
|
||||
)
|
||||
self.obj_preds.append(
|
||||
nn.Conv2d(in_channels=int(256 * width), out_channels=1, kernel_size=1, stride=1, padding=0)
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
# ---------------------------------------------------#
|
||||
# inputs输入
|
||||
# P3_out 80, 80, 256
|
||||
# P4_out 40, 40, 512
|
||||
# P5_out 20, 20, 1024
|
||||
# ---------------------------------------------------#
|
||||
outputs = []
|
||||
for k, x in enumerate(inputs):
|
||||
# ---------------------------------------------------#
|
||||
# 利用1x1卷积进行通道整合
|
||||
# ---------------------------------------------------#
|
||||
x = self.stems[k](x)
|
||||
# ---------------------------------------------------#
|
||||
# 利用两个卷积标准化激活函数来进行特征提取
|
||||
# ---------------------------------------------------#
|
||||
cls_feat = self.cls_convs[k](x)
|
||||
# ---------------------------------------------------#
|
||||
# 判断特征点所属的种类
|
||||
# 80, 80, num_classes
|
||||
# 40, 40, num_classes
|
||||
# 20, 20, num_classes
|
||||
# ---------------------------------------------------#
|
||||
cls_output = self.cls_preds[k](cls_feat)
|
||||
|
||||
# ---------------------------------------------------#
|
||||
# 利用两个卷积标准化激活函数来进行特征提取
|
||||
# ---------------------------------------------------#
|
||||
reg_feat = self.reg_convs[k](x)
|
||||
# ---------------------------------------------------#
|
||||
# 特征点的回归系数
|
||||
# reg_pred 80, 80, 4
|
||||
# reg_pred 40, 40, 4
|
||||
# reg_pred 20, 20, 4
|
||||
# ---------------------------------------------------#
|
||||
reg_output = self.reg_preds[k](reg_feat)
|
||||
# ---------------------------------------------------#
|
||||
# 判断特征点是否有对应的物体
|
||||
# obj_pred 80, 80, 1
|
||||
# obj_pred 40, 40, 1
|
||||
# obj_pred 20, 20, 1
|
||||
# ---------------------------------------------------#
|
||||
obj_output = self.obj_preds[k](reg_feat)
|
||||
|
||||
output = torch.cat([reg_output, obj_output, cls_output], 1)
|
||||
outputs.append(output)
|
||||
return outputs
|
||||
|
||||
|
||||
model_config = {
|
||||
|
||||
'backbone_2d': 'yolo_free_nano',
|
||||
'pretrained_2d': True,
|
||||
'stride': [8, 16, 32],
|
||||
# ## 3D
|
||||
'backbone_3d': 'shufflenetv2',
|
||||
'model_size': '1.0x', # 1.0x
|
||||
'pretrained_3d': True,
|
||||
'memory_momentum': 0.9,
|
||||
'head_dim': 128, # 64
|
||||
'head_norm': 'BN',
|
||||
'head_act': 'lrelu',
|
||||
'num_cls_heads': 2,
|
||||
'num_reg_heads': 2,
|
||||
'head_depthwise': True,
|
||||
|
||||
}
|
||||
|
||||
|
||||
def build_backbone_3d(cfg, pretrained=False):
|
||||
backbone = Backbone3D(cfg, pretrained)
|
||||
return backbone, backbone.feat_dim
|
||||
|
||||
|
||||
mcfg = model_config
|
||||
|
||||
|
||||
class TDCNetwork(nn.Module):
|
||||
def __init__(self, num_classes, fp16=False, num_frame=5):
|
||||
super(TDCNetwork, self).__init__()
|
||||
self.num_frame = num_frame
|
||||
self.backbone2d = Feature_Backbone(0.33, 0.50)
|
||||
self.backbone3d, bk_dim_3d = build_backbone_3d(mcfg, pretrained=mcfg['pretrained_3d'] and True)
|
||||
self.backbonetd = BackboneTD(mcfg, pretrained=mcfg['pretrained_3d'] and True)
|
||||
self.q_sa1 = SelfAttention(128, window_size=(2, 8, 8), num_heads=4, use_shift=True, mlp_ratio=1.5)
|
||||
self.k_sa1 = SelfAttention(128, window_size=(2, 8, 8), num_heads=4, use_shift=True, mlp_ratio=1.5)
|
||||
self.v_sa1 = SelfAttention(128, window_size=(2, 8, 8), num_heads=4, use_shift=True, mlp_ratio=1.5)
|
||||
self.q_sa2 = SelfAttention(256, window_size=(2, 4, 4), num_heads=4, use_shift=True, mlp_ratio=1.5)
|
||||
self.k_sa2 = SelfAttention(256, window_size=(2, 4, 4), num_heads=4, use_shift=True, mlp_ratio=1.5)
|
||||
self.v_sa2 = SelfAttention(256, window_size=(2, 4, 4), num_heads=4, use_shift=True, mlp_ratio=1.5)
|
||||
self.q_sa3 = SelfAttention(512, window_size=(2, 2, 2), num_heads=4, use_shift=True, mlp_ratio=1.5)
|
||||
self.k_sa3 = SelfAttention(512, window_size=(2, 2, 2), num_heads=4, use_shift=True, mlp_ratio=1.5)
|
||||
self.v_sa3 = SelfAttention(512, window_size=(2, 2, 2), num_heads=4, use_shift=True, mlp_ratio=1.5)
|
||||
self.ca1 = CrossAttention(128, window_size=(2, 8, 8), num_heads=4)
|
||||
self.ca2 = CrossAttention(256, window_size=(2, 4, 4), num_heads=4)
|
||||
self.ca3 = CrossAttention(512, window_size=(2, 2, 2), num_heads=4)
|
||||
self.feature_fusion = Feature_Fusion()
|
||||
self.head = YOLOXHead(num_classes=num_classes, width=1.0, in_channels=[128], act="silu")
|
||||
|
||||
def forward(self, inputs):
|
||||
# inputs: [B, 3, T, H, W]
|
||||
if len(inputs.shape) == 5:
|
||||
T = inputs.shape[2]
|
||||
diff_imgs = inputs[:, :, :T // 2, :, :]
|
||||
mt_imgs = inputs[:, :, T // 2:, :, :]
|
||||
else:
|
||||
diff_imgs = inputs
|
||||
mt_imgs = inputs
|
||||
q_3d = self.backbonetd(diff_imgs)
|
||||
q_3d1, q_3d2, q_3d3 = q_3d['stage2'], q_3d['stage3'], q_3d['stage4']
|
||||
k_3d = self.backbone3d(mt_imgs)
|
||||
k_3d1, k_3d2, k_3d3 = k_3d['stage2'], k_3d['stage3'], k_3d['stage4']
|
||||
[feat1, feat2, feat3] = self.backbone2d(inputs[:, :, -1, :, :])
|
||||
|
||||
def to_5d(x):
|
||||
# [B, C, T, H, W] -> [B, T, H, W, C]
|
||||
return x.permute(0, 2, 3, 4, 1)
|
||||
|
||||
q_3d1 = to_5d(q_3d1)
|
||||
q_3d2 = to_5d(q_3d2)
|
||||
q_3d3 = to_5d(q_3d3)
|
||||
k_3d1 = to_5d(k_3d1)
|
||||
k_3d2 = to_5d(k_3d2)
|
||||
k_3d3 = to_5d(k_3d3)
|
||||
|
||||
# V特征扩展T维度,与Q/K对齐(假设V为最后一帧,T=1)
|
||||
def expand_v(x, T):
|
||||
# [B, C, H, W] -> [B, T, H, W, C],复制T次
|
||||
x = x.permute(0, 2, 3, 1).unsqueeze(1)
|
||||
x = x.expand(-1, T, -1, -1, -1)
|
||||
return x
|
||||
|
||||
T1 = q_3d1.shape[1]
|
||||
T2 = q_3d2.shape[1]
|
||||
T3 = q_3d3.shape[1]
|
||||
v1 = expand_v(feat1, T1)
|
||||
v2 = expand_v(feat2, T2)
|
||||
v3 = expand_v(feat3, T3)
|
||||
|
||||
q1 = self.q_sa1(q_3d1)
|
||||
k1 = self.k_sa1(k_3d1)
|
||||
v1 = self.v_sa1(v1)
|
||||
q2 = self.q_sa2(q_3d2)
|
||||
k2 = self.k_sa2(k_3d2)
|
||||
v2 = self.v_sa2(v2)
|
||||
q3 = self.q_sa3(q_3d3)
|
||||
k3 = self.k_sa3(k_3d3)
|
||||
v3 = self.v_sa3(v3)
|
||||
out1 = self.ca1(q1, k1, v1)
|
||||
out2 = self.ca2(q2, k2, v2)
|
||||
out3 = self.ca3(q3, k3, v3)
|
||||
out1 = out1.mean(1).permute(0, 3, 1, 2)
|
||||
out2 = out2.mean(1).permute(0, 3, 1, 2)
|
||||
out3 = out3.mean(1).permute(0, 3, 1, 2)
|
||||
|
||||
feat_all = self.feature_fusion([out1, out2, out3])
|
||||
outputs = self.head([feat_all])
|
||||
|
||||
return outputs
|
||||
131
model/TDCNet/TDCR.py
Normal file
131
model/TDCNet/TDCR.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class TDC(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=(5, 3, 3), stride=1, padding=(2, 1, 1), groups=1, bias=False, step=1):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=bias)
|
||||
self.step = step
|
||||
self.groups = groups
|
||||
|
||||
def get_time_gradient_weight(self):
|
||||
weight = self.conv.weight
|
||||
kT, kH, kW = weight.shape[2:]
|
||||
grad_weight = torch.zeros_like(weight, device=weight.device, dtype=weight.dtype)
|
||||
if kT == 5:
|
||||
if self.step == -1:
|
||||
grad_weight[:, :, :, :, :] = -weight[:, :, :, :, :]
|
||||
grad_weight[:, :, 4, :, :] = weight[:, :, 0, :, :] + weight[:, :, 1, :, :] + weight[:, :, 2, :, :] + weight[:, :, 3, :, :] + weight[:, :, 4, :, :]
|
||||
elif self.step == 1:
|
||||
grad_weight[:, :, 4, :, :] = weight[:, :, 4, :, :]
|
||||
grad_weight[:, :, 3, :, :] = weight[:, :, 3, :, :] - weight[:, :, 4, :, :]
|
||||
grad_weight[:, :, 2, :, :] = weight[:, :, 2, :, :] - weight[:, :, 3, :, :]
|
||||
grad_weight[:, :, 1, :, :] = weight[:, :, 1, :, :] - weight[:, :, 2, :, :]
|
||||
grad_weight[:, :, 0, :, :] = -weight[:, :, 1, :, :]
|
||||
elif self.step == 2:
|
||||
grad_weight[:, :, 4, :, :] = weight[:, :, 4, :, :]
|
||||
grad_weight[:, :, 3, :, :] = weight[:, :, 3, :, :]
|
||||
grad_weight[:, :, 2, :, :] = weight[:, :, 2, :, :] - weight[:, :, 4, :, :]
|
||||
grad_weight[:, :, 1, :, :] = -weight[:, :, 3, :, :]
|
||||
grad_weight[:, :, 0, :, :] = -weight[:, :, 2, :, :]
|
||||
else:
|
||||
grad_weight = weight
|
||||
bias = self.conv.bias
|
||||
if bias is None:
|
||||
bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)
|
||||
return grad_weight, bias
|
||||
|
||||
def forward(self, x):
|
||||
weight, bias = self.get_time_gradient_weight()
|
||||
x_diff = F.conv3d(x, weight, bias, stride=self.conv.stride, groups=self.groups, padding=self.conv.padding)
|
||||
return x_diff
|
||||
|
||||
|
||||
class RepConv3D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=(5, 3, 3), stride=1, padding=(2, 1, 1), groups=1, deploy=False):
|
||||
super(RepConv3D, self).__init__()
|
||||
self.deploy = deploy
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.stride = stride
|
||||
self.groups = groups
|
||||
if self.deploy:
|
||||
self.conv_reparam = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=True)
|
||||
else:
|
||||
self.l_tdc = nn.Sequential(
|
||||
TDC(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False, step=-1),
|
||||
nn.BatchNorm3d(out_channels)
|
||||
)
|
||||
self.s_tdc = nn.Sequential(
|
||||
TDC(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False, step=1),
|
||||
nn.BatchNorm3d(out_channels)
|
||||
)
|
||||
self.m_tdc = nn.Sequential(
|
||||
TDC(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False, step=2),
|
||||
nn.BatchNorm3d(out_channels)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.deploy:
|
||||
out = F.relu(self.conv_reparam(x))
|
||||
else:
|
||||
out = self.s_tdc(x) + self.m_tdc(x) + self.l_tdc(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
def get_equivalent_kernel_bias(self):
|
||||
kernel_s_tdc, bias_s_tdc = self._fuse_conv_bn(self.s_tdc)
|
||||
kernel_m_tdc, bias_m_tdc = self._fuse_conv_bn(self.m_tdc)
|
||||
kernel_l_tdc, bias_l_tdc = self._fuse_conv_bn(self.l_tdc)
|
||||
kernel = kernel_s_tdc + kernel_m_tdc + kernel_l_tdc
|
||||
bias = bias_s_tdc + bias_m_tdc + bias_l_tdc
|
||||
return kernel, bias
|
||||
|
||||
def switch_to_deploy(self):
|
||||
if self.deploy:
|
||||
return
|
||||
kernel, bias = self.get_equivalent_kernel_bias()
|
||||
self.conv_reparam = nn.Conv3d(
|
||||
self.in_channels, self.out_channels, (5, 3, 3), self.stride,
|
||||
(2, 1, 1), groups=self.groups, bias=True
|
||||
)
|
||||
self.conv_reparam.weight.data = kernel
|
||||
self.conv_reparam.bias.data = bias
|
||||
self.deploy = True
|
||||
del self.s_tdc
|
||||
del self.m_tdc
|
||||
del self.l_tdc
|
||||
|
||||
@staticmethod
|
||||
def _fuse_conv_bn(branch):
|
||||
if branch is None:
|
||||
return 0, 0
|
||||
|
||||
def find_conv(module):
|
||||
if isinstance(module, nn.Conv3d):
|
||||
return module
|
||||
for child in module.children():
|
||||
conv = find_conv(child)
|
||||
if conv is not None:
|
||||
return conv
|
||||
return None
|
||||
|
||||
conv = find_conv(branch[0])
|
||||
bn = branch[1]
|
||||
if hasattr(branch[0], 'get_time_gradient_weight'):
|
||||
w, bias = branch[0].get_time_gradient_weight()
|
||||
else:
|
||||
w = conv.weight
|
||||
if conv.bias is not None:
|
||||
bias = conv.bias
|
||||
else:
|
||||
bias = torch.zeros_like(bn.running_mean)
|
||||
mean = bn.running_mean
|
||||
var_sqrt = torch.sqrt(bn.running_var + bn.eps)
|
||||
gamma = bn.weight
|
||||
beta = bn.bias
|
||||
w = w * (gamma / var_sqrt).reshape(-1, 1, 1, 1, 1)
|
||||
bias = (bias - mean) / var_sqrt * gamma + beta
|
||||
return w, bias
|
||||
239
model/TDCNet/TDCSTA.py
Normal file
239
model/TDCNet/TDCSTA.py
Normal file
@@ -0,0 +1,239 @@
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class WindowAttention3D(nn.Module):
|
||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # (T, H, W)
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads))
|
||||
coords_t = torch.arange(self.window_size[0])
|
||||
coords_h = torch.arange(self.window_size[1])
|
||||
coords_w = torch.arange(self.window_size[2])
|
||||
coords = torch.stack(torch.meshgrid(coords_t, coords_h, coords_w, indexing='ij')) # 3, T, H, W
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 2] += self.window_size[2] - 1
|
||||
relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
|
||||
relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
|
||||
relative_position_index = relative_coords.sum(-1)
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def forward(self, x, k=None, v=None, mask=None):
|
||||
B_, N, C = x.shape
|
||||
if k is None or v is None:
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
else:
|
||||
q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # [B_, num_heads, N, head_dim]
|
||||
k = k.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
v = v.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(N, N, -1)
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
attn = self.attn_drop(attn)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
B, T, H, W, C = x.shape
|
||||
window_size = list(window_size)
|
||||
if T < window_size[0]:
|
||||
window_size[0] = T
|
||||
if H < window_size[1]:
|
||||
window_size[1] = H
|
||||
if W < window_size[2]:
|
||||
window_size[2] = W
|
||||
x = x.view(B, T // window_size[0] if window_size[0] > 0 else 1, window_size[0],
|
||||
H // window_size[1] if window_size[1] > 0 else 1, window_size[1],
|
||||
W // window_size[2] if window_size[2] > 0 else 1, window_size[2], C)
|
||||
windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, B, T, H, W):
|
||||
x = windows.view(B, T // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1)
|
||||
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, T, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
def get_window_size(x_size, window_size, shift_size=None):
|
||||
use_window_size = list(window_size)
|
||||
if shift_size is not None:
|
||||
use_shift_size = list(shift_size)
|
||||
for i in range(len(x_size)):
|
||||
if x_size[i] <= window_size[i]:
|
||||
use_window_size[i] = x_size[i]
|
||||
if shift_size is not None:
|
||||
use_shift_size[i] = 0
|
||||
if shift_size is None:
|
||||
return tuple(use_window_size)
|
||||
else:
|
||||
return tuple(use_window_size), tuple(use_shift_size)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim, window_size=(2, 8, 8), num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., use_shift=False, shift_size=None, mlp_ratio=2.0, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size
|
||||
self.num_heads = num_heads
|
||||
self.use_shift = use_shift
|
||||
self.shift_size = shift_size if shift_size is not None else tuple([w // 2 for w in window_size]) if use_shift else tuple([0] * len(window_size))
|
||||
self.attn1 = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)
|
||||
self.attn2 = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.norm3 = norm_layer(dim)
|
||||
self.norm4 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(dim, mlp_hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(mlp_hidden_dim, dim)
|
||||
)
|
||||
self.mlp2 = nn.Sequential(
|
||||
nn.Linear(dim, mlp_hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(mlp_hidden_dim, dim)
|
||||
)
|
||||
|
||||
def create_mask(self, x_shape, device):
|
||||
B, T, H, W, C = x_shape
|
||||
img_mask = torch.zeros((1, T, H, W, 1), device=device)
|
||||
cnt = 0
|
||||
t_slices = (slice(0, -self.window_size[0]), slice(-self.window_size[0], -self.shift_size[0]), slice(-self.shift_size[0], None))
|
||||
h_slices = (slice(0, -self.window_size[1]), slice(-self.window_size[1], -self.shift_size[1]), slice(-self.shift_size[1], None))
|
||||
w_slices = (slice(0, -self.window_size[2]), slice(-self.window_size[2], -self.shift_size[2]), slice(-self.shift_size[2], None))
|
||||
for t in t_slices:
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, t, h, w, :] = cnt
|
||||
cnt += 1
|
||||
mask_windows = window_partition(img_mask, self.window_size)
|
||||
mask_windows = mask_windows.squeeze(-1)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
return attn_mask
|
||||
|
||||
def forward(self, x):
|
||||
B, T, H, W, C = x.shape
|
||||
window_size, shift_size = get_window_size((T, H, W), self.window_size, self.shift_size)
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
pad_t = (window_size[0] - T % window_size[0]) % window_size[0]
|
||||
pad_h = (window_size[1] - H % window_size[1]) % window_size[1]
|
||||
pad_w = (window_size[2] - W % window_size[2]) % window_size[2]
|
||||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
|
||||
shortcut = F.pad(shortcut, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
|
||||
_, Tp, Hp, Wp, _ = x.shape
|
||||
x_windows = window_partition(x, window_size)
|
||||
attn_windows = self.attn1(x_windows, mask=None)
|
||||
attn_windows = attn_windows.view(-1, *(window_size + (C,)))
|
||||
x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp)
|
||||
x = shortcut + x
|
||||
x = x + self.mlp1(self.norm2(x))
|
||||
shortcut = x
|
||||
x = self.norm3(x)
|
||||
if self.use_shift and any(i > 0 for i in shift_size):
|
||||
shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
|
||||
attn_mask = self.create_mask((B, Tp, Hp, Wp, C), x.device)
|
||||
x_windows = window_partition(shifted_x, window_size)
|
||||
attn_windows = self.attn2(x_windows, mask=attn_mask)
|
||||
attn_windows = attn_windows.view(-1, *(window_size + (C,)))
|
||||
shifted_x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp)
|
||||
x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
|
||||
if pad_t > 0:
|
||||
x = x[:, :T, :, :, :]
|
||||
shortcut = shortcut[:, :T, :, :, :]
|
||||
if pad_h > 0:
|
||||
x = x[:, :, :H, :, :]
|
||||
shortcut = shortcut[:, :, :H, :, :]
|
||||
if pad_w > 0:
|
||||
x = x[:, :, :, :W, :]
|
||||
shortcut = shortcut[:, :, :, :W, :]
|
||||
|
||||
x = shortcut + x
|
||||
x = x + self.mlp2(self.norm4(x))
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, dim, window_size=(2, 8, 8), num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., mlp_ratio=2.0, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size
|
||||
self.num_heads = num_heads
|
||||
self.norm1_q = norm_layer(dim)
|
||||
self.norm1_k = norm_layer(dim)
|
||||
self.norm1_v = norm_layer(dim)
|
||||
self.attn = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(dim, mlp_hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(mlp_hidden_dim, dim)
|
||||
)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
B, T, H, W, C = q.shape
|
||||
window_size = get_window_size((T, H, W), self.window_size)
|
||||
shortcut = v
|
||||
q = self.norm1_q(q)
|
||||
k = self.norm1_k(k)
|
||||
v = self.norm1_v(v)
|
||||
pad_t = (window_size[0] - T % window_size[0]) % window_size[0]
|
||||
pad_h = (window_size[1] - H % window_size[1]) % window_size[1]
|
||||
pad_w = (window_size[2] - W % window_size[2]) % window_size[2]
|
||||
q = F.pad(q, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
|
||||
k = F.pad(k, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
|
||||
v = F.pad(v, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
|
||||
_, Tp, Hp, Wp, _ = q.shape
|
||||
|
||||
q_windows = window_partition(q, window_size)
|
||||
k_windows = window_partition(k, window_size)
|
||||
v_windows = window_partition(v, window_size)
|
||||
attn_windows = self.attn(q_windows, k_windows, v_windows)
|
||||
attn_windows = attn_windows.view(-1, *(window_size + (C,)))
|
||||
shifted_x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp)
|
||||
x = shifted_x
|
||||
if pad_t > 0:
|
||||
x = x[:, :T, :, :, :]
|
||||
if pad_h > 0:
|
||||
x = x[:, :, :H, :, :]
|
||||
if pad_w > 0:
|
||||
x = x[:, :, :, :W, :]
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
0
model/TDCNet/__init__.py
Normal file
0
model/TDCNet/__init__.py
Normal file
272
model/TDCNet/backbone3d.py
Normal file
272
model/TDCNet/backbone3d.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.hub import load_state_dict_from_url
|
||||
|
||||
model_urls = {
|
||||
"0.25x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_0.25x_RGB_16_best.pth",
|
||||
"1.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.0x_RGB_16_best.pth",
|
||||
"1.5x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.5x_RGB_16_best.pth",
|
||||
"2.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_2.0x_RGB_16_best.pth",
|
||||
}
|
||||
|
||||
|
||||
def load_weight(model, arch):
|
||||
url = model_urls[arch]
|
||||
# check
|
||||
if url is None:
|
||||
print('No pretrained weight for 3D CNN: {}'.format(arch.upper()))
|
||||
return model
|
||||
|
||||
# checkpoint state dict
|
||||
checkpoint = load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
|
||||
|
||||
checkpoint_state_dict = checkpoint.pop('state_dict')
|
||||
|
||||
# model state dict
|
||||
model_state_dict = model.state_dict()
|
||||
# reformat checkpoint_state_dict:
|
||||
new_state_dict = {}
|
||||
for k in checkpoint_state_dict.keys():
|
||||
v = checkpoint_state_dict[k]
|
||||
new_state_dict[k[7:]] = v
|
||||
# pdb.set_trace()
|
||||
# check
|
||||
for k in list(new_state_dict.keys()):
|
||||
if k in model_state_dict:
|
||||
shape_model = tuple(model_state_dict[k].shape)
|
||||
shape_checkpoint = tuple(new_state_dict[k].shape)
|
||||
if shape_model != shape_checkpoint:
|
||||
new_state_dict.pop(k)
|
||||
else:
|
||||
new_state_dict.pop(k)
|
||||
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def conv_bn(inp, oup, stride):
|
||||
return nn.Sequential(
|
||||
nn.Conv3d(inp, oup, kernel_size=(5, 3, 3), stride=stride, padding=(2, 1, 1), bias=False),
|
||||
nn.BatchNorm3d(oup),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
oup_inc = oup // 2
|
||||
|
||||
if self.stride == 1:
|
||||
self.banch2 = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
nn.ReLU(inplace=True),
|
||||
# dw
|
||||
nn.Conv3d(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
# pw-linear
|
||||
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
else:
|
||||
self.banch1 = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv3d(inp, inp, (5, 3, 3), stride, (2, 1, 1), groups=inp, bias=False),
|
||||
nn.BatchNorm3d(inp),
|
||||
# pw-linear
|
||||
nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.banch2 = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
nn.ReLU(inplace=True),
|
||||
# dw
|
||||
nn.Conv3d(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
# pw-linear
|
||||
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _concat(x, out):
|
||||
# concatenate along channel axis
|
||||
return torch.cat((x, out), 1)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 1:
|
||||
x1 = x[:, :(x.shape[1] // 2), :, :, :]
|
||||
x2 = x[:, (x.shape[1] // 2):, :, :, :]
|
||||
out = self._concat(x1, self.banch2(x2))
|
||||
elif self.stride == 2:
|
||||
out = self._concat(self.banch1(x), self.banch2(x))
|
||||
|
||||
return channel_shuffle(out, 2)
|
||||
|
||||
|
||||
def channel_shuffle(x, groups):
|
||||
'''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]'''
|
||||
batchsize, num_channels, depth, height, width = x.data.size()
|
||||
channels_per_group = num_channels // groups
|
||||
# reshape
|
||||
x = x.view(batchsize, groups,
|
||||
channels_per_group, depth, height, width)
|
||||
# permute
|
||||
x = x.permute(0, 2, 1, 3, 4, 5).contiguous()
|
||||
# flatten
|
||||
x = x.view(batchsize, num_channels, depth, height, width)
|
||||
return x
|
||||
|
||||
|
||||
class ShuffleNetV2(nn.Module):
|
||||
def __init__(self, width_mult='1.0x', num_classes=600):
|
||||
super(ShuffleNetV2, self).__init__()
|
||||
|
||||
self.stage_repeats = [4, 8, 4]
|
||||
# index 0 is invalid and should never be called.
|
||||
# only used for indexing convenience.
|
||||
if width_mult == '0.25x':
|
||||
self.stage_out_channels = [-1, 24, 32, 64, 128]
|
||||
elif width_mult == '0.5x':
|
||||
self.stage_out_channels = [-1, 24, 48, 96, 192]
|
||||
elif width_mult == '1.0x':
|
||||
self.stage_out_channels = [-1, 24, 128, 256, 512]
|
||||
elif width_mult == '1.5x':
|
||||
self.stage_out_channels = [-1, 24, 176, 352, 704]
|
||||
elif width_mult == '2.0x':
|
||||
self.stage_out_channels = [-1, 24, 224, 488, 976]
|
||||
|
||||
# building first layer
|
||||
input_channel = self.stage_out_channels[1]
|
||||
self.conv1 = conv_bn(3, input_channel, stride=(1, 2, 2))
|
||||
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
|
||||
self.features = []
|
||||
self.features1 = []
|
||||
self.features2 = []
|
||||
self.features3 = []
|
||||
# building inverted residual blocks
|
||||
for idxstage in range(len(self.stage_repeats)):
|
||||
numrepeat = self.stage_repeats[idxstage]
|
||||
output_channel = self.stage_out_channels[idxstage + 2]
|
||||
for i in range(numrepeat):
|
||||
stride = 2 if i == 0 else 1
|
||||
self.features.append(InvertedResidual(input_channel, output_channel, stride))
|
||||
input_channel = output_channel
|
||||
self.features = nn.Sequential(*self.features)
|
||||
# for idxstage in range(len(self.stage_repeats)):
|
||||
# numrepeat = self.stage_repeats[idxstage]
|
||||
# output_channel = self.stage_out_channels[idxstage+2]
|
||||
# for i in range(numrepeat):
|
||||
# if idxstage==0:
|
||||
# stride = 2 if i == 0 else 1
|
||||
# self.features1.append(InvertedResidual(input_channel, output_channel, stride))
|
||||
# input_channel = output_channel
|
||||
# elif idxstage==1:
|
||||
# stride = 2 if i == 0 else 1
|
||||
# self.features2.append(InvertedResidual(input_channel, output_channel, stride))
|
||||
# input_channel = output_channel
|
||||
# elif idxstage==2:
|
||||
# stride = 2 if i == 0 else 1
|
||||
# self.features3.append(InvertedResidual(input_channel, output_channel, stride))
|
||||
# input_channel = output_channel
|
||||
# # make it nn.Sequential
|
||||
# self.features1 = nn.Sequential(*self.features1)
|
||||
# self.features2 = nn.Sequential(*self.features2)
|
||||
# self.features3 = nn.Sequential(*self.features3)
|
||||
|
||||
# # building last several layers
|
||||
# self.conv_last = conv_1x1x1_bn(input_channel, self.stage_out_channels[-1])
|
||||
# self.avgpool = nn.AvgPool3d((2, 1, 1), stride=1)
|
||||
|
||||
def forward(self, x):
|
||||
outputs = {}
|
||||
# pdb.set_trace() #(1,3,16,512,512) #(1,3,5,512,512)
|
||||
x = self.conv1(x) # (1,24,16,256,256) #(1,24,5,256,256)
|
||||
|
||||
x = self.maxpool(x) # (1,24,8,128,128) #(1,24,3,128,128)
|
||||
# outputs['stage1'] = x
|
||||
# x=self.features(x)
|
||||
x = self.features[:4](x) # (1,116,4,64,64) #(1,116,2,64,64)
|
||||
outputs['stage2'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
|
||||
x = self.features[4:12](x) # (1,232,2,32,32) #(1,232,1,32,32)
|
||||
outputs['stage3'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
|
||||
x = self.features[12:16](x) # (1,464,1,16,16) #(1,464,1,16,16)
|
||||
outputs['stage4'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
|
||||
# out = self.conv_last(out)
|
||||
|
||||
# if x.size(2) > 1:
|
||||
# x = torch.mean(x, dim=2, keepdim=True)
|
||||
|
||||
# return x.squeeze(2)
|
||||
return outputs
|
||||
|
||||
|
||||
def build_shufflenetv2_3d(model_size='0.25x', pretrained=False):
|
||||
model = ShuffleNetV2(model_size)
|
||||
feats = model.stage_out_channels[-1]
|
||||
|
||||
# if pretrained:
|
||||
# model = load_weight(model, model_size)
|
||||
|
||||
return model, feats
|
||||
|
||||
|
||||
def build_3d_cnn(cfg, pretrained=False):
|
||||
if 'resnet' in cfg['backbone_3d']:
|
||||
model, feat_dims = build_resnet_3d(
|
||||
model_name=cfg['backbone_3d'],
|
||||
pretrained=pretrained
|
||||
)
|
||||
elif 'resnext' in cfg['backbone_3d']:
|
||||
model, feat_dims = build_resnext_3d(
|
||||
model_name=cfg['backbone_3d'],
|
||||
pretrained=pretrained
|
||||
)
|
||||
elif 'shufflenetv2' in cfg['backbone_3d']:
|
||||
model, feat_dims = build_shufflenetv2_3d(
|
||||
model_size=cfg['model_size'],
|
||||
pretrained=pretrained
|
||||
)
|
||||
else:
|
||||
print('Unknown Backbone ...')
|
||||
exit()
|
||||
|
||||
return model, feat_dims
|
||||
|
||||
|
||||
class Backbone3D(nn.Module):
|
||||
def __init__(self, cfg, pretrained=False):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
# 3D CNN
|
||||
self.backbone, self.feat_dim = build_3d_cnn(cfg, pretrained)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Input:
|
||||
x: (Tensor) -> [B, C, T, H, W]
|
||||
Output:
|
||||
y: (List) -> [
|
||||
(Tensor) -> [B, C1, H1, W1],
|
||||
(Tensor) -> [B, C2, H2, W2],
|
||||
(Tensor) -> [B, C3, H3, W3]
|
||||
]
|
||||
"""
|
||||
feat = self.backbone(x)
|
||||
|
||||
return feat
|
||||
|
||||
281
model/TDCNet/backbonetd.py
Normal file
281
model/TDCNet/backbonetd.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from matplotlib import pyplot as plt
|
||||
from torch.hub import load_state_dict_from_url
|
||||
|
||||
from model.TDCNet.TDCR import RepConv3D
|
||||
|
||||
model_urls = {
|
||||
"0.25x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_0.25x_RGB_16_best.pth",
|
||||
"1.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.0x_RGB_16_best.pth",
|
||||
"1.5x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.5x_RGB_16_best.pth",
|
||||
"2.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_2.0x_RGB_16_best.pth",
|
||||
}
|
||||
|
||||
|
||||
def load_weight(model, arch):
|
||||
url = model_urls[arch]
|
||||
# check
|
||||
if url is None:
|
||||
print('No pretrained weight for 3D CNN: {}'.format(arch.upper()))
|
||||
return model
|
||||
|
||||
# checkpoint state dict
|
||||
checkpoint = load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
|
||||
|
||||
checkpoint_state_dict = checkpoint.pop('state_dict')
|
||||
|
||||
# model state dict
|
||||
model_state_dict = model.state_dict()
|
||||
# reformat checkpoint_state_dict:
|
||||
new_state_dict = {}
|
||||
for k in checkpoint_state_dict.keys():
|
||||
v = checkpoint_state_dict[k]
|
||||
new_state_dict[k[7:]] = v
|
||||
# pdb.set_trace()
|
||||
# check
|
||||
for k in list(new_state_dict.keys()):
|
||||
if k in model_state_dict:
|
||||
shape_model = tuple(model_state_dict[k].shape)
|
||||
shape_checkpoint = tuple(new_state_dict[k].shape)
|
||||
if shape_model != shape_checkpoint:
|
||||
new_state_dict.pop(k)
|
||||
else:
|
||||
new_state_dict.pop(k)
|
||||
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def conv_bn(inp, oup, stride):
|
||||
# return nn.Sequential(
|
||||
# nn.Conv3d(inp, oup, kernel_size=3, stride=stride, padding=(1,1,1), bias=False),
|
||||
# nn.BatchNorm3d(oup),
|
||||
# nn.ReLU(inplace=True)
|
||||
# )
|
||||
return RepConv3D(inp, oup, (5, 3, 3), stride, (2, 1, 1))
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
oup_inc = oup // 2
|
||||
|
||||
if self.stride == 1:
|
||||
self.banch2 = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
nn.ReLU(inplace=True),
|
||||
# dw
|
||||
# nn.Conv3d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False),
|
||||
# nn.BatchNorm3d(oup_inc),
|
||||
RepConv3D(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc),
|
||||
# pw-linear
|
||||
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
else:
|
||||
self.banch1 = nn.Sequential(
|
||||
# dw
|
||||
# nn.Conv3d(inp, inp, 3, stride, 1, groups=inp, bias=False),
|
||||
# nn.BatchNorm3d(inp),
|
||||
RepConv3D(inp, inp, (5, 3, 3), stride, (2, 1, 1), groups=inp, ),
|
||||
# pw-linear
|
||||
nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.banch2 = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
nn.ReLU(inplace=True),
|
||||
# dw
|
||||
# nn.Conv3d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False),
|
||||
# nn.BatchNorm3d(oup_inc),
|
||||
RepConv3D(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc, ),
|
||||
# pw-linear
|
||||
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm3d(oup_inc),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _concat(x, out):
|
||||
# concatenate along channel axis
|
||||
return torch.cat((x, out), 1)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 1:
|
||||
x1 = x[:, :(x.shape[1] // 2), :, :, :]
|
||||
x2 = x[:, (x.shape[1] // 2):, :, :, :]
|
||||
out = self._concat(x1, self.banch2(x2))
|
||||
elif self.stride == 2:
|
||||
out = self._concat(self.banch1(x), self.banch2(x))
|
||||
# return out
|
||||
return channel_shuffle(out, 2)
|
||||
#
|
||||
#
|
||||
def channel_shuffle(x, groups):
|
||||
'''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]'''
|
||||
batchsize, num_channels, depth, height, width = x.data.size()
|
||||
channels_per_group = num_channels // groups
|
||||
# reshape
|
||||
x = x.view(batchsize, groups,
|
||||
channels_per_group, depth, height, width)
|
||||
# permute
|
||||
x = x.permute(0, 2, 1, 3, 4, 5).contiguous()
|
||||
# flatten
|
||||
x = x.view(batchsize, num_channels, depth, height, width)
|
||||
return x
|
||||
|
||||
|
||||
class ShuffleNetV2(nn.Module):
|
||||
def __init__(self, width_mult='1.0x', num_classes=600):
|
||||
super(ShuffleNetV2, self).__init__()
|
||||
|
||||
self.stage_repeats = [4, 8, 4]
|
||||
# index 0 is invalid and should never be called.
|
||||
# only used for indexing convenience.
|
||||
if width_mult == '0.25x':
|
||||
self.stage_out_channels = [-1, 24, 32, 64, 128]
|
||||
elif width_mult == '0.5x':
|
||||
self.stage_out_channels = [-1, 24, 48, 96, 192]
|
||||
elif width_mult == '1.0x':
|
||||
# self.stage_out_channels = [-1, 24, 116, 232, 464]
|
||||
self.stage_out_channels = [-1, 24, 128, 256, 512]
|
||||
elif width_mult == '1.5x':
|
||||
self.stage_out_channels = [-1, 24, 176, 352, 704]
|
||||
elif width_mult == '2.0x':
|
||||
self.stage_out_channels = [-1, 24, 224, 488, 976]
|
||||
|
||||
# building first layer
|
||||
input_channel = self.stage_out_channels[1]
|
||||
self.conv1 = conv_bn(3, input_channel, stride=(1, 2, 2))
|
||||
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
|
||||
self.features = []
|
||||
self.features1 = []
|
||||
self.features2 = []
|
||||
self.features3 = []
|
||||
# building inverted residual blocks
|
||||
for idxstage in range(len(self.stage_repeats)):
|
||||
numrepeat = self.stage_repeats[idxstage]
|
||||
output_channel = self.stage_out_channels[idxstage + 2]
|
||||
for i in range(numrepeat):
|
||||
stride = 2 if i == 0 else 1
|
||||
self.features.append(InvertedResidual(input_channel, output_channel, stride))
|
||||
input_channel = output_channel
|
||||
self.features = nn.Sequential(*self.features)
|
||||
# for idxstage in range(len(self.stage_repeats)):
|
||||
# numrepeat = self.stage_repeats[idxstage]
|
||||
# output_channel = self.stage_out_channels[idxstage+2]
|
||||
# for i in range(numrepeat):
|
||||
# if idxstage==0:
|
||||
# stride = 2 if i == 0 else 1
|
||||
# self.features1.append(InvertedResidual(input_channel, output_channel, stride))
|
||||
# input_channel = output_channel
|
||||
# elif idxstage==1:
|
||||
# stride = 2 if i == 0 else 1
|
||||
# self.features2.append(InvertedResidual(input_channel, output_channel, stride))
|
||||
# input_channel = output_channel
|
||||
# elif idxstage==2:
|
||||
# stride = 2 if i == 0 else 1
|
||||
# self.features3.append(InvertedResidual(input_channel, output_channel, stride))
|
||||
# input_channel = output_channel
|
||||
# # make it nn.Sequential
|
||||
# self.features1 = nn.Sequential(*self.features1)
|
||||
# self.features2 = nn.Sequential(*self.features2)
|
||||
# self.features3 = nn.Sequential(*self.features3)
|
||||
|
||||
# # building last several layers
|
||||
# self.conv_last = conv_1x1x1_bn(input_channel, self.stage_out_channels[-1])
|
||||
# self.avgpool = nn.AvgPool3d((2, 1, 1), stride=1)
|
||||
|
||||
def forward(self, x):
|
||||
outputs = {}
|
||||
# pdb.set_trace() #(1,3,16,512,512) #(1,3,5,512,512)
|
||||
|
||||
x = self.conv1(x) # (1,24,16,256,256) #(1,24,5,256,256)
|
||||
|
||||
x = self.maxpool(x) # (1,24,8,128,128) #(1,24,3,128,128)
|
||||
# outputs['stage1'] = x
|
||||
# x = self.features(x)
|
||||
x = self.features[:4](x) # (1,116,4,64,64) #(1,116,2,64,64)
|
||||
outputs['stage2'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
|
||||
x = self.features[4:12](x) # (1,232,2,32,32) #(1,232,1,32,32)
|
||||
outputs['stage3'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
|
||||
x = self.features[12:16](x) # (1,464,1,16,16) #(1,464,1,16,16)
|
||||
outputs['stage4'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
|
||||
# out = self.conv_last(out)
|
||||
|
||||
# if x.size(2) > 1:
|
||||
# x = torch.mean(x, dim=2, keepdim=True)
|
||||
|
||||
# return x.squeeze(2)
|
||||
return outputs
|
||||
|
||||
|
||||
def build_shufflenetv2_3d(model_size='1.0x', pretrained=False):
|
||||
model = ShuffleNetV2(model_size)
|
||||
feats = model.stage_out_channels[-1]
|
||||
|
||||
# if pretrained:
|
||||
# model = load_weight(model, model_size)
|
||||
|
||||
return model, feats
|
||||
|
||||
|
||||
def build_3d_cnn(cfg, pretrained=False):
|
||||
if 'resnet' in cfg['backbone_3d']:
|
||||
model, feat_dims = build_resnet_3d(
|
||||
model_name=cfg['backbone_3d'],
|
||||
pretrained=pretrained
|
||||
)
|
||||
elif 'resnext' in cfg['backbone_3d']:
|
||||
model, feat_dims = build_resnext_3d(
|
||||
model_name=cfg['backbone_3d'],
|
||||
pretrained=pretrained
|
||||
)
|
||||
elif 'shufflenetv2' in cfg['backbone_3d']:
|
||||
model, feat_dims = build_shufflenetv2_3d(
|
||||
model_size=cfg['model_size'],
|
||||
pretrained=pretrained
|
||||
)
|
||||
else:
|
||||
print('Unknown Backbone ...')
|
||||
exit()
|
||||
|
||||
return model, feat_dims
|
||||
|
||||
|
||||
class BackboneTD(nn.Module):
|
||||
def __init__(self, cfg, pretrained=False):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
# 3D CNN
|
||||
self.backbone, self.feat_dim = build_3d_cnn(cfg, pretrained)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Input:
|
||||
x: (Tensor) -> [B, C, T, H, W]
|
||||
Output:
|
||||
y: (List) -> [
|
||||
(Tensor) -> [B, C1, H1, W1],
|
||||
(Tensor) -> [B, C2, H2, W2],
|
||||
(Tensor) -> [B, C3, H3, W3]
|
||||
]
|
||||
"""
|
||||
feat = self.backbone(x)
|
||||
|
||||
return feat
|
||||
234
model/TDCNet/darknet.py
Normal file
234
model/TDCNet/darknet.py
Normal file
@@ -0,0 +1,234 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Copyright (c) Megvii, Inc. and its affiliates.
|
||||
import os
|
||||
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from torch import nn
|
||||
|
||||
class SiLU(nn.Module):
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
def get_activation(name="silu", inplace=True):
|
||||
if name == "silu":
|
||||
module = SiLU()
|
||||
elif name == "relu":
|
||||
module = nn.ReLU(inplace=inplace)
|
||||
elif name == "lrelu":
|
||||
module = nn.LeakyReLU(0.1, inplace=inplace)
|
||||
elif name == "sigmoid":
|
||||
module = nn.Sigmoid()
|
||||
else:
|
||||
raise AttributeError("Unsupported act type: {}".format(name))
|
||||
return module
|
||||
|
||||
class Focus(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
|
||||
super().__init__()
|
||||
self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
|
||||
|
||||
def forward(self, x):
|
||||
patch_top_left = x[..., ::2, ::2]
|
||||
patch_bot_left = x[..., 1::2, ::2]
|
||||
patch_top_right = x[..., ::2, 1::2]
|
||||
patch_bot_right = x[..., 1::2, 1::2]
|
||||
x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,)
|
||||
return self.conv(x)
|
||||
|
||||
class BaseConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
|
||||
super().__init__()
|
||||
pad = (ksize - 1) // 2
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias)
|
||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
|
||||
self.act = get_activation(act, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.bn(self.conv(x)))
|
||||
|
||||
def fuseforward(self, x):
|
||||
return self.act(self.conv(x))
|
||||
|
||||
class DWConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
|
||||
super().__init__()
|
||||
self.dconv = BaseConv(in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act,)
|
||||
self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dconv(x)
|
||||
return self.pconv(x)
|
||||
|
||||
class SPPBottleneck(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
|
||||
super().__init__()
|
||||
hidden_channels = in_channels // 2
|
||||
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
|
||||
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes])
|
||||
conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
|
||||
self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = torch.cat([x] + [m(x) for m in self.m], dim=1)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
#--------------------------------------------------#
|
||||
# 残差结构的构建,小的残差结构
|
||||
#--------------------------------------------------#
|
||||
class Bottleneck(nn.Module):
|
||||
# Standard bottleneck
|
||||
def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
|
||||
super().__init__()
|
||||
hidden_channels = int(out_channels * expansion)
|
||||
Conv = DWConv if depthwise else BaseConv
|
||||
#--------------------------------------------------#
|
||||
# 利用1x1卷积进行通道数的缩减。缩减率一般是50%
|
||||
#--------------------------------------------------#
|
||||
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
|
||||
#--------------------------------------------------#
|
||||
# 利用3x3卷积进行通道数的拓张。并且完成特征提取
|
||||
#--------------------------------------------------#
|
||||
self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
|
||||
self.use_add = shortcut and in_channels == out_channels
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv2(self.conv1(x))
|
||||
if self.use_add:
|
||||
y = y + x
|
||||
return y
|
||||
|
||||
class CSPLayer(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
|
||||
# ch_in, ch_out, number, shortcut, groups, expansion
|
||||
super().__init__()
|
||||
hidden_channels = int(out_channels * expansion)
|
||||
#--------------------------------------------------#
|
||||
# 主干部分的初次卷积
|
||||
#--------------------------------------------------#
|
||||
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
|
||||
#--------------------------------------------------#
|
||||
# 大的残差边部分的初次卷积
|
||||
#--------------------------------------------------#
|
||||
self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
|
||||
#-----------------------------------------------#
|
||||
# 对堆叠的结果进行卷积的处理
|
||||
#-----------------------------------------------#
|
||||
self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
|
||||
|
||||
#--------------------------------------------------#
|
||||
# 根据循环的次数构建上述Bottleneck残差结构
|
||||
#--------------------------------------------------#
|
||||
module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)]
|
||||
self.m = nn.Sequential(*module_list)
|
||||
|
||||
def forward(self, x):
|
||||
#-------------------------------#
|
||||
# x_1是主干部分
|
||||
#-------------------------------#
|
||||
x_1 = self.conv1(x)
|
||||
#-------------------------------#
|
||||
# x_2是大的残差边部分
|
||||
#-------------------------------#
|
||||
x_2 = self.conv2(x)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 主干部分利用残差结构堆叠继续进行特征提取
|
||||
#-----------------------------------------------#
|
||||
x_1 = self.m(x_1)
|
||||
#-----------------------------------------------#
|
||||
# 主干部分和大的残差边部分进行堆叠
|
||||
#-----------------------------------------------#
|
||||
x = torch.cat((x_1, x_2), dim=1)
|
||||
#-----------------------------------------------#
|
||||
# 对堆叠的结果进行卷积的处理
|
||||
#-----------------------------------------------#
|
||||
return self.conv3(x)
|
||||
|
||||
class CSPDarknet(nn.Module):
|
||||
def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False, act="silu",):
|
||||
super().__init__()
|
||||
assert out_features, "please provide output features of Darknet"
|
||||
self.out_features = out_features
|
||||
Conv = DWConv if depthwise else BaseConv
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 输入图片是640, 640, 3
|
||||
# 初始的基本通道是64
|
||||
#-----------------------------------------------#
|
||||
base_channels = int(wid_mul * 64) # 64
|
||||
base_depth = max(round(dep_mul * 3), 1) # 3
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 利用focus网络结构进行特征提取
|
||||
# 640, 640, 3 -> 320, 320, 12 -> 320, 320, 64
|
||||
#-----------------------------------------------#
|
||||
self.stem = Focus(3, base_channels, ksize=3, act=act)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 完成卷积之后,320, 320, 64 -> 160, 160, 128
|
||||
# 完成CSPlayer之后,160, 160, 128 -> 160, 160, 128
|
||||
#-----------------------------------------------#
|
||||
self.dark2 = nn.Sequential(
|
||||
Conv(base_channels, base_channels * 2, 3, 2, act=act),
|
||||
CSPLayer(base_channels * 2, base_channels * 2, n=base_depth, depthwise=depthwise, act=act),
|
||||
)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 完成卷积之后,160, 160, 128 -> 80, 80, 256
|
||||
# 完成CSPlayer之后,80, 80, 256 -> 80, 80, 256
|
||||
#-----------------------------------------------#
|
||||
self.dark3 = nn.Sequential(
|
||||
Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
|
||||
CSPLayer(base_channels * 4, base_channels * 4, n=base_depth * 3, depthwise=depthwise, act=act),
|
||||
)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 完成卷积之后,80, 80, 256 -> 40, 40, 512
|
||||
# 完成CSPlayer之后,40, 40, 512 -> 40, 40, 512
|
||||
#-----------------------------------------------#
|
||||
self.dark4 = nn.Sequential(
|
||||
Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
|
||||
CSPLayer(base_channels * 8, base_channels * 8, n=base_depth * 3, depthwise=depthwise, act=act),
|
||||
)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 完成卷积之后,40, 40, 512 -> 20, 20, 1024
|
||||
# 完成SPP之后,20, 20, 1024 -> 20, 20, 1024
|
||||
# 完成CSPlayer之后,20, 20, 1024 -> 20, 20, 1024
|
||||
#-----------------------------------------------#
|
||||
self.dark5 = nn.Sequential(
|
||||
Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
|
||||
SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
|
||||
CSPLayer(base_channels * 16, base_channels * 16, n=base_depth, shortcut=False, depthwise=depthwise, act=act),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
outputs = {}
|
||||
x = self.stem(x)
|
||||
outputs["stem"] = x
|
||||
|
||||
|
||||
x = self.dark2(x)
|
||||
outputs["dark2"] = x
|
||||
|
||||
#-----------------------------------------------#
|
||||
# dark3的输出为80, 80, 256,是一个有效特征层
|
||||
#-----------------------------------------------#
|
||||
x = self.dark3(x)
|
||||
outputs["dark3"] = x
|
||||
#-----------------------------------------------#
|
||||
# dark4的输出为40, 40, 512,是一个有效特征层
|
||||
#-----------------------------------------------#
|
||||
x = self.dark4(x)
|
||||
outputs["dark4"] = x
|
||||
#-----------------------------------------------#
|
||||
# dark5的输出为20, 20, 1024,是一个有效特征层
|
||||
#-----------------------------------------------#
|
||||
x = self.dark5(x)
|
||||
outputs["dark5"] = x
|
||||
return {k: v for k, v in outputs.items() if k in self.out_features}
|
||||
507
model/nets/yolo_training.py
Normal file
507
model/nets/yolo_training.py
Normal file
@@ -0,0 +1,507 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Copyright (c) Megvii, Inc. and its affiliates.
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.ops.focal_loss import sigmoid_focal_loss
|
||||
|
||||
|
||||
class IOUloss(nn.Module):
|
||||
def __init__(self, reduction="none", loss_type="iou"):
|
||||
super(IOUloss, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.loss_type = loss_type
|
||||
|
||||
def forward(self, pred, target):
|
||||
assert pred.shape[0] == target.shape[0]
|
||||
|
||||
pred = pred.view(-1, 4)
|
||||
target = target.view(-1, 4)
|
||||
tl = torch.max(
|
||||
(pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
|
||||
)
|
||||
br = torch.min(
|
||||
(pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
|
||||
)
|
||||
|
||||
area_p = torch.prod(pred[:, 2:], 1)
|
||||
area_g = torch.prod(target[:, 2:], 1)
|
||||
|
||||
en = (tl < br).type(tl.type()).prod(dim=1)
|
||||
area_i = torch.prod(br - tl, 1) * en
|
||||
area_u = area_p + area_g - area_i
|
||||
iou = (area_i) / (area_u + 1e-16)
|
||||
|
||||
if self.loss_type == "iou":
|
||||
loss = 1 - iou ** 2
|
||||
elif self.loss_type == "giou":
|
||||
c_tl = torch.min(
|
||||
(pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
|
||||
)
|
||||
c_br = torch.max(
|
||||
(pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
|
||||
)
|
||||
area_c = torch.prod(c_br - c_tl, 1)
|
||||
giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
|
||||
loss = 1 - giou.clamp(min=-1.0, max=1.0)
|
||||
elif self.loss_type == 'ciou':
|
||||
b1_cxy = pred[:,:2]
|
||||
b2_cxy = target[:,:2]
|
||||
# 计算中心的差距
|
||||
center_distance = torch.sum(torch.pow((b1_cxy - b2_cxy), 2), axis=-1)
|
||||
# 找到包裹两个框的最小框的左上角和右下角
|
||||
enclose_mins = torch.min((pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2))
|
||||
enclose_maxes = torch.max((pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2))
|
||||
enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(br))
|
||||
# 计算对角线距离
|
||||
enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1)
|
||||
ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6)
|
||||
v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(pred[:, 2]/torch.clamp(pred[:, 3],min = 1e-6)) - torch.atan(target[:, 2]/torch.clamp(target[:, 3],min = 1e-6))), 2)
|
||||
alpha = v / torch.clamp((1.0 - iou + v),min=1e-6)
|
||||
ciou = ciou - alpha * v
|
||||
loss = 1 - ciou.clamp(min=-1.0, max=1.0)
|
||||
|
||||
if self.reduction == "mean":
|
||||
loss = loss.mean()
|
||||
elif self.reduction == "sum":
|
||||
loss = loss.sum()
|
||||
|
||||
return loss
|
||||
|
||||
class YOLOLoss(nn.Module):
|
||||
def __init__(self, num_classes, fp16, strides=[8, 16, 32]):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.strides = strides
|
||||
|
||||
self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
|
||||
self.iou_loss = IOUloss(reduction="none")
|
||||
self.grids = [torch.zeros(1)] * len(strides)
|
||||
self.fp16 = fp16
|
||||
|
||||
def forward(self, inputs, labels=None):
|
||||
outputs = []
|
||||
x_shifts = []
|
||||
y_shifts = []
|
||||
expanded_strides = []
|
||||
|
||||
#-----------------------------------------------#
|
||||
# inputs [[batch_size, num_classes + 5, 20, 20]
|
||||
# [batch_size, num_classes + 5, 40, 40]
|
||||
# [batch_size, num_classes + 5, 80, 80]]
|
||||
# outputs [[batch_size, 400, num_classes + 5]
|
||||
# [batch_size, 1600, num_classes + 5]
|
||||
# [batch_size, 6400, num_classes + 5]]
|
||||
# x_shifts [[batch_size, 400]
|
||||
# [batch_size, 1600]
|
||||
# [batch_size, 6400]]
|
||||
#-----------------------------------------------#
|
||||
for k, (stride, output) in enumerate(zip(self.strides, inputs)):
|
||||
output, grid = self.get_output_and_grid(output, k, stride)
|
||||
x_shifts.append(grid[:, :, 0])
|
||||
y_shifts.append(grid[:, :, 1])
|
||||
expanded_strides.append(torch.ones_like(grid[:, :, 0]) * stride)
|
||||
outputs.append(output)
|
||||
|
||||
return self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1))
|
||||
|
||||
def get_output_and_grid(self, output, k, stride):
|
||||
grid = self.grids[k]
|
||||
hsize, wsize = output.shape[-2:]
|
||||
if grid.shape[2:4] != output.shape[2:4]:
|
||||
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing='ij')
|
||||
grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type())
|
||||
self.grids[k] = grid
|
||||
grid = grid.view(1, -1, 2)
|
||||
|
||||
output = output.flatten(start_dim=2).permute(0, 2, 1)
|
||||
output[..., :2] = (output[..., :2] + grid.type_as(output)) * stride
|
||||
output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
|
||||
return output, grid
|
||||
|
||||
def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs):
|
||||
#-----------------------------------------------#
|
||||
# [batch, n_anchors_all, 4]
|
||||
#-----------------------------------------------#
|
||||
bbox_preds = outputs[:, :, :4]
|
||||
#-----------------------------------------------#
|
||||
# [batch, n_anchors_all, 1]
|
||||
#-----------------------------------------------#
|
||||
obj_preds = outputs[:, :, 4:5]
|
||||
#-----------------------------------------------#
|
||||
# [batch, n_anchors_all, n_cls]
|
||||
#-----------------------------------------------#
|
||||
cls_preds = outputs[:, :, 5:]
|
||||
|
||||
total_num_anchors = outputs.shape[1]
|
||||
#-----------------------------------------------#
|
||||
# x_shifts [1, n_anchors_all]
|
||||
# y_shifts [1, n_anchors_all]
|
||||
# expanded_strides [1, n_anchors_all]
|
||||
#-----------------------------------------------#
|
||||
x_shifts = torch.cat(x_shifts, 1).type_as(outputs)
|
||||
y_shifts = torch.cat(y_shifts, 1).type_as(outputs)
|
||||
expanded_strides = torch.cat(expanded_strides, 1).type_as(outputs)
|
||||
|
||||
cls_targets = []
|
||||
reg_targets = []
|
||||
obj_targets = []
|
||||
fg_masks = []
|
||||
|
||||
num_fg = 0.0
|
||||
for batch_idx in range(outputs.shape[0]):
|
||||
num_gt = len(labels[batch_idx])
|
||||
if num_gt == 0:
|
||||
cls_target = outputs.new_zeros((0, self.num_classes))
|
||||
reg_target = outputs.new_zeros((0, 4))
|
||||
obj_target = outputs.new_zeros((total_num_anchors, 1))
|
||||
fg_mask = outputs.new_zeros(total_num_anchors).bool()
|
||||
else:
|
||||
#-----------------------------------------------#
|
||||
# gt_bboxes_per_image [num_gt, num_classes]
|
||||
# gt_classes [num_gt]
|
||||
# bboxes_preds_per_image [n_anchors_all, 4]
|
||||
# cls_preds_per_image [n_anchors_all, num_classes]
|
||||
# obj_preds_per_image [n_anchors_all, 1]
|
||||
#-----------------------------------------------#
|
||||
gt_bboxes_per_image = labels[batch_idx][..., :4].type_as(outputs)
|
||||
gt_classes = labels[batch_idx][..., 4].type_as(outputs)
|
||||
bboxes_preds_per_image = bbox_preds[batch_idx]
|
||||
cls_preds_per_image = cls_preds[batch_idx]
|
||||
obj_preds_per_image = obj_preds[batch_idx]
|
||||
|
||||
gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments(
|
||||
num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image,
|
||||
expanded_strides, x_shifts, y_shifts,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
num_fg += num_fg_img
|
||||
cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)
|
||||
obj_target = fg_mask.unsqueeze(-1)
|
||||
reg_target = gt_bboxes_per_image[matched_gt_inds]
|
||||
cls_targets.append(cls_target)
|
||||
reg_targets.append(reg_target)
|
||||
obj_targets.append(obj_target.type(cls_target.type()))
|
||||
fg_masks.append(fg_mask)
|
||||
|
||||
cls_targets = torch.cat(cls_targets, 0)
|
||||
reg_targets = torch.cat(reg_targets, 0)
|
||||
obj_targets = torch.cat(obj_targets, 0)
|
||||
fg_masks = torch.cat(fg_masks, 0)
|
||||
|
||||
num_fg = max(num_fg, 1)
|
||||
loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()
|
||||
loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()
|
||||
loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
|
||||
# loss_obj = (sigmoid_focal_loss(obj_preds.view(-1, 1), obj_targets)).sum()
|
||||
# loss_cls = (sigmoid_focal_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
|
||||
reg_weight = 5.0
|
||||
loss = reg_weight * loss_iou + loss_obj + loss_cls
|
||||
|
||||
return loss / num_fg
|
||||
|
||||
@torch.no_grad()
|
||||
def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts):
|
||||
#-------------------------------------------------------#
|
||||
# fg_mask [n_anchors_all]
|
||||
# is_in_boxes_and_center [num_gt, len(fg_mask)]
|
||||
#-------------------------------------------------------#
|
||||
fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# fg_mask [n_anchors_all]
|
||||
# bboxes_preds_per_image [fg_mask, 4]
|
||||
# cls_preds_ [fg_mask, num_classes]
|
||||
# obj_preds_ [fg_mask, 1]
|
||||
#-------------------------------------------------------#
|
||||
bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
|
||||
cls_preds_ = cls_preds_per_image[fg_mask]
|
||||
obj_preds_ = obj_preds_per_image[fg_mask]
|
||||
num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# pair_wise_ious [num_gt, fg_mask]
|
||||
#-------------------------------------------------------#
|
||||
pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
|
||||
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# cls_preds_ [num_gt, fg_mask, num_classes]
|
||||
# gt_cls_per_image [num_gt, fg_mask, num_classes]
|
||||
#-------------------------------------------------------#
|
||||
if self.fp16:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
|
||||
gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
|
||||
pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
|
||||
else:
|
||||
cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
|
||||
gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
|
||||
pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
|
||||
del cls_preds_
|
||||
|
||||
cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()
|
||||
|
||||
num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
|
||||
del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
|
||||
return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg
|
||||
|
||||
def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True):
|
||||
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
|
||||
raise IndexError
|
||||
|
||||
if xyxy:
|
||||
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
|
||||
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
|
||||
area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
|
||||
area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
|
||||
else:
|
||||
tl = torch.max(
|
||||
(bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
|
||||
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
|
||||
)
|
||||
br = torch.min(
|
||||
(bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
|
||||
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
|
||||
)
|
||||
|
||||
area_a = torch.prod(bboxes_a[:, 2:], 1)
|
||||
area_b = torch.prod(bboxes_b[:, 2:], 1)
|
||||
en = (tl < br).type(tl.type()).prod(dim=2)
|
||||
area_i = torch.prod(br - tl, 2) * en
|
||||
return area_i / (area_a[:, None] + area_b - area_i)
|
||||
|
||||
def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5):
|
||||
#-------------------------------------------------------#
|
||||
# expanded_strides_per_image [n_anchors_all]
|
||||
# x_centers_per_image [num_gt, n_anchors_all]
|
||||
# x_centers_per_image [num_gt, n_anchors_all]
|
||||
#-------------------------------------------------------#
|
||||
expanded_strides_per_image = expanded_strides[0]
|
||||
x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
|
||||
y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# gt_bboxes_per_image_x [num_gt, n_anchors_all]
|
||||
#-------------------------------------------------------#
|
||||
gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
|
||||
gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
|
||||
gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
|
||||
gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# bbox_deltas [num_gt, n_anchors_all, 4]
|
||||
#-------------------------------------------------------#
|
||||
b_l = x_centers_per_image - gt_bboxes_per_image_l
|
||||
b_r = gt_bboxes_per_image_r - x_centers_per_image
|
||||
b_t = y_centers_per_image - gt_bboxes_per_image_t
|
||||
b_b = gt_bboxes_per_image_b - y_centers_per_image
|
||||
bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# is_in_boxes [num_gt, n_anchors_all]
|
||||
# is_in_boxes_all [n_anchors_all]
|
||||
#-------------------------------------------------------#
|
||||
is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
|
||||
is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
|
||||
|
||||
gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
|
||||
gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
|
||||
gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
|
||||
gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# center_deltas [num_gt, n_anchors_all, 4]
|
||||
#-------------------------------------------------------#
|
||||
c_l = x_centers_per_image - gt_bboxes_per_image_l
|
||||
c_r = gt_bboxes_per_image_r - x_centers_per_image
|
||||
c_t = y_centers_per_image - gt_bboxes_per_image_t
|
||||
c_b = gt_bboxes_per_image_b - y_centers_per_image
|
||||
center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# is_in_centers [num_gt, n_anchors_all]
|
||||
# is_in_centers_all [n_anchors_all]
|
||||
#-------------------------------------------------------#
|
||||
is_in_centers = center_deltas.min(dim=-1).values > 0.0
|
||||
is_in_centers_all = is_in_centers.sum(dim=0) > 0
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# is_in_boxes_anchor [n_anchors_all]
|
||||
# is_in_boxes_and_center [num_gt, is_in_boxes_anchor]
|
||||
#-------------------------------------------------------#
|
||||
is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
|
||||
is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
|
||||
return is_in_boxes_anchor, is_in_boxes_and_center
|
||||
|
||||
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
|
||||
#-------------------------------------------------------#
|
||||
# cost [num_gt, fg_mask]
|
||||
# pair_wise_ious [num_gt, fg_mask]
|
||||
# gt_classes [num_gt]
|
||||
# fg_mask [n_anchors_all]
|
||||
# matching_matrix [num_gt, fg_mask]
|
||||
#-------------------------------------------------------#
|
||||
matching_matrix = torch.zeros_like(cost)
|
||||
|
||||
#------------------------------------------------------------#
|
||||
# 选取iou最大的n_candidate_k个点
|
||||
# 然后求和,判断应该有多少点用于该框预测
|
||||
# topk_ious [num_gt, n_candidate_k]
|
||||
# dynamic_ks [num_gt]
|
||||
# matching_matrix [num_gt, fg_mask]
|
||||
#------------------------------------------------------------#
|
||||
n_candidate_k = min(10, pair_wise_ious.size(1))
|
||||
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
|
||||
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
|
||||
|
||||
for gt_idx in range(num_gt):
|
||||
#------------------------------------------------------------#
|
||||
# 给每个真实框选取最小的动态k个点
|
||||
#------------------------------------------------------------#
|
||||
_, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
|
||||
matching_matrix[gt_idx][pos_idx] = 1.0
|
||||
del topk_ious, dynamic_ks, pos_idx
|
||||
|
||||
#------------------------------------------------------------#
|
||||
# anchor_matching_gt [fg_mask]
|
||||
#------------------------------------------------------------#
|
||||
anchor_matching_gt = matching_matrix.sum(0)
|
||||
if (anchor_matching_gt > 1).sum() > 0:
|
||||
#------------------------------------------------------------#
|
||||
# 当某一个特征点指向多个真实框的时候
|
||||
# 选取cost最小的真实框。
|
||||
#------------------------------------------------------------#
|
||||
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
|
||||
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
|
||||
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
|
||||
#------------------------------------------------------------#
|
||||
# fg_mask_inboxes [fg_mask]
|
||||
# num_fg为正样本的特征点个数
|
||||
#------------------------------------------------------------#
|
||||
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
|
||||
num_fg = fg_mask_inboxes.sum().item()
|
||||
|
||||
#------------------------------------------------------------#
|
||||
# 对fg_mask进行更新
|
||||
#------------------------------------------------------------#
|
||||
fg_mask[fg_mask.clone()] = fg_mask_inboxes
|
||||
|
||||
#------------------------------------------------------------#
|
||||
# 获得特征点对应的物品种类
|
||||
#------------------------------------------------------------#
|
||||
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
|
||||
gt_matched_classes = gt_classes[matched_gt_inds]
|
||||
|
||||
pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
|
||||
return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
|
||||
|
||||
def is_parallel(model):
|
||||
# Returns True if model is of type DP or DDP
|
||||
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||
|
||||
def de_parallel(model):
|
||||
# De-parallelize a model: returns single-GPU model if model is of type DP or DDP
|
||||
return model.module if is_parallel(model) else model
|
||||
|
||||
def copy_attr(a, b, include=(), exclude=()):
|
||||
# Copy attributes from b to a, options to only include [...] and to exclude [...]
|
||||
for k, v in b.__dict__.items():
|
||||
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
|
||||
continue
|
||||
else:
|
||||
setattr(a, k, v)
|
||||
|
||||
class ModelEMA:
|
||||
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
|
||||
Keeps a moving average of everything in the model state_dict (parameters and buffers)
|
||||
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
"""
|
||||
|
||||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||
# Create EMA
|
||||
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
||||
# if next(model.parameters()).device.type != 'cpu':
|
||||
# self.ema.half() # FP16 EMA
|
||||
self.updates = updates # number of EMA updates
|
||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
||||
for p in self.ema.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
def update(self, model):
|
||||
# Update EMA parameters
|
||||
with torch.no_grad():
|
||||
self.updates += 1
|
||||
d = self.decay(self.updates)
|
||||
|
||||
msd = de_parallel(model).state_dict() # model state_dict
|
||||
for k, v in self.ema.state_dict().items():
|
||||
if v.dtype.is_floating_point:
|
||||
v *= d
|
||||
v += (1 - d) * msd[k].detach()
|
||||
|
||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
||||
# Update EMA attributes
|
||||
copy_attr(self.ema, model, include, exclude)
|
||||
|
||||
def weights_init(net, init_type='normal', init_gain = 0.02):
|
||||
def init_func(m):
|
||||
classname = m.__class__.__name__
|
||||
if hasattr(m, 'weight') and classname.find('Conv') != -1:
|
||||
if init_type == 'normal':
|
||||
torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
|
||||
elif init_type == 'xavier':
|
||||
torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
|
||||
elif init_type == 'kaiming':
|
||||
torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
||||
elif init_type == 'orthogonal':
|
||||
torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
|
||||
else:
|
||||
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
||||
elif classname.find('BatchNorm2d') != -1:
|
||||
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
torch.nn.init.constant_(m.bias.data, 0.0)
|
||||
print('initialize network with %s type' % init_type)
|
||||
net.apply(init_func)
|
||||
|
||||
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
|
||||
def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
|
||||
if iters <= warmup_total_iters:
|
||||
# lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
|
||||
lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
|
||||
elif iters >= total_iters - no_aug_iter:
|
||||
lr = min_lr
|
||||
else:
|
||||
lr = min_lr + 0.5 * (lr - min_lr) * (
|
||||
1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
|
||||
)
|
||||
return lr
|
||||
|
||||
def step_lr(lr, decay_rate, step_size, iters):
|
||||
if step_size < 1:
|
||||
raise ValueError("step_size must above 1.")
|
||||
n = iters // step_size
|
||||
out_lr = lr * decay_rate ** n
|
||||
return out_lr
|
||||
|
||||
if lr_decay_type == "cos":
|
||||
warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
|
||||
warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
|
||||
no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
|
||||
func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
|
||||
else:
|
||||
decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
|
||||
step_size = total_iters / step_num
|
||||
func = partial(step_lr, lr, decay_rate, step_size)
|
||||
|
||||
return func
|
||||
|
||||
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
|
||||
lr = lr_scheduler_func(epoch)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
Reference in New Issue
Block a user