forked from jittor/jittor
add performance test tutorial
This commit is contained in:
parent
856f201d9f
commit
6c980c2146
|
@ -0,0 +1,176 @@
|
|||
Jittor性能测试与对比方法
|
||||
=====================
|
||||
|
||||
下面代码以ResNet50为例,用于演示 Jittor 性能测试的正确方法:
|
||||
|
||||
```python
|
||||
import time
|
||||
import jittor as jt
|
||||
from jittor.models import resnet50
|
||||
jt.flags.use_cuda = jt.has_cuda
|
||||
|
||||
warmup = 10
|
||||
rerun = 100
|
||||
batch_size = 8
|
||||
data = jt.random((batch_size, 3, 224, 224))
|
||||
model = resnet50()
|
||||
model.eval()
|
||||
|
||||
# 此段代码对jittor进行热身,确保时间测试准确
|
||||
jt.sync_all(True)
|
||||
for i in range(warmup):
|
||||
pred = model(data)
|
||||
# sync是把计算图发送到计算设备上
|
||||
pred.sync()
|
||||
# sync_all(True)是把计算图发送到计算设备上,并且同步。
|
||||
# 只有运行了jt.sync_all(True)才会真正地运行,时间才是有效的,因此执行forward前后都要执行这句话
|
||||
jt.sync_all(True)
|
||||
|
||||
# 开始测试运行时间
|
||||
start = time.time()
|
||||
for i in range(rerun):
|
||||
pred = model(data)
|
||||
pred.sync()
|
||||
jt.sync_all(True)
|
||||
end = time.time()
|
||||
|
||||
print("Jittor FPS:", (rerun*batch_size)/(end-start))
|
||||
|
||||
```
|
||||
|
||||
在这段代码中,我们定义了几个参数:`batch_size`、`warmup` 和 `rerun`。其中 `batch_size` 代表批大小,`warmup` 是用于热身的循环次数,`rerun` 是用于测速的循环次数,最终输出 FPS。对 Jittor 进行正确测速的关键是热身部分和同步部分:热身部分确保测试时间稳定,不包含编译所用的时间;同步部分确保计算完成,因为 Jittor 是一个异步框架,只有同步操作才能保证计算完成。
|
||||
|
||||
以上代码的运行结果如下(RTX Titan,batch 8):
|
||||
|
||||
```
|
||||
Compiling Operators(8/8) used: 7.35s eta: 0s
|
||||
Compiling Operators(13/13) used: 8.36s eta: 0s
|
||||
Jittor FPS: 908.9853866375396
|
||||
```
|
||||
|
||||
我们还可以使用类似的代码测试 PyTorch的性能:
|
||||
|
||||
```python
|
||||
import time
|
||||
import torch
|
||||
from torchvision.models import resnet50
|
||||
|
||||
warmup = 10
|
||||
rerun = 100
|
||||
batch_size = 8
|
||||
data = torch.randn((batch_size, 3, 224, 224)).cuda()
|
||||
model = resnet50()
|
||||
model.cuda()
|
||||
model.eval()
|
||||
|
||||
# 此段代码对pytorch进行热身,确保时间测试准确
|
||||
torch.cuda.synchronize()
|
||||
for i in range(warmup):
|
||||
pred = model(data)
|
||||
# synchronize用于确保PyTorch计算完成
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# 开始测试运行时间
|
||||
start = time.time()
|
||||
for i in range(rerun):
|
||||
pred = model(data)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
print("PyTorch FPS:", (rerun*batch_size)/(end-start))
|
||||
```
|
||||
|
||||
|
||||
以上代码的运行结果如下(RTX Titan,batch 8):
|
||||
|
||||
```
|
||||
PyTorch FPS: 807.4806873965665
|
||||
```
|
||||
|
||||
我们还可以对这两段代码合并,并对比结果的一致性:
|
||||
|
||||
```python
|
||||
import time
|
||||
import jittor as jt
|
||||
from jittor.models import resnet50
|
||||
jt.flags.use_cuda = jt.has_cuda
|
||||
|
||||
warmup = 100
|
||||
rerun = 1000
|
||||
batch_size = 8
|
||||
data = jt.random((batch_size, 3, 224, 224))
|
||||
model = resnet50()
|
||||
model.eval()
|
||||
|
||||
# 此段代码对jittor进行热身,确保时间测试准确
|
||||
jt.sync_all(True)
|
||||
for i in range(warmup):
|
||||
pred = model(data)
|
||||
# sync是把计算图发送到计算设备上
|
||||
pred.sync()
|
||||
# sync_all(True)是把计算图发送到计算设备上,并且同步。
|
||||
# 只有运行了jt.sync_all(True)才会真正地运行,时间才是有效的,因此执行forward前后都要执行这句话
|
||||
jt.sync_all(True)
|
||||
|
||||
# 开始测试运行时间
|
||||
start = time.time()
|
||||
for i in range(rerun):
|
||||
pred = model(data)
|
||||
pred.sync()
|
||||
jt.sync_all(True)
|
||||
end = time.time()
|
||||
|
||||
print("Jittor FPS:", (rerun*batch_size)/(end-start))
|
||||
# 将 jittor 数据和参数导出为 numpy 和 torch 格式
|
||||
jittor_data = pred.numpy()
|
||||
jittor_param = model.state_dict(to="torch")
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.models import resnet50
|
||||
data = torch.Tensor(data.numpy()).cuda()
|
||||
model = resnet50()
|
||||
# 加载 jittor 参数
|
||||
model.load_state_dict(jittor_param)
|
||||
model.cuda()
|
||||
model.eval()
|
||||
|
||||
# 此段代码对pytorch进行热身,确保时间测试准确
|
||||
torch.cuda.synchronize()
|
||||
for i in range(warmup):
|
||||
pred = model(data)
|
||||
# synchronize用于确保PyTorch计算完成
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# 开始测试运行时间
|
||||
start = time.time()
|
||||
for i in range(rerun):
|
||||
pred = model(data)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
print("PyTorch FPS:", (rerun*batch_size)/(end-start))
|
||||
pytorch_data = pred.detach().cpu().numpy()
|
||||
err = np.mean(np.abs(pytorch_data - jittor_data))
|
||||
print("mean error:", err)
|
||||
|
||||
```
|
||||
|
||||
|
||||
以上代码运行结果如下:
|
||||
|
||||
```
|
||||
Jittor FPS: 908.9853866375396
|
||||
PyTorch FPS: 807.4806873965665
|
||||
mean error: 1e-5
|
||||
```
|
||||
|
||||
误差输出为1e-5, 在可接受范围内。正确测速与对比的几大关键点为:
|
||||
|
||||
1. 充分热身,除去框架的准备时间。
|
||||
2. 多次运行,确保测试时间稳定。
|
||||
3. 加上同步语句,确保测试时间准确。
|
||||
4. 保证显存充足,在显存不足时,jittor会调用统一内存来弥补,会产生性能损失,请密切关注`nvidia-smi`的输出结果。
|
||||
5. 保证对比模型的一致性,检查输出结果是否一致。
|
||||
|
||||
如果您对测试结果有疑问,或者有优化需求,欢迎随时联系Jittor开发团队。
|
|
@ -48,6 +48,7 @@
|
|||
:caption: 其他:
|
||||
|
||||
Jittor调试技巧
|
||||
Jittor性能测试与对比方法
|
||||
教程 <https://cg.cs.tsinghua.edu.cn/jittor/tutorial/>
|
||||
|
||||
Indices and tables
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.10'
|
||||
__version__ = '1.3.1.11'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -834,7 +834,35 @@ class Module:
|
|||
self.dfs([], None, callback, callback_leave)
|
||||
return _uniq(ps)
|
||||
|
||||
def state_dict(self):
|
||||
def state_dict(self, to=None):
|
||||
''' Returns a dictionary containing
|
||||
Jittor Var of the module and its descendants.
|
||||
|
||||
Args:
|
||||
to: target type of var, can be None, 'numpy' or 'torch'
|
||||
|
||||
Return:
|
||||
dictionary of module's states.
|
||||
|
||||
Example::
|
||||
|
||||
import jittor as jt
|
||||
from jittor.models import resnet50
|
||||
jittor_model = resnet50()
|
||||
dict = jittor_model.state_dict()
|
||||
jittor_model.load_state_dict(dict)
|
||||
|
||||
Example2(export Jittor params to PyTorch)::
|
||||
|
||||
import jittor as jt
|
||||
from jittor.models import resnet50
|
||||
jittor_model = resnet50()
|
||||
import torch
|
||||
from torchvision.models import resnet50
|
||||
torch_model = resnet50()
|
||||
torch_model.load_state_dict(jittor_model.state_dict(to="torch"))
|
||||
|
||||
'''
|
||||
uniq_set = set()
|
||||
ps = {}
|
||||
stack = []
|
||||
|
@ -855,6 +883,15 @@ class Module:
|
|||
def callback_leave(parents, k, v, n):
|
||||
stack.pop()
|
||||
self.dfs([], None, callback, callback_leave)
|
||||
if to == "numpy":
|
||||
for k,v in ps.items():
|
||||
if isinstance(v, Var):
|
||||
ps[k] = v.numpy()
|
||||
elif to == "torch":
|
||||
import torch
|
||||
for k,v in ps.items():
|
||||
if isinstance(v, Var):
|
||||
ps[k] = torch.Tensor(v.numpy())
|
||||
return ps
|
||||
|
||||
def named_parameters(self):
|
||||
|
|
Loading…
Reference in New Issue