1.Optimize bias add grad kernel

2.Optimize slice grad kernel

3.Add Unet GPU Model
This commit is contained in:
fan1997 2020-12-11 17:29:29 +08:00
parent b12b780163
commit be3d4e6fd3
16 changed files with 764 additions and 231 deletions

View File

@ -19,6 +19,8 @@
#include <vector>
#include <algorithm>
#include <string>
#include <utility>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"
@ -39,8 +41,9 @@ class SliceGradGpuKernel : public GpuKernel {
T *dy = GetDeviceAddress<T>(inputs, 0);
T *dx = GetDeviceAddress<T>(outputs, 0);
FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr));
CalSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, dx,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalSlice4DGrad(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0],
input_shape_[1], input_shape_[2], input_shape_[3], dy, dx,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -49,6 +52,7 @@ class SliceGradGpuKernel : public GpuKernel {
return false;
}
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0);
if (kernel_name == "StridedSliceGrad") {
is_strided_slice_ = true;
std::vector<int64_t> shapex = GetAttr<std::vector<int64_t>>(kernel_node, "shapex");
@ -64,17 +68,15 @@ class SliceGradGpuKernel : public GpuKernel {
}
size_ = GetAttr<std::vector<int64_t>>(kernel_node, "end");
} else {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
ShapeNdTo4d(input_shape, &input_shape_);
size_ = GetAttr<std::vector<int64_t>>(kernel_node, "size");
}
auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
ShapeNdTo4d(dy_shape, &dy_shape_);
begin_ = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
DealParam();
CalcBeginAndSize(data_format);
input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T);
output_size_ = sizeof(T);
for (auto x : dy_shape_) {
output_size_ = output_size_ * x;
@ -89,6 +91,30 @@ class SliceGradGpuKernel : public GpuKernel {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(input_size_);
}
void CalcBeginAndSize(const std::string data_format) {
for (auto i = begin_.size(); i < 4; i++) {
(void)begin_.insert(begin_.begin(), 0);
}
for (auto i = size_.size(); i < 4; i++) {
(void)size_.insert(size_.begin(), 1);
}
if (data_format == "NHWC") {
std::swap(begin_[1], begin_[3]);
std::swap(begin_[1], begin_[2]);
std::swap(size_[1], size_[3]);
std::swap(size_[1], size_[2]);
}
for (size_t i = 0; i < begin_.size(); i++) {
if (begin_[i] < 0) {
begin_[i] = begin_[i] + input_shape_[i];
}
}
for (size_t i = 0; i < size_.size(); i++) {
if (size_[i] < 0) {
size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
}
}
}
private:
bool CheckParam(const CNodePtr &kernel_node) {
@ -108,24 +134,7 @@ class SliceGradGpuKernel : public GpuKernel {
}
return true;
}
void DealParam() {
for (auto i = begin_.size(); i < 4; i++) {
(void)begin_.insert(begin_.begin(), 0);
}
for (auto i = size_.size(); i < 4; i++) {
(void)size_.insert(size_.begin(), 1);
}
for (size_t i = 0; i < begin_.size(); i++) {
if (begin_[i] < 0) {
begin_[i] = begin_[i] + input_shape_[i];
}
}
for (size_t i = 0; i < size_.size(); i++) {
if (size_[i] < 0) {
size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
}
}
}
std::vector<int64_t> begin_;
std::vector<int64_t> size_;
std::vector<int64_t> strides_;

View File

@ -0,0 +1,175 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <stdint.h>
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "backend/kernel_compiler/gpu/cuda_impl/bias_add_grad_impl.cuh"
const int kWarpSize = 32;
// tuning param, for those nhw >= kLargeSize, launch more blocks to solve
const int kLargeSize = 500000; // tuning param for BiasAddGradNHWC
const int kNumBlocks = 8; // tuning param for BiasAddGradNHWC
// For NHWC bias add grad, combine dy's NHW together, matrix column reduce.
// This is a simple implementation, can be further optimized when C is small.
// Firstly, Each warp sums several rows, each thread's partial_sum is the sum of
// a part of one cloumn.
// Secondly, in order to sum up all values in one column, which is to sum up the partial_sum
// in different warps but with the same lane_id, each warp store their partial_sums
// to one row of shared mem, and read partial_sums from one col of shared mem.
// Then each warp do warp reduce to sum up 32 partial_sums, and write final result to db
// For larger NHW, one block is not enough to sum up all rows, needs to launch more blocks.
template<typename T>
__global__ void BiasAddGradNHWC(const T* dy, T* db, const size_t m, const size_t n,
const size_t rows_per_block, size_t rows_per_warp) {
__shared__ float shared_d[kWarpSize][kWarpSize + 1]; // avoid bank conflict
int shm_row_id = (threadIdx.x >> 5);
int shm_col_id = (threadIdx.x % 32);
int block_start_row = blockIdx.x * rows_per_block;
int block_end_row = block_start_row + rows_per_block;
block_end_row = block_end_row < m ? block_end_row : m;
int warp_start_row = blockIdx.x * rows_per_block + shm_row_id * rows_per_warp;
int warp_end_row = warp_start_row + rows_per_warp;
int real_rows_per_warp = warp_end_row < block_end_row ? rows_per_warp : block_end_row - warp_start_row;
// boundary process
// Only the last row or column may not have the full size
bool full_tile = true;
int tile_width_real = 32;
if (blockIdx.y == blockDim.y - 1) {
tile_width_real = n - (blockDim.y - 1) * 32;
full_tile = (tile_width_real == 32);
}
int read_offset = warp_start_row * n + (blockIdx.y << 5) + shm_col_id;
float partial_sum = 0.0;
if (full_tile) {
for (int i = 0; i < real_rows_per_warp; i++) {
partial_sum += static_cast<float>(dy[read_offset]);
read_offset += n;
}
} else {
if (shm_col_id < tile_width_real) {
for (int i = 0; i < real_rows_per_warp; i++) {
partial_sum += static_cast<float>(dy[read_offset]);
read_offset += n;
}
}
}
shared_d[shm_row_id][shm_col_id] = partial_sum;
__syncthreads();
partial_sum = shared_d[shm_col_id][shm_row_id];
__syncthreads();
for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
partial_sum += __shfl_down_sync(0xffffffff, partial_sum, offset);
}
if (shm_col_id == 0) {
if (full_tile) {
MsAtomicAdd(db + (blockIdx.y << 5) + shm_row_id, T(partial_sum));
} else {
if (shm_row_id < tile_width_real) {
MsAtomicAdd(db + (blockIdx.y << 5) + shm_row_id, T(partial_sum));
}
}
}
}
template <typename T>
__global__ void BiasAddGradNCHW(const size_t size, const int batch, const int bias_size, const int h,
const int w, const int bg_size, const T* dy, T* db) {
__shared__ float shared_d[32];
for (int i = threadIdx.x; i < 32; i += blockDim.x) {
shared_d[i] = static_cast<float>(0);
}
__syncthreads();
float sum = 0.;
int lane_id = threadIdx.x % 32;
int thread_id = threadIdx.x;
int img_size = h * w;
// N*H*W -> count / bg_size equals the amount of work one block should reduce
int count = batch * img_size;
int bg_offset = blockIdx.x % bias_size;
int bg_id = blockIdx.x / bias_size;
for (int i = bg_id * blockDim.x + threadIdx.x; // thread start
i < count; i += blockDim.x * bg_size) {
int img_offset = i % img_size;
int img_id = i / img_size;
T val = *(dy + (img_id * bias_size + bg_offset) * img_size + img_offset);
sum += static_cast<float>(val);
}
MsAtomicAdd(shared_d + lane_id, sum);
__syncthreads();
if (thread_id < 32) {
float data = shared_d[thread_id];
for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
data += __shfl_xor_sync(0xffffffff, data, offset);
}
if (thread_id == 0) {
MsAtomicAdd(db + bg_offset, T(data));
}
}
}
template <typename T>
__global__ void FillDb(T *db, const size_t bias_size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < bias_size; pos += blockDim.x * gridDim.x) {
db[pos] = T(0.0);
}
}
template <typename T>
void CalBiasAddGradNCHW(const size_t size, const size_t bias_size, const int height, const int width,
const T* dy, T* db, cudaStream_t cuda_stream) {
int batch_size = size / bias_size / height / width;
int block_num = GET_BLOCKS(size);
int thread_num = GET_THREADS;
// how many blocks to solve one bias's reduce work(N * H * W)
int block_group_size = (block_num + bias_size - 1) / bias_size;
block_num = block_group_size * bias_size;
if (thread_num < kWarpSize) {
thread_num = kWarpSize;
}
FillDb<<<GET_BLOCKS(bias_size), GET_THREADS, 0, cuda_stream>>>(db, bias_size);
BiasAddGradNCHW<<<block_num, thread_num, 0, cuda_stream>>>(size, batch_size, bias_size, height,
width, block_group_size, dy, db);
return;
}
template <typename T>
void CalBiasAddGradNHWC(const size_t size, const size_t bias_size,
const T* dy, T* db, cudaStream_t cuda_stream) {
FillDb<<<GET_BLOCKS(bias_size), GET_THREADS, 0, cuda_stream>>>(db, bias_size);
size_t rows = size/bias_size;
int block_num_x = rows <= kLargeSize ? 1 : kNumBlocks;
int block_num_y = (bias_size + kWarpSize - 1) / kWarpSize;
dim3 grid_size(block_num_x, block_num_y, 1);
dim3 block_size(kWarpSize*kWarpSize);
size_t rows_per_block = (rows + block_num_x - 1) / block_num_x;
size_t rows_per_warp = (rows_per_block + kWarpSize - 1) / kWarpSize;
BiasAddGradNHWC<<<grid_size, block_size, 0, cuda_stream>>>(dy, db, rows, bias_size,
rows_per_block, rows_per_warp);
return;
}
template void CalBiasAddGradNCHW(const size_t size, const size_t bias_size, const int height, const int width,
const float* dy, float* db, cudaStream_t cuda_stream);
template void CalBiasAddGradNCHW(const size_t size, const size_t bias_size, const int height, const int width,
const half* dy, half* db, cudaStream_t cuda_stream);
template void CalBiasAddGradNHWC(const size_t size, const size_t bias_size,
const float* dy, float* db, cudaStream_t cuda_stream);
template void CalBiasAddGradNHWC(const size_t size, const size_t bias_size, const half* dy,
half* db, cudaStream_t cuda_stream);

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BIASADDGRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BIASADDGRAD_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalBiasAddGradNHWC(const size_t size, const size_t bias_size,
const T* dy, T* db, cudaStream_t cuda_stream);
template <typename T>
void CalBiasAddGradNCHW(const size_t size, const size_t bias_size, const int height, const int width,
const T* dy, T* db, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BIASADDGRAD_H_

View File

@ -34,12 +34,20 @@ __global__ void Slice4D(const size_t s1, const size_t s2, const size_t s3, const
output[pos] = input[offset];
}
}
template <typename T>
__global__ void SliceGrad(const T *dy, int64_t p, int64_t start, int64_t length, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) {
output[start + pos] = dy[p + pos];
__global__ void Slice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
const size_t l1, const size_t l2, const size_t l3, const size_t l4,
const size_t d1, const size_t d2, const size_t d3, const size_t d4,
const T *dy, T *dx) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) {
size_t i = pos / (l2 * l3 * l4) % l1;
size_t j = pos / (l3 * l4) % l2;
size_t k = pos / l4 % l3;
size_t o = pos % l4;
size_t input_idx = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4);
dx[input_idx] = dy[pos];
}
return;
}
template <typename T>
@ -62,24 +70,13 @@ void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size
Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4,
input, output);
}
template <typename T>
void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, T *output,
cudaStream_t cuda_stream) {
size_t block = in_shape[1] * in_shape[2] * in_shape[3];
size_t map = in_shape[2] * in_shape[3];
size_t w = in_shape[3];
int64_t length = size[3];
int64_t p = 0;
for (int64_t i = begin[0]; i < size[0] + begin[0]; i++) {
for (int64_t j = begin[1]; j < size[1] + begin[1]; j++) {
for (int64_t k = begin[2]; k < size[2] + begin[2]; k++) {
SliceGrad<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(
dy, p, i * block + j * map + k * w + begin[3], length, output);
p = p + size[3];
}
}
}
void CalSlice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const T *dy, T *dx, cudaStream_t stream) {
Slice4DGrad<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4,
dy, dx);
}
template <typename T>
@ -168,10 +165,6 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
cudaStream_t stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream);
@ -183,36 +176,43 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output,
cudaStream_t stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
cudaStream_t stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream);
template void CalSliceGrad<double>(const size_t input_size, const double *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, double *output,
cudaStream_t cuda_stream);
template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, float *output,
cudaStream_t cuda_stream);
template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, half *output,
cudaStream_t cuda_stream);
template void CalSliceGrad<int64_t>(const size_t input_size, const int64_t *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, int64_t *output,
cudaStream_t cuda_stream);
template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, int *output,
cudaStream_t cuda_stream);
template void CalSliceGrad<short>(const size_t input_size, const short *dy, // NOLINT
const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
const std::vector<int64_t> size, short *output, // NOLINT
cudaStream_t cuda_stream);
template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigned char *dy,
const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
const std::vector<int64_t> size, unsigned char *output,
cudaStream_t cuda_stream);
template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, bool *output,
cudaStream_t cuda_stream);
template void CalSlice4DGrad<double>(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
const size_t l1, const size_t l2, const size_t l3, const size_t l4,
const size_t d1, const size_t d2, const size_t d3, const size_t d4,
const double *dy, double *dx, cudaStream_t stream);
template void CalSlice4DGrad<float>(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const float *dy, float *dx, cudaStream_t stream);
template void CalSlice4DGrad<half>(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const half *dy, half *dx, cudaStream_t stream);
template void CalSlice4DGrad<int>(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const int *dy, int *dx, cudaStream_t stream);
template void CalSlice4DGrad<short>(const size_t s1, const size_t s2, const size_t s3, const size_t s4, // NOLINT
const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const short *dy, short *dx, // NOLINT
cudaStream_t stream);
template void CalSlice4DGrad<unsigned char>(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
const size_t l1, const size_t l2, const size_t l3, const size_t l4,
const size_t d1, const size_t d2, const size_t d3, const size_t d4,
const unsigned char *dy, unsigned char *dx, cudaStream_t stream);
template void CalSlice4DGrad<int64_t>(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
const size_t l1, const size_t l2, const size_t l3, const size_t l4,
const size_t d1, const size_t d2, const size_t d3, const size_t d4,
const int64_t *dy, int64_t *dx, cudaStream_t stream);
template void CalSlice4DGrad<bool>(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const bool *dy, bool *dx, cudaStream_t stream);
template void FillDeviceArray<bool>(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
template void FillDeviceArray<int64_t>(const size_t input_size, int64_t *addr, const float value,

View File

@ -26,9 +26,9 @@ void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size
const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4,
const T *input, T *output, cudaStream_t stream);
template <typename T>
void CalSliceGrad(const size_t input_size, const T *input, const std::vector<size_t> in_shape,
const std::vector<int64_t> begin, const std::vector<int64_t> size, T *output,
cudaStream_t cuda_stream);
void CalSlice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const T *dy, T *dx, cudaStream_t stream);
template <typename T>
void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape, const T *input,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,6 +21,6 @@ namespace kernel {
MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BiasAddGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BiasAddGradGpuKernel, float16)
BiasAddGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,8 +17,6 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <vector>
#include <string>
#include <algorithm>
@ -26,6 +24,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/bias_add_grad_impl.cuh"
namespace mindspore {
namespace kernel {
@ -34,6 +33,10 @@ class BiasAddGradGpuKernel : public GpuKernel {
public:
BiasAddGradGpuKernel()
: same_dims_(true),
use_cudnn_(false),
dy_num_(1),
db_num_(1),
bias_size_(0),
cudnn_handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT),
dy_desc_(nullptr),
@ -48,118 +51,171 @@ class BiasAddGradGpuKernel : public GpuKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *dy_addr = GetDeviceAddress<T>(inputs, 0);
T *db_addr = GetDeviceAddress<T>(outputs, 0);
T *indices_addr = GetDeviceAddress<T>(workspace, 0);
T *workspace_addr = GetDeviceAddress<T>(workspace, 1);
const float alpha = 1;
const float beta = 0;
if (same_dims_) {
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(db_addr, dy_addr, output_size_list_[0], cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed.");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnReduceTensor(cudnn_handle_, op_desc_, indices_addr, workspace_size_list_[0], workspace_addr,
workspace_size_list_[1], &alpha, dy_desc_, dy_addr, &beta, db_desc_, db_addr),
"cudnnReduceTensor failed");
if (use_cudnn_) { // shared memory not satisfied or num_dim > 4
T *indices_addr = GetDeviceAddress<T>(workspace, 0);
T *workspace_addr = GetDeviceAddress<T>(workspace, 1);
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnReduceTensor(cudnn_handle_, op_desc_, indices_addr, workspace_size_list_[0], workspace_addr,
workspace_size_list_[1], &alpha, dy_desc_, dy_addr, &beta, db_desc_, db_addr),
"cudnnReduceTensor failed");
} else { // use own implementation which is more efficient but cannot process num_dim > 4
if (data_format_ == kOpFormat_NHWC) {
CalBiasAddGradNHWC(dy_num_, bias_size_, dy_addr, db_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CalBiasAddGradNCHW(dy_num_, bias_size_, SizeToInt(dy_shape_[2]), SizeToInt(dy_shape_[3]), dy_addr, db_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
}
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto num_dims = dy_shape.size();
if (num_dims < 2) {
MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims;
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0);
cudnn_compute_format_ = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
num_dims_ = dy_shape.size();
if (num_dims_ < 2) {
MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims_;
}
std::string format = GetAttr<std::string>(kernel_node, "format");
string::size_type pos = format.find("C");
if (pos == std::string::npos || pos >= num_dims) {
if (pos == std::string::npos || pos >= num_dims_) {
MS_LOG(EXCEPTION) << "format '" << format << "' invalid";
}
// Expand to 4 dims for cudnnSetTensorNdDescriptorEx.
auto cudnn_dims = std::max(num_dims, 4UL);
std::unique_ptr<int[]> dy_dims = std::make_unique<int[]>(cudnn_dims);
std::unique_ptr<int[]> db_dims = std::make_unique<int[]>(cudnn_dims);
for (size_t i = 0; i < cudnn_dims; i++) {
dy_dims[i] = (i < num_dims) ? SizeToInt(dy_shape[i]) : 1;
db_dims[i] = (i == pos) ? SizeToInt(dy_shape[i]) : 1;
if (dy_dims[i] != db_dims[i]) {
bias_size_ = dy_shape[pos];
auto num_dims_fix = std::max(num_dims_, 4UL);
for (size_t i = 0; i < num_dims_fix; i++) {
dy_shape_.push_back((i < num_dims_) ? dy_shape[i] : 1);
db_shape_.push_back((i == pos) ? dy_shape[i] : 1);
if (dy_shape_[i] != db_shape_[i]) {
same_dims_ = false;
}
}
auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0);
auto cudnn_cal_format = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensorNdDescriptorEx(dy_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensorNdDescriptorEx(db_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN,
CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES),
"cudnnSetReduceTensorDescriptor failed");
for (size_t i = 0; i < dy_shape_.size(); i++) {
dy_num_ *= dy_shape_[i];
}
for (size_t i = 0; i < db_shape_.size(); i++) {
db_num_ *= db_shape_[i];
}
data_format_ = input_device_format; // for opt implementation
if (format == kOpFormat_NHWC) {
data_format_ = kOpFormat_NHWC;
}
MethodSelection();
InitResource();
InitSizeLists();
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnDestroyReduceTensorDescriptor(op_desc_),
"cudnnDestroyReduceTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(db_desc_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_),
"cudnnDestroyOpTensorDescriptor failed");
if (use_cudnn_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnDestroyReduceTensorDescriptor(op_desc_),
"cudnnDestroyReduceTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(db_desc_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_),
"cudnnDestroyOpTensorDescriptor failed");
}
}
protected:
void MethodSelection() {
// opt implementation can only process num_dims_ <= 4
// for num_dims_ = 2, not time-consuming, use cudnn
if (num_dims_ > 4 || num_dims_ == 2) {
use_cudnn_ = true;
return;
}
if (data_format_ == kOpFormat_NHWC) {
size_t required_sharedmem_size = 32 * 33 * sizeof(float);
// nhwc opt implementation performs not so well when bias_size_ <= 6
if (required_sharedmem_size > SHARED_MEM_PER_BLOCK || bias_size_ <= 6) {
use_cudnn_ = true;
return;
}
}
}
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&db_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateReduceTensorDescriptor(&op_desc_),
"cudnnCreateOpTensorDescriptor failed");
if (use_cudnn_) {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&db_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateReduceTensorDescriptor(&op_desc_),
"cudnnCreateOpTensorDescriptor failed");
// Expand to 4 dims for cudnnSetTensorNdDescriptorEx.
auto cudnn_dims = std::max(num_dims_, 4UL);
std::unique_ptr<int[]> dy_dims = std::make_unique<int[]>(cudnn_dims);
std::unique_ptr<int[]> db_dims = std::make_unique<int[]>(cudnn_dims);
for (size_t i = 0; i < cudnn_dims; i++) {
dy_dims[i] = SizeToInt(dy_shape_[i]);
db_dims[i] = SizeToInt(db_shape_[i]);
}
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptorEx(dy_desc_, cudnn_compute_format_, cudnn_data_type_,
SizeToInt(cudnn_dims), dy_dims.get()),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptorEx(db_desc_, cudnn_compute_format_, cudnn_data_type_,
SizeToInt(cudnn_dims), db_dims.get()),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN,
CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES),
"cudnnSetReduceTensorDescriptor failed");
}
}
void InitSizeLists() override {
size_t dy_size, db_size;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dy_desc_, &dy_size),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(db_desc_, &db_size),
"cudnnGetTensorSizeInBytes failed");
input_size_list_.push_back(dy_size);
output_size_list_.push_back(db_size);
size_t indices_size, workspace_size;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnGetReductionIndicesSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &indices_size),
"cudnnGetReductionIndicesSize failed")
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnGetReductionWorkspaceSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &workspace_size),
"cudnnGetReductionWorkspaceSize failed")
workspace_size_list_.push_back(indices_size);
workspace_size_list_.push_back(workspace_size);
if (use_cudnn_) {
size_t dy_size, db_size;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dy_desc_, &dy_size),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(db_desc_, &db_size),
"cudnnGetTensorSizeInBytes failed");
input_size_list_.push_back(dy_size);
output_size_list_.push_back(db_size);
size_t indices_size, workspace_size;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnGetReductionIndicesSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &indices_size),
"cudnnGetReductionIndicesSize failed")
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnGetReductionWorkspaceSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &workspace_size),
"cudnnGetReductionWorkspaceSize failed")
workspace_size_list_.push_back(indices_size);
workspace_size_list_.push_back(workspace_size);
} else {
input_size_list_.push_back(dy_num_ * sizeof(T));
output_size_list_.push_back(db_num_ * sizeof(T));
}
}
private:
bool same_dims_;
bool use_cudnn_;
size_t dy_num_; // for own implementation
size_t db_num_;
size_t num_dims_;
size_t bias_size_; // for own implementation
std::vector<size_t> dy_shape_; // for own implementation
std::vector<size_t> db_shape_; // for own implementation
std::string data_format_ = kOpFormat_NCHW;
// for cudnn implementation
cudnnHandle_t cudnn_handle_;
cudnnDataType_t cudnn_data_type_;
cudnnTensorFormat_t cudnn_compute_format_;
cudnnTensorDescriptor_t dy_desc_;
cudnnTensorDescriptor_t db_desc_;
cudnnReduceTensorDescriptor_t op_desc_;

