forked from jittor/jittor
parallel compiler & normal random
This commit is contained in:
parent
3989c7f19a
commit
e78ef3fc71
|
@ -363,7 +363,7 @@ Jittor还很年轻。 它可能存在错误和问题。 请在我们的错误跟
|
|||
## 团队
|
||||
|
||||
|
||||
Jittor目前由来自[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)的梁盾,杨国烨,杨国炜和周文洋等博士生维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们!
|
||||
Jittor目前由来自[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)的梁盾,杨国烨,杨国炜,周文洋和国孟昊等博士生维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们!
|
||||
|
||||
|
||||
## 版权声明
|
||||
|
|
|
@ -356,7 +356,7 @@ File an issue: https://github.com/Jittor/jittor/issues
|
|||
## The Team
|
||||
|
||||
|
||||
Jittor is currently maintained by Dun Liang, Guo-Ye Yang, Guo-Wei Yang and Wen-Yang Zhou etc. from the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
|
||||
Jittor is currently maintained by Dun Liang, Guo-Ye Yang, Guo-Wei Yang, Wen-Yang Zhou and Meng-Hao Guo etc. from the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
|
||||
|
||||
|
||||
## License
|
||||
|
|
|
@ -457,9 +457,9 @@ File an issue: https://github.com/Jittor/jittor/issues
|
|||
|
||||
## 团队
|
||||
|
||||
Jittor is currently maintained by Dun Liang, Guo-Ye Yang, Guo-Wei Yang and Wen-Yang Zhou etc. from the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
|
||||
Jittor is currently maintained by Dun Liang, Guo-Ye Yang, Guo-Wei Yang, Wen-Yang Zhou and Meng-Hao Guo etc. from the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
|
||||
|
||||
Jittor目前由来自[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)的梁盾,杨国烨,杨国炜和周文洋等博士生维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们!
|
||||
Jittor目前由来自[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)的梁盾,杨国烨,杨国炜,周文洋和国孟昊等博士生维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们!
|
||||
|
||||
## License
|
||||
|
||||
|
|
|
@ -16,13 +16,16 @@
|
|||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
CurandRandomOp::CurandRandomOp(NanoVector shape, NanoString dtype) {
|
||||
CurandRandomOp::CurandRandomOp(NanoVector shape, NanoString dtype, NanoString type) {
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
output = create_output(shape, dtype);
|
||||
this->type = type;
|
||||
ASSERT(type == ns_normal || type == ns_uniform);
|
||||
}
|
||||
|
||||
void CurandRandomOp::jit_prepare() {
|
||||
add_jit_define("T", output->dtype());
|
||||
add_jit_define("R", type);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
@ -31,13 +34,14 @@ void CurandRandomOp::jit_run() {
|
|||
}
|
||||
#else // JIT_cuda
|
||||
void CurandRandomOp::jit_run() {
|
||||
@define(TT,@if(@strcmp(@T,float32)==0,,Double))
|
||||
|
||||
auto* __restrict__ x = output->ptr<T>();
|
||||
index_t num = output->num;
|
||||
if (sizeof(T) == 4) {
|
||||
checkCudaErrors( curandGenerateUniform(gen, (float*)x, num) );
|
||||
} else {
|
||||
checkCudaErrors( curandGenerateUniformDouble(gen, (float64*)x, num) );
|
||||
}
|
||||
@if(@strcmp(@R,uniform)==0,
|
||||
checkCudaErrors(curandGenerateUniform@TT (gen, x, num));,
|
||||
checkCudaErrors(curandGenerateNormal@TT (gen, x, num, 0, 1));
|
||||
)
|
||||
}
|
||||
#endif // JIT_cpu
|
||||
#endif // JIT
|
||||
|
|
|
@ -13,7 +13,8 @@ namespace jittor {
|
|||
|
||||
struct CurandRandomOp : Op {
|
||||
Var* output;
|
||||
CurandRandomOp(NanoVector shape, NanoString dtype=ns_float32);
|
||||
NanoString type;
|
||||
CurandRandomOp(NanoVector shape, NanoString dtype=ns_float32, NanoString type=ns_uniform);
|
||||
|
||||
const char* name() const override { return "curand_random"; }
|
||||
DECLARE_jit_run;
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.1.7.12'
|
||||
__version__ = '1.1.7.13'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
@ -99,6 +99,8 @@ class log_capture_scope(_call_no_record_scope):
|
|||
LOG.log_capture_start()
|
||||
try:
|
||||
self.fs.__enter__()
|
||||
if "log_v" in self.fs.jt_flags:
|
||||
LOG.log_v = self.fs.jt_flags["log_v"]
|
||||
return self.logs
|
||||
except:
|
||||
LOG.log_capture_stop()
|
||||
|
@ -108,6 +110,8 @@ class log_capture_scope(_call_no_record_scope):
|
|||
def __exit__(self, *exc):
|
||||
global single_log_capture
|
||||
self.fs.__exit__(*exc)
|
||||
if "log_v" in self.fs.jt_flags:
|
||||
LOG.log_v = flags.log_v
|
||||
LOG.log_capture_stop()
|
||||
self.logs.extend(LOG.log_capture_read())
|
||||
single_log_capture = None
|
||||
|
|
|
@ -46,6 +46,7 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
|||
run_cmd(cmd)
|
||||
return True
|
||||
link = link_flags
|
||||
base_output = output.split('/')[-1].split('.')[0]
|
||||
# if output is core, add core_link_flags
|
||||
if output.startswith("jittor_core"):
|
||||
link = link + core_link_flags
|
||||
|
@ -77,7 +78,7 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
|||
cc = nvcc_path
|
||||
cmd = f"{cc} {input} {nflags} -c {lto_flags} -o {obj_file}"
|
||||
cmds.append(cmd)
|
||||
jit_utils.run_cmds(cmds, cache_path, jittor_path, "compiling")
|
||||
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
|
||||
cmd = f"{compiler} {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"
|
||||
return do_compile(cmd)
|
||||
|
||||
|
|
|
@ -11,22 +11,22 @@ import numpy as np
|
|||
import math
|
||||
|
||||
def constant(shape, dtype, value=0.0):
|
||||
return jt.array(np.ones(shape)*value).unary(dtype)
|
||||
return jt.array(value).unary(dtype).broadcast(shape)
|
||||
|
||||
def constant_(var, value=0.0):
|
||||
var.assign(constant(tuple(var.shape), var.dtype, value))
|
||||
var.assign(constant(var.shape, var.dtype, value))
|
||||
|
||||
def uniform(shape, dtype, low, high):
|
||||
return jt.array(np.random.uniform(low, high, shape)).unary(dtype)
|
||||
return jt.random(shape, dtype) * (low - high) + high
|
||||
|
||||
def uniform_(var, low, high):
|
||||
var.assign(uniform(tuple(var.shape), var.dtype, low, high))
|
||||
var.assign(uniform(var.shape, var.dtype, low, high))
|
||||
|
||||
def gauss(shape, dtype, mean=0.0, std=1.0):
|
||||
return jt.array(np.random.normal(mean, std, shape)).unary(dtype)
|
||||
return jt.random(shape, dtype, "normal") * std + mean
|
||||
|
||||
def gauss_(var, mean=0.0, std=1.0):
|
||||
var.assign(gauss(tuple(var.shape), var.dtype, mean, std))
|
||||
var.assign(gauss(var.shape, var.dtype, mean, std))
|
||||
|
||||
def invariant_uniform(shape, dtype, mode="fan_in"):
|
||||
assert len(shape)>1
|
||||
|
|
|
@ -38,7 +38,8 @@ def check_equal_with_istrain(arr, j_layer, p_layer, is_train=True, has_running=T
|
|||
assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
|
||||
else:
|
||||
assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
|
||||
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold)
|
||||
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), 1e-2, threshold), \
|
||||
( np.abs(pytorch_result.detach().numpy() - jittor_result.numpy()).max() )
|
||||
|
||||
def check_equal_without_istrain(arr, j_layer, p_layer, threshold=1e-5):
|
||||
jittor_arr = jt.array(arr)
|
||||
|
|
|
@ -112,7 +112,7 @@ def check_backward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
|
|||
op_name = "mkl_conv"
|
||||
|
||||
with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
|
||||
log_v=1, log_vprefix="op.cc=100,exe=1000,conv_t=1000", compile_options={"test":244}
|
||||
log_v=1, log_vprefix="op.cc=1000,exe=1000,conv_t=1000", compile_options={"test":244}
|
||||
) as raw_log:
|
||||
x = jt.random(xshape)
|
||||
w = jt.random(wshape)
|
||||
|
|
|
@ -50,10 +50,10 @@ class TestInit(unittest.TestCase):
|
|||
torch.manual_seed(0)
|
||||
|
||||
def test_conv(self):
|
||||
check(jt.nn.Conv(64, 256, 3), torch.nn.Conv2d(64, 256, 3), rtol=1e-1, mean_atol=1e-3)
|
||||
check(jt.nn.Conv(64, 256, 3), torch.nn.Conv2d(64, 256, 3), rtol=1e-1, mean_atol=1e-2)
|
||||
|
||||
def test_resnet(self):
|
||||
check(models.resnet152(), torchvision.models.resnet152(), rtol=2e-2, mean_atol=1e-2)
|
||||
check(models.resnet152(), torchvision.models.resnet152(), rtol=5e-2, mean_atol=1e-2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -28,8 +28,7 @@ def find_log_with_re(logs, pattern=None, **args):
|
|||
|
||||
class TestLog(unittest.TestCase):
|
||||
def test_log_capture(self):
|
||||
LOG.log_capture_start()
|
||||
with jt.flag_scope(log_v=1000, log_vprefix=""):
|
||||
with jt.log_capture_scope(log_v=1000, log_vprefix="") as logs:
|
||||
LOG.v("1")
|
||||
LOG.vv("2")
|
||||
LOG.i("3")
|
||||
|
@ -37,10 +36,8 @@ class TestLog(unittest.TestCase):
|
|||
LOG.e("5")
|
||||
a = jt.zeros([10])
|
||||
a.sync()
|
||||
LOG.log_capture_stop()
|
||||
# TODO: why need manually delete this variable?
|
||||
del a
|
||||
logs = LOG.log_capture_read()
|
||||
logs2 = LOG.log_capture_read()
|
||||
assert len(logs2)==0
|
||||
|
||||
|
|
|
@ -156,8 +156,8 @@ class TestMatmul(unittest.TestCase):
|
|||
loss_mean.data.sum()
|
||||
jt.liveness_info()
|
||||
|
||||
# result is 0.00038617782411165535
|
||||
result = 0.00038617782411165535
|
||||
# result is 0.00022486248053610325
|
||||
result = 0.00022486248053610325
|
||||
assert abs(loss_mean.data - result) < 1e-6, [loss_mean.data, result]
|
||||
jt.clean()
|
||||
|
||||
|
@ -255,9 +255,9 @@ class TestMatmul(unittest.TestCase):
|
|||
loss_mean.data.sum()
|
||||
jt.liveness_info()
|
||||
|
||||
# result is 0.00038617782411165535
|
||||
result = 0.00038617782411165535
|
||||
assert abs(loss_mean.data - result) < 1e-6
|
||||
# result is 0.00018236637697555125
|
||||
result = 0.00018236637697555125
|
||||
assert abs(loss_mean.data - result) < 1e-2
|
||||
jt.clean()
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
|
|
|
@ -43,5 +43,23 @@ class TestRandomOp(unittest.TestCase):
|
|||
t.data
|
||||
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "curand_random" + ".*)")
|
||||
assert len(logs)==1
|
||||
|
||||
def test_normal(self):
|
||||
from jittor import init
|
||||
n = 10000
|
||||
r = 0.155
|
||||
a = init.gauss([n], "float32", 1, 3)
|
||||
data = a.data
|
||||
|
||||
assert (np.abs((data<(1-3)).mean() - r) < 0.1)
|
||||
assert (np.abs((data<(1)).mean() - 0.5) < 0.1)
|
||||
assert (np.abs((data<(1+3)).mean() - (1-r)) < 0.1)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_normal_cuda(self):
|
||||
self.test_normal()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -108,8 +108,7 @@ def pass_asm(cc_path,s_path):
|
|||
|
||||
output_path=s_path.replace(".post.s",".s")
|
||||
with open(output_path,"w") as f:
|
||||
for line in s_content:
|
||||
f.write(line)
|
||||
f.write("".join(s_content))
|
||||
|
||||
def run_cmd(cmd):
|
||||
LOG.vvvv(f"Run cmd: {cmd}")
|
||||
|
|
|
@ -73,6 +73,7 @@ class DelayProgress:
|
|||
if used > 2:
|
||||
eta = used / (i+1) * (self.n-i-1)
|
||||
print(f"{self.msg}({i+1}/{self.n}) used: {used:.3f}s eta: {eta:.3f}s", end='\r')
|
||||
if i==self.n-1: print()
|
||||
|
||||
# check is in jupyter notebook
|
||||
def in_ipynb():
|
||||
|
@ -153,7 +154,7 @@ def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
|
|||
if pool_size == 0:
|
||||
mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
|
||||
mem_gib = mem_bytes/(1024.**3)
|
||||
pool_size = min(8,max(int(mem_gib // 3), 1))
|
||||
pool_size = min(16,max(int(mem_gib // 3), 1))
|
||||
LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.")
|
||||
cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ]
|
||||
bk = mp.current_process()._config.get('daemon')
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "fused_op.h"
|
||||
#include "fuser.h"
|
||||
#include "profiler/profiler_guard.h"
|
||||
#include "parallel_compiler.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -28,6 +29,43 @@ Executor exe;
|
|||
// from fetch_op.cc
|
||||
extern list<VarPtr> fetcher_to_free;
|
||||
|
||||
void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int ll, int rr, int64 tt) {
|
||||
fused_op.ops.clear();
|
||||
fused_op.edges.clear();
|
||||
auto ntt = ++Node::tflag_count;
|
||||
for (int i=ll; i<rr; i++) {
|
||||
int opid = fuse_ops[i];
|
||||
Op* op = ops[opid];
|
||||
uint64_t fid1 = fused_op.ops.size();
|
||||
op->custom_data = fid1;
|
||||
op->tflag = ntt;
|
||||
fused_op.ops.push_back(op);
|
||||
}
|
||||
for (Op* op : fused_op.ops) {
|
||||
uint fid1 = op->custom_data;
|
||||
uint oid = 0;
|
||||
for (Var* v : op->outputs()) {
|
||||
oid++;
|
||||
if (v->tflag != tt) {
|
||||
// this var node not belong to current execution
|
||||
// this will happend in multiple outputs fuseable op
|
||||
// v->custom_data = 0 represents this var cannot be fused
|
||||
v->custom_data = 0;
|
||||
continue;
|
||||
}
|
||||
for (auto o : v->outputs_with_index()) {
|
||||
Op* op2 = o.op;
|
||||
uint iid = o.index;
|
||||
if (op2->tflag != ntt) continue;
|
||||
uint fid2 = op2->custom_data;
|
||||
fused_op.edges.emplace_back(fid1, oid-1, fid2, iid);
|
||||
}
|
||||
}
|
||||
}
|
||||
LOGvvv << "Prepare fused_op" << fused_op.ops;
|
||||
fused_op.update_ops();
|
||||
}
|
||||
|
||||
void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||
auto allocator = get_allocator();
|
||||
this->allocator = allocator;
|
||||
|
@ -316,10 +354,13 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
for (int i=0; i<var_num; i++) {
|
||||
all_vars[i]->custom_data = var_fused[i]==1;
|
||||
}
|
||||
FusedOp fused_op;
|
||||
|
||||
// compile all ops, prevent compiling during running
|
||||
parallel_compile_all_ops(queue, range, fused_op, fuse_ops, ops, tt);
|
||||
|
||||
// running
|
||||
SetupFreeBuffer setup_free_buffer;
|
||||
FusedOp fused_op;
|
||||
vector<Var*> outputs_bk;
|
||||
#ifdef HAS_CUDA
|
||||
int sync_times = 0;
|
||||
|
@ -332,41 +373,9 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
if (op->type() != OpType::other) {
|
||||
op = &fused_op;
|
||||
is_fused_op = true;
|
||||
fused_op.ops.clear();
|
||||
fused_op.edges.clear();
|
||||
int ll = (rid<queue.size()-1)?range[queue.size()-rid-2]:0, rr = range[queue.size()-rid-1];
|
||||
root = fuse_ops[rr-1];
|
||||
auto ntt = ++Node::tflag_count;
|
||||
for (int i=ll; i<rr; i++) {
|
||||
int opid = fuse_ops[i];
|
||||
Op* op = ops[opid];
|
||||
uint64_t fid1 = fused_op.ops.size();
|
||||
op->custom_data = fid1;
|
||||
op->tflag = ntt;
|
||||
fused_op.ops.push_back(op);
|
||||
}
|
||||
for (Op* op : fused_op.ops) {
|
||||
uint fid1 = op->custom_data;
|
||||
uint oid = 0;
|
||||
for (Var* v : op->outputs()) {
|
||||
oid++;
|
||||
if (v->tflag != tt) {
|
||||
// this var node not belong to current execution
|
||||
// this will happend in multiple outputs fuseable op
|
||||
v->custom_data = 0;
|
||||
continue;
|
||||
}
|
||||
for (auto o : v->outputs_with_index()) {
|
||||
Op* op2 = o.op;
|
||||
uint iid = o.index;
|
||||
if (op2->tflag != ntt) continue;
|
||||
uint fid2 = op2->custom_data;
|
||||
fused_op.edges.emplace_back(fid1, oid-1, fid2, iid);
|
||||
}
|
||||
}
|
||||
}
|
||||
LOGvvv << "Prepare fused_op" << fused_op.ops;
|
||||
fused_op.update_ops();
|
||||
load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
|
||||
}
|
||||
LOGvvv << "Run" << op;
|
||||
if (!op->shape_infered()) op->infer_shape();
|
||||
|
@ -430,7 +439,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
var->finish_pending_liveness();
|
||||
} catch (const std::exception& e) {
|
||||
// log memory info
|
||||
display_memory_info(__FILELINE__);
|
||||
display_memory_info(__FILELINE__, false, true);
|
||||
// log jit_key and file location
|
||||
op->do_prepare();
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
@ -20,5 +21,7 @@ struct Executor {
|
|||
};
|
||||
|
||||
extern Executor exe;
|
||||
|
||||
void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int ll, int rr, int64 tt);
|
||||
|
||||
} // jittor
|
|
@ -118,6 +118,18 @@ FusedOp::FusedOp() {
|
|||
Op::number_of_lived_ops--;
|
||||
}
|
||||
|
||||
FusedOp::FusedOp(const FusedOp& other) {
|
||||
Op::number_of_lived_ops--;
|
||||
ops = other.ops;
|
||||
edges = other.edges;
|
||||
vars = other.vars;
|
||||
loop_options_merged = other.loop_options_merged;
|
||||
loop_options_tuned = other.loop_options_tuned;
|
||||
loop_options = other.loop_options;
|
||||
loop_options_origin = other.loop_options_origin;
|
||||
context = other.context;
|
||||
}
|
||||
|
||||
FusedOp::~FusedOp() {
|
||||
_outputs.clear();
|
||||
Op::number_of_lived_ops++;
|
||||
|
@ -215,7 +227,7 @@ void FusedOp::do_run_after_prepare() {
|
|||
LOGvv << "Jit op key not found:" << jit_key;
|
||||
// compile JIT op
|
||||
context = new FusedOpContext();
|
||||
context->vrm.fop = this;
|
||||
context->setup(this);
|
||||
string prev_jit_key = jit_key;
|
||||
context->entry = OpCompiler::do_compile(this);
|
||||
string new_jit_key = get_jit_key();
|
||||
|
@ -225,6 +237,25 @@ void FusedOp::do_run_after_prepare() {
|
|||
Profiler::record_and_run(context->entry, this, new_jit_key.c_str());
|
||||
}
|
||||
|
||||
void FusedOpContext::setup(FusedOp* fop) {
|
||||
node_id.clear();
|
||||
vrm.fop = fop;
|
||||
for (int i=0; i<fop->ops.size(); i++)
|
||||
node_id[fop->ops[i]] = i;
|
||||
for (int i=0; i<fop->vars.size(); i++)
|
||||
node_id[fop->vars[i].var] = i;
|
||||
}
|
||||
|
||||
int FusedOp::get_node_id(Node* node) {
|
||||
ASSERT(context);
|
||||
return context->node_id.at(node);
|
||||
}
|
||||
|
||||
int FusedOp::has(Node* node) {
|
||||
ASSERT(context);
|
||||
return context->node_id.count(node);
|
||||
}
|
||||
|
||||
void FusedOp::do_run(){
|
||||
do_prepare();
|
||||
do_run_after_prepare();
|
||||
|
|
|
@ -19,8 +19,12 @@ std::ostream& operator<<(std::ostream& os, const VarInfo& vi);
|
|||
struct FusedOpContext {
|
||||
VarRelayManager vrm;
|
||||
jit_op_entry_t entry;
|
||||
unordered_map<Node*, int> node_id;
|
||||
void setup(FusedOp* fop);
|
||||
};
|
||||
|
||||
extern string_view_map<FusedOpContext*> jit_fused_ops;
|
||||
|
||||
struct FusedOp final : Op {
|
||||
vector<Op*> ops;
|
||||
// edges: [[i,j,k,l], ...] represents opi.output(j) == opk.input(i)
|
||||
|
@ -31,8 +35,11 @@ struct FusedOp final : Op {
|
|||
loop_options_t& get_loop_options_tuned();
|
||||
FusedOpContext* context;
|
||||
|
||||
int get_node_id(Node* node);
|
||||
int has(Node* node);
|
||||
void update_ops();
|
||||
FusedOp();
|
||||
FusedOp(const FusedOp& other);
|
||||
~FusedOp();
|
||||
|
||||
int get_loop_option(const string& key, const int& _default=0);
|
||||
|
|
|
@ -61,9 +61,6 @@ static string get_symbol_name(const string& jit_key) {
|
|||
}
|
||||
|
||||
jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_cuda_op, const string& extra_flags) {
|
||||
auto iter = jit_ops.find(jit_key);
|
||||
if (iter != jit_ops.end())
|
||||
return iter->second;
|
||||
LOGvv << "Compile op" << jit_key;
|
||||
// compiler do not allowed filename too long
|
||||
CHECK(cc_path.size());
|
||||
|
@ -92,7 +89,6 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
|||
cache_compile(cmd, cache_path, jittor_path);
|
||||
auto symbol_name = get_symbol_name(jit_key);
|
||||
auto jit_entry = load_jit_lib(jit_lib_path, symbol_name);
|
||||
jit_ops[jit_key] = jit_entry;
|
||||
return jit_entry;
|
||||
}
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ namespace jittor {
|
|||
|
||||
const int page_size = 4*1024;
|
||||
|
||||
extern size_t protected_page;
|
||||
extern thread_local size_t protected_page;
|
||||
|
||||
static size_t get_buffer_end_page(size_t buffer_end) {
|
||||
// get the last complete page in buffer
|
||||
|
@ -112,6 +112,6 @@ vector<pair<string,string>> parse_jit_keys(const string& s) {
|
|||
return jit_keys;
|
||||
}
|
||||
|
||||
JitKey jk;
|
||||
thread_local JitKey jk;
|
||||
|
||||
} // jittor
|
|
@ -65,7 +65,7 @@ struct JitKey {
|
|||
};
|
||||
};
|
||||
|
||||
extern JitKey jk;
|
||||
extern thread_local JitKey jk;
|
||||
typedef JitKey JK;
|
||||
|
||||
inline JK& operator<<(JK& jk, const char* s) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
namespace jittor {
|
||||
|
||||
static int lock_fd = -1;
|
||||
int _has_lock = 0;
|
||||
|
||||
void set_lock_path(string path) {
|
||||
lock_fd = open(path.c_str(), O_RDWR);
|
||||
|
@ -32,6 +33,7 @@ void lock() {
|
|||
.l_len = 0
|
||||
};
|
||||
ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
|
||||
_has_lock = 1;
|
||||
LOGvv << "LOCK Pid:" << getpid();
|
||||
}
|
||||
|
||||
|
@ -44,6 +46,7 @@ void unlock() {
|
|||
.l_len = 0
|
||||
};
|
||||
ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
|
||||
_has_lock = 0;
|
||||
LOGvv << "UNLOCK Pid:" << getpid();
|
||||
}
|
||||
|
||||
|
|
14
src/lock.h
14
src/lock.h
|
@ -18,9 +18,19 @@ void lock();
|
|||
|
||||
void unlock();
|
||||
|
||||
extern int _has_lock;
|
||||
|
||||
struct lock_guard {
|
||||
inline lock_guard() { lock(); }
|
||||
inline ~lock_guard() { unlock(); }
|
||||
int has_lock = 0;
|
||||
inline lock_guard() {
|
||||
if (_has_lock) return;
|
||||
has_lock = 1;
|
||||
lock();
|
||||
}
|
||||
inline ~lock_guard() {
|
||||
if (!has_lock) return;
|
||||
unlock();
|
||||
}
|
||||
};
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -41,9 +41,9 @@ std::ostream& operator<<(std::ostream& os, const FloatOutput& o) {
|
|||
return os << o.suffix;
|
||||
}
|
||||
|
||||
void display_memory_info(const char* fileline, bool dump_var) {
|
||||
void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
|
||||
int p = 3;
|
||||
Log log(fileline, 'i', 0);
|
||||
Log log(fileline, red_color?'e':'i', 0);
|
||||
log << "\n=== display_memory_info ===\n";
|
||||
log << "total_cpu_ram:" <<
|
||||
FloatOutput{(double)mem_info.total_cpu_ram, " KMG", 1024, "B"};
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
namespace jittor {
|
||||
|
||||
// @pyjt(display_memory_info)
|
||||
void display_memory_info(const char* fileline="", bool dump_var=false);
|
||||
void display_memory_info(const char* fileline="", bool dump_var=false, bool red_color=false);
|
||||
|
||||
// @pyjt(MemInfo)
|
||||
struct MemInfo {
|
||||
|
|
|
@ -74,6 +74,9 @@ namespace jittor {
|
|||
m(cosh) \
|
||||
m(acosh) \
|
||||
m(sigmoid) \
|
||||
\
|
||||
m(uniform) \
|
||||
m(normal) \
|
||||
|
||||
struct NanoString;
|
||||
#define DECLEAR_NS(T) extern NanoString ns_##T;
|
||||
|
|
|
@ -110,9 +110,12 @@ void Op::do_jit_prepare() {
|
|||
// check use_cuda_op from outputs may not be enough
|
||||
bool use_cuda_op = use_cuda;
|
||||
for (Var* var : inputs()) {
|
||||
if (var->allocator) {
|
||||
if (var->mem_ptr) {
|
||||
/* jit key don't include here, because
|
||||
parallel compiler don't known
|
||||
jk << JK::key << "alloc_i" << JK::hex1(in_id)
|
||||
<< JK::hex1(var->allocator->flags()) << JK::end;
|
||||
*/
|
||||
use_cuda_op &= var->allocator->is_cuda();
|
||||
}
|
||||
if (var->num >= std::numeric_limits<int32_t>::max())
|
||||
|
@ -120,9 +123,11 @@ void Op::do_jit_prepare() {
|
|||
in_id ++;
|
||||
}
|
||||
for (Var* var : outputs()) {
|
||||
if (var->allocator) {
|
||||
if (var->mem_ptr) {
|
||||
/*
|
||||
jk << JK::key << "alloc_o" << JK::hex1(in_id)
|
||||
<< JK::hex1(var->allocator->flags()) << JK::end;
|
||||
*/
|
||||
use_cuda_op &= var->allocator->is_cuda();
|
||||
}
|
||||
if (var->num >= std::numeric_limits<int32_t>::max())
|
||||
|
|
|
@ -74,22 +74,25 @@ string OpCompiler::get_name_by_op_var(Op* op, Var* var) {
|
|||
var_id++;
|
||||
}
|
||||
ASSERT(found);
|
||||
ASSERT(op->custom_data<(int)op_members.size());
|
||||
auto& v = op_members[op->custom_data];
|
||||
ASSERT(this->op);
|
||||
ASSERT(this->op->context);
|
||||
auto opid = this->op->context->node_id.at(op);
|
||||
ASSERT(opid<(int)op_members.size());
|
||||
auto& v = op_members[opid];
|
||||
ASSERT(var_id < v.size());
|
||||
return v[var_id];
|
||||
}
|
||||
|
||||
string OpCompiler::get_name_by_op_input(Op* op, uint i) {
|
||||
return op_members.at(op->custom_data).at(i);
|
||||
return op_members.at(this->op->get_node_id(op)).at(i);
|
||||
}
|
||||
|
||||
string OpCompiler::get_name_by_op_output(Op* op, uint i) {
|
||||
return op_members.at(op->custom_data).at(i+op->inputs().size());
|
||||
return op_members.at(this->op->get_node_id(op)).at(i+op->inputs().size());
|
||||
}
|
||||
|
||||
bool OpCompiler::op_exist(Op* op) {
|
||||
return op_members.at(op->custom_data).size();
|
||||
return op_members.at(this->op->get_node_id(op)).size();
|
||||
}
|
||||
|
||||
int OpCompiler::total_member_count() {
|
||||
|
@ -733,6 +736,25 @@ string OpCompiler::get_fused_src(FusedOp* op) {
|
|||
return OpCompiler::__get_fused_src(op->ops, op_srcs, op_members);
|
||||
}
|
||||
|
||||
static void fix_op_member(
|
||||
const vector<Op*>& ops,
|
||||
vector<vector<string>>& op_members
|
||||
) {
|
||||
// fill op member: [in0, in1, ... inN, fill, fill, out0, out1, ...]
|
||||
for (int i=0; i<ops.size(); i++) {
|
||||
auto op = ops[i];
|
||||
auto var_num = op->inputs().size() + op->outputs().size();
|
||||
auto& member = op_members.at(i);
|
||||
if (!member.size()) {
|
||||
continue;
|
||||
}
|
||||
ASSERT(member.size() <= var_num);
|
||||
while (member.size() < var_num) {
|
||||
member.insert(member.end() - op->outputs().size(), "__fill__");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
string OpCompiler::__get_fused_src(
|
||||
const vector<Op*>& ops,
|
||||
const vector<string>& op_srcs,
|
||||
|
@ -908,6 +930,7 @@ string OpCompiler::__get_fused_src(
|
|||
break;
|
||||
}
|
||||
}
|
||||
fix_op_member(ops, op_members);
|
||||
CHECK(!(defs.count("JIT_cpu") && defs.count("JIT_cuda")))
|
||||
<< "CPU op and GPU op cannot be fused together.";
|
||||
|
||||
|
|
|
@ -14,36 +14,42 @@
|
|||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
RandomOp::RandomOp(NanoVector shape, NanoString dtype) {
|
||||
RandomOp::RandomOp(NanoVector shape, NanoString dtype, NanoString type) {
|
||||
// auto curand_random = get_op_info("curand_random")
|
||||
// .get_constructor<NanoVector, NanoString>();
|
||||
// output = curand_random(shape, dtype);
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
static VarPtr(*curand_random)(NanoVector, NanoString) = nullptr;
|
||||
static VarPtr(*curand_random)(NanoVector, NanoString, NanoString) = nullptr;
|
||||
if (!curand_random && has_op("curand_random")) {
|
||||
curand_random = get_op_info("curand_random")
|
||||
.get_constructor<VarPtr, NanoVector, NanoString>();
|
||||
.get_constructor<VarPtr, NanoVector, NanoString, NanoString>();
|
||||
}
|
||||
if (curand_random) {
|
||||
auto var = curand_random(shape, dtype);
|
||||
auto var = curand_random(shape, dtype, type);
|
||||
forward(var);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
output = create_output(shape, dtype);
|
||||
this->type = type;
|
||||
ASSERT(type == ns_normal || type == ns_uniform);
|
||||
}
|
||||
|
||||
void RandomOp::jit_prepare() {
|
||||
add_jit_define("T", output->dtype());
|
||||
add_jit_define("R", type);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void RandomOp::jit_run() {
|
||||
auto* generator = get_random_engine();
|
||||
std::uniform_real_distribution<T> distribution(0.0,1.0);
|
||||
@if(@strcmp(@R,uniform)==0,
|
||||
std::uniform_real_distribution<T> distribution(0.0,1.0);,
|
||||
std::normal_distribution<T> distribution(0.0,1.0);
|
||||
)
|
||||
auto* __restrict__ x = output->ptr<T>();
|
||||
index_t num = output->num;
|
||||
for (index_t i=0; i<num; i++)
|
||||
|
|
|
@ -10,7 +10,8 @@ namespace jittor {
|
|||
|
||||
struct RandomOp : Op {
|
||||
Var* output;
|
||||
RandomOp(NanoVector shape, NanoString dtype=ns_float32);
|
||||
NanoString type;
|
||||
RandomOp(NanoVector shape, NanoString dtype=ns_float32, NanoString type=ns_uniform);
|
||||
|
||||
const char* name() const override { return "random"; }
|
||||
DECLARE_jit_run;
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "mem/allocator.h"
|
||||
#include "opt/pass_manager.h"
|
||||
#include "opt/pass/assume_aligned_pass.h"
|
||||
#include "executor.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -31,7 +32,7 @@ void AssumeAlignedPass::run() {
|
|||
Var* var;
|
||||
pm->oc->get_op_var_by_name(name, op_id, opvar_id, op, var);
|
||||
// add assume_aligned if is aligned_allocator
|
||||
if (var->allocator->is_aligned()) {
|
||||
if (exe.allocator->is_aligned()) {
|
||||
// if is a function arguments
|
||||
if (l == ls[0])
|
||||
func->push_front("assume_aligned("+lvalue+");");
|
||||
|
|
|
@ -24,9 +24,10 @@ void LoopVarAnalyzePass::run() {
|
|||
auto& op_members = this->pm->oc->op_members;
|
||||
// TODO: fix it
|
||||
// ugly temp fix for index_var
|
||||
auto opid = this->op->get_node_id(op);
|
||||
if (op->name()==string("index") &&
|
||||
op->inputs().size()+op->outputs().size() != op_members[op->custom_data].size()) {
|
||||
op_members[op->custom_data].insert(op_members[op->custom_data].begin(), "wtf");
|
||||
op->inputs().size()+op->outputs().size() != op_members[opid].size()) {
|
||||
op_members[opid].insert(op_members[opid].begin(), "wtf");
|
||||
}
|
||||
}
|
||||
// LoopVarAnalyzePass has three steps:
|
||||
|
|
|
@ -221,11 +221,11 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
|||
auto op_iop = op->input(0)->input();
|
||||
if (!(op_iop
|
||||
&& op_iop->name_ex()=="binary.multiply"
|
||||
&& op_iop->tflag==op->tflag))
|
||||
&& fop->has(op_iop)))
|
||||
continue;
|
||||
auto bop = (BinaryOp*)op_iop;
|
||||
|
||||
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
|
||||
if (!(bop->y->input() && bop->x->input() && fop->has(bop->x->input()) && fop->has(bop->y->input()))) continue;
|
||||
if (!(bop->x->input()->type()==OpType::broadcast && bop->y->input()->type()==OpType::broadcast)) return;
|
||||
|
||||
// only support float32 currently
|
||||
|
|
|
@ -22,12 +22,12 @@ void MatmulTuner::run(PassManager* pm, TunerManager* tm) {
|
|||
for (Op* op : fop->ops) {
|
||||
if (op->name_ex()!="reduce.add") continue;
|
||||
auto rop = (ReduceOp*)op;
|
||||
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
|
||||
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && fop->has(rop->x->input())))
|
||||
continue;
|
||||
auto bop = (BinaryOp*)(rop->x->input());
|
||||
if (!(bop->x->input() && bop->x->input()->name_ex()=="broadcast_to" && bop->x->input()->tflag==op->tflag))
|
||||
if (!(bop->x->input() && bop->x->input()->name_ex()=="broadcast_to" && fop->has(bop->x->input())))
|
||||
continue;
|
||||
if (!(bop->y->input() && bop->y->input()->name_ex()=="broadcast_to" && bop->y->input()->tflag==op->tflag))
|
||||
if (!(bop->y->input() && bop->y->input()->name_ex()=="broadcast_to" && fop->has(bop->y->input())))
|
||||
continue;
|
||||
auto bcop1 = (BroadcastToOp*)(bop->x->input());
|
||||
auto bcop2 = (BroadcastToOp*)(bop->y->input());
|
||||
|
|
|
@ -31,7 +31,6 @@ template <class T> void TunerManager::run_tuner(PassManager* pm) {
|
|||
}
|
||||
|
||||
string TunerManager::tune() {
|
||||
auto tmp = Var::number_of_lived_vars;
|
||||
PassManager pm(oc);
|
||||
string src_after_passes;
|
||||
pm.run_passes();
|
||||
|
@ -60,7 +59,6 @@ string TunerManager::tune() {
|
|||
}
|
||||
}
|
||||
}
|
||||
ASSERTop(Var::number_of_lived_vars,==,tmp) << (print_trace(), 0) << oc->op->ops << best_tuner->candidates;
|
||||
return src_after_passes;
|
||||
}
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ int VarRelayManager::add_relay_group(const vector<pair<Var*, Var*>>& group) {
|
|||
for (auto& g : relay_groups)
|
||||
for (auto& p : g.relayed_pairs)
|
||||
for (auto& p2 : group)
|
||||
if (p.second == (p2.second->custom_data>>2)) {
|
||||
if (p.second == (fop->get_node_id(p2.second))) {
|
||||
LOGvvvv << "Var allready relayed" << p2.second;
|
||||
return -1;
|
||||
}
|
||||
|
@ -43,8 +43,7 @@ int VarRelayManager::add_relay_group(const vector<pair<Var*, Var*>>& group) {
|
|||
auto& relay_group = relay_groups.back();
|
||||
relay_group.relayed_pairs.reserve(group.size());
|
||||
for (const auto& p : group) {
|
||||
// var->custom_data>>2: var id
|
||||
relay_group.relayed_pairs.push_back({p.first, p.second->custom_data>>2});
|
||||
relay_group.relayed_pairs.push_back({p.first, fop->get_node_id(p.second)});
|
||||
ASSERTop(p.first->size,==,p.second->size);
|
||||
}
|
||||
|
||||
|
@ -101,7 +100,7 @@ int VarRelayManager::add_relay_group(const vector<pair<Var*, Var*>>& group) {
|
|||
oprc.relayed_members[i] = -1;
|
||||
else {
|
||||
ASSERT(fnodes.count(v));
|
||||
oprc.relayed_members[i] = v->custom_data>>2;
|
||||
oprc.relayed_members[i] = fop->get_node_id(v);
|
||||
}
|
||||
LOGvvvv << "Relay op" << oprc.op->name() >>".">>
|
||||
op_info.var_members[i].first << "-->" <<
|
||||
|
@ -115,8 +114,8 @@ vector<pair<int,int>> VarRelayManager::get_op_relay_info(const vector<bool>& rel
|
|||
ASSERT(relay_switches.size()==relay_groups.size());
|
||||
auto num = fop->ops.size()+fop->vars.size();
|
||||
auto node_id = [&](Node* node) -> int {
|
||||
if (node->is_var()) return node->custom_data>>2;
|
||||
return node->custom_data + fop->vars.size();
|
||||
if (node->is_var()) return fop->get_node_id(node);
|
||||
return fop->get_node_id(node) + fop->vars.size();
|
||||
};
|
||||
vector<int> deps(num);
|
||||
// pair: first: group_id, second: relayed_pair id
|
||||
|
|
|
@ -0,0 +1,301 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include <tuple>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <iomanip>
|
||||
|
||||
#include "parallel_compiler.h"
|
||||
#include "op_compiler.h"
|
||||
#include "executor.h"
|
||||
#include "lock.h"
|
||||
#include "opt/jit_searcher.h"
|
||||
#include "fused_op.h"
|
||||
|
||||
|
||||
namespace jittor {
|
||||
|
||||
DEFINE_FLAG(int, use_parallel_op_compiler, 16, "Number of threads that parallel op comiler used, default 16, set this value to 0 will disable parallel op compiler.");
|
||||
|
||||
// from log.cc
|
||||
extern int segfault_happen;
|
||||
|
||||
// simple thread used for parallel compilation
|
||||
struct SimpleThread {
|
||||
int id;
|
||||
typedef std::function<void(int)> Func;
|
||||
Func func;
|
||||
std::mutex mtx;
|
||||
std::condition_variable cv;
|
||||
std::thread thread;
|
||||
void run() {
|
||||
thread_name = "C"+S(id);
|
||||
try{
|
||||
std::unique_lock<std::mutex> lck(mtx);
|
||||
if (func)
|
||||
func(id);
|
||||
while (true) {
|
||||
cv.wait(lck);
|
||||
if (func) {
|
||||
func(id);
|
||||
} else
|
||||
return;
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
LOGe << e.what();
|
||||
}
|
||||
}
|
||||
void launch_one(Func func) {
|
||||
std::unique_lock<std::mutex> lck(mtx);
|
||||
this->func = func;
|
||||
cv.notify_all();
|
||||
}
|
||||
SimpleThread(int id) : id(id), func(nullptr), thread(&SimpleThread::run, this) {}
|
||||
~SimpleThread() {
|
||||
join();
|
||||
}
|
||||
void join() {
|
||||
if (thread.joinable()) {
|
||||
launch_one(nullptr);
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct SimpleThreads {
|
||||
list<SimpleThread> threads;
|
||||
SimpleThreads(int n) {
|
||||
for (int i=0; i<n; i++)
|
||||
threads.emplace_back(i);
|
||||
}
|
||||
void launch_all(int active_thread, SimpleThread::Func func) {
|
||||
if (active_thread == 1) {
|
||||
func(0);
|
||||
return;
|
||||
}
|
||||
for (auto& t : threads) {
|
||||
t.launch_one(func);
|
||||
active_thread--;
|
||||
if (!active_thread)
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int64 tt) {
|
||||
// jit_search_kernel require compile at runtime
|
||||
if (jit_search_kernel || !use_parallel_op_compiler)
|
||||
return;
|
||||
|
||||
vector<int> op_needs_compile;
|
||||
string_view_map<int> map;
|
||||
vector<unique_ptr<FusedOp>> fop_needs_compile;
|
||||
|
||||
for (uint rid=0; rid<queue.size(); rid++) {
|
||||
int root = queue[rid];
|
||||
Op* op = ops[root];
|
||||
bool is_fused_op = false;
|
||||
try {
|
||||
if (op->type() != OpType::other) {
|
||||
op = &fused_op;
|
||||
is_fused_op = true;
|
||||
int ll = (rid<queue.size()-1)?range[queue.size()-rid-2]:0, rr = range[queue.size()-rid-1];
|
||||
root = fuse_ops[rr-1];
|
||||
load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
|
||||
}
|
||||
LOGvvv << "Check op needs compile:" << op;
|
||||
op->do_prepare();
|
||||
if (jk.empty()) continue;
|
||||
|
||||
const char* jit_key = jk.to_cstring();
|
||||
auto iter = jit_key_mapper.find(jit_key);
|
||||
if (iter != jit_key_mapper.end()) continue;
|
||||
|
||||
auto iter2 = map.find(jit_key);
|
||||
if (iter2 != map.end()) continue;
|
||||
|
||||
map[jit_key] = 1;
|
||||
if (is_fused_op) {
|
||||
op_needs_compile.push_back(-1-(int)fop_needs_compile.size());
|
||||
fop_needs_compile.emplace_back(std::make_unique<FusedOp>(fused_op));
|
||||
} else {
|
||||
op_needs_compile.push_back(rid);
|
||||
}
|
||||
|
||||
|
||||
LOGvv << "Op needs compile:" << op;
|
||||
} catch (const std::exception& e) {
|
||||
// log jit_key and file location
|
||||
op->do_prepare();
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
|
||||
LOGe << "[Error] source file location:" << jit_src_path;
|
||||
if (is_fused_op) {
|
||||
LOGf << "Compile fused operator(" >> rid >> '/' >> queue.size() >> ")"
|
||||
<< "failed:" << fused_op.ops << "\n\nReason: " >> e.what();
|
||||
} else
|
||||
LOGf << "Compile operator(" >> rid >> '/' >> queue.size() >> ")"
|
||||
<< "failed:" << op << "\n\nReason: " >> e.what();
|
||||
}
|
||||
}
|
||||
// if too less op needs compile, don't use parallel compiler
|
||||
// if (op_needs_compile.size() < 3) return;
|
||||
if (op_needs_compile.size() == 0) return;
|
||||
|
||||
static int thread_num = std::max(1, std::min(use_parallel_op_compiler,
|
||||
int(mem_info.total_cpu_ram/(1024ll*1024*1024*3))));
|
||||
#ifdef NODE_MEMCHECK
|
||||
// only use one thread in debug mode
|
||||
// because global id map has no lock
|
||||
thread_num = 1;
|
||||
#endif
|
||||
static std::atomic<int> ai;
|
||||
static volatile int has_error;
|
||||
static vector<vector<std::tuple<int,int,void*,string>>> op_entrys(thread_num);
|
||||
// <int,int,void*,string> represents: task id, is_fused_op, entry or context, new_jit_key
|
||||
static SimpleThreads threads(thread_num);
|
||||
static std::mutex entry_lock;
|
||||
ai = 0;
|
||||
has_error = 0;
|
||||
int n = op_needs_compile.size();
|
||||
LOGvv << "Total number of op needs compile" << op_needs_compile.size()
|
||||
<< "thread_num:" << thread_num;
|
||||
|
||||
// backup number
|
||||
auto bk_var = Var::number_of_lived_vars, bk_op = Op::number_of_lived_ops;
|
||||
jittor::lock_guard lg;
|
||||
auto func = [&](int tid) {
|
||||
auto& entrys = op_entrys.at(tid);
|
||||
entrys.clear();
|
||||
while (!has_error && !segfault_happen) {
|
||||
int i = ai++;
|
||||
if (i >= n) break;
|
||||
int rid = op_needs_compile[i];
|
||||
Op* op;
|
||||
bool is_fused_op = rid<0;
|
||||
try {
|
||||
if (!is_fused_op) {
|
||||
int root = queue[rid];
|
||||
op = ops[root];
|
||||
LOGvv << "Compile Op:" << op;
|
||||
op->do_prepare();
|
||||
auto op_entry = OpCompiler::do_compile(op);
|
||||
entrys.emplace_back(std::make_tuple(i, 0, (void*)op_entry, op->get_jit_key()));
|
||||
} else {
|
||||
FusedOp& fused_op = *fop_needs_compile[-rid-1];
|
||||
op = &fused_op;
|
||||
LOGvv << "Compile FusedOp:" << op;
|
||||
fused_op.context = new FusedOpContext();
|
||||
fused_op.context->setup(&fused_op);
|
||||
fused_op.do_prepare();
|
||||
auto op_entry = OpCompiler::do_compile(op);
|
||||
fused_op.context->entry = op_entry;
|
||||
entrys.emplace_back(std::make_tuple(i, 1, (void*)fused_op.context, op->get_jit_key()));
|
||||
|
||||
// compile relay operators
|
||||
for (auto& vrg : fused_op.context->vrm.relay_groups) {
|
||||
for (auto& orc : vrg.oprcs) {
|
||||
orc.op->do_prepare();
|
||||
bool needs_compile;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(entry_lock);
|
||||
auto iter = jit_ops.find(jk.to_cstring());
|
||||
needs_compile = (iter == jit_ops.end());
|
||||
if (needs_compile) {
|
||||
jit_ops[jk.to_cstring()] = nullptr;
|
||||
}
|
||||
}
|
||||
if (!needs_compile) continue;
|
||||
string s = jk.to_string();
|
||||
auto op_entry = OpCompiler::do_compile(orc.op);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(entry_lock);
|
||||
jit_ops[s] = op_entry;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
// log jit_key and file location
|
||||
op->do_prepare();
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
|
||||
LOGe << "[Error] source file location:" << jit_src_path;
|
||||
|
||||
if (is_fused_op) {
|
||||
LOGe << "Compile fused operator(" >> i >> '/' >> n >> ")"
|
||||
<< "failed:" << ((FusedOp*)op)->ops << "\n\nReason: " >> e.what();
|
||||
} else
|
||||
LOGe << "Compile operator(" >> i >> '/' >> n >> ")"
|
||||
<< "failed:" << op << "\n\nReason: " >> e.what();
|
||||
has_error = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}; // end of threads.launch_all
|
||||
int active_threads = std::min(thread_num, (int)op_needs_compile.size());
|
||||
threads.launch_all(active_threads, func);
|
||||
|
||||
typedef std::chrono::high_resolution_clock Time;
|
||||
auto start = Time::now();
|
||||
int prev_i = 0;
|
||||
bool change_line = false;
|
||||
int sleep_us = 10;
|
||||
while (prev_i < n && !has_error && !segfault_happen) {
|
||||
int i = std::max(std::min(ai-active_threads, n), 0);
|
||||
if (i == prev_i) {
|
||||
// std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(sleep_us));
|
||||
sleep_us = std::min(sleep_us*2, 1000000); // max 0.1s
|
||||
continue;
|
||||
}
|
||||
prev_i = i;
|
||||
auto diff = (Time::now() - start).count();
|
||||
if (diff > 2e9) {
|
||||
if (!change_line) {
|
||||
std::cerr << "\n";
|
||||
change_line = true;
|
||||
}
|
||||
// delay output progress in 2s
|
||||
float eta = diff / 1e9 / i * (n-i);
|
||||
std::cerr << "Compiling Operators(" << i << '/' << n << ")"
|
||||
<< " used: " << std::setprecision(3) << std::setw(4) << diff/1e9 << "s eta: "
|
||||
<< std::setprecision(3) << std::setw(4) << eta << "s \r";
|
||||
}
|
||||
}
|
||||
if (change_line)
|
||||
std::cerr << std::endl;
|
||||
Var::number_of_lived_vars = bk_var; Op::number_of_lived_ops = bk_op;
|
||||
|
||||
if (segfault_happen) {
|
||||
LOGe << "Segfault happen, main thread exit";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (has_error) {
|
||||
LOGf << "Error happend during compilation, see error above.";
|
||||
}
|
||||
|
||||
// fill all op entry
|
||||
for (int i=0; i<active_threads; i++) {
|
||||
auto& v = op_entrys[i];
|
||||
for (auto& t : v) {
|
||||
auto& prev_jit_key = map.holder.at(std::get<0>(t));
|
||||
int is_fused_op = std::get<1>(t);
|
||||
auto& new_jit_key = std::get<3>(t);
|
||||
if (is_fused_op)
|
||||
jit_fused_ops[new_jit_key] = jit_fused_ops[prev_jit_key] = (FusedOpContext*)std::get<2>(t);
|
||||
else
|
||||
jit_ops[new_jit_key] = jit_ops[prev_jit_key] = (jit_op_entry_t)std::get<2>(t);
|
||||
jit_key_mapper[prev_jit_key] = new_jit_key;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,14 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int64 tt);
|
||||
|
||||
} // jittor
|
|
@ -55,6 +55,7 @@ JIT_TEST(fused_op_relay_matmul) {
|
|||
// a, b, d can not fuse
|
||||
a->custom_data = b->custom_data = d->custom_data = 1;
|
||||
fop.update_ops();
|
||||
context.setup(&fop);
|
||||
if (!has_op("mkl_matmul")) return;
|
||||
auto make_matmul = get_op_info("mkl_matmul")
|
||||
.get_constructor<VarPtr, Var*, Var*, bool, bool>();
|
||||
|
@ -72,6 +73,7 @@ JIT_TEST(fused_op_relay_matmul) {
|
|||
// broadcast(a) can not fused
|
||||
fop.vars[1].var->custom_data = 1;
|
||||
fop.update_ops();
|
||||
context.setup(&fop);
|
||||
is_op_relayed = context.vrm.get_op_relay_info({1});
|
||||
vector<pair<int,int>> ans{{-1,-1},{0,0},{0,0},{0,0}};
|
||||
CHECKop(is_op_relayed,==,ans);
|
||||
|
|
|
@ -97,7 +97,10 @@ void print_prefix(std::ostream* out) {
|
|||
<< PRINT_W2(lt.tm_min) << ':'
|
||||
<< PRINT_W2(lt.tm_sec) << "."
|
||||
<< PRINT_W6(usecs) << ' '
|
||||
<< PRINT_W2(tid) << ' ';
|
||||
<< PRINT_W2(tid);
|
||||
if (thread_name.size())
|
||||
*out << ":" << thread_name;
|
||||
*out << ' ';
|
||||
}
|
||||
|
||||
MWSR_LIST(log, std::ostringstream);
|
||||
|
@ -109,8 +112,6 @@ std::vector<std::map<string,string>> logs;
|
|||
int log_capture_enabled = 0;
|
||||
|
||||
void log_capture(const string& s) {
|
||||
int bk = log_capture_enabled;
|
||||
log_capture_enabled = 0;
|
||||
// find [ and ]
|
||||
uint i=0;
|
||||
while (i+2<s.size() && !(s[i]=='[' && s[i+2]==' ')) i++;
|
||||
|
@ -139,8 +140,10 @@ void log_capture(const string& s) {
|
|||
if (s[end]=='\n') end--;
|
||||
if (s[end-2]=='\033') end-=3;
|
||||
log["msg"] = s.substr(j, end-j+1);
|
||||
logs.emplace_back(std::move(log));
|
||||
log_capture_enabled = bk;
|
||||
{
|
||||
std::lock_guard<std::mutex> lg(sync_log_capture);
|
||||
logs.emplace_back(std::move(log));
|
||||
}
|
||||
}
|
||||
|
||||
DECLARE_FLAG(int, log_silent);
|
||||
|
@ -174,14 +177,17 @@ std::vector<std::map<string,string>> log_capture_read() {
|
|||
|
||||
void log_exiting();
|
||||
|
||||
size_t protected_page=0;
|
||||
size_t thread_local protected_page = 0;
|
||||
int segfault_happen = 0;
|
||||
string thread_local thread_name;
|
||||
|
||||
void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
|
||||
if (signal == SIGINT) {
|
||||
LOGe << "Caught SIGINT, exit";
|
||||
exit(1);
|
||||
}
|
||||
std::cerr << "Caught segfault at address " << si->si_addr << ", flush log..." << std::endl;
|
||||
std::cerr << "Caught segfault at address " << si->si_addr << ", "
|
||||
<< "thread_name: '" << thread_name << "', flush log..." << std::endl;
|
||||
std::cerr.flush();
|
||||
if (protected_page &&
|
||||
si->si_addr>=(void*)protected_page &&
|
||||
|
@ -189,11 +195,14 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
|
|||
LOGf << "Accessing protect pages, maybe jit_key too long";
|
||||
}
|
||||
if (signal == SIGSEGV) {
|
||||
print_trace();
|
||||
// only print trace in main thread
|
||||
if (thread_name.size() == 0)
|
||||
print_trace();
|
||||
std::cerr << "Segfault, exit" << std::endl;
|
||||
} else {
|
||||
std::cerr << "Get signal " << signal << ", exit" << std::endl;
|
||||
}
|
||||
segfault_happen = 1;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
|
@ -281,8 +290,8 @@ bool check_vlog(const char* fileline, int verbose) {
|
|||
}
|
||||
|
||||
int system_popen(const char* cmd) {
|
||||
static char buf[BUFSIZ];
|
||||
static string cmd2;
|
||||
static thread_local char buf[BUFSIZ];
|
||||
static thread_local string cmd2;
|
||||
cmd2 = cmd;
|
||||
cmd2 += " 2>&1 ";
|
||||
FILE *ptr = popen(cmd2.c_str(), "r");
|
||||
|
|
|
@ -36,9 +36,9 @@ extern "C" uint32_t get_tid();
|
|||
extern "C" bool g_supports_color;
|
||||
extern "C" void print_prefix(std::ostream* out);
|
||||
|
||||
const char green[] = "\033[38;5;2m";
|
||||
const char red[] = "\033[38;5;1m";
|
||||
const char yellow[] = "\033[38;5;3m";
|
||||
constexpr char green[] = "\033[38;5;2m";
|
||||
constexpr char red[] = "\033[38;5;1m";
|
||||
constexpr char yellow[] = "\033[38;5;3m";
|
||||
|
||||
static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
|
||||
if (level == 'i') {
|
||||
|
@ -61,6 +61,7 @@ extern "C" void flush_log();
|
|||
extern "C" void log_capture_start();
|
||||
extern "C" void log_capture_stop();
|
||||
extern std::vector<std::map<string,string>> log_capture_read();
|
||||
extern string thread_local thread_name;
|
||||
|
||||
struct Log {
|
||||
std::ostringstream out;
|
||||
|
|
Loading…
Reference in New Issue