polish win_cuda on linux

This commit is contained in:
Dun Liang 2021-09-26 19:48:22 +08:00
parent e77f1ea7cb
commit c1ee6d9ed3
5 changed files with 69 additions and 21 deletions

View File

@ -126,7 +126,7 @@ def setup_mkl():
mkl_lib_path = os.path.join(mkl_home, "lib")
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn -Wl,-rpath='{mkl_lib_path}' "
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn "
if os.name == 'nt':
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
mkl_bin_path = os.path.join(mkl_home, 'bin')
@ -199,9 +199,9 @@ def setup_cuda_extern():
cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src")
cuda_extern_files = [os.path.join(cuda_extern_src, name)
for name in os.listdir(cuda_extern_src)]
so_name = os.path.join(cache_path_cuda, "cuda_extern"+so)
so_name = os.path.join(cache_path_cuda, "libcuda_extern"+so)
compile(cc_path, cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name)
link_cuda_extern = f" -L\"{cache_path_cuda}\" -lcuda_extern "
link_cuda_extern = f" -L\"{cache_path_cuda}\" -llibcuda_extern "
ctypes.CDLL(so_name, dlopen_flags)
try:

View File

@ -118,7 +118,7 @@ def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="")
inputs = new_inputs
if len(inputs) == 1 or combind_build:
cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} {link} -o {output}"
cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} -o {output}"
return do_compile(fix_cl_flags(cmd))
# split compile object file and link
# remove -l -L flags when compile object files
@ -1019,7 +1019,18 @@ if platform.system() == 'Darwin':
kernel_opt_flags += " -Xpreprocessor -fopenmp "
elif cc_type != 'cl':
kernel_opt_flags += " -fopenmp "
fix_cl_flags = lambda x:x
def fix_cl_flags(cmd):
    """Rewrite linker-related flags in a compiler command line.

    The command is tokenized with ``shsplit`` and each token is mapped
    as follows:

    * a ``-l`` flag whose text contains ``cpython`` or ``lib`` becomes
      ``-l:<name>.so`` (the text after ``-l`` is used as the file stem);
    * a ``-L`` flag is kept and followed by a ``-Wl,-rpath=`` flag that
      carries the same text as the ``-L`` argument;
    * every other token is passed through unchanged.

    Returns the rewritten tokens joined with single spaces.
    """
    def rewrite(flag):
        # Libraries named explicitly by file stem are linked by exact name.
        names_library_file = flag.startswith("-l") and (
            "cpython" in flag or "lib" in flag)
        if names_library_file:
            return "-l:" + flag[2:] + ".so"
        # Mirror each library search directory into the rpath.
        if flag.startswith("-L"):
            return flag + " -Wl,-rpath=" + flag[2:]
        return flag

    return " ".join(rewrite(flag) for flag in shsplit(cmd))
if os.name == 'nt':
if cc_type == 'g++':
pass
@ -1251,9 +1262,9 @@ with jit_utils.import_scope(import_flags):
import jittor_core as core
flags = core.flags()
nvcc_flags = convert_nvcc_flags(cc_flags)
if has_cuda:
nvcc_flags = convert_nvcc_flags(cc_flags)
if len(flags.cuda_archs):
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))

View File

