forked from mindspore-Ecosystem/mindspore

commit be3d4e6fd3 (parent b12b780163)

1. Optimize bias add grad kernel
2. Optimize slice grad kernel
3. Add Unet GPU model
@@ -19,6 +19,8 @@

 #include <vector>
 #include <algorithm>
+#include <string>
+#include <utility>
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"
@@ -39,8 +41,9 @@ class SliceGradGpuKernel : public GpuKernel {
     T *dy = GetDeviceAddress<T>(inputs, 0);
     T *dx = GetDeviceAddress<T>(outputs, 0);
     FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, dx,
-                 reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalSlice4DGrad(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0],
+                   input_shape_[1], input_shape_[2], input_shape_[3], dy, dx,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }

@@ -49,6 +52,7 @@ class SliceGradGpuKernel : public GpuKernel {
       return false;
     }
     auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0);
     if (kernel_name == "StridedSliceGrad") {
       is_strided_slice_ = true;
       std::vector<int64_t> shapex = GetAttr<std::vector<int64_t>>(kernel_node, "shapex");
@@ -64,17 +68,15 @@ class SliceGradGpuKernel : public GpuKernel {
       }
       size_ = GetAttr<std::vector<int64_t>>(kernel_node, "end");
     } else {
-      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+      auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
       ShapeNdTo4d(input_shape, &input_shape_);
       size_ = GetAttr<std::vector<int64_t>>(kernel_node, "size");
     }
-    auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
     ShapeNdTo4d(dy_shape, &dy_shape_);
     begin_ = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
-    DealParam();
+    CalcBeginAndSize(data_format);
     input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T);

     output_size_ = sizeof(T);
     for (auto x : dy_shape_) {
       output_size_ = output_size_ * x;
@@ -89,6 +91,30 @@ class SliceGradGpuKernel : public GpuKernel {
     input_size_list_.push_back(input_size_);
     output_size_list_.push_back(input_size_);
   }
+  void CalcBeginAndSize(const std::string data_format) {
+    for (auto i = begin_.size(); i < 4; i++) {
+      (void)begin_.insert(begin_.begin(), 0);
+    }
+    for (auto i = size_.size(); i < 4; i++) {
+      (void)size_.insert(size_.begin(), 1);
+    }
+    if (data_format == "NHWC") {
+      std::swap(begin_[1], begin_[3]);
+      std::swap(begin_[1], begin_[2]);
+      std::swap(size_[1], size_[3]);
+      std::swap(size_[1], size_[2]);
+    }
+    for (size_t i = 0; i < begin_.size(); i++) {
+      if (begin_[i] < 0) {
+        begin_[i] = begin_[i] + input_shape_[i];
+      }
+    }
+    for (size_t i = 0; i < size_.size(); i++) {
+      if (size_[i] < 0) {
+        size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
+      }
+    }
+  }

  private:
   bool CheckParam(const CNodePtr &kernel_node) {
@@ -108,24 +134,7 @@ class SliceGradGpuKernel : public GpuKernel {
     }
     return true;
   }
-  void DealParam() {
-    for (auto i = begin_.size(); i < 4; i++) {
-      (void)begin_.insert(begin_.begin(), 0);
-    }
-    for (auto i = size_.size(); i < 4; i++) {
-      (void)size_.insert(size_.begin(), 1);
-    }
-    for (size_t i = 0; i < begin_.size(); i++) {
-      if (begin_[i] < 0) {
-        begin_[i] = begin_[i] + input_shape_[i];
-      }
-    }
-    for (size_t i = 0; i < size_.size(); i++) {
-      if (size_[i] < 0) {
-        size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
-      }
-    }
-  }
   std::vector<int64_t> begin_;
   std::vector<int64_t> size_;
   std::vector<int64_t> strides_;
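Note (illustrative, not part of the commit): the two std::swap calls in CalcBeginAndSize apply the same fixed permutation to begin_ and size_, namely (d0, d1, d2, d3) -> (d0, d2, d3, d1); with values given in NCHW order that yields NHWC order, which is why the branch only runs when the device format is NHWC. A minimal standalone sketch with made-up values:

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // Hypothetical "begin" attribute in (N, C, H, W) order: N=0, C=7, H=1, W=2.
  std::vector<int64_t> v = {0, 7, 1, 2};
  std::swap(v[1], v[3]);  // {0, 2, 1, 7}
  std::swap(v[1], v[2]);  // {0, 1, 2, 7} -> (N, H, W, C) order
  for (auto x : v) printf("%ld ", static_cast<long>(x));
  printf("\n");  // prints: 0 1 2 7
  return 0;
}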
@@ -0,0 +1,175 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdio.h>
+#include <stdint.h>
+#include <cuda_runtime.h>
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/bias_add_grad_impl.cuh"
+
+const int kWarpSize = 32;
+// Tuning params for BiasAddGradNHWC: when nhw >= kLargeSize, launch more blocks to solve it.
+const int kLargeSize = 500000;  // tuning param for BiasAddGradNHWC
+const int kNumBlocks = 8;       // tuning param for BiasAddGradNHWC
+
+// For NHWC bias add grad, combine dy's NHW dimensions and do a matrix column reduce.
+// This is a simple implementation and can be further optimized when C is small.
+// First, each warp sums several rows; each thread's partial_sum covers a part of one column.
+// Second, to sum up all values of one column (the partial_sums of different warps that share
+// the same lane_id), each warp stores its partial_sums to one row of shared memory and reads
+// partial_sums back from one column of shared memory.
+// Each warp then does a warp reduce to sum up 32 partial_sums and writes the final result to db.
+// For larger NHW, one block is not enough to sum up all rows, so more blocks are launched.
+template <typename T>
+__global__ void BiasAddGradNHWC(const T* dy, T* db, const size_t m, const size_t n,
+                                const size_t rows_per_block, size_t rows_per_warp) {
+  __shared__ float shared_d[kWarpSize][kWarpSize + 1];  // +1 avoids bank conflicts
+  int shm_row_id = (threadIdx.x >> 5);
+  int shm_col_id = (threadIdx.x % 32);
+  int block_start_row = blockIdx.x * rows_per_block;
+  int block_end_row = block_start_row + rows_per_block;
+  block_end_row = block_end_row < m ? block_end_row : m;
+  int warp_start_row = blockIdx.x * rows_per_block + shm_row_id * rows_per_warp;
+  int warp_end_row = warp_start_row + rows_per_warp;
+  int real_rows_per_warp = warp_end_row < block_end_row ? rows_per_warp : block_end_row - warp_start_row;
+  // Boundary processing: only the last row or column may not have the full size.
+  bool full_tile = true;
+  int tile_width_real = 32;
+  if (blockIdx.y == blockDim.y - 1) {
+    tile_width_real = n - (blockDim.y - 1) * 32;
+    full_tile = (tile_width_real == 32);
+  }
+  int read_offset = warp_start_row * n + (blockIdx.y << 5) + shm_col_id;
+  float partial_sum = 0.0;
+  if (full_tile) {
+    for (int i = 0; i < real_rows_per_warp; i++) {
+      partial_sum += static_cast<float>(dy[read_offset]);
+      read_offset += n;
+    }
+  } else {
+    if (shm_col_id < tile_width_real) {
+      for (int i = 0; i < real_rows_per_warp; i++) {
+        partial_sum += static_cast<float>(dy[read_offset]);
+        read_offset += n;
+      }
+    }
+  }
+  shared_d[shm_row_id][shm_col_id] = partial_sum;
+  __syncthreads();
+  partial_sum = shared_d[shm_col_id][shm_row_id];
+  __syncthreads();
+  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
+    partial_sum += __shfl_down_sync(0xffffffff, partial_sum, offset);
+  }
+  if (shm_col_id == 0) {
+    if (full_tile) {
+      MsAtomicAdd(db + (blockIdx.y << 5) + shm_row_id, T(partial_sum));
+    } else {
+      if (shm_row_id < tile_width_real) {
+        MsAtomicAdd(db + (blockIdx.y << 5) + shm_row_id, T(partial_sum));
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void BiasAddGradNCHW(const size_t size, const int batch, const int bias_size, const int h,
+                                const int w, const int bg_size, const T* dy, T* db) {
+  __shared__ float shared_d[32];
+  for (int i = threadIdx.x; i < 32; i += blockDim.x) {
+    shared_d[i] = static_cast<float>(0);
+  }
+  __syncthreads();
+  float sum = 0.;
+  int lane_id = threadIdx.x % 32;
+  int thread_id = threadIdx.x;
+  int img_size = h * w;
+  // N*H*W -> count / bg_size equals the amount of work one block should reduce.
+  int count = batch * img_size;
+  int bg_offset = blockIdx.x % bias_size;
+  int bg_id = blockIdx.x / bias_size;
+  for (int i = bg_id * blockDim.x + threadIdx.x;  // thread start
+       i < count; i += blockDim.x * bg_size) {
+    int img_offset = i % img_size;
+    int img_id = i / img_size;
+    T val = *(dy + (img_id * bias_size + bg_offset) * img_size + img_offset);
+    sum += static_cast<float>(val);
+  }
+  MsAtomicAdd(shared_d + lane_id, sum);
+  __syncthreads();
+  if (thread_id < 32) {
+    float data = shared_d[thread_id];
+    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
+      data += __shfl_xor_sync(0xffffffff, data, offset);
+    }
+    if (thread_id == 0) {
+      MsAtomicAdd(db + bg_offset, T(data));
+    }
+  }
+}
+
+template <typename T>
+__global__ void FillDb(T *db, const size_t bias_size) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < bias_size; pos += blockDim.x * gridDim.x) {
+    db[pos] = T(0.0);
+  }
+}
+
+template <typename T>
+void CalBiasAddGradNCHW(const size_t size, const size_t bias_size, const int height, const int width,
+                        const T* dy, T* db, cudaStream_t cuda_stream) {
+  int batch_size = size / bias_size / height / width;
+  int block_num = GET_BLOCKS(size);
+  int thread_num = GET_THREADS;
+  // How many blocks to solve one bias's reduce work (N * H * W).
+  int block_group_size = (block_num + bias_size - 1) / bias_size;
+  block_num = block_group_size * bias_size;
+  if (thread_num < kWarpSize) {
+    thread_num = kWarpSize;
+  }
+  FillDb<<<GET_BLOCKS(bias_size), GET_THREADS, 0, cuda_stream>>>(db, bias_size);
+  BiasAddGradNCHW<<<block_num, thread_num, 0, cuda_stream>>>(size, batch_size, bias_size, height,
+                                                             width, block_group_size, dy, db);
+  return;
+}
+
+template <typename T>
+void CalBiasAddGradNHWC(const size_t size, const size_t bias_size,
+                        const T* dy, T* db, cudaStream_t cuda_stream) {
+  FillDb<<<GET_BLOCKS(bias_size), GET_THREADS, 0, cuda_stream>>>(db, bias_size);
+  size_t rows = size / bias_size;
+  int block_num_x = rows <= kLargeSize ? 1 : kNumBlocks;
+  int block_num_y = (bias_size + kWarpSize - 1) / kWarpSize;
+  dim3 grid_size(block_num_x, block_num_y, 1);
+  dim3 block_size(kWarpSize * kWarpSize);
+  size_t rows_per_block = (rows + block_num_x - 1) / block_num_x;
+  size_t rows_per_warp = (rows_per_block + kWarpSize - 1) / kWarpSize;
+  BiasAddGradNHWC<<<grid_size, block_size, 0, cuda_stream>>>(dy, db, rows, bias_size,
+                                                             rows_per_block, rows_per_warp);
+  return;
+}
+
+template void CalBiasAddGradNCHW(const size_t size, const size_t bias_size, const int height, const int width,
+                                 const float* dy, float* db, cudaStream_t cuda_stream);
+template void CalBiasAddGradNCHW(const size_t size, const size_t bias_size, const int height, const int width,
+                                 const half* dy, half* db, cudaStream_t cuda_stream);
+template void CalBiasAddGradNHWC(const size_t size, const size_t bias_size,
+                                 const float* dy, float* db, cudaStream_t cuda_stream);
+template void CalBiasAddGradNHWC(const size_t size, const size_t bias_size, const half* dy,
+                                 half* db, cudaStream_t cuda_stream);
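Note (illustrative, not part of the commit): both kernels above finish with a warp-level shuffle reduction. The standalone sketch below shows only that pattern, stripped of the shared-memory staging and atomics; the kernel name, input values, and single-warp launch are hypothetical.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void WarpSumDemo(const float *in, float *out) {
  float v = in[threadIdx.x];  // one partial sum per lane
  // Halve the stride each step; after the loop, lane 0 holds the sum of all 32 lanes.
  for (int offset = 16; offset > 0; offset /= 2) {
    v += __shfl_down_sync(0xffffffff, v, offset);
  }
  if (threadIdx.x == 0) {
    *out = v;
  }
}

int main() {
  float h_in[32], h_out = 0.f, *d_in, *d_out;
  for (int i = 0; i < 32; i++) h_in[i] = 1.f;  // expect a sum of 32
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  WarpSumDemo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);  // prints 32.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}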
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BIASADDGRAD_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BIASADDGRAD_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void CalBiasAddGradNHWC(const size_t size, const size_t bias_size,
+                        const T* dy, T* db, cudaStream_t cuda_stream);
+template <typename T>
+void CalBiasAddGradNCHW(const size_t size, const size_t bias_size, const int height, const int width,
+                        const T* dy, T* db, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BIASADDGRAD_H_
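Note (illustrative, not part of the commit): a minimal sketch of how the two entry points declared above could be driven from host code. The wrapper function, the flag, and the shape values are hypothetical; dy and db are assumed to be device buffers already allocated by the caller, and the declarations simply mirror the header above.

#include <cuda_runtime.h>

template <typename T>
void CalBiasAddGradNHWC(const size_t size, const size_t bias_size,
                        const T* dy, T* db, cudaStream_t cuda_stream);
template <typename T>
void CalBiasAddGradNCHW(const size_t size, const size_t bias_size, const int height, const int width,
                        const T* dy, T* db, cudaStream_t cuda_stream);

// Hypothetical dy of shape N=8, H=16, W=16, C=64; db is the bias gradient of length C.
void LaunchBiasAddGrad(const float *dy, float *db, bool nhwc, cudaStream_t stream) {
  const size_t n = 8, h = 16, w = 16, c = 64;
  const size_t size = n * h * w * c;  // total element count of dy
  if (nhwc) {
    CalBiasAddGradNHWC(size, c, dy, db, stream);  // reduces size / c rows of width c
  } else {
    CalBiasAddGradNCHW(size, c, static_cast<int>(h), static_cast<int>(w), dy, db, stream);
  }
}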
@@ -34,12 +34,20 @@ __global__ void Slice4D(const size_t s1, const size_t s2, const size_t s3, const
     output[pos] = input[offset];
   }
 }

 template <typename T>
-__global__ void SliceGrad(const T *dy, int64_t p, int64_t start, int64_t length, T *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) {
-    output[start + pos] = dy[p + pos];
+__global__ void Slice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
+                            const size_t l1, const size_t l2, const size_t l3, const size_t l4,
+                            const size_t d1, const size_t d2, const size_t d3, const size_t d4,
+                            const T *dy, T *dx) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) {
+    size_t i = pos / (l2 * l3 * l4) % l1;
+    size_t j = pos / (l3 * l4) % l2;
+    size_t k = pos / l4 % l3;
+    size_t o = pos % l4;
+    size_t input_idx = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4);
+    dx[input_idx] = dy[pos];
   }
-  return;
 }

 template <typename T>
@@ -62,24 +70,13 @@ void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size
   Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4,
                                                                      input, output);
 }

 template <typename T>
-void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<size_t> in_shape,
-                  const std::vector<int64_t> begin, const std::vector<int64_t> size, T *output,
-                  cudaStream_t cuda_stream) {
-  size_t block = in_shape[1] * in_shape[2] * in_shape[3];
-  size_t map = in_shape[2] * in_shape[3];
-  size_t w = in_shape[3];
-  int64_t length = size[3];
-  int64_t p = 0;
-  for (int64_t i = begin[0]; i < size[0] + begin[0]; i++) {
-    for (int64_t j = begin[1]; j < size[1] + begin[1]; j++) {
-      for (int64_t k = begin[2]; k < size[2] + begin[2]; k++) {
-        SliceGrad<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(
-          dy, p, i * block + j * map + k * w + begin[3], length, output);
-        p = p + size[3];
-      }
-    }
-  }
+void CalSlice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                    const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                    const size_t d3, const size_t d4, const T *dy, T *dx, cudaStream_t stream) {
+  Slice4DGrad<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4,
+                                                                         dy, dx);
 }

 template <typename T>
@@ -168,10 +165,6 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream);
-template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
-                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
-                            const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
-                            cudaStream_t stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream);
@@ -183,36 +176,43 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output,
                             cudaStream_t stream);
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
+                            cudaStream_t stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream);

-template void CalSliceGrad<double>(const size_t input_size, const double *dy, const std::vector<size_t> in_shape,
-                                   const std::vector<int64_t> begin, const std::vector<int64_t> size, double *output,
-                                   cudaStream_t cuda_stream);
-template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<size_t> in_shape,
-                                  const std::vector<int64_t> begin, const std::vector<int64_t> size, float *output,
-                                  cudaStream_t cuda_stream);
-template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<size_t> in_shape,
-                                 const std::vector<int64_t> begin, const std::vector<int64_t> size, half *output,
-                                 cudaStream_t cuda_stream);
-template void CalSliceGrad<int64_t>(const size_t input_size, const int64_t *dy, const std::vector<size_t> in_shape,
-                                    const std::vector<int64_t> begin, const std::vector<int64_t> size, int64_t *output,
-                                    cudaStream_t cuda_stream);
-template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<size_t> in_shape,
-                                const std::vector<int64_t> begin, const std::vector<int64_t> size, int *output,
-                                cudaStream_t cuda_stream);
-template void CalSliceGrad<short>(const size_t input_size, const short *dy,  // NOLINT
-                                  const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
-                                  const std::vector<int64_t> size, short *output,  // NOLINT
-                                  cudaStream_t cuda_stream);
-template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigned char *dy,
-                                          const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
-                                          const std::vector<int64_t> size, unsigned char *output,
-                                          cudaStream_t cuda_stream);
-template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<size_t> in_shape,
-                                 const std::vector<int64_t> begin, const std::vector<int64_t> size, bool *output,
-                                 cudaStream_t cuda_stream);
+template void CalSlice4DGrad<double>(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
+                                     const size_t l1, const size_t l2, const size_t l3, const size_t l4,
+                                     const size_t d1, const size_t d2, const size_t d3, const size_t d4,
+                                     const double *dy, double *dx, cudaStream_t stream);
+template void CalSlice4DGrad<float>(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                                    const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                                    const size_t d3, const size_t d4, const float *dy, float *dx, cudaStream_t stream);
+template void CalSlice4DGrad<half>(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                                   const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                                   const size_t d3, const size_t d4, const half *dy, half *dx, cudaStream_t stream);
+template void CalSlice4DGrad<int>(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                                  const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                                  const size_t d3, const size_t d4, const int *dy, int *dx, cudaStream_t stream);
+template void CalSlice4DGrad<short>(const size_t s1, const size_t s2, const size_t s3, const size_t s4,  // NOLINT
+                                    const size_t l1,
+                                    const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                                    const size_t d3, const size_t d4, const short *dy, short *dx,  // NOLINT
+                                    cudaStream_t stream);
+template void CalSlice4DGrad<unsigned char>(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
+                                            const size_t l1, const size_t l2, const size_t l3, const size_t l4,
+                                            const size_t d1, const size_t d2, const size_t d3, const size_t d4,
+                                            const unsigned char *dy, unsigned char *dx, cudaStream_t stream);
+template void CalSlice4DGrad<int64_t>(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
+                                      const size_t l1, const size_t l2, const size_t l3, const size_t l4,
+                                      const size_t d1, const size_t d2, const size_t d3, const size_t d4,
+                                      const int64_t *dy, int64_t *dx, cudaStream_t stream);
+template void CalSlice4DGrad<bool>(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                                   const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                                   const size_t d3, const size_t d4, const bool *dy, bool *dx, cudaStream_t stream);

 template void FillDeviceArray<bool>(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
 template void FillDeviceArray<int64_t>(const size_t input_size, int64_t *addr, const float value,
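Note (illustrative, not part of the commit): Slice4DGrad replaces the old per-row kernel launches with a single kernel that recovers 4-D coordinates from the flat output index by ordinary row-major arithmetic and then offsets them by the slice begin to scatter into dx. A worked example of that index math with made-up sizes (d1 is not needed by the formula):

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical slice: output shape (l1,l2,l3,l4) = (2,3,4,5), input shape (d1,d2,d3,d4) = (4,6,8,10),
  // slice begin (s1,s2,s3,s4) = (1,2,3,4).
  const size_t l1 = 2, l2 = 3, l3 = 4, l4 = 5;
  const size_t d2 = 6, d3 = 8, d4 = 10;
  const size_t s1 = 1, s2 = 2, s3 = 3, s4 = 4;
  const size_t pos = 37;                 // flat index into dy
  size_t i = pos / (l2 * l3 * l4) % l1;  // 37 / 60 % 2 = 0
  size_t j = pos / (l3 * l4) % l2;       // 37 / 20 % 3 = 1
  size_t k = pos / l4 % l3;              // 37 / 5  % 4 = 3
  size_t o = pos % l4;                   // 37 % 5      = 2
  size_t input_idx = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4);
  printf("dy[%zu] -> dx[%zu]\n", pos, input_idx);  // dy[37] -> dx[786]
  return 0;
}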
@@ -26,9 +26,9 @@ void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size
                    const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4,
                    const T *input, T *output, cudaStream_t stream);
 template <typename T>
-void CalSliceGrad(const size_t input_size, const T *input, const std::vector<size_t> in_shape,
-                  const std::vector<int64_t> begin, const std::vector<int64_t> size, T *output,
-                  cudaStream_t cuda_stream);
+void CalSlice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                    const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                    const size_t d3, const size_t d4, const T *dy, T *dx, cudaStream_t stream);
 template <typename T>
 void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
                   const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape, const T *input,
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,6 +21,6 @@ namespace kernel {
 MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       BiasAddGradGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-                      BiasAddGradGpuKernel, float16)
+                      BiasAddGradGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,8 +17,6 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_

-#include <cuda_runtime_api.h>
-#include <cublas_v2.h>
 #include <vector>
 #include <string>
 #include <algorithm>
@@ -26,6 +24,7 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/bias_add_grad_impl.cuh"

 namespace mindspore {
 namespace kernel {
@@ -34,6 +33,10 @@ class BiasAddGradGpuKernel : public GpuKernel {
  public:
   BiasAddGradGpuKernel()
       : same_dims_(true),
+        use_cudnn_(false),
+        dy_num_(1),
+        db_num_(1),
+        bias_size_(0),
         cudnn_handle_(nullptr),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
         dy_desc_(nullptr),
@@ -48,118 +51,171 @@ class BiasAddGradGpuKernel : public GpuKernel {
             const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *dy_addr = GetDeviceAddress<T>(inputs, 0);
     T *db_addr = GetDeviceAddress<T>(outputs, 0);
-    T *indices_addr = GetDeviceAddress<T>(workspace, 0);
-    T *workspace_addr = GetDeviceAddress<T>(workspace, 1);
-
-    const float alpha = 1;
-    const float beta = 0;
     if (same_dims_) {
       CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                  cudaMemcpyAsync(db_addr, dy_addr, output_size_list_[0], cudaMemcpyDeviceToDevice,
                                                  reinterpret_cast<cudaStream_t>(stream_ptr)),
                                  "cudaMemcpyAsync failed.");
     } else {
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        kernel_node_,
-        cudnnReduceTensor(cudnn_handle_, op_desc_, indices_addr, workspace_size_list_[0], workspace_addr,
-                          workspace_size_list_[1], &alpha, dy_desc_, dy_addr, &beta, db_desc_, db_addr),
-        "cudnnReduceTensor failed");
+      if (use_cudnn_) {  // shared memory not satisfied or num_dim > 4
+        T *indices_addr = GetDeviceAddress<T>(workspace, 0);
+        T *workspace_addr = GetDeviceAddress<T>(workspace, 1);
+        const float alpha = 1;
+        const float beta = 0;
+        CHECK_CUDNN_RET_WITH_EXCEPT(
+          kernel_node_,
+          cudnnReduceTensor(cudnn_handle_, op_desc_, indices_addr, workspace_size_list_[0], workspace_addr,
+                            workspace_size_list_[1], &alpha, dy_desc_, dy_addr, &beta, db_desc_, db_addr),
+          "cudnnReduceTensor failed");
+      } else {  // use own implementation which is more efficient but cannot process num_dim > 4
+        if (data_format_ == kOpFormat_NHWC) {
+          CalBiasAddGradNHWC(dy_num_, bias_size_, dy_addr, db_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+        } else {
+          CalBiasAddGradNCHW(dy_num_, bias_size_, SizeToInt(dy_shape_[2]), SizeToInt(dy_shape_[3]), dy_addr, db_addr,
+                             reinterpret_cast<cudaStream_t>(stream_ptr));
+        }
+      }
     }

     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
     kernel_node_ = kernel_node;
-    InitResource();
-    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
     auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    auto num_dims = dy_shape.size();
-    if (num_dims < 2) {
-      MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims;
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0);
+    cudnn_compute_format_ = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
+    num_dims_ = dy_shape.size();
+    if (num_dims_ < 2) {
+      MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims_;
     }

     std::string format = GetAttr<std::string>(kernel_node, "format");
     string::size_type pos = format.find("C");
-    if (pos == std::string::npos || pos >= num_dims) {
+    if (pos == std::string::npos || pos >= num_dims_) {
       MS_LOG(EXCEPTION) << "format '" << format << "' invalid";
     }
-    // Expand to 4 dims for cudnnSetTensorNdDescriptorEx.
-    auto cudnn_dims = std::max(num_dims, 4UL);
-    std::unique_ptr<int[]> dy_dims = std::make_unique<int[]>(cudnn_dims);
-    std::unique_ptr<int[]> db_dims = std::make_unique<int[]>(cudnn_dims);
-    for (size_t i = 0; i < cudnn_dims; i++) {
-      dy_dims[i] = (i < num_dims) ? SizeToInt(dy_shape[i]) : 1;
-      db_dims[i] = (i == pos) ? SizeToInt(dy_shape[i]) : 1;
-
-      if (dy_dims[i] != db_dims[i]) {
+    bias_size_ = dy_shape[pos];
+    auto num_dims_fix = std::max(num_dims_, 4UL);
+    for (size_t i = 0; i < num_dims_fix; i++) {
+      dy_shape_.push_back((i < num_dims_) ? dy_shape[i] : 1);
+      db_shape_.push_back((i == pos) ? dy_shape[i] : 1);
+      if (dy_shape_[i] != db_shape_[i]) {
         same_dims_ = false;
       }
     }
-
-    auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0);
-    auto cudnn_cal_format = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnSetTensorNdDescriptorEx(dy_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()),
-      "cudnnSetTensorNdDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnSetTensorNdDescriptorEx(db_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()),
-      "cudnnSetTensorNdDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN,
-                                     CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES),
-      "cudnnSetReduceTensorDescriptor failed");
-
+    for (size_t i = 0; i < dy_shape_.size(); i++) {
+      dy_num_ *= dy_shape_[i];
+    }
+    for (size_t i = 0; i < db_shape_.size(); i++) {
+      db_num_ *= db_shape_[i];
+    }
+    data_format_ = input_device_format;  // for opt implementation
+    if (format == kOpFormat_NHWC) {
+      data_format_ = kOpFormat_NHWC;
+    }
+    MethodSelection();
+    InitResource();
     InitSizeLists();
     return true;
   }

   void DestroyResource() noexcept override {
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnDestroyReduceTensorDescriptor(op_desc_),
-                                "cudnnDestroyReduceTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(db_desc_),
-                               "cudnnDestroyTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_),
-                               "cudnnDestroyOpTensorDescriptor failed");
+    if (use_cudnn_) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnDestroyReduceTensorDescriptor(op_desc_),
+                                  "cudnnDestroyReduceTensorDescriptor failed");
+      CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(db_desc_),
+                                 "cudnnDestroyTensorDescriptor failed");
+      CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_),
+                                 "cudnnDestroyOpTensorDescriptor failed");
+    }
   }

  protected:
+  void MethodSelection() {
+    // opt implementation can only process num_dims_ <= 4
+    // for num_dims_ = 2, not time-consuming, use cudnn
+    if (num_dims_ > 4 || num_dims_ == 2) {
+      use_cudnn_ = true;
+      return;
+    }
+    if (data_format_ == kOpFormat_NHWC) {
+      size_t required_sharedmem_size = 32 * 33 * sizeof(float);
+      // nhwc opt implementation performs not so well when bias_size_ <= 6
+      if (required_sharedmem_size > SHARED_MEM_PER_BLOCK || bias_size_ <= 6) {
+        use_cudnn_ = true;
+        return;
+      }
+    }
+  }
   void InitResource() override {
-    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_),
-                                "cudnnCreateTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&db_desc_),
-                                "cudnnCreateTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateReduceTensorDescriptor(&op_desc_),
-                                "cudnnCreateOpTensorDescriptor failed");
+    if (use_cudnn_) {
+      cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_),
+                                  "cudnnCreateTensorDescriptor failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&db_desc_),
+                                  "cudnnCreateTensorDescriptor failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateReduceTensorDescriptor(&op_desc_),
+                                  "cudnnCreateOpTensorDescriptor failed");
+      // Expand to 4 dims for cudnnSetTensorNdDescriptorEx.
+      auto cudnn_dims = std::max(num_dims_, 4UL);
+      std::unique_ptr<int[]> dy_dims = std::make_unique<int[]>(cudnn_dims);
+      std::unique_ptr<int[]> db_dims = std::make_unique<int[]>(cudnn_dims);
+      for (size_t i = 0; i < cudnn_dims; i++) {
+        dy_dims[i] = SizeToInt(dy_shape_[i]);
+        db_dims[i] = SizeToInt(db_shape_[i]);
+      }
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                  cudnnSetTensorNdDescriptorEx(dy_desc_, cudnn_compute_format_, cudnn_data_type_,
+                                                               SizeToInt(cudnn_dims), dy_dims.get()),
+                                  "cudnnSetTensorNdDescriptor failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                  cudnnSetTensorNdDescriptorEx(db_desc_, cudnn_compute_format_, cudnn_data_type_,
+                                                               SizeToInt(cudnn_dims), db_dims.get()),
+                                  "cudnnSetTensorNdDescriptor failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN,
+                                       CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES),
+        "cudnnSetReduceTensorDescriptor failed");
+    }
   }
   void InitSizeLists() override {
-    size_t dy_size, db_size;
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dy_desc_, &dy_size),
-                                "cudnnGetTensorSizeInBytes failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(db_desc_, &db_size),
-                                "cudnnGetTensorSizeInBytes failed");
-    input_size_list_.push_back(dy_size);
-    output_size_list_.push_back(db_size);
+    if (use_cudnn_) {
+      size_t dy_size, db_size;
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dy_desc_, &dy_size),
+                                  "cudnnGetTensorSizeInBytes failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(db_desc_, &db_size),
+                                  "cudnnGetTensorSizeInBytes failed");
+      input_size_list_.push_back(dy_size);
+      output_size_list_.push_back(db_size);
       size_t indices_size, workspace_size;
       CHECK_CUDNN_RET_WITH_EXCEPT(
         kernel_node_, cudnnGetReductionIndicesSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &indices_size),
         "cudnnGetReductionIndicesSize failed")
       CHECK_CUDNN_RET_WITH_EXCEPT(
         kernel_node_, cudnnGetReductionWorkspaceSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &workspace_size),
         "cudnnGetReductionWorkspaceSize failed")
       workspace_size_list_.push_back(indices_size);
       workspace_size_list_.push_back(workspace_size);
+    } else {
+      input_size_list_.push_back(dy_num_ * sizeof(T));
+      output_size_list_.push_back(db_num_ * sizeof(T));
+    }
   }

  private:
   bool same_dims_;
+  bool use_cudnn_;
+  size_t dy_num_;    // for own implementation
+  size_t db_num_;
+  size_t num_dims_;
+  size_t bias_size_;               // for own implementation
+  std::vector<size_t> dy_shape_;   // for own implementation
+  std::vector<size_t> db_shape_;   // for own implementation
+  std::string data_format_ = kOpFormat_NCHW;
+  // for cudnn implementation
   cudnnHandle_t cudnn_handle_;
   cudnnDataType_t cudnn_data_type_;
+  cudnnTensorFormat_t cudnn_compute_format_;
   cudnnTensorDescriptor_t dy_desc_;
   cudnnTensorDescriptor_t db_desc_;
   cudnnReduceTensorDescriptor_t op_desc_;
@@ -60,6 +60,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
   {prim::kPrimRelu6->name(), {{0}, {0}}},
   {prim::kPrimRelu6Grad->name(), {{0, 1}, {0}}},
   {kSliceOpName, {{0}, {0}}},
+  {kSliceGradOpName, {{0, 1}, {0}}},
   {kTensorAddOpName, {{0, 1}, {0}}},
   {prim::kPrimConcat->name(), {{kAllPositions}, {0}}},
   {prim::kPrimAddN->name(), {{kAllPositions}, {0}}},
@ -108,8 +108,8 @@ python preprocess_dataset.py -d /data/save_data_path
|
||||||
|
|
||||||
## [Environment Requirements](#contents)
|
## [Environment Requirements](#contents)
|
||||||
|
|
||||||
- Hardware(Ascend)
|
- Hardware(Ascend/GPU)
|
||||||
- Prepare hardware environment with Ascend processor.
|
- Prepare hardware environment with Ascend or GPU processor.
|
||||||
- Framework
|
- Framework
|
||||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
- For more information, please check the resources below:
|
- For more information, please check the resources below:
|
||||||
|
@ -191,6 +191,23 @@ If you want to run in modelarts, please check the official documentation of [mod
|
||||||
# (7) Create your job.
|
# (7) Create your job.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- Run on GPU
|
||||||
|
|
||||||
|
```python
|
||||||
|
# run training example
|
||||||
|
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
|
||||||
|
OR
|
||||||
|
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
|
||||||
|
|
||||||
|
# run distributed training example
|
||||||
|
bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]
|
||||||
|
|
||||||
|
# run evaluation example
|
||||||
|
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/config/ > eval.log 2>&1 &
|
||||||
|
OR
|
||||||
|
bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
|
||||||
|
```
|
||||||
|
|
||||||
## [Script Description](#contents)
|
## [Script Description](#contents)
|
||||||
|
|
||||||
### [Script and Sample Code](#contents)
|
### [Script and Sample Code](#contents)
|
||||||
|
@ -207,6 +224,9 @@ If you want to run in modelarts, please check the official documentation of [mod
|
||||||
│ ├──run_infer_310.sh // shell script for infer on ascend 310
|
│ ├──run_infer_310.sh // shell script for infer on ascend 310
|
||||||
│ ├──run_standalone_train.sh // shell script for standalone on Ascend
|
│ ├──run_standalone_train.sh // shell script for standalone on Ascend
|
||||||
│ ├──run_standalone_eval.sh // shell script for evaluation on Ascend
|
│ ├──run_standalone_eval.sh // shell script for evaluation on Ascend
|
||||||
|
│ ├──run_standalone_train_gpu.sh // shell script for training on GPU
|
||||||
|
│ ├──run_standalone_eval_gpu.sh // shell script forevaluation on GPU
|
||||||
|
│ ├──run_distribute_train_gpu.sh // shell script for distributed on GPU
|
||||||
├── src
|
├── src
|
||||||
│ ├──config.py // parameter configuration
|
│ ├──config.py // parameter configuration
|
||||||
│ ├──data_loader.py // creating dataset
|
│ ├──data_loader.py // creating dataset
|
||||||
|
@ -261,6 +281,8 @@ Parameters for both training and evaluation can be set in config.py
|
||||||
'weight_decay': 0.0005, # weight decay value
|
'weight_decay': 0.0005, # weight decay value
|
||||||
'loss_scale': 1024.0, # loss scale
|
'loss_scale': 1024.0, # loss scale
|
||||||
'FixedLossScaleManager': 1024.0, # fix loss scale
|
'FixedLossScaleManager': 1024.0, # fix loss scale
|
||||||
|
'is_save_on_master': 1, # save checkpoint on master or all rank
|
||||||
|
'rank': 0, # local rank of distributed(default: 0)
|
||||||
'resume': False, # whether training with pretrain model
|
'resume': False, # whether training with pretrain model
|
||||||
'resume_ckpt': './', # pretrain model path
|
'resume_ckpt': './', # pretrain model path
|
||||||
'transfer_training': False # whether do transfer training
|
'transfer_training': False # whether do transfer training
|
||||||
|
@ -337,23 +359,43 @@ step: 600, loss is 0.22070312, fps is 56.99692546024671
|
||||||
|
|
||||||
The model checkpoint will be saved in the current directory.
|
The model checkpoint will be saved in the current directory.
|
||||||
|
|
||||||
#### Distributed Training
|
#### running on GPU
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] [CONFIG_PATH]
|
python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
|
||||||
|
OR
|
||||||
|
bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
|
||||||
```
|
```
|
||||||
|
|
||||||
The above shell script will run distribute training in the background. You can view the results through the file `logs/device[X]/log.log`. The loss value will be achieved as follows:
|
The python command above will run in the background, you can view the results through the file train.log. The model checkpoint will be saved in the current directory.
|
||||||
|
|
||||||
|
### Distributed Training
|
||||||
|
|
||||||
|
#### running on Ascend
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
# grep "loss is" logs/device0/log.log
|
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]
|
||||||
step: 1, loss is 0.70524895, fps is 0.15914689861221412
|
|
||||||
step: 2, loss is 0.6925452, fps is 56.43668656967454
|
|
||||||
...
|
|
||||||
step: 299, loss is 0.20551169, fps is 58.4039329983891
|
|
||||||
step: 300, loss is 0.18949677, fps is 57.63118508760329
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The above shell script will run distribute training in the background. You can view the results through the file `logs/device[X]/log.log`. The loss value will be achieved as follows:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# grep "loss is" logs/device0/log.log
|
||||||
|
step: 1, loss is 0.70524895, fps is 0.15914689861221412
|
||||||
|
step: 2, loss is 0.6925452, fps is 56.43668656967454
|
||||||
|
...
|
||||||
|
step: 299, loss is 0.20551169, fps is 58.4039329983891
|
||||||
|
step: 300, loss is 0.18949677, fps is 57.63118508760329
|
||||||
|
```
|
||||||
|
|
||||||
|
#### running on GPU
|
||||||
|
|
||||||
|
```shell
|
||||||
|
bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]
|
||||||
|
```
|
||||||
|
|
||||||
|
The above shell script will run distribute training in the background. You can view the results through the file `train.log`.
|
||||||
|
|
||||||
#### Evaluation while training
|
#### Evaluation while training
|
||||||
|
|
||||||
You can add `run_eval` to start shell and set it True, if you want evaluation while training. And you can set argument option: `save_best_ckpt`, `eval_start_epoch`, `eval_interval`, `eval_metrics` when `run_eval` is True.
|
You can add `run_eval` to start shell and set it True, if you want evaluation while training. And you can set argument option: `save_best_ckpt`, `eval_start_epoch`, `eval_interval`, `eval_metrics` when `run_eval` is True.
|
||||||
|
@ -379,37 +421,56 @@ The above python command will run in the background. You can view the results th
|
||||||
============== Cross valid dice coeff is: {'dice_coeff': 0.9111}
|
============== Cross valid dice coeff is: {'dice_coeff': 0.9111}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- evaluation on ISBI dataset when running on GPU
|
||||||
|
|
||||||
|
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet/ckpt_unet_medical_adam-2_400.ckpt".
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/config/ > eval.log 2>&1 &
|
||||||
|
OR
|
||||||
|
bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
|
||||||
|
```
|
||||||
|
|
||||||
|
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# grep "Cross valid dice coeff is:" eval.log
|
||||||
|
============== Cross valid dice coeff is: {'dice_coeff': 0.9089390969777261}
|
||||||
|
```
|
||||||
|
|
||||||
## [Model Description](#contents)
|
## [Model Description](#contents)
|
||||||
|
|
||||||
### [Performance](#contents)
|
### [Performance](#contents)
|
||||||
|
|
||||||
#### Evaluation Performance
|
#### Evaluation Performance
|
||||||
|
|
||||||
| Parameters | Ascend | GPU |
| -------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
| Model Version | Unet | Unet |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G |
| Uploaded Date | 09/15/2020 (month/day/year) | 01/20/2021 (month/day/year) |
| MindSpore Version | 1.2.0 | 1.1.0 |
| Dataset | ISBI | ISBI |
| Training Parameters | 1pc: epoch=400, total steps=600, batch_size=16, lr=0.0001 | 1pc: epoch=400, total steps=800, batch_size=12, lr=0.0001 |
| Optimizer | Adam | Adam |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| Outputs | probability | probability |
| Loss | 0.22070312 | 0.21425568 |
| Speed | 1pc: 267 ms/step | 1pc: 423 ms/step |
| Total time | 1pc: 2.67 mins | 1pc: 5.64 mins |
| Parameters (M) | 93M | 93M |
| Checkpoint for Fine tuning | 355.11M (.ckpt file) | 355.11M (.ckpt file) |
| Scripts | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |
## [How to use](#contents)

### Inference

If you need to use the trained model to perform inference on multiple hardware platforms, such as Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). A simple example of the required steps is shown below:

#### Running on Ascend 310

Export MindIR
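A minimal sketch of the export step follows; it uses the public MindSpore export API, while the real export script restores the trained UNet and reads the input size and checkpoint name from the YAML config, so the stand-in network and file names here are placeholders only:

```python
# Minimal MindIR export sketch (placeholder network; the real script restores the trained UNet).
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net

context.set_context(mode=context.GRAPH_MODE)

net = nn.Conv2d(1, 2, 3)                                          # stand-in for the UNet model
# param_dict = load_checkpoint("ckpt_unet_medical_adam-400.ckpt") # checkpoint name from the config
# load_param_into_net(net, param_dict)
dummy_input = Tensor(np.zeros([1, 1, 572, 572], np.float32))      # [N, C, H, W], image_size 572x572
export(net, dummy_input, file_name="unet", file_format="MINDIR")  # writes unet.mindir
```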
@@ -464,4 +525,4 @@ In data_loader.py, we set the seed inside "_get_val_train_indices" function. W

## [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -112,8 +112,8 @@ python preprocess_dataset.py -d /data/save_data_path

## Environment Requirements

- Hardware (Ascend/GPU)
    - Prepare a hardware environment with Ascend or GPU processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
@@ -198,6 +198,25 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]

# (7) Start model inference.
```
- Running on GPU

  ```shell
  # training example
  python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
  OR
  bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]

  # distributed training example
  bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]

  # evaluation example
  python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/config/ > eval.log 2>&1 &
  OR
  bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
  ```

# Script Description

## Script Description

### Script and Sample Code
@@ -214,6 +233,9 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]

│   ├──run_infer_310.sh              // inference script on Ascend 310
│   ├──run_standalone_train.sh       // single-device training script on Ascend
│   ├──run_standalone_eval.sh        // evaluation script on Ascend
│   ├──run_standalone_train_gpu.sh   // training script on GPU
│   ├──run_standalone_eval_gpu.sh    // evaluation script on GPU
│   ├──run_distribute_train_gpu.sh   // distributed training script on GPU
├── src
│   ├──config.py                     // parameter configuration
│   ├──data_loader.py                // data processing
@@ -268,6 +290,8 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR]

'weight_decay': 0.0005,              # weight decay value
'loss_scale': 1024.0,                # loss scale
'FixedLossScaleManager': 1024.0,     # fixed loss-scale value
'is_save_on_master': 1,              # save checkpoints on the master rank only, or on all ranks
'rank': 0,                           # local rank for distributed training (default: 0)
'resume': False,                     # whether to resume training from a pretrained model
'resume_ckpt': './',                 # path of the pretrained model checkpoint
```
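As an illustration of how `weight_decay`, `loss_scale` and `FixedLossScaleManager` typically work together in a MindSpore training script, here is a self-contained sketch with a stand-in network; it is not the repository's train.py:

```python
# Stand-alone sketch: wiring weight decay and a fixed loss scale into a MindSpore Model.
import mindspore.nn as nn
from mindspore import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

net = nn.Dense(4, 2)                                   # stand-in network; the real script builds UNet
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-4,
                    weight_decay=0.0005,               # 'weight_decay'
                    loss_scale=1024.0)                 # 'loss_scale'
manager = FixedLossScaleManager(1024.0, drop_overflow_update=False)  # 'FixedLossScaleManager'
model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=manager)
```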
@@ -332,9 +356,22 @@ python train.py --data_path=/path/to/data/ --config_path=/path/to/yaml > train.l

```

The model checkpoint will be saved in the current path.

- Running on GPU

  ```shell
  python train.py --data_path=/path/to/data/ --config_path=/path/to/config/ --output ./output > train.log 2>&1 &
  OR
  bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]
  ```

  The above python command runs in the background. You can view the results through the file `train.log`.

  After training, you can find the checkpoint files in the default script folder.
### Distributed Training

- Running on Ascend

  ```shell
  bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]
  ```
@@ -350,6 +387,14 @@ step: 299, loss is 0.20551169, fps is 58.4039329983891

step: 300, loss is 0.18949677, fps is 57.63118508760329
```

- Running on GPU

  ```shell
  bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]
  ```

  The above shell script runs distributed training in the background. You can view the results through the file `train.log`.
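For reference, the data-parallel setup that `--run_distribute=True` enables in `train.py` boils down to the following pattern, condensed from the training script; it must be launched by `mpirun` as in `run_distribute_train_gpu.sh`:

```python
# Condensed sketch of the GPU data-parallel initialization used when run_distribute is True.
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
init()                         # initializes NCCL for the processes started by mpirun
group_size = get_group_size()  # equals [RANKSIZE]
rank = get_rank()              # this process's rank; used for per-rank checkpoint directories
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  device_num=group_size)
```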
#### Evaluation While Training

To evaluate while training, add `run_eval` to the startup file and set it to True. At the same time, set `save_best_ckpt`, `eval_start_epoch`, `eval_interval`, and `eval_metrics`.
@@ -375,29 +420,46 @@ python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkp

============== Cross valid dice coeff is: {'dice_coeff': 0.9111}
```

- evaluation on ISBI dataset when running on GPU

  Before running the command below, please check the checkpoint path used for evaluation. Set the checkpoint path to the absolute full path, e.g., "username/unet/ckpt_unet_medical_adam-2_400.ckpt".

  ```shell
  python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ --config_path=/path/to/config/ > eval.log 2>&1 &
  OR
  bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]
  ```

  The above python command runs in the background. You can view the results through the file "eval.log". The accuracy of the test dataset is as follows:

  ```shell
  # grep "Cross valid dice coeff is:" eval.log
  ============== Cross valid dice coeff is: {'dice_coeff': 0.9089390969777261}
  ```

# Model Description

## Performance

### Evaluation Performance

| Parameters | Ascend | GPU |
| -------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
| Model Version | U-Net | U-Net |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB; OS Euler2.8 | NV SMX2 V100, memory 32 GB |
| Uploaded Date | 2020-9-15 | 2020-12-29 |
| MindSpore Version | 1.2.0 | 1.1.0 |
| Dataset | ISBI | ISBI |
| Training Parameters | 1pc: epoch=400, total steps=600, batch_size=16, lr=0.0001 | 1pc: epoch=400, total steps=800, batch_size=12, lr=0.0001 |
| Optimizer | Adam | Adam |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| Outputs | probability | probability |
| Loss | 0.22070312 | 0.21425568 |
| Speed | 1pc: 267 ms/step; 8pc: 280 ms/step | 1pc: 423 ms/step; 8pc: 128 ms/step |
| Total time | 1pc: 2.67 mins; 8pc: 1.40 mins | 1pc: 5.64 mins; 8pc: 3.41 mins |
| Parameters (M) | 93M | 93M |
| Checkpoint for Fine tuning | 355.11M (.ckpt file) | 355.11M (.ckpt file) |
| Scripts | [U-Net script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) | [U-Net script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) |

### Usage
@@ -1,4 +1,4 @@

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -25,9 +25,6 @@ from src.utils import UnetEval, TempLoss, dice_coeff

from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper


@moxing_wrapper()
def test_net(data_dir,
             ckpt_path,

@@ -63,6 +60,10 @@ def test_net(data_dir,


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
    if config.device_target == "Ascend":
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)
    test_net(data_dir=config.data_path,
             ckpt_path=config.checkpoint_file_path,
             cross_valid_ind=config.cross_valid_ind)
@@ -0,0 +1,22 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_distribute_train_gpu.sh [RANKSIZE] [DATASET] [CONFIG_PATH]"
echo "for example: bash run_distribute_train_gpu.sh 8 /path/to/data/ /path/to/config/"
echo "=============================================================================================================="
mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python train.py --run_distribute=True --data_path=$2 --config_path=$3 --output=./output > train.log 2>&1 &
@@ -0,0 +1,22 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_eval_gpu.sh [DATASET] [CHECKPOINT] [CONFIG_PATH]"
echo "for example: bash run_standalone_eval_gpu.sh /path/to/data/ /path/to/checkpoint/ /path/to/config/"
echo "=============================================================================================================="
python eval.py --data_path=$1 --checkpoint_file_path=$2 --config_path=$3 > eval.log 2>&1 &
@@ -0,0 +1,22 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_train_gpu.sh [DATASET] [CONFIG_PATH]"
echo "for example: bash scripts/run_standalone_train_gpu.sh /path/to/data/ /path/to/config/"
echo "=============================================================================================================="
python train.py --data_path=$1 --config_path=$2 --output ./output > train.log 2>&1 &
@@ -12,14 +12,13 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import logging

import mindspore
import mindspore.nn as nn
from mindspore import Model, context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net

@@ -33,10 +32,6 @@ from src.eval_callback import EvalCallBack

from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper


mindspore.set_seed(1)

@@ -51,8 +46,8 @@ def train_net(cross_valid_ind=1,

    run_distribute = config.run_distribute
    if run_distribute:
        init()
        group_size = get_group_size()
        rank = get_rank()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=group_size,

@@ -94,16 +89,21 @@ def train_net(cross_valid_ind=1,

    else:
        repeat = config.repeat
        dataset_sink_mode = False
        if config.device_target == "GPU":
            dataset_sink_mode = True
        per_print_times = 1
    train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind,
                                                  run_distribute, config.crop, config.image_size)
    train_data_size = train_dataset.get_dataset_size()
    print("dataset length is:", train_data_size)
    ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
    save_ck_steps = train_data_size
    if config.device_target == "GPU":
        save_ck_steps = train_data_size * epochs
    ckpt_config = CheckpointConfig(save_checkpoint_steps=save_ck_steps,
                                   keep_checkpoint_max=config.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(config.model_name),
                                 directory=ckpt_save_dir+'./ckpt_{}/'.format(rank),
                                 config=ckpt_config)

    optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=config.weight_decay,

@@ -121,7 +121,7 @@ def train_net(cross_valid_ind=1,

        eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": config.eval_metrics}
        eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
                               eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
                               ckpt_directory=ckpt_save_dir+'./ckpt_{}/'.format(rank), besk_ckpt_name="best.ckpt",
                               metrics_name=config.eval_metrics)
        callbacks.append(eval_cb)
    model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)

@@ -130,9 +130,15 @@ def train_net(cross_valid_ind=1,


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
    if config.device_target == "Ascend":
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)
    epoch_size = config.epochs if not config.run_distribute else config.distribute_epochs
    batchsize = config.batch_size
    if config.device_target == 'GPU' and config.run_distribute:
        batchsize = config.distribute_batchsize
    train_net(cross_valid_ind=config.cross_valid_ind,
              epochs=epoch_size,
              batch_size=batchsize,
              lr=config.lr)
@@ -0,0 +1,69 @@

# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'GPU'
enable_profiling: False

# ==============================================================================
# Training options
model_name: 'unet_medical'
run_eval: False
run_distribute: False
crop: [388, 388]
image_size: [572, 572]
lr: 0.0001
epochs: 400
repeat: 1
distribute_epochs: 1600
batch_size: 12
distribute_batchsize: 3
cross_valid_ind: 1
num_classes: 2
num_channels: 1
weight_decay: 0.0005
loss_scale: 1024.0
FixedLossScaleManager: 1024.0
resume: False
resume_ckpt: './'
transfer_training: False
filter_weight: ['outc.weight', 'outc.bias']

# Eval options
keep_checkpoint_max: 10
eval_activate: 'Softmax'
eval_resize: False
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'ckpt_unet_medical_adam-400.ckpt'
rst_path: './result_Files/'

# Export options
width: 572
height: 572
file_name: unet
file_format: AIR

---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for obs'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
distribute_batchsize: "Batch size for distribute training"
weight_decay: "Weight decay."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."
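This file consists of two YAML documents: the first holds the default option values and the second holds the help text for each option; the repository exposes it as `config` through `src.model_utils.config`. A small, self-contained sketch of how such a file can be read is shown below; it is only an illustration, and the file name is assumed:

```python
# Illustrative loader for the two-document config file above (not the repository's config.py).
import argparse
import yaml

with open("unet_medical_gpu_config.yaml") as f:   # assumed file name
    docs = list(yaml.safe_load_all(f))
defaults = docs[0]                                # default option values
helps = docs[1] if len(docs) > 1 else {}          # per-option help strings

config = argparse.Namespace(**defaults)           # attribute access, e.g. config.batch_size
print(config.device_target, config.lr, helps.get("batch_size"))
```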