forked from jittor/jittor
version 1522f3d004f9bdbf3953d91d4c259c341817c71f
This commit is contained in:
commit
1258121b1f
|
@ -0,0 +1,19 @@
|
|||
my
|
||||
.refresh
|
||||
__pycache__
|
||||
.ipynb_checkpoints/
|
||||
.vscode/
|
||||
__res/
|
||||
perf.data
|
||||
perf.data.old
|
||||
*.swp
|
||||
*.ipynb
|
||||
*.pdf
|
||||
*.zip
|
||||
*.tgz
|
||||
test.py
|
||||
extern/mkl/mkldnn_lnx*/*
|
||||
data/
|
||||
build/
|
||||
*.md
|
||||
!*.src.md
|
|
@ -0,0 +1,203 @@
|
|||
Copyright (c) 2020 Jittor. All Rights Reserved
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,403 @@
|
|||
# Jittor: a Just-in-time(JIT) deep learning framework
|
||||
# Jittor: 即时编译深度学习框架
|
||||
|
||||
[Quickstart](#quickstart) | [Install](#install) | [Tutorial](#tutorial)
|
||||
|
||||
[快速开始](#快速开始) | [安装](#安装) | [教程](#教程)
|
||||
|
||||
Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators. The whole framework and meta-operators are compiled just-in-time. A powerful op compiler and tuner are integrated into Jittor. It allows us to generate high-performance code specialized for your model.
|
||||
|
||||
Jittor 是一个基于即时编译和元算子的高性能深度学习框架,整个框架在即时编译的同时,还集成了强大的Op编译器和调优器,为您的模型生成定制化的高性能代码。
|
||||
|
||||
The front-end language is Python. Modular design is used in the front-end, similar to PyTorch and Keras. The back-end is implemented in high-performance languages, such as CUDA and C++.
|
||||
|
||||
Jittor前端语言为Python。前端使用了模块化的设计,类似于PyTorch,Keras,后端则使用高性能语言编写,如CUDA,C++。
|
||||
|
||||
The following example shows how to model a two-layer neural network step by step and train it from scratch in a few lines of Python code.
|
||||
|
||||
下面的代码演示了如何一步一步使用Python代码,从头对一个双层神经网络建模。
|
||||
|
||||
```python
|
||||
import jittor as jt
|
||||
from jittor import Module
|
||||
from jittor import nn
|
||||
class Model(Module):
|
||||
def __init__(self):
|
||||
self.layer1 = nn.Linear(1, 10)
|
||||
self.relu = nn.Relu()
|
||||
self.layer2 = nn.Linear(10, 1)
|
||||
def execute (self,x) :
|
||||
x = self.layer1(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
def get_data(n): # generate random data for training test.
|
||||
for i in range(n):
|
||||
x = np.random.rand(batch_size, 1)
|
||||
y = x*x
|
||||
yield jt.float32(x), jt.float32(y)
|
||||
|
||||
model = Model()
|
||||
learning_rate = 0.1
|
||||
optim = nn.SGD(model.parameters(), learning_rate)
|
||||
|
||||
for i,(x,y) in enumerate(get_data(n)):
|
||||
pred_y = model(x)
|
||||
loss = ((pred_y - y)**2)
|
||||
loss_mean = loss.mean()
|
||||
optim.step (loss_mean)
|
||||
print(f"step {i}, loss = {loss_mean.data.sum()}")
|
||||
```
|
||||
|
||||
## Contents
|
||||
|
||||
* [Quickstart](#quickstart)
|
||||
* [Install](#install)
|
||||
* [Tutorial](#tutorial)
|
||||
* [Contributing](#contributing)
|
||||
* [The Team](#theteam)
|
||||
* [License](#license)
|
||||
|
||||
## 大纲
|
||||
|
||||
- [快速开始](#快速开始)
|
||||
- [安装](#安装)
|
||||
- [教程](#教程)
|
||||
- [贡献](#贡献)
|
||||
- [团队](#团队)
|
||||
- [版权声明](#版权声明)
|
||||
|
||||
## Quickstart
|
||||
|
||||
## 快速开始
|
||||
|
||||
We provide some jupyter notebooks to help you quick start with Jittor.
|
||||
|
||||
我们提供了一些jupyter notebooks来帮助您快速入门Jittor。
|
||||
|
||||
- [Example: Model definition and training][1]
|
||||
- [示例:模型定义与训练][1]
|
||||
- [Basics: Op, Var][2]
|
||||
- [基础:Op, Var][2]
|
||||
- [Meta-operator: Implement your own convolution with Meta-operator][3]
|
||||
- [元算子:通过元算子实现自己的卷积层][3]
|
||||
|
||||
## Install
|
||||
|
||||
## 安装
|
||||
|
||||
Jittor is written in Python and C++. It requires a compiler for JIT compilation. Currently, we support three compilers:
|
||||
|
||||
Jittor使用Python和C++编写。 它需要用于即时编译的编译器。当前,我们支持三种编译器:
|
||||
|
||||
* CPU compiler (require at least one of the following)
|
||||
* g++ (>=5.4.0)
|
||||
* clang (>=8.0) recommend
|
||||
* CPU 编译器 (需要下列至少一个)
|
||||
- g++ (>=5.4.0)
|
||||
- clang (>=8.0)推荐
|
||||
* GPU compiler (optional)
|
||||
* nvcc (>=10.0)
|
||||
* GPU 编译器(可选)
|
||||
- nvcc(>=10.0)
|
||||
|
||||
Jittor的环境要求如下:
|
||||
|
||||
* 操作系统: Ubuntu>=16.04
|
||||
* Python >= 3.7
|
||||
|
||||
Jittor offers three ways to install: pip, script or manual.
|
||||
|
||||
Jittor 一共提供三种方式安装: pip安装, 一键脚本安装 和 手动安装.
|
||||
|
||||
## Pip 安装
|
||||
|
||||
## Pip install
|
||||
|
||||
如果您已经装好编译器和对应版本的Python,我们强烈推荐您使用这种方法
|
||||
(如果无法访问github, 可以通过jittor主页下载):
|
||||
|
||||
```bash
|
||||
sudo apt install python-dev libomp-dev
|
||||
sudo pip install git+https://github.com/Jittor/jittor.git
|
||||
# if you cannot access github, please download code from our website:
|
||||
# wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz
|
||||
# mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor
|
||||
# sudo pip install ./jittor
|
||||
python3 -m jittor.test.test_example
|
||||
```
|
||||
|
||||
如果测试运行通过,恭喜你已经安装完成.
|
||||
jittor会自动在路径中寻找合适的编译器, 如果您希望手动指定编译器, 请使用环境变量 `cc_path` 和 `nvcc_path`(可选).
|
||||
|
||||
## 一键脚本安装
|
||||
## single line script install
|
||||
|
||||
一键脚本安装会帮您安装好所需的编译器.
|
||||
|
||||
We provide a single-line command for quick installation of the latest version of Jittor (Ubuntu>=16.04):
|
||||
|
||||
我们提供能快速安装最新版本Jittor的单行命令(Ubuntu> = 16.04):
|
||||
|
||||
```bash
|
||||
# install with clang and cuda
|
||||
git clone https://github.com/Jittor/jittor.git && with_clang=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
# install with clang
|
||||
git clone https://github.com/Jittor/jittor.git && with_clang=1 bash ./jittor/script/install.sh
|
||||
# install with g++ and cuda
|
||||
git clone https://github.com/Jittor/jittor.git && with_gcc=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
# install with g++
|
||||
git clone https://github.com/Jittor/jittor.git && with_gcc=1 bash ./jittor/script/install.sh
|
||||
```
|
||||
After execution, the script will show some environment variables you need to export.
|
||||
|
||||
执行后,脚本将显示一些需要导出的环境变量。
|
||||
|
||||
If you use Jittor for CPU computing, we strongly recommend clang(>=8.0) as the back-end compiler of Jittor. Because some customized optimizations will be enabled.
|
||||
|
||||
如果将Jittor用于CPU计算,则强烈建议使用clang(> = 8.0)作为Jittor的后端编译器。 因为Jittor会用到其中一些定制的优化。
|
||||
|
||||
|
||||
## 手动安装
|
||||
## manual install
|
||||
|
||||
We will show how to install Jittor in Ubuntu 16.04 step by step. Other Linux distributions may use similar commands.
|
||||
|
||||
我们将逐步演示如何在Ubuntu 16.04中安装Jittor,其他Linux发行版可能可以使用类似的命令。
|
||||
|
||||
### Step 1: Choose your back-end compiler
|
||||
|
||||
### 步骤一:选择您的后端编译器
|
||||
|
||||
```bash
|
||||
# g++
|
||||
sudo apt install g++ build-essential libomp-dev
|
||||
|
||||
# OR clang-8
|
||||
wget -O - https://apt.llvm.org/llvm.sh > /tmp/llvm.sh
|
||||
bash /tmp/llvm.sh 8
|
||||
```
|
||||
### Step 2: Install Python and python-dev
|
||||
|
||||
### 步骤二:安装Python和python-dev
|
||||
|
||||
Jittor need python version >= 3.7.
|
||||
|
||||
Jittor需要python的版本>=3.7。
|
||||
|
||||
```bash
|
||||
sudo apt install python3.7 python3.7-dev
|
||||
```
|
||||
|
||||
### Step 3: Run Jittor
|
||||
|
||||
### 步骤三:运行Jittor
|
||||
|
||||
The whole framework is compiled Just-in-time. Let's install jittor via pip
|
||||
|
||||
整个框架是即时编译的。让我们通过pip安装jittor
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Jittor/jittor.git
|
||||
sudo pip3.7 install ./jittor
|
||||
export cc_path="clang-8"
|
||||
# if other compiler is used, change cc_path
|
||||
# export cc_path="g++"
|
||||
# export cc_path="icc"
|
||||
|
||||
# run a simple test
|
||||
python3.7 -m jittor.test.test_example
|
||||
```
|
||||
If the test passes, your Jittor is ready.
|
||||
|
||||
如果通过了测试,那么您的Jittor已经准备就绪。
|
||||
|
||||
### Optional Step 4: Enable CUDA
|
||||
|
||||
### 可选步骤四:启用CUDA
|
||||
|
||||
Using CUDA in Jittor is very simple, Just setup environment value `nvcc_path`
|
||||
|
||||
在Jittor中使用CUDA非常简单,只需设置环境值`nvcc_path`
|
||||
|
||||
```bash
|
||||
# replace this var with your nvcc location
|
||||
export nvcc_path="/usr/local/cuda/bin/nvcc"
|
||||
# run a simple cuda test
|
||||
python3.7 -m jittor.test.test_cuda
|
||||
```
|
||||
If the test passes, you can use Jittor with CUDA by setting the `use_cuda` flag.
|
||||
|
||||
如果测试通过,则可以通过设置`use_cuda`标识符在Jittor中启用CUDA。
|
||||
|
||||
```python
|
||||
import jittor as jt
|
||||
jt.flags.use_cuda = 1
|
||||
```
|
||||
|
||||
### Optional Step 5: Run full tests
|
||||
|
||||
### 可选步骤五:进行完整测试
|
||||
|
||||
To check the integrity of Jittor, you can run full tests.
|
||||
|
||||
要检查Jittor的完整性,您可以运行完整的测试。
|
||||
|
||||
```bash
|
||||
python3.7 -m jittor.test -v
|
||||
```
|
||||
If those tests fail, please report bugs to us, and feel free to contribute ^_^
|
||||
|
||||
如果这些测试失败,请为我们报告错误,我们十分欢迎您为Jittor做出贡献^ _ ^
|
||||
|
||||
## Tutorial
|
||||
|
||||
## 教程
|
||||
|
||||
In the tutorial section, we will briefly explain the basic concept of Jittor.
|
||||
|
||||
在教程部分,我们将简要解释Jittor的基本概念。
|
||||
|
||||
To train your model with Jittor, there are only two main concepts you need to know:
|
||||
|
||||
要使用Jittor训练模型,您需要了解两个主要概念:
|
||||
|
||||
* Var: basic data type of jittor
|
||||
* Var:Jittor的基本数据类型
|
||||
* Operations: Jittor's ops are similar to numpy's
|
||||
* Operations:Jittor的算子与numpy类似
|
||||
|
||||
### Var
|
||||
|
||||
### 数据类型
|
||||
|
||||
First, let's get started with Var. Var is the basic data type of jittor. Computation process in Jittor is asynchronous for optimization. If you want to access the data, `Var.data` can be used for synchronous data accessing.
|
||||
|
||||
首先,让我们开始使用Var。Var是jittor的基本数据类型,为了运算更加高效Jittor中的计算过程是异步的。 如果要访问数据,可以使用`Var.data`进行同步数据访问。
|
||||
|
||||
```python
|
||||
import jittor as jt
|
||||
a = jt.float32([1,2,3])
|
||||
print (a)
|
||||
print (a.data)
|
||||
# Output: float32[3,]
|
||||
# Output: [ 1. 2. 3.]
|
||||
```
|
||||
|
||||
And we can give the variable a name.
|
||||
|
||||
此外我们可以给变量起一个名字。
|
||||
|
||||
```python
|
||||
c.name('c')
|
||||
print(c.name())
|
||||
# Output: c
|
||||
```
|
||||
|
||||
### Operations
|
||||
|
||||
### 数据运算
|
||||
|
||||
Jittor's ops are similar to numpy's. Let's try some operations. We create Var `a` and `b` via operation `jt.float32`, and add them. Printing those variables shows they have the same shape and dtype.
|
||||
|
||||
Jittor的算子与numpy类似。 让我们尝试一些运算, 我们通过Op`jt.float32`创建Var `a`和`b`,并将它们相加。 输出这些变量相关信息,可以看出它们具有相同的形状和类型。
|
||||
|
||||
```python
|
||||
import jittor as jt
|
||||
a = jt.float32([1,2,3])
|
||||
b = jt.float32([4,5,6])
|
||||
c = a*b
|
||||
print(a,b,c)
|
||||
print(type(a), type(b), type(c))
|
||||
# Output: float32[3,] float32[3,] float32[3,]
|
||||
# Output: <class 'jittor_core.Var'> <class 'jittor_core.Var'> <class 'jittor_core.Var'>
|
||||
```
|
||||
Besides that, all the operators used as `jt.xxx(Var, ...)` have the alias `Var.xxx(...)`. For example:
|
||||
|
||||
除此之外,我们使用的所有算子`jt.xxx(Var,...)`都具有别名`Var.xxx(...)`。 例如:
|
||||
|
||||
```python
|
||||
c.max() # alias of jt.max(c)
|
||||
c.add(a) # alias of jt.add(c, a)
|
||||
c.min(keepdims=True) # alias of jt.min(c, keepdims=True)
|
||||
```
|
||||
|
||||
If you want to know all the operations Jittor supports, try `help(jt.ops)`. All the operations you find in `jt.ops.xxx` can be used via the alias `jt.xxx`.
|
||||
|
||||
如果您想知道Jittor支持的所有运算,可以运行`help(jt.ops)`。 您在`jt.ops.xxx`中找到的所有运算都可以通过别名`jt.xxx`。
|
||||
|
||||
```python
|
||||
help(jt.ops)
|
||||
# Output:
|
||||
# abs(x: core.Var) -> core.Var
|
||||
# add(x: core.Var, y: core.Var) -> core.Var
|
||||
# array(data: array) -> core.Var
|
||||
# binary(x: core.Var, y: core.Var, op: str) -> core.Var
|
||||
# ......
|
||||
```
|
||||
### More
|
||||
|
||||
### 更多教程
|
||||
|
||||
If you want to know more about Jittor, please check out the notebooks below:
|
||||
|
||||
如果您想进一步了解Jittor,请查看以下notebooks:
|
||||
|
||||
* Quickstart
|
||||
- [Example: Model definition and training][1]
|
||||
- [Basics: Op, Var][2]
|
||||
- [Meta-operator: Implement your own convolution with Meta-operator][3]
|
||||
* 快速开始
|
||||
* [示例:模型定义与训练][1]
|
||||
* [基本概念:Op, Var][2]
|
||||
* [元算子:通过元算子实现自己的卷积层][3]
|
||||
* Advanced
|
||||
- [Custom Op: write your operator with C++ and CUDA and JIT compile it][4]
|
||||
- [Profiler: Profiling your model][5]
|
||||
- Jtune: Tool for performance tuning
|
||||
* 进阶
|
||||
* [自定义算子:使用C++和CUDA编写您的算子,并对其进行即时编译][4]
|
||||
* [性能分析器:分析您的模型][5]
|
||||
* Jtune:性能调优工具
|
||||
|
||||
|
||||
|
||||
[1]: notebooks/example.md "example"
|
||||
[2]: notebooks/basics.md "basics"
|
||||
[3]: notebooks/meta_op.md "meta_op"
|
||||
[4]: notebooks/custom_op.md "custom_op"
|
||||
[5]: notebooks/profiler.md "profiler"
|
||||
[1]: notebooks/example.md "示例"
|
||||
[2]: notebooks/basics.md "基本概念"
|
||||
[3]: notebooks/meta_op.md "元算子"
|
||||
[4]: notebooks/custom_op.md "自定义算子"
|
||||
[5]: notebooks/profiler.md "性能分析器"
|
||||
|
||||
Those notebooks can be started in your own computer by `python3.7 -m jittor.notebook`
|
||||
|
||||
这些notebooks可以通过python3.7 -m jittor.notebook在您自己的计算机中运行。
|
||||
|
||||
## Contributing
|
||||
|
||||
## 贡献
|
||||
|
||||
Jittor is still young. It may contain bugs and issues. Please report them in our bug track system. Contributions are welcome. Besides, if you have any ideas about Jittor, please let us know.
|
||||
|
||||
Jittor还很年轻。 它可能存在错误和问题。 请在我们的错误跟踪系统中报告它们。 我们欢迎您为Jittor做出贡献。 此外,如果您对Jittor有任何想法,请告诉我们。
|
||||
|
||||
## The Team
|
||||
|
||||
## 团队
|
||||
|
||||
Jittor is currently maintained by Dun Liang, Guo-Ye Yang, Guo-Wei Yang and Wen-Yang Zhou etc. from the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
|
||||
|
||||
Jittor目前由来自[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)的梁盾,杨国烨,杨国炜和周文洋等博士生维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们!
|
||||
|
||||
## License
|
||||
|
||||
## 版权声明
|
||||
|
||||
Jittor is Apache 2.0 licensed, as found in the LICENSE.txt file.
|
||||
|
||||
如LICENSE.txt文件中所示,Jittor使用Apache 2.0版权协议。
|
|
@ -0,0 +1,253 @@
|
|||
/******************************************************************************
|
||||
* Copyright (c) 2011, Duane Merrill. All rights reserved.
|
||||
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
/******************************************************************************
|
||||
* Simple example of DeviceRadixSort::SortPairs().
|
||||
*
|
||||
* Sorts an array of float keys paired with a corresponding array of int values.
|
||||
*
|
||||
* To compile using the command line:
|
||||
* nvcc -arch=sm_XX example_device_radix_sort.cu -I../.. -lcudart -O3
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
// Ensure printing of CUDA runtime errors to console
|
||||
#define CUB_STDERR
|
||||
|
||||
#include <stdio.h>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include <cub/util_allocator.cuh>
|
||||
#include <cub/device/device_radix_sort.cuh>
|
||||
|
||||
#include <test/test_util.h>
|
||||
|
||||
using namespace cub;
|
||||
|
||||
|
||||
//---------------------------------------------------------------------
|
||||
// Globals, constants and typedefs
|
||||
//---------------------------------------------------------------------
|
||||
|
||||
bool g_verbose = false; // Whether to display input/output to console
|
||||
|
||||
|
||||
//---------------------------------------------------------------------
|
||||
// Test generation
|
||||
//---------------------------------------------------------------------
|
||||
|
||||
/**
 * Simple key-value pairing for floating point types. Distinguishes
 * between positive and negative zero.
 */
struct Pair
{
    float key;
    int value;

    // Order primarily by numeric key value; when keys compare equal
    // (including -0.0f vs +0.0f), negative zero sorts first.
    bool operator<(const Pair &b) const
    {
        if (key < b.key)
            return true;
        if (key > b.key)
            return false;

        // Keys compare numerically equal (or are unordered): inspect the
        // IEEE-754 sign bit so that -0.0f is ordered before +0.0f.
        const unsigned int SIGN_BIT = 1u << 31;
        unsigned int lhs_bits = *reinterpret_cast<unsigned*>(const_cast<float*>(&key));
        unsigned int rhs_bits = *reinterpret_cast<unsigned*>(const_cast<float*>(&b.key));

        return (lhs_bits & SIGN_BIT) != 0 && (rhs_bits & SIGN_BIT) == 0;
    }
};
|
||||
|
||||
|
||||
/**
|
||||
* Initialize key-value sorting problem.
|
||||
*/
|
||||
void Initialize(
|
||||
float *h_keys,
|
||||
int *h_values,
|
||||
float *h_reference_keys,
|
||||
int *h_reference_values,
|
||||
int num_items)
|
||||
{
|
||||
Pair *h_pairs = new Pair[num_items];
|
||||
|
||||
for (int i = 0; i < num_items; ++i)
|
||||
{
|
||||
RandomBits(h_keys[i]);
|
||||
h_values[i] = i;
|
||||
h_pairs[i].key = h_keys[i];
|
||||
h_pairs[i].value = h_values[i];
|
||||
}
|
||||
|
||||
if (g_verbose)
|
||||
{
|
||||
printf("Input keys:\n");
|
||||
DisplayResults(h_keys, num_items);
|
||||
printf("\n\n");
|
||||
|
||||
printf("Input values:\n");
|
||||
DisplayResults(h_values, num_items);
|
||||
printf("\n\n");
|
||||
}
|
||||
|
||||
std::stable_sort(h_pairs, h_pairs + num_items);
|
||||
|
||||
for (int i = 0; i < num_items; ++i)
|
||||
{
|
||||
h_reference_keys[i] = h_pairs[i].key;
|
||||
h_reference_values[i] = h_pairs[i].value;
|
||||
}
|
||||
|
||||
if (g_verbose)
|
||||
{
|
||||
printf("std Output keys:\n");
|
||||
DisplayResults(h_reference_keys, num_items);
|
||||
printf("\n\n");
|
||||
|
||||
printf("std Output values:\n");
|
||||
DisplayResults(h_reference_values, num_items);
|
||||
printf("\n\n");
|
||||
}
|
||||
delete[] h_pairs;
|
||||
}
|
||||
|
||||
|
||||
//---------------------------------------------------------------------
|
||||
// Main
|
||||
//---------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Main
|
||||
*/
|
||||
/**
 * Entry point for the CUB DeviceRadixSort::SortPairs smoke test.
 * Sorts float keys with int values on the GPU and checks the result
 * against the host-side reference produced by Initialize().
 * Returns 0 on success; calls exit(1) on mismatch.
 * Recognized flags: --n=<items>, --device=<id>, --v (verbose), --help.
 */
int cub_test_entry(int argc, char** argv)
{
    CachingDeviceAllocator g_allocator(true);  // Caching allocator for device memory

    int num_items = 150;

    // Initialize command line
    CommandLineArgs args(argc, argv);
    g_verbose = args.CheckCmdLineFlag("v");
    args.GetCmdLineArgument("n", num_items);

    // Print usage
    if (args.CheckCmdLineFlag("help"))
    {
        printf("%s "
            "[--n=<input items> "
            "[--device=<device-id>] "
            "[--v] "
            "\n", argv[0]);
        exit(0);
    }

    // Initialize device
    CubDebugExit(args.DeviceInit());

    printf("cub::DeviceRadixSort::SortPairs() %d items (%d-byte keys %d-byte values)\n",
        num_items, int(sizeof(float)), int(sizeof(int)));
    fflush(stdout);

    // Allocate host arrays
    float *h_keys = new float[num_items];
    float *h_reference_keys = new float[num_items];
    int *h_values = new int[num_items];
    int *h_reference_values = new int[num_items];

    // Initialize problem and solution on host
    Initialize(h_keys, h_values, h_reference_keys, h_reference_values, num_items);

    // Allocate device arrays. DoubleBuffer: cub ping-pongs between the
    // two buffers across radix passes; `selector` tracks the active one.
    DoubleBuffer<float> d_keys;
    DoubleBuffer<int> d_values;
    CubDebugExit(g_allocator.DeviceAllocate((void**)&d_keys.d_buffers[0], sizeof(float) * num_items));
    CubDebugExit(g_allocator.DeviceAllocate((void**)&d_keys.d_buffers[1], sizeof(float) * num_items));
    CubDebugExit(g_allocator.DeviceAllocate((void**)&d_values.d_buffers[0], sizeof(int) * num_items));
    CubDebugExit(g_allocator.DeviceAllocate((void**)&d_values.d_buffers[1], sizeof(int) * num_items));

    // Allocate temporary storage: the first SortPairs call with a null
    // storage pointer only computes the required temp_storage_bytes.
    size_t temp_storage_bytes = 0;
    void *d_temp_storage = NULL;

    CubDebugExit(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items));
    CubDebugExit(g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));

    // Initialize device arrays
    CubDebugExit(cudaMemcpy(d_keys.d_buffers[d_keys.selector], h_keys, sizeof(float) * num_items, cudaMemcpyHostToDevice));
    CubDebugExit(cudaMemcpy(d_values.d_buffers[d_values.selector], h_values, sizeof(int) * num_items, cudaMemcpyHostToDevice));

    // Run the actual device sort
    CubDebugExit(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items));

    // Check for correctness (and display results, if specified)
    std::unique_ptr<float[]> d_keys_ptr(new float[num_items]);
    std::unique_ptr<int[]> d_values_ptr(new int[num_items]);
    std::unique_ptr<float[]> origin(new float[num_items]);
    cudaMemcpy(d_keys_ptr.get(), d_keys.Current(), sizeof(float) * num_items, cudaMemcpyDeviceToHost);
    cudaMemcpy(d_values_ptr.get(), d_values.Current(), sizeof(int) * num_items, cudaMemcpyDeviceToHost);
    bool ok = true;
    // `origin` maps an original input index back to its key, so equal keys
    // permuted differently by the two sorts still compare as correct.
    for (int i=0; i<num_items; i++) {
        origin[h_reference_values[i]] = h_reference_keys[i];
    }
    for (int i=0; i<num_items; i++) {
        AssertEquals(h_reference_keys[i], d_keys_ptr[i]);
        auto h = origin[h_reference_values[i]];
        auto d = origin[d_values_ptr[i]];
        if (h != d) {
            printf("aa %d\n", h_reference_values[i]);
            printf("aa %d\n", d_values_ptr[i]);
            printf("bb %f %f %f\n", h, d, d_keys_ptr[i]);
            ok = false;
        }
    }
    if (!ok) exit(1);

    // Cleanup
    if (h_keys) delete[] h_keys;
    if (h_reference_keys) delete[] h_reference_keys;
    if (h_values) delete[] h_values;
    if (h_reference_values) delete[] h_reference_values;

    if (d_keys.d_buffers[0]) CubDebugExit(g_allocator.DeviceFree(d_keys.d_buffers[0]));
    if (d_keys.d_buffers[1]) CubDebugExit(g_allocator.DeviceFree(d_keys.d_buffers[1]));
    if (d_values.d_buffers[0]) CubDebugExit(g_allocator.DeviceFree(d_values.d_buffers[0]));
    if (d_values.d_buffers[1]) CubDebugExit(g_allocator.DeviceFree(d_values.d_buffers[1]));
    if (d_temp_storage) CubDebugExit(g_allocator.DeviceFree(d_temp_storage));

    printf("\n\n");

    return 0;
}
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <algorithm>
|
||||
#include "var.h"
|
||||
#include "cub_arg_reduce_op.h"
|
||||
#include <vector>
|
||||
#include "executor.h"
|
||||
#include "ops/arg_reduce_op.h"
|
||||
#ifdef JIT_cuda
|
||||
#include <cub/device/device_segmented_radix_sort.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
// GPU-only segmented arg-reduce (argmin/argmax along the last axis).
// x: input values; offsets: int32 segment boundaries (n+1 entries for n
// segments); op: "min" or "max"; keepdims: keep a trailing axis of size 1.
// Outputs: y (int32 index of the extremum within each segment) and
// y_key (the extremum value itself, same dtype as x).
CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, string op, bool keepdims)
    : x(x), offsets(offsets), op(op), keepdims(keepdims) {
    // cub only provides a CUDA implementation.
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
    ASSERT(offsets->dtype()==ns_int || offsets->dtype()==ns_int32);
    y = create_output(nullptr, ns_int32);
    y_key = create_output(nullptr, x->dtype());
}
|
||||
|
||||
VarPtr CubArgReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    // Delegate to the generic arg_reduce gradient, reducing over the
    // last axis and reusing the stored index output y.
    auto last_dim = v->shape.size()-1;
    return ArgReduceOp::get_grad(out, dout, v, v_index, last_dim, y);
}
|
||||
|
||||
void CubArgReduceOp::infer_shape() {
|
||||
int n = 1;
|
||||
for (int i = 0; i < x->shape.size(); ++i) {
|
||||
if (i < x->shape.size() - 1) {
|
||||
n *= x->shape[i];
|
||||
}
|
||||
}
|
||||
ASSERT(offsets->shape.size() == 1);
|
||||
ASSERT(offsets->shape[0] == n + 1);
|
||||
NanoVector shape;
|
||||
for (int i = 0; i < x->shape.size() - 1; ++i) {
|
||||
shape.push_back(x->shape[i]);
|
||||
}
|
||||
if (keepdims) {
|
||||
shape.push_back(1);
|
||||
}
|
||||
y->set_shape(shape);
|
||||
y_key->set_shape(shape);
|
||||
}
|
||||
|
||||
// Register JIT specialization keys: input dtype, offsets dtype, and
// which cub segmented reduction to instantiate (ArgMin or ArgMax).
void CubArgReduceOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("Toffsets", offsets->dtype());
    add_jit_define("FUNC", op=="min" ? "ArgMin" : "ArgMax");
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
// Unpack cub's KeyValuePair results into two flat arrays.
// For ArgMin/ArgMax, pair.key is the winning index within the segment and
// pair.value is the winning value — so indices go to `val` (the y output)
// and values go to `key` (the y_key output).
// Grid-stride loop: each thread handles elements tid, tid+tnum, ...
static __global__ void split(cub::KeyValuePair<int, Tx>* a, Tx* key, int* val, int n) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int tnum = blockDim.x * gridDim.x;
    for (int i=tid; i<n; i+=tnum) {
        val[i] = a[i].key;
        key[i] = a[i].value;
    }
}
|
||||
|
||||
// Execute the segmented argmin/argmax on the GPU.
// @FUNC@@ is substituted by jit_prepare with ArgMin or ArgMax.
void CubArgReduceOp::jit_run() {
    auto* __restrict__ xp = x->ptr<Tx>();
    auto* __restrict__ offsetsp = offsets->ptr<Toffsets>();

    // One segment per "row" of the flattened leading dimensions.
    int num_segments = 1;
    for (int i = 0; i < x->shape.size() - 1; ++i) {
        num_segments *= x->shape[i];
    }
    // Device buffer holding cub's (index, value) result pair per segment.
    size_t allocation_dout;
    cub::KeyValuePair<int, Tx> *d_out = (cub::KeyValuePair<int, Tx> *)exe.allocator->alloc(sizeof(cub::KeyValuePair<int, Tx>) * num_segments, allocation_dout);

    // Determine temporary device storage requirements: the first call
    // with a null storage pointer only computes temp_storage_bytes.
    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceSegmentedReduce::@FUNC@@(d_temp_storage, temp_storage_bytes,
        xp, d_out, num_segments, offsetsp, offsetsp + 1);
    // Allocate temporary storage
    size_t allocation;
    d_temp_storage = exe.allocator->alloc(temp_storage_bytes, allocation);
    // Run the segmented reduction for real
    cub::DeviceSegmentedReduce::@FUNC@@(d_temp_storage, temp_storage_bytes,
        xp, d_out, num_segments, offsetsp, offsetsp + 1);

    // Scatter the result pairs into the two outputs
    // (indices -> y, extremum values -> y_key).
    auto* __restrict__ yp = y->ptr<int>();
    auto* __restrict__ y_keyp = y_key->ptr<Tx>();
    split<<<max(1,num_segments/1024),1024>>>(d_out, y_keyp, yp, num_segments);

    exe.allocator->free(d_temp_storage, temp_storage_bytes, allocation);
    exe.allocator->free(d_out, sizeof(cub::KeyValuePair<int, Tx>) * num_segments, allocation_dout);
}
|
||||
#endif // JIT_cuda
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,28 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Segmented argmin/argmax on GPU via cub::DeviceSegmentedReduce.
// Produces two outputs: y (int32 index of the extremum in each segment)
// and y_key (the extremum value itself).
struct CubArgReduceOp : Op {
    Var* x, * offsets, * y, * y_key;  // input, segment offsets, index / value outputs
    string op;      // "min" or "max"
    bool keepdims;  // keep a trailing axis of size 1 in the outputs
    // @attrs(multiple_outputs)
    CubArgReduceOp(Var* x, Var* offsets, string op, bool keepdims);
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;

    const char* name() const override { return "cub_arg_reduce"; }
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,93 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <algorithm>
|
||||
#include "var.h"
|
||||
#include "cub_argsort_op.h"
|
||||
#include <vector>
|
||||
#include "executor.h"
|
||||
#include "ops/argsort_op.h"
|
||||
#ifdef JIT_cuda
|
||||
#include <cub/device/device_segmented_radix_sort.cuh>
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
// GPU-only segmented argsort along the last axis via cub radix sort.
// x: keys to sort; indexes: payload carried through the sort (presumably
// 0..len-1 per segment — supplied by the caller); offsets: int32 segment
// boundaries (n+1 entries); descending: sort order; dtype: dtype of y.
// Outputs: y (permuted payload) and y_key (sorted keys).
CubArgsortOp::CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending, NanoString dtype)
    : x(x), indexes(indexes), offsets(offsets), descending(descending) {
    // cub only provides a CUDA implementation.
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
    ASSERT(offsets->dtype()==ns_int || offsets->dtype()==ns_int32);
    y = create_output(nullptr, dtype);
    y_key = create_output(nullptr, x->dtype());
}
|
||||
|
||||
VarPtr CubArgsortOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    // Delegate to the generic argsort gradient over the last axis,
    // reusing the stored permutation output y.
    auto last_dim = v->shape.size()-1;
    return ArgsortOp::get_grad(out, dout, v, v_index, last_dim, y);
}
|
||||
|
||||
void CubArgsortOp::infer_shape() {
    // Keys and payload must have identical shapes; offsets must contain
    // one boundary per segment plus a trailing end marker (n+1 total).
    ASSERT(x->shape.size() == indexes->shape.size());
    int n = 1;
    for (int i = 0; i < x->shape.size(); ++i) {
        ASSERT(x->shape[i] == indexes->shape[i]);
        if (i + 1 < x->shape.size())
            n *= x->shape[i];
    }
    ASSERT(offsets->shape.size() == 1);
    ASSERT(offsets->shape[0] == n + 1);
    // Sorting only permutes within segments, so shapes are unchanged.
    y->set_shape(x->shape);
    y_key->set_shape(x->shape);
}
|
||||
|
||||
// Register JIT specialization keys: the four dtypes involved and which
// cub radix-sort variant to instantiate (ascending or descending).
void CubArgsortOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("Tindexes", indexes->dtype());
    add_jit_define("Toffsets", offsets->dtype());
    add_jit_define("Ty", y->dtype());
    add_jit_define("FUNC", descending ? "SortPairsDescending" : "SortPairs");
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
// Execute the segmented sort on the GPU.
// @FUNC@@ is substituted by jit_prepare with SortPairs[Descending].
void CubArgsortOp::jit_run() {
    auto* __restrict__ xp = x->ptr<Tx>();
    auto* __restrict__ indexesp = indexes->ptr<Tindexes>();
    auto* __restrict__ offsetsp = offsets->ptr<Toffsets>();

    // num_items: total element count; num_segments: product of the
    // leading dims (each last-axis row is an independently sorted segment).
    int num_items = 1, num_segments = 1;
    for (int i = 0; i < x->shape.size(); ++i) {
        num_items *= x->shape[i];
        if (i < x->shape.size() - 1) {
            num_segments *= x->shape[i];
        }
    }
    auto* __restrict__ yp = y->ptr<Ty>();
    auto* __restrict__ y_keyp = y_key->ptr<Tx>();

    // Determine temporary device storage requirements: the first call
    // with a null storage pointer only computes temp_storage_bytes.
    void *d_temp_storage = NULL;
    size_t temp_storage_bytes = 0;
    cub::DeviceSegmentedRadixSort::@FUNC@@(d_temp_storage, temp_storage_bytes,
        xp, y_keyp, indexesp, yp,
        num_items, num_segments, offsetsp, offsetsp + 1);
    // Allocate temporary storage
    size_t allocation;
    d_temp_storage = exe.allocator->alloc(temp_storage_bytes, allocation);
    // Run sorting operation: keys xp -> y_keyp, payload indexesp -> yp.
    cub::DeviceSegmentedRadixSort::@FUNC@@(d_temp_storage, temp_storage_bytes,
        xp, y_keyp, indexesp, yp,
        num_items, num_segments, offsetsp, offsetsp + 1);
    exe.allocator->free(d_temp_storage, temp_storage_bytes, allocation);
}
|
||||
#endif // JIT_cuda
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,27 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Segmented argsort on GPU via cub::DeviceSegmentedRadixSort.
// Sorts x along the last axis and carries `indexes` as the payload,
// producing y (permuted payload) and y_key (sorted keys).
struct CubArgsortOp : Op {
    Var* x, * indexes, * offsets, * y, * y_key;  // keys, payload, boundaries, outputs
    bool descending;  // sort order within each segment
    // @attrs(multiple_outputs)
    CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending=false, NanoString dtype=ns_int32);
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;

    const char* name() const override { return "cub_argsort"; }
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,43 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <random>
|
||||
|
||||
#include "var.h"
|
||||
#include "cub_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
#ifdef JIT
|
||||
#include "cub_test.h"
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
// GPU-only op that runs the CUB radix-sort self test; `cmd` is a
// space-separated command line forwarded to cub_test_entry.
CubTestOp::CubTestOp(string cmd) : cmd(cmd) {
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
    // Single-element float output, used only to signal completion.
    output = create_output(1, ns_float32);
}
|
||||
|
||||
// Only one specialization is needed; T fixes the output dtype.
void CubTestOp::jit_prepare() {
    add_jit_define("T", ns_float32);
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
void CubTestOp::jit_run() {
|
||||
auto args = split(cmd, " ");
|
||||
if (!cmd.size()) args.clear();
|
||||
vector<char*> v(args.size());
|
||||
for (uint i=0; i<args.size(); i++)
|
||||
v[i] = &args[i][0];
|
||||
ASSERT(cub_test_entry(v.size(), &v[0])==0);
|
||||
output->ptr<T>()[0] = 123;
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,21 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Wrapper op exposing the CUB self test (cub_test_entry) to jittor.
struct CubTestOp : Op {
    Var* output;  // dummy 1-element output, set to 123 on success
    string cmd;   // command line forwarded to the test

    CubTestOp(string cmd);

    const char* name() const override { return "cub_test"; }
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,21 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <cuda_runtime.h>
|
||||
#include <cublas_v2.h>
|
||||
|
||||
#include "helper_cuda.h"
|
||||
#include "fp16_emu.h"
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
extern cublasHandle_t cublas_handle;
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,82 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
||||
#include "var.h"
|
||||
#include "cublas_matmul_op.h"
|
||||
#include "cublas_warper.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
// 2-D matrix multiply c = a @ b on GPU via cuBLAS gemm, with optional
// transposes of either input applied first.
CublasMatmulOp::CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
    : a(a), b(b), trans_a(trans_a), trans_b(trans_b) {
    // TODO: support int8 * int8
    // Fixed message: this assert checks that both inputs are floating
    // point; the old text ("should be the same") described the next check.
    ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "inputs of cublas_matmul should be float";
    // TODO: support different input types
    ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "type of two inputs should be the same";
    c = create_output(nullptr, a->dtype());
}
|
||||
|
||||
void CublasMatmulOp::infer_shape() {
|
||||
ASSERTop(a->shape.size(),==,2);
|
||||
ASSERTop(b->shape.size(),==,2);
|
||||
int n = a->shape[0], m = a->shape[1];
|
||||
int m_ = b->shape[0], k = b->shape[1];
|
||||
if (trans_a) {
|
||||
swap(n, m);
|
||||
}
|
||||
if (trans_b) {
|
||||
swap(m_, k);
|
||||
}
|
||||
ASSERTop(m,==,m_);
|
||||
c->set_shape({n, k});
|
||||
}
|
||||
|
||||
// JIT keys: element type, transpose flags as cublas op chars ('T'/'N'),
// and the gemm precision prefix (S for 4-byte float, D otherwise).
void CublasMatmulOp::jit_prepare() {
    add_jit_define("T", a->dtype());
    add_jit_define("Trans_a", trans_a ? "T" : "N");
    add_jit_define("Trans_b", trans_b ? "T" : "N");
    add_jit_define("op", a->dtype().dsize() == 4 ? "S" : "D");
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
// Run c = op(a) * op(b) through cublas<S|D>gemm.
// @Trans_a/@Trans_b/@op are substituted by jit_prepare.
void CublasMatmulOp::jit_run() {
    cublasHandle_t& handle_ = cublas_handle;
    const T alpha = 1.0f;
    const T beta = 0.0f;

    const auto& as = a->shape;
    const auto& bs = b->shape;
    // Logical dims after transposes: a is [n,m], b is [m,k], c is [n,k].
    auto n = as[0];
    auto m = as[1];
    auto k = bs[1];
    if ('@Trans_a'=='T') {
        n = as[1];
        m = as[0];
    }
    if ('@Trans_b'=='T') {
        k = bs[0];
    }
    // a: [n,m], b: [m,k], c: [n,k]
    // cuBLAS is column-major while these buffers are row-major, so compute
    // c^T = b^T * a^T by passing (b, a) with swapped transpose flags; the
    // leading dimensions are each operand's row-major row stride.
    checkCudaErrors(cublas@op@@gemm(handle_,
    CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
    k, n, m, &alpha,
    b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
    a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
    c->ptr<T>(), k));
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,24 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// 2-D matrix multiply c = a @ b via cuBLAS gemm, with optional
// transposes applied to a and/or b before multiplication.
struct CublasMatmulOp : Op {
    Var* a, * b, * c;       // inputs a, b and output c
    bool trans_a, trans_b;  // transpose the corresponding input first
    CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b);

    const char* name() const override { return "cublas_matmul"; }
    void infer_shape() override;
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,33 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <random>
|
||||
|
||||
#include "var.h"
|
||||
#include "cublas_test_op.h"
|
||||
|
||||
int cublas_test_entry(int);
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
// Op wrapping the cuBLAS sgemm self test; size_mult is forwarded to
// cublas_test_entry to scale the test matrix dimensions.
CublasTestOp::CublasTestOp(int size_mult) : size_mult(size_mult) {
    // Single-element float output, used only to signal completion.
    output = create_output(1, ns_float32);
}
|
||||
|
||||
// Only one specialization is needed; T fixes the output dtype.
void CublasTestOp::jit_prepare() {
    add_jit_define("T", ns_float32);
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void CublasTestOp::jit_run() {
    // Run the cuBLAS matmul test; a nonzero return means failure.
    ASSERT(cublas_test_entry(size_mult)==0);
    // Signal success through the dummy output.
    output->ptr<T>()[0] = 123;
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,21 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Wrapper op exposing the cuBLAS matmul self test to jittor.
struct CublasTestOp : Op {
    Var* output;    // dummy 1-element output, set to 123 on success
    int size_mult;  // matrix-size multiplier forwarded to the test

    CublasTestOp(int size_mult);

    const char* name() const override { return "cublas_test"; }
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,352 @@
|
|||
////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Copyright 1993-2015 NVIDIA Corporation. All rights reserved.
|
||||
//
|
||||
// Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
// with this source code for terms and conditions that govern your use of
|
||||
// this software. Any use, reproduction, disclosure, or distribution of
|
||||
// this software and related documentation outside the terms of the EULA
|
||||
// is strictly prohibited.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// Matrix multiplication: C = A * B.
|
||||
// Host code.
|
||||
//
|
||||
// This sample implements matrix multiplication as described in Chapter 3
|
||||
// of the programming guide and uses the CUBLAS library to demonstrate
|
||||
// the best performance.
|
||||
|
||||
// SOME PRECAUTIONS:
|
||||
// IF WE WANT TO CALCULATE ROW-MAJOR MATRIX MULTIPLY C = A * B,
|
||||
// WE JUST NEED CALL CUBLAS API IN A REVERSE ORDER: cublasSegemm(B, A)!
|
||||
// The reason is explained as follows:
|
||||
|
||||
// CUBLAS library uses column-major storage, but C/C++ use row-major storage.
|
||||
// When passing the matrix pointer to CUBLAS, the memory layout alters from
|
||||
// row-major to column-major, which is equivalent to an implicit transpose.
|
||||
|
||||
// In the case of row-major C/C++ matrix A, B, and a simple matrix multiplication
|
||||
// C = A * B, we can't use the input order like cublasSgemm(A, B) because of
|
||||
// implicit transpose. The actual result of cublasSegemm(A, B) is A(T) * B(T).
|
||||
// If col(A(T)) != row(B(T)), equal to row(A) != col(B), A(T) and B(T) are not
|
||||
// multipliable. Moreover, even if A(T) and B(T) are multipliable, the result C
|
||||
// is a column-based cublas matrix, which means C(T) in C/C++, we need extra
|
||||
// transpose code to convert it to a row-based C/C++ matrix.
|
||||
|
||||
// To solve the problem, let's consider our desired result C, a row-major matrix.
|
||||
// In cublas format, it is C(T) actually (because of the implicit transpose).
|
||||
// C = A * B, so C(T) = (A * B) (T) = B(T) * A(T). Cublas matrice B(T) and A(T)
|
||||
// happen to be C/C++ matrice B and A (still because of the implicit transpose)!
|
||||
// We don't need extra transpose code, we only need alter the input order!
|
||||
//
|
||||
// CUBLAS provides high-performance matrix multiplication.
|
||||
// See also:
|
||||
// V. Volkov and J. Demmel, "Benchmarking GPUs to tune dense linear algebra,"
|
||||
// in Proc. 2008 ACM/IEEE Conf. on Supercomputing (SC '08),
|
||||
// Piscataway, NJ: IEEE Press, 2008, pp. Art. 31:1-11.
|
||||
//
|
||||
|
||||
// Utilities and system includes
|
||||
#include <assert.h>
|
||||
#include <helper_string.h> // helper for shared functions common to CUDA Samples
|
||||
|
||||
// CUDA runtime
|
||||
#include <cuda_runtime.h>
|
||||
#include <cublas_v2.h>
|
||||
|
||||
// CUDA and CUBLAS functions
|
||||
#include <helper_functions.h>
|
||||
#include <helper_cuda.h>
|
||||
|
||||
#ifndef min
|
||||
#define min(a,b) ((a < b) ? a : b)
|
||||
#endif
|
||||
#ifndef max
|
||||
#define max(a,b) ((a > b) ? a : b)
|
||||
#endif
|
||||
|
||||
// Width/Height of matrices A, B, C used by the cuBLAS test.
typedef struct _matrixSize // Optional Command-line multiplier for matrix sizes
{
    unsigned int uiWA, uiHA, uiWB, uiHB, uiWC, uiHC;
} sMatrixSize;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Compute reference data set matrix multiply on CPU
|
||||
//! C = A * B
|
||||
//! @param C reference data, computed but preallocated
|
||||
//! @param A matrix A as provided to device
|
||||
//! @param B matrix B as provided to device
|
||||
//! @param hA height of matrix A
|
||||
//! @param wB width of matrix B
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
//! Reference row-major GEMM on the CPU: C[hA x wB] = A[hA x wA] * B[wA x wB].
//! Accumulates each dot product in double before storing as float to
//! reduce rounding error.
//! @param C  output matrix, preallocated by the caller
//! @param A  left operand as provided to the device
//! @param B  right operand as provided to the device
//! @param hA height of A (rows of C)
//! @param wA width of A == height of B (inner dimension)
//! @param wB width of B (cols of C)
////////////////////////////////////////////////////////////////////////////////
void
matrixMulCPU(float *C, const float *A, const float *B, unsigned int hA, unsigned int wA, unsigned int wB)
{
    for (unsigned int row = 0; row < hA; ++row)
    {
        for (unsigned int col = 0; col < wB; ++col)
        {
            double acc = 0;
            for (unsigned int t = 0; t < wA; ++t)
                acc += (double)A[row * wA + t] * (double)B[t * wB + col];
            C[row * wB + col] = (float)acc;
        }
    }
}
|
||||
|
||||
// Allocates a matrix with random float entries.
|
||||
void randomInit(float *data, int size)
|
||||
{
|
||||
for (int i = 0; i < size; ++i)
|
||||
data[i] = rand() / (float)RAND_MAX;
|
||||
}
|
||||
|
||||
// Print up to iListLength element-wise differences between data1 and
// data2 (width x height, row-major) that exceed fListTol, followed by
// the total mismatch count.
void printDiff(float *data1, float *data2, int width, int height, int iListLength, float fListTol)
{
    printf("Listing first %d Differences > %.6f...\n", iListLength, fListTol);
    int error_count = 0;

    for (int row = 0; row < height; row++)
    {
        // Row headers are suppressed once the listing limit is reached,
        // but counting continues over the whole matrix.
        if (error_count < iListLength)
            printf("\n Row %d:\n", row);

        for (int col = 0; col < width; col++)
        {
            const int idx = row * width + col;
            const float fDiff = fabs(data1[idx] - data2[idx]);

            if (fDiff > fListTol)
            {
                if (error_count < iListLength)
                    printf(" Loc(%d,%d)\tCPU=%.5f\tGPU=%.5f\tDiff=%.6f\n", col, row, data1[idx], data2[idx], fDiff);
                error_count++;
            }
        }
    }

    printf(" \n Total Errors = %d\n", error_count);
}
|
||||
|
||||
// Select device 0, clamp iSizeMultiple to [1, 100], and fill matrix_size
// with block-aligned dimensions (multiples of 32) that form a valid
// product: A[4b x 3b] * B[3b x 2b] = C[4b x 2b], where b = 32 * mult.
// Exits the process on device-query failure or inconsistent sizes.
void initializeCUDA(int &devID, int &iSizeMultiple, sMatrixSize &matrix_size)
{
    // By default, we use device 0, otherwise we override the device ID based on what is provided at the command line
    cudaError_t error;
    devID = 0;

    // Keep the size multiplier in a sane range.
    iSizeMultiple = min(iSizeMultiple, 100);
    iSizeMultiple = max(iSizeMultiple, 1);

    cudaDeviceProp deviceProp;

    error = cudaGetDeviceProperties(&deviceProp, devID);

    if (error != cudaSuccess)
    {
        printf("cudaGetDeviceProperties returned error code %d, line(%d)\n", error, __LINE__);
        exit(EXIT_FAILURE);
    }

    printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", devID, deviceProp.name, deviceProp.major, deviceProp.minor);

    int block_size = 32;

    matrix_size.uiWA = 3 * block_size * iSizeMultiple;
    matrix_size.uiHA = 4 * block_size * iSizeMultiple;
    matrix_size.uiWB = 2 * block_size * iSizeMultiple;
    matrix_size.uiHB = 3 * block_size * iSizeMultiple;
    matrix_size.uiWC = 2 * block_size * iSizeMultiple;
    matrix_size.uiHC = 4 * block_size * iSizeMultiple;

    printf("MatrixA(%u,%u), MatrixB(%u,%u), MatrixC(%u,%u)\n",
           matrix_size.uiHA, matrix_size.uiWA,
           matrix_size.uiHB, matrix_size.uiWB,
           matrix_size.uiHC, matrix_size.uiWC);

    // Sanity check: inner dims must agree and C must match A's rows
    // and B's cols.
    if( matrix_size.uiWA != matrix_size.uiHB ||
        matrix_size.uiHA != matrix_size.uiHC ||
        matrix_size.uiWB != matrix_size.uiWC)
    {
        printf("ERROR: Matrix sizes do not match!\n");
        exit(-1);
    }
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Run a simple test matrix multiply using CUBLAS
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
int matrixMultiply(int devID, sMatrixSize &matrix_size)
|
||||
{
|
||||
cudaDeviceProp deviceProp;
|
||||
|
||||
checkCudaErrors(cudaGetDeviceProperties(&deviceProp, devID));
|
||||
|
||||
int block_size = 32;
|
||||
|
||||
// set seed for rand()
|
||||
srand(2006);
|
||||
|
||||
// allocate host memory for matrices A and B
|
||||
unsigned int size_A = matrix_size.uiWA * matrix_size.uiHA;
|
||||
unsigned int mem_size_A = sizeof(float) * size_A;
|
||||
float *h_A = (float *)malloc(mem_size_A);
|
||||
unsigned int size_B = matrix_size.uiWB * matrix_size.uiHB;
|
||||
unsigned int mem_size_B = sizeof(float) * size_B;
|
||||
float *h_B = (float *)malloc(mem_size_B);
|
||||
|
||||
// set seed for rand()
|
||||
srand(2006);
|
||||
|
||||
// initialize host memory
|
||||
randomInit(h_A, size_A);
|
||||
randomInit(h_B, size_B);
|
||||
|
||||
// allocate device memory
|
||||
float *d_A, *d_B, *d_C;
|
||||
unsigned int size_C = matrix_size.uiWC * matrix_size.uiHC;
|
||||
unsigned int mem_size_C = sizeof(float) * size_C;
|
||||
|
||||
// allocate host memory for the result
|
||||
float *h_C = (float *) malloc(mem_size_C);
|
||||
float *h_CUBLAS = (float *) malloc(mem_size_C);
|
||||
|
||||
checkCudaErrors(cudaMalloc((void **) &d_A, mem_size_A));
|
||||
checkCudaErrors(cudaMalloc((void **) &d_B, mem_size_B));
|
||||
checkCudaErrors(cudaMemcpy(d_A, h_A, mem_size_A, cudaMemcpyHostToDevice));
|
||||
checkCudaErrors(cudaMemcpy(d_B, h_B, mem_size_B, cudaMemcpyHostToDevice));
|
||||
checkCudaErrors(cudaMalloc((void **) &d_C, mem_size_C));
|
||||
|
||||
// setup execution parameters
|
||||
dim3 threads(block_size, block_size);
|
||||
dim3 grid(matrix_size.uiWC / threads.x, matrix_size.uiHC / threads.y);
|
||||
|
||||
// create and start timer
|
||||
printf("Computing result using CUBLAS...");
|
||||
|
||||
// execute the kernel
|
||||
int nIter = 30;
|
||||
|
||||
// CUBLAS version 2.0
|
||||
{
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
cublasHandle_t handle;
|
||||
cudaEvent_t start, stop;
|
||||
|
||||
checkCudaErrors(cublasCreate(&handle));
|
||||
|
||||
//Perform warmup operation with cublas
|
||||
checkCudaErrors(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, matrix_size.uiWB, matrix_size.uiHA, matrix_size.uiWA, &alpha, d_B, matrix_size.uiWB, d_A, matrix_size.uiWA, &beta, d_C, matrix_size.uiWB));
|
||||
|
||||
// Allocate CUDA events that we'll use for timing
|
||||
checkCudaErrors(cudaEventCreate(&start));
|
||||
checkCudaErrors(cudaEventCreate(&stop));
|
||||
|
||||
// Record the start event
|
||||
checkCudaErrors(cudaEventRecord(start, NULL));
|
||||
|
||||
for (int j = 0; j < nIter; j++)
|
||||
{
|
||||
//note cublas is column primary!
|
||||
//need to transpose the order
|
||||
checkCudaErrors(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, matrix_size.uiWB, matrix_size.uiHA, matrix_size.uiWA, &alpha, d_B, matrix_size.uiWB, d_A, matrix_size.uiWA, &beta, d_C, matrix_size.uiWB));
|
||||
|
||||
}
|
||||
|
||||
printf("done.\n");
|
||||
|
||||
// Record the stop event
|
||||
checkCudaErrors(cudaEventRecord(stop, NULL));
|
||||
|
||||
// Wait for the stop event to complete
|
||||
checkCudaErrors(cudaEventSynchronize(stop));
|
||||
|
||||
float msecTotal = 0.0f;
|
||||
checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));
|
||||
|
||||
// Compute and print the performance
|
||||
float msecPerMatrixMul = msecTotal / nIter;
|
||||
double flopsPerMatrixMul = 2.0 * (double)matrix_size.uiHC * (double)matrix_size.uiWC * (double)matrix_size.uiHB;
|
||||
double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
|
||||
printf(
|
||||
"Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n",
|
||||
gigaFlops,
|
||||
msecPerMatrixMul,
|
||||
flopsPerMatrixMul);
|
||||
|
||||
// copy result from device to host
|
||||
checkCudaErrors(cudaMemcpy(h_CUBLAS, d_C, mem_size_C, cudaMemcpyDeviceToHost));
|
||||
|
||||
// Destroy the handle
|
||||
checkCudaErrors(cublasDestroy(handle));
|
||||
}
|
||||
|
||||
// compute reference solution
|
||||
printf("Computing result using host CPU...");
|
||||
float *reference = (float *)malloc(mem_size_C);
|
||||
bool resCUBLAS = true;
|
||||
// only compare with cpu when size smaller than 1000
|
||||
if (matrix_size.uiHA < 1000) {
|
||||
matrixMulCPU(reference, h_A, h_B, matrix_size.uiHA, matrix_size.uiWA, matrix_size.uiWB);
|
||||
printf("done.\n");
|
||||
|
||||
// check result (CUBLAS)
|
||||
resCUBLAS = sdkCompareL2fe(reference, h_CUBLAS, size_C, 1.0e-6f);
|
||||
|
||||
if (resCUBLAS != true)
|
||||
{
|
||||
printDiff(reference, h_CUBLAS, matrix_size.uiWC, matrix_size.uiHC, 100, 1.0e-5f);
|
||||
}
|
||||
|
||||
printf("Comparing CUBLAS Matrix Multiply with CPU results: %s\n", (true == resCUBLAS) ? "PASS" : "FAIL");
|
||||
|
||||
printf("\nNOTE: The CUDA Samples are not meant for performance measurements. Results may vary when GPU Boost is enabled.\n");
|
||||
|
||||
}
|
||||
// clean up memory
|
||||
free(h_A);
|
||||
free(h_B);
|
||||
free(h_C);
|
||||
free(reference);
|
||||
checkCudaErrors(cudaFree(d_A));
|
||||
checkCudaErrors(cudaFree(d_B));
|
||||
checkCudaErrors(cudaFree(d_C));
|
||||
|
||||
if (resCUBLAS == true)
|
||||
{
|
||||
return EXIT_SUCCESS; // return value = 1
|
||||
}
|
||||
else
|
||||
{
|
||||
return EXIT_FAILURE; // return value = 0
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Program main
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
int cublas_test_entry(int sizeMult)
|
||||
{
|
||||
printf("[Matrix Multiply CUBLAS] - Starting...\n");
|
||||
|
||||
int devID = 0;
|
||||
sMatrixSize matrix_size;
|
||||
|
||||
initializeCUDA(devID, sizeMult, matrix_size);
|
||||
|
||||
int matrix_result = matrixMultiply(devID, matrix_size);
|
||||
|
||||
return matrix_result;
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "cublas_warper.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
cublasHandle_t cublas_handle;
|
||||
|
||||
// RAII-style guard whose single static instance ("init" below) creates the
// process-wide cublas_handle at library load and destroys it at unload.
struct cublas_initer {

// Create the shared cuBLAS handle once, before any op uses it.
inline cublas_initer() {
    checkCudaErrors(cublasCreate(&cublas_handle));
    LOGv << "cublasCreate finished";
}

// Release the shared handle during static teardown.
inline ~cublas_initer() {
    checkCudaErrors(cublasDestroy(cublas_handle));
    LOGv << "cublasDestroy finished";
}

} init;
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 1993-2017 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
*
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// These are CUDA Helper functions for initialization and error checking
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cublas.h>
|
||||
#include <helper_cuda.h>
|
||||
|
||||
#ifdef CUBLAS_API_H_
|
||||
// cuBLAS API errors
|
||||
// Translate a cublasStatus_t value into its symbolic enum name.
// Unrecognized codes yield "<unknown>".
const char *_cudaGetErrorEnum(cublasStatus_t error) {
    static const struct { cublasStatus_t code; const char *name; } names[] = {
        {CUBLAS_STATUS_SUCCESS,          "CUBLAS_STATUS_SUCCESS"},
        {CUBLAS_STATUS_NOT_INITIALIZED,  "CUBLAS_STATUS_NOT_INITIALIZED"},
        {CUBLAS_STATUS_ALLOC_FAILED,     "CUBLAS_STATUS_ALLOC_FAILED"},
        {CUBLAS_STATUS_INVALID_VALUE,    "CUBLAS_STATUS_INVALID_VALUE"},
        {CUBLAS_STATUS_ARCH_MISMATCH,    "CUBLAS_STATUS_ARCH_MISMATCH"},
        {CUBLAS_STATUS_MAPPING_ERROR,    "CUBLAS_STATUS_MAPPING_ERROR"},
        {CUBLAS_STATUS_EXECUTION_FAILED, "CUBLAS_STATUS_EXECUTION_FAILED"},
        {CUBLAS_STATUS_INTERNAL_ERROR,   "CUBLAS_STATUS_INTERNAL_ERROR"},
        {CUBLAS_STATUS_NOT_SUPPORTED,    "CUBLAS_STATUS_NOT_SUPPORTED"},
        {CUBLAS_STATUS_LICENSE_ERROR,    "CUBLAS_STATUS_LICENSE_ERROR"},
    };
    for (const auto &entry : names)
        if (entry.code == error)
            return entry.name;
    return "<unknown>";
}
|
||||
#endif
|
|
@ -0,0 +1,19 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <cuda_runtime.h>
|
||||
#include <cudnn.h>
|
||||
|
||||
#include "helper_cuda.h"
|
||||
#include "fp16_emu.h"
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {

// Shared cuDNN handle; presumably created/destroyed by a static initializer
// in the matching warper .cc (mirrors cublas_handle) — confirm there.
extern cudnnHandle_t cudnn_handle;
// Upper bound on entries in each conv-algorithm cache; once a cache reaches
// this size, new shapes fall back to heuristic algorithm selection.
constexpr int max_cache_size=100;

} // jittor
|
|
@ -0,0 +1,272 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cudnn_conv_backward_w_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
#include "executor.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
// Return the axis index of layout character c within a 4-char format string,
// e.g. findc("abcd", 'c') == 2. Characters absent from the format trip the
// ASSERT on the last position.
static inline int findc(const string& format, const char& c) {
    for (int i=0; i<3; i++)
        if (c==format[i]) return i;
    ASSERT(c==format[3]) << "Not a valid format" << format << c;
    return 3;
}
|
||||
|
||||
#ifndef JIT
|
||||
// Read the four extents of x in the order requested by f, translating each
// layout character through the variable's own format string.
static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) {
    int* out[4] = {&a, &b, &c, &d};
    for (int i=0; i<4; i++)
        *out[i] = x->shape[findc(format, f[i])];
}
|
||||
|
||||
// Scatter the four extents (given in f's order) into the axis positions
// dictated by the variable's format string, then commit the new shape.
static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) {
    int64 shape[4];
    const int dims[4] = {a, b, c, d};
    for (int i=0; i<4; i++)
        shape[findc(format, f[i])] = dims[i];
    x->set_shape(NanoVector(
        shape[0], shape[1], shape[2], shape[3]));
}
|
||||
|
||||
// Weight-gradient op for conv2d: given forward input x and output gradient
// dy, produces dw using the same hyper-parameters as the forward pass.
CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
    : x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    // cuDNN-only op: CUDA flag set, no CPU fallback.
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);
    // Output dtype follows the usual promotion of dy's and x's dtypes.
    dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
}
|
||||
|
||||
// Derive dw's shape: out-channels = dy's channels, in-channels = x's
// channels, spatial extent = kernel_size x kernel_size.
void CudnnConvBackwardWOp::infer_shape() {
    // Both inputs must be 4-D tensors (batch/channel/height/width in some layout).
    ASSERTop(x->shape.size(),==,4);
    ASSERTop(dy->shape.size(),==,4);
    int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
    get_shape(x, "abcd", xformat, xn, xc, xh, xw);
    get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
    wco = yc, wci = xc;
    wh = kernel_size;
    ww = kernel_size;
    set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
}
|
||||
|
||||
// Register dtypes and layout strings as JIT defines so the @-macros in
// jit_run are specialized per (Tx, Ty, Tw, format) combination.
void CudnnConvBackwardWOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("Ty", dy->dtype());
    add_jit_define("Tw", dw->dtype());
    add_jit_define("XFORMAT", xformat);
    add_jit_define("WFORMAT", wformat);
    add_jit_define("YFORMAT", yformat);
}
|
||||
unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
|
||||
extern unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
void CudnnConvBackwardWOp::jit_run() {
|
||||
auto w = dw;
|
||||
auto y = dy;
|
||||
cudnnHandle_t& handle_ = cudnn_handle;
|
||||
|
||||
cudnnTensorDescriptor_t cudnnIdesc;
|
||||
cudnnFilterDescriptor_t cudnnFdesc;
|
||||
cudnnTensorDescriptor_t cudnnOdesc;
|
||||
cudnnConvolutionDescriptor_t cudnnConvDesc;
|
||||
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
|
||||
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
|
||||
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
|
||||
|
||||
int dimX[] = {
|
||||
(int)x->shape[findc("@XFORMAT", 'a')], // n
|
||||
(int)x->shape[findc("@XFORMAT", 'b')], // c
|
||||
(int)x->shape[findc("@XFORMAT", 'c')], // h
|
||||
(int)x->shape[findc("@XFORMAT", 'd')], // w
|
||||
};
|
||||
int _strideX[] = {0,0,0,1};
|
||||
for (int i=2; i>=0; i--) _strideX[i] = _strideX[i+1] * x->shape[i+1];
|
||||
int strideX[] = {
|
||||
_strideX[findc("@XFORMAT", 'a')], // n
|
||||
_strideX[findc("@XFORMAT", 'b')], // c
|
||||
_strideX[findc("@XFORMAT", 'c')], // h
|
||||
_strideX[findc("@XFORMAT", 'd')], // w
|
||||
};
|
||||
// dimX: nchw
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||
cudnnIdesc, getDataType<Tx>(),
|
||||
4, dimX, strideX
|
||||
));
|
||||
|
||||
auto ws = w->shape;
|
||||
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3]};
|
||||
// cudnn only support this two format
|
||||
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
|
||||
#define filterFormat_oihw CUDNN_TENSOR_NCHW
|
||||
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
|
||||
|
||||
// dimW: KCRS(oihw)
|
||||
checkCudaErrors(cudnnSetFilterNdDescriptor(
|
||||
cudnnFdesc, getDataType<Tw>(),
|
||||
filterFormat_@WFORMAT, 4, dimW
|
||||
));
|
||||
|
||||
int padA[] = {padding, padding};
|
||||
int convstrideA[] = {stride, stride};
|
||||
int dilationA[] = {dilation, dilation};
|
||||
// difference between
|
||||
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
|
||||
// is the kernel rc order
|
||||
// currently, No perf difference is observed between
|
||||
// this two mode
|
||||
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
|
||||
cudnnConvDesc, /*convDim=*/2,
|
||||
padA, convstrideA, dilationA,
|
||||
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
|
||||
));
|
||||
|
||||
// using tensor core
|
||||
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
|
||||
|
||||
int dimY[] = {
|
||||
(int)y->shape[findc("@YFORMAT", 'a')], // n
|
||||
(int)y->shape[findc("@YFORMAT", 'b')], // c
|
||||
(int)y->shape[findc("@YFORMAT", 'c')], // h
|
||||
(int)y->shape[findc("@YFORMAT", 'd')], // w
|
||||
};
|
||||
int _strideY[] = {0,0,0,1};
|
||||
for (int i=2; i>=0; i--) _strideY[i] = _strideY[i+1] * y->shape[i+1];
|
||||
int strideY[] = {
|
||||
_strideY[findc("@YFORMAT", 'a')], // n
|
||||
_strideY[findc("@YFORMAT", 'b')], // c
|
||||
_strideY[findc("@YFORMAT", 'c')], // h
|
||||
_strideY[findc("@YFORMAT", 'd')], // w
|
||||
};
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||
cudnnOdesc, getDataType<Ty>(),
|
||||
4, dimY, strideY
|
||||
));
|
||||
|
||||
cudnnConvolutionBwdFilterAlgo_t algos[] = {
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
|
||||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos];
|
||||
cudnnConvolutionBwdFilterAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
jk.clear();
|
||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
|
||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
|
||||
jk << padding << "," <<stride << "," << dilation << ".";
|
||||
auto iter = bwdw_algo_cache.find(jk.to_string());
|
||||
|
||||
if (iter!=bwdw_algo_cache.end()) algo = iter->second;
|
||||
else {
|
||||
if (bwdw_algo_cache.size()>=max_cache_size) benchmark = false;
|
||||
if (benchmark) {
|
||||
size_t max_ws_size = 0;
|
||||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz);
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
void* ws = exe.allocator->alloc(max_ws_size, allocation);
|
||||
checkCudaErrors(cudnnFindConvolutionBackwardFilterAlgorithmEx(
|
||||
handle_,
|
||||
cudnnIdesc, x->ptr<Tx>(),
|
||||
cudnnOdesc, y->ptr<Ty>(),
|
||||
cudnnConvDesc,
|
||||
cudnnFdesc, w->ptr<Tw>(),
|
||||
num_algos,
|
||||
&perf_count,
|
||||
perf_results,
|
||||
ws,
|
||||
max_ws_size));
|
||||
exe.allocator->free(ws, max_ws_size, allocation);
|
||||
} else {
|
||||
checkCudaErrors(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
|
||||
handle_,
|
||||
cudnnIdesc,
|
||||
cudnnOdesc,
|
||||
cudnnConvDesc,
|
||||
cudnnFdesc,
|
||||
num_algos,
|
||||
&perf_count,
|
||||
perf_results));
|
||||
}
|
||||
int best_algo_idx=-1;
|
||||
for (int i = 0; i < perf_count; i++)
|
||||
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
|
||||
best_algo_idx=i;
|
||||
break;
|
||||
}
|
||||
ASSERT(best_algo_idx!=-1);
|
||||
algo=perf_results[best_algo_idx].algo;
|
||||
if (benchmark) {
|
||||
bwdw_algo_cache[jk.to_string()] = algo;
|
||||
if (bwdw_algo_cache.size()==max_cache_size)
|
||||
LOGw << "backward w algorithm cache is full";
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: warp work space
|
||||
void *workSpace = 0;
|
||||
size_t workSpaceSize;
|
||||
checkCudaErrors (cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
||||
handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc,
|
||||
cudnnFdesc, algo, &workSpaceSize));
|
||||
size_t allocation;
|
||||
if (workSpaceSize > 0) {
|
||||
workSpace = exe.allocator->alloc(workSpaceSize, allocation);
|
||||
}
|
||||
float alpha=1, beta=0;
|
||||
checkCudaErrors(cudnnConvolutionBackwardFilter(
|
||||
handle_,
|
||||
(void*)(&alpha),
|
||||
cudnnIdesc, x->ptr<Tx>(),
|
||||
cudnnOdesc, y->ptr<Ty>(),
|
||||
cudnnConvDesc,
|
||||
algo,
|
||||
workSpace, workSpaceSize,
|
||||
(void*)(&beta),
|
||||
cudnnFdesc, w->ptr<Tw>())
|
||||
);
|
||||
if (workSpace)
|
||||
exe.allocator->free(workSpace, workSpaceSize, allocation);
|
||||
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
|
||||
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
|
||||
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Op computing the convolution weight gradient dw from forward input x and
// output gradient dy via cuDNN. Layouts are 4-char format strings
// ("abcd" for data axes, "oihw" for weights).
struct CudnnConvBackwardWOp : Op {
    // x: forward input; dy: output gradient; dw: resulting weight gradient.
    Var* x, * dy, * dw;
    int kernel_size, stride, padding, dilation;
    string xformat, wformat, yformat;

    CudnnConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");

    const char* name() const override { return "cudnn_conv_backward_w"; }
    void infer_shape() override;
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,273 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cudnn_conv_backward_x_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
#include "executor.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
// Return the axis index of layout character c within a 4-char format string,
// e.g. findc("abcd", 'c') == 2. Characters absent from the format trip the
// ASSERT on the last position.
static inline int findc(const char* format, const char& c) {
    for (int i=0; i<3; i++)
        if (c==format[i]) return i;
    ASSERT(c==format[3]) << "Not a valid format" << format << c;
    return 3;
}
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
// Read the four extents of x in the order requested by f, translating each
// layout character through the variable's own format string.
static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) {
    int* out[4] = {&a, &b, &c, &d};
    for (int i=0; i<4; i++)
        *out[i] = x->shape[findc(format.c_str(), f[i])];
}
|
||||
|
||||
// Scatter the four extents (given in f's order) into the axis positions
// dictated by the variable's format string, then commit the new shape.
static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) {
    int64 shape[4];
    const int dims[4] = {a, b, c, d};
    for (int i=0; i<4; i++)
        shape[findc(format.c_str(), f[i])] = dims[i];
    x->set_shape(NanoVector(
        shape[0], shape[1], shape[2], shape[3]));
}
|
||||
|
||||
// Input-gradient op for conv2d: given weights w and output gradient dy,
// produces dx of spatial size (height, width), which cannot be inferred
// from dy alone and is therefore passed in explicitly.
CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
    : w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    // cuDNN-only op: CUDA flag set, no CPU fallback.
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);
    // Output dtype follows the usual promotion of dy's and w's dtypes.
    dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
}
|
||||
|
||||
// Derive dx's shape: batch from dy, channels from the weight's in-channels,
// spatial size from the (height, width) stored at construction.
void CudnnConvBackwardXOp::infer_shape() {
    // Both inputs must be 4-D tensors.
    ASSERTop(w->shape.size(),==,4);
    ASSERTop(dy->shape.size(),==,4);
    int xn, xc, wh, ww, wci, wco, yn, yc, yh, yw;
    get_shape(w, "oihw", wformat, wco, wci, wh, ww);
    get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
    xn = yn, xc = wci;
    set_shape(dx, "abcd", xformat, xn, xc, xh, xw);
}
|
||||
|
||||
// Register dtypes and layout strings as JIT defines so the @-macros in
// jit_run are specialized per (Tx, Ty, Tw, format) combination.
void CudnnConvBackwardXOp::jit_prepare() {
    add_jit_define("Ty", dy->dtype());
    add_jit_define("Tw", w->dtype());
    add_jit_define("Tx", dx->dtype());
    add_jit_define("XFORMAT", xformat);
    add_jit_define("WFORMAT", wformat);
    add_jit_define("YFORMAT", yformat);
}
|
||||
unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
|
||||
extern unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
void CudnnConvBackwardXOp::jit_run() {
|
||||
auto x = dx;
|
||||
auto y = dy;
|
||||
cudnnHandle_t& handle_ = cudnn_handle;
|
||||
|
||||
cudnnTensorDescriptor_t cudnnIdesc;
|
||||
cudnnFilterDescriptor_t cudnnFdesc;
|
||||
cudnnTensorDescriptor_t cudnnOdesc;
|
||||
cudnnConvolutionDescriptor_t cudnnConvDesc;
|
||||
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
|
||||
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
|
||||
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
|
||||
|
||||
int dimX[] = {
|
||||
(int)x->shape[findc("@XFORMAT", 'a')], // n
|
||||
(int)x->shape[findc("@XFORMAT", 'b')], // c
|
||||
(int)x->shape[findc("@XFORMAT", 'c')], // h
|
||||
(int)x->shape[findc("@XFORMAT", 'd')], // w
|
||||
};
|
||||
int _strideX[] = {0,0,0,1};
|
||||
for (int i=2; i>=0; i--) _strideX[i] = _strideX[i+1] * x->shape[i+1];
|
||||
int strideX[] = {
|
||||
_strideX[findc("@XFORMAT", 'a')], // n
|
||||
_strideX[findc("@XFORMAT", 'b')], // c
|
||||
_strideX[findc("@XFORMAT", 'c')], // h
|
||||
_strideX[findc("@XFORMAT", 'd')], // w
|
||||
};
|
||||
// dimX: nchw
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||
cudnnIdesc, getDataType<Tx>(),
|
||||
4, dimX, strideX
|
||||
));
|
||||
|
||||
auto ws = w->shape;
|
||||
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3]};
|
||||
// cudnn only support this two format
|
||||
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
|
||||
#define filterFormat_oihw CUDNN_TENSOR_NCHW
|
||||
#define filterFormat_iohw CUDNN_TENSOR_NCHW
|
||||
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
|
||||
|
||||
// dimW: KCRS(oihw)
|
||||
checkCudaErrors(cudnnSetFilterNdDescriptor(
|
||||
cudnnFdesc, getDataType<Tw>(),
|
||||
filterFormat_@WFORMAT, 4, dimW
|
||||
));
|
||||
|
||||
int padA[] = {padding, padding};
|
||||
int convstrideA[] = {stride, stride};
|
||||
int dilationA[] = {dilation, dilation};
|
||||
// difference between
|
||||
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
|
||||
// is the kernel rc order
|
||||
// currently, No perf difference is observed between
|
||||
// this two mode
|
||||
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
|
||||
cudnnConvDesc, 2,
|
||||
padA, convstrideA, dilationA,
|
||||
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
|
||||
));
|
||||
|
||||
// using tensor core
|
||||
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
|
||||
|
||||
int dimY[] = {
|
||||
(int)y->shape[findc("@YFORMAT", 'a')], // n
|
||||
(int)y->shape[findc("@YFORMAT", 'b')], // c
|
||||
(int)y->shape[findc("@YFORMAT", 'c')], // h
|
||||
(int)y->shape[findc("@YFORMAT", 'd')], // w
|
||||
};
|
||||
int _strideY[] = {0,0,0,1};
|
||||
for (int i=2; i>=0; i--) _strideY[i] = _strideY[i+1] * y->shape[i+1];
|
||||
int strideY[] = {
|
||||
_strideY[findc("@YFORMAT", 'a')], // n
|
||||
_strideY[findc("@YFORMAT", 'b')], // c
|
||||
_strideY[findc("@YFORMAT", 'c')], // h
|
||||
_strideY[findc("@YFORMAT", 'd')], // w
|
||||
};
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||
cudnnOdesc, getDataType<Ty>(),
|
||||
4, dimY, strideY
|
||||
));
|
||||
|
||||
cudnnConvolutionBwdDataAlgo_t algos[] = {
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
|
||||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos];
|
||||
cudnnConvolutionBwdDataAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
jk.clear();
|
||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
|
||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
|
||||
jk << padding << "," <<stride << "," << dilation << ".";
|
||||
auto iter = bwdx_algo_cache.find(jk.to_string());
|
||||
|
||||
if (iter!=bwdx_algo_cache.end()) algo = iter->second;
|
||||
else {
|
||||
if (bwdx_algo_cache.size()>=max_cache_size) benchmark = false;
|
||||
if (benchmark) {
|
||||
size_t max_ws_size = 0;
|
||||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz);
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
void* ws = exe.allocator->alloc(max_ws_size, allocation);
|
||||
checkCudaErrors(cudnnFindConvolutionBackwardDataAlgorithmEx(
|
||||
handle_,
|
||||
cudnnFdesc, w->ptr<Tw>(),
|
||||
cudnnOdesc, y->ptr<Ty>(),
|
||||
cudnnConvDesc,
|
||||
cudnnIdesc, x->ptr<Tx>(),
|
||||
num_algos,
|
||||
&perf_count,
|
||||
perf_results,
|
||||
ws,
|
||||
max_ws_size));
|
||||
exe.allocator->free(ws, max_ws_size, allocation);
|
||||
} else {
|
||||
checkCudaErrors(cudnnGetConvolutionBackwardDataAlgorithm_v7(
|
||||
handle_,
|
||||
cudnnFdesc,
|
||||
cudnnOdesc,
|
||||
cudnnConvDesc,
|
||||
cudnnIdesc,
|
||||
num_algos,
|
||||
&perf_count,
|
||||
perf_results));
|
||||
}
|
||||
int best_algo_idx=-1;
|
||||
for (int i = 0; i < perf_count; i++)
|
||||
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
|
||||
best_algo_idx=i;
|
||||
break;
|
||||
}
|
||||
ASSERT(best_algo_idx!=-1);
|
||||
algo=perf_results[best_algo_idx].algo;
|
||||
if (benchmark) {
|
||||
bwdx_algo_cache[jk.to_string()] = algo;
|
||||
if (bwdx_algo_cache.size()==max_cache_size)
|
||||
LOGw << "backward x algorithm cache is full";
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: warp work space
|
||||
void *workSpace = 0;
|
||||
size_t workSpaceSize;
|
||||
checkCudaErrors (cudnnGetConvolutionBackwardDataWorkspaceSize(
|
||||
handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc,
|
||||
cudnnIdesc, algo, &workSpaceSize));
|
||||
size_t allocation;
|
||||
if (workSpaceSize > 0) {
|
||||
workSpace = exe.allocator->alloc(workSpaceSize, allocation);
|
||||
}
|
||||
float alpha=1, beta=0;
|
||||
checkCudaErrors(cudnnConvolutionBackwardData(
|
||||
handle_,
|
||||
(void*)(&alpha),
|
||||
cudnnFdesc, w->ptr<Tw>(),
|
||||
cudnnOdesc, y->ptr<Ty>(),
|
||||
cudnnConvDesc,
|
||||
algo,
|
||||
workSpace, workSpaceSize,
|
||||
(void*)(&beta),
|
||||
cudnnIdesc, x->ptr<Tx>())
|
||||
);
|
||||
if (workSpace)
|
||||
exe.allocator->free(workSpace, workSpaceSize, allocation);
|
||||
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
|
||||
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
|
||||
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Op computing the convolution input gradient dx from weights w and output
// gradient dy via cuDNN. The target spatial size (height, width) is supplied
// explicitly because it cannot be recovered from dy alone.
struct CudnnConvBackwardXOp : Op {
    // w: forward weights; dy: output gradient; dx: resulting input gradient.
    Var* w, * dy, * dx;
    // xh/xw hold the requested output height/width.
    int xh, xw, stride, padding, dilation;
    string xformat, wformat, yformat;

    CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");

    const char* name() const override { return "cudnn_conv_backward_x"; }
    void infer_shape() override;
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,277 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cudnn_conv_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
#include "executor.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Return the axis index of layout character c within a 4-char format string,
// e.g. findc("abcd", 'c') == 2. Characters absent from the format trip the
// ASSERT on the last position.
static inline int findc(const char* format, const char& c) {
    for (int i=0; i<3; i++)
        if (c==format[i]) return i;
    ASSERT(c==format[3]) << "Not a valid format" << format << c;
    return 3;
}
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
// Read the four extents of x in the order requested by f, translating each
// layout character through the variable's own format string.
static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) {
    int* out[4] = {&a, &b, &c, &d};
    for (int i=0; i<4; i++)
        *out[i] = x->shape[findc(format.c_str(), f[i])];
}
|
||||
|
||||
// Scatter the four extents (given in f's order) into the axis positions
// dictated by the variable's format string, then commit the new shape.
static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) {
    int64 shape[4];
    const int dims[4] = {a, b, c, d};
    for (int i=0; i<4; i++)
        shape[findc(format.c_str(), f[i])] = dims[i];
    x->set_shape(NanoVector(
        shape[0], shape[1], shape[2], shape[3]));
}
|
||||
|
||||
// Forward conv2d op via cuDNN. If yformat is empty it defaults to xformat.
CudnnConvOp::CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
    : x(x), w(w), stride(stride), padding(padding), dilation(dilation),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    // cuDNN-only op: CUDA flag set, no CPU fallback.
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);
    // Output dtype follows the usual promotion of x's and w's dtypes.
    y = create_output(nullptr, dtype_infer(x->ns, w->ns));
    // An unspecified output layout inherits the input layout.
    if (!this->yformat.size())
        this->yformat = this->xformat;
}
|
||||
|
||||
// Compute the output shape of the forward convolution.
void CudnnConvOp::infer_shape() {
    // Both inputs must be 4-D tensors.
    ASSERTop(x->shape.size(),==,4);
    ASSERTop(w->shape.size(),==,4);
    int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
    get_shape(x, "abcd", xformat, xn, xc, xh, xw);
    get_shape(w, "oihw", wformat, wco, wci, wh, ww);
    // Weight in-channels must match the data channels.
    ASSERTop(wci,==,xc);
    yn = xn, yc = wco;
    // Standard conv output size with dilated kernel (k-1)*dilation+1:
    // out = (in + 2*pad - ((k-1)*dilation + 1)) / stride + 1, rewritten below.
    yh = (xh+padding*2-wh*dilation+dilation-1)/stride+1;
    yw = (xw+padding*2-ww*dilation+dilation-1)/stride+1;
    set_shape(y, "abcd", yformat, yn, yc, yh, yw);
}
|
||||
|
||||
// Register JIT defines for the templated jit_run: element dtypes
// (Tx/Tw/Ty) and layout strings, which are substituted into the
// "@XFORMAT"/"@WFORMAT"/"@YFORMAT" placeholders at JIT-compile time.
void CudnnConvOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("Tw", w->dtype());
    add_jit_define("Ty", y->dtype());
    add_jit_define("XFORMAT", xformat);
    add_jit_define("WFORMAT", wformat);
    add_jit_define("YFORMAT", yformat);
}
|
||||
// Global cache mapping a convolution-geometry key (built in jit_run from
// input dims, weight dims, padding, stride, dilation) to the best
// forward algorithm found by benchmarking.
unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
|
||||
extern unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||
|
||||
// Map a C++ element type to the matching cuDNN dtype enum.
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
// JIT-specialized forward convolution via cuDNN.
// "@XFORMAT"/"@WFORMAT"/"@YFORMAT" and Tx/Tw/Ty are textually substituted
// by jit_prepare's defines before this file is compiled. The best forward
// algorithm is benchmarked once per conv geometry and memoized in
// fwd_algo_cache; once the cache is full, a heuristic pick is used.
void CudnnConvOp::jit_run() {
    cudnnHandle_t& handle_ = cudnn_handle;

    cudnnTensorDescriptor_t cudnnIdesc;
    cudnnFilterDescriptor_t cudnnFdesc;
    cudnnTensorDescriptor_t cudnnOdesc;
    cudnnConvolutionDescriptor_t cudnnConvDesc;

    checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
    checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
    checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
    checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));

    // Input extents reordered from the storage layout (@XFORMAT) to nchw.
    int dimX[] = {
        (int)x->shape[findc("@XFORMAT", 'a')], // n
        (int)x->shape[findc("@XFORMAT", 'b')], // c
        (int)x->shape[findc("@XFORMAT", 'c')], // h
        (int)x->shape[findc("@XFORMAT", 'd')], // w
    };
    // Packed strides computed in storage order ...
    int _strideX[] = {0,0,0,1};
    for (int i=2; i>=0; i--) _strideX[i] = _strideX[i+1] * x->shape[i+1];
    // ... then permuted into nchw order to line up with dimX.
    int strideX[] = {
        _strideX[findc("@XFORMAT", 'a')], // n
        _strideX[findc("@XFORMAT", 'b')], // c
        _strideX[findc("@XFORMAT", 'c')], // h
        _strideX[findc("@XFORMAT", 'd')], // w
    };
    // dimX: nchw
    checkCudaErrors(cudnnSetTensorNdDescriptor(
        cudnnIdesc, getDataType<Tx>(),
        4, dimX, strideX
    ));

    auto ws = w->shape;
    int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3]};
    // cudnn only support this two format
    // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
    #define filterFormat_oihw CUDNN_TENSOR_NCHW
    #define filterFormat_ohwi CUDNN_TENSOR_NHWC

    // dimW: KCRS(oihw)
    checkCudaErrors(cudnnSetFilterNdDescriptor(
        cudnnFdesc, getDataType<Tw>(),
        filterFormat_@WFORMAT, 4, dimW
    ));

    int padA[] = {padding, padding};
    int convstrideA[] = {stride, stride};
    int dilationA[] = {dilation, dilation};
    // difference between
    // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
    // is the kernel rc order
    // currently, No perf difference is observed between
    // this two mode
    checkCudaErrors(cudnnSetConvolutionNdDescriptor(
        cudnnConvDesc, 2,
        padA, convstrideA, dilationA,
        CUDNN_CROSS_CORRELATION, getDataType<Ty>()
    ));

    // using tensor core
    // checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );

    // Output descriptor: same dim/stride permutation scheme as the input.
    int dimY[] = {
        (int)y->shape[findc("@YFORMAT", 'a')], // n
        (int)y->shape[findc("@YFORMAT", 'b')], // c
        (int)y->shape[findc("@YFORMAT", 'c')], // h
        (int)y->shape[findc("@YFORMAT", 'd')], // w
    };
    int _strideY[] = {0,0,0,1};
    for (int i=2; i>=0; i--) _strideY[i] = _strideY[i+1] * y->shape[i+1];
    int strideY[] = {
        _strideY[findc("@YFORMAT", 'a')], // n
        _strideY[findc("@YFORMAT", 'b')], // c
        _strideY[findc("@YFORMAT", 'c')], // h
        _strideY[findc("@YFORMAT", 'd')], // w
    };
    checkCudaErrors(cudnnSetTensorNdDescriptor(
        cudnnOdesc, getDataType<Ty>(),
        4, dimY, strideY
    ));

    // Candidate forward algorithms considered when sizing the
    // benchmark workspace.
    cudnnConvolutionFwdAlgo_t algos[] = {
        CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
        CUDNN_CONVOLUTION_FWD_ALGO_FFT,
        CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
        CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
        CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
        CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
        CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
        CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
    };
    int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
    int perf_count;
    cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos];
    cudnnConvolutionFwdAlgo_t algo;
    bool benchmark=true;

    // Cache key: input dims, weight dims and conv hyper-parameters.
    jk.clear();
    jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
    jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
    jk << padding << "," <<stride << "," << dilation << ".";
    auto iter = fwd_algo_cache.find(jk.to_string());

    if (iter!=fwd_algo_cache.end()) algo = iter->second;
    else {
        // Once the cache is full, stop benchmarking and fall back to
        // the heuristic picker (and do not insert new entries).
        if (fwd_algo_cache.size()>=max_cache_size) benchmark = false;
        if (benchmark) {
            // Size one scratch workspace large enough for every candidate
            // (capped at 512MB) so FindEx can time them all.
            size_t max_ws_size = 0;
            for (int i = 0; i < num_algos; i++) {
                size_t sz;
                cudnnStatus_t ret = cudnnGetConvolutionForwardWorkspaceSize(
                    handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
                    cudnnOdesc, algos[i], &sz);
                if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size && sz<512*1024*1024) max_ws_size = sz;
            }
            size_t allocation;
            void* ws = exe.allocator->alloc(max_ws_size, allocation);
            checkCudaErrors(cudnnFindConvolutionForwardAlgorithmEx(
                handle_,
                cudnnIdesc, x->ptr<Tx>(),
                cudnnFdesc, w->ptr<Tw>(),
                cudnnConvDesc,
                cudnnOdesc, y->ptr<Ty>(),
                num_algos,
                &perf_count,
                perf_results,
                ws,
                max_ws_size));
            exe.allocator->free(ws, max_ws_size, allocation);
        } else {
            // Heuristic selection: no timing runs are performed.
            checkCudaErrors(cudnnGetConvolutionForwardAlgorithm_v7(
                handle_,
                cudnnIdesc,
                cudnnFdesc,
                cudnnConvDesc,
                cudnnOdesc,
                num_algos,
                &perf_count,
                perf_results));
        }
        // perf_results come back best-first; take the first usable entry.
        int best_algo_idx=-1;
        for (int i = 0; i < perf_count; i++)
            if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
                best_algo_idx=i;
                break;
            }
        ASSERT(best_algo_idx!=-1);
        algo=perf_results[best_algo_idx].algo;
        if (benchmark) {
            fwd_algo_cache[jk.to_string()] = algo;
            if (fwd_algo_cache.size()==max_cache_size)
                LOGw << "forward_ algorithm cache is full";
        }
    }

    // TODO: warp work space
    void *workSpace = 0;
    size_t workSpaceSize;
    checkCudaErrors (cudnnGetConvolutionForwardWorkspaceSize(
        handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
        cudnnOdesc, algo, &workSpaceSize) );
    size_t allocation;
    if (workSpaceSize > 0) {
        workSpace = exe.allocator->alloc(workSpaceSize, allocation);
    }
    float alpha=1, beta=0;
    checkCudaErrors(cudnnConvolutionForward(
        handle_,
        (void*)(&alpha),
        cudnnIdesc, x->ptr<Tx>(),
        cudnnFdesc, w->ptr<Tw>(),
        cudnnConvDesc,
        algo,
        workSpace, workSpaceSize,
        (void*)(&beta),
        cudnnOdesc, y->ptr<Ty>())
    );
    if (workSpace)
        exe.allocator->free(workSpace, workSpaceSize, allocation);

    checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
    checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
    checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
    checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,23 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Forward 2-D convolution op backed by cuDNN (CUDA-only).
struct CudnnConvOp : Op {
    // input, weight and output vars
    Var* x, * w, * y;
    int stride, padding, dilation;
    // layout strings; each char names a dim, e.g. xformat "abcd" == nchw,
    // wformat "oihw". Empty yformat defaults to xformat (see constructor).
    string xformat, wformat, yformat;
    /* CudnnConvOp: xformat abcd represents nchw */
    CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="");

    const char* name() const override { return "cudnn_conv"; }
    void infer_shape() override;
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,39 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <random>
|
||||
|
||||
#include "var.h"
|
||||
#include "cudnn_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
int cudnn_test_entry( int argc, char** argv );
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
// Store the command string and allocate a single-element float32 output
// that jit_run later writes as a success flag.
CudnnTestOp::CudnnTestOp(string cmd) : cmd(move(cmd)) {
    output = create_output(1, ns_float32);
}
|
||||
|
||||
// Only the output element type is needed by the JIT specialization.
void CudnnTestOp::jit_prepare() {
    add_jit_define("T", ns_float32);
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void CudnnTestOp::jit_run() {
|
||||
auto args = split(cmd, " ");
|
||||
if (!cmd.size()) args.clear();
|
||||
vector<char*> v(args.size());
|
||||
for (uint i=0; i<args.size(); i++)
|
||||
v[i] = &args[i][0];
|
||||
ASSERT(cudnn_test_entry(v.size(), &v[0])==0);
|
||||
output->ptr<T>()[0] = 123;
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,20 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Op that runs the bundled cuDNN sample program as a sanity test.
struct CudnnTestOp : Op {
    // single float32 element; jit_run writes 123 into it on success
    Var* output;
    // space-separated CLI arguments forwarded to the sample's entry point
    string cmd;
    CudnnTestOp(string cmd);

    const char* name() const override { return "cudnn_test"; }
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,997 @@
|
|||
// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions
|
||||
// are met:
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
// * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
// contributors may be used to endorse or promote products derived
|
||||
// from this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
//
|
||||
// This example demonstrates how to use CUDNN library calls cudnnConvolutionForward,
|
||||
// cudnnConvolutionBackwardData, and cudnnConvolutionBackwardFilter with the option
|
||||
// to enable Tensor Cores on Volta with cudnnSetConvolutionMathType.
|
||||
//
|
||||
// 1. Make sure cuda and cudnn are installed in the same directory.
|
||||
//
|
||||
// 2. Run make from the directory of the sample specifying the cuda installation path:
|
||||
// make CUDA_PATH=<cuda installation path>
|
||||
//
|
||||
// 3. Use the following arguments to run sample with different convolution parameters:
|
||||
// -c2048 -h7 -w7 -k512 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1
|
||||
// -c512 -h28 -w28 -k128 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1
|
||||
// -c512 -h28 -w28 -k1024 -r1 -s1 -pad_h0 -pad_w0 -u2 -v2
|
||||
// -c512 -h28 -w28 -k256 -r1 -s1 -pad_h0 -pad_w0 -u2 -v2
|
||||
// -c256 -h14 -w14 -k256 -r3 -s3 -pad_h1 -pad_w1 -u1 -v1
|
||||
// -c256 -h14 -w14 -k1024 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1
|
||||
// -c1024 -h14 -w14 -k256 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1
|
||||
// -c1024 -h14 -w14 -k2048 -r1 -s1 -pad_h0 -pad_w0 -u2 -v2
|
||||
// -c1024 -h14 -w14 -k512 -r1 -s1 -pad_h0 -pad_w0 -u2 -v2
|
||||
// -c512 -h7 -w7 -k512 -r3 -s3 -pad_h1 -pad_w1 -u1 -v1
|
||||
// -c512 -h7 -w7 -k2048 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1
|
||||
// -c2048 -h7 -w7 -k512 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1
|
||||
//
|
||||
// 4. Use the following additional arguments to run the layer with different setup:
|
||||
// -mathType1 : enable Tensor Cores on Volta.
|
||||
// -dgrad : run cudnnConvolutionBackwardData() instead of cudnnConvolutionForward().
|
||||
// -wgrad : run cudnnConvolutionBackwardFilter() instead of cudnnConvolutionForward().
|
||||
// -n<int> : mini batch size. (use -b with large n)
|
||||
// -b : benchmark mode. Bypass the CPU correctness check.
|
||||
// -filterFormat1 : Use tensor format CUDNN_TENSOR_NHWC instead of CUDNN_TENSOR_NCHW.
|
||||
//
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <ctype.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <assert.h>
|
||||
|
||||
#include <cudnn.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "fp16_dev.h"
|
||||
#include "fp16_emu.h"
|
||||
|
||||
#define SWITCH_CHAR '-'
|
||||
#define THRESHOLD 2.0e-2
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <stddef.h>
|
||||
#include <sys/time.h>
|
||||
#include <sys/resource.h>
|
||||
#include <sys/sysinfo.h>
|
||||
// Wall-clock time in seconds, with microsecond resolution
// (gettimeofday), used for coarse kernel timing.
static double second (void)
{
    struct timeval now;
    gettimeofday(&now, NULL);
    return (double)now.tv_sec + (double)now.tv_usec / 1000000.0;
}
|
||||
#else
|
||||
#error unsupported platform
|
||||
#endif
|
||||
|
||||
// Map a C++ element type to the matching cuDNN dtype enum.
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
//Generate uniform numbers [0,1)
|
||||
// Fill `image` with pseudo-random floats in [0,1) produced by a 32-bit
// linear congruential generator. The seed is static, so successive
// calls continue the same deterministic sequence.
static void initImage(float* image, int imageSize) {
    static unsigned seed = 123456789;
    for (int i = 0; i < imageSize; i++) {
        seed = ( 1103515245 * seed + 12345 ) & 0xffffffff;
        image[i] = float(seed)*2.3283064e-10; //2^-32
    }
}
|
||||
|
||||
// Half-precision overload: same LCG recurrence as the float version,
// but with its own static seed, so the two sequences advance
// independently. Values are rounded to half via cpu_float2half_rn.
static void initImage(half1* image, int imageSize) {
    static unsigned seed = 123456789;
    for (int index = 0; index < imageSize; index++) {
        seed = ( 1103515245 * seed + 12345 ) & 0xffffffff;
        image[index] = cpu_float2half_rn(float(seed)*2.3283064e-10); //2^-32
    }
}
|
||||
|
||||
// Print GPU timing (and, when cpuLib is non-null, CPU reference timing
// plus the speedup ratio) in the sample's "^^^^" log format.
// Gflops/bandwidth figures are only printed when positive.
static void printPerf( double cudaTime, double cudaGflops, double cudaBandwithGb,
        const char *cpuLib, double cpuTime, double cpuGflops, double cpuBandwithGb)
{
    printf( "^^^^ CUDA : elapsed = %g sec, ", cudaTime );
    if (cudaGflops > 0) printf( "Gflops = %.3f ", cudaGflops );
    if (cudaBandwithGb > 0) printf( "Bandwidth = %.3f ", cudaBandwithGb );
    printf( "\n");
    if (cpuLib) {
        printf( "^^^^%s : elapsed = %g sec, ", cpuLib, cpuTime );
        if (cpuGflops > 0) printf( "Gflops = %.3f ", cpuGflops );
        if (cpuBandwithGb > 0) printf( "Bandwidth = %.3f, ", cpuBandwithGb );
        printf( "Speedup %.2f\n", cpuTime/cudaTime );

    }
}
|
||||
|
||||
// Fill strideA with element strides for a packed tensor of extents dimA.
// isNchw=true yields a plain row-major layout (last dim innermost);
// otherwise an NHWC-style layout is produced where the channel dim
// (index 1) is innermost, spatial dims follow, and batch is outermost.
static void generateStrides(const int* dimA, int* strideA, int nbDims, bool isNchw) {
    if (isNchw) {
        // Each stride is the product of all inner extents.
        int acc = 1;
        for (int d = nbDims-1; d >= 0; d--) {
            strideA[d] = acc;
            acc *= dimA[d];
        }
    } else {
        strideA[1] = 1;
        strideA[nbDims-1] = strideA[1]*dimA[1];
        for (int d = nbDims-2; d >= 2; d--)
            strideA[d] = strideA[d+1] * dimA[d+1];
        strideA[0] = strideA[2]*dimA[2];
    }
}
|
||||
|
||||
// Convert a linear index
|
||||
// i = d_1 s_1 ... s_n + d_2 s_2 ... s_n + d_n-1 s_n + d_n
|
||||
// into a multidimensional index
|
||||
// (d_1, d_2, ..., d_n)
|
||||
// Decompose the linear index `id` into per-dimension coordinates `ids`
// for a tensor with extents `dims` (row-major: last dimension fastest).
void lin2dim(int id, int* ids, const int* dims, int length) {
    int rem = id;
    int prod = 1; // product of the extents processed so far
    for (int i = length-1; i >= 0; i--) {
        ids[i] = (rem / prod) % dims[i];
        rem = id - ids[i] * prod;
        prod *= dims[i];
    }
}
|
||||
|
||||
// Convert a multidimensional index
|
||||
// (d_1, d_2, ..., d_n)
|
||||
// into a linear index
|
||||
// i = d_1 s_1 + ... + d_n s_n
|
||||
// Dot product of a multidimensional index with its strides:
// converts (d_1, ..., d_n) back into a linear offset.
static int dim2lin(const int* ids, const int* strides, int length) {
    int lin = 0;
    for (int i = 0; i < length; i++)
        lin += ids[i] * strides[i];
    return lin;
}
|
||||
|
||||
// Scalar multiply-accumulate in float: returns fval*ival + tmp.
static float doFma(float fval, float ival, float tmp) {
    float prod = fval * ival;
    return prod + tmp;
}
|
||||
|
||||
// Half-precision overload: promote both operands to float, then
// multiply-accumulate in float.
static float doFma(half1 fval, half1 ival, float tmp) {
    return cpu_half2float(fval)*cpu_half2float(ival)+tmp;
}
|
||||
|
||||
// Store the accumulated result: out[idx] = alphaAcc + beta*out[idx],
// skipping the read of the old value entirely when beta is zero.
static void doEpilog(float *out, int idx, float alphaAcc, float beta) {
    if (beta == 0.f) {
        out[idx] = alphaAcc;
        return;
    }
    out[idx] = alphaAcc + out[idx]*beta;
}
|
||||
|
||||
// Half-precision overload of doEpilog: blend in float precision, then
// round the final value back to half exactly once.
static void doEpilog(half1 *out, int idx, float alphaAcc, float beta) {
    if( beta == 0.f ) {
        out[idx] = cpu_float2half_rn(alphaAcc);
    } else {
        out[idx] = cpu_float2half_rn(alphaAcc + cpu_half2float(out[idx])*beta);
    }
}
|
||||
|
||||
// CPU reference implementation of forward convolution, used to verify
// cuDNN results. Accumulates in float regardless of T_ELEM and applies
// the alpha/beta epilog per output element. Layout of all tensors is
// controlled by isNchw (strides for the filter are derived here).
template <typename T_ELEM>
static void conv_cpu_ref (
    const T_ELEM* inputData,
    const T_ELEM* filterData,
    T_ELEM* outputData,
    float alpha,
    float beta,
    bool isNchw,
    const int* inDims,
    const int* filDims,
    const int* outDims,
    const int* inStride,
    const int* outStride,
    const int* stride,
    const int* pad,
    const int* dilation,
    int nbDims
) {
    // number of spatial dimensions (dims minus batch and channel)
    int imDims = nbDims - 2 ;

    int filStride[8] = {0} ;
    generateStrides(filDims, filStride, nbDims, isNchw);

    bool isConv = true; //(CUDNN_CONVOLUTION == mode) ;
    // Number of pixels in output
    int nPixelsOut = 1 ;
    for(int i = 2 ; i < nbDims ; i++)
        nPixelsOut *= outDims[i] ;
    // Number of pixels in filter
    int nPixelsFil = 1 ;
    for(int i = 2 ; i < nbDims ; i++)
        nPixelsFil *= filDims[i] ;
    // Used to store coordinates
    int filIds[8] = {0} ;
    int outIds[8] = {0} ;
    int inIds [8] = {0} ;
    int tmpIds[8] = {0} ;
    // For each image in the output
    for(int ni = 0 ; ni < outDims[0] ; ni++) {
        // For each feature layer of the output
        for(int ki = 0 ; ki < outDims[1] ; ki++) {
            int outputOffset = ni * outStride[0] + ki * outStride[1] ;
            // Loop over all entries of the result
            for(int outId = 0 ; outId < nPixelsOut ; outId++) {
                // Get output pixel ids
                lin2dim(outId, outIds, outDims+2, imDims) ; // Skip n and k dimensions
                // Now we get the coordinates in input space of the "top left" corner of the filter: multiply by stride and remove pad
                for(int d = 0 ; d < imDims ; d++) {
                    inIds[d] = outIds[d] * stride[d] - pad[d] ;
                }
                // We then accumulate
                float tmp = 0.f;
                for(int ci = 0 ; ci < inDims[1] ; ci++) {
                    int inputOffset = ni * inStride[0] + ci * inStride[1] ;
                    int filterOffset = ki * filStride[0] + ci * filStride[1] ;
                    for(int filId = 0 ; filId < nPixelsFil ; filId ++) {
                        // Get the position of the pixel
                        lin2dim(filId, filIds, filDims+2, imDims) ;
                        // Compute the corresponding output pixel
                        // and check wether we are in the padding area on the fly too (not that for convolution, we flip the image patch (equivalent to flipping the filter patch))
                        bool inside = true ;
                        for(int d = 0 ; d < imDims && inside ; d++) {
                            if (isConv) {
                                tmpIds[d] = inIds[d] + dilation[d] * (filDims[2+d]-1 - filIds[d]) ;
                            } else {
                                tmpIds[d] = inIds[d] + dilation[d] * filIds[d] ;
                            }
                            inside &= (tmpIds[d] >= 0 && tmpIds[d] < inDims[2+d]) ; // If we are in the padding area: stop and skip computations
                        }
                        if(inside) {
                            int actualTmpId = inputOffset + dim2lin(tmpIds, (inStride)+2, imDims) ;
                            //int actualFilId = filterOffset + filId ;
                            int actualFilId = filterOffset + dim2lin(filIds, (filStride)+2, imDims) ;
                            T_ELEM fval = filterData[actualFilId] ;
                            T_ELEM ival = inputData [actualTmpId] ;
                            tmp = doFma(fval, ival, tmp);
                        }
                    }
                }

                // We put the result in the output
                int actualOutId = outputOffset + dim2lin(outIds, (outStride)+2, imDims) ;
                doEpilog(outputData, actualOutId, alpha*tmp, beta);
            }
        }
    }
}
|
||||
|
||||
// CPU reference for the data-gradient (backward-data) pass: computes
// output (dgrad, n x c x h x w) from weight (k x c x r x s) and
// top_diff (n x k x p x q). Accumulation is in float; the alpha/beta
// epilog is applied per output element.
template<typename T_ELEM>
static void dataGrad_cpu_ref (
    const T_ELEM *weight,
    const T_ELEM *top_diff,
    T_ELEM *output,
    float alpha,
    float beta,
    bool isNchw,
    const int* inDims,
    const int* filDims,
    const int* outDims,
    const int* inStride,
    const int* outStride,
    const int* stride,
    const int* pad,
    const int* dilation,
    int nbDims )
{

    // Sanity checks
    // output is n x c x h x w
    // diff is n x k x p x q
    // filter is k x c x r x s
    assert(inDims[0] == outDims[0]); // n
    assert(inDims[1] == filDims[0]); // k
    assert(outDims[1] == filDims[1]); // c

    int filStride[8] = {0} ;
    generateStrides(filDims, filStride, nbDims, isNchw);

    bool isConv = true; //(CUDNN_CONVOLUTION == mode) ;

    // For every output pixel (n x c x h x w)
    for(int ni = 0; ni < outDims[0]; ni++) {
        for(int ci = 0; ci < outDims[1]; ci++) {
            for(int hi = 0; hi < outDims[2]; hi++) {
                for(int wi = 0; wi < outDims[3]; wi++) {
                    int outIdx = ni * outStride[0] +
                        ci * outStride[1] +
                        hi * outStride[2] +
                        wi * outStride[3];
                    float val = 0.0;

                    // For every diff channel (k)
                    for(int ki = 0; ki < inDims[1]; ki++) { // Sum over k channels
                        int offset_filter = ki * filStride[0] + ci * filStride[1];
                        int offset_diff = ni * inStride[0] + ki * inStride[1];
                        // For every pixel if filter (r x s)
                        for(int ri = 0; ri < filDims[2]; ri++) {
                            // Map back to the diff row index p that this
                            // output row/filter-row pair contributed to.
                            int p = hi + pad[0];
                            if (isConv){
                                p -= (filDims[2] - 1 - ri) * dilation[0];
                            } else {
                                p -= ri * dilation[0];
                            }
                            // Skip if p isn't a multiple of the stride.
                            if ( p%stride[0] )
                                continue;
                            p/=stride[0];

                            for(int si = 0; si < filDims[3]; si++) {
                                int q = wi + pad[1];
                                // Fetch the value in filter and diff, product and accumulate
                                // So basically, for the convolution, we replace r by dim-1-r and s by dim-1-s to "flip" the filter
                                // We can then just reason in term of correlation
                                if (isConv){
                                    q -= (filDims[3] - 1 - si) * dilation[1];
                                } else {
                                    q -= si * dilation[1];
                                }
                                //Skip if q or p isn't multiple of strides
                                if ( q%stride[1] )
                                    continue;
                                q/=stride[1];
                                int inBounds = ( (p >= 0) && (p < inDims[2]) && (q >= 0) && (q < inDims[3]) );
                                if (inBounds) {
                                    int filterIdx = offset_filter + ri * filStride[2] + si * filStride[3];
                                    int diffIdx = offset_diff + p * inStride[2] + q * inStride[3];
                                    T_ELEM imTmp = top_diff[diffIdx];
                                    T_ELEM filTmp = weight[filterIdx];
                                    val = doFma(filTmp, imTmp, val);
                                }
                            }
                        }
                    }
                    doEpilog(output, outIdx, alpha*val, beta);
                }
            }
        }
    }
}
|
||||
|
||||
// CPU reference for the weight-gradient (backward-filter) pass:
// computes output (wgrad, k x c x r x s) from image (n x c x h x w)
// and diffData (n x k x p x q). Accumulation is in float; the
// alpha/beta epilog is applied per filter element.
template<typename T_ELEM>
static void weightGrad_cpu_ref(/*const TensorNdTestDesc_t *tensorInputDesc,*/
    const T_ELEM *image,
    /*const TensorNdTestDesc_t *tensorDiffDesc,*/
    const T_ELEM *diffData,
    /*const ConvNdTestDesc_t *convDesc,*/
    /*const TensorNdTestDesc_t *filterOutputDesc,*/
    float alpha,
    float beta,
    T_ELEM *output,
    bool isNchw,
    const int* inDims,
    const int* filDims,
    const int* diffDims,
    const int* inStride,
    const int* diffStride,
    const int* stride,
    const int* pad,
    const int* dilation,
    int nbDims )
{
    // Some sanity checks
    // image is n x c x h x w
    // diff is n x k x p x q
    // filter is k x c x r x s
    assert(inDims[0] == diffDims[0]) ;
    assert(inDims[1] == filDims[1]) ;
    assert(diffDims[1] == filDims[0]) ;

    // Filter stride
    int filterStride[4] ;
    generateStrides(filDims, filterStride, nbDims, isNchw);

    bool isConv = true; //(CUDNN_CONVOLUTION == mode) ;

    // For every filter pixel (k x c x r x s)
    for(int ci = 0; ci < inDims[1]; ci++) { // Loop over filter output pixels
        for(int ri = 0; ri < filDims[2]; ri++) { //        ^
            for(int si = 0; si < filDims[3]; si++) { //    ^
                for(int ki = 0; ki < filDims[0]; ki++){ //  ^
                    int filIdx = ki * filterStride[0] + ci * filterStride[1] + ri * filterStride[2] + si * filterStride[3] ;
                    float val = 0.f ;
                    // For every image (n)
                    for(int ni = 0 ; ni < inDims[0]; ni++) { // Sum over the batch
                        int offset_image = ni * inStride[0] + ci * inStride[1] ;
                        int offset_diff = ni * diffStride[0] + ki * diffStride[1] ;
                        // For every pixel in diff (p x q)
                        for(int pi = 0; pi < diffDims[2] ; pi++ ) { // Sum over the pixels of diff
                            for(int qi = 0; qi < diffDims[3] ; qi++ ) { //  ^
                                // Fetch the value in image and diff, product and accumulate
                                int y = pi * stride[0] - pad[0] ;
                                int x = qi * stride[1] - pad[1] ;
                                // Convolution = Correlation with a flipped filter
                                // So basically, for the convolution, we replace r by dim-1-r and s by dim-1-s to "flip" the filter
                                // We can then just reason in term of correlation
                                if (isConv){
                                    y += (filDims[2] - 1 - ri) * dilation[0] ;
                                    x += (filDims[3] - 1 - si) * dilation[1] ;
                                } else {
                                    // The effect of dilation on the gradient is to start the "zone of influence" of a given pixel further into the image, so dilation
                                    // only produces a shift in x and y
                                    y += ri * dilation[0] ;
                                    x += si * dilation[1] ;
                                }
                                // Image value
                                int inBounds = ((x >=0)&&(x < inDims[3])&&(y >=0)&&(y < inDims[2]));
                                if (inBounds) {
                                    int imIdx = offset_image + y * inStride[2] + x * inStride[3] ;
                                    // Diff value
                                    int diffIdx = offset_diff + pi * diffStride[2] + qi * diffStride[3] ;
                                    // Prod and accumulate
                                    T_ELEM imTmp = image[imIdx] ;
                                    T_ELEM diffTmp = diffData[diffIdx];
                                    val = doFma(diffTmp, imTmp, val);
                                }
                            }
                        }
                    }
                    doEpilog(output, filIdx, alpha*val, beta);
                }
            }
        }
    }
}
|
||||
|
||||
|
||||
// Error metric between a device value and a reference value:
// relative error when |ref| > 1, absolute difference otherwise.
// May be negative; callers take the absolute value.
float getError(float dev, float ref) {
    bool useRelative = (ref > 1.0 || ref < -1.0);
    if (useRelative)
        return (dev - ref)/ref;
    return dev - ref;
}
|
||||
|
||||
// Half-precision overload: compare in float after conversion;
// relative error when |ref| > 1, absolute difference otherwise.
float getError(half1 dev, half1 ref) {
    if (cpu_half2float(ref) > 1.0 || cpu_half2float(ref) < -1.0)
        return (cpu_half2float(dev) - cpu_half2float(ref))/cpu_half2float(ref);
    else
        return cpu_half2float(dev) - cpu_half2float(ref);
}
|
||||
|
||||
// Effective spatial extent of a filter after dilation:
// (filterDim - 1) * dilation + 1.
static inline int getFwdConvDilatedFilterDim(int filterDim,
        int dilation)
{
    int span = (filterDim - 1) * dilation;
    return span + 1;
}
|
||||
|
||||
// Image extent after symmetric zero-padding on both sides.
static inline int getFwdConvPaddedImageDim(int tensorDim,
        int pad)
{
    return 2 * pad + tensorDim;
}
|
||||
|
||||
// Output spatial extent of a forward convolution:
//   floor((padded_input - dilated_filter) / stride) + 1
// with dilated_filter = (filterDim-1)*dilation + 1 and
// padded_input = tensorDim + 2*pad (the two helper computations
// are inlined here).
static inline int getFwdConvOutputDim( int tensorDim,
        int pad,
        int filterDim,
        int stride,
        int dilation)
{
    int dilatedFilter = (filterDim - 1) * dilation + 1;
    int paddedImage = tensorDim + 2 * pad;
    return (paddedImage - dilatedFilter) / stride + 1;
}
|
||||
|
||||
template <typename T_ELEM>
|
||||
int doConv(
|
||||
cudnnHandle_t handle_,
|
||||
T_ELEM* devPtrI,
|
||||
T_ELEM* devPtrF,
|
||||
T_ELEM* devPtrO,
|
||||
T_ELEM* hostI,
|
||||
T_ELEM* hostF,
|
||||
T_ELEM* hostO,
|
||||
cudnnTensorDescriptor_t cudnnIdesc,
|
||||
cudnnFilterDescriptor_t cudnnFdesc,
|
||||
cudnnTensorDescriptor_t cudnnOdesc,
|
||||
cudnnConvolutionDescriptor_t cudnnConvDesc,
|
||||
float alpha,
|
||||
float beta,
|
||||
cudnnTensorFormat_t filterFormat,
|
||||
const int* dimA,
|
||||
const int* filterdimA,
|
||||
const int* outdimA,
|
||||
const int* strideA,
|
||||
const int* outstrideA,
|
||||
const int* convstrideA,
|
||||
const int* padA,
|
||||
const int* dilationA,
|
||||
const int benchmark) {
|
||||
|
||||
int outsize = outstrideA[0]*outdimA[0];
|
||||
T_ELEM* hostOfromdev = (T_ELEM*)calloc (outsize, sizeof(hostO[0]) );
|
||||
|
||||
cudnnConvolutionFwdAlgo_t algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
|
||||
|
||||
void *workSpace = 0;
|
||||
size_t workSpaceSize;
|
||||
|
||||
checkCudaErrors ( cudnnGetConvolutionForwardWorkspaceSize(handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
|
||||
cudnnOdesc, algo, &workSpaceSize) );
|
||||
|
||||
if (workSpaceSize > 0) {
|
||||
cudaMalloc(&workSpace, workSpaceSize);
|
||||
}
|
||||
double start = second();
|
||||
checkCudaErrors ( cudnnConvolutionForward (handle_,
|
||||
(void*)(&alpha),
|
||||
cudnnIdesc, devPtrI,
|
||||
cudnnFdesc, devPtrF,
|
||||
cudnnConvDesc,
|
||||
algo,
|
||||
workSpace, workSpaceSize,
|
||||
(void*)(&beta),
|
||||
cudnnOdesc, devPtrO) );
|
||||
checkCudaErrors( cudaDeviceSynchronize() );
|
||||
double stop = second();
|
||||
printPerf( stop - start, 0, 0,
|
||||
0, 0, 0, 0);
|
||||
checkCudaErrors( cudaMemcpy(hostOfromdev, devPtrO, sizeof(hostO[0]) * outsize, cudaMemcpyDeviceToHost) );
|
||||
checkCudaErrors( cudaDeviceSynchronize() );
|
||||
if (workSpace) {
|
||||
cudaFree(workSpace);
|
||||
workSpace = 0;
|
||||
}
|
||||
int numErrors = 0;
|
||||
if (!benchmark) {
|
||||
conv_cpu_ref<T_ELEM>( hostI, hostF, hostO, alpha, beta, (filterFormat == CUDNN_TENSOR_NCHW), dimA, filterdimA, outdimA, strideA, outstrideA, convstrideA, padA, dilationA, 4);
|
||||
for (int index = 0; index < outsize; index++) { // assuming out data is packed
|
||||
float diff = getError(hostOfromdev[index], hostO[index]);
|
||||
if (diff < 0) diff = -diff;
|
||||
if(diff > THRESHOLD) {
|
||||
numErrors++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return numErrors;
|
||||
}
|
||||
|
||||
template <typename T_ELEM>
|
||||
int doDgrad(
|
||||
cudnnHandle_t handle_,
|
||||
T_ELEM* devPtrI,
|
||||
T_ELEM* devPtrF,
|
||||
T_ELEM* devPtrO,
|
||||
T_ELEM* hostI,
|
||||
T_ELEM* hostF,
|
||||
T_ELEM* hostO,
|
||||
cudnnTensorDescriptor_t cudnnIdesc,
|
||||
cudnnFilterDescriptor_t cudnnFdesc,
|
||||
cudnnTensorDescriptor_t cudnnOdesc,
|
||||
cudnnConvolutionDescriptor_t cudnnConvDesc,
|
||||
float alpha,
|
||||
float beta,
|
||||
cudnnTensorFormat_t filterFormat,
|
||||
const int* dimA,
|
||||
const int* filterdimA,
|
||||
const int* outdimA,
|
||||
const int* strideA,
|
||||
const int* outstrideA,
|
||||
const int* convstrideA,
|
||||
const int* padA,
|
||||
const int* dilationA,
|
||||
const int benchmark) {
|
||||
|
||||
int insize = strideA[0]*dimA[0];
|
||||
T_ELEM* hostIfromdev = (T_ELEM*)calloc (insize, sizeof(hostI[0]) );
|
||||
cudnnConvolutionBwdDataAlgo_t algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
||||
|
||||
void *workSpace = 0;
|
||||
size_t workSpaceSize;
|
||||
|
||||
checkCudaErrors ( cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc,
|
||||
cudnnIdesc, algo, &workSpaceSize) );
|
||||
|
||||
if (workSpaceSize > 0) {
|
||||
cudaMalloc(&workSpace, workSpaceSize);
|
||||
}
|
||||
double start = second();
|
||||
checkCudaErrors ( cudnnConvolutionBackwardData (handle_,
|
||||
(void*)(&alpha),
|
||||
cudnnFdesc, devPtrF,
|
||||
cudnnOdesc, devPtrO,
|
||||
cudnnConvDesc,
|
||||
algo,
|
||||
workSpace, workSpaceSize,
|
||||
(void*)(&beta),
|
||||
cudnnIdesc, devPtrI) );
|
||||
checkCudaErrors( cudaDeviceSynchronize() );
|
||||
double stop = second();
|
||||
printPerf( stop - start, 0, 0,
|
||||
0, 0, 0, 0);
|
||||
checkCudaErrors( cudaMemcpy(hostIfromdev, devPtrI, sizeof(hostI[0]) * insize, cudaMemcpyDeviceToHost) );
|
||||
checkCudaErrors( cudaDeviceSynchronize() );
|
||||
if (workSpace) {
|
||||
cudaFree(workSpace);
|
||||
workSpace = 0;
|
||||
}
|
||||
int numErrors = 0;
|
||||
if (!benchmark) {
|
||||
dataGrad_cpu_ref<T_ELEM>(hostF, hostO, hostI, alpha, beta, (filterFormat == CUDNN_TENSOR_NCHW), outdimA, filterdimA, dimA, outstrideA, strideA, convstrideA, padA, dilationA, 4);
|
||||
for (int index = 0; index < insize; index++) { // assuming in data is packed
|
||||
float diff = getError(hostIfromdev[index], hostI[index]);
|
||||
if (diff < 0) diff = -diff;
|
||||
if(diff > THRESHOLD) {
|
||||
numErrors++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return numErrors;
|
||||
}
|
||||
|
||||
template <typename T_ELEM>
|
||||
int doWgrad(
|
||||
cudnnHandle_t handle_,
|
||||
T_ELEM* devPtrI,
|
||||
T_ELEM* devPtrF,
|
||||
T_ELEM* devPtrO,
|
||||
T_ELEM* hostI,
|
||||
T_ELEM* hostF,
|
||||
T_ELEM* hostO,
|
||||
cudnnTensorDescriptor_t cudnnIdesc,
|
||||
cudnnFilterDescriptor_t cudnnFdesc,
|
||||
cudnnTensorDescriptor_t cudnnOdesc,
|
||||
cudnnConvolutionDescriptor_t cudnnConvDesc,
|
||||
float alpha,
|
||||
float beta,
|
||||
cudnnTensorFormat_t filterFormat,
|
||||
const int* dimA,
|
||||
const int* filterdimA,
|
||||
const int* outdimA,
|
||||
const int* strideA,
|
||||
const int* outstrideA,
|
||||
const int* convstrideA,
|
||||
const int* padA,
|
||||
const int* dilationA,
|
||||
const int benchmark) {
|
||||
|
||||
int filsize = filterdimA[0]*filterdimA[1]*filterdimA[2]*filterdimA[3];
|
||||
T_ELEM* hostFfromdev = (T_ELEM*)calloc (filsize, sizeof(hostF[0]) );
|
||||
cudnnConvolutionBwdFilterAlgo_t algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
||||
|
||||
void *workSpace = 0;
|
||||
size_t workSpaceSize;
|
||||
|
||||
checkCudaErrors ( cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc,
|
||||
cudnnFdesc, algo, &workSpaceSize) );
|
||||
|
||||
if (workSpaceSize > 0) {
|
||||
cudaMalloc(&workSpace, workSpaceSize);
|
||||
}
|
||||
double start = second();
|
||||
checkCudaErrors ( cudnnConvolutionBackwardFilter (handle_,
|
||||
(void*)(&alpha),
|
||||
cudnnIdesc, devPtrI,
|
||||
cudnnOdesc, devPtrO,
|
||||
cudnnConvDesc,
|
||||
algo,
|
||||
workSpace, workSpaceSize,
|
||||
(void*)(&beta),
|
||||
cudnnFdesc, devPtrF) );
|
||||
checkCudaErrors( cudaDeviceSynchronize() );
|
||||
double stop = second();
|
||||
printPerf( stop - start, 0, 0,
|
||||
0, 0, 0, 0);
|
||||
checkCudaErrors( cudaMemcpy(hostFfromdev, devPtrF, sizeof(hostF[0]) * filsize, cudaMemcpyDeviceToHost) );
|
||||
checkCudaErrors( cudaDeviceSynchronize() );
|
||||
if (workSpace) {
|
||||
cudaFree(workSpace);
|
||||
workSpace = 0;
|
||||
}
|
||||
int numErrors = 0;
|
||||
if (!benchmark) {
|
||||
weightGrad_cpu_ref<T_ELEM>(hostI, hostO, alpha, beta, hostF, (filterFormat == CUDNN_TENSOR_NCHW), dimA, filterdimA, outdimA, strideA, outstrideA, convstrideA, padA, dilationA, 4);
|
||||
for (int index = 0; index < filsize; index++) { // assuming in data is packed
|
||||
float diff = getError(hostFfromdev[index], hostF[index]);
|
||||
if (diff < 0) diff = -diff;
|
||||
if(diff > THRESHOLD) {
|
||||
numErrors++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return numErrors;
|
||||
}
|
||||
|
||||
template <typename T_ELEM>
|
||||
int doTest(int algo, int* dimA, int* padA, int* convstrideA, int* filterdimA, cudnnTensorFormat_t filterFormat, int mathType, int benchmark) {
|
||||
|
||||
cudnnHandle_t handle_;
|
||||
T_ELEM* devPtrI;
|
||||
T_ELEM* devPtrF;
|
||||
T_ELEM* devPtrO;
|
||||
T_ELEM* hostI;
|
||||
T_ELEM* hostF;
|
||||
T_ELEM* hostO;
|
||||
|
||||
cudnnTensorDescriptor_t cudnnIdesc;
|
||||
cudnnFilterDescriptor_t cudnnFdesc;
|
||||
cudnnTensorDescriptor_t cudnnOdesc;
|
||||
cudnnConvolutionDescriptor_t cudnnConvDesc;
|
||||
|
||||
int convDim = 2;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0;
|
||||
|
||||
checkCudaErrors(cudnnCreate(&handle_));
|
||||
|
||||
checkCudaErrors( cudnnCreateTensorDescriptor( &cudnnIdesc ));
|
||||
checkCudaErrors( cudnnCreateFilterDescriptor( &cudnnFdesc ));
|
||||
checkCudaErrors( cudnnCreateTensorDescriptor( &cudnnOdesc ));
|
||||
checkCudaErrors( cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
|
||||
|
||||
int dilationA[] = {1, 1};
|
||||
|
||||
int strideA[] = {8192, 1024, 32, 1};
|
||||
generateStrides(dimA, strideA, 4, (filterFormat == CUDNN_TENSOR_NCHW));
|
||||
int insize = strideA[0]*dimA[0];
|
||||
|
||||
int filtersize = filterdimA[0]*filterdimA[1]*filterdimA[2]*filterdimA[3];
|
||||
|
||||
int outdimA[] = {1, 8, 30, 30};
|
||||
outdimA[0] = dimA[0];
|
||||
outdimA[1] = filterdimA[0];
|
||||
for( int dim = 0; dim < 2; dim++) {
|
||||
outdimA[dim+2] = getFwdConvOutputDim( dimA[dim+2],
|
||||
padA[dim],
|
||||
filterdimA[dim+2],
|
||||
convstrideA[dim],
|
||||
dilationA[dim]);
|
||||
}
|
||||
|
||||
int outstrideA[] = {7200, 900, 30, 1};
|
||||
generateStrides(outdimA, outstrideA, 4, (filterFormat == CUDNN_TENSOR_NCHW));
|
||||
int outsize = outstrideA[0]*outdimA[0];
|
||||
|
||||
cudaMalloc ((void**)&(devPtrI), (insize) * sizeof(devPtrI[0]) );
|
||||
cudaMalloc ((void**)&(devPtrF), (filtersize) * sizeof(devPtrF[0]) );
|
||||
cudaMalloc ((void**)&(devPtrO), (outsize) * sizeof(devPtrO[0]) );
|
||||
hostI = (T_ELEM*)calloc (insize, sizeof(hostI[0]) );
|
||||
hostF = (T_ELEM*)calloc (filtersize, sizeof(hostF[0]) );
|
||||
hostO = (T_ELEM*)calloc (outsize, sizeof(hostO[0]) );
|
||||
|
||||
initImage(hostI, insize);
|
||||
initImage(hostF, filtersize);
|
||||
initImage(hostO, outsize);
|
||||
|
||||
checkCudaErrors( cudaMemcpy(devPtrI, hostI, sizeof(hostI[0]) * insize, cudaMemcpyHostToDevice));
|
||||
checkCudaErrors( cudaMemcpy(devPtrF, hostF, sizeof(hostF[0]) * filtersize, cudaMemcpyHostToDevice));
|
||||
checkCudaErrors( cudaMemcpy(devPtrO, hostO, sizeof(hostO[0]) * outsize, cudaMemcpyHostToDevice));
|
||||
checkCudaErrors( cudaDeviceSynchronize() );
|
||||
|
||||
checkCudaErrors( cudnnSetTensorNdDescriptor(cudnnIdesc, getDataType<T_ELEM>(), convDim+2, dimA, strideA) );
|
||||
|
||||
checkCudaErrors( cudnnSetFilterNdDescriptor(cudnnFdesc, getDataType<T_ELEM>(), filterFormat, convDim+2, filterdimA));
|
||||
|
||||
checkCudaErrors( cudnnSetConvolutionNdDescriptor(cudnnConvDesc,
|
||||
convDim,
|
||||
padA,
|
||||
convstrideA,
|
||||
dilationA,
|
||||
CUDNN_CONVOLUTION,
|
||||
CUDNN_DATA_FLOAT) );
|
||||
if (mathType == 1) {
|
||||
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
|
||||
}
|
||||
|
||||
checkCudaErrors( cudnnSetTensorNdDescriptor(cudnnOdesc, getDataType<T_ELEM>(), convDim+2, outdimA, outstrideA) );
|
||||
|
||||
int numErrors = 0;
|
||||
if (algo == 0) {
|
||||
printf("Testing conv\n");
|
||||
numErrors = doConv(
|
||||
handle_,
|
||||
devPtrI,
|
||||
devPtrF,
|
||||
devPtrO,
|
||||
hostI,
|
||||
hostF,
|
||||
hostO,
|
||||
cudnnIdesc,
|
||||
cudnnFdesc,
|
||||
cudnnOdesc,
|
||||
cudnnConvDesc,
|
||||
alpha,
|
||||
beta,
|
||||
filterFormat,
|
||||
dimA,
|
||||
filterdimA,
|
||||
outdimA,
|
||||
strideA,
|
||||
outstrideA,
|
||||
convstrideA,
|
||||
padA,
|
||||
dilationA,
|
||||
benchmark);
|
||||
} else if (algo == 1) {
|
||||
printf("Testing dgrad\n");
|
||||
numErrors = doDgrad(
|
||||
handle_,
|
||||
devPtrI,
|
||||
devPtrF,
|
||||
devPtrO,
|
||||
hostI,
|
||||
hostF,
|
||||
hostO,
|
||||
cudnnIdesc,
|
||||
cudnnFdesc,
|
||||
cudnnOdesc,
|
||||
cudnnConvDesc,
|
||||
alpha,
|
||||
beta,
|
||||
filterFormat,
|
||||
dimA,
|
||||
filterdimA,
|
||||
outdimA,
|
||||
strideA,
|
||||
outstrideA,
|
||||
convstrideA,
|
||||
padA,
|
||||
dilationA,
|
||||
benchmark);
|
||||
} else {
|
||||
printf("Testing wgrad\n");
|
||||
numErrors = doWgrad(
|
||||
handle_,
|
||||
devPtrI,
|
||||
devPtrF,
|
||||
devPtrO,
|
||||
hostI,
|
||||
hostF,
|
||||
hostO,
|
||||
cudnnIdesc,
|
||||
cudnnFdesc,
|
||||
cudnnOdesc,
|
||||
cudnnConvDesc,
|
||||
alpha,
|
||||
beta,
|
||||
filterFormat,
|
||||
dimA,
|
||||
filterdimA,
|
||||
outdimA,
|
||||
strideA,
|
||||
outstrideA,
|
||||
convstrideA,
|
||||
padA,
|
||||
dilationA,
|
||||
benchmark);
|
||||
}
|
||||
|
||||
if (!benchmark) {
|
||||
if (numErrors == 0) {
|
||||
printf("Test PASSED\n");
|
||||
} else {
|
||||
printf("Test FAILED, num errors = %d\n", numErrors);
|
||||
}
|
||||
}
|
||||
|
||||
if (devPtrI) cudaFree (devPtrI);
|
||||
if (devPtrF) cudaFree (devPtrF);
|
||||
if (devPtrO) cudaFree (devPtrO);
|
||||
if (cudnnIdesc) cudnnDestroyTensorDescriptor(cudnnIdesc);
|
||||
if (cudnnFdesc) cudnnDestroyFilterDescriptor(cudnnFdesc);
|
||||
if (cudnnOdesc) cudnnDestroyTensorDescriptor(cudnnOdesc);
|
||||
if (cudnnConvDesc) cudnnDestroyConvolutionDescriptor(cudnnConvDesc);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Command-line entry point for the cuDNN convolution test.
// Each argument is SWITCH_CHAR followed by a single-letter (or named) option:
//   -b            benchmark mode (skip CPU validation)
//   -c<n>         input/filter channel count
//   -dgrad        test backward-data instead of forward
//   -filterFormat<n>  cudnnTensorFormat_t value (0 = NCHW)
//   -h<n> -w<n>   input spatial size; -n<n> batch size
//   -k<n>         filter count; -r<n> -s<n> filter spatial size
//   -mathType1    enable Tensor Core math
//   -pad_h<n> -pad_w<n>  padding; -u<n> -v<n> conv stride
//   -wgrad        test backward-filter
// Runs the configured test once in float and once in half precision.
// Returns 0 on success, non-zero on a bad argument.
// NOTE(review): the loop consumes argv[0] as an option too — the caller is
// presumably expected to pass an argv already stripped of the program name;
// confirm at the call site.
int cudnn_test_entry( int argc, char** argv )
{
    int algo = 0;          // 0 = conv, 1 = dgrad, 2 = wgrad
    int mathType = 0;
    int benchmark = 0;

    // Defaults: 1x8x32x32 input (NCHW), 8 filters of 8x3x3.
    int dimA[] = {1, 8, 32, 32};

    int padA[] = {0, 0};
    int convstrideA[] = {1, 1};

    int filterdimA[] = {8, 8, 3, 3};

    cudnnTensorFormat_t filterFormat = CUDNN_TENSOR_NCHW;

    int error = 0;
    while (argc) {
        if (*argv[0] == SWITCH_CHAR) {
            // Dispatch on the first character after the switch char;
            // multi-character options are disambiguated with strncmp.
            switch (*(argv[0]+1)) {
                case 'b':
                    benchmark = 1;
                    break;
                case 'c':
                    // argv[0]+2 skips the switch char and option letter.
                    dimA[1] = atol(argv[0]+2);
                    filterdimA[1] = dimA[1];
                    break;
                case 'd':
                    if ( strncmp( argv[0]+1, "dgrad" , strlen("dgrad")) == 0) {
                        algo = 1;
                    }
                    break;
                case 'f':
                    if ( strncmp( argv[0]+1, "filterFormat" , strlen("filterFormat")) == 0) {
                        filterFormat = (cudnnTensorFormat_t)(atoi(argv[0]+ 1 + strlen("filterFormat")));
                    }
                    break;
                case 'h':
                    dimA[2] = atol(argv[0]+2);
                    break;
                case 'k':
                    filterdimA[0] = atol(argv[0]+2);
                    break;
                case 'm':
                    if ( strncmp( argv[0]+1, "mathType1" , strlen("mathType1")) == 0) {
                        mathType = 1;
                    }
                    break;
                case 'n':
                    dimA[0] = atol(argv[0]+2);
                    break;
                case 'p':
                    if ( strncmp( argv[0]+1, "pad_h" , strlen("pad_h")) == 0) {
                        padA[0] = (int)atol(argv[0]+ 1 + strlen("pad_h"));
                    }
                    else if ( strncmp( argv[0]+1, "pad_w" , strlen("pad_w")) == 0) {
                        padA[1] = (int)atol(argv[0]+ 1 + strlen("pad_w"));
                    }
                    break;
                case 'r':
                    filterdimA[2] = atol(argv[0]+2);
                    break;
                case 's':
                    filterdimA[3] = atol(argv[0]+2);
                    break;
                case 'u':
                    convstrideA[0] = atol(argv[0]+2);
                    break;
                case 'v':
                    convstrideA[1] = atol(argv[0]+2);
                    break;
                case 'w':
                    // "-wgrad" selects the backward-filter test; a bare
                    // "-w<n>" sets the input width.
                    if ( strncmp( argv[0]+1, "wgrad" , strlen("wgrad")) == 0) {
                        algo = 2;
                    }
                    else dimA[3] = atol(argv[0]+2);
                    break;
                default:
                    error++;
                    break;
            }
            if (error) {
                fprintf(stderr, "Unknown switch '%c%s'\n\n", SWITCH_CHAR, argv[0]+1);
                return error;
            }
        }
        else {
            fprintf(stderr, "Invalid separator '%c' for option '%s'\n\n", *argv[0], argv[0] );
            return 1;
        }
        argc -= 1;
        argv++;
    }

    printf("Testing single precision\n");
    doTest<float>(algo, dimA, padA, convstrideA, filterdimA, filterFormat, mathType, benchmark);
    printf("Testing half precision (math in single precision)\n");
    doTest<half1>(algo, dimA, padA, convstrideA, filterdimA, filterFormat, mathType, benchmark);

    return 0;
}
|
|
@ -0,0 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "cudnn_warper.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
cudnnHandle_t cudnn_handle;
|
||||
|
||||
// RAII-style global: creates the shared cuDNN handle (`cudnn_handle`) when
// the module is loaded and destroys it at unload, via the `init` instance
// below. All jittor cuDNN ops share this one handle.
struct cudnn_initer {

    inline cudnn_initer() {
        checkCudaErrors(cudnnCreate(&cudnn_handle));
        LOGv << "cudnnCreate finished";
    }

    inline ~cudnn_initer() {
        checkCudaErrors(cudnnDestroy(cudnn_handle));
        LOGv << "cudnnDestroy finished";
    }

} init;
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,6 @@
|
|||
#include <cudnn.h>
|
||||
#include <helper_cuda.h>
|
||||
|
||||
// Translate a cuDNN status code into its human-readable message, so that the
// generic checkCudaErrors machinery can report cuDNN failures by name.
const char *_cudaGetErrorEnum(cudnnStatus_t error) {
    const char* message = cudnnGetErrorString(error);
    return message;
}
|
|
@ -0,0 +1,21 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <cuda_runtime.h>
|
||||
#include <curand.h>
|
||||
|
||||
#include "helper_cuda.h"
|
||||
#include "fp16_emu.h"
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
extern curandGenerator_t gen;
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,45 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <random>
|
||||
|
||||
#include "var.h"
|
||||
#include "init.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include <curand.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "curand_random_op.h"
|
||||
#include "curand_warper.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
// Op that fills a new tensor of the given shape/dtype with uniform random
// numbers generated by cuRAND.
CurandRandomOp::CurandRandomOp(NanoVector shape, NanoString dtype) {
    flags.set(NodeFlags::_cuda, 1);  // cuRAND backend: CUDA execution
    output = create_output(shape, dtype);
}
|
||||
|
||||
// Expose the output dtype to the JIT compiler as macro T.
void CurandRandomOp::jit_prepare() {
    add_jit_define("T", output->dtype());
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
// CPU specialization is a stub: this op is CUDA-only (see the _cuda flag
// set in the constructor); cuRAND has no CPU path here.
void CurandRandomOp::jit_run() {
}
|
||||
#else // JIT_cuda
|
||||
// Fill the output buffer with uniform random values using the global
// cuRAND generator `gen`.
void CurandRandomOp::jit_run() {
    auto* __restrict__ x = output->ptr<T>();
    index_t num = output->num;
    // Dispatch on element size: 4 bytes -> single-precision path, otherwise
    // double. NOTE(review): this assumes T is float32 or float64; any other
    // dtype would be misinterpreted — confirm callers restrict the dtype.
    if (sizeof(T) == 4) {
        checkCudaErrors( curandGenerateUniform(gen, (float*)x, num) );
    } else {
        checkCudaErrors( curandGenerateUniformDouble(gen, (float64*)x, num) );
    }
}
|
||||
#endif // JIT_cpu
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,22 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Graph op that produces a tensor of uniform random numbers via cuRAND.
// CUDA-only; the actual generation happens in the JIT-compiled jit_run.
struct CurandRandomOp : Op {
    Var* output;  // the generated tensor (shape/dtype fixed at construction)
    CurandRandomOp(NanoVector shape, NanoString dtype=ns_float);

    const char* name() const override { return "curand_random"; }
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,33 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "curand_warper.h"
|
||||
#include "init.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Process-wide cuRAND generator shared by all curand ops.
curandGenerator_t gen;

// RAII-style global: creates `gen` at module load, registers a callback so
// jittor's set_seed reseeds it, and destroys the generator at unload (via
// the `init_` instance below).
struct curand_initer {

    inline curand_initer() {
        checkCudaErrors( curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT) );
        add_set_seed_callback([](int seed) {
            checkCudaErrors( curandSetPseudoRandomGeneratorSeed(gen, seed) );
        });
        LOGv << "curandCreate finished";
    }

    inline ~curand_initer() {
        checkCudaErrors( curandDestroyGenerator(gen) );
        LOGv << "curandDestroy finished";
    }

} init_;
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,60 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
// These are CUDA Helper functions for initialization and error checking
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cublas.h>
|
||||
#include <helper_cuda.h>
|
||||
#include <curand.h>
|
||||
|
||||
// cuRAND API errors
|
||||
// Translate a curandStatus_t into its enumerator name, for use by the
// checkCudaErrors reporting machinery. Unlike cuDNN, cuRAND provides no
// error-string API, so the mapping is spelled out case by case.
const char *_cudaGetErrorEnum(curandStatus_t error) {
    switch (error) {
        case CURAND_STATUS_SUCCESS:
            return "CURAND_STATUS_SUCCESS";

        case CURAND_STATUS_VERSION_MISMATCH:
            return "CURAND_STATUS_VERSION_MISMATCH";

        case CURAND_STATUS_NOT_INITIALIZED:
            return "CURAND_STATUS_NOT_INITIALIZED";

        case CURAND_STATUS_ALLOCATION_FAILED:
            return "CURAND_STATUS_ALLOCATION_FAILED";

        case CURAND_STATUS_TYPE_ERROR:
            return "CURAND_STATUS_TYPE_ERROR";

        case CURAND_STATUS_OUT_OF_RANGE:
            return "CURAND_STATUS_OUT_OF_RANGE";

        case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
            return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";

        case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
            return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";

        case CURAND_STATUS_LAUNCH_FAILURE:
            return "CURAND_STATUS_LAUNCH_FAILURE";

        case CURAND_STATUS_PREEXISTING_FAILURE:
            return "CURAND_STATUS_PREEXISTING_FAILURE";

        case CURAND_STATUS_INITIALIZATION_FAILED:
            return "CURAND_STATUS_INITIALIZATION_FAILED";

        case CURAND_STATUS_ARCH_MISMATCH:
            return "CURAND_STATUS_ARCH_MISMATCH";

        case CURAND_STATUS_INTERNAL_ERROR:
            return "CURAND_STATUS_INTERNAL_ERROR";
    }

    // Fallback for values outside the enum (e.g. a newer cuRAND release).
    return "<unknown>";
}
|
|
@ -0,0 +1,42 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2019 Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cutt_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
#ifdef JIT
|
||||
#include "cutt.h"
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
// Smoke-test op for the cuTT (CUDA tensor transpose) library. Takes a
// space-separated command string and produces a single float32 output.
CuttTestOp::CuttTestOp(string cmd) : cmd(cmd) {
    flags.set(NodeFlags::_cpu, 0);   // GPU-only op
    flags.set(NodeFlags::_cuda, 1);
    output = create_output(1, ns_float32);
}
|
||||
|
||||
// Output is always float32, so the JIT macro T is fixed.
void CuttTestOp::jit_prepare() {
    add_jit_define("T", ns_float32);
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
// Writes a fixed sentinel value to the output.
void CuttTestOp::jit_run() {
    // Split cmd into argv-style char* pointers.
    // NOTE(review): `args`/`v` are built but never consumed below — this
    // looks like leftover scaffolding for a cuTT test-harness call; confirm
    // before removing.
    auto args = split(cmd, " ");
    if (!cmd.size()) args.clear();
    vector<char*> v(args.size());
    for (uint i=0; i<args.size(); i++)
        v[i] = &args[i][0];
    output->ptr<T>()[0] = 123;  // sentinel result

}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,24 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2019
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// GPU-only smoke-test op for the cuTT library; see cutt_test_op.cc.
struct CuttTestOp : Op {
    Var* output;   // single-element float32 result
    string cmd;    // space-separated command string passed at construction

    CuttTestOp(string cmd);

    const char* name() const override { return "cutt_test"; }
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,108 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2019 Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cutt_transpose_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include <iostream>
|
||||
|
||||
#ifdef JIT
|
||||
#include "cutt.h"
|
||||
#endif
|
||||
#include "cutt_warper.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
// Constructor of the cutt_transpose op, looked up once for reuse in grad().
static auto make_transpose = get_op_info("cutt_transpose")
    .get_constructor<VarPtr, Var*, NanoVector>();

// GPU transpose via the cuTT library. `axes` is the output-axis permutation
// (numpy-style); empty means "reverse all axes" (resolved in infer_shape).
CuttTransposeOp::CuttTransposeOp(Var* x, NanoVector axes) : x(x), axes(axes) {
    flags.set(NodeFlags::_cpu, 0);   // cuTT is GPU-only
    flags.set(NodeFlags::_cuda, 1);
    // If axes is explicitly the identity permutation, the transpose is a
    // no-op: forward the input instead of allocating an output.
    int i=0;
    for (; i<axes.size(); i++)
        if (i!=axes[i]) break;
    if (i==axes.size() && axes.size()) {
        forward(x);
        return;
    }
    // Shape is unknown until infer_shape runs; dtype matches the input.
    y = create_output(nullptr, x->dtype());
}
|
||||
|
||||
void CuttTransposeOp::infer_shape() {
|
||||
auto xdim = x->shape.size();
|
||||
CHECK(xdim);
|
||||
if (!axes.size()) {
|
||||
for (int i=0; i<(int)xdim; i++)
|
||||
axes.push_back(xdim-1-i);
|
||||
} else {
|
||||
CHECKop(axes.size(),==,xdim);
|
||||
int64_t mask=0;
|
||||
for (auto i : axes) mask |= 1<<i;
|
||||
CHECK(mask==((1ll<<xdim)-1)) << "Invalid axes" << axes;
|
||||
}
|
||||
NanoVector shape;
|
||||
for (uint i=0; i<xdim; i++)
|
||||
shape.push_back(x->shape[axes[i]]);
|
||||
y->set_shape(shape);
|
||||
}
|
||||
|
||||
// Gradient of a transpose is the transpose by the inverse permutation:
// if output axis i came from input axis axes[i], then input axis axes[i]
// is recovered from output axis i.
VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    NanoVector reverse;
    reverse.reserve(axes.size(), axes.size());
    for (uint i=0; i<axes.size(); i++)
        reverse.set_data(axes[i], i);
    return make_transpose(dout, reverse);
}
|
||||
|
||||
// Expose dtype, rank and each permutation entry to the JIT as macros
// Tx, DIM and AXES0..AXES{DIM-1}.
void CuttTransposeOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("DIM", JK::hex1(axes.size()));
    for (uint i=0; i<axes.size(); i++)
        add_jit_define("AXES", JK::hex1(axes[i]), S(i));
}
// Cache of cuTT plans, keyed by (rank, shape, permutation, element size);
// building a plan is expensive, executing a cached one is cheap.
unordered_map<string, unsigned int> cutt_plan_cache;
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
extern unordered_map<string, unsigned int> cutt_plan_cache;
|
||||
|
||||
void CuttTransposeOp::jit_run() {
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
vector<int> permutation, permutation2;
|
||||
vector<int> y_shape;
|
||||
vector<int> x_shape;
|
||||
@for(i, 0, DIM, permutation.push_back(DIM-1-AXES@i);)
|
||||
@for(i, 0, DIM, permutation2.push_back(permutation[DIM-1-@i@@]);)
|
||||
std::vector<int> reverse;
|
||||
reverse.reserve(permutation2.size());
|
||||
for (uint i=0; i<permutation2.size(); i++)
|
||||
reverse[permutation2[i]] = i;
|
||||
|
||||
@for(i, 0, DIM, x_shape.push_back(x->shape[DIM-1-@i@@]);)
|
||||
|
||||
jk.clear();
|
||||
jk << @DIM << ",";
|
||||
for (uint i=0; i<@DIM; i++) jk << x_shape[i] << ",";
|
||||
for (uint i=0; i<@DIM; i++) jk << reverse[i] << ",";
|
||||
jk << sizeof(Tx) << ".";
|
||||
auto iter = cutt_plan_cache.find(jk.to_string());
|
||||
|
||||
if (iter!=cutt_plan_cache.end()){
|
||||
cuttExecute(iter->second, xp, yp);
|
||||
} else {
|
||||
cuttHandle plan;
|
||||
cuttPlan(&plan, @DIM, x_shape.data(), reverse.data(), sizeof(Tx), 0);
|
||||
cutt_plan_cache[jk.to_string()] = plan;
|
||||
cuttExecute(plan, xp, yp);
|
||||
}
|
||||
}
|
||||
#endif // JIT_cuda
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,25 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2019
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// GPU tensor transpose backed by the cuTT library. `axes` is the output-axis
// permutation; empty means reverse all axes (numpy transpose default).
struct CuttTransposeOp : Op {
    Var* x, * y;        // input and (permuted) output tensors
    NanoVector axes;    // permutation; resolved/validated in infer_shape
    CuttTransposeOp(Var* x, NanoVector axes=NanoVector());

    const char* name() const override { return "cutt_transpose"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,36 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2019
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "cutt_warper.h"
|
||||
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Allocation hook handed to cuTT: routes cuTT's device-memory requests
// through jittor's executor allocator. `allocation` is the allocator's
// opaque handle, needed again at free time.
void jt_alloc(void** p, size_t len, size_t& allocation) {
    *p = exe.allocator->alloc(len, allocation);
}
|
||||
|
||||
// Matching free hook for jt_alloc: returns the buffer to jittor's
// executor allocator using the handle produced at allocation time.
void jt_free(void* p, size_t len, size_t& allocation) {
    exe.allocator->free(p, len, allocation);
}
|
||||
|
||||
// RAII-style global: installs jittor's allocator hooks into cuTT at module
// load (via the `cutt_init` instance below) so cuTT's device allocations go
// through jittor's allocator. No corresponding teardown is required beyond
// logging.
struct cutt_initer {

    inline cutt_initer() {
        custom_cuda_malloc = jt_alloc;
        custom_cuda_free = jt_free;
        LOGv << "cuttCreate finished";
    }

    inline ~cutt_initer() {
        LOGv << "cuttDestroy finished";
    }

} cutt_init;
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,15 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2019
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "executor.h"
|
||||
#include "CudaUtils.h"
|
||||
|
||||
void jt_alloc(void** p, size_t len, size_t& allocation);
|
||||
|
||||
void jt_free(void* p, size_t len, size_t& allocation);
|
|
@ -0,0 +1,21 @@
|
|||
/**
|
||||
* Copyright 2014 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
*
|
||||
*/
|
||||
|
||||
#if !defined(_FP16_DEV_H_)
|
||||
#define _FP16_DEV_H_
|
||||
|
||||
#include "fp16_emu.h"
|
||||
|
||||
template <class value_type>
|
||||
void gpu_float2half_rn(int size, const value_type *buffIn, half1 *buffOut);
|
||||
|
||||
#endif // _FP16_DEV_H_
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
/*
|
||||
* Copyright 1993-2014 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NOTICE TO LICENSEE:
|
||||
*
|
||||
* This source code and/or documentation ("Licensed Deliverables") are
|
||||
* subject to NVIDIA intellectual property rights under U.S. and
|
||||
* international Copyright laws.
|
||||
*
|
||||
* These Licensed Deliverables contained herein is PROPRIETARY and
|
||||
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
||||
* conditions of a form of NVIDIA software license agreement by and
|
||||
* between NVIDIA and Licensee ("License Agreement") or electronically
|
||||
* accepted by Licensee. Notwithstanding any terms or conditions to
|
||||
* the contrary in the License Agreement, reproduction or disclosure
|
||||
* of the Licensed Deliverables to any third party without the express
|
||||
* written consent of NVIDIA is prohibited.
|
||||
*
|
||||
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
||||
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
||||
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
||||
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
||||
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
||||
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
||||
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
||||
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
||||
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
||||
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
||||
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
||||
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
||||
* OF THESE LICENSED DELIVERABLES.
|
||||
*
|
||||
* U.S. Government End Users. These Licensed Deliverables are a
|
||||
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
||||
* 1995), consisting of "commercial computer software" and "commercial
|
||||
* computer software documentation" as such terms are used in 48
|
||||
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
||||
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
||||
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
||||
* U.S. Government End Users acquire the Licensed Deliverables with
|
||||
* only those rights set forth herein.
|
||||
*
|
||||
* Any use of the Licensed Deliverables in individual and commercial
|
||||
* software must include, in the user documentation and internal
|
||||
* comments to the code, the above Disclaimer and U.S. Government End
|
||||
* Users Notice.
|
||||
*/
|
||||
|
||||
// Conversion from/to 16-bit floating point (half-precision).
|
||||
|
||||
#if !defined(_FP16_EMU_H_)
|
||||
#define _FP16_EMU_H_
|
||||
|
||||
#include <driver_types.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
// Necessary to ensure visibility of CUDART_VERSION macro
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
// Definition of '__half_raw' was not provided before CUDA 9.0.
|
||||
// '__half_raw' is our type where the unsigned 16-bit integer
|
||||
// data member 'x' can be accessed in both CUDA 9.0 and 8.0.
|
||||
#if CUDART_VERSION < 9000
|
||||
typedef __half __half_raw;
|
||||
#endif
|
||||
|
||||
// Internally, in CUDNN we use half1 struct as the FP16 type.
|
||||
typedef __half half1;
|
||||
|
||||
#define HLF_EPSILON 4.887581E-04
|
||||
#define HLF_MIN 6.103516E-05
|
||||
#define HLF_MAX 6.550400E+04
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#endif
|
||||
|
||||
half1 cpu_float2half_rn(float f);
|
||||
|
||||
float cpu_half2float(half1 h);
|
||||
|
||||
// Absolute value of an FP16 number: clears the sign bit (bit 15) of the raw
// 16-bit pattern. Pure bit manipulation — also "works" on NaN/Inf payloads.
static __inline__ __device__ __host__ half1 habs(half1 h)
{
    // View the half as its raw 16-bit representation (aliasing is deliberate;
    // the surrounding pragma suppresses -Wstrict-aliasing).
    __half_raw hr = reinterpret_cast<__half_raw&>(h);
    hr.x &= 0x7fffU;  // mask off the sign bit
    return reinterpret_cast<half1&>(hr);
}
|
||||
|
||||
// Negate an FP16 number: flips the sign bit (bit 15) of the raw pattern.
static __inline__ __device__ __host__ half1 hneg(half1 h)
{
    __half_raw hr = reinterpret_cast<__half_raw&>(h);
    hr.x ^= 0x8000U;  // toggle the sign bit
    return reinterpret_cast<half1&>(hr);
}
|
||||
|
||||
// Return nonzero iff the FP16 value is NaN.
static __inline__ __device__ __host__ int ishnan(half1 h)
{
    // When input is NaN, exponent is all ones and mantissa is non-zero.
    __half_raw hr = reinterpret_cast<__half_raw&>(h);
    return (hr.x & 0x7c00U) == 0x7c00U && (hr.x & 0x03ffU) != 0;
}
|
||||
|
||||
// Return nonzero iff the FP16 value is +Inf or -Inf.
static __inline__ __device__ __host__ int ishinf(half1 h)
{
    // When input is +/- inf, exponent is all ones and mantissa is zero.
    __half_raw hr = reinterpret_cast<__half_raw&>(h);
    return (hr.x & 0x7c00U) == 0x7c00U && (hr.x & 0x03ffU) == 0;
}
|
||||
|
||||
// Bitwise equality of two FP16 values, except that NaN never compares equal
// (IEEE-like). Note: this is a raw bit compare, so +0 and -0 (0x0000 vs
// 0x8000) compare UNEQUAL here, unlike IEEE ==.
static __inline__ __device__ __host__ int ishequ(half1 x, half1 y)
{
    __half_raw xr = reinterpret_cast<__half_raw&>(x);
    __half_raw yr = reinterpret_cast<__half_raw&>(y);
    return ishnan(x) == 0 && ishnan(y) == 0 && xr.x == yr.x;
}
|
||||
|
||||
// Returns 0.0000 in FP16 binary form
|
||||
// Returns 0.0000 in FP16 binary form (all bits zero).
static __inline__ __device__ __host__ half1 hzero()
{
    __half_raw hr;
    hr.x = 0x0000U;
    return reinterpret_cast<half1&>(hr);
}
|
||||
|
||||
// Returns 1.0000 in FP16 binary form
|
||||
// Returns 1.0000 in FP16 binary form (biased exponent 15, zero mantissa).
static __inline__ __device__ __host__ half1 hone()
{
    __half_raw hr;
    hr.x = 0x3c00U;
    return reinterpret_cast<half1&>(hr);
}
|
||||
|
||||
// Returns quiet NaN, the most significant fraction bit #9 is set
|
||||
// Returns a quiet NaN: exponent all ones, most significant fraction bit set.
static __inline__ __device__ __host__ half1 hnan()
{
    __half_raw hr;
    hr.x = 0x7e00U;
    return reinterpret_cast<half1&>(hr);
}
|
||||
|
||||
// Largest positive FP16 value, corresponds to 6.5504e+04
|
||||
// Largest positive finite FP16 value, corresponds to 6.5504e+04.
static __inline__ __device__ __host__ half1 hmax()
{
    // Exponent all ones except LSB (0x1e), mantissa is all ones (0x3ff)
    __half_raw hr;
    hr.x = 0x7bffU;
    return reinterpret_cast<half1&>(hr);
}
|
||||
|
||||
// Smallest positive (normalized) FP16 value, corresponds to 6.1035e-05
|
||||
// Smallest positive normalized FP16 value, corresponds to 6.1035e-05.
static __inline__ __device__ __host__ half1 hmin()
{
    // Exponent is 0x01 (5 bits), mantissa is all zeros (10 bits)
    __half_raw hr;
    hr.x = 0x0400U;
    return reinterpret_cast<half1&>(hr);
}
|
||||
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#endif // _FP16_EMU_H_
|
||||
|
|
@ -0,0 +1,418 @@
|
|||
/**
|
||||
* Copyright 1993-2017 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
*
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// These are CUDA Helper functions for initialization and error checking
|
||||
|
||||
#ifndef COMMON_HELPER_CUDA_H_
|
||||
#define COMMON_HELPER_CUDA_H_
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdexcept>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <helper_string.h>
|
||||
|
||||
#ifndef EXIT_WAIVED
|
||||
#define EXIT_WAIVED 2
|
||||
#endif
|
||||
|
||||
// Note, it is required that your SDK sample to include the proper header
|
||||
// files, please refer the CUDA examples for examples of the needed CUDA
|
||||
// headers, which may change depending on which CUDA functions are used.
|
||||
|
||||
// CUDA Runtime error messages
|
||||
#ifdef __DRIVER_TYPES_H__
|
||||
// Map a CUDA Runtime error code to its enum name string, delegating to the
// runtime's own lookup (e.g. "cudaErrorMemoryAllocation").
inline const char *_cudaGetErrorEnum(cudaError_t error) {
  return cudaGetErrorName(error);
}
|
||||
#endif
|
||||
|
||||
// CUDA Driver API errors
|
||||
#ifdef CUDA_DRIVER_API
|
||||
// Map a CUDA Driver API error code to its enum name string.
// cuGetErrorName leaves `ret` untouched on failure, hence the NULL fallback.
inline const char *_cudaGetErrorEnum(CUresult error) {
  const char *ret = NULL;
  cuGetErrorName(error, &ret);
  return ret ? ret : "<unknown>";
}
|
||||
#endif
|
||||
|
||||
#ifdef CUBLAS_API_H_
|
||||
// cuBLAS API errors
|
||||
const char *_cudaGetErrorEnum(cublasStatus_t error);
|
||||
#endif
|
||||
|
||||
#ifdef CUDNN_H_
|
||||
// cudnn API errors
|
||||
const char *_cudaGetErrorEnum(cudnnStatus_t error);
|
||||
#endif
|
||||
|
||||
#ifdef _CUFFT_H_
|
||||
// cuFFT API errors
|
||||
const char *_cudaGetErrorEnum(cufftResult error);
|
||||
#endif
|
||||
|
||||
#ifdef CUSPARSEAPI
|
||||
// cuSPARSE API errors
|
||||
const char *_cudaGetErrorEnum(cusparseStatus_t error);
|
||||
#endif
|
||||
|
||||
#ifdef CUSOLVER_COMMON_H_
|
||||
// cuSOLVER API errors
|
||||
const char *_cudaGetErrorEnum(cusolverStatus_t error);
|
||||
#endif
|
||||
|
||||
#ifdef CURAND_H_
|
||||
// cuRAND API errors
|
||||
const char *_cudaGetErrorEnum(curandStatus_t error);
|
||||
#endif
|
||||
|
||||
#ifdef NV_NPPIDEFS_H
|
||||
// NPP API errors
|
||||
const char *_cudaGetErrorEnum(NppStatus error);
|
||||
#endif
|
||||
|
||||
#ifdef __DRIVER_TYPES_H__
|
||||
#ifndef DEVICE_RESET
|
||||
#define DEVICE_RESET cudaDeviceReset();
|
||||
#endif
|
||||
#else
|
||||
#ifndef DEVICE_RESET
|
||||
#define DEVICE_RESET
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Fail-fast checker for CUDA API return codes (used via checkCudaErrors).
// On any nonzero result it prints file/line, numeric code, decoded enum name
// and the failing expression text, resets the device, and throws.
template <typename T>
void check(T result, char const *const func, const char *const file,
           int const line) {
  if (result) {
    fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line,
            static_cast<unsigned int>(result), _cudaGetErrorEnum(result), func);
    DEVICE_RESET  // macro expands to cudaDeviceReset(); (or nothing) — no ';' needed
    throw std::runtime_error("CUDA error");
  }
}
|
||||
|
||||
#ifdef __DRIVER_TYPES_H__
|
||||
// This will output the proper CUDA error strings in the event
|
||||
// that a CUDA host call returns an error
|
||||
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
|
||||
|
||||
// This will output the proper error string when calling cudaGetLastError
|
||||
#define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__)
|
||||
|
||||
// Check cudaGetLastError(); on error, print caller-supplied context plus the
// decoded CUDA error string, reset the device, and terminate the process.
// (Use __printLastCudaError for a non-fatal variant.)
inline void __getLastCudaError(const char *errorMessage, const char *file,
                               const int line) {
  cudaError_t err = cudaGetLastError();

  if (cudaSuccess != err) {
    fprintf(stderr,
            "%s(%i) : getLastCudaError() CUDA error :"
            " %s : (%d) %s.\n",
            file, line, errorMessage, static_cast<int>(err),
            cudaGetErrorString(err));
    DEVICE_RESET
    exit(EXIT_FAILURE);
  }
}
|
||||
|
||||
// This will only print the proper error string when calling cudaGetLastError
|
||||
// but not exit program incase error detected.
|
||||
#define printLastCudaError(msg) __printLastCudaError(msg, __FILE__, __LINE__)
|
||||
|
||||
inline void __printLastCudaError(const char *errorMessage, const char *file,
|
||||
const int line) {
|
||||
cudaError_t err = cudaGetLastError();
|
||||
|
||||
if (cudaSuccess != err) {
|
||||
fprintf(stderr,
|
||||
"%s(%i) : getLastCudaError() CUDA error :"
|
||||
" %s : (%d) %s.\n",
|
||||
file, line, errorMessage, static_cast<int>(err),
|
||||
cudaGetErrorString(err));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef MAX
// Fully parenthesized so arguments containing operators of lower precedence
// than '>' (e.g. MAX(1 & 3, 2)) and the surrounding expression expand
// correctly. NOTE: arguments are still evaluated twice — avoid side effects.
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
#endif
|
||||
|
||||
// Float To Int conversion
|
||||
// Float-to-int conversion, rounding halves away from zero
// (1.5 -> 2, -1.5 -> -2).
inline int ftoi(float value) {
  const float bias = (value >= 0) ? 0.5f : -0.5f;
  return static_cast<int>(value + bias);
}
|
||||
|
||||
// Beginning of GPU Architecture definitions
|
||||
// Map a GPU SM (compute capability) version to the number of CUDA cores per
// multiprocessor. Unknown versions fall back to the newest known entry, with
// a warning printed to stdout.
inline int _ConvertSMVer2Cores(int major, int minor) {
  // Table key is 0xMm (hex): M = SM major version, m = SM minor version.
  typedef struct {
    int SM;
    int Cores;
  } sSMtoCores;

  sSMtoCores nGpuArchCoresPerSM[] = {
      {0x30, 192}, {0x32, 192}, {0x35, 192}, {0x37, 192},
      {0x50, 128}, {0x52, 128}, {0x53, 128},
      {0x60, 64},  {0x61, 128}, {0x62, 128},
      {0x70, 64},  {0x72, 64},  {0x75, 64},
      {-1, -1}};  // sentinel terminator

  const int sm = (major << 4) + minor;
  int index = 0;

  for (; nGpuArchCoresPerSM[index].SM != -1; ++index) {
    if (nGpuArchCoresPerSM[index].SM == sm) {
      return nGpuArchCoresPerSM[index].Cores;
    }
  }

  // Unknown SM version: default to the last (most recent) known entry so the
  // caller can still run.
  printf(
      "MapSMtoCores for SM %d.%d is undefined."
      " Default to use %d Cores/SM\n",
      major, minor, nGpuArchCoresPerSM[index - 1].Cores);
  return nGpuArchCoresPerSM[index - 1].Cores;
}
|
||||
// end of GPU Architecture definitions
|
||||
|
||||
#ifdef __CUDA_RUNTIME_H__
|
||||
// General GPU Device CUDA Initialization
|
||||
inline int gpuDeviceInit(int devID) {
|
||||
int device_count;
|
||||
checkCudaErrors(cudaGetDeviceCount(&device_count));
|
||||
|
||||
if (device_count == 0) {
|
||||
fprintf(stderr,
|
||||
"gpuDeviceInit() CUDA error: "
|
||||
"no devices supporting CUDA.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (devID < 0) {
|
||||
devID = 0;
|
||||
}
|
||||
|
||||
if (devID > device_count - 1) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, ">> %d CUDA capable GPU device(s) detected. <<\n",
|
||||
device_count);
|
||||
fprintf(stderr,
|
||||
">> gpuDeviceInit (-device=%d) is not a valid"
|
||||
" GPU device. <<\n",
|
||||
devID);
|
||||
fprintf(stderr, "\n");
|
||||
return -devID;
|
||||
}
|
||||
|
||||
cudaDeviceProp deviceProp;
|
||||
checkCudaErrors(cudaGetDeviceProperties(&deviceProp, devID));
|
||||
|
||||
if (deviceProp.computeMode == cudaComputeModeProhibited) {
|
||||
fprintf(stderr,
|
||||
"Error: device is running in <Compute Mode "
|
||||
"Prohibited>, no threads can use cudaSetDevice().\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (deviceProp.major < 1) {
|
||||
fprintf(stderr, "gpuDeviceInit(): GPU device does not support CUDA.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
checkCudaErrors(cudaSetDevice(devID));
|
||||
printf("gpuDeviceInit() CUDA Device [%d]: \"%s\n", devID, deviceProp.name);
|
||||
|
||||
return devID;
|
||||
}
|
||||
|
||||
// This function returns the best GPU (with maximum GFLOPS)
|
||||
// This function returns the best GPU (with maximum GFLOPS).
// "Performance" is estimated as multiProcessorCount * coresPerSM * clockRate;
// devices in Compute Mode Prohibited are skipped. Exits the process when no
// device exists or all are prohibited.
inline int gpuGetMaxGflopsDeviceId() {
  int current_device = 0, sm_per_multiproc = 0;
  int max_perf_device = 0;
  int device_count = 0;
  int devices_prohibited = 0;

  uint64_t max_compute_perf = 0;
  cudaDeviceProp deviceProp;
  checkCudaErrors(cudaGetDeviceCount(&device_count));

  if (device_count == 0) {
    fprintf(stderr,
            "gpuGetMaxGflopsDeviceId() CUDA error:"
            " no devices supporting CUDA.\n");
    exit(EXIT_FAILURE);
  }

  // Find the best CUDA capable GPU device
  current_device = 0;

  while (current_device < device_count) {
    cudaGetDeviceProperties(&deviceProp, current_device);

    // If this GPU is not running on Compute Mode prohibited,
    // then we can add it to the list
    if (deviceProp.computeMode != cudaComputeModeProhibited) {
      if (deviceProp.major == 9999 && deviceProp.minor == 9999) {
        // 9999.9999 marks a device without a real compute capability
        // (e.g. emulation); count a single core per SM.
        sm_per_multiproc = 1;
      } else {
        sm_per_multiproc =
            _ConvertSMVer2Cores(deviceProp.major, deviceProp.minor);
      }

      // 64-bit math: SM count * cores/SM * kHz clock can overflow 32 bits.
      uint64_t compute_perf = (uint64_t)deviceProp.multiProcessorCount *
                              sm_per_multiproc * deviceProp.clockRate;

      if (compute_perf > max_compute_perf) {
        max_compute_perf = compute_perf;
        max_perf_device = current_device;
      }
    } else {
      devices_prohibited++;
    }

    ++current_device;
  }

  if (devices_prohibited == device_count) {
    fprintf(stderr,
            "gpuGetMaxGflopsDeviceId() CUDA error:"
            " all devices have compute mode prohibited.\n");
    exit(EXIT_FAILURE);
  }

  return max_perf_device;
}
|
||||
|
||||
// Initialization code to find the best CUDA Device
|
||||
// Initialization code to find the best CUDA Device.
// Honors a "-device=N" command-line flag when present (exiting on an invalid
// value); otherwise picks the device with the highest estimated GFLOPS and
// makes it current. Returns the selected device id.
inline int findCudaDevice(int argc, const char **argv) {
  cudaDeviceProp deviceProp;
  int devID = 0;

  // If the command-line has a device number specified, use it
  if (checkCmdLineFlag(argc, argv, "device")) {
    devID = getCmdLineArgumentInt(argc, argv, "device=");

    if (devID < 0) {
      printf("Invalid command line parameter\n ");
      exit(EXIT_FAILURE);
    } else {
      devID = gpuDeviceInit(devID);

      // gpuDeviceInit returns a negative id when the device is unusable.
      if (devID < 0) {
        printf("exiting...\n");
        exit(EXIT_FAILURE);
      }
    }
  } else {
    // Otherwise pick the device with highest Gflops/s
    devID = gpuGetMaxGflopsDeviceId();
    checkCudaErrors(cudaSetDevice(devID));
    checkCudaErrors(cudaGetDeviceProperties(&deviceProp, devID));
    printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", devID,
           deviceProp.name, deviceProp.major, deviceProp.minor);
  }

  return devID;
}
|
||||
|
||||
// Find the first integrated (CPU-sharing-memory) GPU that is compute capable,
// make it the current device, and return its id. Exits the process when no
// devices exist or none qualify; the trailing return -1 is unreachable in
// practice but kept for the compiler.
inline int findIntegratedGPU() {
  int current_device = 0;
  int device_count = 0;
  int devices_prohibited = 0;

  cudaDeviceProp deviceProp;
  checkCudaErrors(cudaGetDeviceCount(&device_count));

  if (device_count == 0) {
    fprintf(stderr, "CUDA error: no devices supporting CUDA.\n");
    exit(EXIT_FAILURE);
  }

  // Find the integrated GPU which is compute capable
  while (current_device < device_count) {
    cudaGetDeviceProperties(&deviceProp, current_device);

    // If GPU is integrated and is not running on Compute Mode prohibited,
    // then cuda can map to GLES resource
    if (deviceProp.integrated &&
        (deviceProp.computeMode != cudaComputeModeProhibited)) {
      checkCudaErrors(cudaSetDevice(current_device));
      checkCudaErrors(cudaGetDeviceProperties(&deviceProp, current_device));
      printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n",
             current_device, deviceProp.name, deviceProp.major,
             deviceProp.minor);

      return current_device;
    } else {
      devices_prohibited++;
    }

    current_device++;
  }

  if (devices_prohibited == device_count) {
    fprintf(stderr,
            "CUDA error:"
            " No GLES-CUDA Interop capable GPU found.\n");
    exit(EXIT_FAILURE);
  }

  return -1;
}
|
||||
|
||||
// General check for CUDA GPU SM Capabilities
|
||||
// General check for CUDA GPU SM Capabilities: returns true when the CURRENT
// device's compute capability is >= major_version.minor_version (lexicographic
// on (major, minor)); prints the outcome either way.
inline bool checkCudaCapabilities(int major_version, int minor_version) {
  cudaDeviceProp deviceProp;
  deviceProp.major = 0;
  deviceProp.minor = 0;
  int dev;

  checkCudaErrors(cudaGetDevice(&dev));
  checkCudaErrors(cudaGetDeviceProperties(&deviceProp, dev));

  if ((deviceProp.major > major_version) ||
      (deviceProp.major == major_version &&
       deviceProp.minor >= minor_version)) {
    printf("  Device %d: <%16s >, Compute SM %d.%d detected\n", dev,
           deviceProp.name, deviceProp.major, deviceProp.minor);
    return true;
  } else {
    printf(
        "  No GPU device was found that can support "
        "CUDA compute capability %d.%d.\n",
        major_version, minor_version);
    return false;
  }
}
|
||||
#endif
|
||||
|
||||
// end of CUDA Helper Functions
|
||||
|
||||
#endif // COMMON_HELPER_CUDA_H_
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
*
|
||||
*/
|
||||
|
||||
// These are helper functions for the SDK samples (string parsing,
|
||||
// timers, image helpers, etc)
|
||||
#ifndef COMMON_HELPER_FUNCTIONS_H_
|
||||
#define COMMON_HELPER_FUNCTIONS_H_
|
||||
|
||||
#ifdef WIN32
|
||||
#pragma warning(disable : 4996)
|
||||
#endif
|
||||
|
||||
// includes, project
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// includes, timer, string parsing, image helpers
|
||||
#include <helper_image.h> // helper functions for image compare, dump, data comparisons
|
||||
#include <helper_string.h> // helper functions for string parsing
|
||||
#include <helper_timer.h> // helper functions for timers
|
||||
|
||||
#ifndef EXIT_WAIVED
|
||||
#define EXIT_WAIVED 2
|
||||
#endif
|
||||
|
||||
#endif // COMMON_HELPER_FUNCTIONS_H_
|
|
@ -0,0 +1,984 @@
|
|||
/**
|
||||
* Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
*
|
||||
*/
|
||||
|
||||
// These are helper functions for the SDK samples (image,bitmap)
|
||||
#ifndef COMMON_HELPER_IMAGE_H_
|
||||
#define COMMON_HELPER_IMAGE_H_
|
||||
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifndef MIN
// Fully parenthesized so arguments containing operators of lower precedence
// than the comparison expand correctly (e.g. MIN(x & 7, y)). NOTE: arguments
// are evaluated twice — avoid side effects.
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#endif
#ifndef MAX
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
#endif
|
||||
|
||||
#ifndef EXIT_WAIVED
|
||||
#define EXIT_WAIVED 2
|
||||
#endif
|
||||
|
||||
#include <helper_string.h>
|
||||
|
||||
// namespace unnamed (internal)
|
||||
namespace helper_image_internal {
|
||||
//! size of PGM file header
|
||||
const unsigned int PGMHeaderSize = 0x40;
|
||||
|
||||
// types
|
||||
|
||||
//! Data converter from unsigned char / unsigned byte to type T
|
||||
template <class T>
|
||||
struct ConverterFromUByte;
|
||||
|
||||
//! Data converter from unsigned char / unsigned byte
|
||||
template <>
|
||||
struct ConverterFromUByte<unsigned char> {
|
||||
//! Conversion operator
|
||||
//! @return converted value
|
||||
//! @param val value to convert
|
||||
float operator()(const unsigned char &val) {
|
||||
return static_cast<unsigned char>(val);
|
||||
}
|
||||
};
|
||||
|
||||
//! Data converter from unsigned char / unsigned byte to float
|
||||
//! Data converter from unsigned char / unsigned byte to float,
//! normalizing the byte range [0, 255] into [0.0, 1.0].
template <>
struct ConverterFromUByte<float> {
  //! Conversion operator
  //! @return converted value
  //! @param  val  value to convert
  float operator()(const unsigned char &val) {
    return static_cast<float>(val) / 255.0f;
  }
};
|
||||
|
||||
//! Data converter from unsigned char / unsigned byte to type T
|
||||
template <class T>
|
||||
struct ConverterToUByte;
|
||||
|
||||
//! Data converter from unsigned char / unsigned byte to unsigned int
|
||||
//! Data converter to unsigned char / unsigned byte (identity case)
template <>
struct ConverterToUByte<unsigned char> {
  //! Conversion operator (essentially a passthru)
  //! @return converted value
  //! @param  val  value to convert
  unsigned char operator()(const unsigned char &val) { return val; }
};
|
||||
|
||||
//! Data converter from unsigned char / unsigned byte to unsigned int
|
||||
//! Data converter from float to unsigned char / unsigned byte, scaling
//! [0.0, 1.0] back to [0, 255] by truncation.
//! NOTE(review): values outside [0.0, 1.0] are not clamped — out-of-range
//! input truncates/wraps; callers appear to supply normalized data.
template <>
struct ConverterToUByte<float> {
  //! Conversion operator
  //! @return converted value
  //! @param  val  value to convert
  unsigned char operator()(const float &val) {
    return static_cast<unsigned char>(val * 255.0f);
  }
};
|
||||
} // namespace helper_image_internal
|
||||
|
||||
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
|
||||
#ifndef FOPEN
|
||||
#define FOPEN(fHandle, filename, mode) fopen_s(&fHandle, filename, mode)
|
||||
#endif
|
||||
#ifndef FOPEN_FAIL
|
||||
#define FOPEN_FAIL(result) (result != 0)
|
||||
#endif
|
||||
#ifndef SSCANF
|
||||
#define SSCANF sscanf_s
|
||||
#endif
|
||||
#else
|
||||
#ifndef FOPEN
|
||||
#define FOPEN(fHandle, filename, mode) (fHandle = fopen(filename, mode))
|
||||
#endif
|
||||
#ifndef FOPEN_FAIL
|
||||
#define FOPEN_FAIL(result) (result == NULL)
|
||||
#endif
|
||||
#ifndef SSCANF
|
||||
#define SSCANF sscanf
|
||||
#endif
|
||||
#endif
|
||||
|
||||
inline bool __loadPPM(const char *file, unsigned char **data, unsigned int *w,
|
||||
unsigned int *h, unsigned int *channels) {
|
||||
FILE *fp = NULL;
|
||||
|
||||
if (FOPEN_FAIL(FOPEN(fp, file, "rb"))) {
|
||||
std::cerr << "__LoadPPM() : Failed to open file: " << file << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// check header
|
||||
char header[helper_image_internal::PGMHeaderSize];
|
||||
|
||||
if (fgets(header, helper_image_internal::PGMHeaderSize, fp) == NULL) {
|
||||
std::cerr << "__LoadPPM() : reading PGM header returned NULL" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (strncmp(header, "P5", 2) == 0) {
|
||||
*channels = 1;
|
||||
} else if (strncmp(header, "P6", 2) == 0) {
|
||||
*channels = 3;
|
||||
} else {
|
||||
std::cerr << "__LoadPPM() : File is not a PPM or PGM image" << std::endl;
|
||||
*channels = 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
// parse header, read maxval, width and height
|
||||
unsigned int width = 0;
|
||||
unsigned int height = 0;
|
||||
unsigned int maxval = 0;
|
||||
unsigned int i = 0;
|
||||
|
||||
while (i < 3) {
|
||||
if (fgets(header, helper_image_internal::PGMHeaderSize, fp) == NULL) {
|
||||
std::cerr << "__LoadPPM() : reading PGM header returned NULL"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (header[0] == '#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (i == 0) {
|
||||
i += SSCANF(header, "%u %u %u", &width, &height, &maxval);
|
||||
} else if (i == 1) {
|
||||
i += SSCANF(header, "%u %u", &height, &maxval);
|
||||
} else if (i == 2) {
|
||||
i += SSCANF(header, "%u", &maxval);
|
||||
}
|
||||
}
|
||||
|
||||
// check if given handle for the data is initialized
|
||||
if (NULL != *data) {
|
||||
if (*w != width || *h != height) {
|
||||
std::cerr << "__LoadPPM() : Invalid image dimensions." << std::endl;
|
||||
}
|
||||
} else {
|
||||
*data = (unsigned char *)malloc(sizeof(unsigned char) * width * height *
|
||||
*channels);
|
||||
*w = width;
|
||||
*h = height;
|
||||
}
|
||||
|
||||
// read and close file
|
||||
if (fread(*data, sizeof(unsigned char), width * height * *channels, fp) ==
|
||||
0) {
|
||||
std::cerr << "__LoadPPM() read data returned error." << std::endl;
|
||||
}
|
||||
|
||||
fclose(fp);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Load a PGM image and convert its bytes to element type T (via
// ConverterFromUByte<T>). Allocates *data when NULL; returns false when the
// underlying __loadPPM fails.
template <class T>
inline bool sdkLoadPGM(const char *file, T **data, unsigned int *w,
                       unsigned int *h) {
  unsigned char *idata = NULL;
  unsigned int channels;

  if (true != __loadPPM(file, &idata, w, h, &channels)) {
    return false;
  }

  unsigned int size = *w * *h * channels;

  // initialize mem if necessary
  // the correct size is checked / set in loadPGMc()
  if (NULL == *data) {
    *data = reinterpret_cast<T *>(malloc(sizeof(T) * size));
  }

  // copy and cast data
  std::transform(idata, idata + size, *data,
                 helper_image_internal::ConverterFromUByte<T>());

  free(idata);

  return true;
}
|
||||
|
||||
template <class T>
|
||||
inline bool sdkLoadPPM4(const char *file, T **data, unsigned int *w,
|
||||
unsigned int *h) {
|
||||
unsigned char *idata = 0;
|
||||
unsigned int channels;
|
||||
|
||||
if (__loadPPM(file, &idata, w, h, &channels)) {
|
||||
// pad 4th component
|
||||
int size = *w * *h;
|
||||
// keep the original pointer
|
||||
unsigned char *idata_orig = idata;
|
||||
*data = reinterpret_cast<T *>(malloc(sizeof(T) * size * 4));
|
||||
unsigned char *ptr = *data;
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
*ptr++ = *idata++;
|
||||
*ptr++ = *idata++;
|
||||
*ptr++ = *idata++;
|
||||
*ptr++ = 0;
|
||||
}
|
||||
|
||||
free(idata_orig);
|
||||
return true;
|
||||
} else {
|
||||
free(idata);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Save raw image bytes as a binary PGM ("P5", channels == 1) or PPM ("P6",
// channels == 3) file with maxval 255. Returns false on open/write failure
// or an unsupported channel count.
inline bool __savePPM(const char *file, unsigned char *data, unsigned int w,
                      unsigned int h, unsigned int channels) {
  assert(NULL != data);
  assert(w > 0);
  assert(h > 0);

  std::fstream fh(file, std::fstream::out | std::fstream::binary);

  if (fh.bad()) {
    std::cerr << "__savePPM() : Opening file failed." << std::endl;
    return false;
  }

  if (channels == 1) {
    fh << "P5\n";
  } else if (channels == 3) {
    fh << "P6\n";
  } else {
    std::cerr << "__savePPM() : Invalid number of channels." << std::endl;
    return false;
  }

  // Header: width, height, maxval (0xff streams as decimal 255).
  fh << w << "\n" << h << "\n" << 0xff << std::endl;

  for (unsigned int i = 0; (i < (w * h * channels)) && fh.good(); ++i) {
    fh << data[i];
  }

  fh.flush();

  if (fh.bad()) {
    std::cerr << "__savePPM() : Writing data failed." << std::endl;
    return false;
  }

  fh.close();

  return true;
}
|
||||
|
||||
// Convert a T-typed image buffer to bytes (via ConverterToUByte<T>) and save
// it as a single-channel PGM file. Returns the underlying save result.
template <class T>
inline bool sdkSavePGM(const char *file, T *data, unsigned int w,
                       unsigned int h) {
  unsigned int size = w * h;
  unsigned char *idata = (unsigned char *)malloc(sizeof(unsigned char) * size);

  std::transform(data, data + size, idata,
                 helper_image_internal::ConverterToUByte<T>());

  // write file
  bool result = __savePPM(file, idata, w, h, 1);

  // cleanup
  free(idata);

  return result;
}
|
||||
|
||||
inline bool sdkSavePPM4ub(const char *file, unsigned char *data, unsigned int w,
|
||||
unsigned int h) {
|
||||
// strip 4th component
|
||||
int size = w * h;
|
||||
unsigned char *ndata =
|
||||
(unsigned char *)malloc(sizeof(unsigned char) * size * 3);
|
||||
unsigned char *ptr = ndata;
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
*ptr++ = *data++;
|
||||
*ptr++ = *data++;
|
||||
*ptr++ = *data++;
|
||||
data++;
|
||||
}
|
||||
|
||||
bool result = __savePPM(file, ndata, w, h, 3);
|
||||
free(ndata);
|
||||
return result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
//! Read file \filename and return the data
|
||||
//! @return bool if reading the file succeeded, otherwise false
|
||||
//! @param filename name of the source file
|
||||
//! @param data uninitialized pointer, returned initialized and pointing to
|
||||
//! the data read
|
||||
//! @param len number of data elements in data, -1 on error
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
template <class T>
|
||||
inline bool sdkReadFile(const char *filename, T **data, unsigned int *len,
|
||||
bool verbose) {
|
||||
// check input arguments
|
||||
assert(NULL != filename);
|
||||
assert(NULL != len);
|
||||
|
||||
// intermediate storage for the data read
|
||||
std::vector<T> data_read;
|
||||
|
||||
// open file for reading
|
||||
FILE *fh = NULL;
|
||||
|
||||
// check if filestream is valid
|
||||
if (FOPEN_FAIL(FOPEN(fh, filename, "r"))) {
|
||||
printf("Unable to open input file: %s\n", filename);
|
||||
return false;
|
||||
}
|
||||
|
||||
// read all data elements
|
||||
T token;
|
||||
|
||||
while (!feof(fh)) {
|
||||
fscanf(fh, "%f", &token);
|
||||
data_read.push_back(token);
|
||||
}
|
||||
|
||||
// the last element is read twice
|
||||
data_read.pop_back();
|
||||
fclose(fh);
|
||||
|
||||
// check if the given handle is already initialized
|
||||
if (NULL != *data) {
|
||||
if (*len != data_read.size()) {
|
||||
std::cerr << "sdkReadFile() : Initialized memory given but "
|
||||
<< "size mismatch with signal read "
|
||||
<< "(data read / data init = " << (unsigned int)data_read.size()
|
||||
<< " / " << *len << ")" << std::endl;
|
||||
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// allocate storage for the data read
|
||||
*data = reinterpret_cast<T *>(malloc(sizeof(T) * data_read.size()));
|
||||
// store signal size
|
||||
*len = static_cast<unsigned int>(data_read.size());
|
||||
}
|
||||
|
||||
// copy data
|
||||
memcpy(*data, &data_read.front(), sizeof(T) * data_read.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
//! Read file \filename and return the data
|
||||
//! @return bool if reading the file succeeded, otherwise false
|
||||
//! @param filename name of the source file
|
||||
//! @param data uninitialized pointer, returned initialized and pointing to
|
||||
//! the data read
|
||||
//! @param len number of data elements in data, -1 on error
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////
//! Read one fixed-size block of binary data from \filename into
//! data[block_num] (allocated here); *len receives the element count read.
//! @return bool if reading the file succeeded, otherwise false
//! @param filename name of the source file
//! @param data  array of block pointers; slot block_num is filled in
//! @param len  number of data elements read into the block
//! @param block_num  index of the block to read (also the file offset factor)
//! @param block_size  size of one block in bytes
//! @param verbose  print a message when opening fails
//////////////////////////////////////////////////////////////////////////////
template <class T>
inline bool sdkReadFileBlocks(const char *filename, T **data, unsigned int *len,
                              unsigned int block_num, unsigned int block_size,
                              bool verbose) {
  // check input arguments
  assert(NULL != filename);
  assert(NULL != len);

  // open file for reading
  FILE *fh = fopen(filename, "rb");

  // Bug fix: the NULL check was previously gated on `verbose`, so a failed
  // open with verbose == false fell through and dereferenced a NULL FILE*.
  if (fh == NULL) {
    if (verbose) {
      std::cerr << "sdkReadFile() : Opening file failed." << std::endl;
    }
    return false;
  }

  // allocate storage for the data read
  data[block_num] = reinterpret_cast<T *>(malloc(block_size));

  // seek to this block's offset and read all its data elements
  fseek(fh, block_num * block_size, SEEK_SET);
  *len = fread(data[block_num], sizeof(T), block_size / sizeof(T), fh);

  fclose(fh);

  return true;
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
//! Write a data file \filename
//! @return true if writing the file succeeded, otherwise false
//! @param filename name of the destination file
//! @param data data to write
//! @param len number of data elements in data
//! @param epsilon epsilon for comparison, written as a "# <eps>" header line
//! @param verbose if true, log progress/errors to stderr
//! @param append if true, append to an existing file instead of overwriting
//////////////////////////////////////////////////////////////////////////////
template <class T, class S>
inline bool sdkWriteFile(const char *filename, const T *data, unsigned int len,
                         const S epsilon, bool verbose, bool append = false) {
  assert(NULL != filename);
  assert(NULL != data);

  // BUGFIX: the original always opened with out|ate and ignored 'append'
  // (a dead, commented-out branch hinted at the intent). std::ios::ate does
  // not imply std::ios::app — plain 'out' still truncates — so append mode
  // never appended. Honor the flag: app preserves + appends, out truncates.
  std::fstream fh(filename,
                  append ? (std::fstream::out | std::fstream::app)
                         : std::fstream::out);

  if (verbose) {
    std::cerr << "sdkWriteFile() : Open file " << filename
              << " for write/append." << std::endl;
  }

  // check if filestream is valid
  if (!fh.good()) {
    if (verbose) {
      std::cerr << "sdkWriteFile() : Opening file failed." << std::endl;
    }

    return false;
  }

  // first write epsilon
  fh << "# " << epsilon << "\n";

  // write data
  for (unsigned int i = 0; (i < len) && (fh.good()); ++i) {
    fh << data[i] << ' ';
  }

  // Check if writing succeeded
  if (!fh.good()) {
    if (verbose) {
      std::cerr << "sdkWriteFile() : Writing file failed." << std::endl;
    }

    return false;
  }

  // file ends with nl
  fh << std::endl;

  return true;
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
//! Compare two arrays of arbitrary type
//! @return true if \a reference and \a data are identical, otherwise false
//! @param reference timer_interface to the reference data / gold image
//! @param data handle to the computed data
//! @param len number of elements in reference and data
//! @param epsilon epsilon to use for the comparison
//! @param threshold fraction of mismatching elements tolerated (0 = exact)
//////////////////////////////////////////////////////////////////////////////
template <class T, class S>
inline bool compareData(const T *reference, const T *data,
                        const unsigned int len, const S epsilon,
                        const float threshold) {
  assert(epsilon >= 0);

  unsigned int error_count = 0;

  // Count every element whose difference to the reference falls outside
  // the closed interval [-epsilon, +epsilon]. The negated conjunction also
  // flags NaN differences as mismatches.
  for (unsigned int idx = 0; idx < len; ++idx) {
    const float delta =
        static_cast<float>(reference[idx]) - static_cast<float>(data[idx]);
    const bool within = (delta <= epsilon) && (delta >= -epsilon);

    if (!within) {
      ++error_count;
    }
  }

  if (threshold == 0.0f) {
    // strict mode: every element must match
    return error_count == 0;
  }

  // tolerant mode: report, then accept up to len * threshold mismatches
  if (error_count) {
    printf("%4.2f(%%) of bytes mismatched (count=%d)\n",
           static_cast<float>(error_count) * 100 / static_cast<float>(len),
           error_count);
  }

  return len * threshold > error_count;
}
|
||||
|
||||
#ifndef __MIN_EPSILON_ERROR
#define __MIN_EPSILON_ERROR 1e-3f
#endif

//////////////////////////////////////////////////////////////////////////////
//! Compare two arrays of arbitrary type
//! @return true if \a reference and \a data are identical, otherwise false
//! @param reference handle to the reference data / gold image
//! @param data handle to the computed data
//! @param len number of elements in reference and data
//! @param epsilon epsilon to use for the comparison
//! @param threshold threshold % of (# of bytes) for pass/fail
//////////////////////////////////////////////////////////////////////////////
template <class T, class S>
inline bool compareDataAsFloatThreshold(const T *reference, const T *data,
                                        const unsigned int len, const S epsilon,
                                        const float threshold) {
  assert(epsilon >= 0);

  // If epsilon is 0 (or tiny), fall back to a minimum comparison threshold.
  const float eps_f = (float)epsilon;
  const float max_error =
      (eps_f > __MIN_EPSILON_ERROR) ? eps_f : __MIN_EPSILON_ERROR;
  int error_count = 0;

  // Count elements whose absolute difference is not strictly below
  // max_error (the negation also flags NaN differences as errors).
  for (unsigned int idx = 0; idx < len; ++idx) {
    const float diff = fabs(static_cast<float>(reference[idx]) -
                            static_cast<float>(data[idx]));

    if (!(diff < max_error)) {
      ++error_count;
    }
  }

  if (threshold == 0.0f) {
    // strict mode: any error fails the comparison
    if (error_count) {
      printf("total # of errors = %d\n", error_count);
    }

    return error_count == 0;
  }

  // tolerant mode: report, then accept up to len * threshold mismatches
  if (error_count) {
    printf("%4.2f(%%) of bytes mismatched (count=%d)\n",
           static_cast<float>(error_count) * 100 / static_cast<float>(len),
           error_count);
  }

  return len * threshold > error_count;
}
|
||||
|
||||
//! Dump a raw binary buffer to a file.
//! @param data      pointer to the bytes to write
//! @param bytes     number of bytes to write
//! @param filename  destination file name
inline void sdkDumpBin(void *data, unsigned int bytes, const char *filename) {
  printf("sdkDumpBin: <%s>\n", filename);
  FILE *fp;
  FOPEN(fp, filename, "wb");

  // BUGFIX: the original wrote through fp unconditionally; a failed open
  // (NULL handle) crashed inside fwrite. Skip the write when the open failed.
  if (fp == NULL) {
    printf("sdkDumpBin: unable to open <%s> for writing\n", filename);
    return;
  }

  fwrite(data, bytes, 1, fp);
  fflush(fp);
  fclose(fp);
}
|
||||
|
||||
inline bool sdkCompareBin2BinUint(const char *src_file, const char *ref_file,
|
||||
unsigned int nelements, const float epsilon,
|
||||
const float threshold, char *exec_path) {
|
||||
unsigned int *src_buffer, *ref_buffer;
|
||||
FILE *src_fp = NULL, *ref_fp = NULL;
|
||||
|
||||
uint64_t error_count = 0;
|
||||
size_t fsize = 0;
|
||||
|
||||
if (FOPEN_FAIL(FOPEN(src_fp, src_file, "rb"))) {
|
||||
printf("compareBin2Bin <unsigned int> unable to open src_file: %s\n",
|
||||
src_file);
|
||||
error_count++;
|
||||
}
|
||||
|
||||
char *ref_file_path = sdkFindFilePath(ref_file, exec_path);
|
||||
|
||||
if (ref_file_path == NULL) {
|
||||
printf("compareBin2Bin <unsigned int> unable to find <%s> in <%s>\n",
|
||||
ref_file, exec_path);
|
||||
printf(">>> Check info.xml and [project//data] folder <%s> <<<\n",
|
||||
ref_file);
|
||||
printf("Aborting comparison!\n");
|
||||
printf(" FAILED\n");
|
||||
error_count++;
|
||||
|
||||
if (src_fp) {
|
||||
fclose(src_fp);
|
||||
}
|
||||
|
||||
if (ref_fp) {
|
||||
fclose(ref_fp);
|
||||
}
|
||||
} else {
|
||||
if (FOPEN_FAIL(FOPEN(ref_fp, ref_file_path, "rb"))) {
|
||||
printf(
|
||||
"compareBin2Bin <unsigned int>"
|
||||
" unable to open ref_file: %s\n",
|
||||
ref_file_path);
|
||||
error_count++;
|
||||
}
|
||||
|
||||
if (src_fp && ref_fp) {
|
||||
src_buffer = (unsigned int *)malloc(nelements * sizeof(unsigned int));
|
||||
ref_buffer = (unsigned int *)malloc(nelements * sizeof(unsigned int));
|
||||
|
||||
fsize = fread(src_buffer, nelements, sizeof(unsigned int), src_fp);
|
||||
fsize = fread(ref_buffer, nelements, sizeof(unsigned int), ref_fp);
|
||||
|
||||
printf(
|
||||
"> compareBin2Bin <unsigned int> nelements=%d,"
|
||||
" epsilon=%4.2f, threshold=%4.2f\n",
|
||||
nelements, epsilon, threshold);
|
||||
printf(" src_file <%s>, size=%d bytes\n", src_file,
|
||||
static_cast<int>(fsize));
|
||||
printf(" ref_file <%s>, size=%d bytes\n", ref_file_path,
|
||||
static_cast<int>(fsize));
|
||||
|
||||
if (!compareData<unsigned int, float>(ref_buffer, src_buffer, nelements,
|
||||
epsilon, threshold)) {
|
||||
error_count++;
|
||||
}
|
||||
|
||||
fclose(src_fp);
|
||||
fclose(ref_fp);
|
||||
|
||||
free(src_buffer);
|
||||
free(ref_buffer);
|
||||
} else {
|
||||
if (src_fp) {
|
||||
fclose(src_fp);
|
||||
}
|
||||
|
||||
if (ref_fp) {
|
||||
fclose(ref_fp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (error_count == 0) {
|
||||
printf(" OK\n");
|
||||
} else {
|
||||
printf(" FAILURE: %d errors...\n", (unsigned int)error_count);
|
||||
}
|
||||
|
||||
return (error_count == 0); // returns true if all pixels pass
|
||||
}
|
||||
|
||||
//! Compare a binary file of floats against a reference binary file.
//! Locates ref_file via sdkFindFilePath(), reads up to nelements floats from
//! each file and compares them with compareDataAsFloatThreshold().
//! @return true if both files open and the data compares clean, else false
//! @param src_file   file with the computed data
//! @param ref_file   reference file name (searched relative to exec_path)
//! @param nelements  number of floats to read and compare
//! @param epsilon    per-element tolerance
//! @param threshold  fraction of mismatching elements tolerated
//! @param exec_path  executable path handed to sdkFindFilePath()
inline bool sdkCompareBin2BinFloat(const char *src_file, const char *ref_file,
                                   unsigned int nelements, const float epsilon,
                                   const float threshold, char *exec_path) {
  float *src_buffer = NULL, *ref_buffer = NULL;
  FILE *src_fp = NULL, *ref_fp = NULL;
  size_t fsize = 0;

  uint64_t error_count = 0;

  if (FOPEN_FAIL(FOPEN(src_fp, src_file, "rb"))) {
    printf("compareBin2Bin <float> unable to open src_file: %s\n", src_file);
    error_count = 1;
  }

  char *ref_file_path = sdkFindFilePath(ref_file, exec_path);

  if (ref_file_path == NULL) {
    printf("compareBin2Bin <float> unable to find <%s> in <%s>\n", ref_file,
           exec_path);
    printf(">>> Check info.xml and [project//data] folder <%s> <<<\n",
           exec_path);
    printf("Aborting comparison!\n");
    printf(" FAILED\n");
    error_count++;

    // release whichever handles were opened before bailing out
    if (src_fp) {
      fclose(src_fp);
    }

    if (ref_fp) {
      fclose(ref_fp);
    }
  } else {
    if (FOPEN_FAIL(FOPEN(ref_fp, ref_file_path, "rb"))) {
      printf("compareBin2Bin <float> unable to open ref_file: %s\n",
             ref_file_path);
      error_count = 1;
    }

    if (src_fp && ref_fp) {
      src_buffer = reinterpret_cast<float *>(malloc(nelements * sizeof(float)));
      ref_buffer = reinterpret_cast<float *>(malloc(nelements * sizeof(float)));

      printf(
          "> compareBin2Bin <float> nelements=%d, epsilon=%4.2f,"
          " threshold=%4.2f\n",
          nelements, epsilon, threshold);
      // read each file and report how many bytes were actually read
      fsize = fread(src_buffer, sizeof(float), nelements, src_fp);
      printf(" src_file <%s>, size=%d bytes\n", src_file,
             static_cast<int>(fsize * sizeof(float)));
      fsize = fread(ref_buffer, sizeof(float), nelements, ref_fp);
      printf(" ref_file <%s>, size=%d bytes\n", ref_file_path,
             static_cast<int>(fsize * sizeof(float)));

      // element-wise float comparison with epsilon/threshold tolerance
      if (!compareDataAsFloatThreshold<float, float>(
              ref_buffer, src_buffer, nelements, epsilon, threshold)) {
        error_count++;
      }

      fclose(src_fp);
      fclose(ref_fp);

      free(src_buffer);
      free(ref_buffer);
    } else {
      // one of the two opens failed: release whichever handle we do hold
      if (src_fp) {
        fclose(src_fp);
      }

      if (ref_fp) {
        fclose(ref_fp);
      }
    }
  }

  if (error_count == 0) {
    printf(" OK\n");
  } else {
    printf(" FAILURE: %d errors...\n", (unsigned int)error_count);
  }

  return (error_count == 0);  // returns true if all pixels pass
}
|
||||
|
||||
//! Compare two float arrays using the relative L2 norm of their difference.
//! @return true if ||reference - data|| / ||reference|| < epsilon
//! @param reference  gold data
//! @param data       computed data
//! @param len        number of elements in both arrays
//! @param epsilon    relative error bound (must be >= 0)
inline bool sdkCompareL2fe(const float *reference, const float *data,
                           const unsigned int len, const float epsilon) {
  assert(epsilon >= 0);

  float diff_sq_sum = 0;
  float ref_sq_sum = 0;

  // accumulate squared difference and squared reference magnitude
  for (unsigned int idx = 0; idx < len; ++idx) {
    const float delta = reference[idx] - data[idx];
    diff_sq_sum += delta * delta;
    ref_sq_sum += reference[idx] * reference[idx];
  }

  // a (near-)zero reference norm makes the relative error undefined
  if (fabs(ref_sq_sum) < 1e-7) {
#ifdef _DEBUG
    std::cerr << "ERROR, reference l2-norm is 0\n";
#endif
    return false;
  }

  const float rel_error = sqrtf(diff_sq_sum) / sqrtf(ref_sq_sum);
  const bool result = rel_error < epsilon;

#ifdef _DEBUG
  if (!result) {
    std::cerr << "ERROR, l2-norm error " << rel_error
              << " is greater than epsilon " << epsilon << "\n";
  }
#endif

  return result;
}
|
||||
|
||||
inline bool sdkLoadPPMub(const char *file, unsigned char **data,
|
||||
unsigned int *w, unsigned int *h) {
|
||||
unsigned int channels;
|
||||
return __loadPPM(file, data, w, h, &channels);
|
||||
}
|
||||
|
||||
inline bool sdkLoadPPM4ub(const char *file, unsigned char **data,
|
||||
unsigned int *w, unsigned int *h) {
|
||||
unsigned char *idata = 0;
|
||||
unsigned int channels;
|
||||
|
||||
if (__loadPPM(file, &idata, w, h, &channels)) {
|
||||
// pad 4th component
|
||||
int size = *w * *h;
|
||||
// keep the original pointer
|
||||
unsigned char *idata_orig = idata;
|
||||
*data = (unsigned char *)malloc(sizeof(unsigned char) * size * 4);
|
||||
unsigned char *ptr = *data;
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
*ptr++ = *idata++;
|
||||
*ptr++ = *idata++;
|
||||
*ptr++ = *idata++;
|
||||
*ptr++ = 0;
|
||||
}
|
||||
|
||||
free(idata_orig);
|
||||
return true;
|
||||
} else {
|
||||
free(idata);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool sdkComparePPM(const char *src_file, const char *ref_file,
|
||||
const float epsilon, const float threshold,
|
||||
bool verboseErrors) {
|
||||
unsigned char *src_data, *ref_data;
|
||||
uint64_t error_count = 0;
|
||||
unsigned int ref_width, ref_height;
|
||||
unsigned int src_width, src_height;
|
||||
|
||||
if (src_file == NULL || ref_file == NULL) {
|
||||
if (verboseErrors) {
|
||||
std::cerr << "PPMvsPPM: src_file or ref_file is NULL."
|
||||
" Aborting comparison\n";
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
if (verboseErrors) {
|
||||
std::cerr << "> Compare (a)rendered: <" << src_file << ">\n";
|
||||
std::cerr << "> (b)reference: <" << ref_file << ">\n";
|
||||
}
|
||||
|
||||
if (sdkLoadPPM4ub(ref_file, &ref_data, &ref_width, &ref_height) != true) {
|
||||
if (verboseErrors) {
|
||||
std::cerr << "PPMvsPPM: unable to load ref image file: " << ref_file
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
if (sdkLoadPPM4ub(src_file, &src_data, &src_width, &src_height) != true) {
|
||||
std::cerr << "PPMvsPPM: unable to load src image file: " << src_file
|
||||
<< "\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src_height != ref_height || src_width != ref_width) {
|
||||
if (verboseErrors) {
|
||||
std::cerr << "PPMvsPPM: source and ref size mismatch (" << src_width
|
||||
<< "," << src_height << ")vs(" << ref_width << "," << ref_height
|
||||
<< ")\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (verboseErrors) {
|
||||
std::cerr << "PPMvsPPM: comparing images size (" << src_width << ","
|
||||
<< src_height << ") epsilon(" << epsilon << "), threshold("
|
||||
<< threshold * 100 << "%)\n";
|
||||
}
|
||||
|
||||
if (compareData(ref_data, src_data, src_width * src_height * 4, epsilon,
|
||||
threshold) == false) {
|
||||
error_count = 1;
|
||||
}
|
||||
|
||||
if (error_count == 0) {
|
||||
if (verboseErrors) {
|
||||
std::cerr << " OK\n\n";
|
||||
}
|
||||
} else {
|
||||
if (verboseErrors) {
|
||||
std::cerr << " FAILURE! " << error_count << " errors...\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// returns true if all pixels pass
|
||||
return (error_count == 0) ? true : false;
|
||||
}
|
||||
|
||||
inline bool sdkComparePGM(const char *src_file, const char *ref_file,
|
||||
const float epsilon, const float threshold,
|
||||
bool verboseErrors) {
|
||||
unsigned char *src_data = 0, *ref_data = 0;
|
||||
uint64_t error_count = 0;
|
||||
unsigned int ref_width, ref_height;
|
||||
unsigned int src_width, src_height;
|
||||
|
||||
if (src_file == NULL || ref_file == NULL) {
|
||||
if (verboseErrors) {
|
||||
std::cerr << "PGMvsPGM: src_file or ref_file is NULL."
|
||||
" Aborting comparison\n";
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
if (verboseErrors) {
|
||||
std::cerr << "> Compare (a)rendered: <" << src_file << ">\n";
|
||||
std::cerr << "> (b)reference: <" << ref_file << ">\n";
|
||||
}
|
||||
|
||||
if (sdkLoadPPMub(ref_file, &ref_data, &ref_width, &ref_height) != true) {
|
||||
if (verboseErrors) {
|
||||
std::cerr << "PGMvsPGM: unable to load ref image file: " << ref_file
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
if (sdkLoadPPMub(src_file, &src_data, &src_width, &src_height) != true) {
|
||||
std::cerr << "PGMvsPGM: unable to load src image file: " << src_file
|
||||
<< "\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src_height != ref_height || src_width != ref_width) {
|
||||
if (verboseErrors) {
|
||||
std::cerr << "PGMvsPGM: source and ref size mismatch (" << src_width
|
||||
<< "," << src_height << ")vs(" << ref_width << "," << ref_height
|
||||
<< ")\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (verboseErrors)
|
||||
std::cerr << "PGMvsPGM: comparing images size (" << src_width << ","
|
||||
<< src_height << ") epsilon(" << epsilon << "), threshold("
|
||||
<< threshold * 100 << "%)\n";
|
||||
|
||||
if (compareData(ref_data, src_data, src_width * src_height, epsilon,
|
||||
threshold) == false) {
|
||||
error_count = 1;
|
||||
}
|
||||
|
||||
if (error_count == 0) {
|
||||
if (verboseErrors) {
|
||||
std::cerr << " OK\n\n";
|
||||
}
|
||||
} else {
|
||||
if (verboseErrors) {
|
||||
std::cerr << " FAILURE! " << error_count << " errors...\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// returns true if all pixels pass
|
||||
return (error_count == 0) ? true : false;
|
||||
}
|
||||
|
||||
#endif // COMMON_HELPER_IMAGE_H_
|
|
@ -0,0 +1,683 @@
|
|||
/**
|
||||
* Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
*
|
||||
*/
|
||||
|
||||
// These are helper functions for the SDK samples (string parsing, timers, etc)
|
||||
#ifndef COMMON_HELPER_STRING_H_
|
||||
#define COMMON_HELPER_STRING_H_
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
|
||||
#ifndef _CRT_SECURE_NO_DEPRECATE
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#endif
|
||||
#ifndef STRCASECMP
|
||||
#define STRCASECMP _stricmp
|
||||
#endif
|
||||
#ifndef STRNCASECMP
|
||||
#define STRNCASECMP _strnicmp
|
||||
#endif
|
||||
#ifndef STRCPY
|
||||
#define STRCPY(sFilePath, nLength, sPath) strcpy_s(sFilePath, nLength, sPath)
|
||||
#endif
|
||||
|
||||
#ifndef FOPEN
|
||||
#define FOPEN(fHandle, filename, mode) fopen_s(&fHandle, filename, mode)
|
||||
#endif
|
||||
#ifndef FOPEN_FAIL
|
||||
#define FOPEN_FAIL(result) (result != 0)
|
||||
#endif
|
||||
#ifndef SSCANF
|
||||
#define SSCANF sscanf_s
|
||||
#endif
|
||||
#ifndef SPRINTF
|
||||
#define SPRINTF sprintf_s
|
||||
#endif
|
||||
#else // Linux Includes
|
||||
#include <string.h>
|
||||
#include <strings.h>
|
||||
|
||||
#ifndef STRCASECMP
|
||||
#define STRCASECMP strcasecmp
|
||||
#endif
|
||||
#ifndef STRNCASECMP
|
||||
#define STRNCASECMP strncasecmp
|
||||
#endif
|
||||
#ifndef STRCPY
|
||||
#define STRCPY(sFilePath, nLength, sPath) strcpy(sFilePath, sPath)
|
||||
#endif
|
||||
|
||||
#ifndef FOPEN
|
||||
#define FOPEN(fHandle, filename, mode) (fHandle = fopen(filename, mode))
|
||||
#endif
|
||||
#ifndef FOPEN_FAIL
|
||||
#define FOPEN_FAIL(result) (result == NULL)
|
||||
#endif
|
||||
#ifndef SSCANF
|
||||
#define SSCANF sscanf
|
||||
#endif
|
||||
#ifndef SPRINTF
|
||||
#define SPRINTF sprintf
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef EXIT_WAIVED
|
||||
#define EXIT_WAIVED 2
|
||||
#endif
|
||||
|
||||
// CUDA Utility Helper Functions
|
||||
// Return the index of the first character after a run of leading
// 'delimiter' characters in 'string'; returns 0 when stripping the
// delimiters would leave one character or less.
inline int stringRemoveDelimiter(char delimiter, const char *string) {
  int first_non_delim = 0;

  for (; string[first_non_delim] == delimiter; ++first_non_delim) {
  }

  // nothing useful left after the delimiters: report offset 0 instead
  if (first_non_delim >= static_cast<int>(strlen(string) - 1)) {
    return 0;
  }

  return first_non_delim;
}
|
||||
|
||||
// Locate the extension (text after the final '.') in 'filename'.
// On success *extension points into filename just past the dot and the
// offset is returned; otherwise *extension is set to NULL and 0 returned.
// NOTE: mirrors the original backward scan exactly, including its behavior
// when the dot sits at index 0 or 1 (no extension is reported then).
inline int getFileExtension(char *filename, char **extension) {
  int pos = static_cast<int>(strlen(filename));

  // scan backwards for the '.' (replicates the original post-decrement loop)
  for (;;) {
    const char current = filename[pos];
    pos--;

    if (current == '.') {
      break;
    }

    if (pos == 0) {
      break;
    }
  }

  // step over the '.' when one was found above position 0
  if (pos > 0) {
    pos += 2;
  }

  *extension = (pos == 0) ? NULL : &filename[pos];

  return pos;
}
|
||||
|
||||
inline bool checkCmdLineFlag(const int argc, const char **argv,
|
||||
const char *string_ref) {
|
||||
bool bFound = false;
|
||||
|
||||
if (argc >= 1) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
int string_start = stringRemoveDelimiter('-', argv[i]);
|
||||
const char *string_argv = &argv[i][string_start];
|
||||
|
||||
const char *equal_pos = strchr(string_argv, '=');
|
||||
int argv_length = static_cast<int>(
|
||||
equal_pos == 0 ? strlen(string_argv) : equal_pos - string_argv);
|
||||
|
||||
int length = static_cast<int>(strlen(string_ref));
|
||||
|
||||
if (length == argv_length &&
|
||||
!STRNCASECMP(string_argv, string_ref, length)) {
|
||||
bFound = true;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bFound;
|
||||
}
|
||||
|
||||
// This function wraps the CUDA Driver API into a template function
// Parse "--<string_ref>=<int>" from the command line into *value.
// Returns true when a matching flag is found; *value is written only when
// at least one character follows the flag name (parsed with atoi and cast
// to T). Unlike the other getters, the scan stops at the FIRST match
// (via the i = argc assignment below).
template <class T>
inline bool getCmdLineArgumentValue(const int argc, const char **argv,
                                    const char *string_ref, T *value) {
  bool bFound = false;

  if (argc >= 1) {
    for (int i = 1; i < argc; i++) {
      // strip leading '-' delimiters from argv[i]
      int string_start = stringRemoveDelimiter('-', argv[i]);
      const char *string_argv = &argv[i][string_start];
      int length = static_cast<int>(strlen(string_ref));

      // case-insensitive prefix match against the flag name
      if (!STRNCASECMP(string_argv, string_ref, length)) {
        if (length + 1 <= static_cast<int>(strlen(string_argv))) {
          // skip an optional '=' between flag name and value
          int auto_inc = (string_argv[length] == '=') ? 1 : 0;
          *value = (T)atoi(&string_argv[length + auto_inc]);
        }

        bFound = true;
        i = argc;  // first match wins: terminate the scan
      }
    }
  }

  return bFound;
}
|
||||
|
||||
inline int getCmdLineArgumentInt(const int argc, const char **argv,
|
||||
const char *string_ref) {
|
||||
bool bFound = false;
|
||||
int value = -1;
|
||||
|
||||
if (argc >= 1) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
int string_start = stringRemoveDelimiter('-', argv[i]);
|
||||
const char *string_argv = &argv[i][string_start];
|
||||
int length = static_cast<int>(strlen(string_ref));
|
||||
|
||||
if (!STRNCASECMP(string_argv, string_ref, length)) {
|
||||
if (length + 1 <= static_cast<int>(strlen(string_argv))) {
|
||||
int auto_inc = (string_argv[length] == '=') ? 1 : 0;
|
||||
value = atoi(&string_argv[length + auto_inc]);
|
||||
} else {
|
||||
value = 0;
|
||||
}
|
||||
|
||||
bFound = true;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (bFound) {
|
||||
return value;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Parse "--<string_ref>=<float>" style arguments.
// Returns the value of the last matching flag, 0.f when the flag appears
// without a value, and 0 when no argument matches.
inline float getCmdLineArgumentFloat(const int argc, const char **argv,
                                     const char *string_ref) {
  bool bFound = false;
  float value = -1;

  if (argc >= 1) {
    for (int i = 1; i < argc; i++) {
      // strip leading '-' delimiters from argv[i]
      int string_start = stringRemoveDelimiter('-', argv[i]);
      const char *string_argv = &argv[i][string_start];
      int length = static_cast<int>(strlen(string_ref));

      // case-insensitive prefix match against the flag name
      if (!STRNCASECMP(string_argv, string_ref, length)) {
        if (length + 1 <= static_cast<int>(strlen(string_argv))) {
          // skip an optional '=' between flag name and value
          int auto_inc = (string_argv[length] == '=') ? 1 : 0;
          value = static_cast<float>(atof(&string_argv[length + auto_inc]));
        } else {
          // flag present without a value
          value = 0.f;
        }

        bFound = true;
        continue;  // keep scanning: the last match wins
      }
    }
  }

  if (bFound) {
    return value;
  } else {
    return 0;
  }
}
|
||||
|
||||
inline bool getCmdLineArgumentString(const int argc, const char **argv,
|
||||
const char *string_ref,
|
||||
char **string_retval) {
|
||||
bool bFound = false;
|
||||
|
||||
if (argc >= 1) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
int string_start = stringRemoveDelimiter('-', argv[i]);
|
||||
char *string_argv = const_cast<char*>(&argv[i][string_start]);
|
||||
int length = static_cast<int>(strlen(string_ref));
|
||||
|
||||
if (!STRNCASECMP(string_argv, string_ref, length)) {
|
||||
*string_retval = &string_argv[length + 1];
|
||||
bFound = true;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!bFound) {
|
||||
*string_retval = NULL;
|
||||
}
|
||||
|
||||
return bFound;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
//! Find the path for a file assuming that
|
||||
//! files are found in the searchPath.
|
||||
//!
|
||||
//! @return the path if succeeded, otherwise 0
|
||||
//! @param filename name of the file
|
||||
//! @param executable_path optional absolute path of the executable
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
inline char *sdkFindFilePath(const char *filename,
|
||||
const char *executable_path) {
|
||||
// <executable_name> defines a variable that is replaced with the name of the
|
||||
// executable
|
||||
|
||||
// Typical relative search paths to locate needed companion files (e.g. sample
|
||||
// input data, or JIT source files) The origin for the relative search may be
|
||||
// the .exe file, a .bat file launching an .exe, a browser .exe launching the
|
||||
// .exe or .bat, etc
|
||||
const char *searchPath[] = {
|
||||
"./", // same dir
|
||||
"./<executable_name>_data_files/",
|
||||
"./common/", // "/common/" subdir
|
||||
"./common/data/", // "/common/data/" subdir
|
||||
"./data/", // "/data/" subdir
|
||||
"./src/", // "/src/" subdir
|
||||
"./src/<executable_name>/data/", // "/src/<executable_name>/data/" subdir
|
||||
"./inc/", // "/inc/" subdir
|
||||
"./0_Simple/", // "/0_Simple/" subdir
|
||||
"./1_Utilities/", // "/1_Utilities/" subdir
|
||||
"./2_Graphics/", // "/2_Graphics/" subdir
|
||||
"./3_Imaging/", // "/3_Imaging/" subdir
|
||||
"./4_Finance/", // "/4_Finance/" subdir
|
||||
"./5_Simulations/", // "/5_Simulations/" subdir
|
||||
"./6_Advanced/", // "/6_Advanced/" subdir
|
||||
"./7_CUDALibraries/", // "/7_CUDALibraries/" subdir
|
||||
"./8_Android/", // "/8_Android/" subdir
|
||||
"./samples/", // "/samples/" subdir
|
||||
|
||||
"./0_Simple/<executable_name>/data/", // "/0_Simple/<executable_name>/data/"
|
||||
// subdir
|
||||
"./1_Utilities/<executable_name>/data/", // "/1_Utilities/<executable_name>/data/"
|
||||
// subdir
|
||||
"./2_Graphics/<executable_name>/data/", // "/2_Graphics/<executable_name>/data/"
|
||||
// subdir
|
||||
"./3_Imaging/<executable_name>/data/", // "/3_Imaging/<executable_name>/data/"
|
||||
// subdir
|
||||
"./4_Finance/<executable_name>/data/", // "/4_Finance/<executable_name>/data/"
|
||||
// subdir
|
||||
"./5_Simulations/<executable_name>/data/", // "/5_Simulations/<executable_name>/data/"
|
||||
// subdir
|
||||
"./6_Advanced/<executable_name>/data/", // "/6_Advanced/<executable_name>/data/"
|
||||
// subdir
|
||||
"./7_CUDALibraries/<executable_name>/", // "/7_CUDALibraries/<executable_name>/"
|
||||
// subdir
|
||||
"./7_CUDALibraries/<executable_name>/data/", // "/7_CUDALibraries/<executable_name>/data/"
|
||||
// subdir
|
||||
|
||||
"../", // up 1 in tree
|
||||
"../common/", // up 1 in tree, "/common/" subdir
|
||||
"../common/data/", // up 1 in tree, "/common/data/" subdir
|
||||
"../data/", // up 1 in tree, "/data/" subdir
|
||||
"../src/", // up 1 in tree, "/src/" subdir
|
||||
"../inc/", // up 1 in tree, "/inc/" subdir
|
||||
|
||||
"../0_Simple/<executable_name>/data/", // up 1 in tree,
|
||||
// "/0_Simple/<executable_name>/"
|
||||
// subdir
|
||||
"../1_Utilities/<executable_name>/data/", // up 1 in tree,
|
||||
// "/1_Utilities/<executable_name>/"
|
||||
// subdir
|
||||
"../2_Graphics/<executable_name>/data/", // up 1 in tree,
|
||||
// "/2_Graphics/<executable_name>/"
|
||||
// subdir
|
||||
"../3_Imaging/<executable_name>/data/", // up 1 in tree,
|
||||
// "/3_Imaging/<executable_name>/"
|
||||
// subdir
|
||||
"../4_Finance/<executable_name>/data/", // up 1 in tree,
|
||||
// "/4_Finance/<executable_name>/"
|
||||
// subdir
|
||||
"../5_Simulations/<executable_name>/data/", // up 1 in tree,
|
||||
// "/5_Simulations/<executable_name>/"
|
||||
// subdir
|
||||
"../6_Advanced/<executable_name>/data/", // up 1 in tree,
|
||||
// "/6_Advanced/<executable_name>/"
|
||||
// subdir
|
||||
"../7_CUDALibraries/<executable_name>/data/", // up 1 in tree,
|
||||
// "/7_CUDALibraries/<executable_name>/"
|
||||
// subdir
|
||||
"../8_Android/<executable_name>/data/", // up 1 in tree,
|
||||
// "/8_Android/<executable_name>/"
|
||||
// subdir
|
||||
"../samples/<executable_name>/data/", // up 1 in tree,
|
||||
// "/samples/<executable_name>/"
|
||||
// subdir
|
||||
"../../", // up 2 in tree
|
||||
"../../common/", // up 2 in tree, "/common/" subdir
|
||||
"../../common/data/", // up 2 in tree, "/common/data/" subdir
|
||||
"../../data/", // up 2 in tree, "/data/" subdir
|
||||
"../../src/", // up 2 in tree, "/src/" subdir
|
||||
"../../inc/", // up 2 in tree, "/inc/" subdir
|
||||
"../../sandbox/<executable_name>/data/", // up 2 in tree,
|
||||
// "/sandbox/<executable_name>/"
|
||||
// subdir
|
||||
"../../0_Simple/<executable_name>/data/", // up 2 in tree,
|
||||
// "/0_Simple/<executable_name>/"
|
||||
// subdir
|
||||
"../../1_Utilities/<executable_name>/data/", // up 2 in tree,
|
||||
// "/1_Utilities/<executable_name>/"
|
||||
// subdir
|
||||
"../../2_Graphics/<executable_name>/data/", // up 2 in tree,
|
||||
// "/2_Graphics/<executable_name>/"
|
||||
// subdir
|
||||
"../../3_Imaging/<executable_name>/data/", // up 2 in tree,
|
||||
// "/3_Imaging/<executable_name>/"
|
||||
// subdir
|
||||
"../../4_Finance/<executable_name>/data/", // up 2 in tree,
|
||||
// "/4_Finance/<executable_name>/"
|
||||
// subdir
|
||||
"../../5_Simulations/<executable_name>/data/", // up 2 in tree,
|
||||
// "/5_Simulations/<executable_name>/"
|
||||
// subdir
|
||||
"../../6_Advanced/<executable_name>/data/", // up 2 in tree,
|
||||
// "/6_Advanced/<executable_name>/"
|
||||
// subdir
|
||||
"../../7_CUDALibraries/<executable_name>/data/", // up 2 in tree,
|
||||
// "/7_CUDALibraries/<executable_name>/"
|
||||
// subdir
|
||||
"../../8_Android/<executable_name>/data/", // up 2 in tree,
|
||||
// "/8_Android/<executable_name>/"
|
||||
// subdir
|
||||
"../../samples/<executable_name>/data/", // up 2 in tree,
|
||||
// "/samples/<executable_name>/"
|
||||
// subdir
|
||||
"../../../", // up 3 in tree
|
||||
"../../../src/<executable_name>/", // up 3 in tree,
|
||||
// "/src/<executable_name>/" subdir
|
||||
"../../../src/<executable_name>/data/", // up 3 in tree,
|
||||
// "/src/<executable_name>/data/"
|
||||
// subdir
|
||||
"../../../src/<executable_name>/src/", // up 3 in tree,
|
||||
// "/src/<executable_name>/src/"
|
||||
// subdir
|
||||
"../../../src/<executable_name>/inc/", // up 3 in tree,
|
||||
// "/src/<executable_name>/inc/"
|
||||
// subdir
|
||||
"../../../sandbox/<executable_name>/", // up 3 in tree,
|
||||
// "/sandbox/<executable_name>/"
|
||||
// subdir
|
||||
"../../../sandbox/<executable_name>/data/", // up 3 in tree,
|
||||
// "/sandbox/<executable_name>/data/"
|
||||
// subdir
|
||||
"../../../sandbox/<executable_name>/src/", // up 3 in tree,
|
||||
// "/sandbox/<executable_name>/src/"
|
||||
// subdir
|
||||
"../../../sandbox/<executable_name>/inc/", // up 3 in tree,
|
||||
// "/sandbox/<executable_name>/inc/"
|
||||
// subdir
|
||||
"../../../0_Simple/<executable_name>/data/", // up 3 in tree,
|
||||
// "/0_Simple/<executable_name>/"
|
||||
// subdir
|
||||
"../../../1_Utilities/<executable_name>/data/", // up 3 in tree,
|
||||
// "/1_Utilities/<executable_name>/"
|
||||
// subdir
|
||||
"../../../2_Graphics/<executable_name>/data/", // up 3 in tree,
|
||||
// "/2_Graphics/<executable_name>/"
|
||||
// subdir
|
||||
"../../../3_Imaging/<executable_name>/data/", // up 3 in tree,
|
||||
// "/3_Imaging/<executable_name>/"
|
||||
// subdir
|
||||
"../../../4_Finance/<executable_name>/data/", // up 3 in tree,
|
||||
// "/4_Finance/<executable_name>/"
|
||||
// subdir
|
||||
"../../../5_Simulations/<executable_name>/data/", // up 3 in tree,
|
||||
// "/5_Simulations/<executable_name>/"
|
||||
// subdir
|
||||
"../../../6_Advanced/<executable_name>/data/", // up 3 in tree,
|
||||
// "/6_Advanced/<executable_name>/"
|
||||
// subdir
|
||||
"../../../7_CUDALibraries/<executable_name>/data/", // up 3 in tree,
|
||||
// "/7_CUDALibraries/<executable_name>/"
|
||||
// subdir
|
||||
"../../../8_Android/<executable_name>/data/", // up 3 in tree,
|
||||
// "/8_Android/<executable_name>/"
|
||||
// subdir
|
||||
"../../../0_Simple/<executable_name>/", // up 3 in tree,
|
||||
// "/0_Simple/<executable_name>/"
|
||||
// subdir
|
||||
"../../../1_Utilities/<executable_name>/", // up 3 in tree,
|
||||
// "/1_Utilities/<executable_name>/"
|
||||
// subdir
|
||||
"../../../2_Graphics/<executable_name>/", // up 3 in tree,
|
||||
// "/2_Graphics/<executable_name>/"
|
||||
// subdir
|
||||
"../../../3_Imaging/<executable_name>/", // up 3 in tree,
|
||||
// "/3_Imaging/<executable_name>/"
|
||||
// subdir
|
||||
"../../../4_Finance/<executable_name>/", // up 3 in tree,
|
||||
// "/4_Finance/<executable_name>/"
|
||||
// subdir
|
||||
"../../../5_Simulations/<executable_name>/", // up 3 in tree,
|
||||
// "/5_Simulations/<executable_name>/"
|
||||
// subdir
|
||||
"../../../6_Advanced/<executable_name>/", // up 3 in tree,
|
||||
// "/6_Advanced/<executable_name>/"
|
||||
// subdir
|
||||
"../../../7_CUDALibraries/<executable_name>/", // up 3 in tree,
|
||||
// "/7_CUDALibraries/<executable_name>/"
|
||||
// subdir
|
||||
"../../../8_Android/<executable_name>/", // up 3 in tree,
|
||||
// "/8_Android/<executable_name>/"
|
||||
// subdir
|
||||
"../../../samples/<executable_name>/data/", // up 3 in tree,
|
||||
// "/samples/<executable_name>/"
|
||||
// subdir
|
||||
"../../../common/", // up 3 in tree, "../../../common/" subdir
|
||||
"../../../common/data/", // up 3 in tree, "../../../common/data/" subdir
|
||||
"../../../data/", // up 3 in tree, "../../../data/" subdir
|
||||
"../../../../", // up 4 in tree
|
||||
"../../../../src/<executable_name>/", // up 4 in tree,
|
||||
// "/src/<executable_name>/" subdir
|
||||
"../../../../src/<executable_name>/data/", // up 4 in tree,
|
||||
// "/src/<executable_name>/data/"
|
||||
// subdir
|
||||
"../../../../src/<executable_name>/src/", // up 4 in tree,
|
||||
// "/src/<executable_name>/src/"
|
||||
// subdir
|
||||
"../../../../src/<executable_name>/inc/", // up 4 in tree,
|
||||
// "/src/<executable_name>/inc/"
|
||||
// subdir
|
||||
"../../../../sandbox/<executable_name>/", // up 4 in tree,
|
||||
// "/sandbox/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../sandbox/<executable_name>/data/", // up 4 in tree,
|
||||
// "/sandbox/<executable_name>/data/"
|
||||
// subdir
|
||||
"../../../../sandbox/<executable_name>/src/", // up 4 in tree,
|
||||
// "/sandbox/<executable_name>/src/"
|
||||
// subdir
|
||||
"../../../../sandbox/<executable_name>/inc/", // up 4 in tree,
|
||||
// "/sandbox/<executable_name>/inc/"
|
||||
// subdir
|
||||
"../../../../0_Simple/<executable_name>/data/", // up 4 in tree,
|
||||
// "/0_Simple/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../1_Utilities/<executable_name>/data/", // up 4 in tree,
|
||||
// "/1_Utilities/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../2_Graphics/<executable_name>/data/", // up 4 in tree,
|
||||
// "/2_Graphics/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../3_Imaging/<executable_name>/data/", // up 4 in tree,
|
||||
// "/3_Imaging/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../4_Finance/<executable_name>/data/", // up 4 in tree,
|
||||
// "/4_Finance/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../5_Simulations/<executable_name>/data/", // up 4 in tree,
|
||||
// "/5_Simulations/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../6_Advanced/<executable_name>/data/", // up 4 in tree,
|
||||
// "/6_Advanced/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../7_CUDALibraries/<executable_name>/data/", // up 4 in tree,
|
||||
// "/7_CUDALibraries/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../8_Android/<executable_name>/data/", // up 4 in tree,
|
||||
// "/8_Android/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../0_Simple/<executable_name>/", // up 4 in tree,
|
||||
// "/0_Simple/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../1_Utilities/<executable_name>/", // up 4 in tree,
|
||||
// "/1_Utilities/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../2_Graphics/<executable_name>/", // up 4 in tree,
|
||||
// "/2_Graphics/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../3_Imaging/<executable_name>/", // up 4 in tree,
|
||||
// "/3_Imaging/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../4_Finance/<executable_name>/", // up 4 in tree,
|
||||
// "/4_Finance/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../5_Simulations/<executable_name>/", // up 4 in tree,
|
||||
// "/5_Simulations/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../6_Advanced/<executable_name>/", // up 4 in tree,
|
||||
// "/6_Advanced/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../7_CUDALibraries/<executable_name>/", // up 4 in tree,
|
||||
// "/7_CUDALibraries/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../8_Android/<executable_name>/", // up 4 in tree,
|
||||
// "/8_Android/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../samples/<executable_name>/data/", // up 4 in tree,
|
||||
// "/samples/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../common/", // up 4 in tree, "../../../common/" subdir
|
||||
"../../../../common/data/", // up 4 in tree, "../../../common/data/"
|
||||
// subdir
|
||||
"../../../../data/", // up 4 in tree, "../../../data/" subdir
|
||||
"../../../../../", // up 5 in tree
|
||||
"../../../../../src/<executable_name>/", // up 5 in tree,
|
||||
// "/src/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../src/<executable_name>/data/", // up 5 in tree,
|
||||
// "/src/<executable_name>/data/"
|
||||
// subdir
|
||||
"../../../../../src/<executable_name>/src/", // up 5 in tree,
|
||||
// "/src/<executable_name>/src/"
|
||||
// subdir
|
||||
"../../../../../src/<executable_name>/inc/", // up 5 in tree,
|
||||
// "/src/<executable_name>/inc/"
|
||||
// subdir
|
||||
"../../../../../sandbox/<executable_name>/", // up 5 in tree,
|
||||
// "/sandbox/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../sandbox/<executable_name>/data/", // up 5 in tree,
|
||||
// "/sandbox/<executable_name>/data/"
|
||||
// subdir
|
||||
"../../../../../sandbox/<executable_name>/src/", // up 5 in tree,
|
||||
// "/sandbox/<executable_name>/src/"
|
||||
// subdir
|
||||
"../../../../../sandbox/<executable_name>/inc/", // up 5 in tree,
|
||||
// "/sandbox/<executable_name>/inc/"
|
||||
// subdir
|
||||
"../../../../../0_Simple/<executable_name>/data/", // up 5 in tree,
|
||||
// "/0_Simple/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../1_Utilities/<executable_name>/data/", // up 5 in tree,
|
||||
// "/1_Utilities/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../2_Graphics/<executable_name>/data/", // up 5 in tree,
|
||||
// "/2_Graphics/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../3_Imaging/<executable_name>/data/", // up 5 in tree,
|
||||
// "/3_Imaging/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../4_Finance/<executable_name>/data/", // up 5 in tree,
|
||||
// "/4_Finance/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../5_Simulations/<executable_name>/data/", // up 5 in tree,
|
||||
// "/5_Simulations/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../6_Advanced/<executable_name>/data/", // up 5 in tree,
|
||||
// "/6_Advanced/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../7_CUDALibraries/<executable_name>/data/", // up 5 in
|
||||
// tree,
|
||||
// "/7_CUDALibraries/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../8_Android/<executable_name>/data/", // up 5 in tree,
|
||||
// "/8_Android/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../samples/<executable_name>/data/", // up 5 in tree,
|
||||
// "/samples/<executable_name>/"
|
||||
// subdir
|
||||
"../../../../../common/", // up 5 in tree, "../../../common/" subdir
|
||||
"../../../../../common/data/", // up 5 in tree, "../../../common/data/"
|
||||
// subdir
|
||||
};
|
||||
|
||||
// Extract the executable name
|
||||
std::string executable_name;
|
||||
|
||||
if (executable_path != 0) {
|
||||
executable_name = std::string(executable_path);
|
||||
|
||||
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
|
||||
// Windows path delimiter
|
||||
size_t delimiter_pos = executable_name.find_last_of('\\');
|
||||
executable_name.erase(0, delimiter_pos + 1);
|
||||
|
||||
if (executable_name.rfind(".exe") != std::string::npos) {
|
||||
// we strip .exe, only if the .exe is found
|
||||
executable_name.resize(executable_name.size() - 4);
|
||||
}
|
||||
|
||||
#else
|
||||
// Linux & OSX path delimiter
|
||||
size_t delimiter_pos = executable_name.find_last_of('/');
|
||||
executable_name.erase(0, delimiter_pos + 1);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Loop over all search paths and return the first hit
|
||||
for (unsigned int i = 0; i < sizeof(searchPath) / sizeof(char *); ++i) {
|
||||
std::string path(searchPath[i]);
|
||||
size_t executable_name_pos = path.find("<executable_name>");
|
||||
|
||||
// If there is executable_name variable in the searchPath
|
||||
// replace it with the value
|
||||
if (executable_name_pos != std::string::npos) {
|
||||
if (executable_path != 0) {
|
||||
path.replace(executable_name_pos, strlen("<executable_name>"),
|
||||
executable_name);
|
||||
} else {
|
||||
// Skip this path entry if no executable argument is given
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _DEBUG
|
||||
printf("sdkFindFilePath <%s> in %s\n", filename, path.c_str());
|
||||
#endif
|
||||
|
||||
// Test if the file exists
|
||||
path.append(filename);
|
||||
FILE *fp;
|
||||
FOPEN(fp, path.c_str(), "rb");
|
||||
|
||||
if (fp != NULL) {
|
||||
fclose(fp);
|
||||
// File found
|
||||
// returning an allocated array here for backwards compatibility reasons
|
||||
char *file_path = reinterpret_cast<char *>(malloc(path.length() + 1));
|
||||
STRCPY(file_path, path.length() + 1, path.c_str());
|
||||
return file_path;
|
||||
}
|
||||
|
||||
if (fp) {
|
||||
fclose(fp);
|
||||
}
|
||||
}
|
||||
|
||||
// File not found
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // COMMON_HELPER_STRING_H_
|
|
@ -0,0 +1,448 @@
|
|||
/**
|
||||
* Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
*
|
||||
*/
|
||||
|
||||
// Helper Timing Functions
|
||||
#ifndef COMMON_HELPER_TIMER_H_
|
||||
#define COMMON_HELPER_TIMER_H_
|
||||
|
||||
#ifndef EXIT_WAIVED
|
||||
#define EXIT_WAIVED 2
|
||||
#endif
|
||||
|
||||
// includes, system
|
||||
#include <vector>
|
||||
|
||||
|
||||
// Definition of the StopWatch Interface, this is used if we don't want to use
// the CUT functions But rather in a self contained class interface.
// Concrete per-OS implementations (StopWatchWin / StopWatchLinux) are
// selected at compile time further below; all times are in milliseconds.
class StopWatchInterface {
 public:
  StopWatchInterface() {}
  virtual ~StopWatchInterface() {}

 public:
  //! Start time measurement
  virtual void start() = 0;

  //! Stop time measurement
  virtual void stop() = 0;

  //! Reset time counters to zero
  virtual void reset() = 0;

  //! Time in msec. after start. If the stop watch is still running (i.e. there
  //! was no call to stop()) then the elapsed time is returned, otherwise the
  //! time between the last start() and stop call is returned
  virtual float getTime() = 0;

  //! Mean time to date based on the number of times the stopwatch has been
  //! _stopped_ (ie finished sessions) and the current total time
  virtual float getAverageTime() = 0;
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////
|
||||
// Begin Stopwatch timer class definitions for all OS platforms //
|
||||
//////////////////////////////////////////////////////////////////
|
||||
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
|
||||
// includes, system
|
||||
#define WINDOWS_LEAN_AND_MEAN
|
||||
#include <windows.h>
|
||||
#undef min
|
||||
#undef max
|
||||
|
||||
//! Windows specific implementation of StopWatch,
//! built on QueryPerformanceCounter / QueryPerformanceFrequency.
class StopWatchWin : public StopWatchInterface {
 public:
  //! Constructor, default. Queries the OS tick frequency once so later
  //! tick deltas can be converted to milliseconds.
  StopWatchWin()
      : start_time(),
        end_time(),
        diff_time(0.0f),
        total_time(0.0f),
        running(false),
        clock_sessions(0),
        freq(0),
        freq_set(false) {
    if (!freq_set) {
      // helper variable
      LARGE_INTEGER temp;

      // get the tick frequency from the OS
      QueryPerformanceFrequency(reinterpret_cast<LARGE_INTEGER *>(&temp));

      // convert to type in which it is needed (ticks per millisecond)
      freq = (static_cast<double>(temp.QuadPart)) / 1000.0;

      // remember query
      freq_set = true;
    }
  }

  // Destructor
  ~StopWatchWin() {}

 public:
  //! Start time measurement
  inline void start();

  //! Stop time measurement
  inline void stop();

  //! Reset time counters to zero
  inline void reset();

  //! Time in msec. after start. If the stop watch is still running (i.e. there
  //! was no call to stop()) then the elapsed time is returned, otherwise the
  //! time between the last start() and stop call is returned
  inline float getTime();

  //! Mean time to date based on the number of times the stopwatch has been
  //! _stopped_ (ie finished sessions) and the current total time
  inline float getAverageTime();

 private:
  // member variables

  //! Start of measurement (performance-counter ticks)
  LARGE_INTEGER start_time;
  //! End of measurement (performance-counter ticks)
  LARGE_INTEGER end_time;

  //! Time difference between the last start and stop, in msec.
  float diff_time;

  //! TOTAL time difference between starts and stops, in msec.
  float total_time;

  //! flag if the stop watch is running
  bool running;

  //! Number of times clock has been started
  //! and stopped to allow averaging
  int clock_sessions;

  //! tick frequency, in ticks per millisecond
  double freq;

  //! flag if the frequency has been set
  bool freq_set;
};
|
||||
|
||||
// functions, inlined
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
//! Start time measurement: snapshot the performance counter and mark the
//! watch as running.
////////////////////////////////////////////////////////////////////////////////
inline void StopWatchWin::start() {
  QueryPerformanceCounter(reinterpret_cast<LARGE_INTEGER *>(&start_time));
  running = true;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
//! Stop time measurement and increment add to the current diff_time summation
//! variable. Also increment the number of times this clock has been run.
////////////////////////////////////////////////////////////////////////////////
inline void StopWatchWin::stop() {
  QueryPerformanceCounter(reinterpret_cast<LARGE_INTEGER *>(&end_time));
  // Tick delta divided by ticks-per-millisecond (freq) gives elapsed msec.
  diff_time = static_cast<float>(((static_cast<double>(end_time.QuadPart) -
                                   static_cast<double>(start_time.QuadPart)) /
                                  freq));

  total_time += diff_time;
  clock_sessions++;
  running = false;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
//! Reset the timer to 0. Does not change the timer running state but does
//! recapture this point in time as the current start time if it is running.
////////////////////////////////////////////////////////////////////////////////
inline void StopWatchWin::reset() {
  diff_time = 0;
  total_time = 0;
  clock_sessions = 0;

  if (running) {
    // Re-anchor the in-flight measurement at "now".
    QueryPerformanceCounter(reinterpret_cast<LARGE_INTEGER *>(&start_time));
  }
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
//! Time in msec. after start. If the stop watch is still running (i.e. there
//! was no call to stop()) then the elapsed time is returned added to the
//! current diff_time sum, otherwise the current summed time difference alone
//! is returned.
////////////////////////////////////////////////////////////////////////////////
inline float StopWatchWin::getTime() {
  // Return the TOTAL time to date
  float retval = total_time;

  if (running) {
    // Add the still-open session's elapsed time on top of the total.
    LARGE_INTEGER temp;
    QueryPerformanceCounter(reinterpret_cast<LARGE_INTEGER *>(&temp));
    retval += static_cast<float>(((static_cast<double>(temp.QuadPart) -
                                   static_cast<double>(start_time.QuadPart)) /
                                  freq));
  }

  return retval;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Time in msec. for a single run based on the total number of COMPLETED runs
|
||||
//! and the total time.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline float StopWatchWin::getAverageTime() {
|
||||
return (clock_sessions > 0) ? (total_time / clock_sessions) : 0.0f;
|
||||
}
|
||||
#else
|
||||
// Declarations for Stopwatch on Linux and Mac OSX
|
||||
// includes, system
|
||||
#include <sys/time.h>
|
||||
#include <ctime>
|
||||
|
||||
//! Linux / Mac OSX specific implementation of StopWatch, built on
//! gettimeofday(2); all reported times are in milliseconds.
class StopWatchLinux : public StopWatchInterface {
 public:
  //! Constructor, default
  StopWatchLinux()
      : start_time(),
        diff_time(0.0),
        total_time(0.0),
        running(false),
        clock_sessions(0) {}

  // Destructor
  virtual ~StopWatchLinux() {}

 public:
  //! Start time measurement
  inline void start();

  //! Stop time measurement
  inline void stop();

  //! Reset time counters to zero
  inline void reset();

  //! Time in msec. after start. If the stop watch is still running (i.e. there
  //! was no call to stop()) then the elapsed time is returned, otherwise the
  //! time between the last start() and stop call is returned
  inline float getTime();

  //! Mean time to date based on the number of times the stopwatch has been
  //! _stopped_ (ie finished sessions) and the current total time
  inline float getAverageTime();

 private:
  // helper functions

  //! Get difference between start time and current time, in msec.
  inline float getDiffTime();

 private:
  // member variables

  //! Start of measurement
  struct timeval start_time;

  //! Time difference between the last start and stop, in msec.
  float diff_time;

  //! TOTAL time difference between starts and stops, in msec.
  float total_time;

  //! flag if the stop watch is running
  bool running;

  //! Number of times clock has been started
  //! and stopped to allow averaging
  int clock_sessions;
};
|
||||
|
||||
// functions, inlined
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
//! Start time measurement: snapshot the wall clock and mark the watch
//! as running.
////////////////////////////////////////////////////////////////////////////////
inline void StopWatchLinux::start() {
  gettimeofday(&start_time, 0);
  running = true;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Stop time measurement and increment add to the current diff_time summation
|
||||
//! variable. Also increment the number of times this clock has been run.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline void StopWatchLinux::stop() {
|
||||
diff_time = getDiffTime();
|
||||
total_time += diff_time;
|
||||
running = false;
|
||||
clock_sessions++;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Reset the timer to 0. Does not change the timer running state but does
|
||||
//! recapture this point in time as the current start time if it is running.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline void StopWatchLinux::reset() {
|
||||
diff_time = 0;
|
||||
total_time = 0;
|
||||
clock_sessions = 0;
|
||||
|
||||
if (running) {
|
||||
gettimeofday(&start_time, 0);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Time in msec. after start. If the stop watch is still running (i.e. there
|
||||
//! was no call to stop()) then the elapsed time is returned added to the
|
||||
//! current diff_time sum, otherwise the current summed time difference alone
|
||||
//! is returned.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline float StopWatchLinux::getTime() {
|
||||
// Return the TOTAL time to date
|
||||
float retval = total_time;
|
||||
|
||||
if (running) {
|
||||
retval += getDiffTime();
|
||||
}
|
||||
|
||||
return retval;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Time in msec. for a single run based on the total number of COMPLETED runs
|
||||
//! and the total time.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline float StopWatchLinux::getAverageTime() {
|
||||
return (clock_sessions > 0) ? (total_time / clock_sessions) : 0.0f;
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline float StopWatchLinux::getDiffTime() {
|
||||
struct timeval t_time;
|
||||
gettimeofday(&t_time, 0);
|
||||
|
||||
// time difference in milli-seconds
|
||||
return static_cast<float>(1000.0 * (t_time.tv_sec - start_time.tv_sec) +
|
||||
(0.001 * (t_time.tv_usec - start_time.tv_usec)));
|
||||
}
|
||||
#endif // WIN32
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Timer functionality exported
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Create a new timer
|
||||
//! @return true if a time has been created, otherwise false
|
||||
//! @param name of the new timer, 0 if the creation failed
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline bool sdkCreateTimer(StopWatchInterface **timer_interface) {
|
||||
// printf("sdkCreateTimer called object %08x\n", (void *)*timer_interface);
|
||||
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
|
||||
*timer_interface = reinterpret_cast<StopWatchInterface *>(new StopWatchWin());
|
||||
#else
|
||||
*timer_interface =
|
||||
reinterpret_cast<StopWatchInterface *>(new StopWatchLinux());
|
||||
#endif
|
||||
return (*timer_interface != NULL) ? true : false;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Delete a timer
|
||||
//! @return true if a time has been deleted, otherwise false
|
||||
//! @param name of the timer to delete
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline bool sdkDeleteTimer(StopWatchInterface **timer_interface) {
|
||||
// printf("sdkDeleteTimer called object %08x\n", (void *)*timer_interface);
|
||||
if (*timer_interface) {
|
||||
delete *timer_interface;
|
||||
*timer_interface = NULL;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Start the time with name \a name
|
||||
//! @param name name of the timer to start
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline bool sdkStartTimer(StopWatchInterface **timer_interface) {
|
||||
// printf("sdkStartTimer called object %08x\n", (void *)*timer_interface);
|
||||
if (*timer_interface) {
|
||||
(*timer_interface)->start();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Stop the time with name \a name. Does not reset.
|
||||
//! @param name name of the timer to stop
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline bool sdkStopTimer(StopWatchInterface **timer_interface) {
|
||||
// printf("sdkStopTimer called object %08x\n", (void *)*timer_interface);
|
||||
if (*timer_interface) {
|
||||
(*timer_interface)->stop();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Resets the timer's counter.
|
||||
//! @param name name of the timer to reset.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline bool sdkResetTimer(StopWatchInterface **timer_interface) {
|
||||
// printf("sdkResetTimer called object %08x\n", (void *)*timer_interface);
|
||||
if (*timer_interface) {
|
||||
(*timer_interface)->reset();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Return the average time for timer execution as the total time
|
||||
//! for the timer dividied by the number of completed (stopped) runs the timer
|
||||
//! has made.
|
||||
//! Excludes the current running time if the timer is currently running.
|
||||
//! @param name name of the timer to return the time of
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline float sdkGetAverageTimerValue(StopWatchInterface **timer_interface) {
|
||||
// printf("sdkGetAverageTimerValue called object %08x\n", (void
|
||||
// *)*timer_interface);
|
||||
if (*timer_interface) {
|
||||
return (*timer_interface)->getAverageTime();
|
||||
} else {
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//! Total execution time for the timer over all runs since the last reset
|
||||
//! or timer creation.
|
||||
//! @param name name of the timer to obtain the value of.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
inline float sdkGetTimerValue(StopWatchInterface **timer_interface) {
|
||||
// printf("sdkGetTimerValue called object %08x\n", (void *)*timer_interface);
|
||||
if (*timer_interface) {
|
||||
return (*timer_interface)->getTime();
|
||||
} else {
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // COMMON_HELPER_TIMER_H_
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2014 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
*
|
||||
*/
|
||||
|
||||
#include "helper_cuda.h"
|
||||
#include "fp16_dev.h"
|
||||
|
||||
#define BLOCK_SIZE 128
|
||||
// Element-wise value_type -> half conversion kernel (round-to-nearest-even
// via __float2half_rn). One thread converts one element; threads whose index
// falls past `size` exit early.
template <class value_type>
__global__ void float2half_rn_kernel(int size, const value_type *buffIn, half1 *buffOut)
{
    const int idx = BLOCK_SIZE*blockIdx.x+threadIdx.x;
    if (idx >= size) {
        return;
    }
#if CUDART_VERSION < 9000
    // Pre-CUDA-9: __float2half_rn returns the raw bit pattern, which is
    // stored into the half1 wrapper's .x field.
    half1 val;
    val.x = __float2half_rn(float(buffIn[idx]));
#else
    // CUDA 9+: __float2half_rn returns a half value directly.
    half1 val = __float2half_rn(float(buffIn[idx]));
#endif
    buffOut[idx] = val;
}
|
||||
|
||||
// Host-side launcher: converts `size` elements from value_type to half on the
// GPU, then synchronizes (checking for errors) before returning.
// @param size    number of elements to convert; <= 0 is a no-op
// @param buffIn  device pointer to the input elements
// @param buffOut device pointer receiving the converted half values
template <class value_type>
void gpu_float2half_rn(int size, const value_type *buffIn, half1 *buffOut)
{
    // Guard: a zero/negative size would produce grid_size == 0, which is an
    // invalid launch configuration — return without launching instead.
    if (size <= 0) {
        return;
    }

    // Round up so every element is covered by a thread.
    int grid_size = (size + BLOCK_SIZE - 1) / BLOCK_SIZE;
    float2half_rn_kernel<value_type><<<grid_size, BLOCK_SIZE>>> (size, buffIn, buffOut);
    checkCudaErrors(cudaDeviceSynchronize());
}

// Explicit instantiations for the two supported input types.
template void gpu_float2half_rn<float> (int, const float*, half1*);
template void gpu_float2half_rn<double> (int, const double*, half1*);
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* Copyright 1993-2014 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NOTICE TO LICENSEE:
|
||||
*
|
||||
* This source code and/or documentation ("Licensed Deliverables") are
|
||||
* subject to NVIDIA intellectual property rights under U.S. and
|
||||
* international Copyright laws.
|
||||
*
|
||||
* These Licensed Deliverables contained herein is PROPRIETARY and
|
||||
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
||||
* conditions of a form of NVIDIA software license agreement by and
|
||||
* between NVIDIA and Licensee ("License Agreement") or electronically
|
||||
* accepted by Licensee. Notwithstanding any terms or conditions to
|
||||
* the contrary in the License Agreement, reproduction or disclosure
|
||||
* of the Licensed Deliverables to any third party without the express
|
||||
* written consent of NVIDIA is prohibited.
|
||||
*
|
||||
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
||||
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
||||
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
||||
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
||||
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
||||
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
||||
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
||||
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
||||
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
||||
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
||||
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
||||
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
||||
* OF THESE LICENSED DELIVERABLES.
|
||||
*
|
||||
* U.S. Government End Users. These Licensed Deliverables are a
|
||||
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
||||
* 1995), consisting of "commercial computer software" and "commercial
|
||||
* computer software documentation" as such terms are used in 48
|
||||
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
||||
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
||||
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
||||
* U.S. Government End Users acquire the Licensed Deliverables with
|
||||
* only those rights set forth herein.
|
||||
*
|
||||
* Any use of the Licensed Deliverables in individual and commercial
|
||||
* software must include, in the user documentation and internal
|
||||
* comments to the code, the above Disclaimer and U.S. Government End
|
||||
* Users Notice.
|
||||
*/
|
||||
|
||||
#include "fp16_emu.h"
|
||||
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
|
||||
#endif
|
||||
|
||||
#define STATIC_ASSERT(cond) do { typedef char compile_time_assert[(cond) ? 1 : -1]; } while (0)
|
||||
|
||||
// Host functions for converting between FP32 and FP16 formats
|
||||
// Paulius Micikevicius (pauliusm@nvidia.com)
|
||||
|
||||
// Convert a 32-bit float to a 16-bit half, rounding to nearest even.
// NOTE(review): the bit patterns are moved via pointer casts / reference
// reinterpret_cast, which is why "-Wstrict-aliasing" is silenced above;
// behavior is kept exactly as in the reference implementation.
half1 cpu_float2half_rn(float f)
{
    unsigned bits = *((int*)(void*)(&f));       // raw IEEE-754 binary32 bits
    unsigned abs_bits = (bits & 0x7fffffff);    // magnitude (sign stripped)
    unsigned sign_bit, exp_field, mant_field;
    unsigned rem, sh, ulp, ulp_half, ulp_mask;

    __half_raw raw;

    // Any NaN (payload nonzero above the Inf pattern) canonicalizes to one
    // quiet-NaN half pattern, dropping the sign.
    if (abs_bits > 0x7f800000) {
        raw.x = 0x7fffU;
        return reinterpret_cast<half1&>(raw);
    }

    sign_bit = ((bits >> 16) & 0x8000);

    // Magnitudes that overflow half's range (including +/-Inf) saturate to
    // a signed infinity.
    if (abs_bits > 0x477fefff) {
        raw.x = sign_bit | 0x7c00U;
        return reinterpret_cast<half1&>(raw);
    }
    // Magnitudes below the smallest half subnormal (including +/-0) flush
    // to a signed zero.
    if (abs_bits < 0x33000001) {
        raw.x = sign_bit | 0x0000U;
        return reinterpret_cast<half1&>(raw);
    }

    exp_field  = ((abs_bits >> 23) & 0xff);
    mant_field = (abs_bits & 0x7fffff);

    if (exp_field > 0x70) {
        // Result is a normal half: drop 13 mantissa bits, rebias exponent
        // from 127 (float) to 15 (half).
        sh = 13;
        exp_field -= 0x70;
    } else {
        // Result is a subnormal half: shift the mantissa further and make
        // the hidden leading 1 explicit first.
        sh = 0x7e - exp_field;
        exp_field = 0;
        mant_field |= 0x800000;
    }
    ulp      = (1 << sh);   // value of one unit in the last kept place
    ulp_half = (ulp >> 1);  // halfway point for the round decision
    ulp_mask = (ulp - 1);   // bits that will be discarded

    // Round to nearest even: bump up when past the halfway point, or when
    // exactly halfway and the kept mantissa is odd.
    rem = (mant_field & ulp_mask);
    mant_field >>= sh;
    if (rem > ulp_half || (rem == ulp_half && (mant_field & 0x1))) {
        ++mant_field;
        if (!(mant_field & 0x3ff)) {
            // Mantissa carried out; the overflow goes into the exponent.
            ++exp_field;
            mant_field = 0;
        }
    }

    raw.x = (sign_bit | (exp_field << 10) | mant_field);

    return reinterpret_cast<half1&>(raw);
}
|
||||
|
||||
|
||||
// Convert a 16-bit half to a 32-bit float. The conversion is exact: every
// half value (including subnormals, infinities and NaN) is representable
// as a float.
float cpu_half2float(half1 h)
{
    STATIC_ASSERT(sizeof(int) == sizeof(float));

    __half_raw raw = reinterpret_cast<__half_raw&>(h);

    unsigned sign_bit   = ((raw.x >> 15) & 1);
    unsigned exp_field  = ((raw.x >> 10) & 0x1f);
    unsigned mant_field = ((raw.x & 0x3ff) << 13);

    if (exp_field == 0x1f) {
        // Inf (zero mantissa) keeps its sign; NaN becomes a positive
        // quiet NaN with all mantissa bits set.
        mant_field = (mant_field ? (sign_bit = 0, 0x7fffff) : 0);
        exp_field = 0xff;
    } else if (!exp_field) {
        // Zero stays zero; a subnormal half becomes a normal float.
        if (mant_field) {
            unsigned int top_bit;
            exp_field = 0x71;
            do {
                top_bit = (mant_field & 0x400000);
                mant_field <<= 1; /* normalize */
                --exp_field;
            } while (!top_bit);
            mant_field &= 0x7fffff; /* the leading 1 becomes implicit */
        }
    } else {
        // Normal number: rebias exponent from 15 (half) to 127 (float).
        exp_field += 0x70;
    }

    int out_bits = ((sign_bit << 31) | (exp_field << 23) | mant_field);

    return reinterpret_cast<float&>(out_bits);
}
|
|
@ -0,0 +1,460 @@
|
|||
/**
|
||||
* Copyright 1993-2017 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
*
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// These are CUDA Helper functions for initialization and error checking
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
|
||||
|
||||
#ifdef _CUFFT_H_
|
||||
// cuFFT API errors
|
||||
// Translate a cuFFT status code into the name of its enumerator.
// Any value outside the enumeration maps to "<unknown>".
const char *_cudaGetErrorEnum(cufftResult error) {
  switch (error) {
    case CUFFT_SUCCESS:                   return "CUFFT_SUCCESS";
    case CUFFT_INVALID_PLAN:              return "CUFFT_INVALID_PLAN";
    case CUFFT_ALLOC_FAILED:              return "CUFFT_ALLOC_FAILED";
    case CUFFT_INVALID_TYPE:              return "CUFFT_INVALID_TYPE";
    case CUFFT_INVALID_VALUE:             return "CUFFT_INVALID_VALUE";
    case CUFFT_INTERNAL_ERROR:            return "CUFFT_INTERNAL_ERROR";
    case CUFFT_EXEC_FAILED:               return "CUFFT_EXEC_FAILED";
    case CUFFT_SETUP_FAILED:              return "CUFFT_SETUP_FAILED";
    case CUFFT_INVALID_SIZE:              return "CUFFT_INVALID_SIZE";
    case CUFFT_UNALIGNED_DATA:            return "CUFFT_UNALIGNED_DATA";
    case CUFFT_INCOMPLETE_PARAMETER_LIST: return "CUFFT_INCOMPLETE_PARAMETER_LIST";
    case CUFFT_INVALID_DEVICE:            return "CUFFT_INVALID_DEVICE";
    case CUFFT_PARSE_ERROR:               return "CUFFT_PARSE_ERROR";
    case CUFFT_NO_WORKSPACE:              return "CUFFT_NO_WORKSPACE";
    case CUFFT_NOT_IMPLEMENTED:           return "CUFFT_NOT_IMPLEMENTED";
    case CUFFT_LICENSE_ERROR:             return "CUFFT_LICENSE_ERROR";
    case CUFFT_NOT_SUPPORTED:             return "CUFFT_NOT_SUPPORTED";
  }

  return "<unknown>";
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef CUSPARSEAPI
|
||||
// cuSPARSE API errors
|
||||
// Translate a cuSPARSE status code into the name of its enumerator.
// Any value outside the enumeration maps to "<unknown>".
const char *_cudaGetErrorEnum(cusparseStatus_t error) {
  switch (error) {
    case CUSPARSE_STATUS_SUCCESS:           return "CUSPARSE_STATUS_SUCCESS";
    case CUSPARSE_STATUS_NOT_INITIALIZED:   return "CUSPARSE_STATUS_NOT_INITIALIZED";
    case CUSPARSE_STATUS_ALLOC_FAILED:      return "CUSPARSE_STATUS_ALLOC_FAILED";
    case CUSPARSE_STATUS_INVALID_VALUE:     return "CUSPARSE_STATUS_INVALID_VALUE";
    case CUSPARSE_STATUS_ARCH_MISMATCH:     return "CUSPARSE_STATUS_ARCH_MISMATCH";
    case CUSPARSE_STATUS_MAPPING_ERROR:     return "CUSPARSE_STATUS_MAPPING_ERROR";
    case CUSPARSE_STATUS_EXECUTION_FAILED:  return "CUSPARSE_STATUS_EXECUTION_FAILED";
    case CUSPARSE_STATUS_INTERNAL_ERROR:    return "CUSPARSE_STATUS_INTERNAL_ERROR";
    case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return "CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
  }

  return "<unknown>";
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef CUSOLVER_COMMON_H_
|
||||
// cuSOLVER API errors
|
||||
// Translate a cuSOLVER status code into the name of its enumerator.
// Any value outside the enumeration maps to "<unknown>".
// Fix: CUSOLVER_STATUS_NOT_SUPPORTED previously returned the name with a
// stray trailing space ("CUSOLVER_STATUS_NOT_SUPPORTED "), inconsistent
// with every other case and breaking exact string comparisons/logs.
const char *_cudaGetErrorEnum(cusolverStatus_t error) {
  switch (error) {
    case CUSOLVER_STATUS_SUCCESS:
      return "CUSOLVER_STATUS_SUCCESS";
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      return "CUSOLVER_STATUS_NOT_INITIALIZED";
    case CUSOLVER_STATUS_ALLOC_FAILED:
      return "CUSOLVER_STATUS_ALLOC_FAILED";
    case CUSOLVER_STATUS_INVALID_VALUE:
      return "CUSOLVER_STATUS_INVALID_VALUE";
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      return "CUSOLVER_STATUS_ARCH_MISMATCH";
    case CUSOLVER_STATUS_MAPPING_ERROR:
      return "CUSOLVER_STATUS_MAPPING_ERROR";
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      return "CUSOLVER_STATUS_EXECUTION_FAILED";
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      return "CUSOLVER_STATUS_INTERNAL_ERROR";
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
    case CUSOLVER_STATUS_NOT_SUPPORTED:
      return "CUSOLVER_STATUS_NOT_SUPPORTED";
    case CUSOLVER_STATUS_ZERO_PIVOT:
      return "CUSOLVER_STATUS_ZERO_PIVOT";
    case CUSOLVER_STATUS_INVALID_LICENSE:
      return "CUSOLVER_STATUS_INVALID_LICENSE";
  }

  return "<unknown>";
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef CURAND_H_
|
||||
// cuRAND API errors
|
||||
// Translate a cuRAND status code into the name of its enumerator.
// Any value outside the enumeration maps to "<unknown>".
const char *_cudaGetErrorEnum(curandStatus_t error) {
  switch (error) {
    case CURAND_STATUS_SUCCESS:              return "CURAND_STATUS_SUCCESS";
    case CURAND_STATUS_VERSION_MISMATCH:     return "CURAND_STATUS_VERSION_MISMATCH";
    case CURAND_STATUS_NOT_INITIALIZED:      return "CURAND_STATUS_NOT_INITIALIZED";
    case CURAND_STATUS_ALLOCATION_FAILED:    return "CURAND_STATUS_ALLOCATION_FAILED";
    case CURAND_STATUS_TYPE_ERROR:           return "CURAND_STATUS_TYPE_ERROR";
    case CURAND_STATUS_OUT_OF_RANGE:         return "CURAND_STATUS_OUT_OF_RANGE";
    case CURAND_STATUS_LENGTH_NOT_MULTIPLE:  return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
    case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
      return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
    case CURAND_STATUS_LAUNCH_FAILURE:       return "CURAND_STATUS_LAUNCH_FAILURE";
    case CURAND_STATUS_PREEXISTING_FAILURE:  return "CURAND_STATUS_PREEXISTING_FAILURE";
    case CURAND_STATUS_INITIALIZATION_FAILED:
      return "CURAND_STATUS_INITIALIZATION_FAILED";
    case CURAND_STATUS_ARCH_MISMATCH:        return "CURAND_STATUS_ARCH_MISMATCH";
    case CURAND_STATUS_INTERNAL_ERROR:       return "CURAND_STATUS_INTERNAL_ERROR";
  }

  return "<unknown>";
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef NV_NPPIDEFS_H
|
||||
// NPP API errors
|
||||
// Translate an NPP status code into the name of its enumerator; values not
// covered map to "<unknown>". Several enumerators were renamed across NPP
// releases, so the case set is selected by the NPP version macros.
// NOTE(review): a few pre-5.5 cases deliberately return slightly different
// text than their enumerator (e.g. NPP_MEMCPY_ERR -> "NPP_MEMCPY_ERROR");
// this mirrors NVIDIA's upstream helper_cuda.h and is preserved as-is.
const char *_cudaGetErrorEnum(NppStatus error) {
  switch (error) {
    case NPP_NOT_SUPPORTED_MODE_ERROR:        return "NPP_NOT_SUPPORTED_MODE_ERROR";
    case NPP_ROUND_MODE_NOT_SUPPORTED_ERROR:  return "NPP_ROUND_MODE_NOT_SUPPORTED_ERROR";
    case NPP_RESIZE_NO_OPERATION_ERROR:       return "NPP_RESIZE_NO_OPERATION_ERROR";
    case NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY:
      return "NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY";

#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) <= 0x5000
    // Pre-5.5 enumerator spellings.
    case NPP_BAD_ARG_ERROR:                   return "NPP_BAD_ARGUMENT_ERROR";
    case NPP_COEFF_ERROR:                     return "NPP_COEFFICIENT_ERROR";
    case NPP_RECT_ERROR:                      return "NPP_RECTANGLE_ERROR";
    case NPP_QUAD_ERROR:                      return "NPP_QUADRANGLE_ERROR";
    case NPP_MEM_ALLOC_ERR:                   return "NPP_MEMORY_ALLOCATION_ERROR";
    case NPP_HISTO_NUMBER_OF_LEVELS_ERROR:    return "NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR";
    case NPP_INVALID_INPUT:                   return "NPP_INVALID_INPUT";
    case NPP_POINTER_ERROR:                   return "NPP_POINTER_ERROR";
    case NPP_WARNING:                         return "NPP_WARNING";
    case NPP_ODD_ROI_WARNING:                 return "NPP_ODD_ROI_WARNING";
#else
    // These are for CUDA 5.5 or higher.
    case NPP_BAD_ARGUMENT_ERROR:              return "NPP_BAD_ARGUMENT_ERROR";
    case NPP_COEFFICIENT_ERROR:               return "NPP_COEFFICIENT_ERROR";
    case NPP_RECTANGLE_ERROR:                 return "NPP_RECTANGLE_ERROR";
    case NPP_QUADRANGLE_ERROR:                return "NPP_QUADRANGLE_ERROR";
    case NPP_MEMORY_ALLOCATION_ERR:           return "NPP_MEMORY_ALLOCATION_ERROR";
    case NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR:
      return "NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR";
    case NPP_INVALID_HOST_POINTER_ERROR:      return "NPP_INVALID_HOST_POINTER_ERROR";
    case NPP_INVALID_DEVICE_POINTER_ERROR:    return "NPP_INVALID_DEVICE_POINTER_ERROR";
#endif

    case NPP_LUT_NUMBER_OF_LEVELS_ERROR:      return "NPP_LUT_NUMBER_OF_LEVELS_ERROR";
    case NPP_TEXTURE_BIND_ERROR:              return "NPP_TEXTURE_BIND_ERROR";
    case NPP_WRONG_INTERSECTION_ROI_ERROR:    return "NPP_WRONG_INTERSECTION_ROI_ERROR";
    case NPP_NOT_EVEN_STEP_ERROR:             return "NPP_NOT_EVEN_STEP_ERROR";
    case NPP_INTERPOLATION_ERROR:             return "NPP_INTERPOLATION_ERROR";
    case NPP_RESIZE_FACTOR_ERROR:             return "NPP_RESIZE_FACTOR_ERROR";
    case NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR:
      return "NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR";

#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) <= 0x5000
    case NPP_MEMFREE_ERR:                     return "NPP_MEMFREE_ERR";
    case NPP_MEMSET_ERR:                      return "NPP_MEMSET_ERR";
    case NPP_MEMCPY_ERR:                      return "NPP_MEMCPY_ERROR";
    case NPP_MIRROR_FLIP_ERR:                 return "NPP_MIRROR_FLIP_ERR";
#else
    case NPP_MEMFREE_ERROR:                   return "NPP_MEMFREE_ERROR";
    case NPP_MEMSET_ERROR:                    return "NPP_MEMSET_ERROR";
    case NPP_MEMCPY_ERROR:                    return "NPP_MEMCPY_ERROR";
    case NPP_MIRROR_FLIP_ERROR:               return "NPP_MIRROR_FLIP_ERROR";
#endif

    case NPP_ALIGNMENT_ERROR:                 return "NPP_ALIGNMENT_ERROR";
    case NPP_STEP_ERROR:                      return "NPP_STEP_ERROR";
    case NPP_SIZE_ERROR:                      return "NPP_SIZE_ERROR";
    case NPP_NULL_POINTER_ERROR:              return "NPP_NULL_POINTER_ERROR";
    case NPP_CUDA_KERNEL_EXECUTION_ERROR:     return "NPP_CUDA_KERNEL_EXECUTION_ERROR";
    case NPP_NOT_IMPLEMENTED_ERROR:           return "NPP_NOT_IMPLEMENTED_ERROR";
    case NPP_ERROR:                           return "NPP_ERROR";
    case NPP_SUCCESS:                         return "NPP_SUCCESS";
    case NPP_WRONG_INTERSECTION_QUAD_WARNING:
      return "NPP_WRONG_INTERSECTION_QUAD_WARNING";
    case NPP_MISALIGNED_DST_ROI_WARNING:      return "NPP_MISALIGNED_DST_ROI_WARNING";
    case NPP_AFFINE_QUAD_INCORRECT_WARNING:   return "NPP_AFFINE_QUAD_INCORRECT_WARNING";
    case NPP_DOUBLE_SIZE_WARNING:             return "NPP_DOUBLE_SIZE_WARNING";
    case NPP_WRONG_INTERSECTION_ROI_WARNING:
      return "NPP_WRONG_INTERSECTION_ROI_WARNING";

#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) >= 0x6000
    /* These are 6.0 or higher */
    case NPP_LUT_PALETTE_BITSIZE_ERROR:       return "NPP_LUT_PALETTE_BITSIZE_ERROR";
    case NPP_ZC_MODE_NOT_SUPPORTED_ERROR:     return "NPP_ZC_MODE_NOT_SUPPORTED_ERROR";
    case NPP_QUALITY_INDEX_ERROR:             return "NPP_QUALITY_INDEX_ERROR";
    case NPP_CHANNEL_ORDER_ERROR:             return "NPP_CHANNEL_ORDER_ERROR";
    case NPP_ZERO_MASK_VALUE_ERROR:           return "NPP_ZERO_MASK_VALUE_ERROR";
    case NPP_NUMBER_OF_CHANNELS_ERROR:        return "NPP_NUMBER_OF_CHANNELS_ERROR";
    case NPP_COI_ERROR:                       return "NPP_COI_ERROR";
    case NPP_DIVISOR_ERROR:                   return "NPP_DIVISOR_ERROR";
    case NPP_CHANNEL_ERROR:                   return "NPP_CHANNEL_ERROR";
    case NPP_STRIDE_ERROR:                    return "NPP_STRIDE_ERROR";
    case NPP_ANCHOR_ERROR:                    return "NPP_ANCHOR_ERROR";
    case NPP_MASK_SIZE_ERROR:                 return "NPP_MASK_SIZE_ERROR";
    case NPP_MOMENT_00_ZERO_ERROR:            return "NPP_MOMENT_00_ZERO_ERROR";
    case NPP_THRESHOLD_NEGATIVE_LEVEL_ERROR:
      return "NPP_THRESHOLD_NEGATIVE_LEVEL_ERROR";
    case NPP_THRESHOLD_ERROR:                 return "NPP_THRESHOLD_ERROR";
    case NPP_CONTEXT_MATCH_ERROR:             return "NPP_CONTEXT_MATCH_ERROR";
    case NPP_FFT_FLAG_ERROR:                  return "NPP_FFT_FLAG_ERROR";
    case NPP_FFT_ORDER_ERROR:                 return "NPP_FFT_ORDER_ERROR";
    case NPP_SCALE_RANGE_ERROR:               return "NPP_SCALE_RANGE_ERROR";
    case NPP_DATA_TYPE_ERROR:                 return "NPP_DATA_TYPE_ERROR";
    case NPP_OUT_OFF_RANGE_ERROR:             return "NPP_OUT_OFF_RANGE_ERROR";
    case NPP_DIVIDE_BY_ZERO_ERROR:            return "NPP_DIVIDE_BY_ZERO_ERROR";
    case NPP_RANGE_ERROR:                     return "NPP_RANGE_ERROR";
    case NPP_NO_MEMORY_ERROR:                 return "NPP_NO_MEMORY_ERROR";
    case NPP_ERROR_RESERVED:                  return "NPP_ERROR_RESERVED";
    case NPP_NO_OPERATION_WARNING:            return "NPP_NO_OPERATION_WARNING";
    case NPP_DIVIDE_BY_ZERO_WARNING:          return "NPP_DIVIDE_BY_ZERO_WARNING";
#endif

#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) >= 0x7000
    /* These are 7.0 or higher */
    case NPP_OVERFLOW_ERROR:                  return "NPP_OVERFLOW_ERROR";
    case NPP_CORRUPTED_DATA_ERROR:            return "NPP_CORRUPTED_DATA_ERROR";
#endif
  }

  return "<unknown>";
}
|
||||
#endif
|
|
@ -0,0 +1,416 @@
|
|||
//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
|
||||
// Set Load/Store Alignments From Assumptions
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a ScalarEvolution-based transformation to set
|
||||
// the alignments of load, stores and memory intrinsics based on the truth
|
||||
// expressions of assume intrinsics. The primary motivation is to handle
|
||||
// complex alignment assumptions that apply to vector loads and stores that
|
||||
// appear after vectorization and unrolling.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define AA_NAME "jt-alignment-from-assumptions"
|
||||
#define DEBUG_TYPE AA_NAME
|
||||
|
||||
#include <llvm/Pass.h>
|
||||
#include <llvm/IR/Function.h>
|
||||
#include <llvm/Support/raw_ostream.h>
|
||||
|
||||
#include <llvm/ADT/SmallPtrSet.h>
|
||||
#include <llvm/ADT/Statistic.h>
|
||||
#include <llvm/Analysis/ScalarEvolution.h>
|
||||
#include <llvm/Analysis/AliasAnalysis.h>
|
||||
#include <llvm/Analysis/AssumptionCache.h>
|
||||
#include <llvm/Analysis/GlobalsModRef.h>
|
||||
#include <llvm/Analysis/LoopInfo.h>
|
||||
#include <llvm/Analysis/ScalarEvolutionExpressions.h>
|
||||
#include <llvm/Analysis/ValueTracking.h>
|
||||
#include <llvm/IR/Constant.h>
|
||||
#include <llvm/IR/Dominators.h>
|
||||
#include <llvm/IR/Instruction.h>
|
||||
#include <llvm/IR/Intrinsics.h>
|
||||
#include <llvm/IR/IntrinsicInst.h>
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <llvm/Support/Debug.h>
|
||||
#include <llvm/Support/raw_ostream.h>
|
||||
#include <llvm/Transforms/Scalar.h>
|
||||
|
||||
#include <llvm/IR/LegacyPassManager.h>
|
||||
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
// Transformation counters reported via LLVM's -stats machinery.
STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");
|
||||
|
||||
namespace {
|
||||
|
||||
// Core implementation shared with the legacy wrapper pass below: derives
// larger load/store/mem-intrinsic alignments from alignment predicates
// carried by llvm.assume calls.
struct AlignmentFromAssumptionsPass
    : public PassInfoMixin<AlignmentFromAssumptionsPass> {

  // Entry point: walks F's assumption cache and processes each assumption;
  // returns true if any alignment was changed.
  bool runImpl(Function &F, AssumptionCache &AC, ScalarEvolution *SE_,
               DominatorTree *DT_);

  // Analyses cached by runImpl for use in the helpers below.
  ScalarEvolution *SE = nullptr;
  DominatorTree *DT = nullptr;

  // Decodes an assume of the form ((ptrtoint(p) + off) & mask) == 0 into
  // the pointer (AAPtr), its alignment (AlignSCEV) and offset (OffSCEV).
  bool extractAlignmentInfo(CallInst *I, Value *&AAPtr, const SCEV *&AlignSCEV,
                            const SCEV *&OffSCEV);
  // Applies one decoded assumption to all dominated users of the pointer.
  bool processAssumption(CallInst *I);
};
|
||||
|
||||
struct JittorAlignmentFromAssumptions : public FunctionPass {
|
||||
static char ID;
|
||||
JittorAlignmentFromAssumptions() : FunctionPass(ID) {}
|
||||
|
||||
bool runOnFunction(Function &F) override {
|
||||
if (skipFunction(F))
|
||||
return false;
|
||||
|
||||
auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
|
||||
ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
|
||||
DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
|
||||
|
||||
return Impl.runImpl(F, AC, SE, DT);
|
||||
}
|
||||
|
||||
void getAnalysisUsage(AnalysisUsage &AU) const override {
|
||||
AU.addRequired<AssumptionCacheTracker>();
|
||||
AU.addRequired<ScalarEvolutionWrapperPass>();
|
||||
AU.addRequired<DominatorTreeWrapperPass>();
|
||||
|
||||
AU.setPreservesCFG();
|
||||
AU.addPreserved<AAResultsWrapperPass>();
|
||||
AU.addPreserved<GlobalsAAWrapperPass>();
|
||||
AU.addPreserved<LoopInfoWrapperPass>();
|
||||
AU.addPreserved<DominatorTreeWrapperPass>();
|
||||
AU.addPreserved<ScalarEvolutionWrapperPass>();
|
||||
}
|
||||
|
||||
AlignmentFromAssumptionsPass Impl;
|
||||
}; // end of struct JittorAlignmentFromAssumptions
|
||||
|
||||
// Given an expression for the (constant) alignment, AlignSCEV, and an
|
||||
// expression for the displacement between a pointer and the aligned address,
|
||||
// DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
|
||||
// to a constant. Using SCEV to compute alignment handles the case where
|
||||
// DiffSCEV is a recurrence with constant start such that the aligned offset
|
||||
// is constant. e.g. {16,+,32} % 32 -> 16.
|
||||
static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
|
||||
const SCEV *AlignSCEV,
|
||||
ScalarEvolution *SE) {
|
||||
// DiffUnits = Diff % int64_t(Alignment)
|
||||
const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);
|
||||
|
||||
LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
|
||||
<< *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");
|
||||
|
||||
if (const SCEVConstant *ConstDUSCEV =
|
||||
dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
|
||||
int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();
|
||||
|
||||
// If the displacement is an exact multiple of the alignment, then the
|
||||
// displaced pointer has the same alignment as the aligned pointer, so
|
||||
// return the alignment value.
|
||||
if (!DiffUnits)
|
||||
return (unsigned)
|
||||
cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue();
|
||||
|
||||
// If the displacement is not an exact multiple, but the remainder is a
|
||||
// constant, then return this remainder (but only if it is a power of 2).
|
||||
uint64_t DiffUnitsAbs = std::abs(DiffUnits);
|
||||
if (isPowerOf2_64(DiffUnitsAbs))
|
||||
return (unsigned) DiffUnitsAbs;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// There is an address given by an offset OffSCEV from AASCEV which has an
|
||||
// alignment AlignSCEV. Use that information, if possible, to compute a new
|
||||
// alignment for Ptr.
|
||||
static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
|
||||
const SCEV *OffSCEV, Value *Ptr,
|
||||
ScalarEvolution *SE) {
|
||||
const SCEV *PtrSCEV = SE->getSCEV(Ptr);
|
||||
const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
|
||||
|
||||
// On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
|
||||
// sign-extended OffSCEV to i64, so make sure they agree again.
|
||||
DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());
|
||||
|
||||
// What we really want to know is the overall offset to the aligned
|
||||
// address. This address is displaced by the provided offset.
|
||||
DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV);
|
||||
|
||||
LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
|
||||
<< *AlignSCEV << " and offset " << *OffSCEV
|
||||
<< " using diff " << *DiffSCEV << "\n");
|
||||
|
||||
unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE);
|
||||
LLVM_DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n");
|
||||
|
||||
if (NewAlignment) {
|
||||
return NewAlignment;
|
||||
} else if (const SCEVAddRecExpr *DiffARSCEV =
|
||||
dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
|
||||
// The relative offset to the alignment assumption did not yield a constant,
|
||||
// but we should try harder: if we assume that a is 32-byte aligned, then in
|
||||
// for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
|
||||
// 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
|
||||
// As a result, the new alignment will not be a constant, but can still
|
||||
// be improved over the default (of 4) to 16.
|
||||
|
||||
const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
|
||||
const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);
|
||||
|
||||
LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
|
||||
<< *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");
|
||||
|
||||
// Now compute the new alignment using the displacement to the value in the
|
||||
// first iteration, and also the alignment using the per-iteration delta.
|
||||
// If these are the same, then use that answer. Otherwise, use the smaller
|
||||
// one, but only if it divides the larger one.
|
||||
NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
|
||||
unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
|
||||
|
||||
LLVM_DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n");
|
||||
LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n");
|
||||
|
||||
if (!NewAlignment || !NewIncAlignment) {
|
||||
return 0;
|
||||
} else if (NewAlignment > NewIncAlignment) {
|
||||
if (NewAlignment % NewIncAlignment == 0) {
|
||||
LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewIncAlignment
|
||||
<< "\n");
|
||||
return NewIncAlignment;
|
||||
}
|
||||
} else if (NewIncAlignment > NewAlignment) {
|
||||
if (NewIncAlignment % NewAlignment == 0) {
|
||||
LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment
|
||||
<< "\n");
|
||||
return NewAlignment;
|
||||
}
|
||||
} else if (NewIncAlignment == NewAlignment) {
|
||||
LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment
|
||||
<< "\n");
|
||||
return NewAlignment;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Decode an alignment assumption. The assume condition must state that the
// low bits of (possibly offset) ptrtoint(ptr) are zero, i.e. an icmp eq of
// the form ((ptrtoint(p) + off) & mask) == 0. On success fills in AAPtr,
// AlignSCEV (as an i64 constant) and OffSCEV (sign-extended to i64).
bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
                                                        Value *&AAPtr,
                                                        const SCEV *&AlignSCEV,
                                                        const SCEV *&OffSCEV) {
  // The condition must be an integer compare...
  ICmpInst *Cmp = dyn_cast<ICmpInst>(I->getArgOperand(0));
  if (!Cmp)
    return false;

  // ...and specifically an equality: x & m == 0.
  if (Cmp->getPredicate() != ICmpInst::ICMP_EQ)
    return false;

  // Normalize so the zero ends up on the RHS.
  Value *MaskedVal = Cmp->getOperand(0);
  Value *ZeroVal = Cmp->getOperand(1);
  const SCEV *MaskedSCEV = SE->getSCEV(MaskedVal);
  const SCEV *ZeroSCEV = SE->getSCEV(ZeroVal);
  if (MaskedSCEV->isZero())
    std::swap(MaskedVal, ZeroVal);
  else if (!ZeroSCEV->isZero())
    return false;

  BinaryOperator *AndOp = dyn_cast<BinaryOperator>(MaskedVal);
  if (!AndOp || AndOp->getOpcode() != Instruction::And)
    return false;

  // Normalize so the constant mask is the right operand of the and; a
  // variable mask cannot be handled.
  Value *PtrSide = AndOp->getOperand(0);
  Value *MaskSide = AndOp->getOperand(1);
  const SCEV *PtrSideSCEV = SE->getSCEV(PtrSide);
  const SCEV *MaskSideSCEV = SE->getSCEV(MaskSide);
  if (isa<SCEVConstant>(PtrSideSCEV)) {
    std::swap(PtrSide, MaskSide);
    std::swap(PtrSideSCEV, MaskSideSCEV);
  }

  const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(MaskSideSCEV);
  if (!MaskSCEV)
    return false;

  // A mask with no trailing ones is a trivial condition and proves nothing
  // about the alignment of the other operand.
  unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
  if (!TrailingOnes)
    return false;

  // Cap the alignment at LLVM's maximum, and keep the shift in range.
  uint64_t Alignment;
  TrailingOnes = std::min(TrailingOnes,
                          unsigned(sizeof(unsigned) * CHAR_BIT - 1));
  Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);

  Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
  AlignSCEV = SE->getConstant(Int64Ty, Alignment);

  // The pointer side is either a bare ptrtoint, or a ptrtoint plus an
  // offset expression.
  AAPtr = nullptr;
  OffSCEV = nullptr;
  if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(PtrSide)) {
    AAPtr = PToI->getPointerOperand();
    OffSCEV = SE->getZero(Int64Ty);
  } else if (const SCEVAddExpr *AddSCEV =
                 dyn_cast<SCEVAddExpr>(PtrSideSCEV)) {
    // Locate the ptrtoint among the addends; the rest is the offset.
    for (SCEVAddExpr::op_iterator J = AddSCEV->op_begin(),
                                  JE = AddSCEV->op_end(); J != JE; ++J)
      if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
        if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
          AAPtr = PToI->getPointerOperand();
          OffSCEV = SE->getMinusSCEV(AddSCEV, *J);
          break;
        }
  }

  if (!AAPtr)
    return false;

  // Sign-extend the offset to 64 bits so it matches all other expressions.
  unsigned OffBits = OffSCEV->getType()->getPrimitiveSizeInBits();
  if (OffBits < 64)
    OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
  else if (OffBits > 64)
    return false;

  AAPtr = AAPtr->stripPointerCasts();
  return true;
}
|
||||
|
||||
// Apply one alignment assumption: decode it, then raise the alignment of
// every load, store and memory intrinsic (transitively) using the assumed
// pointer wherever the assume's context is valid. Returns true when the
// assumption was successfully decoded.
bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
    return false;

  // Skip ConstantPointerNull and UndefValue: assumptions on those should
  // not affect other users.
  if (isa<ConstantData>(AAPtr))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Seed the worklist with the pointer's direct users (except the assume
  // itself) that the assumption is valid for.
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction *, 16> WorkList;
  for (User *U : AAPtr->users()) {
    if (U == ACall)
      continue;
    if (Instruction *UserInst = dyn_cast<Instruction>(U))
      if (isValidAssumeForContext(ACall, UserInst, DT))
        WorkList.push_back(UserInst);
  }

  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();

    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                              LI->getPointerOperand(), SE);
      // Only ever raise an alignment; never lower it.
      if (NewAlignment > LI->getAlignment()) {
        LI->setAlignment(NewAlignment);
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                              SI->getPointerOperand(), SE);
      if (NewAlignment > SI->getAlignment()) {
        SI->setAlignment(NewAlignment);
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                                  MI->getDest(), SE);

      LLVM_DEBUG(dbgs() << "\tmem inst: " << NewDestAlignment << "\n";);
      if (NewDestAlignment > MI->getDestAlignment()) {
        MI->setDestAlignment(NewDestAlignment);
        ++NumMemIntAlignChanged;
      }

      // Memory transfers also carry a source alignment that can be raised.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                                   MTI->getSource(), SE);

        LLVM_DEBUG(dbgs() << "\tmem trans: " << NewSrcAlignment << "\n";);

        if (NewSrcAlignment > MTI->getSourceAlignment()) {
          MTI->setSourceAlignment(NewSrcAlignment);
          ++NumMemIntAlignChanged;
        }
      }
    }

    // Follow this use of the pointer to further (unvisited) uses for which
    // the assumption still holds.
    Visited.insert(J);
    for (User *UJ : J->users()) {
      Instruction *K = cast<Instruction>(UJ);
      if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
        WorkList.push_back(K);
    }
  }

  return true;
}
|
||||
|
||||
// Pass entry point: walk every llvm.assume call cached for this function
// and try to derive improved alignments from each one.
//
// Returns true when any instruction's alignment was actually changed.
bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
                                           ScalarEvolution *SE_,
                                           DominatorTree *DT_) {
  // Stash the analyses so the per-assumption helpers can reach them
  // through the pass object.
  SE = SE_;
  DT = DT_;

  bool MadeChange = false;
  for (auto &AssumeVH : AC.assumptions()) {
    // Assumption value handles may have been invalidated by earlier
    // transforms; skip dead entries.
    if (!AssumeVH)
      continue;
    MadeChange |= processAssumption(cast<CallInst>(AssumeVH));
  }

  return MadeChange;
}
|
||||
|
||||
} // end of anonymous namespace
|
||||
|
||||
// Unique address used by the legacy pass manager to identify this pass.
char JittorAlignmentFromAssumptions::ID = 0;

// Register the pass under the "jt-alignment-from-assumptions" command-line
// name so it can be invoked explicitly (e.g. via opt).
static RegisterPass<JittorAlignmentFromAssumptions> X(
    "jt-alignment-from-assumptions",
    "Jittor Alignment From Assumptions",
    false /* Only looks at CFG */,
    false /* Analysis Pass */);

// Additionally hook the pass into the standard optimization pipeline at
// the EP_OptimizerLast extension point, so it runs automatically at the
// end of the optimizer without an explicit request.
static RegisterStandardPasses Y(
    PassManagerBuilder::EP_OptimizerLast,
    [](const PassManagerBuilder &Builder,
       legacy::PassManagerBase &PM) { PM.add(new JittorAlignmentFromAssumptions()); });
|
|
@ -0,0 +1,800 @@
|
|||
/*******************************************************************************
|
||||
* Copyright 2016-2019 Intel Corporation
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*******************************************************************************/
|
||||
|
||||
/// @example cpu_cnn_inference_f32.cpp
|
||||
/// @copybrief cpu_cnn_inference_f32_cpp
|
||||
/// > Annotated version: @ref cpu_cnn_inference_f32_cpp
|
||||
|
||||
/// @page cpu_cnn_inference_f32_cpp CNN f32 inference example
|
||||
/// This C++ API example demonstrates how to build an AlexNet neural
|
||||
/// network topology for forward-pass inference.
|
||||
///
|
||||
/// > Example code: @ref cpu_cnn_inference_f32.cpp
|
||||
///
|
||||
/// Some key take-aways include:
|
||||
///
|
||||
/// * How tensors are implemented and submitted to primitives.
|
||||
/// * How primitives are created.
|
||||
/// * How primitives are sequentially submitted to the network, where the output
|
||||
/// from primitives is passed as input to the next primitive. The latter
|
||||
/// specifies a dependency between the primitive input and output data.
|
||||
/// * Specific 'inference-only' configurations.
|
||||
/// * Limiting the number of reorders performed that are detrimental
|
||||
/// to performance.
|
||||
///
|
||||
/// The example implements the AlexNet layers
|
||||
/// as numbered primitives (for example, conv1, pool1, conv2).
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <mkldnn.hpp>
|
||||
|
||||
using namespace mkldnn;
|
||||
|
||||
using namespace std;
|
||||
|
||||
// Return the total number of elements described by a dims vector,
// i.e. the product of all dimension extents (1 for an empty vector).
memory::dim product(const memory::dims &dims) {
    memory::dim total = 1;
    for (const auto &extent : dims)
        total *= extent;
    return total;
}
|
||||
|
||||
void simple_net(int times = 100) {
|
||||
using tag = memory::format_tag;
|
||||
using dt = memory::data_type;
|
||||
|
||||
/// Initialize a CPU engine and stream. The last parameter in the call represents
|
||||
/// the index of the engine.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Initialize engine and stream
|
||||
//[Initialize engine and stream]
|
||||
engine eng(engine::kind::cpu, 0);
|
||||
stream s(eng);
|
||||
//[Initialize engine and stream]
|
||||
|
||||
/// Create a vector for the primitives and a vector to hold memory
|
||||
/// that will be used as arguments.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Create network
|
||||
//[Create network]
|
||||
std::vector<primitive> net;
|
||||
std::vector<std::unordered_map<int, memory>> net_args;
|
||||
//[Create network]
|
||||
|
||||
const memory::dim batch = 1;
|
||||
|
||||
// AlexNet: conv1
|
||||
// {batch, 3, 227, 227} (x) {96, 3, 11, 11} -> {batch, 96, 55, 55}
|
||||
// strides: {4, 4}
|
||||
memory::dims conv1_src_tz = { batch, 3, 227, 227 };
|
||||
memory::dims conv1_weights_tz = { 96, 3, 11, 11 };
|
||||
memory::dims conv1_bias_tz = { 96 };
|
||||
memory::dims conv1_dst_tz = { batch, 96, 55, 55 };
|
||||
memory::dims conv1_strides = { 4, 4 };
|
||||
memory::dims conv1_padding = { 0, 0 };
|
||||
|
||||
/// Allocate buffers for input and output data, weights, and bias.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Allocate buffers
|
||||
//[Allocate buffers]
|
||||
std::vector<float> user_src(batch * 3 * 227 * 227);
|
||||
std::vector<float> user_dst(batch * 1000);
|
||||
std::vector<float> conv1_weights(product(conv1_weights_tz));
|
||||
std::vector<float> conv1_bias(product(conv1_bias_tz));
|
||||
//[Allocate buffers]
|
||||
|
||||
/// Create memory that describes data layout in the buffers. This example uses
|
||||
/// tag::nchw (batch-channels-height-width) for input data and tag::oihw
|
||||
/// for weights.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Create user memory
|
||||
//[Create user memory]
|
||||
auto user_src_memory = memory(
|
||||
{ { conv1_src_tz }, dt::f32, tag::nchw }, eng, user_src.data());
|
||||
auto user_weights_memory
|
||||
= memory({ { conv1_weights_tz }, dt::f32, tag::oihw }, eng,
|
||||
conv1_weights.data());
|
||||
auto conv1_user_bias_memory = memory(
|
||||
{ { conv1_bias_tz }, dt::f32, tag::x }, eng, conv1_bias.data());
|
||||
//[Create user memory]
|
||||
|
||||
/// Create memory descriptors with layout tag::any. The `any` format enables
|
||||
/// the convolution primitive to choose the data format that will result in
|
||||
/// best performance based on its input parameters (convolution kernel
|
||||
/// sizes, strides, padding, and so on). If the resulting format is different
|
||||
/// from `nchw`, the user data must be transformed to the format required for
|
||||
/// the convolution (as explained below).
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Create convolution memory descriptors
|
||||
//[Create convolution memory descriptors]
|
||||
auto conv1_src_md = memory::desc({ conv1_src_tz }, dt::f32, tag::any);
|
||||
auto conv1_bias_md = memory::desc({ conv1_bias_tz }, dt::f32, tag::any);
|
||||
auto conv1_weights_md
|
||||
= memory::desc({ conv1_weights_tz }, dt::f32, tag::any);
|
||||
auto conv1_dst_md = memory::desc({ conv1_dst_tz }, dt::f32, tag::any);
|
||||
//[Create convolution memory descriptors]
|
||||
|
||||
/// Create a convolution descriptor by specifying propagation kind,
|
||||
/// [convolution algorithm](@ref dev_guide_convolution), shapes of input,
|
||||
/// weights, bias, output, convolution strides, padding, and kind of padding.
|
||||
/// Propagation kind is set to prop_kind::forward_inference to optimize for
|
||||
/// inference execution and omit computations that are necessary only for
|
||||
/// backward propagation.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Create convolution descriptor
|
||||
//[Create convolution descriptor]
|
||||
auto conv1_desc = convolution_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::convolution_direct, conv1_src_md, conv1_weights_md, conv1_bias_md,
|
||||
conv1_dst_md, conv1_strides, conv1_padding, conv1_padding);
|
||||
//[Create convolution descriptor]
|
||||
|
||||
/// Create a convolution primitive descriptor. Once created, this
|
||||
/// descriptor has specific formats instead of the `any` format specified
|
||||
/// in the convolution descriptor.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Create convolution primitive descriptor
|
||||
//[Create convolution primitive descriptor]
|
||||
auto conv1_prim_desc = convolution_forward::primitive_desc(conv1_desc, eng);
|
||||
//[Create convolution primitive descriptor]
|
||||
|
||||
|
||||
    /// Check whether the data and weights formats required by the convolution are different
|
||||
/// from the user format. In case it is different change the layout using
|
||||
/// reorder primitive.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Reorder data and weights
|
||||
//[Reorder data and weights]
|
||||
auto conv1_src_memory = user_src_memory;
|
||||
if (conv1_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
||||
conv1_src_memory = memory(conv1_prim_desc.src_desc(), eng);
|
||||
net.push_back(reorder(user_src_memory, conv1_src_memory));
|
||||
net_args.push_back({ { MKLDNN_ARG_FROM, user_src_memory },
|
||||
{ MKLDNN_ARG_TO, conv1_src_memory } });
|
||||
}
|
||||
|
||||
auto conv1_weights_memory = user_weights_memory;
|
||||
if (conv1_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
|
||||
conv1_weights_memory = memory(conv1_prim_desc.weights_desc(), eng);
|
||||
reorder(user_weights_memory, conv1_weights_memory)
|
||||
.execute(s, user_weights_memory, conv1_weights_memory);
|
||||
}
|
||||
//[Reorder data and weights]
|
||||
|
||||
/// Create a memory primitive for output.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Create memory for output
|
||||
//[Create memory for output]
|
||||
auto conv1_dst_memory = memory(conv1_prim_desc.dst_desc(), eng);
|
||||
//[Create memory for output]
|
||||
|
||||
/// Create a convolution primitive and add it to the net.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Create memory for output
|
||||
//[Create convolution primitive]
|
||||
net.push_back(convolution_forward(conv1_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv1_src_memory },
|
||||
{ MKLDNN_ARG_WEIGHTS, conv1_weights_memory },
|
||||
{ MKLDNN_ARG_BIAS, conv1_user_bias_memory },
|
||||
{ MKLDNN_ARG_DST, conv1_dst_memory } });
|
||||
//[Create convolution primitive]
|
||||
|
||||
// AlexNet: relu1
|
||||
// {batch, 96, 55, 55} -> {batch, 96, 55, 55}
|
||||
const float negative1_slope = 1.0f;
|
||||
|
||||
|
||||
/// Create the relu primitive. For better performance, keep the input data
|
||||
/// format for ReLU (as well as for other operation primitives until another
|
||||
/// convolution or inner product is encountered) the same as the one chosen
|
||||
/// for convolution. Also note that ReLU is done in-place by using conv1 memory.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Create relu primitive
|
||||
//[Create relu primitive]
|
||||
auto relu1_desc = eltwise_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::eltwise_relu, conv1_dst_memory.get_desc(),
|
||||
negative1_slope);
|
||||
auto relu1_prim_desc = eltwise_forward::primitive_desc(relu1_desc, eng);
|
||||
|
||||
net.push_back(eltwise_forward(relu1_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv1_dst_memory },
|
||||
{ MKLDNN_ARG_DST, conv1_dst_memory } });
|
||||
//[Create relu primitive]
|
||||
|
||||
// AlexNet: lrn1
|
||||
// {batch, 96, 55, 55} -> {batch, 96, 55, 55}
|
||||
// local size: 5
|
||||
// alpha1: 0.0001
|
||||
// beta1: 0.75
|
||||
const memory::dim local1_size = 5;
|
||||
const float alpha1 = 0.0001f;
|
||||
const float beta1 = 0.75f;
|
||||
const float k1 = 1.0f;
|
||||
|
||||
// create lrn primitive and add it to net
|
||||
auto lrn1_desc = lrn_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::lrn_across_channels, conv1_dst_memory.get_desc(), local1_size,
|
||||
alpha1, beta1, k1);
|
||||
auto lrn1_prim_desc = lrn_forward::primitive_desc(lrn1_desc, eng);
|
||||
auto lrn1_dst_memory = memory(lrn1_prim_desc.dst_desc(), eng);
|
||||
|
||||
net.push_back(lrn_forward(lrn1_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv1_dst_memory },
|
||||
{ MKLDNN_ARG_DST, lrn1_dst_memory } });
|
||||
|
||||
// AlexNet: pool1
|
||||
// {batch, 96, 55, 55} -> {batch, 96, 27, 27}
|
||||
// kernel: {3, 3}
|
||||
// strides: {2, 2}
|
||||
memory::dims pool1_dst_tz = { batch, 96, 27, 27 };
|
||||
memory::dims pool1_kernel = { 3, 3 };
|
||||
memory::dims pool1_strides = { 2, 2 };
|
||||
memory::dims pool_padding = { 0, 0 };
|
||||
|
||||
auto pool1_dst_md = memory::desc({ pool1_dst_tz }, dt::f32, tag::any);
|
||||
|
||||
/// For training execution, pooling requires a private workspace memory
|
||||
/// to perform the backward pass. However, pooling should not use 'workspace'
|
||||
/// for inference, because this is detrimental to performance.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Create pooling primitive
|
||||
///
|
||||
/// The example continues to create more layers according
|
||||
/// to the AlexNet topology.
|
||||
//[Create pooling primitive]
|
||||
auto pool1_desc = pooling_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::pooling_max, lrn1_dst_memory.get_desc(), pool1_dst_md,
|
||||
pool1_strides, pool1_kernel, pool_padding, pool_padding);
|
||||
auto pool1_pd = pooling_forward::primitive_desc(pool1_desc, eng);
|
||||
auto pool1_dst_memory = memory(pool1_pd.dst_desc(), eng);
|
||||
|
||||
net.push_back(pooling_forward(pool1_pd));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, lrn1_dst_memory },
|
||||
{ MKLDNN_ARG_DST, pool1_dst_memory } });
|
||||
//[Create pooling primitive]
|
||||
|
||||
// AlexNet: conv2
|
||||
// {batch, 96, 27, 27} (x) {2, 128, 48, 5, 5} -> {batch, 256, 27, 27}
|
||||
// strides: {1, 1}
|
||||
memory::dims conv2_src_tz = { batch, 96, 27, 27 };
|
||||
memory::dims conv2_weights_tz = { 2, 128, 48, 5, 5 };
|
||||
memory::dims conv2_bias_tz = { 256 };
|
||||
memory::dims conv2_dst_tz = { batch, 256, 27, 27 };
|
||||
memory::dims conv2_strides = { 1, 1 };
|
||||
memory::dims conv2_padding = { 2, 2 };
|
||||
|
||||
std::vector<float> conv2_weights(product(conv2_weights_tz));
|
||||
std::vector<float> conv2_bias(product(conv2_bias_tz));
|
||||
|
||||
// create memory for user data
|
||||
auto conv2_user_weights_memory
|
||||
= memory({ { conv2_weights_tz }, dt::f32, tag::goihw }, eng,
|
||||
conv2_weights.data());
|
||||
auto conv2_user_bias_memory = memory(
|
||||
{ { conv2_bias_tz }, dt::f32, tag::x }, eng, conv2_bias.data());
|
||||
|
||||
// create memory descriptors for convolution data w/ no specified format
|
||||
auto conv2_src_md = memory::desc({ conv2_src_tz }, dt::f32, tag::any);
|
||||
auto conv2_bias_md = memory::desc({ conv2_bias_tz }, dt::f32, tag::any);
|
||||
auto conv2_weights_md
|
||||
= memory::desc({ conv2_weights_tz }, dt::f32, tag::any);
|
||||
auto conv2_dst_md = memory::desc({ conv2_dst_tz }, dt::f32, tag::any);
|
||||
|
||||
// create a convolution
|
||||
auto conv2_desc = convolution_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::convolution_direct, conv2_src_md, conv2_weights_md, conv2_bias_md,
|
||||
conv2_dst_md, conv2_strides, conv2_padding, conv2_padding);
|
||||
auto conv2_prim_desc = convolution_forward::primitive_desc(conv2_desc, eng);
|
||||
|
||||
auto conv2_src_memory = pool1_dst_memory;
|
||||
if (conv2_prim_desc.src_desc() != conv2_src_memory.get_desc()) {
|
||||
conv2_src_memory = memory(conv2_prim_desc.src_desc(), eng);
|
||||
net.push_back(reorder(pool1_dst_memory, conv2_src_memory));
|
||||
net_args.push_back({ { MKLDNN_ARG_FROM, pool1_dst_memory },
|
||||
{ MKLDNN_ARG_TO, conv2_src_memory } });
|
||||
}
|
||||
|
||||
auto conv2_weights_memory = conv2_user_weights_memory;
|
||||
if (conv2_prim_desc.weights_desc()
|
||||
!= conv2_user_weights_memory.get_desc()) {
|
||||
conv2_weights_memory = memory(conv2_prim_desc.weights_desc(), eng);
|
||||
reorder(conv2_user_weights_memory, conv2_weights_memory)
|
||||
.execute(s, conv2_user_weights_memory, conv2_weights_memory);
|
||||
}
|
||||
|
||||
auto conv2_dst_memory = memory(conv2_prim_desc.dst_desc(), eng);
|
||||
|
||||
// create convolution primitive and add it to net
|
||||
net.push_back(convolution_forward(conv2_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv2_src_memory },
|
||||
{ MKLDNN_ARG_WEIGHTS, conv2_weights_memory },
|
||||
{ MKLDNN_ARG_BIAS, conv2_user_bias_memory },
|
||||
{ MKLDNN_ARG_DST, conv2_dst_memory } });
|
||||
|
||||
// AlexNet: relu2
|
||||
// {batch, 256, 27, 27} -> {batch, 256, 27, 27}
|
||||
const float negative2_slope = 1.0f;
|
||||
|
||||
// create relu primitive and add it to net
|
||||
auto relu2_desc = eltwise_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::eltwise_relu, conv2_dst_memory.get_desc(),
|
||||
negative2_slope);
|
||||
auto relu2_prim_desc = eltwise_forward::primitive_desc(relu2_desc, eng);
|
||||
|
||||
net.push_back(eltwise_forward(relu2_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv2_dst_memory },
|
||||
{ MKLDNN_ARG_DST, conv2_dst_memory } });
|
||||
|
||||
// AlexNet: lrn2
|
||||
// {batch, 256, 27, 27} -> {batch, 256, 27, 27}
|
||||
// local size: 5
|
||||
// alpha2: 0.0001
|
||||
// beta2: 0.75
|
||||
const memory::dim local2_size = 5;
|
||||
const float alpha2 = 0.0001f;
|
||||
const float beta2 = 0.75f;
|
||||
const float k2 = 1.0f;
|
||||
|
||||
// create lrn primitive and add it to net
|
||||
auto lrn2_desc = lrn_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::lrn_across_channels, conv2_prim_desc.dst_desc(), local2_size,
|
||||
alpha2, beta2, k2);
|
||||
auto lrn2_prim_desc = lrn_forward::primitive_desc(lrn2_desc, eng);
|
||||
auto lrn2_dst_memory = memory(lrn2_prim_desc.dst_desc(), eng);
|
||||
|
||||
net.push_back(lrn_forward(lrn2_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv2_dst_memory },
|
||||
{ MKLDNN_ARG_DST, lrn2_dst_memory } });
|
||||
|
||||
// AlexNet: pool2
|
||||
// {batch, 256, 27, 27} -> {batch, 256, 13, 13}
|
||||
// kernel: {3, 3}
|
||||
// strides: {2, 2}
|
||||
memory::dims pool2_dst_tz = { batch, 256, 13, 13 };
|
||||
memory::dims pool2_kernel = { 3, 3 };
|
||||
memory::dims pool2_strides = { 2, 2 };
|
||||
memory::dims pool2_padding = { 0, 0 };
|
||||
|
||||
auto pool2_dst_md = memory::desc({ pool2_dst_tz }, dt::f32, tag::any);
|
||||
|
||||
// create a pooling
|
||||
auto pool2_desc = pooling_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::pooling_max, lrn2_dst_memory.get_desc(), pool2_dst_md,
|
||||
pool2_strides, pool2_kernel, pool2_padding, pool2_padding);
|
||||
auto pool2_pd = pooling_forward::primitive_desc(pool2_desc, eng);
|
||||
auto pool2_dst_memory = memory(pool2_pd.dst_desc(), eng);
|
||||
|
||||
    // create pooling primitive and add it to net
|
||||
net.push_back(pooling_forward(pool2_pd));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, lrn2_dst_memory },
|
||||
{ MKLDNN_ARG_DST, pool2_dst_memory } });
|
||||
|
||||
// AlexNet: conv3
|
||||
// {batch, 256, 13, 13} (x) {384, 256, 3, 3}; -> {batch, 384, 13, 13};
|
||||
// strides: {1, 1}
|
||||
memory::dims conv3_src_tz = { batch, 256, 13, 13 };
|
||||
memory::dims conv3_weights_tz = { 384, 256, 3, 3 };
|
||||
memory::dims conv3_bias_tz = { 384 };
|
||||
memory::dims conv3_dst_tz = { batch, 384, 13, 13 };
|
||||
memory::dims conv3_strides = { 1, 1 };
|
||||
memory::dims conv3_padding = { 1, 1 };
|
||||
|
||||
std::vector<float> conv3_weights(product(conv3_weights_tz));
|
||||
std::vector<float> conv3_bias(product(conv3_bias_tz));
|
||||
|
||||
// create memory for user data
|
||||
auto conv3_user_weights_memory
|
||||
= memory({ { conv3_weights_tz }, dt::f32, tag::oihw }, eng,
|
||||
conv3_weights.data());
|
||||
auto conv3_user_bias_memory = memory(
|
||||
{ { conv3_bias_tz }, dt::f32, tag::x }, eng, conv3_bias.data());
|
||||
|
||||
// create memory descriptors for convolution data w/ no specified format
|
||||
auto conv3_src_md = memory::desc({ conv3_src_tz }, dt::f32, tag::any);
|
||||
auto conv3_bias_md = memory::desc({ conv3_bias_tz }, dt::f32, tag::any);
|
||||
auto conv3_weights_md
|
||||
= memory::desc({ conv3_weights_tz }, dt::f32, tag::any);
|
||||
auto conv3_dst_md = memory::desc({ conv3_dst_tz }, dt::f32, tag::any);
|
||||
|
||||
// create a convolution
|
||||
auto conv3_desc = convolution_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::convolution_direct, conv3_src_md, conv3_weights_md, conv3_bias_md,
|
||||
conv3_dst_md, conv3_strides, conv3_padding, conv3_padding);
|
||||
auto conv3_prim_desc = convolution_forward::primitive_desc(conv3_desc, eng);
|
||||
|
||||
auto conv3_src_memory = pool2_dst_memory;
|
||||
if (conv3_prim_desc.src_desc() != conv3_src_memory.get_desc()) {
|
||||
conv3_src_memory = memory(conv3_prim_desc.src_desc(), eng);
|
||||
net.push_back(reorder(pool2_dst_memory, conv3_src_memory));
|
||||
net_args.push_back({ { MKLDNN_ARG_FROM, pool2_dst_memory },
|
||||
{ MKLDNN_ARG_TO, conv3_src_memory } });
|
||||
}
|
||||
|
||||
auto conv3_weights_memory = conv3_user_weights_memory;
|
||||
if (conv3_prim_desc.weights_desc()
|
||||
!= conv3_user_weights_memory.get_desc()) {
|
||||
conv3_weights_memory = memory(conv3_prim_desc.weights_desc(), eng);
|
||||
reorder(conv3_user_weights_memory, conv3_weights_memory)
|
||||
.execute(s, conv3_user_weights_memory, conv3_weights_memory);
|
||||
}
|
||||
|
||||
auto conv3_dst_memory = memory(conv3_prim_desc.dst_desc(), eng);
|
||||
|
||||
// create convolution primitive and add it to net
|
||||
net.push_back(convolution_forward(conv3_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv3_src_memory },
|
||||
{ MKLDNN_ARG_WEIGHTS, conv3_weights_memory },
|
||||
{ MKLDNN_ARG_BIAS, conv3_user_bias_memory },
|
||||
{ MKLDNN_ARG_DST, conv3_dst_memory } });
|
||||
|
||||
// AlexNet: relu3
|
||||
// {batch, 384, 13, 13} -> {batch, 384, 13, 13}
|
||||
const float negative3_slope = 1.0f;
|
||||
|
||||
// create relu primitive and add it to net
|
||||
auto relu3_desc = eltwise_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::eltwise_relu, conv3_dst_memory.get_desc(),
|
||||
negative3_slope);
|
||||
auto relu3_prim_desc = eltwise_forward::primitive_desc(relu3_desc, eng);
|
||||
|
||||
net.push_back(eltwise_forward(relu3_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv3_dst_memory },
|
||||
{ MKLDNN_ARG_DST, conv3_dst_memory } });
|
||||
|
||||
// AlexNet: conv4
|
||||
// {batch, 384, 13, 13} (x) {2, 192, 192, 3, 3}; ->
|
||||
// {batch, 384, 13, 13};
|
||||
// strides: {1, 1}
|
||||
memory::dims conv4_src_tz = { batch, 384, 13, 13 };
|
||||
memory::dims conv4_weights_tz = { 2, 192, 192, 3, 3 };
|
||||
memory::dims conv4_bias_tz = { 384 };
|
||||
memory::dims conv4_dst_tz = { batch, 384, 13, 13 };
|
||||
memory::dims conv4_strides = { 1, 1 };
|
||||
memory::dims conv4_padding = { 1, 1 };
|
||||
|
||||
std::vector<float> conv4_weights(product(conv4_weights_tz));
|
||||
std::vector<float> conv4_bias(product(conv4_bias_tz));
|
||||
|
||||
// create memory for user data
|
||||
auto conv4_user_weights_memory
|
||||
= memory({ { conv4_weights_tz }, dt::f32, tag::goihw }, eng,
|
||||
conv4_weights.data());
|
||||
auto conv4_user_bias_memory = memory(
|
||||
{ { conv4_bias_tz }, dt::f32, tag::x }, eng, conv4_bias.data());
|
||||
|
||||
// create memory descriptors for convolution data w/ no specified format
|
||||
auto conv4_src_md = memory::desc({ conv4_src_tz }, dt::f32, tag::any);
|
||||
auto conv4_bias_md = memory::desc({ conv4_bias_tz }, dt::f32, tag::any);
|
||||
auto conv4_weights_md
|
||||
= memory::desc({ conv4_weights_tz }, dt::f32, tag::any);
|
||||
auto conv4_dst_md = memory::desc({ conv4_dst_tz }, dt::f32, tag::any);
|
||||
|
||||
// create a convolution
|
||||
auto conv4_desc = convolution_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::convolution_direct, conv4_src_md, conv4_weights_md, conv4_bias_md,
|
||||
conv4_dst_md, conv4_strides, conv4_padding, conv4_padding);
|
||||
auto conv4_prim_desc = convolution_forward::primitive_desc(conv4_desc, eng);
|
||||
|
||||
auto conv4_src_memory = conv3_dst_memory;
|
||||
if (conv4_prim_desc.src_desc() != conv4_src_memory.get_desc()) {
|
||||
conv4_src_memory = memory(conv4_prim_desc.src_desc(), eng);
|
||||
net.push_back(reorder(conv3_dst_memory, conv4_src_memory));
|
||||
net_args.push_back({ { MKLDNN_ARG_FROM, conv3_dst_memory },
|
||||
{ MKLDNN_ARG_TO, conv4_src_memory } });
|
||||
}
|
||||
|
||||
auto conv4_weights_memory = conv4_user_weights_memory;
|
||||
if (conv4_prim_desc.weights_desc()
|
||||
!= conv4_user_weights_memory.get_desc()) {
|
||||
conv4_weights_memory = memory(conv4_prim_desc.weights_desc(), eng);
|
||||
reorder(conv4_user_weights_memory, conv4_weights_memory)
|
||||
.execute(s, conv4_user_weights_memory, conv4_weights_memory);
|
||||
}
|
||||
|
||||
auto conv4_dst_memory = memory(conv4_prim_desc.dst_desc(), eng);
|
||||
|
||||
// create convolution primitive and add it to net
|
||||
net.push_back(convolution_forward(conv4_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv4_src_memory },
|
||||
{ MKLDNN_ARG_WEIGHTS, conv4_weights_memory },
|
||||
{ MKLDNN_ARG_BIAS, conv4_user_bias_memory },
|
||||
{ MKLDNN_ARG_DST, conv4_dst_memory } });
|
||||
|
||||
// AlexNet: relu4
|
||||
// {batch, 384, 13, 13} -> {batch, 384, 13, 13}
|
||||
const float negative4_slope = 1.0f;
|
||||
|
||||
// create relu primitive and add it to net
|
||||
auto relu4_desc = eltwise_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::eltwise_relu, conv4_dst_memory.get_desc(),
|
||||
negative4_slope);
|
||||
auto relu4_prim_desc = eltwise_forward::primitive_desc(relu4_desc, eng);
|
||||
|
||||
net.push_back(eltwise_forward(relu4_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv4_dst_memory },
|
||||
{ MKLDNN_ARG_DST, conv4_dst_memory } });
|
||||
|
||||
// AlexNet: conv5
|
||||
// {batch, 384, 13, 13} (x) {2, 128, 192, 3, 3}; -> {batch, 256, 13, 13};
|
||||
// strides: {1, 1}
|
||||
memory::dims conv5_src_tz = { batch, 384, 13, 13 };
|
||||
memory::dims conv5_weights_tz = { 2, 128, 192, 3, 3 };
|
||||
memory::dims conv5_bias_tz = { 256 };
|
||||
memory::dims conv5_dst_tz = { batch, 256, 13, 13 };
|
||||
memory::dims conv5_strides = { 1, 1 };
|
||||
memory::dims conv5_padding = { 1, 1 };
|
||||
|
||||
std::vector<float> conv5_weights(product(conv5_weights_tz));
|
||||
std::vector<float> conv5_bias(product(conv5_bias_tz));
|
||||
|
||||
// create memory for user data
|
||||
auto conv5_user_weights_memory
|
||||
= memory({ { conv5_weights_tz }, dt::f32, tag::goihw }, eng,
|
||||
conv5_weights.data());
|
||||
auto conv5_user_bias_memory = memory(
|
||||
{ { conv5_bias_tz }, dt::f32, tag::x }, eng, conv5_bias.data());
|
||||
|
||||
// create memory descriptors for convolution data w/ no specified format
|
||||
auto conv5_src_md = memory::desc({ conv5_src_tz }, dt::f32, tag::any);
|
||||
auto conv5_weights_md
|
||||
= memory::desc({ conv5_weights_tz }, dt::f32, tag::any);
|
||||
auto conv5_bias_md = memory::desc({ conv5_bias_tz }, dt::f32, tag::any);
|
||||
auto conv5_dst_md = memory::desc({ conv5_dst_tz }, dt::f32, tag::any);
|
||||
|
||||
// create a convolution
|
||||
auto conv5_desc = convolution_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::convolution_direct, conv5_src_md, conv5_weights_md, conv5_bias_md,
|
||||
conv5_dst_md, conv5_strides, conv5_padding, conv5_padding);
|
||||
auto conv5_prim_desc = convolution_forward::primitive_desc(conv5_desc, eng);
|
||||
|
||||
auto conv5_src_memory = conv4_dst_memory;
|
||||
if (conv5_prim_desc.src_desc() != conv5_src_memory.get_desc()) {
|
||||
conv5_src_memory = memory(conv5_prim_desc.src_desc(), eng);
|
||||
net.push_back(reorder(conv4_dst_memory, conv5_src_memory));
|
||||
net_args.push_back({ { MKLDNN_ARG_FROM, conv4_dst_memory },
|
||||
{ MKLDNN_ARG_TO, conv5_src_memory } });
|
||||
}
|
||||
|
||||
auto conv5_weights_memory = conv5_user_weights_memory;
|
||||
if (conv5_prim_desc.weights_desc()
|
||||
!= conv5_user_weights_memory.get_desc()) {
|
||||
conv5_weights_memory = memory(conv5_prim_desc.weights_desc(), eng);
|
||||
reorder(conv5_user_weights_memory, conv5_weights_memory)
|
||||
.execute(s, conv5_user_weights_memory, conv5_weights_memory);
|
||||
}
|
||||
|
||||
auto conv5_dst_memory = memory(conv5_prim_desc.dst_desc(), eng);
|
||||
|
||||
// create convolution primitive and add it to net
|
||||
net.push_back(convolution_forward(conv5_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv5_src_memory },
|
||||
{ MKLDNN_ARG_WEIGHTS, conv5_weights_memory },
|
||||
{ MKLDNN_ARG_BIAS, conv5_user_bias_memory },
|
||||
{ MKLDNN_ARG_DST, conv5_dst_memory } });
|
||||
|
||||
// AlexNet: relu5
|
||||
// {batch, 256, 13, 13} -> {batch, 256, 13, 13}
|
||||
const float negative5_slope = 1.0f;
|
||||
|
||||
// create relu primitive and add it to net
|
||||
auto relu5_desc = eltwise_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::eltwise_relu, conv5_dst_memory.get_desc(),
|
||||
negative5_slope);
|
||||
auto relu5_prim_desc = eltwise_forward::primitive_desc(relu5_desc, eng);
|
||||
|
||||
net.push_back(eltwise_forward(relu5_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv5_dst_memory },
|
||||
{ MKLDNN_ARG_DST, conv5_dst_memory } });
|
||||
|
||||
// AlexNet: pool5
|
||||
// {batch, 256, 13, 13} -> {batch, 256, 6, 6}
|
||||
// kernel: {3, 3}
|
||||
// strides: {2, 2}
|
||||
memory::dims pool5_dst_tz = { batch, 256, 6, 6 };
|
||||
memory::dims pool5_kernel = { 3, 3 };
|
||||
memory::dims pool5_strides = { 2, 2 };
|
||||
memory::dims pool5_padding = { 0, 0 };
|
||||
|
||||
std::vector<float> pool5_dst(product(pool5_dst_tz));
|
||||
|
||||
auto pool5_dst_md = memory::desc({ pool5_dst_tz }, dt::f32, tag::any);
|
||||
|
||||
// create a pooling
|
||||
auto pool5_desc = pooling_forward::desc(prop_kind::forward_inference,
|
||||
algorithm::pooling_max, conv5_dst_memory.get_desc(), pool5_dst_md,
|
||||
pool5_strides, pool5_kernel, pool5_padding, pool5_padding);
|
||||
auto pool5_pd = pooling_forward::primitive_desc(pool5_desc, eng);
|
||||
|
||||
auto pool5_dst_memory = memory(pool5_pd.dst_desc(), eng);
|
||||
|
||||
    // create pooling primitive and add it to net
|
||||
net.push_back(pooling_forward(pool5_pd));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, conv5_dst_memory },
|
||||
{ MKLDNN_ARG_DST, pool5_dst_memory } });
|
||||
|
||||
|
||||
// fc6 inner product {batch, 256, 6, 6} (x) {4096, 256, 6, 6}-> {batch,
|
||||
// 4096}
|
||||
memory::dims fc6_src_tz = { batch, 256, 6, 6 };
|
||||
memory::dims fc6_weights_tz = { 4096, 256, 6, 6 };
|
||||
memory::dims fc6_bias_tz = { 4096 };
|
||||
memory::dims fc6_dst_tz = { batch, 4096 };
|
||||
|
||||
std::vector<float> fc6_weights(product(fc6_weights_tz));
|
||||
std::vector<float> fc6_bias(product(fc6_bias_tz));
|
||||
|
||||
// create memory for user data
|
||||
auto fc6_user_weights_memory
|
||||
= memory({ { fc6_weights_tz }, dt::f32, tag::oihw }, eng,
|
||||
fc6_weights.data());
|
||||
auto fc6_user_bias_memory = memory(
|
||||
{ { fc6_bias_tz }, dt::f32, tag::x }, eng, fc6_bias.data());
|
||||
|
||||
// create memory descriptors for convolution data w/ no specified format
|
||||
auto fc6_src_md = memory::desc({ fc6_src_tz }, dt::f32, tag::any);
|
||||
auto fc6_bias_md = memory::desc({ fc6_bias_tz }, dt::f32, tag::any);
|
||||
auto fc6_weights_md = memory::desc({ fc6_weights_tz }, dt::f32, tag::any);
|
||||
auto fc6_dst_md = memory::desc({ fc6_dst_tz }, dt::f32, tag::any);
|
||||
|
||||
// create a inner_product
|
||||
auto fc6_desc = inner_product_forward::desc(prop_kind::forward_inference,
|
||||
fc6_src_md, fc6_weights_md, fc6_bias_md, fc6_dst_md);
|
||||
auto fc6_prim_desc = inner_product_forward::primitive_desc(fc6_desc, eng);
|
||||
|
||||
auto fc6_src_memory = pool5_dst_memory;
|
||||
if (fc6_prim_desc.src_desc() != fc6_src_memory.get_desc()) {
|
||||
fc6_src_memory = memory(fc6_prim_desc.src_desc(), eng);
|
||||
net.push_back(reorder(pool5_dst_memory, fc6_src_memory));
|
||||
net_args.push_back({ { MKLDNN_ARG_FROM, pool5_dst_memory },
|
||||
{ MKLDNN_ARG_TO, fc6_src_memory } });
|
||||
}
|
||||
|
||||
auto fc6_weights_memory = fc6_user_weights_memory;
|
||||
if (fc6_prim_desc.weights_desc() != fc6_user_weights_memory.get_desc()) {
|
||||
fc6_weights_memory = memory(fc6_prim_desc.weights_desc(), eng);
|
||||
reorder(fc6_user_weights_memory, fc6_weights_memory)
|
||||
.execute(s, fc6_user_weights_memory, fc6_weights_memory);
|
||||
}
|
||||
|
||||
auto fc6_dst_memory = memory(fc6_prim_desc.dst_desc(), eng);
|
||||
|
||||
// create convolution primitive and add it to net
|
||||
net.push_back(inner_product_forward(fc6_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, fc6_src_memory },
|
||||
{ MKLDNN_ARG_WEIGHTS, fc6_weights_memory },
|
||||
{ MKLDNN_ARG_BIAS, fc6_user_bias_memory },
|
||||
{ MKLDNN_ARG_DST, fc6_dst_memory } });
|
||||
|
||||
|
||||
// fc7 inner product {batch, 4096} (x) {4096, 4096}-> {batch, 4096}
|
||||
memory::dims fc7_weights_tz = { 4096, 4096 };
|
||||
memory::dims fc7_bias_tz = { 4096 };
|
||||
memory::dims fc7_dst_tz = { batch, 4096 };
|
||||
|
||||
std::vector<float> fc7_weights(product(fc7_weights_tz));
|
||||
std::vector<float> fc7_bias(product(fc7_bias_tz));
|
||||
|
||||
// create memory for user data
|
||||
auto fc7_user_weights_memory = memory(
|
||||
{ { fc7_weights_tz }, dt::f32, tag::nc }, eng, fc7_weights.data());
|
||||
|
||||
auto fc7_user_bias_memory = memory(
|
||||
{ { fc7_bias_tz }, dt::f32, tag::x }, eng, fc7_bias.data());
|
||||
|
||||
// create memory descriptors for convolution data w/ no specified format
|
||||
auto fc7_bias_md = memory::desc({ fc7_bias_tz }, dt::f32, tag::any);
|
||||
auto fc7_weights_md = memory::desc({ fc7_weights_tz }, dt::f32, tag::any);
|
||||
auto fc7_dst_md = memory::desc({ fc7_dst_tz }, dt::f32, tag::any);
|
||||
|
||||
// create a inner_product
|
||||
auto fc7_desc = inner_product_forward::desc(prop_kind::forward_inference,
|
||||
fc6_dst_memory.get_desc(), fc7_weights_md, fc7_bias_md, fc7_dst_md);
|
||||
auto fc7_prim_desc = inner_product_forward::primitive_desc(fc7_desc, eng);
|
||||
|
||||
auto fc7_weights_memory = fc7_user_weights_memory;
|
||||
if (fc7_prim_desc.weights_desc() != fc7_user_weights_memory.get_desc()) {
|
||||
fc7_weights_memory = memory(fc7_prim_desc.weights_desc(), eng);
|
||||
reorder(fc7_user_weights_memory, fc7_weights_memory)
|
||||
.execute(s, fc7_user_weights_memory, fc7_weights_memory);
|
||||
}
|
||||
|
||||
auto fc7_dst_memory = memory(fc7_prim_desc.dst_desc(), eng);
|
||||
|
||||
// create convolution primitive and add it to net
|
||||
net.push_back(inner_product_forward(fc7_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, fc6_dst_memory },
|
||||
{ MKLDNN_ARG_WEIGHTS, fc7_weights_memory },
|
||||
{ MKLDNN_ARG_BIAS, fc7_user_bias_memory },
|
||||
{ MKLDNN_ARG_DST, fc7_dst_memory } });
|
||||
|
||||
// fc8 inner product {batch, 4096} (x) {1000, 4096}-> {batch, 1000}
|
||||
memory::dims fc8_weights_tz = { 1000, 4096 };
|
||||
memory::dims fc8_bias_tz = { 1000 };
|
||||
memory::dims fc8_dst_tz = { batch, 1000 };
|
||||
|
||||
std::vector<float> fc8_weights(product(fc8_weights_tz));
|
||||
std::vector<float> fc8_bias(product(fc8_bias_tz));
|
||||
|
||||
// create memory for user data
|
||||
auto fc8_user_weights_memory = memory(
|
||||
{ { fc8_weights_tz }, dt::f32, tag::nc }, eng, fc8_weights.data());
|
||||
auto fc8_user_bias_memory = memory(
|
||||
{ { fc8_bias_tz }, dt::f32, tag::x }, eng, fc8_bias.data());
|
||||
auto user_dst_memory = memory(
|
||||
{ { fc8_dst_tz }, dt::f32, tag::nc }, eng, user_dst.data());
|
||||
|
||||
// create memory descriptors for convolution data w/ no specified format
|
||||
auto fc8_bias_md = memory::desc({ fc8_bias_tz }, dt::f32, tag::any);
|
||||
auto fc8_weights_md = memory::desc({ fc8_weights_tz }, dt::f32, tag::any);
|
||||
auto fc8_dst_md = memory::desc({ fc8_dst_tz }, dt::f32, tag::any);
|
||||
|
||||
// create a inner_product
|
||||
auto fc8_desc = inner_product_forward::desc(prop_kind::forward_inference,
|
||||
fc7_dst_memory.get_desc(), fc8_weights_md, fc8_bias_md, fc8_dst_md);
|
||||
auto fc8_prim_desc = inner_product_forward::primitive_desc(fc8_desc, eng);
|
||||
|
||||
auto fc8_weights_memory = fc8_user_weights_memory;
|
||||
if (fc8_prim_desc.weights_desc() != fc8_user_weights_memory.get_desc()) {
|
||||
fc8_weights_memory = memory(fc8_prim_desc.weights_desc(), eng);
|
||||
reorder(fc8_user_weights_memory, fc8_weights_memory)
|
||||
.execute(s, fc8_user_weights_memory, fc8_weights_memory);
|
||||
}
|
||||
|
||||
auto fc8_dst_memory = memory(fc8_prim_desc.dst_desc(), eng);
|
||||
|
||||
// create convolution primitive and add it to net
|
||||
net.push_back(inner_product_forward(fc8_prim_desc));
|
||||
net_args.push_back({ { MKLDNN_ARG_SRC, fc7_dst_memory },
|
||||
{ MKLDNN_ARG_WEIGHTS, fc8_weights_memory },
|
||||
{ MKLDNN_ARG_BIAS, fc8_user_bias_memory },
|
||||
{ MKLDNN_ARG_DST, fc8_dst_memory } });
|
||||
|
||||
// create reorder between internal and user data if it is needed and
|
||||
// add it to net after pooling
|
||||
if (fc8_dst_memory != user_dst_memory) {
|
||||
net.push_back(reorder(fc8_dst_memory, user_dst_memory));
|
||||
net_args.push_back({ { MKLDNN_ARG_FROM, fc8_dst_memory },
|
||||
{ MKLDNN_ARG_TO, user_dst_memory } });
|
||||
}
|
||||
|
||||
/// @page cpu_cnn_inference_f32_cpp
|
||||
/// Finally, execute the primitives. For this example, the net is executed
|
||||
/// multiple times and each execution is timed individually.
|
||||
/// @snippet cpu_cnn_inference_f32.cpp Execute model
|
||||
//[Execute model]
|
||||
for (int j = 0; j < times; ++j) {
|
||||
assert(net.size() == net_args.size() && "something is missing");
|
||||
for (size_t i = 0; i < net.size(); ++i)
|
||||
net.at(i).execute(s, net_args.at(i));
|
||||
}
|
||||
//[Execute model]
|
||||
|
||||
s.wait();
|
||||
}
|
||||
|
||||
// extern "C" int mkl_test_entry();
|
||||
|
||||
int mkl_test_entry() {
|
||||
try {
|
||||
auto begin = chrono::duration_cast<chrono::milliseconds>(
|
||||
chrono::steady_clock::now().time_since_epoch())
|
||||
.count();
|
||||
int times = 100;
|
||||
simple_net(times);
|
||||
auto end = chrono::duration_cast<chrono::milliseconds>(
|
||||
chrono::steady_clock::now().time_since_epoch())
|
||||
.count();
|
||||
cout << "Use time " << (end - begin) / (times + 0.0) << "\n";
|
||||
} catch (error &e) {
|
||||
std::cerr << "status: " << e.status << std::endl;
|
||||
std::cerr << "message: " << e.message << std::endl;
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,209 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <random>
|
||||
|
||||
#include "var.h"
|
||||
#include "mkl_conv_backward_w_op.h"
|
||||
|
||||
#include <mkldnn.hpp>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
// Return the axis index of dimension letter `c` inside the 4-char layout
// string `format` (e.g. findc("abcd", 'c') == 2). Aborts via ASSERT when
// `c` is not one of the four letters.
static inline int findc(const string& format, const char& c) {
    for (int i = 0; i < 3; ++i)
        if (c == format[i]) return i;
    ASSERT(c==format[3]) << "Not a valid format" << format << c;
    return 3;
}
|
||||
|
||||
#ifndef JIT
|
||||
// Read the four dims of `x` in the axis order named by `f`, where `format`
// describes the layout of x->shape (both are 4-letter strings sharing the
// same alphabet, e.g. f="abcd" against format="acdb").
static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) {
    auto& dims = x->shape;
    a = dims[findc(format, f[0])];
    b = dims[findc(format, f[1])];
    c = dims[findc(format, f[2])];
    d = dims[findc(format, f[3])];
}
|
||||
|
||||
// Set the four dims of `x`, placing a,b,c,d at the positions that the
// letters of `f` occupy in the layout string `format` (inverse of get_shape).
static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) {
    int64 dims[4];
    int vals[4] = {a, b, c, d};
    for (int i = 0; i < 4; ++i)
        dims[findc(format, f[i])] = vals[i];
    x->set_shape(NanoVector(dims[0], dims[1], dims[2], dims[3]));
}
|
||||
|
||||
// Build the conv-weight-gradient op: `x` is the forward input, `dy` the
// gradient of the forward output. Creates output var `dw` (the weight
// gradient) whose dtype is inferred from dy and x.
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
    : x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
}
|
||||
|
||||
// Infer dw's shape: (out_channels=yc, in_channels=xc, kernel_size,
// kernel_size), laid out per `wformat`. Requires x and dy to be 4-D.
void MklConvBackwardWOp::infer_shape() {
    ASSERTop(x->shape.size(),==,4);
    ASSERTop(dy->shape.size(),==,4);
    int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
    get_shape(x, "abcd", xformat, xn, xc, xh, xw);
    get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
    // output channels come from dy, input channels from x
    wco = yc, wci = xc;
    wh = kernel_size;
    ww = kernel_size;
    set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
}
|
||||
|
||||
// Map a Var's dtype to the mkldnn short type name used in JIT defines:
// f64/f32/f16/f8 for floating types, s64/s32/s16/s8 otherwise, keyed on
// the element size in bytes (anything unrecognized falls back to 1-byte).
static const char* short_type(Var* x) {
    bool is_f = x->is_float();
    switch (x->dsize()) {
        case 8: return is_f ? "f64" : "s64";
        case 4: return is_f ? "f32" : "s32";
        case 2: return is_f ? "f16" : "s16";
        default: return is_f ? "f8" : "s8";
    }
}
|
||||
|
||||
// Register the JIT substitution symbols used by jit_run: the element
// dtypes (Txd/Tyd/Twd), the mkldnn short type names (Tx/Tw/Ty), and the
// three layout format strings.
void MklConvBackwardWOp::jit_prepare() {
    add_jit_define("Txd", x->dtype());
    add_jit_define("Tyd", dy->dtype());
    add_jit_define("Twd", dw->dtype());
    add_jit_define("Tx", short_type(x));
    add_jit_define("Tw", short_type(dw));
    add_jit_define("Ty", short_type(dy));
    add_jit_define("XFORMAT", xformat);
    add_jit_define("WFORMAT", wformat);
    add_jit_define("YFORMAT", yformat);
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void MklConvBackwardWOp::jit_run() {
|
||||
int batch = x->shape[findc("@XFORMAT",'a')];
|
||||
int ch_in = x->shape[findc("@XFORMAT",'b')];
|
||||
int height = x->shape[findc("@XFORMAT",'c')];
|
||||
int width = x->shape[findc("@XFORMAT",'d')];
|
||||
int ch_out = dw->shape[findc("@WFORMAT",'o')];
|
||||
int kernel_size = dw->shape[findc("@WFORMAT",'h')];
|
||||
|
||||
auto* __restrict__ net_src = x->ptr<Txd>();
|
||||
auto* __restrict__ net_diff_dst = dy->ptr<Tyd>();
|
||||
auto* __restrict__ conv_user_diff_weights_buffer = dw->ptr<Twd>();
|
||||
|
||||
using tag = memory::format_tag;
|
||||
using dt = memory::data_type;
|
||||
|
||||
auto eng = engine(engine::kind::cpu, 0);
|
||||
stream s(eng);
|
||||
|
||||
std::vector<primitive> net_bwd;
|
||||
std::vector<std::unordered_map<int, memory>> net_bwd_args;
|
||||
|
||||
memory::dims conv_src_tz = {batch, ch_in, height, width};
|
||||
memory::dims conv_weights_tz = {ch_out, ch_in, kernel_size, kernel_size};
|
||||
memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kernel_size*dilation+dilation-1)/stride+1, (width+padding*2-kernel_size*dilation+dilation-1)/stride+1};
|
||||
memory::dims conv_strides = {stride, stride};
|
||||
memory::dims conv_padding = {padding, padding};
|
||||
memory::dims conv_dilation = {dilation-1, dilation-1};
|
||||
|
||||
auto conv_user_src_memory
|
||||
= memory({{conv_src_tz}, dt::@Tx, tag::@XFORMAT}, eng, net_src);
|
||||
|
||||
auto conv_src_md = memory::desc({conv_src_tz}, dt::@Tx, tag::any);
|
||||
auto conv_weights_md = memory::desc({conv_weights_tz}, dt::@Tw, tag::any);
|
||||
auto conv_dst_md = memory::desc({conv_dst_tz}, dt::@Ty, tag::any);
|
||||
|
||||
auto conv_desc = convolution_forward::desc(prop_kind::forward,
|
||||
algorithm::convolution_direct, conv_src_md, conv_weights_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding);
|
||||
auto conv_pd = convolution_forward::primitive_desc(conv_desc, eng);
|
||||
|
||||
auto conv_src_memory = conv_user_src_memory;
|
||||
if (conv_pd.src_desc() != conv_user_src_memory.get_desc()) {
|
||||
conv_src_memory = memory(conv_pd.src_desc(), eng);
|
||||
net_bwd.push_back(reorder(conv_user_src_memory, conv_src_memory));
|
||||
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_user_src_memory},
|
||||
{MKLDNN_ARG_TO, conv_src_memory}});
|
||||
}
|
||||
|
||||
auto conv_user_diff_dst_memory
|
||||
= memory({{conv_dst_tz}, dt::@Ty, tag::YFORMAT}, eng, net_diff_dst);
|
||||
|
||||
auto conv_user_diff_weights_memory
|
||||
= memory({{conv_weights_tz}, dt::@Tw, tag::WFORMAT}, eng, conv_user_diff_weights_buffer);
|
||||
|
||||
auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::@Tx, tag::any);
|
||||
auto conv_diff_weights_md
|
||||
= memory::desc({conv_weights_tz}, dt::@Tw, tag::any);
|
||||
auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::@Ty, tag::any);
|
||||
|
||||
auto conv_bwd_weights_desc
|
||||
= convolution_backward_weights::desc(algorithm::convolution_direct,
|
||||
conv_bwd_src_md, conv_diff_weights_md,
|
||||
conv_diff_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding);
|
||||
auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(
|
||||
conv_bwd_weights_desc, eng, conv_pd);
|
||||
|
||||
auto conv_bwd_src_memory = conv_src_memory;
|
||||
if (conv_bwd_weights_pd.src_desc() != conv_src_memory.get_desc()) {
|
||||
conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_desc(), eng);
|
||||
net_bwd.push_back(reorder(conv_src_memory, conv_bwd_src_memory));
|
||||
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_src_memory},
|
||||
{MKLDNN_ARG_TO, conv_bwd_src_memory}});
|
||||
}
|
||||
|
||||
auto conv_diff_dst_memory = conv_user_diff_dst_memory;
|
||||
if (conv_bwd_weights_pd.diff_dst_desc()
|
||||
!= conv_user_diff_dst_memory.get_desc()) {
|
||||
conv_diff_dst_memory = memory(conv_bwd_weights_pd.diff_dst_desc(), eng);
|
||||
net_bwd.push_back(reorder(conv_user_diff_dst_memory, conv_diff_dst_memory));
|
||||
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_user_diff_dst_memory},
|
||||
{MKLDNN_ARG_TO, conv_diff_dst_memory}});
|
||||
}
|
||||
|
||||
net_bwd.push_back(convolution_backward_weights(conv_bwd_weights_pd));
|
||||
net_bwd_args.push_back({{MKLDNN_ARG_SRC, conv_bwd_src_memory},
|
||||
{MKLDNN_ARG_DIFF_DST, conv_diff_dst_memory}});
|
||||
|
||||
auto conv_diff_weights_memory = conv_user_diff_weights_memory;
|
||||
if (conv_bwd_weights_pd.diff_weights_desc()
|
||||
!= conv_user_diff_weights_memory.get_desc()) {
|
||||
conv_diff_weights_memory
|
||||
= memory(conv_bwd_weights_pd.diff_weights_desc(), eng);
|
||||
net_bwd_args.back().insert(
|
||||
{MKLDNN_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
|
||||
|
||||
net_bwd.push_back(reorder(
|
||||
conv_diff_weights_memory, conv_user_diff_weights_memory));
|
||||
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_diff_weights_memory},
|
||||
{MKLDNN_ARG_TO, conv_user_diff_weights_memory}});
|
||||
} else {
|
||||
net_bwd_args.back().insert(
|
||||
{MKLDNN_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
|
||||
}
|
||||
|
||||
ASSERTop(net_bwd.size(),==,net_bwd_args.size());
|
||||
|
||||
for (size_t i = 0; i < net_bwd.size(); ++i)
|
||||
net_bwd.at(i).execute(s, net_bwd_args.at(i));
|
||||
|
||||
s.wait();
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Op computing the gradient of a 2D convolution with respect to its
// weights, backed by MKL-DNN (see mkl_conv_backward_w_op.cc).
struct MklConvBackwardWOp : Op {
    // x: forward input; dy: gradient of the forward output (named `y` in the
    // ctor declaration below); dw: produced weight gradient (output var).
    Var* x, * dy, * dw;
    int kernel_size, stride, padding, dilation;
    // Layout strings: xformat/yformat use axis letters "abcd"; wformat is a
    // permutation of "oihw".
    string xformat, wformat, yformat;

    MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");

    const char* name() const override { return "mkl_conv_backward_w"; }
    void infer_shape() override;
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,205 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <random>
|
||||
|
||||
#include "var.h"
|
||||
#include "mkl_conv_backward_x_op.h"
|
||||
|
||||
#include <mkldnn.hpp>
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
// Return the axis index of dimension letter `c` inside the 4-char layout
// string `format` (e.g. findc("abcd", 'c') == 2). Aborts via ASSERT when
// `c` is not one of the four letters.
static inline int findc(const string& format, const char& c) {
    for (int i = 0; i < 3; ++i)
        if (c == format[i]) return i;
    ASSERT(c==format[3]) << "Not a valid format" << format << c;
    return 3;
}
|
||||
|
||||
#ifndef JIT
|
||||
// Read the four dims of `x` in the axis order named by `f`, where `format`
// describes the layout of x->shape (both are 4-letter strings sharing the
// same alphabet).
static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) {
    auto& dims = x->shape;
    a = dims[findc(format, f[0])];
    b = dims[findc(format, f[1])];
    c = dims[findc(format, f[2])];
    d = dims[findc(format, f[3])];
}
|
||||
|
||||
// Set the four dims of `x`, placing a,b,c,d at the positions that the
// letters of `f` occupy in the layout string `format` (inverse of get_shape).
static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) {
    int64 dims[4];
    int vals[4] = {a, b, c, d};
    for (int i = 0; i < 4; ++i)
        dims[findc(format, f[i])] = vals[i];
    x->set_shape(NanoVector(dims[0], dims[1], dims[2], dims[3]));
}
|
||||
|
||||
// Build the conv-input-gradient op: `w` are the forward weights, `dy` the
// gradient of the forward output; height/width give the spatial size of
// the input being reconstructed. Creates output var `dx` (the input
// gradient) whose dtype is inferred from dy and w.
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
    : w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
}
|
||||
|
||||
// Infer dx's shape: batch from dy, channels from the weights' input-channel
// dim, spatial size from the xh/xw ctor arguments; laid out per `xformat`.
void MklConvBackwardXOp::infer_shape() {
    ASSERTop(w->shape.size(),==,4);
    ASSERTop(dy->shape.size(),==,4);
    int xn, xc, wh, ww, wci, wco, yn, yc, yh, yw;
    get_shape(w, "oihw", wformat, wco, wci, wh, ww);
    get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
    // batch from dy, input channels from w
    xn = yn, xc = wci;
    set_shape(dx, "abcd", xformat, xn, xc, xh, xw);
}
|
||||
|
||||
// Map a Var's dtype to the mkldnn short type name used in JIT defines:
// f64/f32/f16/f8 for floating types, s64/s32/s16/s8 otherwise, keyed on
// the element size in bytes (anything unrecognized falls back to 1-byte).
static const char* short_type(Var* x) {
    bool is_f = x->is_float();
    switch (x->dsize()) {
        case 8: return is_f ? "f64" : "s64";
        case 4: return is_f ? "f32" : "s32";
        case 2: return is_f ? "f16" : "s16";
        default: return is_f ? "f8" : "s8";
    }
}
|
||||
|
||||
// Register the JIT substitution symbols used by jit_run: the element
// dtypes (Tyd/Twd/Txd), the mkldnn short type names (Tx/Tw/Ty), and the
// three layout format strings.
void MklConvBackwardXOp::jit_prepare() {
    add_jit_define("Tyd", dy->dtype());
    add_jit_define("Twd", w->dtype());
    add_jit_define("Txd", dx->dtype());
    add_jit_define("Tx", short_type(dx));
    add_jit_define("Tw", short_type(w));
    add_jit_define("Ty", short_type(dy));
    add_jit_define("XFORMAT", xformat);
    add_jit_define("WFORMAT", wformat);
    add_jit_define("YFORMAT", yformat);
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void MklConvBackwardXOp::jit_run() {
|
||||
int batch = dx->shape[findc("@XFORMAT",'a')];
|
||||
int ch_in = dx->shape[findc("@XFORMAT",'b')];
|
||||
int height = dx->shape[findc("@XFORMAT",'c')];
|
||||
int width = dx->shape[findc("@XFORMAT",'d')];
|
||||
int ch_out = w->shape[findc("@WFORMAT",'o')];
|
||||
int kernel_size = w->shape[findc("@WFORMAT",'h')];
|
||||
|
||||
auto* __restrict__ conv_weights = w->ptr<Twd>();
|
||||
auto* __restrict__ net_diff_dst = dy->ptr<Tyd>();
|
||||
auto* __restrict__ conv_user_diff_src_buffer = dx->ptr<Txd>();
|
||||
|
||||
using tag = memory::format_tag;
|
||||
using dt = memory::data_type;
|
||||
|
||||
auto eng = engine(engine::kind::cpu, 0);
|
||||
stream s(eng);
|
||||
|
||||
std::vector<primitive> net_bwd;
|
||||
std::vector<std::unordered_map<int, memory>> net_bwd_args;
|
||||
|
||||
memory::dims conv_src_tz = {batch, ch_in, height, width};
|
||||
memory::dims conv_weights_tz = {ch_out, ch_in, kernel_size, kernel_size};
|
||||
memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kernel_size*dilation+dilation-1)/stride+1, (width+padding*2-kernel_size*dilation+dilation-1)/stride+1};
|
||||
memory::dims conv_strides = {stride, stride};
|
||||
memory::dims conv_padding = {padding, padding};
|
||||
memory::dims conv_dilation = {dilation-1, dilation-1};
|
||||
|
||||
auto conv_user_weights_memory
|
||||
= memory({{conv_weights_tz}, dt::@Tw, tag::@WFORMAT}, eng, conv_weights);
|
||||
|
||||
auto conv_src_md = memory::desc({conv_src_tz}, dt::@Tx, tag::any);
|
||||
auto conv_weights_md = memory::desc({conv_weights_tz}, dt::@Tw, tag::any);
|
||||
auto conv_dst_md = memory::desc({conv_dst_tz}, dt::@Ty, tag::any);
|
||||
|
||||
auto conv_desc = convolution_forward::desc(prop_kind::forward,
|
||||
algorithm::convolution_direct, conv_src_md, conv_weights_md,
|
||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||
conv_padding);
|
||||
auto conv_pd = convolution_forward::primitive_desc(conv_desc, eng);
|
||||
|
||||
auto conv_weights_memory = conv_user_weights_memory;
|
||||
if (conv_pd.weights_desc() != conv_user_weights_memory.get_desc()) {
|
||||
conv_weights_memory = memory(conv_pd.weights_desc(), eng);
|
||||
net_bwd.push_back(
|
||||
reorder(conv_user_weights_memory, conv_weights_memory));
|
||||
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_user_weights_memory},
|
||||
{MKLDNN_ARG_TO, conv_weights_memory}});
|
||||
}
|
||||
|
||||
auto conv_user_diff_dst_memory
|
||||
= memory({{conv_dst_tz}, dt::@Ty, tag::@YFORMAT}, eng, net_diff_dst);
|
||||
|
||||
auto conv_user_diff_src_memory
|
||||
= memory({{conv_src_tz}, dt::@Tx, tag::@XFORMAT}, eng, conv_user_diff_src_buffer);
|
||||
|
||||
auto conv_bwd_weights_md
|
||||
= memory::desc({conv_weights_tz}, dt::@Tw, tag::any);
|
||||
auto conv_diff_src_md = memory::desc({conv_src_tz}, dt::@Tx, tag::any);
|
||||
auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::@Ty, tag::any);
|
||||
|
||||
auto conv_bwd_data_desc
|
||||
= convolution_backward_data::desc(algorithm::convolution_direct,
|
||||
conv_diff_src_md, conv_bwd_weights_md, conv_diff_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding);
|
||||
auto conv_bwd_data_pd = convolution_backward_data::primitive_desc(
|
||||
conv_bwd_data_desc, eng, conv_pd);
|
||||
|
||||
auto conv_diff_dst_memory = conv_user_diff_dst_memory;
|
||||
if (conv_bwd_data_pd.diff_dst_desc()
|
||||
!= conv_user_diff_dst_memory.get_desc()) {
|
||||
conv_diff_dst_memory = memory(conv_bwd_data_pd.diff_dst_desc(), eng);
|
||||
net_bwd.push_back(reorder(conv_user_diff_dst_memory, conv_diff_dst_memory));
|
||||
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_user_diff_dst_memory},
|
||||
{MKLDNN_ARG_TO, conv_diff_dst_memory}});
|
||||
}
|
||||
|
||||
auto conv_bwd_weights_memory = conv_weights_memory;
|
||||
if (conv_bwd_data_pd.weights_desc() != conv_weights_memory.get_desc()) {
|
||||
conv_bwd_weights_memory = memory(conv_bwd_data_pd.weights_desc(), eng);
|
||||
net_bwd.push_back(reorder(conv_weights_memory, conv_bwd_weights_memory));
|
||||
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_weights_memory},
|
||||
{MKLDNN_ARG_TO, conv_bwd_weights_memory}});
|
||||
}
|
||||
|
||||
net_bwd.push_back(convolution_backward_data(conv_bwd_data_pd));
|
||||
net_bwd_args.push_back({{MKLDNN_ARG_WEIGHTS, conv_bwd_weights_memory},
|
||||
{MKLDNN_ARG_DIFF_DST, conv_diff_dst_memory}});
|
||||
|
||||
auto conv_diff_src_memory = conv_user_diff_src_memory;
|
||||
if (conv_bwd_data_pd.diff_src_desc()
|
||||
!= conv_user_diff_src_memory.get_desc()) {
|
||||
conv_diff_src_memory
|
||||
= memory(conv_bwd_data_pd.diff_src_desc(), eng);
|
||||
net_bwd_args.back().insert(
|
||||
{MKLDNN_ARG_DIFF_SRC, conv_diff_src_memory});
|
||||
|
||||
net_bwd.push_back(reorder(
|
||||
conv_diff_src_memory, conv_user_diff_src_memory));
|
||||
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_diff_src_memory},
|
||||
{MKLDNN_ARG_TO, conv_user_diff_src_memory}});
|
||||
} else {
|
||||
net_bwd_args.back().insert(
|
||||
{MKLDNN_ARG_DIFF_SRC, conv_diff_src_memory});
|
||||
}
|
||||
|
||||
ASSERTop(net_bwd.size(),==,net_bwd_args.size());
|
||||
|
||||
for (size_t i = 0; i < net_bwd.size(); ++i)
|
||||
net_bwd.at(i).execute(s, net_bwd_args.at(i));
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Op computing the gradient of a 2D convolution with respect to its input,
// backed by MKL-DNN (see mkl_conv_backward_x_op.cc).
struct MklConvBackwardXOp : Op {
    // w: forward weights; dy: gradient of the forward output (named `y` in
    // the ctor declaration below); dx: produced input gradient (output var).
    Var* w, * dy, * dx;
    // xh/xw: spatial size of the reconstructed input (passed as height/width).
    int xh, xw, stride, padding, dilation;
    // Layout strings: xformat/yformat use axis letters "abcd"; wformat is a
    // permutation of "oihw".
    string xformat, wformat, yformat;

    MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");

    const char* name() const override { return "mkl_conv_backward_x"; }
    void infer_shape() override;
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,186 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <mkldnn.hpp>
|
||||
|
||||
#include "var.h"
|
||||
#include "mkl_conv_op.h"
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Return the axis index of dimension letter `c` inside the 4-char layout
// string `format` (e.g. findc("abcd", 'c') == 2). Aborts via ASSERT when
// `c` is not one of the four letters.
static inline int findc(const string& format, const char& c) {
    for (int i = 0; i < 3; ++i)
        if (c == format[i]) return i;
    ASSERT(c==format[3]) << "Not a valid format" << format << c;
    return 3;
}
|
||||
|
||||
// Read the four dims of `x` in the axis order named by `f`, where `format`
// describes the layout of x->shape (both are 4-letter strings sharing the
// same alphabet).
static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) {
    auto& dims = x->shape;
    a = dims[findc(format, f[0])];
    b = dims[findc(format, f[1])];
    c = dims[findc(format, f[2])];
    d = dims[findc(format, f[3])];
}
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
// Set the four dims of `x`, placing a,b,c,d at the positions that the
// letters of `f` occupy in the layout string `format` (inverse of get_shape).
static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) {
    int64 dims[4];
    int vals[4] = {a, b, c, d};
    for (int i = 0; i < 4; ++i)
        dims[findc(format, f[i])] = vals[i];
    x->set_shape(NanoVector(dims[0], dims[1], dims[2], dims[3]));
}
|
||||
|
||||
// Build the forward conv op from input `x` and weights `w`. Creates output
// var `y` whose dtype is inferred from x and w. An empty yformat defaults
// to xformat.
MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
    : x(x), w(w), stride(stride), padding(padding), dilation(dilation),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    y = create_output(nullptr, dtype_infer(x->ns, w->ns));
    if (!this->yformat.size())
        this->yformat = this->xformat;
}
|
||||
|
||||
// Infer y's shape from x and w using the standard dilated-conv output
// formula; checks that w's input-channel dim matches x's channel dim.
void MklConvOp::infer_shape() {
    ASSERTop(x->shape.size(),==,4);
    ASSERTop(w->shape.size(),==,4);
    int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
    get_shape(x, "abcd", xformat, xn, xc, xh, xw);
    get_shape(w, "oihw", wformat, wco, wci, wh, ww);
    ASSERTop(wci,==,xc);
    // batch from x, output channels from w
    yn = xn, yc = wco;
    yh = (xh+padding*2-wh*dilation+dilation-1)/stride+1;
    yw = (xw+padding*2-ww*dilation+dilation-1)/stride+1;
    set_shape(y, "abcd", yformat, yn, yc, yh, yw);
}
|
||||
|
||||
// Map a Var's dtype to the mkldnn short type name used in JIT defines:
// f64/f32/f16/f8 for floating types, s64/s32/s16/s8 otherwise, keyed on
// the element size in bytes (anything unrecognized falls back to 1-byte).
static const char* short_type(Var* x) {
    bool is_f = x->is_float();
    switch (x->dsize()) {
        case 8: return is_f ? "f64" : "s64";
        case 4: return is_f ? "f32" : "s32";
        case 2: return is_f ? "f16" : "s16";
        default: return is_f ? "f8" : "s8";
    }
}
|
||||
|
||||
// Register the JIT substitution symbols used by jit_run: the mkldnn short
// type names (Tx/Tw/Ty) and the three layout format strings.
void MklConvOp::jit_prepare() {
    add_jit_define("Tx", short_type(x));
    add_jit_define("Tw", short_type(w));
    add_jit_define("Ty", short_type(y));
    add_jit_define("XFORMAT", xformat);
    add_jit_define("WFORMAT", wformat);
    add_jit_define("YFORMAT", yformat);
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
// Execute the convolution with MKL-DNN.
// x (input), w (weights) and y (output) are the op's Vars;
// @XFORMAT/@YFORMAT/@WFORMAT and @Tx/@Ty/@Tw are JIT-substituted
// layout tags and data types fixed at jit_prepare time.
void MklConvOp::jit_run() {
    const auto& xs = x->shape;
    const auto& ws = w->shape;

    using tag = memory::format_tag;
    using dt = memory::data_type;

    // Fast path: a 1x1 convolution with unit stride, no padding and no
    // dilation over nhwc/hwio float32 data is exactly a matrix multiply,
    // so call sgemm directly instead of building a conv primitive.
    if (tag::@XFORMAT==tag::nhwc && tag::@YFORMAT==tag::nhwc && tag::@WFORMAT==tag::hwio
        && stride==1 && padding==0 && dilation==1 && ws[0]==1 && ws[1]==1
        && dt::@Tx==dt::f32 && dt::@Ty==dt::f32 && dt::@Tw==dt::f32) {
        auto m = xs[0]*xs[1]*xs[2];   // n*h*w collapsed into the row dim
        auto n = ws[3];               // output channels
        auto k = xs[3];               // input channels
        // x: [m,k], w: [k,n], y: [m,n]
        ASSERTop(0,==,mkldnn_sgemm('N', 'N', m, n, k,
            1.f, x->ptr<float32>(), k,
            w->ptr<float32>(), n,
            0.f, y->ptr<float32>(), n));
        return;
    }

    // General path: build and run an MKL-DNN convolution primitive.
    engine eng(engine::kind::cpu, 0);
    stream s(eng);

    // Primitives and their argument maps, executed pairwise at the end.
    std::vector<primitive> net;
    std::vector<std::unordered_map<int, memory>> net_args;

    // Decode logical dimensions from the per-tensor layout strings.
    int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
    get_shape(x, "abcd", xformat, xn, xc, xh, xw);
    get_shape(w, "oihw", wformat, wco, wci, wh, ww);
    get_shape(y, "abcd", yformat, yn, yc, yh, yw);

    memory::dims conv1_src_tz = {xn, xc, xh, xw};
    memory::dims conv1_weights_tz = {wco, wci, wh, ww};
    memory::dims conv1_dst_tz = {yn, yc, yh, yw};
    memory::dims conv1_strides = { stride, stride };
    memory::dims conv1_padding = { padding, padding };
    // MKL-DNN dilation is zero-based (0 means no dilation), hence the -1.
    memory::dims conv1_dilation = { dilation-1, dilation-1 };

    // Wrap the user buffers in memory objects with their actual layouts.
    auto user_src_memory = memory(
        { { conv1_src_tz }, dt::@Tx, tag::@XFORMAT }, eng, x->mem_ptr);
    auto user_dst_memory = memory(
        { { conv1_dst_tz }, dt::@Ty, tag::@YFORMAT }, eng, y->mem_ptr);
    auto user_weights_memory = memory(
        { { conv1_weights_tz }, dt::@Tw, tag::@WFORMAT }, eng, w->mem_ptr);

    // tag::any lets MKL-DNN choose the layout it prefers for the primitive.
    auto conv1_src_md = memory::desc({ conv1_src_tz }, dt::@Tx, tag::any);
    auto conv1_weights_md
        = memory::desc({ conv1_weights_tz }, dt::@Tw, tag::any);
    auto conv1_dst_md = memory::desc({ conv1_dst_tz }, dt::@Ty, tag::any);

    auto conv1_desc = convolution_forward::desc(prop_kind::forward_inference,
        algorithm::convolution_auto, conv1_src_md, conv1_weights_md, conv1_dst_md, conv1_strides, conv1_dilation, conv1_padding, conv1_padding);

    auto conv1_prim_desc = convolution_forward::primitive_desc(conv1_desc, eng);

    net.clear();
    net_args.clear();
    // Reorder the source into the primitive's preferred layout if it differs.
    auto conv1_src_memory = user_src_memory;
    if (conv1_prim_desc.src_desc() != user_src_memory.get_desc()) {
        conv1_src_memory = memory(conv1_prim_desc.src_desc(), eng);
        net.push_back(reorder(user_src_memory, conv1_src_memory));
        net_args.push_back({ { MKLDNN_ARG_FROM, user_src_memory },
            { MKLDNN_ARG_TO, conv1_src_memory } });
    }

    // Same for the weights.
    auto conv1_weights_memory = user_weights_memory;
    if (conv1_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
        conv1_weights_memory = memory(conv1_prim_desc.weights_desc(), eng);
        net.push_back(reorder(user_weights_memory, conv1_weights_memory));
        net_args.push_back({ { MKLDNN_ARG_FROM, user_weights_memory }, { MKLDNN_ARG_TO, conv1_weights_memory } });
    }

    auto conv1_dst_memory = memory(conv1_prim_desc.dst_desc(), eng);

    net.push_back(convolution_forward(conv1_prim_desc));
    net_args.push_back({ { MKLDNN_ARG_SRC, conv1_src_memory },
        { MKLDNN_ARG_WEIGHTS, conv1_weights_memory },
        { MKLDNN_ARG_DST, conv1_dst_memory } });

    // Reorder the result back into the user's output layout if needed.
    if (conv1_dst_memory != user_dst_memory) {
        net.push_back(reorder(conv1_dst_memory, user_dst_memory));
        net_args.push_back({ { MKLDNN_ARG_FROM, conv1_dst_memory },{ MKLDNN_ARG_TO, user_dst_memory } });
    }

    // Execute every queued primitive with its matching argument map.
    ASSERTop(net.size(),==,net_args.size());
    for (size_t i = 0; i < net.size(); ++i)
        net.at(i).execute(s, net_args.at(i));
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Convolution op implemented via Intel MKL-DNN.
struct MklConvOp : Op {
    Var* x, * w, * y;                   // input, weights, output
    int stride, padding, dilation;
    string xformat, wformat, yformat;   // tensor layout strings
    /* MklConvOp: xformat abcd represents nchw */
    MklConvOp(Var* x, Var* w, int stride, int padding, int dilation=1, string xformat="abcd", string wformat="oihw", string yformat="");

    const char* name() const override { return "mkl_conv"; }
    void infer_shape() override;
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,76 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <mkldnn.hpp>
|
||||
|
||||
#include "var.h"
|
||||
#include "mkl_matmul_op.h"
|
||||
|
||||
using namespace mkldnn;
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
// Matrix multiply backed by mkldnn_sgemm.
// a, b: rank-2 operands; trans_a/trans_b request transposed reads.
MklMatmulOp::MklMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
    : a(a), b(b), trans_a(trans_a), trans_b(trans_b) {
    // TODO: support int8 * int8
    // Fixed message: the check is "both inputs are float", not "same type".
    ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "inputs of mkl_matmul should be float";
    // TODO: support different input types
    ASSERT(a->dtype().dsize() == 4 && b->dtype().dsize() == 4) << "support float32 only now.";
    // Shape is unknown until infer_shape runs, hence the nullptr shape.
    c = create_output(nullptr, a->dtype());
}
|
||||
|
||||
void MklMatmulOp::infer_shape() {
|
||||
ASSERTop(a->shape.size(),==,2);
|
||||
ASSERTop(b->shape.size(),==,2);
|
||||
int n = a->shape[0], m = a->shape[1];
|
||||
int m_ = b->shape[0], k = b->shape[1];
|
||||
if (trans_a) {
|
||||
swap(n, m);
|
||||
}
|
||||
if (trans_b) {
|
||||
swap(m_, k);
|
||||
}
|
||||
ASSERTop(m,==,m_);
|
||||
c->set_shape({n, k});
|
||||
}
|
||||
|
||||
// Expose the element type and the two transpose flags to the JIT
// compiler; jit_run substitutes them as @T, @Trans_a and @Trans_b.
void MklMatmulOp::jit_prepare() {
    const char* flag_a = trans_a ? "T" : "N";
    const char* flag_b = trans_b ? "T" : "N";
    add_jit_define("T", a->dtype());
    add_jit_define("Trans_a", flag_a);
    add_jit_define("Trans_b", flag_b);
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
// JIT-specialized matmul: '@Trans_a' / '@Trans_b' expand to the character
// 'T' or 'N' chosen in jit_prepare, so the branches below fold to
// compile-time constants (hence the tautological-compare pragma above).
void MklMatmulOp::jit_run() {
    const auto& as = a->shape;
    const auto& bs = b->shape;
    // Logical GEMM sizes after applying the transpose flags.
    auto n = as[0];
    auto m = as[1];
    auto k = bs[1];
    if ('@Trans_a'=='T') {
        n = as[1];
        m = as[0];
    }
    if ('@Trans_b'=='T') {
        k = bs[0];
    }
    // a: [n,m], b: [m,k], c: [n,k]
    // Leading dimensions are the stored (untransposed) row widths.
    ASSERTop(0,==,mkldnn_sgemm('@Trans_a', '@Trans_b', n, k, m,
        1.f, a->ptr<T>(), '@Trans_a'=='N'? m : n,
        b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
        0.f, c->ptr<T>(), k));
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,24 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Matrix multiplication op implemented via MKL-DNN's sgemm.
struct MklMatmulOp : Op {
    Var* a, * b, * c;        // inputs a, b; output c
    bool trans_a, trans_b;   // whether a / b are read transposed
    MklMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b);

    const char* name() const override { return "mkl_matmul"; }
    void infer_shape() override;
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,33 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <random>
|
||||
|
||||
#include "var.h"
|
||||
#include "mkl_test_op.h"
|
||||
|
||||
int mkl_test_entry();
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
// Produces a single float32 element; serves as an MKL availability check.
MklTestOp::MklTestOp() {
    output = create_output(1, ns_float32);
}
|
||||
|
||||
// Expose the output dtype to the JIT compiler as T (always float32 here).
void MklTestOp::jit_prepare() {
    add_jit_define("T", ns_float32);
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
// Run the external MKL self-test and record success in the output.
void MklTestOp::jit_run() {
    // mkl_test_entry returns 0 on success.
    ASSERT(mkl_test_entry()==0);
    output->ptr<T>()[0] = 123;  // sentinel value; presumably checked by the caller — verify
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,19 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Smoke-test op: checks that the MKL library can be invoked at all.
struct MklTestOp : Op {
    Var* output;   // single-element result written by jit_run
    MklTestOp();

    const char* name() const override { return "mkl_test"; }
    DECLARE_jit_run;
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,7 @@
|
|||
from .md_to_ipynb import dirname, notebook_dir
import os
import shlex
import sys

# Copy the rendered notebooks into the notebook directory and launch
# jupyter there, forwarding any extra command-line arguments.
# shlex.quote keeps arguments containing spaces or shell metacharacters
# from being mangled (or interpreted) by the shell.
extra_args = ' '.join(shlex.quote(a) for a in sys.argv[1:])
cmd = f"cp -r {dirname}/* {notebook_dir}/ && cd {notebook_dir} && jupyter notebook {extra_args}"
print("run cmd:", cmd)
os.system(cmd)
|
|
@ -0,0 +1,57 @@
|
|||
# Basics: Op, Var
|
||||
|
||||
# 基本概念:Op, Var
|
||||
|
||||
To train your model with jittor, there are only two main concepts you need to know:
|
||||
|
||||
要使用jittor训练模型,您需要了解两个主要概念:
|
||||
|
||||
* Var: basic data type of jittor
|
||||
* Var:Jittor的基本数据类型
|
||||
* Operations: Jittor's operators are similar to numpy's
|
||||
* Operations:Jittor的算子与numpy类似
|
||||
|
||||
## Var
|
||||
First, let's get started with Var. Var is the basic data type of jittor. Computation process in Jittor is asynchronous for optimization. If you want to access the data, `Var.data` can be used for synchronous data accessing.
|
||||
|
||||
首先,让我们开始使用Var。Var是jittor的基本数据类型,为了运算更加高效Jittor中的计算过程是异步的。 如果要访问数据,可以使用`Var.data`进行同步数据访问。
|
||||
|
||||
```
|
||||
import jittor as jt
|
||||
a = jt.float32([1,2,3])
|
||||
print (a)
|
||||
print (a.data)
|
||||
# Output: float32[3,]
|
||||
# Output: [ 1. 2. 3.]
|
||||
```
|
||||
## Op
|
||||
Jittor's operators are similar to numpy's. Let's try some operations. We create Var `a` and `b` via operation `jt.float32`, and add them. Printing those variables shows they have the same shape and dtype.
|
||||
|
||||
Jittor的算子与numpy类似。 让我们尝试一些操作, 我们通过操作jt.float32创建Var `a`和`b`,并将它们相加。 输出这些变量相关信息,可以看出它们具有相同的形状和类型。
|
||||
|
||||
```
|
||||
import jittor as jt
|
||||
a = jt.float32([1,2,3])
|
||||
b = jt.float32([4,5,6])
|
||||
c = a+b
|
||||
print(a,b,c)
|
||||
```
|
||||
|
||||
Besides that, all the operators we used `jt.xxx(Var, ...)` have the alias `Var.xxx(...)`. For example:
|
||||
|
||||
除此之外,我们使用的所有算子`jt.xxx(Var,...)`都具有别名`Var.xxx(...)`。 例如:
|
||||
|
||||
```
|
||||
c.max() # alias of jt.max(c)
|
||||
c.add(a) # alias of jt.add(c, a)
|
||||
c.min(keepdims=True) # alias of jt.min(c, keepdims=True)
|
||||
```
|
||||
|
||||
If you want to know all the operations that Jittor supports, try `help(jt.ops)`. All the operations you find in `jt.ops.xxx` can be used via the alias `jt.xxx`.
|
||||
|
||||
如果您想知道Jittor支持的所有操作,可以运行`help(jt.ops)`。 您在`jt.ops.xxx`中找到的所有操作都可以通过别名`jt.xxx`。
|
||||
|
||||
```
|
||||
help(jt.ops)
|
||||
```
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
# Custom Op: write your operator with C++ and CUDA and JIT compile it
|
||||
|
||||
# 自定义算子:使用C++和CUDA编写您的算子,并对其进行即时编译
|
||||
|
||||
> NOTE: This tutorial is still working in progress
|
||||
|
||||
In this tutorial, we will show:
|
||||
|
||||
1. how to write your operator with C++ and CUDA and JIT compile it
|
||||
2. execute your custom operation
|
||||
|
||||
If you want to implement a very simple op with a few lines of code, please use the code op; see `help(jt.code)`.
|
||||
custom_op is used to implement a complicated op. The capabilities of custom_op and built-in operations are exactly the same.
|
||||
|
||||
> 注意:本教程仍在持续更新中
|
||||
|
||||
在本教程中,我们将展示:
|
||||
|
||||
1. 如何用C ++和CUDA编写您的算子并对其进行即时编译
|
||||
2. 运行您的自定义算子
|
||||
|
||||
如果您想用几行代码来实现一个非常简单的算子,请使用code运算,请参阅`help(jt.code)`.
|
||||
custom_op用于实现复杂的算子。 custom_op和内置运算的功能完全相同。
|
||||
|
||||
```python
|
||||
import jittor as jt
|
||||
|
||||
header ="""
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct CustomOp : Op {
|
||||
Var* output;
|
||||
CustomOp(NanoVector shape, NanoString dtype=ns_float);
|
||||
|
||||
const char* name() const override { return "custom"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
||||
"""
|
||||
|
||||
src = """
|
||||
#include "var.h"
|
||||
#include "custom_op.h"
|
||||
|
||||
namespace jittor {
|
||||
#ifndef JIT
|
||||
CustomOp::CustomOp(NanoVector shape, NanoString dtype) {
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 1);
|
||||
output = create_output(shape, dtype);
|
||||
}
|
||||
|
||||
void CustomOp::jit_prepare() {
|
||||
add_jit_define("T", output->dtype());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void CustomOp::jit_run() {
|
||||
index_t num = output->num;
|
||||
auto* __restrict__ x = output->ptr<T>();
|
||||
for (index_t i=0; i<num; i++)
|
||||
x[i] = (T)i;
|
||||
}
|
||||
#else
|
||||
// JIT_cuda
|
||||
__global__ void kernel(index_t n, T *x) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i = index; i < n; i += stride)
|
||||
x[i] = (T)-i;
|
||||
}
|
||||
|
||||
void CustomOp::jit_run() {
|
||||
index_t num = output->num;
|
||||
auto* __restrict__ x = output->ptr<T>();
|
||||
int blockSize = 256;
|
||||
int numBlocks = (num + blockSize - 1) / blockSize;
|
||||
kernel<<<numBlocks, blockSize>>>(num, x);
|
||||
}
|
||||
#endif // JIT_cpu
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
||||
"""
|
||||
|
||||
my_op = jt.compile_custom_op(header, src, "custom", warp=False)
|
||||
```
|
||||
|
||||
Let's check the result of this op.
|
||||
|
||||
让我们查看一下这个运算的结果。
|
||||
|
||||
```python
|
||||
# run cpu version
|
||||
jt.flags.use_cuda = 0
|
||||
a = my_op([3,4,5], 'float').fetch_sync()
|
||||
assert (a.flatten() == range(3*4*5)).all()
|
||||
|
||||
if jt.compiler.has_cuda:
|
||||
# run cuda version
|
||||
jt.flags.use_cuda = 1
|
||||
a = my_op([3,4,5], 'float').fetch_sync()
|
||||
assert (-a.flatten() == range(3*4*5)).all()
|
||||
```
|
|
@ -0,0 +1,68 @@
|
|||
# Example: Model definition and training
|
||||
|
||||
# 示例:模型定义与训练
|
||||
|
||||
The following example shows how to build a two-layer neural network step by step and train it from scratch in a few lines of Python code.
|
||||
|
||||
以下示例展示了如何逐步搭建两层神经网络模型,并通过几行Python代码从头开始进行模型训练。
|
||||
|
||||
```
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor import nn, Module, init
|
||||
```
|
||||
|
||||
The following code defines our model, which is a two-layer neural network. The size of the hidden layer is 10, and the activation function is relu.
|
||||
|
||||
以下代码定义了我们的模型,该模型是一个两层神经网络。 隐藏层的大小为10,激活函数为relu。
|
||||
|
||||
```
|
||||
### model define
|
||||
|
||||
class Model(Module):
|
||||
def __init__(self):
|
||||
self.layer1 = nn.Linear(1, 10)
|
||||
self.relu = nn.ReLU()
|
||||
self.layer2 = nn.Linear(10, 1)
|
||||
def execute (self,x) :
|
||||
x = self.layer1(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
```
|
||||
|
||||
At last, this model is trained from scratch. A simple gradient descent is used, and the loss function is L2 distance. The training process is asynchronous for efficiency. jittor calculates the gradients and applies graph- and operator-level optimizations via **unify IR graph** and **jit analyzer**.
|
||||
In this example, multiple optimizations can be used, including: **operator fusion**, the activation function and loss function can be fused into the first and second linear layers; Three meta-operators in matrix multiplication could also be fused. **Parallelism**, it can improve performance of compute-intensive operations on modern multi-core CPUs and GPUs. The operator fusion is a graph-level optimization, and parallelism can be achieved in both graph-level and operator-level.
|
||||
|
||||
最后,从头开始训练该模型。 优化器使用简单的梯度下降,损失函数为L2距离。 为提高效率训练过程是异步的。 jittor通过**统一计算图**和**即时分析器**计算梯度,并进行计算图级和算子级的优化。
|
||||
|
||||
在该示例中,Jittor使用了多个优化,包括:**算子融合**,激活函数和损失函数可以融合到第一和第二全连接层中; 矩阵乘法中的三元算子也可以融合。 **并行化**,它可以提高现代多核CPU和GPU上计算密集型运算的性能。 算子融合是一种计算图级优化,而并行化则同时作用于图形级和算子级的优化。
|
||||
|
||||
```
|
||||
np.random.seed(0)
|
||||
jt.set_seed(3)
|
||||
n = 1000
|
||||
batch_size = 50
|
||||
base_lr = 0.05
|
||||
# we need to stop grad of global value to prevent memory leak
|
||||
lr = jt.float32(base_lr).name("lr").stop_grad()
|
||||
|
||||
def get_data(n):
|
||||
for i in range(n):
|
||||
x = np.random.rand(batch_size, 1)
|
||||
y = x*x
|
||||
yield jt.float32(x), jt.float32(y)
|
||||
|
||||
model = Model()
|
||||
learning_rate = 0.1
|
||||
optim = nn.SGD (model.parameters(), learning_rate)
|
||||
|
||||
for i,(x,y) in enumerate(get_data(n)):
|
||||
pred_y = model(x)
|
||||
loss = ((pred_y - y)**2)
|
||||
loss_mean = loss.mean()
|
||||
optim.step (loss_mean)
|
||||
print(f"step {i}, loss = {loss_mean.data.sum()}")
|
||||
|
||||
assert loss_mean.data < 0.005
|
||||
```
|
|
@ -0,0 +1,936 @@
|
|||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!-- Created with Inkscape (http://www.inkscape.org/) -->
|
||||
|
||||
<svg
|
||||
xmlns:dc="http://purl.org/dc/elements/1.1/"
|
||||
xmlns:cc="http://creativecommons.org/ns#"
|
||||
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||
xmlns:svg="http://www.w3.org/2000/svg"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
width="55.587765mm"
|
||||
height="77.815247mm"
|
||||
viewBox="0 0 196.96452 275.72331"
|
||||
id="svg4512"
|
||||
version="1.1"
|
||||
inkscape:version="0.91 r13725"
|
||||
sodipodi:docname="mop.svg">
|
||||
<defs
|
||||
id="defs4514">
|
||||
<marker
|
||||
inkscape:isstock="true"
|
||||
style="overflow:visible"
|
||||
id="marker18253"
|
||||
refX="0"
|
||||
refY="0"
|
||||
orient="auto"
|
||||
inkscape:stockid="Arrow1Send">
|
||||
<path
|
||||
inkscape:connector-curvature="0"
|
||||
transform="matrix(-0.2,0,0,-0.2,-1.2,0)"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
id="path18255" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow1Send"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="marker16795"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true"
|
||||
inkscape:collect="always">
|
||||
<path
|
||||
id="path16797"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
transform="matrix(-0.2,0,0,-0.2,-1.2,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:isstock="true"
|
||||
style="overflow:visible"
|
||||
id="marker16429"
|
||||
refX="0"
|
||||
refY="0"
|
||||
orient="auto"
|
||||
inkscape:stockid="Arrow1Send"
|
||||
inkscape:collect="always">
|
||||
<path
|
||||
transform="matrix(-0.2,0,0,-0.2,-1.2,0)"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
id="path16431"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow1Send"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="marker16111"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true"
|
||||
inkscape:collect="always">
|
||||
<path
|
||||
id="path16113"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
transform="matrix(-0.2,0,0,-0.2,-1.2,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:isstock="true"
|
||||
style="overflow:visible"
|
||||
id="marker15763"
|
||||
refX="0"
|
||||
refY="0"
|
||||
orient="auto"
|
||||
inkscape:stockid="Arrow1Send"
|
||||
inkscape:collect="always">
|
||||
<path
|
||||
transform="matrix(-0.2,0,0,-0.2,-1.2,0)"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
id="path15765"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow1Send"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="marker15385"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true"
|
||||
inkscape:collect="always">
|
||||
<path
|
||||
id="path15387"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
transform="matrix(-0.2,0,0,-0.2,-1.2,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow1Sstart"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="Arrow1Sstart"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true">
|
||||
<path
|
||||
id="path5140"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
transform="matrix(0.2,0,0,0.2,1.2,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow1Send"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="marker10149"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true">
|
||||
<path
|
||||
id="path10151"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
transform="matrix(-0.2,0,0,-0.2,-1.2,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:isstock="true"
|
||||
style="overflow:visible"
|
||||
id="marker9311"
|
||||
refX="0"
|
||||
refY="0"
|
||||
orient="auto"
|
||||
inkscape:stockid="Arrow1Send"
|
||||
inkscape:collect="always">
|
||||
<path
|
||||
transform="matrix(-0.2,0,0,-0.2,-1.2,0)"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
id="path9313"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:isstock="true"
|
||||
style="overflow:visible"
|
||||
id="marker7663"
|
||||
refX="0"
|
||||
refY="0"
|
||||
orient="auto"
|
||||
inkscape:stockid="Arrow1Send"
|
||||
inkscape:collect="always">
|
||||
<path
|
||||
transform="matrix(-0.2,0,0,-0.2,-1.2,0)"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
id="path7665"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow1Send"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="Arrow1Send"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true"
|
||||
inkscape:collect="always">
|
||||
<path
|
||||
id="path5143"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
transform="matrix(-0.2,0,0,-0.2,-1.2,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:isstock="true"
|
||||
style="overflow:visible"
|
||||
id="marker7523"
|
||||
refX="0"
|
||||
refY="0"
|
||||
orient="auto"
|
||||
inkscape:stockid="Arrow1Mend">
|
||||
<path
|
||||
transform="matrix(-0.4,0,0,-0.4,-4,0)"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
id="path7525"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow1Mend"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="marker7459"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true">
|
||||
<path
|
||||
id="path7461"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
transform="matrix(-0.4,0,0,-0.4,-4,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:isstock="true"
|
||||
style="overflow:visible"
|
||||
id="marker7287"
|
||||
refX="0"
|
||||
refY="0"
|
||||
orient="auto"
|
||||
inkscape:stockid="Arrow1Mend">
|
||||
<path
|
||||
transform="matrix(-0.4,0,0,-0.4,-4,0)"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
id="path7289"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow1Mend"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="marker5741"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true">
|
||||
<path
|
||||
id="path5743"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
transform="matrix(-0.4,0,0,-0.4,-4,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<inkscape:path-effect
|
||||
effect="powerstroke"
|
||||
id="path-effect5731"
|
||||
is_visible="true"
|
||||
offset_points="0,0.5"
|
||||
sort_points="true"
|
||||
interpolator_type="Linear"
|
||||
interpolator_beta="0.2"
|
||||
start_linecap_type="zerowidth"
|
||||
linejoin_type="round"
|
||||
miter_limit="4"
|
||||
end_linecap_type="zerowidth"
|
||||
cusp_linecap_type="round" />
|
||||
<marker
|
||||
inkscape:isstock="true"
|
||||
style="overflow:visible"
|
||||
id="marker5665"
|
||||
refX="0"
|
||||
refY="0"
|
||||
orient="auto"
|
||||
inkscape:stockid="Arrow1Mend"
|
||||
inkscape:collect="always">
|
||||
<path
|
||||
transform="matrix(-0.4,0,0,-0.4,-4,0)"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
id="path5667"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow2Lend"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="Arrow2Lend"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true">
|
||||
<path
|
||||
id="path5149"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:0.625;stroke-linejoin:round;stroke-opacity:1"
|
||||
d="M 8.7185878,4.0337352 -2.2072895,0.01601326 8.7185884,-4.0017078 c -1.7454984,2.3720609 -1.7354408,5.6174519 -6e-7,8.035443 z"
|
||||
transform="matrix(-1.1,0,0,-1.1,-1.1,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow1Mend"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="Arrow1Mend"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true"
|
||||
inkscape:collect="always">
|
||||
<path
|
||||
id="path5137"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
transform="matrix(-0.4,0,0,-0.4,-4,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
<marker
|
||||
inkscape:stockid="Arrow1Lend"
|
||||
orient="auto"
|
||||
refY="0"
|
||||
refX="0"
|
||||
id="Arrow1Lend"
|
||||
style="overflow:visible"
|
||||
inkscape:isstock="true">
|
||||
<path
|
||||
id="path5131"
|
||||
d="M 0,0 5,-5 -12.5,0 5,5 0,0 Z"
|
||||
style="fill:#000000;fill-opacity:1;fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1"
|
||||
transform="matrix(-0.8,0,0,-0.8,-10,0)"
|
||||
inkscape:connector-curvature="0" />
|
||||
</marker>
|
||||
</defs>
|
||||
<sodipodi:namedview
|
||||
id="base"
|
||||
pagecolor="#ffffff"
|
||||
bordercolor="#666666"
|
||||
borderopacity="1.0"
|
||||
inkscape:pageopacity="0.0"
|
||||
inkscape:pageshadow="2"
|
||||
inkscape:zoom="2.8"
|
||||
inkscape:cx="-17.612906"
|
||||
inkscape:cy="137.15779"
|
||||
inkscape:document-units="px"
|
||||
inkscape:current-layer="layer1"
|
||||
showgrid="false"
|
||||
inkscape:window-width="1855"
|
||||
inkscape:window-height="1056"
|
||||
inkscape:window-x="65"
|
||||
inkscape:window-y="24"
|
||||
inkscape:window-maximized="1"
|
||||
showguides="false"
|
||||
fit-margin-top="0"
|
||||
fit-margin-left="0"
|
||||
fit-margin-right="0"
|
||||
fit-margin-bottom="0" />
|
||||
<metadata
|
||||
id="metadata4517">
|
||||
<rdf:RDF>
|
||||
<cc:Work
|
||||
rdf:about="">
|
||||
<dc:format>image/svg+xml</dc:format>
|
||||
<dc:type
|
||||
rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
|
||||
<dc:title></dc:title>
|
||||
</cc:Work>
|
||||
</rdf:RDF>
|
||||
</metadata>
|
||||
<g
|
||||
inkscape:label="Layer 1"
|
||||
inkscape:groupmode="layer"
|
||||
id="layer1"
|
||||
transform="translate(-95.719566,-101.59962)">
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.50000012;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:3.00000008, 3.00000008;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect5060"
|
||||
width="99.362617"
|
||||
height="28.284271"
|
||||
x="144.16338"
|
||||
y="102.34962" />
|
||||
<text
|
||||
xml:space="preserve"
|
||||
style="font-style:normal;font-weight:normal;font-size:15.71928406px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
x="148.01875"
|
||||
y="122.35196"
|
||||
id="text5062"
|
||||
sodipodi:linespacing="125%"><tspan
|
||||
sodipodi:role="line"
|
||||
id="tspan5064"
|
||||
x="148.01875"
|
||||
y="122.35196"
|
||||
style="font-weight:bold">DL Models</tspan></text>
|
||||
<rect
|
||||
y="157.87215"
|
||||
x="96.469566"
|
||||
height="48.674656"
|
||||
width="195.46452"
|
||||
id="rect5066"
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:3.00000006, 3.00000006;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text5068"
|
||||
y="173.86671"
|
||||
x="112.07446"
|
||||
style="font-style:normal;font-weight:normal;font-size:12.89062405px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
xml:space="preserve"><tspan
|
||||
y="173.86671"
|
||||
x="112.07446"
|
||||
id="tspan5070"
|
||||
sodipodi:role="line"
|
||||
style="font-weight:bold">Common DL Operators</tspan></text>
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect5072"
|
||||
width="46.972092"
|
||||
height="15.657365"
|
||||
x="102.53049"
|
||||
y="184.89371" />
|
||||
<text
|
||||
xml:space="preserve"
|
||||
style="font-style:normal;font-weight:normal;font-size:9.49497509px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1.5;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;stroke-miterlimit:4;stroke-dasharray:none"
|
||||
x="113.84646"
|
||||
y="196.1787"
|
||||
id="text5074"
|
||||
sodipodi:linespacing="125%"><tspan
|
||||
sodipodi:role="line"
|
||||
id="tspan5076"
|
||||
x="113.84646"
|
||||
y="196.1787">Conv</tspan></text>
|
||||
<rect
|
||||
y="184.89371"
|
||||
x="102.53049"
|
||||
height="15.657365"
|
||||
width="46.972092"
|
||||
id="rect5084"
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text5086"
|
||||
y="196.1787"
|
||||
x="113.84646"
|
||||
style="font-style:normal;font-weight:normal;font-size:9.49497509px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1.5;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;stroke-miterlimit:4;stroke-dasharray:none"
|
||||
xml:space="preserve"><tspan
|
||||
y="196.1787"
|
||||
x="113.84646"
|
||||
id="tspan5088"
|
||||
sodipodi:role="line">Conv</tspan></text>
|
||||
<rect
|
||||
y="184.89371"
|
||||
x="156.82619"
|
||||
height="15.657365"
|
||||
width="46.972092"
|
||||
id="rect5090"
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text5092"
|
||||
y="196.1161"
|
||||
x="167.29837"
|
||||
style="font-style:normal;font-weight:normal;font-size:9.49497509px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1.5;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;stroke-miterlimit:4;stroke-dasharray:none"
|
||||
xml:space="preserve"><tspan
|
||||
y="196.1161"
|
||||
x="167.29837"
|
||||
id="tspan5094"
|
||||
sodipodi:role="line">Norm</tspan></text>
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect5096"
|
||||
width="46.972092"
|
||||
height="15.657365"
|
||||
x="156.57365"
|
||||
y="184.89371" />
|
||||
<text
|
||||
xml:space="preserve"
|
||||
style="font-style:normal;font-weight:normal;font-size:9.49497509px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
x="168.35912"
|
||||
y="213.97815"
|
||||
id="text5098"
|
||||
sodipodi:linespacing="125%"><tspan
|
||||
sodipodi:role="line"
|
||||
id="tspan5100"
|
||||
x="168.35912"
|
||||
y="213.97815" /></text>
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect5102"
|
||||
width="46.972092"
|
||||
height="15.657365"
|
||||
x="210.61681"
|
||||
y="184.89371" />
|
||||
<rect
|
||||
y="184.89371"
|
||||
x="210.61681"
|
||||
height="15.657365"
|
||||
width="46.972092"
|
||||
id="rect5108"
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text5110"
|
||||
y="196.26215"
|
||||
x="224.26016"
|
||||
style="font-style:normal;font-weight:normal;font-size:9.49497509px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1.5;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;stroke-miterlimit:4;stroke-dasharray:none"
|
||||
xml:space="preserve"><tspan
|
||||
y="196.26215"
|
||||
x="224.26016"
|
||||
id="tspan5112"
|
||||
sodipodi:role="line">Pool</tspan></text>
|
||||
<circle
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="path5116"
|
||||
cx="265.79639"
|
||||
cy="193.35374"
|
||||
r="1.8940361" />
|
||||
<circle
|
||||
r="1.8940361"
|
||||
cy="193.35374"
|
||||
cx="271.8573"
|
||||
id="circle5118"
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<circle
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="circle5120"
|
||||
cx="277.91821"
|
||||
cy="193.35374"
|
||||
r="1.8940361" />
|
||||
<path
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2;stroke-linecap:butt;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1;marker-end:url(#Arrow1Mend)"
|
||||
d="m 194.20182,156.87315 -1e-5,-24.59346"
|
||||
id="path5122"
|
||||
inkscape:connector-type="polyline"
|
||||
inkscape:connector-curvature="0" />
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:3, 3;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect5582"
|
||||
width="195.46452"
|
||||
height="142.93935"
|
||||
x="96.469566"
|
||||
y="233.63358" />
|
||||
<text
|
||||
xml:space="preserve"
|
||||
style="font-style:normal;font-weight:normal;font-size:12.89062405px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
x="143.14287"
|
||||
y="249.62816"
|
||||
id="text5584"
|
||||
sodipodi:linespacing="125%"><tspan
|
||||
sodipodi:role="line"
|
||||
id="tspan5586"
|
||||
x="143.14287"
|
||||
y="249.62816"
|
||||
style="font-weight:bold">Meta-Operators</tspan></text>
|
||||
<rect
|
||||
y="279.53458"
|
||||
x="102.53049"
|
||||
height="15.657365"
|
||||
width="54.928276"
|
||||
id="rect5588"
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect5600"
|
||||
width="54.928276"
|
||||
height="15.657365"
|
||||
x="166.02286"
|
||||
y="279.53458" />
|
||||
<rect
|
||||
y="279.53458"
|
||||
x="228.92459"
|
||||
height="15.657365"
|
||||
width="54.928276"
|
||||
id="rect5608"
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1.5;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<path
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2;stroke-linecap:butt;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1;marker-end:url(#marker5665)"
|
||||
d="m 194.20183,233.63358 0,-27.08678"
|
||||
id="path5626"
|
||||
inkscape:connector-type="polyline"
|
||||
inkscape:connector-curvature="0"
|
||||
inkscape:connection-start="#rect5582"
|
||||
inkscape:connection-end="#rect5066" />
|
||||
<text
|
||||
xml:space="preserve"
|
||||
style="font-style:normal;font-weight:normal;font-size:9.88953686px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1.5;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;stroke-miterlimit:4;stroke-dasharray:none"
|
||||
x="109.88236"
|
||||
y="291.05011"
|
||||
id="text5628"
|
||||
sodipodi:linespacing="125%"><tspan
|
||||
sodipodi:role="line"
|
||||
id="tspan5630"
|
||||
x="109.88236"
|
||||
y="291.05011">Reindex</tspan></text>
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text5632"
|
||||
y="289.71198"
|
||||
x="168.07721"
|
||||
style="font-style:normal;font-weight:normal;font-size:6.30014944px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1.5;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;stroke-miterlimit:4;stroke-dasharray:none"
|
||||
xml:space="preserve"><tspan
|
||||
y="289.71198"
|
||||
x="168.07721"
|
||||
id="tspan5634"
|
||||
sodipodi:role="line">Reindex Reduce</tspan></text>
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text5636"
|
||||
y="290.1658"
|
||||
x="230.89981"
|
||||
style="font-style:normal;font-weight:normal;font-size:7.51746845px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1.5;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;stroke-miterlimit:4;stroke-dasharray:none"
|
||||
xml:space="preserve"><tspan
|
||||
y="290.1658"
|
||||
x="230.89981"
|
||||
id="tspan5638"
|
||||
sodipodi:role="line">Element-Wise</tspan></text>
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect5711"
|
||||
width="40.999706"
|
||||
height="11.694196"
|
||||
x="115.03664"
|
||||
y="310.45312" />
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text5713"
|
||||
y="318.95584"
|
||||
x="116.79535"
|
||||
style="font-style:normal;font-weight:normal;font-size:7.38404226px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
xml:space="preserve"
|
||||
transform="scale(0.9996927,1.0003074)"><tspan
|
||||
y="318.95584"
|
||||
x="116.79535"
|
||||
id="tspan5715"
|
||||
sodipodi:role="line">Broadcast</tspan></text>
|
||||
<rect
|
||||
y="327.42856"
|
||||
x="115.03664"
|
||||
height="11.694196"
|
||||
width="40.999706"
|
||||
id="rect5717"
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<text
|
||||
xml:space="preserve"
|
||||
style="font-style:normal;font-weight:normal;font-size:7.38404226px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
x="128.88095"
|
||||
y="335.92606"
|
||||
id="text5719"
|
||||
sodipodi:linespacing="125%"
|
||||
transform="scale(0.9996927,1.0003074)"><tspan
|
||||
sodipodi:role="line"
|
||||
id="tspan5721"
|
||||
x="128.88095"
|
||||
y="335.92606">Pad</tspan></text>
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect5723"
|
||||
width="40.999706"
|
||||
height="11.694196"
|
||||
x="115.03664"
|
||||
y="344.40402" />
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text5725"
|
||||
y="352.8963"
|
||||
x="126.83484"
|
||||
style="font-style:normal;font-weight:normal;font-size:7.38404226px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
xml:space="preserve"
|
||||
transform="scale(0.9996927,1.0003074)"><tspan
|
||||
y="352.8963"
|
||||
x="126.83484"
|
||||
id="tspan5727"
|
||||
sodipodi:role="line">Slice</tspan></text>
|
||||
<path
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#Arrow1Send)"
|
||||
d="m 150.36843,279.29918 0,-7.44793 22.61636,0 0,6.01905"
|
||||
id="path5733"
|
||||
inkscape:connector-curvature="0"
|
||||
sodipodi:nodetypes="cccc" />
|
||||
<path
|
||||
inkscape:connector-curvature="0"
|
||||
id="path7661"
|
||||
d="m 176.042,279.30845 0,-9.89119 -27.43443,0 0,8.47063"
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker7663)"
|
||||
sodipodi:nodetypes="cccc" />
|
||||
<text
|
||||
xml:space="preserve"
|
||||
style="font-style:normal;font-weight:normal;font-size:6.72696304px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
x="145.52654"
|
||||
y="266.45084"
|
||||
id="text9185"
|
||||
sodipodi:linespacing="125%"><tspan
|
||||
sodipodi:role="line"
|
||||
id="tspan9187"
|
||||
x="145.52654"
|
||||
y="266.45084">Backward</tspan></text>
|
||||
<path
|
||||
sodipodi:nodetypes="cccc"
|
||||
inkscape:connector-curvature="0"
|
||||
id="path9303"
|
||||
d="m 243.69211,277.26609 0,-6.31376 23.24776,0 0,6.06777"
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-start:url(#Arrow1Sstart);marker-end:url(#marker9311)" />
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text9307"
|
||||
y="268.72366"
|
||||
x="239.09192"
|
||||
style="font-style:normal;font-weight:normal;font-size:6.72696304px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
xml:space="preserve"><tspan
|
||||
y="268.72366"
|
||||
x="239.09192"
|
||||
id="tspan9309"
|
||||
sodipodi:role="line">Backward</tspan></text>
|
||||
<path
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker15385)"
|
||||
d="m 105.3084,295.44449 0,72.22591 7.32361,0"
|
||||
id="path15377"
|
||||
inkscape:connector-curvature="0" />
|
||||
<circle
|
||||
r="1.8940361"
|
||||
cy="367.79669"
|
||||
cx="127.15295"
|
||||
id="circle15749"
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<circle
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="circle15751"
|
||||
cx="133.21387"
|
||||
cy="367.79669"
|
||||
r="1.8940361" />
|
||||
<circle
|
||||
r="1.8940361"
|
||||
cy="367.79669"
|
||||
cx="139.27478"
|
||||
id="circle15753"
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<path
|
||||
inkscape:connector-curvature="0"
|
||||
id="path15761"
|
||||
d="m 105.3084,295.6514 0,54.44285 7.32361,0"
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:0.99999994px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker15763)" />
|
||||
<path
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:0.99999988px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker16111)"
|
||||
d="m 105.3084,295.90679 0,37.89779 7.32361,0"
|
||||
id="path16109"
|
||||
inkscape:connector-curvature="0" />
|
||||
<path
|
||||
inkscape:connector-curvature="0"
|
||||
id="path16427"
|
||||
d="m 105.3084,296.29759 0,20.27361 7.32361,0"
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:0.99999988px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker16429)" />
|
||||
<rect
|
||||
y="310.45312"
|
||||
x="179.67949"
|
||||
height="11.694196"
|
||||
width="40.999706"
|
||||
id="rect16763"
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<text
|
||||
transform="scale(0.9996927,1.0003074)"
|
||||
xml:space="preserve"
|
||||
style="font-style:normal;font-weight:normal;font-size:7.38404226px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
x="186.41022"
|
||||
y="318.95584"
|
||||
id="text16765"
|
||||
sodipodi:linespacing="125%"><tspan
|
||||
sodipodi:role="line"
|
||||
id="tspan16767"
|
||||
x="186.41022"
|
||||
y="318.95584">Reduce</tspan></text>
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect16769"
|
||||
width="40.999706"
|
||||
height="11.694196"
|
||||
x="179.67949"
|
||||
y="327.42856" />
|
||||
<text
|
||||
transform="scale(0.9996927,1.0003074)"
|
||||
sodipodi:linespacing="125%"
|
||||
id="text16771"
|
||||
y="335.92606"
|
||||
x="185.9361"
|
||||
style="font-style:normal;font-weight:normal;font-size:7.38404226px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
xml:space="preserve"><tspan
|
||||
y="335.92606"
|
||||
x="185.9361"
|
||||
id="tspan16773"
|
||||
sodipodi:role="line">Product</tspan></text>
|
||||
<rect
|
||||
y="344.40402"
|
||||
x="179.67949"
|
||||
height="11.694196"
|
||||
width="40.999706"
|
||||
id="rect16775"
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<text
|
||||
transform="scale(0.9996927,1.0003074)"
|
||||
xml:space="preserve"
|
||||
style="font-style:normal;font-weight:normal;font-size:7.38404226px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
x="192.02757"
|
||||
y="352.8963"
|
||||
id="text16777"
|
||||
sodipodi:linespacing="125%"><tspan
|
||||
sodipodi:role="line"
|
||||
id="tspan16779"
|
||||
x="192.02757"
|
||||
y="352.8963">Sum</tspan></text>
|
||||
<path
|
||||
inkscape:connector-curvature="0"
|
||||
id="path16781"
|
||||
d="m 169.95126,295.44449 0,72.22591 7.32361,0"
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker15385)" />
|
||||
<circle
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="circle16783"
|
||||
cx="191.79581"
|
||||
cy="367.79669"
|
||||
r="1.8940361" />
|
||||
<circle
|
||||
r="1.8940361"
|
||||
cy="367.79669"
|
||||
cx="197.85672"
|
||||
id="circle16785"
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<circle
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="circle16787"
|
||||
cx="203.91763"
|
||||
cy="367.79669"
|
||||
r="1.8940361" />
|
||||
<path
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:0.99999994px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker15763)"
|
||||
d="m 169.95126,295.6514 0,54.44285 7.32361,0"
|
||||
id="path16789"
|
||||
inkscape:connector-curvature="0" />
|
||||
<path
|
||||
inkscape:connector-curvature="0"
|
||||
id="path16791"
|
||||
d="m 169.95126,295.90679 0,37.89779 7.32361,0"
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:0.99999988px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker16111)" />
|
||||
<path
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:0.99999988px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker16795)"
|
||||
d="m 169.95126,296.29759 0,20.27361 7.32361,0"
|
||||
id="path16793"
|
||||
inkscape:connector-curvature="0" />
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect16925"
|
||||
width="40.999706"
|
||||
height="11.694196"
|
||||
x="242.53664"
|
||||
y="310.45312" />
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text16927"
|
||||
y="318.95584"
|
||||
x="251.88805"
|
||||
style="font-style:normal;font-weight:normal;font-size:7.38404226px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
xml:space="preserve"
|
||||
transform="scale(0.9996927,1.0003074)"><tspan
|
||||
y="318.95584"
|
||||
x="251.88805"
|
||||
id="tspan16929"
|
||||
sodipodi:role="line">Unary</tspan></text>
|
||||
<rect
|
||||
y="327.42856"
|
||||
x="242.53664"
|
||||
height="11.694196"
|
||||
width="40.999706"
|
||||
id="rect16931"
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<text
|
||||
xml:space="preserve"
|
||||
style="font-style:normal;font-weight:normal;font-size:7.38404226px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
x="250.9957"
|
||||
y="335.92606"
|
||||
id="text16933"
|
||||
sodipodi:linespacing="125%"
|
||||
transform="scale(0.9996927,1.0003074)"><tspan
|
||||
sodipodi:role="line"
|
||||
id="tspan16935"
|
||||
x="250.9957"
|
||||
y="335.92606">Binary</tspan></text>
|
||||
<rect
|
||||
style="opacity:1;fill:none;fill-opacity:1;fill-rule:nonzero;stroke:#000000;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="rect16937"
|
||||
width="40.999706"
|
||||
height="11.694196"
|
||||
x="242.53664"
|
||||
y="344.40402" />
|
||||
<text
|
||||
sodipodi:linespacing="125%"
|
||||
id="text16939"
|
||||
y="352.8963"
|
||||
x="249.56612"
|
||||
style="font-style:normal;font-weight:normal;font-size:7.38404226px;line-height:125%;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1"
|
||||
xml:space="preserve"
|
||||
transform="scale(0.9996927,1.0003074)"><tspan
|
||||
y="352.8963"
|
||||
x="249.56612"
|
||||
id="tspan16941"
|
||||
sodipodi:role="line">Ternary</tspan></text>
|
||||
<path
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:1px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker15385)"
|
||||
d="m 232.8084,295.44449 0,72.22591 7.32361,0"
|
||||
id="path16943"
|
||||
inkscape:connector-curvature="0" />
|
||||
<circle
|
||||
r="1.8940361"
|
||||
cy="367.79669"
|
||||
cx="254.65295"
|
||||
id="circle16945"
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<circle
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
|
||||
id="circle16947"
|
||||
cx="260.71387"
|
||||
cy="367.79669"
|
||||
r="1.8940361" />
|
||||
<circle
|
||||
r="1.8940361"
|
||||
cy="367.79669"
|
||||
cx="266.77478"
|
||||
id="circle16949"
|
||||
style="opacity:1;fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
|
||||
<path
|
||||
inkscape:connector-curvature="0"
|
||||
id="path16951"
|
||||
d="m 232.8084,295.6514 0,54.44285 7.32361,0"
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:0.99999994px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker15763)" />
|
||||
<path
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:0.99999988px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker16111)"
|
||||
d="m 232.8084,295.90679 0,37.89779 7.32361,0"
|
||||
id="path16953"
|
||||
inkscape:connector-curvature="0" />
|
||||
<path
|
||||
inkscape:connector-curvature="0"
|
||||
id="path16955"
|
||||
d="m 232.8084,296.29759 0,20.27361 7.32361,0"
|
||||
style="fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:0.99999988px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker18253)" />
|
||||
</g>
|
||||
</svg>
|
After Width: | Height: | Size: 40 KiB |
|
@ -0,0 +1,66 @@
|
|||
#!python3
|
||||
import os, json
|
||||
from pathlib import Path
|
||||
notebook_dir = os.path.join(str(Path.home()), ".cache","jittor","notebook")
|
||||
if not os.path.isdir(notebook_dir):
|
||||
os.mkdir(notebook_dir)
|
||||
dirname = os.path.dirname(__file__)
|
||||
all_md = []
|
||||
for r, _, f in os.walk(dirname):
|
||||
for fname in f:
|
||||
if not fname.endswith(".md"): continue
|
||||
all_md.append(os.path.join(r, fname))
|
||||
for mdname in all_md:
|
||||
with open(os.path.join(dirname, mdname), "r") as f:
|
||||
src = f.read()
|
||||
blocks = []
|
||||
for i, b in enumerate(src.split("```")):
|
||||
b = b.strip()
|
||||
is_markdown_block = i%2==0
|
||||
if not is_markdown_block and not b.startswith("python"):
|
||||
is_markdown_block = True
|
||||
b = "```\n"+b+"\n```"
|
||||
if is_markdown_block:
|
||||
# in a markdown block
|
||||
if len(blocks)%2==0:
|
||||
# prev code block
|
||||
blocks.append(b)
|
||||
else:
|
||||
# prev markdown block
|
||||
blocks[-1] += "\n\n" + b
|
||||
else:
|
||||
# in a code block
|
||||
if b.startswith("python"):
|
||||
b = b[6:].strip()
|
||||
# prev markdown block
|
||||
assert len(blocks)%2==1
|
||||
blocks.append(b)
|
||||
cells = []
|
||||
for i, b in enumerate(blocks):
|
||||
b = b.strip()
|
||||
if len(b)==0: continue
|
||||
b = b.split("\n")
|
||||
for j in range(len(b)-1):
|
||||
b[j] += '\n'
|
||||
cell = {
|
||||
"source": b,
|
||||
"metadata": {},
|
||||
}
|
||||
if i%2==0:
|
||||
cell["cell_type"] = "markdown"
|
||||
else:
|
||||
cell["cell_type"] = "code"
|
||||
cell["outputs"] = []
|
||||
cell["execution_count"] = None
|
||||
cells.append(cell)
|
||||
ipynb = {
|
||||
"cells":cells,
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2,
|
||||
"metadata": {
|
||||
},
|
||||
}
|
||||
ipynb_name = mdname[:-2]+"ipynb"
|
||||
print(mdname, len(src), len(blocks), len(cells), "--->", ipynb_name)
|
||||
with open(os.path.join(notebook_dir, ipynb_name), "w") as f:
|
||||
f.write(json.dumps(ipynb))
|
|
@ -0,0 +1,252 @@
|
|||
# Meta-operator: Implement your own convolution with Meta-operator
|
||||
|
||||
# 元算子:通过元算子实现自己的卷积层
|
||||
|
||||
Meta-operator is a key concept of Jittor. The hierarchical architecture of meta-operators is shown below.
|
||||
|
||||
The meta-operators consist of reindex, reindex-reduce and element-wise operators. Reindex and reindex-reduce operators are both unary operators. The reindex operator is a one-to-many mapping between its input and output, and the reindex-reduce operator is a many-to-one mapping. Broadcast, pad and slice operators are common reindex operators, while reduce, product and sum are common reindex-reduce operators. The element-wise operator is the third component of meta-operators. Compared to the first two, element-wise operators may contain multiple inputs, but all the input and output shapes of element-wise operators must be the same, and they are one-to-one mapped. For example, the addition of two variables is a binary element-wise operator.
|
||||
|
||||
元算子是jittor的关键概念,元算子的层次结构如下所示。
|
||||
|
||||
元算子由重索引算子,重索引化简算子和元素级算子组成。重索引算子,重索引化简算子都是一元算子。 重索引算子是其输入和输出之间的一对多映射。重索引简化算子是多对一映射。广播,填补, 切分算子是常见的重新索引算子。 而化简,累乘,累加算子是常见的索引化简算子。 元素级算子是元算子的第三部分,与前两个相比,元素算级子可能包含多个输入。 但是元素级算子的所有输入和输出形状必须相同,它们是一对一映射的。 例如,两个变量的加法是一个二进制的逐元素算子。
|
||||
|
||||
> ![](./figs/mop.svg)
|
||||
> The hierarchical architecture of meta-operators. The meta-operators consist of reindex, reindex-reduce and element-wise operators. Reindex and reindex-reduce are each other's backward operators. The backward operators of element-wise operators are element-wise operators themselves. These meta-operators are fused into common DL operators, and these DL operators further constitute the model.
|
||||
>
|
||||
> 元算子的层级结构。元算子包含三类算子,重索引算子,重索引化简算子,元素级算子。元算
|
||||
> 子的反向传播算子还是元算子。元算子可以组成常用的深度学习算子。而这些深度学习算子又
|
||||
> 可以进一步组成深度学习模型。
|
||||
|
||||
In the previous [example](example.ipynb), we have demonstrated how to implement matrix multiplication via three meta-operators:
|
||||
|
||||
在第一个[示例](example.ipynb)中,我们演示了如何通过三个元算子实现矩阵乘法:
|
||||
|
||||
```
|
||||
def matmul(a, b):
|
||||
(n, m), k = a.shape, b.shape[-1]
|
||||
a = a.broadcast([n,m,k], dims=[2])
|
||||
b = b.broadcast([n,m,k], dims=[0])
|
||||
return (a*b).sum(dim=1)
|
||||
```
|
||||
|
||||
In this tutorial, we will show how to implement your own convolution with meta-operator.
|
||||
|
||||
First, let's implement a naive Python convolution:
|
||||
|
||||
在本教程中,我们将展示如何使用元算子实现自己的卷积。
|
||||
|
||||
首先,让我们实现一个朴素的Python卷积:
|
||||
|
||||
```
|
||||
import numpy as np
|
||||
import os
|
||||
def conv_naive(x, w):
|
||||
N,H,W,C = x.shape
|
||||
|
||||
Kh, Kw, _C, Kc = w.shape
|
||||
assert C==_C, (x.shape, w.shape)
|
||||
y = np.zeros([N,H-Kh+1,W-Kw+1,Kc])
|
||||
for i0 in range(N):
|
||||
        for i1 in range(H-Kh+1):
|
||||
for i2 in range(W-Kw+1):
|
||||
for i3 in range(Kh):
|
||||
for i4 in range(Kw):
|
||||
for i5 in range(C):
|
||||
                            for i6 in range(Kc):
                                y[i0, i1, i2, i6] += x[i0, i1 + i3, i2 + i4, i5] * w[i3,i4,i5,i6]
|
||||
return y
|
||||
```
|
||||
|
||||
Then, let's download a cat image, and run `conv_naive` with a simple horizontal filter.
|
||||
|
||||
然后,让我们下载一个猫的图像,并使用`conv_naive`实现一个简单的水平滤波器。
|
||||
|
||||
```
|
||||
# %matplotlib inline
|
||||
import pylab as pl
|
||||
img_path="/tmp/cat.jpg"
|
||||
if not os.path.isfile(img_path):
|
||||
!wget -O - 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/4f/Felis_silvestris_catus_lying_on_rice_straw.jpg/220px-Felis_silvestris_catus_lying_on_rice_straw.jpg' > $img_path
|
||||
img = pl.imread(img_path)
|
||||
pl.subplot(121)
|
||||
pl.imshow(img)
|
||||
kernel = np.array([
|
||||
[-1, -1, -1],
|
||||
[0, 0, 0],
|
||||
[1, 1, 1],
|
||||
])
|
||||
pl.subplot(122)
|
||||
x = img[np.newaxis,:,:,:1].astype("float32")
|
||||
w = kernel[:,:,np.newaxis,np.newaxis].astype("float32")
|
||||
y = conv_naive(x, w)
|
||||
print (x.shape, y.shape)
|
||||
pl.imshow(y[0,:,:,0])
|
||||
```
|
||||
It looks good, our `conv_naive` works well. Let's replace our naive implementation with jittor.
|
||||
|
||||
看起来不错,我们的`conv_naive`运作良好。现在让我们用jittor替换我们的朴素实现。
|
||||
|
||||
```
|
||||
import jittor as jt
|
||||
|
||||
def conv(x, w):
    """Convolution built from jittor meta-operators: reindex + broadcast + mul + sum.

    x: (N, H, W, C) input var; w: (Kh, Kw, C, Kc) kernel var.
    Returns a (N, H-Kh+1, W-Kw+1, Kc) var matching conv_naive's output.
    """
    N,H,W,C = x.shape
    Kh, Kw, _C, Kc = w.shape
    assert C==_C
    # Expand x into a 7-d view indexed by [n, oh, ow, kh, kw, c, kc];
    # each output cell reads x[n, oh+kh, ow+kw, c].
    xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
        'i0', # Nid
        'i1+i3', # Hid+Khid
        'i2+i4', # Wid+KWid
        'i5', # Cid
    ])
    ww = w.broadcast_var(xx)
    yy = xx*ww
    y = yy.sum([3,4,5]) # reduce over Kh, Kw, C
    return y
|
||||
|
||||
# Let's disable tuner. This will cause jittor not to use mkl for convolution
|
||||
jt.flags.enable_tuner = 0
|
||||
|
||||
jx = jt.array(x)
|
||||
jw = jt.array(w)
|
||||
jy = conv(jx, jw).fetch_sync()
|
||||
print (jx.shape, jy.shape)
|
||||
pl.imshow(jy[0,:,:,0])
|
||||
```
|
||||
|
||||
They look the same. How about the performance?
|
||||
|
||||
他们的结果看起来一样。那么它们的性能如何?
|
||||
|
||||
```
|
||||
%time y = conv_naive(x, w)
|
||||
%time jy = conv(jx, jw).fetch_sync()
|
||||
```
|
||||
|
||||
The jittor implementation is much faster. So why are these two implementations equivalent in math, and why is jittor's implementation faster? We will explain step by step:
|
||||
|
||||
First, let's take a look at the help document of `jt.reindex`.
|
||||
|
||||
可以看出jittor的实现要快得多。 那么,为什么这两个实现在数学上等效,而jittor的实现运行速度更快? 我们将逐步进行解释:
|
||||
|
||||
首先,让我们看一下`jt.reindex`的帮助文档。
|
||||
|
||||
```
|
||||
help(jt.reindex)
|
||||
```
|
||||
|
||||
Following the document, we can expand the reindex operation for better understanding:
|
||||
|
||||
遵循该文档,我们可以扩展重索引操作以便更好地理解:
|
||||
|
||||
```
|
||||
py
|
||||
xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
|
||||
'i0', # Nid
|
||||
'i1+i3', # Hid+Khid
|
||||
'i2+i4', # Wid+KWid
|
||||
'i5', # Cid
|
||||
])
|
||||
ww = w.broadcast_var(xx)
|
||||
yy = xx*ww
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, Kc
|
||||
```
|
||||
|
||||
**After expansion:**
|
||||
|
||||
扩展后:
|
||||
|
||||
```
|
||||
py
|
||||
shape = [N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc]
|
||||
# expansion of x.reindex
|
||||
xx = np.zeros(shape, x.dtype)
|
||||
for i0 in range(shape[0]):
|
||||
for i1 in range(shape[1]):
|
||||
for i2 in range(shape[2]):
|
||||
for i3 in range(shape[3]):
|
||||
for i4 in range(shape[4]):
|
||||
for i5 in range(shape[5]):
|
||||
for i6 in range(shape[6]):
|
||||
if is_overflow(i0,i1,i2,i3,i4,i5,i6):
|
||||
y[i0,i1,...,in] = 0
|
||||
else:
|
||||
y[i0,i1,i2,i3,i4,i5,i6] = x[i0,i1+i3,i2+i4,i5]
|
||||
|
||||
# expansion of w.broadcast_var(xx)
|
||||
ww = np.zeros(shape, x.dtype)
|
||||
for i0 in range(shape[0]):
|
||||
for i1 in range(shape[1]):
|
||||
for i2 in range(shape[2]):
|
||||
for i3 in range(shape[3]):
|
||||
for i4 in range(shape[4]):
|
||||
for i5 in range(shape[5]):
|
||||
for i6 in range(shape[6]):
|
||||
ww[i0,i1,i2,i3,i4,i5,i6] = w[i3,i4,i5,i6]
|
||||
# expansion of xx*ww
|
||||
yy = np.zeros(shape, x.dtype)
|
||||
for i0 in range(shape[0]):
|
||||
for i1 in range(shape[1]):
|
||||
for i2 in range(shape[2]):
|
||||
for i3 in range(shape[3]):
|
||||
for i4 in range(shape[4]):
|
||||
for i5 in range(shape[5]):
|
||||
for i6 in range(shape[6]):
|
||||
yy[i0,i1,i2,i3,i4,i5,i6] = xx[i0,i1,i2,i3,i4,i5,i6] * ww[i0,i1,i2,i3,i4,i5,i6]
|
||||
# expansion of yy.sum([3,4,5])
|
||||
shape2 = [N,H-Kh+1,W-Kw+1,Kc]
|
||||
y = np.zeros(shape2, x.dtype)
|
||||
for i0 in range(shape[0]):
|
||||
for i1 in range(shape[1]):
|
||||
for i2 in range(shape[2]):
|
||||
for i3 in range(shape[3]):
|
||||
for i4 in range(shape[4]):
|
||||
for i5 in range(shape[5]):
|
||||
for i6 in range(shape[6]):
|
||||
y[i0,i1,i2,i6] += yy[i0,i1,i2,i3,i4,i5,i6]
|
||||
```
|
||||
|
||||
**After loop fusion:**
|
||||
|
||||
循环融合后:
|
||||
|
||||
```
|
||||
py
|
||||
shape2 = [N,H-Kh+1,W-Kw+1,Kc]
|
||||
y = np.zeros(shape2, x.dtype)
|
||||
for i0 in range(shape[0]):
|
||||
for i1 in range(shape[1]):
|
||||
for i2 in range(shape[2]):
|
||||
for i3 in range(shape[3]):
|
||||
for i4 in range(shape[4]):
|
||||
for i5 in range(shape[5]):
|
||||
for i6 in range(shape[6]):
|
||||
if not is_overflow(i0,i1,i2,i3,i4,i5,i6):
|
||||
y[i0,i1,i2,i6] += x[i0,i1+i3,i2+i4,i5]
|
||||
```
|
||||
|
||||
This is the trick of the meta-operator: it can fuse multiple operators into one complicated operation, covering many variants of convolution (e.g. group conv, separable conv, ...).
|
||||
|
||||
jittor will try to optimize the fused operator as fast as possible. Let's try some optimizations(compile the shapes as constants into the kernel), and show the underlying c++ kernel.
|
||||
|
||||
这是就元算子的优化技巧,它可以将多个算子融合为一个复杂的融合算子,包括许多卷积的变化(例如group conv,separate conv等)。
|
||||
|
||||
jittor会尝试将融合算子优化得尽可能快。 让我们尝试一些优化(将形状作为常量编译到内核中),并编译到底层的c++内核代码中。
|
||||
|
||||
|
||||
```
|
||||
jt.flags.compile_options={"compile_shapes":1}
|
||||
with jt.profile_scope() as report:
|
||||
jy = conv(jx, jw).fetch_sync()
|
||||
jt.flags.compile_options={}
|
||||
|
||||
print(f"Time: {float(report[1][4])/1e6}ms")
|
||||
|
||||
with open(report[1][1], 'r') as f:
|
||||
print(f.read())
|
||||
```
|
||||
|
||||
Even faster than the previous implementation! From the output we can look at the function definition of func0. This is the main code of our convolution kernel, which is generated Just-in-time. Because the compiler knows the shapes of the kernel and more optimizations are used.
|
||||
|
||||
比之前的实现还要更快! 从输出中我们可以看一看`func0`的函数定义,这是我们卷积内核的主要代码,该内核代码是即时生成的。因为编译器知道内核的形状,所以使用了更多的优化方法。
|
|
@ -0,0 +1,20 @@
|
|||
# Profiler: Profiling your model
|
||||
|
||||
# 性能分析器:分析您的模型
|
||||
|
||||
> NOTE: This tutorial is still working in progress
|
||||
|
||||
In this tutorial, we will show:
|
||||
1. how to profiling your model and check the elapsed time of each operation
|
||||
2. profiling the cache hit rate
|
||||
|
||||
> 注意:本教程仍在持续更新中
|
||||
|
||||
在本教程中,我们将展示:
|
||||
|
||||
1. 如何分析模型并检查每个运算的耗时
|
||||
2. 分析缓存命中率
|
||||
|
||||
```python
|
||||
import jittor as jt
|
||||
```
|
|
@ -0,0 +1,647 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
#
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
from . import compiler
|
||||
from .compiler import LOG, has_cuda
|
||||
from .compiler import compile_custom_ops, compile_custom_op
|
||||
import jittor_core as core
|
||||
from jittor_core import *
|
||||
from jittor_core.ops import *
|
||||
from . import compile_extern
|
||||
from .compile_extern import mkl_ops
|
||||
|
||||
import contextlib
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
import types
|
||||
import pickle
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
def dfs(scope, vars):
    """Append every Var reachable from `scope` (depth-first) to `vars`."""
    for child in scope.children.values():
        if type(child) == Scope:
            dfs(child, vars)
        else:
            vars.append(child)
|
||||
|
||||
def dfs_records(scope, records):
    """Append every record reachable from `scope` (depth-first) to `records`.

    Child-scope records come first, then this scope's own records.
    """
    for child in scope.children.values():
        if type(child) == Scope:
            dfs_records(child, records)
    records.extend(scope.records.values())
|
||||
|
||||
class Scope:
    """A tree node mapping names to child scopes and Vars (name-scope tree).

    children: OrderedDict name -> Scope or Var
    index:    per-name counters used to generate auto-unique suffixes
    records:  values recorded inside this scope via record_in_scope
    """
    def __init__(self, parent=None, name=None):
        self.children = OrderedDict()
        self.index = {}
        self.records = OrderedDict()
        if name is None:
            # root scope: empty name and empty path prefix
            self.name = self.full_name = ""
        else:
            self.name = name
            self.full_name = parent.full_name + name + "/"

    def get_scope(self, name, unique=True):
        """Return (creating if needed) the child scope `name`.

        With unique=False an auto-incremented suffix `_{i}` is appended,
        so repeated calls yield distinct scopes (name_0, name_1, ...).
        """
        if not unique:
            index = self.index.get(name, 0)
            self.index[name] = index+1
            name = name + f'_{index}'
        if name not in self.children:
            sub_scope = Scope(self, name)
            self.children[name] = sub_scope
        else:
            sub_scope = self.children[name]
            assert type(sub_scope) == Scope, f"Name {name} is a Var: {sub_scope}"
        return sub_scope

    def make_var(self, shape, dtype, init, name, unique):
        """Return the Var `name` in this scope, creating it from `init` if new.

        init may be a Var, a callable (shape, dtype) -> value, or raw data.
        An existing Var is reused after a shape/dtype compatibility check.
        """
        if not unique:
            index = self.index.get(name, 0)
            self.index[name] = index+1
            name = name + f'_{index}'
        if name in self.children:
            var = self.children[name]
            assert type(var) == core.Var, f"Name {name} exist: {var}"
            assert (shape is None or var.shape == shape) and var.dtype == dtype, f"Shape or dtype not match {var} != {dtype}{shape}"
            return var
        else:
            full_name = self.full_name + name
            if type(init) != core.Var:
                if callable(init):
                    var = init(shape, dtype)
                    if type(var) != core.Var:
                        var = array(var)
                else:
                    # was `init != None`; `is not None` is the correct identity test
                    assert init is not None
                    var = array(init)
            else:
                var = init
            # keep parameters as standalone ops so they are not fused away
            var.stop_fuse()
            self.children[name] = var
            var.name(full_name)
            return var

    def clean_index(self): self.index.clear()

    def clean(self):
        """Drop all children, records and name counters."""
        self.children.clear()
        self.records.clear()
        self.index.clear()
|
||||
|
||||
current_scope = Scope()
|
||||
root_scope = current_scope
|
||||
|
||||
class _call_record_scope:
    """Base for scopes usable both as context manager and as decorator.

    When used as a decorator, the wrapped function runs inside the scope
    and its return value is recorded under the name "output".
    """
    def __enter__(self): pass
    def __exit__(self, *exc): pass
    def __call__(self, func):
        def wrapper(*args, **kw):
            with self:
                result = func(*args, **kw)
                # record while still inside the scope, so it lands here
                record_in_scope(result, "output")
                return result
        return wrapper
|
||||
|
||||
class _call_no_record_scope:
|
||||
def __enter__(self): pass
|
||||
def __exit__(self, *exc): pass
|
||||
def __call__(self, func):
|
||||
def inner(*args, **kw):
|
||||
with self:
|
||||
ret = func(*args, **kw)
|
||||
return ret
|
||||
return inner
|
||||
|
||||
class flag_scope(_call_no_record_scope):
    """Temporarily set jittor flags; previous values are restored on exit."""
    def __init__(self, **jt_flags):
        self.jt_flags = jt_flags

    def __enter__(self):
        self.flags_bk = {}
        try:
            for key, value in self.jt_flags.items():
                # back up the old value before overwriting, so __exit__
                # (and the rollback below) can restore it
                self.flags_bk[key] = getattr(flags, key)
                setattr(flags, key, value)
        except:
            # roll back whatever was already set, then propagate
            self.__exit__()
            raise

    def __exit__(self, *exc):
        for key, value in self.flags_bk.items():
            setattr(flags, key, value)
|
||||
|
||||
class var_scope(_call_record_scope):
    """Enter a named sub-scope of current_scope, optionally setting jt flags.

    Usable as a context manager or decorator. With unique=False the scope
    name gets an auto-incremented suffix on each entry.
    """
    def __init__(self, name="scope", unique=False, **jt_flags):
        self.fs = flag_scope(**jt_flags)
        self.name = name
        self.unique = unique

    def __enter__(self):
        global current_scope
        # remember where we were so __exit__ (or a failure here) can restore it
        self.prev = current_scope
        try:
            current_scope = current_scope.get_scope(self.name, self.unique)
            current_scope.clean_index()
            self.fs.__enter__()
        except:
            # restore the previous scope before propagating
            current_scope = self.prev
            del self.prev
            raise

    def __exit__(self, *exc):
        self.fs.__exit__(*exc)
        global current_scope
        current_scope = self.prev
        del self.prev
|
||||
|
||||
single_log_capture = None
|
||||
|
||||
class log_capture_scope(_call_no_record_scope):
    """log capture scope
    example:
        with jt.log_capture_scope(log_v=0) as logs:
            LOG.v("...")
        print(logs)
    """
    def __init__(self, **jt_flags):
        self.fs = flag_scope(**jt_flags)

    def __enter__(self):
        global single_log_capture
        # only one capture may be active at a time (the LOG buffer is global)
        assert not single_log_capture
        single_log_capture = True
        self.logs = []
        LOG.log_capture_start()
        try:
            self.fs.__enter__()
            return self.logs
        except:
            LOG.log_capture_stop()
            single_log_capture = None
            raise

    def __exit__(self, *exc):
        global single_log_capture
        self.fs.__exit__(*exc)
        LOG.log_capture_stop()
        # drain the captured records into the list handed out by __enter__
        self.logs.extend(LOG.log_capture_read())
        single_log_capture = None
|
||||
|
||||
|
||||
class profile_scope(_call_no_record_scope):
    """ profile scope
    example:
        with jt.profile_scope() as report:
            ......
        print(report)
    """
    def __init__(self, warmup=0, rerun=0, **jt_flags):
        self.fs = flag_scope(**jt_flags)
        self.warmup = warmup
        self.rerun = rerun

    def __enter__(self):
        # nested profiling is not supported
        assert not flags.profiler_enable
        profiler.start(self.warmup, self.rerun)
        self.report = []
        try:
            self.fs.__enter__()
            return self.report
        except:
            profiler.stop()
            raise

    def __exit__(self, *exc):
        self.fs.__exit__(*exc)
        profiler.stop()
        # report rows only become available after the profiler stops
        self.report.extend(profiler.report())
|
||||
|
||||
def make_var(shape=None, dtype="float32", init=None, name='var', unique=False):
    """Create (or fetch) a named Var in the current scope; see Scope.make_var."""
    return current_scope.make_var(shape, dtype, init, name, unique)
|
||||
|
||||
def find_vars(path=None):
    """Return all Vars under the "/"-separated scope `path` (default: current).

    If `path` names a single Var, it is returned as a one-element list.
    """
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        parts = path.split("/")
        if parts[-1] == "":
            parts.pop()
        for part in parts:
            scope = scope.children[part]
    if not isinstance(scope, Scope):
        return [scope]
    found = []
    dfs(scope, found)
    return found
|
||||
|
||||
def find_var(path):
    """Return the single Var at "/"-separated `path`; it must not be a scope."""
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        for part in path.split("/"):
            scope = scope.children[part]
    assert not isinstance(scope, Scope)
    return scope
|
||||
|
||||
def find_records(path=None):
    """Return all recorded values under the scope at `path` (default: current)."""
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        parts = path.split("/")
        if parts[-1] == "":
            parts.pop()
        for part in parts:
            scope = scope.children[part]
    assert isinstance(scope, Scope)
    found = []
    dfs_records(scope, found)
    return found
|
||||
|
||||
def find_record(path):
    """Return the record named by the last component of "/"-separated `path`."""
    assert isinstance(path, str)
    *scope_names, record_name = path.split("/")
    scope = current_scope
    for part in scope_names:
        scope = scope.children[part]
    assert isinstance(scope, Scope)
    return scope.records[record_name]
|
||||
|
||||
def find_scope(path):
    """Return the Scope at "/"-separated `path` (a trailing "/" is allowed)."""
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        parts = path.split("/")
        if parts[-1] == "":
            parts.pop()
        for part in parts:
            scope = scope.children[part]
    assert isinstance(scope, Scope)
    return scope
|
||||
|
||||
def record_in_scope(self, name):
    """Store `self` under `name` in the current scope's records.

    Vars are additionally renamed to their full scoped name. Returns `self`
    so this can be chained (it is also attached as Var.record_in_scope).
    """
    current_scope.records[name] = self
    if isinstance(self, Var):
        full_name = current_scope.full_name + name
        self.name(full_name)
    return self
|
||||
|
||||
Var.record_in_scope = record_in_scope
|
||||
|
||||
def clean():
    """Reset the current scope tree and force a full garbage collection."""
    current_scope.clean()
    import gc
    # make sure python does a full collection so freed Vars are released
    gc.collect()
|
||||
|
||||
cast = unary
|
||||
|
||||
def array(data, dtype=None):
    """Convert `data` (Var, numpy array, or raw python data) into a jittor Var.

    If dtype is given, Vars are cast via `cast` and everything else is
    converted through a numpy array of that dtype.
    """
    if type(data) == core.Var:
        if dtype is None:
            return cast(data, data.dtype)
        return cast(data, dtype)
    # was `dtype != None`; `is not None` is the correct identity test
    if dtype is not None:
        return ops.array(np.array(data, dtype))
    if type(data) == np.ndarray:
        if data.flags.c_contiguous:
            return ops.array(data)
        else:
            # ops.array requires contiguous memory
            return ops.array(data.copy())
    return ops.array(np.array(data))
|
||||
|
||||
def grad(loss, targets):
    """Gradients of `loss` w.r.t. `targets`.

    Accepts a single Var or a collection; a single Var target returns a
    single grad Var instead of a list.
    """
    if type(targets) == core.Var:
        return core.grad(loss, [targets])[0]
    return core.grad(loss, targets)
|
||||
|
||||
def liveness_info():
    """Return counts of held/lived vars and ops (useful for leak debugging)."""
    return {
        "hold_vars": core.number_of_hold_vars(),
        "lived_vars": core.number_of_lived_vars(),
        "lived_ops": core.number_of_lived_ops(),
    }
|
||||
|
||||
def ones(shape, dtype="float32"):
    """Return a var of `shape` filled with ones."""
    return unary(1, dtype).broadcast(shape)
|
||||
|
||||
def zeros(shape, dtype="float32"):
    """Return a var of `shape` filled with zeros."""
    return unary(0, dtype).broadcast(shape)
|
||||
|
||||
flags = core.flags()
|
||||
|
||||
def detach(x):
    """Return a detached copy of `x` (gradient does not flow through it)."""
    # NOTE(review): clone -> stop_grad -> clone; presumably the clones keep
    # the stop_grad flag off both the original and the result — confirm
    # before simplifying.
    return x.clone().stop_grad().clone()
|
||||
Var.detach = detach
|
||||
|
||||
def detach_inplace(x):
    """Detach `x` in place by swapping it with a stop-grad clone."""
    return x.swap(x.stop_grad().clone())
|
||||
Var.start_grad = Var.detach_inplace = detach_inplace
|
||||
|
||||
def unsqueeze(x, dim):
    """Return `x` reshaped with a new axis of length 1 inserted at `dim`.

    `dim` may be negative, counted from the end as in numpy/pytorch:
    dim=-1 appends the new axis after the current last one. The original
    accepted negative dims but inserted at the wrong position.
    """
    shape = list(x.shape)
    if dim < 0:
        # normalize: valid negative range is [-(len+1), -1]
        dim += len(shape) + 1
    assert 0 <= dim <= len(shape)
    return x.reshape(shape[:dim] + [1] + shape[dim:])
|
||||
Var.unsqueeze = unsqueeze
|
||||
|
||||
def squeeze(x, dim):
    """Return `x` reshaped with the length-1 axis at `dim` removed.

    `dim` may be negative, counted from the end as in numpy/pytorch.
    The original accepted dim=-1 but then computed shape[:dim]+shape[0:],
    duplicating the whole shape instead of dropping the axis.
    """
    shape = list(x.shape)
    if dim < 0:
        dim += len(shape)
    assert 0 <= dim < len(shape)
    assert shape[dim] == 1
    return x.reshape(shape[:dim] + shape[dim+1:])
|
||||
Var.squeeze = squeeze
|
||||
|
||||
def clamp(x, min_v, max_v):
    """Clamp `x` elementwise into [min_v, max_v] using 0/1 masks."""
    # TODO: change to x.maximum(min_v).minimum(max_v)
    assert min_v <= max_v
    below = (x < min_v).int()
    above = (x > max_v).int()
    # keep x where it is in range, substitute the violated bound elsewhere
    inside = 1 - below - above
    return x * inside + min_v * below + max_v * above
|
||||
Var.clamp = clamp
|
||||
|
||||
def type_as(a, b):
    """Cast `a` to the dtype of `b` (unary cast op)."""
    return a.unary(op=b.dtype)
|
||||
Var.type_as = type_as
|
||||
|
||||
def masked_fill(x, mask, value):
    """Return x with entries replaced by `value` where `mask` is 1.

    `mask` must match x's shape and is expected to hold only 0s and 1s.
    """
    assert list(x.shape) == list(mask.shape)
    # TODO: assert mask = 0 or 1
    # blend: keep x where mask==0, take value where mask==1
    keep = 1 - mask
    return x * keep + mask * value
|
||||
Var.masked_fill = masked_fill
|
||||
|
||||
|
||||
def sqr(x):
    """Return the elementwise square of x."""
    return x * x
|
||||
Var.sqr = sqr
|
||||
|
||||
def attrs(var):
    """Summarize a Var's flags, shape and dtype as a dict (for debugging)."""
    return dict(
        is_stop_fuse=var.is_stop_fuse(),
        is_stop_grad=var.is_stop_grad(),
        shape=var.shape,
        dtype=var.dtype,
    )
|
||||
Var.attrs = attrs
|
||||
|
||||
def fetch(vars, func, *args, **kw):
    """Asynchronously fetch `vars`; func(*values, *args, **kw) runs when ready."""
    core.fetch(vars, lambda *results: func(*results, *args, **kw))
|
||||
|
||||
def fetch_var(var, func, *args, **kw):
    """Asynchronously fetch one var; func(value, *args, **kw) runs when ready."""
    core.fetch([var], lambda a: func(a, *args, **kw))
|
||||
Var.fetch = fetch_var
|
||||
del fetch_var
|
||||
|
||||
def import_vars(data):
    ''' Load variables into current scopes.

    data maps "/"-separated full names to Vars or raw data, e.g.:
        import_vars({"w":[1.0,2.0,3.0]})
        jt.get_var([3], "float64", name="w", gen_index=False)
    Must be called before the names exist in the scope tree.
    '''
    for k in data:
        v = data[k]
        if type(v) != core.Var:
            # wrap raw data; stop_fuse keeps the parameter as its own op
            v = array(v).stop_fuse()
        scopes = k.split("/")
        scope = current_scope
        # walk (creating as needed) the intermediate scopes along the path
        for i in range(len(scopes)-1):
            scope = scope.get_scope(scopes[i])
        vname = scopes[-1]
        assert vname not in scope.children, f"Var {k} exists. Please load_vars at the beginning"
        v.name(k)
        scope.children[vname] = v
|
||||
|
||||
def export_vars():
    ''' Export all vars into a dictionary

    return: a dictionary, key is var name, value is numpy array
    '''
    data = {}
    for v in find_vars():
        data[v.name()] = v.fetch_sync()
    return data
|
||||
|
||||
def load(path):
    """Load a pickled model dict from `path`.

    Uses a context manager so the file handle is always closed; the
    original opened the file and never closed it.
    """
    with open(path, 'rb') as pkl_file:
        model_dict = pickle.load(pkl_file)
    return model_dict
|
||||
|
||||
class Module:
    """Base class for jittor models/layers.

    Subclasses implement execute(); calling the module instance forwards to
    it. Child modules are discovered by scanning instance attributes.
    """
    def __init__(self, *args, **kw):
        # The original line was `__doc__ == 'doc'`, a no-op comparison;
        # there is no base-class state to initialize.
        pass
    def execute(self, *args, **kw):
        # to be overridden by subclasses; the base implementation does nothing
        pass
    def __call__(self, *args, **kw):
        return self.execute(*args, **kw)
    def __repr__(self):
        return self.__str__()
    def _get_name(self):
        return self.__class__.__name__
    def __doc__(self):
        pass
    def __name__(self):
        pass

    def dfs(self, parents, k, callback, callback_leave=None):
        """Walk the module tree depth-first.

        callback(parents, name, module, n_children) runs before children and
        may return False to prune; callback_leave runs after the children.
        """
        n_children = 0
        for v in self.__dict__.values():
            if isinstance(v, Module):
                n_children += 1
        ret = callback(parents, k, self, n_children)
        if ret == False: return
        for k,v in self.__dict__.items():
            if not isinstance(v, Module):
                continue
            parents.append(self)
            v.dfs(parents, k, callback, callback_leave)
            parents.pop()
        if callback_leave:
            callback_leave(parents, k, self, n_children)

    def __str__(self):
        """Pretty-print the module tree, one line per module, indented."""
        ss = []
        def callback(parents, k, v, n):
            # indent key:class_name(extra_repr)
            k = f"{k}: " if k is not None else ""
            s = f"{' '*(len(parents)*4)}{k}{v.__class__.__name__}"
            if n:
                s += '('
            else:
                s += f"({v.extra_repr()})"
            ss.append(s)
        def callback_leave(parents, k, v, n):
            if n:
                ss.append(' '*(len(parents)*4)+')')
        self.dfs([], None, callback, callback_leave)
        return "\n".join(ss)

    def parameters(self):
        """Collect all Vars in the tree, renaming each to its dotted path."""
        ps = []
        stack = []
        def callback(parents, k, v, n):
            stack.append(str(k))
            for k2, p in v.__dict__.items():
                if isinstance(p, Var):
                    ps.append(p)
                    # stack[0] is the root's None key; drop it from the path
                    p.name(".".join(stack[1:]+[str(k2)]))
        def callback_leave(parents, k, v, n):
            stack.pop()
        self.dfs([], None, callback, callback_leave)
        return ps

    def modules(self):
        """Return every module in the tree (including self), depth-first."""
        ms = []
        def callback(parents, k, v, n):
            if isinstance(v, Module):
                ms.append(v)
        self.dfs([], None, callback, None)
        return ms

    def children(self):
        """Return only the direct child modules."""
        cd = []
        def callback(parents, k, v, n):
            if len(parents) == 1 and isinstance(v, Module):
                cd.append(v)
                # prune: no need to descend past direct children
                return False
        self.dfs([], None, callback, None)
        return cd

    def extra_repr(self):
        """Render __init__ args for repr: positional values bare, defaults as k=v."""
        ss = []
        # __defaults__ is None when __init__ declares no default args; the
        # original crashed on len(None) in that case
        n = len(self.__init__.__code__.co_varnames) - \
            len(self.__init__.__defaults__ or ())
        for i, k in enumerate(self.__init__.__code__.co_varnames[1:]):
            v = getattr(self, k) if hasattr(self, k) else None
            if isinstance(v, Var): v = v.peek()
            s = f"{k}={v}" if i >= n else str(v)
            ss.append(s)
        return ", ".join(ss)

    def load_parameters(self, params):
        """Assign values from `params` (dotted-name -> array/list/Var/tensor).

        Keys that do not resolve to an attribute path are reported and skipped.
        """
        for key in params.keys():
            v = self
            key_ = key.split('.')
            end = 0
            for k in key_:
                if isinstance(v, nn.Sequential):
                    # np.int was removed from numpy; plain int(k) is equivalent
                    if int(k) >= len(v.layers):
                        end = 1
                        break
                    else:
                        v = v[int(k)]
                else:
                    if hasattr(v, k):
                        v = getattr(v, k)
                    else:
                        end = 1
                        break
            if end == 1:
                print(f'init {key} fail ...')
            else:
                # print(f'init {key} success ...')
                if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
                    v.assign(array(params[key]))
                elif isinstance(params[key], Var):
                    v.assign(params[key])
                else:
                    # assume a torch tensor: move to cpu, detach, convert
                    v.assign(array(params[key].cpu().detach().numpy()))

    def save(self, path):
        """Pickle {name: numpy data} of all parameters to `path`."""
        params = self.parameters()
        params_dict = {}
        for p in params:
            params_dict[p.name()] = p.data
        with open(path, 'wb') as f:
            pickle.dump(params_dict, f, pickle.HIGHEST_PROTOCOL)

    def eval(self):
        """Switch to evaluation mode and stop gradients on all parameters."""
        def callback(parents, k, v, n):
            if isinstance(v, Module) and hasattr(v, "is_train"):
                v.is_train = False
        self.dfs([], None, callback, None)

        # backup stop grad or not, so train() can restore it
        if not hasattr(self, "backup_grad_state"):
            self.backup_grad_state = {}
        for p in self.parameters():
            if id(p) not in self.backup_grad_state:
                self.backup_grad_state[id(p)] = not p.is_stop_grad()
            p.stop_grad()

    def train(self):
        """Switch to training mode and restore gradients saved by eval()."""
        def callback(parents, k, v, n):
            if isinstance(v, Module) and hasattr(v, "is_train"):
                v.is_train = True
        self.dfs([], None, callback, None)

        # restore the grad state recorded by eval()
        if hasattr(self, "backup_grad_state"):
            for p in self.parameters():
                if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
                    p.start_grad()
|
||||
|
||||
def make_module(func, exec_n_args=1):
    """Wrap a functional op `func` into a Module subclass.

    Constructor args are stored and appended to the call args, i.e.
    MakeModule(*cfg)(*inputs) == func(*inputs, *cfg).
    """
    class MakeModule(Module):
        def __init__(self, *args, **kw):
            self.args = args
            self.kw = kw
            # the original had `self.__doc__ == 'doc'`, a no-op comparison
        def execute(self, *args):
            return func(*args, *self.args, **self.kw)
        def __str__(self):
            return 'str'
        def __repr__(self):
            return self.__str__()
        def extra_repr(self):
            return ''

    return MakeModule
|
||||
|
||||
|
||||
def dirty_fix_pytorch_runtime_error():
    ''' This function should be called before pytorch.
    Example:
        import jittor as jt
        jt.dirty_fix_pytorch_runtime_error()
        import torch
    '''
    import os
    # Monkeypatch os.RTLD_GLOBAL to also carry RTLD_DEEPBIND.
    # NOTE(review): presumably so that torch's subsequent dlopen with
    # RTLD_GLOBAL binds its own symbols first and avoids clashes with
    # jittor's already-loaded libraries — confirm.
    os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||
|
||||
|
||||
import atexit
|
||||
|
||||
class ExitHooks(object):
    """Intercept sys.exit and uncaught exceptions to remember how python ended.

    jittor_exit consults exit_code/exception to decide whether a final
    sync is safe at interpreter shutdown.
    """
    def __init__(self):
        # set by exit() / exc_handler(); both None means a clean exit
        self.exit_code = None
        self.exception = None

    def hook(self):
        # install the sys.exit wrapper and the uncaught-exception hook
        self._orig_exit = sys.exit
        sys.exit = self.exit
        sys.excepthook = self.exc_handler

    def exit(self, code=0):
        # remember the code, then delegate to the real sys.exit
        self.exit_code = code
        self._orig_exit(code)

    def exc_handler(self, exc_type, exc, *args):
        # remember the exception and still print the usual traceback
        self.exception = exc
        traceback.print_exception(exc_type, exc, *args)
|
||||
|
||||
hooks = ExitHooks()
|
||||
hooks.hook()
|
||||
|
||||
def jittor_exit():
    """atexit handler: sync pending ops only on a clean, non-exception exit."""
    if hooks.exit_code is not None:
        # explicit sys.exit(): skip the final sync
        pass
    elif hooks.exception is not None:
        # died with an uncaught exception: skip the final sync as well
        pass
    else:
        core.sync_all(True)
|
||||
atexit.register(jittor_exit)
|
||||
|
||||
Var.__repr__ = Var.__str__ = lambda x: str(x.data)
|
||||
Var.peek = lambda x: str(x.dtype)+str(x.shape)
|
||||
|
||||
from . import nn
|
||||
from .nn import matmul
|
||||
from . import contrib
|
|
@ -0,0 +1,234 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import os, sys
|
||||
from .compiler import *
|
||||
from jittor.dataset.utils import download_url_to_local
|
||||
|
||||
def search_file(dirs, name):
    """Return the first existing `dirs[i]/name`; fatal-log if none is found.

    LOG.f terminates with a fatal error, so a missing file does not return.
    """
    for d in dirs:
        fname = os.path.join(d, name)
        if os.path.isfile(fname):
            LOG.i(f"found {fname}")
            return fname
    LOG.f(f"file {name} not found in {dirs}")
|
||||
|
||||
def install_mkl(root_folder):
    """Download and extract the mkl-dnn binary release, then smoke-test it.

    The extracted tree lives at root_folder/mkldnn_lnx_1.0.2_cpu_gomp; the
    download is skipped when the compiled example test binary already exists.
    """
    url = "https://github.com/intel/mkl-dnn/releases/download/v1.0.2/mkldnn_lnx_1.0.2_cpu_gomp.tgz"
    filename = "mkldnn_lnx_1.0.2_cpu_gomp.tgz"
    fullname = os.path.join(root_folder, filename)
    dirname = os.path.join(root_folder, filename.replace(".tgz",""))

    if not os.path.isfile(os.path.join(dirname, "examples", "test")):
        LOG.i("Downloading mkl...")
        # the final argument is the expected md5 of the archive
        download_url_to_local(url, filename, root_folder, "47187284ede27ad3bd64b5f0e7d5e730")
        import tarfile

        with tarfile.open(fullname, "r") as tar:
            tar.extractall(root_folder)

        # build and run the bundled example to verify the install works
        assert 0 == os.system(f"cd {dirname}/examples && "
            f"{cc_path} -std=c++14 cpu_cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test")
|
||||
|
||||
def setup_mkl():
    """Locate (or install) mkl-dnn and compile jittor's mkl custom ops.

    Honors env vars use_mkl / mkl_include_path / mkl_lib_path; sets the
    module globals `use_mkl` and `mkl_ops` (None when disabled).
    """
    global mkl_ops, use_mkl
    use_mkl = os.environ.get("use_mkl", "1")=="1"
    mkl_ops = None
    if not use_mkl: return
    mkl_include_path = os.environ.get("mkl_include_path")
    mkl_lib_path = os.environ.get("mkl_lib_path")

    if mkl_lib_path is None or mkl_include_path is None:
        mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh")
        LOG.v("setup mkl...")
        # mkl_path = os.path.join(cache_path, "mkl")
        # mkl_path is decoupled from cc_path: one shared cache per user
        from pathlib import Path
        mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl")

        make_cache_dir(mkl_path)
        install_mkl(mkl_path)
        # locate the extracted mkldnn_lnx_* directory inside the cache
        mkl_home = ""
        for name in os.listdir(mkl_path):
            if name.startswith("mkldnn_lnx") and os.path.isdir(os.path.join(mkl_path, name)):
                mkl_home = os.path.join(mkl_path, name)
                break
        assert mkl_home!=""
        mkl_include_path = os.path.join(mkl_home, "include")
        mkl_lib_path = os.path.join(mkl_home, "lib")

    mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
    assert os.path.isdir(mkl_include_path)
    assert os.path.isdir(mkl_lib_path)
    assert os.path.isfile(mkl_lib_name)
    LOG.v(f"mkl_include_path: {mkl_include_path}")
    LOG.v(f"mkl_lib_path: {mkl_lib_path}")
    LOG.v(f"mkl_lib_name: {mkl_lib_name}")
    # We do not link manually, link in custom ops
    # ctypes.CDLL(mkl_lib_name, dlopen_flags)

    mkl_op_dir = os.path.join(jittor_path, "extern", "mkl", "ops")
    mkl_op_files = [os.path.join(mkl_op_dir, name) for name in os.listdir(mkl_op_dir)]
    mkl_ops = compile_custom_ops(mkl_op_files,
        extra_flags=f" -I'{mkl_include_path}' -L'{mkl_lib_path}' -lmkldnn -Wl,-rpath='{mkl_lib_path}' ")
    LOG.vv("Get mkl_ops: "+str(dir(mkl_ops)))
|
||||
|
||||
|
||||
def install_cub(root_folder):
    """Download and extract NVIDIA cub 1.8.0, smoke-test it, return its dir.

    The download is skipped when the compiled example test binary exists.
    """
    url = "https://github.com/NVlabs/cub/archive/v1.8.0.tar.gz"
    filename = "cub-1.8.0.tgz"
    fullname = os.path.join(root_folder, filename)
    dirname = os.path.join(root_folder, filename.replace(".tgz",""))

    if not os.path.isfile(os.path.join(dirname, "examples", "test")):
        LOG.i("Downloading cub...")
        # the final argument is the expected md5 of the archive
        download_url_to_local(url, filename, root_folder, "9203ea2499b56782601fddf8a12e9b08")
        import tarfile

        with tarfile.open(fullname, "r") as tar:
            tar.extractall(root_folder)

        # build and run a bundled example to verify nvcc + cub work
        assert 0 == os.system(f"cd {dirname}/examples && "
            f"{nvcc_path} device/example_device_radix_sort.cu -O2 -I.. -o test && ./test")
    return dirname
|
||||
|
||||
def setup_cub():
    """Install cub into the user cache and register it as a header-only cuda lib."""
    from pathlib import Path
    cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub")
    cub_home = install_cub(cub_path)
    # link=False: cub is header-only, only the include path is needed
    setup_cuda_lib("cub", link=False, extra_flags=f"-I{cub_home}")
|
||||
|
||||
def setup_cuda_extern():
    """Compile jittor's cuda helper sources and set up optional cuda libraries.

    No-op without CUDA. Failures of individual libraries (cub, cublas,
    cudnn, curand) are logged as warnings instead of aborting.
    """
    if not has_cuda: return
    LOG.vv("setup cuda extern...")
    cache_path_cuda = os.path.join(cache_path, "cuda")
    cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc")
    make_cache_dir(cache_path_cuda)
    cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src")
    cuda_extern_files = [os.path.join(cuda_extern_src, name)
        for name in os.listdir(cuda_extern_src)]
    so_name = os.path.join(cache_path_cuda, "cuda_extern.so")
    compile(cc_path, cc_flags+f" -I'{cuda_include}' ", cuda_extern_files, so_name)
    ctypes.CDLL(so_name, dlopen_flags)

    try:
        setup_cub()
    except Exception as e:
        import traceback
        line = traceback.format_exc()
        LOG.w(f"CUDA found but cub is not loaded:\n{line}")

    libs = ["cublas", "cudnn", "curand"]
    for lib_name in libs:
        try:
            setup_cuda_lib(lib_name)
        except Exception as e:
            import traceback
            line = traceback.format_exc()
            LOG.w(f"CUDA found but {lib_name} is not loaded:\n{line}")
|
||||
|
||||
def setup_cuda_lib(lib_name, link=True, extra_flags=""):
    """Compile the custom ops for one cuda library and expose them.

    Publishes the result as the module global `{lib_name}_ops` (None when
    CUDA is unavailable or the library has no sources). With link=True the
    shared library is located, dlopen'ed and added to the link flags.
    """
    globals()[lib_name+"_ops"] = None
    if not has_cuda: return
    LOG.v(f"setup {lib_name}...")

    culib_path = os.path.join(cuda_lib, f"lib{lib_name}.so")
    jt_cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc")
    jt_culib_include = os.path.join(jittor_path, "extern", "cuda", lib_name, "inc")

    link_flags = ""
    if link:
        # search_file fatal-logs if the header or .so cannot be found
        cuda_include_name = search_file([cuda_include, "/usr/include"], lib_name+".h")
        culib_path = search_file([cuda_lib, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so")
        # dynamic link cuda library
        ctypes.CDLL(culib_path, dlopen_flags)
        link_flags = f"-l{lib_name} -L'{cuda_lib}'"

    # find all source files
    culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name)
    culib_src_files = []
    for r, _, f in os.walk(culib_src_dir):
        for fname in f:
            culib_src_files.append(os.path.join(r, fname))
    if len(culib_src_files) == 0:
        return

    # compile and get operators
    culib_ops = compile_custom_ops(culib_src_files,
        extra_flags=f" -I'{jt_cuda_include}' -I'{jt_culib_include}' {link_flags} {extra_flags} ")
    globals()[lib_name+"_ops"] = culib_ops
    LOG.vv(f"Get {lib_name}_ops: "+str(dir(culib_ops)))
|
||||
|
||||
def install_cutt(root_folder):
    """Download, verify, extract and build the cutt library under
    *root_folder*; return the cutt source directory.

    A cached archive is reused when its md5 matches; otherwise the stale
    archive and any extracted tree are removed and re-downloaded.
    """
    url = "https://cloud.tsinghua.edu.cn/f/4be7e1dd51c6459aa119/?dl=1"
    filename = "cutt.tgz"
    fullname = os.path.join(root_folder, filename)
    dirname = os.path.join(root_folder, filename.replace(".tgz",""))
    true_md5 = "c79ad93b76544d598eb250ec749c492c"

    if os.path.exists(fullname):
        # Compute the md5 in-process instead of shelling out to `md5sum`:
        # no coreutils dependency, and paths with spaces are safe.
        import hashlib
        with open(fullname, "rb") as f:
            md5 = hashlib.md5(f.read()).hexdigest()
    else:
        md5 = '233'  # sentinel that never matches true_md5
    if md5 != true_md5:
        # Stale or corrupt download: drop the archive and any extracted
        # tree. Guarded removals replace `os.system('rm ...')`, which also
        # spammed stderr when the archive never existed.
        import shutil
        if os.path.exists(fullname):
            os.remove(fullname)
        shutil.rmtree(dirname, ignore_errors=True)
    if not os.path.isfile(os.path.join(dirname, "bin", "cutt_test")):
        LOG.i("Downloading cutt...")  # fixed: was mislabeled "cub"
        download_url_to_local(url, filename, root_folder, true_md5)
        import tarfile

        with tarfile.open(fullname, "r") as tar:
            tar.extractall(root_folder)

        from jittor_utils import run_cmd
        LOG.i("installing cutt...")
        run_cmd(f"cd {dirname} && make")
    return dirname
|
||||
|
||||
def setup_cutt():
    """Build/load the cutt (CUDA transpose) backend and publish its ops.

    Sets the module-level globals ``use_cutt`` (bool) and ``cutt_ops``
    (op module or None). Controlled by the env vars ``use_cutt``,
    ``cutt_include_path`` and ``cutt_lib_path``; when the latter two are
    not both set, cutt is downloaded and built under ~/.cache/jittor/cutt.
    """
    global cutt_ops, use_cutt
    if not has_cuda:
        # cutt is CUDA-only
        use_cutt = False
        return
    # opt-out switch: use_cutt=0 disables the backend
    use_cutt = os.environ.get("use_cutt", "1")=="1"
    cutt_ops = None
    if not use_cutt: return
    cutt_include_path = os.environ.get("cutt_include_path")
    cutt_lib_path = os.environ.get("cutt_lib_path")

    if cutt_lib_path is None or cutt_include_path is None:
        LOG.v("setup cutt...")
        # cutt_path is decoupled from cc_path: one shared cache per user
        from pathlib import Path
        cutt_path = os.path.join(str(Path.home()), ".cache", "jittor", "cutt")

        make_cache_dir(cutt_path)
        install_cutt(cutt_path)
        cutt_home = os.path.join(cutt_path, "cutt")
        cutt_include_path = os.path.join(cutt_home, "src")
        cutt_lib_path = os.path.join(cutt_home, "lib")

    cutt_lib_name = os.path.join(cutt_lib_path, "libcutt.so")
    # fail fast with clear assertions if the layout is not as expected
    assert os.path.isdir(cutt_include_path)
    assert os.path.isdir(cutt_lib_path)
    assert os.path.isfile(cutt_lib_name), cutt_lib_name
    LOG.v(f"cutt_include_path: {cutt_include_path}")
    LOG.v(f"cutt_lib_path: {cutt_lib_path}")
    LOG.v(f"cutt_lib_name: {cutt_lib_name}")
    # We do not link manually; linking happens in the custom ops build.
    # dlopen here only makes the symbols available to them.
    ctypes.CDLL(cutt_lib_name, dlopen_flags)

    # compile every op wrapper under extern/cuda/cutt/ops
    cutt_op_dir = os.path.join(jittor_path, "extern", "cuda", "cutt", "ops")
    cutt_op_files = [os.path.join(cutt_op_dir, name) for name in os.listdir(cutt_op_dir)]
    cutt_ops = compile_custom_ops(cutt_op_files,
        extra_flags=f" -I'{cutt_include_path}'")
    LOG.vv("Get cutt_ops: "+str(dir(cutt_ops)))
|
||||
|
||||
|
||||
# Module import time: initialize the optional accelerator backends.
# Each setup_* function is a no-op (or sets its *_ops global to None)
# when its backend is unavailable.
setup_cutt()
setup_mkl()

setup_cuda_extern()
|
|
# ----- file boundary (extraction artifact "@ -0,0 +1,924 @@" neutralized) -----
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import subprocess as sp
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import inspect
|
||||
import datetime
|
||||
import threading
|
||||
import ctypes
|
||||
from ctypes import cdll
|
||||
from ctypes.util import find_library
|
||||
|
||||
import jittor_utils as jit_utils
|
||||
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
|
||||
from . import pyjt_compiler
|
||||
|
||||
def find_jittor_path():
    """Return the directory that contains this jittor module file."""
    package_dir, _ = os.path.split(__file__)
    return package_dir
|
||||
|
||||
def make_cache_dir(cache_path):
    """Create *cache_path* if it does not already exist (idempotent).

    Uses ``os.makedirs`` instead of ``os.mkdir`` so a nested cache path
    whose parent directories are missing is still created in one call.
    """
    if not os.path.isdir(cache_path):
        LOG.i(f"Create cache dir: {cache_path}")
        os.makedirs(cache_path)
|
||||
|
||||
def remove_flags(flags, rm_flags):
    """Return *flags* (a space-separated flag string) with every token
    that starts with any prefix in *rm_flags* removed."""
    kept = [
        token for token in flags.split(" ")
        if not any(token.startswith(prefix) for prefix in rm_flags)
    ]
    return " ".join(kept)
|
||||
|
||||
def compile(compiler, flags, inputs, output, combind_build=False):
    """Compile *inputs* into *output* (under cache_path), via the
    cache-aware compiler when available.

    Inputs ending in ".o" are linked as-is; other inputs are resolved
    relative to jittor_path and compiled to per-file objects first, then
    linked (unless there is a single input or ``combind_build`` is set,
    in which case one combined compile+link command is issued).
    Returns the result of the cached compile (truthy when recompiled).
    NOTE(review): the loop variable `input` shadows the builtin — kept
    for byte-identity.
    """
    def do_compile(cmd):
        # prefer the native cache_compile helper; fall back to a plain run
        if jit_utils.cc:
            return jit_utils.cc.cache_compile(cmd, cache_path, jittor_path)
        else:
            run_cmd(cmd)
            return True
    link = link_flags
    # if output is core, add core_link_flags
    if output.startswith("jittor_core"):
        link = link + core_link_flags
    output = os.path.join(cache_path, output)
    # don't recompile object file in inputs
    obj_files = []
    new_inputs = []
    for name in inputs:
        if name.endswith(".o"):
            obj_files.append(name)
        else:
            new_inputs.append(os.path.join(jittor_path, name))
            obj_files.append(os.path.join(
                cache_path, "obj_files", os.path.basename(name)+".o"))
    inputs = new_inputs

    if len(inputs) == 1 or combind_build:
        # single translation unit (or forced combined build): one command
        cmd = f"{compiler} {' '.join(inputs)} {flags} {link} -o {output}"
        return do_compile(cmd)
    # split compile object file and link
    # remove -l -L flags when compile object files
    oflags = remove_flags(flags, ['-l', '-L', '-Wl,'])
    cmds = []
    for input, obj_file in zip(inputs, obj_files):
        cc = compiler
        nflags = oflags
        if has_cuda and input.endswith(".cu"):
            # .cu sources go through nvcc with translated flags
            nflags = convert_nvcc_flags(oflags)
            cc = nvcc_path
        cmd = f"{cc} {input} {nflags} -c {lto_flags} -o {obj_file}"
        cmds.append(cmd)
    # compile all objects (possibly in parallel), then link
    jit_utils.run_cmds(cmds, cache_path, jittor_path)
    cmd = f"{compiler} {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"
    return do_compile(cmd)
|
||||
|
||||
def gen_jit_tests():
    """Scan src/ for JIT_TEST(...) declarations and generate
    gen/jit_tests.h, which exposes each test as a @pyjt-bound function
    under the `tests` submodule."""
    # every .cc file under src/ (paths relative to jittor_path)
    all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    jit_declares = []
    re_def = re.compile("JIT_TEST\\((.*?)\\)")
    names = set()
    test_defs = []

    for src_name in all_src:
        src_name = os.path.join(jittor_path, src_name)
        with open(src_name) as f:
            src = f.read()
        defs = re_def.findall(src)
        for name in defs:
            LOG.vv(f"Find test {name} from {src_name}")
            # test names must be globally unique across all sources
            assert name not in names, f"Conflict test name {name}"
            names.add(name)
            jit_declares.append(f"JIT_TEST({name});")
            test_defs.append(f"""
            /* From {src_name} */
            // @pyjt({name})
            static inline void test_{name}() {{ jit_test_{name}(); }}
            """)

    jit_declares = "\n ".join(jit_declares)
    # assemble the generated header text
    jit_src = f"""
    #pragma once
    #include "common.h"

    void expect_error(std::function<void()> func) {{
        try {{ func(); }}
        catch (...) {{ return; }}
        CHECK(0) << "Missing error";
    }}

    namespace jittor {{

    {jit_declares}

    // @pyjt(tests)
    // @attrs(submodule)
    namespace tests {{
        {"".join(test_defs)}
    }}

    }} // jittor
    """
    LOG.vvvv(jit_src)
    with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w') as f:
        f.write(jit_src)
|
||||
|
||||
def gen_jit_flags():
    """Scan src/ for DEFINE_FLAG(...) declarations and generate
    gen/jit_flags.h, exposing each flag as a property on the @pyjt-bound
    `flags` struct (int flags also accept bool setters)."""
    all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    jit_declares = []
    re_def = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)

    flags_defs = []
    visit = {}

    for src_name in all_src:
        src_name = os.path.join(jittor_path, src_name)
        with open(src_name) as f:
            src = f.read()
        defs = re_def.findall(src)
        for _, args in defs:
            # DEFINE_FLAG(type, name, default, doc...)
            args = args.split(",")
            type = args[0].strip()  # NOTE(review): shadows builtin `type`
            name = args[1].strip()
            # skip cuda-only flags when building without CUDA
            if not has_cuda and "cuda" in name and name!="use_cuda":
                continue
            default = args[2].strip()
            doc = ",".join(args[3:])
            # NOTE(review): eval on text taken from the repo's own C++
            # sources — repo-controlled input, not user input.
            doc = eval(f"({doc})")
            LOG.vv(f"Find define {name} from {src_name}")
            if name in visit:
                # first definition wins; duplicates are ignored
                continue
            visit[name] = 1
            jit_declares.append(f"DECLARE_FLAG({type}, {name});")
            flags_defs.append(f"""
            /* {name}(type:{type}, default:{default}): {doc} */
            // @pyjt(__get__{name})
            {type} _get_{name}() {{ return {name}; }}
            // @pyjt(__set__{name})
            void _set_{name}({type} v) {{ set_{name}(v); }}
            {f'''// @pyjt(__set__{name})
            void _set_{name}(bool v) {{ set_{name}(v); }}
            ''' if type=="int" else ""}
            """)

    jit_declares = "\n ".join(jit_declares)
    # assemble the generated header text
    jit_src = f"""
    #include "utils/flags.h"

    namespace jittor {{

    {jit_declares}

    // @pyjt(flags)
    struct _Flags {{
        // @pyjt(__init__)
        _Flags() {{}}
        {"".join(flags_defs)}
    }};

    }} // jittor
    """
    LOG.vvvv(jit_src)
    with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w') as f:
        f.write(jit_src)
|
||||
|
||||
def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
    """Generate the C++ source that binds every op header in *op_headers*
    into python-callable op-maker functions.

    For each op constructor found in a header this emits a `make_*`
    VarPtr factory, a @pyjt-bound wrapper, and (for operator-like
    bindings such as __add__) in-place and reflected variants. For the
    binary/unary/reduce meta-ops it additionally expands one alias per
    supported functor name. Returns the generated source as a string;
    when *export* is a module name, module-init boilerplate is appended.
    """
    def add_src(
        cc_func_name,
        cc_args,
        op_name,
        op_args,
        src,
        pybind_name,
        py_args,
        jit_cc_src,
        doc_string,
        attrs
    ):
        # Emit the C++ maker + pyjt wrapper for one constructor overload.
        # names that map onto python dunder operators (get i*/r* variants)
        has_ir = set(["add", "sub", "mul", "matmul", "truediv", "floordiv", "mod", "divmod", "pow", "lshift", "rshift", "and", "xor", "or"])
        pybind_names = [ s.strip() for s in pybind_name.split(",")]
        # maker takes Var* where the wrapper takes VarHolder*
        cc_make_args = [ arg.replace("VarHolder*", "Var*") for arg in cc_args ]
        op_make_args = [ arg.replace("->var", "") for arg in op_args ]
        py_args = [ arg.replace("Var*", "VarHolder*") for arg in py_args ]
        op_args = []
        cc_args_with_default = []
        for i, arg in enumerate(cc_args):
            # pre_arg: bare parameter name (strip type and default)
            pre_arg = arg.split()[-1].split('=')[0]
            op_arg = None
            if arg.startswith("VarHolder*"):
                # unwrap holder -> raw Var*
                op_arg = pre_arg+"->var"
            elif arg.startswith("vector<VarHolder*>"):
                op_arg = f"convert({pre_arg})"
            if "&&" in arg:
                # rvalue-reference parameters must be moved through
                if op_arg == None:
                    op_arg = "move("+pre_arg+")"
                op_make_args[i] = "move("+pre_arg+")"
            if op_arg==None: op_arg = pre_arg
            op_args.append(op_arg)
            py_arg = py_args[i]
            if "_a=" not in py_arg:
                cc_args_with_default.append(arg)
                continue
            # propagate the pybind default value into the C++ signature
            py_arg = py_arg.split("_a=")[1]
            cc_args_with_default.append(arg + "=" + py_arg)
        cc_args = cc_args_with_default
        # steps of Op creation:
        # 1. new op
        # 2. new output var (create_output in op constructor)
        # 3. take over op's output VarPtr from outputs_holder
        # 4. set op's output
        # 5. set op's input
        # 6. infer shape(op->init())
        if "multiple_outputs" not in attrs:
            jit_cc_src.append(f"""
    VarPtr make_{cc_func_name}({", ".join(cc_make_args)}) {{
        Op* _op = new {op_name}({", ".join(op_make_args)});
        if (_op->outputs_holder.size() != 1) {{
            delete _op;
            LOGf << "Wrong output size of" << \"{op_name}\";
        }}
        if (_op->flags.get(NodeFlags::_forwarded)) {{
            VarPtr output(move(_op->outputs_holder[0]));
            delete _op;
            return output;
        }}
        _op->outputs_holder[0]->set_inputs({{_op}});
        VarPtr output(move(_op->outputs_holder[0]));
        {src.replace("->var","")};
        _op->init();
        return output;
    }}
    """)
        else:
            jit_cc_src.append(f"""
    vector<VarPtr> make_{cc_func_name}({", ".join(cc_make_args)}) {{
        Op* _op = new {op_name}({", ".join(op_make_args)});
        if (_op->flags.get(NodeFlags::_forwarded)) {{
            vector<VarPtr> outputs = move(_op->outputs_holder);
            delete _op;
            return outputs;
        }}
        vector<VarPtr> outputs = move(_op->outputs_holder);
        for (uint i=0; i<outputs.size(); i++)
            outputs[i]->set_inputs({{_op}});
        {src.replace("->var","")};
        _op->init();
        return outputs;
    }}
    """)
        if pybind_name == 'None':
            # maker only; no python binding requested
            return
        pyjt_names = []
        for pybind_name in pybind_names:
            if pybind_name.startswith("__"):
                # dunder names bind onto Var
                pyjt_names.append("Var."+pybind_name)
            else:
                pyjt_names.append(pybind_name)
                if len(cc_args)>0 and cc_args[0].startswith("VarHolder* "):
                    # also expose as a Var method when first arg is a Var
                    pyjt_names.append("Var."+pybind_name)
        if "multiple_outputs" in attrs:
            jit_cc_src.append(f"""
    /*{doc_string}*/
    // @pyjt({",".join(pyjt_names)})
    vector<VarHolder*> {cc_func_name}({", ".join(cc_args)}) {{
        return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));
    }}
    """)
        else:
            jit_cc_src.append(f"""
    /*{doc_string}*/
    // @pyjt({",".join(pyjt_names)})
    VarHolder* {cc_func_name}({", ".join(cc_args)}) {{
        return new VarHolder(make_{cc_func_name}({", ".join(op_args)}));
    }}
    """)
        # operator bindings also need __iadd__-style and reflected forms
        need_ir_define = False
        ir_name = None
        for pybind_name in pybind_names:
            if pybind_name.startswith("__") and pybind_name[2:-2] in has_ir:
                need_ir_define = True
                assert ir_name is None
                ir_name = pybind_name[2:-2]
        if need_ir_define:
            assert len(cc_args)>0 and cc_args[0].startswith("VarHolder* ")
            this = cc_args[0].split()[-1]
            jit_cc_src.append(f"""
    // @pyjt(Var.__i{ir_name}__)
    // @attrs(return_self)
    VarHolder* i{cc_func_name}({", ".join(cc_args)}) {{
        *{this} = make_{cc_func_name}({", ".join(op_args)});
        return {this};
    }}
    """)
            assert len(cc_args)>1 and cc_args[1].startswith("VarHolder* "), cc_args
            # reflected variant: swap the first two arguments
            r_cc_args = [cc_args[1], cc_args[0]] + cc_args[2:]
            r_py_args = [py_args[1], py_args[0]] + py_args[2:]  # NOTE(review): unused
            jit_cc_src.append(f"""
    VarHolder* r{cc_func_name}({", ".join(r_cc_args)}) {{
        return new VarHolder(make_{cc_func_name}({", ".join(op_args)}));
    }}
    """)

    jit_cc_src = []
    jit_headers = ""
    initer = []
    # optional "/* doc */" and "// @pybind(...)" prefix before a constructor
    pybind_reg = '(/\\*(.*?)\\*/\\s*)?(//\\s*@pybind\\(([^\\n]*)\\)\\s*)?'
    pybind_attrs_reg = pybind_reg + '(//\\s*@attrs\\(([^\\n]*)\\)\\s*)?'
    for header in op_headers:
        # xxx_xxx_op
        name = os.path.basename(header)
        name = os.path.splitext(name)[0]
        # xxx_xxx
        assert name.endswith("_op")
        func_name = name[:-3]
        # XxxXxxOp
        name2 = map(lambda s:s[:1].upper() + s[1:], name.split('_'))
        name2 = "".join(name2)
        with open(os.path.join(jittor_path, header), encoding='utf8') as f:
            src = f.read()
        # XxxXxxOp(args) — every constructor overload with its annotations
        res = re.findall(pybind_attrs_reg + '('+name2+"\\([^\\n]*\\))", src, re.S)
        assert len(res) >= 1, "Wrong op args in " + header
        # register op (makers + Var member offsets) with the runtime
        cc_name = os.path.join(jittor_path, header[:-2] + ".cc")
        constructors = []
        for i in range(len(res)):
            # overloads are disambiguated by trailing underscores
            name = 'make_'+func_name+'_'*i
            constructors.append(f"{{ &typeid(&{name}), (void*)&{name} }}")
        constructors = ",".join(constructors)
        # collect `Var* x, y;` style member declarations from the header
        var_member_reg = r"\n\s*Var\b(.*);"
        var_member_match = re.findall(var_member_reg, src)
        var_member_match = " ".join(var_member_match)
        for c in "*,": var_member_match = var_member_match.replace(c, " ")
        var_member = var_member_match.split()
        LOG.vv("var_member_match "+var_member_match)
        LOG.vv("var_member "+str(var_member))
        var_member_src = [ f"VAR_MEMBER_NAME_AND_OFFSET({name}, {name2})" for name in var_member ]
        var_member_src = ",".join(var_member_src)
        initer.append(f'\n    op_registe({{ "{func_name}", R"({cc_name})", extra_flags, {{{constructors}}}, {{{var_member_src}}} }});')
        for hid, h_def in enumerate(res):
            h_def = list(h_def)
            # // @attrs(...)
            attrs = {}
            if h_def[4] != "":
                attrs = pyjt_compiler.parse_attrs(h_def[5])
            del h_def[4:6]
            # /* doc_string */
            # // @pybind(bind_name)
            # XxxXxxOp(args_def)
            doc_string = h_def[1].strip()
            h_def = h_def[2:]
            # strip "XxxXxxOp(" and the trailing ")" to get the raw args
            args_def = h_def[2][len(name2)+1:-1]
            bind_name = h_def[1]
            if bind_name == "":
                bind_name = func_name
            if args_def=="":
                args = []
            else:
                args = list(map(lambda s: s.split()[-1].split('=')[0], args_def.split(',')))
            # py_args: "arg"_a=default
            py_args = []
            new_args_def = []
            new_args = []
            # source of convert VarHolder* to Var*
            vh2v_src = []
            more_src = []
            for arg, arg_def in zip(args, args_def.split(',')):
                py_arg = f'"{arg}"_a'
                if '=' in arg_def:
                    py_arg += "=" + arg_def.split('=')[-1]
                    arg_def = arg_def.split('=')[0]
                py_args.append(py_arg)
                # everything before the parameter name is its type
                arg_type = arg_def[:-(len(arg)+1)].strip()
                if arg_type == "Var*":
                    new_args_def.append("VarHolder* " + arg)
                    vh2v_src.append(arg + "->var")
                    new_args.append(arg + "->var")
                elif arg_type.startswith("vector<Var*>"):
                    new_args_def.append(
                        arg_type.replace("Var", "VarHolder")+' '+arg)
                    new_args.append(arg)
                    more_src.append(f"_op->add_inputs({arg});")
                else:
                    new_args_def.append(arg_def)
                    new_args.append(arg)
            vh2v_src = "_op->set_inputs({" + ", ".join(vh2v_src) + "});" + \
                "".join(more_src)
            LOG.vvvv(f"Find op: {name2} args: {new_args}")
            # generated code includes the op header relative to src/
            if header.startswith("src/"):
                jit_headers += f"#include \"{header[4:]}\"\n"
            else:
                jit_headers += f"#include \"{header}\"\n"
            add_src(
                func_name+'_'*hid,
                new_args_def,
                name2,
                new_args,
                vh2v_src,
                bind_name,
                py_args,
                jit_cc_src,
                doc_string,
                attrs
            )
            if func_name in ["binary", "unary", "reduce"]:
                # generate one alias per supported functor of the meta-op
                with open(os.path.join(jittor_path, f"src/ops/{func_name}_op.cc"), encoding="utf-8") as f:
                    src = f.read()
                src = src.split(f"unordered_set<string> {func_name}_ops = ""{")[1].split("};")[0]
                res2 = re.findall(pybind_reg + "\"([a-z_A-Z0-9]*)\"", src, re.S)
                # remove /* doc_string */ pattern
                res2 = [ (_[3], _[4]) for _ in res2 ]
                LOG.vvvv(f"All supported {func_name} ops: {res2}")
                # remove op args (the functor-name parameter is implied)
                if func_name == "reduce":
                    args_def = new_args_def[:1] + new_args_def[2:]
                    py_args_s = py_args[:1] + py_args[2:]
                else:
                    args_def = new_args_def[:-1]
                    py_args_s = py_args[:-1]
                # find the last type id(float64)
                # add "_" suffix for all function
                if func_name == "unary":
                    last_tid = res2.index(("","float64"))
                # for each functor
                for tid, (bind_name, func_name2) in enumerate(res2):
                    # add _ for types
                    if func_name == "unary" and tid <= last_tid:
                        func_name3 = func_name2 + "_"
                    elif func_name == "reduce":
                        func_name4 = func_name2
                        func_name2 = "reduce_" + func_name2
                        func_name3 = func_name2
                    else:
                        func_name3 = func_name2
                    if len(bind_name) == 0:
                        bind_name = func_name2
                    if func_name == "reduce":
                        args = new_args[:1] + [f'ns_{func_name4}'] + new_args[2:]
                    else:
                        args = new_args[:-1] + [f'ns_{func_name2}']
                    add_src(
                        func_name3+'_'*hid,
                        args_def,
                        name2,
                        args,
                        vh2v_src,
                        bind_name,
                        py_args_s,
                        jit_cc_src,
                        doc_string,
                        attrs
                    )

    # assemble the full generated translation unit
    jit_src = f"""
    #pragma once
    #include "pyjt/py_obj_holder.h"
    #include "var.h"
    #include "var_holder.h"
    #include "ops/op_register.h"
    {jit_headers}

    namespace jittor {{
    // fix make_array(py::array) undefine reference
    #pragma GCC visibility push(default)
    #define JIT_NAMESPACE {export+"_maker" if export else "jit_op_maker"}
    // @pyjt(ops)
    // @attrs(submodule{",core_name="+export if export else ""})
    namespace JIT_NAMESPACE {{
    {"".join(jit_cc_src)}

    void initer() {{
        string extra_flags = R"({extra_flags})";
        {"".join(initer)}
    }}
    int caller = (initer(), 0);

    }} // JIT_NAMESPACE
    }} // jittor
    {f'''
    namespace jittor {{
    extern void pyjt_def_{export}(PyObject*);
    }}

    static void init_module(PyModuleDef* mdef, PyObject* m) {{
        mdef->m_doc = "User defined custom ops";
        jittor::pyjt_def_{export}(m);
    }}
    PYJF_MODULE_INIT({export});

    ''' if export else ""}
    """
    return jit_src
|
||||
|
||||
def compile_custom_op(header, source, op_name, warp=True):
    """Compile a single custom op from in-memory source text.

    header: code of op header, not a path
    source: code of op source, not a path
    op_name: name used for the generated header/source files; if the
        type name of the op is XxxXxxOp, op_name should be xxx_xxx
    warp: if true, wrap header and source in the standard boilerplate
        (pragma once, op/var includes, jittor namespace)
        NOTE(review): parameter spelled "warp" in the original API; kept
        for backward compatibility.
    Returns the compiled op callable (attribute *op_name* of the
    generated ops module).
    """
    if warp:
        header = f"""
        #pragma once
        #include "op.h"
        #include "var.h"
        namespace jittor {{
        {header}
        }}
        """
        source = f"""
        #include "{op_name}_op.h"
        namespace jittor {{
        {source}
        }}
        """
    # write the snippets into the cache as a proper header/source pair
    cops_dir = os.path.join(cache_path, "custom_ops")
    make_cache_dir(cops_dir)
    hname = os.path.join(cops_dir, op_name+"_op.h")
    ccname = os.path.join(cops_dir, op_name+"_op.cc")
    with open(hname, 'w') as f:
        f.write(header)
    with open(ccname, 'w') as f:
        f.write(source)
    m = compile_custom_ops([hname, ccname])
    return getattr(m, op_name)
|
||||
|
||||
def compile_custom_ops(filenames, extra_flags=""):
    """Compile custom ops.

    filenames: paths of op source files; they must come in pairs of
        xxx_xxx_op.cc and xxx_xxx_op.h, and the type name of the op
        must be XxxXxxOp.
    extra_flags: extra compile flags
    return: the `ops` attribute of the generated, imported module
    """
    srcs = {}
    headers = {}
    builds = []
    includes = []
    for name in filenames:
        name = os.path.realpath(name)
        # every translation unit (.cc/.cpp/.cu) is built
        if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
            builds.append(name)
        bname = os.path.basename(name)
        bname = os.path.splitext(bname)[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    # each op needs exactly one .cc and one .h
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if len(gen_name) > 100:
        # keep the generated module name filesystem-friendly
        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))

    # deduplicate header dirs and turn them into -I flags
    includes = set(includes)
    includes = "".join(map(lambda x: f" -I'{x}' ", includes))
    LOG.vvvv(f"Include flags:{includes}")

    op_extra_flags = includes + extra_flags

    # generate the maker/binding source, then pyjt-compile it
    gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
    gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
    with open(gen_head_fname, "w") as f:
        f.write(gen_src)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
    # gen src initialize first
    builds.insert(0, gen_src_fname)
    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib)

    # add python path and import the freshly built extension module
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
        exec(f"import {gen_name}")
    return (locals()[gen_name]).ops
|
||||
|
||||
|
||||
def get_full_path_of_executable(name):
    """Resolve *name* to the fully-resolved path of an executable file.

    Symlinks are followed; when the resolved path is not an executable
    file, fall back to locating the command via find_exe and resolving
    the result recursively.
    """
    resolved = os.path.abspath(name)
    while os.path.islink(resolved):
        resolved = os.path.realpath(resolved)
    is_executable_file = os.path.isfile(resolved) and os.access(resolved, os.X_OK)
    if is_executable_file:
        return resolved
    # not directly usable: look the command up on PATH and try again
    return get_full_path_of_executable(find_exe(name))
|
||||
|
||||
def compile_extern():
    """Compile the LLVM pass plugins shipped under extern/llvm (clang
    builds only) and, on success, append the -Xclang -load flags to the
    module-global kernel_opt_flags."""
    # compile llvm passes — only meaningful when compiling with clang
    if cc_type != "clang":
        return
    global kernel_opt_flags
    cache_path_llvm = os.path.join(cache_path, "llvm")
    jittor_path_llvm = os.path.join(jittor_path, "extern", "llvm")
    clang_dir = os.path.dirname(get_full_path_of_executable(cc_path))
    assert clang_dir.endswith("bin") and "llvm" in clang_dir, f"Wrong clang_dir: {clang_dir}"
    llvm_include = os.path.abspath(os.path.join(clang_dir, "..", "include"))
    assert os.path.isdir(llvm_include), "LLVM include path not found"
    make_cache_dir(cache_path_llvm)
    files = os.listdir(jittor_path_llvm)
    # test_pass.cc is used for test link problem of llvm pass plugin
    test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
    with open(test_pass_path, 'w') as f:
        f.write("int main() {return 0;}")

    # -fno-rtti fix link error

    # -Wl,-znodelete fix segfault
    # https://github.com/sampsyo/llvm-pass-skeleton/issues/7#issuecomment-401834287

    # -D_GLIBCXX_USE_CXX11_ABI=0 fix undefined symbol: createPrinterPass
    # https://stackoverflow.com/questions/37366291/undefined-symbol-for-self-built-llvm-opt

    # try different flags
    try_flags = [
        " -Wl,-znodelete -D_GLIBCXX_USE_CXX11_ABI=0 ",
        " -Wl,-znodelete ",
    ]
    # index of the flag combination that linked successfully (-1 = unknown)
    found_flags_id = -1
    for fname in files:
        for i, flag in enumerate(try_flags):
            if found_flags_id != -1 and found_flags_id != i:
                # a working combination is already known; skip the others
                continue
            so_name = os.path.join(cache_path_llvm, os.path.splitext(fname)[0]+f".{i}.so")
            compile(
                cc_path,
                f"{cc_flags} {opt_flags} {flag} -I'{llvm_include}'",
                [os.path.join(jittor_path_llvm, fname)],
                so_name
            )
            # if not found available flags, we test it.
            if found_flags_id == -1:
                try:
                    s = run_cmd(
                        f"{cc_path} {cc_flags} -Xclang -load -Xclang '{so_name}' {test_pass_path}",
                        cache_path_llvm,
                        print_error=False
                    )
                except Exception as e:
                    LOG.v(f"Try flag {flag} failed: {e}")
                    continue
                found_flags_id = i
            kernel_opt_flags += f" -Xclang -load -Xclang '{so_name}' "
            break
        else:
            # no flag combination linked for this plugin: give up entirely
            LOG.w("Clang is used, but LLVM pass plugin is unable to link.")
            break
    LOG.vv(f"Compile extern llvm passes: {str(files)}")
|
||||
|
||||
def check_cuda():
    """Detect the CUDA toolkit from nvcc_path and, when found, set the
    cuda_* module globals, extend cc_flags/core_link_flags, dlopen
    libcudart, and flip has_cuda to 1. No-op when nvcc is absent."""
    if nvcc_path == "":
        return
    global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include
    cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path))
    assert cuda_dir.endswith("bin") and "cuda" in cuda_dir, f"Wrong cuda_dir: {cuda_dir}"
    # toolkit layout: <root>/bin/nvcc, <root>/include, <root>/lib64
    cuda_include = os.path.abspath(os.path.join(cuda_dir, "..", "include"))
    cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib64"))
    cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc")
    cc_flags += f" -DHAS_CUDA -I'{cuda_include}' -I'{cuda_include2}' "
    core_link_flags += f" -lcudart -L'{cuda_lib}' "
    # ctypes.CDLL(cuda_lib+"/libcudart.so", import_flags)
    ctypes.CDLL(cuda_lib+"/libcudart.so", dlopen_flags)
    has_cuda = 1
|
||||
|
||||
def check_cache_compile():
    """Build (or rebuild) the jit_utils_core helper extension and make
    sure jit_utils.cc is loaded; exits the process when the helper was
    rebuilt while an old copy is already loaded."""
    files = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
    ]
    global jit_utils_core_files
    jit_utils_core_files = files
    recompile = compile(cc_path, cc_flags+f" {opt_flags} ", files, 'jit_utils_core'+extension_suffix, True)
    if recompile and jit_utils.cc:
        # a stale jit_utils_core is already loaded in this interpreter;
        # the fresh build can only take effect after a restart
        LOG.e("jit_utils updated, please restart jittor.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        compile(cc_path, cc_flags+f" {opt_flags} ", files, 'jit_utils_core'+extension_suffix, True)
|
||||
|
||||
def env_or_try_find(name, bname):
    """Return the value of env var *name* when set; otherwise try to
    locate the executable *bname* ("" when that also fails)."""
    value = os.environ.get(name)
    if value is not None:
        return value
    return try_find_exe(bname)
|
||||
|
||||
def try_find_exe(*args):
    """Like find_exe(*args), but return "" instead of raising when the
    executable cannot be located.

    Catches ``Exception`` rather than using a bare ``except:`` so that
    SystemExit and KeyboardInterrupt still propagate.
    """
    try:
        return find_exe(*args)
    except Exception:
        LOG.v(f"{args[0]} not found.")
        return ""
|
||||
|
||||
def check_pybt(gdb_path, python_path):
    """Return True when the given gdb supports the `py-bt` command for
    this python; False when either path is empty or support is absent."""
    if gdb_path == '' or python_path == '':
        return False
    output = sp.getoutput(f"{gdb_path} --batch {python_path} -ex 'help py-bt'")
    supported = 'python frame' in output
    if supported:
        LOG.v("py-bt found in gdb.")
    return supported
|
||||
|
||||
def check_debug_flags():
    """Set the module-wide debug switch from the ``debug`` env var and,
    when debugging is on, add debug/memcheck compiler flags."""
    global is_debug, cc_flags
    is_debug = 1 if os.environ.get("debug") == "1" else 0
    if is_debug:
        cc_flags += " -g -DNODE_MEMCHECK "
|
||||
|
||||
cc_flags = " " + os.environ.get("cc_flags", "")
|
||||
# os.RTLD_NOW | os.RTLD_GLOBAL cause segfault when import torch first
|
||||
import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||
# if cc_type=="icc":
|
||||
# # weird link problem, icc omp library may conflict and cause segfault
|
||||
# import_flags = os.RTLD_NOW | os.RTLD_GLOBAL
|
||||
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
jit_utils.try_import_jit_utils_core()
|
||||
jittor_path = find_jittor_path()
|
||||
check_debug_flags()
|
||||
|
||||
sys.path.append(cache_path)
|
||||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
jit_utils.try_import_jit_utils_core()
|
||||
|
||||
python_path = sys.executable
|
||||
py3_config_path = sys.executable+"-config"
|
||||
assert os.path.isfile(python_path)
|
||||
if not os.path.isfile(py3_config_path) :
|
||||
py3_config_path = sys.executable + '3-config'
|
||||
|
||||
assert os.path.isfile(py3_config_path)
|
||||
nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc')
|
||||
gdb_path = try_find_exe('gdb')
|
||||
addr2line_path = try_find_exe('addr2line')
|
||||
has_pybt = check_pybt(gdb_path, python_path)
|
||||
|
||||
cc_flags += " -Wall -Werror -Wno-unknown-pragmas -std=c++14 -fPIC -march=native "
|
||||
link_flags = " -lstdc++ -ldl -shared "
|
||||
core_link_flags = ""
|
||||
opt_flags = ""
|
||||
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags + " -fopenmp "
|
||||
|
||||
if ' -O' not in cc_flags:
|
||||
opt_flags += " -O2 "
|
||||
kernel_opt_flags += " -Ofast "
|
||||
lto_flags = ""
|
||||
if os.environ.get("enable_lto") == "1":
|
||||
if cc_type == "icc":
|
||||
lto_flags = " -flto -ipo -ipo-c "
|
||||
elif cc_type == "g++":
|
||||
lto_flags = " -flto -fuse-linker-plugin "
|
||||
else:
|
||||
lto_flags = " -flto "
|
||||
|
||||
pybind_include = run_cmd(python_path+" -m pybind11 --includes")
|
||||
LOG.i(f"pybind_include: {pybind_include}")
|
||||
extension_suffix = run_cmd(py3_config_path+" --extension-suffix")
|
||||
LOG.i(f"extension_suffix: {extension_suffix}")
|
||||
|
||||
make_cache_dir(cache_path)
|
||||
make_cache_dir(os.path.join(cache_path, "jit"))
|
||||
make_cache_dir(os.path.join(cache_path, "obj_files"))
|
||||
make_cache_dir(os.path.join(cache_path, "gen"))
|
||||
|
||||
# build cache_compile
|
||||
cc_flags += pybind_include
|
||||
cc_flags += f" -I{jittor_path}/src "
|
||||
check_cache_compile()
|
||||
LOG.v(f"Get cache_compile: {jit_utils.cc}")
|
||||
|
||||
# check cuda
|
||||
has_cuda = 0
|
||||
check_cuda()
|
||||
nvcc_flags = os.environ.get("nvcc_flags", "")
|
||||
if has_cuda:
|
||||
nvcc_flags += cc_flags + link_flags
|
||||
def convert_nvcc_flags(nvcc_flags):
    """Translate host-compiler (g++/clang/icc) flags into flags nvcc accepts.

    Flags nvcc rejects outright are removed; flags it only understands when
    forwarded to the host compiler are prefixed with ``-Xcompiler``. CUDA
    specific options (fast math, include path, optional ``-G`` debug info)
    are appended afterwards.
    """
    # nvcc don't support -Wall option; the other entries are either dropped
    # or rewritten to be forwarded via -Xcompiler.
    rewrites = (
        ("-Wall", ""),
        ("-Wno-unknown-pragmas", ""),
        ("-fopenmp", ""),
        ("-march", "-Xcompiler -march"),
        ("-Werror", ""),
        ("-fPIC", "-Xcompiler -fPIC"),
    )
    for old, new in rewrites:
        nvcc_flags = nvcc_flags.replace(old, new)
    nvcc_flags += f" -x cu --cudart=shared -ccbin='{cc_path}' --use_fast_math "
    # nvcc warning is noise
    nvcc_flags += " -w "
    cuda_inc = os.path.join(jittor_path, 'extern/cuda/inc')
    nvcc_flags += f" -I'{cuda_inc}' "
    if os.environ.get("cuda_debug", "0") == "1":
        nvcc_flags += " -G "
    return nvcc_flags
|
||||
nvcc_flags = convert_nvcc_flags(nvcc_flags)

# build core
# Generate JIT flag/test headers and the op-maker header from all op.h files.
gen_jit_flags()
gen_jit_tests()
op_headers = run_cmd('find -L src/ops/ | grep "op.h$"', jittor_path).splitlines()
jit_src = gen_jit_op_maker(op_headers)
LOG.vvvv(jit_src)
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
    f.write(jit_src)
cc_flags += f' -I{cache_path} '
# gen pyjt
pyjt_compiler.compile(cache_path, jittor_path)

# initialize order:
# 1. registers
# 2. generate source
# 3. op_utils
# 4. other
files2 = run_cmd(f'find "{os.path.join(cache_path, "gen")}" | grep "cc$"').splitlines()
files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
# Sources whose static initializers must run first / last.
at_beginning = [
    "src/ops/op_utils.cc",
    "src/event_queue.cc",
    "src/mem/allocator/sfrl_allocator.cc",
    "src/mem/allocator.cc",
]
at_last = [
    "src/profiler/profiler.cc",
    "src/executor.cc",
    "src/fetcher.cc",
]
# Move the "at_beginning" files (those that exist) to the front, in order.
for i in range(len(at_beginning)):
    if at_beginning[i] not in files4:
        continue
    files4.remove(at_beginning[i])
    files4.insert(i, at_beginning[i])
# Move the "at_last" files (those that exist) to the back, in order.
for v in at_last:
    if v not in files4:
        continue
    files4.remove(v)
    files4.append(v)
# register_* sources go before everything else (rule 1 above)
registers = [ name for name in files4 if "register" in name ]
for name in registers: files4.remove(name)
files = registers + files2 + files4
# jit_utils core files are compiled into a separate module; exclude them here
for file in jit_utils_core_files:
    files.remove(file)
LOG.vv("compile order:", files)

# manual Link omp using flags(os.RTLD_NOW | os.RTLD_GLOBAL)
# if cc_type=="icc":
#   os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Pre-load the OpenMP runtime matching the compiler so JIT-compiled kernels
# can resolve its symbols globally.
libname = {"clang":"omp", "icc":"iomp5", "g++":"gomp"}[cc_type]
libname = ctypes.util.find_library(libname)
assert libname is not None, "openmp library not found"
ctypes.CDLL(libname, os.RTLD_NOW | os.RTLD_GLOBAL)

# If a version file exists, fetch a matching prebuilt object and link it in.
version_file = os.path.join(jittor_path, "version")
if os.path.isfile(version_file):
    with open(version_file, 'r') as f:
        version = f.read().strip()
    key = f"{version}-{cc_type}-{'cuda' if has_cuda else 'cpu'}.o"
    # TODO: open the website
    extra_obj = os.path.join(cache_path, key)
    url = os.path.join("https://cg.cs.tsinghua.edu.cn/jittor/assets/build/"+key)
    jit_utils.download(url, extra_obj)
    files.append(extra_obj)

# Compile everything into the jittor_core extension module.
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)

# TODO: move to compile_extern.py
compile_extern()

# Import the freshly built core and publish the effective configuration
# through its global flags object.
with jit_utils.import_scope(import_flags):
    import jittor_core as core
flags = core.flags()
flags.cc_path = cc_path
flags.cc_type = cc_type
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags
flags.nvcc_path = nvcc_path
flags.nvcc_flags = nvcc_flags
flags.python_path = python_path
flags.cache_path = cache_path
flags.jittor_path = jittor_path
flags.gdb_path = gdb_path
flags.addr2line_path = addr2line_path
flags.has_pybt = has_pybt
|
|
@ -0,0 +1,204 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
def argmax_pool(x, size, stride, padding=0):
    """Max-pool a 4-D input over its two trailing spatial dimensions.

    Implemented as a jt.code op with hand-written C++ for both the forward
    pass and the gradient. The backward pass routes each output gradient to
    the first input position whose value equals the pooled maximum.

    Args:
        x: 4-D input Var (N, C, H, W layout implied by the indexing below).
        size: pooling window side length.
        stride: pooling stride.
        padding: implicit zero-padding on each spatial side (default 0).

    Returns:
        Var with pooled spatial dims of size (dim + 2*padding - size)//stride + 1.
    """
    y_shape = list(x.shape)
    # standard pooling output size formula
    y_shape[2]=(x.shape[2]+padding*2-size)//stride+1
    y_shape[3]=(x.shape[3]+padding*2-size)//stride+1

    # NOTE(review): the seed read @in0(i,j,kx,ky) below is not bounds-checked;
    # with padding > 0, kx/ky can fall outside the input — confirm callers
    # only use padding=0 or that jt.code tolerates this.
    y = jt.code(y_shape, x.dtype, [x],
        cpu_src=f'''
            for (int i=0; i<outshape0; i++)
            for (int j=0; j<outshape1; j++)
            for (int k=0; k<outshape2; k++)
            for (int l=0; l<outshape3; l++) {{
                int kx=k*{stride}+{size}/2-{padding};
                int ky=l*{stride}+{size}/2-{padding};
                @out(i,j,k,l) = @in0(i,j,kx,ky);
                for (int p=kx-{size}/2;p<=kx+{size}/2;p++)
                for (int q=ky-{size}/2;q<=ky+{size}/2;q++)
                if (p>=0 && q>=0 && p<in0shape2 && q<in0shape3)
                    if (@out(i,j,k,l) < @in0(i,j,p,q))
                        @out(i,j,k,l) = @in0(i,j,p,q);
            }}
        ''',
        cpu_grad_src = [f'''
            for (int i=0; i<outshape0; i++)
            for (int j=0; j<outshape1; j++)
            for (int k=0; k<outshape2; k++)
            for (int l=0; l<outshape3; l++) @out(i,j,k,l) = 0;

            for (int i=0; i<poutshape0; i++)
            for (int j=0; j<poutshape1; j++)
            for (int k=0; k<poutshape2; k++)
            for (int l=0; l<poutshape3; l++) {{
                int kx=k*{stride}+{size}/2-{padding};
                int ky=l*{stride}+{size}/2-{padding};
                int bo=1;
                for (int p=kx-{size}/2;p<=kx+{size}/2 && bo;p++)
                for (int q=ky-{size}/2;q<=ky+{size}/2 && bo;q++)
                if (p>=0 && q>=0 && p<in0shape2 && q<in0shape3)
                    if (@pout(i,j,k,l) == @in0(i,j,p,q)) {{
                        @out(i,j,p,q) += @dout(i,j,k,l);
                        bo=0;
                    }}
            }}
        '''])
    return y
|
||||
|
||||
def concat(arr, dim):
    """Concatenate a sequence of Vars along dimension `dim`.

    Each element is reindexed into the full output shape (shifted along
    `dim`) and the shifted pieces are summed together.

    Args:
        arr: non-empty sequence of Vars with identical shapes except along `dim`.
        dim: dimension to concatenate on.

    Returns:
        A Var holding the concatenation.

    Raises:
        ValueError: if `arr` is empty.
    """
    # TODO: low performance when concat lots of vars
    if not arr:
        # Previously this path crashed with a NameError (the index template
        # below read a loop variable that was never bound).
        raise ValueError("concat expects a non-empty sequence of vars")
    total_dim = 0
    for a in arr:
        total_dim += a.shape[dim]
    cdim = 0
    s = None
    # Index expressions i0, i1, ...; rank taken from the first element
    # (all elements must share the same rank).
    indexes = [ f"i{i}" for i in range(len(arr[0].shape)) ]
    for a in arr:
        shape = list(a.shape)
        shape[dim] = total_dim
        # shift this piece into its slot along `dim`
        indexes[dim] = f"i{dim}-{cdim}"
        b = a.reindex(shape, indexes)
        # ugly fix for preventing large fused op
        if len(arr)>=10:
            b.stop_fuse()
        if s is None:
            s = b
        else:
            s += b
        cdim += a.shape[dim]
    return s
|
||||
|
||||
def check(bc):
    """Validate broadcast-compatible shapes and return the broadcast shape.

    Args:
        bc: sequence of equal-length shape rows, e.g. [[1, 3], [2, 3]].

    Returns:
        np.ndarray: column-wise maximum — the broadcast result shape.

    Raises:
        ValueError: if some entry is neither 1 nor its column's maximum,
            i.e. the shapes cannot be broadcast together.
    """
    bc = np.array(bc)
    # A column broadcasts iff every entry is either 1 or the column maximum.
    if ((bc != 1) * (bc != bc.max(0))).sum() > 0:
        # ValueError is a subclass of Exception, so existing callers that
        # catch Exception keep working; the message now names the shapes.
        raise ValueError(f"Shape not match: {bc.tolist()}")
    return bc.max(0)
|
||||
|
||||
def slice_var_index(x, slices):
    """Translate a numpy-style slice specification into reindex() arguments.

    Returns a tuple (out_shape, out_index, 0, [], extras) suitable for
    x.reindex(*result); advanced (Var/array/list) indices are broadcast to a
    common shape and passed through `extras` as @e0, @e1, ... inputs.
    """
    if not isinstance(slices, tuple):
        slices = (slices,)
    # A single boolean Var selects by mask: delegate to where().
    if isinstance(slices[0], jt.Var):
        if len(slices) == 1 and slices[0].dtype == "bool":
            return (slices[0].where(),)
    # Collect the shapes of all advanced indices and the maximum rank.
    bc = []
    ml = -1
    for idx, s in enumerate(slices):
        if isinstance(s, jt.Var):
            shape = s.shape
        elif isinstance(s, np.ndarray):
            shape = list(s.shape)
        elif isinstance(s, list):
            shape = list(np.array(s).shape)
        else:
            # ints and plain slices carry no broadcast shape
            continue
        if len(shape) >= ml:
            ml = len(shape)
        bc.append(shape)
    # Left-pad lower-rank shapes with 1s so all rows have rank ml.
    for idx, shape in enumerate(bc):
        if len(shape) < ml:
            shape = (ml - len(shape)) * [1] + shape
            bc[idx] = shape
    if len(bc) >= 1:
        # check() raises if the advanced indices cannot broadcast together.
        bc_shape = check(bc)
        ss = []
        for idx, s in enumerate(slices):
            if isinstance(s, np.ndarray) or isinstance(s, list):
                ss.append(jt.array(s).broadcast(bc_shape.tolist()))
            elif isinstance(s, jt.Var):
                ss.append(s.broadcast(bc_shape.tolist()))
            else:
                ss.append(s)
        slices = ss
    out_shape = []
    out_index = []
    shape = x.shape
    cnt_list = 0          # number of advanced (Var) indices seen so far
    extras_idx = []       # output index names covering the broadcast dims
    extras = []           # the Var indices, fed to reindex as extra inputs
    for i in range(len(shape)):
        # dimensions beyond the given slices are taken whole
        if i>=len(slices):
            s = slice(None)
        else:
            s = slices[i]
        sp = shape[i]
        j = len(out_shape)
        if isinstance(s, int):
            # single index: dimension is dropped from the output
            if s<0: s += sp
            out_index.append(str(s))
        elif isinstance(s, slice):
            if s == slice(None):
                out_shape.append(sp)
                out_index.append(f"i{j}")
                continue
            start = 0 if s.start is None else s.start
            stop = sp if s.stop is None else s.stop
            step = 1 if s.step is None else s.step
            if start<0: start += sp
            if stop<0: stop += sp
            # length of range(start, stop, step), clamped at >= 1 element form
            out_shape.append(1+int(max(0, (stop-start-1)//step)))
            out_index.append(f"{start}+i{j}*{step}")
        elif isinstance(s, jt.Var):
            # All advanced indices share one group of broadcast output dims,
            # inserted at the position of the first advanced index.
            if cnt_list == 0:
                for idx in range(len(bc_shape)):
                    extras_idx.append(f"i{len(out_shape) + idx}")
                out_shape += bc_shape.tolist()
            out_index.append(f"@e{cnt_list}("+ ",".join(extras_idx) + ")")
            cnt_list += 1
            extras.append(s)
        else:
            raise Exception(f"Not support slice {s}")
    # scalar result: keep a 1-element shape
    if len(out_shape)==0:
        out_shape = [1]
    # Stop fuse both input and output, prevent recompile
    x.stop_fuse()
    return (out_shape, out_index, 0, [], extras)
|
||||
|
||||
def slice_var(x, slices):
    """Slice `x` with a numpy-style specification via reindex().

    Fusion is stopped on both the source and the result to avoid
    recompilation of large fused ops.
    """
    args = slice_var_index(x, slices)
    x.stop_fuse()
    sliced = x.reindex(*args)
    return sliced.stop_fuse()
|
||||
|
||||
def setitem(x, slices, value):
    """In-place assignment x[slices] = value, implemented with reindex_reduce.

    Builds a mask of the selected positions and scatters `value` into them,
    then blends with the original via ternary(); `x` is reassigned and
    returned so jt.Var.__setitem__ can use this directly.
    """
    reindex_args = slice_var_index(x, slices)
    # reindex_reduce wants the destination (full) shape plus the same index
    # expressions / extras used for the forward slice.
    reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
    xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
    value = jt.broadcast(value, xslice)
    one = jt.broadcast(1, xslice)
    # NOTE(review): reindex_args is rebuilt here but never read afterwards —
    # looks like dead code (mask/data below use reindex_reduce_args); confirm
    # before removing.
    if not isinstance(reindex_args[0][0], jt.Var):
        reindex_args = (x.shape,) + reindex_args[1:]
    # mask > 0 exactly at the positions addressed by `slices`
    mask = one.reindex_reduce("add", *reindex_reduce_args)
    data = value.reindex_reduce("add", *reindex_reduce_args)
    # Stop fuse both input and output, prevent recompile
    out = mask.ternary(data, x).stop_fuse()
    x.assign(out)
    return x

# Install numpy-style indexing on jittor Vars.
jt.Var.__getitem__ = jt.Var.slice_var = slice_var
jt.Var.__setitem__ = setitem
|
||||
|
||||
def adam(model, loss, lr=3e-4, betas=[0.9, 0.999], eps=1e-8):
    """One Adam update step for all vars under scope `model` w.r.t. `loss`.

    Optimizer state (step counter, first/second moments m and v) lives in a
    dedicated "<model>_adam" var scope so it persists across calls.

    NOTE(review): `betas` is a mutable default argument; it is never mutated
    here so this is benign, but a tuple default would be safer.
    """
    ps = jt.find_vars(model)
    gs = jt.grad(loss, ps)
    with jt.var_scope('_'.join([model, 'adam']), unique=True):
        adam_step = jt.make_var([1], init=jt.zeros)
        adam_step += 1
        for p,g in zip(ps,gs):
            # exponential moving averages of the gradient and its square
            m = jt.make_var(p.shape, init=jt.zeros)
            v = jt.make_var(p.shape, init=jt.zeros)

            m.assign(betas[0] * m + (1-betas[0]) * g)
            v.assign(betas[1] * v + (1-betas[1]) * g * g)
            # bias-corrected step size (standard Adam formulation)
            step_size = lr * jt.sqrt(1-betas[1]**adam_step) / (1-betas[0] ** adam_step)
            p -= m * step_size / (jt.sqrt(v) + eps)
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
import numpy as np
|
||||
from urllib import request
|
||||
import gzip
|
||||
import pickle
|
||||
import os
|
||||
from jittor.dataset.utils import get_random_list, get_order_list, collate_batch
|
||||
import pathlib
|
||||
from PIL import Image
|
||||
|
||||
dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
|
||||
|
||||
class Dataset(object):
    '''
    Base class for reading data in batches.

    Subclasses implement __getitem__ and must call set_attrs(total_len=...).

    Example:
        class YourDataset(Dataset):
            def __init__(self):
                super().__init__()
                self.set_attrs(total_len=1024)

            def __getitem__(self, k):
                return k, k*k

        dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
        for x, y in dataset:
            ......
    '''
    def __init__(self):
        super().__init__()
        self.batch_size = 16    # samples per yielded batch
        self.total_len = None   # number of samples; must be set via set_attrs
        self.shuffle = False    # reshuffle sample order each epoch
        self.drop_last = False  # drop the trailing partial batch if True

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        '''Return the number of batches one epoch yields.'''
        assert self.total_len >= 0
        assert self.batch_size > 0
        if self.drop_last:
            # __iter__ discards the trailing partial batch in this case, so
            # counting it (as the previous ceil-division did) overstated the
            # epoch length.
            return self.total_len // self.batch_size
        return (self.total_len-1) // self.batch_size + 1

    def set_attrs(self, **kw):
        '''set attributes of dataset, equivalent to setattr

        Attrs:
            batch_size(int): batch size, default 16.
            total_len(int): total length.
            shuffle(bool): shuffle at each epoch, default False.
            drop_last(bool): if true, the last batch, which may be smaller
                than batch_size, is dropped; default False.
        '''
        for k,v in kw.items():
            assert hasattr(self, k), k
            setattr(self, k, v)
        return self

    def collate_batch(self, batch):
        # Hook point: subclasses may override how samples are merged.
        return collate_batch(batch)

    def __iter__(self):
        if not self.shuffle:
            index_list = get_order_list(self.total_len)
        else:
            index_list = get_random_list(self.total_len)
        batch_data = []
        for idx in index_list:
            batch_data.append(self[int(idx)])
            if len(batch_data) == self.batch_size:
                batch_data = self.collate_batch(batch_data)
                yield batch_data
                batch_data = []

        # depend on drop_last
        if not self.drop_last and len(batch_data) > 0:
            batch_data = self.collate_batch(batch_data)
            yield batch_data
|
||||
|
||||
class ImageFolder(Dataset):
    """A image classify dataset, load image and label from directory:

        root/label1/img1.png
        root/label1/img2.png
        ...
        root/label2/img1.png
        root/label2/img2.png
        ...

    Args:
        root(string): Root directory path.
        transform: optional callable applied to each PIL image.

    Attributes:
        classes(list): List of the class names.
        class_to_idx(dict): map from class_name to class_index.
        imgs(list): List of (image_path, class_index) tuples
    """
    def __init__(self, root, transform=None):
        # import ipdb; ipdb.set_trace()
        super().__init__()
        self.root = root
        self.transform = transform
        # one class per immediate subdirectory, sorted for a stable index order
        self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()])
        self.class_to_idx = {v:k for k,v in enumerate(self.classes)}
        self.imgs = []
        image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'))

        # walk each class directory (following symlinks) in sorted order so
        # the sample order is deterministic
        for i, class_name in enumerate(self.classes):
            class_dir = os.path.join(root, class_name)
            for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)):
                for fname in sorted(fnames):
                    if os.path.splitext(fname)[-1].lower() in image_exts:
                        path = os.path.join(class_dir, fname)
                        self.imgs.append((path, i))
        print(f"Found {len(self.classes)} classes and {len(self.imgs)} images.")
        self.set_attrs(total_len=len(self.imgs))

    def __getitem__(self, k):
        # Load lazily; convert to RGB so grayscale/palette images get 3 channels.
        with open(self.imgs[k][0], 'rb') as f:
            img = Image.open(f).convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, self.imgs[k][1]
|
|
@ -0,0 +1,67 @@
|
|||
# ***************************************************************
|
||||
# Copyright(c) 2019
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
import numpy as np
|
||||
import gzip
|
||||
from PIL import Image
|
||||
# our lib jittor import
|
||||
from jittor.dataset.dataset import Dataset, dataset_root
|
||||
from jittor.dataset.utils import ensure_dir, download_url_to_local
|
||||
import jittor as jt
|
||||
import jittor.transform as trans
|
||||
|
||||
class MNIST(Dataset):
    """MNIST handwritten-digit dataset (reads the original IDX gzip files).

    Args:
        data_root: directory holding (or receiving) the four *.gz files.
        train: use the training split if True, else the test split.
        download: fetch the files from the official site if missing.
        transform: optional callable applied to each PIL image.
    """
    def __init__(self, data_root=dataset_root+"/mnist_data/", train=True ,download=True, transform=None):
        # if you want to test resnet etc you should set input_channel = 3, because the net set 3 as the input dimensions
        super().__init__()
        self.data_root = data_root
        self.is_train = train
        self.transform = transform
        if download == True:
            self.download_url()

        filesname = [
            "train-images-idx3-ubyte.gz",
            "t10k-images-idx3-ubyte.gz",
            "train-labels-idx1-ubyte.gz",
            "t10k-labels-idx1-ubyte.gz"
        ]
        self.mnist = {}
        # IDX format: 16-byte header for images, 8-byte header for labels.
        if self.is_train:
            with gzip.open(data_root + filesname[0], 'rb') as f:
                self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28, 28)
            with gzip.open(data_root + filesname[2], 'rb') as f:
                self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8)
        else:
            with gzip.open(data_root + filesname[1], 'rb') as f:
                self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28, 28)
            with gzip.open(data_root + filesname[3], 'rb') as f:
                self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8)
        assert(self.mnist["images"].shape[0] == self.mnist["labels"].shape[0])
        self.total_len = self.mnist["images"].shape[0]
        # this function must be called
        self.set_attrs(total_len = self.total_len)

    def __getitem__(self, index):
        # Convert to RGB so 3-channel networks (e.g. resnet) accept the input.
        img = Image.fromarray(self.mnist['images'][index]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return trans.to_tensor(img), self.mnist['labels'][index]

    def download_url(self):
        """Download the four MNIST archives (with md5 verification) if absent."""
        resources = [
            ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
            ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
            ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
            ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
        ]

        for url, md5 in resources:
            filename = url.rpartition('/')[2]
            download_url_to_local(url, filename, self.data_root, md5)
|
|
@ -0,0 +1,122 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
import jittor as jt
|
||||
import os
|
||||
from six.moves import urllib
|
||||
import hashlib
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from collections.abc import Sequence, Mapping
|
||||
from PIL import Image
|
||||
|
||||
def ensure_dir(dir_path):
    """Create `dir_path` (including parents) if it does not already exist.

    Uses makedirs(exist_ok=True) so concurrent callers cannot race between
    an existence check and the creation (the previous isdir+makedirs pair
    could raise FileExistsError under concurrency).
    """
    os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
def _progress():
    """Return a urlretrieve reporthook that drives a single tqdm bar."""
    # total is unknown until the first callback reports the remote file size
    pbar = tqdm(total=None)

    def bar_update(block_num, block_size, total_size):
        """ reporthook
        @block_num: the num of downloaded data block
        @block_size: the size of data block
        @total_size: the total size of remote file
        """
        if pbar.total is None and total_size:
            pbar.total = total_size
        # advance the bar to the absolute downloaded byte count
        progress_bytes = block_num * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update
|
||||
|
||||
|
||||
def download_url_to_local(url, filename, root_folder, md5):
    """Download `url` into root_folder/filename and verify its md5.

    Skips the download when a file with the expected checksum already
    exists. Raises RuntimeError if the file is missing or corrupt after
    the download attempt.
    """
    ensure_dir(root_folder)
    file_path = os.path.join(root_folder, filename)
    if check_file_exist(file_path, md5):
        print("Data file has been downloaded and verified")
    else:
        print('Downloading ' + url + ' to ' + file_path)
        # The previous try/except around this call only re-raised the caught
        # exception unchanged, so it was removed; errors propagate as-is.
        urllib.request.urlretrieve(
            url, file_path,
            reporthook=_progress()
        )
    if not check_file_exist(file_path, md5):
        raise RuntimeError("File downloads failed.")
|
||||
|
||||
|
||||
|
||||
def check_file_exist(file_path, md5):
    """Return True iff `file_path` exists and (when `md5` is given) its
    checksum matches. A None md5 skips verification."""
    if not os.path.isfile(file_path):
        return False
    return True if md5 is None else check_md5(file_path, md5)
|
||||
|
||||
|
||||
def calculate_md5(file_path, chunk_size=1024 * 1024):
    """Return the hex md5 digest of the file at `file_path`.

    The file is streamed in `chunk_size`-byte pieces so arbitrarily large
    files are hashed with constant memory.
    """
    hasher = hashlib.md5()
    with open(file_path, 'rb') as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            hasher.update(chunk)
    return hasher.hexdigest()
|
||||
|
||||
|
||||
def check_md5(file_path, md5, **kwargs):
    """Return True iff the file's md5 digest equals `md5`.

    Extra keyword arguments (e.g. chunk_size) are forwarded to
    calculate_md5().
    """
    actual = calculate_md5(file_path, **kwargs)
    return actual == md5
|
||||
|
||||
|
||||
def get_random_list(n):
    """Return a random permutation of [0, n) as a list (numpy integers)."""
    perm = np.random.permutation(range(n))
    return list(perm)
|
||||
|
||||
def get_order_list(n):
    """Return the identity ordering [0, 1, ..., n-1] as a list."""
    # list(range(n)) is the idiomatic, C-speed equivalent of the previous
    # [i for i in range(n)] comprehension.
    return list(range(n))
|
||||
|
||||
|
||||
def collate_batch(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    Dispatches on the type of the first element: Vars/ndarrays/PIL images
    are stacked, numbers become a 1-D array, strings pass through, and
    mappings/sequences are collated recursively field-by-field.
    """
    # NOTE(review): real_size is never used below; kept for compatibility.
    real_size = len(batch)
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, jt.Var):
        if elem.ndim == 1:
            temp_data = np.stack([data.numpy() for data in batch], 0)
            # NOTE(review): squeezing the last axis assumes 1-D Vars here
            # have a trailing singleton dimension — confirm with callers.
            temp_data = np.squeeze(temp_data, -1)
            return jt.array(temp_data)
        else:
            temp_data = np.stack([data.numpy() for data in batch], 0)
            return jt.array(temp_data)
    if elem_type is np.ndarray:
        temp_data = np.stack([data for data in batch], 0)
        return jt.array(temp_data)
    elif np.issubdtype(elem_type, np.integer):
        # numpy scalar integers
        return jt.array(batch)
    elif isinstance(elem, int):
        return jt.array(batch)
    elif isinstance(elem, float):
        return jt.array(batch)
    elif isinstance(elem, str):
        # strings cannot be tensorized; return the list unchanged
        return batch
    elif isinstance(elem, Mapping):
        # collate each key across the batch
        return {key: collate_batch([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, Sequence):
        # collate each position across the batch (e.g. (img, label) pairs)
        transposed = zip(*batch)
        return [collate_batch(samples) for samples in transposed]
    elif isinstance(elem, Image.Image):
        temp_data = np.stack([np.array(data) for data in batch], 0)
        return jt.array(temp_data)
    else:
        raise TypeError(f"Not support type <{elem_type.__name__}>")
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
# ***************************************************************
|
||||
# Copyright(c) 2019
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
from dataset import Dataset, dataset_root
|
||||
|
||||
class VOC(Dataset):
    """PASCAL VOC semantic-segmentation dataset.

    Reads the image/label file list for `split` from
    ImageSets/Segmentation/<split>.txt and yields (image, label) pairs
    resized to 513x513.
    """
    NUM_CLASSES = 21
    def __init__(self, data_root=dataset_root+'/voc/', split='train'):
        super().__init__()
        ''' total_len , batch_size, shuffle must be set '''
        self.data_root = data_root
        self.split = split

        self.image_root = os.path.join(data_root, 'JPEGImages')
        self.label_root = os.path.join(data_root, 'SegmentationClass')

        self.data_list_path = os.path.join(self.data_root, 'ImageSets', 'Segmentation', self.split + '.txt')
        self.image_path = []
        self.label_path = []

        with open(self.data_list_path, "r") as f:
            lines = f.read().splitlines()

        # each line names one sample: <line>.jpg image, <line>.png label mask
        for idx, line in enumerate(lines):
            _img_path = os.path.join(self.image_root, line + '.jpg')
            _label_path = os.path.join(self.label_root, line + '.png')
            assert os.path.isfile(_img_path)
            assert os.path.isfile(_label_path)
            self.image_path.append(_img_path)
            self.label_path.append(_label_path)
        self.set_attrs(total_len = len(self.image_path))

    def __getitem__(self, index):
        _img = Image.open(self.image_path[index])
        _label = Image.open(self.label_path[index])
        # fixed input size; labels are resized with the same (default) filter
        _img = _img.resize((513, 513))
        _label = _label.resize((513, 513))
        _img = np.array(_img)
        _label = np.array(_label)
        # HWC -> CHW for the network input
        _img = _img.transpose(2,0,1)
        return _img, _label
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
import jittor as jt
|
||||
import numpy as np
|
||||
import sys, os
|
||||
f32 = jt.float32
|
||||
|
||||
@jt.var_scope('linear')
def linear(x, n):
    """Fully-connected layer: x @ w + b with n output features.

    Weights use a scaled-uniform init: U(-0.5, 0.5) / sqrt(fan_in);
    biases use U(-0.5, 0.5).
    """
    w = jt.make_var([x.shape[-1], n], init=lambda *a:
        (jt.random(*a)-f32(0.5)) / f32(x.shape[-1])**f32(0.5))
    b = jt.make_var([n], init=lambda *a: jt.random(*a)-f32(0.5))
    return jt.matmul(x, w) + b
|
||||
|
||||
# element-wise ReLU
def relu(x): return jt.maximum(x, f32(0))

@jt.var_scope('model', unique=True)
def model(x):
    """Tiny 2-layer MLP (10 hidden units, scalar output) used by the demo below."""
    x = linear(x, 10)
    x = relu(x)
    x = linear(x, 1)
    return x
|
||||
|
||||
# --- demo: fit y = x^2 with SGD and check for memory leaks -------------------
np.random.seed(0)
jt.set_seed(3)
n = 1000
batch_size = 50
base_lr = 0.05
# we need to stop grad of global value to prevent memory leak
lr = f32(base_lr).name("lr").stop_grad()

def get_data(n):
    """Yield n random (x, x^2) float32 batches of shape (batch_size, 1)."""
    for i in range(n):
        x = np.random.rand(batch_size, 1)
        y = x*x
        yield np.float32(x), np.float32(y)

for i,(x,y) in enumerate(get_data(n)):
    pred_y = model(x).name("pred_y")
    loss = ((pred_y - y)**f32(2)).name("loss")
    loss_mean = loss.mean()

    # plain SGD over all vars in the 'model' scope
    ps = jt.find_vars('model')
    gs = jt.grad(loss_mean, ps)
    for p,g in zip(ps, gs):
        p -= g * lr
    # after warm-up, the liveness counters must be stable step-to-step,
    # otherwise something is leaking graph nodes
    if i>2:
        assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}"
    prev = jt.liveness_info()
    print(f"step {i}, loss = {loss_mean().sum()}")

# result is 0.0009948202641680837
result = 0.0009948202641680837
assert abs(loss_mean.data - result) < 1e-6
|
|
@ -0,0 +1 @@
|
|||
../../extern
|
|
@ -0,0 +1,57 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
def constant(shape, dtype, value=0.0):
    """Return a new Var of the given shape/dtype filled with `value`."""
    data = np.ones(shape) * value
    return jt.array(data).unary(dtype)

def constant_(var, value=0.0):
    """Fill `var` in place with `value`."""
    var.assign(constant(tuple(var.shape), var.dtype, value))

def uniform(shape, dtype, low, high):
    """Return a new Var sampled uniformly from [low, high)."""
    data = np.random.uniform(low, high, shape)
    return jt.array(data).unary(dtype)

def uniform_(var, low, high):
    """Fill `var` in place with uniform samples from [low, high)."""
    var.assign(uniform(tuple(var.shape), var.dtype, low, high))

def gauss(shape, dtype, mean=0.0, std=1.0):
    """Return a new Var sampled from N(mean, std^2)."""
    data = np.random.normal(mean, std, shape)
    return jt.array(data).unary(dtype)

def gauss_(var, mean=0.0, std=1.0):
    """Fill `var` in place with samples from N(mean, std^2)."""
    var.assign(gauss(tuple(var.shape), var.dtype, mean, std))
|
||||
|
||||
def invariant_uniform(shape, dtype, mode="fan_in"):
    """Scaled-uniform init U(-b, b) with b = sqrt(1/fan).

    `fan` counts incoming (mode="fan_in") or outgoing (mode="fan_out")
    connections: channel count times the product of any trailing kernel
    dimensions.
    """
    assert len(shape)>1
    assert mode=="fan_in" or mode=="fan_out"

    # receptive-field size: product of all kernel (trailing) dimensions
    matsize = 1
    for dim in shape[2:]:
        matsize *= dim
    fan = (shape[1] if mode == "fan_in" else shape[0]) * matsize
    bound = math.sqrt(1.0/fan)
    return uniform(shape, dtype, -bound, bound)

def invariant_uniform_(var, mode="fan_in"):
    """In-place variant of invariant_uniform()."""
    var.assign(invariant_uniform(tuple(var.shape), var.dtype, mode))
|
||||
|
||||
def relu_invariant_gauss(shape, dtype, mode="fan_in"):
    """He-style Gaussian init N(0, sqrt(2/fan)^2) for ReLU networks.

    `fan` counts incoming (mode="fan_in") or outgoing (mode="fan_out")
    connections: channel count times the product of any trailing kernel
    dimensions.
    """
    assert len(shape)>1
    assert mode=="fan_in" or mode=="fan_out"

    # receptive-field size: product of all kernel (trailing) dimensions
    matsize = 1
    for dim in shape[2:]:
        matsize *= dim
    fan = (shape[1] if mode == "fan_in" else shape[0]) * matsize
    std = math.sqrt(2.0/fan)
    return gauss(shape, dtype, 0, std)

def relu_invariant_gauss_(var, mode="fan_in"):
    """In-place variant of relu_invariant_gauss()."""
    var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))
|
|
@ -0,0 +1,2 @@
|
|||
from . import resnet
|
||||
from . import vgg
|
|
@ -0,0 +1,237 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
from jittor import nn
|
||||
from jittor import Module
|
||||
|
||||
@jt.var_scope('basic_block')
def basic_block(x, is_train, in_planes, out_planes, stride = 1):
    """Functional ResNet basic block: two 3x3 conv+BN stages with a residual
    connection, ReLU after the sum."""
    identity = x
    x = nn.conv(x, in_planes, out_planes, 3, 1, stride)
    x = nn.batch_norm(x, is_train)
    x = nn.relu(x)
    x = nn.conv(x, out_planes, out_planes, 3, 1)
    x = nn.batch_norm(x, is_train)
    # 1x1 projection shortcut when the channel count changes.
    # NOTE(review): stride>1 with in_planes==out_planes would leave the
    # shortcut un-downsampled; presumably callers always change channels
    # when striding — confirm.
    if in_planes!=out_planes:
        identity = nn.conv(identity, in_planes, out_planes, 1, 0, stride)
        identity = nn.batch_norm(identity, is_train)
    x = x+identity
    x = nn.relu(x)
    return x
|
||||
|
||||
@jt.var_scope('make_layer')
def make_layer(x, is_train, out_planes, blocks, layer_in_planes, stride = 1):
    """Stack `blocks` basic blocks; only the first may stride/change channels.

    Returns the output tensor and the updated channel count.
    """
    x = basic_block(x, is_train, layer_in_planes, out_planes, stride)
    layer_in_planes = out_planes

    for i in range(1, blocks):
        x = basic_block(x, is_train, layer_in_planes, out_planes)
    return x, layer_in_planes
|
||||
|
||||
@jt.var_scope('bottleneck_block')
def bottleneck_block(x, is_train, in_planes, out_planes, stride = 1):
    """Functional ResNet bottleneck block: 1x1 reduce, 3x3 (strided), then
    1x1 expand by 4x, with a projected residual connection when needed."""
    expansion = 4
    width = out_planes
    identity = x

    # 1x1: reduce to `width` channels
    x = nn.conv(x, in_planes, width, 1, 0)
    x = nn.batch_norm(x, is_train)
    x = nn.relu(x)

    # 3x3: spatial conv, carries the stride
    x = nn.conv(x, width, width, 3, 1, stride)
    x = nn.batch_norm(x, is_train)
    x = nn.relu(x)

    # 1x1: expand to out_planes * expansion channels
    x = nn.conv(x, width, out_planes * expansion, 1, 0)
    x = nn.batch_norm(x, is_train)

    # project the shortcut when the channel count differs
    if in_planes != out_planes * expansion:
        identity = nn.conv(identity, in_planes, out_planes * expansion, 1, 0, stride)
        identity = nn.batch_norm(identity, is_train)

    x = x+identity
    x = nn.relu(x)
    return x
|
||||
|
||||
@jt.var_scope('make_layer_bottleneck')
def make_layer_bottleneck(x, is_train, out_planes, blocks, layer_in_planes, stride = 1):
    """Stack `blocks` bottleneck blocks; only the first may stride.

    Returns the output tensor and the updated channel count
    (out_planes * 4, the bottleneck expansion).
    """
    expansion = 4
    x = bottleneck_block(x, is_train, layer_in_planes, out_planes, stride)
    layer_in_planes = out_planes * expansion
    for i in range(1, blocks):
        x = bottleneck_block(x, is_train, layer_in_planes, out_planes)
    return x, layer_in_planes
|
||||
|
||||
@jt.var_scope('resnet')
def resnet(x, is_train, block, layers, num_classes = 1000):
    """Generic functional ResNet.

    Args:
        x: input batch (3-channel images implied by the stem conv).
        is_train: training flag forwarded to batch_norm.
        block: layer builder (make_layer or make_layer_bottleneck).
        layers: block counts for the four stages, e.g. [2, 2, 2, 2].
        num_classes: size of the final linear classifier.
    """
    layer_in_planes = 64
    # stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool
    x = nn.conv(x, 3, layer_in_planes, 7, 3, 2)
    x = nn.batch_norm(x, is_train)
    x = nn.relu(x)
    x = nn.pool(x, 3, "maximum", 1, 2)
    # four stages; stages 2-4 downsample with stride 2
    x, layer_in_planes = block(x, is_train, 64, layers[0], layer_in_planes)
    x, layer_in_planes = block(x, is_train, 128, layers[1], layer_in_planes, 2)
    x, layer_in_planes = block(x, is_train, 256, layers[2], layer_in_planes, 2)
    x, layer_in_planes = block(x, is_train, 512, layers[3], layer_in_planes, 2)

    # global average pooling over the spatial dims via reindex_reduce
    x = x.reindex_reduce("add", [x.shape[0],x.shape[1]], ["i0","i1"])/x.shape[2]/x.shape[3]
    x = nn.linear(x, num_classes)

    return x
|
||||
|
||||
# Functional ResNet entry points. NOTE(review): `unique=True` presumably gives
# each architecture its own fixed parameter namespace — confirm against
# jt.var_scope's documentation.
@jt.var_scope('resnet18', unique=True)
def resnet18(x, is_train):
    # ResNet-18: basic blocks, stage sizes [2, 2, 2, 2]
    return resnet(x, is_train, make_layer, [2, 2, 2, 2])

@jt.var_scope('resnet34', unique=True)
def resnet34(x, is_train):
    # ResNet-34: basic blocks, stage sizes [3, 4, 6, 3]
    return resnet(x, is_train, make_layer, [3, 4, 6, 3])

@jt.var_scope('resnet50', unique=True)
def resnet50(x, is_train):
    # ResNet-50: bottleneck blocks, stage sizes [3, 4, 6, 3]
    return resnet(x, is_train, make_layer_bottleneck, [3, 4, 6, 3])

@jt.var_scope('resnet101', unique=True)
def resnet101(x, is_train):
    # ResNet-101: bottleneck blocks, stage sizes [3, 4, 23, 3]
    return resnet(x, is_train, make_layer_bottleneck, [3, 4, 23, 3])

@jt.var_scope('resnet152', unique=True)
def resnet152(x, is_train):
    # ResNet-152: bottleneck blocks, stage sizes [3, 8, 36, 3]
    return resnet(x, is_train, make_layer_bottleneck, [3, 8, 36, 3])
|
||||
|
||||
class BasicBlock(Module):
    """Two 3x3 convolutions with a residual connection (expansion 1)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        # main path: conv-bn-relu-conv-bn; only conv1 carries the stride
        self.conv1 = nn.Conv(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm(planes)
        self.relu = nn.Relu()
        self.conv2 = nn.Conv(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm(planes)
        # optional projection for the shortcut when shapes change
        self.downsample = downsample
        self.stride = stride
        self.planes = planes

    def execute(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)
|
||||
|
||||
class Bottleneck(Module):
    # 1x1 reduce -> 3x3 (stride) -> 1x1 expand residual block.
    # Output channel count is planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        self.conv1 = nn.Conv(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm(planes)
        self.conv2 = nn.Conv(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm(planes)
        self.conv3 = nn.Conv(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm(planes * self.expansion)
        self.relu = nn.Relu()
        # optional projection applied to the identity path when the input
        # and output shapes differ
        self.downsample = downsample
        self.stride = stride

    def execute(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        # no ReLU after bn3: activation comes after the residual add
        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
|
||||
|
||||
class ResNet(Module):
    # Module-style ResNet (mirrors the functional version above):
    # stem -> 4 stages -> 7x7 average pool -> fc.
    # NOTE(review): the fixed 7x7 avgpool implies 224x224 inputs — confirm.
    def __init__(self, block, layers, num_classes=1000):
        # running input-channel count for the next stage
        self.inplanes = 64
        self.conv1 = nn.Conv(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm(64)
        self.relu = nn.Relu()
        self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.Pool(7, stride=1, op="mean")
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        # Build one stage of `blocks` residual blocks. Only the first block
        # may change stride/width, so it alone gets a projection shortcut.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv(self.inplanes, planes * block.expansion,
                    kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def execute(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # pool, flatten to (N, C), classify
        x = self.avgpool(x)
        x = jt.reshape(x, [x.shape[0],-1])
        x = self.fc(x)

        return x
|
||||
|
||||
def Resnet18():
    """ResNet-18: BasicBlock with stage sizes [2, 2, 2, 2]."""
    return ResNet(BasicBlock, [2, 2, 2, 2])

def Resnet34():
    """ResNet-34: BasicBlock with stage sizes [3, 4, 6, 3]."""
    return ResNet(BasicBlock, [3, 4, 6, 3])

def Resnet50():
    """ResNet-50: Bottleneck with stage sizes [3, 4, 6, 3]."""
    return ResNet(Bottleneck, [3, 4, 6, 3])

def Resnet101():
    """ResNet-101: Bottleneck with stage sizes [3, 4, 23, 3]."""
    return ResNet(Bottleneck, [3, 4, 23, 3])

def Resnet152():
    """ResNet-152: Bottleneck with stage sizes [3, 8, 36, 3]."""
    return ResNet(Bottleneck, [3, 8, 36, 3])
|
|
@ -0,0 +1,122 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
from jittor import nn
|
||||
|
||||
|
||||
__all__ = [
|
||||
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
|
||||
'vgg19_bn', 'vgg19',
|
||||
]
|
||||
|
||||
|
||||
class VGG(nn.Module):
    # VGG classifier head on top of a `features` extractor built by
    # make_layers. The classifier assumes the extractor emits 512*7*7
    # activations per sample.

    def __init__(self, features, num_classes=1000, init_weights=True):
        # NOTE(review): `init_weights` is accepted for torchvision API
        # compatibility but is never used here — confirm whether the layers
        # initialize themselves.
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def execute(self, x):
        x = self.features(x)
        # flatten to (N, 512*7*7) before the fully connected head
        x = jt.reshape(x, [x.shape[0],-1])
        x = self.classifier(x)
        return x
|
||||
|
||||
def make_layers(cfg, batch_norm=False):
    """Translate a VGG config list into a Sequential feature extractor.

    Integers in `cfg` become 3x3 conv layers (optionally followed by
    BatchNorm) with ReLU; the sentinel 'M' inserts a 2x2 max pool.
    """
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers.append(nn.Pool(kernel_size=2, stride=2, op="maximum"))
            continue
        conv2d = nn.Conv(in_channels, v, kernel_size=3, padding=1)
        if batch_norm:
            layers.extend([conv2d, nn.BatchNorm(v), nn.ReLU()])
        else:
            layers.extend([conv2d, nn.ReLU()])
        in_channels = v
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
# Per-architecture channel configurations: integers are conv output channel
# counts, 'M' marks a 2x2 max-pool (see make_layers).
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # VGG-11
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # VGG-13
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],  # VGG-16
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],  # VGG-19
}
|
||||
|
||||
|
||||
def _vgg(arch, cfg, batch_norm, **kwargs):
    """Build a VGG model from the named config.

    NOTE(review): `arch` is unused here (torchvision uses it for
    pretrained-weight lookup); kept for interface compatibility.
    """
    return VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
|
||||
|
||||
|
||||
# Convenience constructors mirroring torchvision's API (no pretrained-weight
# loading is available here).
def VGG11(**kwargs):
    r"""VGG 11-layer model (configuration "A") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    """
    return _vgg('vgg11', 'A', False, **kwargs)


def VGG11_bn(**kwargs):
    r"""VGG 11-layer model (configuration "A") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    """
    return _vgg('vgg11_bn', 'A', True, **kwargs)


def VGG13(**kwargs):
    r"""VGG 13-layer model (configuration "B")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    """
    return _vgg('vgg13', 'B', False, **kwargs)


def VGG13_bn(**kwargs):
    r"""VGG 13-layer model (configuration "B") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    """
    return _vgg('vgg13_bn', 'B', True, **kwargs)


def VGG16(**kwargs):
    r"""VGG 16-layer model (configuration "D")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    """
    return _vgg('vgg16', 'D', False, **kwargs)


def VGG16_bn(**kwargs):
    r"""VGG 16-layer model (configuration "D") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    """
    return _vgg('vgg16_bn', 'D', True, **kwargs)


def VGG19(**kwargs):
    r"""VGG 19-layer model (configuration "E")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    """
    return _vgg('vgg19', 'E', False, **kwargs)


def VGG19_bn(**kwargs):
    r"""VGG 19-layer model (configuration 'E') with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
    """
    return _vgg('vgg19_bn', 'E', True, **kwargs)
|
|
@ -0,0 +1,462 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
from jittor import init, Module
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
def matmul_transpose(a, b):
    '''
    returns a * b^T
    '''
    # b must be 2-D; a may carry leading batch dimensions. The product is
    # formed by broadcasting both operands to a common shape and summing over
    # the shared last axis, so no transposed copy of b is materialized.
    assert len(a.shape) >= 2 and len(b.shape) == 2
    assert a.shape[-1] == b.shape[-1]

    # common shape: a's batch dims + (rows of b, shared dim)
    shape = list(a.shape)[:-1] + list(b.shape)
    a = a.broadcast(shape, [len(shape)-2])
    b = b.broadcast(shape)
    return (a*b).sum(len(shape)-1)
|
||||
|
||||
def matmul(a, b):
    # Matrix product a @ b via broadcast + sum. b must be 2-D; a may have
    # leading batch dimensions.
    assert len(a.shape) >= 2 and len(b.shape) == 2
    assert a.shape[-1] == b.shape[-2]

    # common shape: a's dims + b's output dim; reduce over the shared axis
    shape = list(a.shape) + [b.shape[-1]]
    a = a.broadcast(shape, [len(shape)-1])
    b = b.broadcast(shape)
    return (a*b).sum(len(shape)-2)
# install operator support on jt.Var: a.matmul(b), a @ b, a @= b
jt.Var.matmul = jt.Var.__matmul__ = matmul
jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
|
||||
|
||||
def get_init_var_rand(shape, dtype):
    # Standard-normal (mean 0, std 1) initializer.
    # NOTE(review): `dtype` is ignored — the result is always float32;
    # confirm whether any caller expects another dtype.
    return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32))
|
||||
|
||||
@jt.var_scope('batch_norm')
def batch_norm(x, is_train, eps=1e-5, momentum=0.1):
    # Functional batch normalization over the channel axis of an NCHW input.
    # Scale/shift and running statistics are created lazily via jt.make_var
    # under the current variable scope.
    w = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
    b = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
    running_mean = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
    running_var = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))

    w = w.broadcast(x, [0,2,3])
    b = b.broadcast(x, [0,2,3])
    if is_train:
        # training: normalize with batch statistics
        xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
        x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
        # Var[x] = E[x^2] - E[x]^2 (biased estimate)
        xvar = x2mean-xmean*xmean
        norm_x = (x-xmean)/jt.sqrt(xvar+eps)

        # exponential moving average update of the running statistics
        # (.sum squeezes the kept size-1 dims back to shape [C])
        running_mean += (xmean.sum([0,2,3])-running_mean)*momentum
        running_var += (xvar.sum([0,2,3])-running_var)*momentum
    else:
        # inference: normalize with the stored running statistics
        running_mean = running_mean.broadcast(x, [0,2,3])
        running_var = running_var.broadcast(x, [0,2,3])
        norm_x = (x-running_mean)/jt.sqrt(running_var+eps)

    return norm_x * w + b
|
||||
|
||||
def pool(x, size, op, padding, stride = 1):
    # Functional 2D pooling: gather size x size windows with reindex, then
    # collapse them with the reduction `op` (e.g. "maximum", "mean").
    N,C,H,W = x.shape
    # output spatial dims (floor division: border remainders are dropped)
    h = (H+padding*2-size)//stride+1
    w = (W+padding*2-size)//stride+1
    xx = x.reindex([N,C,h,w,size,size], [
        "i0", # Nid
        "i1", # Cid
        f"i2*{stride}-{padding}+i4", # Hid
        f"i3*{stride}-{padding}+i5", # Wid
    ])
    return xx.reduce(op, [4,5])
|
||||
|
||||
@jt.var_scope('conv')
def conv(x, in_planes, out_planes, kernel_size, padding, stride = 1, init_method=None):
    """Functional 2D convolution implemented with reindex + sum reduction.

    Weight is created lazily via jt.make_var; `init_method` overrides the
    default relu-invariant Gaussian initializer.
    """
    Kw = kernel_size
    Kh = kernel_size
    _C = in_planes
    Kc = out_planes
    N,C,H,W = x.shape

    assert C==_C
    # fix: identity comparison with None (was `init_method==None`)
    if init_method is None:
        w = jt.make_var([Kc, _C, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
    else:
        w = jt.make_var([Kc, _C, Kh, Kw], init=init_method)
    # gather each (Kh, Kw) input patch for every output location
    xx = x.reindex([N,Kc,C,(H+padding*2-kernel_size)//stride+1,(W+padding*2-kernel_size)//stride+1,Kh,Kw], [
        'i0', # Nid
        'i2', # Cid
        f'i3*{stride}-{padding}+i5', # Hid+Khid
        f'i4*{stride}-{padding}+i6', # Wid+KWid
    ])
    ww = w.broadcast(xx.shape, [0,3,4])
    yy = xx*ww
    # reduce over input channels and kernel positions
    y = yy.sum([2,5,6]) # C, Kh, Kw
    return y
|
||||
|
||||
@jt.var_scope('linear')
def linear(x, n):
    # Functional fully connected layer: y = x @ W + b with `n` outputs.
    w = jt.make_var([n, x.shape[-1]], init=lambda *a: init.invariant_uniform(*a))
    # transpose (n, in) -> (in, n) via reindex so x can right-multiply it
    w = w.reindex([w.shape[1], w.shape[0]],["i1","i0"])
    # bias bound follows the usual 1/sqrt(fan_in) rule
    bound = 1.0/math.sqrt(w.shape[0])
    b = jt.make_var([n], init=lambda *a: init.uniform(*a,-bound,bound))
    return jt.matmul(x, w) + b
|
||||
|
||||
# ReLU: elementwise max(x, 0)
def relu(x): return jt.maximum(x, 0)
# leaky ReLU: positives pass through, negatives are scaled by `scale`
def leaky_relu(x, scale): return jt.ternary(x>0, x, x*scale)
|
||||
|
||||
#TODO dims is 4 will cause slowly execution
def cross_entropy_loss(output, target, ignore_index=None):
    """Softmax cross entropy between `output` logits and integer `target`.

    4-D outputs (N, C, H, W) are flattened to (N*H*W, C) per-pixel logits.
    When `ignore_index` is given, matching targets are excluded from both
    the loss sum and the averaging denominator.
    """
    if len(output.shape) == 4:
        c_dim = output.shape[1]
        output = output.transpose((0, 2, 3, 1))
        output = output.reshape((-1, c_dim))
    if ignore_index is not None:
        # remap ignored targets to -1 so their one-hot row is all zeros
        target = jt.ternary(target==ignore_index,
            jt.array(-1).broadcast(target), target)
        mask = jt.logical_and(target >= 0, target < output.shape[1])
    target = target.reshape((-1, ))
    target = target.broadcast(output, [1])
    target = target.index(1) == target

    # numerically stable log-sum-exp minus the target logit
    output = output - output.max([1], keepdims=True)
    loss = output.exp().sum(1).log()
    loss = loss - (output*target).sum(1)
    if ignore_index is None:
        return loss.mean()
    # fix: ignored rows previously still contributed their logsumexp term to
    # the numerator (their one-hot row is all zeros) while being excluded
    # from the denominator; zero them out explicitly
    loss = loss * mask.reshape((-1,)).int()
    return loss.sum() / jt.maximum(mask.int().sum(), 1)
|
||||
|
||||
class SGD(object):
    """SGD optimizer with momentum / weight decay / nesterov support.

    Usage:
    optimizer = nn.SGD(model.parameters(), lr)
    optimizer.step(loss)
    """
    def __init__(self, parameters, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

        # split parameters: stop-grad ones are only synced, never updated
        self.no_grad_parameters = []
        self.parameters = []
        self.values = []
        for p in parameters:
            if p.is_stop_grad():
                self.no_grad_parameters.append(p)
                continue
            self.parameters.append(p)
            # one momentum buffer per trainable parameter
            self.values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())

    def step(self, loss):
        # One SGD update from the gradient of `loss` w.r.t. all parameters.
        ps = self.parameters
        gs = jt.grad(loss, ps)
        for p, g, v in zip(ps, gs, self.values):
            # L2 weight decay folded into the gradient
            dp = p * self.weight_decay + g
            # momentum buffer: v = momentum*v + (1-dampening)*dp
            v.assign(self.momentum * v + dp * (1 - self.dampening))
            if self.nesterov:
                p -= (dp + self.momentum * v) * self.lr
            else:
                p -= v * self.lr
            # detach with the prev graph to reduce memory consumption
            p.detach_inplace()
        # sync all no grad parameters, such as
        # moving_mean and moving_var in batch_norm
        # sync such parameters to reduce memory consumption
        jt.sync(self.no_grad_parameters)
|
||||
|
||||
class Adam(object):
    """Adam optimizer (Kingma & Ba).

    Usage:
    optimizer = nn.Adam(model.parameters(), lr)
    optimizer.step(loss)
    """
    def __init__(self, parameters, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        self.lr = lr
        self.eps = eps
        self.betas = betas
        # self.weight_decay = weight_decay
        assert weight_decay==0, "weight_decay is not supported yet"
        # global step counter used for bias correction
        self.adam_step = 0

        # stop-grad parameters are only synced, never updated
        self.no_grad_parameters = []
        self.parameters = []
        self.values = []  # second-moment (v) estimates
        self.m = []  # first-moment (m) estimates
        for p in parameters:
            if p.is_stop_grad():
                self.no_grad_parameters.append(p)
                continue
            self.parameters.append(p)
            self.values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
            self.m.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())

    def step(self, loss):
        # One Adam update; bias correction is folded into step_size instead
        # of correcting m and v individually.
        ps = self.parameters
        gs = jt.grad(loss, ps)
        self.adam_step += 1
        n, (b0, b1) = float(self.adam_step), self.betas
        for p, g, v, m in zip(ps, gs, self.values, self.m):
            m.assign(b0 * m + (1-b0) * g)
            v.assign(b1 * v + (1-b1) * g * g)
            step_size = self.lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
            p -= m * step_size / (jt.sqrt(v) + self.eps)
            # drop the old graph to bound memory usage
            p.detach_inplace()
        jt.sync(self.no_grad_parameters)
|
||||
|
||||
def softmax(x, dim = None):
    """Softmax of `x` along `dim` (over all elements when dim is None).

    The max is subtracted before exponentiation for numerical stability.
    """
    if dim is None:
        e = (x - x.max()).exp()
        return e / e.sum()
    e = (x - x.max(dim, keepdims=True)).exp()
    return e / e.sum(dim, keepdims=True)
|
||||
|
||||
class Dropout(Module):
    # Dropout: zero each activation independently with probability p,
    # active only while is_train is True.
    # NOTE(review): surviving activations are NOT rescaled by 1/(1-p)
    # ("inverted dropout"), so train/eval magnitudes differ — confirm
    # callers compensate.
    def __init__(self, p=0.5, is_train=False):
        assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p)
        self.p = p
        self.is_train = is_train
    #TODO: test model.train() to change self.is_train
    def execute(self, input):
        output = input
        if self.p > 0 and self.is_train:
            if self.p == 1:
                # p == 1 drops everything
                noise = jt.zeros(input.shape)
            else:
                # Bernoulli keep-mask: 1 where uniform noise exceeds p
                noise = jt.random(input.shape)
                noise = (noise > self.p).int()
            output = output * noise
        return output
|
||||
|
||||
class Linear(Module):
    """Fully connected layer: y = x @ W^T + b.

    Weight has shape (out_features, in_features); bias is drawn from
    U(-1/sqrt(in_features), 1/sqrt(in_features)) when enabled.
    """
    def __init__(self, in_features, out_features, bias=True):
        self.in_features = in_features
        self.out_features = out_features
        self.weight = init.invariant_uniform((out_features, in_features), "float32")
        bound = 1.0 / math.sqrt(in_features)
        self.bias = init.uniform((out_features,), "float32", -bound, bound) if bias else None

    def execute(self, x):
        y = matmul_transpose(x, self.weight)
        return y if self.bias is None else y + self.bias
|
||||
|
||||
class BatchNorm(Module):
    # Batch normalization over the channel axis of an NCHW input, with
    # learnable scale/shift and running statistics for inference.
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True):
        # `affine` kept for torch API parity, unsupported
        assert affine == None

        self.num_features = num_features
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.weight = init.constant((num_features,), "float32", 1.0)
        self.bias = init.constant((num_features,), "float32", 0.0)
        # running statistics are buffers, excluded from gradients
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

    def execute(self, x):
        if self.is_train:
            # training: normalize with batch statistics
            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
            # Var[x] = E[x^2] - E[x]^2 (biased estimate)
            xvar = x2mean-xmean*xmean
            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
            # exponential moving average of the running statistics
            # (.sum squeezes the kept size-1 dims back to shape [C])
            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
        else:
            # inference: normalize with the stored running statistics
            running_mean = self.running_mean.broadcast(x, [0,2,3])
            running_var = self.running_var.broadcast(x, [0,2,3])
            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
        w = self.weight.broadcast(x, [0,2,3])
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b
|
||||
|
||||
class Pool(Module):
    # 2D pooling module; `op` selects the reduction ("maximum", "mean", ...).
    def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"):
        # dilation / return_indices kept for torch API parity, unsupported
        assert dilation == None
        assert return_indices == None
        self.kernel_size = kernel_size
        self.op = op
        # torch semantics: stride defaults to kernel_size
        self.stride = stride if stride else kernel_size
        self.padding = padding
        self.ceil_mode = ceil_mode

    def execute(self, x):
        N,C,H,W = x.shape
        if (self.ceil_mode == False):
            # floor mode: partial border windows are dropped
            h = (H+self.padding*2-self.kernel_size)//self.stride+1
            w = (W+self.padding*2-self.kernel_size)//self.stride+1
        else:
            # ceil mode: partial border windows are kept
            h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
            w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1

        # gather kernel windows, then reduce the two window axes
        xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [
            "i0", # Nid
            "i1", # Cid
            f"i2*{self.stride}-{self.padding}+i4", # Hid
            f"i3*{self.stride}-{self.padding}+i5", # Wid
        ])
        return xx.reduce(self.op, [4,5])
|
||||
|
||||
# Module wrappers generated from the functional activations above; the
# integer argument is the number of positional arguments the wrapped
# function takes.
Relu = jt.make_module(relu)
ReLU = Relu
Leaky_relu = jt.make_module(leaky_relu, 2)
Softmax = jt.make_module(softmax, 2)
|
||||
|
||||
class Conv(Module):
    # 2D convolution implemented with an im2col-style reindex gather followed
    # by a broadcast multiply and sum reduction.
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        assert groups == 1

        self.in_channels = in_channels
        self.out_channels = out_channels
        # normalize scalar hyper-parameters to (h, w) tuples
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
        Kh, Kw = self.kernel_size
        assert groups==1, "Group conv not supported yet."
        self.weight = init.relu_invariant_gauss([out_channels, in_channels, Kh, Kw], dtype="float", mode="fan_out")
        if bias:
            # NOTE(review): bias drawn from U(-1, 1) rather than the usual
            # U(-1/sqrt(fan_in), 1/sqrt(fan_in)) — confirm this is intended
            self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
        else:
            self.bias = None

    def execute(self, x):
        N,C,H,W = x.shape
        Kh, Kw = self.kernel_size
        assert C==self.in_channels
        # output spatial dims accounting for padding, dilation and stride
        oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
        ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
        # gather every (Kh, Kw) input patch for every output location
        xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
            'i0', # Nid
            'i2', # Cid
            f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid
            f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid
        ])
        ww = self.weight.broadcast(xx.shape, [0,3,4])
        yy = xx*ww
        # reduce over input channels and kernel positions
        y = yy.sum([2,5,6]) # Kc, Kh, Kw
        if self.bias is not None:
            b = self.bias.broadcast(y.shape, [0,2,3])
            y = y + b
        return y
|
||||
|
||||
class ConvTranspose(Module):
    """Transposed 2D convolution (deconvolution) via reindex_reduce scatter-add."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
            padding=0, output_padding=0, groups=1, bias=True, dilation=1):
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.group = groups
        assert groups==1, "Group conv not supported yet."

        # normalize scalar hyper-parameters to (h, w) tuples.
        # (fix: the original assigned self.dilation twice — the first scalar
        # assignment was dead and has been removed)
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        # padding of the equivalent forward convolution
        self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0],
            self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1])
        self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)

        # weight layout is (in, out, Kh, Kw), mirroring torch's ConvTranspose2d
        self.weight = init.relu_invariant_gauss((in_channels, out_channels) + self.kernel_size, dtype="float", mode="fan_out")
        if bias:
            # NOTE(review): bias drawn from U(-1, 1), same as Conv — confirm
            self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
        else:
            self.bias = None

    def execute(self, x):
        N,C,H,W = x.shape
        i,o,h,w = self.weight.shape
        assert C==i
        stride_h, stride_w = self.stride
        padding_h, padding_w = self.padding
        dilation_h, dilation_w = self.dilation

        # standard transposed-convolution output-size formula
        h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
        w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
        out_shape = (N, o, h_out, w_out)
        shape = (N, i, o, H, W, h, w)
        xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
        ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W
        # scatter-add each input pixel's weighted kernel into the output
        y = (ww*xx).reindex_reduce("add", out_shape, [
            'i0', # N
            'i2', # o
            f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
            f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
        ])
        if self.bias is not None:
            b = self.bias.broadcast(y.shape, [0,2,3])
            y = y + b
        return y
|
||||
|
||||
|
||||
class Tanh(Module):
    """Hyperbolic tangent activation: (e^x - e^-x) / (e^x + e^-x)."""

    def __init__(self):
        super().__init__()

    def execute(self, x):
        pos = jt.exp(x)
        neg = jt.exp(-x)
        return (pos - neg) / (pos + neg)
|
||||
|
||||
class Sigmoid(Module):
    """Logistic sigmoid activation: 1 / (1 + e^-x)."""

    def __init__(self):
        super().__init__()

    def execute(self, x):
        neg_exp = jt.exp(-x)
        return 1 / (1 + neg_exp)
|
||||
|
||||
def resize(x, size, mode="nearest"):
    """Resize an NCHW tensor to spatial `size` (H, W).

    mode: "nearest" or "bilinear" (align-corners style sampling).
    Raises ValueError for any other mode.
    """
    img = x
    n,c,h,w = x.shape
    H,W = size
    new_size = [n,c,H,W]
    nid, cid, hid, wid = jt.index(new_size)
    # map output pixels to input coordinates; the max(..., 1) guards against
    # division by zero for 1-pixel outputs (fix: was an unconditional H-1/W-1)
    x = hid * ((h-1)/max(H-1, 1))
    y = wid * ((w-1)/max(W-1, 1))
    if mode=="nearest":
        return img.reindex([nid, cid, x.floor(), y.floor()])
    if mode=="bilinear":
        # the four surrounding pixels and the fractional offsets
        fx, fy = x.floor(), y.floor()
        cx, cy = fx+1, fy+1
        dx, dy = x-fx, y-fy
        a = img.reindex_var([nid, cid, fx, fy])
        b = img.reindex_var([nid, cid, cx, fy])
        c = img.reindex_var([nid, cid, fx, cy])
        d = img.reindex_var([nid, cid, cx, cy])
        dnx, dny = 1-dx, 1-dy
        ab = dx*b + dnx*a
        cd = dx*d + dnx*c
        o = ab*dny + cd*dy
        return o
    # fix: the original raised a bare f-string (TypeError) referencing the
    # undefined name `interpolation` (NameError)
    raise ValueError(f"Not support {mode}")
|
||||
|
||||
|
||||
|
||||
class Sequential(Module):
    # Minimal sequential container: stores its children as a tuple and pipes
    # the input through them in order.
    def __init__(self, *args):
        self.layers = args
    def __getitem__(self, idx):
        return self.layers[idx]
    def execute(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    def dfs(self, parents, k, callback, callback_leave):
        # Depth-first traversal hook used by Module; callback may return
        # False to prune this subtree.
        # NOTE(review): the loop below rebinds `k`, so callback_leave
        # receives the last child's index rather than this module's own
        # key — confirm whether that is intentional.
        n_children = len(self.layers)
        ret = callback(parents, k, self, n_children)
        if ret == False:
            return
        for k,v in enumerate(self.layers):
            parents.append(self)
            v.dfs(parents, k, callback, callback_leave)
            parents.pop()
        if callback_leave:
            callback_leave(parents, k, self, n_children)
|
|
@ -0,0 +1 @@
|
|||
../../notebook
|
|
@ -0,0 +1,841 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import re
|
||||
import os
|
||||
from jittor_utils import LOG, run_cmd, simple_timer
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
def parse_attrs(s):
|
||||
'''parse @attrs(..., x=y) syntax'''
|
||||
attrs = {}
|
||||
if s is None: return attrs
|
||||
for a in s.split(','):
|
||||
a = a.strip()
|
||||
if len(a)==0: continue
|
||||
if '=' in a:
|
||||
k, v = a.split('=')
|
||||
attrs[k] = v
|
||||
else:
|
||||
attrs[a] = 1
|
||||
return attrs
|
||||
|
||||
|
||||
# For each C++ type: [Python->C converter, C->Python constructor, exact-type
# checker] helper names used when generating binding code.
pytype_map = {
    "const char*": ["PyUnicode_AsUTF8", "PyUnicode_FromString", "PyUnicode_CheckExact"],
    "int": ["PyLong_AsLong", "PyLong_FromLong", "PyLong_CheckExact"],
    "int64": ["PyLong_AsLongLong", "PyLong_FromLongLong", "PyLong_CheckExact"],
    "uint": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
    "uint64": ["PyLong_AsUnsignedLongLong", "PyLong_FromUnsignedLongLong", "PyLong_CheckExact"],
    "void": ["...", "GET_PY_NONE", "..."],
}
def get_pytype_map(T, i):
    """Return the helper name for C++ type T.

    i = 0: Python->C converter, 1: C->Python constructor, 2: type checker.
    Types absent from pytype_map fall back to templated generic helpers.
    """
    entry = pytype_map.get(T)
    if entry is not None:
        return entry[i]
    generic = ("from_py_object", "to_py_object", "is_type")
    return f"{generic[i]}<{T}>"
|
||||
|
||||
# Mapping from Python binary dunder names to the corresponding
# PyNumberMethods slot names used when generating the C extension type.
binary_number_slots = {
    "__add__": "nb_add",
    "__sub__": "nb_subtract",
    "__mul__": "nb_multiply",
    "__mod__": "nb_remainder",
    "__divmod__": "nb_divmod",
    "__pow__": "nb_power",
    "__lshift__": "nb_lshift",
    "__rshift__": "nb_rshift",
    "__and__": "nb_and",
    "__xor__": "nb_xor",
    "__or__": "nb_or",
    "__floordiv__": "nb_floor_divide",
    "__truediv__": "nb_true_divide",
    "__matmul__": "nb_matrix_multiply",
}

# derive the in-place variants from the plain ones:
# __add__: nb_add ----> __iadd__: nb_inplace_add
for k,v in list(binary_number_slots.items()):
    binary_number_slots["__i"+k[2:]] = "nb_inplace"+v[2:]

# unary dunder -> PyNumberMethods slot
unary_number_slots = {
    "__neg__": "nb_negative",
    "__abs__": "nb_absolute",
}
|
||||
|
||||
def split_args(s):
    """Split a C++ argument list on top-level commas.

    Commas nested inside template angle brackets ("map<int,int>") are not
    treated as separators. Items are returned unstripped (apart from the
    outer strip of the whole string); empty input yields [].
    """
    s = s.strip()
    if not s:
        return []
    args = []
    depth = 0  # current template-bracket nesting level
    start = 0
    for pos, ch in enumerate(s):
        if ch == '<':
            depth += 1
        elif ch == '>':
            depth -= 1
        elif ch == ',' and depth == 0:
            args.append(s[start:pos])
            start = pos + 1
    args.append(s[start:])
    return args
|
||||
|
||||
def get_def_code(df, scope_name, pyname, self_as_arg0=False):
    """Generate the C++ snippets needed to bind one declaration to Python.

    Args:
        df: a def-info dict produced by compile_src (keys: "args",
            "func_name", "attrs", "is_property", "is_scope_def",
            "is_static", "return_t", ...).
        scope_name: enclosing C++ class or submodule name, used to qualify
            the call as Scope::f(...) or (GET_RAW_PTR(Scope,self))->f(...).
        pyname: the Python-visible name; dunder names select slot handling.
        self_as_arg0: when True, the Python ``self`` object is consumed as
            the first C++ argument (submodule function promoted to method).

    Returns a 5-tuple of strings/bool:
        (quick-check condition, argument-conversion code,
         keyword/default fill code, the C++ call expression, has_return).
    """
    # Dunder methods go through slots, not the METH_FASTCALL path.
    is_fast_call = not pyname.startswith("__")
    # __getitem__ receives a Py_ssize_t directly; no PyObject* conversion.
    no_need_convert = pyname == "__getitem__"
    args = df["args"]
    # n==1 && PyXXX__CheckExact(args[0]) && ...
    max_args = len(args)
    min_args = max_args
    # The first argument carrying a default value marks the minimum arity
    # (assumes defaults are trailing, as in C++).
    for tid, a in enumerate(args):
        if a[2] != "":
            min_args = tid
            break
    arg_names = [ f"args[{i}]" for i in range(len(args))]
    if self_as_arg0:
        # ``self`` supplies the first argument, so one fewer comes from args[].
        max_args -= 1
        min_args -= 1
        arg_names = ["self"] + arg_names[:-1]
    # Arguments that may also be passed by keyword (VarHolder* is excluded).
    kw_args_id = []
    for aid, arg in enumerate(args):
        if "VarHolder*" != arg[0] and is_fast_call:
            kw_args_id.append(aid)
    func_quick_check_runable = ""
    func_quick_check_size = f"n<={max_args} && n>={min_args}"
    if len(kw_args_id):
        # Keyword arguments arrive after the positional ones; count both.
        func_quick_check_size = f"n+(kw?Py_SIZE(kw):0)<={max_args} && n+(kw?Py_SIZE(kw):0)>={min_args}"
    fill_with_default = ""
    func_args_convert = ""
    func_call = df["func_name"]+"("
    pytypes = [ get_pytype_map(a[0],0) for a in args ]
    for tid, tpc in enumerate(pytypes):
        check = get_pytype_map(args[tid][0],2)
        default_arg = args[tid][2]
        jtp = args[tid][0]
        holder_dec = ""
        holder_set = ""
        # VarHolder* arguments need a unique_ptr holder that the converter
        # fills in, keeping the object alive for the call.
        if jtp == "VarHolder*":
            holder_dec = f"unique_ptr<VarHolder> arg{tid}_holder"
            holder_set = f", arg{tid}_holder"
        if len(default_arg):
            # Optional argument: convert only when provided, record it in the
            # arg_filled bitmask.  NOTE: {tid-self_as_arg0} relies on bool
            # arithmetic (True == 1) to shift the positional index.
            func_args_convert += f"""
        {holder_dec};
        {jtp} arg{tid};
        if (n>{tid-self_as_arg0}) {{
            CHECK(({check}({arg_names[tid]})));
            arg{tid} = {tpc}({arg_names[tid]}{holder_set});
            arg_filled |= 1ull << {tid};
        }}
        """
            fill_with_default += f"""
        if (!(arg_filled & (1ull<<{tid}))) {{
            arg{tid} = {default_arg};
        }}
        """
        else:
            # Mandatory argument: its type check becomes part of the overload
            # dispatch condition.
            func_quick_check_runable += f" && {check}({arg_names[tid]})"
            func_args_convert += f"""
        {holder_dec};
        {jtp} arg{tid} = {tpc}({arg_names[tid]}{holder_set});
        """
        if tid: func_call += ","
        # Rvalue-reference parameters are moved into the call.
        if args[tid][3].endswith("&&"):
            func_call += f"move(arg{tid})"
        else:
            func_call += f"arg{tid}"
    if pyname == "__richcmp__":
        # Restrict the shared tp_richcompare slot to the comparison ops this
        # declaration actually implements (recorded in attrs by compile_src).
        for rname in [ "__lt__", "__le__", "__gt__",
            "__ge__", "__eq__", "__ne__"]:
            if rname in df["attrs"]:
                func_quick_check_runable += " && op==Py_"+rname[2:-2].upper()
    # fill args with keyword arguments
    fill_with_kw = ""
    if is_fast_call and len(kw_args_id):
        # Keyword names are dispatched by a compile-time hash (see get_hash);
        # hash collisions are rejected at generation time.
        fill_with_kw = f"""
        if (kw) {{
            auto kw_n = Py_SIZE(kw);
            for (int i=0; i<kw_n; i++) {{
                auto ko = PyTuple_GET_ITEM(kw, i);
                auto vo = args[i+n];
                auto ks = PyUnicode_AsUTF8(ko);
                uint khash = hash(ks);
                {"".join([
                f'''
                if (khash == {get_hash(args[aid][1])}u) {{
                    // hash match {args[aid][1]}
                    CHECK(({get_pytype_map(args[aid][0],2)}(vo)));
                    arg{aid} = {pytypes[aid]}(vo);
                    arg_filled |= 1ull << {aid};
                    continue;
                }}
                '''
                for aid in kw_args_id
                ])}
                LOGf << "Not a valid keyword:" << ks;
            }}
        }}
        """
    if len(args):
        # Converters may set a Python error without raising a C++ exception.
        func_args_convert += """
        CHECK(!PyErr_Occurred());
        """
    func_call += ")"
    if df["is_property"]:
        # Properties bind to a raw member: plain read for __get__, an
        # assignment from arg0 for __set__.
        if pyname.startswith("__get__"):
            func_call = df["func_name"]
        else:
            assert pyname.startswith("__set__"), pyname
            func_call = df["func_name"] + "= arg0"
    has_return = df["return_t"]!="void" and df["return_t"]!=""
    # add XXX::xxx or XXX->xxx if is class def
    if df["is_scope_def"]:
        if df["is_static"]:
            func_call = f"{scope_name}::" + func_call
        else:
            func_call = f"(GET_RAW_PTR({scope_name},self))->" + func_call
    if pyname == "__init__":
        # XXX->xxx(...) ---> new XXX xxx(...)
        assert "->" in func_call
        func_call = "new " + func_call.replace("->", " ")
    if no_need_convert:
        # __getitem__ already receives a converted index; drop the generated
        # checks and conversions.
        func_quick_check_runable = ""
        func_args_convert = ""
        fill_with_kw = fill_with_default = ""
    return (
        func_quick_check_size + func_quick_check_runable,
        func_args_convert,
        fill_with_kw+fill_with_default,
        func_call,
        has_return
    )
|
||||
|
||||
# Registry of every hash value handed out so far, used to detect collisions
# between keyword names at code-generation time.
hash_to_key_map = {}

def get_hash(s):
    """Return a 32-bit polynomial hash of *s* (base 55, mod 2**32).

    Must stay in sync with the C++ ``hash`` function used by the generated
    keyword dispatch.  Raises AssertionError if two distinct strings ever
    produce the same value.
    """
    MOD = 1 << 32
    v, mul = 0, 1
    for ch in s:
        v = (v + mul * ord(ch)) % MOD
        mul = (mul * 55) % MOD
    prev = hash_to_key_map.get(v)
    if prev is not None:
        assert prev == s, \
            f"hash conflict {prev} {s} {hash_to_key_map}"
    hash_to_key_map[v] = s
    return v
|
||||
|
||||
|
||||
# Matches one @pyjt annotation: an optional /* doc */ block, the mandatory
# "// @pyjt(names)" marker, then an optional "// @attrs(...)" line.
# Group 1 = doc string, group 3 = pyjt argument list, group 5 = attrs list.
reg = re.compile(
    '(/\\*(.*?)\\*/\\s*)?(//\\s*@pyjt\\(([^\\n]*)\\)\\s*)'
    # ^^^^^^^^^^^^^^^^^ ^^^^ ^^^^
    # doc string $1 pyjt args $3
    +
    '(//\\s*@attrs\\(([^\\n]*)\\)\\s*)?'
    # ^^^^^ ^^^^^^^
    # attrs args $5
    , re.DOTALL)
|
||||
|
||||
def compile_src(src, h, basename):
    """Scan a C++ header for // @pyjt(...) annotations and emit binding code.

    Args:
        src: full text of the header file.
        h: header path; embedded verbatim in the generated ``#include``.
        basename: header base name, used for ``pyjt_def_<basename>``.

    Returns:
        The generated C++ source (``pyjt_<basename>.cc`` body) as a string,
        or None when the header contains no annotations / definitions.
    """
    # BUGFIX: Pattern.finditer's second positional argument is a start
    # *position*, not a flags word.  The old call reg.finditer(src, re.S)
    # passed re.S (== 16) as pos and silently skipped annotations in the
    # first 16 characters.  The pattern is already compiled with re.DOTALL.
    res = list(reg.finditer(src))
    if len(res) == 0: return
    class_ranges = None
    class_name = None
    class_info = None
    submodule_name = None
    submodule_ranges = None
    submodule_info = None
    defs = []
    LOG.vv(("find in", h))
    for x in res:
        LOG.vvv((x, x.groups()))
        g = x.groups()
        doc = g[1]
        pyjt = g[3]
        attrs = g[5]
        # Split a comma-separated annotation argument list, dropping blanks.
        esplit = lambda x: [] if x is None else \
            [ a.strip() for a in x.split(",") if len(a.strip()) ]
        attrs = parse_attrs(attrs)
        pynames = esplit(pyjt)
        end = x.end()

        def find_bc(i):
            """From src[i] onward, find the first '(', '{' or ';' and the
            index of its matching close (for ';' both indices point past it)."""
            while src[i] not in "({;":
                i += 1
            j = i+1
            if src[i] == ';':
                return i, j
            presum = 1
            while True:
                if src[j] in "({[":
                    presum += 1
                elif src[j] in ")}]":
                    presum -= 1
                    if presum == 0:
                        s = src[i]+src[j]
                        # dedup of the old ("()","{}","()") tuple; same check
                        assert s in ("()","{}"), "braces not match "+s
                        return i, j
                j += 1
        # // @pyjt(DType)
        # struct DType {
        #        ^ --> a
        # .....
        # } <--- b
        # or
        # // @pyjt(hash)
        # inline uint hash(const char* input)
        #                 ^ --> a            ^ --> b
        a, b = find_bc(end)
        is_property = 0
        if src[a] == ';':
            # This case
            # class XXX {
            #     // @pyjt(property)
            #     T property;
            # }
            is_property = 1
        if src[a] == '{':
            # A braced body right after the annotation: class or submodule.
            assert len(pynames) == 1
            if "submodule" in attrs:
                assert submodule_ranges is None
                submodule_ranges = (a, b)
                submodule_name = src[end:a-1].strip().split()[-1]
                submodule_info = {
                    "pynames": pynames,
                    "attrs": attrs
                }
                continue
            assert class_ranges is None
            class_ranges = (a, b)
            class_name = src[end:a-1].strip().split()[-1]
            class_info = {
                "pynames": pynames,
                "attrs": attrs
            }
            continue
        is_scope_def = False
        is_static = False
        scope_name = ""
        # A declaration inside the class/submodule body belongs to that scope.
        if class_ranges is not None:
            if class_ranges[0] < a and a < class_ranges[1]:
                is_scope_def = True
                scope_name = class_name
        if submodule_ranges is not None:
            if submodule_ranges[0] < a and a < submodule_ranges[1]:
                is_scope_def = True
                scope_name = submodule_name
                is_static = True
        dec = src[end:b+1].strip()
        arr = src[end:a].strip().split()
        func_name = arr[-1]

        is_constructor = False
        if is_scope_def and func_name == class_name:
            is_constructor = True

        args = []
        for arg in split_args(src[a+1:b]):
            if arg == "": continue
            default = ""
            if "=" in arg:
                # BUGFIX: split only on the first '=' so default values that
                # themselves contain '=' no longer raise ValueError.
                arg, default = arg.split('=', 1)
            arg = arg.strip()
            name = arg.split(' ')[-1]
            tp = arg[:-len(name)]
            tp = tp.strip()
            prev_tp = tp
            # const string& ----> string
            if tp.startswith("const") and tp.endswith("&"):
                tp = tp[5:-1].strip()
            # T&& -> T
            if tp.endswith("&&"):
                tp = tp[:-2].strip()
            # ArrayArgs& -> ArrayArgs
            if tp.endswith("&"):
                tp = tp[:-1].strip()
            args.append((tp, name.strip(), default.strip(), prev_tp))
        return_t = ""
        # renamed loop variable (was ``a``, shadowing the position above)
        for tok in arr[:-1]:
            if tok in ["", "inline", "constexpr"]: continue
            if tok == "static":
                is_static = True
                continue
            if return_t != "": return_t += " "
            return_t += tok

        if is_scope_def and class_info and "submodule" in class_info["attrs"]:
            is_static = True

        # Rich comparison operators all share the tp_richcompare slot; record
        # which ones exist in attrs and rename to __richcmp__.
        for pid, pyname in enumerate(pynames):
            for rname in [ "__lt__", "__le__", "__gt__",
                "__ge__", "__eq__", "__ne__"]:
                if pyname.endswith(rname):
                    attrs[rname] = 1
                    pynames[pid] = pyname.replace(rname, "__richcmp__")

        def_info = {
            "is_scope_def": is_scope_def,
            "is_constructor": is_constructor,
            "is_static": is_static,
            "is_property": is_property,
            "func_name": func_name,
            "args": args, # [(type,name,defaut), ...]
            "return_t": return_t, # return type
            "dec": dec, # full string of xxx(A a, B b)
            "pynames": pynames, # names in @pyjt(...)
            "attrs": attrs, # attrs in @attrs(...)
            "doc": doc,
            "scope_name": scope_name,
        }
        if is_property:
            # This case
            # class XXX {
            #     // @pyjt(property)
            #     T property;
            # }
            # One member generates a getter def and a setter def.
            assert is_scope_def and not is_static
            def_info["is_property"] = 1
            def_info["pynames"] = ["__get__"+n for n in pynames]
            assert return_t != "void"
            defs.append(dict(def_info))
            def_info["pynames"] = ["__set__"+n for n in pynames]
            assert len(args) == 0
            def_info["args"] = [(def_info["return_t"], func_name, "", "")]
            def_info["return_t"] = "void"
            defs.append(dict(def_info))
            continue
        else:
            defs.append(def_info)
        LOG.vvv(json.dumps(def_info, indent=4))
    # deal with defs
    if len(defs) == 0: return
    # include_name = h[4:] # remove "src/" prefix
    include_name = h
    code = []
    class_defs_code = []
    class_getsets_code = []
    class_gets = OrderedDict()
    class_sets = OrderedDict()
    class_slots_code = []
    submodule_defs_code = []
    # Group defs by their fully-qualified python name so overloads share one
    # generated entry point.
    def_targets = OrderedDict()
    for df in defs:
        for name in df["pynames"]:
            if df["is_scope_def"] and '.' not in name:
                if df["scope_name"] == class_name:
                    name = class_info["pynames"][0] + '.' + name
                else:
                    name = submodule_info["pynames"][0] + '.' + name
            if name not in def_targets:
                def_targets[name] = []
            def_targets[name].append(df)
    for name in def_targets:
        dfs = def_targets[name]
        target_scope_name = None
        LOG.vv(name)
        if "." in name:
            target_scope_name, name = name.split(".")
        # array for each df:
        arr_func_quick_check_runable = []
        arr_func_args_convert = []
        arr_fill_with_default = []
        arr_func_call = []
        arr_has_return = []
        self_as_arg0 = False
        for df in dfs:
            # A submodule function exposed under the class consumes ``self``
            # as its first C++ argument.
            self_as_arg0 = class_info and \
                target_scope_name == class_info["pynames"][0] and \
                df["scope_name"] == submodule_name \
                and not name.startswith("__")
            parts = get_def_code(df, df["scope_name"], name, bool(self_as_arg0))
            arr_func_quick_check_runable.append(parts[0])
            arr_func_args_convert.append(parts[1])
            arr_fill_with_default.append(parts[2])
            arr_func_call.append(parts[3])
            arr_has_return.append(parts[4])

        # Pick the CPython slot / calling convention for this name.
        slot_name = None
        func_cast = ""
        func_fill = ""
        if name == "__init__":
            slot_name = "tp_init"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> int"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """
        elif name == "__repr__":
            slot_name = "tp_repr"
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"
        elif name.startswith("__get__"):
            slot_name = "tp_gets"
            name = name[len("__get__"):]
            func_head = "(PyObject* self, void*) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"
        elif name.startswith("__set__"):
            slot_name = "tp_sets"
            name = name[len("__set__"):]
            func_head = "(PyObject* self, PyObject* arg, void*) -> int"
            func_fill = """
                int64 n=1;
                PyObject** args = &arg;
                (void)n, (void)args;
            """
        elif name == "__call__":
            slot_name = "tp_call"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> PyObject*"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """
        elif name == "__dealloc__":
            slot_name = "tp_dealloc"
            func_head = "(PyObject* self) -> void"
            func_fill = "int64 n = 0"
        elif name in binary_number_slots:
            slot_name = "tp_as_number->"+binary_number_slots[name]
            func_head = "(PyObject* self, PyObject* b) -> PyObject*"
            if name.endswith("pow__"):
                # nb_power takes a third (modulus) argument.
                func_head = "(PyObject* self, PyObject* b, PyObject*) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """
        elif name in unary_number_slots:
            slot_name = "tp_as_number->"+unary_number_slots[name]
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = """
                int64 n = 1;
                PyObject* args[] = {self};
                (void)n, (void)args;
            """
        elif name == "__richcmp__":
            slot_name = "tp_richcompare"
            func_head = "(PyObject* self, PyObject* b, int op) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """
        elif name == "__len__":
            slot_name = "tp_as_sequence->sq_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;
                (void)n;
            """
        elif name == "__map_len__":
            slot_name = "tp_as_mapping->mp_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;
                (void)n;
            """
        elif name == "__getitem__":
            slot_name = "tp_as_sequence->sq_item"
            func_head = "(PyObject* self, Py_ssize_t arg0) -> PyObject*"
            func_fill = f"""
                int64 n = 1;
                (void)n;
                if (arg0 >= GET_RAW_PTR({dfs[0]["scope_name"]},self)->size()) {{
                    PyErr_SetString(PyExc_IndexError, "");
                    return 0;
                }}
            """
        elif name == "__map_getitem__":
            slot_name = "tp_as_mapping->mp_subscript"
            func_head = "(PyObject* self, PyObject* arg0) -> PyObject*"
            func_fill = f"""
                int64 n = 1;
                PyObject* args[] = {{arg0}};
                (void)n;
            """
        elif name.startswith("__"):
            LOG.f(f"Not support slot {name}")
            continue
        else:
            # Plain method/function: METH_FASTCALL | METH_KEYWORDS entry.
            func_head = "(PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*"
            func_cast = f"(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))"
            # if not return, return py_none
            arr_has_return = [ True for _ in arr_has_return ]

        # Build the return statement for each overload plus the combined doc.
        arr_func_return = []
        doc_all = ""
        decs = "Declarations:\n"
        for did, has_return in enumerate(arr_has_return):
            df = dfs[did]
            func_call = arr_func_call[did]
            if df["doc"]:
                doc_all += "Document:\n"
                doc_all += df["doc"]
                doc_all += "\nDeclaration:\n"
                doc_all += df["dec"]
            decs += df["dec"]+'\n'
            if has_return:
                assert "-> int" not in func_head
                if "-> PyObject*" in func_head:
                    if "return_self" in df["attrs"]:
                        arr_func_return.append(
                            f"return (({func_call}), Py_INCREF(self), self)")
                    else:
                        arr_func_return.append(
                            f"return {get_pytype_map(df['return_t'],1)}(({func_call}))")
                    func_return_failed = "return nullptr"
                else:
                    arr_func_return.append(
                        f"return ({func_call});")
                    func_return_failed = "return -1"
            else:
                if "-> int" in func_head:
                    arr_func_return.append(f"return ({func_call},0)")
                    func_return_failed = "return -1"
                else:
                    assert "-> void" in func_head
                    arr_func_return.append(f"{func_call};return")
                    func_return_failed = "return"
        # The generated lambda tries each overload's quick check in order.
        func = f"""
            {func_cast}[]{func_head} {{
                try {{
                    {func_fill};
                    uint64 arg_filled=0;
                    (void)arg_filled;
                    {"".join([f'''
                    if ({arr_func_quick_check_runable[did]}) {{
                        {arr_func_args_convert[did]};
                        {arr_fill_with_default[did]};
                        {arr_func_return[did]};
                    }}
                    '''
                    for did in range(len(arr_func_return))
                    ])}
                    LOGf << "Not a valid call";
                }} catch (const std::exception& e) {{
                    PyErr_Format(PyExc_RuntimeError, "%s\\n%s",
                        e.what(),
                        R""({decs})""
                    );
                }}
                {func_return_failed};
            }}
        """
        if slot_name:
            if slot_name == "tp_gets":
                class_gets[name] = {
                    "func": func,
                    "doc": doc_all
                }
                continue
            if slot_name == "tp_sets":
                class_sets[name] = {
                    "func": func,
                    "doc": ""
                }
                continue
            class_slots_code.append(f"""
            tp.{slot_name} = {func};
            """)
            continue
        # Non-slot entry: wrap in a PyMethodDef initializer.
        need_static = ""
        if df["is_scope_def"] and df["is_static"] and \
            df["scope_name"] == class_name and \
            "submodule" not in class_info["attrs"]:
            need_static = " | METH_STATIC"
        func = (f"""
        {{ R""({name})"",
        {func},
        METH_FASTCALL | METH_KEYWORDS{need_static},
        R""({doc_all})""
        }}""")
        if df["is_scope_def"]:
            if df["scope_name"] == class_name or \
                (class_info and \
                target_scope_name == class_info["pynames"][0]):
                class_defs_code.append(func)
            else:
                submodule_defs_code.append(func)
        else:
            code.append(func)
    # Merge getter/setter pairs into PyGetSetDef entries (sorted for
    # deterministic output).
    prop_names = list(set(class_gets.keys()).union(class_sets.keys()))
    prop_names = sorted(prop_names)
    for prop_name in prop_names:
        get_func = "NULL"
        set_func = "NULL"
        doc = ""
        if prop_name in class_gets:
            get_func = class_gets[prop_name]["func"]
            if class_gets[prop_name]["doc"]:
                doc += class_gets[prop_name]["doc"]
        if prop_name in class_sets:
            set_func = class_sets[prop_name]["func"]
            if class_sets[prop_name]["doc"]:
                doc += class_sets[prop_name]["doc"]
        class_getsets_code.append(f"""
        {{"{prop_name}", {get_func}, {set_func}, R""({doc})""}}
        """)
    # Sentinel terminators required by the CPython table conventions.
    code.append("{0,0,0,0}")
    class_defs_code.append("{0,0,0,0}")
    class_getsets_code.append("{0,0,0,0}")
    submodule_defs_code.append("{0,0,0,0}")
    core_name = "jittor_core"
    if class_info and "attrs" in class_info and "core_name" in class_info["attrs"]:
        core_name = class_info["attrs"]["core_name"]
    if submodule_info and "attrs" in submodule_info and "core_name" in submodule_info["attrs"]:
        core_name = submodule_info["attrs"]["core_name"]
    # NOTE(review): hard-coded type names that need mapping/sequence slot
    # tables — confirm when new bound classes need __getitem__/__len__.
    has_map = class_name in ["VarHolder", "NanoVector"]
    has_seq = class_name == "NanoVector"
    code = f"""
#include "pyjt/py_converter.h"
#include "common.h"
#include "{include_name}"

namespace jittor {{

{
    "" if class_name is None else
    f"PyHeapTypeObject Pyjt{class_name};" if "heaptype" in class_info["attrs"] else
    f"PyTypeObject Pyjt{class_name};"
}

void pyjt_def_{basename}(PyObject* m) {{
    static PyMethodDef defs[] = {{
        {",".join(code)}
    }};
    ASSERT(PyModule_AddFunctions(m, defs)==0);
    {
        f'''
        static PyMethodDef class_defs[] = {{
            {",".join(class_defs_code)}
        }};
        static PyGetSetDef class_getsets[] = {{
            {",".join(class_getsets_code)}
        }};

        static PyNumberMethods number_methods = {{0}};
        {f"auto& htp =Pyjt{class_name}; auto& tp = htp.ht_type;"
        if "heaptype" in class_info["attrs"] else
        f"auto& tp = Pyjt{class_name};"}
        tp.tp_as_number = &number_methods;

        {f"static PyMappingMethods class_map_defs = {{0}};" if has_map else ""}
        {f"tp.tp_as_mapping = &class_map_defs;" if has_map else ""}

        {f"static PySequenceMethods class_seq_defs = {{0}};" if has_seq else ""}
        {f"tp.tp_as_sequence = &class_seq_defs;" if has_seq else ""}

        tp.tp_name = "{core_name}.{class_info["pynames"][0]}";
        tp.tp_basicsize = GET_OBJ_SIZE({class_name});
        tp.tp_new = PyType_GenericNew;
        tp.tp_flags = Py_TPFLAGS_DEFAULT;
        {"tp.tp_flags |= Py_TPFLAGS_HEAPTYPE; htp.ht_name = htp.ht_qualname = to_py_object<string>(tp.tp_name);"
        if "heaptype" in class_info["attrs"] else ""}
        tp.tp_methods = &class_defs[0];
        tp.tp_getset = &class_getsets[0];
        {"".join(class_slots_code)};
        ASSERT(0==PyType_Ready(&tp)) << (PyErr_Print(), 0);
        Py_INCREF(&tp);
        ASSERT(0==PyModule_AddObject(m, "{class_info["pynames"][0]}", (PyObject*)&tp));
        ''' if class_name is not None else ""
    }
    {f'''

    // sub module def
    static PyMethodDef submodule_defs[] = {{
        {",".join(submodule_defs_code)}
    }};
    auto sub = PyImport_AddModule("{core_name}.{submodule_info["pynames"][0]}");
    ASSERT(PyModule_AddFunctions(sub, submodule_defs)==0);
    ASSERT(sub);
    ASSERT(0==PyModule_AddObject(m, "{submodule_info["pynames"][0]}", sub));
    ''' if submodule_name is not None else ""
    }

}}

}}
"""
    return code
|
||||
|
||||
def compile_single(head_file_name, src_file_name, src=None):
    """Generate the pyjt binding .cc file for one header.

    Args:
        head_file_name: path of the annotated C++ header.
        src_file_name: output path for the generated binding source.
        src: optional pre-read header text; when None the header is read
            from disk (callers use this to prepend merged headers).

    Returns:
        True when binding code was generated and written, False when the
        header contains no @pyjt annotations.
    """
    basename = head_file_name.split("/")[-1].split(".")[0]
    # idiom fix: compare to None with ``is`` (was ``src==None``)
    if src is None:
        with open(head_file_name, 'r') as f:
            src = f.read()
    code = compile_src(src, head_file_name, basename)
    if not code:
        return False
    LOG.vvv("write to", src_file_name)
    LOG.vvvv(code)
    with open(src_file_name, 'w') as f:
        f.write(code)
    return True
|
||||
|
||||
# NOTE: shadows the builtin ``compile``; callers reference it via this module.
def compile(cache_path, jittor_path):
    """Generate pyjt bindings for every header under src/ and gen/, then
    write gen/pyjt_all.cc, which registers all generated modules.

    Args:
        cache_path: build cache directory containing gen/.
        jittor_path: repository root containing src/.
    """
    # Collect header paths via shell find/grep (relies on a POSIX shell).
    headers1 = run_cmd('find -L src/ | grep ".h$"', jittor_path).splitlines()
    headers2 = run_cmd('find gen/ | grep ".h$"', cache_path).splitlines()
    headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
        [ os.path.join(cache_path, h) for h in headers2 ]
    basenames = []
    for h in headers:
        with open(h, 'r') as f:
            src = f.read()

        # jit_op_maker.h merge compile with var_holder.h
        # (var_holder.h is skipped on its own and prepended to jit_op_maker.h
        # so both are compiled into a single binding unit)
        if h.endswith("src/var_holder.h"): continue
        if h.endswith("jit_op_maker.h"):
            with open(os.path.join(jittor_path, "src", "var_holder.h"), "r") as f:
                src = f.read() + src
        basename = h.split("/")[-1].split(".")[0]
        fname = "pyjt_"+basename+".cc"
        fname = os.path.join(cache_path, "gen", fname)
        check = compile_single(h, fname, src)

        # Headers without @pyjt annotations produce no binding file.
        if not check: continue

        basenames.append(basename)

    # Emit the aggregator that declares and calls every pyjt_def_<basename>.
    code = f"""
#include "pyjt/numpy.h"
#include "pyjt/py_converter.h"
#include "common.h"

namespace jittor {{

{ " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])}

void pyjt_def_all(PyObject* m) {{
    numpy_init();
    { " ".join([f"pyjt_def_{n}(m);" for n in basenames])}
}}

}}
"""
    fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", fname))
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
|
@ -0,0 +1 @@
|
|||
../../script
|
|
@ -0,0 +1 @@
|
|||
../../src
|
|
@ -0,0 +1,27 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
if __name__ == "__main__":
    # Test-package entry point: collect every TestCase from all test_*.py
    # modules in this directory into this module's globals, then let
    # unittest.main() discover and run them.
    import unittest

    suffix = "__main__.py"
    assert __file__.endswith(suffix)
    test_dir = __file__[:-len(suffix)]
    import os
    test_files = os.listdir(test_dir)
    for test_file in test_files:
        if not test_file.startswith("test_"):
            continue
        test_name = test_file.split(".")[0]
        # Relative import by generated statement; binds the module into
        # globals() under its name.
        exec(f"from . import {test_name}")
        test_mod = globals()[test_name]
        print(test_name)
        # Re-export each TestCase under a unique prefixed name so cases from
        # different modules cannot collide.
        for i in dir(test_mod):
            obj = getattr(test_mod, i)
            if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
                globals()[test_name+"_"+i] = obj

    unittest.main()
|
|
@ -0,0 +1,32 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <functional>
|
||||
|
||||
using namespace std;
|
||||
|
||||
void test_main();
|
||||
|
||||
// Error hook for the test build: convert a detected error into a C++
// exception so expect_error() can catch it.
void on_error() {
    throw std::exception();
}
|
||||
|
||||
// Run func and require that it throws; any exception type is accepted.
// Fails the test (via the project's CHECK macro) when func returns normally.
void expect_error(function<void()> func) {
    try {
        func();
    } catch (...) {
        return;
    }
    CHECK(0) << "Missing error";
}
|
||||
|
||||
// Test driver: run the test body, mapping any thrown exception to exit
// code 1.  Falling off the end of main returns 0 (success) implicitly.
int main() {
    try {
        test_main();
    } catch (...) {
        return 1;
    }
}
|
|
@ -0,0 +1,26 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import gc
|
||||
|
||||
class TestAllocator(unittest.TestCase):
    """Checks the stat_allocator counters for a tiny compute graph."""

    def test_stat(self):
        # Start from a clean graph so counters reflect only this test.
        jt.clean()
        with jt.flag_scope(use_stat_allocator=1, use_sfrl_allocator = 0):
            a = jt.random([10,10])
            b = a+a
            c = a*b
            c.data
            del a,b,c
            gc.collect()
            # 10x10 float32 = 400 bytes per buffer; expected totals below
            # (2 allocs / 2 frees of 400 bytes each) — assumes float32
            # default dtype and no extra temporaries.
            assert jt.flags.stat_allocator_total_alloc_call == 2
            assert jt.flags.stat_allocator_total_alloc_byte == 800
            assert jt.flags.stat_allocator_total_free_call == 2
            assert jt.flags.stat_allocator_total_free_byte == 800

if __name__ == "__main__":
    unittest.main()
|
|
@ -0,0 +1,45 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import gc
|
||||
|
||||
def test(h, w, total_alloc_call, total_alloc_byte, total_free_call = 0, total_free_byte = 0):
    """Build a small h-by-w compute graph under the stat allocator and
    assert the four allocator counters match the expected values.

    Args:
        h, w: shape of the random input variable.
        total_alloc_call / total_alloc_byte: expected allocation counters.
        total_free_call / total_free_byte: expected free counters
            (default 0: the sfrl allocator caches blocks instead of freeing).
    """
    # Reset graph and allocator state so counters start from zero.
    jt.clean()
    jt.gc()
    with jt.flag_scope(use_stat_allocator=1):
        a = jt.random([h,w])
        b = a+a
        c = a*b
        c.data
        del a,b,c
        gc.collect()
        # Read the counters while the flag scope is still active.
        x = (
            jt.flags.stat_allocator_total_alloc_call,
            jt.flags.stat_allocator_total_alloc_byte,
            jt.flags.stat_allocator_total_free_call,
            jt.flags.stat_allocator_total_free_byte
        )
    y = (total_alloc_call, total_alloc_byte, total_free_call, total_free_byte)
    assert x==y, (x, y)
|
||||
|
||||
|
||||
class TestAllocator2(unittest.TestCase):
    """Exercises the block-size tiers of the caching allocator; the comments
    give the raw tensor size while the expected byte counts reflect the
    allocator's block granularity."""

    def test_stat(self):
        #small_block
        test(10, 10, 1, 1048576) #800
        #small_block
        test(100, 100, 1, 1048576) #80000
        #large_block
        test(1000, 1000, 1, 20971520) #8000000
        #large_block2
        test(8000, 1000, 2, 67108864) #64000000

if __name__ == "__main__":
    unittest.main()
|
|
@ -0,0 +1,122 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor import compile_extern
|
||||
from .test_log import find_log_with_re
|
||||
import copy
|
||||
if compile_extern.has_cuda:
|
||||
from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
|
||||
else:
|
||||
cublas_ops = cudnn_ops = cub_ops = None
|
||||
|
||||
def check_reduce(shape, op, dim, keepdims, is_cuda = False):
    """Compare jt.arg_reduce against numpy argmax/argmin on random input.

    Args:
        shape: input shape for jt.random.
        op: 'max' or 'min'.
        dim: axis to reduce over.
        keepdims: whether the reduced axis is kept with size 1.
        is_cuda: when True, also require that the cub_arg_reduce kernel was
            actually selected (checked via the captured op.cc log).
    """
    with jt.log_capture_scope(
        log_silent=1,
        log_v=0, log_vprefix="op.cc=100"
    ) as raw_log:
        x = jt.random(shape)
        # arg_reduce returns (indices, values)
        key, v = jt.arg_reduce(x, op, dim, keepdims)
        x_ = x.data
        key_ = key.data
        v_ = v.data
    if (is_cuda):
        # Exactly one JIT compile/lookup of the cub kernel is expected.
        logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "cub_arg_reduce" + ".*)")
        assert len(logs)==1
    if op == 'max':
        key__ = np.argmax(x_, axis=dim)
        v__ = np.max(x_, axis=dim)
    else:
        key__ = np.argmin(x_, axis=dim)
        v__ = np.min(x_, axis=dim)

    if keepdims:
        # numpy drops the reduced axis; restore it to match keepdims output.
        key__ = np.expand_dims(key__, axis=dim)
        v__ = np.expand_dims(v__, axis=dim)
    assert np.allclose(key_, key__)
    assert np.allclose(v_, v__)
|
||||
|
||||
def check_backward(shape, op, dim, keepdims):
    """Check the gradient of arg_reduce's value output.

    For loss = sum(value^2), d loss / dx is 2*value at the selected
    positions and 0 elsewhere, so gs = grad/2 equals x where selected;
    hence gs*x == gs*gs everywhere.

    NOTE(review): names are swapped relative to check_reduce — here ``v``
    receives the indices and ``key`` the values; verify against
    jt.arg_reduce's return order.
    """
    x = jt.random(shape)
    v, key = jt.arg_reduce(x, op, dim, keepdims)
    loss = (key * key).sum()
    gs = jt.grad(loss, x) / 2
    assert np.allclose((gs * x).data, (gs * gs).data)
|
||||
|
||||
class TestArgReduceOp(unittest.TestCase):
    """Forward (vs numpy) and backward tests for jt.arg_reduce, CPU and CUDA."""

    # Shared (shape, op, dim) cases; each test runs them with keepdims=True
    # first and then keepdims=False, matching the original call order.
    # NOTE: the original test_cuda used shape [5,5] for a single
    # keepdims=False case, unlike every sibling call; normalized to [5,5,5].
    cases = [
        ([5, 5, 5], 'min', 0),
        ([5, 5, 5], 'min', 2),
        ([5, 5, 5], 'min', 1),
        ([20, 20, 20, 20], 'max', 0),
        ([20, 20, 20, 20], 'max', 2),
        ([20, 20, 20, 20], 'max', 1),
        ([20, 20, 20, 20], 'max', 3),
    ]

    def _run_all(self, func, *extra):
        # Drive `func(shape, op, dim, keepdims, *extra)` over every case.
        for keepdims in (True, False):
            for shape, op, dim in self.cases:
                func(shape, op, dim, keepdims, *extra)

    def test_backward(self):
        self._run_all(check_backward)

    @unittest.skipIf(cub_ops==None, "Not use cub, Skip")
    @jt.flag_scope(use_cuda=1)
    def test_backward_cuda(self):
        self._run_all(check_backward)

    def test(self):
        self._run_all(check_reduce)

    @unittest.skipIf(cub_ops==None, "Not use cub, Skip")
    @jt.flag_scope(use_cuda=1)
    def test_cuda(self):
        # is_cuda=True makes check_reduce also assert the cub kernel log.
        self._run_all(check_reduce, True)
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
|
@ -0,0 +1,123 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor import compile_extern
|
||||
from .test_log import find_log_with_re
|
||||
if compile_extern.has_cuda:
|
||||
from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
|
||||
else:
|
||||
cublas_ops = cudnn_ops = cub_ops = None
|
||||
|
||||
def check_argsort(shape, dim, descending = False):
    """Check jt.argsort against np.sort.

    Both the returned keys and the values gathered from x through the
    returned indices must equal numpy's sort along `dim`.
    """
    x = jt.random(shape)
    idx, keys = jt.argsort(x, dim=dim, descending=descending)
    # Gather x through the sort indices along `dim`; identity on other axes.
    selectors = [idx if axis == dim else jt.index(shape, dim=axis)
                 for axis in range(len(shape))]
    gathered = jt.reindex(x, selectors).data
    keys_np = keys.data
    ref = x.data
    if descending:
        ref = -ref
    expected = np.sort(ref, axis=dim)
    if descending:
        expected = -expected
    assert np.allclose(keys_np, expected)
    assert np.allclose(gathered, expected)
|
||||
|
||||
def check_cub_argsort(shape, dim, descending = False):
    """Same as check_argsort, but additionally verifies (via captured op.cc
    logs) that the cub_argsort CUDA kernel was the op actually used."""
    with jt.log_capture_scope(
        log_silent=1,
        log_v=0, log_vprefix="op.cc=100"
    ) as raw_log:
        x = jt.random(shape)
        idx, keys = jt.argsort(x, dim=dim, descending=descending)
        # Gather x through the sort indices along `dim`; identity elsewhere.
        selectors = [idx if axis == dim else jt.index(shape, dim=axis)
                     for axis in range(len(shape))]
        gathered = jt.reindex(x, selectors).data
        keys_np = keys.data
    pattern = "(Jit op key (not )?found: " + "cub_argsort" + ".*)"
    assert len(find_log_with_re(raw_log, pattern)) == 1
    ref = x.data
    if descending:
        ref = -ref
    expected = np.sort(ref, axis=dim)
    if descending:
        expected = -expected
    assert np.allclose(keys_np, expected)
    assert np.allclose(gathered, expected)
|
||||
|
||||
def check_backward(shape, dim, descending = False):
    """Gradient check for argsort: sorting only permutes values, so for
    loss = sum(keys**2) the gradient wrt x is exactly 2*x."""
    x = jt.random(shape)
    _, keys = jt.argsort(x, dim=dim, descending=descending)
    grad = jt.grad((keys * keys).sum(), x)
    assert np.allclose(x.data * 2, grad.data)
|
||||
|
||||
class TestArgsortOp(unittest.TestCase):
    """Tests for jt.argsort: forward vs numpy, gradients, docstring, cub kernel."""

    # (shape, dim, descending) tuples, executed in this exact order.
    _cases = [
        ([5, 5], 0, False),
        ([5, 5], 0, True),
        ([5, 5], 1, False),
        ([5, 5], 1, True),
        ([12, 34, 56, 78], 1, True),
        ([12, 34, 56, 78], 3, True),
        ([12, 34, 56, 78], 2, False),
        ([12, 34, 56, 78], 0, False),
    ]

    def test(self):
        for shape, dim, descending in self._cases:
            check_argsort(shape, dim, descending)

    def test_backward(self):
        for shape, dim, descending in self._cases:
            check_backward(shape, dim, descending)

    def test_doc(self):
        assert "Argsort Operator" in jt.argsort.__doc__

    @unittest.skipIf(cub_ops==None, "Not use cub, Skip")
    @jt.flag_scope(use_cuda=1)
    def test_cub(self):
        for shape, dim, descending in self._cases:
            check_cub_argsort(shape, dim, descending)

    @unittest.skipIf(cub_ops==None, "Not use cub, Skip")
    @jt.flag_scope(use_cuda=1)
    def test_cub_backward(self):
        for shape, dim, descending in self._cases:
            check_backward(shape, dim, descending)
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
|
@ -0,0 +1,118 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor import compile_extern
|
||||
|
||||
class TestArray(unittest.TestCase):
    """Tests for jt.array: data access/liveness, data assignment, CPU<->GPU
    memcopy overlap, and regression tests for past segfaults."""

    def test_data(self):
        """`.data` exposes a live view; holding it keeps the buffer alive
        after the Var itself is deleted (checked via liveness_info)."""
        a = jt.array([1,2,3])
        assert (a.data == [1,2,3]).all()
        d = a.data
        # Writing through the .data view must be visible on re-read and sync.
        a.data[1] = -2
        assert (a.data == [1,-2,3]).all()
        assert (a.fetch_sync()==[1,-2,3]).all()
        li = jt.liveness_info()
        del a
        # `d` still references the buffer, so liveness must be unchanged...
        assert li == jt.liveness_info()
        del d
        # ...and dropping the last reference must change it.
        assert li != jt.liveness_info()

    def test_set_data(self):
        """Assigning to .data accepts both plain lists and jt arrays."""
        a = jt.array([1,2,3])
        assert (a.fetch_sync()==[1,2,3]).all()
        a.data = [4,5,6]
        assert (a.fetch_sync()==[4,5,6]).all()
        a.data = jt.array([7,8,9])
        assert (a.fetch_sync()==[7,8,9]).all()

    @unittest.skipIf(not compile_extern.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_memcopy_overlap(self):
        """Host->device copies should overlap with compute: feeding a fresh
        jt.array each iteration must cost (almost) nothing extra vs reusing
        one pre-uploaded input."""
        import time
        from jittor.models import resnet

        im=np.random.rand(100,3,224,224).astype(np.float32)
        net = resnet.Resnet34()
        net.eval()
        # warm up
        x = jt.array(im).stop_grad()

        for i in range(10):
            a = net(x)
            a.sync()
        jt.sync(device_sync=True)

        # pure compute: input uploaded once, 10 forward passes.
        time_start=time.time()
        x = jt.array(im).stop_grad()
        for i in range(10):
            a = net(x)
            a.sync()
        jt.sync(device_sync=True)
        t1 = time.time() - time_start

        # warm up
        for i in range(3):
            x = jt.array(im)
            b = net(x)
            b.sync()
        jt.sync(device_sync=True)

        # overlap: fresh upload per iteration, async fetch of each result.
        time_start=time.time()
        results = []
        for i in range(10):
            x = jt.array(im)
            b = net(x)
            b.fetch(lambda b: results.append(b))
            # del c
        jt.sync(device_sync=True)
        t2 = time.time() - time_start

        # Overhead budget of 10ms over the pure-compute baseline.
        assert t2-t1 < 0.010, (t2, t1, t2-t1)
        assert np.allclose(a.data, b.data)
        assert len(results) == 10
        for v in results:
            assert np.allclose(a.data, v), (v.shape, a.data.shape)
        jt.LOG.v(f"pure compute: {t1}, overlap: {t2}")

    def test_segfault(self):
        """Regression: grad through maximum+sum must not crash and must
        leave the input readable afterwards."""
        a = jt.array([1.0,2.0,3.0])
        b = (jt.maximum(a, 0)).sum() * 2.0
        da = jt.grad(b, a)
        jt.sync_all()
        assert (a.data==[1,2,3]).all()
        assert (da.data==[2,2,2]).all()

    def test_segfault2(self):
        """Regression: reshape of a temporary array, on CPU and (if built) CUDA."""
        assert (jt.array([1,2,3]).reshape((1,3)).data==[1,2,3]).all()
        if jt.has_cuda:
            with jt.flag_scope(use_cuda=1):
                assert (jt.array([1,2,3]).reshape((1,3)).data==[1,2,3]).all()

    @unittest.skipIf(not compile_extern.has_cuda, "Cuda not found")
    def test_array_dual(self):
        """jt.array of a numpy float32 array is readable under use_cuda=1."""
        with jt.flag_scope(use_cuda=1):
            a = jt.array(np.float32([1,2,3]))
            assert (a.data==[1,2,3]).all()

    @unittest.skipIf(not compile_extern.has_cuda, "Cuda not found")
    def test_array_migrate(self):
        """A CUDA-created array consumed by a cpu_src jt.code op must migrate
        back to host memory correctly."""
        with jt.flag_scope(use_cuda=1):
            a = jt.array(np.float32([1,2,3]))
            b = jt.code(a.shape, a.dtype, [a], cpu_src="""
                for (int i=0; i<in0shape0; i++)
                    @out(i) = @in0(i)*@in0(i)*2;
            """)
            assert (b.data==[2,8,18]).all()
|
||||
|
||||
|
||||
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
|
@ -0,0 +1,110 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
from jittor import LOG
|
||||
import os
|
||||
import re
|
||||
|
||||
class TestAsmTuner(unittest.TestCase):
    """Tests utils/asm_tuner.py: compiling a fused-op source whose //@begin
    //@end region requests a vmova -> vmovnt assembly rewrite, then checking
    the generated .s file for the substitution."""

    @classmethod
    def setUpClass(self):
        # clang needs always_inline so func0 is emitted inside jit_run and
        # the asm rewrite region survives into the final assembly.
        inline = "inline"
        if jt.flags.cc_type == "clang":
            inline = "__attribute__((always_inline))"
        # Hand-written fused-op C++ source; the //@begin line carries the
        # regex replace directive consumed by asm_tuner.py.
        self.cc_content='''
#include <cmath>
#include <algorithm>
#include "var.h"
#include "ops/broadcast_to_op.h"
#include "ops/binary_op.h"
#include "fused_op.h"
#define op0_Tx float32
#define op0_DIM 2
#define op0_BCAST 1
#define op0_index_t int32_t
#define op1_Tx float
#define op1_DIM 2
#define op1_BCAST 0
#define op1_index_t int32_t
#define op2_Tx float
#define op2_Ty float32
#define op2_Tz float32
#define op2_OP subtract
#define op2_index_t int32_t
using namespace jittor;
#define INLINE_FUNC '''+inline+''' void
INLINE_FUNC func0(op0_Tx* __restrict__ op0_xp, op1_Tx* __restrict__ op1_xp, op2_Tz* __restrict__ op2_zp) {
//@begin replace "vmova(.*,.*\(.*\))" "vmovnt\g<1>"
(void)(__builtin_assume_aligned(op0_xp, alignment));
(void)(__builtin_assume_aligned(op1_xp, alignment));
(void)(__builtin_assume_aligned(op2_zp, alignment));
op2_index_t range0 = 1048576;
op2_index_t range1 = 32;
op0_index_t op0_xstride1 = 1;
auto op0_xstride0 = op0_xstride1 * range1;
op1_index_t op1_xstride1 = 1;
auto op1_xstride0 = op1_xstride1 * range1;
op2_index_t op2_zstride1 = 1;
auto op2_zstride0 = op2_zstride1 * range1;
for (op2_index_t id0 = 0; id0<range0; id0++) {
for (op2_index_t id1 = 0; id1<range1; id1++) {
auto op0_xid = + 0 * op0_xstride0 + id1 * op0_xstride1;
auto op0_zd = op0_xp[op0_xid];
auto op1_xid = + id0 * op1_xstride0 + id1 * op1_xstride1;
auto op1_zd = op1_xp[op1_xid];
op2_index_t op2_i = + id0 * op2_zstride0 + id1 * op2_zstride1;
op2_zp[op2_i] = ((op1_zd )-(op0_zd ));
}
}
//@end
}
void jittor::FusedOp::jit_run() {
auto op0_x = ((BroadcastToOp*)(ops[0]))->x;
auto op1_x = ((BroadcastToOp*)(ops[1]))->x;
auto op2_z = ((BinaryOp*)(ops[2]))->z;
auto* __restrict__ op0_xp = op0_x->ptr<op0_Tx>();
auto* __restrict__ op1_xp = op1_x->ptr<op1_Tx>();
auto* __restrict__ op2_zp = op2_z->ptr<op2_Tz>();
func0(op0_xp,op1_xp,op2_zp);
}
'''

        # Paths inside jittor's jit cache; the .so path doubles as the base
        # name for the generated .s assembly file.
        self.src_path=os.path.join(jt.flags.cache_path, 'jit', 'asm_test_op.cc')
        self.asm_path = os.path.join(jt.flags.jittor_path, "utils/asm_tuner.py")
        self.so_path=self.src_path.replace(".cc",".so")

    def run_cmd(self, cmd):
        # Thin wrapper so failures surface through jittor's compiler logging.
        return jt.compiler.run_cmd(cmd)

    def check_cc(self, content, check_movnt):
        """Write `content` to src_path, compile it via asm_tuner.py, then
        scan the generated .s file for the rewritten `vmovnt` instruction."""
        LOG.vv("check_cc")
        with open(self.src_path,"w") as f:
            f.write(content)

        # Run asm_tuner.py as a compiler wrapper with JIT flags; it produces
        # both the shared object and the tuned assembly listing.
        cmd = jt.flags.python_path + " " + \
            jt.flags.jittor_path+"/utils/asm_tuner.py --cc_path=" + jt.flags.cc_path + " '" + self.src_path + "'" + " -DJIT -DJIT_cpu " + jt.flags.cc_flags + " -o '" + self.so_path + "'";
        self.run_cmd(cmd)

        s_path=self.so_path.replace(".so",".s")
        bo=False
        with open(s_path) as f:
            for line in f:
                if line.find("vmovnt")!=-1:
                    bo=True
                    break
        # The rewrite is only asserted for clang output; other compilers may
        # not emit a matching vmova in the first place.
        if check_movnt and jt.flags.cc_type == "clang":
            assert bo

    def test_asm_tuner(self):
        # With the @begin/@end markers intact the rewrite must happen;
        # with the markers destroyed it must not be required.
        self.check_cc(self.cc_content,True)
        self.check_cc(self.cc_content.replace("@begin","233").replace("@end","666"), False)
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
|
@ -0,0 +1,74 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
from jittor import nn, Module
|
||||
import numpy as np
|
||||
import sys, os
|
||||
import random
|
||||
import math
|
||||
import unittest
|
||||
from .test_reorder_tuner import simple_parser
|
||||
from .test_log import find_log_with_re
|
||||
|
||||
class testNet(Module):
    """Toy module: two reindex_reduce reductions of a NHWC input — one onto
    [N,H] and one onto [H,W] — broadcast back to [N,H,W] and summed."""

    def __init__(self, op):
        # Reduction name forwarded to reindex_reduce ("add"/"maximum"/"minimum").
        self.op = op

    def execute(self, x):
        n, h, w, c = x.shape
        nh = x.reindex_reduce(self.op, [n, h], ["i0", "i1"])
        hw = x.reindex_reduce(self.op, [h, w], ["i1", "i2"])
        # Broadcast both partial reductions to [N, H, W] and combine.
        return nh.broadcast([n, h, w], [2]) + hw.broadcast([n, h, w], [0])
|
||||
|
||||
class TestAtomicTunerClass(unittest.TestCase):
    """Checks the atomic-tuner pass: results must match the non-atomic path
    and the expected 'move ... to loop ...' tuner logs must be emitted."""

    @classmethod
    def setUpClass(self):
        # One toy model per supported reduce op.
        self.addNet = testNet("add")
        self.maxNet = testNet("maximum")
        self.minNet = testNet("minimum")

    def check(self, model, std_log):
        """Run `model` with and without the atomic pass; compare outputs and
        verify the captured tuner logs equal `std_log` as a multiset."""
        x = jt.random([100, 64, 128, 128])
        with jt.log_capture_scope(
            log_silent=1,
            log_v=0, log_vprefix="atomic_tuner_pass.cc=100",
        ) as logs:
            y = model(x).numpy()
        with jt.log_capture_scope(
            log_v=0,
            exclude_pass="atomic",
        ) as logs2:
            y_std = model(x).numpy()

        # Use the absolute deviation: the original np.max(y_std-y) silently
        # passed when y overshot y_std (difference negative).
        err = np.max(np.abs(y_std - y)) / (np.mean(y_std) + 1e-6)
        assert err < 1e-5
        log_move = find_log_with_re(logs, "atomictuner: move .* to loop .*")
        assert len(log_move) == len(std_log), (len(log_move), len(std_log))
        # Order-insensitive comparison; equivalent to the original
        # mark-as-matched loop but without mutating the caller's list.
        assert sorted(log_move) == sorted(std_log), (log_move, std_log)

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_atomic_tuner(self):
        self.check(self.addNet, ['atomictuner: move atomicAdd to loop 1', 'atomictuner: move atomicAdd to loop 2'])
        self.check(self.maxNet, ['atomictuner: move cuda_atomic_max to loop 1', 'atomictuner: move cuda_atomic_max to loop 2'])
        self.check(self.minNet, ['atomictuner: move cuda_atomic_min to loop 1', 'atomictuner: move cuda_atomic_min to loop 2'])
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
|
@ -0,0 +1,132 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from .test_core import expect_error
|
||||
from .test_grad import ngrad
|
||||
from .test_cuda import test_cuda
|
||||
|
||||
def all_eq(x, y):
    """Strict array equality: same dtype, same shape, all elements equal.

    0-d inputs are promoted to 1-element 1-d arrays, and bool arrays are
    compared as uint8 (so jittor's bool results match numpy's).
    """
    def normalize(a):
        if len(a.shape) == 0:
            a = np.array([a])
        return a.astype("uint8") if a.dtype == "bool" else a

    x, y = normalize(x), normalize(y)
    if x.dtype != y.dtype or x.shape != y.shape:
        return False
    return (x == y).all()
|
||||
|
||||
def check(op, *args):
    """Apply the same-named numpy and jittor op to `args` and require strict
    equality (dtype, shape, values) via all_eq."""
    # getattr dispatch instead of eval(f"np.{op}(*args)"): identical behavior
    # for valid op names, without evaluating `op` as arbitrary code.
    x = getattr(np, op)(*args)
    y = getattr(jt, op)(*args).data
    assert all_eq(x, y), f"{x}\n{y}"
|
||||
|
||||
class TestBinaryOp(unittest.TestCase):
    """Tests jittor binary ops against numpy/python semantics: the jt.binary
    dispatcher, in-place operators, reflected operators, and gradients."""

    def test_binary_op(self):
        """jt.binary dispatch plus logical/bitwise ops compared to numpy."""
        assert np.all(jt.binary(1,2,'maximum').data == 2)
        assert np.all(jt.binary([[1,2]],[[3,4]],'add').data == [[4,6]])
        assert np.all(jt.less(1,2).data)
        # Comparison ops must produce bool dtype, not int.
        assert jt.less(1,2).data.dtype == "bool"
        x = (jt.array(1) << jt.array(3)).data
        assert (x == 8).all()
        x = (jt.array(2) ** jt.array(3)).data
        assert (x == 8).all()
        a = [1,2,3]
        b = [7,10,13]
        check("logical_and", a, b)
        check("logical_or", a, b)
        check("logical_xor", a, b)
        check("bitwise_and", a, b)
        check("bitwise_or", a, b)
        check("bitwise_xor", a, b)

    def test_i(self):
        """In-place operators (`+=`, `*=`, ..., `@=`) vs python/numpy results."""
        def check(op, a, b):
            # matmul in-place is not supported on CUDA here — skip.
            if jt.flags.use_cuda and op == "@":
                return
            if op=="@":
                a = np.float32(a)
                b = np.float32(b)
            ja = jt.array(a)
            jb = jt.array(b)
            # NOTE(review): exec of "ja op= jb" relies on the in-place op
            # mutating the jt Var (exec cannot rebind the local) — confirm.
            exec(f"ja {op}= jb")
            ja = ja.fetch_sync()

            if op == "@":
                # numpy do not support @=
                a = np.array(a) @ np.array(b)
            else:
                a = eval(f"a {op} b")
            # Compare in float32 to sidestep int/float dtype differences.
            a = np.float32(a)
            ja = np.float32(ja)

            assert all_eq(ja, a), (ja,a)
        check("+", 5, 2)
        check("-", 5, 2)
        check("*", 5, 2)
        check("/", 5, 2)
        check("//", 5, 2)
        check("@", [[5]], [[2]])
        check("%", 5, 2)
        check("**", 5, 2)
        check("<<", 5, 2)
        check(">>", 5, 2)
        check("&", 5, 2)
        check("^", 5, 2)
        check("|", 5, 2)

    def test_r(self):
        """Reflected operators: python scalar on the left, jt array on the right."""
        def check(op, a, b):
            if jt.flags.use_cuda and op == "@":
                return
            jb = jt.array(b)
            # `a op jb` exercises jt's __r<op>__ implementations.
            jc = eval(f"a {op} jb").data


            if op == "@":
                # numpy do not support @=
                a = np.array(a) @ np.array(b)
            else:
                a = eval(f"a {op} b")
                a = np.array(a)

            assert all_eq(jc, a), f"\n{jc}\n{a}"
        check("+", 5, 2)
        check("-", 5, 2)
        check("*", 5, 2)
        check("/", 5, 2)
        check("//", 5, 2)
        # check("@", [[5]], [[2]])
        check("%", 5, 2)
        check("**", 5, 2)
        check("<<", 5, 2)
        check(">>", 5, 2)
        check("&", 5, 2)
        check("^", 5, 2)
        check("|", 5, 2)

    def test_grad(self):
        """Gradients of arithmetic ops vs numerical gradients (ngrad)."""
        ops = ["+", "-", "*", "/", "**"]
        np.random.seed(3)
        a = np.random.rand(10)
        b = np.random.rand(10)
        c = np.random.rand(10)
        for op in ops:
            # loss = sum((a op b) * c); compare jt.grad to finite differences.
            func = lambda x: eval(f"((x[0]{op}x[1])*x[2]).sum()")
            x, grads = ngrad(func, [a,b,c], 1e-8)
            ja = jt.array(a).name("ja")
            jb = jt.array(b).name("jb")
            jc = jt.array(c).name("jc")
            jx = eval(f"(ja{op}jb)*jc")
            jgrads = jt.grad(jx, [ja,jb,jc])
            for jd, nd in zip(jgrads, grads):
                assert (np.abs(jd.data-nd)<1e-4).all(), f"\n{jd.data}\n{nd}"
|
||||
|
||||
|
||||
# Re-run every TestBinaryOp case with CUDA enabled.
# NOTE(review): test_cuda(2) presumably returns a mixin that flips use_cuda —
# verify against test_cuda.py.
class TestBinaryOpCuda(TestBinaryOp, test_cuda(2)):
    pass
|
||||
|
||||
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue