commit 71118fc64925c027139d797a58ebc1527e816e5b
Author: esenke
Date: Mon Dec 8 21:38:53 2025 +0800

    init

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..b760ed3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,173 @@
# TDCNet
Official repository of the AAAI 2026 paper "Spatio-Temporal Context Learning with Temporal Difference Convolution for Moving Infrared Small Target Detection".

## Spatio-Temporal Context Learning with Temporal Difference Convolution for Moving Infrared Small Target Detection [[PDF](https://arxiv.org/pdf/2511.09352)]

Authors: Houzhang Fang1, Shukai Guo1, Qiuhuan Chen1, Yi Chang2, Luxin Yan2

1Xidian University, 2Huazhong University of Science and Technology

## Abstract

Moving infrared small target detection (IRSTD) plays a critical role in practical applications, such as surveillance of unmanned aerial vehicles (UAVs) and UAV-based search systems. Moving IRSTD remains highly challenging due to weak target features and complex background interference. Accurate spatio-temporal feature modeling is crucial for moving target detection and is typically achieved through either temporal differences or spatio-temporal (3D) convolutions. Temporal difference can explicitly leverage motion cues but has limited capability in extracting spatial features, whereas 3D convolution effectively represents spatio-temporal features yet lacks explicit awareness of motion dynamics along the temporal dimension. In this paper, we propose a novel moving IRSTD network (TDCNet), which effectively extracts and enhances spatio-temporal features for accurate target detection. Specifically, we introduce a novel temporal difference convolution (TDC) re-parameterization module that comprises three parallel TDC blocks designed to capture contextual dependencies across different temporal ranges. Each TDC block fuses temporal difference and 3D convolution into a unified spatio-temporal convolution representation. This re-parameterized module can effectively capture multi-scale motion contextual features while suppressing pseudo-motion clutter in complex backgrounds, significantly improving detection performance. Moreover, we propose a TDC-guided spatio-temporal attention mechanism that performs cross-attention between the spatio-temporal features extracted from the TDC-based backbone and a parallel 3D backbone. This mechanism models their global semantic dependencies to refine the current frame's features, thereby guiding the model to focus more accurately on critical target regions. To facilitate comprehensive evaluation, we construct a new challenging benchmark, IRSTD-UAV, consisting of 15,106 real infrared images with diverse low signal-to-clutter ratio scenarios and complex backgrounds. Extensive experiments on IRSTD-UAV and public infrared datasets demonstrate that our TDCNet achieves state-of-the-art detection performance in moving target detection.

## TDCNet Framework

![image-20250407200916034](./figs/overall_framework.png)

## Visualization

![image-20250407201214584](./figs/vis_main.png)

Visual comparison of results from SOTA methods and TDCNet on the IRSTD-UAV and IRDST datasets. Boxes in green and red represent ground-truth and detected targets, respectively.

## Environment

- [Python](https://www.python.org/)
- [PyTorch](https://pytorch.org/)
- [tqdm](https://github.com/tqdm/tqdm)
- [pycocotools](https://github.com/cocodataset/cocoapi)
- [OpenCV (cv2)](https://opencv.org/)
- [NumPy](https://numpy.org/)

## Data
1. Download the datasets.
   - [IRSTD-UAV (ours)](https://drive.google.com/file/d/1orHDqG-nLYBSdJETyt6ozpAGOSoKiT-k/view): We constructed the IRSTD-UAV dataset, which contains 17 real infrared video sequences with 15,106 frames. It features small targets in complex backgrounds such as buildings, trees, and clouds, providing a realistic benchmark for UAV-based IRSTD. If you use this dataset, please cite our work.
   - [IRDST](https://xzbai.buaa.edu.cn/datasets.html)

2. Perform background alignment before training or testing.
   You can use the [GIM](https://github.com/xuelunshen/gim) method for background alignment.

3. Organize the dataset structure as follows:
   ```
   IRSTD-UAV/
   ├── images/
   │   ├── 1/
   │   │   ├── 00000000.png
   │   │   ├── 00000001.png
   │   │   └── ...
   │   ├── 2/
   │   │   ├── 00000000.png
   │   │   ├── 00000001.png
   │   │   └── ...
   │   └── ...
   ├── labels/
   │   ├── 1/
   │   │   ├── 00000000.png
   │   │   ├── 00000001.png
   │   │   └── ...
   │   ├── 2/
   │   │   ├── 00000000.png
   │   │   ├── 00000001.png
   │   │   └── ...
   │   └── ...
   ├── matches/
   │   ├── 1/
   │   │   ├── 00000000/
   │   │   │   ├── match_1.png
   │   │   │   ├── match_2.png
   │   │   │   └── ...
   │   │   ├── 00000001/
   │   │   │   ├── match_1.png
   │   │   │   ├── match_2.png
   │   │   │   └── ...
   │   │   └── ...
   │   └── ...
   ├── train.txt
   ├── val.txt
   ├── train_coco.json
   └── val_coco.json
   ```

## How To Train

1. **Prepare the environment and dataset**
   - Configure environment variables:
     ```python
     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
     ```
   - Set dataset paths in the script:
     ```python
     DATA_PATH = "/Dataset/IRSTD-UAV/"
     train_annotation_path = "/Dataset/IRSTD-UAV/train.txt"
     val_annotation_path = "/Dataset/IRSTD-UAV/val.txt"
     ```

2. **Training command**
   ```bash
   python train.py
   ```

## How To Test
```bash
python test.py
```

The testing results will be saved in the ./results folder.

## Citation
If you find our work and our dataset IRSTD-UAV useful for your research, please consider citing our paper:
```bibtex
@inproceedings{2026AAAI_TDCNet,
  title = {Spatio-Temporal Context Learning with Temporal Difference Convolution for Moving Infrared Small Target Detection},
  author = {Houzhang Fang and Shukai Guo and Qiuhuan Chen and Yi Chang and Luxin Yan},
  booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
  year = {2026},
  pages = { },
}
```

In addition to the above paper, please also consider citing the following references. Thank you!
```bibtex
@inproceedings{2025CVPR_UniCD,
  title = {Detection-Friendly Nonuniformity Correction: A Union Framework for Infrared {UAV} Target Detection},
  author = {Houzhang Fang and Xiaolin Wang and Zengyang Li and Lu Wang and Qingshan Li and Yi Chang and Luxin Yan},
  booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year = {2025},
  pages = {11898-11907},
}
@article{2023TII_DAGNet,
  title = {Differentiated Attention Guided Network Over Hierarchical and Aggregated Features for Intelligent {UAV} Surveillance},
  author = {Houzhang Fang and Zikai Liao and Xuhua Wang and Yi Chang and Luxin Yan},
  journal = {IEEE Transactions on Industrial Informatics},
  year = {2023},
  volume = {19},
  number = {9},
  pages = {9909-9920},
}
@inproceedings{2023ACMMM_DANet,
  title = {{DANet}: Multi-scale {UAV} Target Detection with Dynamic Feature Perception and Scale-aware Knowledge Distillation},
  author = {Houzhang Fang and Zikai Liao and Lu Wang and Qingshan Li and Yi Chang and Luxin Yan and Xuhua Wang},
  booktitle = {Proceedings of the 31st ACM International Conference on Multimedia (ACMMM)},
  year = {2023},
  pages = {2121-2130},
}
@article{2024TGRS_SCINet,
  title = {{SCINet}: Spatial and Contrast Interactive Super-Resolution Assisted Infrared {UAV} Target Detection},
  author = {Houzhang Fang and Lan Ding and Xiaolin Wang and Yi Chang and Luxin Yan and Li Liu and Jinrui Fang},
  journal = {IEEE Transactions on Geoscience and Remote Sensing},
  year = {2024},
  volume = {62},
  pages = {1-22},
}
@article{2022TIMFang,
  title = {Infrared Small {UAV} Target Detection Based on Depthwise Separable Residual Dense Network and Multiscale Feature Fusion},
  author = {Houzhang Fang and Lan Ding and Liming Wang and Yi Chang and Luxin Yan and Jinhui Han},
  journal = {IEEE Transactions on Instrumentation and Measurement},
  year = {2022},
  volume = {71},
  pages = {1-20},
}
```

## Contact
If you have any questions, please contact houzhangfang@xidian.edu.cn.

Copyright © Xidian University.

## Acknowledgments
Some of the code is based on [STMENet](https://github.com/UESTC-nnLab/STME) and [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). Thanks for their excellent work!

## License
MIT License. This code is only freely available for non-commercial research use.
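A note on the annotation format before the dataloader code below: `seqDataset` in both dataloaders parses each line of `train.txt`/`val.txt` as one image path followed by space-separated boxes of comma-separated integers, which are later treated as corner coordinates plus a class index (the `4+1` layout referenced in the code). A minimal parsing sketch under that reading; the sample path and box values are illustrative only:

```python
# Minimal sketch of one annotation line as consumed by seqDataset
# (data/dataloader_for_IRSTD_UAV.py). The path and numbers are made up.
line = "/Dataset/IRSTD-UAV/images/1/00000123.png 257,318,263,324,0"

tokens = line.strip().split()
key_frame_path = tokens[0]                                    # current (key) frame
boxes = [list(map(int, b.split(","))) for b in tokens[1:]]    # [x1, y1, x2, y2, class]
print(key_frame_path, boxes)  # -> .../00000123.png [[257, 318, 263, 324, 0]]
```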
+ diff --git a/data/dataloader_for_IRDST.py b/data/dataloader_for_IRDST.py new file mode 100644 index 0000000..51e3e38 --- /dev/null +++ b/data/dataloader_for_IRDST.py @@ -0,0 +1,123 @@ +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data.dataset import Dataset + + +def cvtColor(image): + if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: + return image + else: + image = image.convert('RGB') + return image + +def preprocess(image): + image = image.astype(np.float32) + image /= 255.0 + return image + + +def rand(a=0, b=1): + return np.random.rand() * (b - a) + a + + +class seqDataset(Dataset): + def __init__(self, dataset_path, image_size, num_frame=5, type='train'): + super(seqDataset, self).__init__() + self.dataset_path = dataset_path + self.img_idx = [] + self.anno_idx = [] + self.image_size = image_size + self.num_frame = num_frame + self.txt_path = dataset_path + with open(self.txt_path) as f: + data_lines = f.readlines() + self.length = len(data_lines) + for line in data_lines: + line = line.strip('\n').split() + self.img_idx.append(line[0]) + self.anno_idx.append(np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])) + + def __len__(self): + return self.length + + def __getitem__(self, index): + images, box = self.get_data(index) + images = np.array(images) + images = np.transpose(preprocess(images), (3, 0, 1, 2)) + if len(box) != 0: + box[:, 2:4] = box[:, 2:4] - box[:, 0:2] + box[:, 0:2] = box[:, 0:2] + (box[:, 2:4] / 2) + return images, box + + def get_data(self, index): + h, w = self.image_size, self.image_size + file_name = self.img_idx[index] + + dir_path = file_name.replace('images', 'matches').replace('.png', '') + images = [] + for i in range(self.num_frame): + img_path = os.path.join(dir_path, f"img_{i + 1}.png") + img = Image.open(img_path) + img = cvtColor(img) + iw, ih = img.size + scale = min(w / iw, h / ih) + nw = int(iw * scale) + nh = int(ih * scale) + dx = (w - nw) // 2 + dy = (h - nh) // 2 + img = img.resize((nw, nh), Image.BICUBIC) + new_img = Image.new('RGB', (w, h), (128, 128, 128)) + new_img.paste(img, (dx, dy)) + images.append(np.array(new_img, np.float32)) + + image_data = [] + image_id = int(file_name.split("/")[-1][:-4]) + image_path = file_name.replace(file_name.split("/")[-1], '') + for id in range(0, self.num_frame): + img = Image.open(image_path + '%d.png' % max(image_id - id, 0)) + + img = cvtColor(img) + iw, ih = img.size + + scale = min(w / iw, h / ih) + nw = int(iw * scale) + nh = int(ih * scale) + dx = (w - nw) // 2 + dy = (h - nh) // 2 + + img = img.resize((nw, nh), Image.BICUBIC) # 原图等比列缩放 + new_img = Image.new('RGB', (w, h), (128, 128, 128)) # 预期大小的灰色图 + new_img.paste(img, (dx, dy)) # 缩放图片放在正中 + image_data.append(np.array(new_img, np.float32)) + + image_data = image_data[::-1] # 关键帧在后 # [5,w,h,3] + for img in image_data: + images.append(img.copy()) + label_data = self.anno_idx[index] # 4+1 + if len(label_data) > 0: + np.random.shuffle(label_data) + label_data[:, [0, 2]] = label_data[:, [0, 2]] * nw / iw + dx + label_data[:, [1, 3]] = label_data[:, [1, 3]] * nh / ih + dy + label_data[:, 0:2][label_data[:, 0:2] < 0] = 0 + label_data[:, 2][label_data[:, 2] > w] = w + label_data[:, 3][label_data[:, 3] > h] = h + box_w = label_data[:, 2] - label_data[:, 0] + box_h = label_data[:, 3] - label_data[:, 1] + label_data = label_data[np.logical_and(box_w > 1, box_h > 1)] + images = np.array(images) + label_data = np.array(label_data, dtype=np.float32) + return images, label_data + + +def 
dataset_collate(batch): + images = [] + bboxes = [] + for img, box in batch: + images.append(img) + bboxes.append(box) + images = torch.from_numpy(np.array(images)).type(torch.FloatTensor) + bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes] + return images, bboxes diff --git a/data/dataloader_for_IRSTD_UAV.py b/data/dataloader_for_IRSTD_UAV.py new file mode 100644 index 0000000..bf62058 --- /dev/null +++ b/data/dataloader_for_IRSTD_UAV.py @@ -0,0 +1,125 @@ +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data.dataset import Dataset + + +def cvtColor(image): + if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: + return image + else: + image = image.convert('RGB') + return image + + +def preprocess(image): + image = image.astype(np.float32) + image /= 255.0 + return image + + +def rand(a=0, b=1): + return np.random.rand() * (b - a) + a + + +class seqDataset(Dataset): + def __init__(self, dataset_path, image_size, num_frame=5, type='train'): + super(seqDataset, self).__init__() + self.dataset_path = dataset_path + self.img_idx = [] + self.anno_idx = [] + self.image_size = image_size + self.num_frame = num_frame + self.txt_path = dataset_path + with open(self.txt_path) as f: + data_lines = f.readlines() + self.length = len(data_lines) + for line in data_lines: + line = line.strip('\n').split() + self.img_idx.append(line[0]) + self.anno_idx.append(np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])) + + def __len__(self): + return self.length + + def __getitem__(self, index): + images, box = self.get_data(index) + images = np.array(images) + images = np.transpose(preprocess(images), (3, 0, 1, 2)) + if len(box) != 0: + box[:, 2:4] = box[:, 2:4] - box[:, 0:2] + box[:, 0:2] = box[:, 0:2] + (box[:, 2:4] / 2) + return images, box + + def get_data(self, index): + h, w = self.image_size, self.image_size + file_name = self.img_idx[index] + + dir_path = file_name.replace('images', 'matches').replace('.png', '') + images = [] + for i in range(self.num_frame): + img_path = os.path.join(dir_path, f"match_{i + 1}.png") + img = Image.open(img_path) + img = cvtColor(img) + iw, ih = img.size + scale = min(w / iw, h / ih) + nw = int(iw * scale) + nh = int(ih * scale) + dx = (w - nw) // 2 + dy = (h - nh) // 2 + img = img.resize((nw, nh), Image.BICUBIC) + new_img = Image.new('RGB', (w, h), (128, 128, 128)) + new_img.paste(img, (dx, dy)) + images.append(np.array(new_img, np.float32)) + + image_data = [] + image_id = int(file_name.split("/")[-1][:8]) + image_path = file_name.replace(file_name.split("/")[-1], '') + min_index = image_id - (image_id % 50) + for id in range(0, self.num_frame): + img = Image.open(image_path + '%08d.png' % max(image_id - id, min_index)) + + img = cvtColor(img) + iw, ih = img.size + + scale = min(w / iw, h / ih) + nw = int(iw * scale) + nh = int(ih * scale) + dx = (w - nw) // 2 + dy = (h - nh) // 2 + + img = img.resize((nw, nh), Image.BICUBIC) # 原图等比列缩放 + new_img = Image.new('RGB', (w, h), (128, 128, 128)) # 预期大小的灰色图 + new_img.paste(img, (dx, dy)) # 缩放图片放在正中 + image_data.append(np.array(new_img, np.float32)) + + image_data = image_data[::-1] # 关键帧在后 # [5,w,h,3] + for img in image_data: + images.append(img.copy()) + label_data = self.anno_idx[index] # 4+1 + if len(label_data) > 0: + np.random.shuffle(label_data) + label_data[:, [0, 2]] = label_data[:, [0, 2]] * nw / iw + dx + label_data[:, [1, 3]] = label_data[:, [1, 3]] * nh / ih + dy + label_data[:, 0:2][label_data[:, 0:2] < 0] = 0 + label_data[:, 
2][label_data[:, 2] > w] = w + label_data[:, 3][label_data[:, 3] > h] = h + box_w = label_data[:, 2] - label_data[:, 0] + box_h = label_data[:, 3] - label_data[:, 1] + label_data = label_data[np.logical_and(box_w > 1, box_h > 1)] + images = np.array(images) + label_data = np.array(label_data, dtype=np.float32) + return images, label_data + + +def dataset_collate(batch): + images = [] + bboxes = [] + for img, box in batch: + images.append(img) + bboxes.append(box) + images = torch.from_numpy(np.array(images)).type(torch.FloatTensor) + bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes] + return images, bboxes diff --git a/figs/overall_framework.png b/figs/overall_framework.png new file mode 100644 index 0000000..bd4e813 Binary files /dev/null and b/figs/overall_framework.png differ diff --git a/figs/vis_main.png b/figs/vis_main.png new file mode 100644 index 0000000..ed94ca8 Binary files /dev/null and b/figs/vis_main.png differ diff --git a/model/TDCNet/TDCNetwork.py b/model/TDCNet/TDCNetwork.py new file mode 100644 index 0000000..03e83df --- /dev/null +++ b/model/TDCNet/TDCNetwork.py @@ -0,0 +1,373 @@ +import torch +import torch.nn as nn + +from model.TDCNet.TDCSTA import CrossAttention, SelfAttention +from model.TDCNet.backbone3d import Backbone3D +from model.TDCNet.backbonetd import BackboneTD +from model.TDCNet.darknet import BaseConv, CSPDarknet, DWConv + + +class Feature_Backbone(nn.Module): + def __init__(self, depth=1.0, width=1.0, in_features=("dark3", "dark4", "dark5"), in_channels=[256, 512, 1024], depthwise=False, act="silu"): + super().__init__() + self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act) + self.in_features = in_features + + def forward(self, input): + out_features = self.backbone.forward(input) + [feat1, feat2, feat3] = [out_features[f] for f in self.in_features] + return [feat1, feat2, feat3] + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu", ): + super().__init__() + hidden_channels = int(out_channels * expansion) + Conv = BaseConv # if depthwise else BaseConv + # --------------------------------------------------# + # 利用1x1卷积进行通道数的缩减。缩减率一般是50% + # --------------------------------------------------# + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + # --------------------------------------------------# + # 利用3x3卷积进行通道数的拓张。并且完成特征提取 + # --------------------------------------------------# + self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act) + # self.conv2=nn.Identity() + self.use_add = shortcut and in_channels == out_channels + + def forward(self, x): + y = self.conv2(self.conv1(x)) + if self.use_add: + y = y + x + return y + + +class FusionLayer(nn.Module): + def __init__(self, in_channels, out_channels, expansion=0.5, depthwise=False, act="silu", ): + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) + n = 1 + # --------------------------------------------------# + # 主干部分的初次卷积 + # --------------------------------------------------# + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + # --------------------------------------------------# + # 大的残差边部分的初次卷积 + # --------------------------------------------------# + self.conv2 = BaseConv(hidden_channels, hidden_channels, 1, stride=1, act=act) # in_channel + # -----------------------------------------------# + # 对堆叠的结果进行卷积的处理 
+ # self.deepfeature=nn.Sequential(BaseConv(hidden_channels, hidden_channels//2, 1, stride=1, act=act), + # BaseConv(hidden_channels//2, hidden_channels, 3, stride=1, act=act)) + # -----------------------------------------------# + # module_list = [Bottleneck(hidden_channels, hidden_channels, True, 1.0, depthwise, act=act) for _ in range(n)] + # self.deepfeature = nn.Sequential(*module_list) + self.conv3 = BaseConv(hidden_channels, out_channels, 1, stride=1, act=act) # 2*hidden_channel + + # --------------------------------------------------# + # 根据循环的次数构建上述Bottleneck残差结构 + # --------------------------------------------------# + # module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)] + # self.m = nn.Sequential(*module_list) + + def forward(self, x): + # -------------------------------# + # x_1是主干部分 + # -------------------------------# + # x_1 = self.conv1(x) + x = self.conv1(x) + # -------------------------------# + # x_2是大的残差边部分 + # -------------------------------# + # x_2 = self.conv2(x) + x = self.conv2(x) + # -----------------------------------------------# + # 主干部分利用残差结构堆叠继续进行特征提取 + # -----------------------------------------------# + # x_1 = self.deepfeature(x_1) + # -----------------------------------------------# + # 主干部分和大的残差边部分进行堆叠 + # -----------------------------------------------# + # x = torch.cat((x_1, x_2), dim=1) + # -----------------------------------------------# + # 对堆叠的结果进行卷积的处理 + # -----------------------------------------------# + return self.conv3(x) + + +class Feature_Fusion(nn.Module): + def __init__(self, in_channels=[128, 256, 512], depthwise=False, act="silu"): + super().__init__() + Conv = DWConv if depthwise else BaseConv + + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + + # -------------------------------------------# + # 20, 20, 1024 -> 20, 20, 512 + # -------------------------------------------# + # self.lateral_conv0 = BaseConv(2 * int(in_channels[2]), int(in_channels[1]), 1, 1, act=act) + self.lateral_conv0 = BaseConv(in_channels[1] + in_channels[2], in_channels[1], 1, 1, act=act) + + # -------------------------------------------# + # 40, 40, 1024 -> 40, 40, 512 + # -------------------------------------------# + self.C3_p4 = FusionLayer( + int(2 * in_channels[1]), + int(in_channels[1]), + depthwise=depthwise, + act=act, + ) + + # -------------------------------------------# + # 40, 40, 512 -> 40, 40, 256 + # -------------------------------------------# + # self.reduce_conv1 = BaseConv(int(2 * in_channels[1]), int(in_channels[0]), 1, 1, act=act) + self.reduce_conv1 = BaseConv(int(in_channels[0] + in_channels[1]), int(in_channels[0]), 1, 1, act=act) + # -------------------------------------------# + # 80, 80, 512 -> 80, 80, 256 + # -------------------------------------------# + self.C3_p3 = FusionLayer( + int(2 * in_channels[0]), + int(in_channels[0]), + depthwise=depthwise, + act=act, + ) + + def forward(self, input): + out_features = input # self.backbone.forward(input) + [feat1, feat2, feat3] = out_features # [out_features[f] for f in self.in_features] + + # -------------------------------------------# + # 20, 20, 1024 -> 20, 20, 512 + # -------------------------------------------# + # P5 = self.lateral_conv0(feat3) + # -------------------------------------------# + # 20, 20, 512 -> 40, 40, 512 + # -------------------------------------------# + P5_upsample = self.upsample(feat3) + # -------------------------------------------# + # 40, 40, 512 + 40, 40, 512 -> 40, 40, 1024 + # 
-------------------------------------------# + P5_upsample = torch.cat([P5_upsample, feat2], 1) + # pdb.set_trace() + # -------------------------------------------# + # 40, 40, 1024 -> 40, 40, 512 + # -------------------------------------------# + P4 = self.lateral_conv0(P5_upsample) + # P5_upsample = self.C3_p4(P5_upsample) + + # -------------------------------------------# + # 40, 40, 512 -> 40, 40, 256 + # -------------------------------------------# + # P4 = self.reduce_conv1(P5_upsample) + # -------------------------------------------# + # 40, 40, 256 -> 80, 80, 256 + # -------------------------------------------# + P4_upsample = self.upsample(P4) + # -------------------------------------------# + # 80, 80, 256 + 80, 80, 256 -> 80, 80, 512 + # -------------------------------------------# + P4_upsample = torch.cat([P4_upsample, feat1], 1) + # -------------------------------------------# + # 80, 80, 512 -> 80, 80, 256 + # -------------------------------------------# + P3_out = self.reduce_conv1(P4_upsample) + # P3_out = self.C3_p3(P4_upsample) + + return P3_out + + +class YOLOXHead(nn.Module): + def __init__(self, num_classes, width=1.0, in_channels=[16, 32, 64], act="silu"): + super().__init__() + Conv = BaseConv + + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.cls_preds = nn.ModuleList() + self.reg_preds = nn.ModuleList() + self.obj_preds = nn.ModuleList() + self.stems = nn.ModuleList() + + for i in range(len(in_channels)): + self.stems.append(BaseConv(in_channels=int(in_channels[i]), out_channels=int(256 * width), ksize=1, stride=1, act=act)) # 128-> 256 通道整合 + self.cls_convs.append(nn.Sequential(*[ + Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act), + Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act), + ])) + self.cls_preds.append( + nn.Conv2d(in_channels=int(256 * width), out_channels=num_classes, kernel_size=1, stride=1, padding=0) + ) + + self.reg_convs.append(nn.Sequential(*[ + Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act), + Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act) + ])) + self.reg_preds.append( + nn.Conv2d(in_channels=int(256 * width), out_channels=4, kernel_size=1, stride=1, padding=0) + ) + self.obj_preds.append( + nn.Conv2d(in_channels=int(256 * width), out_channels=1, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, inputs): + # ---------------------------------------------------# + # inputs输入 + # P3_out 80, 80, 256 + # P4_out 40, 40, 512 + # P5_out 20, 20, 1024 + # ---------------------------------------------------# + outputs = [] + for k, x in enumerate(inputs): + # ---------------------------------------------------# + # 利用1x1卷积进行通道整合 + # ---------------------------------------------------# + x = self.stems[k](x) + # ---------------------------------------------------# + # 利用两个卷积标准化激活函数来进行特征提取 + # ---------------------------------------------------# + cls_feat = self.cls_convs[k](x) + # ---------------------------------------------------# + # 判断特征点所属的种类 + # 80, 80, num_classes + # 40, 40, num_classes + # 20, 20, num_classes + # ---------------------------------------------------# + cls_output = self.cls_preds[k](cls_feat) + + # ---------------------------------------------------# + # 利用两个卷积标准化激活函数来进行特征提取 + # ---------------------------------------------------# + reg_feat = self.reg_convs[k](x) + # 
---------------------------------------------------# + # 特征点的回归系数 + # reg_pred 80, 80, 4 + # reg_pred 40, 40, 4 + # reg_pred 20, 20, 4 + # ---------------------------------------------------# + reg_output = self.reg_preds[k](reg_feat) + # ---------------------------------------------------# + # 判断特征点是否有对应的物体 + # obj_pred 80, 80, 1 + # obj_pred 40, 40, 1 + # obj_pred 20, 20, 1 + # ---------------------------------------------------# + obj_output = self.obj_preds[k](reg_feat) + + output = torch.cat([reg_output, obj_output, cls_output], 1) + outputs.append(output) + return outputs + + +model_config = { + + 'backbone_2d': 'yolo_free_nano', + 'pretrained_2d': True, + 'stride': [8, 16, 32], + # ## 3D + 'backbone_3d': 'shufflenetv2', + 'model_size': '1.0x', # 1.0x + 'pretrained_3d': True, + 'memory_momentum': 0.9, + 'head_dim': 128, # 64 + 'head_norm': 'BN', + 'head_act': 'lrelu', + 'num_cls_heads': 2, + 'num_reg_heads': 2, + 'head_depthwise': True, + +} + + +def build_backbone_3d(cfg, pretrained=False): + backbone = Backbone3D(cfg, pretrained) + return backbone, backbone.feat_dim + + +mcfg = model_config + + +class TDCNetwork(nn.Module): + def __init__(self, num_classes, fp16=False, num_frame=5): + super(TDCNetwork, self).__init__() + self.num_frame = num_frame + self.backbone2d = Feature_Backbone(0.33, 0.50) + self.backbone3d, bk_dim_3d = build_backbone_3d(mcfg, pretrained=mcfg['pretrained_3d'] and True) + self.backbonetd = BackboneTD(mcfg, pretrained=mcfg['pretrained_3d'] and True) + self.q_sa1 = SelfAttention(128, window_size=(2, 8, 8), num_heads=4, use_shift=True, mlp_ratio=1.5) + self.k_sa1 = SelfAttention(128, window_size=(2, 8, 8), num_heads=4, use_shift=True, mlp_ratio=1.5) + self.v_sa1 = SelfAttention(128, window_size=(2, 8, 8), num_heads=4, use_shift=True, mlp_ratio=1.5) + self.q_sa2 = SelfAttention(256, window_size=(2, 4, 4), num_heads=4, use_shift=True, mlp_ratio=1.5) + self.k_sa2 = SelfAttention(256, window_size=(2, 4, 4), num_heads=4, use_shift=True, mlp_ratio=1.5) + self.v_sa2 = SelfAttention(256, window_size=(2, 4, 4), num_heads=4, use_shift=True, mlp_ratio=1.5) + self.q_sa3 = SelfAttention(512, window_size=(2, 2, 2), num_heads=4, use_shift=True, mlp_ratio=1.5) + self.k_sa3 = SelfAttention(512, window_size=(2, 2, 2), num_heads=4, use_shift=True, mlp_ratio=1.5) + self.v_sa3 = SelfAttention(512, window_size=(2, 2, 2), num_heads=4, use_shift=True, mlp_ratio=1.5) + self.ca1 = CrossAttention(128, window_size=(2, 8, 8), num_heads=4) + self.ca2 = CrossAttention(256, window_size=(2, 4, 4), num_heads=4) + self.ca3 = CrossAttention(512, window_size=(2, 2, 2), num_heads=4) + self.feature_fusion = Feature_Fusion() + self.head = YOLOXHead(num_classes=num_classes, width=1.0, in_channels=[128], act="silu") + + def forward(self, inputs): + # inputs: [B, 3, T, H, W] + if len(inputs.shape) == 5: + T = inputs.shape[2] + diff_imgs = inputs[:, :, :T // 2, :, :] + mt_imgs = inputs[:, :, T // 2:, :, :] + else: + diff_imgs = inputs + mt_imgs = inputs + q_3d = self.backbonetd(diff_imgs) + q_3d1, q_3d2, q_3d3 = q_3d['stage2'], q_3d['stage3'], q_3d['stage4'] + k_3d = self.backbone3d(mt_imgs) + k_3d1, k_3d2, k_3d3 = k_3d['stage2'], k_3d['stage3'], k_3d['stage4'] + [feat1, feat2, feat3] = self.backbone2d(inputs[:, :, -1, :, :]) + + def to_5d(x): + # [B, C, T, H, W] -> [B, T, H, W, C] + return x.permute(0, 2, 3, 4, 1) + + q_3d1 = to_5d(q_3d1) + q_3d2 = to_5d(q_3d2) + q_3d3 = to_5d(q_3d3) + k_3d1 = to_5d(k_3d1) + k_3d2 = to_5d(k_3d2) + k_3d3 = to_5d(k_3d3) + + # V特征扩展T维度,与Q/K对齐(假设V为最后一帧,T=1) + def expand_v(x, 
T): + # [B, C, H, W] -> [B, T, H, W, C],复制T次 + x = x.permute(0, 2, 3, 1).unsqueeze(1) + x = x.expand(-1, T, -1, -1, -1) + return x + + T1 = q_3d1.shape[1] + T2 = q_3d2.shape[1] + T3 = q_3d3.shape[1] + v1 = expand_v(feat1, T1) + v2 = expand_v(feat2, T2) + v3 = expand_v(feat3, T3) + + q1 = self.q_sa1(q_3d1) + k1 = self.k_sa1(k_3d1) + v1 = self.v_sa1(v1) + q2 = self.q_sa2(q_3d2) + k2 = self.k_sa2(k_3d2) + v2 = self.v_sa2(v2) + q3 = self.q_sa3(q_3d3) + k3 = self.k_sa3(k_3d3) + v3 = self.v_sa3(v3) + out1 = self.ca1(q1, k1, v1) + out2 = self.ca2(q2, k2, v2) + out3 = self.ca3(q3, k3, v3) + out1 = out1.mean(1).permute(0, 3, 1, 2) + out2 = out2.mean(1).permute(0, 3, 1, 2) + out3 = out3.mean(1).permute(0, 3, 1, 2) + + feat_all = self.feature_fusion([out1, out2, out3]) + outputs = self.head([feat_all]) + + return outputs \ No newline at end of file diff --git a/model/TDCNet/TDCR.py b/model/TDCNet/TDCR.py new file mode 100644 index 0000000..a500dfd --- /dev/null +++ b/model/TDCNet/TDCR.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TDC(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=(5, 3, 3), stride=1, padding=(2, 1, 1), groups=1, bias=False, step=1): + super().__init__() + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=bias) + self.step = step + self.groups = groups + + def get_time_gradient_weight(self): + weight = self.conv.weight + kT, kH, kW = weight.shape[2:] + grad_weight = torch.zeros_like(weight, device=weight.device, dtype=weight.dtype) + if kT == 5: + if self.step == -1: + grad_weight[:, :, :, :, :] = -weight[:, :, :, :, :] + grad_weight[:, :, 4, :, :] = weight[:, :, 0, :, :] + weight[:, :, 1, :, :] + weight[:, :, 2, :, :] + weight[:, :, 3, :, :] + weight[:, :, 4, :, :] + elif self.step == 1: + grad_weight[:, :, 4, :, :] = weight[:, :, 4, :, :] + grad_weight[:, :, 3, :, :] = weight[:, :, 3, :, :] - weight[:, :, 4, :, :] + grad_weight[:, :, 2, :, :] = weight[:, :, 2, :, :] - weight[:, :, 3, :, :] + grad_weight[:, :, 1, :, :] = weight[:, :, 1, :, :] - weight[:, :, 2, :, :] + grad_weight[:, :, 0, :, :] = -weight[:, :, 1, :, :] + elif self.step == 2: + grad_weight[:, :, 4, :, :] = weight[:, :, 4, :, :] + grad_weight[:, :, 3, :, :] = weight[:, :, 3, :, :] + grad_weight[:, :, 2, :, :] = weight[:, :, 2, :, :] - weight[:, :, 4, :, :] + grad_weight[:, :, 1, :, :] = -weight[:, :, 3, :, :] + grad_weight[:, :, 0, :, :] = -weight[:, :, 2, :, :] + else: + grad_weight = weight + bias = self.conv.bias + if bias is None: + bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype) + return grad_weight, bias + + def forward(self, x): + weight, bias = self.get_time_gradient_weight() + x_diff = F.conv3d(x, weight, bias, stride=self.conv.stride, groups=self.groups, padding=self.conv.padding) + return x_diff + + +class RepConv3D(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=(5, 3, 3), stride=1, padding=(2, 1, 1), groups=1, deploy=False): + super(RepConv3D, self).__init__() + self.deploy = deploy + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.groups = groups + if self.deploy: + self.conv_reparam = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=True) + else: + self.l_tdc = nn.Sequential( + TDC(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False, step=-1), + nn.BatchNorm3d(out_channels) + ) + 
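            # s_tdc and m_tdc below, together with l_tdc above, are the three parallel
            # TDC + BatchNorm3d branches of the re-parameterization module; they differ
            # only in the temporal-difference step (-1, 1, 2), presumably covering long,
            # short, and medium temporal ranges. Their outputs are summed in forward(),
            # and switch_to_deploy() folds all three into the single conv_reparam Conv3d.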
self.s_tdc = nn.Sequential( + TDC(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False, step=1), + nn.BatchNorm3d(out_channels) + ) + self.m_tdc = nn.Sequential( + TDC(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False, step=2), + nn.BatchNorm3d(out_channels) + ) + + def forward(self, x): + if self.deploy: + out = F.relu(self.conv_reparam(x)) + else: + out = self.s_tdc(x) + self.m_tdc(x) + self.l_tdc(x) + out = F.relu(out) + return out + + def get_equivalent_kernel_bias(self): + kernel_s_tdc, bias_s_tdc = self._fuse_conv_bn(self.s_tdc) + kernel_m_tdc, bias_m_tdc = self._fuse_conv_bn(self.m_tdc) + kernel_l_tdc, bias_l_tdc = self._fuse_conv_bn(self.l_tdc) + kernel = kernel_s_tdc + kernel_m_tdc + kernel_l_tdc + bias = bias_s_tdc + bias_m_tdc + bias_l_tdc + return kernel, bias + + def switch_to_deploy(self): + if self.deploy: + return + kernel, bias = self.get_equivalent_kernel_bias() + self.conv_reparam = nn.Conv3d( + self.in_channels, self.out_channels, (5, 3, 3), self.stride, + (2, 1, 1), groups=self.groups, bias=True + ) + self.conv_reparam.weight.data = kernel + self.conv_reparam.bias.data = bias + self.deploy = True + del self.s_tdc + del self.m_tdc + del self.l_tdc + + @staticmethod + def _fuse_conv_bn(branch): + if branch is None: + return 0, 0 + + def find_conv(module): + if isinstance(module, nn.Conv3d): + return module + for child in module.children(): + conv = find_conv(child) + if conv is not None: + return conv + return None + + conv = find_conv(branch[0]) + bn = branch[1] + if hasattr(branch[0], 'get_time_gradient_weight'): + w, bias = branch[0].get_time_gradient_weight() + else: + w = conv.weight + if conv.bias is not None: + bias = conv.bias + else: + bias = torch.zeros_like(bn.running_mean) + mean = bn.running_mean + var_sqrt = torch.sqrt(bn.running_var + bn.eps) + gamma = bn.weight + beta = bn.bias + w = w * (gamma / var_sqrt).reshape(-1, 1, 1, 1, 1) + bias = (bias - mean) / var_sqrt * gamma + beta + return w, bias diff --git a/model/TDCNet/TDCSTA.py b/model/TDCNet/TDCSTA.py new file mode 100644 index 0000000..f2d89fa --- /dev/null +++ b/model/TDCNet/TDCSTA.py @@ -0,0 +1,239 @@ +from functools import reduce +from operator import mul + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class WindowAttention3D(nn.Module): + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.window_size = window_size # (T, H, W) + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) + coords_t = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid(coords_t, coords_h, coords_w, indexing='ij')) # 3, T, H, W + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) + 
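        # Summing the shifted and scaled (t, h, w) offsets gives each pair of window
        # positions a unique flat index into relative_position_bias_table (the 3D
        # relative position bias scheme used by Video Swin Transformer's window attention).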
relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self, x, k=None, v=None, mask=None): + B_, N, C = x.shape + if k is None or v is None: + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + else: + q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # [B_, num_heads, N, head_dim] + k = k.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = v.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(N, N, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size): + B, T, H, W, C = x.shape + window_size = list(window_size) + if T < window_size[0]: + window_size[0] = T + if H < window_size[1]: + window_size[1] = H + if W < window_size[2]: + window_size[2] = W + x = x.view(B, T // window_size[0] if window_size[0] > 0 else 1, window_size[0], + H // window_size[1] if window_size[1] > 0 else 1, window_size[1], + W // window_size[2] if window_size[2] > 0 else 1, window_size[2], C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C) + return windows + + +def window_reverse(windows, window_size, B, T, H, W): + x = windows.view(B, T // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, T, H, W, -1) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +class SelfAttention(nn.Module): + def __init__(self, dim, window_size=(2, 8, 8), num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., use_shift=False, shift_size=None, mlp_ratio=2.0, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + self.use_shift = use_shift + self.shift_size = shift_size if shift_size is not None else tuple([w // 2 for w in window_size]) if use_shift else tuple([0] * len(window_size)) + self.attn1 = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 
attn_drop=attn_drop, proj_drop=proj_drop) + self.attn2 = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.norm3 = norm_layer(dim) + self.norm4 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp1 = nn.Sequential( + nn.Linear(dim, mlp_hidden_dim), + nn.GELU(), + nn.Linear(mlp_hidden_dim, dim) + ) + self.mlp2 = nn.Sequential( + nn.Linear(dim, mlp_hidden_dim), + nn.GELU(), + nn.Linear(mlp_hidden_dim, dim) + ) + + def create_mask(self, x_shape, device): + B, T, H, W, C = x_shape + img_mask = torch.zeros((1, T, H, W, 1), device=device) + cnt = 0 + t_slices = (slice(0, -self.window_size[0]), slice(-self.window_size[0], -self.shift_size[0]), slice(-self.shift_size[0], None)) + h_slices = (slice(0, -self.window_size[1]), slice(-self.window_size[1], -self.shift_size[1]), slice(-self.shift_size[1], None)) + w_slices = (slice(0, -self.window_size[2]), slice(-self.window_size[2], -self.shift_size[2]), slice(-self.shift_size[2], None)) + for t in t_slices: + for h in h_slices: + for w in w_slices: + img_mask[:, t, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.squeeze(-1) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + return attn_mask + + def forward(self, x): + B, T, H, W, C = x.shape + window_size, shift_size = get_window_size((T, H, W), self.window_size, self.shift_size) + shortcut = x + x = self.norm1(x) + pad_t = (window_size[0] - T % window_size[0]) % window_size[0] + pad_h = (window_size[1] - H % window_size[1]) % window_size[1] + pad_w = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t)) + shortcut = F.pad(shortcut, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t)) + _, Tp, Hp, Wp, _ = x.shape + x_windows = window_partition(x, window_size) + attn_windows = self.attn1(x_windows, mask=None) + attn_windows = attn_windows.view(-1, *(window_size + (C,))) + x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp) + x = shortcut + x + x = x + self.mlp1(self.norm2(x)) + shortcut = x + x = self.norm3(x) + if self.use_shift and any(i > 0 for i in shift_size): + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + attn_mask = self.create_mask((B, Tp, Hp, Wp, C), x.device) + x_windows = window_partition(shifted_x, window_size) + attn_windows = self.attn2(x_windows, mask=attn_mask) + attn_windows = attn_windows.view(-1, *(window_size + (C,))) + shifted_x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp) + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + if pad_t > 0: + x = x[:, :T, :, :, :] + shortcut = shortcut[:, :T, :, :, :] + if pad_h > 0: + x = x[:, :, :H, :, :] + shortcut = shortcut[:, :, :H, :, :] + if pad_w > 0: + x = x[:, :, :, :W, :] + shortcut = shortcut[:, :, :, :W, :] + + x = shortcut + x + x = x + self.mlp2(self.norm4(x)) + return x + + +class CrossAttention(nn.Module): + def __init__(self, dim, window_size=(2, 8, 8), num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., mlp_ratio=2.0, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + self.norm1_q = 
norm_layer(dim) + self.norm1_k = norm_layer(dim) + self.norm1_v = norm_layer(dim) + self.attn = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = nn.Sequential( + nn.Linear(dim, mlp_hidden_dim), + nn.GELU(), + nn.Linear(mlp_hidden_dim, dim) + ) + + def forward(self, q, k, v): + B, T, H, W, C = q.shape + window_size = get_window_size((T, H, W), self.window_size) + shortcut = v + q = self.norm1_q(q) + k = self.norm1_k(k) + v = self.norm1_v(v) + pad_t = (window_size[0] - T % window_size[0]) % window_size[0] + pad_h = (window_size[1] - H % window_size[1]) % window_size[1] + pad_w = (window_size[2] - W % window_size[2]) % window_size[2] + q = F.pad(q, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t)) + k = F.pad(k, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t)) + v = F.pad(v, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t)) + _, Tp, Hp, Wp, _ = q.shape + + q_windows = window_partition(q, window_size) + k_windows = window_partition(k, window_size) + v_windows = window_partition(v, window_size) + attn_windows = self.attn(q_windows, k_windows, v_windows) + attn_windows = attn_windows.view(-1, *(window_size + (C,))) + shifted_x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp) + x = shifted_x + if pad_t > 0: + x = x[:, :T, :, :, :] + if pad_h > 0: + x = x[:, :, :H, :, :] + if pad_w > 0: + x = x[:, :, :, :W, :] + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + return x diff --git a/model/TDCNet/__init__.py b/model/TDCNet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/TDCNet/backbone3d.py b/model/TDCNet/backbone3d.py new file mode 100644 index 0000000..2ec37b2 --- /dev/null +++ b/model/TDCNet/backbone3d.py @@ -0,0 +1,272 @@ +import torch +import torch.nn as nn + +from torch.hub import load_state_dict_from_url + +model_urls = { + "0.25x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_0.25x_RGB_16_best.pth", + "1.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.0x_RGB_16_best.pth", + "1.5x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.5x_RGB_16_best.pth", + "2.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_2.0x_RGB_16_best.pth", +} + + +def load_weight(model, arch): + url = model_urls[arch] + # check + if url is None: + print('No pretrained weight for 3D CNN: {}'.format(arch.upper())) + return model + + # checkpoint state dict + checkpoint = load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) + + checkpoint_state_dict = checkpoint.pop('state_dict') + + # model state dict + model_state_dict = model.state_dict() + # reformat checkpoint_state_dict: + new_state_dict = {} + for k in checkpoint_state_dict.keys(): + v = checkpoint_state_dict[k] + new_state_dict[k[7:]] = v + # pdb.set_trace() + # check + for k in list(new_state_dict.keys()): + if k in model_state_dict: + shape_model = tuple(model_state_dict[k].shape) + shape_checkpoint = tuple(new_state_dict[k].shape) + if shape_model != shape_checkpoint: + new_state_dict.pop(k) + else: + new_state_dict.pop(k) + + model.load_state_dict(new_state_dict) + + return model + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv3d(inp, oup, kernel_size=(5, 3, 3), stride=stride, padding=(2, 1, 1), bias=False), + nn.BatchNorm3d(oup), + 
nn.ReLU(inplace=True) + ) + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + oup_inc = oup // 2 + + if self.stride == 1: + self.banch2 = nn.Sequential( + # pw + nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False), + nn.BatchNorm3d(oup_inc), + nn.ReLU(inplace=True), + # dw + nn.Conv3d(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc, bias=False), + nn.BatchNorm3d(oup_inc), + # pw-linear + nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False), + nn.BatchNorm3d(oup_inc), + nn.ReLU(inplace=True) + ) + + else: + self.banch1 = nn.Sequential( + # dw + nn.Conv3d(inp, inp, (5, 3, 3), stride, (2, 1, 1), groups=inp, bias=False), + nn.BatchNorm3d(inp), + # pw-linear + nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False), + nn.BatchNorm3d(oup_inc), + nn.ReLU(inplace=True) + ) + self.banch2 = nn.Sequential( + # pw + nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False), + nn.BatchNorm3d(oup_inc), + nn.ReLU(inplace=True), + # dw + nn.Conv3d(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc, bias=False), + nn.BatchNorm3d(oup_inc), + # pw-linear + nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False), + nn.BatchNorm3d(oup_inc), + nn.ReLU(inplace=True) + ) + + @staticmethod + def _concat(x, out): + # concatenate along channel axis + return torch.cat((x, out), 1) + + def forward(self, x): + if self.stride == 1: + x1 = x[:, :(x.shape[1] // 2), :, :, :] + x2 = x[:, (x.shape[1] // 2):, :, :, :] + out = self._concat(x1, self.banch2(x2)) + elif self.stride == 2: + out = self._concat(self.banch1(x), self.banch2(x)) + + return channel_shuffle(out, 2) + + +def channel_shuffle(x, groups): + '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' + batchsize, num_channels, depth, height, width = x.data.size() + channels_per_group = num_channels // groups + # reshape + x = x.view(batchsize, groups, + channels_per_group, depth, height, width) + # permute + x = x.permute(0, 2, 1, 3, 4, 5).contiguous() + # flatten + x = x.view(batchsize, num_channels, depth, height, width) + return x + + +class ShuffleNetV2(nn.Module): + def __init__(self, width_mult='1.0x', num_classes=600): + super(ShuffleNetV2, self).__init__() + + self.stage_repeats = [4, 8, 4] + # index 0 is invalid and should never be called. + # only used for indexing convenience. 
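        # Note: the '1.0x' widths below (128/256/512) deviate from the stock
        # ShuffleNetV2-1.0x stage channels (116/232/464); this appears to be so the
        # stage outputs line up with the 128/256/512 attention dims used in TDCNetwork.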
+ if width_mult == '0.25x': + self.stage_out_channels = [-1, 24, 32, 64, 128] + elif width_mult == '0.5x': + self.stage_out_channels = [-1, 24, 48, 96, 192] + elif width_mult == '1.0x': + self.stage_out_channels = [-1, 24, 128, 256, 512] + elif width_mult == '1.5x': + self.stage_out_channels = [-1, 24, 176, 352, 704] + elif width_mult == '2.0x': + self.stage_out_channels = [-1, 24, 224, 488, 976] + + # building first layer + input_channel = self.stage_out_channels[1] + self.conv1 = conv_bn(3, input_channel, stride=(1, 2, 2)) + self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) + self.features = [] + self.features1 = [] + self.features2 = [] + self.features3 = [] + # building inverted residual blocks + for idxstage in range(len(self.stage_repeats)): + numrepeat = self.stage_repeats[idxstage] + output_channel = self.stage_out_channels[idxstage + 2] + for i in range(numrepeat): + stride = 2 if i == 0 else 1 + self.features.append(InvertedResidual(input_channel, output_channel, stride)) + input_channel = output_channel + self.features = nn.Sequential(*self.features) + # for idxstage in range(len(self.stage_repeats)): + # numrepeat = self.stage_repeats[idxstage] + # output_channel = self.stage_out_channels[idxstage+2] + # for i in range(numrepeat): + # if idxstage==0: + # stride = 2 if i == 0 else 1 + # self.features1.append(InvertedResidual(input_channel, output_channel, stride)) + # input_channel = output_channel + # elif idxstage==1: + # stride = 2 if i == 0 else 1 + # self.features2.append(InvertedResidual(input_channel, output_channel, stride)) + # input_channel = output_channel + # elif idxstage==2: + # stride = 2 if i == 0 else 1 + # self.features3.append(InvertedResidual(input_channel, output_channel, stride)) + # input_channel = output_channel + # # make it nn.Sequential + # self.features1 = nn.Sequential(*self.features1) + # self.features2 = nn.Sequential(*self.features2) + # self.features3 = nn.Sequential(*self.features3) + + # # building last several layers + # self.conv_last = conv_1x1x1_bn(input_channel, self.stage_out_channels[-1]) + # self.avgpool = nn.AvgPool3d((2, 1, 1), stride=1) + + def forward(self, x): + outputs = {} + # pdb.set_trace() #(1,3,16,512,512) #(1,3,5,512,512) + x = self.conv1(x) # (1,24,16,256,256) #(1,24,5,256,256) + + x = self.maxpool(x) # (1,24,8,128,128) #(1,24,3,128,128) + # outputs['stage1'] = x + # x=self.features(x) + x = self.features[:4](x) # (1,116,4,64,64) #(1,116,2,64,64) + outputs['stage2'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2) + x = self.features[4:12](x) # (1,232,2,32,32) #(1,232,1,32,32) + outputs['stage3'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2) + x = self.features[12:16](x) # (1,464,1,16,16) #(1,464,1,16,16) + outputs['stage4'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2) + # out = self.conv_last(out) + + # if x.size(2) > 1: + # x = torch.mean(x, dim=2, keepdim=True) + + # return x.squeeze(2) + return outputs + + +def build_shufflenetv2_3d(model_size='0.25x', pretrained=False): + model = ShuffleNetV2(model_size) + feats = model.stage_out_channels[-1] + + # if pretrained: + # model = load_weight(model, model_size) + + return model, feats + + +def build_3d_cnn(cfg, pretrained=False): + if 'resnet' in cfg['backbone_3d']: + model, feat_dims = build_resnet_3d( + model_name=cfg['backbone_3d'], + pretrained=pretrained + ) + elif 'resnext' in cfg['backbone_3d']: + model, feat_dims = build_resnext_3d( + model_name=cfg['backbone_3d'], + pretrained=pretrained + ) + elif 'shufflenetv2' in cfg['backbone_3d']: 
+ model, feat_dims = build_shufflenetv2_3d( + model_size=cfg['model_size'], + pretrained=pretrained + ) + else: + print('Unknown Backbone ...') + exit() + + return model, feat_dims + + +class Backbone3D(nn.Module): + def __init__(self, cfg, pretrained=False): + super().__init__() + self.cfg = cfg + + # 3D CNN + self.backbone, self.feat_dim = build_3d_cnn(cfg, pretrained) + + def forward(self, x): + """ + Input: + x: (Tensor) -> [B, C, T, H, W] + Output: + y: (List) -> [ + (Tensor) -> [B, C1, H1, W1], + (Tensor) -> [B, C2, H2, W2], + (Tensor) -> [B, C3, H3, W3] + ] + """ + feat = self.backbone(x) + + return feat + diff --git a/model/TDCNet/backbonetd.py b/model/TDCNet/backbonetd.py new file mode 100644 index 0000000..e4e1e56 --- /dev/null +++ b/model/TDCNet/backbonetd.py @@ -0,0 +1,281 @@ +import os + +import torch +import torch.nn as nn +from matplotlib import pyplot as plt +from torch.hub import load_state_dict_from_url + +from model.TDCNet.TDCR import RepConv3D + +model_urls = { + "0.25x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_0.25x_RGB_16_best.pth", + "1.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.0x_RGB_16_best.pth", + "1.5x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_1.5x_RGB_16_best.pth", + "2.0x": "https://github.com/yjh0410/PyTorch_YOWO/releases/download/yowo-weight/kinetics_shufflenetv2_2.0x_RGB_16_best.pth", +} + + +def load_weight(model, arch): + url = model_urls[arch] + # check + if url is None: + print('No pretrained weight for 3D CNN: {}'.format(arch.upper())) + return model + + # checkpoint state dict + checkpoint = load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) + + checkpoint_state_dict = checkpoint.pop('state_dict') + + # model state dict + model_state_dict = model.state_dict() + # reformat checkpoint_state_dict: + new_state_dict = {} + for k in checkpoint_state_dict.keys(): + v = checkpoint_state_dict[k] + new_state_dict[k[7:]] = v + # pdb.set_trace() + # check + for k in list(new_state_dict.keys()): + if k in model_state_dict: + shape_model = tuple(model_state_dict[k].shape) + shape_checkpoint = tuple(new_state_dict[k].shape) + if shape_model != shape_checkpoint: + new_state_dict.pop(k) + else: + new_state_dict.pop(k) + + model.load_state_dict(new_state_dict) + + return model + + +def conv_bn(inp, oup, stride): + # return nn.Sequential( + # nn.Conv3d(inp, oup, kernel_size=3, stride=stride, padding=(1,1,1), bias=False), + # nn.BatchNorm3d(oup), + # nn.ReLU(inplace=True) + # ) + return RepConv3D(inp, oup, (5, 3, 3), stride, (2, 1, 1)) + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + oup_inc = oup // 2 + + if self.stride == 1: + self.banch2 = nn.Sequential( + # pw + nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False), + nn.BatchNorm3d(oup_inc), + nn.ReLU(inplace=True), + # dw + # nn.Conv3d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), + # nn.BatchNorm3d(oup_inc), + RepConv3D(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc), + # pw-linear + nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False), + nn.BatchNorm3d(oup_inc), + nn.ReLU(inplace=True) + ) + + else: + self.banch1 = nn.Sequential( + # dw + # nn.Conv3d(inp, inp, 3, stride, 1, groups=inp, bias=False), + # nn.BatchNorm3d(inp), + RepConv3D(inp, inp, (5, 3, 3), stride, (2, 1, 1), groups=inp, 
), + # pw-linear + nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False), + nn.BatchNorm3d(oup_inc), + nn.ReLU(inplace=True) + ) + self.banch2 = nn.Sequential( + # pw + nn.Conv3d(inp, oup_inc, 1, 1, 0, bias=False), + nn.BatchNorm3d(oup_inc), + nn.ReLU(inplace=True), + # dw + # nn.Conv3d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), + # nn.BatchNorm3d(oup_inc), + RepConv3D(oup_inc, oup_inc, (5, 3, 3), stride, (2, 1, 1), groups=oup_inc, ), + # pw-linear + nn.Conv3d(oup_inc, oup_inc, 1, 1, 0, bias=False), + nn.BatchNorm3d(oup_inc), + nn.ReLU(inplace=True) + ) + + @staticmethod + def _concat(x, out): + # concatenate along channel axis + return torch.cat((x, out), 1) + + def forward(self, x): + if self.stride == 1: + x1 = x[:, :(x.shape[1] // 2), :, :, :] + x2 = x[:, (x.shape[1] // 2):, :, :, :] + out = self._concat(x1, self.banch2(x2)) + elif self.stride == 2: + out = self._concat(self.banch1(x), self.banch2(x)) + # return out + return channel_shuffle(out, 2) +# +# +def channel_shuffle(x, groups): + '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' + batchsize, num_channels, depth, height, width = x.data.size() + channels_per_group = num_channels // groups + # reshape + x = x.view(batchsize, groups, + channels_per_group, depth, height, width) + # permute + x = x.permute(0, 2, 1, 3, 4, 5).contiguous() + # flatten + x = x.view(batchsize, num_channels, depth, height, width) + return x + + +class ShuffleNetV2(nn.Module): + def __init__(self, width_mult='1.0x', num_classes=600): + super(ShuffleNetV2, self).__init__() + + self.stage_repeats = [4, 8, 4] + # index 0 is invalid and should never be called. + # only used for indexing convenience. + if width_mult == '0.25x': + self.stage_out_channels = [-1, 24, 32, 64, 128] + elif width_mult == '0.5x': + self.stage_out_channels = [-1, 24, 48, 96, 192] + elif width_mult == '1.0x': + # self.stage_out_channels = [-1, 24, 116, 232, 464] + self.stage_out_channels = [-1, 24, 128, 256, 512] + elif width_mult == '1.5x': + self.stage_out_channels = [-1, 24, 176, 352, 704] + elif width_mult == '2.0x': + self.stage_out_channels = [-1, 24, 224, 488, 976] + + # building first layer + input_channel = self.stage_out_channels[1] + self.conv1 = conv_bn(3, input_channel, stride=(1, 2, 2)) + self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) + self.features = [] + self.features1 = [] + self.features2 = [] + self.features3 = [] + # building inverted residual blocks + for idxstage in range(len(self.stage_repeats)): + numrepeat = self.stage_repeats[idxstage] + output_channel = self.stage_out_channels[idxstage + 2] + for i in range(numrepeat): + stride = 2 if i == 0 else 1 + self.features.append(InvertedResidual(input_channel, output_channel, stride)) + input_channel = output_channel + self.features = nn.Sequential(*self.features) + # for idxstage in range(len(self.stage_repeats)): + # numrepeat = self.stage_repeats[idxstage] + # output_channel = self.stage_out_channels[idxstage+2] + # for i in range(numrepeat): + # if idxstage==0: + # stride = 2 if i == 0 else 1 + # self.features1.append(InvertedResidual(input_channel, output_channel, stride)) + # input_channel = output_channel + # elif idxstage==1: + # stride = 2 if i == 0 else 1 + # self.features2.append(InvertedResidual(input_channel, output_channel, stride)) + # input_channel = output_channel + # elif idxstage==2: + # stride = 2 if i == 0 else 1 + # self.features3.append(InvertedResidual(input_channel, output_channel, stride)) + # input_channel = output_channel + # # make it 
nn.Sequential + # self.features1 = nn.Sequential(*self.features1) + # self.features2 = nn.Sequential(*self.features2) + # self.features3 = nn.Sequential(*self.features3) + + # # building last several layers + # self.conv_last = conv_1x1x1_bn(input_channel, self.stage_out_channels[-1]) + # self.avgpool = nn.AvgPool3d((2, 1, 1), stride=1) + + def forward(self, x): + outputs = {} + # pdb.set_trace() #(1,3,16,512,512) #(1,3,5,512,512) + + x = self.conv1(x) # (1,24,16,256,256) #(1,24,5,256,256) + + x = self.maxpool(x) # (1,24,8,128,128) #(1,24,3,128,128) + # outputs['stage1'] = x + # x = self.features(x) + x = self.features[:4](x) # (1,116,4,64,64) #(1,116,2,64,64) + outputs['stage2'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2) + x = self.features[4:12](x) # (1,232,2,32,32) #(1,232,1,32,32) + outputs['stage3'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2) + x = self.features[12:16](x) # (1,464,1,16,16) #(1,464,1,16,16) + outputs['stage4'] = x # torch.mean(x, dim=2, keepdim=True).squeeze(2) + # out = self.conv_last(out) + + # if x.size(2) > 1: + # x = torch.mean(x, dim=2, keepdim=True) + + # return x.squeeze(2) + return outputs + + +def build_shufflenetv2_3d(model_size='1.0x', pretrained=False): + model = ShuffleNetV2(model_size) + feats = model.stage_out_channels[-1] + + # if pretrained: + # model = load_weight(model, model_size) + + return model, feats + + +def build_3d_cnn(cfg, pretrained=False): + if 'resnet' in cfg['backbone_3d']: + model, feat_dims = build_resnet_3d( + model_name=cfg['backbone_3d'], + pretrained=pretrained + ) + elif 'resnext' in cfg['backbone_3d']: + model, feat_dims = build_resnext_3d( + model_name=cfg['backbone_3d'], + pretrained=pretrained + ) + elif 'shufflenetv2' in cfg['backbone_3d']: + model, feat_dims = build_shufflenetv2_3d( + model_size=cfg['model_size'], + pretrained=pretrained + ) + else: + print('Unknown Backbone ...') + exit() + + return model, feat_dims + + +class BackboneTD(nn.Module): + def __init__(self, cfg, pretrained=False): + super().__init__() + self.cfg = cfg + + # 3D CNN + self.backbone, self.feat_dim = build_3d_cnn(cfg, pretrained) + + def forward(self, x): + """ + Input: + x: (Tensor) -> [B, C, T, H, W] + Output: + y: (List) -> [ + (Tensor) -> [B, C1, H1, W1], + (Tensor) -> [B, C2, H2, W2], + (Tensor) -> [B, C3, H3, W3] + ] + """ + feat = self.backbone(x) + + return feat diff --git a/model/TDCNet/darknet.py b/model/TDCNet/darknet.py new file mode 100644 index 0000000..b2072c4 --- /dev/null +++ b/model/TDCNet/darknet.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +# Copyright (c) Megvii, Inc. and its affiliates. 
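The sketch below is not part of the repository; it is a minimal usage example for the 3D backbone defined above in backbonetd.py. It builds the ShuffleNetV2-3D variant through `build_3d_cnn` (only the `shufflenetv2` branch is defined in this file) and prints the multi-scale stage features returned by `forward`. The config keys match what `build_3d_cnn` reads; the input size is illustrative and assumes the TDC `RepConv3D` blocks preserve spatial/temporal size like an ordinary padded convolution.

```python
import torch

from model.TDCNet.backbonetd import build_3d_cnn

cfg = {'backbone_3d': 'shufflenetv2', 'model_size': '1.0x'}
backbone, feat_dim = build_3d_cnn(cfg, pretrained=False)   # feat_dim = 512 for '1.0x'

# [B, C, T, H, W]; T = 10 mirrors the 2 * num_frame clips used in summary.py
x = torch.randn(1, 3, 10, 256, 256)
with torch.no_grad():
    feats = backbone(x)

for name, f in feats.items():
    # stage2 / stage3 / stage4 sit at 1/8, 1/16 and 1/32 of the input resolution
    print(name, tuple(f.shape))
```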
+import os + +import torch +from matplotlib import pyplot as plt +from torch import nn + +class SiLU(nn.Module): + @staticmethod + def forward(x): + return x * torch.sigmoid(x) + +def get_activation(name="silu", inplace=True): + if name == "silu": + module = SiLU() + elif name == "relu": + module = nn.ReLU(inplace=inplace) + elif name == "lrelu": + module = nn.LeakyReLU(0.1, inplace=inplace) + elif name == "sigmoid": + module = nn.Sigmoid() + else: + raise AttributeError("Unsupported act type: {}".format(name)) + return module + +class Focus(nn.Module): + def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"): + super().__init__() + self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act) + + def forward(self, x): + patch_top_left = x[..., ::2, ::2] + patch_bot_left = x[..., 1::2, ::2] + patch_top_right = x[..., ::2, 1::2] + patch_bot_right = x[..., 1::2, 1::2] + x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,) + return self.conv(x) + +class BaseConv(nn.Module): + def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"): + super().__init__() + pad = (ksize - 1) // 2 + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03) + self.act = get_activation(act, inplace=True) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + +class DWConv(nn.Module): + def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"): + super().__init__() + self.dconv = BaseConv(in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act,) + self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) + + def forward(self, x): + x = self.dconv(x) + return self.pconv(x) + +class SPPBottleneck(nn.Module): + def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"): + super().__init__() + hidden_channels = in_channels // 2 + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes]) + conv2_channels = hidden_channels * (len(kernel_sizes) + 1) + self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation) + + def forward(self, x): + x = self.conv1(x) + x = torch.cat([x] + [m(x) for m in self.m], dim=1) + x = self.conv2(x) + return x + +#--------------------------------------------------# +# 残差结构的构建,小的残差结构 +#--------------------------------------------------# +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",): + super().__init__() + hidden_channels = int(out_channels * expansion) + Conv = DWConv if depthwise else BaseConv + #--------------------------------------------------# + # 利用1x1卷积进行通道数的缩减。缩减率一般是50% + #--------------------------------------------------# + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + #--------------------------------------------------# + # 利用3x3卷积进行通道数的拓张。并且完成特征提取 + #--------------------------------------------------# + self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act) + self.use_add = shortcut and in_channels == out_channels + + def forward(self, x): + y = 
self.conv2(self.conv1(x)) + if self.use_add: + y = y + x + return y + +class CSPLayer(nn.Module): + def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",): + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) + #--------------------------------------------------# + # 主干部分的初次卷积 + #--------------------------------------------------# + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + #--------------------------------------------------# + # 大的残差边部分的初次卷积 + #--------------------------------------------------# + self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + #-----------------------------------------------# + # 对堆叠的结果进行卷积的处理 + #-----------------------------------------------# + self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act) + + #--------------------------------------------------# + # 根据循环的次数构建上述Bottleneck残差结构 + #--------------------------------------------------# + module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)] + self.m = nn.Sequential(*module_list) + + def forward(self, x): + #-------------------------------# + # x_1是主干部分 + #-------------------------------# + x_1 = self.conv1(x) + #-------------------------------# + # x_2是大的残差边部分 + #-------------------------------# + x_2 = self.conv2(x) + + #-----------------------------------------------# + # 主干部分利用残差结构堆叠继续进行特征提取 + #-----------------------------------------------# + x_1 = self.m(x_1) + #-----------------------------------------------# + # 主干部分和大的残差边部分进行堆叠 + #-----------------------------------------------# + x = torch.cat((x_1, x_2), dim=1) + #-----------------------------------------------# + # 对堆叠的结果进行卷积的处理 + #-----------------------------------------------# + return self.conv3(x) + +class CSPDarknet(nn.Module): + def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False, act="silu",): + super().__init__() + assert out_features, "please provide output features of Darknet" + self.out_features = out_features + Conv = DWConv if depthwise else BaseConv + + #-----------------------------------------------# + # 输入图片是640, 640, 3 + # 初始的基本通道是64 + #-----------------------------------------------# + base_channels = int(wid_mul * 64) # 64 + base_depth = max(round(dep_mul * 3), 1) # 3 + + #-----------------------------------------------# + # 利用focus网络结构进行特征提取 + # 640, 640, 3 -> 320, 320, 12 -> 320, 320, 64 + #-----------------------------------------------# + self.stem = Focus(3, base_channels, ksize=3, act=act) + + #-----------------------------------------------# + # 完成卷积之后,320, 320, 64 -> 160, 160, 128 + # 完成CSPlayer之后,160, 160, 128 -> 160, 160, 128 + #-----------------------------------------------# + self.dark2 = nn.Sequential( + Conv(base_channels, base_channels * 2, 3, 2, act=act), + CSPLayer(base_channels * 2, base_channels * 2, n=base_depth, depthwise=depthwise, act=act), + ) + + #-----------------------------------------------# + # 完成卷积之后,160, 160, 128 -> 80, 80, 256 + # 完成CSPlayer之后,80, 80, 256 -> 80, 80, 256 + #-----------------------------------------------# + self.dark3 = nn.Sequential( + Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), + CSPLayer(base_channels * 4, base_channels * 4, n=base_depth * 3, depthwise=depthwise, act=act), + ) + + #-----------------------------------------------# + # 完成卷积之后,80, 80, 256 -> 
40, 40, 512 + # 完成CSPlayer之后,40, 40, 512 -> 40, 40, 512 + #-----------------------------------------------# + self.dark4 = nn.Sequential( + Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), + CSPLayer(base_channels * 8, base_channels * 8, n=base_depth * 3, depthwise=depthwise, act=act), + ) + + #-----------------------------------------------# + # 完成卷积之后,40, 40, 512 -> 20, 20, 1024 + # 完成SPP之后,20, 20, 1024 -> 20, 20, 1024 + # 完成CSPlayer之后,20, 20, 1024 -> 20, 20, 1024 + #-----------------------------------------------# + self.dark5 = nn.Sequential( + Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), + SPPBottleneck(base_channels * 16, base_channels * 16, activation=act), + CSPLayer(base_channels * 16, base_channels * 16, n=base_depth, shortcut=False, depthwise=depthwise, act=act), + ) + + def forward(self, x): + outputs = {} + x = self.stem(x) + outputs["stem"] = x + + + x = self.dark2(x) + outputs["dark2"] = x + + #-----------------------------------------------# + # dark3的输出为80, 80, 256,是一个有效特征层 + #-----------------------------------------------# + x = self.dark3(x) + outputs["dark3"] = x + #-----------------------------------------------# + # dark4的输出为40, 40, 512,是一个有效特征层 + #-----------------------------------------------# + x = self.dark4(x) + outputs["dark4"] = x + #-----------------------------------------------# + # dark5的输出为20, 20, 1024,是一个有效特征层 + #-----------------------------------------------# + x = self.dark5(x) + outputs["dark5"] = x + return {k: v for k, v in outputs.items() if k in self.out_features} \ No newline at end of file diff --git a/model/nets/yolo_training.py b/model/nets/yolo_training.py new file mode 100644 index 0000000..5308c7f --- /dev/null +++ b/model/nets/yolo_training.py @@ -0,0 +1,507 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +# Copyright (c) Megvii, Inc. and its affiliates. 
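As a quick sanity check of the 2D backbone above, the following sketch (not part of the repository) instantiates `CSPDarknet` and prints the three feature maps returned through its default `out_features`; the `dep_mul`/`wid_mul` values are illustrative small-model multipliers, and 640 matches the `input_shape` used in train.py and test.py.

```python
import torch

from model.TDCNet.darknet import CSPDarknet

backbone = CSPDarknet(dep_mul=0.33, wid_mul=0.50)   # illustrative depth/width multipliers
x = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    feats = backbone(x)

for name, f in feats.items():
    print(name, tuple(f.shape))
# with wid_mul = 0.50: dark3 (1, 128, 80, 80), dark4 (1, 256, 40, 40), dark5 (1, 512, 20, 20)
```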
+import math +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops.focal_loss import sigmoid_focal_loss + + +class IOUloss(nn.Module): + def __init__(self, reduction="none", loss_type="iou"): + super(IOUloss, self).__init__() + self.reduction = reduction + self.loss_type = loss_type + + def forward(self, pred, target): + assert pred.shape[0] == target.shape[0] + + pred = pred.view(-1, 4) + target = target.view(-1, 4) + tl = torch.max( + (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2) + ) + br = torch.min( + (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2) + ) + + area_p = torch.prod(pred[:, 2:], 1) + area_g = torch.prod(target[:, 2:], 1) + + en = (tl < br).type(tl.type()).prod(dim=1) + area_i = torch.prod(br - tl, 1) * en + area_u = area_p + area_g - area_i + iou = (area_i) / (area_u + 1e-16) + + if self.loss_type == "iou": + loss = 1 - iou ** 2 + elif self.loss_type == "giou": + c_tl = torch.min( + (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2) + ) + c_br = torch.max( + (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2) + ) + area_c = torch.prod(c_br - c_tl, 1) + giou = iou - (area_c - area_u) / area_c.clamp(1e-16) + loss = 1 - giou.clamp(min=-1.0, max=1.0) + elif self.loss_type == 'ciou': + b1_cxy = pred[:,:2] + b2_cxy = target[:,:2] + # 计算中心的差距 + center_distance = torch.sum(torch.pow((b1_cxy - b2_cxy), 2), axis=-1) + # 找到包裹两个框的最小框的左上角和右下角 + enclose_mins = torch.min((pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)) + enclose_maxes = torch.max((pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)) + enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(br)) + # 计算对角线距离 + enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1) + ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6) + v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(pred[:, 2]/torch.clamp(pred[:, 3],min = 1e-6)) - torch.atan(target[:, 2]/torch.clamp(target[:, 3],min = 1e-6))), 2) + alpha = v / torch.clamp((1.0 - iou + v),min=1e-6) + ciou = ciou - alpha * v + loss = 1 - ciou.clamp(min=-1.0, max=1.0) + + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "sum": + loss = loss.sum() + + return loss + +class YOLOLoss(nn.Module): + def __init__(self, num_classes, fp16, strides=[8, 16, 32]): + super().__init__() + self.num_classes = num_classes + self.strides = strides + + self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none") + self.iou_loss = IOUloss(reduction="none") + self.grids = [torch.zeros(1)] * len(strides) + self.fp16 = fp16 + + def forward(self, inputs, labels=None): + outputs = [] + x_shifts = [] + y_shifts = [] + expanded_strides = [] + + #-----------------------------------------------# + # inputs [[batch_size, num_classes + 5, 20, 20] + # [batch_size, num_classes + 5, 40, 40] + # [batch_size, num_classes + 5, 80, 80]] + # outputs [[batch_size, 400, num_classes + 5] + # [batch_size, 1600, num_classes + 5] + # [batch_size, 6400, num_classes + 5]] + # x_shifts [[batch_size, 400] + # [batch_size, 1600] + # [batch_size, 6400]] + #-----------------------------------------------# + for k, (stride, output) in enumerate(zip(self.strides, inputs)): + output, grid = self.get_output_and_grid(output, k, stride) + x_shifts.append(grid[:, :, 0]) + y_shifts.append(grid[:, :, 1]) + expanded_strides.append(torch.ones_like(grid[:, :, 0]) * 
stride) + outputs.append(output) + + return self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1)) + + def get_output_and_grid(self, output, k, stride): + grid = self.grids[k] + hsize, wsize = output.shape[-2:] + if grid.shape[2:4] != output.shape[2:4]: + yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing='ij') + grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type()) + self.grids[k] = grid + grid = grid.view(1, -1, 2) + + output = output.flatten(start_dim=2).permute(0, 2, 1) + output[..., :2] = (output[..., :2] + grid.type_as(output)) * stride + output[..., 2:4] = torch.exp(output[..., 2:4]) * stride + return output, grid + + def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs): + #-----------------------------------------------# + # [batch, n_anchors_all, 4] + #-----------------------------------------------# + bbox_preds = outputs[:, :, :4] + #-----------------------------------------------# + # [batch, n_anchors_all, 1] + #-----------------------------------------------# + obj_preds = outputs[:, :, 4:5] + #-----------------------------------------------# + # [batch, n_anchors_all, n_cls] + #-----------------------------------------------# + cls_preds = outputs[:, :, 5:] + + total_num_anchors = outputs.shape[1] + #-----------------------------------------------# + # x_shifts [1, n_anchors_all] + # y_shifts [1, n_anchors_all] + # expanded_strides [1, n_anchors_all] + #-----------------------------------------------# + x_shifts = torch.cat(x_shifts, 1).type_as(outputs) + y_shifts = torch.cat(y_shifts, 1).type_as(outputs) + expanded_strides = torch.cat(expanded_strides, 1).type_as(outputs) + + cls_targets = [] + reg_targets = [] + obj_targets = [] + fg_masks = [] + + num_fg = 0.0 + for batch_idx in range(outputs.shape[0]): + num_gt = len(labels[batch_idx]) + if num_gt == 0: + cls_target = outputs.new_zeros((0, self.num_classes)) + reg_target = outputs.new_zeros((0, 4)) + obj_target = outputs.new_zeros((total_num_anchors, 1)) + fg_mask = outputs.new_zeros(total_num_anchors).bool() + else: + #-----------------------------------------------# + # gt_bboxes_per_image [num_gt, num_classes] + # gt_classes [num_gt] + # bboxes_preds_per_image [n_anchors_all, 4] + # cls_preds_per_image [n_anchors_all, num_classes] + # obj_preds_per_image [n_anchors_all, 1] + #-----------------------------------------------# + gt_bboxes_per_image = labels[batch_idx][..., :4].type_as(outputs) + gt_classes = labels[batch_idx][..., 4].type_as(outputs) + bboxes_preds_per_image = bbox_preds[batch_idx] + cls_preds_per_image = cls_preds[batch_idx] + obj_preds_per_image = obj_preds[batch_idx] + + gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments( + num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, + expanded_strides, x_shifts, y_shifts, + ) + torch.cuda.empty_cache() + num_fg += num_fg_img + cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1) + obj_target = fg_mask.unsqueeze(-1) + reg_target = gt_bboxes_per_image[matched_gt_inds] + cls_targets.append(cls_target) + reg_targets.append(reg_target) + obj_targets.append(obj_target.type(cls_target.type())) + fg_masks.append(fg_mask) + + cls_targets = torch.cat(cls_targets, 0) + reg_targets = torch.cat(reg_targets, 0) + obj_targets = torch.cat(obj_targets, 0) + fg_masks = 
torch.cat(fg_masks, 0) + + num_fg = max(num_fg, 1) + loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum() + loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum() + loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum() + # loss_obj = (sigmoid_focal_loss(obj_preds.view(-1, 1), obj_targets)).sum() + # loss_cls = (sigmoid_focal_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum() + reg_weight = 5.0 + loss = reg_weight * loss_iou + loss_obj + loss_cls + + return loss / num_fg + + @torch.no_grad() + def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts): + #-------------------------------------------------------# + # fg_mask [n_anchors_all] + # is_in_boxes_and_center [num_gt, len(fg_mask)] + #-------------------------------------------------------# + fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt) + + #-------------------------------------------------------# + # fg_mask [n_anchors_all] + # bboxes_preds_per_image [fg_mask, 4] + # cls_preds_ [fg_mask, num_classes] + # obj_preds_ [fg_mask, 1] + #-------------------------------------------------------# + bboxes_preds_per_image = bboxes_preds_per_image[fg_mask] + cls_preds_ = cls_preds_per_image[fg_mask] + obj_preds_ = obj_preds_per_image[fg_mask] + num_in_boxes_anchor = bboxes_preds_per_image.shape[0] + + #-------------------------------------------------------# + # pair_wise_ious [num_gt, fg_mask] + #-------------------------------------------------------# + pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False) + pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) + + #-------------------------------------------------------# + # cls_preds_ [num_gt, fg_mask, num_classes] + # gt_cls_per_image [num_gt, fg_mask, num_classes] + #-------------------------------------------------------# + if self.fp16: + with torch.cuda.amp.autocast(enabled=False): + cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() + gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1) + pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1) + else: + cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() + gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1) + pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1) + del cls_preds_ + + cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float() + + num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask) + del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss + return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg + + def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True): + if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: + raise IndexError + + if xyxy: + tl = 
torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) + br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) + area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) + area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) + else: + tl = torch.max( + (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), + ) + br = torch.min( + (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), + ) + + area_a = torch.prod(bboxes_a[:, 2:], 1) + area_b = torch.prod(bboxes_b[:, 2:], 1) + en = (tl < br).type(tl.type()).prod(dim=2) + area_i = torch.prod(br - tl, 2) * en + return area_i / (area_a[:, None] + area_b - area_i) + + def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5): + #-------------------------------------------------------# + # expanded_strides_per_image [n_anchors_all] + # x_centers_per_image [num_gt, n_anchors_all] + # x_centers_per_image [num_gt, n_anchors_all] + #-------------------------------------------------------# + expanded_strides_per_image = expanded_strides[0] + x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1) + y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1) + + #-------------------------------------------------------# + # gt_bboxes_per_image_x [num_gt, n_anchors_all] + #-------------------------------------------------------# + gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors) + gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors) + gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors) + gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors) + + #-------------------------------------------------------# + # bbox_deltas [num_gt, n_anchors_all, 4] + #-------------------------------------------------------# + b_l = x_centers_per_image - gt_bboxes_per_image_l + b_r = gt_bboxes_per_image_r - x_centers_per_image + b_t = y_centers_per_image - gt_bboxes_per_image_t + b_b = gt_bboxes_per_image_b - y_centers_per_image + bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2) + + #-------------------------------------------------------# + # is_in_boxes [num_gt, n_anchors_all] + # is_in_boxes_all [n_anchors_all] + #-------------------------------------------------------# + is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0 + is_in_boxes_all = is_in_boxes.sum(dim=0) > 0 + + gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0) + gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0) + gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0) + gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0) + + #-------------------------------------------------------# + # center_deltas [num_gt, n_anchors_all, 4] + 
#-------------------------------------------------------# + c_l = x_centers_per_image - gt_bboxes_per_image_l + c_r = gt_bboxes_per_image_r - x_centers_per_image + c_t = y_centers_per_image - gt_bboxes_per_image_t + c_b = gt_bboxes_per_image_b - y_centers_per_image + center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2) + + #-------------------------------------------------------# + # is_in_centers [num_gt, n_anchors_all] + # is_in_centers_all [n_anchors_all] + #-------------------------------------------------------# + is_in_centers = center_deltas.min(dim=-1).values > 0.0 + is_in_centers_all = is_in_centers.sum(dim=0) > 0 + + #-------------------------------------------------------# + # is_in_boxes_anchor [n_anchors_all] + # is_in_boxes_and_center [num_gt, is_in_boxes_anchor] + #-------------------------------------------------------# + is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all + is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor] + return is_in_boxes_anchor, is_in_boxes_and_center + + def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask): + #-------------------------------------------------------# + # cost [num_gt, fg_mask] + # pair_wise_ious [num_gt, fg_mask] + # gt_classes [num_gt] + # fg_mask [n_anchors_all] + # matching_matrix [num_gt, fg_mask] + #-------------------------------------------------------# + matching_matrix = torch.zeros_like(cost) + + #------------------------------------------------------------# + # 选取iou最大的n_candidate_k个点 + # 然后求和,判断应该有多少点用于该框预测 + # topk_ious [num_gt, n_candidate_k] + # dynamic_ks [num_gt] + # matching_matrix [num_gt, fg_mask] + #------------------------------------------------------------# + n_candidate_k = min(10, pair_wise_ious.size(1)) + topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1) + dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1) + + for gt_idx in range(num_gt): + #------------------------------------------------------------# + # 给每个真实框选取最小的动态k个点 + #------------------------------------------------------------# + _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False) + matching_matrix[gt_idx][pos_idx] = 1.0 + del topk_ious, dynamic_ks, pos_idx + + #------------------------------------------------------------# + # anchor_matching_gt [fg_mask] + #------------------------------------------------------------# + anchor_matching_gt = matching_matrix.sum(0) + if (anchor_matching_gt > 1).sum() > 0: + #------------------------------------------------------------# + # 当某一个特征点指向多个真实框的时候 + # 选取cost最小的真实框。 + #------------------------------------------------------------# + _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0) + matching_matrix[:, anchor_matching_gt > 1] *= 0.0 + matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0 + #------------------------------------------------------------# + # fg_mask_inboxes [fg_mask] + # num_fg为正样本的特征点个数 + #------------------------------------------------------------# + fg_mask_inboxes = matching_matrix.sum(0) > 0.0 + num_fg = fg_mask_inboxes.sum().item() + + #------------------------------------------------------------# + # 对fg_mask进行更新 + #------------------------------------------------------------# + fg_mask[fg_mask.clone()] = fg_mask_inboxes + + #------------------------------------------------------------# + # 获得特征点对应的物品种类 + #------------------------------------------------------------# + matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0) + 
gt_matched_classes = gt_classes[matched_gt_inds] + + pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes] + return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds + +def is_parallel(model): + # Returns True if model is of type DP or DDP + return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) + +def de_parallel(model): + # De-parallelize a model: returns single-GPU model if model is of type DP or DDP + return model.module if is_parallel(model) else model + +def copy_attr(a, b, include=(), exclude=()): + # Copy attributes from b to a, options to only include [...] and to exclude [...] + for k, v in b.__dict__.items(): + if (len(include) and k not in include) or k.startswith('_') or k in exclude: + continue + else: + setattr(a, k, v) + +class ModelEMA: + """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models + Keeps a moving average of everything in the model state_dict (parameters and buffers) + For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + """ + + def __init__(self, model, decay=0.9999, tau=2000, updates=0): + # Create EMA + self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA + # if next(model.parameters()).device.type != 'cpu': + # self.ema.half() # FP16 EMA + self.updates = updates # number of EMA updates + self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def update(self, model): + # Update EMA parameters + with torch.no_grad(): + self.updates += 1 + d = self.decay(self.updates) + + msd = de_parallel(model).state_dict() # model state_dict + for k, v in self.ema.state_dict().items(): + if v.dtype.is_floating_point: + v *= d + v += (1 - d) * msd[k].detach() + + def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): + # Update EMA attributes + copy_attr(self.ema, model, include, exclude) + +def weights_init(net, init_type='normal', init_gain = 0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and classname.find('Conv') != -1: + if init_type == 'normal': + torch.nn.init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + torch.nn.init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + elif classname.find('BatchNorm2d') != -1: + torch.nn.init.normal_(m.weight.data, 1.0, 0.02) + torch.nn.init.constant_(m.bias.data, 0.0) + print('initialize network with %s type' % init_type) + net.apply(init_func) + +def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10): + def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters): + if iters <= warmup_total_iters: + # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start + lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start + elif iters >= total_iters - no_aug_iter: + lr = min_lr + else: + lr = min_lr + 0.5 * (lr - min_lr) * ( + 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / 
(total_iters - warmup_total_iters - no_aug_iter)) + ) + return lr + + def step_lr(lr, decay_rate, step_size, iters): + if step_size < 1: + raise ValueError("step_size must above 1.") + n = iters // step_size + out_lr = lr * decay_rate ** n + return out_lr + + if lr_decay_type == "cos": + warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3) + warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6) + no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15) + func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter) + else: + decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) + step_size = total_iters / step_num + func = partial(step_lr, lr, decay_rate, step_size) + + return func + +def set_optimizer_lr(optimizer, lr_scheduler_func, epoch): + lr = lr_scheduler_func(epoch) + for param_group in optimizer.param_groups: + param_group['lr'] = lr diff --git a/model_data/classes.txt b/model_data/classes.txt new file mode 100644 index 0000000..1de5659 --- /dev/null +++ b/model_data/classes.txt @@ -0,0 +1 @@ +target \ No newline at end of file diff --git a/summary.py b/summary.py new file mode 100644 index 0000000..98aa489 --- /dev/null +++ b/summary.py @@ -0,0 +1,83 @@ +# --------------------------------------------# +# 该部分代码用于看网络结构 +# --------------------------------------------# +import os + +from torch import nn + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' +import torch +from thop import clever_format, profile + +from model.TDCNet.TDCNetwork import TDCNetwork +from model.TDCNet.TDCR import RepConv3D + +if __name__ == "__main__": + input_shape = [640, 640] + num_classes = 1 + + # 需要使用device来指定网络在GPU还是CPU运行8824 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + num_frame = 5 + m = TDCNetwork(num_classes, num_frame=5) + for mm in m.modules(): + if isinstance(mm, RepConv3D): + mm.switch_to_deploy() + dummy_input = torch.randn(1, 3, num_frame * 2, input_shape[0], input_shape[1]).to(device) # torch.randn(1, 3, 5,input_shape[0], input_shape[1]).to(device) + flops, params = profile(m.to(device), (dummy_input,), verbose=False) + # #--------------------------------------------------------# + # # flops * 2是因为profile没有将卷积作为两个operations + # # 有些论文将卷积算乘法、加法两个operations。此时乘2 + # # 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2 + # # 本代码选择乘2,参考YOLOX。 + # #--------------------------------------------------------# + flops = flops * 2 + flops, params = clever_format([flops, params], "%.3f") + print('Total GFLOPS: %s' % (flops)) + print('Total params: %s' % (params)) + + # 计算FPS + + from data.dataloader_for_IRSTD_UAV import seqDataset, dataset_collate + from torch.utils.data import DataLoader + import time + max_iter = 2000 + log_interval = 50 + num_warmup = 20 + pure_inf_time = 0 + fps = 0 + val_annotation_path = "/data1/gsk/Dataset/IRSTD-UAV/val.txt" + val_dataset = seqDataset(val_annotation_path, input_shape[0], num_frame, 'val') + gen_val = DataLoader(val_dataset, shuffle = False, batch_size = 1, num_workers = 10, pin_memory=True, + drop_last=True, collate_fn=dataset_collate) + m = nn.DataParallel(m).cuda() + + # benchmark with 2000 image and take the average + for i, data in enumerate(gen_val): + torch.cuda.synchronize() + start_time = time.perf_counter() + + with torch.no_grad(): + m(data[0].to('cuda')) + + torch.cuda.synchronize() + elapsed = time.perf_counter() - start_time + + if i >= num_warmup: + pure_inf_time += elapsed + if (i + 1) % log_interval == 0: + fps = (i + 1 - num_warmup) / pure_inf_time + print( + f'Done image [{i + 1:<3}/ 
{max_iter}], ' + f'fps: {fps:.1f} img / s, ' + f'times per image: {1000 / fps:.1f} ms / img', + flush=True) + + if (i + 1) == max_iter: + fps = (i + 1 - num_warmup) / pure_inf_time + print( + f'Overall fps: {fps:.1f} img / s, ' + f'times per image: {1000 / fps:.1f} ms / img', + flush=True) + break + print("FPS:" ,fps) diff --git a/test.py b/test.py new file mode 100644 index 0000000..440419c --- /dev/null +++ b/test.py @@ -0,0 +1,158 @@ +import json +import os + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from tqdm import tqdm + +from model.TDCNet.TDCNetwork import TDCNetwork +from model.TDCNet.TDCR import RepConv3D +from utils.utils import get_classes, show_config +from utils.utils_bbox import decode_outputs, non_max_suppression + +# 配置参数 +cocoGt_path = '/Dataset/IRSTD-UAV/val_coco.json' +dataset_img_path = '/Dataset/IRSTD-UAV' +temp_save_path = 'results/TDCNet_epoch_100_batch_4_optim_adam_lr_0.001_T_5' +model_path = '' +num_frame = 5 +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def get_history_imgs(image_path): + """获取5帧背景对齐图 + 5帧原图""" + dir_path = os.path.dirname(image_path) + index = int(os.path.basename(image_path).split('.')[0]) + match_dir = dir_path.replace('images', f'matches') + f"/{index:08d}" + match_imgs = [os.path.join(match_dir, f"match_{i}.png") for i in range(1, num_frame + 1)] + + min_index = index - (index % 50) + original_imgs = [os.path.join(dir_path, f"{max(index - i, min_index):08d}.png") for i in reversed(range(num_frame))] + return match_imgs + original_imgs + + +def letterbox_image_batch(images, target_size=(512, 512), color=(128, 128, 128)): + """ + letterbox预处理 + Args: + images: list of np.ndarray, shape: (H, W, 3) + target_size: desired (width, height) + Returns: + np.ndarray of shape (N, target_H, target_W, 3) + """ + w, h = target_size + output = np.full((len(images), h, w, 3), color, dtype=np.uint8) + + for i, img in enumerate(images): + ih, iw = img.shape[:2] + scale = min(w / iw, h / ih) + nw = int(iw * scale) + nh = int(ih * scale) + + resized = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR) + top = (h - nh) // 2 + left = (w - nw) // 2 + output[i, top:top + nh, left:left + nw, :] = resized + + return output + + +class MAP_vid: + def __init__(self): + self.model_path = model_path + self.classes_path = 'model_data/classes.txt' + self.input_shape = [640, 640] + self.confidence = 0.001 + self.nms_iou = 0.5 + self.letterbox_image = True + self.cuda = True + + self.class_names, self.num_classes = get_classes(self.classes_path) + self.net = TDCNetwork(self.num_classes, num_frame=num_frame) + state_dict = torch.load(self.model_path, map_location='cuda' if self.cuda else 'cpu') + self.net.load_state_dict(state_dict) + for m in self.net.modules(): + if isinstance(m, RepConv3D): + m.switch_to_deploy() + + if self.cuda: + self.net = nn.DataParallel(self.net).cuda() + self.net = self.net.eval() + show_config(**self.__dict__) + + def detect_image(self, image_id, images, results): + # Resize + image_shape = np.array(images[0].shape[:2]) + images = letterbox_image_batch(images, target_size=tuple(self.input_shape)) + + # Preprocess + images = np.array(images).astype(np.float32) / 255.0 + images = images.transpose(3, 0, 1, 2)[None] + + # To tensor + with torch.no_grad(): + images_tensor = torch.from_numpy(images).cuda() + + # Inference + with torch.no_grad(): + outputs = self.net(images_tensor) + outputs = decode_outputs(outputs, self.input_shape) + + # NMS 
+ outputs = non_max_suppression(outputs, self.num_classes, self.input_shape, + image_shape, self.letterbox_image, + conf_thres=self.confidence, nms_thres=self.nms_iou) + + # Postprocess + if outputs[0] is not None: + top_label = np.array(outputs[0][:, 6], dtype='int32') + top_conf = outputs[0][:, 4] * outputs[0][:, 5] + top_boxes = outputs[0][:, :4] + + for i, c in enumerate(top_label): + top, left, bottom, right = top_boxes[i] + results.append({ + "image_id": int(image_id), + "category_id": clsid2catid[c], + "bbox": [float(left), float(top), float(right - left), float(bottom - top)], + "score": float(top_conf[i]) + }) + return results + + +if __name__ == "__main__": + os.makedirs(temp_save_path, exist_ok=True) + cocoGt = COCO(cocoGt_path) + ids = list(cocoGt.imgToAnns.keys()) + global clsid2catid + clsid2catid = cocoGt.getCatIds() + + yolo = MAP_vid() + results = [] + + for image_id in tqdm(ids): + file_name = cocoGt.loadImgs(image_id)[0]['file_name'] + image_path = os.path.join(dataset_img_path, file_name) + image_paths = get_history_imgs(image_path) + images = [cv2.imread(p)[:, :, ::-1] for p in image_paths] # BGR -> RGB + results = yolo.detect_image(image_id, images, results) + + with open(os.path.join(temp_save_path, 'eval_results.json'), 'w') as f: + json.dump(results, f) + + cocoDt = cocoGt.loadRes(os.path.join(temp_save_path, 'eval_results.json')) + cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + precisions = cocoEval.eval['precision'] + precision_50 = precisions[0, :, 0, 0, -1] # 第三为类别 (T,R,K,A,M) + recalls = cocoEval.eval['recall'] + recall_50 = recalls[0, 0, 0, -1] # 第二为类别 (T,K,A,M) + + print("Precision: %.4f, Recall: %.4f, F1: %.4f" % (np.mean(precision_50[:int(recall_50 * 100)]), recall_50, 2 * recall_50 * np.mean(precision_50[:int(recall_50 * 100)]) / (recall_50 + np.mean(precision_50[:int(recall_50 * 100)])))) + print("Get map done.") diff --git a/train.py b/train.py new file mode 100644 index 0000000..154f944 --- /dev/null +++ b/train.py @@ -0,0 +1,480 @@ +# -------------------------------------# +# 对数据集进行训练 +# -------------------------------------# +import datetime +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader + +from data.dataloader_for_IRSTD_UAV import seqDataset, dataset_collate +from model.TDCNet.TDCNetwork import TDCNetwork +from model.nets.yolo_training import (ModelEMA, YOLOLoss, get_lr_scheduler, set_optimizer_lr, weights_init) +from utils.callbacks import EvalCallback, LossHistory +from utils.utils import get_classes, show_config +from utils.utils_fit import fit_one_epoch + +if __name__ == "__main__": + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + # ---------------------------------# + # num_frame 输入帧数 + # ---------------------------------# + num_frame = 5 + # ---------------------------------# + # Cuda 是否使用Cuda + # 没有GPU可以设置成False + # ---------------------------------# + Cuda = True + # ---------------------------------------------------------------------# + # distributed 用于指定是否使用单机多卡分布式运行 + # 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。 + # Windows系统下默认使用DP模式调用所有显卡,不支持DDP。 + # DP模式: + # 设置 distributed = False + # 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python train.py + # DDP模式: + # 设置 distributed = True + # 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py + # 
---------------------------------------------------------------------# + distributed = False + # ---------------------------------------------------------------------# + # sync_bn 是否使用sync_bn,DDP模式多卡可用 + # ---------------------------------------------------------------------# + sync_bn = False + # ---------------------------------------------------------------------# + # fp16 是否使用混合精度训练 + # 可减少约一半的显存、需要pytorch1.7.1以上 + # ---------------------------------------------------------------------# + fp16 = False + # ---------------------------------------------------------------------# + # classes_path 指向model_data下的txt,与自己训练的数据集相关 + # 训练前一定要修改classes_path,使其对应自己的数据集 + # ---------------------------------------------------------------------# + classes_path = 'model_data/classes.txt' + model_path = '' + input_shape = [640, 640] + # ----------------------------------------------------------------------------------------------------------------------------# + # 训练分为两个阶段,分别是冻结阶段和解冻阶段。设置冻结阶段是为了满足机器性能不足的同学的训练需求。 + # 冻结训练需要的显存较小,显卡非常差的情况下,可设置Freeze_Epoch等于UnFreeze_Epoch,Freeze_Train = True,此时仅仅进行冻结训练。 + # + # 在此提供若干参数设置建议,各位训练者根据自己的需求进行灵活调整: + # (一)从整个模型的预训练权重开始训练: + # Adam: + # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 1e-3,weight_decay = 0。(冻结) + # Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-3,weight_decay = 0。(不冻结) + # SGD: + # Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 300,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 5e-4。(冻结) + # Init_Epoch = 0,UnFreeze_Epoch = 300,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 5e-4。(不冻结) + # 其中:UnFreeze_Epoch可以在100-300之间调整。 + # (二)从0开始训练: + # Init_Epoch = 0,UnFreeze_Epoch >= 300,Unfreeze_batch_size >= 16,Freeze_Train = False(不冻结训练) + # 其中:UnFreeze_Epoch尽量不小于300。optimizer_type = 'sgd',Init_lr = 1e-2,mosaic = True。 + # (三)batch_size的设置: + # 在显卡能够接受的范围内,以大为好。显存不足与数据集大小无关,提示显存不足(OOM或者CUDA out of memory)请调小batch_size。 + # 受到BatchNorm层影响,batch_size最小为2,不能为1。 + # 正常情况下Freeze_batch_size建议为Unfreeze_batch_size的1-2倍。不建议设置的差距过大,因为关系到学习率的自动调整。 + # ----------------------------------------------------------------------------------------------------------------------------# + # ------------------------------------------------------------------# + # 冻结阶段训练参数 + # 此时模型的主干被冻结了,特征提取网络不发生改变 + # 占用的显存较小,仅对网络进行微调 + # Init_Epoch 模型当前开始的训练世代,其值可以大于Freeze_Epoch,如设置: + # Init_Epoch = 60、Freeze_Epoch = 50、UnFreeze_Epoch = 100 + # 会跳过冻结阶段,直接从60代开始,并调整对应的学习率。 + # (断点续练时使用) + # Freeze_Epoch 模型冻结训练的Freeze_Epoch + # (当Freeze_Train=False时失效) + # Freeze_batch_size 模型冻结训练的batch_size + # (当Freeze_Train=False时失效) + # ------------------------------------------------------------------# + Init_Epoch = 0 + Freeze_Epoch = 100 + Freeze_batch_size = 4 + # ------------------------------------------------------------------# + # 解冻阶段训练参数 + # 此时模型的主干不被冻结了,特征提取网络会发生改变 + # 占用的显存较大,网络所有的参数都会发生改变 + # UnFreeze_Epoch 模型总共训练的epoch + # SGD需要更长的时间收敛,因此设置较大的UnFreeze_Epoch + # Adam可以使用相对较小的UnFreeze_Epoch + # Unfreeze_batch_size 模型在解冻后的batch_size + # ------------------------------------------------------------------# + UnFreeze_Epoch = 100 + Unfreeze_batch_size = 4 + # ------------------------------------------------------------------# + # Freeze_Train 是否进行冻结训练 + # 默认先冻结主干训练后解冻训练。 + # ------------------------------------------------------------------# + Freeze_Train = False + + # ------------------------------------------------------------------# + # 
其它训练参数:学习率、优化器、学习率下降有关 + # ------------------------------------------------------------------# + # ------------------------------------------------------------------# + # Init_lr 模型的最大学习率 + # Min_lr 模型的最小学习率,默认为最大学习率的0.01 + # ------------------------------------------------------------------# + Init_lr = 1e-3 + Min_lr = Init_lr * 0.01 + # ------------------------------------------------------------------# + # optimizer_type 使用到的优化器种类,可选的有adam、sgd + # 当使用Adam优化器时建议设置 Init_lr=1e-3 + # 当使用SGD优化器时建议设置 Init_lr=1e-2 + # momentum 优化器内部使用到的momentum参数 + # weight_decay 权值衰减,可防止过拟合 + # adam会导致weight_decay错误,使用adam时建议设置为0。 + # ------------------------------------------------------------------# + optimizer_type = "adam" + momentum = 0.937 + weight_decay = 1e-4 + # ------------------------------------------------------------------# + # lr_decay_type 使用到的学习率下降方式,可选的有step、cos + # ------------------------------------------------------------------# + lr_decay_type = "cos" + # ------------------------------------------------------------------# + # save_period 多少个epoch保存一次权值 + # ------------------------------------------------------------------# + save_period = 10 + # ------------------------------------------------------------------# + # save_dir 权值与日志文件保存的文件夹 + # ------------------------------------------------------------------# + save_dir = f'logs/TDCNet_epoch_{UnFreeze_Epoch}_batch_{Unfreeze_batch_size}_optim_{optimizer_type}_lr_{Init_lr}_T_{num_frame}' + # ------------------------------------------------------------------# + # eval_flag 是否在训练时进行评估,评估对象为验证集 + # 安装pycocotools库后,评估体验更佳。 + # eval_period 代表多少个epoch评估一次,不建议频繁的评估 + # 评估需要消耗较多的时间,频繁评估会导致训练非常慢 + # 此处获得的mAP会与get_map.py获得的会有所不同,原因有二: + # (一)此处获得的mAP为验证集的mAP。 + # (二)此处设置评估参数较为保守,目的是加快评估速度。 + # ------------------------------------------------------------------# + eval_flag = True + eval_period = 200 + # ------------------------------------------------------------------# + # num_workers 用于设置是否使用多线程读取数据 + # 开启后会加快数据读取速度,但是会占用更多内存 + # 内存较小的电脑可以设置为2或者0 + # ------------------------------------------------------------------# + num_workers = 8 + + # ----------------------------------------------------# + # 获得图片路径和标签 + # ----------------------------------------------------# + + DATA_PATH = "/Dataset/IRSTD-UAV/" + train_annotation_path = "/Dataset/IRSTD-UAV/train.txt" + val_annotation_path = "/Dataset/IRSTD-UAV/val.txt" + + # ------------------------------------------------------# + # 设置用到的显卡 + # ------------------------------------------------------# + ngpus_per_node = torch.cuda.device_count() + if distributed: + dist.init_process_group(backend="nccl") + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + device = torch.device("cuda", local_rank) + if local_rank == 0: + print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...") + print("Gpu Device Count : ", ngpus_per_node) + else: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + local_rank = 0 + rank = 0 + + # ------------------------------------------------------# + # 设置随机种子 + # ------------------------------------------------------# + seed = 42 + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + class_names, num_classes = get_classes(classes_path) + + model = TDCNetwork(num_classes=1, num_frame=num_frame) + + weights_init(model) + if model_path != '': + if local_rank == 0: + print('Load weights {}.'.format(model_path)) + + # 
------------------------------------------------------# + # 根据预训练权重的Key和模型的Key进行加载 + # ------------------------------------------------------# + model_dict = model.state_dict() + # pdb.set_trace() + pretrained_dict = torch.load(model_path, map_location=device) + load_key, no_load_key, temp_dict = [], [], {} + for k, v in pretrained_dict.items(): + if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v): + temp_dict[k] = v + load_key.append(k) + else: + no_load_key.append(k) + model_dict.update(temp_dict) + model.load_state_dict(model_dict) + # ------------------------------------------------------# + # 显示没有匹配上的Key + # ------------------------------------------------------# + if local_rank == 0: + print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key)) + print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key)) + print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m") + + # ----------------------# + # 获得损失函数 + # ----------------------# + + yolo_loss = YOLOLoss(num_classes, fp16, strides=[8]) + # ----------------------# + # 记录Loss + # ----------------------# + if local_rank == 0: + time_str = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S') + log_dir = os.path.join(save_dir, "loss_" + str(time_str)) + + loss_history = LossHistory(log_dir, model, input_shape=input_shape) + + # pdb.set_trace() + else: + loss_history = None + + # ------------------------------------------------------------------# + # torch 1.2不支持amp,建议使用torch 1.7.1及以上正确使用fp16 + # 因此torch1.2这里显示"could not be resolve" + # ------------------------------------------------------------------# + if fp16: + from torch.cuda.amp import GradScaler as GradScaler + + scaler = GradScaler() + else: + scaler = None + + model_train = model.train() + # ----------------------------# + # 多卡同步Bn + # ----------------------------# + if sync_bn and ngpus_per_node > 1 and distributed: + model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train) + elif sync_bn: + print("Sync_bn is not support in one gpu or not distributed.") + + if Cuda: + if distributed: + # ----------------------------# + # 多卡平行运行 + # ----------------------------# + model_train = model_train.cuda(local_rank) + model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=True) + else: + model_train = torch.nn.DataParallel(model) + cudnn.benchmark = True + model_train = model_train.cuda() + + # ----------------------------# + # 权值平滑 + # ----------------------------# + # pdb.set_trace() + ema = ModelEMA(model_train) + + # ---------------------------# + # 读取数据集对应的txt + # ---------------------------# + with open(train_annotation_path, encoding='utf-8') as f: + train_lines = f.readlines() + with open(val_annotation_path, encoding='utf-8') as f: + val_lines = f.readlines() + num_train = len(train_lines) + num_val = len(val_lines) + + if local_rank == 0: + show_config( + classes_path=classes_path, model_path=model_path, input_shape=input_shape, \ + Init_Epoch=Init_Epoch, Freeze_Epoch=Freeze_Epoch, UnFreeze_Epoch=UnFreeze_Epoch, Freeze_batch_size=Freeze_batch_size, Unfreeze_batch_size=Unfreeze_batch_size, Freeze_Train=Freeze_Train, \ + Init_lr=Init_lr, Min_lr=Min_lr, optimizer_type=optimizer_type, momentum=momentum, lr_decay_type=lr_decay_type, \ + save_period=save_period, save_dir=log_dir, num_workers=num_workers, num_train=num_train, num_val=num_val + ) + # 
---------------------------------------------------------# + # 总训练世代指的是遍历全部数据的总次数 + # 总训练步长指的是梯度下降的总次数 + # 每个训练世代包含若干训练步长,每个训练步长进行一次梯度下降。 + # 此处仅建议最低训练世代,上不封顶,计算时只考虑了解冻部分 + # ----------------------------------------------------------# + wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4 + total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch + if total_step <= wanted_step: + if num_train // Unfreeze_batch_size == 0: + raise ValueError('数据集过小,无法进行训练,请扩充数据集。') + wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1 + print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m" % (optimizer_type, wanted_step)) + print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m" % (num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step)) + print("\033[1;33;44m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d。\033[0m" % (total_step, wanted_step, wanted_epoch)) + + # ------------------------------------------------------# + # 主干特征提取网络特征通用,冻结训练可以加快训练速度 + # 也可以在训练初期防止权值被破坏。 + # Init_Epoch为起始世代 + # Freeze_Epoch为冻结训练的世代 + # UnFreeze_Epoch总训练世代 + # 提示OOM或者显存不足请调小Batch_size + # ------------------------------------------------------# + if True: + UnFreeze_flag = False + # ------------------------------------# + # 冻结一定部分训练 + # ------------------------------------# + if Freeze_Train: + for param in model.backbone.parameters(): + param.requires_grad = False + for param in model.backbone_3d.parameters(): + param.requires_grad = False + + # -------------------------------------------------------------------# + # 如果不冻结训练的话,直接设置batch_size为Unfreeze_batch_size + # -------------------------------------------------------------------# + batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size + + # -------------------------------------------------------------------# + # 判断当前batch_size,自适应调整学习率 + # -------------------------------------------------------------------# + nbs = 64 + lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2 + lr_limit_min = 1e-5 if optimizer_type == 'adam' else 5e-4 + Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) + Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) + + # ---------------------------------------# + # 根据optimizer_type选择优化器 + # ---------------------------------------# + pg0, pg1, pg2 = [], [], [] + for k, v in model.named_modules(): + if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter): + pg2.append(v.bias) + if isinstance(v, nn.BatchNorm2d) or "bn" in k: + pg0.append(v.weight) + elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter): + pg1.append(v.weight) + optimizer = { + 'adam': optim.Adam(pg0, Init_lr_fit, betas=(momentum, 0.999)), + 'sgd': optim.SGD(pg0, Init_lr_fit, momentum=momentum, nesterov=True) + }[optimizer_type] + optimizer.add_param_group({"params": pg1, "weight_decay": weight_decay}) + optimizer.add_param_group({"params": pg2}) + + # ---------------------------------------# + # 获得学习率下降的公式 + # ---------------------------------------# + lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) + + # ---------------------------------------# + # 判断每一个世代的长度 + # ---------------------------------------# + epoch_step = num_train // batch_size + epoch_step_val = num_val // batch_size + + if epoch_step == 0 or epoch_step_val == 0: + raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") + + if ema: + ema.updates = epoch_step * Init_Epoch + + train_dataset = 
seqDataset(train_annotation_path, input_shape[0], num_frame, 'train') # 5 + val_dataset = seqDataset(val_annotation_path, input_shape[0], num_frame, 'val') # 5 + + if distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, ) + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, ) + batch_size = batch_size // ngpus_per_node + shuffle = False + else: + train_sampler = None + val_sampler = None + shuffle = True + + gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True, + drop_last=True, collate_fn=dataset_collate, sampler=train_sampler) + gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True, + drop_last=True, collate_fn=dataset_collate, sampler=val_sampler) + + # ----------------------# + # 记录eval的map曲线 + # ----------------------# + if local_rank == 0: + eval_callback = EvalCallback(model, input_shape, class_names, num_classes, val_lines, log_dir, Cuda, \ + eval_flag=eval_flag, period=eval_period) + else: + eval_callback = None + + # ---------------------------------------# + # 开始模型训练 + # ---------------------------------------# + for epoch in range(Init_Epoch, UnFreeze_Epoch): + # ---------------------------------------# + # 如果模型有冻结学习部分 + # 则解冻,并设置参数 + # ---------------------------------------# + if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train: + batch_size = Unfreeze_batch_size + + # -------------------------------------------------------------------# + # 判断当前batch_size,自适应调整学习率 + # -------------------------------------------------------------------# + nbs = 64 + lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2 + lr_limit_min = 1e-5 if optimizer_type == 'adam' else 5e-4 + Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) + Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) + # ---------------------------------------# + # 获得学习率下降的公式 + # ---------------------------------------# + lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) + + for param in model.backbone.parameters(): + param.requires_grad = True + for param in model.backbone_3d.parameters(): + param.requires_grad = True + + epoch_step = num_train // batch_size + epoch_step_val = num_val // batch_size + + if epoch_step == 0 or epoch_step_val == 0: + raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") + + if distributed: + batch_size = batch_size // ngpus_per_node + + if ema: + ema.updates = epoch_step * epoch + + gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True, + drop_last=True, collate_fn=dataset_collate, sampler=train_sampler) + gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True, + drop_last=True, collate_fn=dataset_collate, sampler=val_sampler) + + UnFreeze_flag = True + + gen.dataset.epoch_now = epoch + gen_val.dataset.epoch_now = epoch + + if distributed: + train_sampler.set_epoch(epoch) + + set_optimizer_lr(optimizer, lr_scheduler_func, epoch) + # pdb.set_trace() + + fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, log_dir, local_rank) + + if distributed: + dist.barrier() + + if local_rank == 0: + loss_history.writer.close() diff --git 
a/utils/callbacks.py b/utils/callbacks.py new file mode 100644 index 0000000..539c0e1 --- /dev/null +++ b/utils/callbacks.py @@ -0,0 +1,288 @@ +from email.mime import image +import os + +import torch +import matplotlib +matplotlib.use('Agg') +import scipy.signal +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter + +import shutil +import numpy as np + +from PIL import Image +from tqdm import tqdm +from .utils import cvtColor, preprocess_input, resize_image +from .utils_bbox import decode_outputs, non_max_suppression +from .utils_map import get_coco_map, get_map +# from utils import cvtColor, preprocess_input, resize_image +# from utils_bbox import decode_outputs, non_max_suppression +# from utils_map import get_coco_map, get_map + + +class LossHistory(): + def __init__(self, log_dir, model, input_shape): + self.log_dir = log_dir + self.losses = [] + self.val_loss = [] + + os.makedirs(self.log_dir) + self.writer = SummaryWriter(self.log_dir) + try: + dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) + self.writer.add_graph(model, dummy_input) + except: + pass + + def append_loss(self, epoch, loss, val_loss): + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + + self.losses.append(loss) + self.val_loss.append(val_loss) + + with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: + f.write(str(loss)) + f.write("\n") + with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: + f.write(str(val_loss)) + f.write("\n") + + self.writer.add_scalar('loss', loss, epoch) + self.writer.add_scalar('val_loss', val_loss, epoch) + self.loss_plot() + + def loss_plot(self): + iters = range(len(self.losses)) + + plt.figure() + plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') + plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') + try: + if len(self.losses) < 25: + num = 5 + else: + num = 15 + + plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') + plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') + except: + pass + + plt.grid(True) + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.legend(loc="upper right") + + plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) + + plt.cla() + plt.close("all") + +class EvalCallback(): + def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \ + map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1): + super(EvalCallback, self).__init__() + + self.net = net + self.input_shape = input_shape + self.class_names = class_names + self.num_classes = num_classes + self.val_lines = val_lines + + self.log_dir = log_dir + self.cuda = cuda + self.map_out_path = map_out_path + self.max_boxes = max_boxes + self.confidence = confidence + self.nms_iou = nms_iou + self.letterbox_image = letterbox_image + self.MINOVERLAP = MINOVERLAP + self.eval_flag = eval_flag + self.period = period + + self.maps = [0] + self.epoches = [0] + if self.eval_flag: + with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: + f.write(str(0)) + f.write("\n") + + # def get_history_imgs(self, line): + # dir_path = line.replace(line.split('/')[-1],'') + # file_type = line.split('.')[-1] + # index = int(line.split('/')[-1][:-4]) + # return [os.path.join(dir_path, "%d.%s" % (max(id, 
0),file_type)) for id in range(index - 4, index + 1)]
+    # def get_history_imgs(self, line):
+    #     dir_path = line.replace(line.split('/')[-1],'')
+    #     file_type = line.split('.')[-1]
+    #     index = int(line.split("/")[-1][:8])
+    #     return [os.path.join(dir_path, "%08d.%s" % (max(id, 0),file_type)) for id in range(index - 4, index + 1)]
+    def get_history_imgs(self, line):
+        dir_path = line.replace(line.split('/')[-1],'')
+        file_type = line.split('.')[-1]
+        index = int(line.split("/")[-1][:8])
+        return [os.path.join(dir_path, "%08d.%s" % (max(id, 1),file_type)) for id in range(index - 4, index + 1)]
+
+
+
+    def get_map_txt(self, image_id, images, class_names, map_out_path):
+        f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
+        image_shape = np.array(np.shape(images[0])[0:2])
+        #---------------------------------------------------------#
+        #   Convert the images to RGB here to avoid errors when predicting on grayscale input.
+        #   The code only supports prediction on RGB images; all other types are converted to RGB.
+        #---------------------------------------------------------#
+        images = [cvtColor(image) for image in images]
+        #---------------------------------------------------------#
+        #   Pad the images with gray bars for a distortion-free resize.
+        #   A plain resize can also be used for detection.
+        #---------------------------------------------------------#
+        image_data = [resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) for image in images]
+        #---------------------------------------------------------#
+        #   Add the batch_size dimension.
+        #---------------------------------------------------------#
+        image_data = [np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)) for image in image_data]
+        # num_frame arrays of shape (3, 640, 640) -> (3, num_frame, 640, 640)
+        image_data = np.stack(image_data, axis=1)
+
+
+        image_data = np.expand_dims(image_data, 0)
+
+
+        with torch.no_grad():
+            images = torch.from_numpy(image_data)
+            if self.cuda:
+                images = images.cuda()
+            #---------------------------------------------------------#
+            #   Feed the frame sequence into the network for prediction!
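+            #   At this point `images` has shape (1, 3, num_frame, H, W): the frame
+            #   sequence returned by get_history_imgs, with the current frame as the last entry.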
+ #---------------------------------------------------------# + outputs = self.net(images) + outputs = decode_outputs(outputs, self.input_shape) + #---------------------------------------------------------# + # 将预测框进行堆叠,然后进行非极大抑制 + #---------------------------------------------------------# + results = non_max_suppression(outputs, self.num_classes, self.input_shape, + image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou) + + if results[0] is None: + return + + top_label = np.array(results[0][:, 6], dtype = 'int32') + top_conf = results[0][:, 4] * results[0][:, 5] + top_boxes = results[0][:, :4] + + top_100 = np.argsort(top_label)[::-1][:self.max_boxes] + top_boxes = top_boxes[top_100] + top_conf = top_conf[top_100] + top_label = top_label[top_100] + + for i, c in list(enumerate(top_label)): + predicted_class = self.class_names[int(c)] + box = top_boxes[i] + score = str(top_conf[i]) + + top, left, bottom, right = box + if predicted_class not in class_names: + continue + + f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom)))) + + f.close() + return + + def on_epoch_end(self, epoch, model_eval): + if epoch % self.period == 0 and self.eval_flag: + self.net = model_eval + if not os.path.exists(self.map_out_path): + os.makedirs(self.map_out_path) + if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")): + os.makedirs(os.path.join(self.map_out_path, "ground-truth")) + if not os.path.exists(os.path.join(self.map_out_path, "detection-results")): + os.makedirs(os.path.join(self.map_out_path, "detection-results")) + print("Get map.") + for annotation_line in tqdm(self.val_lines): + line = annotation_line.split() + ''' + # 不同视频的图片序号会重复, 视频号-图片序号作为id + ''' + image_id = "-".join(line[0].split("/")[6:8]).split('.')[0] + #------------------------------# + # 读取图像并转换成RGB图像 + #------------------------------# + # cb update + images = self.get_history_imgs(line[0]) + images = [Image.open(item) for item in images] + # image = Image.open(line[0]) + #------------------------------# + # 获得预测框 + #------------------------------# + gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) + #------------------------------# + # 获得预测txt + #------------------------------# + self.get_map_txt(image_id, images, self.class_names, self.map_out_path) + + #------------------------------# + # 获得真实框txt + #------------------------------# + with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: + for box in gt_boxes: + left, top, right, bottom, obj = box + obj_name = self.class_names[obj] + new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) + + print("Calculate Map.") + try: + temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1] + except: + temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path) + self.maps.append(temp_map) + self.epoches.append(epoch) + + with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: + f.write(str(temp_map)) + f.write("\n") + + plt.figure() + plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map') + + plt.grid(True) + plt.xlabel('Epoch') + plt.ylabel('Map %s'%str(self.MINOVERLAP)) + plt.title('A Map Curve') + plt.legend(loc="upper right") + + plt.savefig(os.path.join(self.log_dir, "epoch_map.png")) + plt.cla() + plt.close("all") + + print("Get map done.") + shutil.rmtree(self.map_out_path) + + + + + + + +# def 
get_history_imgs(line): +# dir_path = line.replace(line.split('/')[-1],'') +# file_type = line.split('.')[-1] +# index = int(line.split('/')[-1][:-4]) +# image_id = "-".join(line.split("/")[6:8]).split('.')[0] +# print(image_id) + +# return [os.path.join(dir_path, "%d.%s" % (max(id, 0),file_type)) for id in range(index - 4, index + 1)] + + +# if __name__ == "__main__": +# with open('coco_val.txt', encoding='utf-8') as f: +# val_lines = f.readlines() +# for annotation_line in val_lines: +# line = annotation_line.split() +# images = get_history_imgs(line[0]) +# # for item in images: +# # print(item) +# # break \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..29b0c73 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,64 @@ +import numpy as np +from PIL import Image + + +#---------------------------------------------------------# +# 将图像转换成RGB图像,防止灰度图在预测时报错。 +# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB +#---------------------------------------------------------# +def cvtColor(image): + if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: + return image + else: + image = image.convert('RGB') + return image + +#---------------------------------------------------# +# 对输入图像进行resize +#---------------------------------------------------# +def resize_image(image, size, letterbox_image): + iw, ih = image.size + w, h = size + if letterbox_image: + scale = min(w/iw, h/ih) + nw = int(iw*scale) + nh = int(ih*scale) + + image = image.resize((nw,nh), Image.BICUBIC) + new_image = Image.new('RGB', size, (128,128,128)) + new_image.paste(image, ((w-nw)//2, (h-nh)//2)) + else: + new_image = image.resize((w, h), Image.BICUBIC) + return new_image + +#---------------------------------------------------# +# 获得类 +#---------------------------------------------------# +def get_classes(classes_path): + with open(classes_path, encoding='utf-8') as f: + class_names = f.readlines() + class_names = [c.strip() for c in class_names] + return class_names, len(class_names) + +def preprocess_input(image): + image /= 255.0 + image -= np.array([0.485, 0.456, 0.406]) + image /= np.array([0.229, 0.224, 0.225]) + return image + +#---------------------------------------------------# +# 获得学习率 +#---------------------------------------------------# +def get_lr(optimizer): + for param_group in optimizer.param_groups: + return param_group['lr'] + +def show_config(**kwargs): + print('Configurations:') + print('-' * 130) + print('|%25s | %100s|' % ('keys', 'values')) + print('-' * 130) + for key, value in kwargs.items(): + print('|%25s | %100s|' % (str(key), str(value))) + print('-' * 130) + \ No newline at end of file diff --git a/utils/utils_bbox.py b/utils/utils_bbox.py new file mode 100644 index 0000000..64fc1ae --- /dev/null +++ b/utils/utils_bbox.py @@ -0,0 +1,180 @@ +import numpy as np +import torch +from torchvision.ops import nms, boxes + +def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image): + #-----------------------------------------------------------------# + # 把y轴放前面是因为方便预测框和图像的宽高进行相乘 + #-----------------------------------------------------------------# + box_yx = box_xy[..., ::-1] + box_hw = box_wh[..., ::-1] + input_shape = np.array(input_shape) + image_shape = np.array(image_shape) + + if letterbox_image: + #-----------------------------------------------------------------# + # 这里求出来的offset是图像有效区域相对于图像左上角的偏移情况 + # new_shape指的是宽高缩放情况 + #-----------------------------------------------------------------# + new_shape = np.round(image_shape * 
np.min(input_shape/image_shape)) + offset = (input_shape - new_shape)/2./input_shape + scale = input_shape/new_shape + + box_yx = (box_yx - offset) * scale + box_hw *= scale + + box_mins = box_yx - (box_hw / 2.) + box_maxes = box_yx + (box_hw / 2.) + boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1) + boxes *= np.concatenate([image_shape, image_shape], axis=-1) + return boxes + +def decode_outputs(outputs, input_shape): + grids = [] + strides = [] + hw = [x.shape[-2:] for x in outputs] + #---------------------------------------------------# + # outputs输入前代表每个特征层的预测结果 + # batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400 + # batch_size, 5 + num_classes, 40, 40 + # batch_size, 5 + num_classes, 20, 20 + # batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400 + # 堆叠后为batch_size, 8400, 5 + num_classes + #---------------------------------------------------# + outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1) + #---------------------------------------------------# + # 获得每一个特征点属于每一个种类的概率 + #---------------------------------------------------# + outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:]) + for h, w in hw: + #---------------------------# + # 根据特征层的高宽生成网格点 + #---------------------------# + grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)], indexing='ij') + #---------------------------# + # 1, 6400, 2 + # 1, 1600, 2 + # 1, 400, 2 + #---------------------------# + grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2) + shape = grid.shape[:2] + + grids.append(grid) + strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h)) + #---------------------------# + # 将网格点堆叠到一起 + # 1, 6400, 2 + # 1, 1600, 2 + # 1, 400, 2 + # + # 1, 8400, 2 + #---------------------------# + grids = torch.cat(grids, dim=1).type(outputs.type()) + strides = torch.cat(strides, dim=1).type(outputs.type()) + #------------------------# + # 根据网格点进行解码 + #------------------------# + outputs[..., :2] = (outputs[..., :2] + grids) * strides + outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides + #-----------------# + # 归一化 + #-----------------# + outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1] + outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0] + return outputs + +def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4): + #----------------------------------------------------------# + # 将预测结果的格式转换成左上角右下角的格式。 + # prediction [batch_size, num_anchors, 85] + #----------------------------------------------------------# + box_corner = prediction.new(prediction.shape) + box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 + box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 + box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 + box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 + prediction[:, :, :4] = box_corner[:, :, :4] + + output = [None for _ in range(len(prediction))] + #----------------------------------------------------------# + # 对输入图片进行循环,一般只会进行一次 + #----------------------------------------------------------# + for i, image_pred in enumerate(prediction): + #----------------------------------------------------------# + # 对种类预测部分取max。 + # class_conf [num_anchors, 1] 种类置信度 + # class_pred [num_anchors, 1] 种类 + #----------------------------------------------------------# + 
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True) + + #----------------------------------------------------------# + # 利用置信度进行第一轮筛选 + #----------------------------------------------------------# + conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze() + + if not image_pred.size(0): + continue + #-------------------------------------------------------------------------# + # detections [num_anchors, 7] + # 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred + #-------------------------------------------------------------------------# + detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1) + detections = detections[conf_mask] + + nms_out_index = boxes.batched_nms( + detections[:, :4], + detections[:, 4] * detections[:, 5], + detections[:, 6], + nms_thres, + ) + + output[i] = detections[nms_out_index] + + # #------------------------------------------# + # # 获得预测结果中包含的所有种类 + # #------------------------------------------# + # unique_labels = detections[:, -1].cpu().unique() + + # if prediction.is_cuda: + # unique_labels = unique_labels.cuda() + # detections = detections.cuda() + + # for c in unique_labels: + # #------------------------------------------# + # # 获得某一类得分筛选后全部的预测结果 + # #------------------------------------------# + # detections_class = detections[detections[:, -1] == c] + + # #------------------------------------------# + # # 使用官方自带的非极大抑制会速度更快一些! + # #------------------------------------------# + # keep = nms( + # detections_class[:, :4], + # detections_class[:, 4] * detections_class[:, 5], + # nms_thres + # ) + # max_detections = detections_class[keep] + + # # # 按照存在物体的置信度排序 + # # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True) + # # detections_class = detections_class[conf_sort_index] + # # # 进行非极大抑制 + # # max_detections = [] + # # while detections_class.size(0): + # # # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉 + # # max_detections.append(detections_class[0].unsqueeze(0)) + # # if len(detections_class) == 1: + # # break + # # ious = bbox_iou(max_detections[-1], detections_class[1:]) + # # detections_class = detections_class[1:][ious < nms_thres] + # # # 堆叠 + # # max_detections = torch.cat(max_detections).data + + # # Add max detections to outputs + # output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections)) + + if output[i] is not None: + output[i] = output[i].cpu().numpy() + box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2] + output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image) + return output diff --git a/utils/utils_fit.py b/utils/utils_fit.py new file mode 100644 index 0000000..ae07337 --- /dev/null +++ b/utils/utils_fit.py @@ -0,0 +1,145 @@ +import os + +import torch +from tqdm import tqdm + +from utils.utils import get_lr + +import pdb +def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0): + loss = 0 + val_loss = 0 + + epoch_step = epoch_step // 5 # 每次epoch只随机用训练集合的一部分 防止过拟合 + + if local_rank == 0: + print('Start Train') + pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) + #pdb.set_trace() + model_train.train() + for iteration, batch in enumerate(gen): + if iteration >= epoch_step: + break + + images, targets = batch[0], 
batch[1] + with torch.no_grad(): + if cuda: + images = images.cuda(local_rank) + targets = [ann.cuda(local_rank) for ann in targets] + #----------------------# + # 清零梯度 + #----------------------# + optimizer.zero_grad() + if not fp16: + #----------------------# + # 前向传播 + #----------------------# + #pdb.set_trace() + outputs = model_train(images) + + #----------------------# + # 计算损失 + #----------------------# + loss_value = yolo_loss(outputs, targets) #+ motion_loss + + #----------------------# + # 反向传播 + #----------------------# + # torch.autograd.set_detect_anomaly(True) + # with torch.autograd.detect_anomaly(): + loss_value.backward() + optimizer.step() + else: + from torch.cuda.amp import autocast + with autocast(): + outputs = model_train(images) + #----------------------# + # 计算损失 + #----------------------# + loss_value = yolo_loss(outputs, targets) + + #----------------------# + # 反向传播 + #----------------------# + scaler.scale(loss_value).backward() + scaler.step(optimizer) + scaler.update() + #pdb.set_trace() + if ema: + ema.update(model_train) + + loss += loss_value.item() + + if local_rank == 0: + pbar.set_postfix(**{'loss' : loss / (iteration + 1), + 'lr' : get_lr(optimizer)}) + pbar.update(1) + + if local_rank == 0: + pbar.close() + print('Finish Train') + print('Start Validation') + pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) + + if ema: + model_train_eval = ema.ema + else: + model_train_eval = model_train.eval() + + for iteration, batch in enumerate(gen_val): + if iteration >= epoch_step_val: + break + images, targets = batch[0], batch[1] + with torch.no_grad(): + if cuda: + images = images.cuda(local_rank) + targets = [ann.cuda(local_rank) for ann in targets] + #----------------------# + # 清零梯度 + #----------------------# + optimizer.zero_grad() + #----------------------# + # 前向传播 + #----------------------# + outputs = model_train_eval(images) + + #----------------------# + # 计算损失 + #----------------------# + loss_value = yolo_loss(outputs, targets) + + val_loss += loss_value.item() + if local_rank == 0: + pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)}) + pbar.update(1) + + if local_rank == 0: + pbar.close() + print('Finish Validation') + loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val) + eval_callback.on_epoch_end(epoch + 1, model_train_eval) + print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch)) + print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val)) + + #-----------------------------------------------# + # 保存权值 + #-----------------------------------------------# + if ema: + save_state_dict = ema.ema.state_dict() + else: + save_state_dict = model.state_dict() + + if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: + torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val))) + + if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): + print('Save best model to best_epoch_weights.pth') + torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth")) + + torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth")) + + + + + + diff --git a/utils/utils_map.py b/utils/utils_map.py new file mode 100644 index 0000000..b49ddba --- /dev/null +++ b/utils/utils_map.py @@ -0,0 +1,923 @@ +import glob +import json +import math +import operator +import os +import shutil +import sys +try: 
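+    # pycocotools is optional here: get_coco_map below requires it, while the VOC-style
+    # get_map works without it, and EvalCallback falls back to get_map when this import fails.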
+ from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval +except: + pass +import cv2 +import matplotlib +matplotlib.use('Agg') +from matplotlib import pyplot as plt +import numpy as np + +''' + 0,0 ------> x (width) + | + | (Left,Top) + | *_________ + | | | + | | + y |_________| + (height) * + (Right,Bottom) +''' + +def log_average_miss_rate(precision, fp_cumsum, num_images): + """ + log-average miss rate: + Calculated by averaging miss rates at 9 evenly spaced FPPI points + between 10e-2 and 10e0, in log-space. + + output: + lamr | log-average miss rate + mr | miss rate + fppi | false positives per image + + references: + [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the + State of the Art." Pattern Analysis and Machine Intelligence, IEEE + Transactions on 34.4 (2012): 743 - 761. + """ + + if precision.size == 0: + lamr = 0 + mr = 1 + fppi = 0 + return lamr, mr, fppi + + fppi = fp_cumsum / float(num_images) + mr = (1 - precision) + + fppi_tmp = np.insert(fppi, 0, -1.0) + mr_tmp = np.insert(mr, 0, 1.0) + + ref = np.logspace(-2.0, 0.0, num = 9) + for i, ref_i in enumerate(ref): + j = np.where(fppi_tmp <= ref_i)[-1][-1] + ref[i] = mr_tmp[j] + + lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref)))) + + return lamr, mr, fppi + +""" + throw error and exit +""" +def error(msg): + print(msg) + sys.exit(0) + +""" + check if the number is a float between 0.0 and 1.0 +""" +def is_float_between_0_and_1(value): + try: + val = float(value) + if val > 0.0 and val < 1.0: + return True + else: + return False + except ValueError: + return False + +""" + Calculate the AP given the recall and precision array + 1st) We compute a version of the measured precision/recall curve with + precision monotonically decreasing + 2nd) We compute the AP as the area under this curve by numerical integration. 
+""" +def voc_ap(rec, prec): + """ + --- Official matlab code VOC2012--- + mrec=[0 ; rec ; 1]; + mpre=[0 ; prec ; 0]; + for i=numel(mpre)-1:-1:1 + mpre(i)=max(mpre(i),mpre(i+1)); + end + i=find(mrec(2:end)~=mrec(1:end-1))+1; + ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); + """ + rec.insert(0, 0.0) # insert 0.0 at begining of list + rec.append(1.0) # insert 1.0 at end of list + mrec = rec[:] + prec.insert(0, 0.0) # insert 0.0 at begining of list + prec.append(0.0) # insert 0.0 at end of list + mpre = prec[:] + """ + This part makes the precision monotonically decreasing + (goes from the end to the beginning) + matlab: for i=numel(mpre)-1:-1:1 + mpre(i)=max(mpre(i),mpre(i+1)); + """ + for i in range(len(mpre)-2, -1, -1): + mpre[i] = max(mpre[i], mpre[i+1]) + """ + This part creates a list of indexes where the recall changes + matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1; + """ + i_list = [] + for i in range(1, len(mrec)): + if mrec[i] != mrec[i-1]: + i_list.append(i) # if it was matlab would be i + 1 + """ + The Average Precision (AP) is the area under the curve + (numerical integration) + matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); + """ + ap = 0.0 + for i in i_list: + ap += ((mrec[i]-mrec[i-1])*mpre[i]) + return ap, mrec, mpre + + +""" + Convert the lines of a file to a list +""" +def file_lines_to_list(path): + # open txt file lines to a list + with open(path) as f: + content = f.readlines() + # remove whitespace characters like `\n` at the end of each line + content = [x.strip() for x in content] + return content + +""" + Draws text in image +""" +def draw_text_in_image(img, text, pos, color, line_width): + font = cv2.FONT_HERSHEY_PLAIN + fontScale = 1 + lineType = 1 + bottomLeftCornerOfText = pos + cv2.putText(img, text, + bottomLeftCornerOfText, + font, + fontScale, + color, + lineType) + text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0] + return img, (line_width + text_width) + +""" + Plot - adjust axes +""" +def adjust_axes(r, t, fig, axes): + # get text width for re-scaling + bb = t.get_window_extent(renderer=r) + text_width_inches = bb.width / fig.dpi + # get axis width in inches + current_fig_width = fig.get_figwidth() + new_fig_width = current_fig_width + text_width_inches + propotion = new_fig_width / current_fig_width + # get axis limit + x_lim = axes.get_xlim() + axes.set_xlim([x_lim[0], x_lim[1]*propotion]) + +""" + Draw plot using Matplotlib +""" +def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar): + # sort the dictionary by decreasing value, into a list of tuples + sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1)) + # unpacking the list of tuples into two lists + sorted_keys, sorted_values = zip(*sorted_dic_by_value) + # + if true_p_bar != "": + """ + Special case to draw in: + - green -> TP: True Positives (object detected and matches ground-truth) + - red -> FP: False Positives (object detected but does not match ground-truth) + - orange -> FN: False Negatives (object not detected but present in the ground-truth) + """ + fp_sorted = [] + tp_sorted = [] + for key in sorted_keys: + fp_sorted.append(dictionary[key] - true_p_bar[key]) + tp_sorted.append(true_p_bar[key]) + plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive') + plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted) + # add legend + plt.legend(loc='lower right') + """ + Write number on side of bar + """ 
+ fig = plt.gcf() # gcf - get current figure + axes = plt.gca() + r = fig.canvas.get_renderer() + for i, val in enumerate(sorted_values): + fp_val = fp_sorted[i] + tp_val = tp_sorted[i] + fp_str_val = " " + str(fp_val) + tp_str_val = fp_str_val + " " + str(tp_val) + # trick to paint multicolor with offset: + # first paint everything and then repaint the first number + t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold') + plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold') + if i == (len(sorted_values)-1): # largest bar + adjust_axes(r, t, fig, axes) + else: + plt.barh(range(n_classes), sorted_values, color=plot_color) + """ + Write number on side of bar + """ + fig = plt.gcf() # gcf - get current figure + axes = plt.gca() + r = fig.canvas.get_renderer() + for i, val in enumerate(sorted_values): + str_val = " " + str(val) # add a space before + if val < 1.0: + str_val = " {0:.2f}".format(val) + t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold') + # re-set axes to show number inside the figure + if i == (len(sorted_values)-1): # largest bar + adjust_axes(r, t, fig, axes) + # set window title + fig.canvas.set_window_title(window_title) + # write classes in y axis + tick_font_size = 12 + plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size) + """ + Re-scale height accordingly + """ + init_height = fig.get_figheight() + # comput the matrix height in points and inches + dpi = fig.dpi + height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing) + height_in = height_pt / dpi + # compute the required figure height + top_margin = 0.15 # in percentage of the figure height + bottom_margin = 0.05 # in percentage of the figure height + figure_height = height_in / (1 - top_margin - bottom_margin) + # set new height + if figure_height > init_height: + fig.set_figheight(figure_height) + + # set plot title + plt.title(plot_title, fontsize=14) + # set axis titles + # plt.xlabel('classes') + plt.xlabel(x_label, fontsize='large') + # adjust size of window + fig.tight_layout() + # save the plot + fig.savefig(output_path) + # show image + if to_show: + plt.show() + # close the plot + plt.close() + +def get_map(MINOVERLAP, draw_plot, score_threhold=0.5, path = './map_out'): + GT_PATH = os.path.join(path, 'ground-truth') + DR_PATH = os.path.join(path, 'detection-results') + IMG_PATH = os.path.join(path, 'images-optional') + TEMP_FILES_PATH = os.path.join(path, '.temp_files') + RESULTS_FILES_PATH = os.path.join(path, 'results') + + show_animation = True + if os.path.exists(IMG_PATH): + for dirpath, dirnames, files in os.walk(IMG_PATH): + if not files: + show_animation = False + else: + show_animation = False + + if not os.path.exists(TEMP_FILES_PATH): + os.makedirs(TEMP_FILES_PATH) + + if os.path.exists(RESULTS_FILES_PATH): + shutil.rmtree(RESULTS_FILES_PATH) + else: + os.makedirs(RESULTS_FILES_PATH) + if draw_plot: + try: + matplotlib.use('TkAgg') + except: + pass + os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP")) + os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1")) + os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall")) + os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision")) + if show_animation: + os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one")) + + ground_truth_files_list = glob.glob(GT_PATH + '/*.txt') + if len(ground_truth_files_list) == 0: + error("Error: No ground-truth files found!") + ground_truth_files_list.sort() + gt_counter_per_class = {} + 
counter_images_per_class = {} + + for txt_file in ground_truth_files_list: + file_id = txt_file.split(".txt", 1)[0] + file_id = os.path.basename(os.path.normpath(file_id)) + temp_path = os.path.join(DR_PATH, (file_id + ".txt")) + if not os.path.exists(temp_path): + error_msg = "Error. File not found: {}\n".format(temp_path) + error(error_msg) + lines_list = file_lines_to_list(txt_file) + bounding_boxes = [] + is_difficult = False + already_seen_classes = [] + for line in lines_list: + try: + if "difficult" in line: + class_name, left, top, right, bottom, _difficult = line.split() + is_difficult = True + else: + class_name, left, top, right, bottom = line.split() + except: + if "difficult" in line: + line_split = line.split() + _difficult = line_split[-1] + bottom = line_split[-2] + right = line_split[-3] + top = line_split[-4] + left = line_split[-5] + class_name = "" + for name in line_split[:-5]: + class_name += name + " " + class_name = class_name[:-1] + is_difficult = True + else: + line_split = line.split() + bottom = line_split[-1] + right = line_split[-2] + top = line_split[-3] + left = line_split[-4] + class_name = "" + for name in line_split[:-4]: + class_name += name + " " + class_name = class_name[:-1] + + bbox = left + " " + top + " " + right + " " + bottom + if is_difficult: + bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True}) + is_difficult = False + else: + bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False}) + if class_name in gt_counter_per_class: + gt_counter_per_class[class_name] += 1 + else: + gt_counter_per_class[class_name] = 1 + + if class_name not in already_seen_classes: + if class_name in counter_images_per_class: + counter_images_per_class[class_name] += 1 + else: + counter_images_per_class[class_name] = 1 + already_seen_classes.append(class_name) + + with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile: + json.dump(bounding_boxes, outfile) + + gt_classes = list(gt_counter_per_class.keys()) + gt_classes = sorted(gt_classes) + n_classes = len(gt_classes) + + dr_files_list = glob.glob(DR_PATH + '/*.txt') + dr_files_list.sort() + for class_index, class_name in enumerate(gt_classes): + bounding_boxes = [] + for txt_file in dr_files_list: + file_id = txt_file.split(".txt",1)[0] + file_id = os.path.basename(os.path.normpath(file_id)) + temp_path = os.path.join(GT_PATH, (file_id + ".txt")) + if class_index == 0: + if not os.path.exists(temp_path): + error_msg = "Error. 
File not found: {}\n".format(temp_path) + error(error_msg) + lines = file_lines_to_list(txt_file) + for line in lines: + try: + tmp_class_name, confidence, left, top, right, bottom = line.split() + except: + line_split = line.split() + bottom = line_split[-1] + right = line_split[-2] + top = line_split[-3] + left = line_split[-4] + confidence = line_split[-5] + tmp_class_name = "" + for name in line_split[:-5]: + tmp_class_name += name + " " + tmp_class_name = tmp_class_name[:-1] + + if tmp_class_name == class_name: + bbox = left + " " + top + " " + right + " " +bottom + bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox}) + + bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True) + with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile: + json.dump(bounding_boxes, outfile) + + sum_AP = 0.0 + ap_dictionary = {} + lamr_dictionary = {} + with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file: + results_file.write("# AP and precision/recall per class\n") + count_true_positives = {} + + for class_index, class_name in enumerate(gt_classes): + count_true_positives[class_name] = 0 + dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json" + dr_data = json.load(open(dr_file)) + + nd = len(dr_data) + tp = [0] * nd + fp = [0] * nd + score = [0] * nd + score_threhold_idx = 0 + for idx, detection in enumerate(dr_data): + file_id = detection["file_id"] + score[idx] = float(detection["confidence"]) + if score[idx] >= score_threhold: + score_threhold_idx = idx + + if show_animation: + ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*") + if len(ground_truth_img) == 0: + error("Error. Image not found with id: " + file_id) + elif len(ground_truth_img) > 1: + error("Error. Multiple image with id: " + file_id) + else: + img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0]) + img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0] + if os.path.isfile(img_cumulative_path): + img_cumulative = cv2.imread(img_cumulative_path) + else: + img_cumulative = img.copy() + bottom_border = 60 + BLACK = [0, 0, 0] + img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK) + + gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json" + ground_truth_data = json.load(open(gt_file)) + ovmax = -1 + gt_match = -1 + bb = [float(x) for x in detection["bbox"].split()] + for obj in ground_truth_data: + if obj["class_name"] == class_name: + bbgt = [ float(x) for x in obj["bbox"].split() ] + bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])] + iw = bi[2] - bi[0] + 1 + ih = bi[3] - bi[1] + 1 + if iw > 0 and ih > 0: + ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0] + + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih + ov = iw * ih / ua + if ov > ovmax: + ovmax = ov + gt_match = obj + + if show_animation: + status = "NO MATCH FOUND!" + + min_overlap = MINOVERLAP + if ovmax >= min_overlap: + if "difficult" not in gt_match: + if not bool(gt_match["used"]): + tp[idx] = 1 + gt_match["used"] = True + count_true_positives[class_name] += 1 + with open(gt_file, 'w') as f: + f.write(json.dumps(ground_truth_data)) + if show_animation: + status = "MATCH!" + else: + fp[idx] = 1 + if show_animation: + status = "REPEATED MATCH!" 
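+                # a detection whose best IoU with an unused ground-truth box stays below
+                # MINOVERLAP is counted as a false positive in the branch below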
+ else: + fp[idx] = 1 + if ovmax > 0: + status = "INSUFFICIENT OVERLAP" + + """ + Draw image to show animation + """ + if show_animation: + height, widht = img.shape[:2] + white = (255,255,255) + light_blue = (255,200,100) + green = (0,255,0) + light_red = (30,30,255) + margin = 10 + # 1nd line + v_pos = int(height - margin - (bottom_border / 2.0)) + text = "Image: " + ground_truth_img[0] + " " + img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0) + text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " " + img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width) + if ovmax != -1: + color = light_red + if status == "INSUFFICIENT OVERLAP": + text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100) + else: + text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100) + color = green + img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width) + # 2nd line + v_pos += int(bottom_border / 2.0) + rank_pos = str(idx+1) + text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100) + img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0) + color = light_red + if status == "MATCH!": + color = green + text = "Result: " + status + " " + img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width) + + font = cv2.FONT_HERSHEY_SIMPLEX + if ovmax > 0: + bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ] + cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2) + cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2) + cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA) + bb = [int(i) for i in bb] + cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2) + cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2) + cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA) + + cv2.imshow("Animation", img) + cv2.waitKey(20) + output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg" + cv2.imwrite(output_img_path, img) + cv2.imwrite(img_cumulative_path, img_cumulative) + + cumsum = 0 + for idx, val in enumerate(fp): + fp[idx] += cumsum + cumsum += val + + cumsum = 0 + for idx, val in enumerate(tp): + tp[idx] += cumsum + cumsum += val + + rec = tp[:] + for idx, val in enumerate(tp): + rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1) + + prec = tp[:] + for idx, val in enumerate(tp): + prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1) + + ap, mrec, mprec = voc_ap(rec[:], prec[:]) + F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec))) + + sum_AP += ap + text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100) + + if len(prec)>0: + F1_text = "{0:.2f}".format(F1[score_threhold_idx]) + " = " + class_name + " F1 " + Recall_text = "{0:.2f}%".format(rec[score_threhold_idx]*100) + " = " + class_name + " Recall " + Precision_text = "{0:.2f}%".format(prec[score_threhold_idx]*100) + " = " + class_name + " Precision " + else: + F1_text = "0.00" + " = " + class_name + " F1 " + Recall_text = "0.00%" + " = " + class_name + " Recall " + Precision_text = "0.00%" + " = " + class_name + " Precision " + 
+ rounded_prec = [ '%.2f' % elem for elem in prec ] + rounded_rec = [ '%.2f' % elem for elem in rec ] + results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n") + + if len(prec)>0: + print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=" + "{0:.2f}".format(F1[score_threhold_idx])\ + + " ; Recall=" + "{0:.2f}%".format(rec[score_threhold_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score_threhold_idx]*100)) + else: + print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=0.00% ; Recall=0.00% ; Precision=0.00%") + ap_dictionary[class_name] = ap + + n_images = counter_images_per_class[class_name] + lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images) + lamr_dictionary[class_name] = lamr + + if draw_plot: + plt.plot(rec, prec, '-o') + area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]] + area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]] + plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r') + + fig = plt.gcf() + fig.canvas.set_window_title('AP ' + class_name) + + plt.title('class: ' + text) + plt.xlabel('Recall') + plt.ylabel('Precision') + axes = plt.gca() + axes.set_xlim([0.0,1.0]) + axes.set_ylim([0.0,1.05]) + fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png") + plt.cla() + + plt.plot(score, F1, "-", color='orangered') + plt.title('class: ' + F1_text + "\nscore_threhold=" + str(score_threhold)) + plt.xlabel('Score_Threhold') + plt.ylabel('F1') + axes = plt.gca() + axes.set_xlim([0.0,1.0]) + axes.set_ylim([0.0,1.05]) + fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png") + plt.cla() + + plt.plot(score, rec, "-H", color='gold') + plt.title('class: ' + Recall_text + "\nscore_threhold=" + str(score_threhold)) + plt.xlabel('Score_Threhold') + plt.ylabel('Recall') + axes = plt.gca() + axes.set_xlim([0.0,1.0]) + axes.set_ylim([0.0,1.05]) + fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png") + plt.cla() + + plt.plot(score, prec, "-s", color='palevioletred') + plt.title('class: ' + Precision_text + "\nscore_threhold=" + str(score_threhold)) + plt.xlabel('Score_Threhold') + plt.ylabel('Precision') + axes = plt.gca() + axes.set_xlim([0.0,1.0]) + axes.set_ylim([0.0,1.05]) + fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png") + plt.cla() + + if show_animation: + cv2.destroyAllWindows() + if n_classes == 0: + print("未检测到任何种类,请检查标签信息与get_map.py中的classes_path是否修改。") + return 0 + results_file.write("\n# mAP of all classes\n") + mAP = sum_AP / n_classes + text = "mAP = {0:.2f}%".format(mAP*100) + results_file.write(text + "\n") + print(text) + + shutil.rmtree(TEMP_FILES_PATH) + + """ + Count total of detection-results + """ + det_counter_per_class = {} + for txt_file in dr_files_list: + lines_list = file_lines_to_list(txt_file) + for line in lines_list: + class_name = line.split()[0] + if class_name in det_counter_per_class: + det_counter_per_class[class_name] += 1 + else: + det_counter_per_class[class_name] = 1 + dr_classes = list(det_counter_per_class.keys()) + + """ + Write number of ground-truth objects per class to results.txt + """ + with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file: + results_file.write("\n# Number of ground-truth objects per class\n") + for class_name in sorted(gt_counter_per_class): + results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n") + + """ + Finish counting true positives + """ + for class_name in dr_classes: + if 
class_name not in gt_classes: + count_true_positives[class_name] = 0 + + """ + Write number of detected objects per class to results.txt + """ + with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file: + results_file.write("\n# Number of detected objects per class\n") + for class_name in sorted(dr_classes): + n_det = det_counter_per_class[class_name] + text = class_name + ": " + str(n_det) + text += " (tp:" + str(count_true_positives[class_name]) + "" + text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n" + results_file.write(text) + + """ + Plot the total number of occurences of each class in the ground-truth + """ + if draw_plot: + window_title = "ground-truth-info" + plot_title = "ground-truth\n" + plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)" + x_label = "Number of objects per class" + output_path = RESULTS_FILES_PATH + "/ground-truth-info.png" + to_show = False + plot_color = 'forestgreen' + draw_plot_func( + gt_counter_per_class, + n_classes, + window_title, + plot_title, + x_label, + output_path, + to_show, + plot_color, + '', + ) + + # """ + # Plot the total number of occurences of each class in the "detection-results" folder + # """ + # if draw_plot: + # window_title = "detection-results-info" + # # Plot title + # plot_title = "detection-results\n" + # plot_title += "(" + str(len(dr_files_list)) + " files and " + # count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values())) + # plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)" + # # end Plot title + # x_label = "Number of objects per class" + # output_path = RESULTS_FILES_PATH + "/detection-results-info.png" + # to_show = False + # plot_color = 'forestgreen' + # true_p_bar = count_true_positives + # draw_plot_func( + # det_counter_per_class, + # len(det_counter_per_class), + # window_title, + # plot_title, + # x_label, + # output_path, + # to_show, + # plot_color, + # true_p_bar + # ) + + """ + Draw log-average miss rate plot (Show lamr of all classes in decreasing order) + """ + if draw_plot: + window_title = "lamr" + plot_title = "log-average miss rate" + x_label = "log-average miss rate" + output_path = RESULTS_FILES_PATH + "/lamr.png" + to_show = False + plot_color = 'royalblue' + draw_plot_func( + lamr_dictionary, + n_classes, + window_title, + plot_title, + x_label, + output_path, + to_show, + plot_color, + "" + ) + + """ + Draw mAP plot (Show AP's of all classes in decreasing order) + """ + if draw_plot: + window_title = "mAP" + plot_title = "mAP = {0:.2f}%".format(mAP*100) + x_label = "Average Precision" + output_path = RESULTS_FILES_PATH + "/mAP.png" + to_show = True + plot_color = 'royalblue' + draw_plot_func( + ap_dictionary, + n_classes, + window_title, + plot_title, + x_label, + output_path, + to_show, + plot_color, + "" + ) + return mAP + +def preprocess_gt(gt_path, class_names): + image_ids = os.listdir(gt_path) + results = {} + + images = [] + bboxes = [] + for i, image_id in enumerate(image_ids): + lines_list = file_lines_to_list(os.path.join(gt_path, image_id)) + boxes_per_image = [] + image = {} + image_id = os.path.splitext(image_id)[0] + image['file_name'] = image_id + '.jpg' + image['width'] = 1 + image['height'] = 1 + #-----------------------------------------------------------------# + # 感谢 多学学英语吧 的提醒 + # 解决了'Results do not correspond to current coco set'问题 + #-----------------------------------------------------------------# + image['id'] = str(image_id) 
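+        # each ground-truth line is "<class name> <left> <top> <right> <bottom> [difficult]";
+        # class names may contain spaces, so the coordinates are read from the end of the split line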
+ + for line in lines_list: + difficult = 0 + if "difficult" in line: + line_split = line.split() + left, top, right, bottom, _difficult = line_split[-5:] + class_name = "" + for name in line_split[:-5]: + class_name += name + " " + class_name = class_name[:-1] + difficult = 1 + else: + line_split = line.split() + left, top, right, bottom = line_split[-4:] + class_name = "" + for name in line_split[:-4]: + class_name += name + " " + class_name = class_name[:-1] + + left, top, right, bottom = float(left), float(top), float(right), float(bottom) + if class_name not in class_names: + continue + cls_id = class_names.index(class_name) + 1 + bbox = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0] + boxes_per_image.append(bbox) + images.append(image) + bboxes.extend(boxes_per_image) + results['images'] = images + + categories = [] + for i, cls in enumerate(class_names): + category = {} + category['supercategory'] = cls + category['name'] = cls + category['id'] = i + 1 + categories.append(category) + results['categories'] = categories + + annotations = [] + for i, box in enumerate(bboxes): + annotation = {} + annotation['area'] = box[-1] + annotation['category_id'] = box[-2] + annotation['image_id'] = box[-3] + annotation['iscrowd'] = box[-4] + annotation['bbox'] = box[:4] + annotation['id'] = i + annotations.append(annotation) + results['annotations'] = annotations + return results + +def preprocess_dr(dr_path, class_names): + image_ids = os.listdir(dr_path) + results = [] + for image_id in image_ids: + lines_list = file_lines_to_list(os.path.join(dr_path, image_id)) + image_id = os.path.splitext(image_id)[0] + for line in lines_list: + line_split = line.split() + confidence, left, top, right, bottom = line_split[-5:] + class_name = "" + for name in line_split[:-5]: + class_name += name + " " + class_name = class_name[:-1] + left, top, right, bottom = float(left), float(top), float(right), float(bottom) + result = {} + result["image_id"] = str(image_id) + if class_name not in class_names: + continue + result["category_id"] = class_names.index(class_name) + 1 + result["bbox"] = [left, top, right - left, bottom - top] + result["score"] = float(confidence) + results.append(result) + return results + +def get_coco_map(class_names, path): + GT_PATH = os.path.join(path, 'ground-truth') + DR_PATH = os.path.join(path, 'detection-results') + COCO_PATH = os.path.join(path, 'coco_eval') + + if not os.path.exists(COCO_PATH): + os.makedirs(COCO_PATH) + + GT_JSON_PATH = os.path.join(COCO_PATH, 'instances_gt.json') + DR_JSON_PATH = os.path.join(COCO_PATH, 'instances_dr.json') + + with open(GT_JSON_PATH, "w") as f: + results_gt = preprocess_gt(GT_PATH, class_names) + json.dump(results_gt, f, indent=4) + + with open(DR_JSON_PATH, "w") as f: + results_dr = preprocess_dr(DR_PATH, class_names) + json.dump(results_dr, f, indent=4) + if len(results_dr) == 0: + print("未检测到任何目标。") + return [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + + cocoGt = COCO(GT_JSON_PATH) + cocoDt = cocoGt.loadRes(DR_JSON_PATH) + cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + + return cocoEval.stats \ No newline at end of file