init
summary.py (new file, 83 lines added)
@@ -0,0 +1,83 @@
# --------------------------------------------#
# This part of the code is used to inspect the
# network structure (FLOPs, params, and FPS).
# --------------------------------------------#
import os

from torch import nn

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
from thop import clever_format, profile

from model.TDCNet.TDCNetwork import TDCNetwork
from model.TDCNet.TDCR import RepConv3D

if __name__ == "__main__":
    input_shape = [640, 640]
    num_classes = 1

    # Use device to choose whether the network runs on the GPU or the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_frame = 5
    m = TDCNetwork(num_classes, num_frame=num_frame)
    for mm in m.modules():
        if isinstance(mm, RepConv3D):
            mm.switch_to_deploy()
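    # switch_to_deploy is assumed to fold each RepConv3D's training-time branches into a
    # single 3D convolution (RepVGG-style re-parameterization), so the FLOPs/params and
    # FPS measured below correspond to the deployed form of the network.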
    dummy_input = torch.randn(1, 3, num_frame * 2, input_shape[0], input_shape[1]).to(device)  # torch.randn(1, 3, 5, input_shape[0], input_shape[1]).to(device)
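    # dummy_input presumably follows the (N, C, D, H, W) layout expected by 3D convolutions:
    # batch 1, 3 channels, temporal depth num_frame * 2 = 10, spatial size 640 x 640.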
    flops, params = profile(m.to(device), (dummy_input,), verbose=False)
    # --------------------------------------------------------#
    # flops * 2 because profile does not count a convolution as two operations.
    # Some papers count a convolution as two operations (multiply and add): multiply by 2.
    # Others only count multiplications and ignore additions: do not multiply by 2.
    # This code multiplies by 2, following YOLOX.
    # --------------------------------------------------------#
    flops = flops * 2
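    # For example, under this convention a model for which profile reports 100e9
    # multiply-accumulates is doubled to 200e9 and printed below as 200.000G.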
    flops, params = clever_format([flops, params], "%.3f")
    print('Total FLOPs: %s' % (flops))
    print('Total params: %s' % (params))

    # Compute FPS

    from data.dataloader_for_IRSTD_UAV import seqDataset, dataset_collate
    from torch.utils.data import DataLoader
    import time
    max_iter = 2000
    log_interval = 50
    num_warmup = 20
    pure_inf_time = 0
    fps = 0
    val_annotation_path = "/data1/gsk/Dataset/IRSTD-UAV/val.txt"
    val_dataset = seqDataset(val_annotation_path, input_shape[0], num_frame, 'val')
    gen_val = DataLoader(val_dataset, shuffle=False, batch_size=1, num_workers=10, pin_memory=True,
                         drop_last=True, collate_fn=dataset_collate)
    m = nn.DataParallel(m).cuda()

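    # Timing notes: batch_size = 1 so each measured interval is the latency of a single
    # sample, torch.cuda.synchronize() brackets the forward pass so asynchronous CUDA
    # kernels are fully counted, and the first num_warmup iterations are excluded so that
    # GPU initialization and other one-time setup costs do not skew the average.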
    # benchmark with up to max_iter (2000) images and take the average
    for i, data in enumerate(gen_val):
        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            m(data[0].to('cuda'))

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(
                    f'Done image [{i + 1:<3}/ {max_iter}], '
                    f'fps: {fps:.1f} img / s, '
                    f'time per image: {1000 / fps:.1f} ms / img',
                    flush=True)

        if (i + 1) == max_iter:
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(
                f'Overall fps: {fps:.1f} img / s, '
                f'time per image: {1000 / fps:.1f} ms / img',
                flush=True)
            break
    print("FPS:", fps)