forked from jittor/jittor
msvc support
This commit is contained in:
parent
4e38190483
commit
c3938e14bf
|
@ -12,6 +12,7 @@ perf.data.old
|
||||||
*.pdf
|
*.pdf
|
||||||
*.zip
|
*.zip
|
||||||
*.tgz
|
*.tgz
|
||||||
|
*.obj
|
||||||
test.py
|
test.py
|
||||||
extern/mkl/mkldnn_lnx*/*
|
extern/mkl/mkldnn_lnx*/*
|
||||||
data/
|
data/
|
||||||
|
|
|
@ -25,6 +25,7 @@ def install_mkl(root_folder):
|
||||||
# origin url is
|
# origin url is
|
||||||
# url = "https://github.com/intel/mkl-dnn/releases/download/v1.0.2/mkldnn_lnx_1.0.2_cpu_gomp.tgz"
|
# url = "https://github.com/intel/mkl-dnn/releases/download/v1.0.2/mkldnn_lnx_1.0.2_cpu_gomp.tgz"
|
||||||
import platform
|
import platform
|
||||||
|
url = None
|
||||||
if platform.system()=="Linux":
|
if platform.system()=="Linux":
|
||||||
if platform.machine()=='x86_64':
|
if platform.machine()=='x86_64':
|
||||||
filename = "dnnl_lnx_2.2.0_cpu_gomp.tgz"
|
filename = "dnnl_lnx_2.2.0_cpu_gomp.tgz"
|
||||||
|
@ -35,23 +36,44 @@ def install_mkl(root_folder):
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet,"
|
raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet,"
|
||||||
" Please contact us on https://github.com/jittor/jittor ")
|
" Please contact us on https://github.com/jittor/jittor ")
|
||||||
|
elif os.name == "nt":
|
||||||
|
# url = "https://github.com/oneapi-src/oneDNN/releases/download/v2.2/dnnl_win_2.2.0_cpu_iomp.zip"
|
||||||
|
# url = "https://github.com/oneapi-src/oneDNN/releases/download/v2.2/dnnl_win_2.2.0_cpu_vcomp.zip"
|
||||||
|
filename = "dnnl_win_2.2.0_cpu_vcomp.zip"
|
||||||
|
md5 = "fa12c693b2ec07700d174e1e99d60a7e"
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet,"
|
raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet,"
|
||||||
" Please contact us on https://github.com/jittor/jittor ")
|
" Please contact us on https://github.com/jittor/jittor ")
|
||||||
|
|
||||||
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
|
if not url:
|
||||||
|
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
|
||||||
fullname = os.path.join(root_folder, filename)
|
fullname = os.path.join(root_folder, filename)
|
||||||
dirname = os.path.join(root_folder, filename.replace(".tgz",""))
|
dirname = os.path.join(root_folder, filename.rsplit(".",1)[0])
|
||||||
|
|
||||||
if not os.path.isfile(os.path.join(dirname, "lib", "libmkldnn.so")):
|
if not (os.path.isfile(os.path.join(dirname, "lib", "libmkldnn.so")) or
|
||||||
|
os.path.isfile(os.path.join(dirname, "bin", "dnnl.dll"))):
|
||||||
LOG.i("Downloading mkl...")
|
LOG.i("Downloading mkl...")
|
||||||
download_url_to_local(url, filename, root_folder, md5)
|
download_url_to_local(url, filename, root_folder, md5)
|
||||||
import tarfile
|
if fullname.endswith(".zip"):
|
||||||
|
import zipfile
|
||||||
with tarfile.open(fullname, "r") as tar:
|
with zipfile.ZipFile(fullname, "r") as f:
|
||||||
tar.extractall(root_folder)
|
f.extractall(root_folder)
|
||||||
|
else:
|
||||||
assert 0 == os.system(f"cd {dirname}/examples && "
|
import tarfile
|
||||||
|
with tarfile.open(fullname, "r") as tar:
|
||||||
|
tar.extractall(root_folder)
|
||||||
|
if os.name == 'nt':
|
||||||
|
# this env is used for execute example/text
|
||||||
|
bin_path = os.path.join(dirname, "bin")
|
||||||
|
sys.path.append(bin_path)
|
||||||
|
os.add_dll_directory(bin_path)
|
||||||
|
os.environ["PATH"] = os.environ.get("PATH", "") + ";" + bin_path
|
||||||
|
cmd = f"cd /d {dirname}/examples && {cc_path} {dirname}/examples/cnn_inference_f32.cpp -I{dirname}/include -Fe: {dirname}/examples/test {cc_flags} {win_link_flags} {dirname}/lib/mkldnn.lib"
|
||||||
|
|
||||||
|
assert 0 == os.system(cmd)
|
||||||
|
assert 0 == os.system(f"{dirname}/examples/test")
|
||||||
|
else:
|
||||||
|
assert 0 == os.system(f"cd {dirname}/examples && "
|
||||||
f"{cc_path} -std=c++14 cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test")
|
f"{cc_path} -std=c++14 cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test")
|
||||||
|
|
||||||
def setup_mkl():
|
def setup_mkl():
|
||||||
|
@ -74,7 +96,7 @@ def setup_mkl():
|
||||||
mkl_include_path = os.environ.get("mkl_include_path")
|
mkl_include_path = os.environ.get("mkl_include_path")
|
||||||
mkl_lib_path = os.environ.get("mkl_lib_path")
|
mkl_lib_path = os.environ.get("mkl_lib_path")
|
||||||
|
|
||||||
if platform.system() == 'Linux':
|
if platform.system() == 'Linux' or os.name == 'nt':
|
||||||
if mkl_lib_path is None or mkl_include_path is None:
|
if mkl_lib_path is None or mkl_include_path is None:
|
||||||
mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh")
|
mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh")
|
||||||
LOG.v("setup mkl...")
|
LOG.v("setup mkl...")
|
||||||
|
@ -95,6 +117,13 @@ def setup_mkl():
|
||||||
mkl_lib_path = os.path.join(mkl_home, "lib")
|
mkl_lib_path = os.path.join(mkl_home, "lib")
|
||||||
|
|
||||||
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
|
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
|
||||||
|
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn -Wl,-rpath='{mkl_lib_path}' "
|
||||||
|
if os.name == 'nt':
|
||||||
|
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
|
||||||
|
mkl_bin_path = os.path.join(mkl_home, 'bin')
|
||||||
|
os.add_dll_directory(mkl_bin_path)
|
||||||
|
mkl_lib = os.path.join(mkl_lib_path, "dnnl.lib")
|
||||||
|
extra_flags = f" -I\"{mkl_include_path}\" \"{mkl_lib}\" "
|
||||||
assert os.path.isdir(mkl_include_path)
|
assert os.path.isdir(mkl_include_path)
|
||||||
assert os.path.isdir(mkl_lib_path)
|
assert os.path.isdir(mkl_lib_path)
|
||||||
assert os.path.isfile(mkl_lib_name)
|
assert os.path.isfile(mkl_lib_name)
|
||||||
|
@ -103,7 +132,6 @@ def setup_mkl():
|
||||||
LOG.v(f"mkl_lib_name: {mkl_lib_name}")
|
LOG.v(f"mkl_lib_name: {mkl_lib_name}")
|
||||||
# We do not link manualy, link in custom ops
|
# We do not link manualy, link in custom ops
|
||||||
# ctypes.CDLL(mkl_lib_name, dlopen_flags)
|
# ctypes.CDLL(mkl_lib_name, dlopen_flags)
|
||||||
extra_flags = f" -I'{mkl_include_path}' -L'{mkl_lib_path}' -lmkldnn -Wl,-rpath='{mkl_lib_path}' "
|
|
||||||
|
|
||||||
elif platform.system() == 'Darwin':
|
elif platform.system() == 'Darwin':
|
||||||
mkl_lib_paths = [
|
mkl_lib_paths = [
|
||||||
|
@ -508,6 +536,7 @@ world_size = mpi.world_size() if in_mpi else 1
|
||||||
setup_nccl()
|
setup_nccl()
|
||||||
|
|
||||||
setup_cutt()
|
setup_cutt()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
setup_mkl()
|
setup_mkl()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -55,18 +55,22 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
||||||
link = link_flags
|
link = link_flags
|
||||||
base_output = os.path.basename(output).split('.')[0]
|
base_output = os.path.basename(output).split('.')[0]
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
# initialize order in windows seems reversed
|
# windows do not combind build, need gen def
|
||||||
inputs = list(inputs[::-1])
|
combind_build = False
|
||||||
# windows need libxxx.a
|
# windows need xxxx.lib
|
||||||
afile = os.path.join(cache_path, f"lib{base_output}.a")
|
afile = output.rsplit('.', 1)[0] + ".lib"
|
||||||
link = link + f' -Wl,--export-all-symbols,--out-implib,"{afile}" '
|
afile = os.path.join(cache_path, afile)
|
||||||
if base_output == "jit_utils_core":
|
if cc_type != 'cl':
|
||||||
pass
|
# initialize order in windows seems reversed
|
||||||
elif base_output == "jittor_core":
|
inputs = list(inputs[::-1])
|
||||||
inputs.append(os.path.join(cache_path, f"libjit_utils_core.a"))
|
link = link + f' -Wl,--export-all-symbols,--out-implib,"{afile}" '
|
||||||
else:
|
if base_output == "jit_utils_core":
|
||||||
inputs.append(os.path.join(cache_path, f"libjit_utils_core.a"))
|
pass
|
||||||
inputs.append(os.path.join(cache_path, f"libjittor_core.a"))
|
elif base_output == "jittor_core":
|
||||||
|
inputs.append(os.path.join(cache_path, f"jit_utils_core{lib_suffix}"))
|
||||||
|
else:
|
||||||
|
inputs.append(os.path.join(cache_path, f"jit_utils_core{lib_suffix}"))
|
||||||
|
inputs.append(os.path.join(cache_path, f"jittor_core{lib_suffix}"))
|
||||||
|
|
||||||
# if output is core, add core_link_flags
|
# if output is core, add core_link_flags
|
||||||
if output.startswith("jittor_core"):
|
if output.startswith("jittor_core"):
|
||||||
|
@ -77,7 +81,7 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
||||||
ex_obj_files = []
|
ex_obj_files = []
|
||||||
new_inputs = []
|
new_inputs = []
|
||||||
for name in inputs:
|
for name in inputs:
|
||||||
if name[-1] in 'oa':
|
if name[-1] in 'oab':
|
||||||
ex_obj_files.append(name)
|
ex_obj_files.append(name)
|
||||||
else:
|
else:
|
||||||
new_inputs.append(os.path.join(jittor_path, name))
|
new_inputs.append(os.path.join(jittor_path, name))
|
||||||
|
@ -87,7 +91,7 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
||||||
|
|
||||||
if len(inputs) == 1 or combind_build:
|
if len(inputs) == 1 or combind_build:
|
||||||
cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} {link} -o {output}"
|
cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} {link} -o {output}"
|
||||||
return do_compile(cmd)
|
return do_compile(fix_cl_flags(cmd))
|
||||||
# split compile object file and link
|
# split compile object file and link
|
||||||
# remove -l -L flags when compile object files
|
# remove -l -L flags when compile object files
|
||||||
oflags = remove_flags(flags, ['-l', '-L', '-Wl,'])
|
oflags = remove_flags(flags, ['-l', '-L', '-Wl,'])
|
||||||
|
@ -101,16 +105,20 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
||||||
cc = nvcc_path
|
cc = nvcc_path
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
cmd = f"{cc} {input} {nflags} -c {lto_flags} -o {obj_file}"
|
cmd = f"\"{cc}\" {input} {nflags} {lto_flags} -c -o {obj_file}"
|
||||||
if "nan_checker" in input:
|
if "nan_checker" in input:
|
||||||
# nan checker needs to disable fast_math
|
# nan checker needs to disable fast_math
|
||||||
cmd = cmd.replace("--use_fast_math", "")
|
cmd = cmd.replace("--use_fast_math", "")
|
||||||
cmd = cmd.replace("-Ofast", "-O2")
|
cmd = cmd.replace("-Ofast", "-O2")
|
||||||
cmds.append(cmd)
|
cmds.append(fix_cl_flags(cmd))
|
||||||
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
|
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
|
||||||
obj_files += ex_obj_files
|
obj_files += ex_obj_files
|
||||||
|
if os.name == 'nt':
|
||||||
|
dumpdef_path = os.path.join(jittor_path, "utils", "dumpdef.py")
|
||||||
|
cmd = f"\"{sys.executable}\" \"{dumpdef_path}\" {' '.join(obj_files)} -Fo: \"{output}.def\""
|
||||||
|
do_compile(fix_cl_flags(cmd))
|
||||||
cmd = f"\"{compiler}\" {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"
|
cmd = f"\"{compiler}\" {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"
|
||||||
return do_compile(cmd)
|
return do_compile(fix_cl_flags(cmd))
|
||||||
|
|
||||||
def gen_jit_tests():
|
def gen_jit_tests():
|
||||||
all_src = glob.glob(jittor_path+"/src/**/*.cc", recursive=True)
|
all_src = glob.glob(jittor_path+"/src/**/*.cc", recursive=True)
|
||||||
|
@ -660,7 +668,7 @@ def compile_custom_ops(
|
||||||
gen_name = gen_name[:80] + "___hash" + hashlib.md5(gen_name.encode()).hexdigest()
|
gen_name = gen_name[:80] + "___hash" + hashlib.md5(gen_name.encode()).hexdigest()
|
||||||
|
|
||||||
includes = sorted(list(set(includes)))
|
includes = sorted(list(set(includes)))
|
||||||
includes = "".join(map(lambda x: f" -I'{x}' ", includes))
|
includes = "".join(map(lambda x: f" -I\"{x}\" ", includes))
|
||||||
LOG.vvvv(f"Include flags:{includes}")
|
LOG.vvvv(f"Include flags:{includes}")
|
||||||
|
|
||||||
op_extra_flags = includes + extra_flags
|
op_extra_flags = includes + extra_flags
|
||||||
|
@ -916,7 +924,7 @@ if not nvcc_path:
|
||||||
nvcc_path = try_find_exe(nvcc_path)
|
nvcc_path = try_find_exe(nvcc_path)
|
||||||
if nvcc_path is None:
|
if nvcc_path is None:
|
||||||
nvcc_path = ""
|
nvcc_path = ""
|
||||||
gdb_path = try_find_exe('gdb')
|
gdb_path = env_or_try_find('gdb_path', 'gdb')
|
||||||
addr2line_path = try_find_exe('addr2line')
|
addr2line_path = try_find_exe('addr2line')
|
||||||
has_pybt = check_pybt(gdb_path, python_path)
|
has_pybt = check_pybt(gdb_path, python_path)
|
||||||
|
|
||||||
|
@ -952,26 +960,80 @@ if platform.system() == 'Darwin':
|
||||||
core_link_flags = ""
|
core_link_flags = ""
|
||||||
opt_flags = ""
|
opt_flags = ""
|
||||||
|
|
||||||
|
py_include = jit_utils.get_py3_include_path()
|
||||||
|
LOG.i(f"py_include: {py_include}")
|
||||||
|
extension_suffix = jit_utils.get_py3_extension_suffix()
|
||||||
|
lib_suffix = extension_suffix.replace(".pyd", ".lib")
|
||||||
|
LOG.i(f"extension_suffix: {extension_suffix}")
|
||||||
|
|
||||||
|
|
||||||
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags
|
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags
|
||||||
if platform.system() == 'Darwin':
|
if platform.system() == 'Darwin':
|
||||||
# TODO: if not using apple clang, cannot add -Xpreprocessor
|
# TODO: if not using apple clang, cannot add -Xpreprocessor
|
||||||
kernel_opt_flags = kernel_opt_flags + " -Xpreprocessor -fopenmp "
|
kernel_opt_flags = kernel_opt_flags + " -Xpreprocessor -fopenmp "
|
||||||
else:
|
elif cc_type != 'cl':
|
||||||
kernel_opt_flags = kernel_opt_flags + " -fopenmp "
|
kernel_opt_flags = kernel_opt_flags + " -fopenmp "
|
||||||
|
fix_cl_flags = lambda x:x
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
link_flags = link_flags.replace('-ldl', '')
|
if cc_type == 'g++':
|
||||||
py3_link_path = '-L"' + os.path.join(
|
link_flags = link_flags.replace('-ldl', '')
|
||||||
os.path.dirname(sys.executable),
|
py3_link_path = '-L"' + os.path.join(
|
||||||
"libs"
|
os.path.dirname(sys.executable),
|
||||||
) + f'" -lpython3{sys.version_info.minor} '
|
"libs"
|
||||||
core_link_flags = py3_link_path
|
) + f'" -lpython3{sys.version_info.minor} '
|
||||||
link_flags += core_link_flags
|
core_link_flags = py3_link_path
|
||||||
# link_flags += " -Wl,--unresolved-symbols=ignore-all "
|
link_flags += core_link_flags
|
||||||
# cc_flags += " -Xlinker --allow-shlib-undefined "
|
# link_flags += " -Wl,--unresolved-symbols=ignore-all "
|
||||||
cc_flags = cc_flags.replace('-std=c++14', '-std=c++17')
|
# cc_flags += " -Xlinker --allow-shlib-undefined "
|
||||||
link_flags += " -fopenmp "
|
cc_flags = cc_flags.replace('-std=c++14', '-std=c++17')
|
||||||
kernel_opt_flags += f" {cache_path}\\libjit_utils_core.a "
|
link_flags += " -fopenmp "
|
||||||
kernel_opt_flags += f" {cache_path}\\libjittor_core.a "
|
kernel_opt_flags += f" {cache_path}\\jit_utils_core{lib_suffix} "
|
||||||
|
kernel_opt_flags += f" {cache_path}\\jittor_core{lib_suffix} "
|
||||||
|
elif cc_type == 'cl':
|
||||||
|
py3_link_path = os.path.join(
|
||||||
|
os.path.dirname(sys.executable),
|
||||||
|
"libs",
|
||||||
|
f'python3{sys.version_info.minor}.lib'
|
||||||
|
)
|
||||||
|
# core_link_flags = py3_link_path
|
||||||
|
link_flags += core_link_flags
|
||||||
|
# link_flags += " -Wl,--unresolved-symbols=ignore-all "
|
||||||
|
# cc_flags += " -Xlinker --allow-shlib-undefined "
|
||||||
|
kernel_opt_flags += f" {cache_path}\\jit_utils_core{lib_suffix} "
|
||||||
|
kernel_opt_flags += f" {cache_path}\\jittor_core{lib_suffix} "
|
||||||
|
# cc_flags = " -std:c++17 -O2 -fp:fast -EHsc "
|
||||||
|
cc_flags = " -std:c++17 -O2 -fp:fast -EHsc "
|
||||||
|
# cc_flags += py3_link_path + " "
|
||||||
|
import jittor_utils
|
||||||
|
if jittor_utils.msvc_path:
|
||||||
|
mp = jittor_utils.msvc_path
|
||||||
|
cc_flags += f' -nologo -I"{mp}\\cl_x64\\include" -I"{mp}\\win10_kits\\include\\ucrt" -I"{mp}\\win10_kits\\include\\shared" -I"{mp}\\win10_kits\\include\\um" -DNOMINMAX '
|
||||||
|
win_link_flags = f' -link -LIBPATH:"{mp}\\cl_x64\\lib" -LIBPATH:"{mp}\\win10_kits\\lib\\um\\x64" -LIBPATH:"{mp}\\win10_kits\\lib\\ucrt\\x64" '
|
||||||
|
link_flags = ' -LD '
|
||||||
|
kernel_opt_flags += win_link_flags# + " -EXPORT:\"?jit_run@FusedOp@jittor@@QEAAXXZ\""
|
||||||
|
def fix_cl_flags(cmd):
|
||||||
|
cmd = cmd.replace(".o ", ".obj ")
|
||||||
|
cmd = cmd.replace(".o\" ", ".obj\" ")
|
||||||
|
if cmd.endswith(".o"): cmd += "bj"
|
||||||
|
from shlex import split
|
||||||
|
if " -LD " in cmd:
|
||||||
|
cmd = cmd.replace(" -o ", " -Fe: ")
|
||||||
|
output = split(cmd.split("-Fe:")[1].strip(), posix=False)[0]
|
||||||
|
base_output = os.path.basename(output).split('.')[0]
|
||||||
|
cmd += win_link_flags
|
||||||
|
cmd += f" -DEF:\"{output}.def\" -IGNORE:4102 -IGNORE:4197 -IGNORE:4217 {py3_link_path}"
|
||||||
|
if base_output == "jit_utils_core":
|
||||||
|
pass
|
||||||
|
elif base_output == "jittor_core":
|
||||||
|
cmd += " " + os.path.join(cache_path, f"jit_utils_core{lib_suffix}")
|
||||||
|
else:
|
||||||
|
cmd += " " + os.path.join(cache_path, f"jit_utils_core{lib_suffix} ")
|
||||||
|
cmd += " " + os.path.join(cache_path, f"jittor_core{lib_suffix} ")
|
||||||
|
|
||||||
|
elif " -c -o " in cmd:
|
||||||
|
cmd = cmd.replace(" -c -o ", " -c -Fo: ")
|
||||||
|
cmd = cmd.replace("-include", "-FI")
|
||||||
|
return cmd
|
||||||
|
|
||||||
if ' -O' not in cc_flags:
|
if ' -O' not in cc_flags:
|
||||||
opt_flags += " -O2 "
|
opt_flags += " -O2 "
|
||||||
|
@ -985,11 +1047,6 @@ if os.environ.get("enable_lto") == "1":
|
||||||
else:
|
else:
|
||||||
lto_flags = " -flto "
|
lto_flags = " -flto "
|
||||||
|
|
||||||
py_include = jit_utils.get_py3_include_path()
|
|
||||||
LOG.i(f"py_include: {py_include}")
|
|
||||||
extension_suffix = jit_utils.get_py3_extension_suffix()
|
|
||||||
LOG.i(f"extension_suffix: {extension_suffix}")
|
|
||||||
|
|
||||||
make_cache_dir(cache_path)
|
make_cache_dir(cache_path)
|
||||||
make_cache_dir(os.path.join(cache_path, "jit"))
|
make_cache_dir(os.path.join(cache_path, "jit"))
|
||||||
make_cache_dir(os.path.join(cache_path, "obj_files"))
|
make_cache_dir(os.path.join(cache_path, "obj_files"))
|
||||||
|
@ -1107,7 +1164,8 @@ if use_data_gz:
|
||||||
dflags = (cc_flags+opt_flags)\
|
dflags = (cc_flags+opt_flags)\
|
||||||
.replace("-Wall", "") \
|
.replace("-Wall", "") \
|
||||||
.replace("-Werror", "")
|
.replace("-Werror", "")
|
||||||
run_cmd(f"{cc_path} {dflags} \"-D_P(...)=\" {data_s_path} -c -o {data_o_path}")
|
vdp = os.path.join(jittor_path, "src", "utils", "vdp")
|
||||||
|
run_cmd(fix_cl_flags(f"{cc_path} {dflags} -include {vdp} {data_s_path} -c -o {data_o_path}"))
|
||||||
os.remove(data_s_path)
|
os.remove(data_s_path)
|
||||||
with open(data_gz_md5_path, 'w') as f:
|
with open(data_gz_md5_path, 'w') as f:
|
||||||
f.write(md5)
|
f.write(md5)
|
||||||
|
|
|
@ -28,6 +28,43 @@ mpi = jt.mpi
|
||||||
img_open_hook = HookTimer(Image, "open")
|
img_open_hook = HookTimer(Image, "open")
|
||||||
CHECK_MEMORY = int(os.environ.get("CHECK_MEMORY", "0"))
|
CHECK_MEMORY = int(os.environ.get("CHECK_MEMORY", "0"))
|
||||||
|
|
||||||
|
if os.name == "nt":
|
||||||
|
from multiprocessing import shared_memory
|
||||||
|
class RingBuffer:
|
||||||
|
def __init__(self, size, shm=None):
|
||||||
|
for i in range(100):
|
||||||
|
if (1<<i) >= size: break
|
||||||
|
size = 1<<i
|
||||||
|
init = False
|
||||||
|
if shm is None:
|
||||||
|
init = True
|
||||||
|
shm = shared_memory.SharedMemory(create=True, size=size+1024)
|
||||||
|
rb = jt.core.RingBuffer(size, id(shm.buf), init)
|
||||||
|
self.size = size
|
||||||
|
self.shm = shm
|
||||||
|
self.rb = rb
|
||||||
|
|
||||||
|
def __reduce__(self):
|
||||||
|
return (RingBuffer, (self.size, self.shm))
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
del self.rb
|
||||||
|
del self.shm
|
||||||
|
|
||||||
|
def push(self, obj): self.send(obj)
|
||||||
|
def pop(self): return self.recv()
|
||||||
|
def send(self, obj): self.rb.push(obj)
|
||||||
|
def recv(self): return self.rb.pop()
|
||||||
|
def clear(self): return self.rb.clear()
|
||||||
|
def stop(self): return self.rb.stop()
|
||||||
|
def is_stop(self): return self.rb.is_stop()
|
||||||
|
def total_pop(self): return self.rb.total_pop()
|
||||||
|
def total_push(self): return self.rb.total_push()
|
||||||
|
def __repr__(self): return repr(self.rb)
|
||||||
|
def keep_numpy_array(self, keep): self.rb.keep_numpy_array(keep)
|
||||||
|
|
||||||
|
jt.RingBuffer = RingBuffer
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
def __init__(self, target, args, buffer_size, keep_numpy_array=False):
|
def __init__(self, target, args, buffer_size, keep_numpy_array=False):
|
||||||
self.buffer = jt.RingBuffer(buffer_size)
|
self.buffer = jt.RingBuffer(buffer_size)
|
||||||
|
|
|
@ -18,6 +18,6 @@
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
extern cublasHandle_t cublas_handle;
|
EXTERN_LIB cublasHandle_t cublas_handle;
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
||||||
|
|
|
@ -15,9 +15,9 @@
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
extern cudnnHandle_t cudnn_handle;
|
EXTERN_LIB cudnnHandle_t cudnn_handle;
|
||||||
extern int max_cache_size;
|
EXTERN_LIB int max_cache_size;
|
||||||
extern float max_workspace_ratio;
|
EXTERN_LIB float max_workspace_ratio;
|
||||||
|
|
||||||
// @pyjt(set_algorithm_cache_size)
|
// @pyjt(set_algorithm_cache_size)
|
||||||
void set_algorithm_cache_size(int size);
|
void set_algorithm_cache_size(int size);
|
||||||
|
|
|
@ -87,7 +87,7 @@ VarPtr CudnnConv3dBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
|
|
||||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||||
|
|
||||||
extern unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||||
|
|
||||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||||
|
@ -194,6 +194,7 @@ void CudnnConv3dBackwardWOp::jit_run() {
|
||||||
cudnnConvolutionBwdFilterAlgo_t algo;
|
cudnnConvolutionBwdFilterAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
JK& jk = get_jk();
|
||||||
jk.clear();
|
jk.clear();
|
||||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
||||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
||||||
|
|
|
@ -77,7 +77,7 @@ VarPtr CudnnConv3dBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
|
|
||||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||||
|
|
||||||
extern unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||||
|
|
||||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||||
|
@ -185,6 +185,7 @@ void CudnnConv3dBackwardXOp::jit_run() {
|
||||||
cudnnConvolutionBwdDataAlgo_t algo;
|
cudnnConvolutionBwdDataAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
JK& jk = get_jk();
|
||||||
jk.clear();
|
jk.clear();
|
||||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
||||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
||||||
|
|
|
@ -80,7 +80,7 @@ VarPtr CudnnConv3dOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
|
|
||||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||||
|
|
||||||
extern unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
EXTERN_LIB unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||||
|
|
||||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||||
|
@ -188,6 +188,7 @@ void CudnnConv3dOp::jit_run() {
|
||||||
cudnnConvolutionFwdAlgo_t algo;
|
cudnnConvolutionFwdAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
JK& jk = get_jk();
|
||||||
jk.clear();
|
jk.clear();
|
||||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
||||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
||||||
|
|
|
@ -79,7 +79,7 @@ unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||||
|
|
||||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||||
|
|
||||||
extern unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||||
|
|
||||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||||
|
@ -184,6 +184,7 @@ void CudnnConvBackwardWOp::jit_run() {
|
||||||
cudnnConvolutionBwdFilterAlgo_t algo;
|
cudnnConvolutionBwdFilterAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
JK& jk = get_jk();
|
||||||
jk.clear();
|
jk.clear();
|
||||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
|
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
|
||||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
|
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
|
||||||
|
|
|
@ -79,7 +79,7 @@ unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||||
|
|
||||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||||
|
|
||||||
extern unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||||
|
|
||||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||||
|
@ -185,6 +185,7 @@ void CudnnConvBackwardXOp::jit_run() {
|
||||||
cudnnConvolutionBwdDataAlgo_t algo;
|
cudnnConvolutionBwdDataAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
JK& jk = get_jk();
|
||||||
jk.clear();
|
jk.clear();
|
||||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
|
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
|
||||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
|
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
|
||||||
|
|
|
@ -81,7 +81,7 @@ unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||||
|
|
||||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||||
|
|
||||||
extern unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
EXTERN_LIB unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||||
|
|
||||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||||
|
@ -187,6 +187,7 @@ void CudnnConvOp::jit_run() {
|
||||||
cudnnConvolutionFwdAlgo_t algo;
|
cudnnConvolutionFwdAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
JK& jk = get_jk();
|
||||||
jk.clear();
|
jk.clear();
|
||||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
|
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
|
||||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
|
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
|
||||||
|
|
|
@ -17,6 +17,6 @@
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
extern curandGenerator_t gen;
|
EXTERN_LIB curandGenerator_t gen;
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
||||||
|
|
|
@ -66,7 +66,7 @@ unordered_map<string, unsigned int> cutt_plan_cache;
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
||||||
extern unordered_map<string, unsigned int> cutt_plan_cache;
|
EXTERN_LIB unordered_map<string, unsigned int> cutt_plan_cache;
|
||||||
|
|
||||||
void CuttTransposeOp::jit_run() {
|
void CuttTransposeOp::jit_run() {
|
||||||
auto* __restrict__ xp = x->mem_ptr;
|
auto* __restrict__ xp = x->mem_ptr;
|
||||||
|
@ -93,6 +93,7 @@ void CuttTransposeOp::jit_run() {
|
||||||
checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDefault, 0));
|
checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDefault, 0));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
JK& jk = get_jk();
|
||||||
jk.clear();
|
jk.clear();
|
||||||
jk << dim << ',';
|
jk << dim << ',';
|
||||||
for (int i=0; i<dim; i++) jk << x_shape[i] << ',';
|
for (int i=0; i<dim; i++) jk << x_shape[i] << ',';
|
||||||
|
|
|
@ -102,7 +102,7 @@ const char *_cudaGetErrorEnum(NppStatus error);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
extern bool peek_logged;
|
EXTERN_LIB bool peek_logged;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
extern ncclComm_t comm;
|
EXTERN_LIB ncclComm_t comm;
|
||||||
extern ncclUniqueId id;
|
EXTERN_LIB ncclUniqueId id;
|
||||||
extern int nccl_device_id;
|
EXTERN_LIB int nccl_device_id;
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
#pragma once
|
#pragma once
|
||||||
#define OMPI_SKIP_MPICXX
|
#define OMPI_SKIP_MPICXX
|
||||||
|
#include <common.h>
|
||||||
#include <mpi.h>
|
#include <mpi.h>
|
||||||
|
|
||||||
extern void throw_mpi_error(int result,
|
extern void throw_mpi_error(int result,
|
||||||
|
@ -25,13 +26,13 @@ static inline void mpi_check(int result,
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
extern int mpi_world_size;
|
EXTERN_LIB int mpi_world_size;
|
||||||
extern int mpi_world_rank;
|
EXTERN_LIB int mpi_world_rank;
|
||||||
extern int mpi_local_size;
|
EXTERN_LIB int mpi_local_size;
|
||||||
extern int mpi_local_rank;
|
EXTERN_LIB int mpi_local_rank;
|
||||||
extern bool inside_mpi;
|
EXTERN_LIB bool inside_mpi;
|
||||||
extern bool mpi_enabled;
|
EXTERN_LIB bool mpi_enabled;
|
||||||
extern bool use_device_mpi;
|
EXTERN_LIB bool use_device_mpi;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
Return number of MPI nodes.
|
Return number of MPI nodes.
|
||||||
|
|
|
@ -1,432 +1,432 @@
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
# Maintainers:
|
# Maintainers:
|
||||||
# Haoyang Peng <2247838039@qq.com>
|
# Haoyang Peng <2247838039@qq.com>
|
||||||
# Guowei Yang <471184555@qq.com>
|
# Guowei Yang <471184555@qq.com>
|
||||||
# Dun Liang <randonlang@gmail.com>.
|
# Dun Liang <randonlang@gmail.com>.
|
||||||
#
|
#
|
||||||
# This file is subject to the terms and conditions defined in
|
# This file is subject to the terms and conditions defined in
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
import jittor as jt
|
import jittor as jt
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
#TODO:full_matrices=1
|
#TODO:full_matrices=1
|
||||||
def svd(x):
|
def svd(x):
|
||||||
r'''
|
r'''
|
||||||
calculate the Singular Value Decomposition of x.It follows the below fomula:
|
calculate the Singular Value Decomposition of x.It follows the below fomula:
|
||||||
x = usv*
|
x = usv*
|
||||||
only support full matrices == False ver now, which means:
|
only support full matrices == False ver now, which means:
|
||||||
x's shape (...,M,K)
|
x's shape (...,M,K)
|
||||||
u's shape (...,M,K)
|
u's shape (...,M,K)
|
||||||
s's shape (...,K)
|
s's shape (...,K)
|
||||||
v's shape (...,K,N)
|
v's shape (...,K,N)
|
||||||
where K is min(M,N).
|
where K is min(M,N).
|
||||||
:param x:
|
:param x:
|
||||||
:return:u,s,v.
|
:return:u,s,v.
|
||||||
'''
|
'''
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
u, s, v = data["outputs"]
|
u, s, v = data["outputs"]
|
||||||
#TODO:remove copyto
|
#TODO:remove copyto
|
||||||
tu, ts, tv = np.linalg.svd(a, full_matrices=0)
|
tu, ts, tv = np.linalg.svd(a, full_matrices=0)
|
||||||
np.copyto(u, tu)
|
np.copyto(u, tu)
|
||||||
np.copyto(s, ts)
|
np.copyto(s, ts)
|
||||||
np.copyto(v, tv)
|
np.copyto(v, tv)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
def T(x):
|
def T(x):
|
||||||
return np.swapaxes(x, -1, -2)
|
return np.swapaxes(x, -1, -2)
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
inp = data["inputs"][0]
|
inp = data["inputs"][0]
|
||||||
out_index = data["out_index"]
|
out_index = data["out_index"]
|
||||||
u, s, v = data["f_outputs"]
|
u, s, v = data["f_outputs"]
|
||||||
v = T(v)
|
v = T(v)
|
||||||
m, n = inp.shape[-2:]
|
m, n = inp.shape[-2:]
|
||||||
k = np.min((m, n))
|
k = np.min((m, n))
|
||||||
i = np.reshape(np.eye(k), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (k, k))))
|
i = np.reshape(np.eye(k), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (k, k))))
|
||||||
if out_index == 0:
|
if out_index == 0:
|
||||||
f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i)
|
f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i)
|
||||||
gu = dout
|
gu = dout
|
||||||
utgu = _dot(T(u), gu)
|
utgu = _dot(T(u), gu)
|
||||||
t = (f * (utgu - T(utgu))) * s[..., np.newaxis, :]
|
t = (f * (utgu - T(utgu))) * s[..., np.newaxis, :]
|
||||||
t = _dot(_dot(u, t), T(v))
|
t = _dot(_dot(u, t), T(v))
|
||||||
if m > n:
|
if m > n:
|
||||||
i_minus_uut = (np.reshape(np.eye(m), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (m, m)))) -
|
i_minus_uut = (np.reshape(np.eye(m), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (m, m)))) -
|
||||||
_dot(u, np.conj(T(u))))
|
_dot(u, np.conj(T(u))))
|
||||||
t = t + T(_dot(_dot(v / s[..., np.newaxis, :], T(gu)), i_minus_uut))
|
t = t + T(_dot(_dot(v / s[..., np.newaxis, :], T(gu)), i_minus_uut))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
elif out_index == 1:
|
elif out_index == 1:
|
||||||
gs = dout
|
gs = dout
|
||||||
t = i * gs[..., :, np.newaxis]
|
t = i * gs[..., :, np.newaxis]
|
||||||
t = _dot(_dot(u, t), T(v))
|
t = _dot(_dot(u, t), T(v))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
elif out_index == 2:
|
elif out_index == 2:
|
||||||
f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i)
|
f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i)
|
||||||
gv = dout
|
gv = dout
|
||||||
vtgv = _dot(T(v), gv)
|
vtgv = _dot(T(v), gv)
|
||||||
t = s[..., :, np.newaxis] * (f * (vtgv - T(vtgv)))
|
t = s[..., :, np.newaxis] * (f * (vtgv - T(vtgv)))
|
||||||
t = _dot(_dot(u, t), T(v))
|
t = _dot(_dot(u, t), T(v))
|
||||||
if m < n:
|
if m < n:
|
||||||
i_minus_vvt = (np.reshape(np.eye(n), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (n, n)))) -
|
i_minus_vvt = (np.reshape(np.eye(n), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (n, n)))) -
|
||||||
_dot(v, np.conj(T(v))))
|
_dot(v, np.conj(T(v))))
|
||||||
t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
|
t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
|
|
||||||
m, n = x.shape[-2:]
|
m, n = x.shape[-2:]
|
||||||
k = min(m, n)
|
k = min(m, n)
|
||||||
s1 = list(x.shape)
|
s1 = list(x.shape)
|
||||||
s1[-1] = k
|
s1[-1] = k
|
||||||
s2 = list(x.shape)
|
s2 = list(x.shape)
|
||||||
s2[-2] = k
|
s2[-2] = k
|
||||||
s3 = list(x.shape)[:-2]
|
s3 = list(x.shape)[:-2]
|
||||||
s3.append(k)
|
s3.append(k)
|
||||||
u, s, v = jt.numpy_code(
|
u, s, v = jt.numpy_code(
|
||||||
[s1, s3, s2],
|
[s1, s3, s2],
|
||||||
[x.dtype, x.dtype, x.dtype],
|
[x.dtype, x.dtype, x.dtype],
|
||||||
[x],
|
[x],
|
||||||
forward_code,
|
forward_code,
|
||||||
[backward_code],
|
[backward_code],
|
||||||
)
|
)
|
||||||
return u, s, v
|
return u, s, v
|
||||||
|
|
||||||
|
|
||||||
def eigh(x):
|
def eigh(x):
|
||||||
r"""
|
r"""
|
||||||
calculate the eigenvalues and eigenvectors of x.
|
calculate the eigenvalues and eigenvectors of x.
|
||||||
:param x (...,M,M):
|
:param x (...,M,M):
|
||||||
:return:w, v.
|
:return:w, v.
|
||||||
w (...,M) : the eigenvalues.
|
w (...,M) : the eigenvalues.
|
||||||
v (...,M,M) : normalized eigenvectors.
|
v (...,M,M) : normalized eigenvectors.
|
||||||
"""
|
"""
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
w, v = data["outputs"]
|
w, v = data["outputs"]
|
||||||
tw, tv = np.linalg.eigh(a, UPLO='L')
|
tw, tv = np.linalg.eigh(a, UPLO='L')
|
||||||
np.copyto(w, tw)
|
np.copyto(w, tw)
|
||||||
np.copyto(v, tv)
|
np.copyto(v, tv)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
def T(x):
|
def T(x):
|
||||||
return np.swapaxes(x, -1, -2)
|
return np.swapaxes(x, -1, -2)
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
inp = data["inputs"][0]
|
inp = data["inputs"][0]
|
||||||
out_index = data["out_index"]
|
out_index = data["out_index"]
|
||||||
w, v = data["f_outputs"]
|
w, v = data["f_outputs"]
|
||||||
k = int(inp.shape[-1])
|
k = int(inp.shape[-1])
|
||||||
w_repeated = np.repeat(w[..., np.newaxis], k, axis=-1)
|
w_repeated = np.repeat(w[..., np.newaxis], k, axis=-1)
|
||||||
if out_index == 0:
|
if out_index == 0:
|
||||||
t = _dot(v * dout[..., np.newaxis, :], T(v))
|
t = _dot(v * dout[..., np.newaxis, :], T(v))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
elif out_index == 1:
|
elif out_index == 1:
|
||||||
if np.any(dout):
|
if np.any(dout):
|
||||||
off_diag = np.ones((k, k)) - np.eye(k)
|
off_diag = np.ones((k, k)) - np.eye(k)
|
||||||
F = off_diag / (T(w_repeated) - w_repeated + np.eye(k))
|
F = off_diag / (T(w_repeated) - w_repeated + np.eye(k))
|
||||||
t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
|
t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
|
|
||||||
sw = x.shape[:-2] + x.shape[-1:]
|
sw = x.shape[:-2] + x.shape[-1:]
|
||||||
sv = x.shape
|
sv = x.shape
|
||||||
w, v = jt.numpy_code(
|
w, v = jt.numpy_code(
|
||||||
[sw, sv],
|
[sw, sv],
|
||||||
[x.dtype, x.dtype],
|
[x.dtype, x.dtype],
|
||||||
[x],
|
[x],
|
||||||
forward_code,
|
forward_code,
|
||||||
[backward_code],
|
[backward_code],
|
||||||
)
|
)
|
||||||
return w, v
|
return w, v
|
||||||
|
|
||||||
|
|
||||||
def inv(x):
|
def inv(x):
|
||||||
r"""
|
r"""
|
||||||
calculate the inverse of x.
|
calculate the inverse of x.
|
||||||
:param x (...,M,M):
|
:param x (...,M,M):
|
||||||
:return:x^-1 (...,M,M).
|
:return:x^-1 (...,M,M).
|
||||||
"""
|
"""
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
m_a = data["outputs"][0]
|
m_a = data["outputs"][0]
|
||||||
t_a = np.linalg.inv(a)
|
t_a = np.linalg.inv(a)
|
||||||
np.copyto(m_a, t_a)
|
np.copyto(m_a, t_a)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
def T(x):
|
def T(x):
|
||||||
return np.swapaxes(x, -1, -2)
|
return np.swapaxes(x, -1, -2)
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
lmx = data["f_outputs"]
|
lmx = data["f_outputs"]
|
||||||
mx = lmx[0]
|
mx = lmx[0]
|
||||||
t = -_dot(_dot(T(mx), dout), T(mx))
|
t = -_dot(_dot(T(mx), dout), T(mx))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
|
|
||||||
lmx = jt.numpy_code(
|
lmx = jt.numpy_code(
|
||||||
[x.shape],
|
[x.shape],
|
||||||
[x.dtype],
|
[x.dtype],
|
||||||
[x],
|
[x],
|
||||||
forward_code,
|
forward_code,
|
||||||
[backward_code],
|
[backward_code],
|
||||||
)
|
)
|
||||||
mx = lmx[0]
|
mx = lmx[0]
|
||||||
return mx
|
return mx
|
||||||
|
|
||||||
|
|
||||||
def pinv(x):
|
def pinv(x):
|
||||||
r"""
|
r"""
|
||||||
calculate the pseudo-inverse of a x.
|
calculate the pseudo-inverse of a x.
|
||||||
:param x (...,M,N)
|
:param x (...,M,N)
|
||||||
:return: x's pinv (...N,M)
|
:return: x's pinv (...N,M)
|
||||||
"""
|
"""
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
m_a = data["outputs"][0]
|
m_a = data["outputs"][0]
|
||||||
t_a = np.linalg.pinv(a)
|
t_a = np.linalg.pinv(a)
|
||||||
np.copyto(m_a, t_a)
|
np.copyto(m_a, t_a)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
def T(x):
|
def T(x):
|
||||||
return np.swapaxes(x, -1, -2)
|
return np.swapaxes(x, -1, -2)
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
inp = data["inputs"][0]
|
inp = data["inputs"][0]
|
||||||
lmx = data["f_outputs"]
|
lmx = data["f_outputs"]
|
||||||
mx = lmx[0]
|
mx = lmx[0]
|
||||||
t = T(
|
t = T(
|
||||||
-_dot(_dot(mx, T(dout)), mx)
|
-_dot(_dot(mx, T(dout)), mx)
|
||||||
+ _dot(_dot(_dot(mx, T(mx)), dout), np.eye(inp.shape[-2]) - _dot(inp, mx))
|
+ _dot(_dot(_dot(mx, T(mx)), dout), np.eye(inp.shape[-2]) - _dot(inp, mx))
|
||||||
+ _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx)
|
+ _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx)
|
||||||
)
|
)
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]]
|
sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]]
|
||||||
lmx = jt.numpy_code(
|
lmx = jt.numpy_code(
|
||||||
[sw],
|
[sw],
|
||||||
[x.dtype],
|
[x.dtype],
|
||||||
[x],
|
[x],
|
||||||
forward_code,
|
forward_code,
|
||||||
[backward_code],
|
[backward_code],
|
||||||
)
|
)
|
||||||
mx = lmx[0]
|
mx = lmx[0]
|
||||||
return mx
|
return mx
|
||||||
|
|
||||||
|
|
||||||
def det(x):
|
def det(x):
|
||||||
r"""
|
r"""
|
||||||
calculate the determinant of x.
|
calculate the determinant of x.
|
||||||
:param x (...,M,M):
|
:param x (...,M,M):
|
||||||
:return:|x| (...,1)
|
:return:|x| (...,1)
|
||||||
"""
|
"""
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
L = data["outputs"][0]
|
L = data["outputs"][0]
|
||||||
tL = np.linalg.det(a)
|
tL = np.linalg.det(a)
|
||||||
np.copyto(L, tL)
|
np.copyto(L, tL)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
def T(x):
|
def T(x):
|
||||||
return np.swapaxes(x, -1, -2)
|
return np.swapaxes(x, -1, -2)
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
f_out = data["f_outputs"][0]
|
f_out = data["f_outputs"][0]
|
||||||
inp = data["inputs"][0]
|
inp = data["inputs"][0]
|
||||||
n_d = np.reshape(dout, np.shape(dout) + (1, 1))
|
n_d = np.reshape(dout, np.shape(dout) + (1, 1))
|
||||||
n_o = np.reshape(f_out, np.shape(f_out) + (1, 1))
|
n_o = np.reshape(f_out, np.shape(f_out) + (1, 1))
|
||||||
s = n_d * n_o * T(np.linalg.inv(inp))
|
s = n_d * n_o * T(np.linalg.inv(inp))
|
||||||
np.copyto(out, s)
|
np.copyto(out, s)
|
||||||
|
|
||||||
s = x.shape
|
s = x.shape
|
||||||
x_s = s[:-2]
|
x_s = s[:-2]
|
||||||
if len(s) == 2:
|
if len(s) == 2:
|
||||||
x_s.append(1)
|
x_s.append(1)
|
||||||
l_det = jt.numpy_code(
|
l_det = jt.numpy_code(
|
||||||
[x_s],
|
[x_s],
|
||||||
[x.dtype],
|
[x.dtype],
|
||||||
[x],
|
[x],
|
||||||
forward_code,
|
forward_code,
|
||||||
[backward_code],
|
[backward_code],
|
||||||
)
|
)
|
||||||
det = l_det[0]
|
det = l_det[0]
|
||||||
return det
|
return det
|
||||||
|
|
||||||
|
|
||||||
def slogdet(x):
|
def slogdet(x):
|
||||||
r"""
|
r"""
|
||||||
calculate the sign and log of the determinant of x.
|
calculate the sign and log of the determinant of x.
|
||||||
:param x (...,M,M):
|
:param x (...,M,M):
|
||||||
:return sign, x's logdet.
|
:return sign, x's logdet.
|
||||||
sign array decides the sign of determinant and their values can be -1,0,1.Only Real number now.0 means det is 0 and logdet is -inf.
|
sign array decides the sign of determinant and their values can be -1,0,1.Only Real number now.0 means det is 0 and logdet is -inf.
|
||||||
logdet in shape (...,1).
|
logdet in shape (...,1).
|
||||||
"""
|
"""
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
sign, m_a = data["outputs"]
|
sign, m_a = data["outputs"]
|
||||||
sign_, t_a = np.linalg.slogdet(a)
|
sign_, t_a = np.linalg.slogdet(a)
|
||||||
np.copyto(m_a, t_a)
|
np.copyto(m_a, t_a)
|
||||||
np.copyto(sign, sign_)
|
np.copyto(sign, sign_)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
def T(x):
|
def T(x):
|
||||||
return np.swapaxes(x, -1, -2)
|
return np.swapaxes(x, -1, -2)
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
inp = data["inputs"][0]
|
inp = data["inputs"][0]
|
||||||
out_index = data["out_index"]
|
out_index = data["out_index"]
|
||||||
if out_index == 0:
|
if out_index == 0:
|
||||||
np.copyto(out, 0)
|
np.copyto(out, 0)
|
||||||
if out_index == 1:
|
if out_index == 1:
|
||||||
t = np.reshape(dout, np.shape(dout) + (1, 1))
|
t = np.reshape(dout, np.shape(dout) + (1, 1))
|
||||||
t = t * T(np.linalg.inv(inp))
|
t = t * T(np.linalg.inv(inp))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
|
|
||||||
s = x.shape
|
s = x.shape
|
||||||
det_s = s[:-2]
|
det_s = s[:-2]
|
||||||
if len(det_s) == 0:
|
if len(det_s) == 0:
|
||||||
det_s.append(1)
|
det_s.append(1)
|
||||||
sign, mx = jt.numpy_code(
|
sign, mx = jt.numpy_code(
|
||||||
[det_s, det_s],
|
[det_s, det_s],
|
||||||
[x.dtype, x.dtype],
|
[x.dtype, x.dtype],
|
||||||
[x],
|
[x],
|
||||||
forward_code,
|
forward_code,
|
||||||
[backward_code],
|
[backward_code],
|
||||||
)
|
)
|
||||||
return sign, mx
|
return sign, mx
|
||||||
|
|
||||||
|
|
||||||
def cholesky(x):
|
def cholesky(x):
|
||||||
r"""
|
r"""
|
||||||
do Cholesky decomposition of x in the form of below formula:
|
do Cholesky decomposition of x in the form of below formula:
|
||||||
x = LL^T
|
x = LL^T
|
||||||
x must be a Hermite and positive-definite matrix. L is a lower-triangular matrix.
|
x must be a Hermite and positive-definite matrix. L is a lower-triangular matrix.
|
||||||
:param x (...,M,M):
|
:param x (...,M,M):
|
||||||
:return: L (...,M,M).
|
:return: L (...,M,M).
|
||||||
"""
|
"""
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
L = data["outputs"][0]
|
L = data["outputs"][0]
|
||||||
tL = np.linalg.cholesky(a)
|
tL = np.linalg.cholesky(a)
|
||||||
np.copyto(L, tL)
|
np.copyto(L, tL)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
def T(x):
|
def T(x):
|
||||||
return np.swapaxes(x, -1, -2)
|
return np.swapaxes(x, -1, -2)
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
f_out = data["f_outputs"][0]
|
f_out = data["f_outputs"][0]
|
||||||
solve_trans = lambda a, b: np.linalg.solve(T(a), b)
|
solve_trans = lambda a, b: np.linalg.solve(T(a), b)
|
||||||
phi = lambda X: np.tril(X) / (1. + np.eye(X.shape[-1]))
|
phi = lambda X: np.tril(X) / (1. + np.eye(X.shape[-1]))
|
||||||
|
|
||||||
def conjugate_solve(L, X):
|
def conjugate_solve(L, X):
|
||||||
return solve_trans(L, T(solve_trans(L, T(X))))
|
return solve_trans(L, T(solve_trans(L, T(X))))
|
||||||
|
|
||||||
s = conjugate_solve(f_out, phi(np.einsum('...ki,...kj->...ij', f_out, dout)))
|
s = conjugate_solve(f_out, phi(np.einsum('...ki,...kj->...ij', f_out, dout)))
|
||||||
s = (s + T(s)) / 2.
|
s = (s + T(s)) / 2.
|
||||||
np.copyto(out, s)
|
np.copyto(out, s)
|
||||||
|
|
||||||
lL = jt.numpy_code(
|
lL = jt.numpy_code(
|
||||||
[x.shape],
|
[x.shape],
|
||||||
[x.dtype],
|
[x.dtype],
|
||||||
[x],
|
[x],
|
||||||
forward_code,
|
forward_code,
|
||||||
[backward_code],
|
[backward_code],
|
||||||
)
|
)
|
||||||
L = lL[0]
|
L = lL[0]
|
||||||
return L
|
return L
|
||||||
|
|
||||||
|
|
||||||
def solve(a,b):
|
def solve(a,b):
|
||||||
r"""
|
r"""
|
||||||
Solve a linear matrix equation Ax = B.This is done by calculating x = A^-1B.So A must not be singular.
|
Solve a linear matrix equation Ax = B.This is done by calculating x = A^-1B.So A must not be singular.
|
||||||
:param a:(...,M,M)
|
:param a:(...,M,M)
|
||||||
:param b:(...,M)
|
:param b:(...,M)
|
||||||
:return:solution of Ax = b formula.x in the shape of (...M)
|
:return:solution of Ax = b formula.x in the shape of (...M)
|
||||||
"""
|
"""
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a, b = data["inputs"]
|
a, b = data["inputs"]
|
||||||
L = data["outputs"][0]
|
L = data["outputs"][0]
|
||||||
ans = np.linalg.solve(a, b)
|
ans = np.linalg.solve(a, b)
|
||||||
np.copyto(L, ans)
|
np.copyto(L, ans)
|
||||||
|
|
||||||
def backward_code1(np, data):
|
def backward_code1(np, data):
|
||||||
def T(x):
|
def T(x):
|
||||||
return np.swapaxes(x, -1, -2)
|
return np.swapaxes(x, -1, -2)
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
f_out = data["f_outputs"][0]
|
f_out = data["f_outputs"][0]
|
||||||
inp = data["inputs"][0]
|
inp = data["inputs"][0]
|
||||||
updim = lambda x: x if x.ndim == a.ndim else x[..., None]
|
updim = lambda x: x if x.ndim == a.ndim else x[..., None]
|
||||||
t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out)))
|
t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out)))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
|
|
||||||
def backward_code2(np, data):
|
def backward_code2(np, data):
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
np.copyto(out, 0)
|
np.copyto(out, 0)
|
||||||
|
|
||||||
l_ans = jt.numpy_code(
|
l_ans = jt.numpy_code(
|
||||||
[b.shape],
|
[b.shape],
|
||||||
[b.dtype],
|
[b.dtype],
|
||||||
[a, b],
|
[a, b],
|
||||||
forward_code,
|
forward_code,
|
||||||
[backward_code1, backward_code2],
|
[backward_code1, backward_code2],
|
||||||
)
|
)
|
||||||
ans = l_ans[0]
|
ans = l_ans[0]
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
def qr(x):
|
def qr(x):
|
||||||
r"""
|
r"""
|
||||||
do the qr factorization of x in the below formula:
|
do the qr factorization of x in the below formula:
|
||||||
x = QR where Q is orthogonal matrix and R is upper-triangle matrix.
|
x = QR where Q is orthogonal matrix and R is upper-triangle matrix.
|
||||||
:param x (...,M,M):
|
:param x (...,M,M):
|
||||||
:return:q,r as the result of qr factorization.They are both in the shape of (...,M,M).
|
:return:q,r as the result of qr factorization.They are both in the shape of (...,M,M).
|
||||||
"""
|
"""
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
q, r = data["outputs"]
|
q, r = data["outputs"]
|
||||||
Q, R = np.linalg.qr(a)
|
Q, R = np.linalg.qr(a)
|
||||||
np.copyto(q,Q)
|
np.copyto(q,Q)
|
||||||
np.copyto(r,R)
|
np.copyto(r,R)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
def T(x):
|
def T(x):
|
||||||
return np.swapaxes(x, -1, -2)
|
return np.swapaxes(x, -1, -2)
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
_harmard = partial(np.einsum, '...ij,...ij->...ij')
|
_harmard = partial(np.einsum, '...ij,...ij->...ij')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
q, r = data["f_outputs"]
|
q, r = data["f_outputs"]
|
||||||
out_index = data["out_index"]
|
out_index = data["out_index"]
|
||||||
#pl = np.tril(np.ones((inp.shape[-1],inp.shape[-1])))-diags
|
#pl = np.tril(np.ones((inp.shape[-1],inp.shape[-1])))-diags
|
||||||
if out_index == 0: # Q_TERM
|
if out_index == 0: # Q_TERM
|
||||||
q_t = _dot(T(q),dout)
|
q_t = _dot(T(q),dout)
|
||||||
rhs_solve = q_t - T(q_t)
|
rhs_solve = q_t - T(q_t)
|
||||||
rhs_solve = T(np.tril(rhs_solve,-1))
|
rhs_solve = T(np.tril(rhs_solve,-1))
|
||||||
qsolve = np.linalg.solve(r,rhs_solve)
|
qsolve = np.linalg.solve(r,rhs_solve)
|
||||||
qsolve = T(qsolve)
|
qsolve = T(qsolve)
|
||||||
tq = _dot(q,qsolve)
|
tq = _dot(q,qsolve)
|
||||||
np.copyto(out,tq)
|
np.copyto(out,tq)
|
||||||
else: #R_TERM
|
else: #R_TERM
|
||||||
r_t = _dot(r ,T(dout))
|
r_t = _dot(r ,T(dout))
|
||||||
rhs_solve = r_t - T(r_t)
|
rhs_solve = r_t - T(r_t)
|
||||||
rhs_solve = np.tril(rhs_solve,-1)
|
rhs_solve = np.tril(rhs_solve,-1)
|
||||||
rhs_solve = T(rhs_solve)
|
rhs_solve = T(rhs_solve)
|
||||||
r_solve = np.linalg.solve(r,rhs_solve)
|
r_solve = np.linalg.solve(r,rhs_solve)
|
||||||
tr = _dot(q,(T(r_solve) + dout))
|
tr = _dot(q,(T(r_solve) + dout))
|
||||||
np.copyto(out,tr)
|
np.copyto(out,tr)
|
||||||
|
|
||||||
q, r = jt.numpy_code(
|
q, r = jt.numpy_code(
|
||||||
[x.shape,x.shape],
|
[x.shape,x.shape],
|
||||||
[x.dtype,x.dtype],
|
[x.dtype,x.dtype],
|
||||||
[x],
|
[x],
|
||||||
forward_code,
|
forward_code,
|
||||||
[backward_code],
|
[backward_code],
|
||||||
)
|
)
|
||||||
return q, r
|
return q, r
|
||||||
|
|
|
@ -614,7 +614,7 @@ def compile_src(src, h, basename):
|
||||||
(void)n;
|
(void)n;
|
||||||
if (arg0 >= GET_RAW_PTR({dfs[0]["scope_name"]},self)->size()) {{
|
if (arg0 >= GET_RAW_PTR({dfs[0]["scope_name"]},self)->size()) {{
|
||||||
PyErr_SetString(PyExc_IndexError, "");
|
PyErr_SetString(PyExc_IndexError, "");
|
||||||
return 0;
|
return (PyObject*)nullptr;
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -675,7 +675,7 @@ def compile_src(src, h, basename):
|
||||||
error_log_code = generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename ,h, class_info)
|
error_log_code = generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename ,h, class_info)
|
||||||
func = f"""
|
func = f"""
|
||||||
{func_cast}[]{func_head} {{
|
{func_cast}[]{func_head} {{
|
||||||
try {{
|
try {{_JT_SEH_START3;
|
||||||
{func_fill};
|
{func_fill};
|
||||||
uint64 arg_filled=0;
|
uint64 arg_filled=0;
|
||||||
(void)arg_filled;
|
(void)arg_filled;
|
||||||
|
@ -689,7 +689,7 @@ def compile_src(src, h, basename):
|
||||||
for did in range(len(arr_func_return))
|
for did in range(len(arr_func_return))
|
||||||
])}
|
])}
|
||||||
LOGf << "Not a valid call.";
|
LOGf << "Not a valid call.";
|
||||||
}} catch (const std::exception& e) {{
|
_JT_SEH_END3; }} catch (const std::exception& e) {{
|
||||||
if (!PyErr_Occurred()) {{
|
if (!PyErr_Occurred()) {{
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss {error_log_code};
|
ss {error_log_code};
|
||||||
|
@ -775,6 +775,7 @@ def compile_src(src, h, basename):
|
||||||
if include_name.endswith("var_slices.h"):
|
if include_name.endswith("var_slices.h"):
|
||||||
src_code += '#include "var_holder.h"\n'
|
src_code += '#include "var_holder.h"\n'
|
||||||
src_code += f"""
|
src_code += f"""
|
||||||
|
#include "utils/seh.h"
|
||||||
#include "pyjt/py_converter.h"
|
#include "pyjt/py_converter.h"
|
||||||
#include "pyjt/py_arg_printer.h"
|
#include "pyjt/py_arg_printer.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
// file 'LICENSE.txt', which is part of this source code package.
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <stddef.h>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "utils/log.h"
|
#include "utils/log.h"
|
||||||
|
@ -26,4 +25,14 @@ void expect_error(std::function<void()> func);
|
||||||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
|
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
|
||||||
#pragma GCC diagnostic ignored "-Wdiv-by-zero"
|
#pragma GCC diagnostic ignored "-Wdiv-by-zero"
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#ifndef __restrict__
|
||||||
|
#define __restrict__ __restrict
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#define __builtin_popcount __popcnt
|
||||||
|
#endif
|
||||||
|
|
|
@ -14,7 +14,7 @@ namespace jittor {
|
||||||
|
|
||||||
// @pyjt(number_of_hold_vars)
|
// @pyjt(number_of_hold_vars)
|
||||||
inline static uint64 get_number_of_hold_vars() {
|
inline static uint64 get_number_of_hold_vars() {
|
||||||
return VarHolder::hold_vars.size();
|
return hold_vars.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
// @pyjt(number_of_lived_vars)
|
// @pyjt(number_of_lived_vars)
|
||||||
|
|
|
@ -34,7 +34,7 @@ void EventQueue::Worker::stop() {
|
||||||
LOGv << "stopped event queue worker.";
|
LOGv << "stopped event queue worker.";
|
||||||
}
|
}
|
||||||
|
|
||||||
extern vector<void(*)()> cleanup_callback;
|
EXTERN_LIB vector<void(*)()> cleanup_callback;
|
||||||
|
|
||||||
EventQueue::Worker::Worker() : thread(EventQueue::Worker::start) {
|
EventQueue::Worker::Worker() : thread(EventQueue::Worker::start) {
|
||||||
cleanup_callback.push_back(&EventQueue::Worker::stop);
|
cleanup_callback.push_back(&EventQueue::Worker::stop);
|
||||||
|
|
|
@ -88,7 +88,7 @@ struct EventQueue {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
extern EventQueue event_queue;
|
EXTERN_LIB EventQueue event_queue;
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
|
@ -28,16 +28,17 @@
|
||||||
#include "memory_profiler.h"
|
#include "memory_profiler.h"
|
||||||
#include "misc/nan_checker.h"
|
#include "misc/nan_checker.h"
|
||||||
#include "memory_profiler.h"
|
#include "memory_profiler.h"
|
||||||
|
#include "utils/seh.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
Executor exe;
|
Executor exe;
|
||||||
extern MemoryProfiler memory_profiler;
|
EXTERN_LIB MemoryProfiler memory_profiler;
|
||||||
DECLARE_FLAG(int, profile_memory_enable);
|
DECLARE_FLAG(int, profile_memory_enable);
|
||||||
DEFINE_FLAG(int, gopt_disable, 0, "Disable graph optimizer.");
|
DEFINE_FLAG(int, gopt_disable, 0, "Disable graph optimizer.");
|
||||||
|
|
||||||
// from fetch_op.cc
|
// from fetch_op.cc
|
||||||
extern list<VarPtr> fetcher_to_free;
|
EXTERN_LIB list<VarPtr> fetcher_to_free;
|
||||||
// from cuda_managed_allocator
|
// from cuda_managed_allocator
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
DECLARE_FLAG(int, use_cuda_managed_allocator);
|
DECLARE_FLAG(int, use_cuda_managed_allocator);
|
||||||
|
@ -414,7 +415,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
int sync_times = 0;
|
int sync_times = 0;
|
||||||
#endif
|
#endif
|
||||||
auto& jkl = jk;
|
auto& jkl = get_jk();
|
||||||
for (uint rid=0; rid<queue.size(); rid++) {
|
for (uint rid=0; rid<queue.size(); rid++) {
|
||||||
int root = queue[rid];
|
int root = queue[rid];
|
||||||
Op* op = ops[root];
|
Op* op = ops[root];
|
||||||
|
@ -471,7 +472,9 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
last_is_cuda = is_cuda;
|
last_is_cuda = is_cuda;
|
||||||
|
_JT_SEH_START2;
|
||||||
op->do_run_after_prepare(jkl);
|
op->do_run_after_prepare(jkl);
|
||||||
|
_JT_SEH_END2;
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
// migrate to gpu
|
// migrate to gpu
|
||||||
if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) {
|
if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) {
|
||||||
|
|
|
@ -24,7 +24,7 @@ struct Executor {
|
||||||
void run_sync(vector<Var*> vars, bool device_sync);
|
void run_sync(vector<Var*> vars, bool device_sync);
|
||||||
};
|
};
|
||||||
|
|
||||||
extern Executor exe;
|
EXTERN_LIB Executor exe;
|
||||||
|
|
||||||
void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int ll, int rr, int64 tt);
|
void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int ll, int rr, int64 tt);
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ loop_options_t& FusedOp::get_loop_options_tuned() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void FusedOp::update_jit_key() {
|
void FusedOp::update_jit_key() {
|
||||||
|
JK& jk = get_jk();
|
||||||
jk.clear();
|
jk.clear();
|
||||||
do_jit_prepare(jk);
|
do_jit_prepare(jk);
|
||||||
}
|
}
|
||||||
|
@ -256,7 +257,8 @@ int FusedOp::has(Node* node) {
|
||||||
return context->node_id.count(node);
|
return context->node_id.count(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FusedOp::do_run(){
|
void FusedOp::do_run() {
|
||||||
|
JK& jk = get_jk();
|
||||||
do_prepare(jk);
|
do_prepare(jk);
|
||||||
do_run_after_prepare(jk);
|
do_run_after_prepare(jk);
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,7 @@ struct FusedOpContext {
|
||||||
void setup(FusedOp* fop);
|
void setup(FusedOp* fop);
|
||||||
};
|
};
|
||||||
|
|
||||||
extern string_view_map<FusedOpContext*> jit_fused_ops;
|
EXTERN_LIB string_view_map<FusedOpContext*> jit_fused_ops;
|
||||||
|
|
||||||
struct FusedOp final : Op {
|
struct FusedOp final : Op {
|
||||||
vector<Op*> ops;
|
vector<Op*> ops;
|
||||||
|
|
|
@ -153,8 +153,8 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
||||||
if (op->flags.get(NodeFlags::_grads)) {
|
if (op->flags.get(NodeFlags::_grads)) {
|
||||||
// backward together
|
// backward together
|
||||||
auto n_i = op->inputs().size();
|
auto n_i = op->inputs().size();
|
||||||
Var* douts[n_o];
|
STACK_ALLOC(Var*, douts, n_o);
|
||||||
VarPtr dins[n_i];
|
STACK_ALLOC(VarPtr, dins, n_i);
|
||||||
// dump "for (Var* out : op->outputs())"
|
// dump "for (Var* out : op->outputs())"
|
||||||
for (int i=0; i<n_o; i++,j++) {
|
for (int i=0; i<n_o; i++,j++) {
|
||||||
auto id = id_buffer[j].second;
|
auto id = id_buffer[j].second;
|
||||||
|
|
|
@ -13,7 +13,7 @@ namespace jittor {
|
||||||
|
|
||||||
DEFINE_FLAG(int, check_graph, 0, "Unify graph sanity check.");
|
DEFINE_FLAG(int, check_graph, 0, "Unify graph sanity check.");
|
||||||
|
|
||||||
extern unordered_map<void*, int64> lived_nodes;
|
EXTERN_LIB unordered_map<void*, int64> lived_nodes;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string ss_convert(T x) {
|
string ss_convert(T x) {
|
||||||
|
@ -25,7 +25,7 @@ string ss_convert(T x) {
|
||||||
void do_graph_check() {
|
void do_graph_check() {
|
||||||
vector<Node*> queue;
|
vector<Node*> queue;
|
||||||
unordered_map<Node*,int> visited;
|
unordered_map<Node*,int> visited;
|
||||||
for (auto& vh : VarHolder::hold_vars) {
|
for (auto& vh : hold_vars) {
|
||||||
if (0==visited[vh->var]++)
|
if (0==visited[vh->var]++)
|
||||||
queue.push_back(vh->var);
|
queue.push_back(vh->var);
|
||||||
}
|
}
|
||||||
|
@ -85,7 +85,7 @@ void do_graph_check() {
|
||||||
DumpGraphs dump_all_graphs() {
|
DumpGraphs dump_all_graphs() {
|
||||||
vector<Node*> queue;
|
vector<Node*> queue;
|
||||||
auto t = ++Node::tflag_count;
|
auto t = ++Node::tflag_count;
|
||||||
for (auto& vh : VarHolder::hold_vars)
|
for (auto& vh : hold_vars)
|
||||||
if (vh->var->tflag != t) {
|
if (vh->var->tflag != t) {
|
||||||
vh->var->tflag = t;
|
vh->var->tflag = t;
|
||||||
queue.push_back(vh->var);
|
queue.push_back(vh->var);
|
||||||
|
|
|
@ -27,9 +27,9 @@ vector<set_seed_callback> callbacks;
|
||||||
int current_seed;
|
int current_seed;
|
||||||
|
|
||||||
// fron fetch_op.cc
|
// fron fetch_op.cc
|
||||||
extern list<VarPtr> fetcher;
|
EXTERN_LIB list<VarPtr> fetcher;
|
||||||
extern list<VarPtr> fetcher_to_free;
|
EXTERN_LIB list<VarPtr> fetcher_to_free;
|
||||||
extern vector<void(*)()> cleanup_callback;
|
EXTERN_LIB vector<void(*)()> cleanup_callback;
|
||||||
|
|
||||||
void cleanup() {
|
void cleanup() {
|
||||||
fetcher_to_free.clear();
|
fetcher_to_free.clear();
|
||||||
|
|
|
@ -37,10 +37,13 @@ namespace jit_compiler {
|
||||||
std::mutex dl_open_mutex;
|
std::mutex dl_open_mutex;
|
||||||
|
|
||||||
jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
|
jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
|
||||||
|
std::lock_guard<std::mutex> lock(dl_open_mutex);
|
||||||
const char* msg = "";
|
const char* msg = "";
|
||||||
LOGvv << "Opening jit lib:" << name;
|
LOGvv << "Opening jit lib:" << name;
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
void* handle = (void*)LoadLibrary(name.c_str());
|
void* handle = (void*)LoadLibraryExA(name.c_str(), nullptr,
|
||||||
|
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
|
||||||
|
LOAD_LIBRARY_SEARCH_USER_DIRS);
|
||||||
#elif defined(__linux__)
|
#elif defined(__linux__)
|
||||||
void* handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL);
|
void* handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL);
|
||||||
msg = dlerror();
|
msg = dlerror();
|
||||||
|
@ -76,7 +79,11 @@ static string get_symbol_name(const string& jit_key) {
|
||||||
op_name = Op::file_name_to_class_name(op_name);
|
op_name = Op::file_name_to_class_name(op_name);
|
||||||
// _ZN7jittorXyyyyyy7jit_runEv
|
// _ZN7jittorXyyyyyy7jit_runEv
|
||||||
// jittor::yyyyyy::jit_run
|
// jittor::yyyyyy::jit_run
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
op_name = "?jit_run@"+op_name+"Op@jittor@@QEAAXXZ";
|
||||||
|
#else
|
||||||
op_name = "_ZN6jittor"+S(op_name.size()+2)+op_name+"Op7jit_runEv";
|
op_name = "_ZN6jittor"+S(op_name.size()+2)+op_name+"Op7jit_runEv";
|
||||||
|
#endif
|
||||||
return op_name;
|
return op_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,13 +102,15 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
||||||
if (rewrite_op || !file_exist(jit_src_path))
|
if (rewrite_op || !file_exist(jit_src_path))
|
||||||
write(jit_src_path, src);
|
write(jit_src_path, src);
|
||||||
string cmd;
|
string cmd;
|
||||||
|
|
||||||
|
#ifndef _MSC_VER
|
||||||
if (is_cuda_op) {
|
if (is_cuda_op) {
|
||||||
cmd = nvcc_path
|
cmd = "\"" + nvcc_path + "\""
|
||||||
+ " \"" + jit_src_path + "\"" + other_src
|
+ " \"" + jit_src_path + "\"" + other_src
|
||||||
+ nvcc_flags + extra_flags
|
+ nvcc_flags + extra_flags
|
||||||
+ " -o \"" + jit_lib_path + "\"";
|
+ " -o \"" + jit_lib_path + "\"";
|
||||||
} else {
|
} else {
|
||||||
cmd = cc_path
|
cmd = "\"" + cc_path + "\""
|
||||||
+ " \"" + jit_src_path + "\"" + other_src
|
+ " \"" + jit_src_path + "\"" + other_src
|
||||||
+ cc_flags + extra_flags
|
+ cc_flags + extra_flags
|
||||||
+ " -o \"" + jit_lib_path + "\"";
|
+ " -o \"" + jit_lib_path + "\"";
|
||||||
|
@ -110,6 +119,24 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
||||||
"--cc_path=" + cmd;
|
"--cc_path=" + cmd;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
#else // Windows _MSC_VER
|
||||||
|
if (is_cuda_op) {
|
||||||
|
cmd = "\"" + nvcc_path + "\""
|
||||||
|
+ " \"" + jit_src_path + "\"" + other_src
|
||||||
|
+ nvcc_flags + extra_flags
|
||||||
|
+ " -o \"" + jit_lib_path + "\"";
|
||||||
|
} else {
|
||||||
|
auto symbol_name = get_symbol_name(jit_key);
|
||||||
|
auto pos = cc_flags.find("-link");
|
||||||
|
auto cc_flags1 = cc_flags.substr(0, pos);
|
||||||
|
auto cc_flags2 = cc_flags.substr(pos);
|
||||||
|
cmd = "\"" + cc_path + "\""
|
||||||
|
+ " \"" + jit_src_path + "\"" + other_src
|
||||||
|
+ cc_flags1 + extra_flags
|
||||||
|
+ " -Fe: \"" + jit_lib_path + "\" " + cc_flags2 + " -EXPORT:\""
|
||||||
|
+ symbol_name + "\"";
|
||||||
|
}
|
||||||
|
#endif
|
||||||
cache_compile(cmd, cache_path, jittor_path);
|
cache_compile(cmd, cache_path, jittor_path);
|
||||||
auto symbol_name = get_symbol_name(jit_key);
|
auto symbol_name = get_symbol_name(jit_key);
|
||||||
auto jit_entry = load_jit_lib(jit_lib_path, symbol_name);
|
auto jit_entry = load_jit_lib(jit_lib_path, symbol_name);
|
||||||
|
|
|
@ -6,17 +6,17 @@
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
#include <sys/mman.h>
|
#include <sys/mman.h>
|
||||||
|
#include <unistd.h>
|
||||||
#endif
|
#endif
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <unistd.h>
|
|
||||||
#include "jit_key.h"
|
#include "jit_key.h"
|
||||||
#include "utils/str_utils.h"
|
#include "utils/str_utils.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
extern thread_local size_t protected_page;
|
|
||||||
|
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
|
EXTERN_LIB thread_local size_t protected_page;
|
||||||
|
|
||||||
static size_t get_buffer_end_page(size_t buffer_end) {
|
static size_t get_buffer_end_page(size_t buffer_end) {
|
||||||
// get the last complete page in buffer
|
// get the last complete page in buffer
|
||||||
// 4k align :
|
// 4k align :
|
||||||
|
@ -121,4 +121,8 @@ vector<pair<string,string>> parse_jit_keys(const string& s) {
|
||||||
|
|
||||||
thread_local JitKey jk;
|
thread_local JitKey jk;
|
||||||
|
|
||||||
|
JK& get_jk() {
|
||||||
|
return jk;
|
||||||
|
}
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -78,8 +78,8 @@ struct __jk_int256 {
|
||||||
int64 a,b,c,d;
|
int64 a,b,c,d;
|
||||||
};
|
};
|
||||||
|
|
||||||
extern thread_local JitKey jk;
|
|
||||||
typedef JitKey JK;
|
typedef JitKey JK;
|
||||||
|
EXTERN_LIB JK& get_jk();
|
||||||
|
|
||||||
inline JK& operator<<(JK& jk, const char* s) {
|
inline JK& operator<<(JK& jk, const char* s) {
|
||||||
int i;
|
int i;
|
||||||
|
@ -284,7 +284,11 @@ getChr(s,35)
|
||||||
|
|
||||||
#define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0)
|
#define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0)
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#define _CS(str) str
|
||||||
|
#else
|
||||||
#define _CS(str) _CS_G<_CS_T(str)>()
|
#define _CS(str) _CS_G<_CS_T(str)>()
|
||||||
|
#endif
|
||||||
|
|
||||||
template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G {
|
template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G {
|
||||||
};
|
};
|
||||||
|
|
|
@ -8,10 +8,15 @@
|
||||||
// file 'LICENSE.txt', which is part of this source code package.
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <unistd.h>
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
|
#include <windows.h>
|
||||||
#include <fileapi.h>
|
#include <fileapi.h>
|
||||||
|
#include <process.h>
|
||||||
|
#include <io.h>
|
||||||
|
#define getpid _getpid
|
||||||
|
#define open _open
|
||||||
#else
|
#else
|
||||||
|
#include <unistd.h>
|
||||||
#endif
|
#endif
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <errno.h>
|
#include <errno.h>
|
||||||
|
|
|
@ -19,7 +19,7 @@ void lock();
|
||||||
|
|
||||||
void unlock();
|
void unlock();
|
||||||
|
|
||||||
extern int _has_lock;
|
EXTERN_LIB int _has_lock;
|
||||||
|
|
||||||
struct lock_guard {
|
struct lock_guard {
|
||||||
int has_lock = 0;
|
int has_lock = 0;
|
||||||
|
|
|
@ -27,7 +27,7 @@ struct Allocator {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct AlignedAllocator;
|
struct AlignedAllocator;
|
||||||
extern AlignedAllocator aligned_allocator;
|
EXTERN_LIB AlignedAllocator aligned_allocator;
|
||||||
|
|
||||||
struct Allocation {
|
struct Allocation {
|
||||||
void* ptr;
|
void* ptr;
|
||||||
|
@ -48,7 +48,7 @@ struct Allocation {
|
||||||
{ if (ptr) allocator->free(ptr, size, allocation); }
|
{ if (ptr) allocator->free(ptr, size, allocation); }
|
||||||
};
|
};
|
||||||
|
|
||||||
extern Allocator* cpu_allocator;
|
EXTERN_LIB Allocator* cpu_allocator;
|
||||||
Allocator* get_allocator(bool temp_allocator=false);
|
Allocator* get_allocator(bool temp_allocator=false);
|
||||||
// @pyjt(gc)
|
// @pyjt(gc)
|
||||||
void gc_all();
|
void gc_all();
|
||||||
|
|
|
@ -25,7 +25,11 @@ void* AlignedAllocator::alloc(size_t size, size_t& allocation) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void AlignedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
|
void AlignedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
|
||||||
|
#ifdef _WIN32
|
||||||
|
_aligned_free(mem_ptr);
|
||||||
|
#else
|
||||||
::free(mem_ptr);
|
::free(mem_ptr);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -16,6 +16,6 @@ struct AlignedAllocator : Allocator {
|
||||||
void free(void* mem_ptr, size_t size, const size_t& allocation) override;
|
void free(void* mem_ptr, size_t size, const size_t& allocation) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
extern AlignedAllocator aligned_allocator;
|
EXTERN_LIB AlignedAllocator aligned_allocator;
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -12,7 +12,7 @@
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
CudaDeviceAllocator cuda_device_allocator;
|
CudaDeviceAllocator cuda_device_allocator;
|
||||||
extern bool no_cuda_error_when_free;
|
EXTERN_LIB bool no_cuda_error_when_free;
|
||||||
|
|
||||||
const char* CudaDeviceAllocator::name() const {return "cuda_device";}
|
const char* CudaDeviceAllocator::name() const {return "cuda_device";}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ struct CudaDeviceAllocator : Allocator {
|
||||||
void free(void* mem_ptr, size_t size, const size_t& allocation) override;
|
void free(void* mem_ptr, size_t size, const size_t& allocation) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
extern CudaDeviceAllocator cuda_device_allocator;
|
EXTERN_LIB CudaDeviceAllocator cuda_device_allocator;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,9 +24,9 @@ struct DualAllocation {
|
||||||
size_t host_allocation, device_allocation;
|
size_t host_allocation, device_allocation;
|
||||||
};
|
};
|
||||||
|
|
||||||
extern SFRLAllocator cuda_dual_host_allocator;
|
EXTERN_LIB SFRLAllocator cuda_dual_host_allocator;
|
||||||
extern SFRLAllocator cuda_dual_device_allocator;
|
EXTERN_LIB SFRLAllocator cuda_dual_device_allocator;
|
||||||
extern bool no_cuda_error_when_free;
|
EXTERN_LIB bool no_cuda_error_when_free;
|
||||||
|
|
||||||
struct CudaDualAllocator : Allocator {
|
struct CudaDualAllocator : Allocator {
|
||||||
//for recycle block_id
|
//for recycle block_id
|
||||||
|
@ -74,11 +74,11 @@ struct CudaDualAllocator : Allocator {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
extern CudaDualAllocator cuda_dual_allocator;
|
EXTERN_LIB CudaDualAllocator cuda_dual_allocator;
|
||||||
|
|
||||||
namespace cuda_dual_local {
|
namespace cuda_dual_local {
|
||||||
|
|
||||||
extern list<Allocation> allocations;
|
EXTERN_LIB list<Allocation> allocations;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,7 +115,7 @@ struct DelayFree final : Allocator {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
extern DelayFree delay_free;
|
EXTERN_LIB DelayFree delay_free;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
CudaHostAllocator cuda_host_allocator;
|
CudaHostAllocator cuda_host_allocator;
|
||||||
extern bool no_cuda_error_when_free;
|
EXTERN_LIB bool no_cuda_error_when_free;
|
||||||
|
|
||||||
const char* CudaHostAllocator::name() const {return "cuda_host";}
|
const char* CudaHostAllocator::name() const {return "cuda_host";}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ struct CudaHostAllocator : Allocator {
|
||||||
void free(void* mem_ptr, size_t size, const size_t& allocation) override;
|
void free(void* mem_ptr, size_t size, const size_t& allocation) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
extern CudaHostAllocator cuda_host_allocator;
|
EXTERN_LIB CudaHostAllocator cuda_host_allocator;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ namespace jittor {
|
||||||
|
|
||||||
CudaManagedAllocator cuda_managed_allocator;
|
CudaManagedAllocator cuda_managed_allocator;
|
||||||
DEFINE_FLAG(int, use_cuda_managed_allocator, 1, "Enable cuda_managed_allocator");
|
DEFINE_FLAG(int, use_cuda_managed_allocator, 1, "Enable cuda_managed_allocator");
|
||||||
extern bool no_cuda_error_when_free;
|
EXTERN_LIB bool no_cuda_error_when_free;
|
||||||
|
|
||||||
const char* CudaManagedAllocator::name() const {return "cuda_managed";}
|
const char* CudaManagedAllocator::name() const {return "cuda_managed";}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ struct CudaManagedAllocator : Allocator {
|
||||||
void free(void* mem_ptr, size_t size, const size_t& allocation) override;
|
void free(void* mem_ptr, size_t size, const size_t& allocation) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
extern CudaManagedAllocator cuda_managed_allocator;
|
EXTERN_LIB CudaManagedAllocator cuda_managed_allocator;
|
||||||
DECLARE_FLAG(int, use_cuda_managed_allocator);
|
DECLARE_FLAG(int, use_cuda_managed_allocator);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,9 @@
|
||||||
#elif defined(_WIN32)
|
#elif defined(_WIN32)
|
||||||
#include <windows.h>
|
#include <windows.h>
|
||||||
#endif
|
#endif
|
||||||
|
#ifndef _WIN32
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "var.h"
|
#include "var.h"
|
||||||
#include "op.h"
|
#include "op.h"
|
||||||
|
@ -62,7 +64,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
|
||||||
FloatOutput{(double)mem_info.total_cpu_ram, " KMG", 1024, "B"};
|
FloatOutput{(double)mem_info.total_cpu_ram, " KMG", 1024, "B"};
|
||||||
log << "total_cuda_ram:" <<
|
log << "total_cuda_ram:" <<
|
||||||
FloatOutput{(double)mem_info.total_cuda_ram, " KMG", 1024, "B"} >> "\n";
|
FloatOutput{(double)mem_info.total_cuda_ram, " KMG", 1024, "B"} >> "\n";
|
||||||
log << "hold_vars:" << VarHolder::hold_vars.size()
|
log << "hold_vars:" << hold_vars.size()
|
||||||
<< "lived_vars:" << Var::number_of_lived_vars
|
<< "lived_vars:" << Var::number_of_lived_vars
|
||||||
<< "lived_ops:" << Op::number_of_lived_ops >> '\n';
|
<< "lived_ops:" << Op::number_of_lived_ops >> '\n';
|
||||||
log << "update queue:" << update_queue.queue.size()
|
log << "update queue:" << update_queue.queue.size()
|
||||||
|
@ -72,7 +74,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
|
||||||
// get the oldest var
|
// get the oldest var
|
||||||
// vector<Node*> queue;
|
// vector<Node*> queue;
|
||||||
// auto t = ++Node::tflag_count;
|
// auto t = ++Node::tflag_count;
|
||||||
// for (auto& vh : VarHolder::hold_vars)
|
// for (auto& vh : hold_vars)
|
||||||
// if (vh->var->tflag != t) {
|
// if (vh->var->tflag != t) {
|
||||||
// vh->var->tflag = t;
|
// vh->var->tflag = t;
|
||||||
// queue.push_back(vh->var);
|
// queue.push_back(vh->var);
|
||||||
|
@ -148,7 +150,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
|
||||||
if (dump_var) {
|
if (dump_var) {
|
||||||
vector<Node*> queue;
|
vector<Node*> queue;
|
||||||
unordered_set<Node*> visited;
|
unordered_set<Node*> visited;
|
||||||
for (auto& vh : VarHolder::hold_vars)
|
for (auto& vh : hold_vars)
|
||||||
if (!visited.count(vh->var)) {
|
if (!visited.count(vh->var)) {
|
||||||
queue.push_back(vh->var);
|
queue.push_back(vh->var);
|
||||||
visited.insert(vh->var);
|
visited.insert(vh->var);
|
||||||
|
@ -186,7 +188,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
|
||||||
log.end();
|
log.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
extern vector<void(*)()> sigquit_callback;
|
EXTERN_LIB vector<void(*)()> sigquit_callback;
|
||||||
|
|
||||||
void meminfo_callback() {
|
void meminfo_callback() {
|
||||||
display_memory_info();
|
display_memory_info();
|
||||||
|
|
|
@ -24,7 +24,7 @@ struct MemInfo {
|
||||||
MemInfo();
|
MemInfo();
|
||||||
};
|
};
|
||||||
|
|
||||||
extern MemInfo mem_info;
|
EXTERN_LIB MemInfo mem_info;
|
||||||
|
|
||||||
// @pyjt(get_mem_info)
|
// @pyjt(get_mem_info)
|
||||||
inline MemInfo get_mem_info() { return mem_info; }
|
inline MemInfo get_mem_info() { return mem_info; }
|
||||||
|
|
|
@ -79,7 +79,7 @@ void MemoryProfiler::check() {
|
||||||
vector<Node*> queue;
|
vector<Node*> queue;
|
||||||
|
|
||||||
auto t = ++Node::tflag_count;
|
auto t = ++Node::tflag_count;
|
||||||
for (auto& vh : VarHolder::hold_vars)
|
for (auto& vh : hold_vars)
|
||||||
if (vh->var->tflag != t) {
|
if (vh->var->tflag != t) {
|
||||||
vh->var->tflag = t;
|
vh->var->tflag = t;
|
||||||
queue.push_back(vh->var);
|
queue.push_back(vh->var);
|
||||||
|
|
|
@ -39,7 +39,7 @@ struct MemoryProfiler {
|
||||||
string get_max_memory_info();
|
string get_max_memory_info();
|
||||||
};
|
};
|
||||||
|
|
||||||
extern MemoryProfiler memory_profiler;
|
EXTERN_LIB MemoryProfiler memory_profiler;
|
||||||
|
|
||||||
DECLARE_FLAG(int, profile_memory_enable);
|
DECLARE_FLAG(int, profile_memory_enable);
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
extern std::atomic_flag lock;
|
EXTERN_LIB std::atomic_flag lock;
|
||||||
|
|
||||||
struct spin_lock_guard {
|
struct spin_lock_guard {
|
||||||
inline spin_lock_guard() {
|
inline spin_lock_guard() {
|
||||||
|
|
|
@ -15,7 +15,7 @@ namespace jittor {
|
||||||
DEFINE_FLAG_WITH_SETTER(int, use_cuda, 0,
|
DEFINE_FLAG_WITH_SETTER(int, use_cuda, 0,
|
||||||
"Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda.");
|
"Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda.");
|
||||||
|
|
||||||
extern void sync_all(bool device_sync);
|
EXTERN_LIB void sync_all(bool device_sync);
|
||||||
|
|
||||||
void setter_use_cuda(int value) {
|
void setter_use_cuda(int value) {
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
|
|
|
@ -18,8 +18,8 @@ namespace jittor {
|
||||||
|
|
||||||
|
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
extern void check_nan_float32(float32* ptr, int64 num);
|
EXTERN_LIB void check_nan_float32(float32* ptr, int64 num);
|
||||||
extern void check_nan_float64(float64* ptr, int64 num);
|
EXTERN_LIB void check_nan_float64(float64* ptr, int64 num);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
bool check_nan(Var* v) {
|
bool check_nan(Var* v) {
|
||||||
|
|
|
@ -22,7 +22,16 @@ namespace jittor {
|
||||||
m(float32) \
|
m(float32) \
|
||||||
m(float64)
|
m(float64)
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
inline int ffs(int i) {
|
||||||
|
int j=0;
|
||||||
|
while (i) j++,i/=2;
|
||||||
|
return j;
|
||||||
|
}
|
||||||
|
#define map_size(T) {#T, ffs(sizeof(T))-1},
|
||||||
|
#else
|
||||||
#define map_size(T) {#T, __builtin_ffs(sizeof(T))-1},
|
#define map_size(T) {#T, __builtin_ffs(sizeof(T))-1},
|
||||||
|
#endif
|
||||||
|
|
||||||
unordered_map<string, size_t> dsize_map = {FOR_ALL_TYPES(map_size)};
|
unordered_map<string, size_t> dsize_map = {FOR_ALL_TYPES(map_size)};
|
||||||
|
|
||||||
|
@ -120,9 +129,9 @@ static unordered_set<string> binary_ops = {
|
||||||
#define DEFINE_NS(T) NanoString ns_##T;
|
#define DEFINE_NS(T) NanoString ns_##T;
|
||||||
FOR_ALL_NS(DEFINE_NS);
|
FOR_ALL_NS(DEFINE_NS);
|
||||||
|
|
||||||
unordered_map<string, NanoString> NanoString::__string_to_ns;
|
unordered_map<string, NanoString> __string_to_ns;
|
||||||
char NanoString::__ns_to_string[ns_max_size*ns_max_len];
|
char __ns_to_string[ns_max_size*ns_max_len];
|
||||||
int NanoString::__ns_len[ns_max_size];
|
int __ns_len[ns_max_size];
|
||||||
|
|
||||||
static void init_ns() {
|
static void init_ns() {
|
||||||
NanoString::ns_t i=0;
|
NanoString::ns_t i=0;
|
||||||
|
@ -146,27 +155,27 @@ static void init_ns() {
|
||||||
ns.set(NanoString::_type, NanoString::_binary, NanoString::_type_nbits);
|
ns.set(NanoString::_type, NanoString::_binary, NanoString::_type_nbits);
|
||||||
ns.set(NanoString::_bool, is_bool.count(name));
|
ns.set(NanoString::_bool, is_bool.count(name));
|
||||||
}
|
}
|
||||||
NanoString::__string_to_ns[name] = ns;
|
__string_to_ns[name] = ns;
|
||||||
auto name2 = ns.to_cstring();
|
auto name2 = ns.to_cstring();
|
||||||
int len=0;
|
int len=0;
|
||||||
for (;;len++) {
|
for (;;len++) {
|
||||||
name2[len] = name[len];
|
name2[len] = name[len];
|
||||||
if (!name[len]) break;
|
if (!name[len]) break;
|
||||||
}
|
}
|
||||||
NanoString::__ns_len[i-1] = len;
|
__ns_len[i-1] = len;
|
||||||
};
|
};
|
||||||
#define INIT_NS(T) func(#T, ns_##T);
|
#define INIT_NS(T) func(#T, ns_##T);
|
||||||
FOR_ALL_NS(INIT_NS);
|
FOR_ALL_NS(INIT_NS);
|
||||||
ASSERT(i<=(1<<NanoString::_index_nbits));
|
ASSERT(i<=(1<<NanoString::_index_nbits));
|
||||||
NanoString::__string_to_ns["sum"] = ns_add;
|
__string_to_ns["sum"] = ns_add;
|
||||||
NanoString::__string_to_ns["min"] = ns_minimum;
|
__string_to_ns["min"] = ns_minimum;
|
||||||
NanoString::__string_to_ns["max"] = ns_maximum;
|
__string_to_ns["max"] = ns_maximum;
|
||||||
NanoString::__string_to_ns["float"] = ns_float32;
|
__string_to_ns["float"] = ns_float32;
|
||||||
NanoString::__string_to_ns["double"] = ns_float64;
|
__string_to_ns["double"] = ns_float64;
|
||||||
NanoString::__string_to_ns["int"] = ns_int32;
|
__string_to_ns["int"] = ns_int32;
|
||||||
NanoString::__string_to_ns["uint"] = ns_uint32;
|
__string_to_ns["uint"] = ns_uint32;
|
||||||
LOGvv << "init __string_to_ns" << NanoString::__string_to_ns;
|
LOGvv << "init __string_to_ns" << __string_to_ns;
|
||||||
LOGvv << "init __ns_to_string" << NanoString::__ns_to_string;
|
LOGvv << "init __ns_to_string" << __ns_to_string;
|
||||||
}
|
}
|
||||||
|
|
||||||
int __init_ns = (init_ns(), 0);
|
int __init_ns = (init_ns(), 0);
|
||||||
|
|
|
@ -86,9 +86,14 @@ constexpr int ns_max_len = 16;
|
||||||
m(normal) \
|
m(normal) \
|
||||||
|
|
||||||
struct NanoString;
|
struct NanoString;
|
||||||
#define DECLEAR_NS(T) extern NanoString ns_##T;
|
#define DECLEAR_NS(T) EXTERN_LIB NanoString ns_##T;
|
||||||
FOR_ALL_NS(DECLEAR_NS);
|
FOR_ALL_NS(DECLEAR_NS);
|
||||||
|
|
||||||
|
|
||||||
|
EXTERN_LIB unordered_map<string, NanoString> __string_to_ns;
|
||||||
|
EXTERN_LIB char __ns_to_string[];
|
||||||
|
EXTERN_LIB int __ns_len[];
|
||||||
|
|
||||||
// @pyjt(NanoString)
|
// @pyjt(NanoString)
|
||||||
struct NanoString {
|
struct NanoString {
|
||||||
typedef uint16 ns_t;
|
typedef uint16 ns_t;
|
||||||
|
@ -113,10 +118,6 @@ struct NanoString {
|
||||||
};
|
};
|
||||||
ns_t data=0;
|
ns_t data=0;
|
||||||
|
|
||||||
static unordered_map<string, NanoString> __string_to_ns;
|
|
||||||
static char __ns_to_string[];
|
|
||||||
static int __ns_len[];
|
|
||||||
|
|
||||||
inline void set(Flags f, ns_t a=1, ns_t nbits=1) {
|
inline void set(Flags f, ns_t a=1, ns_t nbits=1) {
|
||||||
ns_t mask = (((1u<<nbits)-1)<<f);
|
ns_t mask = (((1u<<nbits)-1)<<f);
|
||||||
data = (data & ~mask) | ((a<<f)&mask);
|
data = (data & ~mask) | ((a<<f)&mask);
|
||||||
|
|
|
@ -16,9 +16,13 @@ static inline int lzcnt(int64 v) {
|
||||||
#else
|
#else
|
||||||
return v ? __builtin_clzll(v) : 64;
|
return v ? __builtin_clzll(v) : 64;
|
||||||
#endif
|
#endif
|
||||||
|
#else
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
return __lzcnt64(v);
|
||||||
#else
|
#else
|
||||||
return __builtin_clzll(v);
|
return __builtin_clzll(v);
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Slice {
|
struct Slice {
|
||||||
|
|
|
@ -35,7 +35,7 @@ RingBuffer::~RingBuffer() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
RingBuffer* RingBuffer::make_ring_buffer(uint64 size, bool multiprocess) {
|
RingBuffer* RingBuffer::make_ring_buffer(uint64 size, bool multiprocess, uint64 buffer, bool init) {
|
||||||
int i=0;
|
int i=0;
|
||||||
for (;(1ll<<i)<size;i++);
|
for (;(1ll<<i)<size;i++);
|
||||||
uint64 size_mask = (1ll<<i)-1;
|
uint64 size_mask = (1ll<<i)-1;
|
||||||
|
@ -47,26 +47,30 @@ RingBuffer* RingBuffer::make_ring_buffer(uint64 size, bool multiprocess) {
|
||||||
mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)
|
mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)
|
||||||
#else
|
#else
|
||||||
// TODO: multiprocess ring buffer in windows
|
// TODO: multiprocess ring buffer in windows
|
||||||
(void*)malloc(total_size)
|
(void*)buffer
|
||||||
#endif
|
#endif
|
||||||
:
|
:
|
||||||
// mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED, -1, 0) :
|
// mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED, -1, 0) :
|
||||||
(void*)malloc(total_size);
|
(void*)malloc(total_size);
|
||||||
std::memset(ptr, 0, total_size);
|
|
||||||
auto rb = (RingBuffer*)ptr;
|
auto rb = (RingBuffer*)ptr;
|
||||||
|
if (!init) return rb;
|
||||||
|
std::memset(ptr, 0, total_size);
|
||||||
new (rb) RingBuffer(size, multiprocess);
|
new (rb) RingBuffer(size, multiprocess);
|
||||||
return rb;
|
return rb;
|
||||||
}
|
}
|
||||||
|
|
||||||
void RingBuffer::free_ring_buffer(RingBuffer* rb) {
|
void RingBuffer::free_ring_buffer(RingBuffer* rb, uint64 buffer, bool init) {
|
||||||
uint64 total_size = sizeof(RingBuffer) + rb->size;
|
uint64 total_size = sizeof(RingBuffer) + rb->size;
|
||||||
auto is_multiprocess = rb->is_multiprocess;
|
auto is_multiprocess = rb->is_multiprocess;
|
||||||
rb->~RingBuffer();
|
if (init)
|
||||||
|
rb->~RingBuffer();
|
||||||
if (is_multiprocess) {
|
if (is_multiprocess) {
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
munmap(rb, total_size);
|
munmap(rb, total_size);
|
||||||
#else
|
#else
|
||||||
free((void*)rb);
|
if (!buffer)
|
||||||
|
free((void*)rb);
|
||||||
|
// this buffer is not owned by this obj
|
||||||
#endif
|
#endif
|
||||||
(void)total_size;
|
(void)total_size;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -5,7 +5,11 @@
|
||||||
// file 'LICENSE.txt', which is part of this source code package.
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
#pragma once
|
#pragma once
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#include <windows.h>
|
||||||
|
#else
|
||||||
#include <pthread.h>
|
#include <pthread.h>
|
||||||
|
#endif
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
|
@ -13,6 +17,37 @@ namespace jittor {
|
||||||
|
|
||||||
struct RingBuffer {
|
struct RingBuffer {
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
struct Mutex {
|
||||||
|
HANDLE handle;
|
||||||
|
inline Mutex(bool multiprocess=0) {
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void lock() {
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void unlock() {
|
||||||
|
}
|
||||||
|
inline ~Mutex() {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
struct MutexScope {
|
||||||
|
Mutex* m;
|
||||||
|
inline MutexScope(Mutex& m) : m(&m) { m.lock(); }
|
||||||
|
inline ~MutexScope() { m->unlock(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Cond {
|
||||||
|
inline Cond(bool multiprocess=0) {
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void wait(MutexScope& m) {
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void notify() {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#else
|
||||||
struct Mutex {
|
struct Mutex {
|
||||||
pthread_mutex_t m;
|
pthread_mutex_t m;
|
||||||
inline Mutex(bool multiprocess=0) {
|
inline Mutex(bool multiprocess=0) {
|
||||||
|
@ -35,6 +70,11 @@ struct RingBuffer {
|
||||||
pthread_mutex_unlock(&m);
|
pthread_mutex_unlock(&m);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
struct MutexScope {
|
||||||
|
Mutex* m;
|
||||||
|
inline MutexScope(Mutex& m) : m(&m) { m.lock(); }
|
||||||
|
inline ~MutexScope() { m->unlock(); }
|
||||||
|
};
|
||||||
|
|
||||||
struct Cond {
|
struct Cond {
|
||||||
pthread_cond_t cv;
|
pthread_cond_t cv;
|
||||||
|
@ -56,20 +96,15 @@ struct RingBuffer {
|
||||||
pthread_cond_destroy(&cv);
|
pthread_cond_destroy(&cv);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void wait(Mutex& m) {
|
inline void wait(MutexScope& m) {
|
||||||
pthread_cond_wait(&cv, &m.m);
|
pthread_cond_wait(&cv, &m.m->m);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void notify() {
|
inline void notify() {
|
||||||
pthread_cond_signal(&cv);
|
pthread_cond_signal(&cv);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
#endif
|
||||||
struct MutexScope {
|
|
||||||
Mutex* m;
|
|
||||||
inline MutexScope(Mutex& m) : m(&m) { m.lock(); }
|
|
||||||
inline ~MutexScope() { m->unlock(); }
|
|
||||||
};
|
|
||||||
|
|
||||||
uint64 size;
|
uint64 size;
|
||||||
uint64 size_mask;
|
uint64 size_mask;
|
||||||
|
@ -86,8 +121,8 @@ struct RingBuffer {
|
||||||
RingBuffer(uint64 size, bool multiprocess=false);
|
RingBuffer(uint64 size, bool multiprocess=false);
|
||||||
~RingBuffer();
|
~RingBuffer();
|
||||||
void stop();
|
void stop();
|
||||||
static RingBuffer* make_ring_buffer(uint64 size, bool multiprocess);
|
static RingBuffer* make_ring_buffer(uint64 size, bool multiprocess, uint64 buffer=0, bool init=true);
|
||||||
static void free_ring_buffer(RingBuffer* rb);
|
static void free_ring_buffer(RingBuffer* rb, uint64 buffer=0, bool init=true);
|
||||||
|
|
||||||
inline void clear() { l = r = is_stop = 0; }
|
inline void clear() { l = r = is_stop = 0; }
|
||||||
|
|
||||||
|
@ -102,7 +137,7 @@ struct RingBuffer {
|
||||||
is_wait = 0;
|
is_wait = 0;
|
||||||
}
|
}
|
||||||
is_wait = 1;
|
is_wait = 1;
|
||||||
cv.wait(m);
|
cv.wait(_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@ namespace jittor {
|
||||||
using std::string_view;
|
using std::string_view;
|
||||||
#elif defined(__GNUC__)
|
#elif defined(__GNUC__)
|
||||||
using std::experimental::string_view;
|
using std::experimental::string_view;
|
||||||
|
#else
|
||||||
|
using std::string_view;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
|
|
|
@ -12,10 +12,10 @@
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
extern unordered_map<void*, int64> lived_nodes;
|
EXTERN_LIB unordered_map<void*, int64> lived_nodes;
|
||||||
extern int64 total_node;
|
EXTERN_LIB int64 total_node;
|
||||||
extern int64 nt;
|
EXTERN_LIB int64 nt;
|
||||||
extern vector<Node*> free_buffer;
|
EXTERN_LIB vector<Node*> free_buffer;
|
||||||
|
|
||||||
struct NodeFlags {
|
struct NodeFlags {
|
||||||
typedef uint16 nf_t;
|
typedef uint16 nf_t;
|
||||||
|
|
|
@ -97,12 +97,13 @@ string Op::get_jit_key(JK& jk) {
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<pair<string,string>> Op::get_jit_define() {
|
vector<pair<string,string>> Op::get_jit_define() {
|
||||||
return parse_jit_keys(get_jit_key(jk));
|
return parse_jit_keys(get_jit_key(get_jk()));
|
||||||
}
|
}
|
||||||
|
|
||||||
string Op::get_hash_name() {
|
string Op::get_hash_name() {
|
||||||
string hash_name;
|
string hash_name;
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
|
JK& jk = get_jk();
|
||||||
do_prepare(jk);
|
do_prepare(jk);
|
||||||
ss << std::hex << std::hash<string>()(jk.to_string());
|
ss << std::hex << std::hash<string>()(jk.to_string());
|
||||||
hash_name = ss.str();
|
hash_name = ss.str();
|
||||||
|
@ -186,12 +187,13 @@ void Op::do_prepare(JK& jk){
|
||||||
|
|
||||||
void Op::do_run_after_prepare(JK& jk) {
|
void Op::do_run_after_prepare(JK& jk) {
|
||||||
if (!jk.empty())
|
if (!jk.empty())
|
||||||
jit_run();
|
jit_run(jk);
|
||||||
else
|
else
|
||||||
run();
|
run();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Op::do_run() {
|
void Op::do_run() {
|
||||||
|
JK& jk = get_jk();
|
||||||
do_prepare(jk);
|
do_prepare(jk);
|
||||||
do_run_after_prepare(jk);
|
do_run_after_prepare(jk);
|
||||||
}
|
}
|
||||||
|
@ -209,10 +211,7 @@ string Op::get_filename_from_jit_key(const string& jit_key, const string& suffix
|
||||||
}
|
}
|
||||||
s = ss.str();
|
s = ss.str();
|
||||||
for (char& c : s) {
|
for (char& c : s) {
|
||||||
if (c=='[' || c==']' || c=='<' || c=='>'
|
if (!((c>='a' && c<='z') || (c>='A' && c<='Z') || (c>='0' && c<='9')))
|
||||||
|| c=='{' || c=='}' || c=='(' || c==')' || c==','
|
|
||||||
|| c=='\n' || c=='\t' || c==' ' || c=='&' || c=='|'
|
|
||||||
|| c=='/' || c==':')
|
|
||||||
c = '_';
|
c = '_';
|
||||||
}
|
}
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
|
@ -248,7 +247,7 @@ string Op::file_name_to_class_name(const string& s) {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Op::jit_run() {
|
void Op::jit_run(JK& jk) {
|
||||||
const char* jit_key = jk.to_cstring();
|
const char* jit_key = jk.to_cstring();
|
||||||
auto iter = jit_ops.find(jit_key);
|
auto iter = jit_ops.find(jit_key);
|
||||||
if (iter != jit_ops.end()) {
|
if (iter != jit_ops.end()) {
|
||||||
|
|
|
@ -50,7 +50,7 @@ struct Op : Node {
|
||||||
virtual VarPtr duplicate();
|
virtual VarPtr duplicate();
|
||||||
virtual void compile_optimize(string& src);
|
virtual void compile_optimize(string& src);
|
||||||
virtual void graph_optimize();
|
virtual void graph_optimize();
|
||||||
void jit_run();
|
void jit_run(JK& jk);
|
||||||
|
|
||||||
string name_ex() const;
|
string name_ex() const;
|
||||||
string get_jit_key(JK& jk);
|
string get_jit_key(JK& jk);
|
||||||
|
@ -60,9 +60,9 @@ struct Op : Node {
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Op* var);
|
std::ostream& operator<<(std::ostream& os, const Op* var);
|
||||||
|
|
||||||
extern string_view_map<jit_op_entry_t> jit_ops;
|
EXTERN_LIB string_view_map<jit_op_entry_t> jit_ops;
|
||||||
// jit_key_mapper: map origin jit_key -> tuned jit_key
|
// jit_key_mapper: map origin jit_key -> tuned jit_key
|
||||||
extern string_view_map<string> jit_key_mapper;
|
EXTERN_LIB string_view_map<string> jit_key_mapper;
|
||||||
|
|
||||||
#ifdef JIT
|
#ifdef JIT
|
||||||
#define DECLARE_jit_run void jit_run();
|
#define DECLARE_jit_run void jit_run();
|
||||||
|
|
|
@ -1042,7 +1042,7 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
|
||||||
src = &src_after_passes;
|
src = &src_after_passes;
|
||||||
}
|
}
|
||||||
op->compile_optimize(*src);
|
op->compile_optimize(*src);
|
||||||
auto ret = oc.compile(op->get_jit_key(jk), *src);
|
auto ret = oc.compile(op->get_jit_key(get_jk()), *src);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -129,9 +129,13 @@ void BroadcastToOp::infer_shape() {
|
||||||
auto xdim = x->shape.size();
|
auto xdim = x->shape.size();
|
||||||
auto ydim = yshapes.size();
|
auto ydim = yshapes.size();
|
||||||
auto count = __builtin_popcount(bcast_mask&~keepdims_mask);
|
auto count = __builtin_popcount(bcast_mask&~keepdims_mask);
|
||||||
auto zdim = std::max(xdim, ydim-count) + count;
|
auto zdim = std::max(uint64(xdim), uint64(ydim-count)) + count;
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
int64 zz[10];
|
||||||
|
#else
|
||||||
int64 zz[zdim];
|
int64 zz[zdim];
|
||||||
|
#endif
|
||||||
for (int i=zdim-1, xi = xdim-1, yi = ydim-1; i>=0; i--) {
|
for (int i=zdim-1, xi = xdim-1, yi = ydim-1; i>=0; i--) {
|
||||||
bool bx = xi>=0;
|
bool bx = xi>=0;
|
||||||
bool by = yi>=0;
|
bool by = yi>=0;
|
||||||
|
|
|
@ -280,7 +280,7 @@ void GetitemOp::_compile_optimize(string& src) {
|
||||||
new_func->push_back(func->children.back()->move_out());
|
new_func->push_back(func->children.back()->move_out());
|
||||||
auto& loop = new_func->children.back();
|
auto& loop = new_func->children.back();
|
||||||
int no = o_shape.size();
|
int no = o_shape.size();
|
||||||
KernelIR* loops[no];
|
STACK_ALLOC(KernelIR*, loops, no);
|
||||||
if (!no) {
|
if (!no) {
|
||||||
func->push_back("func<<<1,1>>>("+arg_call+");");
|
func->push_back("func<<<1,1>>>("+arg_call+");");
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -38,6 +38,6 @@ VarPtr make_number(float number, Var* x) {
|
||||||
static void init() {
|
static void init() {
|
||||||
op_registe({"number", "", "", {{&typeid(&make_number), (void*)&make_number}}});
|
op_registe({"number", "", "", {{&typeid(&make_number), (void*)&make_number}}});
|
||||||
}
|
}
|
||||||
__attribute__((unused)) static int caller = (init(), 0);
|
static int caller = (init(), 0);
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
||||||
|
|
|
@ -213,17 +213,17 @@ static void getitem_inplace(GetitemOp* op) {
|
||||||
void SetitemOp::graph_optimize() {
|
void SetitemOp::graph_optimize() {
|
||||||
// LOGir << "hello graph_optimize";
|
// LOGir << "hello graph_optimize";
|
||||||
setitem_inplace(this);
|
setitem_inplace(this);
|
||||||
(void)setitem_inplace;
|
(void*)setitem_inplace;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetitemOp::graph_optimize() {
|
void GetitemOp::graph_optimize() {
|
||||||
// This optimize is still WIP
|
// This optimize is still WIP
|
||||||
// LOGir << "hello getitem graph_optimize";
|
// LOGir << "hello getitem graph_optimize";
|
||||||
// setitem_grad_opt(this);
|
// setitem_grad_opt(this);
|
||||||
(void)setitem_grad_opt;
|
(void*)setitem_grad_opt;
|
||||||
// (void)getitem_inplace;
|
// (void)getitem_inplace;
|
||||||
getitem_inplace(this);
|
getitem_inplace(this);
|
||||||
(void)getitem_inplace;
|
(void*)getitem_inplace;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ Searcher::Searcher(OpCompiler* oc) : oc(oc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t Searcher::get_time_of_current_choices() {
|
int64_t Searcher::get_time_of_current_choices() {
|
||||||
|
JK& jk = get_jk();
|
||||||
auto* op = oc->op;
|
auto* op = oc->op;
|
||||||
// generate jit_key
|
// generate jit_key
|
||||||
op->update_jit_key();
|
op->update_jit_key();
|
||||||
|
|
|
@ -90,7 +90,7 @@ void CheckCachePass::run() {
|
||||||
ir->push_back("#include \"profiler/memory_checker.h\"", &ir->before);
|
ir->push_back("#include \"profiler/memory_checker.h\"", &ir->before);
|
||||||
ir->push_back("using namespace jittor;", &ir->before);
|
ir->push_back("using namespace jittor;", &ir->before);
|
||||||
// declaration
|
// declaration
|
||||||
ir->push_back("extern \"C\" std::unique_ptr<MemoryChecker> memory_checker;", &ir->before);
|
ir->push_back("EXTERN_LIB \"C\" std::unique_ptr<MemoryChecker> memory_checker;", &ir->before);
|
||||||
// definition
|
// definition
|
||||||
ir->push_back("std::unique_ptr<MemoryChecker> memory_checker;", &ir->before);
|
ir->push_back("std::unique_ptr<MemoryChecker> memory_checker;", &ir->before);
|
||||||
vector<string> commands;
|
vector<string> commands;
|
||||||
|
|
|
@ -17,6 +17,7 @@ namespace jittor {
|
||||||
using namespace expr;
|
using namespace expr;
|
||||||
|
|
||||||
void ConstVarPass::run() {
|
void ConstVarPass::run() {
|
||||||
|
JK& jk = get_jk();
|
||||||
int changed = 0;
|
int changed = 0;
|
||||||
for (int i=0; i<op->ops.size(); i++) {
|
for (int i=0; i<op->ops.size(); i++) {
|
||||||
auto opi = op->ops[i];
|
auto opi = op->ops[i];
|
||||||
|
|
|
@ -234,7 +234,7 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
||||||
continue;
|
continue;
|
||||||
Op* ops[3] = {op, bop->x->input(), bop->y->input()};
|
Op* ops[3] = {op, bop->x->input(), bop->y->input()};
|
||||||
int ok = 0;
|
int ok = 0;
|
||||||
LOGvvvv << "conv like op" << fop << fop->get_jit_key(jk);
|
LOGvvvv << "conv like op" << fop << fop->get_jit_key(get_jk());
|
||||||
for (int y_id=0; y_id<3; y_id++)
|
for (int y_id=0; y_id<3; y_id++)
|
||||||
for (int x_id=0; x_id<3; x_id++)
|
for (int x_id=0; x_id<3; x_id++)
|
||||||
for (int w_id=0; w_id<3; w_id++) {
|
for (int w_id=0; w_id<3; w_id++) {
|
||||||
|
|
|
@ -69,7 +69,7 @@ int VarRelayManager::add_relay_group(const vector<pair<Var*, Var*>>& group) {
|
||||||
if (node->is_var())
|
if (node->is_var())
|
||||||
continue;
|
continue;
|
||||||
Op* op = node->op();
|
Op* op = node->op();
|
||||||
op->do_jit_prepare(jk);
|
op->do_jit_prepare(get_jk());
|
||||||
list<Node*> new_inputs;
|
list<Node*> new_inputs;
|
||||||
int removed = 0;
|
int removed = 0;
|
||||||
for (Var* v : op->inputs())
|
for (Var* v : op->inputs())
|
||||||
|
|
|
@ -25,7 +25,7 @@ namespace jittor {
|
||||||
DEFINE_FLAG(int, use_parallel_op_compiler, 16, "Number of threads that parallel op comiler used, default 16, set this value to 0 will disable parallel op compiler.");
|
DEFINE_FLAG(int, use_parallel_op_compiler, 16, "Number of threads that parallel op comiler used, default 16, set this value to 0 will disable parallel op compiler.");
|
||||||
|
|
||||||
// from log.cc
|
// from log.cc
|
||||||
extern int segfault_happen;
|
EXTERN_LIB int segfault_happen;
|
||||||
|
|
||||||
// simple thread used for parallel compilation
|
// simple thread used for parallel compilation
|
||||||
struct SimpleThread {
|
struct SimpleThread {
|
||||||
|
@ -36,7 +36,7 @@ struct SimpleThread {
|
||||||
std::condition_variable cv;
|
std::condition_variable cv;
|
||||||
std::thread thread;
|
std::thread thread;
|
||||||
void run() {
|
void run() {
|
||||||
thread_name = "C"+S(id);
|
get_thread_name() = "C"+S(id);
|
||||||
try {
|
try {
|
||||||
std::unique_lock<std::mutex> lck(mtx);
|
std::unique_lock<std::mutex> lck(mtx);
|
||||||
if (func)
|
if (func)
|
||||||
|
@ -70,8 +70,8 @@ struct SimpleThread {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SimpleThreads;
|
struct SimpleThreads;
|
||||||
extern SimpleThreads threads;
|
EXTERN_LIB SimpleThreads threads;
|
||||||
extern vector<void(*)()> cleanup_callback;
|
EXTERN_LIB vector<void(*)()> cleanup_callback;
|
||||||
|
|
||||||
struct SimpleThreads {
|
struct SimpleThreads {
|
||||||
list<SimpleThread> threads;
|
list<SimpleThread> threads;
|
||||||
|
@ -136,7 +136,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
||||||
vector<int> op_needs_compile;
|
vector<int> op_needs_compile;
|
||||||
string_view_map<int> map;
|
string_view_map<int> map;
|
||||||
vector<unique_ptr<FusedOp>> fop_needs_compile;
|
vector<unique_ptr<FusedOp>> fop_needs_compile;
|
||||||
auto& jkl = jk;
|
auto& jkl = get_jk();
|
||||||
|
|
||||||
for (uint rid=0; rid<queue.size(); rid++) {
|
for (uint rid=0; rid<queue.size(); rid++) {
|
||||||
int root = queue[rid];
|
int root = queue[rid];
|
||||||
|
@ -213,7 +213,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
||||||
auto func = [&](int tid) {
|
auto func = [&](int tid) {
|
||||||
auto& entrys = op_entrys.at(tid);
|
auto& entrys = op_entrys.at(tid);
|
||||||
entrys.clear();
|
entrys.clear();
|
||||||
auto& jkl = jk;
|
auto& jkl = get_jk();
|
||||||
while (!has_error && !segfault_happen) {
|
while (!has_error && !segfault_happen) {
|
||||||
int i = ai++;
|
int i = ai++;
|
||||||
if (i >= n) break;
|
if (i >= n) break;
|
||||||
|
@ -247,14 +247,14 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
||||||
bool needs_compile;
|
bool needs_compile;
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(entry_lock);
|
std::lock_guard<std::mutex> lock(entry_lock);
|
||||||
auto iter = jit_ops.find(jk.to_cstring());
|
auto iter = jit_ops.find(jkl.to_cstring());
|
||||||
needs_compile = (iter == jit_ops.end());
|
needs_compile = (iter == jit_ops.end());
|
||||||
if (needs_compile) {
|
if (needs_compile) {
|
||||||
jit_ops[jk.to_cstring()] = nullptr;
|
jit_ops[jkl.to_cstring()] = nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!needs_compile) continue;
|
if (!needs_compile) continue;
|
||||||
string s = jk.to_string();
|
string s = jkl.to_string();
|
||||||
auto op_entry = OpCompiler::do_compile(orc.op);
|
auto op_entry = OpCompiler::do_compile(orc.op);
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(entry_lock);
|
std::lock_guard<std::mutex> lock(entry_lock);
|
||||||
|
@ -266,7 +266,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
// log jit_key and file location
|
// log jit_key and file location
|
||||||
op->do_prepare(jkl);
|
op->do_prepare(jkl);
|
||||||
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
|
string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc");
|
||||||
LOGe << "[Error] source file location:" << jit_src_path;
|
LOGe << "[Error] source file location:" << jit_src_path;
|
||||||
|
|
||||||
if (is_fused_op) {
|
if (is_fused_op) {
|
||||||
|
|
|
@ -87,7 +87,7 @@ unique_ptr<MemoryChecker>* load_memory_checker(string name) {
|
||||||
return mm;
|
return mm;
|
||||||
}
|
}
|
||||||
|
|
||||||
extern string _get_stack_info(Node* node);
|
EXTERN_LIB string _get_stack_info(Node* node);
|
||||||
|
|
||||||
static string get_stack_info(Op* op) {
|
static string get_stack_info(Op* op) {
|
||||||
string stack_info = "stack info:\n";
|
string stack_info = "stack info:\n";
|
||||||
|
|
|
@ -59,7 +59,7 @@ struct Profiler {
|
||||||
~Profiler();
|
~Profiler();
|
||||||
};
|
};
|
||||||
|
|
||||||
extern Profiler profiler;
|
EXTERN_LIB Profiler profiler;
|
||||||
|
|
||||||
DECLARE_FLAG(int, profiler_enable);
|
DECLARE_FLAG(int, profiler_enable);
|
||||||
|
|
||||||
|
|
|
@ -18,9 +18,13 @@ static inline int _lzcnt(int64 v) {
|
||||||
#else
|
#else
|
||||||
return v ? __builtin_clzll(v) : 64;
|
return v ? __builtin_clzll(v) : 64;
|
||||||
#endif
|
#endif
|
||||||
|
#else
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
return __lzcnt64(v);
|
||||||
#else
|
#else
|
||||||
return __builtin_clzll(v);
|
return __builtin_clzll(v);
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SimpleProfiler {
|
struct SimpleProfiler {
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
// Those function is generated by python
|
// Those function is generated by python
|
||||||
extern void pyjt_def_all(PyObject* m);
|
EXTERN_LIB void pyjt_def_all(PyObject* m);
|
||||||
|
|
||||||
vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets) {
|
vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets) {
|
||||||
vector<Var*> vs;
|
vector<Var*> vs;
|
||||||
|
|
|
@ -94,7 +94,7 @@ static vector<Stack> get_stack_info() {
|
||||||
auto frame = (PyFrameObject*)ret.obj;
|
auto frame = (PyFrameObject*)ret.obj;
|
||||||
int n=0;
|
int n=0;
|
||||||
while (frame) n++, frame = frame->f_back;
|
while (frame) n++, frame = frame->f_back;
|
||||||
PyFrameObject* frames[n];
|
STACK_ALLOC(PyFrameObject*, frames, n);
|
||||||
frame = (PyFrameObject*)ret.obj;
|
frame = (PyFrameObject*)ret.obj;
|
||||||
int i=n;
|
int i=n;
|
||||||
while (i) frames[--i] = frame, frame = frame->f_back;
|
while (i) frames[--i] = frame, frame = frame->f_back;
|
||||||
|
@ -225,7 +225,7 @@ static inline string get_var_data_str(Var* v) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TraceData::record_node(Node* node, bool record_stack) {
|
void TraceData::record_node(Node* node, bool record_stack) {
|
||||||
if (thread_name.size()) return;
|
if (get_thread_name().size()) return;
|
||||||
NodeData data;
|
NodeData data;
|
||||||
data.id = node_data_cnt++;
|
data.id = node_data_cnt++;
|
||||||
id_map[node] = data.id;
|
id_map[node] = data.id;
|
||||||
|
@ -261,7 +261,7 @@ static int64 get_node_id(Node* node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TraceData::release_node(Node* node) {
|
void TraceData::release_node(Node* node) {
|
||||||
if (thread_name.size()) return;
|
if (get_thread_name().size()) return;
|
||||||
auto iter = trace_data.id_map.find(node);
|
auto iter = trace_data.id_map.find(node);
|
||||||
if (iter == trace_data.id_map.end())
|
if (iter == trace_data.id_map.end())
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
DECLARE_FLAG(int, trace_py_var);
|
DECLARE_FLAG(int, trace_py_var);
|
||||||
extern Op* trace_grad_op;
|
EXTERN_LIB Op* trace_grad_op;
|
||||||
struct JitKey;
|
struct JitKey;
|
||||||
|
|
||||||
struct Stack {
|
struct Stack {
|
||||||
|
@ -64,7 +64,7 @@ struct TraceData {
|
||||||
void record_execution(Op* op, bool is_fused_op, JitKey& jk);
|
void record_execution(Op* op, bool is_fused_op, JitKey& jk);
|
||||||
};
|
};
|
||||||
|
|
||||||
extern TraceData trace_data;
|
EXTERN_LIB TraceData trace_data;
|
||||||
|
|
||||||
void print_node_trace(const Node* node, std::ostream& os);
|
void print_node_trace(const Node* node, std::ostream& os);
|
||||||
vector<Stack> get_node_trace(Node* node);
|
vector<Stack> get_node_trace(Node* node);
|
||||||
|
|
|
@ -50,8 +50,8 @@ enum NPY_TYPES {
|
||||||
NPY_OBJECT=17,
|
NPY_OBJECT=17,
|
||||||
};
|
};
|
||||||
|
|
||||||
extern NanoString npy2ns[];
|
EXTERN_LIB NanoString npy2ns[];
|
||||||
extern NPY_TYPES ns2npy[];
|
EXTERN_LIB NPY_TYPES ns2npy[];
|
||||||
|
|
||||||
#define NPY_ARRAY_C_CONTIGUOUS 0x0001
|
#define NPY_ARRAY_C_CONTIGUOUS 0x0001
|
||||||
#define NPY_ARRAY_ALIGNED 0x0100
|
#define NPY_ARRAY_ALIGNED 0x0100
|
||||||
|
@ -74,19 +74,19 @@ inline int get_typenum(NanoString ns) {
|
||||||
|
|
||||||
typedef Py_intptr_t npy_intp;
|
typedef Py_intptr_t npy_intp;
|
||||||
|
|
||||||
extern unordered_map<string, int> np_typenum_map;
|
EXTERN_LIB unordered_map<string, int> np_typenum_map;
|
||||||
|
|
||||||
extern void** PyArray_API;
|
EXTERN_LIB void** PyArray_API;
|
||||||
extern PyTypeObject *PyArray_Type;
|
EXTERN_LIB PyTypeObject *PyArray_Type;
|
||||||
extern PyTypeObject *PyNumberArrType_Type;
|
EXTERN_LIB PyTypeObject *PyNumberArrType_Type;
|
||||||
extern PyTypeObject *PyArrayDescr_Type;
|
EXTERN_LIB PyTypeObject *PyArrayDescr_Type;
|
||||||
extern PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_intp const *, void *, int, int, PyObject *);
|
EXTERN_LIB PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_intp const *, void *, int, int, PyObject *);
|
||||||
extern PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *);
|
EXTERN_LIB PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *);
|
||||||
extern unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
EXTERN_LIB unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
||||||
extern int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
EXTERN_LIB int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
||||||
extern PyObject* (*PyArray_NewCopy)(PyObject *, int);
|
EXTERN_LIB PyObject* (*PyArray_NewCopy)(PyObject *, int);
|
||||||
extern int (*PyArray_CopyInto)(PyObject *, PyObject *);
|
EXTERN_LIB int (*PyArray_CopyInto)(PyObject *, PyObject *);
|
||||||
extern void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode);
|
EXTERN_LIB void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode);
|
||||||
|
|
||||||
#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0)
|
#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0)
|
||||||
|
|
||||||
|
@ -121,7 +121,7 @@ union tmp_data_t {
|
||||||
int8 i8;
|
int8 i8;
|
||||||
};
|
};
|
||||||
|
|
||||||
extern tmp_data_t tmp_data;
|
EXTERN_LIB tmp_data_t tmp_data;
|
||||||
|
|
||||||
void numpy_init();
|
void numpy_init();
|
||||||
|
|
||||||
|
|
|
@ -141,7 +141,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
|
||||||
} else {
|
} else {
|
||||||
// this is non-continue numpy array
|
// this is non-continue numpy array
|
||||||
#if defined(__linux__) || defined(_WIN32)
|
#if defined(__linux__) || defined(_WIN32)
|
||||||
int64 dims[args.shape.size()];
|
STACK_ALLOC(int64, dims, args.shape.size());
|
||||||
#elif defined(__APPLE__)
|
#elif defined(__APPLE__)
|
||||||
long dims[args.shape.size()];
|
long dims[args.shape.size()];
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -135,7 +135,7 @@ DEF_IS(Slice, T) from_py_object(PyObject* obj) {
|
||||||
|
|
||||||
// DumpGraphs
|
// DumpGraphs
|
||||||
struct DumpGraphs;
|
struct DumpGraphs;
|
||||||
extern PyTypeObject PyjtDumpGraphs;
|
EXTERN_LIB PyTypeObject PyjtDumpGraphs;
|
||||||
DEF_IS(DumpGraphs, bool) is_type(PyObject* obj) {
|
DEF_IS(DumpGraphs, bool) is_type(PyObject* obj) {
|
||||||
return Py_TYPE(obj) == &PyjtDumpGraphs;
|
return Py_TYPE(obj) == &PyjtDumpGraphs;
|
||||||
}
|
}
|
||||||
|
@ -157,7 +157,7 @@ DEF_IS(DumpGraphs, const T&) from_py_object(PyObject* obj) {
|
||||||
|
|
||||||
// MemInfo
|
// MemInfo
|
||||||
struct MemInfo;
|
struct MemInfo;
|
||||||
extern PyTypeObject PyjtMemInfo;
|
EXTERN_LIB PyTypeObject PyjtMemInfo;
|
||||||
DEF_IS(MemInfo, bool) is_type(PyObject* obj) {
|
DEF_IS(MemInfo, bool) is_type(PyObject* obj) {
|
||||||
return Py_TYPE(obj) == &PyjtMemInfo;
|
return Py_TYPE(obj) == &PyjtMemInfo;
|
||||||
}
|
}
|
||||||
|
@ -177,7 +177,7 @@ DEF_IS(MemInfo, const T&) from_py_object(PyObject* obj) {
|
||||||
|
|
||||||
// NanoString
|
// NanoString
|
||||||
struct NanoString;
|
struct NanoString;
|
||||||
extern PyTypeObject PyjtNanoString;
|
EXTERN_LIB PyTypeObject PyjtNanoString;
|
||||||
DEF_IS(NanoString, bool) is_type(PyObject* obj) {
|
DEF_IS(NanoString, bool) is_type(PyObject* obj) {
|
||||||
return Py_TYPE(obj) == &PyjtNanoString ||
|
return Py_TYPE(obj) == &PyjtNanoString ||
|
||||||
PyUnicode_CheckExact(obj) ||
|
PyUnicode_CheckExact(obj) ||
|
||||||
|
@ -215,7 +215,7 @@ DEF_IS(NanoString, T) from_py_object(PyObject* obj) {
|
||||||
|
|
||||||
// NanoVector
|
// NanoVector
|
||||||
struct NanoVector;
|
struct NanoVector;
|
||||||
extern PyTypeObject PyjtNanoVector;
|
EXTERN_LIB PyTypeObject PyjtNanoVector;
|
||||||
DEF_IS(NanoVector, bool) is_type(PyObject* obj) {
|
DEF_IS(NanoVector, bool) is_type(PyObject* obj) {
|
||||||
return Py_TYPE(obj) == &PyjtNanoVector ||
|
return Py_TYPE(obj) == &PyjtNanoVector ||
|
||||||
PyList_CheckExact(obj) || PyTuple_CheckExact(obj);
|
PyList_CheckExact(obj) || PyTuple_CheckExact(obj);
|
||||||
|
@ -253,7 +253,7 @@ DEF_IS(NanoVector, T) from_py_object(PyObject* obj) {
|
||||||
struct ArrayArgs;
|
struct ArrayArgs;
|
||||||
struct VarHolder;
|
struct VarHolder;
|
||||||
vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh);
|
vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh);
|
||||||
extern PyHeapTypeObject PyjtVarHolder;
|
EXTERN_LIB PyHeapTypeObject PyjtVarHolder;
|
||||||
DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
|
DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
|
||||||
return
|
return
|
||||||
Py_TYPE(obj) == &PyjtVarHolder.ht_type ||
|
Py_TYPE(obj) == &PyjtVarHolder.ht_type ||
|
||||||
|
@ -267,7 +267,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
|
||||||
|
|
||||||
DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) {
|
DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) {
|
||||||
#if defined(__linux__) || defined(_WIN32)
|
#if defined(__linux__) || defined(_WIN32)
|
||||||
int64 dims[a.shape.size()];
|
STACK_ALLOC(int64, dims, a.shape.size());
|
||||||
#elif defined(__APPLE__)
|
#elif defined(__APPLE__)
|
||||||
long dims[a.shape.size()];
|
long dims[a.shape.size()];
|
||||||
#endif
|
#endif
|
||||||
|
@ -351,8 +351,8 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
|
||||||
|
|
||||||
// VarHolder
|
// VarHolder
|
||||||
struct VarHolder;
|
struct VarHolder;
|
||||||
extern PyHeapTypeObject PyjtVarHolder;
|
EXTERN_LIB PyHeapTypeObject PyjtVarHolder;
|
||||||
namespace jit_op_maker { extern VarHolder* array_(ArrayArgs&& args); }
|
namespace jit_op_maker { EXTERN_LIB VarHolder* array_(ArrayArgs&& args); }
|
||||||
DEF_IS(VarHolder*, bool) is_type(PyObject* obj) {
|
DEF_IS(VarHolder*, bool) is_type(PyObject* obj) {
|
||||||
return Py_TYPE(obj) == &PyjtVarHolder.ht_type ||
|
return Py_TYPE(obj) == &PyjtVarHolder.ht_type ||
|
||||||
is_type<ArrayArgs>(obj);
|
is_type<ArrayArgs>(obj);
|
||||||
|
@ -383,7 +383,7 @@ DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr<VarHolder>& holde
|
||||||
struct DataView;
|
struct DataView;
|
||||||
DEF_IS(DataView, PyObject*) to_py_object(T a) {
|
DEF_IS(DataView, PyObject*) to_py_object(T a) {
|
||||||
#if defined(__linux__) || defined(_WIN32)
|
#if defined(__linux__) || defined(_WIN32)
|
||||||
int64 dims[a.shape.size()];
|
STACK_ALLOC(int64, dims, a.shape.size());
|
||||||
#elif defined(__APPLE__)
|
#elif defined(__APPLE__)
|
||||||
long dims[a.shape.size()];
|
long dims[a.shape.size()];
|
||||||
#endif
|
#endif
|
||||||
|
@ -410,8 +410,9 @@ DEF_IS(DataView, PyObject*) to_py_object(T a) {
|
||||||
return oh.release();
|
return oh.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef __GNUC__
|
||||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||||
|
#endif
|
||||||
struct ItemData;
|
struct ItemData;
|
||||||
DEF_IS(ItemData, PyObject*) to_py_object(T a) {
|
DEF_IS(ItemData, PyObject*) to_py_object(T a) {
|
||||||
if (a.dtype == ns_bool) {
|
if (a.dtype == ns_bool) {
|
||||||
|
|
|
@ -110,7 +110,7 @@ static void push_py_object(RingBuffer* rb, PyObject* obj, uint64& __restrict__ o
|
||||||
rb->push(size, offset);
|
rb->push(size, offset);
|
||||||
args.ptr = rb->get_ptr(size, offset);
|
args.ptr = rb->get_ptr(size, offset);
|
||||||
#if defined(__linux__) || defined(_WIN32)
|
#if defined(__linux__) || defined(_WIN32)
|
||||||
int64 dims[args.shape.size()];
|
STACK_ALLOC(int64, dims, args.shape.size());
|
||||||
#elif defined(__APPLE__)
|
#elif defined(__APPLE__)
|
||||||
long dims[args.shape.size()];
|
long dims[args.shape.size()];
|
||||||
#endif
|
#endif
|
||||||
|
@ -225,12 +225,19 @@ PyObject* PyMultiprocessRingBuffer::pop() {
|
||||||
return obj;
|
return obj;
|
||||||
}
|
}
|
||||||
|
|
||||||
PyMultiprocessRingBuffer::PyMultiprocessRingBuffer(uint64 size) {
|
PyMultiprocessRingBuffer::PyMultiprocessRingBuffer(uint64 size, uint64 buffer, bool init) {
|
||||||
rb = RingBuffer::make_ring_buffer(size, 1);
|
this->buffer = buffer;
|
||||||
|
this->init = init;
|
||||||
|
if (buffer) {
|
||||||
|
auto mobj = (PyObject*)buffer;
|
||||||
|
auto buf = PyMemoryView_GET_BUFFER(mobj);
|
||||||
|
buffer = (uint64)buf->buf;
|
||||||
|
}
|
||||||
|
rb = RingBuffer::make_ring_buffer(size, 1, buffer, init);
|
||||||
}
|
}
|
||||||
|
|
||||||
PyMultiprocessRingBuffer::~PyMultiprocessRingBuffer() {
|
PyMultiprocessRingBuffer::~PyMultiprocessRingBuffer() {
|
||||||
RingBuffer::free_ring_buffer(rb);
|
RingBuffer::free_ring_buffer(rb, buffer, init);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,9 +13,11 @@ namespace jittor {
|
||||||
// @pyjt(RingBuffer)
|
// @pyjt(RingBuffer)
|
||||||
struct PyMultiprocessRingBuffer {
|
struct PyMultiprocessRingBuffer {
|
||||||
RingBuffer* rb;
|
RingBuffer* rb;
|
||||||
|
uint64 buffer;
|
||||||
bool _keep_numpy_array = false;
|
bool _keep_numpy_array = false;
|
||||||
|
bool init;
|
||||||
// @pyjt(__init__)
|
// @pyjt(__init__)
|
||||||
PyMultiprocessRingBuffer(uint64 size);
|
PyMultiprocessRingBuffer(uint64 size, uint64 buffer=0, bool init=true);
|
||||||
// @pyjt(__dealloc__)
|
// @pyjt(__dealloc__)
|
||||||
~PyMultiprocessRingBuffer();
|
~PyMultiprocessRingBuffer();
|
||||||
// @pyjt(push,send)
|
// @pyjt(push,send)
|
||||||
|
@ -46,6 +48,9 @@ struct PyMultiprocessRingBuffer {
|
||||||
s += ")";
|
s += ")";
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// @pyjt(__get__size)
|
||||||
|
inline uint64 size() { return rb->size; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,10 +9,11 @@
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
JIT_TEST(jit_key) {
|
JIT_TEST(jit_key) {
|
||||||
|
JK& jk = get_jk();
|
||||||
jk.clear();
|
jk.clear();
|
||||||
for (int i=0; i<JK::buffer_size/2; i++)
|
for (int i=0; i<JK::buffer_size/2; i++)
|
||||||
jk.buffer[i] = i%256;
|
jk.buffer[i] = i%256;
|
||||||
expect_error([]() {
|
expect_error([&]() {
|
||||||
for (int i=0; i<JK::buffer_size; i++)
|
for (int i=0; i<JK::buffer_size; i++)
|
||||||
jk.buffer[i] = i%256;
|
jk.buffer[i] = i%256;
|
||||||
});
|
});
|
||||||
|
@ -45,9 +46,11 @@ JIT_TEST(jit_key) {
|
||||||
jk.clear();
|
jk.clear();
|
||||||
add_jit_define(jk, "f", 0.01);
|
add_jit_define(jk, "f", 0.01);
|
||||||
add_jit_define(jk, "f", 0.5);
|
add_jit_define(jk, "f", 0.5);
|
||||||
|
#ifndef _MSC_VER
|
||||||
add_jit_define(jk, "f", 1.0/0);
|
add_jit_define(jk, "f", 1.0/0);
|
||||||
add_jit_define(jk, "f", -1.0/0);
|
add_jit_define(jk, "f", -1.0/0);
|
||||||
add_jit_define(jk, "f", 0.0/0);
|
add_jit_define(jk, "f", 0.0/0);
|
||||||
|
#endif
|
||||||
keys = parse_jit_keys(jk.to_string());
|
keys = parse_jit_keys(jk.to_string());
|
||||||
k2 = {{"f","0x1.47ae147ae147bp-7"},
|
k2 = {{"f","0x1.47ae147ae147bp-7"},
|
||||||
{"f","0x1p-1"},
|
{"f","0x1p-1"},
|
||||||
|
|
|
@ -31,6 +31,7 @@ JIT_TEST(op_register) {
|
||||||
}
|
}
|
||||||
|
|
||||||
JIT_TEST(fused_op_relay_matmul) {
|
JIT_TEST(fused_op_relay_matmul) {
|
||||||
|
JK& jk = get_jk();
|
||||||
VarPtr a({10,10}, "float32");
|
VarPtr a({10,10}, "float32");
|
||||||
VarPtr b({10,10}, "float32");
|
VarPtr b({10,10}, "float32");
|
||||||
auto aa = make_broadcast_to_op(a, {10,10,10}, {2});
|
auto aa = make_broadcast_to_op(a, {10,10,10}, {2});
|
||||||
|
|
|
@ -13,7 +13,7 @@ void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims);
|
||||||
|
|
||||||
JIT_TEST(cuda_loop_schedule) {
|
JIT_TEST(cuda_loop_schedule) {
|
||||||
auto check = [&](const vector<int64>& shape, const vector<int>& masks, vector<int> tdims={}) {
|
auto check = [&](const vector<int64>& shape, const vector<int>& masks, vector<int> tdims={}) {
|
||||||
int masks2[shape.size()];
|
STACK_ALLOC(int, masks2, shape.size());
|
||||||
int tdims2[6];
|
int tdims2[6];
|
||||||
cuda_loop_schedule(shape, masks2, tdims2);
|
cuda_loop_schedule(shape, masks2, tdims2);
|
||||||
while (tdims.size() < 6) tdims.push_back(1);
|
while (tdims.size() < 6) tdims.push_back(1);
|
||||||
|
|
|
@ -21,7 +21,7 @@ struct TestTask {
|
||||||
|
|
||||||
JIT_TEST(sfrl_allocator_time) {
|
JIT_TEST(sfrl_allocator_time) {
|
||||||
Allocator* allocator = get_allocator();
|
Allocator* allocator = get_allocator();
|
||||||
int max_allc_num = 10000;
|
constexpr int max_allc_num = 10000;
|
||||||
size_t id[max_allc_num];
|
size_t id[max_allc_num];
|
||||||
size_t temp[max_allc_num];
|
size_t temp[max_allc_num];
|
||||||
std::vector<TestTask> tasks;
|
std::vector<TestTask> tasks;
|
||||||
|
@ -52,7 +52,7 @@ JIT_TEST(sfrl_allocator_time) {
|
||||||
|
|
||||||
JIT_TEST(sfrl_allocator_share) {
|
JIT_TEST(sfrl_allocator_share) {
|
||||||
Allocator* allocator = get_allocator();
|
Allocator* allocator = get_allocator();
|
||||||
int max_allc_num = 10000;
|
constexpr int max_allc_num = 10000;
|
||||||
size_t id[max_allc_num];
|
size_t id[max_allc_num];
|
||||||
size_t temp[max_allc_num];
|
size_t temp[max_allc_num];
|
||||||
std::vector<TestTask> tasks;
|
std::vector<TestTask> tasks;
|
||||||
|
@ -88,7 +88,7 @@ JIT_TEST(sfrl_allocator_share) {
|
||||||
|
|
||||||
JIT_TEST(sfrl_allocator_share_without_size_and_ptr) {
|
JIT_TEST(sfrl_allocator_share_without_size_and_ptr) {
|
||||||
Allocator* allocator = get_allocator();
|
Allocator* allocator = get_allocator();
|
||||||
int max_allc_num = 1000;
|
constexpr int max_allc_num = 1000;
|
||||||
size_t id[max_allc_num];
|
size_t id[max_allc_num];
|
||||||
size_t temp[max_allc_num];
|
size_t temp[max_allc_num];
|
||||||
std::vector<TestTask> tasks;
|
std::vector<TestTask> tasks;
|
||||||
|
|
|
@ -22,7 +22,7 @@ struct UpdateQueue {
|
||||||
void auto_flush();
|
void auto_flush();
|
||||||
};
|
};
|
||||||
|
|
||||||
extern UpdateQueue update_queue;
|
EXTERN_LIB UpdateQueue update_queue;
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ void write(const string& fname, const string& src) {
|
||||||
|
|
||||||
bool file_exist(const string& fname) {
|
bool file_exist(const string& fname) {
|
||||||
std::ifstream f(fname);
|
std::ifstream f(fname);
|
||||||
return f.good();
|
return f && f.good();
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -45,23 +45,21 @@ string join(string a, string b) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void find_names(string cmd, vector<string>& input_names, string& output_name, map<string,vector<string>>& extra) {
|
void find_names(string cmd, vector<string>& input_names, string& output_name, map<string,vector<string>>& extra) {
|
||||||
size_t i=0;
|
|
||||||
while (i<cmd.size() && cmd[i] != ' ') i++;
|
|
||||||
CHECK(i<cmd.size());
|
|
||||||
// find space not in str
|
// find space not in str
|
||||||
|
#define is_quate(x) ((x)=='\'' || (x)=='\"')
|
||||||
auto pass = [&](size_t& j) {
|
auto pass = [&](size_t& j) {
|
||||||
while (j<cmd.size()) {
|
while (j<cmd.size()) {
|
||||||
if (cmd[j]=='\'') {
|
if (is_quate(cmd[j])) {
|
||||||
j++;
|
j++;
|
||||||
while (j<cmd.size() && cmd[j]!='\'') j++;
|
while (j<cmd.size() && !is_quate(cmd[j])) j++;
|
||||||
ASSERT(j<cmd.size());
|
ASSERT(j<cmd.size());
|
||||||
j++;
|
j++;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
while (j<cmd.size() && cmd[j]!=' ' && cmd[j]!='\'') j++;
|
while (j<cmd.size() && cmd[j]!=' ' && !is_quate(cmd[j])) j++;
|
||||||
if (j<cmd.size()) {
|
if (j<cmd.size()) {
|
||||||
if (cmd[j]==' ') break;
|
if (cmd[j]==' ') break;
|
||||||
if (cmd[j]=='\'') continue;
|
if (is_quate(cmd[j])) continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -69,15 +67,33 @@ void find_names(string cmd, vector<string>& input_names, string& output_name, ma
|
||||||
auto substr = [&](size_t i, size_t j) -> string {
|
auto substr = [&](size_t i, size_t j) -> string {
|
||||||
string s;
|
string s;
|
||||||
for (size_t k=i; k<j; k++)
|
for (size_t k=i; k<j; k++)
|
||||||
if (cmd[k]!='\'' && cmd[k]!='"') s += cmd[k];
|
if (!is_quate(cmd[k])) s += cmd[k];
|
||||||
return s;
|
return s;
|
||||||
};
|
};
|
||||||
|
size_t i=0;
|
||||||
|
pass(i);
|
||||||
while (i<cmd.size()) {
|
while (i<cmd.size()) {
|
||||||
if (cmd[i] == ' ') {
|
if (cmd[i] == ' ') {
|
||||||
i++;
|
i++;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (cmd[i] == '-') {
|
if (cmd[i] == '-') {
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
if (i+4<cmd.size() && cmd[i+1]=='F' && cmd[i+4]==' ') {
|
||||||
|
// -Fo: -Fe:
|
||||||
|
auto j=i+5;
|
||||||
|
while (j<cmd.size() && cmd[j] == ' ') j++;
|
||||||
|
CHECK(j<cmd.size());
|
||||||
|
auto k=j;
|
||||||
|
pass(k);
|
||||||
|
CHECK(j<k && output_name.size()==0);
|
||||||
|
// -Fo: xxx
|
||||||
|
// i j k
|
||||||
|
output_name = substr(j, k);
|
||||||
|
i = k;
|
||||||
|
continue;
|
||||||
|
} else
|
||||||
|
#endif
|
||||||
if (i+2<cmd.size() && cmd[i+1]=='o' && cmd[i+2]==' ') {
|
if (i+2<cmd.size() && cmd[i+1]=='o' && cmd[i+2]==' ') {
|
||||||
auto j=i+3;
|
auto j=i+3;
|
||||||
while (j<cmd.size() && cmd[j] == ' ') j++;
|
while (j<cmd.size() && cmd[j] == ' ') j++;
|
||||||
|
@ -141,6 +157,8 @@ size_t skip_comments(const string& src, size_t i) {
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
map<string,string> jt_env;
|
||||||
|
|
||||||
void process(string src, vector<string>& input_names, string& cmd) {
|
void process(string src, vector<string>& input_names, string& cmd) {
|
||||||
for (size_t i=0; i<src.size(); i++) {
|
for (size_t i=0; i<src.size(); i++) {
|
||||||
i = skip_comments(src, i);
|
i = skip_comments(src, i);
|
||||||
|
@ -149,8 +167,9 @@ void process(string src, vector<string>& input_names, string& cmd) {
|
||||||
// #include "a.h"
|
// #include "a.h"
|
||||||
// i jk l
|
// i jk l
|
||||||
auto j=i+1;
|
auto j=i+1;
|
||||||
while (j<src.size() && src[j] != ' ') j++;
|
while (j<src.size() && (src[j] != ' ' && src[j] != '\n')) j++;
|
||||||
if (j>=src.size()) return;
|
if (j>=src.size()) return;
|
||||||
|
if (j-i != 8 && j-i != 6) continue;
|
||||||
auto k=j+1;
|
auto k=j+1;
|
||||||
while (k<src.size() && src[k] == ' ') k++;
|
while (k<src.size() && src[k] == ' ') k++;
|
||||||
if (k>=src.size()) return;
|
if (k>=src.size()) return;
|
||||||
|
@ -167,12 +186,22 @@ void process(string src, vector<string>& input_names, string& cmd) {
|
||||||
auto inc = src.substr(k, l-k);
|
auto inc = src.substr(k, l-k);
|
||||||
auto env = getenv(inc.c_str());
|
auto env = getenv(inc.c_str());
|
||||||
if (env && string(env)!="0") {
|
if (env && string(env)!="0") {
|
||||||
string dflag = " -D"+inc+"="+string(env)+" -o ";
|
auto senv = string(env);
|
||||||
|
if (!jt_env.count(inc)) {
|
||||||
|
LOGe << "Load JT env ok:" << inc << senv;
|
||||||
|
jt_env[inc] = senv;
|
||||||
|
}
|
||||||
|
string dflag = " -D"+inc+"="+senv;
|
||||||
if (cmd.find(dflag) == string::npos) {
|
if (cmd.find(dflag) == string::npos) {
|
||||||
// -D flags should insert before -o flag
|
// -D flags should insert before -o flag
|
||||||
auto cmds = split(cmd, " -o ", 2);
|
#ifdef _MSC_VER
|
||||||
|
string patt = " -Fo: ";
|
||||||
|
#else
|
||||||
|
string patt = " -o ";
|
||||||
|
#endif
|
||||||
|
auto cmds = split(cmd, patt, 2);
|
||||||
if (cmds.size() == 2) {
|
if (cmds.size() == 2) {
|
||||||
cmd = cmds[0] + dflag + cmds[1];
|
cmd = cmds[0] + dflag + patt + cmds[1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -199,7 +228,7 @@ static inline void check_win_file(const string& name) {
|
||||||
|
|
||||||
static inline bool is_full_path(const string& name) {
|
static inline bool is_full_path(const string& name) {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
return name.size()>=2 && name[1]==':';
|
return name.size()>=2 && (name[1]==':' || (name[0]=='\\' && name[1]=='\\'));
|
||||||
#else
|
#else
|
||||||
return name.size() && name[0]=='/';
|
return name.size() && name[0]=='/';
|
||||||
#endif
|
#endif
|
||||||
|
@ -217,6 +246,7 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
|
||||||
unordered_set<string> processed;
|
unordered_set<string> processed;
|
||||||
auto src_path = join(jittor_path, "src");
|
auto src_path = join(jittor_path, "src");
|
||||||
const auto& extra_include = extra["I"];
|
const auto& extra_include = extra["I"];
|
||||||
|
string tmp_dir =join(cache_path, "obj_files");
|
||||||
for (size_t i=0; i<input_names.size(); i++) {
|
for (size_t i=0; i<input_names.size(); i++) {
|
||||||
if (processed.count(input_names[i]) != 0)
|
if (processed.count(input_names[i]) != 0)
|
||||||
continue;
|
continue;
|
||||||
|
@ -224,10 +254,13 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
|
||||||
continue;
|
continue;
|
||||||
processed.insert(input_names[i]);
|
processed.insert(input_names[i]);
|
||||||
auto src = read_all(input_names[i]);
|
auto src = read_all(input_names[i]);
|
||||||
ASSERT(src.size()) << "Source read failed:" << input_names[i];
|
ASSERT(src.size()) << "Source read failed:" << input_names[i] << "cmd:" << cmd;
|
||||||
auto hash = S(hash64(src));
|
auto hash = S(hash64(src));
|
||||||
vector<string> new_names;
|
vector<string> new_names;
|
||||||
process(src, new_names, cmd);
|
auto back = input_names[i].back();
|
||||||
|
// *.obj, *.o, *.pyd
|
||||||
|
if (back != 'j' && back != 'o' && back != 'd')
|
||||||
|
process(src, new_names, cmd);
|
||||||
for (auto& name : new_names) {
|
for (auto& name : new_names) {
|
||||||
string full_name;
|
string full_name;
|
||||||
if (name.substr(0, 4) == "jit/" || name.substr(0, 4) == "gen/")
|
if (name.substr(0, 4) == "jit/" || name.substr(0, 4) == "gen/")
|
||||||
|
@ -261,14 +294,15 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
|
||||||
if (output_cache_key.size() == 0) {
|
if (output_cache_key.size() == 0) {
|
||||||
LOGvv << "Cache key of" << output_name << "not found.";
|
LOGvv << "Cache key of" << output_name << "not found.";
|
||||||
LOGvvv << "Run cmd:" << cmd;
|
LOGvvv << "Run cmd:" << cmd;
|
||||||
system_with_check(cmd.c_str());
|
check_win_file(output_name);
|
||||||
|
system_with_check(cmd.c_str(), tmp_dir.c_str());
|
||||||
ran = true;
|
ran = true;
|
||||||
}
|
}
|
||||||
if (output_cache_key.size() != 0 && output_cache_key != cache_key) {
|
if (output_cache_key.size() != 0 && output_cache_key != cache_key) {
|
||||||
LOGvv << "Cache key of" << output_name << "changed.";
|
LOGvv << "Cache key of" << output_name << "changed.";
|
||||||
LOGvvv << "Run cmd:" << cmd;
|
LOGvvv << "Run cmd:" << cmd;
|
||||||
check_win_file(output_name);
|
check_win_file(output_name);
|
||||||
system_with_check(cmd.c_str());
|
system_with_check(cmd.c_str(), tmp_dir.c_str());
|
||||||
ran = true;
|
ran = true;
|
||||||
}
|
}
|
||||||
if (output_cache_key != cache_key) {
|
if (output_cache_key != cache_key) {
|
||||||
|
@ -277,7 +311,7 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
|
||||||
write(output_name+".key", cache_key);
|
write(output_name+".key", cache_key);
|
||||||
}
|
}
|
||||||
if (!ran)
|
if (!ran)
|
||||||
LOGvv << "Command cached:" << cmd;
|
LOGvvvv << "Command cached:" << cmd;
|
||||||
return ran;
|
return ran;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
|
#include <sys/wait.h>
|
||||||
|
#ifdef __linux__
|
||||||
|
#include <sys/prctl.h>
|
||||||
|
#endif
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <execinfo.h>
|
||||||
|
#include <sys/wait.h>
|
||||||
|
#include <sys/time.h>
|
||||||
|
#else
|
||||||
|
#include <wchar.h>
|
||||||
|
#include <windows.h>
|
||||||
|
#endif
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#include <process.h>
|
||||||
|
#include <synchapi.h>
|
||||||
|
#define getpid _getpid
|
||||||
|
inline void sleep(int s) { Sleep(s*1000); }
|
||||||
|
#else
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
|
||||||
|
// typedef struct timeval {
|
||||||
|
// long tv_sec;
|
||||||
|
// long tv_usec;
|
||||||
|
// } timeval;
|
||||||
|
|
||||||
|
inline int gettimeofday(struct timeval * tp, struct timezone * tzp)
|
||||||
|
{
|
||||||
|
// Note: some broken versions only have 8 trailing zero's, the correct epoch has 9 trailing zero's
|
||||||
|
// This magic number is the number of 100 nanosecond intervals since January 1, 1601 (UTC)
|
||||||
|
// until 00:00:00 January 1, 1970
|
||||||
|
static const uint64_t EPOCH = ((uint64_t) 116444736000000000ULL);
|
||||||
|
|
||||||
|
SYSTEMTIME system_time;
|
||||||
|
FILETIME file_time;
|
||||||
|
uint64_t time;
|
||||||
|
|
||||||
|
GetSystemTime( &system_time );
|
||||||
|
SystemTimeToFileTime( &system_time, &file_time );
|
||||||
|
time = ((uint64_t)file_time.dwLowDateTime ) ;
|
||||||
|
time += ((uint64_t)file_time.dwHighDateTime) << 32;
|
||||||
|
|
||||||
|
tp->tv_sec = (long) ((time - EPOCH) / 10000000L);
|
||||||
|
tp->tv_usec = (long) (system_time.wMilliseconds * 1000);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -19,9 +19,209 @@
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
#ifdef _WIN32
|
||||||
|
#include <exception>
|
||||||
|
#include <windows.h>
|
||||||
|
#include <eh.h>
|
||||||
|
#include <sstream>
|
||||||
|
#endif
|
||||||
|
#include "utils/seh.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
|
||||||
|
using std::stringstream;
|
||||||
|
|
||||||
|
void raise_win_error(int ierr) {
|
||||||
|
DWORD err = (DWORD)ierr;
|
||||||
|
WCHAR *s_buf = NULL; /* Free via LocalFree */
|
||||||
|
stringstream message;
|
||||||
|
|
||||||
|
if (err==0) {
|
||||||
|
err = GetLastError();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto len = FormatMessageW(
|
||||||
|
/* Error API error */
|
||||||
|
FORMAT_MESSAGE_ALLOCATE_BUFFER |
|
||||||
|
FORMAT_MESSAGE_FROM_SYSTEM |
|
||||||
|
FORMAT_MESSAGE_IGNORE_INSERTS,
|
||||||
|
NULL, /* no message source */
|
||||||
|
err,
|
||||||
|
MAKELANGID(LANG_NEUTRAL,
|
||||||
|
SUBLANG_DEFAULT), /* Default language */
|
||||||
|
(LPWSTR) &s_buf,
|
||||||
|
0, /* size not used */
|
||||||
|
NULL); /* no args */
|
||||||
|
|
||||||
|
if (len==0) {
|
||||||
|
/* Only seen this in out of mem situations */
|
||||||
|
message << "Windows Error " << err;
|
||||||
|
s_buf = NULL;
|
||||||
|
} else {
|
||||||
|
/* remove trailing cr/lf and dots */
|
||||||
|
while (len > 0 && (s_buf[len-1] <= L' ' || s_buf[len-1] == L'.'))
|
||||||
|
s_buf[--len] = L'\0';
|
||||||
|
message << s_buf;
|
||||||
|
}
|
||||||
|
if (s_buf)
|
||||||
|
LocalFree(s_buf);
|
||||||
|
throw std::runtime_error(message.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
void raise_cxx_exception(DWORD code, const EXCEPTION_RECORD* pr) {
|
||||||
|
|
||||||
|
/* The 'code' is a normal win32 error code so it could be handled by
|
||||||
|
raise_win_error(). However, for some errors, we have additional
|
||||||
|
information not included in the error code. We handle those here and
|
||||||
|
delegate all others to the generic function. */
|
||||||
|
stringstream message;
|
||||||
|
switch (code) {
|
||||||
|
case EXCEPTION_ACCESS_VIOLATION:
|
||||||
|
/* The thread attempted to read from or write
|
||||||
|
to a virtual address for which it does not
|
||||||
|
have the appropriate access. */
|
||||||
|
if (pr->ExceptionInformation[0] == 0)
|
||||||
|
message << "exception: access violation reading " << (void*)pr->ExceptionInformation[1];
|
||||||
|
else
|
||||||
|
message << "exception: access violation writing " << (void*)pr->ExceptionInformation[1];
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_BREAKPOINT:
|
||||||
|
/* A breakpoint was encountered. */
|
||||||
|
message << "exception: breakpoint encountered";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_DATATYPE_MISALIGNMENT:
|
||||||
|
/* The thread attempted to read or write data that is
|
||||||
|
misaligned on hardware that does not provide
|
||||||
|
alignment. For example, 16-bit values must be
|
||||||
|
aligned on 2-byte boundaries, 32-bit values on
|
||||||
|
4-byte boundaries, and so on. */
|
||||||
|
message << "exception: datatype misalignment";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_SINGLE_STEP:
|
||||||
|
/* A trace trap or other single-instruction mechanism
|
||||||
|
signaled that one instruction has been executed. */
|
||||||
|
message << "exception: single step";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_ARRAY_BOUNDS_EXCEEDED:
|
||||||
|
/* The thread attempted to access an array element
|
||||||
|
that is out of bounds, and the underlying hardware
|
||||||
|
supports bounds checking. */
|
||||||
|
message << "exception: array bounds exceeded";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_FLT_DENORMAL_OPERAND:
|
||||||
|
/* One of the operands in a floating-point operation
|
||||||
|
is denormal. A denormal value is one that is too
|
||||||
|
small to represent as a standard floating-point
|
||||||
|
value. */
|
||||||
|
message << "exception: floating-point operand denormal";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_FLT_DIVIDE_BY_ZERO:
|
||||||
|
/* The thread attempted to divide a floating-point
|
||||||
|
value by a floating-point divisor of zero. */
|
||||||
|
message << "exception: float divide by zero";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_FLT_INEXACT_RESULT:
|
||||||
|
/* The result of a floating-point operation cannot be
|
||||||
|
represented exactly as a decimal fraction. */
|
||||||
|
message << "exception: float inexact";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_FLT_INVALID_OPERATION:
|
||||||
|
/* This exception represents any floating-point
|
||||||
|
exception not included in this list. */
|
||||||
|
message << "exception: float invalid operation";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_FLT_OVERFLOW:
|
||||||
|
/* The exponent of a floating-point operation is
|
||||||
|
greater than the magnitude allowed by the
|
||||||
|
corresponding type. */
|
||||||
|
message << "exception: float overflow";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_FLT_STACK_CHECK:
|
||||||
|
/* The stack overflowed or underflowed as the result
|
||||||
|
of a floating-point operation. */
|
||||||
|
message << "exception: stack over/underflow";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_STACK_OVERFLOW:
|
||||||
|
/* The stack overflowed or underflowed as the result
|
||||||
|
of a floating-point operation. */
|
||||||
|
message << "exception: stack overflow";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_FLT_UNDERFLOW:
|
||||||
|
/* The exponent of a floating-point operation is less
|
||||||
|
than the magnitude allowed by the corresponding
|
||||||
|
type. */
|
||||||
|
message << "exception: float underflow";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_INT_DIVIDE_BY_ZERO:
|
||||||
|
/* The thread attempted to divide an integer value by
|
||||||
|
an integer divisor of zero. */
|
||||||
|
message << "exception: integer divide by zero";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_INT_OVERFLOW:
|
||||||
|
/* The result of an integer operation caused a carry
|
||||||
|
out of the most significant bit of the result. */
|
||||||
|
message << "exception: integer overflow";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_PRIV_INSTRUCTION:
|
||||||
|
/* The thread attempted to execute an instruction
|
||||||
|
whose operation is not allowed in the current
|
||||||
|
machine mode. */
|
||||||
|
message << "exception: privileged instruction";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EXCEPTION_NONCONTINUABLE_EXCEPTION:
|
||||||
|
/* The thread attempted to continue execution after a
|
||||||
|
noncontinuable exception occurred. */
|
||||||
|
message << "exception: nocontinuable";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 0xE06D7363:
|
||||||
|
/* magic number(0xE06D7363) of c++ exception:
|
||||||
|
https://devblogs.microsoft.com/oldnewthing/20100730-00/?p=13273
|
||||||
|
*/
|
||||||
|
message << "Error c++ exception";
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
raise_win_error(code);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// std::cout << message.str() << std::endl;
|
||||||
|
throw std::runtime_error(message.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DWORD HandleException(EXCEPTION_POINTERS *ptrs,
|
||||||
|
DWORD *pdw, EXCEPTION_RECORD *record)
|
||||||
|
{
|
||||||
|
*pdw = ptrs->ExceptionRecord->ExceptionCode;
|
||||||
|
*record = *ptrs->ExceptionRecord;
|
||||||
|
/* We don't want to catch breakpoint exceptions, they are used to attach
|
||||||
|
* a debugger to the process.
|
||||||
|
*/
|
||||||
|
if (*pdw == EXCEPTION_BREAKPOINT)
|
||||||
|
return EXCEPTION_CONTINUE_SEARCH;
|
||||||
|
return EXCEPTION_EXECUTE_HANDLER;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
void init_subprocess() {
|
void init_subprocess() {
|
||||||
#ifdef __linux__
|
#ifdef __linux__
|
||||||
prctl(PR_SET_PDEATHSIG, SIGKILL);
|
prctl(PR_SET_PDEATHSIG, SIGKILL);
|
||||||
|
@ -193,7 +393,7 @@ static void pyjt_def_core(PyObject* m) {
|
||||||
{ R""(cache_compile)"",
|
{ R""(cache_compile)"",
|
||||||
|
|
||||||
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
||||||
try {
|
try {_JT_SEH_START3;
|
||||||
;
|
;
|
||||||
uint64 arg_filled=0;
|
uint64 arg_filled=0;
|
||||||
(void)arg_filled;
|
(void)arg_filled;
|
||||||
|
@ -270,7 +470,7 @@ static void pyjt_def_core(PyObject* m) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGf << "Not a valid call.";
|
LOGf << "Not a valid call.";
|
||||||
} catch (const std::exception& e) {
|
_JT_SEH_END3; } catch (const std::exception& e) {
|
||||||
if (!PyErr_Occurred()) {
|
if (!PyErr_Occurred()) {
|
||||||
PyErr_Format(PyExc_RuntimeError, e.what());
|
PyErr_Format(PyExc_RuntimeError, e.what());
|
||||||
}
|
}
|
||||||
|
@ -287,7 +487,7 @@ bool cache_compile(const string& cmd, const string& cache_path="", const string&
|
||||||
{ R""(log)"",
|
{ R""(log)"",
|
||||||
|
|
||||||
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
||||||
try {
|
try {_JT_SEH_START3;
|
||||||
;
|
;
|
||||||
uint64 arg_filled=0;
|
uint64 arg_filled=0;
|
||||||
(void)arg_filled;
|
(void)arg_filled;
|
||||||
|
@ -357,7 +557,7 @@ bool cache_compile(const string& cmd, const string& cache_path="", const string&
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGf << "Not a valid call.";
|
LOGf << "Not a valid call.";
|
||||||
} catch (const std::exception& e) {
|
_JT_SEH_END3; } catch (const std::exception& e) {
|
||||||
if (!PyErr_Occurred()) {
|
if (!PyErr_Occurred()) {
|
||||||
PyErr_Format(PyExc_RuntimeError, e.what());
|
PyErr_Format(PyExc_RuntimeError, e.what());
|
||||||
}
|
}
|
||||||
|
@ -374,7 +574,7 @@ void log(const std::string& fileline, const char* level, int verbose, const std:
|
||||||
{ R""(init_subprocess)"",
|
{ R""(init_subprocess)"",
|
||||||
|
|
||||||
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
||||||
try {
|
try {_JT_SEH_START3;
|
||||||
;
|
;
|
||||||
uint64 arg_filled=0;
|
uint64 arg_filled=0;
|
||||||
(void)arg_filled;
|
(void)arg_filled;
|
||||||
|
@ -386,7 +586,7 @@ void log(const std::string& fileline, const char* level, int verbose, const std:
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGf << "Not a valid call.";
|
LOGf << "Not a valid call.";
|
||||||
} catch (const std::exception& e) {
|
_JT_SEH_END3; } catch (const std::exception& e) {
|
||||||
if (!PyErr_Occurred()) {
|
if (!PyErr_Occurred()) {
|
||||||
PyErr_Format(PyExc_RuntimeError, e.what());
|
PyErr_Format(PyExc_RuntimeError, e.what());
|
||||||
}
|
}
|
||||||
|
@ -403,7 +603,7 @@ void init_subprocess()
|
||||||
{ R""(log_capture_start)"",
|
{ R""(log_capture_start)"",
|
||||||
|
|
||||||
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
||||||
try {
|
try {_JT_SEH_START3;
|
||||||
;
|
;
|
||||||
uint64 arg_filled=0;
|
uint64 arg_filled=0;
|
||||||
(void)arg_filled;
|
(void)arg_filled;
|
||||||
|
@ -415,7 +615,7 @@ void init_subprocess()
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGf << "Not a valid call.";
|
LOGf << "Not a valid call.";
|
||||||
} catch (const std::exception& e) {
|
_JT_SEH_END3; } catch (const std::exception& e) {
|
||||||
if (!PyErr_Occurred()) {
|
if (!PyErr_Occurred()) {
|
||||||
PyErr_Format(PyExc_RuntimeError, e.what());
|
PyErr_Format(PyExc_RuntimeError, e.what());
|
||||||
}
|
}
|
||||||
|
@ -432,7 +632,7 @@ void log_capture_start()
|
||||||
{ R""(log_capture_stop)"",
|
{ R""(log_capture_stop)"",
|
||||||
|
|
||||||
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
||||||
try {
|
try {_JT_SEH_START3;
|
||||||
;
|
;
|
||||||
uint64 arg_filled=0;
|
uint64 arg_filled=0;
|
||||||
(void)arg_filled;
|
(void)arg_filled;
|
||||||
|
@ -444,7 +644,7 @@ void log_capture_start()
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGf << "Not a valid call.";
|
LOGf << "Not a valid call.";
|
||||||
} catch (const std::exception& e) {
|
_JT_SEH_END3; } catch (const std::exception& e) {
|
||||||
if (!PyErr_Occurred()) {
|
if (!PyErr_Occurred()) {
|
||||||
PyErr_Format(PyExc_RuntimeError, e.what());
|
PyErr_Format(PyExc_RuntimeError, e.what());
|
||||||
}
|
}
|
||||||
|
@ -461,7 +661,7 @@ void log_capture_stop()
|
||||||
{ R""(log_capture_read)"",
|
{ R""(log_capture_read)"",
|
||||||
|
|
||||||
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
||||||
try {
|
try {_JT_SEH_START3;
|
||||||
;
|
;
|
||||||
uint64 arg_filled=0;
|
uint64 arg_filled=0;
|
||||||
(void)arg_filled;
|
(void)arg_filled;
|
||||||
|
@ -475,7 +675,7 @@ void log_capture_stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGf << "Not a valid call.";
|
LOGf << "Not a valid call.";
|
||||||
} catch (const std::exception& e) {
|
_JT_SEH_END3; } catch (const std::exception& e) {
|
||||||
if (!PyErr_Occurred()) {
|
if (!PyErr_Occurred()) {
|
||||||
PyErr_Format(PyExc_RuntimeError, e.what());
|
PyErr_Format(PyExc_RuntimeError, e.what());
|
||||||
}
|
}
|
||||||
|
@ -492,7 +692,7 @@ void log_capture_read()
|
||||||
{ R""(ostream_redirect)"",
|
{ R""(ostream_redirect)"",
|
||||||
|
|
||||||
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
|
||||||
try {
|
try {_JT_SEH_START3;
|
||||||
;
|
;
|
||||||
uint64 arg_filled=0;
|
uint64 arg_filled=0;
|
||||||
(void)arg_filled;
|
(void)arg_filled;
|
||||||
|
@ -540,7 +740,7 @@ void log_capture_read()
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGf << "Not a valid call.";
|
LOGf << "Not a valid call.";
|
||||||
} catch (const std::exception& e) {
|
_JT_SEH_END3; } catch (const std::exception& e) {
|
||||||
if (!PyErr_Occurred()) {
|
if (!PyErr_Occurred()) {
|
||||||
PyErr_Format(PyExc_RuntimeError, e.what());
|
PyErr_Format(PyExc_RuntimeError, e.what());
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,15 +6,10 @@
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <signal.h>
|
#include <signal.h>
|
||||||
#include <sys/time.h>
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unistd.h>
|
#include "utils/cross_platform.h"
|
||||||
#ifdef _WIN32
|
|
||||||
#include <wchar.h>
|
|
||||||
#include <windows.h>
|
|
||||||
#endif
|
|
||||||
#include "utils/log.h"
|
#include "utils/log.h"
|
||||||
#include "utils/mwsr_list.h"
|
#include "utils/mwsr_list.h"
|
||||||
#include "utils/str_utils.h"
|
#include "utils/str_utils.h"
|
||||||
|
@ -72,6 +67,7 @@ static bool supports_color() {
|
||||||
return term_supports_color;
|
return term_supports_color;
|
||||||
}
|
}
|
||||||
bool g_supports_color = supports_color();
|
bool g_supports_color = supports_color();
|
||||||
|
string thread_local thread_name;
|
||||||
|
|
||||||
struct timeval start_tv;
|
struct timeval start_tv;
|
||||||
|
|
||||||
|
@ -166,10 +162,10 @@ void log_capture(const string& s) {
|
||||||
|
|
||||||
DECLARE_FLAG(int, log_silent);
|
DECLARE_FLAG(int, log_silent);
|
||||||
|
|
||||||
void send_log(std::ostringstream&& out) {
|
void send_log(std::ostringstream&& out, char level, int verbose) {
|
||||||
if (log_capture_enabled)
|
if (log_capture_enabled)
|
||||||
log_capture(out.str());
|
log_capture(out.str());
|
||||||
if (log_silent) return;
|
if ((level=='i' || level=='w') && log_silent) return;
|
||||||
if (!log_sync) {
|
if (!log_sync) {
|
||||||
#if LOG_ASYNC
|
#if LOG_ASYNC
|
||||||
mwsr_list_log::push(move(out));
|
mwsr_list_log::push(move(out));
|
||||||
|
@ -203,12 +199,15 @@ void log_exiting();
|
||||||
bool exited = false;
|
bool exited = false;
|
||||||
size_t thread_local protected_page = 0;
|
size_t thread_local protected_page = 0;
|
||||||
int segfault_happen = 0;
|
int segfault_happen = 0;
|
||||||
string thread_local thread_name;
|
|
||||||
static int _pid = getpid();
|
static int _pid = getpid();
|
||||||
vector<void(*)()> cleanup_callback;
|
vector<void(*)()> cleanup_callback;
|
||||||
vector<void(*)()> sigquit_callback;
|
vector<void(*)()> sigquit_callback;
|
||||||
int64 last_q_time;
|
int64 last_q_time;
|
||||||
|
|
||||||
|
string& get_thread_name() {
|
||||||
|
return thread_name;
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
void handle_signal(int signal) {
|
void handle_signal(int signal) {
|
||||||
std::cerr << "Caught SIGNAL " << signal << ", quick exit";
|
std::cerr << "Caught SIGNAL " << signal << ", quick exit";
|
||||||
|
@ -432,7 +431,7 @@ If you still have problems, please contact us:
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
int system_popen(const char *cmd) {
|
int system_popen(const char *cmd, const char* cwd) {
|
||||||
HANDLE g_hChildStd_OUT_Rd = NULL;
|
HANDLE g_hChildStd_OUT_Rd = NULL;
|
||||||
HANDLE g_hChildStd_OUT_Wr = NULL;
|
HANDLE g_hChildStd_OUT_Wr = NULL;
|
||||||
SECURITY_ATTRIBUTES saAttr;
|
SECURITY_ATTRIBUTES saAttr;
|
||||||
|
@ -472,7 +471,7 @@ int system_popen(const char *cmd) {
|
||||||
TRUE, // handles are inherited
|
TRUE, // handles are inherited
|
||||||
0, // creation flags
|
0, // creation flags
|
||||||
NULL, // use parent's environment
|
NULL, // use parent's environment
|
||||||
NULL, // use parent's current directory
|
cwd, // use cwd directory
|
||||||
&siStartInfo, // STARTUPINFO pointer
|
&siStartInfo, // STARTUPINFO pointer
|
||||||
&piProcInfo); // receives PROCESS_INFORMATION
|
&piProcInfo); // receives PROCESS_INFORMATION
|
||||||
|
|
||||||
|
@ -495,7 +494,8 @@ int system_popen(const char *cmd) {
|
||||||
if (!bSuccess || dwRead == 0)
|
if (!bSuccess || dwRead == 0)
|
||||||
break;
|
break;
|
||||||
output += chBuf;
|
output += chBuf;
|
||||||
bSuccess = WriteFile(hParentStdOut, chBuf,
|
if (log_v)
|
||||||
|
bSuccess = WriteFile(hParentStdOut, chBuf,
|
||||||
dwRead, &dwWritten, NULL);
|
dwRead, &dwWritten, NULL);
|
||||||
if (!bSuccess)
|
if (!bSuccess)
|
||||||
break;
|
break;
|
||||||
|
@ -508,6 +508,8 @@ int system_popen(const char *cmd) {
|
||||||
// of the child process, for example.
|
// of the child process, for example.
|
||||||
CloseHandle(piProcInfo.hProcess);
|
CloseHandle(piProcInfo.hProcess);
|
||||||
CloseHandle(piProcInfo.hThread);
|
CloseHandle(piProcInfo.hThread);
|
||||||
|
if (ec && !log_v)
|
||||||
|
LOGe << output;
|
||||||
|
|
||||||
if (ec) {
|
if (ec) {
|
||||||
check_cuda_unsupport_version(output);
|
check_cuda_unsupport_version(output);
|
||||||
|
@ -516,7 +518,7 @@ int system_popen(const char *cmd) {
|
||||||
return ec;
|
return ec;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
int system_popen(const char* cmd) {
|
int system_popen(const char* cmd, const char* cwd) {
|
||||||
char buf[BUFSIZ];
|
char buf[BUFSIZ];
|
||||||
string cmd2;
|
string cmd2;
|
||||||
cmd2 = cmd;
|
cmd2 = cmd;
|
||||||
|
@ -542,8 +544,8 @@ int system_popen(const char* cmd) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void system_with_check(const char* cmd) {
|
void system_with_check(const char* cmd, const char* cwd) {
|
||||||
auto ret = system_popen(cmd);
|
auto ret = system_popen(cmd, cwd);
|
||||||
CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd <<
|
CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd <<
|
||||||
"\nreturn ">> ret >> ". This might be an overcommit issue or out of memory."
|
"\nreturn ">> ret >> ". This might be an overcommit issue or out of memory."
|
||||||
<< "Try : sudo sysctl vm.overcommit_memory=1";
|
<< "Try : sudo sysctl vm.overcommit_memory=1";
|
||||||
|
|
|
@ -32,11 +32,26 @@ constexpr int32_t basename_index(const char * const path, const int32_t index =
|
||||||
#define __FILELINE__ \
|
#define __FILELINE__ \
|
||||||
(&((__FILE__ ":" STRINGIZE(__LINE__))[jittor::basename_index(__FILE__)]))
|
(&((__FILE__ ":" STRINGIZE(__LINE__))[jittor::basename_index(__FILE__)]))
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
#define PREDICT_BRANCH_NOT_TAKEN(x) (__builtin_expect(x, 0))
|
#define PREDICT_BRANCH_NOT_TAKEN(x) (__builtin_expect(x, 0))
|
||||||
|
#else
|
||||||
|
#define PREDICT_BRANCH_NOT_TAKEN(x) (x)
|
||||||
|
#endif
|
||||||
|
|
||||||
extern uint32_t get_tid();
|
|
||||||
extern bool g_supports_color;
|
#ifdef _MSC_VER
|
||||||
extern void print_prefix(std::ostream* out);
|
#define STACK_ALLOC(T, a, n) T* a = (T*)_alloca(sizeof(T)*(n))
|
||||||
|
#define EXTERN_LIB extern __declspec(dllimport)
|
||||||
|
#define EXPORT_LIB __declspec(dllimport)
|
||||||
|
#else
|
||||||
|
#define STACK_ALLOC(T, a, n) T a[n]
|
||||||
|
#define EXTERN_LIB extern
|
||||||
|
#define EXPORT_LIB
|
||||||
|
#endif
|
||||||
|
|
||||||
|
EXTERN_LIB uint32_t get_tid();
|
||||||
|
EXTERN_LIB bool g_supports_color;
|
||||||
|
EXTERN_LIB void print_prefix(std::ostream* out);
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
constexpr char green[] = "\x1b[1;32m";
|
constexpr char green[] = "\x1b[1;32m";
|
||||||
|
@ -44,7 +59,7 @@ constexpr char red[] = "\x1b[1;31m";
|
||||||
constexpr char yellow[] = "\x1b[1;33m";
|
constexpr char yellow[] = "\x1b[1;33m";
|
||||||
|
|
||||||
|
|
||||||
static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
|
inline static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
|
||||||
if (level == 'i') {
|
if (level == 'i') {
|
||||||
if (verbose == 0) color_begin = "\x1b[1;32m"; else
|
if (verbose == 0) color_begin = "\x1b[1;32m"; else
|
||||||
if (verbose < 10) color_begin = "\x1b[1;32m"; else
|
if (verbose < 10) color_begin = "\x1b[1;32m"; else
|
||||||
|
@ -65,7 +80,7 @@ constexpr char green[] = "\033[38;5;2m";
|
||||||
constexpr char red[] = "\033[38;5;1m";
|
constexpr char red[] = "\033[38;5;1m";
|
||||||
constexpr char yellow[] = "\033[38;5;3m";
|
constexpr char yellow[] = "\033[38;5;3m";
|
||||||
|
|
||||||
static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
|
inline static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
|
||||||
if (level == 'i') {
|
if (level == 'i') {
|
||||||
if (verbose == 0) color_begin = "\033[38;5;2m"; else
|
if (verbose == 0) color_begin = "\033[38;5;2m"; else
|
||||||
if (verbose < 10) color_begin = "\033[38;5;250m"; else
|
if (verbose < 10) color_begin = "\033[38;5;250m"; else
|
||||||
|
@ -83,18 +98,22 @@ static void get_color(char level, int verbose, const char*& color_begin, const c
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
extern void send_log(std::ostringstream&& out);
|
EXTERN_LIB void send_log(std::ostringstream&& out, char level, int verbose);
|
||||||
extern void flush_log();
|
EXTERN_LIB void flush_log();
|
||||||
extern void log_capture_start();
|
EXTERN_LIB void log_capture_start();
|
||||||
extern void log_capture_stop();
|
EXTERN_LIB void log_capture_stop();
|
||||||
extern std::vector<std::map<string,string>> log_capture_read();
|
EXTERN_LIB std::vector<std::map<string,string>> log_capture_read();
|
||||||
extern string thread_local thread_name;
|
EXTERN_LIB string& get_thread_name();
|
||||||
|
|
||||||
struct Log {
|
struct Log {
|
||||||
std::ostringstream out;
|
std::ostringstream out;
|
||||||
const char* color_end;
|
const char* color_end;
|
||||||
|
int verbose;
|
||||||
|
char level;
|
||||||
|
|
||||||
Log(const char* const fileline, char level, int verbose) {
|
inline Log(const char* const fileline, char level, int verbose) {
|
||||||
|
this->verbose = verbose;
|
||||||
|
this->level = level;
|
||||||
const char* color_begin;
|
const char* color_begin;
|
||||||
get_color(level, verbose, color_begin, color_end);
|
get_color(level, verbose, color_begin, color_end);
|
||||||
if (g_supports_color) out << color_begin;
|
if (g_supports_color) out << color_begin;
|
||||||
|
@ -104,12 +123,12 @@ struct Log {
|
||||||
out << fileline << ']';
|
out << fileline << ']';
|
||||||
}
|
}
|
||||||
|
|
||||||
void end() {
|
inline void end() {
|
||||||
if (g_supports_color) out << color_end;
|
if (g_supports_color) out << color_end;
|
||||||
out << '\n';
|
out << '\n';
|
||||||
send_log(move(out));
|
send_log(move(out), level, verbose);
|
||||||
}
|
}
|
||||||
void flush() { flush_log(); }
|
inline void flush() { flush_log(); }
|
||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
Log& operator<<(const T& a) { out << ' ' << a; return *this; }
|
Log& operator<<(const T& a) { out << ' ' << a; return *this; }
|
||||||
|
@ -118,11 +137,11 @@ struct Log {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LogVoidify {
|
struct LogVoidify {
|
||||||
void operator&&(Log& log) { log.end(); }
|
inline void operator&&(Log& log) { log.end(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LogFatalVoidify {
|
struct LogFatalVoidify {
|
||||||
void operator&&(Log& log) {
|
inline void operator&&(Log& log) {
|
||||||
log.flush();
|
log.flush();
|
||||||
if (g_supports_color) log.out << log.color_end;
|
if (g_supports_color) log.out << log.color_end;
|
||||||
throw std::runtime_error(log.out.str());
|
throw std::runtime_error(log.out.str());
|
||||||
|
@ -170,9 +189,9 @@ template<class T> T get_from_env(const char* name,const T& _default) {
|
||||||
template<> std::string get_from_env(const char* name, const std::string& _default);
|
template<> std::string get_from_env(const char* name, const std::string& _default);
|
||||||
|
|
||||||
#define DECLARE_FLAG(type, name) \
|
#define DECLARE_FLAG(type, name) \
|
||||||
extern type name; \
|
EXTERN_LIB type name; \
|
||||||
extern std::string doc_ ## name; \
|
EXTERN_LIB std::string doc_ ## name; \
|
||||||
extern void set_ ## name (const type&);
|
EXTERN_LIB void set_ ## name (const type&);
|
||||||
|
|
||||||
|
|
||||||
#ifdef JIT
|
#ifdef JIT
|
||||||
|
@ -256,6 +275,6 @@ bool check_vlog(const char* fileline, int verbose);
|
||||||
#define LOGig LOGi >> jittor::green
|
#define LOGig LOGi >> jittor::green
|
||||||
#define LOGiy LOGi >> jittor::yellow
|
#define LOGiy LOGi >> jittor::yellow
|
||||||
|
|
||||||
void system_with_check(const char* cmd);
|
void system_with_check(const char* cmd, const char* cwd=nullptr);
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -0,0 +1,77 @@
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#ifdef _WIN32
|
||||||
|
#include <windows.h>
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
EXTERN_LIB void raise_win_error(int ierr);
|
||||||
|
EXTERN_LIB void raise_cxx_exception(DWORD code, const EXCEPTION_RECORD* pr);
|
||||||
|
EXTERN_LIB DWORD HandleException(EXCEPTION_POINTERS *ptrs,
|
||||||
|
DWORD *pdw, EXCEPTION_RECORD *record);
|
||||||
|
|
||||||
|
#define _JT_SEH_TRY \
|
||||||
|
DWORD dwExceptionCode = 0; \
|
||||||
|
EXCEPTION_RECORD record; \
|
||||||
|
__try {
|
||||||
|
|
||||||
|
#define _JT_SEH_CATCH \
|
||||||
|
} \
|
||||||
|
__except (HandleException(GetExceptionInformation(), \
|
||||||
|
&dwExceptionCode, &record)) { \
|
||||||
|
raise_cxx_exception(dwExceptionCode, &record); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define _JT_SEH_START \
|
||||||
|
return [&]() { \
|
||||||
|
_JT_SEH_TRY; \
|
||||||
|
return [&]() {
|
||||||
|
|
||||||
|
#define _JT_SEH_END \
|
||||||
|
}(); \
|
||||||
|
_JT_SEH_CATCH; \
|
||||||
|
}(); \
|
||||||
|
|
||||||
|
|
||||||
|
#define _JT_SEH_START2 \
|
||||||
|
[&]() { \
|
||||||
|
_JT_SEH_TRY;
|
||||||
|
|
||||||
|
#define _JT_SEH_END2 \
|
||||||
|
_JT_SEH_CATCH; \
|
||||||
|
}();
|
||||||
|
|
||||||
|
#ifdef JT_SEH_FULL
|
||||||
|
|
||||||
|
|
||||||
|
#define _JT_SEH_START3 \
|
||||||
|
return [&]() { \
|
||||||
|
_JT_SEH_TRY; \
|
||||||
|
return [&]() {
|
||||||
|
|
||||||
|
#define _JT_SEH_END3 \
|
||||||
|
}(); \
|
||||||
|
_JT_SEH_CATCH; \
|
||||||
|
}(); \
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
#define _JT_SEH_START3
|
||||||
|
#define _JT_SEH_END3
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
|
||||||
|
#define _JT_SEH_TRY
|
||||||
|
#define _JT_SEH_CATCH
|
||||||
|
#define _JT_SEH_START
|
||||||
|
#define _JT_SEH_END
|
||||||
|
#define _JT_SEH_START2
|
||||||
|
#define _JT_SEH_END2
|
||||||
|
#define _JT_SEH_START3
|
||||||
|
#define _JT_SEH_END3
|
||||||
|
|
||||||
|
#endif
|
|
@ -6,19 +6,8 @@
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#ifndef _WIN32
|
|
||||||
#include <sys/wait.h>
|
|
||||||
#ifdef __linux__
|
|
||||||
#include <sys/prctl.h>
|
|
||||||
#endif
|
|
||||||
#include <unistd.h>
|
|
||||||
#include <execinfo.h>
|
|
||||||
#include <sys/wait.h>
|
|
||||||
#else
|
|
||||||
#include <windows.h>
|
|
||||||
#endif
|
|
||||||
#include <unistd.h>
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include "utils/cross_platform.h"
|
||||||
#include "utils/tracer.h"
|
#include "utils/tracer.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
@ -32,7 +21,7 @@ DEFINE_FLAG_WITH_SETTER(int, gdb_attach, 0, "gdb attach self process.");
|
||||||
|
|
||||||
string _extra_gdb_cmd;
|
string _extra_gdb_cmd;
|
||||||
|
|
||||||
int system_popen(const char* cmd);
|
int system_popen(const char* cmd, const char* cwd=nullptr);
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
string get_cmds(const vector<const char*>& argv) {
|
string get_cmds(const vector<const char*>& argv) {
|
||||||
|
@ -76,9 +65,9 @@ void setter_gdb_attach(int v) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
LOGi << "gdb attach for" << "pid=" >> pid_buf << argv;
|
||||||
// argv.insert(argv.end(), {name_buf, pid_buf, NULL});
|
// argv.insert(argv.end(), {name_buf, pid_buf, NULL});
|
||||||
argv.insert(argv.end(), {"-p", pid_buf, NULL});
|
argv.insert(argv.end(), {"-p", pid_buf, NULL});
|
||||||
LOGi << "gdb attach for" << "pid=" >> pid_buf << argv;
|
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
// _spawnvp(_P_OVERLAY, gdb_path.c_str(), (char* const*)&argv[0]);
|
// _spawnvp(_P_OVERLAY, gdb_path.c_str(), (char* const*)&argv[0]);
|
||||||
|
@ -150,6 +139,7 @@ void breakpoint() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void print_trace() {
|
void print_trace() {
|
||||||
|
LOGir << "???" << gdb_path;
|
||||||
if (gdb_path.size()) {
|
if (gdb_path.size()) {
|
||||||
// using gdb to print the stack trace
|
// using gdb to print the stack trace
|
||||||
char pid_buf[30];
|
char pid_buf[30];
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
#define _P(...)
|
|
@ -21,11 +21,11 @@ namespace jittor {
|
||||||
|
|
||||||
DEFINE_FLAG(int, lazy_execution, 1, "Default enabled, if disable, use immediately eager execution rather than lazy execution, This flag makes error message and traceback infomation better. But this flag will raise memory consumption and lower the performance.");
|
DEFINE_FLAG(int, lazy_execution, 1, "Default enabled, if disable, use immediately eager execution rather than lazy execution, This flag makes error message and traceback infomation better. But this flag will raise memory consumption and lower the performance.");
|
||||||
|
|
||||||
list<VarHolder*> VarHolder::hold_vars;
|
list<VarHolder*> hold_vars;
|
||||||
|
|
||||||
void add_hold_vars(VarHolder* self) {
|
void add_hold_vars(VarHolder* self) {
|
||||||
VarHolder::hold_vars.push_front(self);
|
hold_vars.push_front(self);
|
||||||
self->iter = VarHolder::hold_vars.begin();
|
self->iter = hold_vars.begin();
|
||||||
if (lazy_execution) return;
|
if (lazy_execution) return;
|
||||||
auto v = self->var;
|
auto v = self->var;
|
||||||
for (int i=0; i<5; i++) {
|
for (int i=0; i<5; i++) {
|
||||||
|
@ -129,7 +129,7 @@ VarHolder* VarHolder::_update(VarHolder* v) {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
extern Executor exe;
|
EXTERN_LIB Executor exe;
|
||||||
|
|
||||||
void VarHolder::sync(bool device_sync) {
|
void VarHolder::sync(bool device_sync) {
|
||||||
jittor::sync({this}, device_sync);
|
jittor::sync({this}, device_sync);
|
||||||
|
@ -162,12 +162,12 @@ ItemData VarHolder::item() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// from fetch_op.cc
|
// from fetch_op.cc
|
||||||
extern list<VarPtr> fetcher;
|
EXTERN_LIB list<VarPtr> fetcher;
|
||||||
|
|
||||||
void sync_all(bool device_sync) {
|
void sync_all(bool device_sync) {
|
||||||
vector<Var*> vars;
|
vector<Var*> vars;
|
||||||
vars.reserve(VarHolder::hold_vars.size());
|
vars.reserve(hold_vars.size());
|
||||||
for (auto v : VarHolder::hold_vars) {
|
for (auto v : hold_vars) {
|
||||||
if (!v->var->_outputs.size())
|
if (!v->var->_outputs.size())
|
||||||
vars.push_back(v->var);
|
vars.push_back(v->var);
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,6 +30,8 @@ struct ItemData {
|
||||||
|
|
||||||
typedef struct _object PyObject;
|
typedef struct _object PyObject;
|
||||||
|
|
||||||
|
EXTERN_LIB list<VarHolder*> hold_vars;
|
||||||
|
|
||||||
// @pyjt(Var)
|
// @pyjt(Var)
|
||||||
// @attrs(heaptype)
|
// @attrs(heaptype)
|
||||||
struct VarHolder {
|
struct VarHolder {
|
||||||
|
@ -82,7 +84,6 @@ struct VarHolder {
|
||||||
|
|
||||||
void operator=(VarPtr&& v);
|
void operator=(VarPtr&& v);
|
||||||
|
|
||||||
static list<VarHolder*> hold_vars;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set the name of the Var.
|
* set the name of the Var.
|
||||||
|
|
|
@ -17,6 +17,8 @@ def all_eq(x, y):
|
||||||
convert = lambda x: x.astype("uint8") if x.dtype=="bool" else x
|
convert = lambda x: x.astype("uint8") if x.dtype=="bool" else x
|
||||||
x = convert(x)
|
x = convert(x)
|
||||||
y = convert(y)
|
y = convert(y)
|
||||||
|
if str(x.dtype).startswith("float"):
|
||||||
|
return str(y.dtype).startswith("float") and x.shape == y.shape and (x==y).all()
|
||||||
return x.dtype == y.dtype and x.shape == y.shape and (x==y).all()
|
return x.dtype == y.dtype and x.shape == y.shape and (x==y).all()
|
||||||
|
|
||||||
def check(op, *args):
|
def check(op, *args):
|
||||||
|
|
|
@ -76,23 +76,59 @@ class TestDataset(unittest.TestCase):
|
||||||
assert isinstance(batch[1], np.ndarray)
|
assert isinstance(batch[1], np.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
class YourDataset(Dataset):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.set_attrs(total_len=10240)
|
||||||
|
|
||||||
|
def __getitem__(self, k):
|
||||||
|
self.tmp = None
|
||||||
|
x = jt.array(k)
|
||||||
|
y = x
|
||||||
|
for i in range(10):
|
||||||
|
for j in range(i+2):
|
||||||
|
y = y + j - j
|
||||||
|
y.stop_fuse()
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
class YourDataset2(Dataset):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.set_attrs(total_len=16)
|
||||||
|
|
||||||
|
def __getitem__(self, k):
|
||||||
|
return np.random.rand(2)
|
||||||
|
|
||||||
|
|
||||||
|
class YourDataset3(Dataset):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.set_attrs(total_len=16)
|
||||||
|
|
||||||
|
def __getitem__(self, k):
|
||||||
|
return random.randint(0,1000)
|
||||||
|
|
||||||
|
|
||||||
|
class YourDataset4(Dataset):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.set_attrs(total_len=160)
|
||||||
|
|
||||||
|
def __getitem__(self, k):
|
||||||
|
return jt.rand(2)
|
||||||
|
|
||||||
|
|
||||||
|
class YourDataset5(Dataset):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.set_attrs(total_len=160)
|
||||||
|
|
||||||
|
def __getitem__(self, k):
|
||||||
|
return { "a":np.array([1,2,3]) }
|
||||||
|
|
||||||
class TestDataset2(unittest.TestCase):
|
class TestDataset2(unittest.TestCase):
|
||||||
def test_dataset_use_jittor(self):
|
def test_dataset_use_jittor(self):
|
||||||
class YourDataset(Dataset):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.set_attrs(total_len=10240)
|
|
||||||
|
|
||||||
def __getitem__(self, k):
|
|
||||||
self.tmp = None
|
|
||||||
x = jt.array(k)
|
|
||||||
y = x
|
|
||||||
for i in range(10):
|
|
||||||
for j in range(i+2):
|
|
||||||
y = y + j - j
|
|
||||||
y.stop_fuse()
|
|
||||||
return x, y
|
|
||||||
|
|
||||||
dataset = YourDataset().set_attrs(batch_size=256, shuffle=True, num_workers=4)
|
dataset = YourDataset().set_attrs(batch_size=256, shuffle=True, num_workers=4)
|
||||||
dataset.tmp = jt.array([1,2,3,4,5])
|
dataset.tmp = jt.array([1,2,3,4,5])
|
||||||
dataset.tmp.sync()
|
dataset.tmp.sync()
|
||||||
|
@ -108,15 +144,8 @@ class TestDataset2(unittest.TestCase):
|
||||||
|
|
||||||
class TestDatasetSeed(unittest.TestCase):
|
class TestDatasetSeed(unittest.TestCase):
|
||||||
def test_np(self):
|
def test_np(self):
|
||||||
class YourDataset(Dataset):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.set_attrs(total_len=16)
|
|
||||||
|
|
||||||
def __getitem__(self, k):
|
dataset = YourDataset2().set_attrs(batch_size=1, shuffle=True, num_workers=4)
|
||||||
return np.random.rand(2)
|
|
||||||
|
|
||||||
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
|
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
dd = []
|
dd = []
|
||||||
for d in dataset:
|
for d in dataset:
|
||||||
|
@ -127,16 +156,9 @@ class TestDatasetSeed(unittest.TestCase):
|
||||||
|
|
||||||
def test_py_native(self):
|
def test_py_native(self):
|
||||||
import random
|
import random
|
||||||
class YourDataset(Dataset):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.set_attrs(total_len=16)
|
|
||||||
|
|
||||||
def __getitem__(self, k):
|
|
||||||
return random.randint(0,1000)
|
|
||||||
|
|
||||||
jt.set_global_seed(0)
|
jt.set_global_seed(0)
|
||||||
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
|
dataset = YourDataset3().set_attrs(batch_size=1, shuffle=True, num_workers=4)
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
dd = []
|
dd = []
|
||||||
for d in dataset:
|
for d in dataset:
|
||||||
|
@ -147,16 +169,9 @@ class TestDatasetSeed(unittest.TestCase):
|
||||||
|
|
||||||
def test_jtrand(self):
|
def test_jtrand(self):
|
||||||
import random
|
import random
|
||||||
class YourDataset(Dataset):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.set_attrs(total_len=160)
|
|
||||||
|
|
||||||
def __getitem__(self, k):
|
|
||||||
return jt.rand(2)
|
|
||||||
|
|
||||||
jt.set_global_seed(0)
|
jt.set_global_seed(0)
|
||||||
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
|
dataset = YourDataset4().set_attrs(batch_size=1, shuffle=True, num_workers=4)
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
dd = []
|
dd = []
|
||||||
for d in dataset:
|
for d in dataset:
|
||||||
|
@ -167,16 +182,9 @@ class TestDatasetSeed(unittest.TestCase):
|
||||||
|
|
||||||
def test_dict(self):
|
def test_dict(self):
|
||||||
import random
|
import random
|
||||||
class YourDataset(Dataset):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.set_attrs(total_len=160)
|
|
||||||
|
|
||||||
def __getitem__(self, k):
|
|
||||||
return { "a":np.array([1,2,3]) }
|
|
||||||
|
|
||||||
jt.set_global_seed(0)
|
jt.set_global_seed(0)
|
||||||
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
|
dataset = YourDataset5().set_attrs(batch_size=1, shuffle=True, num_workers=4)
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
dd = []
|
dd = []
|
||||||
for d in dataset:
|
for d in dataset:
|
||||||
|
@ -216,6 +224,11 @@ class TestDatasetSeed(unittest.TestCase):
|
||||||
assert z[i] == c
|
assert z[i] == c
|
||||||
|
|
||||||
def test_children_died(self):
|
def test_children_died(self):
|
||||||
|
if os.name == 'nt':
|
||||||
|
# TODO: windows cannot pass this test now
|
||||||
|
# don't know how to detect child died in windows
|
||||||
|
# some clue: https://ikriv.com/blog/?p=1431
|
||||||
|
return
|
||||||
src = """
|
src = """
|
||||||
import jittor as jt
|
import jittor as jt
|
||||||
from jittor.dataset import Dataset
|
from jittor.dataset import Dataset
|
||||||
|
@ -231,13 +244,13 @@ class YourDataset(Dataset):
|
||||||
while 1:
|
while 1:
|
||||||
pass
|
pass
|
||||||
return { "a":np.array([1,2,3]) }
|
return { "a":np.array([1,2,3]) }
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dataset = YourDataset()
|
||||||
|
dataset.set_attrs(num_workers=2)
|
||||||
|
|
||||||
dataset = YourDataset()
|
for d in dataset:
|
||||||
dataset.set_attrs(num_workers=2)
|
dataset.workers[0].p.kill()
|
||||||
|
pass
|
||||||
for d in dataset:
|
|
||||||
dataset.workers[0].p.kill()
|
|
||||||
pass
|
|
||||||
"""
|
"""
|
||||||
fname = os.path.join(jt.flags.cache_path, "children_dead_test.py")
|
fname = os.path.join(jt.flags.cache_path, "children_dead_test.py")
|
||||||
with open(fname, 'w') as f:
|
with open(fname, 'w') as f:
|
||||||
|
@ -271,12 +284,13 @@ class YourDataset(Dataset):
|
||||||
pass
|
pass
|
||||||
return { "a":np.array([1,2,3]) }
|
return { "a":np.array([1,2,3]) }
|
||||||
|
|
||||||
dataset = YourDataset()
|
if __name__ == "__main__":
|
||||||
dataset.set_attrs(num_workers=2)
|
dataset = YourDataset()
|
||||||
|
dataset.set_attrs(num_workers=2)
|
||||||
|
|
||||||
for d in dataset:
|
for d in dataset:
|
||||||
break
|
break
|
||||||
dataset.terminate()
|
dataset.terminate()
|
||||||
"""
|
"""
|
||||||
fname = os.path.join(jt.flags.cache_path, "children_dead_test.py")
|
fname = os.path.join(jt.flags.cache_path, "children_dead_test.py")
|
||||||
with open(fname, 'w') as f:
|
with open(fname, 'w') as f:
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue