esenke
2025-12-08 21:38:53 +08:00
commit 71118fc649
22 changed files with 4780 additions and 0 deletions

173
README.md Normal file

@@ -0,0 +1,173 @@
# TDCNet
Official repository of the AAAI 2026 paper "Spatio-Temporal Context Learning with Temporal Difference Convolution for Moving Infrared Small Target Detection".
## Spatio-Temporal Context Learning with Temporal Difference Convolution for Moving Infrared Small Target Detection [[PDF](https://arxiv.org/pdf/2511.09352)]
Authors: Houzhang Fang<sup>1</sup>, Shukai Guo<sup>1</sup>, Qiuhuan Chen<sup>1</sup>, Yi Chang<sup>2</sup>, Luxin Yan<sup>2</sup>
<sup>1</sup>Xidian University, <sup>2</sup>Huazhong University of Science and Technology
## Abstract
Moving infrared small target detection (IRSTD) plays a critical role in practical applications, such as surveillance of unmanned aerial vehicles (UAVs) and UAV-based search systems. Moving IRSTD remains highly challenging due to weak target features and complex background interference. Accurate spatio-temporal feature modeling is crucial for moving target detection, typically achieved through either temporal differences or spatio-temporal (3D) convolutions. Temporal difference can explicitly leverage motion cues but exhibits limited capability in extracting spatial features, whereas 3D convolution effectively represents spatio-temporal features yet lacks explicit awareness of motion dynamics along the temporal dimension. In this paper, we propose a novel moving IRSTD network (TDCNet), which effectively extracts and enhances spatio-temporal features for accurate target detection. Specifically, we introduce a novel temporal difference convolution (TDC) re-parameterization module that comprises three parallel TDC blocks designed to capture contextual dependencies across different temporal ranges. Each TDC block fuses temporal difference and 3D convolution into a unified spatio-temporal convolution representation. This re-parameterized module can effectively capture multi-scale motion contextual features while suppressing pseudo-motion clutter in complex backgrounds, significantly improving detection performance. Moreover, we propose a TDC-guided spatio-temporal attention mechanism that performs cross-attention between the spatio-temporal features extracted from the TDC-based backbone and a parallel 3D backbone. This mechanism models their global semantic dependencies to refine the current frame's features, thereby guiding the model to focus more accurately on critical target regions. To facilitate comprehensive evaluation, we construct a new challenging benchmark, IRSTD-UAV, consisting of 15,106 real infrared images with diverse low signal-to-clutter ratio scenarios and complex backgrounds. Extensive experiments on IRSTD-UAV and public infrared datasets demonstrate that our TDCNet achieves state-of-the-art detection performance in moving target detection.
## TDCNet Framework
![image-20250407200916034](./figs/overall_framework.png)
## Visualization
![image-20250407201214584](./figs/vis_main.png)
Visual comparison of results from SOTA methods and TDCNet on the IRSTD-UAV and IRDST datasets. Boxes in green and red represent ground-truth and detected targets, respectively.
## Environment
- [Python](https://www.python.org/)
- [PyTorch](https://pytorch.org/)
- [tqdm](https://github.com/tqdm/tqdm)
- [pycocotools](https://github.com/cocodataset/cocoapi)
- [OpenCV (cv2)](https://opencv.org/)
- [NumPy](https://numpy.org/)
## Data
1. Download the datasets.
- [IRSTD-UAV (ours)](https://drive.google.com/file/d/1orHDqG-nLYBSdJETyt6ozpAGOSoKiT-k/view): We constructed the IRSTD-UAV dataset, which contains 17 real infrared video sequences with 15,106 frames. It features small targets in complex backgrounds such as buildings, trees, and clouds, providing a realistic benchmark for UAV-based IRSTD. If you use this dataset, please cite our work.
- [IRDST](https://xzbai.buaa.edu.cn/datasets.html)
2. Perform background alignment before training or testing.
You can use the [GIM](https://github.com/xuelunshen/gim) method for background alignment.
3. Organize the dataset structure as follows:
```
IRSTD-UAV/
├── images/
│ ├── 1/
│ │ ├── 00000000.png
│ │ ├── 00000001.png
│ │ └── ...
│ ├── 2/
│ │ ├── 00000000.png
│ │ ├── 00000001.png
│ │ └── ...
│ └── ...
├── labels/
│ ├── 1/
│ │ ├── 00000000.png
│ │ ├── 00000001.png
│ │ └── ...
│ ├── 2/
│ │ ├── 00000000.png
│ │ ├── 00000001.png
│ │ └── ...
│ └── ...
├── matches/
│ ├── 1/
│ │ ├── 00000000/
│ │ │ ├── match_1.png
│ │ │ ├── match_2.png
│ │ │ └── ...
│ │ ├── 00000001/
│ │ │ ├── match_1.png
│ │ │ ├── match_2.png
│ │ │ └── ...
│ │ └── ...
│ └── ...
├── train.txt
├── val.txt
├── train_coco.json
└── val_coco.json
```
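Each line of `train.txt`/`val.txt` is read by the dataset loader as one key-frame image path followed by its bounding boxes in `x1,y1,x2,y2,class_id` form, separated by spaces. This format is inferred from the loader code in this repository rather than an official specification, so treat the line below as an illustrative sketch (the path and coordinates are made up):
```
/Dataset/IRSTD-UAV/images/1/00000012.png 104,87,118,99,0 230,41,244,55,0
```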
## How To Train
1. **Prepare the environment and dataset**
- Configure environment variables:
```python
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
```
- Set dataset paths in the script:
```python
DATA_PATH = "/Dataset/IRSTD-UAV/"
train_annotation_path = "/Dataset/IRSTD-UAV/train.txt"
val_annotation_path = "/Dataset/IRSTD-UAV/val.txt"
```
2. **Training command**
```bash
python train.py
```
## How To Test
```bash
python test.py
```
The testing results will be saved in the ./results folder.
## Citation
If you find our work and our dataset IRSTD-UAV useful for your research, please consider citing our paper:
```bibtex
@inproceedings{2026AAAI_TDCNet,
title = {Spatio-Temporal Context Learning with Temporal Difference Convolution for Moving Infrared Small Target Detection},
author = {Houzhang Fang and Shukai Guo and Qiuhuan Chen and Yi Chang and Luxin Yan},
booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
year = {2026},
pages = { },
}
```
In addition to the above paper, please also consider citing the following references. Thank you!
```bibtex
@inproceedings{2025CVPR_UniCD,
title = {Detection-Friendly Nonuniformity Correction: A Union Framework for Infrared {UAV} Target Detection},
author = {Houzhang Fang and Xiaolin Wang and Zengyang Li and Lu Wang and Qingshan Li and Yi Chang and Luxin Yan},
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2025},
pages = {11898-11907},
}
@ARTICLE{2023TII_DAGNet,
title = {Differentiated Attention Guided Network Over Hierarchical and Aggregated Features for Intelligent {UAV} Surveillance},
author = {Houzhang Fang and Zikai Liao and Xuhua Wang and Yi Chang and Luxin Yan},
journal = {IEEE Transactions on Industrial Informatics},
year = {2023},
volume = {19},
number = {9},
pages = {9909-9920},
}
@inproceedings{2023ACMMM_DANet,
title = {{DANet}: Multi-scale {UAV} Target Detection with Dynamic Feature Perception and Scale-aware Knowledge Distillation},
author = {Houzhang Fang and Zikai Liao and Lu Wang and Qingshan Li and Yi Chang and Luxin Yan and Xuhua Wang},
booktitle = {Proceedings of the 31st ACM International Conference on Multimedia (ACMMM)},
pages = {2121-2130},
year = {2023},
}
@article{2024TGRS_SCINet,
title = {{SCINet}: Spatial and Contrast Interactive Super-Resolution Assisted Infrared {UAV} Target Detection},
author = {Houzhang Fang and Lan Ding and Xiaolin Wang and Yi Chang and Luxin Yan and Li Liu and Jinrui Fang},
journal = {IEEE Transactions on Geoscience and Remote Sensing},
volume = {62},
year = {2024},
pages = {1-22},
}
@ARTICLE{2022TIMFang,
title = {Infrared Small {UAV} Target Detection Based on Depthwise Separable Residual Dense Network and Multiscale Feature Fusion},
author = {Houzhang Fang and Lan Ding and Liming Wang and Yi Chang and Luxin Yan and Jinhui Han},
journal = {IEEE Transactions on Instrumentation and Measurement},
year = {2022},
volume = {71},
number = {},
pages = {1-20},
}
```
## Contact
If you have any questions, please contact houzhangfang@xidian.edu.cn.
Copyright &copy; Xidian University.
## Acknowledgments
Some of the code is based on [STMENet](https://github.com/UESTC-nnLab/STME) and [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). Thanks for their excellent work!
## License
MIT License. This code is freely available for non-commercial research use only.


@@ -0,0 +1,123 @@
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
def cvtColor(image):
if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
return image
else:
image = image.convert('RGB')
return image
def preprocess(image):
image = image.astype(np.float32)
image /= 255.0
return image
def rand(a=0, b=1):
return np.random.rand() * (b - a) + a
class seqDataset(Dataset):
def __init__(self, dataset_path, image_size, num_frame=5, type='train'):
super(seqDataset, self).__init__()
self.dataset_path = dataset_path
self.img_idx = []
self.anno_idx = []
self.image_size = image_size
self.num_frame = num_frame
self.txt_path = dataset_path
with open(self.txt_path) as f:
data_lines = f.readlines()
self.length = len(data_lines)
for line in data_lines:
line = line.strip('\n').split()
self.img_idx.append(line[0])
self.anno_idx.append(np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]]))
def __len__(self):
return self.length
def __getitem__(self, index):
images, box = self.get_data(index)
images = np.array(images)
images = np.transpose(preprocess(images), (3, 0, 1, 2))
if len(box) != 0:
box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
box[:, 0:2] = box[:, 0:2] + (box[:, 2:4] / 2)
return images, box
def get_data(self, index):
h, w = self.image_size, self.image_size
file_name = self.img_idx[index]
dir_path = file_name.replace('images', 'matches').replace('.png', '')
images = []
for i in range(self.num_frame):
img_path = os.path.join(dir_path, f"img_{i + 1}.png")
img = Image.open(img_path)
img = cvtColor(img)
iw, ih = img.size
scale = min(w / iw, h / ih)
nw = int(iw * scale)
nh = int(ih * scale)
dx = (w - nw) // 2
dy = (h - nh) // 2
img = img.resize((nw, nh), Image.BICUBIC)
new_img = Image.new('RGB', (w, h), (128, 128, 128))
new_img.paste(img, (dx, dy))
images.append(np.array(new_img, np.float32))
image_data = []
image_id = int(file_name.split("/")[-1][:-4])
image_path = file_name.replace(file_name.split("/")[-1], '')
for id in range(0, self.num_frame):
img = Image.open(image_path + '%d.png' % max(image_id - id, 0))
img = cvtColor(img)
iw, ih = img.size
scale = min(w / iw, h / ih)
nw = int(iw * scale)
nh = int(ih * scale)
dx = (w - nw) // 2
dy = (h - nh) // 2
img = img.resize((nw, nh), Image.BICUBIC)  # resize the original image while keeping its aspect ratio
new_img = Image.new('RGB', (w, h), (128, 128, 128))  # gray canvas of the target size
new_img.paste(img, (dx, dy))  # paste the resized image at the center
image_data.append(np.array(new_img, np.float32))
image_data = image_data[::-1]  # key frame last  # [5, w, h, 3]
for img in image_data:
images.append(img.copy())
label_data = self.anno_idx[index]  # 4 box coordinates + 1 class id
if len(label_data) > 0:
np.random.shuffle(label_data)
label_data[:, [0, 2]] = label_data[:, [0, 2]] * nw / iw + dx
label_data[:, [1, 3]] = label_data[:, [1, 3]] * nh / ih + dy
label_data[:, 0:2][label_data[:, 0:2] < 0] = 0
label_data[:, 2][label_data[:, 2] > w] = w
label_data[:, 3][label_data[:, 3] > h] = h
box_w = label_data[:, 2] - label_data[:, 0]
box_h = label_data[:, 3] - label_data[:, 1]
label_data = label_data[np.logical_and(box_w > 1, box_h > 1)]
images = np.array(images)
label_data = np.array(label_data, dtype=np.float32)
return images, label_data
def dataset_collate(batch):
images = []
bboxes = []
for img, box in batch:
images.append(img)
bboxes.append(box)
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
return images, bboxes


@@ -0,0 +1,125 @@
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
def cvtColor(image):
if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
return image
else:
image = image.convert('RGB')
return image
def preprocess(image):
image = image.astype(np.float32)
image /= 255.0
return image
def rand(a=0, b=1):
return np.random.rand() * (b - a) + a
class seqDataset(Dataset):
def __init__(self, dataset_path, image_size, num_frame=5, type='train'):
super(seqDataset, self).__init__()
self.dataset_path = dataset_path
self.img_idx = []
self.anno_idx = []
self.image_size = image_size
self.num_frame = num_frame
self.txt_path = dataset_path
with open(self.txt_path) as f:
data_lines = f.readlines()
self.length = len(data_lines)
for line in data_lines:
line = line.strip('\n').split()
self.img_idx.append(line[0])
self.anno_idx.append(np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]]))
def __len__(self):
return self.length
def __getitem__(self, index):
images, box = self.get_data(index)
images = np.array(images)
images = np.transpose(preprocess(images), (3, 0, 1, 2))
if len(box) != 0:
box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
box[:, 0:2] = box[:, 0:2] + (box[:, 2:4] / 2)
return images, box
def get_data(self, index):
h, w = self.image_size, self.image_size
file_name = self.img_idx[index]
dir_path = file_name.replace('images', 'matches').replace('.png', '')
images = []
for i in range(self.num_frame):
img_path = os.path.join(dir_path, f"match_{i + 1}.png")
img = Image.open(img_path)
img = cvtColor(img)
iw, ih = img.size
scale = min(w / iw, h / ih)
nw = int(iw * scale)
nh = int(ih * scale)
dx = (w - nw) // 2
dy = (h - nh) // 2
img = img.resize((nw, nh), Image.BICUBIC)
new_img = Image.new('RGB', (w, h), (128, 128, 128))
new_img.paste(img, (dx, dy))
images.append(np.array(new_img, np.float32))
image_data = []
image_id = int(file_name.split("/")[-1][:8])
image_path = file_name.replace(file_name.split("/")[-1], '')
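# earliest frame we may step back to: the first frame of the current 50-frame block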
min_index = image_id - (image_id % 50)
for id in range(0, self.num_frame):
img = Image.open(image_path + '%08d.png' % max(image_id - id, min_index))
img = cvtColor(img)
iw, ih = img.size
scale = min(w / iw, h / ih)
nw = int(iw * scale)
nh = int(ih * scale)
dx = (w - nw) // 2
dy = (h - nh) // 2
img = img.resize((nw, nh), Image.BICUBIC)  # resize the original image while keeping its aspect ratio
new_img = Image.new('RGB', (w, h), (128, 128, 128))  # gray canvas of the target size
new_img.paste(img, (dx, dy))  # paste the resized image at the center
image_data.append(np.array(new_img, np.float32))
image_data = image_data[::-1]  # key frame last  # [5, w, h, 3]
for img in image_data:
images.append(img.copy())
label_data = self.anno_idx[index]  # 4 box coordinates + 1 class id
if len(label_data) > 0:
np.random.shuffle(label_data)
label_data[:, [0, 2]] = label_data[:, [0, 2]] * nw / iw + dx
label_data[:, [1, 3]] = label_data[:, [1, 3]] * nh / ih + dy
label_data[:, 0:2][label_data[:, 0:2] < 0] = 0
label_data[:, 2][label_data[:, 2] > w] = w
label_data[:, 3][label_data[:, 3] > h] = h
box_w = label_data[:, 2] - label_data[:, 0]
box_h = label_data[:, 3] - label_data[:, 1]
label_data = label_data[np.logical_and(box_w > 1, box_h > 1)]
images = np.array(images)
label_data = np.array(label_data, dtype=np.float32)
return images, label_data
def dataset_collate(batch):
images = []
bboxes = []
for img, box in batch:
images.append(img)
bboxes.append(box)
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
return images, bboxes
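# --- Optional usage sketch (not part of the original file). It assumes the dataset has been
# --- prepared as described in the README, with the split file at the path used there.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    train_set = seqDataset("/Dataset/IRSTD-UAV/train.txt", image_size=512, num_frame=5)
    loader = DataLoader(train_set, batch_size=2, shuffle=True, collate_fn=dataset_collate)
    imgs, boxes = next(iter(loader))
    # imgs: [B, 3, 2*num_frame, 512, 512] (aligned "match" frames followed by raw frames,
    # key frame last); boxes: list of [num_targets, 5] tensors in (cx, cy, w, h, class) form
    print(imgs.shape, [tuple(b.shape) for b in boxes])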

BIN
figs/overall_framework.png Normal file (binary image, 1.6 MiB)

BIN
figs/vis_main.png Normal file (binary image, 1.4 MiB)

373
model/TDCNet/TDCNetwork.py Normal file

@@ -0,0 +1,373 @@
import torch
import torch.nn as nn
from model.TDCNet.TDCSTA import CrossAttention, SelfAttention
from model.TDCNet.backbone3d import Backbone3D
from model.TDCNet.backbonetd import BackboneTD
from model.TDCNet.darknet import BaseConv, CSPDarknet, DWConv
class Feature_Backbone(nn.Module):
def __init__(self, depth=1.0, width=1.0, in_features=("dark3", "dark4", "dark5"), in_channels=[256, 512, 1024], depthwise=False, act="silu"):
super().__init__()
self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
self.in_features = in_features
def forward(self, input):
out_features = self.backbone.forward(input)
[feat1, feat2, feat3] = [out_features[f] for f in self.in_features]
return [feat1, feat2, feat3]
class Bottleneck(nn.Module):
# Standard bottleneck
def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu", ):
super().__init__()
hidden_channels = int(out_channels * expansion)
Conv = BaseConv # if depthwise else BaseConv
# --------------------------------------------------#
# use a 1x1 convolution to reduce the number of channels (typically by 50%)
# --------------------------------------------------#
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
# --------------------------------------------------#
# use a 3x3 convolution to expand the channels and complete the feature extraction
# --------------------------------------------------#
self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
# self.conv2=nn.Identity()
self.use_add = shortcut and in_channels == out_channels
def forward(self, x):
y = self.conv2(self.conv1(x))
if self.use_add:
y = y + x
return y
class FusionLayer(nn.Module):
def __init__(self, in_channels, out_channels, expansion=0.5, depthwise=False, act="silu", ):
# ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
hidden_channels = int(out_channels * expansion)
n = 1
# --------------------------------------------------#
# initial convolution of the main branch
# --------------------------------------------------#
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
# --------------------------------------------------#
# initial convolution of the large residual branch
# --------------------------------------------------#
self.conv2 = BaseConv(hidden_channels, hidden_channels, 1, stride=1, act=act) # in_channel
# -----------------------------------------------#
# convolve the concatenated result
# self.deepfeature=nn.Sequential(BaseConv(hidden_channels, hidden_channels//2, 1, stride=1, act=act),
# BaseConv(hidden_channels//2, hidden_channels, 3, stride=1, act=act))
# -----------------------------------------------#
# module_list = [Bottleneck(hidden_channels, hidden_channels, True, 1.0, depthwise, act=act) for _ in range(n)]
# self.deepfeature = nn.Sequential(*module_list)
self.conv3 = BaseConv(hidden_channels, out_channels, 1, stride=1, act=act) # 2*hidden_channel
# --------------------------------------------------#
# build the Bottleneck residual blocks according to the number of repetitions
# --------------------------------------------------#
# module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)]
# self.m = nn.Sequential(*module_list)
def forward(self, x):
# -------------------------------#
# x_1 is the main branch
# -------------------------------#
# x_1 = self.conv1(x)
x = self.conv1(x)
# -------------------------------#
# x_2 is the large residual branch
# -------------------------------#
# x_2 = self.conv2(x)
x = self.conv2(x)
# -----------------------------------------------#
# the main branch continues feature extraction through stacked residual blocks
# -----------------------------------------------#
# x_1 = self.deepfeature(x_1)
# -----------------------------------------------#
# concatenate the main branch and the large residual branch
# -----------------------------------------------#
# x = torch.cat((x_1, x_2), dim=1)
# -----------------------------------------------#
# convolve the concatenated result
# -----------------------------------------------#
return self.conv3(x)
class Feature_Fusion(nn.Module):
def __init__(self, in_channels=[128, 256, 512], depthwise=False, act="silu"):
super().__init__()
Conv = DWConv if depthwise else BaseConv
self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
# -------------------------------------------#
# 20, 20, 1024 -> 20, 20, 512
# -------------------------------------------#
# self.lateral_conv0 = BaseConv(2 * int(in_channels[2]), int(in_channels[1]), 1, 1, act=act)
self.lateral_conv0 = BaseConv(in_channels[1] + in_channels[2], in_channels[1], 1, 1, act=act)
# -------------------------------------------#
# 40, 40, 1024 -> 40, 40, 512
# -------------------------------------------#
self.C3_p4 = FusionLayer(
int(2 * in_channels[1]),
int(in_channels[1]),
depthwise=depthwise,
act=act,
)
# -------------------------------------------#
# 40, 40, 512 -> 40, 40, 256
# -------------------------------------------#
# self.reduce_conv1 = BaseConv(int(2 * in_channels[1]), int(in_channels[0]), 1, 1, act=act)
self.reduce_conv1 = BaseConv(int(in_channels[0] + in_channels[1]), int(in_channels[0]), 1, 1, act=act)
# -------------------------------------------#
# 80, 80, 512 -> 80, 80, 256
# -------------------------------------------#
self.C3_p3 = FusionLayer(
int(2 * in_channels[0]),
int(in_channels[0]),
depthwise=depthwise,
act=act,
)
def forward(self, input):
out_features = input # self.backbone.forward(input)
[feat1, feat2, feat3] = out_features # [out_features[f] for f in self.in_features]
# -------------------------------------------#
# 20, 20, 1024 -> 20, 20, 512
# -------------------------------------------#
# P5 = self.lateral_conv0(feat3)
# -------------------------------------------#
# 20, 20, 512 -> 40, 40, 512
# -------------------------------------------#
P5_upsample = self.upsample(feat3)
# -------------------------------------------#
# 40, 40, 512 + 40, 40, 512 -> 40, 40, 1024
# -------------------------------------------#
P5_upsample = torch.cat([P5_upsample, feat2], 1)
# pdb.set_trace()
# -------------------------------------------#
# 40, 40, 1024 -> 40, 40, 512
# -------------------------------------------#
P4 = self.lateral_conv0(P5_upsample)
# P5_upsample = self.C3_p4(P5_upsample)
# -------------------------------------------#
# 40, 40, 512 -> 40, 40, 256
# -------------------------------------------#
# P4 = self.reduce_conv1(P5_upsample)
# -------------------------------------------#
# 40, 40, 256 -> 80, 80, 256
# -------------------------------------------#
P4_upsample = self.upsample(P4)
# -------------------------------------------#
# 80, 80, 256 + 80, 80, 256 -> 80, 80, 512
# -------------------------------------------#
P4_upsample = torch.cat([P4_upsample, feat1], 1)
# -------------------------------------------#
# 80, 80, 512 -> 80, 80, 256
# -------------------------------------------#
P3_out = self.reduce_conv1(P4_upsample)
# P3_out = self.C3_p3(P4_upsample)
return P3_out
class YOLOXHead(nn.Module):
def __init__(self, num_classes, width=1.0, in_channels=[16, 32, 64], act="silu"):
super().__init__()
Conv = BaseConv
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
self.cls_preds = nn.ModuleList()
self.reg_preds = nn.ModuleList()
self.obj_preds = nn.ModuleList()
self.stems = nn.ModuleList()
for i in range(len(in_channels)):
self.stems.append(BaseConv(in_channels=int(in_channels[i]), out_channels=int(256 * width), ksize=1, stride=1, act=act))  # 128 -> 256 channel integration
self.cls_convs.append(nn.Sequential(*[
Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act),
Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act),
]))
self.cls_preds.append(
nn.Conv2d(in_channels=int(256 * width), out_channels=num_classes, kernel_size=1, stride=1, padding=0)
)
self.reg_convs.append(nn.Sequential(*[
Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act),
Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act)
]))
self.reg_preds.append(
nn.Conv2d(in_channels=int(256 * width), out_channels=4, kernel_size=1, stride=1, padding=0)
)
self.obj_preds.append(
nn.Conv2d(in_channels=int(256 * width), out_channels=1, kernel_size=1, stride=1, padding=0)
)
def forward(self, inputs):
# ---------------------------------------------------#
# inputs
# P3_out 80, 80, 256
# P4_out 40, 40, 512
# P5_out 20, 20, 1024
# ---------------------------------------------------#
outputs = []
for k, x in enumerate(inputs):
# ---------------------------------------------------#
# integrate channels with a 1x1 convolution
# ---------------------------------------------------#
x = self.stems[k](x)
# ---------------------------------------------------#
# extract features with two conv + BN + activation blocks
# ---------------------------------------------------#
cls_feat = self.cls_convs[k](x)
# ---------------------------------------------------#
# predict the class of each feature point
# 80, 80, num_classes
# 40, 40, num_classes
# 20, 20, num_classes
# ---------------------------------------------------#
cls_output = self.cls_preds[k](cls_feat)
# ---------------------------------------------------#
# extract features with two conv + BN + activation blocks
# ---------------------------------------------------#
reg_feat = self.reg_convs[k](x)
# ---------------------------------------------------#
# regression parameters of each feature point
# reg_pred 80, 80, 4
# reg_pred 40, 40, 4
# reg_pred 20, 20, 4
# ---------------------------------------------------#
reg_output = self.reg_preds[k](reg_feat)
# ---------------------------------------------------#
# predict whether each feature point corresponds to an object
# obj_pred 80, 80, 1
# obj_pred 40, 40, 1
# obj_pred 20, 20, 1
# ---------------------------------------------------#
obj_output = self.obj_preds[k](reg_feat)
output = torch.cat([reg_output, obj_output, cls_output], 1)
outputs.append(output)
return outputs
model_config = {
'backbone_2d': 'yolo_free_nano',
'pretrained_2d': True,
'stride': [8, 16, 32],
# ## 3D
'backbone_3d': 'shufflenetv2',
'model_size': '1.0x', # 1.0x
'pretrained_3d': True,
'memory_momentum': 0.9,
'head_dim': 128, # 64
'head_norm': 'BN',
'head_act': 'lrelu',
'num_cls_heads': 2,
'num_reg_heads': 2,
'head_depthwise': True,
}
def build_backbone_3d(cfg, pretrained=False):
backbone = Backbone3D(cfg, pretrained)
return backbone, backbone.feat_dim
mcfg = model_config
class TDCNetwork(nn.Module):
def __init__(self, num_classes, fp16=False, num_frame=5):
super(TDCNetwork, self).__init__()
self.num_frame = num_frame
self.backbone2d = Feature_Backbone(0.33, 0.50)
self.backbone3d, bk_dim_3d = build_backbone_3d(mcfg, pretrained=mcfg['pretrained_3d'] and True)
self.backbonetd = BackboneTD(mcfg, pretrained=mcfg['pretrained_3d'] and True)
self.q_sa1 = SelfAttention(128, window_size=(2, 8, 8), num_heads=4, use_shift=True, mlp_ratio=1.5)
self.k_sa1 = SelfAttention(128, window_size=(2, 8, 8), num_heads=4, use_shift=True, mlp_ratio=1.5)
self.v_sa1 = SelfAttention(128, window_size=(2, 8, 8), num_heads=4, use_shift=True, mlp_ratio=1.5)
self.q_sa2 = SelfAttention(256, window_size=(2, 4, 4), num_heads=4, use_shift=True, mlp_ratio=1.5)
self.k_sa2 = SelfAttention(256, window_size=(2, 4, 4), num_heads=4, use_shift=True, mlp_ratio=1.5)
self.v_sa2 = SelfAttention(256, window_size=(2, 4, 4), num_heads=4, use_shift=True, mlp_ratio=1.5)
self.q_sa3 = SelfAttention(512, window_size=(2, 2, 2), num_heads=4, use_shift=True, mlp_ratio=1.5)
self.k_sa3 = SelfAttention(512, window_size=(2, 2, 2), num_heads=4, use_shift=True, mlp_ratio=1.5)
self.v_sa3 = SelfAttention(512, window_size=(2, 2, 2), num_heads=4, use_shift=True, mlp_ratio=1.5)
self.ca1 = CrossAttention(128, window_size=(2, 8, 8), num_heads=4)
self.ca2 = CrossAttention(256, window_size=(2, 4, 4), num_heads=4)
self.ca3 = CrossAttention(512, window_size=(2, 2, 2), num_heads=4)
self.feature_fusion = Feature_Fusion()
self.head = YOLOXHead(num_classes=num_classes, width=1.0, in_channels=[128], act="silu")
def forward(self, inputs):
# inputs: [B, 3, T, H, W]
if len(inputs.shape) == 5:
T = inputs.shape[2]
diff_imgs = inputs[:, :, :T // 2, :, :]
mt_imgs = inputs[:, :, T // 2:, :, :]
else:
diff_imgs = inputs
mt_imgs = inputs
q_3d = self.backbonetd(diff_imgs)
q_3d1, q_3d2, q_3d3 = q_3d['stage2'], q_3d['stage3'], q_3d['stage4']
k_3d = self.backbone3d(mt_imgs)
k_3d1, k_3d2, k_3d3 = k_3d['stage2'], k_3d['stage3'], k_3d['stage4']
[feat1, feat2, feat3] = self.backbone2d(inputs[:, :, -1, :, :])
def to_5d(x):
# [B, C, T, H, W] -> [B, T, H, W, C]
return x.permute(0, 2, 3, 4, 1)
q_3d1 = to_5d(q_3d1)
q_3d2 = to_5d(q_3d2)
q_3d3 = to_5d(q_3d3)
k_3d1 = to_5d(k_3d1)
k_3d2 = to_5d(k_3d2)
k_3d3 = to_5d(k_3d3)
# expand V along the temporal dimension to align with Q/K (V comes from the last frame, T = 1)
def expand_v(x, T):
# [B, C, H, W] -> [B, T, H, W, C], replicated T times
x = x.permute(0, 2, 3, 1).unsqueeze(1)
x = x.expand(-1, T, -1, -1, -1)
return x
T1 = q_3d1.shape[1]
T2 = q_3d2.shape[1]
T3 = q_3d3.shape[1]
v1 = expand_v(feat1, T1)
v2 = expand_v(feat2, T2)
v3 = expand_v(feat3, T3)
q1 = self.q_sa1(q_3d1)
k1 = self.k_sa1(k_3d1)
v1 = self.v_sa1(v1)
q2 = self.q_sa2(q_3d2)
k2 = self.k_sa2(k_3d2)
v2 = self.v_sa2(v2)
q3 = self.q_sa3(q_3d3)
k3 = self.k_sa3(k_3d3)
v3 = self.v_sa3(v3)
out1 = self.ca1(q1, k1, v1)
out2 = self.ca2(q2, k2, v2)
out3 = self.ca3(q3, k3, v3)
out1 = out1.mean(1).permute(0, 3, 1, 2)
out2 = out2.mean(1).permute(0, 3, 1, 2)
out3 = out3.mean(1).permute(0, 3, 1, 2)
feat_all = self.feature_fusion([out1, out2, out3])
outputs = self.head([feat_all])
return outputs
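# --- Optional end-to-end shape sketch (not part of the original file). It assumes the
# --- CSPDarknet in model/TDCNet/darknet.py follows the standard YOLOX layout with
# --- dark3/dark4/dark5 outputs, and feeds 2*num_frame = 10 frames per clip (aligned
# --- frames first, raw frames with the key frame last), as produced by the data loader.
if __name__ == "__main__":
    net = TDCNetwork(num_classes=1, num_frame=5).eval()
    x = torch.randn(1, 3, 10, 512, 512)  # [B, C, T, H, W]
    with torch.no_grad():
        outs = net(x)
    # expected: a single prediction map of shape [B, 4 + 1 + num_classes, H/8, W/8]
    print([tuple(o.shape) for o in outs])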

131
model/TDCNet/TDCR.py Normal file

@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class TDC(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=(5, 3, 3), stride=1, padding=(2, 1, 1), groups=1, bias=False, step=1):
super().__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=bias)
self.step = step
self.groups = groups
def get_time_gradient_weight(self):
weight = self.conv.weight
kT, kH, kW = weight.shape[2:]
grad_weight = torch.zeros_like(weight, device=weight.device, dtype=weight.dtype)
if kT == 5:
if self.step == -1:
grad_weight[:, :, :, :, :] = -weight[:, :, :, :, :]
grad_weight[:, :, 4, :, :] = weight[:, :, 0, :, :] + weight[:, :, 1, :, :] + weight[:, :, 2, :, :] + weight[:, :, 3, :, :] + weight[:, :, 4, :, :]
elif self.step == 1:
grad_weight[:, :, 4, :, :] = weight[:, :, 4, :, :]
grad_weight[:, :, 3, :, :] = weight[:, :, 3, :, :] - weight[:, :, 4, :, :]
grad_weight[:, :, 2, :, :] = weight[:, :, 2, :, :] - weight[:, :, 3, :, :]
grad_weight[:, :, 1, :, :] = weight[:, :, 1, :, :] - weight[:, :, 2, :, :]
grad_weight[:, :, 0, :, :] = -weight[:, :, 1, :, :]
elif self.step == 2:
grad_weight[:, :, 4, :, :] = weight[:, :, 4, :, :]
grad_weight[:, :, 3, :, :] = weight[:, :, 3, :, :]
grad_weight[:, :, 2, :, :] = weight[:, :, 2, :, :] - weight[:, :, 4, :, :]
grad_weight[:, :, 1, :, :] = -weight[:, :, 3, :, :]
grad_weight[:, :, 0, :, :] = -weight[:, :, 2, :, :]
else:
grad_weight = weight
bias = self.conv.bias
if bias is None:
bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)
return grad_weight, bias
def forward(self, x):
weight, bias = self.get_time_gradient_weight()
x_diff = F.conv3d(x, weight, bias, stride=self.conv.stride, groups=self.groups, padding=self.conv.padding)
return x_diff
class RepConv3D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=(5, 3, 3), stride=1, padding=(2, 1, 1), groups=1, deploy=False):
super(RepConv3D, self).__init__()
self.deploy = deploy
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.groups = groups
if self.deploy:
self.conv_reparam = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=True)
else:
self.l_tdc = nn.Sequential(
TDC(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False, step=-1),
nn.BatchNorm3d(out_channels)
)
self.s_tdc = nn.Sequential(
TDC(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False, step=1),
nn.BatchNorm3d(out_channels)
)
self.m_tdc = nn.Sequential(
TDC(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False, step=2),
nn.BatchNorm3d(out_channels)
)
def forward(self, x):
if self.deploy:
out = F.relu(self.conv_reparam(x))
else:
out = self.s_tdc(x) + self.m_tdc(x) + self.l_tdc(x)
out = F.relu(out)
return out
def get_equivalent_kernel_bias(self):
kernel_s_tdc, bias_s_tdc = self._fuse_conv_bn(self.s_tdc)
kernel_m_tdc, bias_m_tdc = self._fuse_conv_bn(self.m_tdc)
kernel_l_tdc, bias_l_tdc = self._fuse_conv_bn(self.l_tdc)
kernel = kernel_s_tdc + kernel_m_tdc + kernel_l_tdc
bias = bias_s_tdc + bias_m_tdc + bias_l_tdc
return kernel, bias
def switch_to_deploy(self):
if self.deploy:
return
kernel, bias = self.get_equivalent_kernel_bias()
self.conv_reparam = nn.Conv3d(
self.in_channels, self.out_channels, (5, 3, 3), self.stride,
(2, 1, 1), groups=self.groups, bias=True
)
self.conv_reparam.weight.data = kernel
self.conv_reparam.bias.data = bias
self.deploy = True
del self.s_tdc
del self.m_tdc
del self.l_tdc
@staticmethod
def _fuse_conv_bn(branch):
if branch is None:
return 0, 0
def find_conv(module):
if isinstance(module, nn.Conv3d):
return module
for child in module.children():
conv = find_conv(child)
if conv is not None:
return conv
return None
conv = find_conv(branch[0])
bn = branch[1]
if hasattr(branch[0], 'get_time_gradient_weight'):
w, bias = branch[0].get_time_gradient_weight()
else:
w = conv.weight
if conv.bias is not None:
bias = conv.bias
else:
bias = torch.zeros_like(bn.running_mean)
mean = bn.running_mean
var_sqrt = torch.sqrt(bn.running_var + bn.eps)
gamma = bn.weight
beta = bn.bias
w = w * (gamma / var_sqrt).reshape(-1, 1, 1, 1, 1)
bias = (bias - mean) / var_sqrt * gamma + beta
return w, bias
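# --- Optional sanity-check sketch (not part of the original file): verify that, in eval mode,
# --- the fused deploy-time convolution reproduces the output of the three parallel TDC branches.
if __name__ == "__main__":
    m = RepConv3D(8, 8, kernel_size=(5, 3, 3), stride=1, padding=(2, 1, 1)).eval()
    x = torch.randn(1, 8, 5, 32, 32)  # [B, C, T, H, W]
    with torch.no_grad():
        y_train = m(x)         # sum of the three TDC + BN branches
        m.switch_to_deploy()   # fuse into a single Conv3d
        y_deploy = m(x)
    print(torch.allclose(y_train, y_deploy, atol=1e-5))  # expected: True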

239
model/TDCNet/TDCSTA.py Normal file

@@ -0,0 +1,239 @@
from functools import reduce
from operator import mul
import torch
import torch.nn as nn
import torch.nn.functional as F
class WindowAttention3D(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # (T, H, W)
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads))
coords_t = torch.arange(self.window_size[0])
coords_h = torch.arange(self.window_size[1])
coords_w = torch.arange(self.window_size[2])
coords = torch.stack(torch.meshgrid(coords_t, coords_h, coords_w, indexing='ij')) # 3, T, H, W
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 2] += self.window_size[2] - 1
relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self, x, k=None, v=None, mask=None):
B_, N, C = x.shape
if k is None or v is None:
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
else:
q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # [B_, num_heads, N, head_dim]
k = k.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
v = v.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def window_partition(x, window_size):
B, T, H, W, C = x.shape
window_size = list(window_size)
if T < window_size[0]:
window_size[0] = T
if H < window_size[1]:
window_size[1] = H
if W < window_size[2]:
window_size[2] = W
x = x.view(B, T // window_size[0] if window_size[0] > 0 else 1, window_size[0],
H // window_size[1] if window_size[1] > 0 else 1, window_size[1],
W // window_size[2] if window_size[2] > 0 else 1, window_size[2], C)
windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
return windows
def window_reverse(windows, window_size, B, T, H, W):
x = windows.view(B, T // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1)
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, T, H, W, -1)
return x
def get_window_size(x_size, window_size, shift_size=None):
use_window_size = list(window_size)
if shift_size is not None:
use_shift_size = list(shift_size)
for i in range(len(x_size)):
if x_size[i] <= window_size[i]:
use_window_size[i] = x_size[i]
if shift_size is not None:
use_shift_size[i] = 0
if shift_size is None:
return tuple(use_window_size)
else:
return tuple(use_window_size), tuple(use_shift_size)
class SelfAttention(nn.Module):
def __init__(self, dim, window_size=(2, 8, 8), num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., use_shift=False, shift_size=None, mlp_ratio=2.0, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
self.use_shift = use_shift
self.shift_size = shift_size if shift_size is not None else tuple([w // 2 for w in window_size]) if use_shift else tuple([0] * len(window_size))
self.attn1 = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)
self.attn2 = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.norm3 = norm_layer(dim)
self.norm4 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp1 = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim)
)
self.mlp2 = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim)
)
def create_mask(self, x_shape, device):
B, T, H, W, C = x_shape
img_mask = torch.zeros((1, T, H, W, 1), device=device)
cnt = 0
t_slices = (slice(0, -self.window_size[0]), slice(-self.window_size[0], -self.shift_size[0]), slice(-self.shift_size[0], None))
h_slices = (slice(0, -self.window_size[1]), slice(-self.window_size[1], -self.shift_size[1]), slice(-self.shift_size[1], None))
w_slices = (slice(0, -self.window_size[2]), slice(-self.window_size[2], -self.shift_size[2]), slice(-self.shift_size[2], None))
for t in t_slices:
for h in h_slices:
for w in w_slices:
img_mask[:, t, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.squeeze(-1)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x):
B, T, H, W, C = x.shape
window_size, shift_size = get_window_size((T, H, W), self.window_size, self.shift_size)
shortcut = x
x = self.norm1(x)
pad_t = (window_size[0] - T % window_size[0]) % window_size[0]
pad_h = (window_size[1] - H % window_size[1]) % window_size[1]
pad_w = (window_size[2] - W % window_size[2]) % window_size[2]
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
shortcut = F.pad(shortcut, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
_, Tp, Hp, Wp, _ = x.shape
x_windows = window_partition(x, window_size)
attn_windows = self.attn1(x_windows, mask=None)
attn_windows = attn_windows.view(-1, *(window_size + (C,)))
x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp)
x = shortcut + x
x = x + self.mlp1(self.norm2(x))
shortcut = x
x = self.norm3(x)
if self.use_shift and any(i > 0 for i in shift_size):
shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
attn_mask = self.create_mask((B, Tp, Hp, Wp, C), x.device)
x_windows = window_partition(shifted_x, window_size)
attn_windows = self.attn2(x_windows, mask=attn_mask)
attn_windows = attn_windows.view(-1, *(window_size + (C,)))
shifted_x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp)
x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
if pad_t > 0:
x = x[:, :T, :, :, :]
shortcut = shortcut[:, :T, :, :, :]
if pad_h > 0:
x = x[:, :, :H, :, :]
shortcut = shortcut[:, :, :H, :, :]
if pad_w > 0:
x = x[:, :, :, :W, :]
shortcut = shortcut[:, :, :, :W, :]
x = shortcut + x
x = x + self.mlp2(self.norm4(x))
return x
class CrossAttention(nn.Module):
def __init__(self, dim, window_size=(2, 8, 8), num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., mlp_ratio=2.0, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
self.norm1_q = norm_layer(dim)
self.norm1_k = norm_layer(dim)
self.norm1_v = norm_layer(dim)
self.attn = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim)
)
def forward(self, q, k, v):
B, T, H, W, C = q.shape
window_size = get_window_size((T, H, W), self.window_size)
shortcut = v
q = self.norm1_q(q)
k = self.norm1_k(k)
v = self.norm1_v(v)
pad_t = (window_size[0] - T % window_size[0]) % window_size[0]
pad_h = (window_size[1] - H % window_size[1]) % window_size[1]
pad_w = (window_size[2] - W % window_size[2]) % window_size[2]
q = F.pad(q, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
k = F.pad(k, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
v = F.pad(v, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
_, Tp, Hp, Wp, _ = q.shape
q_windows = window_partition(q, window_size)
k_windows = window_partition(k, window_size)
v_windows = window_partition(v, window_size)
attn_windows = self.attn(q_windows, k_windows, v_windows)
attn_windows = attn_windows.view(-1, *(window_size + (C,)))
shifted_x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp)
x = shifted_x
if pad_t > 0:
x = x[:, :T, :, :, :]
if pad_h > 0:
x = x[:, :, :H, :, :]
if pad_w > 0:
x = x[:, :, :, :W, :]
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
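# --- Optional shape sketch (not part of the original file): the attention blocks operate on
# --- [B, T, H, W, C] tensors, so random inputs are enough to check the plumbing.
if __name__ == "__main__":
    q = torch.randn(2, 2, 16, 16, 128)  # [B, T, H, W, C]
    k = torch.randn(2, 2, 16, 16, 128)
    v = torch.randn(2, 2, 16, 16, 128)
    sa = SelfAttention(128, window_size=(2, 8, 8), num_heads=4, use_shift=True, mlp_ratio=1.5)
    ca = CrossAttention(128, window_size=(2, 8, 8), num_heads=4)
    with torch.no_grad():
        out = ca(sa(q), sa(k), sa(v))
    print(out.shape)  # torch.Size([2, 2, 16, 16, 128])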

0
model/TDCNet/__init__.py Normal file

272
model/TDCNet/backbone3d.py Normal file

@@ -0,0 +1,272 @@
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
model_urls = {
"0.25x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_0.25x_RGB_16_best.pth",
"1.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.0x_RGB_16_best.pth",
"1.5x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.5x_RGB_16_best.pth",
"2.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_2.0x_RGB_16_best.pth",
}
def load_weight(model, arch):
url = model_urls[arch]
# check
if url is None:
print('No pretrained weight for 3D CNN: {}'.format(arch.upper()))
return model
# checkpoint state dict
checkpoint = load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
checkpoint_state_dict = checkpoint.pop('state_dict')
# model state dict
model_state_dict = model.state_dict()
# reformat checkpoint_state_dict:
new_state_dict = {}
for k in checkpoint_state_dict.keys():
v = checkpoint_state_dict[k]
new_state_dict[k[7:]] = v
# pdb.set_trace()
# check
for k in list(new_state_dict.keys()):
if k in model_state_dict:
shape_model = tuple(model_state_dict[k].shape)
shape_checkpoint = tuple(new_state_dict[k].shape)
if shape_model != shape_checkpoint:
new_state_dict.pop(k)
else:
new_state_dict.pop(k)
model.load_state_dict(new_state_dict)
return model
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv3d(inp, oup, kernel_size=(5, 3, 3), stride=stride, padding=(2, 1, 1), bias=False),
nn.BatchNorm3d(oup),
nn.ReLU(inplace=True)
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
oup_inc = oup // 2
if self.stride == 1:
self.banch2 = nn.Sequential(
# pw
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
nn.BatchNorm3d(oup_inc),
nn.ReLU(inplace=True),
# dw
nn.Conv3d(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc, bias=False),
nn.BatchNorm3d(oup_inc),
# pw-linear
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
nn.BatchNorm3d(oup_inc),
nn.ReLU(inplace=True)
)
else:
self.banch1 = nn.Sequential(
# dw
nn.Conv3d(inp, inp, (5, 3, 3), stride, (2, 1, 1), groups=inp, bias=False),
nn.BatchNorm3d(inp),
# pw-linear
nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False),
nn.BatchNorm3d(oup_inc),
nn.ReLU(inplace=True)
)
self.banch2 = nn.Sequential(
# pw
nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False),
nn.BatchNorm3d(oup_inc),
nn.ReLU(inplace=True),
# dw
nn.Conv3d(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc, bias=False),
nn.BatchNorm3d(oup_inc),
# pw-linear
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
nn.BatchNorm3d(oup_inc),
nn.ReLU(inplace=True)
)
@staticmethod
def _concat(x, out):
# concatenate along channel axis
return torch.cat((x, out), 1)
def forward(self, x):
if self.stride == 1:
x1 = x[:, :(x.shape[1] // 2), :, :, :]
x2 = x[:, (x.shape[1] // 2):, :, :, :]
out = self._concat(x1, self.banch2(x2))
elif self.stride == 2:
out = self._concat(self.banch1(x), self.banch2(x))
return channel_shuffle(out, 2)
def channel_shuffle(x, groups):
'''Channel shuffle: [N,C,D,H,W] -> [N,g,C/g,D,H,W] -> [N,C/g,g,D,H,W] -> [N,C,D,H,W]'''
batchsize, num_channels, depth, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups,
channels_per_group, depth, height, width)
# permute
x = x.permute(0, 2, 1, 3, 4, 5).contiguous()
# flatten
x = x.view(batchsize, num_channels, depth, height, width)
return x
class ShuffleNetV2(nn.Module):
def __init__(self, width_mult='1.0x', num_classes=600):
super(ShuffleNetV2, self).__init__()
self.stage_repeats = [4, 8, 4]
# index 0 is invalid and should never be called.
# only used for indexing convenience.
if width_mult == '0.25x':
self.stage_out_channels = [-1, 24, 32, 64, 128]
elif width_mult == '0.5x':
self.stage_out_channels = [-1, 24, 48, 96, 192]
elif width_mult == '1.0x':
self.stage_out_channels = [-1, 24, 128, 256, 512]
elif width_mult == '1.5x':
self.stage_out_channels = [-1, 24, 176, 352, 704]
elif width_mult == '2.0x':
self.stage_out_channels = [-1, 24, 224, 488, 976]
# building first layer
input_channel = self.stage_out_channels[1]
self.conv1 = conv_bn(3, input_channel, stride=(1, 2, 2))
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
self.features = []
self.features1 = []
self.features2 = []
self.features3 = []
# building inverted residual blocks
for idxstage in range(len(self.stage_repeats)):
numrepeat = self.stage_repeats[idxstage]
output_channel = self.stage_out_channels[idxstage + 2]
for i in range(numrepeat):
stride = 2 if i == 0 else 1
self.features.append(InvertedResidual(input_channel, output_channel, stride))
input_channel = output_channel
self.features = nn.Sequential(*self.features)
# for idxstage in range(len(self.stage_repeats)):
# numrepeat = self.stage_repeats[idxstage]
# output_channel = self.stage_out_channels[idxstage+2]
# for i in range(numrepeat):
# if idxstage==0:
# stride = 2 if i == 0 else 1
# self.features1.append(InvertedResidual(input_channel, output_channel, stride))
# input_channel = output_channel
# elif idxstage==1:
# stride = 2 if i == 0 else 1
# self.features2.append(InvertedResidual(input_channel, output_channel, stride))
# input_channel = output_channel
# elif idxstage==2:
# stride = 2 if i == 0 else 1
# self.features3.append(InvertedResidual(input_channel, output_channel, stride))
# input_channel = output_channel
# # make it nn.Sequential
# self.features1 = nn.Sequential(*self.features1)
# self.features2 = nn.Sequential(*self.features2)
# self.features3 = nn.Sequential(*self.features3)
# # building last several layers
# self.conv_last = conv_1x1x1_bn(input_channel, self.stage_out_channels[-1])
# self.avgpool = nn.AvgPool3d((2, 1, 1), stride=1)
def forward(self, x):
outputs = {}
# pdb.set_trace() #(1,3,16,512,512) #(1,3,5,512,512)
x = self.conv1(x) # (1,24,16,256,256) #(1,24,5,256,256)
x = self.maxpool(x) # (1,24,8,128,128) #(1,24,3,128,128)
# outputs['stage1'] = x
# x=self.features(x)
x = self.features[:4](x) # (1,116,4,64,64) #(1,116,2,64,64)
outputs['stage2'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
x = self.features[4:12](x) # (1,232,2,32,32) #(1,232,1,32,32)
outputs['stage3'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
x = self.features[12:16](x) # (1,464,1,16,16) #(1,464,1,16,16)
outputs['stage4'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
# out = self.conv_last(out)
# if x.size(2) > 1:
# x = torch.mean(x, dim=2, keepdim=True)
# return x.squeeze(2)
return outputs
def build_shufflenetv2_3d(model_size='0.25x', pretrained=False):
model = ShuffleNetV2(model_size)
feats = model.stage_out_channels[-1]
# if pretrained:
# model = load_weight(model, model_size)
return model, feats
def build_3d_cnn(cfg, pretrained=False):
# NOTE: build_resnet_3d / build_resnext_3d are not defined in this file, so only the
# 'shufflenetv2' branch below is usable as-is.
if 'resnet' in cfg['backbone_3d']:
model, feat_dims = build_resnet_3d(
model_name=cfg['backbone_3d'],
pretrained=pretrained
)
elif 'resnext' in cfg['backbone_3d']:
model, feat_dims = build_resnext_3d(
model_name=cfg['backbone_3d'],
pretrained=pretrained
)
elif 'shufflenetv2' in cfg['backbone_3d']:
model, feat_dims = build_shufflenetv2_3d(
model_size=cfg['model_size'],
pretrained=pretrained
)
else:
print('Unknown Backbone ...')
exit()
return model, feat_dims
class Backbone3D(nn.Module):
def __init__(self, cfg, pretrained=False):
super().__init__()
self.cfg = cfg
# 3D CNN
self.backbone, self.feat_dim = build_3d_cnn(cfg, pretrained)
def forward(self, x):
"""
Input:
x: (Tensor) -> [B, C, T, H, W]
Output:
y: (Dict) -> {
'stage2': (Tensor) -> [B, C1, T1, H1, W1],
'stage3': (Tensor) -> [B, C2, T2, H2, W2],
'stage4': (Tensor) -> [B, C3, T3, H3, W3]
}
"""
feat = self.backbone(x)
return feat
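# --- Optional shape sketch (not part of the original file), using a minimal config dict that
# --- contains only the keys this file actually reads.
if __name__ == "__main__":
    cfg = {'backbone_3d': 'shufflenetv2', 'model_size': '1.0x'}
    net = Backbone3D(cfg, pretrained=False).eval()
    x = torch.randn(1, 3, 5, 256, 256)  # [B, C, T, H, W]
    with torch.no_grad():
        feats = net(x)
    # expected: stage2 (1, 128, 2, 32, 32), stage3 (1, 256, 1, 16, 16), stage4 (1, 512, 1, 8, 8)
    print({k: tuple(v.shape) for k, v in feats.items()})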

281
model/TDCNet/backbonetd.py Normal file

@@ -0,0 +1,281 @@
import os
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.hub import load_state_dict_from_url
from model.TDCNet.TDCR import RepConv3D
model_urls = {
"0.25x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_0.25x_RGB_16_best.pth",
"1.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.0x_RGB_16_best.pth",
"1.5x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.5x_RGB_16_best.pth",
"2.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_2.0x_RGB_16_best.pth",
}
def load_weight(model, arch):
url = model_urls[arch]
# check
if url is None:
print('No pretrained weight for 3D CNN: {}'.format(arch.upper()))
return model
# checkpoint state dict
checkpoint = load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
checkpoint_state_dict = checkpoint.pop('state_dict')
# model state dict
model_state_dict = model.state_dict()
# reformat checkpoint_state_dict:
new_state_dict = {}
for k in checkpoint_state_dict.keys():
v = checkpoint_state_dict[k]
new_state_dict[k[7:]] = v
# pdb.set_trace()
# check
for k in list(new_state_dict.keys()):
if k in model_state_dict:
shape_model = tuple(model_state_dict[k].shape)
shape_checkpoint = tuple(new_state_dict[k].shape)
if shape_model != shape_checkpoint:
new_state_dict.pop(k)
else:
new_state_dict.pop(k)
model.load_state_dict(new_state_dict)
return model
def conv_bn(inp, oup, stride):
# return nn.Sequential(
# nn.Conv3d(inp, oup, kernel_size=3, stride=stride, padding=(1,1,1), bias=False),
# nn.BatchNorm3d(oup),
# nn.ReLU(inplace=True)
# )
return RepConv3D(inp, oup, (5, 3, 3), stride, (2, 1, 1))
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
oup_inc = oup // 2
if self.stride == 1:
self.banch2 = nn.Sequential(
# pw
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
nn.BatchNorm3d(oup_inc),
nn.ReLU(inplace=True),
# dw
# nn.Conv3d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False),
# nn.BatchNorm3d(oup_inc),
RepConv3D(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc),
# pw-linear
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
nn.BatchNorm3d(oup_inc),
nn.ReLU(inplace=True)
)
else:
self.banch1 = nn.Sequential(
# dw
# nn.Conv3d(inp, inp, 3, stride, 1, groups=inp, bias=False),
# nn.BatchNorm3d(inp),
RepConv3D(inp, inp, (5, 3, 3), stride, (2, 1, 1), groups=inp, ),
# pw-linear
nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False),
nn.BatchNorm3d(oup_inc),
nn.ReLU(inplace=True)
)
self.banch2 = nn.Sequential(
# pw
nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False),
nn.BatchNorm3d(oup_inc),
nn.ReLU(inplace=True),
# dw
# nn.Conv3d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False),
# nn.BatchNorm3d(oup_inc),
RepConv3D(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc, ),
# pw-linear
nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False),
nn.BatchNorm3d(oup_inc),
nn.ReLU(inplace=True)
)
@staticmethod
def _concat(x, out):
# concatenate along channel axis
return torch.cat((x, out), 1)
def forward(self, x):
if self.stride == 1:
x1 = x[:, :(x.shape[1] // 2), :, :, :]
x2 = x[:, (x.shape[1] // 2):, :, :, :]
out = self._concat(x1, self.banch2(x2))
elif self.stride == 2:
out = self._concat(self.banch1(x), self.banch2(x))
# return out
return channel_shuffle(out, 2)
#
#
def channel_shuffle(x, groups):
'''Channel shuffle: [N,C,D,H,W] -> [N,g,C/g,D,H,W] -> [N,C/g,g,D,H,W] -> [N,C,D,H,W]'''
batchsize, num_channels, depth, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups,
channels_per_group, depth, height, width)
# permute
x = x.permute(0, 2, 1, 3, 4, 5).contiguous()
# flatten
x = x.view(batchsize, num_channels, depth, height, width)
return x
class ShuffleNetV2(nn.Module):
def __init__(self, width_mult='1.0x', num_classes=600):
super(ShuffleNetV2, self).__init__()
self.stage_repeats = [4, 8, 4]
# index 0 is invalid and should never be called.
# only used for indexing convenience.
if width_mult == '0.25x':
self.stage_out_channels = [-1, 24, 32, 64, 128]
elif width_mult == '0.5x':
self.stage_out_channels = [-1, 24, 48, 96, 192]
elif width_mult == '1.0x':
# self.stage_out_channels = [-1, 24, 116, 232, 464]
self.stage_out_channels = [-1, 24, 128, 256, 512]
elif width_mult == '1.5x':
self.stage_out_channels = [-1, 24, 176, 352, 704]
elif width_mult == '2.0x':
self.stage_out_channels = [-1, 24, 224, 488, 976]
# building first layer
input_channel = self.stage_out_channels[1]
self.conv1 = conv_bn(3, input_channel, stride=(1, 2, 2))
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
self.features = []
self.features1 = []
self.features2 = []
self.features3 = []
# building inverted residual blocks
for idxstage in range(len(self.stage_repeats)):
numrepeat = self.stage_repeats[idxstage]
output_channel = self.stage_out_channels[idxstage + 2]
for i in range(numrepeat):
stride = 2 if i == 0 else 1
self.features.append(InvertedResidual(input_channel, output_channel, stride))
input_channel = output_channel
self.features = nn.Sequential(*self.features)
# for idxstage in range(len(self.stage_repeats)):
# numrepeat = self.stage_repeats[idxstage]
# output_channel = self.stage_out_channels[idxstage+2]
# for i in range(numrepeat):
# if idxstage==0:
# stride = 2 if i == 0 else 1
# self.features1.append(InvertedResidual(input_channel, output_channel, stride))
# input_channel = output_channel
# elif idxstage==1:
# stride = 2 if i == 0 else 1
# self.features2.append(InvertedResidual(input_channel, output_channel, stride))
# input_channel = output_channel
# elif idxstage==2:
# stride = 2 if i == 0 else 1
# self.features3.append(InvertedResidual(input_channel, output_channel, stride))
# input_channel = output_channel
# # make it nn.Sequential
# self.features1 = nn.Sequential(*self.features1)
# self.features2 = nn.Sequential(*self.features2)
# self.features3 = nn.Sequential(*self.features3)
# # building last several layers
# self.conv_last = conv_1x1x1_bn(input_channel, self.stage_out_channels[-1])
# self.avgpool = nn.AvgPool3d((2, 1, 1), stride=1)
def forward(self, x):
outputs = {}
# pdb.set_trace() #(1,3,16,512,512) #(1,3,5,512,512)
x = self.conv1(x) # (1,24,16,256,256) #(1,24,5,256,256)
x = self.maxpool(x) # (1,24,8,128,128) #(1,24,3,128,128)
# outputs['stage1'] = x
# x = self.features(x)
x = self.features[:4](x) # (1,116,4,64,64) #(1,116,2,64,64)
outputs['stage2'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
x = self.features[4:12](x) # (1,232,2,32,32) #(1,232,1,32,32)
outputs['stage3'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
x = self.features[12:16](x) # (1,464,1,16,16) #(1,464,1,16,16)
outputs['stage4'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2)
# out = self.conv_last(out)
# if x.size(2) > 1:
# x = torch.mean(x, dim=2, keepdim=True)
# return x.squeeze(2)
return outputs
def build_shufflenetv2_3d(model_size='1.0x', pretrained=False):
model = ShuffleNetV2(model_size)
feats = model.stage_out_channels[-1]
# if pretrained:
# model = load_weight(model, model_size)
return model, feats
def build_3d_cnn(cfg, pretrained=False):
if 'resnet' in cfg['backbone_3d']:
model, feat_dims = build_resnet_3d(
model_name=cfg['backbone_3d'],
pretrained=pretrained
)
elif 'resnext' in cfg['backbone_3d']:
model, feat_dims = build_resnext_3d(
model_name=cfg['backbone_3d'],
pretrained=pretrained
)
elif 'shufflenetv2' in cfg['backbone_3d']:
model, feat_dims = build_shufflenetv2_3d(
model_size=cfg['model_size'],
pretrained=pretrained
)
else:
print('Unknown Backbone ...')
exit()
return model, feat_dims
class BackboneTD(nn.Module):
def __init__(self, cfg, pretrained=False):
super().__init__()
self.cfg = cfg
# 3D CNN
self.backbone, self.feat_dim = build_3d_cnn(cfg, pretrained)
def forward(self, x):
"""
Input:
x: (Tensor) -> [B, C, T, H, W]
Output:
y: (List) -> [
(Tensor) -> [B, C1, H1, W1],
(Tensor) -> [B, C2, H2, W2],
(Tensor) -> [B, C3, H3, W3]
]
"""
feat = self.backbone(x)
return feat
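# ----------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original repository): a quick shape check
# of the 3D ShuffleNetV2 backbone wrapped by BackboneTD. The cfg keys and the
# input size below are assumptions for demonstration only, and it presumes the
# blocks defined above (conv_bn, RepConv3D, InvertedResidual) build correctly.
# ----------------------------------------------------------------------------- #
if __name__ == "__main__":
    cfg = {'backbone_3d': 'shufflenetv2', 'model_size': '1.0x'}
    backbone = BackboneTD(cfg)
    dummy = torch.randn(1, 3, 5, 128, 128)  # [B, C, T, H, W]
    with torch.no_grad():
        feats = backbone(dummy)
    for name, feat in feats.items():
        print(name, tuple(feat.shape))  # stage2 / stage3 / stage4 feature maps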

234
model/TDCNet/darknet.py Normal file
View File

@@ -0,0 +1,234 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import os
import torch
from matplotlib import pyplot as plt
from torch import nn
class SiLU(nn.Module):
@staticmethod
def forward(x):
return x * torch.sigmoid(x)
def get_activation(name="silu", inplace=True):
if name == "silu":
module = SiLU()
elif name == "relu":
module = nn.ReLU(inplace=inplace)
elif name == "lrelu":
module = nn.LeakyReLU(0.1, inplace=inplace)
elif name == "sigmoid":
module = nn.Sigmoid()
else:
raise AttributeError("Unsupported act type: {}".format(name))
return module
class Focus(nn.Module):
def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
super().__init__()
self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
def forward(self, x):
patch_top_left = x[..., ::2, ::2]
patch_bot_left = x[..., 1::2, ::2]
patch_top_right = x[..., ::2, 1::2]
patch_bot_right = x[..., 1::2, 1::2]
x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,)
return self.conv(x)
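# Illustrative note (added for clarity, not in the original code): Focus is a
# space-to-depth stem. For an input of shape (B, 3, 640, 640), the four strided
# slices are concatenated into (B, 12, 320, 320), which the following BaseConv
# maps to (B, out_channels, 320, 320); resolution is halved without discarding pixels.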
class BaseConv(nn.Module):
def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
super().__init__()
pad = (ksize - 1) // 2
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
self.act = get_activation(act, inplace=True)
def forward(self, x):
return self.act(self.bn(self.conv(x)))
def fuseforward(self, x):
return self.act(self.conv(x))
class DWConv(nn.Module):
def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
super().__init__()
self.dconv = BaseConv(in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act,)
self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)
def forward(self, x):
x = self.dconv(x)
return self.pconv(x)
class SPPBottleneck(nn.Module):
def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
super().__init__()
hidden_channels = in_channels // 2
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes])
conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
def forward(self, x):
x = self.conv1(x)
x = torch.cat([x] + [m(x) for m in self.m], dim=1)
x = self.conv2(x)
return x
#--------------------------------------------------#
# Builds the small residual block (standard bottleneck)
#--------------------------------------------------#
class Bottleneck(nn.Module):
# Standard bottleneck
def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
super().__init__()
hidden_channels = int(out_channels * expansion)
Conv = DWConv if depthwise else BaseConv
#--------------------------------------------------#
# Use a 1x1 convolution to reduce the channel count (typically by 50%)
#--------------------------------------------------#
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
#--------------------------------------------------#
# Use a 3x3 convolution to expand the channels back and extract features
#--------------------------------------------------#
self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
self.use_add = shortcut and in_channels == out_channels
def forward(self, x):
y = self.conv2(self.conv1(x))
if self.use_add:
y = y + x
return y
class CSPLayer(nn.Module):
def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
# ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
hidden_channels = int(out_channels * expansion)
#--------------------------------------------------#
# Initial convolution of the main branch
#--------------------------------------------------#
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
#--------------------------------------------------#
# Initial convolution of the large residual (shortcut) branch
#--------------------------------------------------#
self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
#-----------------------------------------------#
# Convolution applied to the concatenated result
#-----------------------------------------------#
self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
#--------------------------------------------------#
# Stack n of the Bottleneck residual blocks defined above
#--------------------------------------------------#
module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)]
self.m = nn.Sequential(*module_list)
def forward(self, x):
#-------------------------------#
# x_1 is the main branch
#-------------------------------#
x_1 = self.conv1(x)
#-------------------------------#
# x_2 is the large residual branch
#-------------------------------#
x_2 = self.conv2(x)
#-----------------------------------------------#
# The main branch keeps extracting features through the stacked residual blocks
#-----------------------------------------------#
x_1 = self.m(x_1)
#-----------------------------------------------#
# Concatenate the main branch with the large residual branch
#-----------------------------------------------#
x = torch.cat((x_1, x_2), dim=1)
#-----------------------------------------------#
# Convolve the concatenated result
#-----------------------------------------------#
return self.conv3(x)
class CSPDarknet(nn.Module):
def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False, act="silu",):
super().__init__()
assert out_features, "please provide output features of Darknet"
self.out_features = out_features
Conv = DWConv if depthwise else BaseConv
#-----------------------------------------------#
# The input image is 640, 640, 3
# The initial base channel count is 64
#-----------------------------------------------#
base_channels = int(wid_mul * 64) # 64
base_depth = max(round(dep_mul * 3), 1) # 3
#-----------------------------------------------#
# Feature extraction with the Focus structure
# 640, 640, 3 -> 320, 320, 12 -> 320, 320, 64
#-----------------------------------------------#
self.stem = Focus(3, base_channels, ksize=3, act=act)
#-----------------------------------------------#
# After the convolution: 320, 320, 64 -> 160, 160, 128
# After the CSPLayer: 160, 160, 128 -> 160, 160, 128
#-----------------------------------------------#
self.dark2 = nn.Sequential(
Conv(base_channels, base_channels * 2, 3, 2, act=act),
CSPLayer(base_channels * 2, base_channels * 2, n=base_depth, depthwise=depthwise, act=act),
)
#-----------------------------------------------#
# After the convolution: 160, 160, 128 -> 80, 80, 256
# After the CSPLayer: 80, 80, 256 -> 80, 80, 256
#-----------------------------------------------#
self.dark3 = nn.Sequential(
Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
CSPLayer(base_channels * 4, base_channels * 4, n=base_depth * 3, depthwise=depthwise, act=act),
)
#-----------------------------------------------#
# After the convolution: 80, 80, 256 -> 40, 40, 512
# After the CSPLayer: 40, 40, 512 -> 40, 40, 512
#-----------------------------------------------#
self.dark4 = nn.Sequential(
Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
CSPLayer(base_channels * 8, base_channels * 8, n=base_depth * 3, depthwise=depthwise, act=act),
)
#-----------------------------------------------#
# After the convolution: 40, 40, 512 -> 20, 20, 1024
# After the SPP: 20, 20, 1024 -> 20, 20, 1024
# After the CSPLayer: 20, 20, 1024 -> 20, 20, 1024
#-----------------------------------------------#
self.dark5 = nn.Sequential(
Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
CSPLayer(base_channels * 16, base_channels * 16, n=base_depth, shortcut=False, depthwise=depthwise, act=act),
)
def forward(self, x):
outputs = {}
x = self.stem(x)
outputs["stem"] = x
x = self.dark2(x)
outputs["dark2"] = x
#-----------------------------------------------#
# The dark3 output (80, 80, 256) is an effective feature map
#-----------------------------------------------#
x = self.dark3(x)
outputs["dark3"] = x
#-----------------------------------------------#
# The dark4 output (40, 40, 512) is an effective feature map
#-----------------------------------------------#
x = self.dark4(x)
outputs["dark4"] = x
#-----------------------------------------------#
# The dark5 output (20, 20, 1024) is an effective feature map
#-----------------------------------------------#
x = self.dark5(x)
outputs["dark5"] = x
return {k: v for k, v in outputs.items() if k in self.out_features}
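# ----------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original repository): a quick shape check
# of CSPDarknet. The depth/width multipliers and the input size are arbitrary
# values chosen for demonstration only.
# ----------------------------------------------------------------------------- #
if __name__ == "__main__":
    backbone = CSPDarknet(dep_mul=0.33, wid_mul=0.50)
    dummy = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        feats = backbone(dummy)
    for name, feat in feats.items():
        # Expected strides: dark3 -> 1/8, dark4 -> 1/16, dark5 -> 1/32 of the input
        print(name, tuple(feat.shape))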

507
model/nets/yolo_training.py Normal file
View File

@@ -0,0 +1,507 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import math
from copy import deepcopy
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops.focal_loss import sigmoid_focal_loss
class IOUloss(nn.Module):
def __init__(self, reduction="none", loss_type="iou"):
super(IOUloss, self).__init__()
self.reduction = reduction
self.loss_type = loss_type
def forward(self, pred, target):
assert pred.shape[0] == target.shape[0]
pred = pred.view(-1, 4)
target = target.view(-1, 4)
tl = torch.max(
(pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
)
br = torch.min(
(pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
)
area_p = torch.prod(pred[:, 2:], 1)
area_g = torch.prod(target[:, 2:], 1)
en = (tl < br).type(tl.type()).prod(dim=1)
area_i = torch.prod(br - tl, 1) * en
area_u = area_p + area_g - area_i
iou = (area_i) / (area_u + 1e-16)
if self.loss_type == "iou":
loss = 1 - iou ** 2
elif self.loss_type == "giou":
c_tl = torch.min(
(pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
)
c_br = torch.max(
(pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
)
area_c = torch.prod(c_br - c_tl, 1)
giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
loss = 1 - giou.clamp(min=-1.0, max=1.0)
elif self.loss_type == 'ciou':
b1_cxy = pred[:,:2]
b2_cxy = target[:,:2]
# Distance between the two box centers
center_distance = torch.sum(torch.pow((b1_cxy - b2_cxy), 2), axis=-1)
# Top-left and bottom-right corners of the smallest box enclosing both boxes
enclose_mins = torch.min((pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2))
enclose_maxes = torch.max((pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2))
enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(br))
# Squared diagonal length of the enclosing box
enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1)
ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6)
v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(pred[:, 2]/torch.clamp(pred[:, 3],min = 1e-6)) - torch.atan(target[:, 2]/torch.clamp(target[:, 3],min = 1e-6))), 2)
alpha = v / torch.clamp((1.0 - iou + v),min=1e-6)
ciou = ciou - alpha * v
loss = 1 - ciou.clamp(min=-1.0, max=1.0)
if self.reduction == "mean":
loss = loss.mean()
elif self.reduction == "sum":
loss = loss.sum()
return loss
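# Worked example (added for clarity, not in the original code), loss_type="iou":
#   pred   = [[50., 50., 20., 20.]]   # (cx, cy, w, h)
#   target = [[55., 55., 20., 20.]]
# intersection = 15 * 15 = 225, union = 400 + 400 - 225 = 575, iou ≈ 0.391,
# so the per-box loss is 1 - iou**2 ≈ 0.847.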
class YOLOLoss(nn.Module):
def __init__(self, num_classes, fp16, strides=[8, 16, 32]):
super().__init__()
self.num_classes = num_classes
self.strides = strides
self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
self.iou_loss = IOUloss(reduction="none")
self.grids = [torch.zeros(1)] * len(strides)
self.fp16 = fp16
def forward(self, inputs, labels=None):
outputs = []
x_shifts = []
y_shifts = []
expanded_strides = []
#-----------------------------------------------#
# inputs [[batch_size, num_classes + 5, 20, 20]
# [batch_size, num_classes + 5, 40, 40]
# [batch_size, num_classes + 5, 80, 80]]
# outputs [[batch_size, 400, num_classes + 5]
# [batch_size, 1600, num_classes + 5]
# [batch_size, 6400, num_classes + 5]]
# x_shifts [[batch_size, 400]
# [batch_size, 1600]
# [batch_size, 6400]]
#-----------------------------------------------#
for k, (stride, output) in enumerate(zip(self.strides, inputs)):
output, grid = self.get_output_and_grid(output, k, stride)
x_shifts.append(grid[:, :, 0])
y_shifts.append(grid[:, :, 1])
expanded_strides.append(torch.ones_like(grid[:, :, 0]) * stride)
outputs.append(output)
return self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1))
def get_output_and_grid(self, output, k, stride):
grid = self.grids[k]
hsize, wsize = output.shape[-2:]
if grid.shape[2:4] != output.shape[2:4]:
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing='ij')
grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type())
self.grids[k] = grid
grid = grid.view(1, -1, 2)
output = output.flatten(start_dim=2).permute(0, 2, 1)
output[..., :2] = (output[..., :2] + grid.type_as(output)) * stride
output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
return output, grid
def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs):
#-----------------------------------------------#
# [batch, n_anchors_all, 4]
#-----------------------------------------------#
bbox_preds = outputs[:, :, :4]
#-----------------------------------------------#
# [batch, n_anchors_all, 1]
#-----------------------------------------------#
obj_preds = outputs[:, :, 4:5]
#-----------------------------------------------#
# [batch, n_anchors_all, n_cls]
#-----------------------------------------------#
cls_preds = outputs[:, :, 5:]
total_num_anchors = outputs.shape[1]
#-----------------------------------------------#
# x_shifts [1, n_anchors_all]
# y_shifts [1, n_anchors_all]
# expanded_strides [1, n_anchors_all]
#-----------------------------------------------#
x_shifts = torch.cat(x_shifts, 1).type_as(outputs)
y_shifts = torch.cat(y_shifts, 1).type_as(outputs)
expanded_strides = torch.cat(expanded_strides, 1).type_as(outputs)
cls_targets = []
reg_targets = []
obj_targets = []
fg_masks = []
num_fg = 0.0
for batch_idx in range(outputs.shape[0]):
num_gt = len(labels[batch_idx])
if num_gt == 0:
cls_target = outputs.new_zeros((0, self.num_classes))
reg_target = outputs.new_zeros((0, 4))
obj_target = outputs.new_zeros((total_num_anchors, 1))
fg_mask = outputs.new_zeros(total_num_anchors).bool()
else:
#-----------------------------------------------#
# gt_bboxes_per_image [num_gt, num_classes]
# gt_classes [num_gt]
# bboxes_preds_per_image [n_anchors_all, 4]
# cls_preds_per_image [n_anchors_all, num_classes]
# obj_preds_per_image [n_anchors_all, 1]
#-----------------------------------------------#
gt_bboxes_per_image = labels[batch_idx][..., :4].type_as(outputs)
gt_classes = labels[batch_idx][..., 4].type_as(outputs)
bboxes_preds_per_image = bbox_preds[batch_idx]
cls_preds_per_image = cls_preds[batch_idx]
obj_preds_per_image = obj_preds[batch_idx]
gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments(
num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image,
expanded_strides, x_shifts, y_shifts,
)
torch.cuda.empty_cache()
num_fg += num_fg_img
cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)
obj_target = fg_mask.unsqueeze(-1)
reg_target = gt_bboxes_per_image[matched_gt_inds]
cls_targets.append(cls_target)
reg_targets.append(reg_target)
obj_targets.append(obj_target.type(cls_target.type()))
fg_masks.append(fg_mask)
cls_targets = torch.cat(cls_targets, 0)
reg_targets = torch.cat(reg_targets, 0)
obj_targets = torch.cat(obj_targets, 0)
fg_masks = torch.cat(fg_masks, 0)
num_fg = max(num_fg, 1)
loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()
loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()
loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
# loss_obj = (sigmoid_focal_loss(obj_preds.view(-1, 1), obj_targets)).sum()
# loss_cls = (sigmoid_focal_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
reg_weight = 5.0
loss = reg_weight * loss_iou + loss_obj + loss_cls
return loss / num_fg
@torch.no_grad()
def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts):
#-------------------------------------------------------#
# fg_mask [n_anchors_all]
# is_in_boxes_and_center [num_gt, len(fg_mask)]
#-------------------------------------------------------#
fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)
#-------------------------------------------------------#
# fg_mask [n_anchors_all]
# bboxes_preds_per_image [fg_mask, 4]
# cls_preds_ [fg_mask, num_classes]
# obj_preds_ [fg_mask, 1]
#-------------------------------------------------------#
bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
cls_preds_ = cls_preds_per_image[fg_mask]
obj_preds_ = obj_preds_per_image[fg_mask]
num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
#-------------------------------------------------------#
# pair_wise_ious [num_gt, fg_mask]
#-------------------------------------------------------#
pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
#-------------------------------------------------------#
# cls_preds_ [num_gt, fg_mask, num_classes]
# gt_cls_per_image [num_gt, fg_mask, num_classes]
#-------------------------------------------------------#
if self.fp16:
with torch.cuda.amp.autocast(enabled=False):
cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
else:
cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
del cls_preds_
cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()
num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg
def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True):
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
raise IndexError
if xyxy:
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
else:
tl = torch.max(
(bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
)
br = torch.min(
(bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
)
area_a = torch.prod(bboxes_a[:, 2:], 1)
area_b = torch.prod(bboxes_b[:, 2:], 1)
en = (tl < br).type(tl.type()).prod(dim=2)
area_i = torch.prod(br - tl, 2) * en
return area_i / (area_a[:, None] + area_b - area_i)
def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5):
#-------------------------------------------------------#
# expanded_strides_per_image [n_anchors_all]
# x_centers_per_image [num_gt, n_anchors_all]
# x_centers_per_image [num_gt, n_anchors_all]
#-------------------------------------------------------#
expanded_strides_per_image = expanded_strides[0]
x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
#-------------------------------------------------------#
# gt_bboxes_per_image_x [num_gt, n_anchors_all]
#-------------------------------------------------------#
gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
#-------------------------------------------------------#
# bbox_deltas [num_gt, n_anchors_all, 4]
#-------------------------------------------------------#
b_l = x_centers_per_image - gt_bboxes_per_image_l
b_r = gt_bboxes_per_image_r - x_centers_per_image
b_t = y_centers_per_image - gt_bboxes_per_image_t
b_b = gt_bboxes_per_image_b - y_centers_per_image
bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
#-------------------------------------------------------#
# is_in_boxes [num_gt, n_anchors_all]
# is_in_boxes_all [n_anchors_all]
#-------------------------------------------------------#
is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
#-------------------------------------------------------#
# center_deltas [num_gt, n_anchors_all, 4]
#-------------------------------------------------------#
c_l = x_centers_per_image - gt_bboxes_per_image_l
c_r = gt_bboxes_per_image_r - x_centers_per_image
c_t = y_centers_per_image - gt_bboxes_per_image_t
c_b = gt_bboxes_per_image_b - y_centers_per_image
center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
#-------------------------------------------------------#
# is_in_centers [num_gt, n_anchors_all]
# is_in_centers_all [n_anchors_all]
#-------------------------------------------------------#
is_in_centers = center_deltas.min(dim=-1).values > 0.0
is_in_centers_all = is_in_centers.sum(dim=0) > 0
#-------------------------------------------------------#
# is_in_boxes_anchor [n_anchors_all]
# is_in_boxes_and_center [num_gt, is_in_boxes_anchor]
#-------------------------------------------------------#
is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
return is_in_boxes_anchor, is_in_boxes_and_center
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
#-------------------------------------------------------#
# cost [num_gt, fg_mask]
# pair_wise_ious [num_gt, fg_mask]
# gt_classes [num_gt]
# fg_mask [n_anchors_all]
# matching_matrix [num_gt, fg_mask]
#-------------------------------------------------------#
matching_matrix = torch.zeros_like(cost)
#------------------------------------------------------------#
# Take the n_candidate_k anchors with the largest IoU for each ground truth,
# then sum their IoUs to decide how many anchors should predict that box
# topk_ious [num_gt, n_candidate_k]
# dynamic_ks [num_gt]
# matching_matrix [num_gt, fg_mask]
#------------------------------------------------------------#
n_candidate_k = min(10, pair_wise_ious.size(1))
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
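# Example (added for clarity): if the top-10 IoUs of one ground-truth box sum to
# 3.7, then dynamic_k = int(3.7) = 3, i.e. that box is matched to its 3
# lowest-cost candidate anchors in the loop below.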
for gt_idx in range(num_gt):
#------------------------------------------------------------#
# For each ground-truth box, pick its dynamic_k anchors with the smallest cost
#------------------------------------------------------------#
_, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
matching_matrix[gt_idx][pos_idx] = 1.0
del topk_ious, dynamic_ks, pos_idx
#------------------------------------------------------------#
# anchor_matching_gt [fg_mask]
#------------------------------------------------------------#
anchor_matching_gt = matching_matrix.sum(0)
if (anchor_matching_gt > 1).sum() > 0:
#------------------------------------------------------------#
# When one anchor is assigned to several ground-truth boxes,
# keep only the ground-truth box with the smallest cost
#------------------------------------------------------------#
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
#------------------------------------------------------------#
# fg_mask_inboxes [fg_mask]
# num_fg is the number of positive anchor points
#------------------------------------------------------------#
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
num_fg = fg_mask_inboxes.sum().item()
#------------------------------------------------------------#
# Update fg_mask
#------------------------------------------------------------#
fg_mask[fg_mask.clone()] = fg_mask_inboxes
#------------------------------------------------------------#
# Get the object class matched to each positive anchor
#------------------------------------------------------------#
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
gt_matched_classes = gt_classes[matched_gt_inds]
pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
def is_parallel(model):
# Returns True if model is of type DP or DDP
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
def de_parallel(model):
# De-parallelize a model: returns single-GPU model if model is of type DP or DDP
return model.module if is_parallel(model) else model
def copy_attr(a, b, include=(), exclude=()):
# Copy attributes from b to a, options to only include [...] and to exclude [...]
for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
continue
else:
setattr(a, k, v)
class ModelEMA:
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
Keeps a moving average of everything in the model state_dict (parameters and buffers)
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
"""
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
# Create EMA
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
# if next(model.parameters()).device.type != 'cpu':
# self.ema.half() # FP16 EMA
self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
for p in self.ema.parameters():
p.requires_grad_(False)
def update(self, model):
# Update EMA parameters
with torch.no_grad():
self.updates += 1
d = self.decay(self.updates)
msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1 - d) * msd[k].detach()
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes
copy_attr(self.ema, model, include, exclude)
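# Typical usage sketch (added for clarity, not in the original code):
#   ema = ModelEMA(model)
#   for each training step: loss.backward(); optimizer.step(); ema.update(model)
# The smoothed weights are then read from ema.ema for validation/checkpointing;
# train.py below constructs the EMA once and hands it to fit_one_epoch.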
def weights_init(net, init_type='normal', init_gain = 0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and classname.find('Conv') != -1:
if init_type == 'normal':
torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
elif init_type == 'xavier':
torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
elif init_type == 'kaiming':
torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
elif classname.find('BatchNorm2d') != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
print('initialize network with %s type' % init_type)
net.apply(init_func)
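# Example (added for clarity): weights_init(model, init_type='kaiming') applies
# Kaiming-normal initialisation to the weight of every Conv layer and resets
# BatchNorm2d layers to weight ~ N(1.0, 0.02) with zero bias.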
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
if iters <= warmup_total_iters:
# lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
elif iters >= total_iters - no_aug_iter:
lr = min_lr
else:
lr = min_lr + 0.5 * (lr - min_lr) * (
1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
)
return lr
def step_lr(lr, decay_rate, step_size, iters):
if step_size < 1:
raise ValueError("step_size must above 1.")
n = iters // step_size
out_lr = lr * decay_rate ** n
return out_lr
if lr_decay_type == "cos":
warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
else:
decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
step_size = total_iters / step_num
func = partial(step_lr, lr, decay_rate, step_size)
return func
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
lr = lr_scheduler_func(epoch)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
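# ----------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original repository): how the scheduler
# returned by get_lr_scheduler is typically driven together with set_optimizer_lr.
# The hyper-parameters below are arbitrary example values.
# ----------------------------------------------------------------------------- #
if __name__ == "__main__":
    dummy_param = nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([dummy_param], lr=1e-2, momentum=0.9)
    lr_func = get_lr_scheduler("cos", lr=1e-2, min_lr=1e-4, total_iters=100)
    for epoch in (0, 1, 5, 50, 99):
        set_optimizer_lr(optimizer, lr_func, epoch)
        print("epoch %3d: lr = %.6f" % (epoch, optimizer.param_groups[0]['lr']))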

1
model_data/classes.txt Normal file
View File

@@ -0,0 +1 @@
target

83
summary.py Normal file
View File

@@ -0,0 +1,83 @@
# --------------------------------------------#
# This script inspects the network structure (parameters, FLOPs and FPS)
# --------------------------------------------#
import os
from torch import nn
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
from thop import clever_format, profile
from model.TDCNet.TDCNetwork import TDCNetwork
from model.TDCNet.TDCR import RepConv3D
if __name__ == "__main__":
input_shape = [640, 640]
num_classes = 1
# Use device to choose whether the network runs on the GPU or the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_frame = 5
m = TDCNetwork(num_classes, num_frame=5)
for mm in m.modules():
if isinstance(mm, RepConv3D):
mm.switch_to_deploy()
dummy_input = torch.randn(1, 3, num_frame * 2, input_shape[0], input_shape[1]).to(device) # torch.randn(1, 3, 5,input_shape[0], input_shape[1]).to(device)
flops, params = profile(m.to(device), (dummy_input,), verbose=False)
# #--------------------------------------------------------#
# # flops * 2 because profile does not count a convolution as two operations
# # Some papers count a convolution as two operations (multiply and add): multiply by 2
# # Some papers only count multiplications and ignore additions: do not multiply by 2
# # This code multiplies by 2, following YOLOX
# #--------------------------------------------------------#
flops = flops * 2
flops, params = clever_format([flops, params], "%.3f")
print('Total GFLOPS: %s' % (flops))
print('Total params: %s' % (params))
# Measure FPS
from data.dataloader_for_IRSTD_UAV import seqDataset, dataset_collate
from torch.utils.data import DataLoader
import time
max_iter = 2000
log_interval = 50
num_warmup = 20
pure_inf_time = 0
fps = 0
val_annotation_path = "/data1/gsk/Dataset/IRSTD-UAV/val.txt"
val_dataset = seqDataset(val_annotation_path, input_shape[0], num_frame, 'val')
gen_val = DataLoader(val_dataset, shuffle = False, batch_size = 1, num_workers = 10, pin_memory=True,
drop_last=True, collate_fn=dataset_collate)
m = nn.DataParallel(m).cuda()
# benchmark with 2000 images and take the average
for i, data in enumerate(gen_val):
torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
m(data[0].to('cuda'))
torch.cuda.synchronize()
elapsed = time.perf_counter() - start_time
if i >= num_warmup:
pure_inf_time += elapsed
if (i + 1) % log_interval == 0:
fps = (i + 1 - num_warmup) / pure_inf_time
print(
f'Done image [{i + 1:<3}/ {max_iter}], '
f'fps: {fps:.1f} img / s, '
f'times per image: {1000 / fps:.1f} ms / img',
flush=True)
if (i + 1) == max_iter:
fps = (i + 1 - num_warmup) / pure_inf_time
print(
f'Overall fps: {fps:.1f} img / s, '
f'times per image: {1000 / fps:.1f} ms / img',
flush=True)
break
print("FPS:" ,fps)

158
test.py Normal file
View File

@@ -0,0 +1,158 @@
import json
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from tqdm import tqdm
from model.TDCNet.TDCNetwork import TDCNetwork
from model.TDCNet.TDCR import RepConv3D
from utils.utils import get_classes, show_config
from utils.utils_bbox import decode_outputs, non_max_suppression
# Configuration
cocoGt_path = '/Dataset/IRSTD-UAV/val_coco.json'
dataset_img_path = '/Dataset/IRSTD-UAV'
temp_save_path = 'results/TDCNet_epoch_100_batch_4_optim_adam_lr_0.001_T_5'
model_path = ''
num_frame = 5
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def get_history_imgs(image_path):
"""获取5帧背景对齐图 + 5帧原图"""
dir_path = os.path.dirname(image_path)
index = int(os.path.basename(image_path).split('.')[0])
match_dir = dir_path.replace('images', f'matches') + f"/{index:08d}"
match_imgs = [os.path.join(match_dir, f"match_{i}.png") for i in range(1, num_frame + 1)]
min_index = index - (index % 50)
original_imgs = [os.path.join(dir_path, f"{max(index - i, min_index):08d}.png") for i in reversed(range(num_frame))]
return match_imgs + original_imgs
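# Hypothetical example (added for clarity; the paths are illustrative only):
# for image_path ".../images/seq/00000123.png" and num_frame = 5 this returns
#   [".../matches/seq/00000123/match_1.png", ..., ".../matches/seq/00000123/match_5.png",
#    ".../images/seq/00000119.png", ..., ".../images/seq/00000123.png"],
# with the earliest original frame clamped to the start of the current 50-frame clip.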
def letterbox_image_batch(images, target_size=(512, 512), color=(128, 128, 128)):
"""
Letterbox preprocessing: resize keeping the aspect ratio and pad with grey
Args:
images: list of np.ndarray, shape: (H, W, 3)
target_size: desired (width, height)
Returns:
np.ndarray of shape (N, target_H, target_W, 3)
"""
w, h = target_size
output = np.full((len(images), h, w, 3), color, dtype=np.uint8)
for i, img in enumerate(images):
ih, iw = img.shape[:2]
scale = min(w / iw, h / ih)
nw = int(iw * scale)
nh = int(ih * scale)
resized = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR)
top = (h - nh) // 2
left = (w - nw) // 2
output[i, top:top + nh, left:left + nw, :] = resized
return output
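# Example (added for clarity, not in the original code): a 480x640 (HxW) frame
# letterboxed to (640, 640) is scaled by min(640/640, 640/480) = 1.0, kept at
# 640x480 and pasted with 80-pixel grey bands at the top and bottom.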
class MAP_vid:
def __init__(self):
self.model_path = model_path
self.classes_path = 'model_data/classes.txt'
self.input_shape = [640, 640]
self.confidence = 0.001
self.nms_iou = 0.5
self.letterbox_image = True
self.cuda = True
self.class_names, self.num_classes = get_classes(self.classes_path)
self.net = TDCNetwork(self.num_classes, num_frame=num_frame)
state_dict = torch.load(self.model_path, map_location='cuda' if self.cuda else 'cpu')
self.net.load_state_dict(state_dict)
for m in self.net.modules():
if isinstance(m, RepConv3D):
m.switch_to_deploy()
if self.cuda:
self.net = nn.DataParallel(self.net).cuda()
self.net = self.net.eval()
show_config(**self.__dict__)
def detect_image(self, image_id, images, results):
# Resize
image_shape = np.array(images[0].shape[:2])
images = letterbox_image_batch(images, target_size=tuple(self.input_shape))
# Preprocess
images = np.array(images).astype(np.float32) / 255.0
images = images.transpose(3, 0, 1, 2)[None]
# To tensor
with torch.no_grad():
images_tensor = torch.from_numpy(images).cuda()
# Inference
with torch.no_grad():
outputs = self.net(images_tensor)
outputs = decode_outputs(outputs, self.input_shape)
# NMS
outputs = non_max_suppression(outputs, self.num_classes, self.input_shape,
image_shape, self.letterbox_image,
conf_thres=self.confidence, nms_thres=self.nms_iou)
# Postprocess
if outputs[0] is not None:
top_label = np.array(outputs[0][:, 6], dtype='int32')
top_conf = outputs[0][:, 4] * outputs[0][:, 5]
top_boxes = outputs[0][:, :4]
for i, c in enumerate(top_label):
top, left, bottom, right = top_boxes[i]
results.append({
"image_id": int(image_id),
"category_id": clsid2catid[c],
"bbox": [float(left), float(top), float(right - left), float(bottom - top)],
"score": float(top_conf[i])
})
return results
if __name__ == "__main__":
os.makedirs(temp_save_path, exist_ok=True)
cocoGt = COCO(cocoGt_path)
ids = list(cocoGt.imgToAnns.keys())
global clsid2catid
clsid2catid = cocoGt.getCatIds()
yolo = MAP_vid()
results = []
for image_id in tqdm(ids):
file_name = cocoGt.loadImgs(image_id)[0]['file_name']
image_path = os.path.join(dataset_img_path, file_name)
image_paths = get_history_imgs(image_path)
images = [cv2.imread(p)[:, :, ::-1] for p in image_paths] # BGR -> RGB
results = yolo.detect_image(image_id, images, results)
with open(os.path.join(temp_save_path, 'eval_results.json'), 'w') as f:
json.dump(results, f)
cocoDt = cocoGt.loadRes(os.path.join(temp_save_path, 'eval_results.json'))
cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
precisions = cocoEval.eval['precision']
precision_50 = precisions[0, :, 0, 0, -1] # dims are (T, R, K, A, M); the third dim is the category
recalls = cocoEval.eval['recall']
recall_50 = recalls[0, 0, 0, -1] # dims are (T, K, A, M); the second dim is the category
print("Precision: %.4f, Recall: %.4f, F1: %.4f" % (np.mean(precision_50[:int(recall_50 * 100)]), recall_50, 2 * recall_50 * np.mean(precision_50[:int(recall_50 * 100)]) / (recall_50 + np.mean(precision_50[:int(recall_50 * 100)]))))
print("Get map done.")

480
train.py Normal file
View File

@@ -0,0 +1,480 @@
# -------------------------------------#
# Training script for the dataset
# -------------------------------------#
import datetime
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from data.dataloader_for_IRSTD_UAV import seqDataset, dataset_collate
from model.TDCNet.TDCNetwork import TDCNetwork
from model.nets.yolo_training import (ModelEMA, YOLOLoss, get_lr_scheduler, set_optimizer_lr, weights_init)
from utils.callbacks import EvalCallback, LossHistory
from utils.utils import get_classes, show_config
from utils.utils_fit import fit_one_epoch
if __name__ == "__main__":
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# ---------------------------------#
# num_frame: number of input frames
# ---------------------------------#
num_frame = 5
# ---------------------------------#
# Cuda: whether to use CUDA
# Set it to False if no GPU is available
# ---------------------------------#
Cuda = True
# ---------------------------------------------------------------------#
# distributed: whether to use single-machine multi-GPU distributed training
# The terminal commands below only work on Ubuntu; CUDA_VISIBLE_DEVICES selects the GPUs under Ubuntu.
# On Windows, DP mode is used by default to run on all GPUs; DDP is not supported.
# DP mode:
# set distributed = False
# run: CUDA_VISIBLE_DEVICES=0,1 python train.py
# DDP mode:
# set distributed = True
# run: CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
# ---------------------------------------------------------------------#
distributed = False
# ---------------------------------------------------------------------#
# sync_bn: whether to use synchronized BatchNorm (multi-GPU DDP only)
# ---------------------------------------------------------------------#
sync_bn = False
# ---------------------------------------------------------------------#
# fp16: whether to use mixed-precision training
# Roughly halves GPU memory usage; requires PyTorch 1.7.1 or later
# ---------------------------------------------------------------------#
fp16 = False
# ---------------------------------------------------------------------#
# classes_path: points to the txt under model_data that matches your own dataset
# Be sure to modify classes_path before training so that it matches your dataset
# ---------------------------------------------------------------------#
classes_path = 'model_data/classes.txt'
model_path = ''
input_shape = [640, 640]
# ----------------------------------------------------------------------------------------------------------------------------#
# Training consists of two stages, a frozen stage and an unfrozen stage. The frozen stage exists to accommodate machines with limited resources.
# Frozen training needs less GPU memory; with a very weak GPU you can set Freeze_Epoch equal to UnFreeze_Epoch and Freeze_Train = True to run frozen training only.
#
# Some suggested parameter settings are provided here; adjust them flexibly to your own needs:
# (1) Training from pretrained weights of the whole model:
# Adam
# Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 100, Freeze_Train = True, optimizer_type = 'adam', Init_lr = 1e-3, weight_decay = 0. (frozen)
# Init_Epoch = 0, UnFreeze_Epoch = 100, Freeze_Train = False, optimizer_type = 'adam', Init_lr = 1e-3, weight_decay = 0. (not frozen)
# SGD
# Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 300, Freeze_Train = True, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 5e-4. (frozen)
# Init_Epoch = 0, UnFreeze_Epoch = 300, Freeze_Train = False, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 5e-4. (not frozen)
# UnFreeze_Epoch can be adjusted between 100 and 300.
# (2) Training from scratch:
# Init_Epoch = 0, UnFreeze_Epoch >= 300, Unfreeze_batch_size >= 16, Freeze_Train = False (no frozen stage)
# UnFreeze_Epoch should not be smaller than 300; optimizer_type = 'sgd', Init_lr = 1e-2, mosaic = True.
# Setting batch_size:
# As large as the GPU can handle. Running out of memory has nothing to do with dataset size; if you see OOM or "CUDA out of memory", reduce batch_size.
# Because of the BatchNorm layers, the minimum batch_size is 2; it cannot be 1.
# Normally Freeze_batch_size should be 1-2x Unfreeze_batch_size; do not make the gap too large, since it affects the automatic learning-rate adjustment.
# ----------------------------------------------------------------------------------------------------------------------------#
# ------------------------------------------------------------------#
# Parameters for the frozen training stage
# The backbone is frozen, so the feature-extraction network does not change
# Less GPU memory is used; only the rest of the network is fine-tuned
# Init_Epoch: the epoch training starts from; it may be larger than Freeze_Epoch, e.g.
# Init_Epoch = 60, Freeze_Epoch = 50, UnFreeze_Epoch = 100
# skips the frozen stage, starts directly from epoch 60 and adjusts the learning rate accordingly.
# (used when resuming from a checkpoint)
# Freeze_Epoch: number of epochs trained with the backbone frozen
# (ignored when Freeze_Train = False)
# Freeze_batch_size: batch size during frozen training
# (ignored when Freeze_Train = False)
# ------------------------------------------------------------------#
Init_Epoch = 0
Freeze_Epoch = 100
Freeze_batch_size = 4
# ------------------------------------------------------------------#
# Parameters for the unfrozen training stage
# The backbone is no longer frozen, so the feature-extraction network changes
# More GPU memory is used and all network parameters are updated
# UnFreeze_Epoch: total number of training epochs
# SGD takes longer to converge, so use a larger UnFreeze_Epoch
# Adam can use a relatively small UnFreeze_Epoch
# Unfreeze_batch_size: batch size after unfreezing
# ------------------------------------------------------------------#
UnFreeze_Epoch = 100
Unfreeze_batch_size = 4
# ------------------------------------------------------------------#
# Freeze_Train: whether to run the frozen stage
# By default the backbone is frozen first and then unfrozen for the rest of training.
# ------------------------------------------------------------------#
Freeze_Train = False
# ------------------------------------------------------------------#
# Other training parameters: learning rate, optimizer and learning-rate decay
# ------------------------------------------------------------------#
# ------------------------------------------------------------------#
# Init_lr: maximum learning rate of the model
# Min_lr: minimum learning rate, by default 0.01x the maximum learning rate
# ------------------------------------------------------------------#
Init_lr = 1e-3
Min_lr = Init_lr * 0.01
# ------------------------------------------------------------------#
# optimizer_type: optimizer to use, either 'adam' or 'sgd'
# With Adam, Init_lr = 1e-3 is recommended
# With SGD, Init_lr = 1e-2 is recommended
# momentum: momentum parameter used inside the optimizer
# weight_decay: weight decay, helps prevent overfitting
# Adam handles weight_decay poorly; set it to 0 when using Adam
# ------------------------------------------------------------------#
optimizer_type = "adam"
momentum = 0.937
weight_decay = 1e-4
# ------------------------------------------------------------------#
# lr_decay_type: learning-rate decay schedule, either 'step' or 'cos'
# ------------------------------------------------------------------#
lr_decay_type = "cos"
# ------------------------------------------------------------------#
# save_period: save the weights every save_period epochs
# ------------------------------------------------------------------#
save_period = 10
# ------------------------------------------------------------------#
# save_dir: folder where weights and log files are saved
# ------------------------------------------------------------------#
save_dir = f'logs/TDCNet_epoch_{UnFreeze_Epoch}_batch_{Unfreeze_batch_size}_optim_{optimizer_type}_lr_{Init_lr}_T_{num_frame}'
# ------------------------------------------------------------------#
# eval_flag: whether to evaluate on the validation set during training
# Installing pycocotools gives a better evaluation experience.
# eval_period: evaluate every eval_period epochs; frequent evaluation is not recommended
# because evaluation takes a lot of time and slows training down considerably
# The mAP obtained here differs from the one from get_map.py for two reasons:
# (1) the mAP here is computed on the validation set;
# (2) the evaluation settings here are conservative in order to speed evaluation up.
# ------------------------------------------------------------------#
eval_flag = True
eval_period = 200
# ------------------------------------------------------------------#
# num_workers: number of worker processes used for data loading
# More workers speed up data loading but use more memory
# On machines with little memory, set it to 2 or 0
# ------------------------------------------------------------------#
num_workers = 8
# ----------------------------------------------------#
# Get the image paths and labels
# ----------------------------------------------------#
DATA_PATH = "/Dataset/IRSTD-UAV/"
train_annotation_path = "/Dataset/IRSTD-UAV/train.txt"
val_annotation_path = "/Dataset/IRSTD-UAV/val.txt"
# ------------------------------------------------------#
# Set the GPUs to use
# ------------------------------------------------------#
ngpus_per_node = torch.cuda.device_count()
if distributed:
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
device = torch.device("cuda", local_rank)
if local_rank == 0:
print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")
print("Gpu Device Count : ", ngpus_per_node)
else:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
local_rank = 0
rank = 0
# ------------------------------------------------------#
# Set the random seed
# ------------------------------------------------------#
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
class_names, num_classes = get_classes(classes_path)
model = TDCNetwork(num_classes=1, num_frame=num_frame)
weights_init(model)
if model_path != '':
if local_rank == 0:
print('Load weights {}.'.format(model_path))
# ------------------------------------------------------#
# Load weights by matching the keys of the pretrained checkpoint against the model keys
# ------------------------------------------------------#
model_dict = model.state_dict()
# pdb.set_trace()
pretrained_dict = torch.load(model_path, map_location=device)
load_key, no_load_key, temp_dict = [], [], {}
for k, v in pretrained_dict.items():
if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
temp_dict[k] = v
load_key.append(k)
else:
no_load_key.append(k)
model_dict.update(temp_dict)
model.load_state_dict(model_dict)
# ------------------------------------------------------#
# Show the keys that failed to match
# ------------------------------------------------------#
if local_rank == 0:
print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
print("\n\033[1;33;44m温馨提示head部分没有载入是正常现象Backbone部分没有载入是错误的。\033[0m")
# ----------------------#
# Build the loss function
# ----------------------#
yolo_loss = YOLOLoss(num_classes, fp16, strides=[8])
# ----------------------#
# Record the loss
# ----------------------#
if local_rank == 0:
time_str = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S')
log_dir = os.path.join(save_dir, "loss_" + str(time_str))
loss_history = LossHistory(log_dir, model, input_shape=input_shape)
# pdb.set_trace()
else:
loss_history = None
# ------------------------------------------------------------------#
# torch 1.2 does not support AMP; use torch 1.7.1 or later for correct fp16 training,
# which is why torch 1.2 reports "could not be resolved" here
# ------------------------------------------------------------------#
if fp16:
from torch.cuda.amp import GradScaler as GradScaler
scaler = GradScaler()
else:
scaler = None
model_train = model.train()
# ----------------------------#
# Synchronized BatchNorm across GPUs
# ----------------------------#
if sync_bn and ngpus_per_node > 1 and distributed:
model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
elif sync_bn:
print("Sync_bn is not support in one gpu or not distributed.")
if Cuda:
if distributed:
# ----------------------------#
# Run in parallel on multiple GPUs
# ----------------------------#
model_train = model_train.cuda(local_rank)
model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=True)
else:
model_train = torch.nn.DataParallel(model)
cudnn.benchmark = True
model_train = model_train.cuda()
# ----------------------------#
# Weight smoothing (EMA)
# ----------------------------#
# pdb.set_trace()
ema = ModelEMA(model_train)
# ---------------------------#
# Read the dataset txt files
# ---------------------------#
with open(train_annotation_path, encoding='utf-8') as f:
train_lines = f.readlines()
with open(val_annotation_path, encoding='utf-8') as f:
val_lines = f.readlines()
num_train = len(train_lines)
num_val = len(val_lines)
if local_rank == 0:
show_config(
classes_path=classes_path, model_path=model_path, input_shape=input_shape, \
Init_Epoch=Init_Epoch, Freeze_Epoch=Freeze_Epoch, UnFreeze_Epoch=UnFreeze_Epoch, Freeze_batch_size=Freeze_batch_size, Unfreeze_batch_size=Unfreeze_batch_size, Freeze_Train=Freeze_Train, \
Init_lr=Init_lr, Min_lr=Min_lr, optimizer_type=optimizer_type, momentum=momentum, lr_decay_type=lr_decay_type, \
save_period=save_period, save_dir=log_dir, num_workers=num_workers, num_train=num_train, num_val=num_val
)
# ---------------------------------------------------------#
# The total number of epochs is the number of passes over the whole dataset;
# the total number of steps is the total number of gradient-descent updates.
# Each epoch contains several steps, and each step performs one gradient descent.
# Only a minimum number of epochs is suggested here (there is no upper bound); the calculation only considers the unfrozen stage.
# ----------------------------------------------------------#
wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4
total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch
if total_step <= wanted_step:
if num_train // Unfreeze_batch_size == 0:
raise ValueError('The dataset is too small to train on; please enlarge it.')
wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1
print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m" % (optimizer_type, wanted_step))
print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%dUnfreeze_batch_size为%d,共训练%d个Epoch计算出总训练步长为%d\033[0m" % (num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))
print("\033[1;33;44m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d\033[0m" % (total_step, wanted_step, wanted_epoch))
# ------------------------------------------------------#
# Backbone features are generic; freezing the backbone speeds up training
# and prevents the weights from being damaged early in training.
# Init_Epoch is the starting epoch
# Freeze_Epoch is the number of frozen training epochs
# UnFreeze_Epoch is the total number of training epochs
# If you get OOM or out-of-memory errors, reduce Batch_size
# ------------------------------------------------------#
if True:
UnFreeze_flag = False
# ------------------------------------#
# Freeze part of the network for training
# ------------------------------------#
if Freeze_Train:
for param in model.backbone.parameters():
param.requires_grad = False
for param in model.backbone_3d.parameters():
param.requires_grad = False
# -------------------------------------------------------------------#
# If frozen training is not used, set batch_size directly to Unfreeze_batch_size
# -------------------------------------------------------------------#
batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size
# -------------------------------------------------------------------#
# Adapt the learning rate to the current batch_size
# -------------------------------------------------------------------#
nbs = 64
lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2
lr_limit_min = 1e-5 if optimizer_type == 'adam' else 5e-4
Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
# ---------------------------------------#
# Select the optimizer according to optimizer_type
# ---------------------------------------#
pg0, pg1, pg2 = [], [], []
for k, v in model.named_modules():
if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
pg2.append(v.bias)
if isinstance(v, nn.BatchNorm2d) or "bn" in k:
pg0.append(v.weight)
elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
pg1.append(v.weight)
optimizer = {
'adam': optim.Adam(pg0, Init_lr_fit, betas=(momentum, 0.999)),
'sgd': optim.SGD(pg0, Init_lr_fit, momentum=momentum, nesterov=True)
}[optimizer_type]
optimizer.add_param_group({"params": pg1, "weight_decay": weight_decay})
optimizer.add_param_group({"params": pg2})
# ---------------------------------------#
# Get the learning-rate decay function
# ---------------------------------------#
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
# ---------------------------------------#
# Determine the number of steps per epoch
# ---------------------------------------#
epoch_step = num_train // batch_size
epoch_step_val = num_val // batch_size
if epoch_step == 0 or epoch_step_val == 0:
raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")
if ema:
ema.updates = epoch_step * Init_Epoch
train_dataset = seqDataset(train_annotation_path, input_shape[0], num_frame, 'train') # 5
val_dataset = seqDataset(val_annotation_path, input_shape[0], num_frame, 'val') # 5
if distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, )
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, )
batch_size = batch_size // ngpus_per_node
shuffle = False
else:
train_sampler = None
val_sampler = None
shuffle = True
gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
drop_last=True, collate_fn=dataset_collate, sampler=train_sampler)
gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
drop_last=True, collate_fn=dataset_collate, sampler=val_sampler)
# ----------------------#
# Record the eval mAP curve
# ----------------------#
if local_rank == 0:
eval_callback = EvalCallback(model, input_shape, class_names, num_classes, val_lines, log_dir, Cuda, \
eval_flag=eval_flag, period=eval_period)
else:
eval_callback = None
# ---------------------------------------#
# Start training the model
# ---------------------------------------#
for epoch in range(Init_Epoch, UnFreeze_Epoch):
# ---------------------------------------#
# If part of the model was frozen,
# unfreeze it and reset the training parameters
# ---------------------------------------#
if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
batch_size = Unfreeze_batch_size
# -------------------------------------------------------------------#
# Adapt the learning rate to the current batch_size
# -------------------------------------------------------------------#
nbs = 64
lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2
lr_limit_min = 1e-5 if optimizer_type == 'adam' else 5e-4
Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
# ---------------------------------------#
# Get the learning-rate decay function
# ---------------------------------------#
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
for param in model.backbone.parameters():
param.requires_grad = True
for param in model.backbone_3d.parameters():
param.requires_grad = True
epoch_step = num_train // batch_size
epoch_step_val = num_val // batch_size
if epoch_step == 0 or epoch_step_val == 0:
raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")
if distributed:
batch_size = batch_size // ngpus_per_node
if ema:
ema.updates = epoch_step * epoch
gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
drop_last=True, collate_fn=dataset_collate, sampler=train_sampler)
gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
drop_last=True, collate_fn=dataset_collate, sampler=val_sampler)
UnFreeze_flag = True
gen.dataset.epoch_now = epoch
gen_val.dataset.epoch_now = epoch
if distributed:
train_sampler.set_epoch(epoch)
set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
# pdb.set_trace()
fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, log_dir, local_rank)
if distributed:
dist.barrier()
if local_rank == 0:
loss_history.writer.close()

288
utils/callbacks.py Normal file
View File

@@ -0,0 +1,288 @@
import os
import torch
import matplotlib
matplotlib.use('Agg')
import scipy.signal
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
from .utils import cvtColor, preprocess_input, resize_image
from .utils_bbox import decode_outputs, non_max_suppression
from .utils_map import get_coco_map, get_map
# from utils import cvtColor, preprocess_input, resize_image
# from utils_bbox import decode_outputs, non_max_suppression
# from utils_map import get_coco_map, get_map
class LossHistory():
def __init__(self, log_dir, model, input_shape):
self.log_dir = log_dir
self.losses = []
self.val_loss = []
os.makedirs(self.log_dir, exist_ok=True)
self.writer = SummaryWriter(self.log_dir)
try:
dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
self.writer.add_graph(model, dummy_input)
except:
pass
def append_loss(self, epoch, loss, val_loss):
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
self.losses.append(loss)
self.val_loss.append(val_loss)
with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
f.write(str(loss))
f.write("\n")
with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
f.write(str(val_loss))
f.write("\n")
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)
self.loss_plot()
def loss_plot(self):
iters = range(len(self.losses))
plt.figure()
plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
try:
if len(self.losses) < 25:
num = 5
else:
num = 15
plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
except:
pass
plt.grid(True)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc="upper right")
plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
plt.cla()
plt.close("all")
class EvalCallback():
def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \
map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
super(EvalCallback, self).__init__()
self.net = net
self.input_shape = input_shape
self.class_names = class_names
self.num_classes = num_classes
self.val_lines = val_lines
self.log_dir = log_dir
self.cuda = cuda
self.map_out_path = map_out_path
self.max_boxes = max_boxes
self.confidence = confidence
self.nms_iou = nms_iou
self.letterbox_image = letterbox_image
self.MINOVERLAP = MINOVERLAP
self.eval_flag = eval_flag
self.period = period
self.maps = [0]
self.epoches = [0]
if self.eval_flag:
with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
f.write(str(0))
f.write("\n")
# def get_history_imgs(self, line):
# dir_path = line.replace(line.split('/')[-1],'')
# file_type = line.split('.')[-1]
# index = int(line.split('/')[-1][:-4])
# return [os.path.join(dir_path, "%d.%s" % (max(id, 0),file_type)) for id in range(index - 4, index + 1)]
# def get_history_imgs(self, line):
# dir_path = line.replace(line.split('/')[-1],'')
# file_type = line.split('.')[-1]
# index = int(line.split("/")[-1][:8])
# return [os.path.join(dir_path, "%08d.%s" % (max(id, 0),file_type)) for id in range(index - 4, index + 1)]
def get_history_imgs(self, line):
dir_path = line.replace(line.split('/')[-1],'')
file_type = line.split('.')[-1]
index = int(line.split("/")[-1][:8])
return [os.path.join(dir_path, "%08d.%s" % (max(id, 1),file_type)) for id in range(index - 4, index + 1)]
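# Example (hypothetical path): '.../seq01/00000042.jpg' ->
# ['.../seq01/00000038.jpg', ..., '.../seq01/00000042.jpg'];
# frame indices are clamped to 1 for the first frames of a sequence.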
def get_map_txt(self, image_id, images, class_names, map_out_path):
f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
image_shape = np.array(np.shape(images[0])[0:2])
#---------------------------------------------------------#
# Convert the images to RGB here to avoid errors when predicting on grayscale images.
# The code only supports prediction on RGB images; all other image types are converted to RGB.
#---------------------------------------------------------#
images = [cvtColor(image) for image in images]
#---------------------------------------------------------#
# Pad the images with gray bars for a distortion-free resize.
# A plain resize (without padding) is also possible.
#---------------------------------------------------------#
image_data = [resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) for image in images]
#---------------------------------------------------------#
# Add the batch_size dimension
#---------------------------------------------------------#
image_data = [np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)) for image in image_data]
# (3, H, W) per frame -> (3, num_frame, H, W) after stacking along the temporal axis
image_data = np.stack(image_data, axis=1)
image_data = np.expand_dims(image_data, 0)
with torch.no_grad():
images = torch.from_numpy(image_data)
if self.cuda:
images = images.cuda()
#---------------------------------------------------------#
# Feed the images into the network for prediction
#---------------------------------------------------------#
outputs = self.net(images)
outputs = decode_outputs(outputs, self.input_shape)
#---------------------------------------------------------#
# Stack the predicted boxes, then apply non-maximum suppression
#---------------------------------------------------------#
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
if results[0] is None:
f.close()
return
top_label = np.array(results[0][:, 6], dtype = 'int32')
top_conf = results[0][:, 4] * results[0][:, 5]
top_boxes = results[0][:, :4]
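# keep at most self.max_boxes detections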
top_100 = np.argsort(top_label)[::-1][:self.max_boxes]
top_boxes = top_boxes[top_100]
top_conf = top_conf[top_100]
top_label = top_label[top_100]
for i, c in list(enumerate(top_label)):
predicted_class = self.class_names[int(c)]
box = top_boxes[i]
score = str(top_conf[i])
top, left, bottom, right = box
if predicted_class not in class_names:
continue
f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
f.close()
return
def on_epoch_end(self, epoch, model_eval):
if epoch % self.period == 0 and self.eval_flag:
self.net = model_eval
if not os.path.exists(self.map_out_path):
os.makedirs(self.map_out_path)
if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
os.makedirs(os.path.join(self.map_out_path, "detection-results"))
print("Get map.")
for annotation_line in tqdm(self.val_lines):
line = annotation_line.split()
'''
# Frame indices repeat across different videos, so "video id - frame id" is used as the image id
'''
image_id = "-".join(line[0].split("/")[6:8]).split('.')[0]
#------------------------------#
# Read the images and convert them to RGB
#------------------------------#
# cb update
images = self.get_history_imgs(line[0])
images = [Image.open(item) for item in images]
# image = Image.open(line[0])
#------------------------------#
# Parse the ground-truth boxes
#------------------------------#
gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
#------------------------------#
# Write the prediction txt
#------------------------------#
self.get_map_txt(image_id, images, self.class_names, self.map_out_path)
#------------------------------#
# Write the ground-truth txt
#------------------------------#
with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
for box in gt_boxes:
left, top, right, bottom, obj = box
obj_name = self.class_names[obj]
new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
print("Calculate Map.")
try:
temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
except:
temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
self.maps.append(temp_map)
self.epoches.append(epoch)
with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
f.write(str(temp_map))
f.write("\n")
plt.figure()
plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')
plt.grid(True)
plt.xlabel('Epoch')
plt.ylabel('Map %s'%str(self.MINOVERLAP))
plt.title('A Map Curve')
plt.legend(loc="upper right")
plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
plt.cla()
plt.close("all")
print("Get map done.")
shutil.rmtree(self.map_out_path)
# def get_history_imgs(line):
# dir_path = line.replace(line.split('/')[-1],'')
# file_type = line.split('.')[-1]
# index = int(line.split('/')[-1][:-4])
# image_id = "-".join(line.split("/")[6:8]).split('.')[0]
# print(image_id)
# return [os.path.join(dir_path, "%d.%s" % (max(id, 0),file_type)) for id in range(index - 4, index + 1)]
# if __name__ == "__main__":
# with open('coco_val.txt', encoding='utf-8') as f:
# val_lines = f.readlines()
# for annotation_line in val_lines:
# line = annotation_line.split()
# images = get_history_imgs(line[0])
# # for item in images:
# # print(item)
# # break

64
utils/utils.py Normal file
View File

@@ -0,0 +1,64 @@
import numpy as np
from PIL import Image
#---------------------------------------------------------#
# Convert the image to RGB to avoid errors when predicting on grayscale images.
# The code only supports prediction on RGB images; all other image types are converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
return image
else:
image = image.convert('RGB')
return image
#---------------------------------------------------#
# Resize the input image
#---------------------------------------------------#
def resize_image(image, size, letterbox_image):
iw, ih = image.size
w, h = size
if letterbox_image:
scale = min(w/iw, h/ih)
nw = int(iw*scale)
nh = int(ih*scale)
image = image.resize((nw,nh), Image.BICUBIC)
new_image = Image.new('RGB', size, (128,128,128))
new_image.paste(image, ((w-nw)//2, (h-nh)//2))
else:
new_image = image.resize((w, h), Image.BICUBIC)
return new_image
#---------------------------------------------------#
# Get the class names
#---------------------------------------------------#
def get_classes(classes_path):
with open(classes_path, encoding='utf-8') as f:
class_names = f.readlines()
class_names = [c.strip() for c in class_names]
return class_names, len(class_names)
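#---------------------------------------------------#
#   Normalize to [0, 1], then apply the ImageNet mean/std
#---------------------------------------------------#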
def preprocess_input(image):
image /= 255.0
image -= np.array([0.485, 0.456, 0.406])
image /= np.array([0.229, 0.224, 0.225])
return image
#---------------------------------------------------#
# Get the current learning rate
#---------------------------------------------------#
def get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group['lr']
def show_config(**kwargs):
print('Configurations:')
print('-' * 130)
print('|%25s | %100s|' % ('keys', 'values'))
print('-' * 130)
for key, value in kwargs.items():
print('|%25s | %100s|' % (str(key), str(value)))
print('-' * 130)

180
utils/utils_bbox.py Normal file
View File

@@ -0,0 +1,180 @@
import numpy as np
import torch
from torchvision.ops import nms, boxes
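#-----------------------------------------------------------------#
#   Map boxes predicted on the (possibly letterboxed) network input
#   back to coordinates in the original image
#-----------------------------------------------------------------#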
def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image):
#-----------------------------------------------------------------#
# Put the y axis first because it makes it easier to multiply the boxes by the image height and width
#-----------------------------------------------------------------#
box_yx = box_xy[..., ::-1]
box_hw = box_wh[..., ::-1]
input_shape = np.array(input_shape)
image_shape = np.array(image_shape)
if letterbox_image:
#-----------------------------------------------------------------#
# The offset computed here is the offset of the valid image area relative to the top-left corner of the image.
# new_shape is the scaled height and width.
#-----------------------------------------------------------------#
new_shape = np.round(image_shape * np.min(input_shape/image_shape))
offset = (input_shape - new_shape)/2./input_shape
scale = input_shape/new_shape
box_yx = (box_yx - offset) * scale
box_hw *= scale
box_mins = box_yx - (box_hw / 2.)
box_maxes = box_yx + (box_hw / 2.)
boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
boxes *= np.concatenate([image_shape, image_shape], axis=-1)
return boxes
def decode_outputs(outputs, input_shape):
grids = []
strides = []
hw = [x.shape[-2:] for x in outputs]
#---------------------------------------------------#
# Before this step, outputs holds the prediction of each feature level:
# batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400
# batch_size, 5 + num_classes, 40, 40
# batch_size, 5 + num_classes, 20, 20
# batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400
# After stacking: batch_size, 8400, 5 + num_classes
#---------------------------------------------------#
outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
#---------------------------------------------------#
# Get the probability of each class at every feature point
#---------------------------------------------------#
outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:])
for h, w in hw:
#---------------------------#
# Generate grid points from the height and width of the feature level
#---------------------------#
grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)], indexing='ij')
#---------------------------#
# 1, 6400, 2
# 1, 1600, 2
# 1, 400, 2
#---------------------------#
grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2)
shape = grid.shape[:2]
grids.append(grid)
strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h))
#---------------------------#
# Stack the grid points together
# 1, 6400, 2
# 1, 1600, 2
# 1, 400, 2
#
# 1, 8400, 2
#---------------------------#
grids = torch.cat(grids, dim=1).type(outputs.type())
strides = torch.cat(strides, dim=1).type(outputs.type())
#------------------------#
# Decode the predictions with the grid points
#------------------------#
outputs[..., :2] = (outputs[..., :2] + grids) * strides
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
#-----------------#
# Normalize
#-----------------#
outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1]
outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0]
return outputs
def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
#----------------------------------------------------------#
# Convert the predictions to the top-left / bottom-right corner format.
# prediction [batch_size, num_anchors, 85]
#----------------------------------------------------------#
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
output = [None for _ in range(len(prediction))]
#----------------------------------------------------------#
# Loop over the input images; usually this runs only once
#----------------------------------------------------------#
for i, image_pred in enumerate(prediction):
#----------------------------------------------------------#
# Take the max over the class predictions.
# class_conf  [num_anchors, 1]  class confidence
# class_pred  [num_anchors, 1]  class index
#----------------------------------------------------------#
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
#----------------------------------------------------------#
# First round of filtering using the confidence
#----------------------------------------------------------#
conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
if not image_pred.size(0):
continue
#-------------------------------------------------------------------------#
# detections [num_anchors, 7]
# The 7 values are x1, y1, x2, y2, obj_conf, class_conf, class_pred
#-------------------------------------------------------------------------#
detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
detections = detections[conf_mask]
nms_out_index = boxes.batched_nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
detections[:, 6],
nms_thres,
)
output[i] = detections[nms_out_index]
# #------------------------------------------#
# #   Get all the classes present in the predictions
# #------------------------------------------#
# unique_labels = detections[:, -1].cpu().unique()
# if prediction.is_cuda:
# unique_labels = unique_labels.cuda()
# detections = detections.cuda()
# for c in unique_labels:
# #------------------------------------------#
# #   Get all the predictions for this class after score filtering
# #------------------------------------------#
# detections_class = detections[detections[:, -1] == c]
# #------------------------------------------#
# #   The official torchvision NMS is faster
# #------------------------------------------#
# keep = nms(
# detections_class[:, :4],
# detections_class[:, 4] * detections_class[:, 5],
# nms_thres
# )
# max_detections = detections_class[keep]
# # # Sort by objectness confidence
# # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
# # detections_class = detections_class[conf_sort_index]
# # # Apply non-maximum suppression
# # max_detections = []
# # while detections_class.size(0):
# # # Take the highest-confidence box of this class, then discard subsequent boxes whose overlap with it exceeds nms_thres
# # max_detections.append(detections_class[0].unsqueeze(0))
# # if len(detections_class) == 1:
# # break
# # ious = bbox_iou(max_detections[-1], detections_class[1:])
# # detections_class = detections_class[1:][ious < nms_thres]
# # # Stack
# # max_detections = torch.cat(max_detections).data
# # Add max detections to outputs
# output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
if output[i] is not None:
output[i] = output[i].cpu().numpy()
box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
return output

145
utils/utils_fit.py Normal file
View File

@@ -0,0 +1,145 @@
import os
import torch
from tqdm import tqdm
from utils.utils import get_lr
import pdb
def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
loss = 0
val_loss = 0
epoch_step = epoch_step // 5 # use only a random subset of the training set each epoch to reduce overfitting
if local_rank == 0:
print('Start Train')
pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
#pdb.set_trace()
model_train.train()
for iteration, batch in enumerate(gen):
if iteration >= epoch_step:
break
images, targets = batch[0], batch[1]
with torch.no_grad():
if cuda:
images = images.cuda(local_rank)
targets = [ann.cuda(local_rank) for ann in targets]
#----------------------#
# Zero the gradients
#----------------------#
optimizer.zero_grad()
if not fp16:
#----------------------#
# Forward pass
#----------------------#
#pdb.set_trace()
outputs = model_train(images)
#----------------------#
# Compute the loss
#----------------------#
loss_value = yolo_loss(outputs, targets) #+ motion_loss
#----------------------#
# Backward pass
#----------------------#
# torch.autograd.set_detect_anomaly(True)
# with torch.autograd.detect_anomaly():
loss_value.backward()
optimizer.step()
else:
from torch.cuda.amp import autocast
with autocast():
outputs = model_train(images)
#----------------------#
# Compute the loss
#----------------------#
loss_value = yolo_loss(outputs, targets)
#----------------------#
# Backward pass
#----------------------#
scaler.scale(loss_value).backward()
scaler.step(optimizer)
scaler.update()
#pdb.set_trace()
if ema:
ema.update(model_train)
loss += loss_value.item()
if local_rank == 0:
pbar.set_postfix(**{'loss' : loss / (iteration + 1),
'lr' : get_lr(optimizer)})
pbar.update(1)
if local_rank == 0:
pbar.close()
print('Finish Train')
print('Start Validation')
pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
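# validate with the EMA weights when available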
if ema:
model_train_eval = ema.ema
else:
model_train_eval = model_train.eval()
for iteration, batch in enumerate(gen_val):
if iteration >= epoch_step_val:
break
images, targets = batch[0], batch[1]
with torch.no_grad():
if cuda:
images = images.cuda(local_rank)
targets = [ann.cuda(local_rank) for ann in targets]
#----------------------#
# Zero the gradients
#----------------------#
optimizer.zero_grad()
#----------------------#
# Forward pass
#----------------------#
outputs = model_train_eval(images)
#----------------------#
# Compute the loss
#----------------------#
loss_value = yolo_loss(outputs, targets)
val_loss += loss_value.item()
if local_rank == 0:
pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
pbar.update(1)
if local_rank == 0:
pbar.close()
print('Finish Validation')
loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
eval_callback.on_epoch_end(epoch + 1, model_train_eval)
print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
#-----------------------------------------------#
# Save the weights
#-----------------------------------------------#
if ema:
save_state_dict = ema.ema.state_dict()
else:
save_state_dict = model.state_dict()
if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val)))
if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
print('Save best model to best_epoch_weights.pth')
torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))
torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth"))

923
utils/utils_map.py Normal file
View File

@@ -0,0 +1,923 @@
import glob
import json
import math
import operator
import os
import shutil
import sys
try:
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
except:
pass
import cv2
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np
'''
0,0 ------> x (width)
|
| (Left,Top)
| *_________
| | |
| |
y |_________|
(height) *
(Right,Bottom)
'''
def log_average_miss_rate(precision, fp_cumsum, num_images):
"""
log-average miss rate:
Calculated by averaging miss rates at 9 evenly spaced FPPI points
between 1e-2 and 1e0, in log-space.
output:
lamr | log-average miss rate
mr | miss rate
fppi | false positives per image
references:
[1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the
State of the Art." Pattern Analysis and Machine Intelligence, IEEE
Transactions on 34.4 (2012): 743 - 761.
"""
if precision.size == 0:
lamr = 0
mr = 1
fppi = 0
return lamr, mr, fppi
fppi = fp_cumsum / float(num_images)
mr = (1 - precision)
fppi_tmp = np.insert(fppi, 0, -1.0)
mr_tmp = np.insert(mr, 0, 1.0)
ref = np.logspace(-2.0, 0.0, num = 9)
for i, ref_i in enumerate(ref):
j = np.where(fppi_tmp <= ref_i)[-1][-1]
ref[i] = mr_tmp[j]
lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))
return lamr, mr, fppi
"""
throw error and exit
"""
def error(msg):
print(msg)
sys.exit(0)
"""
check if the number is a float between 0.0 and 1.0
"""
def is_float_between_0_and_1(value):
try:
val = float(value)
if val > 0.0 and val < 1.0:
return True
else:
return False
except ValueError:
return False
"""
Calculate the AP given the recall and precision array
1st) We compute a version of the measured precision/recall curve with
precision monotonically decreasing
2nd) We compute the AP as the area under this curve by numerical integration.
"""
def voc_ap(rec, prec):
"""
--- Official matlab code VOC2012---
mrec=[0 ; rec ; 1];
mpre=[0 ; prec ; 0];
for i=numel(mpre)-1:-1:1
mpre(i)=max(mpre(i),mpre(i+1));
end
i=find(mrec(2:end)~=mrec(1:end-1))+1;
ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
"""
rec.insert(0, 0.0) # insert 0.0 at beginning of list
rec.append(1.0) # insert 1.0 at end of list
mrec = rec[:]
prec.insert(0, 0.0) # insert 0.0 at beginning of list
prec.append(0.0) # insert 0.0 at end of list
mpre = prec[:]
"""
This part makes the precision monotonically decreasing
(goes from the end to the beginning)
matlab: for i=numel(mpre)-1:-1:1
mpre(i)=max(mpre(i),mpre(i+1));
"""
for i in range(len(mpre)-2, -1, -1):
mpre[i] = max(mpre[i], mpre[i+1])
"""
This part creates a list of indexes where the recall changes
matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
"""
i_list = []
for i in range(1, len(mrec)):
if mrec[i] != mrec[i-1]:
i_list.append(i) # if it was matlab would be i + 1
"""
The Average Precision (AP) is the area under the curve
(numerical integration)
matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
"""
ap = 0.0
for i in i_list:
ap += ((mrec[i]-mrec[i-1])*mpre[i])
return ap, mrec, mpre
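# Worked example (illustrative values): rec=[0.5, 1.0], prec=[1.0, 0.5]
# -> mrec=[0, 0.5, 1.0, 1.0], monotone mpre=[1.0, 1.0, 0.5, 0.0]
# -> AP = (0.5-0)*1.0 + (1.0-0.5)*0.5 = 0.75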
"""
Convert the lines of a file to a list
"""
def file_lines_to_list(path):
# open txt file lines to a list
with open(path) as f:
content = f.readlines()
# remove whitespace characters like `\n` at the end of each line
content = [x.strip() for x in content]
return content
"""
Draws text in image
"""
def draw_text_in_image(img, text, pos, color, line_width):
font = cv2.FONT_HERSHEY_PLAIN
fontScale = 1
lineType = 1
bottomLeftCornerOfText = pos
cv2.putText(img, text,
bottomLeftCornerOfText,
font,
fontScale,
color,
lineType)
text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
return img, (line_width + text_width)
"""
Plot - adjust axes
"""
def adjust_axes(r, t, fig, axes):
# get text width for re-scaling
bb = t.get_window_extent(renderer=r)
text_width_inches = bb.width / fig.dpi
# get axis width in inches
current_fig_width = fig.get_figwidth()
new_fig_width = current_fig_width + text_width_inches
proportion = new_fig_width / current_fig_width
# get axis limit
x_lim = axes.get_xlim()
axes.set_xlim([x_lim[0], x_lim[1]*proportion])
"""
Draw plot using Matplotlib
"""
def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
# sort the dictionary by increasing value, into a list of tuples
sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
# unpacking the list of tuples into two lists
sorted_keys, sorted_values = zip(*sorted_dic_by_value)
#
if true_p_bar != "":
"""
Special case to draw in:
- green -> TP: True Positives (object detected and matches ground-truth)
- red -> FP: False Positives (object detected but does not match ground-truth)
- orange -> FN: False Negatives (object not detected but present in the ground-truth)
"""
fp_sorted = []
tp_sorted = []
for key in sorted_keys:
fp_sorted.append(dictionary[key] - true_p_bar[key])
tp_sorted.append(true_p_bar[key])
plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
# add legend
plt.legend(loc='lower right')
"""
Write number on side of bar
"""
fig = plt.gcf() # gcf - get current figure
axes = plt.gca()
r = fig.canvas.get_renderer()
for i, val in enumerate(sorted_values):
fp_val = fp_sorted[i]
tp_val = tp_sorted[i]
fp_str_val = " " + str(fp_val)
tp_str_val = fp_str_val + " " + str(tp_val)
# trick to paint multicolor with offset:
# first paint everything and then repaint the first number
t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
if i == (len(sorted_values)-1): # largest bar
adjust_axes(r, t, fig, axes)
else:
plt.barh(range(n_classes), sorted_values, color=plot_color)
"""
Write number on side of bar
"""
fig = plt.gcf() # gcf - get current figure
axes = plt.gca()
r = fig.canvas.get_renderer()
for i, val in enumerate(sorted_values):
str_val = " " + str(val) # add a space before
if val < 1.0:
str_val = " {0:.2f}".format(val)
t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
# re-set axes to show number inside the figure
if i == (len(sorted_values)-1): # largest bar
adjust_axes(r, t, fig, axes)
# set window title
fig.canvas.set_window_title(window_title)
# write classes in y axis
tick_font_size = 12
plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
"""
Re-scale height accordingly
"""
init_height = fig.get_figheight()
# compute the matrix height in points and inches
dpi = fig.dpi
height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
height_in = height_pt / dpi
# compute the required figure height
top_margin = 0.15 # in percentage of the figure height
bottom_margin = 0.05 # in percentage of the figure height
figure_height = height_in / (1 - top_margin - bottom_margin)
# set new height
if figure_height > init_height:
fig.set_figheight(figure_height)
# set plot title
plt.title(plot_title, fontsize=14)
# set axis titles
# plt.xlabel('classes')
plt.xlabel(x_label, fontsize='large')
# adjust size of window
fig.tight_layout()
# save the plot
fig.savefig(output_path)
# show image
if to_show:
plt.show()
# close the plot
plt.close()
def get_map(MINOVERLAP, draw_plot, score_threhold=0.5, path = './map_out'):
GT_PATH = os.path.join(path, 'ground-truth')
DR_PATH = os.path.join(path, 'detection-results')
IMG_PATH = os.path.join(path, 'images-optional')
TEMP_FILES_PATH = os.path.join(path, '.temp_files')
RESULTS_FILES_PATH = os.path.join(path, 'results')
show_animation = True
if os.path.exists(IMG_PATH):
for dirpath, dirnames, files in os.walk(IMG_PATH):
if not files:
show_animation = False
else:
show_animation = False
if not os.path.exists(TEMP_FILES_PATH):
os.makedirs(TEMP_FILES_PATH)
if os.path.exists(RESULTS_FILES_PATH):
shutil.rmtree(RESULTS_FILES_PATH)
os.makedirs(RESULTS_FILES_PATH)
if draw_plot:
try:
matplotlib.use('TkAgg')
except:
pass
os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP"))
os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1"))
os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall"))
os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision"))
if show_animation:
os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one"))
ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
if len(ground_truth_files_list) == 0:
error("Error: No ground-truth files found!")
ground_truth_files_list.sort()
gt_counter_per_class = {}
counter_images_per_class = {}
for txt_file in ground_truth_files_list:
file_id = txt_file.split(".txt", 1)[0]
file_id = os.path.basename(os.path.normpath(file_id))
temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
if not os.path.exists(temp_path):
error_msg = "Error. File not found: {}\n".format(temp_path)
error(error_msg)
lines_list = file_lines_to_list(txt_file)
bounding_boxes = []
is_difficult = False
already_seen_classes = []
for line in lines_list:
try:
if "difficult" in line:
class_name, left, top, right, bottom, _difficult = line.split()
is_difficult = True
else:
class_name, left, top, right, bottom = line.split()
except:
if "difficult" in line:
line_split = line.split()
_difficult = line_split[-1]
bottom = line_split[-2]
right = line_split[-3]
top = line_split[-4]
left = line_split[-5]
class_name = ""
for name in line_split[:-5]:
class_name += name + " "
class_name = class_name[:-1]
is_difficult = True
else:
line_split = line.split()
bottom = line_split[-1]
right = line_split[-2]
top = line_split[-3]
left = line_split[-4]
class_name = ""
for name in line_split[:-4]:
class_name += name + " "
class_name = class_name[:-1]
bbox = left + " " + top + " " + right + " " + bottom
if is_difficult:
bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
is_difficult = False
else:
bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
if class_name in gt_counter_per_class:
gt_counter_per_class[class_name] += 1
else:
gt_counter_per_class[class_name] = 1
if class_name not in already_seen_classes:
if class_name in counter_images_per_class:
counter_images_per_class[class_name] += 1
else:
counter_images_per_class[class_name] = 1
already_seen_classes.append(class_name)
with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
json.dump(bounding_boxes, outfile)
gt_classes = list(gt_counter_per_class.keys())
gt_classes = sorted(gt_classes)
n_classes = len(gt_classes)
dr_files_list = glob.glob(DR_PATH + '/*.txt')
dr_files_list.sort()
for class_index, class_name in enumerate(gt_classes):
bounding_boxes = []
for txt_file in dr_files_list:
file_id = txt_file.split(".txt",1)[0]
file_id = os.path.basename(os.path.normpath(file_id))
temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
if class_index == 0:
if not os.path.exists(temp_path):
error_msg = "Error. File not found: {}\n".format(temp_path)
error(error_msg)
lines = file_lines_to_list(txt_file)
for line in lines:
try:
tmp_class_name, confidence, left, top, right, bottom = line.split()
except:
line_split = line.split()
bottom = line_split[-1]
right = line_split[-2]
top = line_split[-3]
left = line_split[-4]
confidence = line_split[-5]
tmp_class_name = ""
for name in line_split[:-5]:
tmp_class_name += name + " "
tmp_class_name = tmp_class_name[:-1]
if tmp_class_name == class_name:
bbox = left + " " + top + " " + right + " " +bottom
bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
json.dump(bounding_boxes, outfile)
sum_AP = 0.0
ap_dictionary = {}
lamr_dictionary = {}
with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file:
results_file.write("# AP and precision/recall per class\n")
count_true_positives = {}
for class_index, class_name in enumerate(gt_classes):
count_true_positives[class_name] = 0
dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
dr_data = json.load(open(dr_file))
nd = len(dr_data)
tp = [0] * nd
fp = [0] * nd
score = [0] * nd
score_threhold_idx = 0
for idx, detection in enumerate(dr_data):
file_id = detection["file_id"]
score[idx] = float(detection["confidence"])
if score[idx] >= score_threhold:
score_threhold_idx = idx
if show_animation:
ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
if len(ground_truth_img) == 0:
error("Error. Image not found with id: " + file_id)
elif len(ground_truth_img) > 1:
error("Error. Multiple image with id: " + file_id)
else:
img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0]
if os.path.isfile(img_cumulative_path):
img_cumulative = cv2.imread(img_cumulative_path)
else:
img_cumulative = img.copy()
bottom_border = 60
BLACK = [0, 0, 0]
img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
ground_truth_data = json.load(open(gt_file))
ovmax = -1
gt_match = -1
bb = [float(x) for x in detection["bbox"].split()]
for obj in ground_truth_data:
if obj["class_name"] == class_name:
bbgt = [ float(x) for x in obj["bbox"].split() ]
bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
iw = bi[2] - bi[0] + 1
ih = bi[3] - bi[1] + 1
if iw > 0 and ih > 0:
ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
+ 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
ov = iw * ih / ua
if ov > ovmax:
ovmax = ov
gt_match = obj
if show_animation:
status = "NO MATCH FOUND!"
min_overlap = MINOVERLAP
if ovmax >= min_overlap:
if "difficult" not in gt_match:
if not bool(gt_match["used"]):
tp[idx] = 1
gt_match["used"] = True
count_true_positives[class_name] += 1
with open(gt_file, 'w') as f:
f.write(json.dumps(ground_truth_data))
if show_animation:
status = "MATCH!"
else:
fp[idx] = 1
if show_animation:
status = "REPEATED MATCH!"
else:
fp[idx] = 1
if ovmax > 0:
status = "INSUFFICIENT OVERLAP"
"""
Draw image to show animation
"""
if show_animation:
height, width = img.shape[:2]
white = (255,255,255)
light_blue = (255,200,100)
green = (0,255,0)
light_red = (30,30,255)
margin = 10
# 1st line
v_pos = int(height - margin - (bottom_border / 2.0))
text = "Image: " + ground_truth_img[0] + " "
img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
if ovmax != -1:
color = light_red
if status == "INSUFFICIENT OVERLAP":
text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
else:
text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
color = green
img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
# 2nd line
v_pos += int(bottom_border / 2.0)
rank_pos = str(idx+1)
text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
color = light_red
if status == "MATCH!":
color = green
text = "Result: " + status + " "
img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
font = cv2.FONT_HERSHEY_SIMPLEX
if ovmax > 0:
bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
bb = [int(i) for i in bb]
cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
cv2.imshow("Animation", img)
cv2.waitKey(20)
output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
cv2.imwrite(output_img_path, img)
cv2.imwrite(img_cumulative_path, img_cumulative)
cumsum = 0
for idx, val in enumerate(fp):
fp[idx] += cumsum
cumsum += val
cumsum = 0
for idx, val in enumerate(tp):
tp[idx] += cumsum
cumsum += val
rec = tp[:]
for idx, val in enumerate(tp):
rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
prec = tp[:]
for idx, val in enumerate(tp):
prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)
ap, mrec, mprec = voc_ap(rec[:], prec[:])
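# F1 = 2*P*R / (P + R), with the denominator guarded against division by zero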
F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))
sum_AP += ap
text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
if len(prec)>0:
F1_text = "{0:.2f}".format(F1[score_threhold_idx]) + " = " + class_name + " F1 "
Recall_text = "{0:.2f}%".format(rec[score_threhold_idx]*100) + " = " + class_name + " Recall "
Precision_text = "{0:.2f}%".format(prec[score_threhold_idx]*100) + " = " + class_name + " Precision "
else:
F1_text = "0.00" + " = " + class_name + " F1 "
Recall_text = "0.00%" + " = " + class_name + " Recall "
Precision_text = "0.00%" + " = " + class_name + " Precision "
rounded_prec = [ '%.2f' % elem for elem in prec ]
rounded_rec = [ '%.2f' % elem for elem in rec ]
results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
if len(prec)>0:
print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=" + "{0:.2f}".format(F1[score_threhold_idx])\
+ " ; Recall=" + "{0:.2f}%".format(rec[score_threhold_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score_threhold_idx]*100))
else:
print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=0.00% ; Recall=0.00% ; Precision=0.00%")
ap_dictionary[class_name] = ap
n_images = counter_images_per_class[class_name]
lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
lamr_dictionary[class_name] = lamr
if draw_plot:
plt.plot(rec, prec, '-o')
area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
fig = plt.gcf()
fig.canvas.set_window_title('AP ' + class_name)
plt.title('class: ' + text)
plt.xlabel('Recall')
plt.ylabel('Precision')
axes = plt.gca()
axes.set_xlim([0.0,1.0])
axes.set_ylim([0.0,1.05])
fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png")
plt.cla()
plt.plot(score, F1, "-", color='orangered')
plt.title('class: ' + F1_text + "\nscore_threhold=" + str(score_threhold))
plt.xlabel('Score_Threhold')
plt.ylabel('F1')
axes = plt.gca()
axes.set_xlim([0.0,1.0])
axes.set_ylim([0.0,1.05])
fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png")
plt.cla()
plt.plot(score, rec, "-H", color='gold')
plt.title('class: ' + Recall_text + "\nscore_threhold=" + str(score_threhold))
plt.xlabel('Score_Threhold')
plt.ylabel('Recall')
axes = plt.gca()
axes.set_xlim([0.0,1.0])
axes.set_ylim([0.0,1.05])
fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png")
plt.cla()
plt.plot(score, prec, "-s", color='palevioletred')
plt.title('class: ' + Precision_text + "\nscore_threhold=" + str(score_threhold))
plt.xlabel('Score_Threhold')
plt.ylabel('Precision')
axes = plt.gca()
axes.set_xlim([0.0,1.0])
axes.set_ylim([0.0,1.05])
fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png")
plt.cla()
if show_animation:
cv2.destroyAllWindows()
if n_classes == 0:
print("未检测到任何种类请检查标签信息与get_map.py中的classes_path是否修改。")
return 0
results_file.write("\n# mAP of all classes\n")
mAP = sum_AP / n_classes
text = "mAP = {0:.2f}%".format(mAP*100)
results_file.write(text + "\n")
print(text)
shutil.rmtree(TEMP_FILES_PATH)
"""
Count total of detection-results
"""
det_counter_per_class = {}
for txt_file in dr_files_list:
lines_list = file_lines_to_list(txt_file)
for line in lines_list:
class_name = line.split()[0]
if class_name in det_counter_per_class:
det_counter_per_class[class_name] += 1
else:
det_counter_per_class[class_name] = 1
dr_classes = list(det_counter_per_class.keys())
"""
Write number of ground-truth objects per class to results.txt
"""
with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
results_file.write("\n# Number of ground-truth objects per class\n")
for class_name in sorted(gt_counter_per_class):
results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
"""
Finish counting true positives
"""
for class_name in dr_classes:
if class_name not in gt_classes:
count_true_positives[class_name] = 0
"""
Write number of detected objects per class to results.txt
"""
with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
results_file.write("\n# Number of detected objects per class\n")
for class_name in sorted(dr_classes):
n_det = det_counter_per_class[class_name]
text = class_name + ": " + str(n_det)
text += " (tp:" + str(count_true_positives[class_name]) + ""
text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
results_file.write(text)
"""
Plot the total number of occurrences of each class in the ground-truth
"""
if draw_plot:
window_title = "ground-truth-info"
plot_title = "ground-truth\n"
plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
x_label = "Number of objects per class"
output_path = RESULTS_FILES_PATH + "/ground-truth-info.png"
to_show = False
plot_color = 'forestgreen'
draw_plot_func(
gt_counter_per_class,
n_classes,
window_title,
plot_title,
x_label,
output_path,
to_show,
plot_color,
'',
)
# """
# Plot the total number of occurences of each class in the "detection-results" folder
# """
# if draw_plot:
# window_title = "detection-results-info"
# # Plot title
# plot_title = "detection-results\n"
# plot_title += "(" + str(len(dr_files_list)) + " files and "
# count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
# plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
# # end Plot title
# x_label = "Number of objects per class"
# output_path = RESULTS_FILES_PATH + "/detection-results-info.png"
# to_show = False
# plot_color = 'forestgreen'
# true_p_bar = count_true_positives
# draw_plot_func(
# det_counter_per_class,
# len(det_counter_per_class),
# window_title,
# plot_title,
# x_label,
# output_path,
# to_show,
# plot_color,
# true_p_bar
# )
"""
Draw log-average miss rate plot (Show lamr of all classes in decreasing order)
"""
if draw_plot:
window_title = "lamr"
plot_title = "log-average miss rate"
x_label = "log-average miss rate"
output_path = RESULTS_FILES_PATH + "/lamr.png"
to_show = False
plot_color = 'royalblue'
draw_plot_func(
lamr_dictionary,
n_classes,
window_title,
plot_title,
x_label,
output_path,
to_show,
plot_color,
""
)
"""
Draw mAP plot (Show AP's of all classes in decreasing order)
"""
if draw_plot:
window_title = "mAP"
plot_title = "mAP = {0:.2f}%".format(mAP*100)
x_label = "Average Precision"
output_path = RESULTS_FILES_PATH + "/mAP.png"
to_show = True
plot_color = 'royalblue'
draw_plot_func(
ap_dictionary,
n_classes,
window_title,
plot_title,
x_label,
output_path,
to_show,
plot_color,
""
)
return mAP
def preprocess_gt(gt_path, class_names):
image_ids = os.listdir(gt_path)
results = {}
images = []
bboxes = []
for i, image_id in enumerate(image_ids):
lines_list = file_lines_to_list(os.path.join(gt_path, image_id))
boxes_per_image = []
image = {}
image_id = os.path.splitext(image_id)[0]
image['file_name'] = image_id + '.jpg'
image['width'] = 1
image['height'] = 1
#-----------------------------------------------------------------#
# Thanks to 多学学英语吧 for pointing this out;
# it fixes the 'Results do not correspond to current coco set' issue
#-----------------------------------------------------------------#
image['id'] = str(image_id)
for line in lines_list:
difficult = 0
if "difficult" in line:
line_split = line.split()
left, top, right, bottom, _difficult = line_split[-5:]
class_name = ""
for name in line_split[:-5]:
class_name += name + " "
class_name = class_name[:-1]
difficult = 1
else:
line_split = line.split()
left, top, right, bottom = line_split[-4:]
class_name = ""
for name in line_split[:-4]:
class_name += name + " "
class_name = class_name[:-1]
left, top, right, bottom = float(left), float(top), float(right), float(bottom)
if class_name not in class_names:
continue
cls_id = class_names.index(class_name) + 1
bbox = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0]
boxes_per_image.append(bbox)
images.append(image)
bboxes.extend(boxes_per_image)
results['images'] = images
categories = []
for i, cls in enumerate(class_names):
category = {}
category['supercategory'] = cls
category['name'] = cls
category['id'] = i + 1
categories.append(category)
results['categories'] = categories
annotations = []
for i, box in enumerate(bboxes):
annotation = {}
annotation['area'] = box[-1]
annotation['category_id'] = box[-2]
annotation['image_id'] = box[-3]
annotation['iscrowd'] = box[-4]
annotation['bbox'] = box[:4]
annotation['id'] = i
annotations.append(annotation)
results['annotations'] = annotations
return results
def preprocess_dr(dr_path, class_names):
image_ids = os.listdir(dr_path)
results = []
for image_id in image_ids:
lines_list = file_lines_to_list(os.path.join(dr_path, image_id))
image_id = os.path.splitext(image_id)[0]
for line in lines_list:
line_split = line.split()
confidence, left, top, right, bottom = line_split[-5:]
class_name = ""
for name in line_split[:-5]:
class_name += name + " "
class_name = class_name[:-1]
left, top, right, bottom = float(left), float(top), float(right), float(bottom)
result = {}
result["image_id"] = str(image_id)
if class_name not in class_names:
continue
result["category_id"] = class_names.index(class_name) + 1
result["bbox"] = [left, top, right - left, bottom - top]
result["score"] = float(confidence)
results.append(result)
return results
def get_coco_map(class_names, path):
GT_PATH = os.path.join(path, 'ground-truth')
DR_PATH = os.path.join(path, 'detection-results')
COCO_PATH = os.path.join(path, 'coco_eval')
if not os.path.exists(COCO_PATH):
os.makedirs(COCO_PATH)
GT_JSON_PATH = os.path.join(COCO_PATH, 'instances_gt.json')
DR_JSON_PATH = os.path.join(COCO_PATH, 'instances_dr.json')
with open(GT_JSON_PATH, "w") as f:
results_gt = preprocess_gt(GT_PATH, class_names)
json.dump(results_gt, f, indent=4)
with open(DR_JSON_PATH, "w") as f:
results_dr = preprocess_dr(DR_PATH, class_names)
json.dump(results_dr, f, indent=4)
if len(results_dr) == 0:
print("未检测到任何目标。")
return [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cocoGt = COCO(GT_JSON_PATH)
cocoDt = cocoGt.loadRes(DR_JSON_PATH)
cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
return cocoEval.stats