support macOS

This commit is contained in:
lzhengning 2021-06-04 13:26:24 +08:00
parent c4a937cd32
commit cf171bf577
23 changed files with 161 additions and 85 deletions

View File

@ -1169,9 +1169,11 @@ def dirty_fix_pytorch_runtime_error():
jt.dirty_fix_pytorch_runtime_error()
import torch
'''
import os
os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND
import os, platform
if platform.system() == 'Linux':
os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND
import atexit

View File

@ -5,6 +5,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import os, sys, shutil
import platform
from .compiler import *
from jittor_utils import run_cmd, get_version, get_int_version
from jittor.utils.misc import download_url_to_local
@ -54,39 +55,46 @@ def setup_mkl():
mkl_include_path = os.environ.get("mkl_include_path")
mkl_lib_path = os.environ.get("mkl_lib_path")
if mkl_lib_path is None or mkl_include_path is None:
mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh")
LOG.v("setup mkl...")
# mkl_path = os.path.join(cache_path, "mkl")
# mkl_path decouple with cc_path
from pathlib import Path
mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl")
make_cache_dir(mkl_path)
install_mkl(mkl_path)
mkl_home = ""
for name in os.listdir(mkl_path):
if name.startswith("mkldnn_lnx") and os.path.isdir(os.path.join(mkl_path, name)):
mkl_home = os.path.join(mkl_path, name)
break
assert mkl_home!=""
if platform.system() == 'Linux':
if mkl_lib_path is None or mkl_include_path is None:
mkl_install_sh = os.path.join(jittor_path, "script", "install_mkl.sh")
LOG.v("setup mkl...")
# mkl_path = os.path.join(cache_path, "mkl")
# mkl_path decouple with cc_path
from pathlib import Path
mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl")
make_cache_dir(mkl_path)
install_mkl(mkl_path)
mkl_home = ""
for name in os.listdir(mkl_path):
if name.startswith("mkldnn_lnx") and os.path.isdir(os.path.join(mkl_path, name)):
mkl_home = os.path.join(mkl_path, name)
break
assert mkl_home!=""
mkl_include_path = os.path.join(mkl_home, "include")
mkl_lib_path = os.path.join(mkl_home, "lib")
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
assert os.path.isdir(mkl_include_path)
assert os.path.isdir(mkl_lib_path)
assert os.path.isfile(mkl_lib_name)
LOG.v(f"mkl_include_path: {mkl_include_path}")
LOG.v(f"mkl_lib_path: {mkl_lib_path}")
LOG.v(f"mkl_lib_name: {mkl_lib_name}")
# We do not link manualy, link in custom ops
# ctypes.CDLL(mkl_lib_name, dlopen_flags)
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
assert os.path.isdir(mkl_include_path)
assert os.path.isdir(mkl_lib_path)
assert os.path.isfile(mkl_lib_name)
LOG.v(f"mkl_include_path: {mkl_include_path}")
LOG.v(f"mkl_lib_path: {mkl_lib_path}")
LOG.v(f"mkl_lib_name: {mkl_lib_name}")
# We do not link manualy, link in custom ops
# ctypes.CDLL(mkl_lib_name, dlopen_flags)
extra_flags = f" -I'{mkl_include_path}' -L'{mkl_lib_path}' -lmkldnn -Wl,-rpath='{mkl_lib_path}' "
elif platform.system() == 'Darwin':
mkl_lib_name = "/usr/local/lib/libmkldnn.dylib"
assert os.path.exists(mkl_lib_name), "Not found onednn, please install it by the command 'brew install onednn@2.2.3'"
extra_flags = f" -lmkldnn "
mkl_op_dir = os.path.join(jittor_path, "extern", "mkl", "ops")
mkl_op_files = [os.path.join(mkl_op_dir, name) for name in os.listdir(mkl_op_dir)]
mkl_ops = compile_custom_ops(mkl_op_files,
extra_flags=f" -I'{mkl_include_path}' -L'{mkl_lib_path}' -lmkldnn -Wl,-rpath='{mkl_lib_path}' ")
mkl_ops = compile_custom_ops(mkl_op_files, extra_flags=extra_flags)
LOG.vv("Get mkl_ops: "+str(dir(mkl_ops)))

View File

@ -11,6 +11,7 @@ import sys
import inspect
import datetime
import threading
import platform
import ctypes
from ctypes import cdll
from ctypes.util import find_library
@ -92,7 +93,7 @@ def compile(compiler, flags, inputs, output, combind_build=False):
return do_compile(cmd)
def gen_jit_tests():
all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
all_src = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
jit_declares = []
re_def = re.compile("JIT_TEST\\((.*?)\\)")
names = set()
@ -142,7 +143,7 @@ def gen_jit_tests():
f.write(jit_src)
def gen_jit_flags():
all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
all_src = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
jit_declares = []
re_def = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)
@ -591,7 +592,7 @@ def compile_custom_ops(
filenames,
extra_flags="",
return_module=False,
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND,
dlopen_flags=None,
gen_name_ = ""):
"""Compile custom ops
filenames: path of op source files, filenames must be
@ -601,6 +602,11 @@ def compile_custom_ops(
return_module: return module rather than ops(default: False)
return: compiled ops
"""
if dlopen_flags is None:
dlopen_flags = os.RTLD_GLOBAL | os.RTLD_NOW
if platform.system() == 'Linux':
dlopen_flags |= os.RTLD_DEEPBIND
srcs = {}
headers = {}
builds = []
@ -836,11 +842,15 @@ def check_debug_flags():
cc_flags = " "
# os.RTLD_NOW | os.RTLD_GLOBAL cause segfault when import torch first
import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
import_flags = os.RTLD_NOW | os.RTLD_GLOBAL
if platform.system() == 'Linux':
import_flags |= os.RTLD_DEEPBIND
# if cc_type=="icc":
# # weird link problem, icc omp library may conflict and cause segfault
# import_flags = os.RTLD_NOW | os.RTLD_GLOBAL
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
if platform.system() == 'Linux':
import_flags |= os.RTLD_DEEPBIND
with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core()
@ -874,9 +884,18 @@ cc_flags += " -fdiagnostics-color=always "
if "cc_flags" in os.environ:
cc_flags += os.environ["cc_flags"] + ' '
link_flags = " -lstdc++ -ldl -shared "
if platform.system() == 'Darwin':
# TODO: if not using apple clang, no need to add -lomp
link_flags += "-undefined dynamic_lookup -lomp "
core_link_flags = ""
opt_flags = ""
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags + " -fopenmp "
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags
if platform.system() == 'Darwin':
# TODO: if not using apple clang, cannot add -Xpreprocessor
kernel_opt_flags = kernel_opt_flags + " -Xpreprocessor -fopenmp "
else:
kernel_opt_flags = kernel_opt_flags + " -fopenmp "
if ' -O' not in cc_flags:
opt_flags += " -O2 "
@ -935,7 +954,7 @@ if has_cuda:
# build core
gen_jit_flags()
gen_jit_tests()
op_headers = run_cmd('find -L src/ops/ | grep "op.h$"', jittor_path).splitlines()
op_headers = run_cmd('find -L src/ops | grep "op.h$"', jittor_path).splitlines()
jit_src = gen_jit_op_maker(op_headers)
LOG.vvvv(jit_src)
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
@ -983,19 +1002,22 @@ LOG.vv("compile order:", files)
# manual Link omp using flags(os.RTLD_NOW | os.RTLD_GLOBAL)
# if cc_type=="icc":
# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
libname = {"clang":"omp", "icc":"iomp5", "g++":"gomp"}[cc_type]
libname = ctypes.util.find_library(libname)
assert libname is not None, "openmp library not found"
ctypes.CDLL(libname, os.RTLD_NOW | os.RTLD_GLOBAL)
# libname = {"clang":"omp", "icc":"iomp5", "g++":"gomp"}[cc_type]
# libname = ctypes.util.find_library(libname)
# assert libname is not None, "openmp library not found"
# ctypes.CDLL(libname, os.RTLD_NOW | os.RTLD_GLOBAL)
# get os release
with open("/etc/os-release", "r", encoding='utf8') as f:
s = f.read().splitlines()
os_release = {}
for line in s:
a = line.split('=')
if len(a) != 2: continue
os_release[a[0]] = a[1].replace("\"", "")
if platform.system() == 'Linux':
with open("/etc/os-release", "r", encoding='utf8') as f:
s = f.read().splitlines()
os_release = {}
for line in s:
a = line.split('=')
if len(a) != 2: continue
os_release[a[0]] = a[1].replace("\"", "")
elif platform.system() == 'Darwin':
os_release = {'ID' : 'macOS'}
os_type = {
"ubuntu": "ubuntu",
@ -1003,7 +1025,9 @@ os_type = {
"centos": "centos",
"rhel": "ubuntu",
"fedora": "ubuntu",
"macOS": "macOS",
}
version_file = os.path.join(jittor_path, "version")
if os.path.isfile(version_file) and not os.path.isdir(os.path.join(jittor_path, "src", "__data__")):
with open(version_file, 'r') as f:

View File

@ -860,8 +860,8 @@ def compile_single(head_file_name, src_file_name, src=None):
return True
def compile(cache_path, jittor_path):
headers1 = run_cmd('find -L src/ | grep ".h$"', jittor_path).splitlines()
headers2 = run_cmd('find gen/ | grep ".h$"', cache_path).splitlines()
headers1 = run_cmd('find -L src | grep ".h$"', jittor_path).splitlines()
headers2 = run_cmd('find gen | grep ".h$"', cache_path).splitlines()
headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
[ os.path.join(cache_path, h) for h in headers2 ]
basenames = []

View File

@ -33,7 +33,11 @@ jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
LOGvv << "Opening jit lib:" << name;
// void* handle = dlopen(name.c_str(), RTLD_NOW | RTLD_DEEPBIND | RTLD_LOCAL);
// RTLD_DEEPBIND and openmp cause segfault
#ifdef __linux__
void* handle = dlopen(name.c_str(), RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND);
#else
void *handle = dlopen(name.c_str(), RTLD_NOW | RTLD_LOCAL);
#endif
CHECK(handle) << "Cannot open library" << name << ":" << dlerror();
//dlerror();
@ -84,8 +88,8 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
+ " '" + jit_src_path + "'" + other_src
+ cc_flags + extra_flags
+ " -o '" + jit_lib_path + "'";
cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
"--cc_path=" + cmd;
// cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
// "--cc_path=" + cmd;
}
cache_compile(cmd, cache_path, jittor_path);
auto symbol_name = get_symbol_name(jit_key);

View File

@ -166,9 +166,11 @@ inline JK& operator<<(JK& jk, int64 c) {
return jk << JK::hex(c);
}
#ifdef __linux__
inline JK& operator<<(JK& jk, long long int c) {
return jk << (int64)c;
}
#endif
inline JK& operator<<(JK& jk, uint64 c) {
return jk << JK::hex(c);

View File

@ -6,7 +6,12 @@
// ***************************************************************
#include <iomanip>
#include <algorithm>
#if defined(__linux__)
#include <sys/sysinfo.h>
#elif defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "var.h"
#include "op.h"
@ -152,9 +157,17 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
}
MemInfo::MemInfo() {
#if defined(__linux__)
struct sysinfo info = {0};
sysinfo(&info);
total_cpu_ram = info.totalram;
#elif defined(__APPLE__)
int mib[] = {CTL_HW, HW_MEMSIZE};
int64 mem;
size_t len;
total_cpu_ram = sysctl(mib, 2, &mem, &len, NULL, 0);
#endif
total_cuda_ram = 0;
#ifdef HAS_CUDA
cudaDeviceProp prop;

View File

@ -14,7 +14,6 @@
#include "mem/allocator/sfrl_allocator.h"
#include <iomanip>
#include <algorithm>
#include <sys/sysinfo.h>
#include <sstream>
#include "pybind/py_var_tracer.h"

View File

@ -4,7 +4,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <bits/stdc++.h>
#include <cstring>
#include "misc/nano_string.h"
namespace jittor {

View File

@ -159,13 +159,8 @@ struct NanoVector {
for (auto a : v) push_back_check_overflow(a);
}
inline static NanoVector make(const int64* v, int n) {
NanoVector nv;
for (int i=0; i<n; i++) nv.push_back_check_overflow(v[i]);
return nv;
}
inline static NanoVector make(const int32* v, int n) {
template<typename TMakeV>
inline static NanoVector make(const TMakeV* v, int n) {
NanoVector nv;
for (int i=0; i<n; i++) nv.push_back_check_overflow(v[i]);
return nv;

View File

@ -51,7 +51,9 @@ struct RingBuffer {
// a dirty hack
// ref: https://stackoverflow.com/questions/20439404/pthread-conditions-and-process-termination
// cv.__data.__wrefs = 0;
#ifdef __linux__
cv.__data = {0};
#endif
pthread_cond_destroy(&cv);
}

View File

@ -5,12 +5,22 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#if defined(__clang__)
#include <string_view>
#elif defined(__GNUC__)
#include <experimental/string_view>
#endif
#include "common.h"
namespace jittor {
#if defined(__clang__)
using std::string_view;
#elif defined(__GNUC__)
using std::experimental::string_view;
#endif
template<class T>
struct string_view_map {

View File

@ -144,7 +144,8 @@ void GetitemOp::infer_slices(
out_shape_j = (slice.stop - slice.start - 1) / slice.step + 1;
else
out_shape_j = (slice.start - slice.stop - 1) / -slice.step + 1;
out_shape_j = std::max(0l, out_shape_j);
out_shape_j = out_shape_j > 0 ? out_shape_j : 0;
}
out_shape.push_back(out_shape_j);
}

View File

@ -58,7 +58,11 @@ void Profiler::stop() {
unique_ptr<MemoryChecker>* load_memory_checker(string name) {
LOGvv << "Opening jit lib:" << name;
void* handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL);
#ifdef __linux__
void *handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL);
#else
void* handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
#endif
CHECK(handle) << "Cannot open library" << name << ":" << dlerror();
//dlerror();

View File

@ -136,7 +136,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
std::memcpy(host_ptr, args.ptr, size);
} else {
// this is non-continue numpy array
int64 dims[args.shape.size()];
long dims[args.shape.size()];
for (int i=0; i<args.shape.size(); i++)
dims[i] = args.shape[i];
holder.assign(PyArray_New(

View File

@ -266,7 +266,7 @@ DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) {
}
DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) {
int64 dims[a.shape.size()];
long dims[a.shape.size()];
for (int i=0; i<a.shape.size(); i++)
dims[i] = a.shape[i];
PyObjHolder obj(PyArray_SimpleNew(
@ -378,7 +378,7 @@ DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr<VarHolder>& holde
struct DataView;
DEF_IS(DataView, PyObject*) to_py_object(T a) {
int64 dims[a.shape.size()];
long dims[a.shape.size()];
for (int i=0; i<a.shape.size(); i++)
dims[i] = a.shape[i];
PyObjHolder oh(PyArray_New(

View File

@ -109,7 +109,7 @@ static void push_py_object(RingBuffer* rb, PyObject* obj, uint64& __restrict__ o
rb->push_t<NanoString>(args.dtype, offset);
rb->push(size, offset);
args.ptr = rb->get_ptr(size, offset);
int64 dims[args.shape.size()];
long dims[args.shape.size()];
for (int i=0; i<args.shape.size(); i++)
dims[i] = args.shape[i];
PyObjHolder oh(PyArray_New(

View File

@ -187,6 +187,8 @@ bool cache_compile(const string& cmd, const string& cache_path, const string& ji
for (size_t i=0; i<input_names.size(); i++) {
if (processed.count(input_names[i]) != 0)
continue;
if (input_names[i] == "dynamic_lookup")
continue;
processed.insert(input_names[i]);
auto src = read_all(input_names[i]);
ASSERT(src.size()) << "Source read failed:" << input_names[i];

View File

@ -12,7 +12,11 @@
#endif
#ifdef __GNUC__
#endif
#ifdef __linux__
#include <sys/prctl.h>
#endif
#include <signal.h>
#include <iterator>
#include <algorithm>
@ -21,7 +25,10 @@
namespace jittor {
void init_subprocess() {
#ifdef __linux__
prctl(PR_SET_PDEATHSIG, SIGKILL);
#endif
}
static void __log(

View File

@ -193,7 +193,11 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
LOGe << "Caught SIGINT, quick exit";
}
exited = true;
#ifdef __APPLE__
_Exit(1);
#else
std::quick_exit(1);
#endif
}
std::cerr << "Caught segfault at address " << si->si_addr << ", "
<< "thread_name: '" << thread_name << "', flush log..." << std::endl;

View File

@ -7,8 +7,10 @@
#include <stdio.h>
#include <stdlib.h>
#include <sys/wait.h>
#include <unistd.h>
#ifdef __linux__
#include <sys/prctl.h>
#endif
#include <unistd.h>
#include <execinfo.h>
#include <iostream>
#include "utils/tracer.h"
@ -61,7 +63,9 @@ void setter_gdb_attach(int v) {
exit(1);
} else {
// allow children ptrace parent
#ifdef __linux__
prctl(PR_SET_PTRACER, child_pid, 0, 0, 0);
#endif
// sleep 5s, wait gdb attach
sleep(5);
}
@ -118,7 +122,9 @@ void print_trace() {
exit(0);
} else {
// allow children ptrace parent
#ifdef __linux__
prctl(PR_SET_PTRACER, child_pid, 0, 0, 0);
#endif
waitpid(child_pid,NULL,0);
}
} else {

View File

@ -154,6 +154,8 @@ def pool_cleanup():
del p
def pool_initializer():
if cc is None:
try_import_jit_utils_core()
cc.init_subprocess()
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):

View File

@ -1,13 +1,13 @@
error_msg = """Jittor only supports Ubuntu>=16.04 currently.
error_msg = """Jittor only supports Linux and macOS currently.
For other OS, use Jittor may be risky.
If you insist on installing, please set the environment variable : export FORCE_INSTALL=1
We strongly recommended docker installation:
We strongly recommend docker installation:
# CPU only(Linux)
# CPU only (Linux)
>>> docker run -it --network host jittor/jittor
# CPU and CUDA(Linux)
# CPU and CUDA (Linux)
>>> docker run -it --network host jittor/jittor-cuda
# CPU only(Mac and Windows)
# CPU only (Mac and Windows)
>>> docker run -it -p 8888:8888 jittor/jittor
Reference:
@ -15,19 +15,10 @@ Reference:
"""
from warnings import warn
import os
try:
with open("/etc/os-release", "r", encoding='utf8') as f:
s = f.read().splitlines()
m = {}
for line in s:
a = line.split('=')
if len(a) != 2: continue
m[a[0]] = a[1].replace("\"", "")
# assert m["NAME"] == "Ubuntu" and float(m["VERSION_ID"].split('.')[0])>=16, error_msg
except Exception as e:
print(e)
warn(error_msg)
if os.environ.get("FORCE_INSTALL", '0') != '1': raise
import platform
if not platform.system() in ['Linux', 'Darwin']:
assert os.environ.get("FORCE_INSTALL", '0') != '1', error_msg
import setuptools
from setuptools import setup, find_packages