View File

@ -60,6 +60,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
{prim::kPrimRelu6->name(), {{0}, {0}}},
{prim::kPrimRelu6Grad->name(), {{0, 1}, {0}}},
{kSliceOpName, {{0}, {0}}},
{kSliceGradOpName, {{0, 1}, {0}}},
{kTensorAddOpName, {{0, 1}, {0}}},
{prim::kPrimConcat->name(), {{kAllPositions}, {0}}},
{prim::kPrimAddN->name(), {{kAllPositions}, {0}}},

View File

@ -108,8 +108,8 @@ python preprocess_dataset.py -d /data/save_data_path
## [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend processor.
- HardwareAscend/GPU
- Prepare hardware environment with Ascend or GPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
@ -191,6 +191,23 @@ If you want to run in modelarts, please check the official documentation of [mod
# (7) Create your job.
```
- Run on GPU
```python
# run training example
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
OR
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
# run distributed training example
bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]
# run evaluation example
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/config/ > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
```
## [Script Description](#contents)
### [Script and Sample Code](#contents)
@ -207,6 +224,9 @@ If you want to run in modelarts, please check the official documentation of [mod
│ ├──run_infer_310.sh // shell script for infer on ascend 310
│ ├──run_standalone_train.sh // shell script for standalone on Ascend
│ ├──run_standalone_eval.sh // shell script for evaluation on Ascend
│ ├──run_standalone_train_gpu.sh // shell script for training on GPU
│ ├──run_standalone_eval_gpu.sh // shell script forevaluation on GPU
│ ├──run_distribute_train_gpu.sh // shell script for distributed on GPU
├── src
│ ├──config.py // parameter configuration
│ ├──data_loader.py // creating dataset
@ -261,6 +281,8 @@ Parameters for both training and evaluation can be set in config.py
'weight_decay': 0.0005, # weight decay value
'loss_scale': 1024.0, # loss scale
'FixedLossScaleManager': 1024.0, # fix loss scale
'is_save_on_master': 1, # save checkpoint on master or all rank
'rank': 0, # local rank of distributed(default: 0)
'resume': False, # whether training with pretrain model
'resume_ckpt': './', # pretrain model path
'transfer_training': False # whether do transfer training
@ -337,23 +359,43 @@ step: 600, loss is 0.22070312, fps is 56.99692546024671
The model checkpoint will be saved in the current directory.
#### Distributed Training
#### running on GPU
```shell
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
OR
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
```
The above shell script will run distribute training in the background. You can view the results through the file `logs/device[X]/log.log`. The loss value will be achieved as follows:
The python command above will run in the background, you can view the results through the file train.log. The model checkpoint will be saved in the current directory.
### Distributed Training
#### running on Ascend
```shell
# grep "loss is" logs/device0/log.log
step: 1, loss is 0.70524895, fps is 0.15914689861221412
step: 2, loss is 0.6925452, fps is 56.43668656967454
...
step: 299, loss is 0.20551169, fps is 58.4039329983891
step: 300, loss is 0.18949677, fps is 57.63118508760329
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]
```
The above shell script will run distribute training in the background. You can view the results through the file `logs/device[X]/log.log`. The loss value will be achieved as follows:
```shell
# grep "loss is" logs/device0/log.log
step: 1, loss is 0.70524895, fps is 0.15914689861221412
step: 2, loss is 0.6925452, fps is 56.43668656967454
...
step: 299, loss is 0.20551169, fps is 58.4039329983891
step: 300, loss is 0.18949677, fps is 57.63118508760329
```
#### running on GPU
```shell
bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]
```
The above shell script will run distribute training in the background. You can view the results through the file `train.log`.
#### Evaluation while training
You can add `run_eval` to start shell and set it True, if you want evaluation while training. And you can set argument option: `save_best_ckpt`, `eval_start_epoch`, `eval_interval`, `eval_metrics` when `run_eval` is True.
@ -379,37 +421,56 @@ The above python command will run in the background. You can view the results th
============== Cross valid dice coeff is: {'dice_coeff': 0.9111}
```
- evaluation on ISBI dataset when running on GPU
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet/ckpt_unet_medical_adam-2_400.ckpt".
```shell
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/config/ > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
```
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
```shell
# grep "Cross valid dice coeff is:" eval.log
============== Cross valid dice coeff is: {'dice_coeff': 0.9089390969777261}
```
## [Model Description](#contents)
### [Performance](#contents)
#### Evaluation Performance
| Parameters | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | Unet |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 |
| uploaded Date | 09/15/2020 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | ISBI |
| Training Parameters | 1pc: epoch=400, total steps=600, batch_size = 16, lr=0.0001 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Loss | 0.22070312 |
| Speed | 1pc: 267 ms/step |
| Total time | 1pc: 2.67 mins |
| Parameters (M) | 93M |
| Checkpoint for Fine tuning | 355.11M (.ckpt file) |
| Scripts | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
| Parameters | Ascend | GPU |
| -------------------------- | ------------------------------------------------------------ | :----------------------------------------------------------- |
| Model Version | Unet | Unet |
| Resource | Ascend 910 ;CPU 2.60GHz,192cores; Memory,755G; OS Euler2.8 | NV SMX2 V100-32G |
| uploaded Date | 09/15/2020 (month/day/year) | 01/20/2021 (month/day/year) |
| MindSpore Version | 1.2.0 | 1.1.0 |
| Dataset | ISBI | ISBI |
| Training Parameters | 1pc: epoch=400, total steps=600, batch_size = 16, lr=0.0001 | 1pc: epoch=400, total steps=800, batch_size = 12, lr=0.0001 |
| Optimizer | ADAM | ADAM |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| outputs | probability | probability |
| Loss | 0.22070312 | 0.21425568 |
| Speed | 1pc: 267 ms/step; | 1pc: 423 ms/step; |
| Total time | 1pc: 2.67 mins; | 1pc: 5.64 mins; |
| Parameters (M) | 93M | 93M |
| Checkpoint for Fine tuning | 355.11M (.ckpt file) | 355.11M (.ckpt file) |
| Scripts | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
### [How to use](#contents)
## [How to use](#contents)
#### Inference
### Inference
If you need to use the trained model to perform inference on multiple hardware platforms, such as Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). Following the steps below, this is a simple example:
If you need to use the trained model to perform inference on multiple hardware platforms, such as Ascend 910 or Ascend 310, you
can refer to this [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). Following
the steps below, this is a simple example:
##### Running on Ascend 310
#### Running on Ascend 310
Export MindIR
@ -464,4 +525,4 @@ In data_loader.py, we set the seed inside “_get_val_train_indices" function. W
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -112,8 +112,8 @@ python preprocess_dataset.py -d /data/save_data_path
## 环境要求
- 硬件Ascend
- 准备Ascend处理器搭建硬件环境。
- 硬件Ascend/GPU
- 准备Ascend处理器或GPU处理器搭建硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install)
- 如需查看详情,请参见如下资源:
@ -198,6 +198,25 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
# (7) 开始模型的推理。
```
- GPU处理器环境运行
```python
# 训练示例
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
OR
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
# 分布式训练示例
bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]
# 评估示例
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/config/ > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
```
# 脚本说明
## 脚本说明
### 脚本及样例代码
@ -214,6 +233,9 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
│ ├──run_infer_310.sh // Ascend 310 推理脚本
│ ├──run_standalone_train.sh // Ascend 上单卡训练脚本
│ ├──run_standalone_eval.sh // Ascend 上推理脚本
│ ├──run_standalone_train_gpu.sh // GPU 上训练脚本
│ ├──run_standalone_eval_gpu.sh // GPU 上评估脚本
│ ├──run_distribute_train_gpu.sh // GPU 上分布式训练脚本
├── src
│ ├──config.py // 参数配置
│ ├──data_loader.py // 数据处理
@ -268,6 +290,8 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]
'weight_decay': 0.0005, # 权重衰减值
'loss_scale': 1024.0, # 损失放大
'FixedLossScaleManager': 1024.0, # 固定损失放大
'is_save_on_master': 1, # 在master或all rank上保存检查点
'rank': 0, # 分布式local rank默认为0
'resume': False, # 是否使用预训练模型训练
'resume_ckpt': './', # 预训练模型路径
```
@ -332,9 +356,22 @@ python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml > train.l
```
模型检查点储存在当前路径中。
- GPU处理器环境运行
```shell
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
OR
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
```
上述python命令在后台运行可通过`train.log`文件查看结果。
训练结束后,您可以在默认脚本文件夹中找到检查点文件。
### 分布式训练
- Ascend处理器环境运行
```shell
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]
```
@ -350,6 +387,14 @@ step: 299, loss is 0.20551169, fps is 58.4039329983891
step: 300, loss is 0.18949677, fps is 57.63118508760329
```
- GPU处理器环境运行
```shell
bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]
```
上述shell脚本在后台运行分布式训练。可通过`train.log`文件查看结果。
#### 训练时推理
训练时推理需要在启动文件中添加`run_eval` 并设置为True。与此同时需要设置: `save_best_ckpt`, `eval_start_epoch`, `eval_interval`, `eval_metrics`
@ -375,29 +420,46 @@ python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkp
============== Cross valid dice coeff is: {'dice_coeff': 0.9111}
```
## 模型描述
- GPU处理器环境运行评估ISBI数据集
### 性能
在运行以下命令之前,请检查用于评估的检查点路径。将检查点路径设置为绝对全路径,如"username/unet/ckpt_unet_medical_adam-2_400.ckpt"。
#### 评估性能
```shell
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/config/ > eval.log 2>&1 &
OR
bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
```
| 参数 | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| 模型版本 | U-Net |
| 资源 | Ascend 910CPU 2.60GHz192核内存 755GB系统 Euler2.8 |
| 上传日期 | 2020-9-15 |
| MindSpore版本 | 1.2.0 |
| 数据集 | ISBI |
| 训练参数 | 1pc: epoch=400, total steps=600, batch_size = 16, lr=0.0001 |
| 优化器 | Adam |
| 损失函数 | Softmax交叉熵 |
| 输出 | 概率 |
| 损失 | 0.22070312 |
| 速度 | 1卡267毫秒/步8卡280毫秒/步 |
| 总时长 | 1卡2.67分钟8卡1.40分钟 |
| 参数(M) | 93M |
| 微调检查点 | 355.11M (.ckpt文件) |
| 脚本 | [U-Net脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
上述python命令在后台运行。可通过"eval.log"文件查看结果。测试数据集的准确率如下:
```shell
# grep "Cross valid dice coeff is:" eval.log
============== Cross valid dice coeff is: {'dice_coeff': 0.9089390969777261}
```
# 模型描述
## 性能
### 评估性能
| 参数 | Ascend | GPU |
| -------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
| 模型版本 | U-Net | U-Net |
| 资源 | Ascend 910CPU2.60GHz192核内存755 GB系统 Euler2.8 | NV SMX2 V100内存32G |
| 上传日期 | 2020-9-15 | 2020-12-29 |
| MindSpore版本 | 1.2.0 | 1.1.0 |
| 数据集 | ISBI | ISBI |
| 训练参数 | 1pc: epoch=400, total steps=600, batch_size = 16, lr=0.0001 | 1pc: epoch=400, total steps=800,batch_size = 12, lr=0.0001 |
| 优化器 | ADAM | ADAM |
| 损失函数 | Softmax交叉熵 | Softmax交叉熵 |
| 输出 | 概率 | 概率 |
| 损失 | 0.22070312 | 0.21425568 |
| 速度 | 1卡267毫秒/步8卡280毫秒/步 | 1卡423毫秒/步8卡128毫秒/步 |
| 总时长 | 1卡2.67分钟8卡1.40分钟 | 1卡5.64分钟8卡3.41分钟 |
| 参数(M) | 93M | 93M |
| 微调检查点 | 355.11M (.ckpt文件) | 355.11M (.ckpt文件) |
| 脚本 | [U-Net脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) | [U-Net脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
### 用法

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -25,9 +25,6 @@ from src.utils import UnetEval, TempLoss, dice_coeff
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
device_id = int(os.getenv("DEVICE_ID"))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
@moxing_wrapper()
def test_net(data_dir,
ckpt_path,
@ -63,6 +60,10 @@ def test_net(data_dir,
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
if config.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
test_net(data_dir=config.data_path,
ckpt_path=config.checkpoint_file_path,
cross_valid_ind=config.cross_valid_ind)

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]"
echo "for example: bash run_distribute_train_gpu.sh 8 /path/to/data/ /path/to/config/"
echo "=============================================================================================================="
mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python train.py --run_distribute=True --data_path=$2 --config_path=$3 --output=./output > train.log 2>&1 &

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]"
echo "for example: bash run_standalone_eval_gpu.sh /path/to/data/ /path/to/checkpoint/ /path/to/config/"
echo "=============================================================================================================="
python eval.py --data_path=$1 --checkpoint_file_path=$2 --config_path=$3 > eval.log 2>&1 &

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH] "
echo "for example: bash scripts/run_standalone_train_gpu.sh /path/to/data/ /path/to/config/"
echo "=============================================================================================================="
python train.py --data_path=$1 --config_path=$2 --output ./output > train.log 2>&1 &

View File

@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import logging
import mindspore
import mindspore.nn as nn
from mindspore import Model, context
from mindspore.communication.management import init
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
@ -33,10 +32,6 @@ from src.eval_callback import EvalCallBack
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_rank_id, get_device_num
device_id = int(os.getenv("DEVICE_ID"))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
mindspore.set_seed(1)
@ -51,8 +46,8 @@ def train_net(cross_valid_ind=1,
run_distribute = config.run_distribute
if run_distribute:
init()
group_size = get_device_num()
rank = get_rank_id()
group_size = get_group_size()
rank = get_rank()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=group_size,
@ -94,16 +89,21 @@ def train_net(cross_valid_ind=1,
else:
repeat = config.repeat
dataset_sink_mode = False
if config.device_target == "GPU":
dataset_sink_mode = True
per_print_times = 1
train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind,
run_distribute, config.crop, config.image_size)
train_data_size = train_dataset.get_dataset_size()
print("dataset length is:", train_data_size)
ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
save_ck_steps = train_data_size
if config.device_target == "GPU":
save_ck_steps = train_data_size * epochs
ckpt_config = CheckpointConfig(save_checkpoint_steps=save_ck_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(config.model_name),
directory=ckpt_save_dir+'./ckpt_{}/'.format(device_id),
directory=ckpt_save_dir+'./ckpt_{}/'.format(rank),
config=ckpt_config)
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=config.weight_decay,
@ -121,7 +121,7 @@ def train_net(cross_valid_ind=1,
eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": config.eval_metrics}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
ckpt_directory=ckpt_save_dir+'./ckpt_{}/'.format(device_id), besk_ckpt_name="best.ckpt",
ckpt_directory=ckpt_save_dir+'./ckpt_{}/'.format(rank), besk_ckpt_name="best.ckpt",
metrics_name=config.eval_metrics)
callbacks.append(eval_cb)
model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
@ -130,9 +130,15 @@ def train_net(cross_valid_ind=1,
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
if config.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
epoch_size = config.epochs if not config.run_distribute else config.distribute_epochs
batchsize = config.batch_size
if config.device_target == 'GPU' and config.run_distribute:
batchsize = config.distribute_batchsize
train_net(cross_valid_ind=config.cross_valid_ind,
epochs=epoch_size,
batch_size=config.batch_size,
batch_size=batchsize,
lr=config.lr)

View File

@ -0,0 +1,69 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'GPU'
enable_profiling: False
# ==============================================================================
# Training options
model_name: 'unet_medical'
run_eval: False
run_distribute: False
crop: [388, 388]
image_size : [572, 572]
lr: 0.0001
epochs: 400
repeat: 1
distribute_epochs: 1600
batch_size: 12
distribute_batchsize: 3
cross_valid_ind: 1
num_classes: 2
num_channels: 1
weight_decay: 0.0005
loss_scale: 1024.0
FixedLossScaleManager: 1024.0
resume: False
resume_ckpt: './'
transfer_training: False
filter_weight: ['outc.weight', 'outc.bias']
#Eval options
keep_checkpoint_max: 10
eval_activate: 'Softmax'
eval_resize: False
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'ckpt_unet_medical_adam-400.ckpt'
rst_path: './result_Files/'
# Export options
width: 572
height: 572
file_name: unet
file_format: AIR
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for obs'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
distribute_batchsize: "Batch size for distribute training"
weight_decay: "Weight decay."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."