forked from jittor/jittor
support macOS
This commit is contained in:
parent
c4a937cd32
commit
cf171bf577
|
@ -1169,9 +1169,11 @@ def dirty_fix_pytorch_runtime_error():
|
|||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
'''
|
||||
import os
|
||||
os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||
import os, platform
|
||||
|
||||
if platform.system() == 'Linux':
|
||||
os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||
|
||||
|
||||
import atexit
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import os, sys, shutil
|
||||
import platform
|
||||
from .compiler import *
|
||||
from jittor_utils import run_cmd, get_version, get_int_version
|
||||
from jittor.utils.misc import download_url_to_local
|
||||
|
@ -54,39 +55,46 @@ def setup_mkl():
|
|||
mkl_include_path = os.environ.get("mkl_include_path")
|
||||
mkl_lib_path = os.environ.get("mkl_lib_path")
|
||||
|
||||
if mkl_lib_path is None or mkl_include_path is None:
|
||||
mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh")
|
||||
LOG.v("setup mkl...")
|
||||
# mkl_path = os.path.join(cache_path, "mkl")
|
||||
# mkl_path decouple with cc_path
|
||||
from pathlib import Path
|
||||
mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl")
|
||||
|
||||
make_cache_dir(mkl_path)
|
||||
install_mkl(mkl_path)
|
||||
mkl_home = ""
|
||||
for name in os.listdir(mkl_path):
|
||||
if name.startswith("mkldnn_lnx") and os.path.isdir(os.path.join(mkl_path, name)):
|
||||
mkl_home = os.path.join(mkl_path, name)
|
||||
break
|
||||
assert mkl_home!=""
|
||||
if platform.system() == 'Linux':
|
||||
if mkl_lib_path is None or mkl_include_path is None:
|
||||
mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh")
|
||||
LOG.v("setup mkl...")
|
||||
# mkl_path = os.path.join(cache_path, "mkl")
|
||||
# mkl_path decouple with cc_path
|
||||
from pathlib import Path
|
||||
mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl")
|
||||
|
||||
make_cache_dir(mkl_path)
|
||||
install_mkl(mkl_path)
|
||||
mkl_home = ""
|
||||
for name in os.listdir(mkl_path):
|
||||
if name.startswith("mkldnn_lnx") and os.path.isdir(os.path.join(mkl_path, name)):
|
||||
mkl_home = os.path.join(mkl_path, name)
|
||||
break
|
||||
assert mkl_home!=""
|
||||
|
||||
mkl_include_path = os.path.join(mkl_home, "include")
|
||||
mkl_lib_path = os.path.join(mkl_home, "lib")
|
||||
|
||||
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
|
||||
assert os.path.isdir(mkl_include_path)
|
||||
assert os.path.isdir(mkl_lib_path)
|
||||
assert os.path.isfile(mkl_lib_name)
|
||||
LOG.v(f"mkl_include_path: {mkl_include_path}")
|
||||
LOG.v(f"mkl_lib_path: {mkl_lib_path}")
|
||||
LOG.v(f"mkl_lib_name: {mkl_lib_name}")
|
||||
# We do not link manualy, link in custom ops
|
||||
# ctypes.CDLL(mkl_lib_name, dlopen_flags)
|
||||
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
|
||||
assert os.path.isdir(mkl_include_path)
|
||||
assert os.path.isdir(mkl_lib_path)
|
||||
assert os.path.isfile(mkl_lib_name)
|
||||
LOG.v(f"mkl_include_path: {mkl_include_path}")
|
||||
LOG.v(f"mkl_lib_path: {mkl_lib_path}")
|
||||
LOG.v(f"mkl_lib_name: {mkl_lib_name}")
|
||||
# We do not link manualy, link in custom ops
|
||||
# ctypes.CDLL(mkl_lib_name, dlopen_flags)
|
||||
extra_flags = f" -I'{mkl_include_path}' -L'{mkl_lib_path}' -lmkldnn -Wl,-rpath='{mkl_lib_path}' "
|
||||
|
||||
elif platform.system() == 'Darwin':
|
||||
mkl_lib_name = "/usr/local/lib/libmkldnn.dylib"
|
||||
assert os.path.exists(mkl_lib_name), "Not found onednn, please install it by the command 'brew install onednn@2.2.3'"
|
||||
extra_flags = f" -lmkldnn "
|
||||
|
||||
mkl_op_dir = os.path.join(jittor_path, "extern", "mkl", "ops")
|
||||
mkl_op_files = [os.path.join(mkl_op_dir, name) for name in os.listdir(mkl_op_dir)]
|
||||
mkl_ops = compile_custom_ops(mkl_op_files,
|
||||
extra_flags=f" -I'{mkl_include_path}' -L'{mkl_lib_path}' -lmkldnn -Wl,-rpath='{mkl_lib_path}' ")
|
||||
mkl_ops = compile_custom_ops(mkl_op_files, extra_flags=extra_flags)
|
||||
LOG.vv("Get mkl_ops: "+str(dir(mkl_ops)))
|
||||
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import sys
|
|||
import inspect
|
||||
import datetime
|
||||
import threading
|
||||
import platform
|
||||
import ctypes
|
||||
from ctypes import cdll
|
||||
from ctypes.util import find_library
|
||||
|
@ -92,7 +93,7 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
|||
return do_compile(cmd)
|
||||
|
||||
def gen_jit_tests():
|
||||
all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
|
||||
all_src = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
|
||||
jit_declares = []
|
||||
re_def = re.compile("JIT_TEST\\((.*?)\\)")
|
||||
names = set()
|
||||
|
@ -142,7 +143,7 @@ def gen_jit_tests():
|
|||
f.write(jit_src)
|
||||
|
||||
def gen_jit_flags():
|
||||
all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
|
||||
all_src = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
|
||||
jit_declares = []
|
||||
re_def = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)
|
||||
|
||||
|
@ -591,7 +592,7 @@ def compile_custom_ops(
|
|||
filenames,
|
||||
extra_flags="",
|
||||
return_module=False,
|
||||
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND,
|
||||
dlopen_flags=None,
|
||||
gen_name_ = ""):
|
||||
"""Compile custom ops
|
||||
filenames: path of op source files, filenames must be
|
||||
|
@ -601,6 +602,11 @@ def compile_custom_ops(
|
|||
return_module: return module rather than ops(default: False)
|
||||
return: compiled ops
|
||||
"""
|
||||
if dlopen_flags is None:
|
||||
dlopen_flags = os.RTLD_GLOBAL | os.RTLD_NOW
|
||||
if platform.system() == 'Linux':
|
||||
dlopen_flags |= os.RTLD_DEEPBIND
|
||||
|
||||
srcs = {}
|
||||
headers = {}
|
||||
builds = []
|
||||
|
@ -836,11 +842,15 @@ def check_debug_flags():
|
|||
|
||||
cc_flags = " "
|
||||
# os.RTLD_NOW | os.RTLD_GLOBAL cause segfault when import torch first
|
||||
import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||
import_flags = os.RTLD_NOW | os.RTLD_GLOBAL
|
||||
if platform.system() == 'Linux':
|
||||
import_flags |= os.RTLD_DEEPBIND
|
||||
# if cc_type=="icc":
|
||||
# # weird link problem, icc omp library may conflict and cause segfault
|
||||
# import_flags = os.RTLD_NOW | os.RTLD_GLOBAL
|
||||
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
|
||||
if platform.system() == 'Linux':
|
||||
import_flags |= os.RTLD_DEEPBIND
|
||||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
jit_utils.try_import_jit_utils_core()
|
||||
|
@ -874,9 +884,18 @@ cc_flags += " -fdiagnostics-color=always "
|
|||
if "cc_flags" in os.environ:
|
||||
cc_flags += os.environ["cc_flags"] + ' '
|
||||
link_flags = " -lstdc++ -ldl -shared "
|
||||
if platform.system() == 'Darwin':
|
||||
# TODO: if not using apple clang, no need to add -lomp
|
||||
link_flags += "-undefined dynamic_lookup -lomp "
|
||||
|
||||
core_link_flags = ""
|
||||
opt_flags = ""
|
||||
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags + " -fopenmp "
|
||||
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags
|
||||
if platform.system() == 'Darwin':
|
||||
# TODO: if not using apple clang, cannot add -Xpreprocessor
|
||||
kernel_opt_flags = kernel_opt_flags + " -Xpreprocessor -fopenmp "
|
||||
else:
|
||||
kernel_opt_flags = kernel_opt_flags + " -fopenmp "
|
||||
|
||||
if ' -O' not in cc_flags:
|
||||
opt_flags += " -O2 "
|
||||
|
@ -935,7 +954,7 @@ if has_cuda:
|
|||
# build core
|
||||
gen_jit_flags()
|
||||
gen_jit_tests()
|
||||
op_headers = run_cmd('find -L src/ops/ | grep "op.h$"', jittor_path).splitlines()
|
||||
op_headers = run_cmd('find -L src/ops | grep "op.h$"', jittor_path).splitlines()
|
||||
jit_src = gen_jit_op_maker(op_headers)
|
||||
LOG.vvvv(jit_src)
|
||||
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
|
||||
|
@ -983,19 +1002,22 @@ LOG.vv("compile order:", files)
|
|||
# manual Link omp using flags(os.RTLD_NOW | os.RTLD_GLOBAL)
|
||||
# if cc_type=="icc":
|
||||
# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
libname = {"clang":"omp", "icc":"iomp5", "g++":"gomp"}[cc_type]
|
||||
libname = ctypes.util.find_library(libname)
|
||||
assert libname is not None, "openmp library not found"
|
||||
ctypes.CDLL(libname, os.RTLD_NOW | os.RTLD_GLOBAL)
|
||||
# libname = {"clang":"omp", "icc":"iomp5", "g++":"gomp"}[cc_type]
|
||||
# libname = ctypes.util.find_library(libname)
|
||||
# assert libname is not None, "openmp library not found"
|
||||
# ctypes.CDLL(libname, os.RTLD_NOW | os.RTLD_GLOBAL)
|
||||
|
||||
# get os release
|
||||
with open("/etc/os-release", "r", encoding='utf8') as f:
|
||||
s = f.read().splitlines()
|
||||
os_release = {}
|
||||
for line in s:
|
||||
a = line.split('=')
|
||||
if len(a) != 2: continue
|
||||
os_release[a[0]] = a[1].replace("\"", "")
|
||||
if platform.system() == 'Linux':
|
||||
with open("/etc/os-release", "r", encoding='utf8') as f:
|
||||
s = f.read().splitlines()
|
||||
os_release = {}
|
||||
for line in s:
|
||||
a = line.split('=')
|
||||
if len(a) != 2: continue
|
||||
os_release[a[0]] = a[1].replace("\"", "")
|
||||
elif platform.system() == 'Darwin':
|
||||
os_release = {'ID' : 'macOS'}
|
||||
|
||||
os_type = {
|
||||
"ubuntu": "ubuntu",
|
||||
|
@ -1003,7 +1025,9 @@ os_type = {
|
|||
"centos": "centos",
|
||||
"rhel": "ubuntu",
|
||||
"fedora": "ubuntu",
|
||||
"macOS": "macOS",
|
||||
}
|
||||
|
||||
version_file = os.path.join(jittor_path, "version")
|
||||
if os.path.isfile(version_file) and not os.path.isdir(os.path.join(jittor_path, "src", "__data__")):
|
||||
with open(version_file, 'r') as f:
|
||||
|
|
|
@ -860,8 +860,8 @@ def compile_single(head_file_name, src_file_name, src=None):
|
|||
return True
|
||||
|
||||
def compile(cache_path, jittor_path):
|
||||
headers1 = run_cmd('find -L src/ | grep ".h$"', jittor_path).splitlines()
|
||||
headers2 = run_cmd('find gen/ | grep ".h$"', cache_path).splitlines()
|
||||
headers1 = run_cmd('find -L src | grep ".h$"', jittor_path).splitlines()
|
||||
headers2 = run_cmd('find gen | grep ".h$"', cache_path).splitlines()
|
||||
headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
|
||||
[ os.path.join(cache_path, h) for h in headers2 ]
|
||||
basenames = []
|
||||
|
|
|
@ -33,7 +33,11 @@ jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
|
|||
LOGvv << "Opening jit lib:" << name;
|
||||
// void* handle = dlopen(name.c_str(), RTLD_NOW | RTLD_DEEPBIND | RTLD_LOCAL);
|
||||
// RTLD_DEEPBIND and openmp cause segfault
|
||||
#ifdef __linux__
|
||||
void* handle = dlopen(name.c_str(), RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND);
|
||||
#else
|
||||
void *handle = dlopen(name.c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
#endif
|
||||
CHECK(handle) << "Cannot open library" << name << ":" << dlerror();
|
||||
|
||||
//dlerror();
|
||||
|
@ -84,8 +88,8 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
|||
+ " '" + jit_src_path + "'" + other_src
|
||||
+ cc_flags + extra_flags
|
||||
+ " -o '" + jit_lib_path + "'";
|
||||
cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
|
||||
"--cc_path=" + cmd;
|
||||
// cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
|
||||
// "--cc_path=" + cmd;
|
||||
}
|
||||
cache_compile(cmd, cache_path, jittor_path);
|
||||
auto symbol_name = get_symbol_name(jit_key);
|
||||
|
|
|
@ -166,9 +166,11 @@ inline JK& operator<<(JK& jk, int64 c) {
|
|||
return jk << JK::hex(c);
|
||||
}
|
||||
|
||||
#ifdef __linux__
|
||||
inline JK& operator<<(JK& jk, long long int c) {
|
||||
return jk << (int64)c;
|
||||
}
|
||||
#endif
|
||||
|
||||
inline JK& operator<<(JK& jk, uint64 c) {
|
||||
return jk << JK::hex(c);
|
||||
|
|
|
@ -6,7 +6,12 @@
|
|||
// ***************************************************************
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <sys/sysinfo.h>
|
||||
#elif defined(__APPLE__)
|
||||
#include <sys/sysctl.h>
|
||||
#endif
|
||||
|
||||
#include "var.h"
|
||||
#include "op.h"
|
||||
|
@ -152,9 +157,17 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
|
|||
}
|
||||
|
||||
MemInfo::MemInfo() {
|
||||
#if defined(__linux__)
|
||||
struct sysinfo info = {0};
|
||||
sysinfo(&info);
|
||||
total_cpu_ram = info.totalram;
|
||||
#elif defined(__APPLE__)
|
||||
int mib[] = {CTL_HW, HW_MEMSIZE};
|
||||
int64 mem;
|
||||
size_t len;
|
||||
total_cpu_ram = sysctl(mib, 2, &mem, &len, NULL, 0);
|
||||
#endif
|
||||
|
||||
total_cuda_ram = 0;
|
||||
#ifdef HAS_CUDA
|
||||
cudaDeviceProp prop;
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
#include "mem/allocator/sfrl_allocator.h"
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
#include <sys/sysinfo.h>
|
||||
#include <sstream>
|
||||
#include "pybind/py_var_tracer.h"
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <bits/stdc++.h>
|
||||
#include <cstring>
|
||||
#include "misc/nano_string.h"
|
||||
|
||||
namespace jittor {
|
||||
|
|
|
@ -159,13 +159,8 @@ struct NanoVector {
|
|||
for (auto a : v) push_back_check_overflow(a);
|
||||
}
|
||||
|
||||
inline static NanoVector make(const int64* v, int n) {
|
||||
NanoVector nv;
|
||||
for (int i=0; i<n; i++) nv.push_back_check_overflow(v[i]);
|
||||
return nv;
|
||||
}
|
||||
|
||||
inline static NanoVector make(const int32* v, int n) {
|
||||
template<typename TMakeV>
|
||||
inline static NanoVector make(const TMakeV* v, int n) {
|
||||
NanoVector nv;
|
||||
for (int i=0; i<n; i++) nv.push_back_check_overflow(v[i]);
|
||||
return nv;
|
||||
|
|
|
@ -51,7 +51,9 @@ struct RingBuffer {
|
|||
// a dirty hack
|
||||
// ref: https://stackoverflow.com/questions/20439404/pthread-conditions-and-process-termination
|
||||
// cv.__data.__wrefs = 0;
|
||||
#ifdef __linux__
|
||||
cv.__data = {0};
|
||||
#endif
|
||||
pthread_cond_destroy(&cv);
|
||||
}
|
||||
|
||||
|
|
|
@ -5,12 +5,22 @@
|
|||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
|
||||
#if defined(__clang__)
|
||||
#include <string_view>
|
||||
#elif defined(__GNUC__)
|
||||
#include <experimental/string_view>
|
||||
#endif
|
||||
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#if defined(__clang__)
|
||||
using std::string_view;
|
||||
#elif defined(__GNUC__)
|
||||
using std::experimental::string_view;
|
||||
#endif
|
||||
|
||||
template<class T>
|
||||
struct string_view_map {
|
||||
|
|
|
@ -144,7 +144,8 @@ void GetitemOp::infer_slices(
|
|||
out_shape_j = (slice.stop - slice.start - 1) / slice.step + 1;
|
||||
else
|
||||
out_shape_j = (slice.start - slice.stop - 1) / -slice.step + 1;
|
||||
out_shape_j = std::max(0l, out_shape_j);
|
||||
|
||||
out_shape_j = out_shape_j > 0 ? out_shape_j : 0;
|
||||
}
|
||||
out_shape.push_back(out_shape_j);
|
||||
}
|
||||
|
|
|
@ -58,7 +58,11 @@ void Profiler::stop() {
|
|||
|
||||
unique_ptr<MemoryChecker>* load_memory_checker(string name) {
|
||||
LOGvv << "Opening jit lib:" << name;
|
||||
void* handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL);
|
||||
#ifdef __linux__
|
||||
void *handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL);
|
||||
#else
|
||||
void* handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
||||
#endif
|
||||
CHECK(handle) << "Cannot open library" << name << ":" << dlerror();
|
||||
|
||||
//dlerror();
|
||||
|
|
|
@ -136,7 +136,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
|
|||
std::memcpy(host_ptr, args.ptr, size);
|
||||
} else {
|
||||
// this is non-continue numpy array
|
||||
int64 dims[args.shape.size()];
|
||||
long dims[args.shape.size()];
|
||||
for (int i=0; i<args.shape.size(); i++)
|
||||
dims[i] = args.shape[i];
|
||||
holder.assign(PyArray_New(
|
||||
|
|
|
@ -266,7 +266,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
|
|||
}
|
||||
|
||||
DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) {
|
||||
int64 dims[a.shape.size()];
|
||||
long dims[a.shape.size()];
|
||||
for (int i=0; i<a.shape.size(); i++)
|
||||
dims[i] = a.shape[i];
|
||||
PyObjHolder obj(PyArray_SimpleNew(
|
||||
|
@ -378,7 +378,7 @@ DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr<VarHolder>& holde
|
|||
|
||||
struct DataView;
|
||||
DEF_IS(DataView, PyObject*) to_py_object(T a) {
|
||||
int64 dims[a.shape.size()];
|
||||
long dims[a.shape.size()];
|
||||
for (int i=0; i<a.shape.size(); i++)
|
||||
dims[i] = a.shape[i];
|
||||
PyObjHolder oh(PyArray_New(
|
||||
|
|
|
@ -109,7 +109,7 @@ static void push_py_object(RingBuffer* rb, PyObject* obj, uint64& __restrict__ o
|
|||
rb->push_t<NanoString>(args.dtype, offset);
|
||||
rb->push(size, offset);
|
||||
args.ptr = rb->get_ptr(size, offset);
|
||||
int64 dims[args.shape.size()];
|
||||
long dims[args.shape.size()];
|
||||
for (int i=0; i<args.shape.size(); i++)
|
||||
dims[i] = args.shape[i];
|
||||
PyObjHolder oh(PyArray_New(
|
||||
|
|
|
@ -187,6 +187,8 @@ bool cache_compile(const string& cmd, const string& cache_path, const string& ji
|
|||
for (size_t i=0; i<input_names.size(); i++) {
|
||||
if (processed.count(input_names[i]) != 0)
|
||||
continue;
|
||||
if (input_names[i] == "dynamic_lookup")
|
||||
continue;
|
||||
processed.insert(input_names[i]);
|
||||
auto src = read_all(input_names[i]);
|
||||
ASSERT(src.size()) << "Source read failed:" << input_names[i];
|
||||
|
|
|
@ -12,7 +12,11 @@
|
|||
#endif
|
||||
#ifdef __GNUC__
|
||||
#endif
|
||||
|
||||
#ifdef __linux__
|
||||
#include <sys/prctl.h>
|
||||
#endif
|
||||
|
||||
#include <signal.h>
|
||||
#include <iterator>
|
||||
#include <algorithm>
|
||||
|
@ -21,7 +25,10 @@
|
|||
namespace jittor {
|
||||
|
||||
void init_subprocess() {
|
||||
#ifdef __linux__
|
||||
prctl(PR_SET_PDEATHSIG, SIGKILL);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
static void __log(
|
||||
|
|
|
@ -193,7 +193,11 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
|
|||
LOGe << "Caught SIGINT, quick exit";
|
||||
}
|
||||
exited = true;
|
||||
#ifdef __APPLE__
|
||||
_Exit(1);
|
||||
#else
|
||||
std::quick_exit(1);
|
||||
#endif
|
||||
}
|
||||
std::cerr << "Caught segfault at address " << si->si_addr << ", "
|
||||
<< "thread_name: '" << thread_name << "', flush log..." << std::endl;
|
||||
|
|
|
@ -7,8 +7,10 @@
|
|||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/wait.h>
|
||||
#include <unistd.h>
|
||||
#ifdef __linux__
|
||||
#include <sys/prctl.h>
|
||||
#endif
|
||||
#include <unistd.h>
|
||||
#include <execinfo.h>
|
||||
#include <iostream>
|
||||
#include "utils/tracer.h"
|
||||
|
@ -61,7 +63,9 @@ void setter_gdb_attach(int v) {
|
|||
exit(1);
|
||||
} else {
|
||||
// allow children ptrace parent
|
||||
#ifdef __linux__
|
||||
prctl(PR_SET_PTRACER, child_pid, 0, 0, 0);
|
||||
#endif
|
||||
// sleep 5s, wait gdb attach
|
||||
sleep(5);
|
||||
}
|
||||
|
@ -118,7 +122,9 @@ void print_trace() {
|
|||
exit(0);
|
||||
} else {
|
||||
// allow children ptrace parent
|
||||
#ifdef __linux__
|
||||
prctl(PR_SET_PTRACER, child_pid, 0, 0, 0);
|
||||
#endif
|
||||
waitpid(child_pid,NULL,0);
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -154,6 +154,8 @@ def pool_cleanup():
|
|||
del p
|
||||
|
||||
def pool_initializer():
|
||||
if cc is None:
|
||||
try_import_jit_utils_core()
|
||||
cc.init_subprocess()
|
||||
|
||||
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
|
||||
|
|
27
setup.py
27
setup.py
|
@ -1,13 +1,13 @@
|
|||
error_msg = """Jittor only supports Ubuntu>=16.04 currently.
|
||||
error_msg = """Jittor only supports Linux and macOS currently.
|
||||
For other OS, use Jittor may be risky.
|
||||
If you insist on installing, please set the environment variable : export FORCE_INSTALL=1
|
||||
We strongly recommended docker installation:
|
||||
We strongly recommend docker installation:
|
||||
|
||||
# CPU only(Linux)
|
||||
# CPU only (Linux)
|
||||
>>> docker run -it --network host jittor/jittor
|
||||
# CPU and CUDA(Linux)
|
||||
# CPU and CUDA (Linux)
|
||||
>>> docker run -it --network host jittor/jittor-cuda
|
||||
# CPU only(Mac and Windows)
|
||||
# CPU only (Mac and Windows)
|
||||
>>> docker run -it -p 8888:8888 jittor/jittor
|
||||
|
||||
Reference:
|
||||
|
@ -15,19 +15,10 @@ Reference:
|
|||
"""
|
||||
from warnings import warn
|
||||
import os
|
||||
try:
|
||||
with open("/etc/os-release", "r", encoding='utf8') as f:
|
||||
s = f.read().splitlines()
|
||||
m = {}
|
||||
for line in s:
|
||||
a = line.split('=')
|
||||
if len(a) != 2: continue
|
||||
m[a[0]] = a[1].replace("\"", "")
|
||||
# assert m["NAME"] == "Ubuntu" and float(m["VERSION_ID"].split('.')[0])>=16, error_msg
|
||||
except Exception as e:
|
||||
print(e)
|
||||
warn(error_msg)
|
||||
if os.environ.get("FORCE_INSTALL", '0') != '1': raise
|
||||
import platform
|
||||
|
||||
if not platform.system() in ['Linux', 'Darwin']:
|
||||
assert os.environ.get("FORCE_INSTALL", '0') != '1', error_msg
|
||||
|
||||
import setuptools
|
||||
from setuptools import setup, find_packages
|
||||
|
|
Loading…
Reference in New Issue