docker install readme

This commit is contained in:
Dun Liang 2020-05-15 15:41:16 +08:00
parent 203208d74f
commit 510a8e8b8b
17 changed files with 115 additions and 10 deletions

View File

@ -1 +1,2 @@
Dockerfile
**/publish.py

View File

@ -1,6 +1,4 @@
# docker build commands
# docker build --tag jittor/jittor:latest . --network host
# docker build --tag jittor/jittor-cuda:latest --build-arg FROM_IMAGE="nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04" . --network host
ARG FROM_IMAGE=ubuntu:18.04
FROM ${FROM_IMAGE}

View File

@ -76,6 +76,18 @@ for i,(x,y) in enumerate(get_data(n)):
## 安装
我们提供了Docker安装方式免去您配置环境Docker安装方法如下
```
# CPU only
docker run -it --network host jittor/jittor
# CPU and CUDA
docker run -it --network host jittor/jittor-cuda
```
关于Docker安装的详细教程可以参考[Windows/Mac/Linux通过Docker安装计图](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-15-00-00-docker/)
Jittor使用Python和C++编写。 它需要用于即时编译的编译器。当前,我们支持三种编译器:

View File

@ -76,6 +76,17 @@ We provide some jupyter notebooks to help you quick start with Jittor.
## Install
We provide a Docker installation method to save you from configuring the environment. The Docker installation method is as follows:
```
# CPU only
docker run -it --network host jittor/jittor
# CPU and CUDA
docker run -it --network host jittor/jittor-cuda
```
Jittor is written in Python and C++. It requires a compiler for JIT compilation, Currently, we support four compilers:

View File

@ -94,6 +94,19 @@ We provide some jupyter notebooks to help you quick start with Jittor.
## 安装
我们提供了Docker安装方式免去您配置环境Docker安装方法如下
We provide a Docker installation method to save you from configuring the environment. The Docker installation method is as follows:
```
# CPU only
docker run -it --network host jittor/jittor
# CPU and CUDA
docker run -it --network host jittor/jittor-cuda
```
关于Docker安装的详细教程可以参考[Windows/Mac/Linux通过Docker安装计图](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-15-00-00-docker/)
Jittor is written in Python and C++. It requires a compiler for JIT compilation, Currently, we support four compilers:
Jittor使用Python和C++编写。 它需要用于即时编译的编译器。当前,我们支持三种编译器:

View File

