forked from jittor/jittor
version 77593ddd55381fddacdfa637355784488523c5e2
This commit is contained in:
parent
1258121b1f
commit
e96f7ceee8
|
@ -104,8 +104,9 @@ Jittor使用Python和C++编写。 它需要用于即时编译的编译器。当
|
|||
|
||||
Jittor的环境要求如下:
|
||||
|
||||
* 操作系统: Ubuntu>=16.04
|
||||
* Python >= 3.7
|
||||
* 操作系统: Ubuntu >= 16.04
|
||||
* Python版本 >= 3.7
|
||||
* C++编译器(g++ 或 clang)
|
||||
|
||||
Jittor offers three ways to install: pip, script or manual.
|
||||
|
||||
|
@ -115,7 +116,7 @@ Jittor 一共提供三种方式安装: pip安装, 一键脚本安装 和 手动
|
|||
|
||||
## Pip install
|
||||
|
||||
如果您已经装好编译器和对应版本的Python,我们强烈推荐您使用这种方法
|
||||
如果您没有准备好环境,欢迎使用我们提供的一键安装脚本;如果您已经装好编译器和对应版本的Python,我们强烈推荐您使用这种方法
|
||||
(如果无法访问github, 可以通过jittor主页下载):
|
||||
|
||||
```bash
|
||||
|
@ -134,7 +135,7 @@ jittor会自动在路径中寻找合适的编译器, 如果您希望手动指定
|
|||
## 一键脚本安装
|
||||
## Single-line script install
|
||||
|
||||
一键脚本安装会帮您安装好所需的编译器.
|
||||
一键脚本安装会帮您安装好所需的编译器以及对应的Python版本.
|
||||
|
||||
We provide a single-line command to quickly install the latest version of Jittor (Ubuntu >= 16.04):
|
||||
|
||||
|
@ -142,13 +143,13 @@ We provide single line command for quick installation the latest version of Jitt
|
|||
|
||||
```bash
|
||||
# install with clang and cuda
|
||||
git clone https://github.com/Jittor/jittor.git && with_clang=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_clang=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
# install with clang
|
||||
git clone https://github.com/Jittor/jittor.git && with_clang=1 bash ./jittor/script/install.sh
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_clang=1 bash ./jittor/script/install.sh
|
||||
# install with g++ and cuda
|
||||
git clone https://github.com/Jittor/jittor.git && with_gcc=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_gcc=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
# install with g++
|
||||
git clone https://github.com/Jittor/jittor.git && with_gcc=1 bash ./jittor/script/install.sh
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_gcc=1 bash ./jittor/script/install.sh
|
||||
```
|
||||
After execution, the script will show some environment variables you need to export.
|
||||
|
||||
|
|
|
@ -166,7 +166,7 @@ def install_cutt(root_folder):
|
|||
filename = "cutt.tgz"
|
||||
fullname = os.path.join(root_folder, filename)
|
||||
dirname = os.path.join(root_folder, filename.replace(".tgz",""))
|
||||
true_md5 = "c79ad93b76544d598eb250ec749c492c"
|
||||
true_md5 = "28a67bb3a713e29ce434303df6577507"
|
||||
|
||||
if os.path.exists(fullname):
|
||||
md5 = os.popen('md5sum ' + fullname).read().split()[0]
|
||||
|
@ -178,6 +178,7 @@ def install_cutt(root_folder):
|
|||
if not os.path.isfile(os.path.join(dirname, "bin", "cutt_test")):
|
||||
LOG.i("Downloading cub...")
|
||||
download_url_to_local(url, filename, root_folder, true_md5)
|
||||
|
||||
import tarfile
|
||||
|
||||
with tarfile.open(fullname, "r") as tar:
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
from jittor.utils import pytorch_converter
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
@ -13,6 +12,7 @@ try:
|
|||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
from torch import nn
|
||||
from jittor.utils import pytorch_converter
|
||||
except:
|
||||
torch = None
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ run_cmd(f"git rev-parse HEAD > {polish_path}/python/jittor/version", jittor_path
|
|||
files = jt.compiler.files
|
||||
file_to_delete = [ name for name in files
|
||||
if name.startswith("src") and \
|
||||
len(name.split("/"))==2
|
||||
len(name.split("/"))==2 and name.endswith("node.cc")
|
||||
]
|
||||
LOG.i("file_to_delete", file_to_delete)
|
||||
run_cmd(f"rm {' '.join(file_to_delete)}", polish_path)
|
||||
|
@ -97,6 +97,7 @@ assert os.system(f"cd {polish_path} && tar -cvzf build/jittor.tgz . --exclude bu
|
|||
jittor_web_base_dir = "Documents/jittor-blog/assets/"
|
||||
jittor_web_build_dir = jittor_web_base_dir + "build/"
|
||||
assert os.system(f"rsync -avPu {polish_path}/build/ jittor-web:{jittor_web_build_dir}")==0
|
||||
assert os.system(f"ssh jittor@166.111.68.30 Documents/jittor-blog.git/hooks/post-update")==0
|
||||
|
||||
# push to github
|
||||
assert os.system(f"cd {polish_path} && git push -f origin master")==0
|
||||
|
|
|
@ -1 +1 @@
|
|||
1522f3d004f9bdbf3953d91d4c259c341817c71f
|
||||
77593ddd55381fddacdfa637355784488523c5e2
|
||||
|
|
|
@ -50,8 +50,10 @@ wget -O - https://bootstrap.pypa.io/get-pip.py | sudo -H python$py_version
|
|||
|
||||
# Step 3: Run jittor
|
||||
|
||||
sudo apt install git -y
|
||||
git clone https://github.com/Jittor/jittor.git
|
||||
if [ ! -d jittor ]; then
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz
|
||||
mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor
|
||||
fi
|
||||
|
||||
sudo python$py_version -m pip install ./jittor
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -28,5 +28,7 @@ setuptools.setup(
|
|||
"pybind11",
|
||||
"numpy",
|
||||
"tqdm",
|
||||
"pillow",
|
||||
"astunparse",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,36 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "event_queue.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
EventQueue event_queue;
|
||||
|
||||
void EventQueue::Worker::start() {
|
||||
Worker* self = &event_queue.worker;
|
||||
while (1) {
|
||||
Func todo;
|
||||
{
|
||||
std::unique_lock<std::mutex> l(self->mtx);
|
||||
event_queue.cv.notify_one();
|
||||
self->cv.wait(l);
|
||||
todo = self->todo;
|
||||
}
|
||||
if (!todo) break;
|
||||
todo();
|
||||
}
|
||||
}
|
||||
|
||||
void EventQueue::worker_caller() {
|
||||
event_queue.func();
|
||||
{
|
||||
std::lock_guard<std::mutex> l(event_queue.mtx);
|
||||
event_queue.run_sync_done = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,405 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
#include "fetcher.h"
|
||||
#include "event_queue.h"
|
||||
#endif
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "executor.h"
|
||||
#include "var.h"
|
||||
#include "op.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "graph.h"
|
||||
#include "fused_op.h"
|
||||
#include "fuser.h"
|
||||
#include "profiler/profiler_guard.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
Executor exe;
|
||||
|
||||
// Executes every unfinished op that the requested `vars` depend on.
//
// Phases, in source order:
//   1. backward BFS from `vars` collecting unfinished nodes (ops and vars);
//   2. count_fuse() fills the union-find array `father` (fusion groups) and
//      the per-var state `var_fused`;
//   3. "external" topological sort between fusion groups -> `queue`;
//   4. "internal" topological sort inside every group, including input ops
//      that are shared into the group -> `fuse_ops` / `range`;
//   5. run each group in `queue` order: through the local FusedOp unless the
//      group's root op has OpType::other, in which case the op runs directly.
//
// When built with HAS_CUDA and `device_sync` is true, a
// cudaDeviceSynchronize() is submitted through event_queue.run_sync at the end.
// Failures raised while running an op are caught and reported through LOGf.
void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||
auto allocator = get_allocator();
|
||||
this->allocator = allocator;
|
||||
// bfs find all ops need to run
|
||||
int op_num = 0;
|
||||
vector<Node*> bfs_q;
|
||||
bfs_q.reserve(vars.size());
|
||||
// Views the vector<Var*> as a vector<Node*> via a pointer cast so it can seed bfs_backward.
auto nodes = (vector<Node*>*)&vars;
|
||||
// Number of requested vars that are not finished yet; handed to count_fuse below.
int start_var_num = 0;
|
||||
for (Var* v : vars)
|
||||
if (!v->is_finished())
|
||||
start_var_num++;
|
||||
// The callback resets custom_data on each visited node, returns false for
// finished nodes and counts the non-var nodes (ops) it accepts.
bfs_backward(*nodes, bfs_q, [&](Node *node) -> bool {
|
||||
node->custom_data = 0;
|
||||
if (node->is_finished())
|
||||
return false;
|
||||
op_num += !node->is_var();
|
||||
return true;
|
||||
});
|
||||
// Snapshot of the traversal flag; below, nodes whose tflag equals tt are
// treated as part of this execution and all others as boundary.
auto tt = Node::tflag_count;
|
||||
vector<Op*> ops;
|
||||
vector<Var*> all_vars;
|
||||
ops.reserve(op_num);
|
||||
// From here on custom_data is the node's index into `ops` or `all_vars`.
for (Node* node : bfs_q)
|
||||
if (!node->is_var()) {
|
||||
node->custom_data = ops.size();
|
||||
ops.push_back(node->op());
|
||||
} else {
|
||||
// set can't fuse flag to false
|
||||
node->custom_data = all_vars.size();
|
||||
all_vars.push_back(node->var());
|
||||
}
|
||||
int var_num = all_vars.size();
|
||||
|
||||
// father: father of union-find set
|
||||
vector<int> father(op_num);
|
||||
for (int i=0; i<op_num; i++) {
|
||||
father[i] = i;
|
||||
}
|
||||
// union-find algorithm
|
||||
// find_fa returns the set representative and rewrites father[] along the
// walked path so those entries point directly at it.
auto find_fa = [&](int i) -> int {
|
||||
int j=i;
|
||||
while (father[j] != j) j = father[j];
|
||||
while (i != j) {
|
||||
int tmp = father[i];
|
||||
father[i] = j;
|
||||
i = tmp;
|
||||
}
|
||||
return j;
|
||||
};
|
||||
vector<int> var_fused(var_num);
|
||||
|
||||
// Verbose-only dump of every collected op together with its inputs.
if (V_ON(100)) {
|
||||
for (uint i=0; i<ops.size(); i++) {
|
||||
Op* op = ops[i];
|
||||
string st="others";
|
||||
if (op->type()==OpType::reduce) st="reduce";
|
||||
if (op->type()==OpType::broadcast) st="broadcast";
|
||||
if (op->type()==OpType::element) st="element";
|
||||
|
||||
LOGvvv << "id:" << ops[i]->custom_data << " type:" <<
|
||||
st << " addr:" << op;
|
||||
for (Var* v : op->inputs()) {
|
||||
Op* next_op = v->input();
|
||||
// continue if is boundary
|
||||
if (!next_op || next_op->tflag != tt) {
|
||||
LOGvvv << "input:" << v;
|
||||
continue;
|
||||
}
|
||||
LOGvvv << "input:" << next_op->custom_data << " addr:" << next_op;
|
||||
}
|
||||
LOGvvv << "";
|
||||
}
|
||||
}
|
||||
|
||||
count_fuse(tt, start_var_num, ops, all_vars, father, var_fused);
|
||||
// var_fused represents:
|
||||
// 0: can fused
|
||||
// 1: cannot fused
|
||||
// 2: can shared
|
||||
vector<int> roots, next(op_num, -1);
|
||||
vector<int> deps(op_num, 0);
|
||||
roots.reserve(op_num);
|
||||
// Chain the members of every union-find set into a singly linked list that
// starts at the set's root; next[] holds the successor and -1 ends the list.
for (int i=0; i<op_num; i++) {
|
||||
int fa = find_fa(i);
|
||||
if (fa == i)
|
||||
roots.push_back(i);
|
||||
else {
|
||||
next[i] = next[fa];
|
||||
next[fa] = i;
|
||||
}
|
||||
}
|
||||
vector<int> queue;
|
||||
queue.reserve(roots.size());
|
||||
|
||||
// ** toplogical_sort external **
|
||||
// output:
|
||||
// queue: toplogical order of fused op
|
||||
{
|
||||
// deps[root] counts the input edges of the group that come from ops of
// other groups; groups without such edges start the queue.
for (int root : roots) {
|
||||
for (int i=root; i>=0; i=next[i]) {
|
||||
Op* op = ops[i];
|
||||
for (Var* v : op->inputs()) {
|
||||
if (v->tflag != tt) continue;
|
||||
Op* opi = v->input();
|
||||
// if those two ops are not fused
|
||||
if (father[opi->custom_data] != root) {
|
||||
deps[root]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (deps[root] == 0)
|
||||
queue.push_back(root);
|
||||
}
|
||||
for (uint s=0; s<queue.size(); s++) {
|
||||
int op_id = queue[s];
|
||||
for (int i=op_id; i>=0; i=next[i]) {
|
||||
Op* op = ops[i];
|
||||
for (Var* v : op->outputs())
|
||||
if (v->tflag == tt)
|
||||
for (Op* op2 : v->outputs()) {
|
||||
if (op2->tflag != tt) continue;
|
||||
int op2_id = father[op2->custom_data];
|
||||
// continue if those two ops are fused
|
||||
if (op2_id == op_id) continue;
|
||||
deps[op2_id]--;
|
||||
if (deps[op2_id] == 0)
|
||||
queue.push_back(op2_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
ASSERTop(queue.size(),==,roots.size());
|
||||
}
|
||||
|
||||
// ** toplogical_sort internal **
|
||||
// output:
|
||||
// fuse_ops: fused op id [000|1111|22|3333]
|
||||
// range: split index ^ ^ ^ ^ ^
|
||||
vector<int> fuse_ops;
|
||||
fuse_ops.reserve(op_num*2);
|
||||
vector<int> range(queue.size());
|
||||
{
|
||||
vector<int> subgraph;
|
||||
subgraph.reserve(16);
|
||||
vector<int> sharegraph;
|
||||
sharegraph.reserve(16);
|
||||
vector<int> sharegraph_q;
|
||||
sharegraph_q.reserve(16);
|
||||
vector<int> shared_id(op_num, -1);
|
||||
|
||||
for (uint rid=0; rid<queue.size(); rid++) {
|
||||
// Groups are visited from the back of the external queue, so range[rid]
// belongs to the group at queue[queue.size()-rid-1].
int root = queue[queue.size()-rid-1];
|
||||
// NOTE: for the rest of this loop body `queue` names `subgraph` and
// shadows the outer topological queue.
auto& queue = subgraph;
|
||||
queue.clear();
|
||||
sharegraph.clear();
|
||||
int total=0;
|
||||
for (int i=root; i>=0; i=next[i], total++) {
|
||||
Op* op = ops[i];
|
||||
for (Var* v : op->inputs()) {
|
||||
if (v->tflag != tt) continue;
|
||||
Op* opi = v->input();
|
||||
// if those two ops are fused
|
||||
int opid = opi->custom_data;
|
||||
auto fopid = father[opid];
|
||||
if (fopid == root)
|
||||
deps[i]++;
|
||||
else if (shared_id[opid] != root) {
|
||||
// var_fused = 1 cannot share input op
|
||||
// TODO: check this input op's output var all can be shared
|
||||
if (var_fused[v->custom_data] == 1)
|
||||
continue;
|
||||
// new shared op
|
||||
deps[opid] = 0;
|
||||
shared_id[opid] = root;
|
||||
sharegraph.push_back(opid);
|
||||
}
|
||||
}
|
||||
if (deps[i] == 0)
|
||||
queue.push_back(i);
|
||||
}
|
||||
// find all share graph
|
||||
// sn remembers how many shared ops were found directly; the loop below may
// append further (transitively shared) ops to sharegraph.
uint sn = sharegraph.size();
|
||||
for (uint i=0; i<sharegraph.size(); i++) {
|
||||
int id = sharegraph[i];
|
||||
Op* op = ops[id];
|
||||
for (Var* v : op->inputs()) {
|
||||
if (v->tflag != tt) continue;
|
||||
int vi = v->custom_data;
|
||||
if (var_fused[vi] == 1)
|
||||
continue;
|
||||
Op* opi = v->input();
|
||||
int opid = opi->custom_data;
|
||||
int& dep = deps[opid];
|
||||
if (shared_id[opid] != root) {
|
||||
shared_id[opid] = root;
|
||||
dep = 1;
|
||||
sharegraph.push_back(opid);
|
||||
} else
|
||||
dep ++;
|
||||
}
|
||||
}
|
||||
sharegraph_q.clear();
|
||||
for (uint i=0; i<sn; i++)
|
||||
if (deps[sharegraph[i]]==0)
|
||||
sharegraph_q.push_back(sharegraph[i]);
|
||||
// topsort in sharegraph_q
|
||||
for (uint i=0; i<sharegraph_q.size(); i++) {
|
||||
int id = sharegraph_q[i];
|
||||
Op* op = ops[id];
|
||||
for (Var* v : op->inputs()) {
|
||||
if (v->tflag != tt) continue;
|
||||
int vi = v->custom_data;
|
||||
if (var_fused[vi] == 1)
|
||||
continue;
|
||||
Op* opi = v->input();
|
||||
int opid = opi->custom_data;
|
||||
int& dep = deps[opid];
|
||||
dep --;
|
||||
if (dep == 0)
|
||||
sharegraph_q.push_back(opid);
|
||||
}
|
||||
}
|
||||
LOGvvvv << "sharegraph_q" << sharegraph_q;
|
||||
ASSERTop(sharegraph.size(),==,sharegraph_q.size());
|
||||
// topsort fused op internal
|
||||
for (uint s=0; s<queue.size(); s++) {
|
||||
int i = queue[s];
|
||||
Op* op = ops[i];
|
||||
|
||||
for (Var* v : op->outputs())
|
||||
if (v->tflag == tt)
|
||||
for (Op* op2 : v->outputs()) {
|
||||
if (op2->tflag != tt) continue;
|
||||
int op2_id = op2->custom_data;
|
||||
// continue if those two ops are not fused
|
||||
if (father[op2_id] != root) continue;
|
||||
deps[op2_id]--;
|
||||
if (deps[op2_id] == 0)
|
||||
queue.push_back(op2_id);
|
||||
}
|
||||
}
|
||||
ASSERTop(queue.size(),==,(uint)total);
|
||||
LOGvvvv << "topsort internal" << queue;
|
||||
// Shared ops are appended in reverse of sharegraph_q, followed by the
// group's own ops in their internal topological order.
for (int i=(int)sharegraph_q.size()-1; i>=0; i--)
|
||||
fuse_ops.push_back(sharegraph_q[i]);
|
||||
for (uint i=0; i<queue.size(); i++)
|
||||
fuse_ops.push_back(queue[i]);
|
||||
range[rid] = fuse_ops.size();
|
||||
}
|
||||
}
|
||||
// custom_data of every var now becomes 1 when var_fused marks it as
// "cannot fuse" and 0 otherwise.
for (int i=0; i<var_num; i++) {
|
||||
all_vars[i]->custom_data = var_fused[i]==1;
|
||||
}
|
||||
|
||||
// running
|
||||
FusedOp fused_op;
|
||||
vector<Var*> outputs_bk;
|
||||
#ifdef HAS_CUDA
|
||||
int sync_times = 0;
|
||||
#endif
|
||||
for (uint rid=0; rid<queue.size(); rid++) {
|
||||
int root = queue[rid];
|
||||
Op* op = ops[root];
|
||||
bool is_fused_op = false;
|
||||
try {
|
||||
if (op->type() != OpType::other) {
|
||||
op = &fused_op;
|
||||
is_fused_op = true;
|
||||
fused_op.ops.clear();
|
||||
fused_op.edges.clear();
|
||||
// range[] was filled in reverse group order, so this group's slice of
// fuse_ops is [range[n-rid-2], range[n-rid-1]) with n = queue.size(),
// and starts at 0 for the last group in queue.
int ll = (rid<queue.size()-1)?range[queue.size()-rid-2]:0, rr = range[queue.size()-rid-1];
|
||||
root = fuse_ops[rr-1];
|
||||
// A fresh flag value marks exactly the ops placed into this fused op.
auto ntt = ++Node::tflag_count;
|
||||
for (int i=ll; i<rr; i++) {
|
||||
int opid = fuse_ops[i];
|
||||
Op* op = ops[opid];
|
||||
uint64_t fid1 = fused_op.ops.size();
|
||||
op->custom_data = fid1;
|
||||
op->tflag = ntt;
|
||||
fused_op.ops.push_back(op);
|
||||
}
|
||||
// Record an edge (producer op, output index, consumer op, input index) for
// every connection between two ops of this fused op.
for (Op* op : fused_op.ops) {
|
||||
uint fid1 = op->custom_data;
|
||||
uint oid = 0;
|
||||
for (Var* v : op->outputs()) {
|
||||
oid++;
|
||||
if (v->tflag != tt) {
|
||||
// this var node not belong to current execution
|
||||
// this will happend in multiple outputs fuseable op
|
||||
v->custom_data = 0;
|
||||
continue;
|
||||
}
|
||||
for (auto o : v->outputs_with_index()) {
|
||||
Op* op2 = o.op;
|
||||
uint iid = o.index;
|
||||
if (op2->tflag != ntt) continue;
|
||||
uint fid2 = op2->custom_data;
|
||||
fused_op.edges.emplace_back(fid1, oid-1, fid2, iid);
|
||||
}
|
||||
}
|
||||
}
|
||||
LOGvvv << "Prepare fused_op" << fused_op.ops;
|
||||
fused_op.update_ops();
|
||||
}
|
||||
LOGvvv << "Run" << op;
|
||||
if (!op->shape_infered()) op->infer_shape();
|
||||
ASSERT(op->shape_infered()) << "Shape of(" >> op->name() >> ") not solved.";
|
||||
for (auto* var : op->outputs())
|
||||
var->alloc(allocator);
|
||||
LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
|
||||
op->do_prepare();
|
||||
bool is_cuda = op->flags.get(NodeFlags::_cuda);
|
||||
#ifdef HAS_CUDA
|
||||
if (!is_cuda) {
|
||||
if (last_is_cuda) {
|
||||
// if prev op in gpu and this op in cpu
|
||||
// cuda sync
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
sync_times++;
|
||||
}
|
||||
for (Var* v : op->inputs()) {
|
||||
migrate_to_cpu(v, allocator);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef NODE_MEMCHECK
|
||||
if (is_fused_op) {
|
||||
for (auto& vi : fused_op.vars)
|
||||
if (vi.type == 0)
|
||||
ASSERT(vi.var->mem_ptr) << vi.var;
|
||||
} else {
|
||||
for (auto* v : op->inputs())
|
||||
ASSERT(v->mem_ptr) << v;
|
||||
}
|
||||
#endif
|
||||
last_is_cuda = is_cuda;
|
||||
op->do_run_after_prepare();
|
||||
LOGvvv << "Finished Op(" >> op->name() << rid >>
|
||||
"/" >> queue.size() >> ") output:" << op->outputs();
|
||||
if (is_fused_op) {
|
||||
for (Var* var : op->outputs())
|
||||
var->finish_pending_liveness();
|
||||
continue;
|
||||
}
|
||||
// release liveness when op is finished
|
||||
// outputs may change during free, we need to backup it;
|
||||
outputs_bk.clear();
|
||||
for (Var* var : op->outputs())
|
||||
outputs_bk.push_back(var);
|
||||
op->finish_pending_liveness();
|
||||
for (Var* var : outputs_bk)
|
||||
// var->finish_pending_liveness();
|
||||
var->finish_pending_liveness();
|
||||
} catch (const std::exception& e) {
|
||||
// Report the failing (fused) operator together with the exception text.
if (is_fused_op) {
|
||||
LOGf << "Execute fused operator(" >> rid >> '/' >> queue.size() >> ")"
|
||||
<< "failed:" << fused_op.ops << "\n\nReason: " >> e.what();
|
||||
} else
|
||||
LOGf << "Execute operator(" >> rid >> '/' >> queue.size() >> ")"
|
||||
<< "failed:" << op << "\n\nReason: " >> e.what();
|
||||
}
|
||||
}
|
||||
LOGvv << "All" << op_num << "ops finished, return vars:" << vars;
|
||||
for (Var* v : vars) ASSERT(v->mem_ptr);
|
||||
#ifdef HAS_CUDA
|
||||
if (device_sync) {
|
||||
last_is_cuda = false;
|
||||
sync_times++;
|
||||
// The final device synchronisation is submitted through
// event_queue.run_sync rather than called inline.
event_queue.run_sync([]() {
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
});
|
||||
}
|
||||
LOGvv << "cudaDeviceSynchronize times:" << sync_times << "/" <<queue.size();
|
||||
#endif
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,108 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include <mutex>
|
||||
#include "mem/allocator/sfrl_allocator.h"
|
||||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
#include "event_queue.h"
|
||||
#endif
|
||||
#include "fetcher.h"
|
||||
#include "mem/allocator.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifdef HAS_CUDA
|
||||
|
||||
#pragma GCC visibility push(hidden)
|
||||
namespace fetcher_local {
|
||||
|
||||
cudaStream_t stream;
|
||||
cudaEvent_t event;
|
||||
|
||||
volatile int64 n_to_fetch;
|
||||
std::mutex m;
|
||||
list<FetchResult> fetch_tasks;
|
||||
|
||||
static void fetch_caller() {
|
||||
fetch_tasks.front().call();
|
||||
fetch_tasks.pop_front();
|
||||
}
|
||||
|
||||
static void to_fetch(void*) {
|
||||
event_queue.push(fetch_caller);
|
||||
}
|
||||
|
||||
struct Init {
|
||||
Init() {
|
||||
checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
|
||||
}
|
||||
~Init() {
|
||||
// do not call deleter on exit
|
||||
for (auto& f : fetch_tasks)
|
||||
f.func.deleter = nullptr;
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
checkCudaErrors(cudaStreamDestroy(stream));
|
||||
checkCudaErrors(cudaEventDestroy(event));
|
||||
}
|
||||
} init;
|
||||
|
||||
}
|
||||
using namespace fetcher_local;
|
||||
|
||||
#endif
|
||||
|
||||
void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
|
||||
sync(vh);
|
||||
vector<Allocation> allocations(vh.size());
|
||||
vector<ArrayArgs> arrays(vh.size());
|
||||
#ifdef HAS_CUDA
|
||||
bool has_cuda_memcpy = false;
|
||||
event_queue.flush();
|
||||
#endif
|
||||
for (int i=0; i<vh.size(); i++) {
|
||||
auto v = vh[i]->var;
|
||||
auto& allocation = allocations[i];
|
||||
#ifdef HAS_CUDA
|
||||
if (v->allocator->is_cuda()) {
|
||||
checkCudaErrors(cudaEventRecord(event, 0));
|
||||
checkCudaErrors(cudaStreamWaitEvent(stream, event, 0));
|
||||
new (&allocation) Allocation(&cuda_dual_allocator, v->size);
|
||||
// mostly device to device
|
||||
checkCudaErrors(cudaMemcpyAsync(
|
||||
allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDefault, stream));
|
||||
auto host_ptr = cuda_dual_allocator.get_dual_allocation(
|
||||
allocation.allocation).host_ptr;
|
||||
// device to host
|
||||
checkCudaErrors(cudaMemcpyAsync(
|
||||
host_ptr, allocation.ptr, v->size, cudaMemcpyDefault, stream));
|
||||
allocation.ptr = host_ptr;
|
||||
has_cuda_memcpy = true;
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
new (&allocation) Allocation(cpu_allocator, v->size);
|
||||
std::memcpy(allocation.ptr, v->mem_ptr, v->size);
|
||||
}
|
||||
arrays[i].ptr = allocation.ptr;
|
||||
arrays[i].shape = v->shape;
|
||||
arrays[i].dtype = v->dtype();
|
||||
}
|
||||
#ifdef HAS_CUDA
|
||||
if (has_cuda_memcpy) {
|
||||
fetch_tasks.push_back({move(func), move(allocations), move(arrays)});
|
||||
checkCudaErrors(cudaLaunchHostFunc(stream, &to_fetch, 0));
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
FetchResult fr{move(func), move(allocations), move(arrays)};
|
||||
fr.call();
|
||||
}
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,242 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "fused_op.h"
|
||||
#include "var.h"
|
||||
#include "op_compiler.h"
|
||||
#include "profiler/profiler.h"
|
||||
#include "misc/fast_shared_ptr.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
string_view_map<FusedOpContext*> jit_fused_ops;
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const VarInfo& vi) {
|
||||
return os << vi.var << " type:" << vi.type;
|
||||
}
|
||||
|
||||
int FusedOp::get_loop_option(const string& key, const int& _default) {
|
||||
auto iter = loop_options->find(key);
|
||||
return iter == loop_options->end() ? _default : iter->second;
|
||||
}
|
||||
|
||||
// Resets loop_options_tuned to a copy of the original options, makes it the
// active option set and returns it.
loop_options_t& FusedOp::get_loop_options_tuned() {
    auto& tuned = loop_options_tuned;
    tuned = *loop_options_origin;
    loop_options = &tuned;
    return tuned;
}
|
||||
|
||||
void FusedOp::update_jit_key() {
|
||||
jk.clear();
|
||||
do_jit_prepare();
|
||||
}
|
||||
|
||||
void FusedOp::update_ops() {
|
||||
loop_options_merged.clear();
|
||||
loop_options_tuned.clear();
|
||||
loop_options = loop_options_origin = nullptr;
|
||||
|
||||
_outputs.clear();
|
||||
jk.clear();
|
||||
for (Op* op : ops) {
|
||||
for (Var* o : op->outputs()) {
|
||||
if (o->loop_options) {
|
||||
if (loop_options_origin == nullptr)
|
||||
loop_options_origin = &o->loop_options.data();
|
||||
else if (loop_options_origin != &o->loop_options.data()) {
|
||||
// merge loop options
|
||||
for (auto& kv : o->loop_options.data())
|
||||
loop_options_merged[kv.first] = kv.second;
|
||||
}
|
||||
}
|
||||
// bit0 represents can fuse or not
|
||||
if (o->custom_data&1)
|
||||
// this var can not fuse
|
||||
_outputs.emplace_back((Node*)o, 0);
|
||||
}
|
||||
}
|
||||
|
||||
if (loop_options_origin) {
|
||||
if (loop_options_merged.size()) {
|
||||
// merge loop_options_origin into loop_options_merged
|
||||
for (auto& kv : *loop_options_origin)
|
||||
loop_options_merged.emplace(kv);
|
||||
}
|
||||
} else {
|
||||
loop_options_origin = &loop_options_merged;
|
||||
}
|
||||
loop_options = loop_options_origin;
|
||||
|
||||
ASSERT(outputs().size());
|
||||
LOGvvvv << "set fused output" << outputs();
|
||||
|
||||
// var.custom_data
|
||||
// meaning of custom_data&1(input): 1: cannot fuse, 0 can fuse
|
||||
// meaning of custom_data&2: visited or not
|
||||
// meaning of custom_data>>2: index of vars
|
||||
|
||||
// op.custom_data: opid
|
||||
for (uint i=0; i<ops.size(); i++) {
|
||||
auto opi = ops[i];
|
||||
opi->custom_data = i;
|
||||
for (Var* i : opi->inputs()) {
|
||||
i->custom_data &= 1;
|
||||
}
|
||||
for (Var* o : opi->outputs()) {
|
||||
o->custom_data &= 1;
|
||||
}
|
||||
}
|
||||
vars.clear();
|
||||
for (Op* opi : ops) {
|
||||
for (Var* i : opi->inputs()) {
|
||||
auto &c = i->custom_data;
|
||||
// if not visited
|
||||
if (!(c&2)) {
|
||||
c += 2 + vars.size()*4;
|
||||
vars.push_back({i, 0});
|
||||
}
|
||||
}
|
||||
for (Var* o : opi->outputs()) {
|
||||
auto &c = o->custom_data;
|
||||
// if not visited
|
||||
if (!(c&2)) {
|
||||
c += 2 + vars.size()*4;
|
||||
// intermediate(can fuse) or output
|
||||
vars.push_back({o, int((c&1)+1)});
|
||||
}
|
||||
}
|
||||
}
|
||||
LOGvvvv << "Var info" << vars;
|
||||
}
|
||||
|
||||
|
||||
// Decrements Op::number_of_lived_ops on construction.
// NOTE(review): presumably this cancels the increment done by the Op base
// constructor so that a FusedOp is not counted as a live op — confirm.
FusedOp::FusedOp() {
    Op::number_of_lived_ops -= 1;
}
|
||||
|
||||
// Clears the output list and increments Op::number_of_lived_ops, mirroring
// the decrement performed in the constructor.
FusedOp::~FusedOp() {
    _outputs.clear();
    Op::number_of_lived_ops += 1;
}
|
||||
|
||||
void FusedOp::infer_shape() {
|
||||
for (uint i=0; i<ops.size(); i++)
|
||||
ops[i]->infer_shape();
|
||||
}
|
||||
|
||||
bool FusedOp::shape_infered() {
|
||||
for (uint i=0; i<ops.size(); i++)
|
||||
if (!ops[i]->shape_infered())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
|
||||
in = out = compute = 0;
|
||||
for (auto& vi : vars) {
|
||||
compute = std::max(compute, (uint64_t)vi.var->num);
|
||||
if (vi.type == 0) in += vi.var->size;
|
||||
if (vi.type == 2) out += vi.var->size;
|
||||
}
|
||||
}
|
||||
|
||||
void FusedOp::do_jit_prepare() {
|
||||
jk.clear();
|
||||
int8 flags = 3;
|
||||
for (uint i=0; i<ops.size(); i++) {
|
||||
Op* op = ops[i];
|
||||
jk << JK::key << "opkey" << i << JK::val;
|
||||
op->do_jit_prepare();
|
||||
jk << JK::end;
|
||||
if (op->flags.get(NodeFlags::_cpu))
|
||||
flags &= 1; // only cpu
|
||||
else
|
||||
flags &= 2; // only gpu
|
||||
}
|
||||
ASSERT(flags) << "FusedOp cannot contain both cpu and cuda ops.";
|
||||
add_jit_define("JIT", "1");
|
||||
if (flags==1) {
|
||||
// only cpu
|
||||
add_jit_define("JIT_cpu", "1");
|
||||
this->flags.set(NodeFlags::_cuda, 0);
|
||||
this->flags.set(NodeFlags::_cpu, 1);
|
||||
} else {
|
||||
add_jit_define("JIT_cuda", "1");
|
||||
this->flags.set(NodeFlags::_cpu, 0);
|
||||
this->flags.set(NodeFlags::_cuda, 1);
|
||||
}
|
||||
jk << JK::key << "graph" << JK::val;
|
||||
for (auto& t : edges) {
|
||||
uint i,j,k,l;
|
||||
std::tie(i,j,k,l) = t;
|
||||
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
|
||||
}
|
||||
jk << JK::end << JK::key << "var_info" << JK::val;
|
||||
for (auto& vi : vars)
|
||||
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
|
||||
jk << JK::end;
|
||||
if (loop_options->size()) {
|
||||
if (get_loop_option("compile_shapes")) {
|
||||
jk << JK::key << "shapes" << JK::val;
|
||||
for (auto& vi : vars) {
|
||||
jk << '[';
|
||||
for (auto a : vi.var->shape)
|
||||
jk << a << ',';
|
||||
jk << "],";
|
||||
}
|
||||
jk << JK::end;
|
||||
}
|
||||
jk << JK::key << "choices" << JK::val;
|
||||
for (auto& kv : *loop_options)
|
||||
jk << kv.first << ':' << kv.second << ',';
|
||||
jk << JK::end;
|
||||
}
|
||||
jk.finilize();
|
||||
}
|
||||
|
||||
void FusedOp::do_prepare() {
|
||||
do_jit_prepare();
|
||||
}
|
||||
|
||||
void FusedOp::do_run_after_prepare() {
|
||||
const char* jit_key = jk.to_cstring();
|
||||
auto iter = jit_fused_ops.find(string_view(jit_key, jk.size));
|
||||
if (iter != jit_fused_ops.end()) {
|
||||
LOGvvv << "Jit fused op key found:" << jit_key << "jit op entry:" << (void*)iter->second;
|
||||
context = iter->second;
|
||||
iter->second->vrm.fop = this;
|
||||
Profiler::record_and_run(iter->second->entry, this, jit_key);
|
||||
return;
|
||||
}
|
||||
LOGvv << "Jit op key not found:" << jit_key;
|
||||
// compile JIT op
|
||||
context = new FusedOpContext();
|
||||
context->vrm.fop = this;
|
||||
string prev_jit_key = jit_key;
|
||||
context->entry = OpCompiler::do_compile(this);
|
||||
string new_jit_key = get_jit_key();
|
||||
jit_fused_ops[new_jit_key] = jit_fused_ops[prev_jit_key] = context;
|
||||
jit_key_mapper[prev_jit_key] = new_jit_key;
|
||||
LOGvv << "Get jit op entry:" << (void*)(context->entry);
|
||||
Profiler::record_and_run(context->entry, this, new_jit_key.c_str());
|
||||
}
|
||||
|
||||
void FusedOp::do_run(){
|
||||
do_prepare();
|
||||
do_run_after_prepare();
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
void FusedOp::jit_run() {
|
||||
for (uint i=0; i<ops.size(); i++) {
|
||||
LOGvvvv << "fuse run:" << ops[i] << ops[i]->inputs() << ops[i]->outputs();
|
||||
ops[i]->do_run();
|
||||
}
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include "fuser.h"
|
||||
#include "var.h"
|
||||
#include "op.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "graph.h"
|
||||
#include "fused_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#define PREVENT_LARGE_FUSED_OP 16
|
||||
|
||||
void count_fuse(int64_t tt, int start_var_num, const vector<Op*>& ops, const vector<Var*>& vars, vector<int> &father, vector<int> &var_fused) {
|
||||
vector<int> dis(ops.size(), -1);
|
||||
|
||||
auto find_fa = [&](int i) -> int {
|
||||
int j=i;
|
||||
while (father[j] != j) j = father[j];
|
||||
while (i != j) {
|
||||
int tmp = father[i];
|
||||
father[i] = j;
|
||||
i = tmp;
|
||||
}
|
||||
return j;
|
||||
};
|
||||
|
||||
auto can_fuse = [&](Var* v, Op* op1, Op* op2, int fuse_type) -> bool {
|
||||
if (v->flags.get(NodeFlags::_stop_fuse))
|
||||
return false;
|
||||
if (fuse_type == 1) {
|
||||
// if v is output, do not fuse
|
||||
if (v->custom_data < start_var_num)
|
||||
return false;
|
||||
// op2 ---> v ---> op1
|
||||
if (op1->type() == OpType::other || op2->type() == OpType::other)
|
||||
return false;
|
||||
if (v->flags.get(NodeFlags::_force_fuse))
|
||||
return true;
|
||||
// Do not fuse op after reduce(has reduce)
|
||||
// TODO: better fuse strategy
|
||||
if (op2->type() == OpType::reduce)
|
||||
return false;
|
||||
// Do not fuse op before broadcast
|
||||
// TODO: better fuse strategy
|
||||
if (op1->type() == OpType::broadcast)
|
||||
return false;
|
||||
return op2->type() == OpType::element ||
|
||||
op2->type() == OpType::broadcast;
|
||||
} else if (fuse_type == 0) {
|
||||
#ifdef PREVENT_LARGE_FUSED_OP
|
||||
// This statement prevent fuse large ops
|
||||
if (v->outputs().size()>=PREVENT_LARGE_FUSED_OP) return false;
|
||||
#endif
|
||||
|
||||
// v ---> op1
|
||||
// |
|
||||
// +----> op2 ( prev of op1 )
|
||||
if (op1->type() == OpType::other || op2->type() == OpType::other)
|
||||
return false;
|
||||
// Do not fuse op after reduce(has reduce)
|
||||
// TODO: better fuse strategy
|
||||
if (op2->type() == OpType::broadcast || op1->type() == OpType::broadcast)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
auto for_each_edge = [&](Op* op, int forward, auto&& func){
|
||||
auto e=op->_inputs.begin();
|
||||
for (Var* v : op->inputs()) {
|
||||
if ((forward && (*e).back!=std::prev(v->_outputs.end())) ||
|
||||
(!forward && (*e).back!=v->_outputs.begin())){
|
||||
Op* next_op = forward ? std::next((*e).back)->node->op() : std::prev((*e).back)->node->op();
|
||||
if (next_op && next_op->tflag==tt
|
||||
&& next_op->custom_data != op->custom_data
|
||||
&& can_fuse(v, next_op, op, 0))
|
||||
func(v, next_op, 0);
|
||||
}
|
||||
e = std::next(e);
|
||||
}
|
||||
|
||||
if (forward) {
|
||||
for (Var* sv : op->outputs())
|
||||
if (sv && sv->tflag == tt)
|
||||
for (Op* next_op: sv->outputs())
|
||||
if (next_op && next_op->tflag==tt) func(sv, next_op, 1);
|
||||
} else {
|
||||
for (Var* sv : op->inputs())
|
||||
if (sv && sv->tflag == tt) func(sv, sv->input(), 1);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
vector<int> queue;
|
||||
vector<int> deps;
|
||||
deps.reserve(ops.size());
|
||||
queue.reserve(ops.size());
|
||||
for (uint i=0; i<ops.size(); i++) {
|
||||
deps.push_back(0);
|
||||
Op* op = ops[i];
|
||||
|
||||
for_each_edge(op, 1, [&](Var* v, Op* next_op, int real_edge) {
|
||||
deps[i]++;
|
||||
});
|
||||
|
||||
if (!deps[i]) {
|
||||
queue.push_back(i);
|
||||
dis[i]=0;
|
||||
}
|
||||
}
|
||||
|
||||
uint head=0;
|
||||
while (head<queue.size()) {
|
||||
int op_id=queue[head++];
|
||||
Op* op = ops[op_id];
|
||||
|
||||
for_each_edge(op, 1, [&](Var* v, Op* next_op, int real_edge) {
|
||||
int next_id = next_op->custom_data;
|
||||
if (dis[next_id] == dis[op_id]){
|
||||
int next_fa = find_fa(next_id);
|
||||
father[next_fa] = op_id;
|
||||
}
|
||||
});
|
||||
|
||||
for_each_edge(op, 0, [&](Var* v, Op* next_op, int real_edge) {
|
||||
int next_id = next_op->custom_data;
|
||||
int lon=0;
|
||||
if (real_edge && !can_fuse(v, op, next_op, 1)) lon=1;
|
||||
if (dis[op_id]+lon>dis[next_id])
|
||||
dis[next_id]=dis[op_id]+lon;
|
||||
if (!--deps[next_id]) queue.push_back(next_id);
|
||||
});
|
||||
}
|
||||
|
||||
if (V_ON(1000)) {
|
||||
for (uint i=0; i<ops.size(); i++)
|
||||
LOGvvvv << ops[i] << dis[i] << deps[i];
|
||||
}
|
||||
|
||||
for (uint i=0; i<vars.size(); i++) {
|
||||
Var* v = vars[i];
|
||||
if (!v || v->tflag!=tt) {
|
||||
var_fused[i]=1;
|
||||
continue;
|
||||
}
|
||||
// sf: input op's father id
|
||||
int sf = -1;
|
||||
// vf: is input op can be fused with all output op
|
||||
int vf = 1;
|
||||
// all outputs are reduce
|
||||
int all_reduce = 1;
|
||||
Op* iop = v->input();
|
||||
// if (iop && iop->tflag==tt)
|
||||
sf = find_fa(iop->custom_data);
|
||||
|
||||
for (Op* sop : v->outputs())
|
||||
if (sop->tflag==tt) {
|
||||
if (vf && !can_fuse(v,sop,iop,1))
|
||||
vf = 0;
|
||||
if (sop->type()!=OpType::reduce)
|
||||
all_reduce = 0;
|
||||
// in two different fused op
|
||||
if (find_fa(sop->custom_data)!=sf) {
|
||||
var_fused[i]=1;
|
||||
}
|
||||
}
|
||||
if (vf==0) var_fused[i]=1;
|
||||
if (var_fused[i] && vf &&
|
||||
(iop->type()==OpType::broadcast || all_reduce || v->flags.get(NodeFlags::_force_fuse)))
|
||||
var_fused[i]=2;
|
||||
}
|
||||
// output vars can not be fused
|
||||
for (int i=0; i<start_var_num; i++)
|
||||
var_fused[i] = 1;
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,134 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "pybind/py_var_tracer.h"
|
||||
#include "grad.h"
|
||||
#include "var.h"
|
||||
#include "op.h"
|
||||
#include "graph.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Threshold used by grad(): once this many partial gradients have been
// added into one accumulated gradient, the sum is flagged _stop_fuse.
#define PREVENT_LARGE_FUSED_OP 16
|
||||
|
||||
static auto make_binary = get_op_info("binary")
|
||||
.get_constructor<VarPtr, Var*, Var*, NanoString>();
|
||||
static auto make_number = get_op_info("number")
|
||||
.get_constructor<VarPtr, float, Var*>();
|
||||
|
||||
|
||||
// Asks `op` for the gradient of its output `out` with respect to its
// input `x` (located at position `x_index`), given the upstream gradient
// `dout`. Returns nullptr right away when there is no upstream gradient.
VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
    if (!dout)
        return nullptr;
    LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs()
        << "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index;
    VarPtr dx = op->grad(out, dout, x, x_index);
    return dx;
}
|
||||
|
||||
// Propagates the _stop_fuse flag from `b` to `a`; nothing else is copied.
inline static void assign_attrs(Var* a, Var* b) {
    if (!b->flags.get(NodeFlags::_stop_fuse))
        return;
    a->flags.set(NodeFlags::_stop_fuse);
}
|
||||
|
||||
// Builds the gradient graph of `loss` with respect to every var in
// `targets` and returns one gradient var per target, in the same order.
// Targets that receive no gradient get an explicit zero created with
// make_number(0.f, target). Both loss and targets must be float vars
// (enforced with CHECK).
vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
    LOGvv << "loss:" >> loss << "targets:" >> targets;
    CHECK(loss->is_float()) << "Loss should be float";
    for (Var* var : targets)
        CHECK(var->is_float()) << "Targets of grad should be float";
    // successors of targets
    vector<Node*> ts(targets.begin(), targets.end());
    // bfs visit find all successors of targets
    LOGvv << "Size of successors:" << ts.size();
    bfs_forward(ts, [](Node*){ return true; });
    vector<Node*> gnodes;
    gnodes.reserve(ts.size());
    auto nt = Node::tflag_count;
    // the loss only takes part when the forward search reached it
    if (loss->tflag == nt)
        gnodes.push_back(loss);
    bfs_backward(gnodes, [&](Node* node) {
        if (node->tflag != nt)
            return false;
        if (node->is_stop_grad())
            return false;
        // int value has zero grad
        if (node->is_var())
            return node->var()->is_float();
        return true;
    });
    LOGvv << "Size of grad nodes:" << gnodes.size();

    vector<Node*> sorted;
    toplogical_sort_backward(gnodes, sorted, [](Node*){});
    // re-read the counter after the sort; the membership tests below
    // (tflag == nt) are done against this value
    nt = Node::tflag_count;
    vector<Var*> gvars;
    gvars.reserve(sorted.size());
    for (Node* node : sorted)
        if (node->is_var())
            gvars.push_back(node->var());
    LOGvv << "Size of grad vars:" << gvars.size();

    vector<VarPtr> grads(gvars.size());
    vector<VarPtr> results(targets.size());
    // custom_data of a grad var is its index into gvars / grads
    for (size_t i=0; i<gvars.size(); i++)
        gvars[i]->custom_data = i;

    for (size_t i=0; i<gvars.size(); i++) {
        Var* var = gvars[i];
        auto& grad = grads[i];
        #ifdef PREVENT_LARGE_FUSED_OP
        int gsum = 0;
        #endif
        if (i==0) {
            // the first sorted var is seeded with gradient 1
            grad = make_number(1.f, loss);
            assign_attrs(grad.ptr, loss);
            registe_node_trace_grad(grad.ptr, loss, 0);
        } else
        for (auto it : var->outputs_with_index()) {
            Op* op = it.op;
            auto index = it.index;
            if (op->tflag != nt) continue;
            // TODO: support two outputs backprop.
            Var* out = op->outputs().back();
            // Only vars that are part of this traversal carry a valid
            // index in custom_data. If the op's last output was not
            // marked, treat it as having no upstream gradient instead of
            // indexing `grads` with a stale value.
            Var* dout = nullptr;
            if (out->tflag == nt)
                dout = grads[out->custom_data];
            VarPtr dvar = make_grad(op, out, dout, var, index);
            registe_node_trace_grad(dvar.ptr, op, index);
            if (dvar)
                ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
                << "dvar" << dvar << "var" << var;
            if (!grad)
                grad = move(dvar);
            else if (dvar) {
                // accumulate the contributions of several consumers
                grad = make_binary(grad, dvar, ns_add);
                #ifdef PREVENT_LARGE_FUSED_OP
                gsum ++;
                if (gsum>=PREVENT_LARGE_FUSED_OP) {
                    // TODO: this is a dirty fix for
                    // stopping fuse lots of op together,
                    // try to find a better solution
                    grad->flags.set(NodeFlags::_stop_fuse);
                }
                #endif
                assign_attrs(grad.ptr, var);
                registe_node_trace_grad(grad.ptr, var, index);
            }
        }
    }
    // set zero grad
    for (size_t i=0; i<results.size(); i++) {
        Var* var = targets[i];
        VarPtr& grad = results[i];
        if (var->tflag == nt)
            grad = move(grads[var->custom_data]);
        if (!grad) {
            LOGvvv << var << "grads[">>i>>"] set to zero";
            grad = make_number(0.f, var);
            assign_attrs(grad.ptr, var);
            registe_node_trace_grad(grad.ptr, var, 0);
        }
    }
    return results;
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,114 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <sstream>
|
||||
#include "graph.h"
|
||||
#include "var_holder.h"
|
||||
#include "var.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
DEFINE_FLAG(int, check_graph, 0, "Unify graph sanity check.");
|
||||
|
||||
extern unordered_map<void*, int64> lived_nodes;
|
||||
|
||||
// Formats `x` with its stream insertion operator and returns the text.
template <typename T>
string ss_convert(T x) {
    std::stringstream buffer;
    buffer << x;
    string text = buffer.str();
    return text;
}
|
||||
|
||||
void do_graph_check() {
|
||||
vector<Node*> queue;
|
||||
unordered_map<Node*,int> visited;
|
||||
for (auto& vh : VarHolder::hold_vars) {
|
||||
if (0==visited[vh->var]++)
|
||||
queue.push_back(vh->var);
|
||||
}
|
||||
LOGvv << "Check hold_vars size" << queue.size();
|
||||
int vhsize = queue.size();
|
||||
for (auto* node : queue) {
|
||||
ASSERTop(node->forward_liveness,>,0);
|
||||
ASSERTop(node->backward_liveness,>,0);
|
||||
}
|
||||
for (uint i=0; i<queue.size(); i++) {
|
||||
auto* node = queue[i];
|
||||
for (auto* i : node->inputs()) {
|
||||
if (visited.count(i)) continue;
|
||||
visited[i] = 0;
|
||||
queue.push_back(i);
|
||||
}
|
||||
}
|
||||
LOGvv << "Check all var size" << queue.size();
|
||||
for (int i=0; i<(int)queue.size(); i++) {
|
||||
auto* node = queue[i];
|
||||
LOGvvvv << "Check node" << i << node;
|
||||
int f=0, b=0, p=0;
|
||||
if (i<vhsize) {
|
||||
f+=visited.at(node), b+=visited.at(node);
|
||||
}
|
||||
for (auto* i : node->inputs()) {
|
||||
if (i->is_stop_grad()) continue;
|
||||
if (!i->forward_liveness) continue;
|
||||
f ++;
|
||||
}
|
||||
for (auto* o : node->outputs()) {
|
||||
if (o->backward_liveness)
|
||||
b ++;
|
||||
if (o->pending_liveness && !o->is_finished())
|
||||
p++;
|
||||
}
|
||||
if (f>0 && b>0 && !node->is_finished()) p++;
|
||||
if (f!=node->forward_liveness || b!=node->backward_liveness || p!=node->pending_liveness) {
|
||||
LOGf << "ERROR" << node << '\n'
|
||||
<< f << b << p << i << '\n'
|
||||
<< node->inputs() << '\n'
|
||||
<< node->outputs();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
for (auto& kv : lived_nodes) {
|
||||
if (!kv.second) continue;
|
||||
auto* node = (Node*) kv.first;
|
||||
if (!visited.count(node) && node->tflag != -1) {
|
||||
if (node->is_var() && node->_inputs.size())
|
||||
continue;
|
||||
LOGf << "ERROR dnode" << (void*)node << kv.second << node;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collects every node reachable (in both directions) from the held vars
// and returns, per node, its printed description plus the indices of its
// input and output nodes. A node's index is its position in the visit
// order and is also written into its custom_data.
DumpGraphs dump_all_graphs() {
    vector<Node*> queue;
    auto mark = ++Node::tflag_count;
    for (auto& vh : VarHolder::hold_vars) {
        if (vh->var->tflag == mark)
            continue;
        vh->var->tflag = mark;
        queue.push_back(vh->var);
    }
    bfs_both(queue, [](Node*){return true;});
    DumpGraphs graphs;
    for (uint idx=0; idx<queue.size(); idx++)
        queue[idx]->custom_data = idx;
    for (Node* node : queue) {
        graphs.nodes_info.emplace_back(ss_convert(node));

        graphs.inputs.emplace_back();
        auto& in_ids = graphs.inputs.back();
        in_ids.reserve(node->_inputs.size());
        for (const auto& edge : node->_inputs)
            in_ids.push_back(edge.node->custom_data);

        graphs.outputs.emplace_back();
        auto& out_ids = graphs.outputs.back();
        out_ids.reserve(node->_outputs.size());
        for (const auto& edge : node->_outputs)
            out_ids.push_back(edge.node->custom_data);
    }
    return graphs;
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,39 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <random>
|
||||
|
||||
#include "init.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
unique_ptr<std::default_random_engine> eng;
|
||||
|
||||
vector<set_seed_callback> callbacks;
|
||||
int current_seed;
|
||||
|
||||
// One-time start-up: seeds the random engine from the current time and
// registers the "fused" op.
void init() {
    // seed the default_random_engine
    const auto now = time(nullptr);
    set_seed(now);
    // register the fused op
    op_registe({"fused","",""});
}
|
||||
|
||||
void set_seed(int seed) {
|
||||
current_seed = seed;
|
||||
eng.reset(new std::default_random_engine(seed));
|
||||
for (auto cb : callbacks)
|
||||
cb(seed);
|
||||
}
|
||||
|
||||
// Registers `callback` for future set_seed() calls and invokes it once
// immediately with the seed that is currently in effect.
void add_set_seed_callback(set_seed_callback callback) {
    callbacks.emplace_back(callback);
    callback(current_seed);
}
|
||||
|
||||
// Non-owning pointer to the global engine; null until the engine has been
// created by set_seed().
std::default_random_engine* get_random_engine() {
    std::default_random_engine* engine = eng.get();
    return engine;
}
|
||||
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <fstream>
|
||||
#include <streambuf>
|
||||
#include <stdlib.h>
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "jit_compiler.h"
|
||||
#include "op.h"
|
||||
#include "utils/cache_compile.h"
|
||||
#include "utils/flags.h"
|
||||
#include "fused_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
DEFINE_FLAG(string, jittor_path, "", "Source path of jittor");
|
||||
DEFINE_FLAG(string, cc_path, "", "Path of C++ compiler");
|
||||
DEFINE_FLAG(string, cc_type, "", "Type of C++ compiler(clang, icc, g++)");
|
||||
DEFINE_FLAG(string, cc_flags, "", "Flags of C++ compiler");
|
||||
DEFINE_FLAG(string, nvcc_path, "", "Path of CUDA C++ compiler");
|
||||
DEFINE_FLAG(string, nvcc_flags, "", "Flags of CUDA C++ compiler");
|
||||
DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
|
||||
DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
|
||||
DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
|
||||
|
||||
namespace jit_compiler {
|
||||
|
||||
// Opens the shared library `name` and returns the address of
// `symbol_name` inside it as a jit op entry. Fails via CHECK when the
// library cannot be opened or the symbol cannot be resolved. The library
// handle is not closed here.
jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
    LOGvv << "Opening jit lib:" << name;
    // NOTE(review): an older remark in this function says RTLD_DEEPBIND
    // together with openmp caused a segfault, yet the flag is enabled
    // here — confirm which of the two is intended.
    void* handle = dlopen(name.c_str(), RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND);
    CHECK(handle) << "Cannot open library" << name << ":" << dlerror();

    // Clear any stale error first: a symbol's value may legitimately be
    // null, so failure is detected through dlerror() and must not be
    // confused with an error left over from an earlier dl* call.
    dlerror();
    auto jit_entry = (jit_op_entry_t)dlsym(handle, symbol_name.c_str());
    const char* dlsym_error = dlerror();
    // report the symbol that was actually requested
    CHECK(!dlsym_error) << "Loading symbol" << symbol_name << "from" << name << "failed:" << dlsym_error;

    return jit_entry;
}
|
||||
|
||||
// Runs `cmd` through system_with_check; when `cwd` is non-empty the
// command is prefixed with "cd <cwd> && " so it runs in that directory.
void run_cmd(string cmd, string cwd="") {
    if (!cwd.empty())
        cmd = "cd " + cwd + " && " + cmd;
    LOGvvv << "Run cmd:" << cmd;
    system_with_check(cmd.c_str());
}
|
||||
|
||||
// Builds the mangled symbol of jittor::<Class>Op::jit_run() for the op a
// jit key belongs to. The op name is the part of the key before the first
// '[' (the whole key when there is none); an empty name means "fused".
static string get_symbol_name(const string& jit_key) {
    auto name_len = jit_key.find('[');
    if (name_len == string::npos)
        name_len = jit_key.size();
    string op_name = name_len ? jit_key.substr(0, name_len) : "fused";
    op_name = Op::file_name_to_class_name(op_name);
    // _ZN7jittorXyyyyyy7jit_runEv
    // jittor::yyyyyy::jit_run
    // the length prefix covers the class name plus the "Op" suffix
    string symbol = "_ZN6jittor";
    symbol += S(op_name.size()+2);
    symbol += op_name;
    symbol += "Op7jit_runEv";
    return symbol;
}
|
||||
|
||||
// Compiles the generated source `src` of the op identified by `jit_key`
// into a shared library, loads it and returns its entry point. The result
// is cached in jit_ops, so a key that is already present returns at once.
//   is_cuda_op  : build with nvcc_path / nvcc_flags instead of cc_path /
//                 cc_flags
//   extra_flags : appended to the compiler flags
jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_cuda_op, const string& extra_flags) {
    auto iter = jit_ops.find(jit_key);
    if (iter != jit_ops.end())
        return iter->second;
    LOGvv << "Compile op" << jit_key;
    // compiler do not allowed filename too long
    CHECK(cc_path.size());
    string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
    string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".so");
    LOGvvv << "Generate" << jit_src_path >> "\n" >> src;
    // an existing source file is kept only when rewrite_op is off
    if (rewrite_op || !file_exist(jit_src_path))
        write(jit_src_path, src);
    string cmd;
    if (is_cuda_op) {
        cmd = nvcc_path
            + " '" + jit_src_path + "'"
            + nvcc_flags + extra_flags
            + " -o '" + jit_lib_path + "'";
    } else {
        cmd = cc_path
            + " '" + jit_src_path + "'"
            + cc_flags + extra_flags
            + " -o '" + jit_lib_path + "'";
        // the host compile command is handed to asm_tuner.py via --cc_path
        cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
            "--cc_path=" + cmd;
    }
    cache_compile(cmd, cache_path, jittor_path);
    auto symbol_name = get_symbol_name(jit_key);
    auto jit_entry = load_jit_lib(jit_lib_path, symbol_name);
    jit_ops[jit_key] = jit_entry;
    return jit_entry;
}
|
||||
|
||||
} // jit_compiler
|
||||
} // jittor
|
|
@ -0,0 +1,117 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <sys/mman.h>
|
||||
#include <sstream>
|
||||
#include "jit_key.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Page granularity (4 KiB) used when locating and mprotect-ing the guard
// page at the end of the JitKey buffer.
// NOTE(review): hard-coded rather than queried from the system — confirm.
const int page_size = 4*1024;
|
||||
|
||||
extern size_t protected_page;
|
||||
|
||||
static size_t get_buffer_end_page(size_t buffer_end) {
|
||||
// get the last complete page in buffer
|
||||
// 4k align :
|
||||
// | | | | |
|
||||
// buffer: xxxxxxxxxxxxxxxxxxxxxxxx
|
||||
// ^ buffer_end_page
|
||||
size_t buffer_end_page = buffer_end - buffer_end % page_size;
|
||||
if (buffer_end_page + page_size-1 > buffer_end)
|
||||
buffer_end_page -= page_size;
|
||||
return buffer_end_page;
|
||||
}
|
||||
|
||||
// Makes the last complete page of `buffer` inaccessible (PROT_NONE) and
// records its address in protected_page.
JitKey::JitKey() {
    const size_t last_byte = (size_t)&buffer[buffer_size-1];
    auto buffer_end_page = get_buffer_end_page(last_byte);
    LOGvv << "protect page" << (void*)buffer_end_page;
    ASSERT(0==mprotect((void*)buffer_end_page, page_size, PROT_NONE));
    protected_page = buffer_end_page;
}
|
||||
|
||||
// Restores read/write/execute access to the guard page set up by the
// constructor and resets protected_page to 0.
JitKey::~JitKey() {
    const size_t last_byte = (size_t)&buffer[buffer_size-1];
    auto buffer_end_page = get_buffer_end_page(last_byte);
    LOGvv << "un-protect page" << (void*)buffer_end_page;
    ASSERT(0==
        mprotect((void*)buffer_end_page, page_size, PROT_READ|PROT_WRITE|PROT_EXEC));
    protected_page = 0;
}
|
||||
|
||||
// Rewrites `s` in place: the text is parsed as a hexadecimal unsigned
// int and replaced by S(value). An empty string is left untouched; no
// validation of the input is performed.
static void hex_to_dec(string& s) {
    if (s.empty())
        return;
    std::stringstream ss(s);
    unsigned int x;
    ss >> std::hex >> x;
    s = S(x);
}
|
||||
|
||||
// Rewrites a value of the form "itof(0x<hex>)" in place: the hex digits
// are parsed, passed through itof() and the result is printed in
// hexfloat notation. A "p+" exponent is shortened to "p", and the
// inf / nan spellings are replaced by division expressions.
static void convert_itof(string& s) {
    uint64 x;
    std::stringstream ss;
    // itof(0x...)
    // ^      ^
    //        7
    ASSERT(s.size()>=8);
    // digits start at index 7; the trailing ')' is dropped
    const string digits = s.substr(7, s.size()-7-1);
    ss << std::hex << digits;
    ASSERT(ss >> x);
    ss.str("");
    ss.clear();
    ss << std::hexfloat << itof(x);
    s = ss.str();
    // 0x0p+0 ---> 0x0p0
    const auto plus = s.find("p+");
    if (plus != string::npos)
        s.erase(plus+1, 1);
    if (s=="inf")
        s = "(1.0/0)";
    else if (s=="-inf")
        s = "(-1.0/0)";
    else if (s=="nan" || s=="-nan")
        s = "(0.0/0)";
}
|
||||
|
||||
// Splits a jit key string into (key, value) pairs.
// Entries are delimited by the marker characters JK::key (opens an entry
// and starts its key), JK::val / JK::hex_val (switch from key to value)
// and JK::end (closes the entry). Markers may nest: only depth-1 markers
// delimit, deeper ones are kept as part of the text. Characters outside
// any entry are ignored. Values introduced by JK::hex_val are converted
// with hex_to_dec, and values starting with "itof" with convert_itof.
// Asserts that every opened entry has been closed.
vector<pair<string,string>> parse_jit_keys(const string& s) {
    vector<pair<string,string>> jit_keys;
    // current nesting depth of open entries
    int presum = 0;
    // marker of the part being collected (JK::key, JK::val or JK::hex_val)
    char state=0;
    string key, val;
    for (char c : s) {
        if (c==JK::key) {
            presum++;
            if (presum==1) {
                state = c;
                continue;
            }
        } else
        if (c==JK::val || c==JK::hex_val) {
            if (presum==1 && state==JK::key) {
                state = c;
                continue;
            }
        } else
        if (c==JK::end) {
            presum--;
            if (presum==0) {
                if (state == JK::hex_val)
                    hex_to_dec(val);
                if (startswith(val, "itof"))
                    convert_itof(val);
                jit_keys.emplace_back(move(key), move(val));
                // A moved-from string is only guaranteed to be valid, not
                // empty; reset both before the next entry appends to them.
                key.clear();
                val.clear();
                continue;
            }
        }
        if (presum) {
            if (state==JK::key)
                key += c;
            if (state==JK::val || state==JK::hex_val)
                val += c;
        }
    }
    ASSERT(presum==0);
    return jit_keys;
}
|
||||
|
||||
JitKey jk;
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,268 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <limits>
|
||||
|
||||
#include "node.h"
|
||||
#include "op.h"
|
||||
#include "var.h"
|
||||
#include "op_compiler.h"
|
||||
#include "profiler/profiler.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "pybind/py_var_tracer.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
DECLARE_FLAG(string, cache_path);
|
||||
|
||||
DEFINE_FLAG(int, try_use_32bit_index, 0,
|
||||
"If not overflow, try to use 32 bit type as index type.");
|
||||
|
||||
string_view_map<jit_op_entry_t> jit_ops;
|
||||
string_view_map<string> jit_key_mapper;
|
||||
|
||||
int64_t Op::number_of_lived_ops = 0;
|
||||
|
||||
// A new op clears the _var flag, sets the _cpu flag and is counted in
// number_of_lived_ops.
Op::Op() {
    ++number_of_lived_ops;
    flags.set(NodeFlags::_var, 0);
    flags.set(NodeFlags::_cpu, 1);
}
|
||||
|
||||
// Keeps the live-op counter in sync with the constructor.
Op::~Op() {
    --number_of_lived_ops;
}
|
||||
|
||||
// Sets the _forwarded flag on this op and appends `input` to
// outputs_holder.
void Op::forward(Var* input) {
    flags.set(NodeFlags::_forwarded);
    outputs_holder.emplace_back(input);
    return;
}
|
||||
|
||||
// Default gradient: logs a warning and returns nullptr. None of the
// arguments are used by this base implementation.
VarPtr Op::grad(Var* out, Var* dout, Var* v, int v_index) {
    LOGw << "Grad of" << name() << "return zeros";
    VarPtr no_grad = nullptr;
    return no_grad;
}
|
||||
|
||||
// Creates a var with the given shape and dtype, hands it over to
// outputs_holder and returns a raw pointer to it.
Var* Op::create_output(NanoVector shape, NanoString dtype) {
    VarPtr holder(shape, dtype);
    // grab the raw pointer before the holder is moved away
    Var* const raw = holder.ptr;
    outputs_holder.emplace_back(move(holder));
    return raw;
}
|
||||
|
||||
void Op::init() {
|
||||
infer_shape();
|
||||
LOGvvvv << "Create" << this << "and outputs" << outputs();
|
||||
for (Var* v : outputs())
|
||||
CHECK(v->shape.size()) << "Number of dims should be solved.";
|
||||
}
|
||||
|
||||
bool Op::shape_infered() {
|
||||
if (flags.get(NodeFlags::_vary_shape)) return true;
|
||||
for (Var* v : outputs())
|
||||
if (v->num < 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
// Extended name: name() alone when ns is ns_void, otherwise
// "<name>.<ns>".
string Op::name_ex() const {
    string full = name();
    if (ns == ns_void)
        return full;
    full += '.';
    full += ns.to_cstring();
    return full;
}
|
||||
|
||||
// Rebuilds the jit key of this op in the global `jk` buffer and returns
// it as a string.
string Op::get_jit_key() {
    jk.clear();
    do_jit_prepare();
    string key = jk.to_string();
    return key;
}
|
||||
|
||||
// Returns the (key, value) pairs parsed from this op's current jit key.
vector<pair<string,string>> Op::get_jit_define() {
    const string key = get_jit_key();
    return parse_jit_keys(key);
}
|
||||
|
||||
void Op::do_jit_prepare() {
|
||||
memcheck_all_exist();
|
||||
jk << name();
|
||||
jit_prepare();
|
||||
if (!jk.empty()) {
|
||||
// check use int64_t as index_t if array is too big
|
||||
int in_id=0, out_id=0;
|
||||
bool use_int64_t = false;
|
||||
// TODO: fused op do not have inputs,
|
||||
// check use_cuda_op from outputs may not be enough
|
||||
bool use_cuda_op = use_cuda;
|
||||
for (Var* var : inputs()) {
|
||||
if (var->allocator) {
|
||||
jk << JK::key << "alloc_i" << JK::hex1(in_id)
|
||||
<< JK::hex1(var->allocator->flags()) << JK::end;
|
||||
use_cuda_op &= var->allocator->is_cuda();
|
||||
}
|
||||
if (var->num >= std::numeric_limits<int32_t>::max())
|
||||
use_int64_t = true;
|
||||
in_id ++;
|
||||
}
|
||||
for (Var* var : outputs()) {
|
||||
if (var->allocator) {
|
||||
jk << JK::key << "alloc_o" << JK::hex1(in_id)
|
||||
<< JK::hex1(var->allocator->flags()) << JK::end;
|
||||
use_cuda_op &= var->allocator->is_cuda();
|
||||
}
|
||||
if (var->num >= std::numeric_limits<int32_t>::max())
|
||||
use_int64_t = true;
|
||||
out_id ++;
|
||||
}
|
||||
add_jit_define("JIT", "1");
|
||||
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
|
||||
add_jit_define("JIT_cuda", "1");
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
// TODO: 64bit index in CUDA
|
||||
use_int64_t = false;
|
||||
} else {
|
||||
if (use_cuda==2) {
|
||||
if (flags.get(NodeFlags::_cuda))
|
||||
LOGf << "Op" << name() >> "'s vars are not allocated in cuda";
|
||||
else
|
||||
LOGf << "Op" << name() << "doesn't have cuda version";
|
||||
}
|
||||
ASSERT(flags.get(NodeFlags::_cpu))
|
||||
<< "Op" << name() << "doesn't have cpu version";
|
||||
add_jit_define("JIT_cpu", "1");
|
||||
flags.set(NodeFlags::_cuda, 0);
|
||||
}
|
||||
if (try_use_32bit_index) use_int64_t = false;
|
||||
add_jit_define("index_t", use_int64_t ? "int64" : "int32");
|
||||
}
|
||||
jk.finilize();
|
||||
}
|
||||
|
||||
void Op::do_prepare(){
|
||||
jk.clear();
|
||||
do_jit_prepare();
|
||||
}
|
||||
|
||||
void Op::do_run_after_prepare() {
|
||||
if (!jk.empty())
|
||||
jit_run();
|
||||
else
|
||||
run();
|
||||
}
|
||||
|
||||
void Op::do_run() {
|
||||
do_prepare();
|
||||
do_run_after_prepare();
|
||||
}
|
||||
|
||||
// Maps a jit key to a file path under "<cache_path>/jit/".
// The key is first translated through jit_key_mapper when an entry
// exists. Keys longer than 100 characters are cut to their first 90
// characters; in both cases a hex hash of the full key is appended.
// Characters unsuitable for a file name are replaced by '_', and
// "_op" + suffix is added at the end.
string Op::get_filename_from_jit_key(const string& jit_key, const string& suffix) {
    auto mapped = jit_key_mapper.find(jit_key);
    string s = mapped==jit_key_mapper.end() ? jit_key : mapped->second;
    std::stringstream ss;
    if (s.size() > 100)
        ss << s.substr(0, 90) << "...hash:";
    else
        ss << s << "_hash:";
    ss << std::hex << std::hash<string>()(s);
    s = ss.str();
    for (char& c : s) {
        switch (c) {
            case '[': case ']': case '<': case '>':
            case '{': case '}': case '(': case ')': case ',':
            case '\n': case '\t': case ' ': case '&': case '|':
            case '/':
                c = '_';
                break;
            default:
                break;
        }
    }
    return cache_path + "/jit/" + s + "_op" + suffix;
}
|
||||
|
||||
// convert xxx.yyy -> xxx
|
||||
string Op::op_name_to_file_name(const string& s) {
|
||||
auto pos = s.find('.');
|
||||
return pos == string::npos ? s : s.substr(0, pos);
|
||||
}
|
||||
// convert xxx_xxx -> XxxXxx
|
||||
string Op::file_name_to_class_name(const string& s) {
|
||||
char prev = '_';
|
||||
string res;
|
||||
res.reserve(s.size());
|
||||
for (char c : s) {
|
||||
if (c != '_') {
|
||||
if (prev == '_')
|
||||
res += c-'a'+'A';
|
||||
else
|
||||
res += c;
|
||||
}
|
||||
prev = c;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void Op::jit_run() {
|
||||
const char* jit_key = jk.to_cstring();
|
||||
auto iter = jit_ops.find(jit_key);
|
||||
if (iter != jit_ops.end()) {
|
||||
LOGvvv << "Jit op key found:" << jit_key << "jit op entry:" << (void*)iter->second;
|
||||
Profiler::record_and_run(iter->second, this, jit_key);
|
||||
return;
|
||||
}
|
||||
LOGvv << "Jit op key not found:" << jit_key;
|
||||
// compile JIT op
|
||||
string prev_jit_key = jit_key;
|
||||
auto op_entry = OpCompiler::do_compile(this);
|
||||
string new_jit_key = get_jit_key();
|
||||
jit_ops[new_jit_key] = jit_ops[prev_jit_key] = op_entry;
|
||||
jit_key_mapper[prev_jit_key] = new_jit_key;
|
||||
LOGvv << "Get jit op entry:" << (void*)op_entry;
|
||||
Profiler::record_and_run(op_entry, this, new_jit_key.c_str());
|
||||
}
|
||||
|
||||
void Op::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
|
||||
in = out = compute = 0;
|
||||
for (Var* var : inputs()) {
|
||||
in += var->size;
|
||||
compute = std::max(compute, (uint64_t)var->num);
|
||||
}
|
||||
for (Var* var : outputs()) {
|
||||
out += var->size;
|
||||
compute = std::max(compute, (uint64_t)var->num);
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Op* op) {
|
||||
if (!op) return os << "Op(0)";
|
||||
os << "Op(" << (void*)op
|
||||
<< ':' << op->forward_liveness
|
||||
<< ':' << op->backward_liveness
|
||||
<< ':' << op->pending_liveness
|
||||
<< ":i" << op->_inputs.size()
|
||||
<< ":o" << op->_outputs.size()
|
||||
<< ":s" << op->is_finished()
|
||||
<< "," << op->name_ex();
|
||||
if (op->_outputs.size()>1)
|
||||
os << "->...";
|
||||
else if (op->_outputs.size() == 1) {
|
||||
auto v = (Var*)op->_outputs.front().node;
|
||||
if (v->name.size())
|
||||
os << "->" << v->name;
|
||||
else
|
||||
os << "->" << (void*)v;
|
||||
}
|
||||
os << ')';
|
||||
#ifdef NODE_MEMCHECK
|
||||
os << '<' << op->__id() << '>';
|
||||
print_node_trace(op, os);
|
||||
#endif
|
||||
return os;
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,915 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include "op.h"
|
||||
#include "fused_op.h"
|
||||
#include "op_compiler.h"
|
||||
#include "jit_compiler.h"
|
||||
#include "utils/cache_compile.h"
|
||||
#include "opt/tuner_manager.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "ops/array_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
DECLARE_FLAG(string, jittor_path);
|
||||
|
||||
using namespace jit_compiler;
|
||||
|
||||
// True for characters accepted by this helper as part of a name:
// letters, digits, '_' and ':'.
static bool isvar(char x) {
    // isalnum requires a value representable as unsigned char (or EOF);
    // a plain char may be negative, so convert before the call.
    return isalnum(static_cast<unsigned char>(x)) || x == '_' || x == ':';
}
|
||||
|
||||
void OpCompiler::get_op_var_by_name(const string& name, uint& op_id, uint& opvar_id, Op*& op, Var*& var) {
|
||||
// name: op{id}_{varname}
|
||||
ASSERT(name.size()>3 && name[0]=='o' && name[1]=='p');
|
||||
uint j=2;
|
||||
while (j<name.size() && isdigit(name[j])) j++;
|
||||
ASSERT(j>2);
|
||||
op_id = std::stoi(name.substr(2, j-2));
|
||||
ASSERT(op_members.size() > op_id);
|
||||
bool found = false;
|
||||
for (opvar_id=0 ;opvar_id < op_members[op_id].size(); opvar_id++) {
|
||||
if (op_members[op_id][opvar_id] == name) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
op = this->op->ops[op_id];
|
||||
ASSERT(found && opvar_id < op->inputs().size() + op->outputs().size());
|
||||
if (opvar_id >= op->inputs().size()) {
|
||||
auto iter = op->outputs().begin();
|
||||
for (uint t=op->inputs().size(); t<opvar_id; t++)
|
||||
iter++;
|
||||
var = *iter;
|
||||
} else {
|
||||
auto iter = op->inputs().begin();
|
||||
for (uint t=0; t<opvar_id; t++)
|
||||
iter++;
|
||||
var = *iter;
|
||||
}
|
||||
}
|
||||
|
||||
// Inverse of get_op_var_by_name: return the member name recorded for `var`
// of `op`. The position is the index of `var` among the op's inputs followed
// by its outputs; the name list is op_members[op->custom_data].
// Asserts when `var` is neither an input nor an output of `op`, or when the
// recorded member list is too short.
string OpCompiler::get_name_by_op_var(Op* op, Var* var) {
    uint var_id = 0;
    bool found = false;
    for (Var* i : op->inputs()) {
        if (i==var) {
            found = true;
            break;
        }
        var_id++;
    }
    if (!found)
        for (Var* o : op->outputs()) {
            if (o==var) {
                found = true;
                break;
            }
            var_id++;
        }
    ASSERT(found);
    // custom_data is used as an index: reject negative values as well as
    // values past the end, otherwise op_members[...] reads out of bounds.
    ASSERT(op->custom_data>=0 && op->custom_data<(int)op_members.size());
    auto& v = op_members[op->custom_data];
    ASSERT(var_id < v.size());
    return v[var_id];
}
|
||||
|
||||
// Member name recorded for the i-th input of `op`
// (bounds-checked through vector::at on both levels).
string OpCompiler::get_name_by_op_input(Op* op, uint i) {
    const auto& names = op_members.at(op->custom_data);
    return names.at(i);
}
|
||||
|
||||
// Member name recorded for the i-th output of `op`; outputs are stored
// after all inputs in the op's member list.
string OpCompiler::get_name_by_op_output(Op* op, uint i) {
    const auto& names = op_members.at(op->custom_data);
    const auto slot = i+op->inputs().size();
    return names.at(slot);
}
|
||||
|
||||
// True when at least one member name was recorded for `op`.
bool OpCompiler::op_exist(Op* op) {
    const auto& names = op_members.at(op->custom_data);
    return !names.empty();
}
|
||||
|
||||
int OpCompiler::total_member_count() {
|
||||
int member_count=0;
|
||||
int i = 0;
|
||||
for (auto& v : op_members) {
|
||||
// array need a extra local var
|
||||
if (op->ops[i]->name()==string("array"))
|
||||
member_count += 1;
|
||||
member_count += v.size();
|
||||
i += 1;
|
||||
}
|
||||
return member_count;
|
||||
}
|
||||
|
||||
// Operator tables used by the JIT expression evaluator.
// Each entry is m(operator, precedence); a smaller number binds tighter.
// The numbers appear to follow the C++ operator precedence table.
// FOR_ALL_UOPS lists the prefix unary operators, FOR_ALL_BOPS the binary ones.
#define FOR_ALL_UOPS(m) \
    m(!,3) m(~,3)
#define FOR_ALL_BOPS(m) \
    m(*,5) m(/,5) m(%,5) \
    m(+,6) m(-,6) \
    m(<<,7) m(>>,7) \
    m(<,9) m(<=,9) m(>,9) m(>=,9) \
    m(!=,10) m(==,10) \
    m(&,11) \
    m(^,12) \
    m(|,13) \
    m(&&,14) \
    m(||,15)

#define FOR_ALL_OPS(m) FOR_ALL_UOPS(m) FOR_ALL_BOPS(m)

// True when `op` is one of the prefix unary operators in FOR_ALL_UOPS.
inline bool is_unary_op(const string& op) {
    #define _u(o, _) if (op == #o) return true;
    FOR_ALL_UOPS(_u);
    #undef _u
    return false;
}

// Precedence level of `op`; anything not in the tables (e.g. "(") gets 20,
// which is weaker than every real operator.
inline int precedence(const string& op) {
    #define _prior(o, p) if (op == #o) return p;
    FOR_ALL_OPS(_prior);
    #undef _prior
    return 20;
}

// Decide whether the operator on top of the stack (`op1`) must be reduced
// before the incoming operator (`op2`) is pushed.
inline bool check_precedence(const string& op1, const string& op2) {
    // A prefix unary operator has no left operand, so nothing already on the
    // stack may be reduced when it arrives. Reducing here would apply the
    // stacked operator to the wrong value (e.g. "~!1" evaluated to 0, not -1).
    if (is_unary_op(op2)) return false;
    return precedence(op1) <= precedence(op2);
}
|
||||
|
||||
inline int64_t calc_op(int64_t a, int64_t b, const string& op) {
|
||||
#define _calc_b(o, _) if (op == #o) return a o b;
|
||||
FOR_ALL_BOPS(_calc_b);
|
||||
#define _calc_u(o, _) if (op == #o) return o b;
|
||||
FOR_ALL_UOPS(_calc_u);
|
||||
ASSERT(0) << "Unrecognized op " << op;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>& vars) {
|
||||
if (expr.find("@") != string::npos) {
|
||||
string new_expr;
|
||||
for (size_t i=0; i<expr.size(); i++) {
|
||||
if (expr[i] != '@') new_expr += expr[i];
|
||||
else {
|
||||
size_t j=i+1;
|
||||
ASSERT(j < expr.size());
|
||||
// syntax @{...}
|
||||
// ij k
|
||||
if (expr[j] == '{') {
|
||||
size_t k=j+1;
|
||||
int presum = 1;
|
||||
while (k<expr.size() && presum) {
|
||||
if (expr[k] == '}')
|
||||
presum--;
|
||||
else if (expr[k] == '{')
|
||||
presum++;
|
||||
k++;
|
||||
}
|
||||
ASSERT(presum==0) << "Jit error: braces are not matched.";
|
||||
new_expr += S(eval(expr.substr(j+1, k-j-2), vars));
|
||||
i = k-1;
|
||||
continue;
|
||||
} else {
|
||||
// syntax: @x
|
||||
ASSERT(isvar(expr[j]));
|
||||
size_t k=j+1;
|
||||
while (k<expr.size() && isvar(expr[k])) k++;
|
||||
string var = expr.substr(j, k-j);
|
||||
auto iter = vars.find(var);
|
||||
ASSERT(iter!=vars.end()) << "Jit var " << var << " not found.";
|
||||
new_expr += iter->second;
|
||||
i = k-1;
|
||||
}
|
||||
}
|
||||
}
|
||||
return eval(new_expr, vars);
|
||||
}
|
||||
vector<int64> values = {0};
|
||||
vector<string> ops;
|
||||
auto pop_values_and_calc_op = [&]() {
|
||||
CHECK(ops.size());
|
||||
auto op = ops.back();
|
||||
ops.pop_back();
|
||||
CHECK(values.size());
|
||||
auto val2 = values.back();
|
||||
values.pop_back();
|
||||
auto val1 = val2;
|
||||
if (!is_unary_op(op)) {
|
||||
CHECK(values.size());
|
||||
val1 = values.back();
|
||||
values.pop_back();
|
||||
}
|
||||
values.push_back(calc_op(val1, val2, op));
|
||||
};
|
||||
for (size_t i=0; i<expr.size(); i++) {
|
||||
if (expr[i] == ' ')
|
||||
continue;
|
||||
if (expr[i] == '(')
|
||||
ops.push_back(string()+expr[i]);
|
||||
else if (isdigit(expr[i])) {
|
||||
int64_t val = 0;
|
||||
while (i<expr.length() && isdigit(expr[i])) {
|
||||
val = val*10 + (expr[i]-'0');
|
||||
i++;
|
||||
}
|
||||
i--;
|
||||
values.push_back(val);
|
||||
} else if (isvar(expr[i])) {
|
||||
auto j=i+1;
|
||||
while (j<expr.size() && isvar(expr[j])) j++;
|
||||
auto var_name = expr.substr(i,j-i);
|
||||
auto iter = vars.find(var_name);
|
||||
ASSERT(iter!=vars.end()) << "Jit var " << var_name << " not found.";
|
||||
try {
|
||||
values.push_back(std::stoll(iter->second));
|
||||
} catch (...) {
|
||||
ASSERT(0) << "'" << iter->second << "' is not integer, expr " << expr;
|
||||
}
|
||||
i = j-1;
|
||||
} else if (expr[i] == ')') {
|
||||
while (ops.size() && ops.back() != "(")
|
||||
pop_values_and_calc_op();
|
||||
ops.pop_back();
|
||||
} else {
|
||||
auto j=i+1;
|
||||
while (j<expr.size() && expr[j] != ' ' &&
|
||||
expr[j] != '!' && expr[j] != '~' &&
|
||||
!isdigit(expr[j]) && !isvar(expr[j]) &&
|
||||
expr[j] != '(' && expr[j] != ')') j++;
|
||||
auto op = expr.substr(i, j-i);
|
||||
while (ops.size() && check_precedence(ops.back(), op))
|
||||
pop_values_and_calc_op();
|
||||
ops.push_back(op);
|
||||
i = j-1;
|
||||
}
|
||||
}
|
||||
while (ops.size())
|
||||
pop_values_and_calc_op();
|
||||
return values.back();
|
||||
}
|
||||
|
||||
void load_macros(const string& src, unordered_map<string,string>& macros) {
|
||||
LOGvvvv << "load_macros" << src;
|
||||
for (size_t i=0; i<src.size(); i++) {
|
||||
if (src[i] == '#') {
|
||||
// #define xxx(...) xxx
|
||||
// i jk r l p q
|
||||
auto j=i+1;
|
||||
while (j<src.size() && src[j] != ' ') j++;
|
||||
if (j-i!=7 || src.substr(i,j-i) != "#define") {
|
||||
i = j;
|
||||
continue;
|
||||
}
|
||||
ASSERT(j<src.size());
|
||||
auto k=j+1;
|
||||
while (k<src.size() && src[k] == ' ') k++;
|
||||
ASSERT(k<src.size());
|
||||
auto l=k+1;
|
||||
while (l<src.size() && (src[l] != '\n' && src[l-1] != ')')) l++;
|
||||
auto p=l;
|
||||
while (p<src.size() && (src[p] == ' ')) p++;
|
||||
auto q=p;
|
||||
while (q<src.size() && (src[q] != '\n')) q++;
|
||||
// TODO: multiline macro
|
||||
auto r=k;
|
||||
while (r<l && src[r] != '(') r++;
|
||||
auto body = q>p ? src.substr(p,q-p) : "";
|
||||
body = (r<l?src.substr(r,l-r):"()") + body;
|
||||
auto header = src.substr(k,r-k);
|
||||
LOGvvvv << "header:" << header << "body:" << body;
|
||||
macros[header] = body;
|
||||
i = q;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void expand_macro(const string& macro, const vector<string>& args, string& new_src) {
|
||||
LOGvvvv << "expand_macro" << macro << "args:" << args;
|
||||
auto i = macro.find(")");
|
||||
ASSERT(i != string::npos);
|
||||
// (a1, a2, ...)body
|
||||
// j k i
|
||||
unordered_map<string, int> args_map;
|
||||
for (uint j=1, l=0; j<i; l++) {
|
||||
uint k=j;
|
||||
while (k<i && macro[k] != ',') k++;
|
||||
args_map[macro.substr(j,k-j)] = l;
|
||||
j = k+1;
|
||||
while (j<i && macro[j] == ' ') j++;
|
||||
}
|
||||
ASSERTop(args.size(),==,args_map.size()) << "Number of macro args not match.";
|
||||
for (i=i+1; i<macro.size(); i++) {
|
||||
if (isvar(macro[i])) {
|
||||
uint j = i+1;
|
||||
while (j<macro.size() && isvar(macro[j])) j++;
|
||||
string var = macro.substr(i, j-i);
|
||||
auto iter = args_map.find(var);
|
||||
if (iter == args_map.end()) {
|
||||
new_src += var;
|
||||
} else {
|
||||
new_src += args[iter->second];
|
||||
}
|
||||
i = j-1;
|
||||
continue;
|
||||
}
|
||||
new_src += macro[i];
|
||||
}
|
||||
}
|
||||
|
||||
// JIT preprocessor: returns `src` with comments removed and the '@' meta
// syntax expanded.
//
//  - "// ..." comments are replaced by a newline, "/* ... */" comments are dropped.
//  - #include "...defs.h" is read from <jittor_path>/src and scanned for
//    macros only; the include line itself is not emitted.
//  - #define lines are recorded in `macros` (and still emitted).
//  - #ifdef / #ifndef on a key starting with "JIT" are resolved here against
//    `defs` and `macros`; other directives are copied through unchanged.
//  - @{expr}                 -> value of OpCompiler::eval(expr, defs)
//  - @for(i, l, r[, step], body) -> body repeated with i = l, l+step, ... until i == r
//  - @if(cond, a[, b])       -> a when cond evaluates non-zero, otherwise b
//  - @expand_macro(m, args...) -> expansion of macro m (or the text of defs[m])
//  - @name(i0, i1, ...)      -> namep[(i0)*namestride0+(i1)*namestride1+...]
//  - @name                   -> precompiled text of defs[name]
//  - @@                      -> separator, emits nothing
// Arguments are split on top-level commas only.
// Any std::exception raised while processing a position is reported through
// LOGf together with the source line that contains the position.
string precompile(unordered_map<string,string> defs, string src, unordered_map<string, string>& macros) {
    string new_src;
    new_src.reserve(src.size());
    for (size_t i=0; i<src.size(); i++) {
        try{
        if (src[i] == '/' && (i+1<src.size() && src[i+1] == '/')) {
            size_t j=i+1;
            while (j<src.size() && src[j] != '\n') j++;
            if (j<src.size()) j++;
            // remove comment
            // for (size_t k=i; k<j; k++) new_src += src[k];
            if (src[j-1]=='\n')
                new_src += '\n';
            i = j-1;
            continue;
        } else
        if (src[i] == '/' && (i+1<src.size() && src[i+1] == '*')) {
            // Start looking for the closing "*/" after the opening "/*", so
            // that the '*' of the opener is not reused ("/*/" is not a
            // complete comment).
            size_t j=i+3;
            while (j<src.size() && !(src[j] == '/' && src[j-1] == '*')) j++;
            if (j<src.size()) j++;
            // remove comment
            // for (size_t k=i; k<j; k++) new_src += src[k];
            i = j-1;
            continue;
        } else
        if (src[i] == '#') {
            // #include "a.h"
            // i       jk    l
            // #define xxx
            // i      jk  l
            auto j=i+1;
            while (j<src.size() && src[j] != ' ') j++;
            ASSERT(j<src.size());
            auto k=j+1;
            while (k<src.size() && src[k] == ' ') k++;
            ASSERT(k<src.size());
            auto l=k+1;
            while (l<src.size() && (src[l] != '\n')) l++;
            if (src[k] == '"' && src[l-1] == '"' && j-i==8 && src.substr(i,j-i) == "#include") {
                auto inc = src.substr(k+1, l-k-2);
                if (inc.size()>=6 && inc.substr(inc.size()-6) == "defs.h") {
                    LOGvvvv << "Found defs include" << inc;
                    auto src_path = join(jittor_path, "src");
                    src_path = join(src_path, inc);
                    auto inc_src = read_all(src_path);
                    // load_macros from include src
                    precompile(defs, inc_src, macros);
                    // we do not include defs.h
                    i = l;
                    continue;
                }
            } else
            if (j-i==7 && src.substr(i,j-i) == "#define") {
                load_macros(src.substr(i,l-i), macros);
            } else
            // #ifdef JITxxx
            // #else
            // #endif
            if (((j-i==6 && src.substr(i,j-i) == "#ifdef") ||
                (j-i==7 && src.substr(i,j-i) == "#ifndef")) && startswith(src, "JIT", k)) {
                bool is_ifndef = j-i==7;
                string key = src.substr(k, l-k);
                // find pair #endif and #else
                int presum = 1;
                size_t prev = l+1, ii = prev;
                string block, else_block;
                while (ii < src.size()) {
                    if (startswith(src, "#if", ii)) {
                        presum++;
                        ii += 3;
                        continue;
                    }
                    if (startswith(src, "#else", ii)) {
                        auto next_ii = ii+5;
                        // remove ' ' or '\n' after #else
                        if (next_ii<src.size() && (src[next_ii]==' ' || src[next_ii]=='\n'))
                            next_ii++;
                        if (presum==1) {
                            block = src.substr(prev, ii-prev);
                            prev = next_ii;
                        }
                        ii = next_ii;
                        continue;
                    }
                    if (startswith(src, "#endif", ii)) {
                        presum--;
                        auto next_ii = ii+6;
                        // remove ' ' or '\n' after #endif
                        if (next_ii<src.size() && (src[next_ii]==' ' || src[next_ii]=='\n'))
                            next_ii++;
                        if (presum==0) {
                            // prev still at its start value means no #else was seen
                            if (prev == l+1)
                                block = src.substr(prev, ii-prev);
                            else
                                else_block = src.substr(prev, ii-prev);
                            ii = next_ii;
                            break;
                        }
                        ii = next_ii;
                        continue;
                    }
                    ii++;
                }
                ASSERT(presum==0);
                if (is_ifndef) block.swap(else_block);
                if (defs.count(key) || macros.count(key)) {
                    new_src += precompile(defs, block, macros);
                } else {
                    new_src += precompile(defs, else_block, macros);
                }
                i = ii-1;
                continue;
            }
            // any other directive: copy the line through unchanged
            for (auto k=i; k<l; k++) new_src += src[k];
            i= l-1;
            continue;
        } else
        if (src[i] == '@' && i+1<src.size()) {
            size_t j=i+1;
            // syntax @{...}
            //        ij   k
            if (src[j] == '{') {
                size_t k=j+1;
                int presum = 1;
                while (k<src.size() && presum) {
                    if (src[k] == '}')
                        presum--;
                    else if (src[k] == '{')
                        presum++;
                    k++;
                }
                ASSERT(presum==0) << "Jit error: braces are not matched.";
                new_src += S(OpCompiler::eval(src.substr(j+1, k-j-2), defs));
                i = k-1;
                continue;
            } else if (isvar(src[j])) {
                size_t k=j+1;
                while (k<src.size() && isvar(src[k])) k++;
                string expr = src.substr(j, k-j);
                int presum = 1;
                // positions of '(' , the top-level commas and the closing ')'
                vector<int> comma;
                vector<string> args;
                size_t l = k+1;
                if (expr == "for" || expr == "if" || expr == "expand_macro" ||
                    (k<src.size() && src[k]=='(')) {
                    ASSERT(src[k] == '(');
                    comma.push_back(k);
                    while (l<src.size() && presum) {
                        if (src[l] == ')')
                            presum--;
                        else if (src[l] == '(')
                            presum++;
                        else if (presum == 1 && src[l] == ',')
                            comma.push_back(l);
                        l++;
                    }
                    ASSERT(presum==0) << "Jit error: braces are not matched.";
                    comma.push_back(l-1);
                    for (uint ci=0; ci+1<comma.size(); ci++)
                        args.push_back(src.substr(comma[ci]+1, comma[ci+1]-comma[ci]-1));
                }
                // syntax @for(i, l, r, ...)
                //        ij  k           l
                if (expr == "for") {
                    CHECKop(args.size(),>=,4u) << "Jit error: for missing arguments.";
                    string vi = args[0];
                    string vl = args[1];
                    string vr = args[2];
                    string vs = args[3];
                    auto vil = OpCompiler::eval(vl, defs);
                    auto vir = OpCompiler::eval(vr, defs);
                    // keep the full 64-bit value of the step
                    int64_t step = 1;
                    if (args.size() >= 5) {
                        // with five arguments the fourth is the step and the fifth the body
                        step = OpCompiler::eval(vs, defs);
                        vs = args[4];
                    }
                    auto new_defs = defs;
                    LOGvvv << "Expand for" << expr >> "[" >> vil >> "," >> vir >> "," >> step >> "]";
                    int total_step = 0;
                    for (auto vii=vil; vii!=vir; vii+=step) {
                        total_step ++;
                        ASSERT(total_step < 1000) << "Too much step.";
                        new_defs[vi] = S(vii);
                        new_src += precompile(new_defs, vs, macros);
                    }
                    i = l-1;
                    continue;
                } else
                if (expr == "if") {
                    // syntax: @if(cond, true[, false])
                    //         ij k                   l
                    ASSERT(args.size()>=2u && args.size()<=3u)
                        << "Jit error: if wrong arguments.";
                    string vcond = args[0];
                    string vtrue = args[1];
                    string vfalse = args.size() == 3u ? args[2] : "";
                    // compare the 64-bit result with zero; narrowing to int
                    // would turn values such as 1<<32 into "false"
                    bool cond = OpCompiler::eval(vcond, defs) != 0;
                    new_src += precompile(defs, cond?vtrue:vfalse, macros);
                    i = l-1;
                    continue;
                } else
                if (expr == "expand_macro") {
                    // syntax: @expand_macro(macro, args)
                    //         ij           k           l
                    for (auto& arg : args) {
                        uint p=0;
                        while (p<arg.size() && arg[p] == ' ') p++;
                        arg = precompile(defs, arg.substr(p), macros);
                    }
                    string vmacro = args[0];
                    args.erase(args.begin());
                    auto iter = macros.find(vmacro);
                    string ns;
                    if (iter == macros.end()) {
                        if (defs.count(vmacro))
                            ns = defs[vmacro];
                        else
                            LOGf << "Macro" << vmacro << "not found.";
                    } else {
                        expand_macro(iter->second, args, ns);
                    }
                    new_src += precompile(defs, ns, macros);
                    i = l-1;
                    continue;
                } else
                if (args.size()) {
                    // syntax: @e0(i0,i1,...,in) -> e0p[i0*e0stride0+i1*e0stride1+...]
                    int nid=(int)expr.size();
                    while (nid && isdigit(expr[nid-1])) nid--;
                    // xyz123 ---> prefix: xxx; suffix: 123
                    string prefix = expr.substr(0, nid);
                    string suffix = expr.substr(nid);
                    string up_prefix = prefix;
                    for (auto& c : up_prefix)
                        if (c>='a' && c<='z') c = c-'a'+'A';
                    string dim = up_prefix + "DIM" + suffix;
                    if (prefix == "e") prefix = "extras";
                    ASSERT(defs.count(dim)) << dim;
                    ASSERTop(defs.at(dim),==,S(args.size()));
                    expr = prefix + suffix; // e0 ->extras0
                    std::stringstream ss;
                    ss << expr << "p[";
                    for (uint ii=0; ii<args.size(); ii++) {
                        string arg = precompile(defs, args[ii], macros);
                        if (ii) ss << "+";
                        ss << '(' << arg << ")*" << expr << "stride" << ii;
                    }
                    ss << ']';
                    new_src += ss.str();
                    i = l-1;
                    continue;
                }
                // syntax: @x
                auto iter = defs.find(expr);
                ASSERT(iter!=defs.end()) << "Jit var " << expr << " not found.";
                new_src += precompile(defs, iter->second, macros);
                i = k-1;
                continue;
            } else if (src[j]=='@') {
                // seperater syntex: @@
                i++;
                continue;
            } else
                LOGf << "Jit error: Invalid syntax.";
        } else
            new_src += src[i];
        } catch (std::exception& e) {
            // Extract the line containing position i: from just after the
            // previous '\n' (or the start of src) up to the next '\n'.
            size_t il = i, ir = i;
            while (il && src[il-1] != '\n') il--;
            while (ir<src.size() && src[ir] != '\n') ir++;
            string this_line = src.substr(il, ir-il);
            LOGf << e.what() >> "\nJit compiler error:\n" >> this_line;
        }
    }
    return new_src;
}
|
||||
|
||||
// Convenience entry point: run the JIT preprocessor on `src` with `defs`,
// starting from an empty macro table that is discarded afterwards.
string OpCompiler::precompile(const unordered_map<string,string>& defs, const string& src) {
    unordered_map<string, string> local_macros;
    string expanded = jittor::precompile(defs, src, local_macros);
    return expanded;
}
|
||||
|
||||
// Build the source that is handed to the JIT compiler for `op`.
// A "fused" op delegates to get_fused_src. For any other op the registered
// source file is read and precompiled with the op's jit defines; the defines
// are also emitted as "#define" lines: keys starting with "JIT" at the very
// top, all others right after the last "#include" line.
string OpCompiler::get_jit_src(Op* op) {
    string op_name = op->name();
    // not used below; kept so the sequence of calls is unchanged
    string file_name = Op::op_name_to_file_name(op_name);
    string class_name = Op::file_name_to_class_name(file_name);
    if (op_name == "fused") {
        string fused = get_fused_src((FusedOp*)op);
        ASSERT(fused.size());
        return fused;
    }
    auto op_info = get_op_info(op_name);
    auto& src_path = op_info.source_path;

    // defines placed before everything else / after the last #include
    string head_defines, post_include_defines;
    auto jit_define = op->get_jit_define();
    for (auto &kv : jit_define) {
        string line = "#define " + kv.first + " ";
        for (char c : kv.second) {
            // continue the #define across embedded newlines
            if (c=='\n') line += '\\';
            line += c;
        }
        line += '\n';
        string& target = startswith(kv.first, "JIT")
            ? head_defines : post_include_defines;
        target += line;
    }
    ASSERT(file_exist(src_path));
    LOGvvv << "Read from" << src_path;
    string src = read_all(src_path);
    ASSERT(src.size()) << "Source read failed:" << src_path;

    unordered_map<string,string> defs(jit_define.begin(), jit_define.end());
    LOGvvv << "Precompile with key:" << defs;
    src = precompile(defs, src);

    // insertion point: just after the line of the last #include "..."\n,
    // or the start of the file when there is no include at all
    size_t insert_at = 0;
    const auto last_include = src.rfind("#include");
    if (last_include != string::npos) {
        const auto eol = src.find("\n", last_include);
        insert_at = (eol == string::npos) ? src.size() : eol+1;
    }

    string new_src = head_defines;
    new_src += src.substr(0, insert_at);
    new_src += post_include_defines;
    new_src += src.substr(insert_at);
    new_src += "\n";
    return new_src;
}
|
||||
|
||||
// Collect one source string per sub-op of the fused op and merge them with
// __get_fused_src. Relay groups switched on through the loop option
// "relay<i>" == 1 contribute their relay source once (for the first op of
// each (group, pair)); further ops of the same pair contribute an empty string.
string OpCompiler::get_fused_src(FusedOp* op) {
    auto& vrm = op->context->vrm;
    // which relay groups are enabled by the loop options
    vector<bool> relay_switch(vrm.relay_groups.size());
    for (uint gi=0; gi<relay_switch.size(); gi++) {
        auto relay_key = "relay"+S(gi);
        const bool enabled = op->loop_options->count(relay_key) &&
            op->loop_options->at(relay_key) == 1;
        if (enabled)
            relay_switch[gi] = true;
    }
    auto relay_source = vrm.get_op_relay_info(relay_switch);

    vector<string> op_srcs;
    std::set<pair<int,int>> relayed;
    for (uint oi=0; oi<op->ops.size(); oi++) {
        // relay group id, pair id
        auto p = relay_source[oi];
        if (p.first == -1) {
            // ordinary op: use its own jit source
            op_srcs.push_back(get_jit_src(op->ops[oi]));
            continue;
        }
        if (!relayed.insert(p).second) {
            // this relay pair was already emitted
            op_srcs.push_back("");
            continue;
        }
        auto relay_src = vrm.get_relay_src(p.first, p.second);
        op_srcs.push_back(relay_src);
        // op_srcs.push_back(get_relayed_src(src));
    }
    return OpCompiler::__get_fused_src(op->ops, op_srcs, op_members);
}
|
||||
|
||||
// Merge the jit sources of several ops into the source of a single
// FusedOp::jit_run().
//
// ops        the fused sub-ops; ops[i] corresponds to op_srcs[i]
// op_srcs    jit source of each sub-op; an empty string is skipped and a
//            source containing "@relay_op" is appended to the kernel as is
// op_members output: for every op, the fused-kernel member names
//            ("op<i>_<member>") in the order they were first seen
//
// For each remaining source the #include lines are collected, #define lines
// are either hoisted to the top (names starting with "JIT", which must agree
// between ops) or prefixed with "op<i>_", and the body of the first function
// matching "::jit_run() {" is copied with its identifiers renamed to
// "op<i>_<name>" unless they are listed in `unchanged`, contain "::" or
// "LOG", start with ':' or a digit, or follow '.' or '>'.
string OpCompiler::__get_fused_src(
    const vector<Op*>& ops,
    const vector<string>& op_srcs,
    vector<vector<string>>& op_members
) {
    string fused_begin;
    string fused_includes;
    string fused_defines;
    string fused_kernel_args;
    string fused_kernel;
    // definitions of fused_begin
    map<string,string> defs;
    // member names for which a local has already been emitted
    unordered_set<string> kernel_args;
    op_members = vector<vector<string>>(op_srcs.size());
    fused_begin += "#define JIT 1\n";
    defs["JIT"] = "1";
    const string pattern = "::jit_run() {";
    // TODO: better check member
    // identifiers treated as op members: read from the op object and bound
    // to a local named op<i>_<member>
    const unordered_set<string> members = {
        "x", "y", "z", "cond", "output", "extras"
    };
    // identifiers that are copied without the op<i>_ prefix
    const unordered_set<string> unchanged = {
        "for", "const", "auto", "get_random_engine",
        "int", "float", "bool", "CHECK", "STRINGIZE",
        "void", "__restrict__", "if", "true", "false",
        "Op", "Var", "Node", "itof"
    };
    auto not_change = [&](const string& s) -> bool {
        if (unchanged.count(s)) return true;
        return (s.find("::") != string::npos) || (s.find("LOG") != string::npos);
    };
    // regex find XxxXxxOp::jit_run
    std::regex e(R"([^]*\s(\S*)Op::jit_run[^]*)");
    for (uint oi=0; oi<op_srcs.size(); oi++) {
        const string& src = op_srcs[oi];
        if (src.size()==0) continue;
        if (src.find("@relay_op") != string::npos) {
            fused_kernel += src;
            continue;
        }
        if (ops[oi]->name()==string("array")) {
            // array ops get hand-written code instead of a parsed jit_run body
            string op_name = "op" + S(oi);
            string arg_name = op_name + "_output";
            string argp_name = op_name + "_outputp";
            string T = ((ArrayOp*)ops[oi])->output->dtype().to_cstring();
            fused_kernel_args += " ArrayOp* " + op_name + " = (ArrayOp*)(ops[" + S(oi) + "]);\n";
            // op_name = "((ArrayOp*)(ops[" + S(oi) + "]))";
            fused_kernel_args += " Var* " + arg_name + " = " + op_name + "->output;\n";

            fused_kernel += " auto* " + argp_name + " = " + arg_name + "->ptr<" + T + ">();\n";
            fused_kernel += " " + argp_name + "[0] = " + op_name + "->ptr<" + T + ">()[0];\n";
            fused_kernel += " int " + arg_name + "shape0 = 1;\n";
            fused_kernel += " int " + arg_name + "stride0 = 1;\n";

            fused_includes += "#include \"ops/array_op.h\"\n";
            op_members[oi].push_back(arg_name);
            // auto opi = (ArrayOp*)(ops[i]);
            // auto opi_output = opi->output;
            // auto* opi_outputp = opi_output->ptr<T>();
            // opi_outputp[0] = ((T*)(opi->buffer.get()))[0];
            continue;
        }
        std::smatch cm;
        std::regex_match(src, cm, e);
        ASSERT(cm.size()>=2) << src;
        // class name without the trailing "Op", used for the cast below
        string name3 = cm[1];
        for (uint i=0; i<src.size(); i++) {
            if (src[i] == '#' &&
                (i+1<src.size() && src[i+1] == 'i') &&
                (i+2<src.size() && src[i+2] == 'n'))
            {
                // #include ...
                uint j=i+1;
                while (j<src.size() && src[j] != '\n') j++;
                if (j<src.size()) j++;
                for (uint k=i; k<j; k++) fused_includes += src[k];
                i = j-1;
                continue;
            }
            if (src[i] == '#' && (i+1<src.size() && src[i+1] == 'd')) {
                // #define aaa bbb
                // i       j  k   l
                // TODO: multi-line define
                uint j=i+1;
                while (j<src.size() && src[j] != ' ') j++;
                while (j<src.size() && src[j] == ' ') j++;
                uint k=j;
                while (k<src.size() && src[k] != ' ') k++;
                uint l=k;
                while (l<src.size() && src[l] != '\n') l++;
                if (l<src.size()) l++;
                CHECK(i<j && j<k && k<l);
                // define startswith JIT should be added at the very beginning
                if (startswith(src, "JIT", j)) {
                    string key = src.substr(j,k-j);
                    string value = src.substr(k+1, l-k-2);
                    if (defs.count(key))
                        // the same JIT define from two ops must carry the same value
                        CHECKop(defs[key],==,value);
                    else {
                        defs[key] = value;
                        fused_begin += "#define ";
                        for (; j<l; j++) fused_begin += src[j];
                    }
                    j = l;
                } else {
                    // other defines are namespaced per op
                    fused_defines += "#define op" + S(oi) + "_";
                    for (; j<l; j++) fused_defines += src[j];
                }
                i = j-1;
                continue;
            }
            // find the first function match the pattern "jit_run"
            bool found = true;
            for (uint j=0; j<pattern.size(); j++)
                if (pattern[j] != src[i+j]) {
                    found = false;
                    break;
                }
            if (!found) continue;
            // j: first character of the function body, k: one past its closing '}'
            uint j = i+pattern.size();
            uint k = j;
            int presum = 1;
            while (k<src.size() && presum) {
                if (src[k] == '}')
                    presum--;
                else if (src[k] == '{')
                    presum++;
                k++;
            }
            ASSERT(presum==0) << "Jit error: braces are not matched.";
            // copy the body (stops two characters before k), renaming identifiers
            for (;j < k-2; j++) {
                if (isvar(src[j])) {
                    uint l=j;
                    while (l<src.size() && isvar(src[l])) l++;
                    auto var = src.substr(j, l-j);
                    if (var[0] == ':' || isdigit(var[0]) || not_change(var) || src[j-1]=='.' || src[j-1]=='>') {} else
                    if (members.count(var)) {
                        string arg_name = "op" + S(oi) + "_" + var;
                        if (l<src.size() && src[l]=='[') {
                            // handle extras[...]
                            //              l   r
                            uint r = l+1;
                            while (r<src.size() && src[r]!=']') r++;
                            ASSERT(r<src.size());
                            for (uint i=l+1; i<r; i++) {
                                ASSERT(isdigit(src[i]));
                                arg_name += src[i];
                            }
                            l = r+1;
                            var = src.substr(j, l-j);
                            // arg_name = opi_extra0
                            // var = extra[0]
                        }
                        // emit the local for this member only once
                        if (!kernel_args.count(arg_name)) {
                            fused_kernel_args +=
                                string(" auto ") + arg_name +
                                " = (("+name3+"Op*)(ops[" + S(oi) + "]))->" + var;
                            fused_kernel_args += ";\n";
                            kernel_args.insert(arg_name);
                            op_members[oi].push_back(arg_name);
                        }
                        fused_kernel += arg_name;
                        j = l-1;
                        continue;
                    } else
                        fused_kernel += "op" + S(oi) + "_";
                    for (uint p=j; p<l; p++) fused_kernel += src[p];
                    j = l-1;
                    continue;
                }
                fused_kernel += src[j];
            }
            // only the first jit_run in the source is used
            break;
        }
    }
    CHECK(!(defs.count("JIT_cpu") && defs.count("JIT_cuda")))
        << "CPU op and GPU op cannot be fused together.";

    fused_kernel = fused_kernel_args + "\n" + fused_kernel;
    LOGvvvv << "Fused kernel:\n" >> fused_kernel;

    auto fused_src = fused_begin + fused_includes + "\n#include \"fused_op.h\"\n" +
        fused_defines + '\n' +
        "void jittor::FusedOp::jit_run() {\n" + fused_kernel + "\n}\n";

    // we assume the member name is in lexicographical order
    // for (auto& v : op_members) std::sort(v.begin(), v.end());

    return fused_src;
}
|
||||
|
||||
// Source to compile for this compiler's op. The cached `src` is returned
// unless the op is a fused op with a loop option whose key starts with
// "relay"; in that case the source is regenerated through get_jit_src.
string OpCompiler::get_src() {
    if (op==nullptr) return src;
    bool has_relay = false;
    for (const auto& option : *op->loop_options) {
        if (startswith(option.first, "relay")) {
            has_relay = true;
            break;
        }
    }
    // return get jit src if has relay op
    if (has_relay)
        return get_jit_src(op);
    return src;
}
|
||||
|
||||
// Remember the op, record it as a FusedOp when its name is "fused"
// (nullptr otherwise) and generate its jit source.
OpCompiler::OpCompiler(Op* op) {
    _op = op;
    if (op->name()==string("fused"))
        this->op = (FusedOp*)op;
    else
        this->op = nullptr;
    src = get_jit_src(op);
}
|
||||
|
||||
// Compile `src` under `jit_key` with the jit compiler, passing whether the
// op carries the cuda flag and the extra flags registered for the op.
jit_op_entry_t OpCompiler::compile(const string& jit_key, const string& src) {
    const bool use_cuda = _op->flags.get(NodeFlags::_cuda);
    // add extra flags for custom ops
    auto info = get_op_info(_op->name());
    return jit_compiler::compile(jit_key, src, use_cuda, info.extra_flags);
}
|
||||
|
||||
// Generate and compile the jit source of `op`. For a fused op the source is
// first run through TunerManager::tune(); otherwise the generated source is
// compiled directly.
jit_op_entry_t OpCompiler::do_compile(Op* op) {
    OpCompiler oc(op);
    string tuned_src;
    const bool is_fused = oc.op != nullptr;
    if (is_fused) {
        // the tuner manager only lives for the duration of tuning
        TunerManager tm(&oc);
        tuned_src = tm.tune();
    }
    const string& final_src = is_fused ? tuned_src : oc.src;
    return oc.compile(op->get_jit_key(), final_src);
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <type_traits>
|
||||
|
||||
#include "var.h"
|
||||
#include "op.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "pybind/py_var_tracer.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Number of Var objects currently alive: incremented in the constructor and
// decremented in the destructor.
int64_t Var::number_of_lived_vars = 0;
|
||||
|
||||
// Flag `compile_options`: every newly constructed Var copies it into its
// loop_options member.
DEFINE_FLAG(fast_shared_ptr<loop_options_t>, compile_options, {},
    "Override the default loop transfrom options");
|
||||
|
||||
// Create a Var with the given shape and dtype. The node is flagged as a var,
// takes the current compile_options as its loop options, is counted in
// number_of_lived_vars and has num/size computed from the shape.
// Asserts that `dtype` is a dtype.
Var::Var(NanoVector shape, NanoString dtype)
    : shape(shape), loop_options(compile_options)
{
    flags.set(NodeFlags::_var, 1);
    this->ns = dtype;
    ASSERT(this->ns.is_dtype());
    ++number_of_lived_vars;
    // fills in num and size
    numel();
}
|
||||
// Return the memory (if any) to the allocator and update the live counter.
Var::~Var() {
    const bool has_memory = mem_ptr != nullptr;
    if (has_memory) {
        allocator->free(mem_ptr, size, allocation);
    }
    --number_of_lived_vars;
}
|
||||
|
||||
// Text form of the var: dtype name immediately followed by the shape.
string Var::to_string() {
    string text(dtype().to_cstring());
    text += shape.to_string();
    return text;
}
|
||||
|
||||
int64_t Var::numel() {
|
||||
if (!shape.size()) return size=num=-1;
|
||||
bool negtive = 0;
|
||||
num=1;
|
||||
for (auto k : shape) {
|
||||
if (k<0) {
|
||||
negtive = 1;
|
||||
num *= -k;
|
||||
} else {
|
||||
num *= k;
|
||||
}
|
||||
}
|
||||
size = num * dsize();
|
||||
if (negtive) num = -num;
|
||||
return num;
|
||||
}
|
||||
|
||||
// Replace the shape and refresh num/size accordingly.
void Var::set_shape(NanoVector shape) {
    this->shape = shape;
    // return value not needed; called for its update of num and size
    this->numel();
}
|
||||
|
||||
// Make sure the var has memory; returns true when mem_ptr is set afterwards.
// If the `allocator` member is non-null it is read as a Var* here; when that
// var's allocator agrees to share its allocation, this var takes over the
// same mem_ptr/allocation/allocator. Otherwise memory is requested from the
// allocator passed in.
// NOTE(review): the member appears to double as a "share with this Var" slot
// before allocation — confirm against the code that sets it.
bool Var::alloc(Allocator* allocator) {
    if (mem_ptr)
        return true;
    Var* donor = (Var*)(this->allocator);
    if (donor && donor->allocator->share_with(size, donor->allocation)) {
        mem_ptr = donor->mem_ptr;
        allocation = donor->allocation;
        this->allocator = donor->allocator;
        return true;
    }
    mem_ptr = allocator->alloc(size, allocation);
    this->allocator = allocator;
    return mem_ptr;
}
|
||||
|
||||
|
||||
// Debug print of a Var:
// Var(<addr>:<fwd>:<bwd>:<pending>:i<#inputs>:o<#outputs>:s<finished>,<dtype>,<name>,<mem_ptr>)<shape>
// With NODE_MEMCHECK the node id and its trace are appended.
std::ostream& operator<<(std::ostream& os, const Var& var) {
    os << "Var" << '(' << (void*)&var;
    // liveness counters
    os << ':' << var.forward_liveness;
    os << ':' << var.backward_liveness;
    os << ':' << var.pending_liveness;
    // graph connectivity and finished state
    os << ":i" << var._inputs.size();
    os << ":o" << var._outputs.size();
    os << ":s" << var.is_finished();
    os << ',';
    os << var.dtype().to_cstring() << ',' << var.name << ',' << var.mem_ptr;
    os << ')' << var.shape;
#ifdef NODE_MEMCHECK
    os << '<' << var.__id() << '>';
    print_node_trace(&var, os);
#endif
    return os;
}
|
||||
std::ostream& operator<<(std::ostream& os, const Var* var) {
|
||||
return os << *var;
|
||||
}
|
||||
// Print the Var held by a VarPtr.
std::ostream& operator<<(std::ostream& os, const VarPtr& v) {
    os << v.ptr;
    return os;
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,125 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#endif
|
||||
#include "var_holder.h"
|
||||
#include "var.h"
|
||||
#include "executor.h"
|
||||
#include "graph.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// All live VarHolder objects; each holder keeps an iterator to its own entry
// and removes it in its destructor.
list<VarHolder*> VarHolder::hold_vars;
|
||||
|
||||
// Register `self` at the front of the global holder list and remember the
// position of its entry in self->iter.
void add_hold_vars(VarHolder* self) {
    auto& holders = VarHolder::hold_vars;
    holders.push_front(self);
    self->iter = holders.begin();
}
|
||||
|
||||
// Hold `v`: register this holder in the global list and take both
// liveness references on the var.
VarHolder::VarHolder(Var* v) : var(v) {
    add_hold_vars(this);
    // Var holder has both forward and backward liveness
    this->var->own_both_liveness();
}
|
||||
|
||||
// Builds a holder from a temporary VarPtr. The delegated VarHolder(Var*)
// constructor registers the holder and acquires liveness on v.ptr; afterwards
// the VarPtr's own liveness is freed and its pointer is cleared.
VarHolder::VarHolder(VarPtr&& v) : VarHolder(v.ptr) {
    VarPtr& source = v;
    source.free_liveness();
    source.ptr = nullptr;
}
|
||||
|
||||
// Replaces the holder `v` with this one: copies its variable pointer, reuses
// its slot in hold_vars (the list entry is repointed to `this`) and frees v's
// storage. No liveness is acquired or released here.
// NOTE(review): `v` is presumably allocated with operator new - confirm at call sites.
VarHolder::VarHolder(VarHolder* v) : var(v->var) {
    this->iter = v->iter;
    *this->iter = this;
    // release the memory only; v's destructor is deliberately not run
    operator delete(v);
}
|
||||
|
||||
// Removes this holder's entry from hold_vars and then releases both forward
// and backward liveness on the held variable.
VarHolder::~VarHolder() {
    VarHolder::hold_vars.erase(this->iter);
    this->var->release_both_liveness();
}
|
||||
|
||||
// assign attributes of b to a
|
||||
static inline void assign_var(Var* a, Var* b) {
|
||||
a->name = move(b->name);
|
||||
if (b->is_stop_grad())
|
||||
a->set_stop_grad();
|
||||
if (b->flags.get(NodeFlags::_stop_fuse))
|
||||
a->flags.set(NodeFlags::_stop_fuse);
|
||||
}
|
||||
|
||||
// Rebinds this holder to the variable carried by `v`.
// The currently held variable's attributes (name, stop-grad, stop-fuse) are
// transferred to the incoming variable, the current variable's liveness is
// released, and `v` is left holding a null pointer. own_both_liveness() is
// not called on the incoming variable here.
void VarHolder::operator=(VarPtr&& v) {
    Var* incoming = v.ptr;
    assign_var(incoming, var);
    var->release_both_liveness();
    var = incoming;
    v.ptr = nullptr;
}
|
||||
|
||||
// Returns the held variable's to_string() result. When var->num is negative,
// sync() is called on this holder first.
// NOTE(review): a negative `num` presumably means the size is not yet known - confirm.
string VarHolder::to_string() {
    const bool needs_sync = var->num < 0;
    if (needs_sync) {
        sync();
    }
    return var->to_string();
}
|
||||
|
||||
// Rebinds this holder to the variable held by `v` and returns `this`.
// The current variable's attributes (name, stop-grad, stop-fuse) are
// transferred to v's variable, liveness on the current variable is released,
// and liveness on the new variable is acquired.
//
// If `v` already refers to the same variable (including v == this), nothing
// is done: running the sequence below would self-move the name and release
// liveness before re-acquiring it on the very same variable, which may act
// on a variable whose last liveness was just dropped.
VarHolder* VarHolder::assign(VarHolder* v) {
    if (v->var == var)
        return this;
    assign_var(v->var, var);
    var->release_both_liveness();
    var = v->var;
    var->own_both_liveness();
    return this;
}
|
||||
|
||||
// Executor object declared here and defined outside this spot; the functions
// below use it for run_sync() and for its `allocator` member.
extern Executor exe;
|
||||
|
||||
void VarHolder::sync(bool device_sync) {
|
||||
jittor::sync({this}, device_sync);
|
||||
}
|
||||
|
||||
// Synchronises this holder with device_sync=true and returns an ArrayArgs
// built from the variable's mem_ptr, shape and dtype. When HAS_CUDA is
// defined, migrate_to_cpu(var, exe.allocator) is called before the result
// is assembled.
ArrayArgs VarHolder::fetch_sync() {
    sync(true);
#ifdef HAS_CUDA
    migrate_to_cpu(var, exe.allocator);
#endif
    ArrayArgs result{var->mem_ptr, var->shape, var->dtype()};
    return result;
}
|
||||
|
||||
void sync_all(bool device_sync) {
|
||||
vector<Var*> vars;
|
||||
vars.reserve(VarHolder::hold_vars.size());
|
||||
for (auto v : VarHolder::hold_vars) {
|
||||
if (!v->var->_outputs.size())
|
||||
vars.push_back(v->var);
|
||||
}
|
||||
graph_check();
|
||||
exe.run_sync(vars, device_sync); //need sync at last
|
||||
graph_check();
|
||||
}
|
||||
|
||||
void sync(const vector<VarHolder*>& vh, bool device_sync) {
|
||||
vector<Var*> vars;
|
||||
vars.reserve(vh.size());
|
||||
for (auto v : vh) vars.push_back(v->var);
|
||||
graph_check();
|
||||
exe.run_sync(vars, device_sync); //need sync at last
|
||||
graph_check();
|
||||
}
|
||||
|
||||
// Synchronises all holders in `vh` (with device_sync=true) and returns one
// ArrayArgs per holder, in the same order, filled with the variable's
// mem_ptr, shape and dtype. When HAS_CUDA is defined, each variable is passed
// to migrate_to_cpu(..., exe.allocator) before its fields are read.
vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh) {
    vector<ArrayArgs> ret(vh.size());
    sync(vh, true);
    // size_t matches the type of vh.size(); the previous `uint` index is a
    // non-standard typedef and narrower than the container's size type.
    for (size_t i = 0; i < vh.size(); i++) {
        Var* held = vh[i]->var;
#ifdef HAS_CUDA
        migrate_to_cpu(held, exe.allocator);
#endif
        ArrayArgs& out = ret[i];
        out.ptr = held->mem_ptr;
        out.shape = held->shape;
        out.dtype = held->dtype();
    }
    return ret;
}
|
||||
|
||||
} // jittor
|
Loading…
Reference in New Issue