version 77593ddd55381fddacdfa637355784488523c5e2

This commit is contained in:
Dun Liang 2020-03-20 09:49:49 +08:00
parent 1258121b1f
commit e96f7ceee8
21 changed files with 2912 additions and 14 deletions

View File

@ -104,8 +104,9 @@ Jittor使用Python和C++编写。 它需要用于即时编译的编译器。当
Jittor的环境要求如下:
* 操作系统: Ubuntu>=16.04
* Python >= 3.7
* 操作系统: Ubuntu >= 16.04
* Python版本 >= 3.7
* C++编译器g++ or clang
Jittor offers three ways to install: pip, script or manual.
@ -115,7 +116,7 @@ Jittor 一共提供三种方式安装: pip安装, 一键脚本安装 和 手动
## Pip install
如果您已经装好编译器和对应版本的Python,我们强烈推荐您使用这种方法
如果您没有准备好环境,欢迎使用我们提供的一键安装脚本, 如果您已经装好编译器和对应版本的Python,我们强烈推荐您使用这种方法
(如果无法访问github, 可以通过jittor主页下载):
```bash
@ -134,7 +135,7 @@ jittor会自动在路径中寻找合适的编译器, 如果您希望手动指定
## 一键脚本安装
## single line script install
一键脚本安装会帮您安装好所需的编译器.
一键脚本安装会帮您安装好所需的编译器以及对应的Python版本.
We provide single line command for quick installation the latest version of Jittor(Ubuntu>=16.04):
@ -142,13 +143,13 @@ We provide single line command for quick installation the latest version of Jitt
```bash
# install with clang and cuda
git clone https://github.com/Jittor/jittor.git && with_clang=1 with_cuda=1 bash ./jittor/script/install.sh
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_clang=1 with_cuda=1 bash ./jittor/script/install.sh
# install with clang
git clone https://github.com/Jittor/jittor.git && with_clang=1 bash ./jittor/script/install.sh
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_clang=1 bash ./jittor/script/install.sh
# install with g++ and cuda
git clone https://github.com/Jittor/jittor.git && with_gcc=1 with_cuda=1 bash ./jittor/script/install.sh
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_gcc=1 with_cuda=1 bash ./jittor/script/install.sh
# install with g++
git clone https://github.com/Jittor/jittor.git && with_gcc=1 bash ./jittor/script/install.sh
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_gcc=1 bash ./jittor/script/install.sh
```
After execution, the script will show some environment variables you need to export.

View File

@ -166,7 +166,7 @@ def install_cutt(root_folder):
filename = "cutt.tgz"
fullname = os.path.join(root_folder, filename)
dirname = os.path.join(root_folder, filename.replace(".tgz",""))
true_md5 = "c79ad93b76544d598eb250ec749c492c"
true_md5 = "28a67bb3a713e29ce434303df6577507"
if os.path.exists(fullname):
md5 = os.popen('md5sum ' + fullname).read().split()[0]
@ -178,6 +178,7 @@ def install_cutt(root_folder):
if not os.path.isfile(os.path.join(dirname, "bin", "cutt_test")):
LOG.i("Downloading cub...")
download_url_to_local(url, filename, root_folder, true_md5)
import tarfile
with tarfile.open(fullname, "r") as tar:

View File

@ -5,7 +5,6 @@
# ***************************************************************
import unittest
import jittor as jt
from jittor.utils import pytorch_converter
import math
import numpy as np
@ -13,6 +12,7 @@ try:
jt.dirty_fix_pytorch_runtime_error()
import torch
from torch import nn
from jittor.utils import pytorch_converter
except:
torch = None

View File

@ -46,7 +46,7 @@ run_cmd(f"git rev-parse HEAD > {polish_path}/python/jittor/version", jittor_path
files = jt.compiler.files
file_to_delete = [ name for name in files
if name.startswith("src") and \
len(name.split("/"))==2
len(name.split("/"))==2 and name.endswith("node.cc")
]
LOG.i("file_to_delete", file_to_delete)
run_cmd(f"rm {' '.join(file_to_delete)}", polish_path)
@ -97,6 +97,7 @@ assert os.system(f"cd {polish_path} && tar -cvzf build/jittor.tgz . --exclude bu
jittor_web_base_dir = "Documents/jittor-blog/assets/"
jittor_web_build_dir = jittor_web_base_dir + "build/"
assert os.system(f"rsync -avPu {polish_path}/build/ jittor-web:{jittor_web_build_dir}")==0
assert os.system(f"ssh jittor@166.111.68.30 Documents/jittor-blog.git/hooks/post-update")==0
# push to github
assert os.system(f"cd {polish_path} && git push -f origin master")==0

View File

@ -1 +1 @@
1522f3d004f9bdbf3953d91d4c259c341817c71f
77593ddd55381fddacdfa637355784488523c5e2

View File

@ -50,8 +50,10 @@ wget -O - https://bootstrap.pypa.io/get-pip.py | sudo -H python$py_version
# Step 3: Run jittor
sudo apt install git -y
git clone https://github.com/Jittor/jittor.git
if [ ! -d jittor ]; then
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz
mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor
fi
sudo python$py_version -m pip install ./jittor

View File

@ -28,5 +28,7 @@ setuptools.setup(
"pybind11",
"numpy",
"tqdm",
"pillow",
"astunparse",
],
)

36
src/event_queue.cc Normal file
View File

@ -0,0 +1,36 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "event_queue.h"
namespace jittor {
EventQueue event_queue;
void EventQueue::Worker::start() {
Worker* self = &event_queue.worker;
while (1) {
Func todo;
{
std::unique_lock<std::mutex> l(self->mtx);
event_queue.cv.notify_one();
self->cv.wait(l);
todo = self->todo;
}
if (!todo) break;
todo();
}
}
void EventQueue::worker_caller() {
event_queue.func();
{
std::lock_guard<std::mutex> l(event_queue.mtx);
event_queue.run_sync_done = true;
}
}
} // jittor

405
src/executor.cc Normal file
View File

@ -0,0 +1,405 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <algorithm>
#include <functional>
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include <helper_cuda.h>
#include "mem/allocator/cuda_dual_allocator.h"
#include "fetcher.h"
#include "event_queue.h"
#endif
#include "misc/cuda_flags.h"
#include "executor.h"
#include "var.h"
#include "op.h"
#include "mem/allocator.h"
#include "graph.h"
#include "fused_op.h"
#include "fuser.h"
#include "profiler/profiler_guard.h"
namespace jittor {
Executor exe;
void Executor::run_sync(vector<Var*> vars, bool device_sync) {
auto allocator = get_allocator();
this->allocator = allocator;
// bfs find all ops need to run
int op_num = 0;
vector<Node*> bfs_q;
bfs_q.reserve(vars.size());
auto nodes = (vector<Node*>*)&vars;
int start_var_num = 0;
for (Var* v : vars)
if (!v->is_finished())
start_var_num++;
bfs_backward(*nodes, bfs_q, [&](Node *node) -> bool {
node->custom_data = 0;
if (node->is_finished())
return false;
op_num += !node->is_var();
return true;
});
auto tt = Node::tflag_count;
vector<Op*> ops;
vector<Var*> all_vars;
ops.reserve(op_num);
for (Node* node : bfs_q)
if (!node->is_var()) {
node->custom_data = ops.size();
ops.push_back(node->op());
} else {
// set can't fuse flag to false
node->custom_data = all_vars.size();
all_vars.push_back(node->var());
}
int var_num = all_vars.size();
// father: father of union-find set
vector<int> father(op_num);
for (int i=0; i<op_num; i++) {
father[i] = i;
}
// union-find algorithm
auto find_fa = [&](int i) -> int {
int j=i;
while (father[j] != j) j = father[j];
while (i != j) {
int tmp = father[i];
father[i] = j;
i = tmp;
}
return j;
};
vector<int> var_fused(var_num);
if (V_ON(100)) {
for (uint i=0; i<ops.size(); i++) {
Op* op = ops[i];
string st="others";
if (op->type()==OpType::reduce) st="reduce";
if (op->type()==OpType::broadcast) st="broadcast";
if (op->type()==OpType::element) st="element";
LOGvvv << "id:" << ops[i]->custom_data << " type:" <<
st << " addr:" << op;
for (Var* v : op->inputs()) {
Op* next_op = v->input();
// continue if is boundary
if (!next_op || next_op->tflag != tt) {
LOGvvv << "input:" << v;
continue;
}
LOGvvv << "input:" << next_op->custom_data << " addr:" << next_op;
}
LOGvvv << "";
}
}
count_fuse(tt, start_var_num, ops, all_vars, father, var_fused);
// var_fused represents:
// 0: can fused
// 1: cannot fused
// 2: can shared
vector<int> roots, next(op_num, -1);
vector<int> deps(op_num, 0);
roots.reserve(op_num);
for (int i=0; i<op_num; i++) {
int fa = find_fa(i);
if (fa == i)
roots.push_back(i);
else {
next[i] = next[fa];
next[fa] = i;
}
}
vector<int> queue;
queue.reserve(roots.size());
// ** toplogical_sort external **
// output:
// queue: toplogical order of fused op
{
for (int root : roots) {
for (int i=root; i>=0; i=next[i]) {
Op* op = ops[i];
for (Var* v : op->inputs()) {
if (v->tflag != tt) continue;
Op* opi = v->input();
// if those two ops are not fused
if (father[opi->custom_data] != root) {
deps[root]++;
}
}
}
if (deps[root] == 0)
queue.push_back(root);
}
for (uint s=0; s<queue.size(); s++) {
int op_id = queue[s];
for (int i=op_id; i>=0; i=next[i]) {
Op* op = ops[i];
for (Var* v : op->outputs())
if (v->tflag == tt)
for (Op* op2 : v->outputs()) {
if (op2->tflag != tt) continue;
int op2_id = father[op2->custom_data];
// continue if those two ops are fused
if (op2_id == op_id) continue;
deps[op2_id]--;
if (deps[op2_id] == 0)
queue.push_back(op2_id);
}
}
}
ASSERTop(queue.size(),==,roots.size());
}
// ** toplogical_sort internal **
// output:
// fuse_ops: fused op id [000|1111|22|3333]
// range: split index ^ ^ ^ ^ ^
vector<int> fuse_ops;
fuse_ops.reserve(op_num*2);
vector<int> range(queue.size());
{
vector<int> subgraph;
subgraph.reserve(16);
vector<int> sharegraph;
sharegraph.reserve(16);
vector<int> sharegraph_q;
sharegraph_q.reserve(16);
vector<int> shared_id(op_num, -1);
for (uint rid=0; rid<queue.size(); rid++) {
int root = queue[queue.size()-rid-1];
auto& queue = subgraph;
queue.clear();
sharegraph.clear();
int total=0;
for (int i=root; i>=0; i=next[i], total++) {
Op* op = ops[i];
for (Var* v : op->inputs()) {
if (v->tflag != tt) continue;
Op* opi = v->input();
// if those two ops are fused
int opid = opi->custom_data;
auto fopid = father[opid];
if (fopid == root)
deps[i]++;
else if (shared_id[opid] != root) {
// var_fused = 1 cannot share input op
// TODO: check this input op's output var all can be shared
if (var_fused[v->custom_data] == 1)
continue;
// new shared op
deps[opid] = 0;
shared_id[opid] = root;
sharegraph.push_back(opid);
}
}
if (deps[i] == 0)
queue.push_back(i);
}
// find all share graph
uint sn = sharegraph.size();
for (uint i=0; i<sharegraph.size(); i++) {
int id = sharegraph[i];
Op* op = ops[id];
for (Var* v : op->inputs()) {
if (v->tflag != tt) continue;
int vi = v->custom_data;
if (var_fused[vi] == 1)
continue;
Op* opi = v->input();
int opid = opi->custom_data;
int& dep = deps[opid];
if (shared_id[opid] != root) {
shared_id[opid] = root;
dep = 1;
sharegraph.push_back(opid);
} else
dep ++;
}
}
sharegraph_q.clear();
for (uint i=0; i<sn; i++)
if (deps[sharegraph[i]]==0)
sharegraph_q.push_back(sharegraph[i]);
// topsort in sharegraph_q
for (uint i=0; i<sharegraph_q.size(); i++) {
int id = sharegraph_q[i];
Op* op = ops[id];
for (Var* v : op->inputs()) {
if (v->tflag != tt) continue;
int vi = v->custom_data;
if (var_fused[vi] == 1)
continue;
Op* opi = v->input();
int opid = opi->custom_data;
int& dep = deps[opid];
dep --;
if (dep == 0)
sharegraph_q.push_back(opid);
}
}
LOGvvvv << "sharegraph_q" << sharegraph_q;
ASSERTop(sharegraph.size(),==,sharegraph_q.size());
// topsort fused op internal
for (uint s=0; s<queue.size(); s++) {
int i = queue[s];
Op* op = ops[i];
for (Var* v : op->outputs())
if (v->tflag == tt)
for (Op* op2 : v->outputs()) {
if (op2->tflag != tt) continue;
int op2_id = op2->custom_data;
// continue if those two ops are not fused
if (father[op2_id] != root) continue;
deps[op2_id]--;
if (deps[op2_id] == 0)
queue.push_back(op2_id);
}
}
ASSERTop(queue.size(),==,(uint)total);
LOGvvvv << "topsort internal" << queue;
for (int i=(int)sharegraph_q.size()-1; i>=0; i--)
fuse_ops.push_back(sharegraph_q[i]);
for (uint i=0; i<queue.size(); i++)
fuse_ops.push_back(queue[i]);
range[rid] = fuse_ops.size();
}
}
for (int i=0; i<var_num; i++) {
all_vars[i]->custom_data = var_fused[i]==1;
}
// running
FusedOp fused_op;
vector<Var*> outputs_bk;
#ifdef HAS_CUDA
int sync_times = 0;
#endif
for (uint rid=0; rid<queue.size(); rid++) {
int root = queue[rid];
Op* op = ops[root];
bool is_fused_op = false;
try {
if (op->type() != OpType::other) {
op = &fused_op;
is_fused_op = true;
fused_op.ops.clear();
fused_op.edges.clear();
int ll = (rid<queue.size()-1)?range[queue.size()-rid-2]:0, rr = range[queue.size()-rid-1];
root = fuse_ops[rr-1];
auto ntt = ++Node::tflag_count;
for (int i=ll; i<rr; i++) {
int opid = fuse_ops[i];
Op* op = ops[opid];
uint64_t fid1 = fused_op.ops.size();
op->custom_data = fid1;
op->tflag = ntt;
fused_op.ops.push_back(op);
}
for (Op* op : fused_op.ops) {
uint fid1 = op->custom_data;
uint oid = 0;
for (Var* v : op->outputs()) {
oid++;
if (v->tflag != tt) {
// this var node not belong to current execution
// this will happend in multiple outputs fuseable op
v->custom_data = 0;
continue;
}
for (auto o : v->outputs_with_index()) {
Op* op2 = o.op;
uint iid = o.index;
if (op2->tflag != ntt) continue;
uint fid2 = op2->custom_data;
fused_op.edges.emplace_back(fid1, oid-1, fid2, iid);
}
}
}
LOGvvv << "Prepare fused_op" << fused_op.ops;
fused_op.update_ops();
}
LOGvvv << "Run" << op;
if (!op->shape_infered()) op->infer_shape();
ASSERT(op->shape_infered()) << "Shape of(" >> op->name() >> ") not solved.";
for (auto* var : op->outputs())
var->alloc(allocator);
LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
op->do_prepare();
bool is_cuda = op->flags.get(NodeFlags::_cuda);
#ifdef HAS_CUDA
if (!is_cuda) {
if (last_is_cuda) {
// if prev op in gpu and this op in cpu
// cuda sync
checkCudaErrors(cudaDeviceSynchronize());
sync_times++;
}
for (Var* v : op->inputs()) {
migrate_to_cpu(v, allocator);
}
}
#endif
#ifdef NODE_MEMCHECK
if (is_fused_op) {
for (auto& vi : fused_op.vars)
if (vi.type == 0)
ASSERT(vi.var->mem_ptr) << vi.var;
} else {
for (auto* v : op->inputs())
ASSERT(v->mem_ptr) << v;
}
#endif
last_is_cuda = is_cuda;
op->do_run_after_prepare();
LOGvvv << "Finished Op(" >> op->name() << rid >>
"/" >> queue.size() >> ") output:" << op->outputs();
if (is_fused_op) {
for (Var* var : op->outputs())
var->finish_pending_liveness();
continue;
}
// release liveness when op is finished
// outputs may change during free, we need to backup it;
outputs_bk.clear();
for (Var* var : op->outputs())
outputs_bk.push_back(var);
op->finish_pending_liveness();
for (Var* var : outputs_bk)
// var->finish_pending_liveness();
var->finish_pending_liveness();
} catch (const std::exception& e) {
if (is_fused_op) {
LOGf << "Execute fused operator(" >> rid >> '/' >> queue.size() >> ")"
<< "failed:" << fused_op.ops << "\n\nReason: " >> e.what();
} else
LOGf << "Execute operator(" >> rid >> '/' >> queue.size() >> ")"
<< "failed:" << op << "\n\nReason: " >> e.what();
}
}
LOGvv << "All" << op_num << "ops finished, return vars:" << vars;
for (Var* v : vars) ASSERT(v->mem_ptr);
#ifdef HAS_CUDA
if (device_sync) {
last_is_cuda = false;
sync_times++;
event_queue.run_sync([]() {
checkCudaErrors(cudaDeviceSynchronize());
});
}
LOGvv << "cudaDeviceSynchronize times:" << sync_times << "/" <<queue.size();
#endif
}
} // jittor

108
src/fetcher.cc Normal file
View File

@ -0,0 +1,108 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include <helper_cuda.h>
#include <mutex>
#include "mem/allocator/sfrl_allocator.h"
#include "mem/allocator/cuda_dual_allocator.h"
#include "event_queue.h"
#endif
#include "fetcher.h"
#include "mem/allocator.h"
namespace jittor {
#ifdef HAS_CUDA
#pragma GCC visibility push(hidden)
namespace fetcher_local {
cudaStream_t stream;
cudaEvent_t event;
volatile int64 n_to_fetch;
std::mutex m;
list<FetchResult> fetch_tasks;
static void fetch_caller() {
fetch_tasks.front().call();
fetch_tasks.pop_front();
}
static void to_fetch(void*) {
event_queue.push(fetch_caller);
}
struct Init {
Init() {
checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
}
~Init() {
// do not call deleter on exit
for (auto& f : fetch_tasks)
f.func.deleter = nullptr;
checkCudaErrors(cudaDeviceSynchronize());
checkCudaErrors(cudaStreamDestroy(stream));
checkCudaErrors(cudaEventDestroy(event));
}
} init;
}
using namespace fetcher_local;
#endif
void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
sync(vh);
vector<Allocation> allocations(vh.size());
vector<ArrayArgs> arrays(vh.size());
#ifdef HAS_CUDA
bool has_cuda_memcpy = false;
event_queue.flush();
#endif
for (int i=0; i<vh.size(); i++) {
auto v = vh[i]->var;
auto& allocation = allocations[i];
#ifdef HAS_CUDA
if (v->allocator->is_cuda()) {
checkCudaErrors(cudaEventRecord(event, 0));
checkCudaErrors(cudaStreamWaitEvent(stream, event, 0));
new (&allocation) Allocation(&cuda_dual_allocator, v->size);
// mostly device to device
checkCudaErrors(cudaMemcpyAsync(
allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDefault, stream));
auto host_ptr = cuda_dual_allocator.get_dual_allocation(
allocation.allocation).host_ptr;
// device to host
checkCudaErrors(cudaMemcpyAsync(
host_ptr, allocation.ptr, v->size, cudaMemcpyDefault, stream));
allocation.ptr = host_ptr;
has_cuda_memcpy = true;
} else
#endif
{
new (&allocation) Allocation(cpu_allocator, v->size);
std::memcpy(allocation.ptr, v->mem_ptr, v->size);
}
arrays[i].ptr = allocation.ptr;
arrays[i].shape = v->shape;
arrays[i].dtype = v->dtype();
}
#ifdef HAS_CUDA
if (has_cuda_memcpy) {
fetch_tasks.push_back({move(func), move(allocations), move(arrays)});
checkCudaErrors(cudaLaunchHostFunc(stream, &to_fetch, 0));
} else
#endif
{
FetchResult fr{move(func), move(allocations), move(arrays)};
fr.call();
}
}
} // jittor

242
src/fused_op.cc Normal file
View File

@ -0,0 +1,242 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "fused_op.h"
#include "var.h"
#include "op_compiler.h"
#include "profiler/profiler.h"
#include "misc/fast_shared_ptr.h"
namespace jittor {
#ifndef JIT
string_view_map<FusedOpContext*> jit_fused_ops;
std::ostream& operator<<(std::ostream& os, const VarInfo& vi) {
return os << vi.var << " type:" << vi.type;
}
int FusedOp::get_loop_option(const string& key, const int& _default) {
auto iter = loop_options->find(key);
return iter == loop_options->end() ? _default : iter->second;
}
loop_options_t& FusedOp::get_loop_options_tuned() {
loop_options_tuned = *loop_options_origin;
loop_options = &loop_options_tuned;
return loop_options_tuned;
}
void FusedOp::update_jit_key() {
jk.clear();
do_jit_prepare();
}
void FusedOp::update_ops() {
loop_options_merged.clear();
loop_options_tuned.clear();
loop_options = loop_options_origin = nullptr;
_outputs.clear();
jk.clear();
for (Op* op : ops) {
for (Var* o : op->outputs()) {
if (o->loop_options) {
if (loop_options_origin == nullptr)
loop_options_origin = &o->loop_options.data();
else if (loop_options_origin != &o->loop_options.data()) {
// merge loop options
for (auto& kv : o->loop_options.data())
loop_options_merged[kv.first] = kv.second;
}
}
// bit0 represents can fuse or not
if (o->custom_data&1)
// this var can not fuse
_outputs.emplace_back((Node*)o, 0);
}
}
if (loop_options_origin) {
if (loop_options_merged.size()) {
// merge loop_options_origin into loop_options_merged
for (auto& kv : *loop_options_origin)
loop_options_merged.emplace(kv);
}
} else {
loop_options_origin = &loop_options_merged;
}
loop_options = loop_options_origin;
ASSERT(outputs().size());
LOGvvvv << "set fused output" << outputs();
// var.custom_data
// meaning of custom_data&1(input): 1: cannot fuse, 0 can fuse
// meaning of custom_data&2: visited or not
// meaning of custom_data>>2: index of vars
// op.custom_data: opid
for (uint i=0; i<ops.size(); i++) {
auto opi = ops[i];
opi->custom_data = i;
for (Var* i : opi->inputs()) {
i->custom_data &= 1;
}
for (Var* o : opi->outputs()) {
o->custom_data &= 1;
}
}
vars.clear();
for (Op* opi : ops) {
for (Var* i : opi->inputs()) {
auto &c = i->custom_data;
// if not visited
if (!(c&2)) {
c += 2 + vars.size()*4;
vars.push_back({i, 0});
}
}
for (Var* o : opi->outputs()) {
auto &c = o->custom_data;
// if not visited
if (!(c&2)) {
c += 2 + vars.size()*4;
// intermediate(can fuse) or output
vars.push_back({o, int((c&1)+1)});
}
}
}
LOGvvvv << "Var info" << vars;
}
FusedOp::FusedOp() {
Op::number_of_lived_ops--;
}
FusedOp::~FusedOp() {
_outputs.clear();
Op::number_of_lived_ops++;
}
void FusedOp::infer_shape() {
for (uint i=0; i<ops.size(); i++)
ops[i]->infer_shape();
}
bool FusedOp::shape_infered() {
for (uint i=0; i<ops.size(); i++)
if (!ops[i]->shape_infered())
return false;
return true;
}
void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
in = out = compute = 0;
for (auto& vi : vars) {
compute = std::max(compute, (uint64_t)vi.var->num);
if (vi.type == 0) in += vi.var->size;
if (vi.type == 2) out += vi.var->size;
}
}
void FusedOp::do_jit_prepare() {
jk.clear();
int8 flags = 3;
for (uint i=0; i<ops.size(); i++) {
Op* op = ops[i];
jk << JK::key << "opkey" << i << JK::val;
op->do_jit_prepare();
jk << JK::end;
if (op->flags.get(NodeFlags::_cpu))
flags &= 1; // only cpu
else
flags &= 2; // only gpu
}
ASSERT(flags) << "FusedOp cannot contain both cpu and cuda ops.";
add_jit_define("JIT", "1");
if (flags==1) {
// only cpu
add_jit_define("JIT_cpu", "1");
this->flags.set(NodeFlags::_cuda, 0);
this->flags.set(NodeFlags::_cpu, 1);
} else {
add_jit_define("JIT_cuda", "1");
this->flags.set(NodeFlags::_cpu, 0);
this->flags.set(NodeFlags::_cuda, 1);
}
jk << JK::key << "graph" << JK::val;
for (auto& t : edges) {
uint i,j,k,l;
std::tie(i,j,k,l) = t;
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
}
jk << JK::end << JK::key << "var_info" << JK::val;
for (auto& vi : vars)
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
jk << JK::end;
if (loop_options->size()) {
if (get_loop_option("compile_shapes")) {
jk << JK::key << "shapes" << JK::val;
for (auto& vi : vars) {
jk << '[';
for (auto a : vi.var->shape)
jk << a << ',';
jk << "],";
}
jk << JK::end;
}
jk << JK::key << "choices" << JK::val;
for (auto& kv : *loop_options)
jk << kv.first << ':' << kv.second << ',';
jk << JK::end;
}
jk.finilize();
}
void FusedOp::do_prepare() {
do_jit_prepare();
}
void FusedOp::do_run_after_prepare() {
const char* jit_key = jk.to_cstring();
auto iter = jit_fused_ops.find(string_view(jit_key, jk.size));
if (iter != jit_fused_ops.end()) {
LOGvvv << "Jit fused op key found:" << jit_key << "jit op entry:" << (void*)iter->second;
context = iter->second;
iter->second->vrm.fop = this;
Profiler::record_and_run(iter->second->entry, this, jit_key);
return;
}
LOGvv << "Jit op key not found:" << jit_key;
// compile JIT op
context = new FusedOpContext();
context->vrm.fop = this;
string prev_jit_key = jit_key;
context->entry = OpCompiler::do_compile(this);
string new_jit_key = get_jit_key();
jit_fused_ops[new_jit_key] = jit_fused_ops[prev_jit_key] = context;
jit_key_mapper[prev_jit_key] = new_jit_key;
LOGvv << "Get jit op entry:" << (void*)(context->entry);
Profiler::record_and_run(context->entry, this, new_jit_key.c_str());
}
void FusedOp::do_run(){
do_prepare();
do_run_after_prepare();
}
#else // JIT
void FusedOp::jit_run() {
for (uint i=0; i<ops.size(); i++) {
LOGvvvv << "fuse run:" << ops[i] << ops[i]->inputs() << ops[i]->outputs();
ops[i]->do_run();
}
}
#endif // JIT
}

187
src/fuser.cc Normal file
View File

@ -0,0 +1,187 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Guowei Yang <471184555@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <algorithm>
#include <functional>
#include "fuser.h"
#include "var.h"
#include "op.h"
#include "mem/allocator.h"
#include "graph.h"
#include "fused_op.h"
namespace jittor {
#define PREVENT_LARGE_FUSED_OP 16
void count_fuse(int64_t tt, int start_var_num, const vector<Op*>& ops, const vector<Var*>& vars, vector<int> &father, vector<int> &var_fused) {
vector<int> dis(ops.size(), -1);
auto find_fa = [&](int i) -> int {
int j=i;
while (father[j] != j) j = father[j];
while (i != j) {
int tmp = father[i];
father[i] = j;
i = tmp;
}
return j;
};
auto can_fuse = [&](Var* v, Op* op1, Op* op2, int fuse_type) -> bool {
if (v->flags.get(NodeFlags::_stop_fuse))
return false;
if (fuse_type == 1) {
// if v is output, do not fuse
if (v->custom_data < start_var_num)
return false;
// op2 ---> v ---> op1
if (op1->type() == OpType::other || op2->type() == OpType::other)
return false;
if (v->flags.get(NodeFlags::_force_fuse))
return true;
// Do not fuse op after reduce(has reduce)
// TODO: better fuse strategy
if (op2->type() == OpType::reduce)
return false;
// Do not fuse op before broadcast
// TODO: better fuse strategy
if (op1->type() == OpType::broadcast)
return false;
return op2->type() == OpType::element ||
op2->type() == OpType::broadcast;
} else if (fuse_type == 0) {
#ifdef PREVENT_LARGE_FUSED_OP
// This statement prevent fuse large ops
if (v->outputs().size()>=PREVENT_LARGE_FUSED_OP) return false;
#endif
// v ---> op1
// |
// +----> op2 ( prev of op1 )
if (op1->type() == OpType::other || op2->type() == OpType::other)
return false;
// Do not fuse op after reduce(has reduce)
// TODO: better fuse strategy
if (op2->type() == OpType::broadcast || op1->type() == OpType::broadcast)
return false;
return true;
}
return false;
};
auto for_each_edge = [&](Op* op, int forward, auto&& func){
auto e=op->_inputs.begin();
for (Var* v : op->inputs()) {
if ((forward && (*e).back!=std::prev(v->_outputs.end())) ||
(!forward && (*e).back!=v->_outputs.begin())){
Op* next_op = forward ? std::next((*e).back)->node->op() : std::prev((*e).back)->node->op();
if (next_op && next_op->tflag==tt
&& next_op->custom_data != op->custom_data
&& can_fuse(v, next_op, op, 0))
func(v, next_op, 0);
}
e = std::next(e);
}
if (forward) {
for (Var* sv : op->outputs())
if (sv && sv->tflag == tt)
for (Op* next_op: sv->outputs())
if (next_op && next_op->tflag==tt) func(sv, next_op, 1);
} else {
for (Var* sv : op->inputs())
if (sv && sv->tflag == tt) func(sv, sv->input(), 1);
}
};
vector<int> queue;
vector<int> deps;
deps.reserve(ops.size());
queue.reserve(ops.size());
for (uint i=0; i<ops.size(); i++) {
deps.push_back(0);
Op* op = ops[i];
for_each_edge(op, 1, [&](Var* v, Op* next_op, int real_edge) {
deps[i]++;
});
if (!deps[i]) {
queue.push_back(i);
dis[i]=0;
}
}
uint head=0;
while (head<queue.size()) {
int op_id=queue[head++];
Op* op = ops[op_id];
for_each_edge(op, 1, [&](Var* v, Op* next_op, int real_edge) {
int next_id = next_op->custom_data;
if (dis[next_id] == dis[op_id]){
int next_fa = find_fa(next_id);
father[next_fa] = op_id;
}
});
for_each_edge(op, 0, [&](Var* v, Op* next_op, int real_edge) {
int next_id = next_op->custom_data;
int lon=0;
if (real_edge && !can_fuse(v, op, next_op, 1)) lon=1;
if (dis[op_id]+lon>dis[next_id])
dis[next_id]=dis[op_id]+lon;
if (!--deps[next_id]) queue.push_back(next_id);
});
}
if (V_ON(1000)) {
for (uint i=0; i<ops.size(); i++)
LOGvvvv << ops[i] << dis[i] << deps[i];
}
for (uint i=0; i<vars.size(); i++) {
Var* v = vars[i];
if (!v || v->tflag!=tt) {
var_fused[i]=1;
continue;
}
// sf: input op's father id
int sf = -1;
// vf: is input op can be fused with all output op
int vf = 1;
// all outputs are reduce
int all_reduce = 1;
Op* iop = v->input();
// if (iop && iop->tflag==tt)
sf = find_fa(iop->custom_data);
for (Op* sop : v->outputs())
if (sop->tflag==tt) {
if (vf && !can_fuse(v,sop,iop,1))
vf = 0;
if (sop->type()!=OpType::reduce)
all_reduce = 0;
// in two different fused op
if (find_fa(sop->custom_data)!=sf) {
var_fused[i]=1;
}
}
if (vf==0) var_fused[i]=1;
if (var_fused[i] && vf &&
(iop->type()==OpType::broadcast || all_reduce || v->flags.get(NodeFlags::_force_fuse)))
var_fused[i]=2;
}
// output vars can not be fused
for (int i=0; i<start_var_num; i++)
var_fused[i] = 1;
}
} // jittor

134
src/grad.cc Normal file
View File

@ -0,0 +1,134 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "pybind/py_var_tracer.h"
#include "grad.h"
#include "var.h"
#include "op.h"
#include "graph.h"
#include "ops/op_register.h"
namespace jittor {
#define PREVENT_LARGE_FUSED_OP 16
static auto make_binary = get_op_info("binary")
.get_constructor<VarPtr, Var*, Var*, NanoString>();
static auto make_number = get_op_info("number")
.get_constructor<VarPtr, float, Var*>();
VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
if (dout == nullptr) return nullptr;
LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs()
<< "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index;
return op->grad(out, dout, x, x_index);
}
inline static void assign_attrs(Var* a, Var* b) {
if (b->flags.get(NodeFlags::_stop_fuse))
a->flags.set(NodeFlags::_stop_fuse);
}
vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
LOGvv << "loss:" >> loss << "targets:" >> targets;
CHECK(loss->is_float()) << "Loss should be float";
for (Var* var : targets)
CHECK(var->is_float()) << "Targets of grad should be float";
// successors of targets
vector<Node*> ts(targets.begin(), targets.end());
// bfs visit find all successors of targets
LOGvv << "Size of successors:" << ts.size();
bfs_forward(ts, [](Node*){ return true; });
vector<Node*> gnodes;
gnodes.reserve(ts.size());
auto nt = Node::tflag_count;
if (loss->tflag == nt)
gnodes.push_back(loss);
bfs_backward(gnodes, [&](Node* node) {
if (node->tflag != nt)
return false;
if (node->is_stop_grad())
return false;
// int value has zero grad
if (node->is_var())
return node->var()->is_float();
return true;
});
LOGvv << "Size of grad nodes:" << gnodes.size();
vector<Node*> sorted;
toplogical_sort_backward(gnodes, sorted, [](Node*){});
nt = Node::tflag_count;
vector<Var*> gvars;
gvars.reserve(sorted.size());
for (Node* node : sorted)
if (node->is_var())
gvars.push_back(node->var());
LOGvv << "Size of grad vars:" << gvars.size();
vector<VarPtr> grads(gvars.size());
vector<VarPtr> results(targets.size());
for (size_t i=0; i<gvars.size(); i++)
gvars[i]->custom_data = i;
for (size_t i=0; i<gvars.size(); i++) {
Var* var = gvars[i];
auto& grad = grads[i];
#ifdef PREVENT_LARGE_FUSED_OP
int gsum = 0;
#endif
if (i==0) {
grad = make_number(1.f, loss);
assign_attrs(grad.ptr, loss);
registe_node_trace_grad(grad.ptr, loss, 0);
} else
for (auto it : var->outputs_with_index()) {
Op* op = it.op;
auto index = it.index;
if (op->tflag != nt) continue;
// TODO: support two outputs backprop.
Var* out = op->outputs().back();
Var* dout = grads[out->custom_data];
VarPtr dvar = make_grad(op, out, dout, var, index);
registe_node_trace_grad(dvar.ptr, op, index);
if (dvar)
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
<< "dvar" << dvar << "var" << var;
if (!grad)
grad = move(dvar);
else if (dvar) {
grad = make_binary(grad, dvar, ns_add);
#ifdef PREVENT_LARGE_FUSED_OP
gsum ++;
if (gsum>=PREVENT_LARGE_FUSED_OP) {
// TODO: this is a dirty fix for
// stopping fuse lots of op together,
// try to find a better solution
grad->flags.set(NodeFlags::_stop_fuse);
}
#endif
assign_attrs(grad.ptr, var);
registe_node_trace_grad(grad.ptr, var, index);
}
}
}
// set zero grad
for (size_t i=0; i<results.size(); i++) {
Var* var = targets[i];
VarPtr& grad = results[i];
if (var->tflag == nt)
grad = move(grads[var->custom_data]);
if (!grad) {
LOGvvv << var << "grads[">>i>>"] set to zero";
grad = make_number(0.f, var);
assign_attrs(grad.ptr, var);
registe_node_trace_grad(grad.ptr, var, 0);
}
}
return results;
}
} // jittor

114
src/graph.cc Normal file
View File

@ -0,0 +1,114 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <sstream>
#include "graph.h"
#include "var_holder.h"
#include "var.h"
namespace jittor {
DEFINE_FLAG(int, check_graph, 0, "Unify graph sanity check.");
extern unordered_map<void*, int64> lived_nodes;
template <typename T>
string ss_convert(T x) {
std::stringstream ss;
ss << x;
return ss.str();
}
void do_graph_check() {
vector<Node*> queue;
unordered_map<Node*,int> visited;
for (auto& vh : VarHolder::hold_vars) {
if (0==visited[vh->var]++)
queue.push_back(vh->var);
}
LOGvv << "Check hold_vars size" << queue.size();
int vhsize = queue.size();
for (auto* node : queue) {
ASSERTop(node->forward_liveness,>,0);
ASSERTop(node->backward_liveness,>,0);
}
for (uint i=0; i<queue.size(); i++) {
auto* node = queue[i];
for (auto* i : node->inputs()) {
if (visited.count(i)) continue;
visited[i] = 0;
queue.push_back(i);
}
}
LOGvv << "Check all var size" << queue.size();
for (int i=0; i<(int)queue.size(); i++) {
auto* node = queue[i];
LOGvvvv << "Check node" << i << node;
int f=0, b=0, p=0;
if (i<vhsize) {
f+=visited.at(node), b+=visited.at(node);
}
for (auto* i : node->inputs()) {
if (i->is_stop_grad()) continue;
if (!i->forward_liveness) continue;
f ++;
}
for (auto* o : node->outputs()) {
if (o->backward_liveness)
b ++;
if (o->pending_liveness && !o->is_finished())
p++;
}
if (f>0 && b>0 && !node->is_finished()) p++;
if (f!=node->forward_liveness || b!=node->backward_liveness || p!=node->pending_liveness) {
LOGf << "ERROR" << node << '\n'
<< f << b << p << i << '\n'
<< node->inputs() << '\n'
<< node->outputs();
continue;
}
}
for (auto& kv : lived_nodes) {
if (!kv.second) continue;
auto* node = (Node*) kv.first;
if (!visited.count(node) && node->tflag != -1) {
if (node->is_var() && node->_inputs.size())
continue;
LOGf << "ERROR dnode" << (void*)node << kv.second << node;
}
}
}
DumpGraphs dump_all_graphs() {
vector<Node*> queue;
auto t = ++Node::tflag_count;
for (auto& vh : VarHolder::hold_vars)
if (vh->var->tflag != t) {
vh->var->tflag = t;
queue.push_back(vh->var);
}
bfs_both(queue, [](Node*){return true;});
DumpGraphs graphs;
for (uint i=0; i<queue.size(); i++)
queue[i]->custom_data = i;
for (Node* node : queue) {
graphs.nodes_info.emplace_back(ss_convert(node));
graphs.inputs.emplace_back();
auto& inputs = graphs.inputs.back();
inputs.reserve(node->_inputs.size());
for (auto i : node->_inputs)
inputs.push_back(i.node->custom_data);
graphs.outputs.emplace_back();
auto& outputs = graphs.outputs.back();
outputs.reserve(node->_outputs.size());
for (auto o : node->_outputs)
outputs.push_back(o.node->custom_data);
}
return graphs;
}
} // jittor

39
src/init.cc Normal file
View File

@ -0,0 +1,39 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <random>
#include "init.h"
#include "ops/op_register.h"
namespace jittor {
unique_ptr<std::default_random_engine> eng;
vector<set_seed_callback> callbacks;
int current_seed;
void init() {
// init default_random_engine
set_seed(time(0));
// init fused op
op_registe({"fused","",""});
}
void set_seed(int seed) {
current_seed = seed;
eng.reset(new std::default_random_engine(seed));
for (auto cb : callbacks)
cb(seed);
}
void add_set_seed_callback(set_seed_callback callback) {
callbacks.push_back(callback);
callback(current_seed);
}
std::default_random_engine* get_random_engine() { return eng.get(); }
}

100
src/jit_compiler.cc Executable file
View File

@ -0,0 +1,100 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <fstream>
#include <streambuf>
#include <stdlib.h>
#include <dlfcn.h>
#include "jit_compiler.h"
#include "op.h"
#include "utils/cache_compile.h"
#include "utils/flags.h"
#include "fused_op.h"
namespace jittor {
DEFINE_FLAG(string, jittor_path, "", "Source path of jittor");
DEFINE_FLAG(string, cc_path, "", "Path of C++ compiler");
DEFINE_FLAG(string, cc_type, "", "Type of C++ compiler(clang, icc, g++)");
DEFINE_FLAG(string, cc_flags, "", "Flags of C++ compiler");
DEFINE_FLAG(string, nvcc_path, "", "Path of CUDA C++ compiler");
DEFINE_FLAG(string, nvcc_flags, "", "Flags of CUDA C++ compiler");
DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
namespace jit_compiler {
jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
LOGvv << "Opening jit lib:" << name;
// void* handle = dlopen(name.c_str(), RTLD_NOW | RTLD_DEEPBIND | RTLD_LOCAL);
// RTLD_DEEPBIND and openmp cause segfault
void* handle = dlopen(name.c_str(), RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND);
CHECK(handle) << "Cannot open library" << name << ":" << dlerror();
//dlerror();
auto jit_entry = (jit_op_entry_t)dlsym(handle, symbol_name.c_str());
const char* dlsym_error = dlerror();
CHECK(!dlsym_error) << "Loading symbol jit_entry from" << name << "failed:" << dlsym_error;
return jit_entry;
}
void run_cmd(string cmd, string cwd="") {
if (cwd.size()) cmd = "cd "+cwd + " && " + cmd;
LOGvvv << "Run cmd:" << cmd;
system_with_check(cmd.c_str());
}
static string get_symbol_name(const string& jit_key) {
int i=0;
while (i<jit_key.size() && jit_key[i]!='[') i++;
string op_name = i ? jit_key.substr(0, i) : "fused";
op_name = Op::file_name_to_class_name(op_name);
// _ZN7jittorXyyyyyy7jit_runEv
// jittor::yyyyyy::jit_run
op_name = "_ZN6jittor"+S(op_name.size()+2)+op_name+"Op7jit_runEv";
return op_name;
}
jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_cuda_op, const string& extra_flags) {
auto iter = jit_ops.find(jit_key);
if (iter != jit_ops.end())
return iter->second;
LOGvv << "Compile op" << jit_key;
// compiler do not allowed filename too long
CHECK(cc_path.size());
string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".so");
string other_src = " "+join(jittor_path, "src/op.cc")+" "+
join(jittor_path, "src/var.cc")+" ";
other_src = "";
LOGvvv << "Generate" << jit_src_path >> "\n" >> src;
if (rewrite_op || !file_exist(jit_src_path))
write(jit_src_path, src);
string cmd;
if (is_cuda_op) {
cmd = nvcc_path
+ " '" + jit_src_path + "'" + other_src
+ nvcc_flags + extra_flags
+ " -o '" + jit_lib_path + "'";
} else {
cmd = cc_path
+ " '" + jit_src_path + "'" + other_src
+ cc_flags + extra_flags
+ " -o '" + jit_lib_path + "'";
cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
"--cc_path=" + cmd;
}
cache_compile(cmd, cache_path, jittor_path);
auto symbol_name = get_symbol_name(jit_key);
auto jit_entry = load_jit_lib(jit_lib_path, symbol_name);
jit_ops[jit_key] = jit_entry;
return jit_entry;
}
} // jit_compiler
} // jittor

117
src/jit_key.cc Normal file
View File

@ -0,0 +1,117 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <sys/mman.h>
#include <sstream>
#include "jit_key.h"
#include "misc/str_utils.h"
namespace jittor {
const int page_size = 4*1024;
extern size_t protected_page;
static size_t get_buffer_end_page(size_t buffer_end) {
// get the last complete page in buffer
// 4k align :
// | | | | |
// buffer: xxxxxxxxxxxxxxxxxxxxxxxx
// ^ buffer_end_page
size_t buffer_end_page = buffer_end - buffer_end % page_size;
if (buffer_end_page + page_size-1 > buffer_end)
buffer_end_page -= page_size;
return buffer_end_page;
}
JitKey::JitKey() {
auto buffer_end_page = get_buffer_end_page((size_t)&buffer[buffer_size-1]);
LOGvv << "protect page" << (void*)buffer_end_page;
ASSERT(0==mprotect((void*)buffer_end_page, page_size, PROT_NONE));
protected_page = buffer_end_page;
}
JitKey::~JitKey() {
auto buffer_end_page = get_buffer_end_page((size_t)&buffer[buffer_size-1]);
LOGvv << "un-protect page" << (void*)buffer_end_page;
ASSERT(0==
mprotect((void*)buffer_end_page, page_size, PROT_READ|PROT_WRITE|PROT_EXEC));
protected_page = 0;
}
static void hex_to_dec(string& s) {
// check s is hex or not, if yes, convert to dec
if (!s.size()) return;
unsigned int x;
std::stringstream ss;
ss << std::hex << s;
ss >> x;
s = S(x);
}
static void convert_itof(string& s) {
uint64 x;
std::stringstream ss;
// itof(0x...)
// ^ ^
// 7
ASSERT(s.size()>=8);
ss << std::hex << s.substr(7, s.size()-7-1);
ASSERT(ss >> x);
ss.str(""); ss.clear();
ss << std::hexfloat << itof(x);
s = ss.str();
// 0x0p+0 ---> 0x0p0
if (s.find("p+") != string::npos)
s.erase(s.find("p+")+1, 1);
if (s=="inf") s = "(1.0/0)";
if (s=="-inf") s = "(-1.0/0)";
if (s=="nan" || s=="-nan") s = "(0.0/0)";
}
vector<pair<string,string>> parse_jit_keys(const string& s) {
vector<pair<string,string>> jit_keys;
int presum = 0;
char state=0;
string key, val;
for (char c : s) {
if (c==JK::key) {
presum++;
if (presum==1) {
state = c;
continue;
}
} else
if (c==JK::val || c==JK::hex_val) {
if (presum==1 && state==JK::key) {
state = c;
continue;
}
} else
if (c==JK::end) {
presum--;
if (presum==0) {
if (state == JK::hex_val)
hex_to_dec(val);
if (startswith(val, "itof"))
convert_itof(val);
jit_keys.emplace_back(move(key), move(val));
continue;
}
}
if (presum) {
if (state==JK::key)
key += c;
if (state==JK::val || state==JK::hex_val)
val += c;
}
}
ASSERT(presum==0);
return jit_keys;
}
JitKey jk;
} // jittor

268
src/op.cc Normal file
View File

@ -0,0 +1,268 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <limits>
#include "node.h"
#include "op.h"
#include "var.h"
#include "op_compiler.h"
#include "profiler/profiler.h"
#include "mem/allocator.h"
#include "misc/cuda_flags.h"
#include "pybind/py_var_tracer.h"
namespace jittor {
DECLARE_FLAG(string, cache_path);
DEFINE_FLAG(int, try_use_32bit_index, 0,
"If not overflow, try to use 32 bit type as index type.");
string_view_map<jit_op_entry_t> jit_ops;
string_view_map<string> jit_key_mapper;
int64_t Op::number_of_lived_ops = 0;
Op::Op() {
flags.set(NodeFlags::_var, 0);
flags.set(NodeFlags::_cpu, 1);
number_of_lived_ops++;
}
Op::~Op() {
number_of_lived_ops--;
}
void Op::forward(Var* input) {
flags.set(NodeFlags::_forwarded);
outputs_holder.emplace_back(input);
}
VarPtr Op::grad(Var* out, Var* dout, Var* v, int v_index) {
LOGw << "Grad of" << name() << "return zeros";
return nullptr;
}
Var* Op::create_output(NanoVector shape, NanoString dtype) {
VarPtr vp(shape, dtype);
Var* output = vp.ptr;
outputs_holder.emplace_back(move(vp));
return output;
}
void Op::init() {
infer_shape();
LOGvvvv << "Create" << this << "and outputs" << outputs();
for (Var* v : outputs())
CHECK(v->shape.size()) << "Number of dims should be solved.";
}
bool Op::shape_infered() {
if (flags.get(NodeFlags::_vary_shape)) return true;
for (Var* v : outputs())
if (v->num < 0) return false;
return true;
}
string Op::name_ex() const {
string a=name();
if (ns!=ns_void) {
a += '.';
a += ns.to_cstring();
}
return a;
}
string Op::get_jit_key() {
jk.clear();
do_jit_prepare();
return jk.to_string();
}
vector<pair<string,string>> Op::get_jit_define() {
return parse_jit_keys(get_jit_key());
}
void Op::do_jit_prepare() {
memcheck_all_exist();
jk << name();
jit_prepare();
if (!jk.empty()) {
// check use int64_t as index_t if array is too big
int in_id=0, out_id=0;
bool use_int64_t = false;
// TODO: fused op do not have inputs,
// check use_cuda_op from outputs may not be enough
bool use_cuda_op = use_cuda;
for (Var* var : inputs()) {
if (var->allocator) {
jk << JK::key << "alloc_i" << JK::hex1(in_id)
<< JK::hex1(var->allocator->flags()) << JK::end;
use_cuda_op &= var->allocator->is_cuda();
}
if (var->num >= std::numeric_limits<int32_t>::max())
use_int64_t = true;
in_id ++;
}
for (Var* var : outputs()) {
if (var->allocator) {
jk << JK::key << "alloc_o" << JK::hex1(in_id)
<< JK::hex1(var->allocator->flags()) << JK::end;
use_cuda_op &= var->allocator->is_cuda();
}
if (var->num >= std::numeric_limits<int32_t>::max())
use_int64_t = true;
out_id ++;
}
add_jit_define("JIT", "1");
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
add_jit_define("JIT_cuda", "1");
flags.set(NodeFlags::_cpu, 0);
// TODO: 64bit index in CUDA
use_int64_t = false;
} else {
if (use_cuda==2) {
if (flags.get(NodeFlags::_cuda))
LOGf << "Op" << name() >> "'s vars are not allocated in cuda";
else
LOGf << "Op" << name() << "doesn't have cuda version";
}
ASSERT(flags.get(NodeFlags::_cpu))
<< "Op" << name() << "doesn't have cpu version";
add_jit_define("JIT_cpu", "1");
flags.set(NodeFlags::_cuda, 0);
}
if (try_use_32bit_index) use_int64_t = false;
add_jit_define("index_t", use_int64_t ? "int64" : "int32");
}
jk.finilize();
}
void Op::do_prepare(){
jk.clear();
do_jit_prepare();
}
void Op::do_run_after_prepare() {
if (!jk.empty())
jit_run();
else
run();
}
void Op::do_run() {
do_prepare();
do_run_after_prepare();
}
string Op::get_filename_from_jit_key(const string& jit_key, const string& suffix) {
auto iter = jit_key_mapper.find(jit_key);
string s = iter==jit_key_mapper.end() ? jit_key : iter->second;
std::stringstream ss;
if (s.size() > 100) {
ss << s.substr(0, 90) << "...hash:"
<< std::hex << std::hash<string>()(s);
} else {
ss << s << "_hash:" <<
std::hex << std::hash<string>()(s);
}
s = ss.str();
for (char& c : s) {
if (c=='[' || c==']' || c=='<' || c=='>'
|| c=='{' || c=='}' || c=='(' || c==')' || c==','
|| c=='\n' || c=='\t' || c==' ' || c=='&' || c=='|'
|| c=='/')
c = '_';
}
string filename = cache_path + "/jit/";
filename += s;
filename += "_op";
filename += suffix;
return filename;
}
// convert xxx.yyy -> xxx
string Op::op_name_to_file_name(const string& s) {
auto pos = s.find('.');
return pos == string::npos ? s : s.substr(0, pos);
}
// convert xxx_xxx -> XxxXxx
string Op::file_name_to_class_name(const string& s) {
char prev = '_';
string res;
res.reserve(s.size());
for (char c : s) {
if (c != '_') {
if (prev == '_')
res += c-'a'+'A';
else
res += c;
}
prev = c;
}
return res;
}
void Op::jit_run() {
const char* jit_key = jk.to_cstring();
auto iter = jit_ops.find(jit_key);
if (iter != jit_ops.end()) {
LOGvvv << "Jit op key found:" << jit_key << "jit op entry:" << (void*)iter->second;
Profiler::record_and_run(iter->second, this, jit_key);
return;
}
LOGvv << "Jit op key not found:" << jit_key;
// compile JIT op
string prev_jit_key = jit_key;
auto op_entry = OpCompiler::do_compile(this);
string new_jit_key = get_jit_key();
jit_ops[new_jit_key] = jit_ops[prev_jit_key] = op_entry;
jit_key_mapper[prev_jit_key] = new_jit_key;
LOGvv << "Get jit op entry:" << (void*)op_entry;
Profiler::record_and_run(op_entry, this, new_jit_key.c_str());
}
void Op::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
in = out = compute = 0;
for (Var* var : inputs()) {
in += var->size;
compute = std::max(compute, (uint64_t)var->num);
}
for (Var* var : outputs()) {
out += var->size;
compute = std::max(compute, (uint64_t)var->num);
}
}
std::ostream& operator<<(std::ostream& os, const Op* op) {
if (!op) return os << "Op(0)";
os << "Op(" << (void*)op
<< ':' << op->forward_liveness
<< ':' << op->backward_liveness
<< ':' << op->pending_liveness
<< ":i" << op->_inputs.size()
<< ":o" << op->_outputs.size()
<< ":s" << op->is_finished()
<< "," << op->name_ex();
if (op->_outputs.size()>1)
os << "->...";
else if (op->_outputs.size() == 1) {
auto v = (Var*)op->_outputs.front().node;
if (v->name.size())
os << "->" << v->name;
else
os << "->" << (void*)v;
}
os << ')';
#ifdef NODE_MEMCHECK
os << '<' << op->__id() << '>';
print_node_trace(op, os);
#endif
return os;
}
} // jittor

915
src/op_compiler.cc Normal file
View File

@ -0,0 +1,915 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <regex>
#include <algorithm>
#include "op.h"
#include "fused_op.h"
#include "op_compiler.h"
#include "jit_compiler.h"
#include "utils/cache_compile.h"
#include "opt/tuner_manager.h"
#include "misc/str_utils.h"
#include "ops/op_register.h"
#include "ops/array_op.h"
namespace jittor {
DECLARE_FLAG(string, jittor_path);
using namespace jit_compiler;
static bool isvar(char x) { return isalnum(x) || x == '_' || x == ':'; }
void OpCompiler::get_op_var_by_name(const string& name, uint& op_id, uint& opvar_id, Op*& op, Var*& var) {
// name: op{id}_{varname}
ASSERT(name.size()>3 && name[0]=='o' && name[1]=='p');
uint j=2;
while (j<name.size() && isdigit(name[j])) j++;
ASSERT(j>2);
op_id = std::stoi(name.substr(2, j-2));
ASSERT(op_members.size() > op_id);
bool found = false;
for (opvar_id=0 ;opvar_id < op_members[op_id].size(); opvar_id++) {
if (op_members[op_id][opvar_id] == name) {
found = true;
break;
}
}
op = this->op->ops[op_id];
ASSERT(found && opvar_id < op->inputs().size() + op->outputs().size());
if (opvar_id >= op->inputs().size()) {
auto iter = op->outputs().begin();
for (uint t=op->inputs().size(); t<opvar_id; t++)
iter++;
var = *iter;
} else {
auto iter = op->inputs().begin();
for (uint t=0; t<opvar_id; t++)
iter++;
var = *iter;
}
}
string OpCompiler::get_name_by_op_var(Op* op, Var* var) {
uint var_id=0;
bool found = 0;
for (Var* i : op->inputs()) {
if (i==var) {
found = 1;
break;
}
var_id++;
}
if (!found)
for (Var* o : op->outputs()) {
if (o==var) {
found = 1;
break;
}
var_id++;
}
ASSERT(found);
ASSERT(op->custom_data<(int)op_members.size());
auto& v = op_members[op->custom_data];
ASSERT(var_id < v.size());
return v[var_id];
}
string OpCompiler::get_name_by_op_input(Op* op, uint i) {
return op_members.at(op->custom_data).at(i);
}
string OpCompiler::get_name_by_op_output(Op* op, uint i) {
return op_members.at(op->custom_data).at(i+op->inputs().size());
}
bool OpCompiler::op_exist(Op* op) {
return op_members.at(op->custom_data).size();
}
int OpCompiler::total_member_count() {
int member_count=0;
int i = 0;
for (auto& v : op_members) {
// array need a extra local var
if (op->ops[i]->name()==string("array"))
member_count += 1;
member_count += v.size();
i += 1;
}
return member_count;
}
#define FOR_ALL_UOPS(m) \
m(!,3) m(~,3)
#define FOR_ALL_BOPS(m) \
m(*,5) m(/,5) m(%,5) \
m(+,6) m(-,6) \
m(<<,7) m(>>,7) \
m(<,9) m(<=,9) m(>,9) m(>=,9) \
m(!=,10) m(==,10) \
m(&,11) \
m(^,12) \
m(|,13) \
m(&&,14) \
m(||,15)
#define FOR_ALL_OPS(m) FOR_ALL_UOPS(m) FOR_ALL_BOPS(m)
inline bool is_unary_op(const string& op) {
#define _u(o, _) if (op == #o) return true;
FOR_ALL_UOPS(_u);
return false;
}
inline int precedence(const string& op) {
#define _prior(o, p) if (op == #o) return p;
FOR_ALL_OPS(_prior);
return 20;
}
inline bool check_precedence(const string& op1, const string& op2) {
if (op1 == op2 && is_unary_op(op1)) return false;
return precedence(op1) <= precedence(op2);
}
inline int64_t calc_op(int64_t a, int64_t b, const string& op) {
#define _calc_b(o, _) if (op == #o) return a o b;
FOR_ALL_BOPS(_calc_b);
#define _calc_u(o, _) if (op == #o) return o b;
FOR_ALL_UOPS(_calc_u);
ASSERT(0) << "Unrecognized op " << op;
return 0;
}
int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>& vars) {
if (expr.find("@") != string::npos) {
string new_expr;
for (size_t i=0; i<expr.size(); i++) {
if (expr[i] != '@') new_expr += expr[i];
else {
size_t j=i+1;
ASSERT(j < expr.size());
// syntax @{...}
// ij k
if (expr[j] == '{') {
size_t k=j+1;
int presum = 1;
while (k<expr.size() && presum) {
if (expr[k] == '}')
presum--;
else if (expr[k] == '{')
presum++;
k++;
}
ASSERT(presum==0) << "Jit error: braces are not matched.";
new_expr += S(eval(expr.substr(j+1, k-j-2), vars));
i = k-1;
continue;
} else {
// syntax: @x
ASSERT(isvar(expr[j]));
size_t k=j+1;
while (k<expr.size() && isvar(expr[k])) k++;
string var = expr.substr(j, k-j);
auto iter = vars.find(var);
ASSERT(iter!=vars.end()) << "Jit var " << var << " not found.";
new_expr += iter->second;
i = k-1;
}
}
}
return eval(new_expr, vars);
}
vector<int64> values = {0};
vector<string> ops;
auto pop_values_and_calc_op = [&]() {
CHECK(ops.size());
auto op = ops.back();
ops.pop_back();
CHECK(values.size());
auto val2 = values.back();
values.pop_back();
auto val1 = val2;
if (!is_unary_op(op)) {
CHECK(values.size());
val1 = values.back();
values.pop_back();
}
values.push_back(calc_op(val1, val2, op));
};
for (size_t i=0; i<expr.size(); i++) {
if (expr[i] == ' ')
continue;
if (expr[i] == '(')
ops.push_back(string()+expr[i]);
else if (isdigit(expr[i])) {
int64_t val = 0;
while (i<expr.length() && isdigit(expr[i])) {
val = val*10 + (expr[i]-'0');
i++;
}
i--;
values.push_back(val);
} else if (isvar(expr[i])) {
auto j=i+1;
while (j<expr.size() && isvar(expr[j])) j++;
auto var_name = expr.substr(i,j-i);
auto iter = vars.find(var_name);
ASSERT(iter!=vars.end()) << "Jit var " << var_name << " not found.";
try {
values.push_back(std::stoll(iter->second));
} catch (...) {
ASSERT(0) << "'" << iter->second << "' is not integer, expr " << expr;
}
i = j-1;
} else if (expr[i] == ')') {
while (ops.size() && ops.back() != "(")
pop_values_and_calc_op();
ops.pop_back();
} else {
auto j=i+1;
while (j<expr.size() && expr[j] != ' ' &&
expr[j] != '!' && expr[j] != '~' &&
!isdigit(expr[j]) && !isvar(expr[j]) &&
expr[j] != '(' && expr[j] != ')') j++;
auto op = expr.substr(i, j-i);
while (ops.size() && check_precedence(ops.back(), op))
pop_values_and_calc_op();
ops.push_back(op);
i = j-1;
}
}
while (ops.size())
pop_values_and_calc_op();
return values.back();
}
void load_macros(const string& src, unordered_map<string,string>& macros) {
LOGvvvv << "load_macros" << src;
for (size_t i=0; i<src.size(); i++) {
if (src[i] == '#') {
// #define xxx(...) xxx
// i jk r l p q
auto j=i+1;
while (j<src.size() && src[j] != ' ') j++;
if (j-i!=7 || src.substr(i,j-i) != "#define") {
i = j;
continue;
}
ASSERT(j<src.size());
auto k=j+1;
while (k<src.size() && src[k] == ' ') k++;
ASSERT(k<src.size());
auto l=k+1;
while (l<src.size() && (src[l] != '\n' && src[l-1] != ')')) l++;
auto p=l;
while (p<src.size() && (src[p] == ' ')) p++;
auto q=p;
while (q<src.size() && (src[q] != '\n')) q++;
// TODO: multiline macro
auto r=k;
while (r<l && src[r] != '(') r++;
auto body = q>p ? src.substr(p,q-p) : "";
body = (r<l?src.substr(r,l-r):"()") + body;
auto header = src.substr(k,r-k);
LOGvvvv << "header:" << header << "body:" << body;
macros[header] = body;
i = q;
}
}
}
void expand_macro(const string& macro, const vector<string>& args, string& new_src) {
LOGvvvv << "expand_macro" << macro << "args:" << args;
auto i = macro.find(")");
ASSERT(i != string::npos);
// (a1, a2, ...)body
// j k i
unordered_map<string, int> args_map;
for (uint j=1, l=0; j<i; l++) {
uint k=j;
while (k<i && macro[k] != ',') k++;
args_map[macro.substr(j,k-j)] = l;
j = k+1;
while (j<i && macro[j] == ' ') j++;
}
ASSERTop(args.size(),==,args_map.size()) << "Number of macro args not match.";
for (i=i+1; i<macro.size(); i++) {
if (isvar(macro[i])) {
uint j = i+1;
while (j<macro.size() && isvar(macro[j])) j++;
string var = macro.substr(i, j-i);
auto iter = args_map.find(var);
if (iter == args_map.end()) {
new_src += var;
} else {
new_src += args[iter->second];
}
i = j-1;
continue;
}
new_src += macro[i];
}
}
string precompile(unordered_map<string,string> defs, string src, unordered_map<string, string>& macros) {
string new_src;
new_src.reserve(src.size());
for (size_t i=0; i<src.size(); i++) {
try{
if (src[i] == '/' && (i+1<src.size() && src[i+1] == '/')) {
size_t j=i+1;
while (j<src.size() && src[j] != '\n') j++;
if (j<src.size()) j++;
// remove comment
// for (size_t k=i; k<j; k++) new_src += src[k];
if (src[j-1]=='\n')
new_src += '\n';
i = j-1;
continue;
} else
if (src[i] == '/' && (i+1<src.size() && src[i+1] == '*')) {
size_t j=i+1;
while (j<src.size() && !(src[j] == '/' && src[j-1] == '*')) j++;
if (j<src.size()) j++;
// remove comment
// for (size_t k=i; k<j; k++) new_src += src[k];
i = j-1;
continue;
} else
if (src[i] == '#') {
// #include "a.h"
// i jk l
// #define xxx
// i jk l
auto j=i+1;
while (j<src.size() && src[j] != ' ') j++;
ASSERT(j<src.size());
auto k=j+1;
while (k<src.size() && src[k] == ' ') k++;
ASSERT(k<src.size());
auto l=k+1;
while (l<src.size() && (src[l] != '\n')) l++;
if (src[k] == '"' && src[l-1] == '"' && j-i==8 && src.substr(i,j-i) == "#include") {
auto inc = src.substr(k+1, l-k-2);
if (inc.size()>=6 && inc.substr(inc.size()-6) == "defs.h") {
LOGvvvv << "Found defs include" << inc;
auto src_path = join(jittor_path, "src");
src_path = join(src_path, inc);
auto inc_src = read_all(src_path);
// load_macros from include src
precompile(defs, inc_src, macros);
// we do not include defs.h
i = l;
continue;
}
} else
if (j-i==7 && src.substr(i,j-i) == "#define") {
load_macros(src.substr(i,l-i), macros);
} else
// #ifdef JITxxx
// #else
// #endif
if (((j-i==6 && src.substr(i,j-i) == "#ifdef") ||
(j-i==7 && src.substr(i,j-i) == "#ifndef")) && startswith(src, "JIT", k)) {
bool is_ifndef = j-i==7;
string key = src.substr(k, l-k);
// find pair #endif and #else
int presum = 1;
size_t prev = l+1, ii = prev;
string block, else_block;
while (ii < src.size()) {
if (startswith(src, "#if", ii)) {
presum++;
ii += 3;
continue;
}
if (startswith(src, "#else", ii)) {
auto next_ii = ii+5;
// remove ' ' or '\n' after #else
if (next_ii<src.size() && (src[next_ii]==' ' || src[next_ii]=='\n'))
next_ii++;
if (presum==1) {
block = src.substr(prev, ii-prev);
prev = next_ii;
}
ii = next_ii;
continue;
}
if (startswith(src, "#endif", ii)) {
presum--;
auto next_ii = ii+6;
// remove ' ' or '\n' after #endif
if (next_ii<src.size() && (src[next_ii]==' ' || src[next_ii]=='\n'))
next_ii++;
if (presum==0) {
if (prev == l+1)
block = src.substr(prev, ii-prev);
else
else_block = src.substr(prev, ii-prev);
ii = next_ii;
break;
}
ii = next_ii;
continue;
}
ii++;
}
ASSERT(presum==0);
if (is_ifndef) block.swap(else_block);
if (defs.count(key) || macros.count(key)) {
new_src += precompile(defs, block, macros);
} else {
new_src += precompile(defs, else_block, macros);
}
i = ii-1;
continue;
}
for (auto k=i; k<l; k++) new_src += src[k];
i= l-1;
continue;
} else
if (src[i] == '@' && i+1<src.size()) {
size_t j=i+1;
// syntax @{...}
// ij k
if (src[j] == '{') {
size_t k=j+1;
int presum = 1;
while (k<src.size() && presum) {
if (src[k] == '}')
presum--;
else if (src[k] == '{')
presum++;
k++;
}
ASSERT(presum==0) << "Jit error: braces are not matched.";
new_src += S(OpCompiler::eval(src.substr(j+1, k-j-2), defs));
i = k-1;
continue;
} else if (isvar(src[j])) {
size_t k=j+1;
while (k<src.size() && isvar(src[k])) k++;
string expr = src.substr(j, k-j);
int presum = 1;
vector<int> comma;
vector<string> args;
size_t l = k+1;
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
(k<src.size() && src[k]=='(')) {
ASSERT(src[k] == '(');
comma.push_back(k);
while (l<src.size() && presum) {
if (src[l] == ')')
presum--;
else if (src[l] == '(')
presum++;
else if (presum == 1 && src[l] == ',')
comma.push_back(l);
l++;
}
ASSERT(presum==0) << "Jit error: braces are not matched.";
comma.push_back(l-1);
for (uint i=0; i+1<comma.size(); i++)
args.push_back(src.substr(comma[i]+1, comma[i+1]-comma[i]-1));
}
// syntax @for(i, l, r, ...)
// ij k l
if (expr == "for") {
CHECKop(args.size(),>=,4u) << "Jit error: for missing arguments.";
string vi = args[0];
string vl = args[1];
string vr = args[2];
string vs = args[3];
auto vil = OpCompiler::eval(vl, defs);
auto vir = OpCompiler::eval(vr, defs);
int step = 1;
if (args.size() >= 5) {
step = OpCompiler::eval(vs, defs);
vs = args[4];
}
auto new_defs = defs;
LOGvvv << "Expand for" << expr >> "[" >> vil >> "," >> vir >> "," >> step >> "]";
int total_step = 0;
for (auto vii=vil; vii!=vir; vii+=step) {
total_step ++;
ASSERT(total_step < 1000) << "Too much step.";
new_defs[vi] = S(vii);
new_src += precompile(new_defs, vs, macros);
}
i = l-1;
continue;
} else
if (expr == "if") {
// syntax: @if(cond, true[, false])
// ij k l
ASSERT(args.size()>=2u && args.size()<=3u)
<< "Jit error: if wrong arguments.";
string vcond = args[0];
string vtrue = args[1];
string vfalse = args.size() == 3u ? args[2] : "";
int cond = OpCompiler::eval(vcond, defs);
new_src += precompile(defs, cond?vtrue:vfalse, macros);
i = l-1;
continue;
} else
if (expr == "expand_macro") {
// syntax: @expand_macro(macro, args)
// ij k l
for (auto& arg : args) {
uint p=0;
while (p<arg.size() && arg[p] == ' ') p++;
arg = precompile(defs, arg.substr(p), macros);
}
string vmacro = args[0];
args.erase(args.begin());
auto iter = macros.find(vmacro);
string ns;
if (iter == macros.end()) {
if (defs.count(vmacro))
ns = defs[vmacro];
else
LOGf << "Macro" << vmacro << "not found.";
} else {
expand_macro(iter->second, args, ns);
}
new_src += precompile(defs, ns, macros);
i = l-1;
continue;
} else
if (args.size()) {
// syntax: @e0(i0,i1,...,in) -> e0p[i0*e0stride0+i1*e0stride1+...]
int nid=(int)expr.size();
while (nid && isdigit(expr[nid-1])) nid--;
// xyz123 ---> prefix: xxx; suffix: 123
string prefix = expr.substr(0, nid);
string suffix = expr.substr(nid);
string up_prefix = prefix;
for (auto& c : up_prefix)
if (c>='a' && c<='z') c = c-'a'+'A';
string dim = up_prefix + "DIM" + suffix;
if (prefix == "e") prefix = "extras";
ASSERT(defs.count(dim)) << dim;
ASSERTop(defs.at(dim),==,S(args.size()));
expr = prefix + suffix; // e0 ->extras0
std::stringstream ss;
ss << expr << "p[";
for (uint ii=0; ii<args.size(); ii++) {
string arg = precompile(defs, args[ii], macros);
if (ii) ss << "+";
ss << '(' << arg << ")*" << expr << "stride" << ii;
}
ss << ']';
new_src += ss.str();
i = l-1;
continue;
}
// syntax: @x
auto iter = defs.find(expr);
ASSERT(iter!=defs.end()) << "Jit var " << expr << " not found.";
new_src += precompile(defs, iter->second, macros);
i = k-1;
continue;
} else if (src[j]=='@') {
// seperater syntex: @@
i++;
continue;
} else
LOGf << "Jit error: Invalid syntax.";
} else
new_src += src[i];
} catch (std::exception& e) {
uint il = i, ir = i;
while (il && src[il] != '\n') il--;
while (ir<src.size() && src[ir] != '\n') ir++;
string this_line = src.substr(il+1, ir-il-1);
LOGf << e.what() >> "\nJit compiler error:\n" >> this_line;
}
}
return new_src;
}
string OpCompiler::precompile(const unordered_map<string,string>& defs, const string& src) {
unordered_map<string, string> macros;
return jittor::precompile(defs, src, macros);
}
string OpCompiler::get_jit_src(Op* op) {
string name = op->name();
string name2 = Op::op_name_to_file_name(name);
string name3 = Op::file_name_to_class_name(name2);
if (name == "fused") {
string src = get_fused_src((FusedOp*)op);
ASSERT(src.size());
return src;
}
auto op_info = get_op_info(name);
auto& src_path = op_info.source_path;
string begin_src = "", end_src = "";
// source that need to be added after the last #include statement
string after_include_src = "";
auto jit_define = op->get_jit_define();
for (auto &t : jit_define) {
string src = "#define " + t.first + " ";
for (char c : t.second) {
if (c=='\n') src += '\\';
src += c;
}
src += '\n';
if (startswith(t.first, "JIT"))
begin_src += src;
else
after_include_src += src;
}
ASSERT(file_exist(src_path));
LOGvvv << "Read from" << src_path;
string src = read_all(src_path);
ASSERT(src.size()) << "Source read failed:" << src_path;
unordered_map<string,string> defs(jit_define.begin(), jit_define.end());
LOGvvv << "Precompile with key:" << defs;
src = precompile(defs, src);
// find the last occur of #include "..."\n
auto pos = src.rfind("#include");
if (pos == string::npos) pos=0;
else {
// find \n
pos = src.find("\n", pos);
if (pos == string::npos)
pos = src.size();
else
pos++;
}
string new_src = begin_src + src.substr(0, pos) +
after_include_src + src.substr(pos) + "\n" + end_src;
return new_src;
}
string OpCompiler::get_fused_src(FusedOp* op) {
vector<string> op_srcs;
vector<bool> relay_switch(op->context->vrm.relay_groups.size());
for (uint i=0; i<relay_switch.size(); i++) {
auto relay_key = "relay"+S(i);
if (op->loop_options->count(relay_key) &&
op->loop_options->at(relay_key) == 1)
relay_switch[i] = 1;
}
auto relay_source = op->context->vrm.get_op_relay_info(relay_switch);
std::set<pair<int,int>> relayed;
for (uint oi=0; oi<op->ops.size(); oi++) {
// relay group id, pair id
auto p = relay_source[oi];
if (p.first != -1) {
if (relayed.count(p)) {
op_srcs.push_back("");
continue;
}
relayed.insert(p);
auto src = op->context->vrm.get_relay_src(p.first, p.second);
op_srcs.push_back(src);
// op_srcs.push_back(get_relayed_src(src));
continue;
}
Op* opi = op->ops[oi];
string src = get_jit_src(opi);
op_srcs.push_back(move(src));
}
return OpCompiler::__get_fused_src(op->ops, op_srcs, op_members);
}
string OpCompiler::__get_fused_src(
const vector<Op*>& ops,
const vector<string>& op_srcs,
vector<vector<string>>& op_members
) {
string fused_begin;
string fused_includes;
string fused_defines;
string fused_kernel_args;
string fused_kernel;
// definitions of fused_begin
map<string,string> defs;
unordered_set<string> kernel_args;
op_members = vector<vector<string>>(op_srcs.size());
fused_begin += "#define JIT 1\n";
defs["JIT"] = "1";
const string pattern = "::jit_run() {";
// TODO: better check member
const unordered_set<string> members = {
"x", "y", "z", "cond", "output", "extras"
};
const unordered_set<string> unchanged = {
"for", "const", "auto", "get_random_engine",
"int", "float", "bool", "CHECK", "STRINGIZE",
"void", "__restrict__", "if", "true", "false",
"Op", "Var", "Node", "itof"
};
auto not_change = [&](const string& s) -> bool {
if (unchanged.count(s)) return true;
return (s.find("::") != string::npos) || (s.find("LOG") != string::npos);
};
// regex find XxxXxxOp::jit_run
std::regex e(R"([^]*\s(\S*)Op::jit_run[^]*)");
for (uint oi=0; oi<op_srcs.size(); oi++) {
const string& src = op_srcs[oi];
if (src.size()==0) continue;
if (src.find("@relay_op") != string::npos) {
fused_kernel += src;
continue;
}
if (ops[oi]->name()==string("array")) {
string op_name = "op" + S(oi);
string arg_name = op_name + "_output";
string argp_name = op_name + "_outputp";
string T = ((ArrayOp*)ops[oi])->output->dtype().to_cstring();
fused_kernel_args += " ArrayOp* " + op_name + " = (ArrayOp*)(ops[" + S(oi) + "]);\n";
// op_name = "((ArrayOp*)(ops[" + S(oi) + "]))";
fused_kernel_args += " Var* " + arg_name + " = " + op_name + "->output;\n";
fused_kernel += " auto* " + argp_name + " = " + arg_name + "->ptr<" + T + ">();\n";
fused_kernel += " " + argp_name + "[0] = " + op_name + "->ptr<" + T + ">()[0];\n";
fused_kernel += " int " + arg_name + "shape0 = 1;\n";
fused_kernel += " int " + arg_name + "stride0 = 1;\n";
fused_includes += "#include \"ops/array_op.h\"\n";
op_members[oi].push_back(arg_name);
// auto opi = (ArrayOp*)(ops[i]);
// auto opi_output = opi->output;
// auto* opi_outputp = opi_output->ptr<T>();
// opi_outputp[0] = ((T*)(opi->buffer.get()))[0];
continue;
}
std::smatch cm;
std::regex_match(src, cm, e);
ASSERT(cm.size()>=2) << src;
string name3 = cm[1];
for (uint i=0; i<src.size(); i++) {
if (src[i] == '#' &&
(i+1<src.size() && src[i+1] == 'i') &&
(i+2<src.size() && src[i+2] == 'n'))
{
// #include ...
uint j=i+1;
while (j<src.size() && src[j] != '\n') j++;
if (j<src.size()) j++;
for (uint k=i; k<j; k++) fused_includes += src[k];
i = j-1;
continue;
}
if (src[i] == '#' && (i+1<src.size() && src[i+1] == 'd')) {
// #define aaa bbb
// i j k l
// TODO: multi-line define
uint j=i+1;
while (j<src.size() && src[j] != ' ') j++;
while (j<src.size() && src[j] == ' ') j++;
uint k=j;
while (k<src.size() && src[k] != ' ') k++;
uint l=k;
while (l<src.size() && src[l] != '\n') l++;
if (l<src.size()) l++;
CHECK(i<j && j<k && k<l);
// define startswith JIT should be added at the very beginning
if (startswith(src, "JIT", j)) {
string key = src.substr(j,k-j);
string value = src.substr(k+1, l-k-2);
if (defs.count(key))
CHECKop(defs[key],==,value);
else {
defs[key] = value;
fused_begin += "#define ";
for (; j<l; j++) fused_begin += src[j];
}
j = l;
} else {
fused_defines += "#define op" + S(oi) + "_";
for (; j<l; j++) fused_defines += src[j];
}
i = j-1;
continue;
}
// find the first function match the pattern "jit_run"
bool found = true;
for (uint j=0; j<pattern.size(); j++)
if (pattern[j] != src[i+j]) {
found = false;
break;
}
if (!found) continue;
uint j = i+pattern.size();
uint k = j;
int presum = 1;
while (k<src.size() && presum) {
if (src[k] == '}')
presum--;
else if (src[k] == '{')
presum++;
k++;
}
ASSERT(presum==0) << "Jit error: braces are not matched.";
for (;j < k-2; j++) {
if (isvar(src[j])) {
uint l=j;
while (l<src.size() && isvar(src[l])) l++;
auto var = src.substr(j, l-j);
if (var[0] == ':' || isdigit(var[0]) || not_change(var) || src[j-1]=='.' || src[j-1]=='>') {} else
if (members.count(var)) {
string arg_name = "op" + S(oi) + "_" + var;
if (l<src.size() && src[l]=='[') {
// handle extras[...]
// l r
uint r = l+1;
while (r<src.size() && src[r]!=']') r++;
ASSERT(r<src.size());
for (uint i=l+1; i<r; i++) {
ASSERT(isdigit(src[i]));
arg_name += src[i];
}
l = r+1;
var = src.substr(j, l-j);
// arg_name = opi_extra0
// var = extra[0]
}
if (!kernel_args.count(arg_name)) {
fused_kernel_args +=
string(" auto ") + arg_name +
" = (("+name3+"Op*)(ops[" + S(oi) + "]))->" + var;
fused_kernel_args += ";\n";
kernel_args.insert(arg_name);
op_members[oi].push_back(arg_name);
}
fused_kernel += arg_name;
j = l-1;
continue;
} else
fused_kernel += "op" + S(oi) + "_";
for (uint p=j; p<l; p++) fused_kernel += src[p];
j = l-1;
continue;
}
fused_kernel += src[j];
}
break;
}
}
CHECK(!(defs.count("JIT_cpu") && defs.count("JIT_cuda")))
<< "CPU op and GPU op cannot be fused together.";
fused_kernel = fused_kernel_args + "\n" + fused_kernel;
LOGvvvv << "Fused kernel:\n" >> fused_kernel;
auto fused_src = fused_begin + fused_includes + "\n#include \"fused_op.h\"\n" +
fused_defines + '\n' +
"void jittor::FusedOp::jit_run() {\n" + fused_kernel + "\n}\n";
// we assume the member name is in lexicographical order
// for (auto& v : op_members) std::sort(v.begin(), v.end());
return fused_src;
}
string OpCompiler::get_src() {
if (op==nullptr) return src;
for (const auto& p : *op->loop_options)
if (startswith(p.first, "relay")) {
// return get jit src if has relay op
return get_jit_src(op);
}
return src;
}
OpCompiler::OpCompiler(Op* op) {
_op = op;
this->op = op->name()==string("fused") ? (FusedOp*)op : nullptr;
src = get_jit_src(op);
}
jit_op_entry_t OpCompiler::compile(const string& jit_key, const string& src) {
// add extra flags for custom ops
bool is_cuda = _op->flags.get(NodeFlags::_cuda);
auto op_info = get_op_info(_op->name());
return jit_compiler::compile(jit_key, src, is_cuda, op_info.extra_flags);
}
jit_op_entry_t OpCompiler::do_compile(Op* op) {
OpCompiler oc(op);
string* src = &oc.src;
string src_after_passes;
// if is fused op
if (oc.op) {
TunerManager tm(&oc);
src_after_passes = tm.tune();
src = &src_after_passes;
}
return oc.compile(op->get_jit_key(), *src);
}
}

101
src/var.cc Normal file
View File

@ -0,0 +1,101 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <type_traits>
#include "var.h"
#include "op.h"
#include "mem/allocator.h"
#include "pybind/py_var_tracer.h"
namespace jittor {
int64_t Var::number_of_lived_vars = 0;
DEFINE_FLAG(fast_shared_ptr<loop_options_t>, compile_options, {},
"Override the default loop transfrom options");
Var::Var(NanoVector shape, NanoString dtype)
: shape(shape),
loop_options(compile_options) {
flags.set(NodeFlags::_var, 1);
ns = dtype;
ASSERT(ns.is_dtype());
number_of_lived_vars++;
numel();
}
Var::~Var() {
if (mem_ptr != nullptr)
allocator->free(mem_ptr, size, allocation);
number_of_lived_vars--;
}
string Var::to_string() {
string s = dtype().to_cstring();
s += shape.to_string();
return s;
}
int64_t Var::numel() {
if (!shape.size()) return size=num=-1;
bool negtive = 0;
num=1;
for (auto k : shape) {
if (k<0) {
negtive = 1;
num *= -k;
} else {
num *= k;
}
}
size = num * dsize();
if (negtive) num = -num;
return num;
}
void Var::set_shape(NanoVector shape) {
this->shape = shape;
numel();
}
bool Var::alloc(Allocator* allocator) {
if (mem_ptr) return true;
if (auto* x = (Var*)(this->allocator)) {
if (x->allocator->share_with(size, x->allocation)) {
mem_ptr = x->mem_ptr;
allocation = x->allocation;
this->allocator = x->allocator;
return true;
}
}
mem_ptr = allocator->alloc(size, allocation);
this->allocator = allocator;
return mem_ptr;
}
std::ostream& operator<<(std::ostream& os, const Var& var) {
os << "Var" << '(' << (void*)&var
<< ':' << var.forward_liveness
<< ':' << var.backward_liveness
<< ':' << var.pending_liveness
<< ":i" << var._inputs.size()
<< ":o" << var._outputs.size()
<< ":s" << var.is_finished()
<< ','
<< var.dtype().to_cstring() << ',' << var.name << ',' << var.mem_ptr
<< ')' << var.shape;
#ifdef NODE_MEMCHECK
os << '<' << var.__id() << '>';
print_node_trace(&var, os);
#endif
return os;
}
std::ostream& operator<<(std::ostream& os, const Var* var) {
return os << *var;
}
std::ostream& operator<<(std::ostream& os, const VarPtr& v) { return os << v.ptr; }
} // jittor

125
src/var_holder.cc Normal file
View File

@ -0,0 +1,125 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include <helper_cuda.h>
#endif
#include "var_holder.h"
#include "var.h"
#include "executor.h"
#include "graph.h"
namespace jittor {
list<VarHolder*> VarHolder::hold_vars;
void add_hold_vars(VarHolder* self) {
VarHolder::hold_vars.push_front(self);
self->iter = VarHolder::hold_vars.begin();
}
VarHolder::VarHolder(Var* v) : var(v) {
add_hold_vars(this);
// Var holder has both forward and backward liveness
var->own_both_liveness();
}
VarHolder::VarHolder(VarPtr&& v) : VarHolder(v.ptr) {
v.free_liveness();
v.ptr = nullptr;
}
VarHolder::VarHolder(VarHolder* v) : var(v->var) {
iter = v->iter;
*iter = this;
// free memory without calling deconstructor
operator delete(v);
}
VarHolder::~VarHolder() {
hold_vars.erase(iter);
var->release_both_liveness();
}
// assign attributes of b to a
static inline void assign_var(Var* a, Var* b) {
a->name = move(b->name);
if (b->is_stop_grad())
a->set_stop_grad();
if (b->flags.get(NodeFlags::_stop_fuse))
a->flags.set(NodeFlags::_stop_fuse);
}
void VarHolder::operator=(VarPtr&& v) {
assign_var(v.ptr, var);
var->release_both_liveness();
var = v.ptr;
v.ptr = nullptr;
}
string VarHolder::to_string() {
if (var->num<0) sync();
return var->to_string();
}
VarHolder* VarHolder::assign(VarHolder* v) {
assign_var(v->var, var);
var->release_both_liveness();
var = v->var;
var->own_both_liveness();
return this;
}
extern Executor exe;
void VarHolder::sync(bool device_sync) {
jittor::sync({this}, device_sync);
}
ArrayArgs VarHolder::fetch_sync() {
sync(true);
#ifdef HAS_CUDA
migrate_to_cpu(var, exe.allocator);
#endif
return {var->mem_ptr, var->shape, var->dtype()};
}
void sync_all(bool device_sync) {
vector<Var*> vars;
vars.reserve(VarHolder::hold_vars.size());
for (auto v : VarHolder::hold_vars) {
if (!v->var->_outputs.size())
vars.push_back(v->var);
}
graph_check();
exe.run_sync(vars, device_sync); //need sync at last
graph_check();
}
void sync(const vector<VarHolder*>& vh, bool device_sync) {
vector<Var*> vars;
vars.reserve(vh.size());
for (auto v : vh) vars.push_back(v->var);
graph_check();
exe.run_sync(vars, device_sync); //need sync at last
graph_check();
}
vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh) {
vector<ArrayArgs> ret(vh.size());
sync(vh, true);
for (uint i=0; i<vh.size(); i++) {
#ifdef HAS_CUDA
migrate_to_cpu(vh[i]->var, exe.allocator);
#endif
ret[i].ptr = vh[i]->var->mem_ptr;
ret[i].shape = vh[i]->var->shape;
ret[i].dtype = vh[i]->var->dtype();
}
return ret;
}
} // jittor