This commit is contained in:
esenke
2025-12-08 21:38:53 +08:00
commit 71118fc649
22 changed files with 4780 additions and 0 deletions

83
summary.py Normal file
View File

@@ -0,0 +1,83 @@
# --------------------------------------------#
# 该部分代码用于看网络结构
# --------------------------------------------#
import os
from torch import nn
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
from thop import clever_format, profile
from model.TDCNet.TDCNetwork import TDCNetwork
from model.TDCNet.TDCR import RepConv3D
if __name__ == "__main__":
input_shape = [640, 640]
num_classes = 1
# 需要使用device来指定网络在GPU还是CPU运行8824
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_frame = 5
m = TDCNetwork(num_classes, num_frame=5)
for mm in m.modules():
if isinstance(mm, RepConv3D):
mm.switch_to_deploy()
dummy_input = torch.randn(1, 3, num_frame * 2, input_shape[0], input_shape[1]).to(device) # torch.randn(1, 3, 5,input_shape[0], input_shape[1]).to(device)
flops, params = profile(m.to(device), (dummy_input,), verbose=False)
# #--------------------------------------------------------#
# # flops * 2是因为profile没有将卷积作为两个operations
# # 有些论文将卷积算乘法、加法两个operations。此时乘2
# # 有些论文只考虑乘法的运算次数忽略加法。此时不乘2
# # 本代码选择乘2参考YOLOX。
# #--------------------------------------------------------#
flops = flops * 2
flops, params = clever_format([flops, params], "%.3f")
print('Total GFLOPS: %s' % (flops))
print('Total params: %s' % (params))
# 计算FPS
from data.dataloader_for_IRSTD_UAV import seqDataset, dataset_collate
from torch.utils.data import DataLoader
import time
max_iter = 2000
log_interval = 50
num_warmup = 20
pure_inf_time = 0
fps = 0
val_annotation_path = "/data1/gsk/Dataset/IRSTD-UAV/val.txt"
val_dataset = seqDataset(val_annotation_path, input_shape[0], num_frame, 'val')
gen_val = DataLoader(val_dataset, shuffle = False, batch_size = 1, num_workers = 10, pin_memory=True,
drop_last=True, collate_fn=dataset_collate)
m = nn.DataParallel(m).cuda()
# benchmark with 2000 image and take the average
for i, data in enumerate(gen_val):
torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
m(data[0].to('cuda'))
torch.cuda.synchronize()
elapsed = time.perf_counter() - start_time
if i >= num_warmup:
pure_inf_time += elapsed
if (i + 1) % log_interval == 0:
fps = (i + 1 - num_warmup) / pure_inf_time
print(
f'Done image [{i + 1:<3}/ {max_iter}], '
f'fps: {fps:.1f} img / s, '
f'times per image: {1000 / fps:.1f} ms / img',
flush=True)
if (i + 1) == max_iter:
fps = (i + 1 - num_warmup) / pure_inf_time
print(
f'Overall fps: {fps:.1f} img / s, '
f'times per image: {1000 / fps:.1f} ms / img',
flush=True)
break
print("FPS:" ,fps)