msvc support

This commit is contained in:
Dun Liang 2021-09-15 17:34:50 +08:00
parent 4e38190483
commit c3938e14bf
114 changed files with 1910 additions and 1142 deletions

1
.gitignore vendored
View File

@ -12,6 +12,7 @@ perf.data.old
*.pdf *.pdf
*.zip *.zip
*.tgz *.tgz
*.obj
test.py test.py
extern/mkl/mkldnn_lnx*/* extern/mkl/mkldnn_lnx*/*
data/ data/

View File

@ -25,6 +25,7 @@ def install_mkl(root_folder):
# origin url is # origin url is
# url = "https://github.com/intel/mkl-dnn/releases/download/v1.0.2/mkldnn_lnx_1.0.2_cpu_gomp.tgz" # url = "https://github.com/intel/mkl-dnn/releases/download/v1.0.2/mkldnn_lnx_1.0.2_cpu_gomp.tgz"
import platform import platform
url = None
if platform.system()=="Linux": if platform.system()=="Linux":
if platform.machine()=='x86_64': if platform.machine()=='x86_64':
filename = "dnnl_lnx_2.2.0_cpu_gomp.tgz" filename = "dnnl_lnx_2.2.0_cpu_gomp.tgz"
@ -35,23 +36,44 @@ def install_mkl(root_folder):
else: else:
raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet," raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet,"
" Please contact us on https://github.com/jittor/jittor ") " Please contact us on https://github.com/jittor/jittor ")
elif os.name == "nt":
# url = "https://github.com/oneapi-src/oneDNN/releases/download/v2.2/dnnl_win_2.2.0_cpu_iomp.zip"
# url = "https://github.com/oneapi-src/oneDNN/releases/download/v2.2/dnnl_win_2.2.0_cpu_vcomp.zip"
filename = "dnnl_win_2.2.0_cpu_vcomp.zip"
md5 = "fa12c693b2ec07700d174e1e99d60a7e"
else: else:
raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet," raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet,"
" Please contact us on https://github.com/jittor/jittor ") " Please contact us on https://github.com/jittor/jittor ")
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename if not url:
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
fullname = os.path.join(root_folder, filename) fullname = os.path.join(root_folder, filename)
dirname = os.path.join(root_folder, filename.replace(".tgz","")) dirname = os.path.join(root_folder, filename.rsplit(".",1)[0])
if not os.path.isfile(os.path.join(dirname, "lib", "libmkldnn.so")): if not (os.path.isfile(os.path.join(dirname, "lib", "libmkldnn.so")) or
os.path.isfile(os.path.join(dirname, "bin", "dnnl.dll"))):
LOG.i("Downloading mkl...") LOG.i("Downloading mkl...")
download_url_to_local(url, filename, root_folder, md5) download_url_to_local(url, filename, root_folder, md5)
import tarfile if fullname.endswith(".zip"):
import zipfile
with tarfile.open(fullname, "r") as tar: with zipfile.ZipFile(fullname, "r") as f:
tar.extractall(root_folder) f.extractall(root_folder)
else:
assert 0 == os.system(f"cd {dirname}/examples && " import tarfile
with tarfile.open(fullname, "r") as tar:
tar.extractall(root_folder)
if os.name == 'nt':
# this env is used for execute example/text
bin_path = os.path.join(dirname, "bin")
sys.path.append(bin_path)
os.add_dll_directory(bin_path)
os.environ["PATH"] = os.environ.get("PATH", "") + ";" + bin_path
cmd = f"cd /d {dirname}/examples && {cc_path} {dirname}/examples/cnn_inference_f32.cpp -I{dirname}/include -Fe: {dirname}/examples/test {cc_flags} {win_link_flags} {dirname}/lib/mkldnn.lib"
assert 0 == os.system(cmd)
assert 0 == os.system(f"{dirname}/examples/test")
else:
assert 0 == os.system(f"cd {dirname}/examples && "
f"{cc_path} -std=c++14 cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test") f"{cc_path} -std=c++14 cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test")
def setup_mkl(): def setup_mkl():
@ -74,7 +96,7 @@ def setup_mkl():
mkl_include_path = os.environ.get("mkl_include_path") mkl_include_path = os.environ.get("mkl_include_path")
mkl_lib_path = os.environ.get("mkl_lib_path") mkl_lib_path = os.environ.get("mkl_lib_path")
if platform.system() == 'Linux': if platform.system() == 'Linux' or os.name == 'nt':
if mkl_lib_path is None or mkl_include_path is None: if mkl_lib_path is None or mkl_include_path is None:
mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh") mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh")
LOG.v("setup mkl...") LOG.v("setup mkl...")
@ -95,6 +117,13 @@ def setup_mkl():
mkl_lib_path = os.path.join(mkl_home, "lib") mkl_lib_path = os.path.join(mkl_home, "lib")
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so") mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn -Wl,-rpath='{mkl_lib_path}' "
if os.name == 'nt':
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
mkl_bin_path = os.path.join(mkl_home, 'bin')
os.add_dll_directory(mkl_bin_path)
mkl_lib = os.path.join(mkl_lib_path, "dnnl.lib")
extra_flags = f" -I\"{mkl_include_path}\" \"{mkl_lib}\" "
assert os.path.isdir(mkl_include_path) assert os.path.isdir(mkl_include_path)
assert os.path.isdir(mkl_lib_path) assert os.path.isdir(mkl_lib_path)
assert os.path.isfile(mkl_lib_name) assert os.path.isfile(mkl_lib_name)
@ -103,7 +132,6 @@ def setup_mkl():
LOG.v(f"mkl_lib_name: {mkl_lib_name}") LOG.v(f"mkl_lib_name: {mkl_lib_name}")
# We do not link manualy, link in custom ops # We do not link manualy, link in custom ops
# ctypes.CDLL(mkl_lib_name, dlopen_flags) # ctypes.CDLL(mkl_lib_name, dlopen_flags)
extra_flags = f" -I'{mkl_include_path}' -L'{mkl_lib_path}' -lmkldnn -Wl,-rpath='{mkl_lib_path}' "
elif platform.system() == 'Darwin': elif platform.system() == 'Darwin':
mkl_lib_paths = [ mkl_lib_paths = [
@ -508,6 +536,7 @@ world_size = mpi.world_size() if in_mpi else 1
setup_nccl() setup_nccl()
setup_cutt() setup_cutt()
try: try:
setup_mkl() setup_mkl()
except Exception as e: except Exception as e:

View File

@ -55,18 +55,22 @@ def compile(compiler, flags, inputs, output, combind_build=False):
link = link_flags link = link_flags
base_output = os.path.basename(output).split('.')[0] base_output = os.path.basename(output).split('.')[0]
if os.name == 'nt': if os.name == 'nt':
# initialize order in windows seems reversed # windows do not combind build, need gen def
inputs = list(inputs[::-1]) combind_build = False
# windows need libxxx.a # windows need xxxx.lib
afile = os.path.join(cache_path, f"lib{base_output}.a") afile = output.rsplit('.', 1)[0] + ".lib"
link = link + f' -Wl,--export-all-symbols,--out-implib,"{afile}" ' afile = os.path.join(cache_path, afile)
if base_output == "jit_utils_core": if cc_type != 'cl':
pass # initialize order in windows seems reversed
elif base_output == "jittor_core": inputs = list(inputs[::-1])
inputs.append(os.path.join(cache_path, f"libjit_utils_core.a")) link = link + f' -Wl,--export-all-symbols,--out-implib,"{afile}" '
else: if base_output == "jit_utils_core":
inputs.append(os.path.join(cache_path, f"libjit_utils_core.a")) pass
inputs.append(os.path.join(cache_path, f"libjittor_core.a")) elif base_output == "jittor_core":
inputs.append(os.path.join(cache_path, f"jit_utils_core{lib_suffix}"))
else:
inputs.append(os.path.join(cache_path, f"jit_utils_core{lib_suffix}"))
inputs.append(os.path.join(cache_path, f"jittor_core{lib_suffix}"))
# if output is core, add core_link_flags # if output is core, add core_link_flags
if output.startswith("jittor_core"): if output.startswith("jittor_core"):
@ -77,7 +81,7 @@ def compile(compiler, flags, inputs, output, combind_build=False):
ex_obj_files = [] ex_obj_files = []
new_inputs = [] new_inputs = []
for name in inputs: for name in inputs:
if name[-1] in 'oa': if name[-1] in 'oab':
ex_obj_files.append(name) ex_obj_files.append(name)
else: else:
new_inputs.append(os.path.join(jittor_path, name)) new_inputs.append(os.path.join(jittor_path, name))
@ -87,7 +91,7 @@ def compile(compiler, flags, inputs, output, combind_build=False):
if len(inputs) == 1 or combind_build: if len(inputs) == 1 or combind_build:
cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} {link} -o {output}" cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} {link} -o {output}"
return do_compile(cmd) return do_compile(fix_cl_flags(cmd))
# split compile object file and link # split compile object file and link
# remove -l -L flags when compile object files # remove -l -L flags when compile object files
oflags = remove_flags(flags, ['-l', '-L', '-Wl,']) oflags = remove_flags(flags, ['-l', '-L', '-Wl,'])
@ -101,16 +105,20 @@ def compile(compiler, flags, inputs, output, combind_build=False):
cc = nvcc_path cc = nvcc_path
else: else:
continue continue
cmd = f"{cc} {input} {nflags} -c {lto_flags} -o {obj_file}" cmd = f"\"{cc}\" {input} {nflags} {lto_flags} -c -o {obj_file}"
if "nan_checker" in input: if "nan_checker" in input:
# nan checker needs to disable fast_math # nan checker needs to disable fast_math
cmd = cmd.replace("--use_fast_math", "") cmd = cmd.replace("--use_fast_math", "")
cmd = cmd.replace("-Ofast", "-O2") cmd = cmd.replace("-Ofast", "-O2")
cmds.append(cmd) cmds.append(fix_cl_flags(cmd))
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output) jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
obj_files += ex_obj_files obj_files += ex_obj_files
if os.name == 'nt':
dumpdef_path = os.path.join(jittor_path, "utils", "dumpdef.py")
cmd = f"\"{sys.executable}\" \"{dumpdef_path}\" {' '.join(obj_files)} -Fo: \"{output}.def\""
do_compile(fix_cl_flags(cmd))
cmd = f"\"{compiler}\" {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}" cmd = f"\"{compiler}\" {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"
return do_compile(cmd) return do_compile(fix_cl_flags(cmd))
def gen_jit_tests(): def gen_jit_tests():
all_src = glob.glob(jittor_path+"/src/**/*.cc", recursive=True) all_src = glob.glob(jittor_path+"/src/**/*.cc", recursive=True)
@ -660,7 +668,7 @@ def compile_custom_ops(
gen_name = gen_name[:80] + "___hash" + hashlib.md5(gen_name.encode()).hexdigest() gen_name = gen_name[:80] + "___hash" + hashlib.md5(gen_name.encode()).hexdigest()
includes = sorted(list(set(includes))) includes = sorted(list(set(includes)))
includes = "".join(map(lambda x: f" -I'{x}' ", includes)) includes = "".join(map(lambda x: f" -I\"{x}\" ", includes))
LOG.vvvv(f"Include flags:{includes}") LOG.vvvv(f"Include flags:{includes}")
op_extra_flags = includes + extra_flags op_extra_flags = includes + extra_flags
@ -916,7 +924,7 @@ if not nvcc_path:
nvcc_path = try_find_exe(nvcc_path) nvcc_path = try_find_exe(nvcc_path)
if nvcc_path is None: if nvcc_path is None:
nvcc_path = "" nvcc_path = ""
gdb_path = try_find_exe('gdb') gdb_path = env_or_try_find('gdb_path', 'gdb')
addr2line_path = try_find_exe('addr2line') addr2line_path = try_find_exe('addr2line')
has_pybt = check_pybt(gdb_path, python_path) has_pybt = check_pybt(gdb_path, python_path)
@ -952,26 +960,80 @@ if platform.system() == 'Darwin':
core_link_flags = "" core_link_flags = ""
opt_flags = "" opt_flags = ""
py_include = jit_utils.get_py3_include_path()
LOG.i(f"py_include: {py_include}")
extension_suffix = jit_utils.get_py3_extension_suffix()
lib_suffix = extension_suffix.replace(".pyd", ".lib")
LOG.i(f"extension_suffix: {extension_suffix}")
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
# TODO: if not using apple clang, cannot add -Xpreprocessor # TODO: if not using apple clang, cannot add -Xpreprocessor
kernel_opt_flags = kernel_opt_flags + " -Xpreprocessor -fopenmp " kernel_opt_flags = kernel_opt_flags + " -Xpreprocessor -fopenmp "
else: elif cc_type != 'cl':
kernel_opt_flags = kernel_opt_flags + " -fopenmp " kernel_opt_flags = kernel_opt_flags + " -fopenmp "
fix_cl_flags = lambda x:x
if os.name == 'nt': if os.name == 'nt':
link_flags = link_flags.replace('-ldl', '') if cc_type == 'g++':
py3_link_path = '-L"' + os.path.join( link_flags = link_flags.replace('-ldl', '')
os.path.dirname(sys.executable), py3_link_path = '-L"' + os.path.join(
"libs" os.path.dirname(sys.executable),
) + f'" -lpython3{sys.version_info.minor} ' "libs"
core_link_flags = py3_link_path ) + f'" -lpython3{sys.version_info.minor} '
link_flags += core_link_flags core_link_flags = py3_link_path
# link_flags += " -Wl,--unresolved-symbols=ignore-all " link_flags += core_link_flags
# cc_flags += " -Xlinker --allow-shlib-undefined " # link_flags += " -Wl,--unresolved-symbols=ignore-all "
cc_flags = cc_flags.replace('-std=c++14', '-std=c++17') # cc_flags += " -Xlinker --allow-shlib-undefined "
link_flags += " -fopenmp " cc_flags = cc_flags.replace('-std=c++14', '-std=c++17')
kernel_opt_flags += f" {cache_path}\\libjit_utils_core.a " link_flags += " -fopenmp "
kernel_opt_flags += f" {cache_path}\\libjittor_core.a " kernel_opt_flags += f" {cache_path}\\jit_utils_core{lib_suffix} "
kernel_opt_flags += f" {cache_path}\\jittor_core{lib_suffix} "
elif cc_type == 'cl':
py3_link_path = os.path.join(
os.path.dirname(sys.executable),
"libs",
f'python3{sys.version_info.minor}.lib'
)
# core_link_flags = py3_link_path
link_flags += core_link_flags
# link_flags += " -Wl,--unresolved-symbols=ignore-all "
# cc_flags += " -Xlinker --allow-shlib-undefined "
kernel_opt_flags += f" {cache_path}\\jit_utils_core{lib_suffix} "
kernel_opt_flags += f" {cache_path}\\jittor_core{lib_suffix} "
# cc_flags = " -std:c++17 -O2 -fp:fast -EHsc "
cc_flags = " -std:c++17 -O2 -fp:fast -EHsc "
# cc_flags += py3_link_path + " "
import jittor_utils
if jittor_utils.msvc_path:
mp = jittor_utils.msvc_path
cc_flags += f' -nologo -I"{mp}\\cl_x64\\include" -I"{mp}\\win10_kits\\include\\ucrt" -I"{mp}\\win10_kits\\include\\shared" -I"{mp}\\win10_kits\\include\\um" -DNOMINMAX '
win_link_flags = f' -link -LIBPATH:"{mp}\\cl_x64\\lib" -LIBPATH:"{mp}\\win10_kits\\lib\\um\\x64" -LIBPATH:"{mp}\\win10_kits\\lib\\ucrt\\x64" '
link_flags = ' -LD '
kernel_opt_flags += win_link_flags# + " -EXPORT:\"?jit_run@FusedOp@jittor@@QEAAXXZ\""
def fix_cl_flags(cmd):
cmd = cmd.replace(".o ", ".obj ")
cmd = cmd.replace(".o\" ", ".obj\" ")
if cmd.endswith(".o"): cmd += "bj"
from shlex import split
if " -LD " in cmd:
cmd = cmd.replace(" -o ", " -Fe: ")
output = split(cmd.split("-Fe:")[1].strip(), posix=False)[0]
base_output = os.path.basename(output).split('.')[0]
cmd += win_link_flags
cmd += f" -DEF:\"{output}.def\" -IGNORE:4102 -IGNORE:4197 -IGNORE:4217 {py3_link_path}"
if base_output == "jit_utils_core":
pass
elif base_output == "jittor_core":
cmd += " " + os.path.join(cache_path, f"jit_utils_core{lib_suffix}")
else:
cmd += " " + os.path.join(cache_path, f"jit_utils_core{lib_suffix} ")
cmd += " " + os.path.join(cache_path, f"jittor_core{lib_suffix} ")
elif " -c -o " in cmd:
cmd = cmd.replace(" -c -o ", " -c -Fo: ")
cmd = cmd.replace("-include", "-FI")
return cmd
if ' -O' not in cc_flags: if ' -O' not in cc_flags:
opt_flags += " -O2 " opt_flags += " -O2 "
@ -985,11 +1047,6 @@ if os.environ.get("enable_lto") == "1":
else: else:
lto_flags = " -flto " lto_flags = " -flto "
py_include = jit_utils.get_py3_include_path()
LOG.i(f"py_include: {py_include}")
extension_suffix = jit_utils.get_py3_extension_suffix()
LOG.i(f"extension_suffix: {extension_suffix}")
make_cache_dir(cache_path) make_cache_dir(cache_path)
make_cache_dir(os.path.join(cache_path, "jit")) make_cache_dir(os.path.join(cache_path, "jit"))
make_cache_dir(os.path.join(cache_path, "obj_files")) make_cache_dir(os.path.join(cache_path, "obj_files"))
@ -1107,7 +1164,8 @@ if use_data_gz:
dflags = (cc_flags+opt_flags)\ dflags = (cc_flags+opt_flags)\
.replace("-Wall", "") \ .replace("-Wall", "") \
.replace("-Werror", "") .replace("-Werror", "")
run_cmd(f"{cc_path} {dflags} \"-D_P(...)=\" {data_s_path} -c -o {data_o_path}") vdp = os.path.join(jittor_path, "src", "utils", "vdp")
run_cmd(fix_cl_flags(f"{cc_path} {dflags} -include {vdp} {data_s_path} -c -o {data_o_path}"))
os.remove(data_s_path) os.remove(data_s_path)
with open(data_gz_md5_path, 'w') as f: with open(data_gz_md5_path, 'w') as f:
f.write(md5) f.write(md5)

View File

@ -28,6 +28,43 @@ mpi = jt.mpi
img_open_hook = HookTimer(Image, "open") img_open_hook = HookTimer(Image, "open")
CHECK_MEMORY = int(os.environ.get("CHECK_MEMORY", "0")) CHECK_MEMORY = int(os.environ.get("CHECK_MEMORY", "0"))
if os.name == "nt":
from multiprocessing import shared_memory
class RingBuffer:
def __init__(self, size, shm=None):
for i in range(100):
if (1<<i) >= size: break
size = 1<<i
init = False
if shm is None:
init = True
shm = shared_memory.SharedMemory(create=True, size=size+1024)
rb = jt.core.RingBuffer(size, id(shm.buf), init)
self.size = size
self.shm = shm
self.rb = rb
def __reduce__(self):
return (RingBuffer, (self.size, self.shm))
def __del__(self):
del self.rb
del self.shm
def push(self, obj): self.send(obj)
def pop(self): return self.recv()
def send(self, obj): self.rb.push(obj)
def recv(self): return self.rb.pop()
def clear(self): return self.rb.clear()
def stop(self): return self.rb.stop()
def is_stop(self): return self.rb.is_stop()
def total_pop(self): return self.rb.total_pop()
def total_push(self): return self.rb.total_push()
def __repr__(self): return repr(self.rb)
def keep_numpy_array(self, keep): self.rb.keep_numpy_array(keep)
jt.RingBuffer = RingBuffer
class Worker: class Worker:
def __init__(self, target, args, buffer_size, keep_numpy_array=False): def __init__(self, target, args, buffer_size, keep_numpy_array=False):
self.buffer = jt.RingBuffer(buffer_size) self.buffer = jt.RingBuffer(buffer_size)

View File

@ -18,6 +18,6 @@
namespace jittor { namespace jittor {
extern cublasHandle_t cublas_handle; EXTERN_LIB cublasHandle_t cublas_handle;
} // jittor } // jittor

View File

@ -15,9 +15,9 @@
namespace jittor { namespace jittor {
extern cudnnHandle_t cudnn_handle; EXTERN_LIB cudnnHandle_t cudnn_handle;
extern int max_cache_size; EXTERN_LIB int max_cache_size;
extern float max_workspace_ratio; EXTERN_LIB float max_workspace_ratio;
// @pyjt(set_algorithm_cache_size) // @pyjt(set_algorithm_cache_size)
void set_algorithm_cache_size(int size); void set_algorithm_cache_size(int size);

View File

@ -87,7 +87,7 @@ VarPtr CudnnConv3dBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
#pragma clang diagnostic ignored "-Wtautological-compare" #pragma clang diagnostic ignored "-Wtautological-compare"
extern unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache; EXTERN_LIB unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType(); template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; } template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
@ -194,6 +194,7 @@ void CudnnConv3dBackwardWOp::jit_run() {
cudnnConvolutionBwdFilterAlgo_t algo; cudnnConvolutionBwdFilterAlgo_t algo;
bool benchmark=true; bool benchmark=true;
JK& jk = get_jk();
jk.clear(); jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ","; jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ","; jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";

View File

@ -77,7 +77,7 @@ VarPtr CudnnConv3dBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
#pragma clang diagnostic ignored "-Wtautological-compare" #pragma clang diagnostic ignored "-Wtautological-compare"
extern unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache; EXTERN_LIB unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType(); template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; } template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
@ -185,6 +185,7 @@ void CudnnConv3dBackwardXOp::jit_run() {
cudnnConvolutionBwdDataAlgo_t algo; cudnnConvolutionBwdDataAlgo_t algo;
bool benchmark=true; bool benchmark=true;
JK& jk = get_jk();
jk.clear(); jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ","; jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ","; jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";

View File

@ -80,7 +80,7 @@ VarPtr CudnnConv3dOp::grad(Var* out, Var* dout, Var* v, int v_index) {
#pragma clang diagnostic ignored "-Wtautological-compare" #pragma clang diagnostic ignored "-Wtautological-compare"
extern unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache; EXTERN_LIB unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType(); template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; } template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
@ -188,6 +188,7 @@ void CudnnConv3dOp::jit_run() {
cudnnConvolutionFwdAlgo_t algo; cudnnConvolutionFwdAlgo_t algo;
bool benchmark=true; bool benchmark=true;
JK& jk = get_jk();
jk.clear(); jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ","; jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ","; jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";

View File

@ -79,7 +79,7 @@ unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
#pragma clang diagnostic ignored "-Wtautological-compare" #pragma clang diagnostic ignored "-Wtautological-compare"
extern unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache; EXTERN_LIB unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType(); template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; } template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
@ -184,6 +184,7 @@ void CudnnConvBackwardWOp::jit_run() {
cudnnConvolutionBwdFilterAlgo_t algo; cudnnConvolutionBwdFilterAlgo_t algo;
bool benchmark=true; bool benchmark=true;
JK& jk = get_jk();
jk.clear(); jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ","; jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ","; jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";

View File

@ -79,7 +79,7 @@ unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
#pragma clang diagnostic ignored "-Wtautological-compare" #pragma clang diagnostic ignored "-Wtautological-compare"
extern unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache; EXTERN_LIB unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType(); template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; } template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
@ -185,6 +185,7 @@ void CudnnConvBackwardXOp::jit_run() {
cudnnConvolutionBwdDataAlgo_t algo; cudnnConvolutionBwdDataAlgo_t algo;
bool benchmark=true; bool benchmark=true;
JK& jk = get_jk();
jk.clear(); jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ","; jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ","; jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";

View File

@ -81,7 +81,7 @@ unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
#pragma clang diagnostic ignored "-Wtautological-compare" #pragma clang diagnostic ignored "-Wtautological-compare"
extern unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache; EXTERN_LIB unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType(); template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; } template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
@ -187,6 +187,7 @@ void CudnnConvOp::jit_run() {
cudnnConvolutionFwdAlgo_t algo; cudnnConvolutionFwdAlgo_t algo;
bool benchmark=true; bool benchmark=true;
JK& jk = get_jk();
jk.clear(); jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ","; jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ","; jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";

View File

@ -17,6 +17,6 @@
namespace jittor { namespace jittor {
extern curandGenerator_t gen; EXTERN_LIB curandGenerator_t gen;
} // jittor } // jittor

View File

@ -66,7 +66,7 @@ unordered_map<string, unsigned int> cutt_plan_cache;
#else // JIT #else // JIT
extern unordered_map<string, unsigned int> cutt_plan_cache; EXTERN_LIB unordered_map<string, unsigned int> cutt_plan_cache;
void CuttTransposeOp::jit_run() { void CuttTransposeOp::jit_run() {
auto* __restrict__ xp = x->mem_ptr; auto* __restrict__ xp = x->mem_ptr;
@ -93,6 +93,7 @@ void CuttTransposeOp::jit_run() {
checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDefault, 0)); checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDefault, 0));
return; return;
} }
JK& jk = get_jk();
jk.clear(); jk.clear();
jk << dim << ','; jk << dim << ',';
for (int i=0; i<dim; i++) jk << x_shape[i] << ','; for (int i=0; i<dim; i++) jk << x_shape[i] << ',';

View File

@ -102,7 +102,7 @@ const char *_cudaGetErrorEnum(NppStatus error);
#endif #endif
namespace jittor { namespace jittor {
extern bool peek_logged; EXTERN_LIB bool peek_logged;
} }
template <typename T> template <typename T>

View File

@ -17,8 +17,8 @@
namespace jittor { namespace jittor {
extern ncclComm_t comm; EXTERN_LIB ncclComm_t comm;
extern ncclUniqueId id; EXTERN_LIB ncclUniqueId id;
extern int nccl_device_id; EXTERN_LIB int nccl_device_id;
} // jittor } // jittor

View File

@ -9,6 +9,7 @@
// *************************************************************** // ***************************************************************
#pragma once #pragma once
#define OMPI_SKIP_MPICXX #define OMPI_SKIP_MPICXX
#include <common.h>
#include <mpi.h> #include <mpi.h>
extern void throw_mpi_error(int result, extern void throw_mpi_error(int result,
@ -25,13 +26,13 @@ static inline void mpi_check(int result,
namespace jittor { namespace jittor {
extern int mpi_world_size; EXTERN_LIB int mpi_world_size;
extern int mpi_world_rank; EXTERN_LIB int mpi_world_rank;
extern int mpi_local_size; EXTERN_LIB int mpi_local_size;
extern int mpi_local_rank; EXTERN_LIB int mpi_local_rank;
extern bool inside_mpi; EXTERN_LIB bool inside_mpi;
extern bool mpi_enabled; EXTERN_LIB bool mpi_enabled;
extern bool use_device_mpi; EXTERN_LIB bool use_device_mpi;
/** /**
Return number of MPI nodes. Return number of MPI nodes.

View File

@ -1,432 +1,432 @@
# *************************************************************** # ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved. # Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: # Maintainers:
# Haoyang Peng <2247838039@qq.com> # Haoyang Peng <2247838039@qq.com>
# Guowei Yang <471184555@qq.com> # Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>. # Dun Liang <randonlang@gmail.com>.
# #
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
import jittor as jt import jittor as jt
from functools import partial from functools import partial
#TODO:full_matrices=1 #TODO:full_matrices=1
def svd(x): def svd(x):
r''' r'''
calculate the Singular Value Decomposition of x.It follows the below fomula: calculate the Singular Value Decomposition of x.It follows the below fomula:
x = usv* x = usv*
only support full matrices == False ver now, which means: only support full matrices == False ver now, which means:
x's shape (...,M,K) x's shape (...,M,K)
u's shape (...,M,K) u's shape (...,M,K)
s's shape (...,K) s's shape (...,K)
v's shape (...,K,N) v's shape (...,K,N)
where K is min(M,N). where K is min(M,N).
:param x: :param x:
:return:u,s,v. :return:u,s,v.
''' '''
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
u, s, v = data["outputs"] u, s, v = data["outputs"]
#TODO:remove copyto #TODO:remove copyto
tu, ts, tv = np.linalg.svd(a, full_matrices=0) tu, ts, tv = np.linalg.svd(a, full_matrices=0)
np.copyto(u, tu) np.copyto(u, tu)
np.copyto(s, ts) np.copyto(s, ts)
np.copyto(v, tv) np.copyto(v, tv)
def backward_code(np, data): def backward_code(np, data):
def T(x): def T(x):
return np.swapaxes(x, -1, -2) return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik') _dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
inp = data["inputs"][0] inp = data["inputs"][0]
out_index = data["out_index"] out_index = data["out_index"]
u, s, v = data["f_outputs"] u, s, v = data["f_outputs"]
v = T(v) v = T(v)
m, n = inp.shape[-2:] m, n = inp.shape[-2:]
k = np.min((m, n)) k = np.min((m, n))
i = np.reshape(np.eye(k), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (k, k)))) i = np.reshape(np.eye(k), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (k, k))))
if out_index == 0: if out_index == 0:
f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i) f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i)
gu = dout gu = dout
utgu = _dot(T(u), gu) utgu = _dot(T(u), gu)
t = (f * (utgu - T(utgu))) * s[..., np.newaxis, :] t = (f * (utgu - T(utgu))) * s[..., np.newaxis, :]
t = _dot(_dot(u, t), T(v)) t = _dot(_dot(u, t), T(v))
if m > n: if m > n:
i_minus_uut = (np.reshape(np.eye(m), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (m, m)))) - i_minus_uut = (np.reshape(np.eye(m), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (m, m)))) -
_dot(u, np.conj(T(u)))) _dot(u, np.conj(T(u))))
t = t + T(_dot(_dot(v / s[..., np.newaxis, :], T(gu)), i_minus_uut)) t = t + T(_dot(_dot(v / s[..., np.newaxis, :], T(gu)), i_minus_uut))
np.copyto(out, t) np.copyto(out, t)
elif out_index == 1: elif out_index == 1:
gs = dout gs = dout
t = i * gs[..., :, np.newaxis] t = i * gs[..., :, np.newaxis]
t = _dot(_dot(u, t), T(v)) t = _dot(_dot(u, t), T(v))
np.copyto(out, t) np.copyto(out, t)
elif out_index == 2: elif out_index == 2:
f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i) f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i)
gv = dout gv = dout
vtgv = _dot(T(v), gv) vtgv = _dot(T(v), gv)
t = s[..., :, np.newaxis] * (f * (vtgv - T(vtgv))) t = s[..., :, np.newaxis] * (f * (vtgv - T(vtgv)))
t = _dot(_dot(u, t), T(v)) t = _dot(_dot(u, t), T(v))
if m < n: if m < n:
i_minus_vvt = (np.reshape(np.eye(n), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (n, n)))) - i_minus_vvt = (np.reshape(np.eye(n), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (n, n)))) -
_dot(v, np.conj(T(v)))) _dot(v, np.conj(T(v))))
t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt)) t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
np.copyto(out, t) np.copyto(out, t)
m, n = x.shape[-2:] m, n = x.shape[-2:]
k = min(m, n) k = min(m, n)
s1 = list(x.shape) s1 = list(x.shape)
s1[-1] = k s1[-1] = k
s2 = list(x.shape) s2 = list(x.shape)
s2[-2] = k s2[-2] = k
s3 = list(x.shape)[:-2] s3 = list(x.shape)[:-2]
s3.append(k) s3.append(k)
u, s, v = jt.numpy_code( u, s, v = jt.numpy_code(
[s1, s3, s2], [s1, s3, s2],
[x.dtype, x.dtype, x.dtype], [x.dtype, x.dtype, x.dtype],
[x], [x],
forward_code, forward_code,
[backward_code], [backward_code],
) )
return u, s, v return u, s, v
def eigh(x): def eigh(x):
r""" r"""
calculate the eigenvalues and eigenvectors of x. calculate the eigenvalues and eigenvectors of x.
:param x (...,M,M): :param x (...,M,M):
:return:w, v. :return:w, v.
w (...,M) : the eigenvalues. w (...,M) : the eigenvalues.
v (...,M,M) : normalized eigenvectors. v (...,M,M) : normalized eigenvectors.
""" """
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
w, v = data["outputs"] w, v = data["outputs"]
tw, tv = np.linalg.eigh(a, UPLO='L') tw, tv = np.linalg.eigh(a, UPLO='L')
np.copyto(w, tw) np.copyto(w, tw)
np.copyto(v, tv) np.copyto(v, tv)
def backward_code(np, data): def backward_code(np, data):
def T(x): def T(x):
return np.swapaxes(x, -1, -2) return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik') _dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
inp = data["inputs"][0] inp = data["inputs"][0]
out_index = data["out_index"] out_index = data["out_index"]
w, v = data["f_outputs"] w, v = data["f_outputs"]
k = int(inp.shape[-1]) k = int(inp.shape[-1])
w_repeated = np.repeat(w[..., np.newaxis], k, axis=-1) w_repeated = np.repeat(w[..., np.newaxis], k, axis=-1)
if out_index == 0: if out_index == 0:
t = _dot(v * dout[..., np.newaxis, :], T(v)) t = _dot(v * dout[..., np.newaxis, :], T(v))
np.copyto(out, t) np.copyto(out, t)
elif out_index == 1: elif out_index == 1:
if np.any(dout): if np.any(dout):
off_diag = np.ones((k, k)) - np.eye(k) off_diag = np.ones((k, k)) - np.eye(k)
F = off_diag / (T(w_repeated) - w_repeated + np.eye(k)) F = off_diag / (T(w_repeated) - w_repeated + np.eye(k))
t = _dot(_dot(v, F * _dot(T(v), dout)), T(v)) t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
np.copyto(out, t) np.copyto(out, t)
sw = x.shape[:-2] + x.shape[-1:] sw = x.shape[:-2] + x.shape[-1:]
sv = x.shape sv = x.shape
w, v = jt.numpy_code( w, v = jt.numpy_code(
[sw, sv], [sw, sv],
[x.dtype, x.dtype], [x.dtype, x.dtype],
[x], [x],
forward_code, forward_code,
[backward_code], [backward_code],
) )
return w, v return w, v
def inv(x): def inv(x):
r""" r"""
calculate the inverse of x. calculate the inverse of x.
:param x (...,M,M): :param x (...,M,M):
:return:x^-1 (...,M,M). :return:x^-1 (...,M,M).
""" """
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
m_a = data["outputs"][0] m_a = data["outputs"][0]
t_a = np.linalg.inv(a) t_a = np.linalg.inv(a)
np.copyto(m_a, t_a) np.copyto(m_a, t_a)
def backward_code(np, data): def backward_code(np, data):
def T(x): def T(x):
return np.swapaxes(x, -1, -2) return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik') _dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
lmx = data["f_outputs"] lmx = data["f_outputs"]
mx = lmx[0] mx = lmx[0]
t = -_dot(_dot(T(mx), dout), T(mx)) t = -_dot(_dot(T(mx), dout), T(mx))
np.copyto(out, t) np.copyto(out, t)
lmx = jt.numpy_code( lmx = jt.numpy_code(
[x.shape], [x.shape],
[x.dtype], [x.dtype],
[x], [x],
forward_code, forward_code,
[backward_code], [backward_code],
) )
mx = lmx[0] mx = lmx[0]
return mx return mx
def pinv(x): def pinv(x):
r""" r"""
calculate the pseudo-inverse of a x. calculate the pseudo-inverse of a x.
:param x (...,M,N) :param x (...,M,N)
:return: x's pinv (...N,M) :return: x's pinv (...N,M)
""" """
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
m_a = data["outputs"][0] m_a = data["outputs"][0]
t_a = np.linalg.pinv(a) t_a = np.linalg.pinv(a)
np.copyto(m_a, t_a) np.copyto(m_a, t_a)
def backward_code(np, data): def backward_code(np, data):
def T(x): def T(x):
return np.swapaxes(x, -1, -2) return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik') _dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
inp = data["inputs"][0] inp = data["inputs"][0]
lmx = data["f_outputs"] lmx = data["f_outputs"]
mx = lmx[0] mx = lmx[0]
t = T( t = T(
-_dot(_dot(mx, T(dout)), mx) -_dot(_dot(mx, T(dout)), mx)
+ _dot(_dot(_dot(mx, T(mx)), dout), np.eye(inp.shape[-2]) - _dot(inp, mx)) + _dot(_dot(_dot(mx, T(mx)), dout), np.eye(inp.shape[-2]) - _dot(inp, mx))
+ _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx) + _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx)
) )
np.copyto(out, t) np.copyto(out, t)
sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]] sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]]
lmx = jt.numpy_code( lmx = jt.numpy_code(
[sw], [sw],
[x.dtype], [x.dtype],
[x], [x],
forward_code, forward_code,
[backward_code], [backward_code],
) )
mx = lmx[0] mx = lmx[0]
return mx return mx
def det(x): def det(x):
r""" r"""
calculate the determinant of x. calculate the determinant of x.
:param x (...,M,M): :param x (...,M,M):
:return:|x| (...,1) :return:|x| (...,1)
""" """
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
L = data["outputs"][0] L = data["outputs"][0]
tL = np.linalg.det(a) tL = np.linalg.det(a)
np.copyto(L, tL) np.copyto(L, tL)
def backward_code(np, data): def backward_code(np, data):
def T(x): def T(x):
return np.swapaxes(x, -1, -2) return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik') _dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
f_out = data["f_outputs"][0] f_out = data["f_outputs"][0]
inp = data["inputs"][0] inp = data["inputs"][0]
n_d = np.reshape(dout, np.shape(dout) + (1, 1)) n_d = np.reshape(dout, np.shape(dout) + (1, 1))
n_o = np.reshape(f_out, np.shape(f_out) + (1, 1)) n_o = np.reshape(f_out, np.shape(f_out) + (1, 1))
s = n_d * n_o * T(np.linalg.inv(inp)) s = n_d * n_o * T(np.linalg.inv(inp))
np.copyto(out, s) np.copyto(out, s)
s = x.shape s = x.shape
x_s = s[:-2] x_s = s[:-2]
if len(s) == 2: if len(s) == 2:
x_s.append(1) x_s.append(1)
l_det = jt.numpy_code( l_det = jt.numpy_code(
[x_s], [x_s],
[x.dtype], [x.dtype],
[x], [x],
forward_code, forward_code,
[backward_code], [backward_code],
) )
det = l_det[0] det = l_det[0]
return det return det
def slogdet(x): def slogdet(x):
r""" r"""
calculate the sign and log of the determinant of x. calculate the sign and log of the determinant of x.
:param x (...,M,M): :param x (...,M,M):
:return sign, x's logdet. :return sign, x's logdet.
sign array decides the sign of determinant and their values can be -1,0,1.Only Real number now.0 means det is 0 and logdet is -inf. sign array decides the sign of determinant and their values can be -1,0,1.Only Real number now.0 means det is 0 and logdet is -inf.
logdet in shape (...,1). logdet in shape (...,1).
""" """
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
sign, m_a = data["outputs"] sign, m_a = data["outputs"]
sign_, t_a = np.linalg.slogdet(a) sign_, t_a = np.linalg.slogdet(a)
np.copyto(m_a, t_a) np.copyto(m_a, t_a)
np.copyto(sign, sign_) np.copyto(sign, sign_)
def backward_code(np, data): def backward_code(np, data):
def T(x): def T(x):
return np.swapaxes(x, -1, -2) return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik') _dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
inp = data["inputs"][0] inp = data["inputs"][0]
out_index = data["out_index"] out_index = data["out_index"]
if out_index == 0: if out_index == 0:
np.copyto(out, 0) np.copyto(out, 0)
if out_index == 1: if out_index == 1:
t = np.reshape(dout, np.shape(dout) + (1, 1)) t = np.reshape(dout, np.shape(dout) + (1, 1))
t = t * T(np.linalg.inv(inp)) t = t * T(np.linalg.inv(inp))
np.copyto(out, t) np.copyto(out, t)
s = x.shape s = x.shape
det_s = s[:-2] det_s = s[:-2]
if len(det_s) == 0: if len(det_s) == 0:
det_s.append(1) det_s.append(1)
sign, mx = jt.numpy_code( sign, mx = jt.numpy_code(
[det_s, det_s], [det_s, det_s],
[x.dtype, x.dtype], [x.dtype, x.dtype],
[x], [x],
forward_code, forward_code,
[backward_code], [backward_code],
) )
return sign, mx return sign, mx
def cholesky(x): def cholesky(x):
r""" r"""
do Cholesky decomposition of x in the form of below formula: do Cholesky decomposition of x in the form of below formula:
x = LL^T x = LL^T
x must be a Hermite and positive-definite matrix. L is a lower-triangular matrix. x must be a Hermite and positive-definite matrix. L is a lower-triangular matrix.
:param x (...,M,M): :param x (...,M,M):
:return: L (...,M,M). :return: L (...,M,M).
""" """
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
L = data["outputs"][0] L = data["outputs"][0]
tL = np.linalg.cholesky(a) tL = np.linalg.cholesky(a)
np.copyto(L, tL) np.copyto(L, tL)
def backward_code(np, data): def backward_code(np, data):
def T(x): def T(x):
return np.swapaxes(x, -1, -2) return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik') _dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
f_out = data["f_outputs"][0] f_out = data["f_outputs"][0]
solve_trans = lambda a, b: np.linalg.solve(T(a), b) solve_trans = lambda a, b: np.linalg.solve(T(a), b)
phi = lambda X: np.tril(X) / (1. + np.eye(X.shape[-1])) phi = lambda X: np.tril(X) / (1. + np.eye(X.shape[-1]))
def conjugate_solve(L, X): def conjugate_solve(L, X):
return solve_trans(L, T(solve_trans(L, T(X)))) return solve_trans(L, T(solve_trans(L, T(X))))
s = conjugate_solve(f_out, phi(np.einsum('...ki,...kj->...ij', f_out, dout))) s = conjugate_solve(f_out, phi(np.einsum('...ki,...kj->...ij', f_out, dout)))
s = (s + T(s)) / 2. s = (s + T(s)) / 2.
np.copyto(out, s) np.copyto(out, s)
lL = jt.numpy_code( lL = jt.numpy_code(
[x.shape], [x.shape],
[x.dtype], [x.dtype],
[x], [x],
forward_code, forward_code,
[backward_code], [backward_code],
) )
L = lL[0] L = lL[0]
return L return L
def solve(a,b): def solve(a,b):
r""" r"""
Solve a linear matrix equation Ax = B.This is done by calculating x = A^-1B.So A must not be singular. Solve a linear matrix equation Ax = B.This is done by calculating x = A^-1B.So A must not be singular.
:param a:(...,M,M) :param a:(...,M,M)
:param b:(...,M) :param b:(...,M)
:return:solution of Ax = b formula.x in the shape of (...M) :return:solution of Ax = b formula.x in the shape of (...M)
""" """
def forward_code(np, data): def forward_code(np, data):
a, b = data["inputs"] a, b = data["inputs"]
L = data["outputs"][0] L = data["outputs"][0]
ans = np.linalg.solve(a, b) ans = np.linalg.solve(a, b)
np.copyto(L, ans) np.copyto(L, ans)
def backward_code1(np, data): def backward_code1(np, data):
def T(x): def T(x):
return np.swapaxes(x, -1, -2) return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik') _dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
f_out = data["f_outputs"][0] f_out = data["f_outputs"][0]
inp = data["inputs"][0] inp = data["inputs"][0]
updim = lambda x: x if x.ndim == a.ndim else x[..., None] updim = lambda x: x if x.ndim == a.ndim else x[..., None]
t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out))) t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out)))
np.copyto(out, t) np.copyto(out, t)
def backward_code2(np, data): def backward_code2(np, data):
out = data["outputs"][0] out = data["outputs"][0]
np.copyto(out, 0) np.copyto(out, 0)
l_ans = jt.numpy_code( l_ans = jt.numpy_code(
[b.shape], [b.shape],
[b.dtype], [b.dtype],
[a, b], [a, b],
forward_code, forward_code,
[backward_code1, backward_code2], [backward_code1, backward_code2],
) )
ans = l_ans[0] ans = l_ans[0]
return ans return ans
def qr(x): def qr(x):
r""" r"""
do the qr factorization of x in the below formula: do the qr factorization of x in the below formula:
x = QR where Q is orthogonal matrix and R is upper-triangle matrix. x = QR where Q is orthogonal matrix and R is upper-triangle matrix.
:param x (...,M,M): :param x (...,M,M):
:return:q,r as the result of qr factorization.They are both in the shape of (...,M,M). :return:q,r as the result of qr factorization.They are both in the shape of (...,M,M).
""" """
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
q, r = data["outputs"] q, r = data["outputs"]
Q, R = np.linalg.qr(a) Q, R = np.linalg.qr(a)
np.copyto(q,Q) np.copyto(q,Q)
np.copyto(r,R) np.copyto(r,R)
def backward_code(np, data): def backward_code(np, data):
def T(x): def T(x):
return np.swapaxes(x, -1, -2) return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik') _dot = partial(np.einsum, '...ij,...jk->...ik')
_harmard = partial(np.einsum, '...ij,...ij->...ij') _harmard = partial(np.einsum, '...ij,...ij->...ij')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
q, r = data["f_outputs"] q, r = data["f_outputs"]
out_index = data["out_index"] out_index = data["out_index"]
#pl = np.tril(np.ones((inp.shape[-1],inp.shape[-1])))-diags #pl = np.tril(np.ones((inp.shape[-1],inp.shape[-1])))-diags
if out_index == 0: # Q_TERM if out_index == 0: # Q_TERM
q_t = _dot(T(q),dout) q_t = _dot(T(q),dout)
rhs_solve = q_t - T(q_t) rhs_solve = q_t - T(q_t)
rhs_solve = T(np.tril(rhs_solve,-1)) rhs_solve = T(np.tril(rhs_solve,-1))
qsolve = np.linalg.solve(r,rhs_solve) qsolve = np.linalg.solve(r,rhs_solve)
qsolve = T(qsolve) qsolve = T(qsolve)
tq = _dot(q,qsolve) tq = _dot(q,qsolve)
np.copyto(out,tq) np.copyto(out,tq)
else: #R_TERM else: #R_TERM
r_t = _dot(r ,T(dout)) r_t = _dot(r ,T(dout))
rhs_solve = r_t - T(r_t) rhs_solve = r_t - T(r_t)
rhs_solve = np.tril(rhs_solve,-1) rhs_solve = np.tril(rhs_solve,-1)
rhs_solve = T(rhs_solve) rhs_solve = T(rhs_solve)
r_solve = np.linalg.solve(r,rhs_solve) r_solve = np.linalg.solve(r,rhs_solve)
tr = _dot(q,(T(r_solve) + dout)) tr = _dot(q,(T(r_solve) + dout))
np.copyto(out,tr) np.copyto(out,tr)
q, r = jt.numpy_code( q, r = jt.numpy_code(
[x.shape,x.shape], [x.shape,x.shape],
[x.dtype,x.dtype], [x.dtype,x.dtype],
[x], [x],
forward_code, forward_code,
[backward_code], [backward_code],
) )
return q, r return q, r

View File

@ -614,7 +614,7 @@ def compile_src(src, h, basename):
(void)n; (void)n;
if (arg0 >= GET_RAW_PTR({dfs[0]["scope_name"]},self)->size()) {{ if (arg0 >= GET_RAW_PTR({dfs[0]["scope_name"]},self)->size()) {{
PyErr_SetString(PyExc_IndexError, ""); PyErr_SetString(PyExc_IndexError, "");
return 0; return (PyObject*)nullptr;
}} }}
""" """
@ -675,7 +675,7 @@ def compile_src(src, h, basename):
error_log_code = generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename ,h, class_info) error_log_code = generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename ,h, class_info)
func = f""" func = f"""
{func_cast}[]{func_head} {{ {func_cast}[]{func_head} {{
try {{ try {{_JT_SEH_START3;
{func_fill}; {func_fill};
uint64 arg_filled=0; uint64 arg_filled=0;
(void)arg_filled; (void)arg_filled;
@ -689,7 +689,7 @@ def compile_src(src, h, basename):
for did in range(len(arr_func_return)) for did in range(len(arr_func_return))
])} ])}
LOGf << "Not a valid call."; LOGf << "Not a valid call.";
}} catch (const std::exception& e) {{ _JT_SEH_END3; }} catch (const std::exception& e) {{
if (!PyErr_Occurred()) {{ if (!PyErr_Occurred()) {{
std::stringstream ss; std::stringstream ss;
ss {error_log_code}; ss {error_log_code};
@ -775,6 +775,7 @@ def compile_src(src, h, basename):
if include_name.endswith("var_slices.h"): if include_name.endswith("var_slices.h"):
src_code += '#include "var_holder.h"\n' src_code += '#include "var_holder.h"\n'
src_code += f""" src_code += f"""
#include "utils/seh.h"
#include "pyjt/py_converter.h" #include "pyjt/py_converter.h"
#include "pyjt/py_arg_printer.h" #include "pyjt/py_arg_printer.h"
#include "common.h" #include "common.h"

View File

@ -5,7 +5,6 @@
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************
#pragma once #pragma once
#include <stddef.h>
#include <memory> #include <memory>
#include <functional> #include <functional>
#include "utils/log.h" #include "utils/log.h"
@ -26,4 +25,14 @@ void expect_error(std::function<void()> func);
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#pragma GCC diagnostic ignored "-Wdiv-by-zero" #pragma GCC diagnostic ignored "-Wdiv-by-zero"
#endif #endif
#endif #endif
#ifdef _WIN32
#ifndef __restrict__
#define __restrict__ __restrict
#endif
#endif
#ifdef _MSC_VER
#define __builtin_popcount __popcnt
#endif

View File

@ -14,7 +14,7 @@ namespace jittor {
// @pyjt(number_of_hold_vars) // @pyjt(number_of_hold_vars)
inline static uint64 get_number_of_hold_vars() { inline static uint64 get_number_of_hold_vars() {
return VarHolder::hold_vars.size(); return hold_vars.size();
} }
// @pyjt(number_of_lived_vars) // @pyjt(number_of_lived_vars)

View File

@ -34,7 +34,7 @@ void EventQueue::Worker::stop() {
LOGv << "stopped event queue worker."; LOGv << "stopped event queue worker.";
} }
extern vector<void(*)()> cleanup_callback; EXTERN_LIB vector<void(*)()> cleanup_callback;
EventQueue::Worker::Worker() : thread(EventQueue::Worker::start) { EventQueue::Worker::Worker() : thread(EventQueue::Worker::start) {
cleanup_callback.push_back(&EventQueue::Worker::stop); cleanup_callback.push_back(&EventQueue::Worker::stop);

View File

@ -88,7 +88,7 @@ struct EventQueue {
} }
}; };
extern EventQueue event_queue; EXTERN_LIB EventQueue event_queue;
#endif #endif

View File

@ -28,16 +28,17 @@
#include "memory_profiler.h" #include "memory_profiler.h"
#include "misc/nan_checker.h" #include "misc/nan_checker.h"
#include "memory_profiler.h" #include "memory_profiler.h"
#include "utils/seh.h"
namespace jittor { namespace jittor {
Executor exe; Executor exe;
extern MemoryProfiler memory_profiler; EXTERN_LIB MemoryProfiler memory_profiler;
DECLARE_FLAG(int, profile_memory_enable); DECLARE_FLAG(int, profile_memory_enable);
DEFINE_FLAG(int, gopt_disable, 0, "Disable graph optimizer."); DEFINE_FLAG(int, gopt_disable, 0, "Disable graph optimizer.");
// from fetch_op.cc // from fetch_op.cc
extern list<VarPtr> fetcher_to_free; EXTERN_LIB list<VarPtr> fetcher_to_free;
// from cuda_managed_allocator // from cuda_managed_allocator
#ifdef HAS_CUDA #ifdef HAS_CUDA
DECLARE_FLAG(int, use_cuda_managed_allocator); DECLARE_FLAG(int, use_cuda_managed_allocator);
@ -414,7 +415,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
#ifdef HAS_CUDA #ifdef HAS_CUDA
int sync_times = 0; int sync_times = 0;
#endif #endif
auto& jkl = jk; auto& jkl = get_jk();
for (uint rid=0; rid<queue.size(); rid++) { for (uint rid=0; rid<queue.size(); rid++) {
int root = queue[rid]; int root = queue[rid];
Op* op = ops[root]; Op* op = ops[root];
@ -471,7 +472,9 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
} }
#endif #endif
last_is_cuda = is_cuda; last_is_cuda = is_cuda;
_JT_SEH_START2;
op->do_run_after_prepare(jkl); op->do_run_after_prepare(jkl);
_JT_SEH_END2;
#ifdef HAS_CUDA #ifdef HAS_CUDA
// migrate to gpu // migrate to gpu
if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) { if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) {

View File

@ -24,7 +24,7 @@ struct Executor {
void run_sync(vector<Var*> vars, bool device_sync); void run_sync(vector<Var*> vars, bool device_sync);
}; };
extern Executor exe; EXTERN_LIB Executor exe;
void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int ll, int rr, int64 tt); void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int ll, int rr, int64 tt);

View File

@ -32,6 +32,7 @@ loop_options_t& FusedOp::get_loop_options_tuned() {
} }
void FusedOp::update_jit_key() { void FusedOp::update_jit_key() {
JK& jk = get_jk();
jk.clear(); jk.clear();
do_jit_prepare(jk); do_jit_prepare(jk);
} }
@ -256,7 +257,8 @@ int FusedOp::has(Node* node) {
return context->node_id.count(node); return context->node_id.count(node);
} }
void FusedOp::do_run(){ void FusedOp::do_run() {
JK& jk = get_jk();
do_prepare(jk); do_prepare(jk);
do_run_after_prepare(jk); do_run_after_prepare(jk);
} }

View File

@ -24,7 +24,7 @@ struct FusedOpContext {
void setup(FusedOp* fop); void setup(FusedOp* fop);
}; };
extern string_view_map<FusedOpContext*> jit_fused_ops; EXTERN_LIB string_view_map<FusedOpContext*> jit_fused_ops;
struct FusedOp final : Op { struct FusedOp final : Op {
vector<Op*> ops; vector<Op*> ops;

View File

@ -153,8 +153,8 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
if (op->flags.get(NodeFlags::_grads)) { if (op->flags.get(NodeFlags::_grads)) {
// backward together // backward together
auto n_i = op->inputs().size(); auto n_i = op->inputs().size();
Var* douts[n_o]; STACK_ALLOC(Var*, douts, n_o);
VarPtr dins[n_i]; STACK_ALLOC(VarPtr, dins, n_i);
// dump "for (Var* out : op->outputs())" // dump "for (Var* out : op->outputs())"
for (int i=0; i<n_o; i++,j++) { for (int i=0; i<n_o; i++,j++) {
auto id = id_buffer[j].second; auto id = id_buffer[j].second;

View File

@ -13,7 +13,7 @@ namespace jittor {
DEFINE_FLAG(int, check_graph, 0, "Unify graph sanity check."); DEFINE_FLAG(int, check_graph, 0, "Unify graph sanity check.");
extern unordered_map<void*, int64> lived_nodes; EXTERN_LIB unordered_map<void*, int64> lived_nodes;
template <typename T> template <typename T>
string ss_convert(T x) { string ss_convert(T x) {
@ -25,7 +25,7 @@ string ss_convert(T x) {
void do_graph_check() { void do_graph_check() {
vector<Node*> queue; vector<Node*> queue;
unordered_map<Node*,int> visited; unordered_map<Node*,int> visited;
for (auto& vh : VarHolder::hold_vars) { for (auto& vh : hold_vars) {
if (0==visited[vh->var]++) if (0==visited[vh->var]++)
queue.push_back(vh->var); queue.push_back(vh->var);
} }
@ -85,7 +85,7 @@ void do_graph_check() {
DumpGraphs dump_all_graphs() { DumpGraphs dump_all_graphs() {
vector<Node*> queue; vector<Node*> queue;
auto t = ++Node::tflag_count; auto t = ++Node::tflag_count;
for (auto& vh : VarHolder::hold_vars) for (auto& vh : hold_vars)
if (vh->var->tflag != t) { if (vh->var->tflag != t) {
vh->var->tflag = t; vh->var->tflag = t;
queue.push_back(vh->var); queue.push_back(vh->var);

View File

@ -27,9 +27,9 @@ vector<set_seed_callback> callbacks;
int current_seed; int current_seed;
// fron fetch_op.cc // fron fetch_op.cc
extern list<VarPtr> fetcher; EXTERN_LIB list<VarPtr> fetcher;
extern list<VarPtr> fetcher_to_free; EXTERN_LIB list<VarPtr> fetcher_to_free;
extern vector<void(*)()> cleanup_callback; EXTERN_LIB vector<void(*)()> cleanup_callback;
void cleanup() { void cleanup() {
fetcher_to_free.clear(); fetcher_to_free.clear();

View File

@ -37,10 +37,13 @@ namespace jit_compiler {
std::mutex dl_open_mutex; std::mutex dl_open_mutex;
jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") { jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
std::lock_guard<std::mutex> lock(dl_open_mutex);
const char* msg = ""; const char* msg = "";
LOGvv << "Opening jit lib:" << name; LOGvv << "Opening jit lib:" << name;
#ifdef _WIN32 #ifdef _WIN32
void* handle = (void*)LoadLibrary(name.c_str()); void* handle = (void*)LoadLibraryExA(name.c_str(), nullptr,
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
LOAD_LIBRARY_SEARCH_USER_DIRS);
#elif defined(__linux__) #elif defined(__linux__)
void* handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL); void* handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL);
msg = dlerror(); msg = dlerror();
@ -76,7 +79,11 @@ static string get_symbol_name(const string& jit_key) {
op_name = Op::file_name_to_class_name(op_name); op_name = Op::file_name_to_class_name(op_name);
// _ZN7jittorXyyyyyy7jit_runEv // _ZN7jittorXyyyyyy7jit_runEv
// jittor::yyyyyy::jit_run // jittor::yyyyyy::jit_run
#ifdef _MSC_VER
op_name = "?jit_run@"+op_name+"Op@jittor@@QEAAXXZ";
#else
op_name = "_ZN6jittor"+S(op_name.size()+2)+op_name+"Op7jit_runEv"; op_name = "_ZN6jittor"+S(op_name.size()+2)+op_name+"Op7jit_runEv";
#endif
return op_name; return op_name;
} }
@ -95,13 +102,15 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
if (rewrite_op || !file_exist(jit_src_path)) if (rewrite_op || !file_exist(jit_src_path))
write(jit_src_path, src); write(jit_src_path, src);
string cmd; string cmd;
#ifndef _MSC_VER
if (is_cuda_op) { if (is_cuda_op) {
cmd = nvcc_path cmd = "\"" + nvcc_path + "\""
+ " \"" + jit_src_path + "\"" + other_src + " \"" + jit_src_path + "\"" + other_src
+ nvcc_flags + extra_flags + nvcc_flags + extra_flags
+ " -o \"" + jit_lib_path + "\""; + " -o \"" + jit_lib_path + "\"";
} else { } else {
cmd = cc_path cmd = "\"" + cc_path + "\""
+ " \"" + jit_src_path + "\"" + other_src + " \"" + jit_src_path + "\"" + other_src
+ cc_flags + extra_flags + cc_flags + extra_flags
+ " -o \"" + jit_lib_path + "\""; + " -o \"" + jit_lib_path + "\"";
@ -110,6 +119,24 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
"--cc_path=" + cmd; "--cc_path=" + cmd;
#endif #endif
} }
#else // Windows _MSC_VER
if (is_cuda_op) {
cmd = "\"" + nvcc_path + "\""
+ " \"" + jit_src_path + "\"" + other_src
+ nvcc_flags + extra_flags
+ " -o \"" + jit_lib_path + "\"";
} else {
auto symbol_name = get_symbol_name(jit_key);
auto pos = cc_flags.find("-link");
auto cc_flags1 = cc_flags.substr(0, pos);
auto cc_flags2 = cc_flags.substr(pos);
cmd = "\"" + cc_path + "\""
+ " \"" + jit_src_path + "\"" + other_src
+ cc_flags1 + extra_flags
+ " -Fe: \"" + jit_lib_path + "\" " + cc_flags2 + " -EXPORT:\""
+ symbol_name + "\"";
}
#endif
cache_compile(cmd, cache_path, jittor_path); cache_compile(cmd, cache_path, jittor_path);
auto symbol_name = get_symbol_name(jit_key); auto symbol_name = get_symbol_name(jit_key);
auto jit_entry = load_jit_lib(jit_lib_path, symbol_name); auto jit_entry = load_jit_lib(jit_lib_path, symbol_name);

View File

@ -6,17 +6,17 @@
// *************************************************************** // ***************************************************************
#ifndef _WIN32 #ifndef _WIN32
#include <sys/mman.h> #include <sys/mman.h>
#include <unistd.h>
#endif #endif
#include <sstream> #include <sstream>
#include <unistd.h>
#include "jit_key.h" #include "jit_key.h"
#include "utils/str_utils.h" #include "utils/str_utils.h"
namespace jittor { namespace jittor {
extern thread_local size_t protected_page;
#ifndef _WIN32 #ifndef _WIN32
EXTERN_LIB thread_local size_t protected_page;
static size_t get_buffer_end_page(size_t buffer_end) { static size_t get_buffer_end_page(size_t buffer_end) {
// get the last complete page in buffer // get the last complete page in buffer
// 4k align : // 4k align :
@ -121,4 +121,8 @@ vector<pair<string,string>> parse_jit_keys(const string& s) {
thread_local JitKey jk; thread_local JitKey jk;
JK& get_jk() {
return jk;
}
} // jittor } // jittor

View File

@ -78,8 +78,8 @@ struct __jk_int256 {
int64 a,b,c,d; int64 a,b,c,d;
}; };
extern thread_local JitKey jk;
typedef JitKey JK; typedef JitKey JK;
EXTERN_LIB JK& get_jk();
inline JK& operator<<(JK& jk, const char* s) { inline JK& operator<<(JK& jk, const char* s) {
int i; int i;
@ -284,7 +284,11 @@ getChr(s,35)
#define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0) #define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0)
#ifdef _MSC_VER
#define _CS(str) str
#else
#define _CS(str) _CS_G<_CS_T(str)>() #define _CS(str) _CS_G<_CS_T(str)>()
#endif
template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G { template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G {
}; };

View File

@ -8,10 +8,15 @@
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************
#include <stdio.h> #include <stdio.h>
#include <unistd.h>
#ifdef _WIN32 #ifdef _WIN32
#include <windows.h>
#include <fileapi.h> #include <fileapi.h>
#include <process.h>
#include <io.h>
#define getpid _getpid
#define open _open
#else #else
#include <unistd.h>
#endif #endif
#include <fcntl.h> #include <fcntl.h>
#include <errno.h> #include <errno.h>

View File

@ -19,7 +19,7 @@ void lock();
void unlock(); void unlock();
extern int _has_lock; EXTERN_LIB int _has_lock;
struct lock_guard { struct lock_guard {
int has_lock = 0; int has_lock = 0;

View File

@ -27,7 +27,7 @@ struct Allocator {
}; };
struct AlignedAllocator; struct AlignedAllocator;
extern AlignedAllocator aligned_allocator; EXTERN_LIB AlignedAllocator aligned_allocator;
struct Allocation { struct Allocation {
void* ptr; void* ptr;
@ -48,7 +48,7 @@ struct Allocation {
{ if (ptr) allocator->free(ptr, size, allocation); } { if (ptr) allocator->free(ptr, size, allocation); }
}; };
extern Allocator* cpu_allocator; EXTERN_LIB Allocator* cpu_allocator;
Allocator* get_allocator(bool temp_allocator=false); Allocator* get_allocator(bool temp_allocator=false);
// @pyjt(gc) // @pyjt(gc)
void gc_all(); void gc_all();

View File

@ -25,7 +25,11 @@ void* AlignedAllocator::alloc(size_t size, size_t& allocation) {
} }
void AlignedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { void AlignedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
#ifdef _WIN32
_aligned_free(mem_ptr);
#else
::free(mem_ptr); ::free(mem_ptr);
#endif
} }
} // jittor } // jittor

View File

@ -16,6 +16,6 @@ struct AlignedAllocator : Allocator {
void free(void* mem_ptr, size_t size, const size_t& allocation) override; void free(void* mem_ptr, size_t size, const size_t& allocation) override;
}; };
extern AlignedAllocator aligned_allocator; EXTERN_LIB AlignedAllocator aligned_allocator;
} // jittor } // jittor

View File

@ -12,7 +12,7 @@
namespace jittor { namespace jittor {
CudaDeviceAllocator cuda_device_allocator; CudaDeviceAllocator cuda_device_allocator;
extern bool no_cuda_error_when_free; EXTERN_LIB bool no_cuda_error_when_free;
const char* CudaDeviceAllocator::name() const {return "cuda_device";} const char* CudaDeviceAllocator::name() const {return "cuda_device";}

View File

@ -17,7 +17,7 @@ struct CudaDeviceAllocator : Allocator {
void free(void* mem_ptr, size_t size, const size_t& allocation) override; void free(void* mem_ptr, size_t size, const size_t& allocation) override;
}; };
extern CudaDeviceAllocator cuda_device_allocator; EXTERN_LIB CudaDeviceAllocator cuda_device_allocator;
} }

View File

@ -24,9 +24,9 @@ struct DualAllocation {
size_t host_allocation, device_allocation; size_t host_allocation, device_allocation;
}; };
extern SFRLAllocator cuda_dual_host_allocator; EXTERN_LIB SFRLAllocator cuda_dual_host_allocator;
extern SFRLAllocator cuda_dual_device_allocator; EXTERN_LIB SFRLAllocator cuda_dual_device_allocator;
extern bool no_cuda_error_when_free; EXTERN_LIB bool no_cuda_error_when_free;
struct CudaDualAllocator : Allocator { struct CudaDualAllocator : Allocator {
//for recycle block_id //for recycle block_id
@ -74,11 +74,11 @@ struct CudaDualAllocator : Allocator {
} }
}; };
extern CudaDualAllocator cuda_dual_allocator; EXTERN_LIB CudaDualAllocator cuda_dual_allocator;
namespace cuda_dual_local { namespace cuda_dual_local {
extern list<Allocation> allocations; EXTERN_LIB list<Allocation> allocations;
} }
@ -115,7 +115,7 @@ struct DelayFree final : Allocator {
} }
}; };
extern DelayFree delay_free; EXTERN_LIB DelayFree delay_free;
} }

View File

@ -12,7 +12,7 @@
namespace jittor { namespace jittor {
CudaHostAllocator cuda_host_allocator; CudaHostAllocator cuda_host_allocator;
extern bool no_cuda_error_when_free; EXTERN_LIB bool no_cuda_error_when_free;
const char* CudaHostAllocator::name() const {return "cuda_host";} const char* CudaHostAllocator::name() const {return "cuda_host";}

View File

@ -17,7 +17,7 @@ struct CudaHostAllocator : Allocator {
void free(void* mem_ptr, size_t size, const size_t& allocation) override; void free(void* mem_ptr, size_t size, const size_t& allocation) override;
}; };
extern CudaHostAllocator cuda_host_allocator; EXTERN_LIB CudaHostAllocator cuda_host_allocator;
} }

View File

@ -13,7 +13,7 @@ namespace jittor {
CudaManagedAllocator cuda_managed_allocator; CudaManagedAllocator cuda_managed_allocator;
DEFINE_FLAG(int, use_cuda_managed_allocator, 1, "Enable cuda_managed_allocator"); DEFINE_FLAG(int, use_cuda_managed_allocator, 1, "Enable cuda_managed_allocator");
extern bool no_cuda_error_when_free; EXTERN_LIB bool no_cuda_error_when_free;
const char* CudaManagedAllocator::name() const {return "cuda_managed";} const char* CudaManagedAllocator::name() const {return "cuda_managed";}

View File

@ -17,7 +17,7 @@ struct CudaManagedAllocator : Allocator {
void free(void* mem_ptr, size_t size, const size_t& allocation) override; void free(void* mem_ptr, size_t size, const size_t& allocation) override;
}; };
extern CudaManagedAllocator cuda_managed_allocator; EXTERN_LIB CudaManagedAllocator cuda_managed_allocator;
DECLARE_FLAG(int, use_cuda_managed_allocator); DECLARE_FLAG(int, use_cuda_managed_allocator);
} }

View File

@ -16,7 +16,9 @@
#elif defined(_WIN32) #elif defined(_WIN32)
#include <windows.h> #include <windows.h>
#endif #endif
#ifndef _WIN32
#include <unistd.h> #include <unistd.h>
#endif
#include "var.h" #include "var.h"
#include "op.h" #include "op.h"
@ -62,7 +64,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
FloatOutput{(double)mem_info.total_cpu_ram, " KMG", 1024, "B"}; FloatOutput{(double)mem_info.total_cpu_ram, " KMG", 1024, "B"};
log << "total_cuda_ram:" << log << "total_cuda_ram:" <<
FloatOutput{(double)mem_info.total_cuda_ram, " KMG", 1024, "B"} >> "\n"; FloatOutput{(double)mem_info.total_cuda_ram, " KMG", 1024, "B"} >> "\n";
log << "hold_vars:" << VarHolder::hold_vars.size() log << "hold_vars:" << hold_vars.size()
<< "lived_vars:" << Var::number_of_lived_vars << "lived_vars:" << Var::number_of_lived_vars
<< "lived_ops:" << Op::number_of_lived_ops >> '\n'; << "lived_ops:" << Op::number_of_lived_ops >> '\n';
log << "update queue:" << update_queue.queue.size() log << "update queue:" << update_queue.queue.size()
@ -72,7 +74,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
// get the oldest var // get the oldest var
// vector<Node*> queue; // vector<Node*> queue;
// auto t = ++Node::tflag_count; // auto t = ++Node::tflag_count;
// for (auto& vh : VarHolder::hold_vars) // for (auto& vh : hold_vars)
// if (vh->var->tflag != t) { // if (vh->var->tflag != t) {
// vh->var->tflag = t; // vh->var->tflag = t;
// queue.push_back(vh->var); // queue.push_back(vh->var);
@ -148,7 +150,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
if (dump_var) { if (dump_var) {
vector<Node*> queue; vector<Node*> queue;
unordered_set<Node*> visited; unordered_set<Node*> visited;
for (auto& vh : VarHolder::hold_vars) for (auto& vh : hold_vars)
if (!visited.count(vh->var)) { if (!visited.count(vh->var)) {
queue.push_back(vh->var); queue.push_back(vh->var);
visited.insert(vh->var); visited.insert(vh->var);
@ -186,7 +188,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
log.end(); log.end();
} }
extern vector<void(*)()> sigquit_callback; EXTERN_LIB vector<void(*)()> sigquit_callback;
void meminfo_callback() { void meminfo_callback() {
display_memory_info(); display_memory_info();

View File

@ -24,7 +24,7 @@ struct MemInfo {
MemInfo(); MemInfo();
}; };
extern MemInfo mem_info; EXTERN_LIB MemInfo mem_info;
// @pyjt(get_mem_info) // @pyjt(get_mem_info)
inline MemInfo get_mem_info() { return mem_info; } inline MemInfo get_mem_info() { return mem_info; }

View File

@ -79,7 +79,7 @@ void MemoryProfiler::check() {
vector<Node*> queue; vector<Node*> queue;
auto t = ++Node::tflag_count; auto t = ++Node::tflag_count;
for (auto& vh : VarHolder::hold_vars) for (auto& vh : hold_vars)
if (vh->var->tflag != t) { if (vh->var->tflag != t) {
vh->var->tflag = t; vh->var->tflag = t;
queue.push_back(vh->var); queue.push_back(vh->var);

View File

@ -39,7 +39,7 @@ struct MemoryProfiler {
string get_max_memory_info(); string get_max_memory_info();
}; };
extern MemoryProfiler memory_profiler; EXTERN_LIB MemoryProfiler memory_profiler;
DECLARE_FLAG(int, profile_memory_enable); DECLARE_FLAG(int, profile_memory_enable);

View File

@ -10,7 +10,7 @@
namespace jittor { namespace jittor {
extern std::atomic_flag lock; EXTERN_LIB std::atomic_flag lock;
struct spin_lock_guard { struct spin_lock_guard {
inline spin_lock_guard() { inline spin_lock_guard() {

View File

@ -15,7 +15,7 @@ namespace jittor {
DEFINE_FLAG_WITH_SETTER(int, use_cuda, 0, DEFINE_FLAG_WITH_SETTER(int, use_cuda, 0,
"Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda."); "Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda.");
extern void sync_all(bool device_sync); EXTERN_LIB void sync_all(bool device_sync);
void setter_use_cuda(int value) { void setter_use_cuda(int value) {
#ifdef HAS_CUDA #ifdef HAS_CUDA

View File

@ -18,8 +18,8 @@ namespace jittor {
#ifdef HAS_CUDA #ifdef HAS_CUDA
extern void check_nan_float32(float32* ptr, int64 num); EXTERN_LIB void check_nan_float32(float32* ptr, int64 num);
extern void check_nan_float64(float64* ptr, int64 num); EXTERN_LIB void check_nan_float64(float64* ptr, int64 num);
#endif #endif
bool check_nan(Var* v) { bool check_nan(Var* v) {

View File

@ -22,7 +22,16 @@ namespace jittor {
m(float32) \ m(float32) \
m(float64) m(float64)
#ifdef _MSC_VER
inline int ffs(int i) {
int j=0;
while (i) j++,i/=2;
return j;
}
#define map_size(T) {#T, ffs(sizeof(T))-1},
#else
#define map_size(T) {#T, __builtin_ffs(sizeof(T))-1}, #define map_size(T) {#T, __builtin_ffs(sizeof(T))-1},
#endif
unordered_map<string, size_t> dsize_map = {FOR_ALL_TYPES(map_size)}; unordered_map<string, size_t> dsize_map = {FOR_ALL_TYPES(map_size)};
@ -120,9 +129,9 @@ static unordered_set<string> binary_ops = {
#define DEFINE_NS(T) NanoString ns_##T; #define DEFINE_NS(T) NanoString ns_##T;
FOR_ALL_NS(DEFINE_NS); FOR_ALL_NS(DEFINE_NS);
unordered_map<string, NanoString> NanoString::__string_to_ns; unordered_map<string, NanoString> __string_to_ns;
char NanoString::__ns_to_string[ns_max_size*ns_max_len]; char __ns_to_string[ns_max_size*ns_max_len];
int NanoString::__ns_len[ns_max_size]; int __ns_len[ns_max_size];
static void init_ns() { static void init_ns() {
NanoString::ns_t i=0; NanoString::ns_t i=0;
@ -146,27 +155,27 @@ static void init_ns() {
ns.set(NanoString::_type, NanoString::_binary, NanoString::_type_nbits); ns.set(NanoString::_type, NanoString::_binary, NanoString::_type_nbits);
ns.set(NanoString::_bool, is_bool.count(name)); ns.set(NanoString::_bool, is_bool.count(name));
} }
NanoString::__string_to_ns[name] = ns; __string_to_ns[name] = ns;
auto name2 = ns.to_cstring(); auto name2 = ns.to_cstring();
int len=0; int len=0;
for (;;len++) { for (;;len++) {
name2[len] = name[len]; name2[len] = name[len];
if (!name[len]) break; if (!name[len]) break;
} }
NanoString::__ns_len[i-1] = len; __ns_len[i-1] = len;
}; };
#define INIT_NS(T) func(#T, ns_##T); #define INIT_NS(T) func(#T, ns_##T);
FOR_ALL_NS(INIT_NS); FOR_ALL_NS(INIT_NS);
ASSERT(i<=(1<<NanoString::_index_nbits)); ASSERT(i<=(1<<NanoString::_index_nbits));
NanoString::__string_to_ns["sum"] = ns_add; __string_to_ns["sum"] = ns_add;
NanoString::__string_to_ns["min"] = ns_minimum; __string_to_ns["min"] = ns_minimum;
NanoString::__string_to_ns["max"] = ns_maximum; __string_to_ns["max"] = ns_maximum;
NanoString::__string_to_ns["float"] = ns_float32; __string_to_ns["float"] = ns_float32;
NanoString::__string_to_ns["double"] = ns_float64; __string_to_ns["double"] = ns_float64;
NanoString::__string_to_ns["int"] = ns_int32; __string_to_ns["int"] = ns_int32;
NanoString::__string_to_ns["uint"] = ns_uint32; __string_to_ns["uint"] = ns_uint32;
LOGvv << "init __string_to_ns" << NanoString::__string_to_ns; LOGvv << "init __string_to_ns" << __string_to_ns;
LOGvv << "init __ns_to_string" << NanoString::__ns_to_string; LOGvv << "init __ns_to_string" << __ns_to_string;
} }
int __init_ns = (init_ns(), 0); int __init_ns = (init_ns(), 0);

View File

@ -86,9 +86,14 @@ constexpr int ns_max_len = 16;
m(normal) \ m(normal) \
struct NanoString; struct NanoString;
#define DECLEAR_NS(T) extern NanoString ns_##T; #define DECLEAR_NS(T) EXTERN_LIB NanoString ns_##T;
FOR_ALL_NS(DECLEAR_NS); FOR_ALL_NS(DECLEAR_NS);
EXTERN_LIB unordered_map<string, NanoString> __string_to_ns;
EXTERN_LIB char __ns_to_string[];
EXTERN_LIB int __ns_len[];
// @pyjt(NanoString) // @pyjt(NanoString)
struct NanoString { struct NanoString {
typedef uint16 ns_t; typedef uint16 ns_t;
@ -113,10 +118,6 @@ struct NanoString {
}; };
ns_t data=0; ns_t data=0;
static unordered_map<string, NanoString> __string_to_ns;
static char __ns_to_string[];
static int __ns_len[];
inline void set(Flags f, ns_t a=1, ns_t nbits=1) { inline void set(Flags f, ns_t a=1, ns_t nbits=1) {
ns_t mask = (((1u<<nbits)-1)<<f); ns_t mask = (((1u<<nbits)-1)<<f);
data = (data & ~mask) | ((a<<f)&mask); data = (data & ~mask) | ((a<<f)&mask);

View File

@ -16,9 +16,13 @@ static inline int lzcnt(int64 v) {
#else #else
return v ? __builtin_clzll(v) : 64; return v ? __builtin_clzll(v) : 64;
#endif #endif
#else
#ifdef _MSC_VER
return __lzcnt64(v);
#else #else
return __builtin_clzll(v); return __builtin_clzll(v);
#endif #endif
#endif
} }
struct Slice { struct Slice {

View File

@ -35,7 +35,7 @@ RingBuffer::~RingBuffer() {
} }
RingBuffer* RingBuffer::make_ring_buffer(uint64 size, bool multiprocess) { RingBuffer* RingBuffer::make_ring_buffer(uint64 size, bool multiprocess, uint64 buffer, bool init) {
int i=0; int i=0;
for (;(1ll<<i)<size;i++); for (;(1ll<<i)<size;i++);
uint64 size_mask = (1ll<<i)-1; uint64 size_mask = (1ll<<i)-1;
@ -47,26 +47,30 @@ RingBuffer* RingBuffer::make_ring_buffer(uint64 size, bool multiprocess) {
mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0) mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)
#else #else
// TODO: multiprocess ring buffer in windows // TODO: multiprocess ring buffer in windows
(void*)malloc(total_size) (void*)buffer
#endif #endif
: :
// mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED, -1, 0) : // mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED, -1, 0) :
(void*)malloc(total_size); (void*)malloc(total_size);
std::memset(ptr, 0, total_size);
auto rb = (RingBuffer*)ptr; auto rb = (RingBuffer*)ptr;
if (!init) return rb;
std::memset(ptr, 0, total_size);
new (rb) RingBuffer(size, multiprocess); new (rb) RingBuffer(size, multiprocess);
return rb; return rb;
} }
void RingBuffer::free_ring_buffer(RingBuffer* rb) { void RingBuffer::free_ring_buffer(RingBuffer* rb, uint64 buffer, bool init) {
uint64 total_size = sizeof(RingBuffer) + rb->size; uint64 total_size = sizeof(RingBuffer) + rb->size;
auto is_multiprocess = rb->is_multiprocess; auto is_multiprocess = rb->is_multiprocess;
rb->~RingBuffer(); if (init)
rb->~RingBuffer();
if (is_multiprocess) { if (is_multiprocess) {
#ifndef _WIN32 #ifndef _WIN32
munmap(rb, total_size); munmap(rb, total_size);
#else #else
free((void*)rb); if (!buffer)
free((void*)rb);
// this buffer is not owned by this obj
#endif #endif
(void)total_size; (void)total_size;
} else { } else {

View File

@ -5,7 +5,11 @@
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************
#pragma once #pragma once
#ifdef _MSC_VER
#include <windows.h>
#else
#include <pthread.h> #include <pthread.h>
#endif
#include <cstring> #include <cstring>
#include "common.h" #include "common.h"
@ -13,6 +17,37 @@ namespace jittor {
struct RingBuffer { struct RingBuffer {
#ifdef _MSC_VER
struct Mutex {
HANDLE handle;
inline Mutex(bool multiprocess=0) {
}
inline void lock() {
}
inline void unlock() {
}
inline ~Mutex() {
}
};
struct MutexScope {
Mutex* m;
inline MutexScope(Mutex& m) : m(&m) { m.lock(); }
inline ~MutexScope() { m->unlock(); }
};
struct Cond {
inline Cond(bool multiprocess=0) {
}
inline void wait(MutexScope& m) {
}
inline void notify() {
}
};
#else
struct Mutex { struct Mutex {
pthread_mutex_t m; pthread_mutex_t m;
inline Mutex(bool multiprocess=0) { inline Mutex(bool multiprocess=0) {
@ -35,6 +70,11 @@ struct RingBuffer {
pthread_mutex_unlock(&m); pthread_mutex_unlock(&m);
} }
}; };
struct MutexScope {
Mutex* m;
inline MutexScope(Mutex& m) : m(&m) { m.lock(); }
inline ~MutexScope() { m->unlock(); }
};
struct Cond { struct Cond {
pthread_cond_t cv; pthread_cond_t cv;
@ -56,20 +96,15 @@ struct RingBuffer {
pthread_cond_destroy(&cv); pthread_cond_destroy(&cv);
} }
inline void wait(Mutex& m) { inline void wait(MutexScope& m) {
pthread_cond_wait(&cv, &m.m); pthread_cond_wait(&cv, &m.m->m);
} }
inline void notify() { inline void notify() {
pthread_cond_signal(&cv); pthread_cond_signal(&cv);
} }
}; };
#endif
struct MutexScope {
Mutex* m;
inline MutexScope(Mutex& m) : m(&m) { m.lock(); }
inline ~MutexScope() { m->unlock(); }
};
uint64 size; uint64 size;
uint64 size_mask; uint64 size_mask;
@ -86,8 +121,8 @@ struct RingBuffer {
RingBuffer(uint64 size, bool multiprocess=false); RingBuffer(uint64 size, bool multiprocess=false);
~RingBuffer(); ~RingBuffer();
void stop(); void stop();
static RingBuffer* make_ring_buffer(uint64 size, bool multiprocess); static RingBuffer* make_ring_buffer(uint64 size, bool multiprocess, uint64 buffer=0, bool init=true);
static void free_ring_buffer(RingBuffer* rb); static void free_ring_buffer(RingBuffer* rb, uint64 buffer=0, bool init=true);
inline void clear() { l = r = is_stop = 0; } inline void clear() { l = r = is_stop = 0; }
@ -102,7 +137,7 @@ struct RingBuffer {
is_wait = 0; is_wait = 0;
} }
is_wait = 1; is_wait = 1;
cv.wait(m); cv.wait(_);
} }
} }

View File

@ -20,6 +20,8 @@ namespace jittor {
using std::string_view; using std::string_view;
#elif defined(__GNUC__) #elif defined(__GNUC__)
using std::experimental::string_view; using std::experimental::string_view;
#else
using std::string_view;
#endif #endif
template<class T> template<class T>

View File

@ -12,10 +12,10 @@
namespace jittor { namespace jittor {
extern unordered_map<void*, int64> lived_nodes; EXTERN_LIB unordered_map<void*, int64> lived_nodes;
extern int64 total_node; EXTERN_LIB int64 total_node;
extern int64 nt; EXTERN_LIB int64 nt;
extern vector<Node*> free_buffer; EXTERN_LIB vector<Node*> free_buffer;
struct NodeFlags { struct NodeFlags {
typedef uint16 nf_t; typedef uint16 nf_t;

View File

@ -97,12 +97,13 @@ string Op::get_jit_key(JK& jk) {
} }
vector<pair<string,string>> Op::get_jit_define() { vector<pair<string,string>> Op::get_jit_define() {
return parse_jit_keys(get_jit_key(jk)); return parse_jit_keys(get_jit_key(get_jk()));
} }
string Op::get_hash_name() { string Op::get_hash_name() {
string hash_name; string hash_name;
std::stringstream ss; std::stringstream ss;
JK& jk = get_jk();
do_prepare(jk); do_prepare(jk);
ss << std::hex << std::hash<string>()(jk.to_string()); ss << std::hex << std::hash<string>()(jk.to_string());
hash_name = ss.str(); hash_name = ss.str();
@ -186,12 +187,13 @@ void Op::do_prepare(JK& jk){
void Op::do_run_after_prepare(JK& jk) { void Op::do_run_after_prepare(JK& jk) {
if (!jk.empty()) if (!jk.empty())
jit_run(); jit_run(jk);
else else
run(); run();
} }
void Op::do_run() { void Op::do_run() {
JK& jk = get_jk();
do_prepare(jk); do_prepare(jk);
do_run_after_prepare(jk); do_run_after_prepare(jk);
} }
@ -209,10 +211,7 @@ string Op::get_filename_from_jit_key(const string& jit_key, const string& suffix
} }
s = ss.str(); s = ss.str();
for (char& c : s) { for (char& c : s) {
if (c=='[' || c==']' || c=='<' || c=='>' if (!((c>='a' && c<='z') || (c>='A' && c<='Z') || (c>='0' && c<='9')))
|| c=='{' || c=='}' || c=='(' || c==')' || c==','
|| c=='\n' || c=='\t' || c==' ' || c=='&' || c=='|'
|| c=='/' || c==':')
c = '_'; c = '_';
} }
#ifndef _WIN32 #ifndef _WIN32
@ -248,7 +247,7 @@ string Op::file_name_to_class_name(const string& s) {
return res; return res;
} }
void Op::jit_run() { void Op::jit_run(JK& jk) {
const char* jit_key = jk.to_cstring(); const char* jit_key = jk.to_cstring();
auto iter = jit_ops.find(jit_key); auto iter = jit_ops.find(jit_key);
if (iter != jit_ops.end()) { if (iter != jit_ops.end()) {

View File

@ -50,7 +50,7 @@ struct Op : Node {
virtual VarPtr duplicate(); virtual VarPtr duplicate();
virtual void compile_optimize(string& src); virtual void compile_optimize(string& src);
virtual void graph_optimize(); virtual void graph_optimize();
void jit_run(); void jit_run(JK& jk);
string name_ex() const; string name_ex() const;
string get_jit_key(JK& jk); string get_jit_key(JK& jk);
@ -60,9 +60,9 @@ struct Op : Node {
std::ostream& operator<<(std::ostream& os, const Op* var); std::ostream& operator<<(std::ostream& os, const Op* var);
extern string_view_map<jit_op_entry_t> jit_ops; EXTERN_LIB string_view_map<jit_op_entry_t> jit_ops;
// jit_key_mapper: map origin jit_key -> tuned jit_key // jit_key_mapper: map origin jit_key -> tuned jit_key
extern string_view_map<string> jit_key_mapper; EXTERN_LIB string_view_map<string> jit_key_mapper;
#ifdef JIT #ifdef JIT
#define DECLARE_jit_run void jit_run(); #define DECLARE_jit_run void jit_run();

View File

@ -1042,7 +1042,7 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
src = &src_after_passes; src = &src_after_passes;
} }
op->compile_optimize(*src); op->compile_optimize(*src);
auto ret = oc.compile(op->get_jit_key(jk), *src); auto ret = oc.compile(op->get_jit_key(get_jk()), *src);
return ret; return ret;
} }

View File

@ -129,9 +129,13 @@ void BroadcastToOp::infer_shape() {
auto xdim = x->shape.size(); auto xdim = x->shape.size();
auto ydim = yshapes.size(); auto ydim = yshapes.size();
auto count = __builtin_popcount(bcast_mask&~keepdims_mask); auto count = __builtin_popcount(bcast_mask&~keepdims_mask);
auto zdim = std::max(xdim, ydim-count) + count; auto zdim = std::max(uint64(xdim), uint64(ydim-count)) + count;
#ifdef _WIN32
int64 zz[10];
#else
int64 zz[zdim]; int64 zz[zdim];
#endif
for (int i=zdim-1, xi = xdim-1, yi = ydim-1; i>=0; i--) { for (int i=zdim-1, xi = xdim-1, yi = ydim-1; i>=0; i--) {
bool bx = xi>=0; bool bx = xi>=0;
bool by = yi>=0; bool by = yi>=0;

View File

@ -280,7 +280,7 @@ void GetitemOp::_compile_optimize(string& src) {
new_func->push_back(func->children.back()->move_out()); new_func->push_back(func->children.back()->move_out());
auto& loop = new_func->children.back(); auto& loop = new_func->children.back();
int no = o_shape.size(); int no = o_shape.size();
KernelIR* loops[no]; STACK_ALLOC(KernelIR*, loops, no);
if (!no) { if (!no) {
func->push_back("func<<<1,1>>>("+arg_call+");"); func->push_back("func<<<1,1>>>("+arg_call+");");
} else { } else {

View File

@ -38,6 +38,6 @@ VarPtr make_number(float number, Var* x) {
static void init() { static void init() {
op_registe({"number", "", "", {{&typeid(&make_number), (void*)&make_number}}}); op_registe({"number", "", "", {{&typeid(&make_number), (void*)&make_number}}});
} }
__attribute__((unused)) static int caller = (init(), 0); static int caller = (init(), 0);
} // jittor } // jittor

View File

@ -213,17 +213,17 @@ static void getitem_inplace(GetitemOp* op) {
void SetitemOp::graph_optimize() { void SetitemOp::graph_optimize() {
// LOGir << "hello graph_optimize"; // LOGir << "hello graph_optimize";
setitem_inplace(this); setitem_inplace(this);
(void)setitem_inplace; (void*)setitem_inplace;
} }
void GetitemOp::graph_optimize() { void GetitemOp::graph_optimize() {
// This optimize is still WIP // This optimize is still WIP
// LOGir << "hello getitem graph_optimize"; // LOGir << "hello getitem graph_optimize";
// setitem_grad_opt(this); // setitem_grad_opt(this);
(void)setitem_grad_opt; (void*)setitem_grad_opt;
// (void)getitem_inplace; // (void)getitem_inplace;
getitem_inplace(this); getitem_inplace(this);
(void)getitem_inplace; (void*)getitem_inplace;
} }
} }

View File

@ -23,6 +23,7 @@ Searcher::Searcher(OpCompiler* oc) : oc(oc) {
} }
int64_t Searcher::get_time_of_current_choices() { int64_t Searcher::get_time_of_current_choices() {
JK& jk = get_jk();
auto* op = oc->op; auto* op = oc->op;
// generate jit_key // generate jit_key
op->update_jit_key(); op->update_jit_key();

View File

@ -90,7 +90,7 @@ void CheckCachePass::run() {
ir->push_back("#include \"profiler/memory_checker.h\"", &ir->before); ir->push_back("#include \"profiler/memory_checker.h\"", &ir->before);
ir->push_back("using namespace jittor;", &ir->before); ir->push_back("using namespace jittor;", &ir->before);
// declaration // declaration
ir->push_back("extern \"C\" std::unique_ptr<MemoryChecker> memory_checker;", &ir->before); ir->push_back("EXTERN_LIB \"C\" std::unique_ptr<MemoryChecker> memory_checker;", &ir->before);
// definition // definition
ir->push_back("std::unique_ptr<MemoryChecker> memory_checker;", &ir->before); ir->push_back("std::unique_ptr<MemoryChecker> memory_checker;", &ir->before);
vector<string> commands; vector<string> commands;

View File

@ -17,6 +17,7 @@ namespace jittor {
using namespace expr; using namespace expr;
void ConstVarPass::run() { void ConstVarPass::run() {
JK& jk = get_jk();
int changed = 0; int changed = 0;
for (int i=0; i<op->ops.size(); i++) { for (int i=0; i<op->ops.size(); i++) {
auto opi = op->ops[i]; auto opi = op->ops[i];

View File

@ -234,7 +234,7 @@ void ConvTuner::forwardTune(FusedOp* fop) {
continue; continue;
Op* ops[3] = {op, bop->x->input(), bop->y->input()}; Op* ops[3] = {op, bop->x->input(), bop->y->input()};
int ok = 0; int ok = 0;
LOGvvvv << "conv like op" << fop << fop->get_jit_key(jk); LOGvvvv << "conv like op" << fop << fop->get_jit_key(get_jk());
for (int y_id=0; y_id<3; y_id++) for (int y_id=0; y_id<3; y_id++)
for (int x_id=0; x_id<3; x_id++) for (int x_id=0; x_id<3; x_id++)
for (int w_id=0; w_id<3; w_id++) { for (int w_id=0; w_id<3; w_id++) {

View File

@ -69,7 +69,7 @@ int VarRelayManager::add_relay_group(const vector<pair<Var*, Var*>>& group) {
if (node->is_var()) if (node->is_var())
continue; continue;
Op* op = node->op(); Op* op = node->op();
op->do_jit_prepare(jk); op->do_jit_prepare(get_jk());
list<Node*> new_inputs; list<Node*> new_inputs;
int removed = 0; int removed = 0;
for (Var* v : op->inputs()) for (Var* v : op->inputs())

View File

@ -25,7 +25,7 @@ namespace jittor {
DEFINE_FLAG(int, use_parallel_op_compiler, 16, "Number of threads that parallel op comiler used, default 16, set this value to 0 will disable parallel op compiler."); DEFINE_FLAG(int, use_parallel_op_compiler, 16, "Number of threads that parallel op comiler used, default 16, set this value to 0 will disable parallel op compiler.");
// from log.cc // from log.cc
extern int segfault_happen; EXTERN_LIB int segfault_happen;
// simple thread used for parallel compilation // simple thread used for parallel compilation
struct SimpleThread { struct SimpleThread {
@ -36,7 +36,7 @@ struct SimpleThread {
std::condition_variable cv; std::condition_variable cv;
std::thread thread; std::thread thread;
void run() { void run() {
thread_name = "C"+S(id); get_thread_name() = "C"+S(id);
try { try {
std::unique_lock<std::mutex> lck(mtx); std::unique_lock<std::mutex> lck(mtx);
if (func) if (func)
@ -70,8 +70,8 @@ struct SimpleThread {
}; };
struct SimpleThreads; struct SimpleThreads;
extern SimpleThreads threads; EXTERN_LIB SimpleThreads threads;
extern vector<void(*)()> cleanup_callback; EXTERN_LIB vector<void(*)()> cleanup_callback;
struct SimpleThreads { struct SimpleThreads {
list<SimpleThread> threads; list<SimpleThread> threads;
@ -136,7 +136,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
vector<int> op_needs_compile; vector<int> op_needs_compile;
string_view_map<int> map; string_view_map<int> map;
vector<unique_ptr<FusedOp>> fop_needs_compile; vector<unique_ptr<FusedOp>> fop_needs_compile;
auto& jkl = jk; auto& jkl = get_jk();
for (uint rid=0; rid<queue.size(); rid++) { for (uint rid=0; rid<queue.size(); rid++) {
int root = queue[rid]; int root = queue[rid];
@ -213,7 +213,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
auto func = [&](int tid) { auto func = [&](int tid) {
auto& entrys = op_entrys.at(tid); auto& entrys = op_entrys.at(tid);
entrys.clear(); entrys.clear();
auto& jkl = jk; auto& jkl = get_jk();
while (!has_error && !segfault_happen) { while (!has_error && !segfault_happen) {
int i = ai++; int i = ai++;
if (i >= n) break; if (i >= n) break;
@ -247,14 +247,14 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
bool needs_compile; bool needs_compile;
{ {
std::lock_guard<std::mutex> lock(entry_lock); std::lock_guard<std::mutex> lock(entry_lock);
auto iter = jit_ops.find(jk.to_cstring()); auto iter = jit_ops.find(jkl.to_cstring());
needs_compile = (iter == jit_ops.end()); needs_compile = (iter == jit_ops.end());
if (needs_compile) { if (needs_compile) {
jit_ops[jk.to_cstring()] = nullptr; jit_ops[jkl.to_cstring()] = nullptr;
} }
} }
if (!needs_compile) continue; if (!needs_compile) continue;
string s = jk.to_string(); string s = jkl.to_string();
auto op_entry = OpCompiler::do_compile(orc.op); auto op_entry = OpCompiler::do_compile(orc.op);
{ {
std::lock_guard<std::mutex> lock(entry_lock); std::lock_guard<std::mutex> lock(entry_lock);
@ -266,7 +266,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
} catch (const std::exception& e) { } catch (const std::exception& e) {
// log jit_key and file location // log jit_key and file location
op->do_prepare(jkl); op->do_prepare(jkl);
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc"); string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc");
LOGe << "[Error] source file location:" << jit_src_path; LOGe << "[Error] source file location:" << jit_src_path;
if (is_fused_op) { if (is_fused_op) {

View File

@ -87,7 +87,7 @@ unique_ptr<MemoryChecker>* load_memory_checker(string name) {
return mm; return mm;
} }
extern string _get_stack_info(Node* node); EXTERN_LIB string _get_stack_info(Node* node);
static string get_stack_info(Op* op) { static string get_stack_info(Op* op) {
string stack_info = "stack info:\n"; string stack_info = "stack info:\n";

View File

@ -59,7 +59,7 @@ struct Profiler {
~Profiler(); ~Profiler();
}; };
extern Profiler profiler; EXTERN_LIB Profiler profiler;
DECLARE_FLAG(int, profiler_enable); DECLARE_FLAG(int, profiler_enable);

View File

@ -18,9 +18,13 @@ static inline int _lzcnt(int64 v) {
#else #else
return v ? __builtin_clzll(v) : 64; return v ? __builtin_clzll(v) : 64;
#endif #endif
#else
#ifdef _MSC_VER
return __lzcnt64(v);
#else #else
return __builtin_clzll(v); return __builtin_clzll(v);
#endif #endif
#endif
} }
struct SimpleProfiler { struct SimpleProfiler {

View File

@ -12,7 +12,7 @@
namespace jittor { namespace jittor {
// Those function is generated by python // Those function is generated by python
extern void pyjt_def_all(PyObject* m); EXTERN_LIB void pyjt_def_all(PyObject* m);
vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets) { vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets) {
vector<Var*> vs; vector<Var*> vs;

View File

@ -94,7 +94,7 @@ static vector<Stack> get_stack_info() {
auto frame = (PyFrameObject*)ret.obj; auto frame = (PyFrameObject*)ret.obj;
int n=0; int n=0;
while (frame) n++, frame = frame->f_back; while (frame) n++, frame = frame->f_back;
PyFrameObject* frames[n]; STACK_ALLOC(PyFrameObject*, frames, n);
frame = (PyFrameObject*)ret.obj; frame = (PyFrameObject*)ret.obj;
int i=n; int i=n;
while (i) frames[--i] = frame, frame = frame->f_back; while (i) frames[--i] = frame, frame = frame->f_back;
@ -225,7 +225,7 @@ static inline string get_var_data_str(Var* v) {
} }
void TraceData::record_node(Node* node, bool record_stack) { void TraceData::record_node(Node* node, bool record_stack) {
if (thread_name.size()) return; if (get_thread_name().size()) return;
NodeData data; NodeData data;
data.id = node_data_cnt++; data.id = node_data_cnt++;
id_map[node] = data.id; id_map[node] = data.id;
@ -261,7 +261,7 @@ static int64 get_node_id(Node* node) {
} }
void TraceData::release_node(Node* node) { void TraceData::release_node(Node* node) {
if (thread_name.size()) return; if (get_thread_name().size()) return;
auto iter = trace_data.id_map.find(node); auto iter = trace_data.id_map.find(node);
if (iter == trace_data.id_map.end()) if (iter == trace_data.id_map.end())
return; return;

View File

@ -10,7 +10,7 @@
namespace jittor { namespace jittor {
DECLARE_FLAG(int, trace_py_var); DECLARE_FLAG(int, trace_py_var);
extern Op* trace_grad_op; EXTERN_LIB Op* trace_grad_op;
struct JitKey; struct JitKey;
struct Stack { struct Stack {
@ -64,7 +64,7 @@ struct TraceData {
void record_execution(Op* op, bool is_fused_op, JitKey& jk); void record_execution(Op* op, bool is_fused_op, JitKey& jk);
}; };
extern TraceData trace_data; EXTERN_LIB TraceData trace_data;
void print_node_trace(const Node* node, std::ostream& os); void print_node_trace(const Node* node, std::ostream& os);
vector<Stack> get_node_trace(Node* node); vector<Stack> get_node_trace(Node* node);

View File

@ -50,8 +50,8 @@ enum NPY_TYPES {
NPY_OBJECT=17, NPY_OBJECT=17,
}; };
extern NanoString npy2ns[]; EXTERN_LIB NanoString npy2ns[];
extern NPY_TYPES ns2npy[]; EXTERN_LIB NPY_TYPES ns2npy[];
#define NPY_ARRAY_C_CONTIGUOUS 0x0001 #define NPY_ARRAY_C_CONTIGUOUS 0x0001
#define NPY_ARRAY_ALIGNED 0x0100 #define NPY_ARRAY_ALIGNED 0x0100
@ -74,19 +74,19 @@ inline int get_typenum(NanoString ns) {
typedef Py_intptr_t npy_intp; typedef Py_intptr_t npy_intp;
extern unordered_map<string, int> np_typenum_map; EXTERN_LIB unordered_map<string, int> np_typenum_map;
extern void** PyArray_API; EXTERN_LIB void** PyArray_API;
extern PyTypeObject *PyArray_Type; EXTERN_LIB PyTypeObject *PyArray_Type;
extern PyTypeObject *PyNumberArrType_Type; EXTERN_LIB PyTypeObject *PyNumberArrType_Type;
extern PyTypeObject *PyArrayDescr_Type; EXTERN_LIB PyTypeObject *PyArrayDescr_Type;
extern PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_intp const *, void *, int, int, PyObject *); EXTERN_LIB PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_intp const *, void *, int, int, PyObject *);
extern PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *); EXTERN_LIB PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *);
extern unsigned int (*PyArray_GetNDArrayCFeatureVersion)(); EXTERN_LIB unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
extern int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj); EXTERN_LIB int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
extern PyObject* (*PyArray_NewCopy)(PyObject *, int); EXTERN_LIB PyObject* (*PyArray_NewCopy)(PyObject *, int);
extern int (*PyArray_CopyInto)(PyObject *, PyObject *); EXTERN_LIB int (*PyArray_CopyInto)(PyObject *, PyObject *);
extern void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode); EXTERN_LIB void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode);
#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0) #define PyArray_Copy(obj) PyArray_NewCopy(obj, 0)
@ -121,7 +121,7 @@ union tmp_data_t {
int8 i8; int8 i8;
}; };
extern tmp_data_t tmp_data; EXTERN_LIB tmp_data_t tmp_data;
void numpy_init(); void numpy_init();

View File

@ -141,7 +141,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
} else { } else {
// this is non-continue numpy array // this is non-continue numpy array
#if defined(__linux__) || defined(_WIN32) #if defined(__linux__) || defined(_WIN32)
int64 dims[args.shape.size()]; STACK_ALLOC(int64, dims, args.shape.size());
#elif defined(__APPLE__) #elif defined(__APPLE__)
long dims[args.shape.size()]; long dims[args.shape.size()];
#endif #endif

View File

@ -135,7 +135,7 @@ DEF_IS(Slice, T) from_py_object(PyObject* obj) {
// DumpGraphs // DumpGraphs
struct DumpGraphs; struct DumpGraphs;
extern PyTypeObject PyjtDumpGraphs; EXTERN_LIB PyTypeObject PyjtDumpGraphs;
DEF_IS(DumpGraphs, bool) is_type(PyObject* obj) { DEF_IS(DumpGraphs, bool) is_type(PyObject* obj) {
return Py_TYPE(obj) == &PyjtDumpGraphs; return Py_TYPE(obj) == &PyjtDumpGraphs;
} }
@ -157,7 +157,7 @@ DEF_IS(DumpGraphs, const T&) from_py_object(PyObject* obj) {
// MemInfo // MemInfo
struct MemInfo; struct MemInfo;
extern PyTypeObject PyjtMemInfo; EXTERN_LIB PyTypeObject PyjtMemInfo;
DEF_IS(MemInfo, bool) is_type(PyObject* obj) { DEF_IS(MemInfo, bool) is_type(PyObject* obj) {
return Py_TYPE(obj) == &PyjtMemInfo; return Py_TYPE(obj) == &PyjtMemInfo;
} }
@ -177,7 +177,7 @@ DEF_IS(MemInfo, const T&) from_py_object(PyObject* obj) {
// NanoString // NanoString
struct NanoString; struct NanoString;
extern PyTypeObject PyjtNanoString; EXTERN_LIB PyTypeObject PyjtNanoString;
DEF_IS(NanoString, bool) is_type(PyObject* obj) { DEF_IS(NanoString, bool) is_type(PyObject* obj) {
return Py_TYPE(obj) == &PyjtNanoString || return Py_TYPE(obj) == &PyjtNanoString ||
PyUnicode_CheckExact(obj) || PyUnicode_CheckExact(obj) ||
@ -215,7 +215,7 @@ DEF_IS(NanoString, T) from_py_object(PyObject* obj) {
// NanoVector // NanoVector
struct NanoVector; struct NanoVector;
extern PyTypeObject PyjtNanoVector; EXTERN_LIB PyTypeObject PyjtNanoVector;
DEF_IS(NanoVector, bool) is_type(PyObject* obj) { DEF_IS(NanoVector, bool) is_type(PyObject* obj) {
return Py_TYPE(obj) == &PyjtNanoVector || return Py_TYPE(obj) == &PyjtNanoVector ||
PyList_CheckExact(obj) || PyTuple_CheckExact(obj); PyList_CheckExact(obj) || PyTuple_CheckExact(obj);
@ -253,7 +253,7 @@ DEF_IS(NanoVector, T) from_py_object(PyObject* obj) {
struct ArrayArgs; struct ArrayArgs;
struct VarHolder; struct VarHolder;
vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh); vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh);
extern PyHeapTypeObject PyjtVarHolder; EXTERN_LIB PyHeapTypeObject PyjtVarHolder;
DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) { DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
return return
Py_TYPE(obj) == &PyjtVarHolder.ht_type || Py_TYPE(obj) == &PyjtVarHolder.ht_type ||
@ -267,7 +267,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) { DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) {
#if defined(__linux__) || defined(_WIN32) #if defined(__linux__) || defined(_WIN32)
int64 dims[a.shape.size()]; STACK_ALLOC(int64, dims, a.shape.size());
#elif defined(__APPLE__) #elif defined(__APPLE__)
long dims[a.shape.size()]; long dims[a.shape.size()];
#endif #endif
@ -351,8 +351,8 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
// VarHolder // VarHolder
struct VarHolder; struct VarHolder;
extern PyHeapTypeObject PyjtVarHolder; EXTERN_LIB PyHeapTypeObject PyjtVarHolder;
namespace jit_op_maker { extern VarHolder* array_(ArrayArgs&& args); } namespace jit_op_maker { EXTERN_LIB VarHolder* array_(ArrayArgs&& args); }
DEF_IS(VarHolder*, bool) is_type(PyObject* obj) { DEF_IS(VarHolder*, bool) is_type(PyObject* obj) {
return Py_TYPE(obj) == &PyjtVarHolder.ht_type || return Py_TYPE(obj) == &PyjtVarHolder.ht_type ||
is_type<ArrayArgs>(obj); is_type<ArrayArgs>(obj);
@ -383,7 +383,7 @@ DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr<VarHolder>& holde
struct DataView; struct DataView;
DEF_IS(DataView, PyObject*) to_py_object(T a) { DEF_IS(DataView, PyObject*) to_py_object(T a) {
#if defined(__linux__) || defined(_WIN32) #if defined(__linux__) || defined(_WIN32)
int64 dims[a.shape.size()]; STACK_ALLOC(int64, dims, a.shape.size());
#elif defined(__APPLE__) #elif defined(__APPLE__)
long dims[a.shape.size()]; long dims[a.shape.size()];
#endif #endif
@ -410,8 +410,9 @@ DEF_IS(DataView, PyObject*) to_py_object(T a) {
return oh.release(); return oh.release();
} }
#ifdef __GNUC__
#pragma GCC diagnostic ignored "-Wstrict-aliasing" #pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
struct ItemData; struct ItemData;
DEF_IS(ItemData, PyObject*) to_py_object(T a) { DEF_IS(ItemData, PyObject*) to_py_object(T a) {
if (a.dtype == ns_bool) { if (a.dtype == ns_bool) {

View File

@ -110,7 +110,7 @@ static void push_py_object(RingBuffer* rb, PyObject* obj, uint64& __restrict__ o
rb->push(size, offset); rb->push(size, offset);
args.ptr = rb->get_ptr(size, offset); args.ptr = rb->get_ptr(size, offset);
#if defined(__linux__) || defined(_WIN32) #if defined(__linux__) || defined(_WIN32)
int64 dims[args.shape.size()]; STACK_ALLOC(int64, dims, args.shape.size());
#elif defined(__APPLE__) #elif defined(__APPLE__)
long dims[args.shape.size()]; long dims[args.shape.size()];
#endif #endif
@ -225,12 +225,19 @@ PyObject* PyMultiprocessRingBuffer::pop() {
return obj; return obj;
} }
PyMultiprocessRingBuffer::PyMultiprocessRingBuffer(uint64 size) { PyMultiprocessRingBuffer::PyMultiprocessRingBuffer(uint64 size, uint64 buffer, bool init) {
rb = RingBuffer::make_ring_buffer(size, 1); this->buffer = buffer;
this->init = init;
if (buffer) {
auto mobj = (PyObject*)buffer;
auto buf = PyMemoryView_GET_BUFFER(mobj);
buffer = (uint64)buf->buf;
}
rb = RingBuffer::make_ring_buffer(size, 1, buffer, init);
} }
PyMultiprocessRingBuffer::~PyMultiprocessRingBuffer() { PyMultiprocessRingBuffer::~PyMultiprocessRingBuffer() {
RingBuffer::free_ring_buffer(rb); RingBuffer::free_ring_buffer(rb, buffer, init);
} }
} }

View File

@ -13,9 +13,11 @@ namespace jittor {
// @pyjt(RingBuffer) // @pyjt(RingBuffer)
struct PyMultiprocessRingBuffer { struct PyMultiprocessRingBuffer {
RingBuffer* rb; RingBuffer* rb;
uint64 buffer;
bool _keep_numpy_array = false; bool _keep_numpy_array = false;
bool init;
// @pyjt(__init__) // @pyjt(__init__)
PyMultiprocessRingBuffer(uint64 size); PyMultiprocessRingBuffer(uint64 size, uint64 buffer=0, bool init=true);
// @pyjt(__dealloc__) // @pyjt(__dealloc__)
~PyMultiprocessRingBuffer(); ~PyMultiprocessRingBuffer();
// @pyjt(push,send) // @pyjt(push,send)
@ -46,6 +48,9 @@ struct PyMultiprocessRingBuffer {
s += ")"; s += ")";
return s; return s;
} }
// @pyjt(__get__size)
inline uint64 size() { return rb->size; }
}; };

View File

@ -9,10 +9,11 @@
namespace jittor { namespace jittor {
JIT_TEST(jit_key) { JIT_TEST(jit_key) {
JK& jk = get_jk();
jk.clear(); jk.clear();
for (int i=0; i<JK::buffer_size/2; i++) for (int i=0; i<JK::buffer_size/2; i++)
jk.buffer[i] = i%256; jk.buffer[i] = i%256;
expect_error([]() { expect_error([&]() {
for (int i=0; i<JK::buffer_size; i++) for (int i=0; i<JK::buffer_size; i++)
jk.buffer[i] = i%256; jk.buffer[i] = i%256;
}); });
@ -45,9 +46,11 @@ JIT_TEST(jit_key) {
jk.clear(); jk.clear();
add_jit_define(jk, "f", 0.01); add_jit_define(jk, "f", 0.01);
add_jit_define(jk, "f", 0.5); add_jit_define(jk, "f", 0.5);
#ifndef _MSC_VER
add_jit_define(jk, "f", 1.0/0); add_jit_define(jk, "f", 1.0/0);
add_jit_define(jk, "f", -1.0/0); add_jit_define(jk, "f", -1.0/0);
add_jit_define(jk, "f", 0.0/0); add_jit_define(jk, "f", 0.0/0);
#endif
keys = parse_jit_keys(jk.to_string()); keys = parse_jit_keys(jk.to_string());
k2 = {{"f","0x1.47ae147ae147bp-7"}, k2 = {{"f","0x1.47ae147ae147bp-7"},
{"f","0x1p-1"}, {"f","0x1p-1"},

View File

@ -31,6 +31,7 @@ JIT_TEST(op_register) {
} }
JIT_TEST(fused_op_relay_matmul) { JIT_TEST(fused_op_relay_matmul) {
JK& jk = get_jk();
VarPtr a({10,10}, "float32"); VarPtr a({10,10}, "float32");
VarPtr b({10,10}, "float32"); VarPtr b({10,10}, "float32");
auto aa = make_broadcast_to_op(a, {10,10,10}, {2}); auto aa = make_broadcast_to_op(a, {10,10,10}, {2});

View File

@ -13,7 +13,7 @@ void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims);
JIT_TEST(cuda_loop_schedule) { JIT_TEST(cuda_loop_schedule) {
auto check = [&](const vector<int64>& shape, const vector<int>& masks, vector<int> tdims={}) { auto check = [&](const vector<int64>& shape, const vector<int>& masks, vector<int> tdims={}) {
int masks2[shape.size()]; STACK_ALLOC(int, masks2, shape.size());
int tdims2[6]; int tdims2[6];
cuda_loop_schedule(shape, masks2, tdims2); cuda_loop_schedule(shape, masks2, tdims2);
while (tdims.size() < 6) tdims.push_back(1); while (tdims.size() < 6) tdims.push_back(1);

View File

@ -21,7 +21,7 @@ struct TestTask {
JIT_TEST(sfrl_allocator_time) { JIT_TEST(sfrl_allocator_time) {
Allocator* allocator = get_allocator(); Allocator* allocator = get_allocator();
int max_allc_num = 10000; constexpr int max_allc_num = 10000;
size_t id[max_allc_num]; size_t id[max_allc_num];
size_t temp[max_allc_num]; size_t temp[max_allc_num];
std::vector<TestTask> tasks; std::vector<TestTask> tasks;
@ -52,7 +52,7 @@ JIT_TEST(sfrl_allocator_time) {
JIT_TEST(sfrl_allocator_share) { JIT_TEST(sfrl_allocator_share) {
Allocator* allocator = get_allocator(); Allocator* allocator = get_allocator();
int max_allc_num = 10000; constexpr int max_allc_num = 10000;
size_t id[max_allc_num]; size_t id[max_allc_num];
size_t temp[max_allc_num]; size_t temp[max_allc_num];
std::vector<TestTask> tasks; std::vector<TestTask> tasks;
@ -88,7 +88,7 @@ JIT_TEST(sfrl_allocator_share) {
JIT_TEST(sfrl_allocator_share_without_size_and_ptr) { JIT_TEST(sfrl_allocator_share_without_size_and_ptr) {
Allocator* allocator = get_allocator(); Allocator* allocator = get_allocator();
int max_allc_num = 1000; constexpr int max_allc_num = 1000;
size_t id[max_allc_num]; size_t id[max_allc_num];
size_t temp[max_allc_num]; size_t temp[max_allc_num];
std::vector<TestTask> tasks; std::vector<TestTask> tasks;

View File

@ -22,7 +22,7 @@ struct UpdateQueue {
void auto_flush(); void auto_flush();
}; };
extern UpdateQueue update_queue; EXTERN_LIB UpdateQueue update_queue;
} // jittor } // jittor

View File

@ -31,7 +31,7 @@ void write(const string& fname, const string& src) {
bool file_exist(const string& fname) { bool file_exist(const string& fname) {
std::ifstream f(fname); std::ifstream f(fname);
return f.good(); return f && f.good();
} }
#endif #endif
@ -45,23 +45,21 @@ string join(string a, string b) {
} }
void find_names(string cmd, vector<string>& input_names, string& output_name, map<string,vector<string>>& extra) { void find_names(string cmd, vector<string>& input_names, string& output_name, map<string,vector<string>>& extra) {
size_t i=0;
while (i<cmd.size() && cmd[i] != ' ') i++;
CHECK(i<cmd.size());
// find space not in str // find space not in str
#define is_quate(x) ((x)=='\'' || (x)=='\"')
auto pass = [&](size_t& j) { auto pass = [&](size_t& j) {
while (j<cmd.size()) { while (j<cmd.size()) {
if (cmd[j]=='\'') { if (is_quate(cmd[j])) {
j++; j++;
while (j<cmd.size() && cmd[j]!='\'') j++; while (j<cmd.size() && !is_quate(cmd[j])) j++;
ASSERT(j<cmd.size()); ASSERT(j<cmd.size());
j++; j++;
continue; continue;
} }
while (j<cmd.size() && cmd[j]!=' ' && cmd[j]!='\'') j++; while (j<cmd.size() && cmd[j]!=' ' && !is_quate(cmd[j])) j++;
if (j<cmd.size()) { if (j<cmd.size()) {
if (cmd[j]==' ') break; if (cmd[j]==' ') break;
if (cmd[j]=='\'') continue; if (is_quate(cmd[j])) continue;
} }
} }
}; };
@ -69,15 +67,33 @@ void find_names(string cmd, vector<string>& input_names, string& output_name, ma
auto substr = [&](size_t i, size_t j) -> string { auto substr = [&](size_t i, size_t j) -> string {
string s; string s;
for (size_t k=i; k<j; k++) for (size_t k=i; k<j; k++)
if (cmd[k]!='\'' && cmd[k]!='"') s += cmd[k]; if (!is_quate(cmd[k])) s += cmd[k];
return s; return s;
}; };
size_t i=0;
pass(i);
while (i<cmd.size()) { while (i<cmd.size()) {
if (cmd[i] == ' ') { if (cmd[i] == ' ') {
i++; i++;
continue; continue;
} }
if (cmd[i] == '-') { if (cmd[i] == '-') {
#ifdef _MSC_VER
if (i+4<cmd.size() && cmd[i+1]=='F' && cmd[i+4]==' ') {
// -Fo: -Fe:
auto j=i+5;
while (j<cmd.size() && cmd[j] == ' ') j++;
CHECK(j<cmd.size());
auto k=j;
pass(k);
CHECK(j<k && output_name.size()==0);
// -Fo: xxx
// i j k
output_name = substr(j, k);
i = k;
continue;
} else
#endif
if (i+2<cmd.size() && cmd[i+1]=='o' && cmd[i+2]==' ') { if (i+2<cmd.size() && cmd[i+1]=='o' && cmd[i+2]==' ') {
auto j=i+3; auto j=i+3;
while (j<cmd.size() && cmd[j] == ' ') j++; while (j<cmd.size() && cmd[j] == ' ') j++;
@ -141,6 +157,8 @@ size_t skip_comments(const string& src, size_t i) {
return i; return i;
} }
map<string,string> jt_env;
void process(string src, vector<string>& input_names, string& cmd) { void process(string src, vector<string>& input_names, string& cmd) {
for (size_t i=0; i<src.size(); i++) { for (size_t i=0; i<src.size(); i++) {
i = skip_comments(src, i); i = skip_comments(src, i);
@ -149,8 +167,9 @@ void process(string src, vector<string>& input_names, string& cmd) {
// #include "a.h" // #include "a.h"
// i jk l // i jk l
auto j=i+1; auto j=i+1;
while (j<src.size() && src[j] != ' ') j++; while (j<src.size() && (src[j] != ' ' && src[j] != '\n')) j++;
if (j>=src.size()) return; if (j>=src.size()) return;
if (j-i != 8 && j-i != 6) continue;
auto k=j+1; auto k=j+1;
while (k<src.size() && src[k] == ' ') k++; while (k<src.size() && src[k] == ' ') k++;
if (k>=src.size()) return; if (k>=src.size()) return;
@ -167,12 +186,22 @@ void process(string src, vector<string>& input_names, string& cmd) {
auto inc = src.substr(k, l-k); auto inc = src.substr(k, l-k);
auto env = getenv(inc.c_str()); auto env = getenv(inc.c_str());
if (env && string(env)!="0") { if (env && string(env)!="0") {
string dflag = " -D"+inc+"="+string(env)+" -o "; auto senv = string(env);
if (!jt_env.count(inc)) {
LOGe << "Load JT env ok:" << inc << senv;
jt_env[inc] = senv;
}
string dflag = " -D"+inc+"="+senv;
if (cmd.find(dflag) == string::npos) { if (cmd.find(dflag) == string::npos) {
// -D flags should insert before -o flag // -D flags should insert before -o flag
auto cmds = split(cmd, " -o ", 2); #ifdef _MSC_VER
string patt = " -Fo: ";
#else
string patt = " -o ";
#endif
auto cmds = split(cmd, patt, 2);
if (cmds.size() == 2) { if (cmds.size() == 2) {
cmd = cmds[0] + dflag + cmds[1]; cmd = cmds[0] + dflag + patt + cmds[1];
} }
} }
} }
@ -199,7 +228,7 @@ static inline void check_win_file(const string& name) {
static inline bool is_full_path(const string& name) { static inline bool is_full_path(const string& name) {
#ifdef _WIN32 #ifdef _WIN32
return name.size()>=2 && name[1]==':'; return name.size()>=2 && (name[1]==':' || (name[0]=='\\' && name[1]=='\\'));
#else #else
return name.size() && name[0]=='/'; return name.size() && name[0]=='/';
#endif #endif
@ -217,6 +246,7 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
unordered_set<string> processed; unordered_set<string> processed;
auto src_path = join(jittor_path, "src"); auto src_path = join(jittor_path, "src");
const auto& extra_include = extra["I"]; const auto& extra_include = extra["I"];
string tmp_dir =join(cache_path, "obj_files");
for (size_t i=0; i<input_names.size(); i++) { for (size_t i=0; i<input_names.size(); i++) {
if (processed.count(input_names[i]) != 0) if (processed.count(input_names[i]) != 0)
continue; continue;
@ -224,10 +254,13 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
continue; continue;
processed.insert(input_names[i]); processed.insert(input_names[i]);
auto src = read_all(input_names[i]); auto src = read_all(input_names[i]);
ASSERT(src.size()) << "Source read failed:" << input_names[i]; ASSERT(src.size()) << "Source read failed:" << input_names[i] << "cmd:" << cmd;
auto hash = S(hash64(src)); auto hash = S(hash64(src));
vector<string> new_names; vector<string> new_names;
process(src, new_names, cmd); auto back = input_names[i].back();
// *.obj, *.o, *.pyd
if (back != 'j' && back != 'o' && back != 'd')
process(src, new_names, cmd);
for (auto& name : new_names) { for (auto& name : new_names) {
string full_name; string full_name;
if (name.substr(0, 4) == "jit/" || name.substr(0, 4) == "gen/") if (name.substr(0, 4) == "jit/" || name.substr(0, 4) == "gen/")
@ -261,14 +294,15 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
if (output_cache_key.size() == 0) { if (output_cache_key.size() == 0) {
LOGvv << "Cache key of" << output_name << "not found."; LOGvv << "Cache key of" << output_name << "not found.";
LOGvvv << "Run cmd:" << cmd; LOGvvv << "Run cmd:" << cmd;
system_with_check(cmd.c_str()); check_win_file(output_name);
system_with_check(cmd.c_str(), tmp_dir.c_str());
ran = true; ran = true;
} }
if (output_cache_key.size() != 0 && output_cache_key != cache_key) { if (output_cache_key.size() != 0 && output_cache_key != cache_key) {
LOGvv << "Cache key of" << output_name << "changed."; LOGvv << "Cache key of" << output_name << "changed.";
LOGvvv << "Run cmd:" << cmd; LOGvvv << "Run cmd:" << cmd;
check_win_file(output_name); check_win_file(output_name);
system_with_check(cmd.c_str()); system_with_check(cmd.c_str(), tmp_dir.c_str());
ran = true; ran = true;
} }
if (output_cache_key != cache_key) { if (output_cache_key != cache_key) {
@ -277,7 +311,7 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
write(output_name+".key", cache_key); write(output_name+".key", cache_key);
} }
if (!ran) if (!ran)
LOGvv << "Command cached:" << cmd; LOGvvvv << "Command cached:" << cmd;
return ran; return ran;
} }

View File

@ -0,0 +1,58 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#ifndef _WIN32
#include <sys/wait.h>
#ifdef __linux__
#include <sys/prctl.h>
#endif
#include <unistd.h>
#include <execinfo.h>
#include <sys/wait.h>
#include <sys/time.h>
#else
#include <wchar.h>
#include <windows.h>
#endif
#ifdef _MSC_VER
#include <process.h>
#include <synchapi.h>
#define getpid _getpid
inline void sleep(int s) { Sleep(s*1000); }
#else
#include <unistd.h>
#endif
#ifdef _MSC_VER
// typedef struct timeval {
// long tv_sec;
// long tv_usec;
// } timeval;
inline int gettimeofday(struct timeval * tp, struct timezone * tzp)
{
// Note: some broken versions only have 8 trailing zero's, the correct epoch has 9 trailing zero's
// This magic number is the number of 100 nanosecond intervals since January 1, 1601 (UTC)
// until 00:00:00 January 1, 1970
static const uint64_t EPOCH = ((uint64_t) 116444736000000000ULL);
SYSTEMTIME system_time;
FILETIME file_time;
uint64_t time;
GetSystemTime( &system_time );
SystemTimeToFileTime( &system_time, &file_time );
time = ((uint64_t)file_time.dwLowDateTime ) ;
time += ((uint64_t)file_time.dwHighDateTime) << 32;
tp->tv_sec = (long) ((time - EPOCH) / 10000000L);
tp->tv_usec = (long) (system_time.wMilliseconds * 1000);
return 0;
}
#endif

View File

@ -19,9 +19,209 @@
#include <iterator> #include <iterator>
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
#ifdef _WIN32
#include <exception>
#include <windows.h>
#include <eh.h>
#include <sstream>
#endif
#include "utils/seh.h"
namespace jittor { namespace jittor {
#ifdef _WIN32
using std::stringstream;
void raise_win_error(int ierr) {
DWORD err = (DWORD)ierr;
WCHAR *s_buf = NULL; /* Free via LocalFree */
stringstream message;
if (err==0) {
err = GetLastError();
}
auto len = FormatMessageW(
/* Error API error */
FORMAT_MESSAGE_ALLOCATE_BUFFER |
FORMAT_MESSAGE_FROM_SYSTEM |
FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, /* no message source */
err,
MAKELANGID(LANG_NEUTRAL,
SUBLANG_DEFAULT), /* Default language */
(LPWSTR) &s_buf,
0, /* size not used */
NULL); /* no args */
if (len==0) {
/* Only seen this in out of mem situations */
message << "Windows Error " << err;
s_buf = NULL;
} else {
/* remove trailing cr/lf and dots */
while (len > 0 && (s_buf[len-1] <= L' ' || s_buf[len-1] == L'.'))
s_buf[--len] = L'\0';
message << s_buf;
}
if (s_buf)
LocalFree(s_buf);
throw std::runtime_error(message.str());
}
void raise_cxx_exception(DWORD code, const EXCEPTION_RECORD* pr) {
/* The 'code' is a normal win32 error code so it could be handled by
raise_win_error(). However, for some errors, we have additional
information not included in the error code. We handle those here and
delegate all others to the generic function. */
stringstream message;
switch (code) {
case EXCEPTION_ACCESS_VIOLATION:
/* The thread attempted to read from or write
to a virtual address for which it does not
have the appropriate access. */
if (pr->ExceptionInformation[0] == 0)
message << "exception: access violation reading " << (void*)pr->ExceptionInformation[1];
else
message << "exception: access violation writing " << (void*)pr->ExceptionInformation[1];
break;
case EXCEPTION_BREAKPOINT:
/* A breakpoint was encountered. */
message << "exception: breakpoint encountered";
break;
case EXCEPTION_DATATYPE_MISALIGNMENT:
/* The thread attempted to read or write data that is
misaligned on hardware that does not provide
alignment. For example, 16-bit values must be
aligned on 2-byte boundaries, 32-bit values on
4-byte boundaries, and so on. */
message << "exception: datatype misalignment";
break;
case EXCEPTION_SINGLE_STEP:
/* A trace trap or other single-instruction mechanism
signaled that one instruction has been executed. */
message << "exception: single step";
break;
case EXCEPTION_ARRAY_BOUNDS_EXCEEDED:
/* The thread attempted to access an array element
that is out of bounds, and the underlying hardware
supports bounds checking. */
message << "exception: array bounds exceeded";
break;
case EXCEPTION_FLT_DENORMAL_OPERAND:
/* One of the operands in a floating-point operation
is denormal. A denormal value is one that is too
small to represent as a standard floating-point
value. */
message << "exception: floating-point operand denormal";
break;
case EXCEPTION_FLT_DIVIDE_BY_ZERO:
/* The thread attempted to divide a floating-point
value by a floating-point divisor of zero. */
message << "exception: float divide by zero";
break;
case EXCEPTION_FLT_INEXACT_RESULT:
/* The result of a floating-point operation cannot be
represented exactly as a decimal fraction. */
message << "exception: float inexact";
break;
case EXCEPTION_FLT_INVALID_OPERATION:
/* This exception represents any floating-point
exception not included in this list. */
message << "exception: float invalid operation";
break;
case EXCEPTION_FLT_OVERFLOW:
/* The exponent of a floating-point operation is
greater than the magnitude allowed by the
corresponding type. */
message << "exception: float overflow";
break;
case EXCEPTION_FLT_STACK_CHECK:
/* The stack overflowed or underflowed as the result
of a floating-point operation. */
message << "exception: stack over/underflow";
break;
case EXCEPTION_STACK_OVERFLOW:
/* The stack overflowed or underflowed as the result
of a floating-point operation. */
message << "exception: stack overflow";
break;
case EXCEPTION_FLT_UNDERFLOW:
/* The exponent of a floating-point operation is less
than the magnitude allowed by the corresponding
type. */
message << "exception: float underflow";
break;
case EXCEPTION_INT_DIVIDE_BY_ZERO:
/* The thread attempted to divide an integer value by
an integer divisor of zero. */
message << "exception: integer divide by zero";
break;
case EXCEPTION_INT_OVERFLOW:
/* The result of an integer operation caused a carry
out of the most significant bit of the result. */
message << "exception: integer overflow";
break;
case EXCEPTION_PRIV_INSTRUCTION:
/* The thread attempted to execute an instruction
whose operation is not allowed in the current
machine mode. */
message << "exception: privileged instruction";
break;
case EXCEPTION_NONCONTINUABLE_EXCEPTION:
/* The thread attempted to continue execution after a
noncontinuable exception occurred. */
message << "exception: nocontinuable";
break;
case 0xE06D7363:
/* magic number(0xE06D7363) of c++ exception:
https://devblogs.microsoft.com/oldnewthing/20100730-00/?p=13273
*/
message << "Error c++ exception";
break;
default:
raise_win_error(code);
break;
}
// std::cout << message.str() << std::endl;
throw std::runtime_error(message.str());
}
DWORD HandleException(EXCEPTION_POINTERS *ptrs,
DWORD *pdw, EXCEPTION_RECORD *record)
{
*pdw = ptrs->ExceptionRecord->ExceptionCode;
*record = *ptrs->ExceptionRecord;
/* We don't want to catch breakpoint exceptions, they are used to attach
* a debugger to the process.
*/
if (*pdw == EXCEPTION_BREAKPOINT)
return EXCEPTION_CONTINUE_SEARCH;
return EXCEPTION_EXECUTE_HANDLER;
}
#endif
void init_subprocess() { void init_subprocess() {
#ifdef __linux__ #ifdef __linux__
prctl(PR_SET_PDEATHSIG, SIGKILL); prctl(PR_SET_PDEATHSIG, SIGKILL);
@ -193,7 +393,7 @@ static void pyjt_def_core(PyObject* m) {
{ R""(cache_compile)"", { R""(cache_compile)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try { try {_JT_SEH_START3;
; ;
uint64 arg_filled=0; uint64 arg_filled=0;
(void)arg_filled; (void)arg_filled;
@ -270,7 +470,7 @@ static void pyjt_def_core(PyObject* m) {
} }
LOGf << "Not a valid call."; LOGf << "Not a valid call.";
} catch (const std::exception& e) { _JT_SEH_END3; } catch (const std::exception& e) {
if (!PyErr_Occurred()) { if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what()); PyErr_Format(PyExc_RuntimeError, e.what());
} }
@ -287,7 +487,7 @@ bool cache_compile(const string& cmd, const string& cache_path="", const string&
{ R""(log)"", { R""(log)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try { try {_JT_SEH_START3;
; ;
uint64 arg_filled=0; uint64 arg_filled=0;
(void)arg_filled; (void)arg_filled;
@ -357,7 +557,7 @@ bool cache_compile(const string& cmd, const string& cache_path="", const string&
} }
LOGf << "Not a valid call."; LOGf << "Not a valid call.";
} catch (const std::exception& e) { _JT_SEH_END3; } catch (const std::exception& e) {
if (!PyErr_Occurred()) { if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what()); PyErr_Format(PyExc_RuntimeError, e.what());
} }
@ -374,7 +574,7 @@ void log(const std::string& fileline, const char* level, int verbose, const std:
{ R""(init_subprocess)"", { R""(init_subprocess)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try { try {_JT_SEH_START3;
; ;
uint64 arg_filled=0; uint64 arg_filled=0;
(void)arg_filled; (void)arg_filled;
@ -386,7 +586,7 @@ void log(const std::string& fileline, const char* level, int verbose, const std:
} }
LOGf << "Not a valid call."; LOGf << "Not a valid call.";
} catch (const std::exception& e) { _JT_SEH_END3; } catch (const std::exception& e) {
if (!PyErr_Occurred()) { if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what()); PyErr_Format(PyExc_RuntimeError, e.what());
} }
@ -403,7 +603,7 @@ void init_subprocess()
{ R""(log_capture_start)"", { R""(log_capture_start)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try { try {_JT_SEH_START3;
; ;
uint64 arg_filled=0; uint64 arg_filled=0;
(void)arg_filled; (void)arg_filled;
@ -415,7 +615,7 @@ void init_subprocess()
} }
LOGf << "Not a valid call."; LOGf << "Not a valid call.";
} catch (const std::exception& e) { _JT_SEH_END3; } catch (const std::exception& e) {
if (!PyErr_Occurred()) { if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what()); PyErr_Format(PyExc_RuntimeError, e.what());
} }
@ -432,7 +632,7 @@ void log_capture_start()
{ R""(log_capture_stop)"", { R""(log_capture_stop)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try { try {_JT_SEH_START3;
; ;
uint64 arg_filled=0; uint64 arg_filled=0;
(void)arg_filled; (void)arg_filled;
@ -444,7 +644,7 @@ void log_capture_start()
} }
LOGf << "Not a valid call."; LOGf << "Not a valid call.";
} catch (const std::exception& e) { _JT_SEH_END3; } catch (const std::exception& e) {
if (!PyErr_Occurred()) { if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what()); PyErr_Format(PyExc_RuntimeError, e.what());
} }
@ -461,7 +661,7 @@ void log_capture_stop()
{ R""(log_capture_read)"", { R""(log_capture_read)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try { try {_JT_SEH_START3;
; ;
uint64 arg_filled=0; uint64 arg_filled=0;
(void)arg_filled; (void)arg_filled;
@ -475,7 +675,7 @@ void log_capture_stop()
} }
LOGf << "Not a valid call."; LOGf << "Not a valid call.";
} catch (const std::exception& e) { _JT_SEH_END3; } catch (const std::exception& e) {
if (!PyErr_Occurred()) { if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what()); PyErr_Format(PyExc_RuntimeError, e.what());
} }
@ -492,7 +692,7 @@ void log_capture_read()
{ R""(ostream_redirect)"", { R""(ostream_redirect)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try { try {_JT_SEH_START3;
; ;
uint64 arg_filled=0; uint64 arg_filled=0;
(void)arg_filled; (void)arg_filled;
@ -540,7 +740,7 @@ void log_capture_read()
} }
LOGf << "Not a valid call."; LOGf << "Not a valid call.";
} catch (const std::exception& e) { _JT_SEH_END3; } catch (const std::exception& e) {
if (!PyErr_Occurred()) { if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what()); PyErr_Format(PyExc_RuntimeError, e.what());
} }

View File

@ -6,15 +6,10 @@
// *************************************************************** // ***************************************************************
#include <string.h> #include <string.h>
#include <signal.h> #include <signal.h>
#include <sys/time.h>
#include <iomanip> #include <iomanip>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#include <unistd.h> #include "utils/cross_platform.h"
#ifdef _WIN32
#include <wchar.h>
#include <windows.h>
#endif
#include "utils/log.h" #include "utils/log.h"
#include "utils/mwsr_list.h" #include "utils/mwsr_list.h"
#include "utils/str_utils.h" #include "utils/str_utils.h"
@ -72,6 +67,7 @@ static bool supports_color() {
return term_supports_color; return term_supports_color;
} }
bool g_supports_color = supports_color(); bool g_supports_color = supports_color();
string thread_local thread_name;
struct timeval start_tv; struct timeval start_tv;
@ -166,10 +162,10 @@ void log_capture(const string& s) {
DECLARE_FLAG(int, log_silent); DECLARE_FLAG(int, log_silent);
void send_log(std::ostringstream&& out) { void send_log(std::ostringstream&& out, char level, int verbose) {
if (log_capture_enabled) if (log_capture_enabled)
log_capture(out.str()); log_capture(out.str());
if (log_silent) return; if ((level=='i' || level=='w') && log_silent) return;
if (!log_sync) { if (!log_sync) {
#if LOG_ASYNC #if LOG_ASYNC
mwsr_list_log::push(move(out)); mwsr_list_log::push(move(out));
@ -203,12 +199,15 @@ void log_exiting();
bool exited = false; bool exited = false;
size_t thread_local protected_page = 0; size_t thread_local protected_page = 0;
int segfault_happen = 0; int segfault_happen = 0;
string thread_local thread_name;
static int _pid = getpid(); static int _pid = getpid();
vector<void(*)()> cleanup_callback; vector<void(*)()> cleanup_callback;
vector<void(*)()> sigquit_callback; vector<void(*)()> sigquit_callback;
int64 last_q_time; int64 last_q_time;
string& get_thread_name() {
return thread_name;
}
#ifdef _WIN32 #ifdef _WIN32
void handle_signal(int signal) { void handle_signal(int signal) {
std::cerr << "Caught SIGNAL " << signal << ", quick exit"; std::cerr << "Caught SIGNAL " << signal << ", quick exit";
@ -432,7 +431,7 @@ If you still have problems, please contact us:
} }
#ifdef _WIN32 #ifdef _WIN32
int system_popen(const char *cmd) { int system_popen(const char *cmd, const char* cwd) {
HANDLE g_hChildStd_OUT_Rd = NULL; HANDLE g_hChildStd_OUT_Rd = NULL;
HANDLE g_hChildStd_OUT_Wr = NULL; HANDLE g_hChildStd_OUT_Wr = NULL;
SECURITY_ATTRIBUTES saAttr; SECURITY_ATTRIBUTES saAttr;
@ -472,7 +471,7 @@ int system_popen(const char *cmd) {
TRUE, // handles are inherited TRUE, // handles are inherited
0, // creation flags 0, // creation flags
NULL, // use parent's environment NULL, // use parent's environment
NULL, // use parent's current directory cwd, // use cwd directory
&siStartInfo, // STARTUPINFO pointer &siStartInfo, // STARTUPINFO pointer
&piProcInfo); // receives PROCESS_INFORMATION &piProcInfo); // receives PROCESS_INFORMATION
@ -495,7 +494,8 @@ int system_popen(const char *cmd) {
if (!bSuccess || dwRead == 0) if (!bSuccess || dwRead == 0)
break; break;
output += chBuf; output += chBuf;
bSuccess = WriteFile(hParentStdOut, chBuf, if (log_v)
bSuccess = WriteFile(hParentStdOut, chBuf,
dwRead, &dwWritten, NULL); dwRead, &dwWritten, NULL);
if (!bSuccess) if (!bSuccess)
break; break;
@ -508,6 +508,8 @@ int system_popen(const char *cmd) {
// of the child process, for example. // of the child process, for example.
CloseHandle(piProcInfo.hProcess); CloseHandle(piProcInfo.hProcess);
CloseHandle(piProcInfo.hThread); CloseHandle(piProcInfo.hThread);
if (ec && !log_v)
LOGe << output;
if (ec) { if (ec) {
check_cuda_unsupport_version(output); check_cuda_unsupport_version(output);
@ -516,7 +518,7 @@ int system_popen(const char *cmd) {
return ec; return ec;
} }
#else #else
int system_popen(const char* cmd) { int system_popen(const char* cmd, const char* cwd) {
char buf[BUFSIZ]; char buf[BUFSIZ];
string cmd2; string cmd2;
cmd2 = cmd; cmd2 = cmd;
@ -542,8 +544,8 @@ int system_popen(const char* cmd) {
} }
#endif #endif
void system_with_check(const char* cmd) { void system_with_check(const char* cmd, const char* cwd) {
auto ret = system_popen(cmd); auto ret = system_popen(cmd, cwd);
CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd << CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd <<
"\nreturn ">> ret >> ". This might be an overcommit issue or out of memory." "\nreturn ">> ret >> ". This might be an overcommit issue or out of memory."
<< "Try : sudo sysctl vm.overcommit_memory=1"; << "Try : sudo sysctl vm.overcommit_memory=1";

View File

@ -32,11 +32,26 @@ constexpr int32_t basename_index(const char * const path, const int32_t index =
#define __FILELINE__ \ #define __FILELINE__ \
(&((__FILE__ ":" STRINGIZE(__LINE__))[jittor::basename_index(__FILE__)])) (&((__FILE__ ":" STRINGIZE(__LINE__))[jittor::basename_index(__FILE__)]))
#ifndef _WIN32
#define PREDICT_BRANCH_NOT_TAKEN(x) (__builtin_expect(x, 0)) #define PREDICT_BRANCH_NOT_TAKEN(x) (__builtin_expect(x, 0))
#else
#define PREDICT_BRANCH_NOT_TAKEN(x) (x)
#endif
extern uint32_t get_tid();
extern bool g_supports_color; #ifdef _MSC_VER
extern void print_prefix(std::ostream* out); #define STACK_ALLOC(T, a, n) T* a = (T*)_alloca(sizeof(T)*(n))
#define EXTERN_LIB extern __declspec(dllimport)
#define EXPORT_LIB __declspec(dllimport)
#else
#define STACK_ALLOC(T, a, n) T a[n]
#define EXTERN_LIB extern
#define EXPORT_LIB
#endif
EXTERN_LIB uint32_t get_tid();
EXTERN_LIB bool g_supports_color;
EXTERN_LIB void print_prefix(std::ostream* out);
#ifdef _WIN32 #ifdef _WIN32
constexpr char green[] = "\x1b[1;32m"; constexpr char green[] = "\x1b[1;32m";
@ -44,7 +59,7 @@ constexpr char red[] = "\x1b[1;31m";
constexpr char yellow[] = "\x1b[1;33m"; constexpr char yellow[] = "\x1b[1;33m";
static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) { inline static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
if (level == 'i') { if (level == 'i') {
if (verbose == 0) color_begin = "\x1b[1;32m"; else if (verbose == 0) color_begin = "\x1b[1;32m"; else
if (verbose < 10) color_begin = "\x1b[1;32m"; else if (verbose < 10) color_begin = "\x1b[1;32m"; else
@ -65,7 +80,7 @@ constexpr char green[] = "\033[38;5;2m";
constexpr char red[] = "\033[38;5;1m"; constexpr char red[] = "\033[38;5;1m";
constexpr char yellow[] = "\033[38;5;3m"; constexpr char yellow[] = "\033[38;5;3m";
static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) { inline static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
if (level == 'i') { if (level == 'i') {
if (verbose == 0) color_begin = "\033[38;5;2m"; else if (verbose == 0) color_begin = "\033[38;5;2m"; else
if (verbose < 10) color_begin = "\033[38;5;250m"; else if (verbose < 10) color_begin = "\033[38;5;250m"; else
@ -83,18 +98,22 @@ static void get_color(char level, int verbose, const char*& color_begin, const c
#endif #endif
extern void send_log(std::ostringstream&& out); EXTERN_LIB void send_log(std::ostringstream&& out, char level, int verbose);
extern void flush_log(); EXTERN_LIB void flush_log();
extern void log_capture_start(); EXTERN_LIB void log_capture_start();
extern void log_capture_stop(); EXTERN_LIB void log_capture_stop();
extern std::vector<std::map<string,string>> log_capture_read(); EXTERN_LIB std::vector<std::map<string,string>> log_capture_read();
extern string thread_local thread_name; EXTERN_LIB string& get_thread_name();
struct Log { struct Log {
std::ostringstream out; std::ostringstream out;
const char* color_end; const char* color_end;
int verbose;
char level;
Log(const char* const fileline, char level, int verbose) { inline Log(const char* const fileline, char level, int verbose) {
this->verbose = verbose;
this->level = level;
const char* color_begin; const char* color_begin;
get_color(level, verbose, color_begin, color_end); get_color(level, verbose, color_begin, color_end);
if (g_supports_color) out << color_begin; if (g_supports_color) out << color_begin;
@ -104,12 +123,12 @@ struct Log {
out << fileline << ']'; out << fileline << ']';
} }
void end() { inline void end() {
if (g_supports_color) out << color_end; if (g_supports_color) out << color_end;
out << '\n'; out << '\n';
send_log(move(out)); send_log(move(out), level, verbose);
} }
void flush() { flush_log(); } inline void flush() { flush_log(); }
template<class T> template<class T>
Log& operator<<(const T& a) { out << ' ' << a; return *this; } Log& operator<<(const T& a) { out << ' ' << a; return *this; }
@ -118,11 +137,11 @@ struct Log {
}; };
struct LogVoidify { struct LogVoidify {
void operator&&(Log& log) { log.end(); } inline void operator&&(Log& log) { log.end(); }
}; };
struct LogFatalVoidify { struct LogFatalVoidify {
void operator&&(Log& log) { inline void operator&&(Log& log) {
log.flush(); log.flush();
if (g_supports_color) log.out << log.color_end; if (g_supports_color) log.out << log.color_end;
throw std::runtime_error(log.out.str()); throw std::runtime_error(log.out.str());
@ -170,9 +189,9 @@ template<class T> T get_from_env(const char* name,const T& _default) {
template<> std::string get_from_env(const char* name, const std::string& _default); template<> std::string get_from_env(const char* name, const std::string& _default);
#define DECLARE_FLAG(type, name) \ #define DECLARE_FLAG(type, name) \
extern type name; \ EXTERN_LIB type name; \
extern std::string doc_ ## name; \ EXTERN_LIB std::string doc_ ## name; \
extern void set_ ## name (const type&); EXTERN_LIB void set_ ## name (const type&);
#ifdef JIT #ifdef JIT
@ -256,6 +275,6 @@ bool check_vlog(const char* fileline, int verbose);
#define LOGig LOGi >> jittor::green #define LOGig LOGi >> jittor::green
#define LOGiy LOGi >> jittor::yellow #define LOGiy LOGi >> jittor::yellow
void system_with_check(const char* cmd); void system_with_check(const char* cmd, const char* cwd=nullptr);
} // jittor } // jittor

View File

@ -0,0 +1,77 @@
#pragma once
#ifdef _WIN32
#include <windows.h>
#include "common.h"
namespace jittor {
EXTERN_LIB void raise_win_error(int ierr);
EXTERN_LIB void raise_cxx_exception(DWORD code, const EXCEPTION_RECORD* pr);
EXTERN_LIB DWORD HandleException(EXCEPTION_POINTERS *ptrs,
DWORD *pdw, EXCEPTION_RECORD *record);
#define _JT_SEH_TRY \
DWORD dwExceptionCode = 0; \
EXCEPTION_RECORD record; \
__try {
#define _JT_SEH_CATCH \
} \
__except (HandleException(GetExceptionInformation(), \
&dwExceptionCode, &record)) { \
raise_cxx_exception(dwExceptionCode, &record); \
}
#define _JT_SEH_START \
return [&]() { \
_JT_SEH_TRY; \
return [&]() {
#define _JT_SEH_END \
}(); \
_JT_SEH_CATCH; \
}(); \
#define _JT_SEH_START2 \
[&]() { \
_JT_SEH_TRY;
#define _JT_SEH_END2 \
_JT_SEH_CATCH; \
}();
#ifdef JT_SEH_FULL
#define _JT_SEH_START3 \
return [&]() { \
_JT_SEH_TRY; \
return [&]() {
#define _JT_SEH_END3 \
}(); \
_JT_SEH_CATCH; \
}(); \
#else
#define _JT_SEH_START3
#define _JT_SEH_END3
#endif
}
#else
#define _JT_SEH_TRY
#define _JT_SEH_CATCH
#define _JT_SEH_START
#define _JT_SEH_END
#define _JT_SEH_START2
#define _JT_SEH_END2
#define _JT_SEH_START3
#define _JT_SEH_END3
#endif

View File

@ -6,19 +6,8 @@
// *************************************************************** // ***************************************************************
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#ifndef _WIN32
#include <sys/wait.h>
#ifdef __linux__
#include <sys/prctl.h>
#endif
#include <unistd.h>
#include <execinfo.h>
#include <sys/wait.h>
#else
#include <windows.h>
#endif
#include <unistd.h>
#include <iostream> #include <iostream>
#include "utils/cross_platform.h"
#include "utils/tracer.h" #include "utils/tracer.h"
namespace jittor { namespace jittor {
@ -32,7 +21,7 @@ DEFINE_FLAG_WITH_SETTER(int, gdb_attach, 0, "gdb attach self process.");
string _extra_gdb_cmd; string _extra_gdb_cmd;
int system_popen(const char* cmd); int system_popen(const char* cmd, const char* cwd=nullptr);
#ifdef _WIN32 #ifdef _WIN32
string get_cmds(const vector<const char*>& argv) { string get_cmds(const vector<const char*>& argv) {
@ -76,9 +65,9 @@ void setter_gdb_attach(int v) {
} }
} }
} }
LOGi << "gdb attach for" << "pid=" >> pid_buf << argv;
// argv.insert(argv.end(), {name_buf, pid_buf, NULL}); // argv.insert(argv.end(), {name_buf, pid_buf, NULL});
argv.insert(argv.end(), {"-p", pid_buf, NULL}); argv.insert(argv.end(), {"-p", pid_buf, NULL});
LOGi << "gdb attach for" << "pid=" >> pid_buf << argv;
#ifdef _WIN32 #ifdef _WIN32
// _spawnvp(_P_OVERLAY, gdb_path.c_str(), (char* const*)&argv[0]); // _spawnvp(_P_OVERLAY, gdb_path.c_str(), (char* const*)&argv[0]);
@ -150,6 +139,7 @@ void breakpoint() {
} }
void print_trace() { void print_trace() {
LOGir << "???" << gdb_path;
if (gdb_path.size()) { if (gdb_path.size()) {
// using gdb to print the stack trace // using gdb to print the stack trace
char pid_buf[30]; char pid_buf[30];

View File

@ -0,0 +1 @@
#define _P(...)

View File

@ -21,11 +21,11 @@ namespace jittor {
DEFINE_FLAG(int, lazy_execution, 1, "Default enabled, if disable, use immediately eager execution rather than lazy execution, This flag makes error message and traceback infomation better. But this flag will raise memory consumption and lower the performance."); DEFINE_FLAG(int, lazy_execution, 1, "Default enabled, if disable, use immediately eager execution rather than lazy execution, This flag makes error message and traceback infomation better. But this flag will raise memory consumption and lower the performance.");
list<VarHolder*> VarHolder::hold_vars; list<VarHolder*> hold_vars;
void add_hold_vars(VarHolder* self) { void add_hold_vars(VarHolder* self) {
VarHolder::hold_vars.push_front(self); hold_vars.push_front(self);
self->iter = VarHolder::hold_vars.begin(); self->iter = hold_vars.begin();
if (lazy_execution) return; if (lazy_execution) return;
auto v = self->var; auto v = self->var;
for (int i=0; i<5; i++) { for (int i=0; i<5; i++) {
@ -129,7 +129,7 @@ VarHolder* VarHolder::_update(VarHolder* v) {
return this; return this;
} }
extern Executor exe; EXTERN_LIB Executor exe;
void VarHolder::sync(bool device_sync) { void VarHolder::sync(bool device_sync) {
jittor::sync({this}, device_sync); jittor::sync({this}, device_sync);
@ -162,12 +162,12 @@ ItemData VarHolder::item() {
} }
// from fetch_op.cc // from fetch_op.cc
extern list<VarPtr> fetcher; EXTERN_LIB list<VarPtr> fetcher;
void sync_all(bool device_sync) { void sync_all(bool device_sync) {
vector<Var*> vars; vector<Var*> vars;
vars.reserve(VarHolder::hold_vars.size()); vars.reserve(hold_vars.size());
for (auto v : VarHolder::hold_vars) { for (auto v : hold_vars) {
if (!v->var->_outputs.size()) if (!v->var->_outputs.size())
vars.push_back(v->var); vars.push_back(v->var);
} }

View File

@ -30,6 +30,8 @@ struct ItemData {
typedef struct _object PyObject; typedef struct _object PyObject;
EXTERN_LIB list<VarHolder*> hold_vars;
// @pyjt(Var) // @pyjt(Var)
// @attrs(heaptype) // @attrs(heaptype)
struct VarHolder { struct VarHolder {
@ -82,7 +84,6 @@ struct VarHolder {
void operator=(VarPtr&& v); void operator=(VarPtr&& v);
static list<VarHolder*> hold_vars;
/** /**
* set the name of the Var. * set the name of the Var.

View File

@ -17,6 +17,8 @@ def all_eq(x, y):
convert = lambda x: x.astype("uint8") if x.dtype=="bool" else x convert = lambda x: x.astype("uint8") if x.dtype=="bool" else x
x = convert(x) x = convert(x)
y = convert(y) y = convert(y)
if str(x.dtype).startswith("float"):
return str(y.dtype).startswith("float") and x.shape == y.shape and (x==y).all()
return x.dtype == y.dtype and x.shape == y.shape and (x==y).all() return x.dtype == y.dtype and x.shape == y.shape and (x==y).all()
def check(op, *args): def check(op, *args):

View File

@ -76,23 +76,59 @@ class TestDataset(unittest.TestCase):
assert isinstance(batch[1], np.ndarray) assert isinstance(batch[1], np.ndarray)
class YourDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=10240)
def __getitem__(self, k):
self.tmp = None
x = jt.array(k)
y = x
for i in range(10):
for j in range(i+2):
y = y + j - j
y.stop_fuse()
return x, y
class YourDataset2(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=16)
def __getitem__(self, k):
return np.random.rand(2)
class YourDataset3(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=16)
def __getitem__(self, k):
return random.randint(0,1000)
class YourDataset4(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=160)
def __getitem__(self, k):
return jt.rand(2)
class YourDataset5(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=160)
def __getitem__(self, k):
return { "a":np.array([1,2,3]) }
class TestDataset2(unittest.TestCase): class TestDataset2(unittest.TestCase):
def test_dataset_use_jittor(self): def test_dataset_use_jittor(self):
class YourDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=10240)
def __getitem__(self, k):
self.tmp = None
x = jt.array(k)
y = x
for i in range(10):
for j in range(i+2):
y = y + j - j
y.stop_fuse()
return x, y
dataset = YourDataset().set_attrs(batch_size=256, shuffle=True, num_workers=4) dataset = YourDataset().set_attrs(batch_size=256, shuffle=True, num_workers=4)
dataset.tmp = jt.array([1,2,3,4,5]) dataset.tmp = jt.array([1,2,3,4,5])
dataset.tmp.sync() dataset.tmp.sync()
@ -108,15 +144,8 @@ class TestDataset2(unittest.TestCase):
class TestDatasetSeed(unittest.TestCase): class TestDatasetSeed(unittest.TestCase):
def test_np(self): def test_np(self):
class YourDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=16)
def __getitem__(self, k): dataset = YourDataset2().set_attrs(batch_size=1, shuffle=True, num_workers=4)
return np.random.rand(2)
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
for _ in range(10): for _ in range(10):
dd = [] dd = []
for d in dataset: for d in dataset:
@ -127,16 +156,9 @@ class TestDatasetSeed(unittest.TestCase):
def test_py_native(self): def test_py_native(self):
import random import random
class YourDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=16)
def __getitem__(self, k):
return random.randint(0,1000)
jt.set_global_seed(0) jt.set_global_seed(0)
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4) dataset = YourDataset3().set_attrs(batch_size=1, shuffle=True, num_workers=4)
for _ in range(10): for _ in range(10):
dd = [] dd = []
for d in dataset: for d in dataset:
@ -147,16 +169,9 @@ class TestDatasetSeed(unittest.TestCase):
def test_jtrand(self): def test_jtrand(self):
import random import random
class YourDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=160)
def __getitem__(self, k):
return jt.rand(2)
jt.set_global_seed(0) jt.set_global_seed(0)
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4) dataset = YourDataset4().set_attrs(batch_size=1, shuffle=True, num_workers=4)
for _ in range(10): for _ in range(10):
dd = [] dd = []
for d in dataset: for d in dataset:
@ -167,16 +182,9 @@ class TestDatasetSeed(unittest.TestCase):
def test_dict(self): def test_dict(self):
import random import random
class YourDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=160)
def __getitem__(self, k):
return { "a":np.array([1,2,3]) }
jt.set_global_seed(0) jt.set_global_seed(0)
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4) dataset = YourDataset5().set_attrs(batch_size=1, shuffle=True, num_workers=4)
for _ in range(10): for _ in range(10):
dd = [] dd = []
for d in dataset: for d in dataset:
@ -216,6 +224,11 @@ class TestDatasetSeed(unittest.TestCase):
assert z[i] == c assert z[i] == c
def test_children_died(self): def test_children_died(self):
if os.name == 'nt':
# TODO: windows cannot pass this test now
# don't know how to detect child died in windows
# some clue: https://ikriv.com/blog/?p=1431
return
src = """ src = """
import jittor as jt import jittor as jt
from jittor.dataset import Dataset from jittor.dataset import Dataset
@ -231,13 +244,13 @@ class YourDataset(Dataset):
while 1: while 1:
pass pass
return { "a":np.array([1,2,3]) } return { "a":np.array([1,2,3]) }
if __name__ == "__main__":
dataset = YourDataset()
dataset.set_attrs(num_workers=2)
dataset = YourDataset() for d in dataset:
dataset.set_attrs(num_workers=2) dataset.workers[0].p.kill()
pass
for d in dataset:
dataset.workers[0].p.kill()
pass
""" """
fname = os.path.join(jt.flags.cache_path, "children_dead_test.py") fname = os.path.join(jt.flags.cache_path, "children_dead_test.py")
with open(fname, 'w') as f: with open(fname, 'w') as f:
@ -271,12 +284,13 @@ class YourDataset(Dataset):
pass pass
return { "a":np.array([1,2,3]) } return { "a":np.array([1,2,3]) }
dataset = YourDataset() if __name__ == "__main__":
dataset.set_attrs(num_workers=2) dataset = YourDataset()
dataset.set_attrs(num_workers=2)
for d in dataset: for d in dataset:
break break
dataset.terminate() dataset.terminate()
""" """
fname = os.path.join(jt.flags.cache_path, "children_dead_test.py") fname = os.path.join(jt.flags.cache_path, "children_dead_test.py")
with open(fname, 'w') as f: with open(fname, 'w') as f:

Some files were not shown because too many files have changed in this diff Show More