Facebook
From 李志成, 1 Month ago, written in Plain Text.
Embed
Download Paste or View Raw
Hits: 163
  1. import torch
  2. import torchvision
  3.  
  4. model = torchvision.models.resnet101(pretrained=True)
  5. model = model.cuda()
  6.  
  7. model.eval()
  8.  
  9. from ptflops import get_model_complexity_info
  10.  
  11. import torch2trt
  12.  
  13. #x = torch.randn(6, 3, 928, 1600, device="cuda:0", dtype=torch.float32)
  14.            
  15. import time
  16. timer = []
  17. input = torch.randn(6, 3, 928, 1600, device="cuda:0", dtype=torch.float32)
  18. trtmodel = torch2trt.torch2trt(model,
  19.                                inputs=[input],
  20.                                fp16_mode=False,
  21.                                max_batch_size=6,
  22.                                max_workspace_size=int(4e9))
  23.  
  24. for i in range(1000):
  25.     s_time = time.time()
  26.     trtmodel(input)
  27.     torch.cuda.synchronize()
  28.     e_time = time.time()
  29.     timer.append(1000 * (e_time - s_time))
  30.     print("{}".format(1000 * (e_time - s_time)))
  31.     #print("{}".format(timer[-1]))
  32. avg_time = sum(timer[500:])/len(timer[500:])
  33. print("resnet_trt with 5 3 928 1600 {}".format(avg_time))
  34.  
  35. timer = []
  36. input = torch.randn(1, 3, 224, 224, device="cuda:0", dtype=torch.float32)
  37. for i in range(1000):
  38.     model(input)
  39.  
  40. for i in range(1000):
  41.     s_time = time.time()
  42.     model(input)
  43.     torch.cuda.synchronize()
  44.     e_time = time.time()
  45.     timer.append(1000 * (e_time - s_time))
  46.     print("{}".format(1000 * (e_time - s_time)))
  47. avg_time = sum(timer)/len(timer)
  48. print("resnet_torch with 1 3 224 224 {}".format(avg_time))
  49.  
  50. timer = []
  51. input = torch.randn(6, 3, 224, 224, device="cuda:0", dtype=torch.float32)
  52. for i in range(1000):
  53.     s_time = time.time()
  54.     model(input)
  55.     torch.cuda.synchronize()
  56.     e_time = time.time()
  57.     timer.append(1000 * (e_time - s_time))
  58.     print("{}".format(1000 * (e_time - s_time)))
  59. avg_time = sum(timer)/len(timer)
  60. print("resnet_torch with 6 3 224 224 {}".format(avg_time))
  61.  
  62. timer = []
  63. input = torch.randn(1, 3, 928, 1600, device="cuda:0", dtype=torch.float32)
  64. for i in range(1000):
  65.     s_time = time.time()
  66.     model(input)
  67.     torch.cuda.synchronize()
  68.     e_time = time.time()
  69.     timer.append(1000 * (e_time - s_time))
  70.     print("{}".format(1000 * (e_time - s_time)))
  71. avg_time = sum(timer)/len(timer)
  72. print("resnet_torch with 1 3 928 1600 {}".format(avg_time))
  73.  
  74. timer = []
  75. input = torch.randn(6, 3, 928, 1600, device="cuda:0", dtype=torch.float32)
  76. for i in range(1000):
  77.     s_time = time.time()
  78.     model(input)
  79.     torch.cuda.synchronize()
  80.     e_time = time.time()
  81.     timer.append(1000 * (e_time - s_time))
  82.     print("{}".format(1000 * (e_time - s_time)))
  83. avg_time = sum(timer)/len(timer)
  84. print("resnet_torch with 6 3 928 1600 {}".format(avg_time))
  85.  
  86.  
  87. timer = []
  88. input = torch.randn(1, 3, 224, 224, device="cuda:0", dtype=torch.float32)
  89. trtmodel = torch2trt.torch2trt(model,
  90.                                inputs=[input],
  91.                                fp16_mode=False,
  92.                                max_batch_size=6,
  93.                                max_workspace_size=int(4e9))
  94.  
  95. for i in range(1000):
  96.     s_time = time.time()
  97.     trtmodel(input)
  98.     torch.cuda.synchronize()
  99.     e_time = time.time()
  100.     timer.append(1000 * (e_time - s_time))
  101.     print("{}".format(1000 * (e_time - s_time)))
  102. avg_time = sum(timer)/len(timer)
  103. print("resnet_trt with 1 3 224 224 {}".format(avg_time))
  104.  
  105. timer = []
  106. input = torch.randn(6, 3, 224, 224, device="cuda:0", dtype=torch.float32)
  107. trtmodel = torch2trt.torch2trt(model,
  108.                                inputs=[input],
  109.                                fp16_mode=False,
  110.                                max_batch_size=6,
  111.                                max_workspace_size=int(4e9))
  112.  
  113. for i in range(1000):
  114.     s_time = time.time()
  115.     trtmodel(input)
  116.     torch.cuda.synchronize()
  117.     e_time = time.time()
  118.     timer.append(1000 * (e_time - s_time))
  119.     print("{}".format(1000 * (e_time - s_time)))
  120. avg_time = sum(timer)/len(timer)
  121. print("resnet_trt with 6 3 224 224 {}".format(avg_time))
  122.  
  123. timer = []
  124. input = torch.randn(1, 3, 928, 1600, device="cuda:0", dtype=torch.float32)
  125. trtmodel = torch2trt.torch2trt(model,
  126.                                inputs=[input],
  127.                                fp16_mode=False,
  128.                                max_batch_size=6,
  129.                                max_workspace_size=int(4e9))
  130.  
  131. for i in range(1000):
  132.     s_time = time.time()
  133.     trtmodel(input)
  134.     torch.cuda.synchronize()
  135.     e_time = time.time()
  136.     timer.append(1000 * (e_time - s_time))
  137.     print("{}".format(1000 * (e_time - s_time)))
  138. avg_time = sum(timer)/len(timer)
  139. print("resnet_trt with 1 3 928 1600 {}".format(avg_time))
  140.  
  141.