@ -7,6 +7,7 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "cublas_warper.h"
#include "misc/cuda_flags.h"
namespace jittor {
@ -15,11 +16,13 @@ cublasHandle_t cublas_handle;
struct cublas_initer {
inline cublas_initer() {
if (!get_device_count()) return;
checkCudaErrors(cublasCreate(&cublas_handle));
LOGv << "cublasCreate finished";
}
inline ~cublas_initer() {
if (!get_device_count()) return;
checkCudaErrors(cublasDestroy(cublas_handle));
LOGv << "cublasDestroy finished";
}

View File

@ -4,6 +4,7 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "cudnn_warper.h"
#include "misc/cuda_flags.h"
namespace jittor {
@ -17,11 +18,13 @@ void set_algorithm_cache_size(int size) {
struct cudnn_initer {
inline cudnn_initer() {
if (!get_device_count()) return;
checkCudaErrors(cudnnCreate(&cudnn_handle));
LOGv << "cudnnCreate finished";
}
inline ~cudnn_initer() {
if (!get_device_count()) return;
checkCudaErrors(cudnnDestroy(cudnn_handle));
LOGv << "cudnnDestroy finished";
}

View File

@ -8,6 +8,7 @@
// ***************************************************************
#include "curand_warper.h"
#include "init.h"
#include "misc/cuda_flags.h"
namespace jittor {
@ -16,6 +17,7 @@ curandGenerator_t gen;
struct curand_initer {
inline curand_initer() {
if (!get_device_count()) return;
checkCudaErrors( curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT) );
add_set_seed_callback([](int seed) {
checkCudaErrors( curandSetPseudoRandomGeneratorSeed(gen, seed) );
@ -24,6 +26,7 @@ inline curand_initer() {
}
inline ~curand_initer() {
if (!get_device_count()) return;
checkCudaErrors( curandDestroyGenerator(gen) );
LOGv << "curandDestroy finished";
}

View File

@ -22,6 +22,7 @@ ncclUniqueId id;
struct nccl_initer {
nccl_initer() {
if (!get_device_count()) return;
if (mpi_world_rank == 0)
checkCudaErrors(ncclGetUniqueId(&id));
MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
@ -34,6 +35,7 @@ nccl_initer() {
}
~nccl_initer() {
if (!get_device_count()) return;
checkCudaErrors(ncclCommDestroy(comm));
}

View File

@ -90,9 +90,10 @@ def install_cub(root_folder):
with tarfile.open(fullname, "r") as tar:
tar.extractall(root_folder)
assert 0 == os.system(f"cd {dirname}/examples && "
f"{nvcc_path} device/example_device_radix_sort.cu -O2 -I.. -o test && ./test")
f"{nvcc_path} device/example_device_radix_sort.cu -O2 -I.. -o test")
if core.get_device_count():
assert 0 == os.system(f"cd {dirname}/examples && ./test")
return dirname
def setup_cub():

View File

@ -95,7 +95,7 @@ class TestResnet(unittest.TestCase):
-jt.flags.stat_allocator_total_free_byte
# assert mem_used < 4e9, mem_used
# TODO: why bigger?
assert mem_used < 5e9, mem_used
assert mem_used < 5.5e9, mem_used
# example log:
# Train Epoch: 0 [0/100 (0%)] Loss: 2.352903 Acc: 0.110000
# Train Epoch: 0 [1/100 (1%)] Loss: 2.840830 Acc: 0.080000

View File

@ -0,0 +1,39 @@
#!/usr/bin/python3
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
# Publish steps:
# 1. build,push,upload docker image[jittor/jittor]
# 2. build,push,upload docker image[jittor/jittor-cuda]
import os
def run_cmd(cmd):
print("[run cmd]", cmd)
assert os.system(cmd) == 0
def upload_file(path):
run_cmd(f"rsync -avPu {path} jittor-web:Documents/jittor-blog/assets/build/")
def docker_task(name, build_cmd):
run_cmd(build_cmd)
run_cmd(f"sudo docker push {name}")
bname = os.path.basename(name)
run_cmd(f"docker save {name}:latest -o /tmp/{bname}.tgz && chmod 666 /tmp/{bname}.tgz")
upload_file(f" /tmp/{bname}.tgz")
docker_task(
"jittor/jittor",
"sudo docker build --tag jittor/jittor:latest . --network host"
)
docker_task(
"jittor/jittor-cuda",
"sudo docker build --tag jittor/jittor-cuda:latest --build-arg FROM_IMAGE='nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04' . --network host"
)
run_cmd("ssh jittor-web Documents/jittor-blog.git/hooks/post-update")

View File

@ -397,7 +397,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
LOGvv << "All" << op_num << "ops finished, return vars:" << vars;
for (Var* v : vars) ASSERT(v->mem_ptr);
#ifdef HAS_CUDA
if (device_sync) {
if (device_sync && use_cuda) {
last_is_cuda = false;
sync_times++;
event_queue.run_sync([]() {

View File

@ -40,10 +40,12 @@ static void to_fetch(CUDA_HOST_FUNC_ARGS) {
struct Init {
Init() {
if (!get_device_count()) return;
checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
}
~Init() {
if (!get_device_count()) return;
// do not call deleter on exit
for (auto& f : fetch_tasks)
f.func.deleter = nullptr;

View File

@ -3,8 +3,11 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
// #include "misc/cuda_flags.h"
#include "common.h"
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#endif
namespace jittor {
@ -13,9 +16,12 @@ DEFINE_FLAG_WITH_SETTER(int, use_cuda, 0,
void setter_use_cuda(int value) {
#ifdef HAS_CUDA
if (value)
if (value) {
int count=0;
cudaGetDeviceCount(&count);
CHECK(count>0) << "No device found.";
LOGi << "CUDA enabled.";
else
} else
LOGi << "CUDA disabled.";
#else
CHECK(value==0) << "No CUDA found.";

View File

@ -14,6 +14,13 @@ namespace jittor {
DECLARE_FLAG(int, use_cuda);
// @pyjt(get_device_count)
inline int get_device_count() {
int count=0;
cudaGetDeviceCount(&count);
return count;
}
} // jittor
#if CUDART_VERSION < 10000
@ -32,5 +39,7 @@ namespace jittor {
constexpr int use_cuda = 0;
inline int get_device_count() { return 0; }
} // jittor
#endif

View File

@ -26,10 +26,12 @@ cudaEvent_t event;
struct Init {
Init() {
if (!get_device_count()) return;
checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
}
~Init() {
if (!get_device_count()) return;
checkCudaErrors(cudaDeviceSynchronize());
checkCudaErrors(cudaStreamDestroy(stream));
checkCudaErrors(cudaEventDestroy(event));