# --------------------------------------------#
#   This code is used to inspect the network structure
# --------------------------------------------#
import os
import time

# Restrict the process to GPU 0; set before torch is imported so the
# setting is in place before CUDA is initialized.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
from thop import clever_format, profile
from torch import nn

from model.TDCNet.TDCNetwork import TDCNetwork
from model.TDCNet.TDCR import RepConv3D

def benchmark_fps(model, loader, device, max_iter=2000, num_warmup=20,
                  log_interval=50):
    """Time forward passes of ``model`` over ``loader`` and return the FPS.

    The model is switched to eval mode (and left in eval mode) and is run
    under ``torch.no_grad()``. The first ``num_warmup`` iterations are
    executed but excluded from the average. Each timed span covers moving
    ``data[0]`` to ``device`` plus the forward pass.

    Args:
        model: ``torch.nn.Module`` to benchmark; it is called with
            ``data[0]`` moved to ``device``.
        loader: Iterable of batches; only element 0 of each batch is used.
        device: ``torch.device`` the inputs are moved to. CUDA
            synchronization is performed only when this is a CUDA device.
        max_iter: Stop after this many iterations (warm-up included).
        num_warmup: Number of leading iterations excluded from timing.
        log_interval: Print a progress line every this many iterations.

    Returns:
        Iterations per second averaged over the timed iterations, or
        ``0.0`` if the loader ran out before any timed iteration.
    """
    on_cuda = device.type == 'cuda'

    # Benchmark in inference mode: thop.profile restores the previous
    # (training) mode when it returns, so eval mode must be set explicitly.
    model.eval()

    pure_inf_time = 0.0
    timed_iters = 0
    for i, data in enumerate(loader):
        # CUDA kernels run asynchronously; synchronize so the wall-clock
        # span brackets the actual work.
        if on_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(data[0].to(device))

        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            timed_iters += 1
            if (i + 1) % log_interval == 0 and pure_inf_time > 0:
                fps = timed_iters / pure_inf_time
                print(
                    f'Done image [{i + 1:<3}/ {max_iter}], '
                    f'fps: {fps:.1f} img / s, '
                    f'times per image: {1000 / fps:.1f} ms / img',
                    flush=True)

        if i + 1 >= max_iter:
            break

    # Report the overall figure even when the loader held fewer than
    # max_iter batches, instead of returning a stale interval value.
    if timed_iters == 0 or pure_inf_time <= 0:
        return 0.0
    fps = timed_iters / pure_inf_time
    print(
        f'Overall fps: {fps:.1f} img / s, '
        f'times per image: {1000 / fps:.1f} ms / img',
        flush=True)
    return fps


if __name__ == "__main__":
    input_shape = [640, 640]
    num_classes = 1
    num_frame = 5

    # Use `device` to choose whether the network runs on GPU or CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    m = TDCNetwork(num_classes, num_frame=num_frame)
    # Call switch_to_deploy() on every RepConv3D submodule before profiling.
    for mm in m.modules():
        if isinstance(mm, RepConv3D):
            mm.switch_to_deploy()

    # NOTE(review): the temporal dimension is num_frame * 2, not num_frame --
    # presumably what TDCNetwork expects; confirm against the model.
    dummy_input = torch.randn(
        1, 3, num_frame * 2, input_shape[0], input_shape[1]).to(device)
    flops, params = profile(m.to(device), (dummy_input,), verbose=False)
    # --------------------------------------------------------#
    #   flops * 2 because profile does not count a convolution
    #   as two operations.
    #   Some papers count the multiply and the add of a convolution as
    #   two operations; in that case multiply by 2.
    #   Some papers count only multiplications and ignore additions;
    #   in that case do not multiply by 2.
    #   This code multiplies by 2, following YOLOX.
    # --------------------------------------------------------#
    flops = flops * 2
    flops, params = clever_format([flops, params], "%.3f")
    print('Total GFLOPS: %s' % (flops))
    print('Total params: %s' % (params))

    # Measure FPS.
    # Imported here so the FLOPs/params report above is printed even if
    # the dataset module cannot be imported.
    from data.dataloader_for_IRSTD_UAV import seqDataset, dataset_collate
    from torch.utils.data import DataLoader

    max_iter = 2000
    log_interval = 50
    num_warmup = 20
    val_annotation_path = "/data1/gsk/Dataset/IRSTD-UAV/val.txt"
    val_dataset = seqDataset(val_annotation_path, input_shape[0], num_frame, 'val')
    gen_val = DataLoader(val_dataset, shuffle=False, batch_size=1,
                         num_workers=10, pin_memory=(device.type == 'cuda'),
                         drop_last=True, collate_fn=dataset_collate)
    # Follow the selected device instead of hard-coding CUDA so the script
    # still runs on a CPU-only host.
    m = nn.DataParallel(m).to(device)

    # Benchmark with up to max_iter images and take the average.
    fps = benchmark_fps(m, gen_val, device, max_iter=max_iter,
                        num_warmup=num_warmup, log_interval=log_interval)
    print("FPS:", fps)