forked from jittor/jittor
polish win_cuda on linux
This commit is contained in:
parent
e77f1ea7cb
commit
c1ee6d9ed3
|
@ -126,7 +126,7 @@ def setup_mkl():
|
|||
mkl_lib_path = os.path.join(mkl_home, "lib")
|
||||
|
||||
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
|
||||
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn -Wl,-rpath='{mkl_lib_path}' "
|
||||
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn "
|
||||
if os.name == 'nt':
|
||||
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
|
||||
mkl_bin_path = os.path.join(mkl_home, 'bin')
|
||||
|
@ -199,9 +199,9 @@ def setup_cuda_extern():
|
|||
cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src")
|
||||
cuda_extern_files = [os.path.join(cuda_extern_src, name)
|
||||
for name in os.listdir(cuda_extern_src)]
|
||||
so_name = os.path.join(cache_path_cuda, "cuda_extern"+so)
|
||||
so_name = os.path.join(cache_path_cuda, "libcuda_extern"+so)
|
||||
compile(cc_path, cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name)
|
||||
link_cuda_extern = f" -L\"{cache_path_cuda}\" -lcuda_extern "
|
||||
link_cuda_extern = f" -L\"{cache_path_cuda}\" -llibcuda_extern "
|
||||
ctypes.CDLL(so_name, dlopen_flags)
|
||||
|
||||
try:
|
||||
|
|
|
@ -118,7 +118,7 @@ def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="")
|
|||
inputs = new_inputs
|
||||
|
||||
if len(inputs) == 1 or combind_build:
|
||||
cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} {link} -o {output}"
|
||||
cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} -o {output}"
|
||||
return do_compile(fix_cl_flags(cmd))
|
||||
# split compile object file and link
|
||||
# remove -l -L flags when compile object files
|
||||
|
@ -1019,7 +1019,18 @@ if platform.system() == 'Darwin':
|
|||
kernel_opt_flags += " -Xpreprocessor -fopenmp "
|
||||
elif cc_type != 'cl':
|
||||
kernel_opt_flags += " -fopenmp "
|
||||
fix_cl_flags = lambda x:x
|
||||
def fix_cl_flags(cmd):
    """Rewrite the library-related flags of a compiler command line.

    The command string is tokenized with ``shsplit`` and every token is
    mapped as follows before the tokens are joined back with single spaces:

    * ``-l<name>`` where the token contains ``cpython`` or ``lib`` becomes
      ``-l:<name>.so`` (the ``-l:`` form names the shared object file
      explicitly instead of relying on the ``lib`` prefix lookup).
    * ``-L<dir>`` is kept and followed by ``-Wl,-rpath=<dir>`` so the same
      directory is also recorded as a runtime search path.
    * Any other token is passed through untouched.

    Args:
        cmd: the full command line (or flag string) as a single string.

    Returns:
        The rewritten command line as a single space-joined string.
    """
    def _rewrite(token):
        # "-L" and "-l" differ in case, so the two checks never overlap.
        if token.startswith("-L"):
            return f"{token} -Wl,-rpath={token[2:]}"
        names_shared_object = any(key in token for key in ("cpython", "lib"))
        if token.startswith("-l") and names_shared_object:
            return f"-l:{token[2:]}.so"
        return token

    return " ".join(_rewrite(token) for token in shsplit(cmd))