@ -33,7 +33,6 @@ DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
#ifdef _MSC_VER
vector<string> shsplit(const string& s) {
auto s1 = split(s, " ");
vector<string> s2;
@ -54,7 +53,8 @@ vector<string> shsplit(const string& s) {
return s2;
}
string fix_cl_flags(const string& cmd) {
string fix_cl_flags(const string& cmd, bool is_cuda) {
#ifdef _MSC_VER
auto flags = shsplit(cmd);
vector<string> output, output2;
@ -95,8 +95,31 @@ string fix_cl_flags(const string& cmd) {
cmdx += " ";
}
return cmdx;
}
#else
auto flags = shsplit(cmd);
vector<string> output;
for (auto& f : flags) {
if (startswith(f, "-l") &&
(f.find("cpython") != string::npos ||
f.find("lib") != string::npos))
output.push_back("-l:"+f.substr(2)+".so");
else if (startswith(f, "-L")) {
if (is_cuda)
output.push_back(f+" -Xlinker -rpath="+f.substr(2));
else
output.push_back(f+" -Wl,-rpath="+f.substr(2));
} else
output.push_back(f);
}
string cmdx = "";
for (auto& s : output) {
cmdx += s;
cmdx += " ";
}
return cmdx;
#endif
}
namespace jit_compiler {
@ -174,12 +197,12 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
if (is_cuda_op) {
cmd = "\"" + nvcc_path + "\""
+ " \"" + jit_src_path + "\"" + other_src
+ nvcc_flags + extra_flags
+ fix_cl_flags(nvcc_flags + extra_flags, is_cuda_op)
+ " -o \"" + jit_lib_path + "\"";
} else {
cmd = "\"" + cc_path + "\""
+ " \"" + jit_src_path + "\"" + other_src
+ cc_flags + extra_flags
+ fix_cl_flags(cc_flags + extra_flags, is_cuda_op)
+ " -o \"" + jit_lib_path + "\"";
#ifdef __linux__
cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
@ -193,12 +216,12 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
+ nvcc_flags + extra_flags
+ " -o \"" + jit_lib_path + "\""
+ " -Xlinker -EXPORT:\""
+ symbol_name + "\"";;
+ symbol_name + "\"";
} else {
cmd = "\"" + cc_path + "\""
+ " \"" + jit_src_path + "\"" + other_src
+ " -Fe: \"" + jit_lib_path + "\" "
+ fix_cl_flags(cc_flags + extra_flags) + " -EXPORT:\""
+ fix_cl_flags(cc_flags + extra_flags, is_cuda_op) + " -EXPORT:\""
+ symbol_name + "\"";
}
#endif

View File

@ -241,7 +241,8 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
find_names(cmd, input_names, output_name, extra);
string output_cache_key;
bool ran = false;
output_cache_key = read_all(output_name+".key");
if (file_exist(output_name))
output_cache_key = read_all(output_name+".key");
string cache_key;
unordered_set<string> processed;
auto src_path = join(jittor_path, "src");

View File

@ -131,15 +131,28 @@ so_pos=cmd.find("_op.so")
# remove -Xclang ...
remove_clang_flag = lambda s: re.sub("-Xclang (('[^']*')|([^ ]*))", "", s)
def shsplit(s):
    """Split a command string on single spaces, keeping quoted spans whole.

    A running count of quote characters (both ``"`` and ``'``, counted
    together) decides whether a piece continues the previous token: while
    the count is odd, pieces are glued back onto the last token with a
    single space. Quote characters are kept in the output.

    Returns the list of tokens; consecutive spaces yield empty tokens.
    """
    tokens = []
    quotes_seen = 0
    for piece in s.split(' '):
        piece_quotes = piece.count('"') + piece.count('\'')
        if quotes_seen % 2 == 1:
            # Still inside an open quote: extend the previous token.
            quotes_seen += piece_quotes
            tokens[-1] = tokens[-1] + " " + piece
        else:
            # Not inside a quote: start a new token and restart the count.
            quotes_seen = piece_quotes
            tokens.append(piece)
    return tokens
def remove_flags(flags, rm_flags):
flags = flags.split(" ")
flags = shsplit(flags)
output = []
for s in flags:
if s.startswith("-load"):
output.append(s)
continue
ss = s.replace("\"", "")
for rm in rm_flags:
if s.startswith(rm):
if ss.startswith(rm) or ss.endswith(rm):
break
else:
output.append(s)
@ -161,7 +174,7 @@ else: #cc_to_so
.replace("-ldl", "") \
.replace("-shared", "-S") \
.replace(" -o ", " -g -o ")
asm_cmd = remove_flags(asm_cmd, ['-l', '-L', '-Wl,'])
asm_cmd = remove_flags(asm_cmd, ['-l', '-L', '-Wl,', '.lib', '-shared'])
run_cmd(asm_cmd)
s_path = cc_path.replace("_op.cc","_op.post.s")
@ -169,5 +182,5 @@ else: #cc_to_so
pass_asm(cc_path,s_path)
asm_cmd = cmd.replace("_op.cc", "_op.s") \
.replace("-g", "")
.replace(" -g", "")
run_cmd(remove_clang_flag(asm_cmd))