|
||||
|
||||
if os.name == 'nt':
|
||||
if cc_type == 'g++':
|
||||
pass
|
||||
|
@ -1251,9 +1262,9 @@ with jit_utils.import_scope(import_flags):
|
|||
import jittor_core as core
|
||||
|
||||
flags = core.flags()
|
||||
nvcc_flags = convert_nvcc_flags(cc_flags)
|
||||
|
||||
if has_cuda:
|
||||
nvcc_flags = convert_nvcc_flags(cc_flags)
|
||||
if len(flags.cuda_archs):
|
||||
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||
|
|
|
@ -33,7 +33,6 @@ DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
|
|||
DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
|
||||
DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
|
||||
|
||||
#ifdef _MSC_VER
|
||||
vector<string> shsplit(const string& s) {
|
||||
auto s1 = split(s, " ");
|
||||
vector<string> s2;
|
||||
|
@ -54,7 +53,8 @@ vector<string> shsplit(const string& s) {
|
|||
return s2;
|
||||
}
|
||||
|
||||
string fix_cl_flags(const string& cmd) {
|
||||
string fix_cl_flags(const string& cmd, bool is_cuda) {
|
||||
#ifdef _MSC_VER
|
||||
auto flags = shsplit(cmd);
|
||||
vector<string> output, output2;
|
||||
|
||||
|
@ -95,8 +95,31 @@ string fix_cl_flags(const string& cmd) {
|
|||
cmdx += " ";
|
||||
}
|
||||
return cmdx;
|
||||
}
|
||||
#else
|
||||
auto flags = shsplit(cmd);
|
||||
vector<string> output;
|
||||
|
||||
for (auto& f : flags) {
|
||||
if (startswith(f, "-l") &&
|
||||
(f.find("cpython") != string::npos ||
|
||||
f.find("lib") != string::npos))
|
||||
output.push_back("-l:"+f.substr(2)+".so");
|
||||
else if (startswith(f, "-L")) {
|
||||
if (is_cuda)
|
||||
output.push_back(f+" -Xlinker -rpath="+f.substr(2));
|
||||
else
|
||||
output.push_back(f+" -Wl,-rpath="+f.substr(2));
|
||||
} else
|
||||
output.push_back(f);
|
||||
}
|
||||
string cmdx = "";
|
||||
for (auto& s : output) {
|
||||
cmdx += s;
|
||||
cmdx += " ";
|
||||
}
|
||||
return cmdx;
|
||||
#endif
|
||||
}
|
||||
|
||||
namespace jit_compiler {
|
||||
|
||||
|
@ -174,12 +197,12 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
|||
if (is_cuda_op) {
|
||||
cmd = "\"" + nvcc_path + "\""
|
||||
+ " \"" + jit_src_path + "\"" + other_src
|
||||
+ nvcc_flags + extra_flags
|
||||
+ fix_cl_flags(nvcc_flags + extra_flags, is_cuda_op)
|
||||
+ " -o \"" + jit_lib_path + "\"";
|
||||
} else {
|
||||
cmd = "\"" + cc_path + "\""
|
||||
+ " \"" + jit_src_path + "\"" + other_src
|
||||
+ cc_flags + extra_flags
|
||||
+ fix_cl_flags(cc_flags + extra_flags, is_cuda_op)
|
||||
+ " -o \"" + jit_lib_path + "\"";
|
||||
#ifdef __linux__
|
||||
cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
|
||||
|
@ -193,12 +216,12 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
|||
+ nvcc_flags + extra_flags
|
||||
+ " -o \"" + jit_lib_path + "\""
|
||||
+ " -Xlinker -EXPORT:\""
|
||||
+ symbol_name + "\"";;
|
||||
+ symbol_name + "\"";
|
||||
} else {
|
||||
cmd = "\"" + cc_path + "\""
|
||||
+ " \"" + jit_src_path + "\"" + other_src
|
||||
+ " -Fe: \"" + jit_lib_path + "\" "
|
||||
+ fix_cl_flags(cc_flags + extra_flags) + " -EXPORT:\""
|
||||
+ fix_cl_flags(cc_flags + extra_flags, is_cuda_op) + " -EXPORT:\""
|
||||
+ symbol_name + "\"";
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -241,7 +241,8 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
|
|||
find_names(cmd, input_names, output_name, extra);
|
||||
string output_cache_key;
|
||||
bool ran = false;
|
||||
output_cache_key = read_all(output_name+".key");
|
||||
if (file_exist(output_name))
|
||||
output_cache_key = read_all(output_name+".key");
|
||||
string cache_key;
|
||||
unordered_set<string> processed;
|
||||
auto src_path = join(jittor_path, "src");
|
||||
|
|
|
@ -131,15 +131,28 @@ so_pos=cmd.find("_op.so")
|
|||
# remove -Xclang ...
|
||||
remove_clang_flag = lambda s: re.sub("-Xclang (('[^']*')|([^ ]*))", "", s)
|
||||
|
||||
def shsplit(s):
    """Split a command string on single spaces, keeping quoted runs intact.

    The string is cut at every ``' '``. While the quote characters (``"``
    and ``'``, counted together) seen since the current piece was started
    are odd in number, the following words are glued back onto that piece
    with a single space. Quote characters are kept in the output, and
    consecutive spaces outside quotes yield empty-string pieces.

    Args:
        s: the command line or flag string to split.

    Returns:
        A list of string pieces in their original order.
    """
    pieces = []
    open_quotes = 0
    for word in s.split(' '):
        marks = word.count('"') + word.count('\'')
        if open_quotes % 2 == 1:
            # Still inside a quoted run: extend the previous piece.
            open_quotes += marks
            pieces[-1] = pieces[-1] + " " + word
        else:
            # Outside any quote: start a fresh piece and restart the count.
            open_quotes = marks
            pieces.append(word)
    return pieces
|
||||
|
||||
def remove_flags(flags, rm_flags):
|
||||
flags = flags.split(" ")
|
||||
flags = shsplit(flags)
|
||||
output = []
|
||||
for s in flags:
|
||||
if s.startswith("-load"):
|
||||
output.append(s)
|
||||
continue
|
||||
ss = s.replace("\"", "")
|
||||
for rm in rm_flags:
|
||||
if s.startswith(rm):
|
||||
if ss.startswith(rm) or ss.endswith(rm):
|
||||
break
|
||||
else:
|
||||
output.append(s)
|
||||
|
@ -161,7 +174,7 @@ else: #cc_to_so
|
|||
.replace("-ldl", "") \
|
||||
.replace("-shared", "-S") \
|
||||
.replace(" -o ", " -g -o ")
|
||||
asm_cmd = remove_flags(asm_cmd, ['-l', '-L', '-Wl,'])
|
||||
asm_cmd = remove_flags(asm_cmd, ['-l', '-L', '-Wl,', '.lib', '-shared'])
|
||||
run_cmd(asm_cmd)
|
||||
|
||||
s_path = cc_path.replace("_op.cc","_op.post.s")
|
||||
|
@ -169,5 +182,5 @@ else: #cc_to_so
|
|||
pass_asm(cc_path,s_path)
|
||||
|
||||
asm_cmd = cmd.replace("_op.cc", "_op.s") \
|
||||
.replace("-g", "")
|
||||
.replace(" -g", "")
|
||||
run_cmd(remove_clang_flag(asm_cmd))
|
Loading…
Reference in New Issue