forked from mindspore-Ecosystem/mindspore
!18472 Implement UNet3d on GPU
Merge pull request !18472 from likesen/master
commit 8e043090be
@@ -42,9 +42,9 @@ class SliceGpuFwdKernel : public GpuKernel {
     }
     T *input = GetDeviceAddress<T>(inputs, 0);
     T *output = GetDeviceAddress<T>(outputs, 0);
-    Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0],
-                  input_shape_[1], input_shape_[2], input_shape_[3], input, output,
-                  reinterpret_cast<cudaStream_t>(stream_ptr));
+    Slice5DKernel(begin_[0], begin_[1], begin_[2], begin_[3], begin_[4], size_[0], size_[1], size_[2], size_[3],
+                  size_[4], input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], input_shape_[4], input,
+                  output, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }

   bool Init(const CNodePtr &kernel_node) override {
@@ -53,28 +53,35 @@ class SliceGpuFwdKernel : public GpuKernel {
     }
     auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0);
     auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-    ShapeNdTo4d(input_shape, &input_shape_);
+    ShapeNdTo5d(input_shape, &input_shape_);

-    for (auto i = begin_.size(); i < 4; i++) {
+    for (auto i = begin_.size(); i < 5; i++) {
       (void)begin_.insert(begin_.begin(), 0);
     }
-    for (size_t i = size_.size(); i < 4; i++) {
+    for (size_t i = size_.size(); i < 5; i++) {
       (void)size_.insert(size_.begin(), 1);
     }

-    input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T);
+    input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * input_shape_[4] * sizeof(T);
     auto out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);

     output_size_ = sizeof(T);
     for (size_t x : out_shape) {
       output_size_ = output_size_ * x;
     }
-    // transpose begin and size for NHWC data
+    // transpose begin and size for NHWC and NDHWC data
     if (data_format == "NHWC") {
       std::swap(begin_[1], begin_[3]);
       std::swap(begin_[1], begin_[2]);
       std::swap(size_[1], size_[3]);
       std::swap(size_[1], size_[2]);
+    } else if (data_format == "NDHWC") {
+      std::swap(begin_[1], begin_[4]);
+      std::swap(begin_[1], begin_[3]);
+      std::swap(begin_[1], begin_[2]);
+      std::swap(size_[1], size_[4]);
+      std::swap(size_[1], size_[3]);
+      std::swap(size_[1], size_[2]);
     }
     InitSizeLists();
     return true;
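As a cross-check of the `Init` logic above, here is a small Python sketch (illustrative only, not MindSpore code) of what the padding and the `std::swap` chain compute, assuming `begin_`/`size_` arrive in canonical NCDHW order and must be rotated to match an NDHWC device layout:

```python
def pad_and_transpose(begin, size, data_format):
    """Mirror of the padding/transpose logic in SliceGpuFwdKernel::Init (sketch)."""
    begin = [0] * (5 - len(begin)) + list(begin)  # pad begin_ with leading zeros
    size = [1] * (5 - len(size)) + list(size)     # pad size_ with leading ones
    if data_format == "NDHWC":
        # swap(1,4), swap(1,3), swap(1,2) rotates [N, C, D, H, W] -> [N, D, H, W, C]
        for j in (4, 3, 2):
            begin[1], begin[j] = begin[j], begin[1]
            size[1], size[j] = size[j], size[1]
    return begin, size

# the channel entry (index 1) ends up last, matching the NDHWC tensor layout
print(pad_and_transpose([0, 4, 1, 2, 3], [2, 5, 2, 2, 2], "NDHWC"))
```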
@@ -87,6 +94,18 @@ class SliceGpuFwdKernel : public GpuKernel {
   }

  private:
+  // expand Nd Shape to 5d (N in [0,5])
+  void ShapeNdTo5d(const std::vector<size_t> &src, std::vector<size_t> *dst) {
+    if (src.size() > 5) {
+      MS_EXCEPTION(ValueError) << src.size() << "-D data is not supported!";
+    }
+    dst->push_back(src.size() < 5 ? 1 : src[src.size() - 5]);
+    dst->push_back(src.size() < 4 ? 1 : src[src.size() - 4]);
+    dst->push_back(src.size() < 3 ? 1 : src[src.size() - 3]);
+    dst->push_back(src.size() < 2 ? 1 : src[src.size() - 2]);
+    dst->push_back(src.size() == 0 ? 1 : src[src.size() - 1]);
+  }
+
   bool CheckParam(const CNodePtr &kernel_node) {
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num != 1) {
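The `ShapeNdTo5d` helper above simply left-pads a shape with 1s up to rank 5. An equivalent Python sketch, for illustration:

```python
def shape_nd_to_5d(src):
    """Left-pad a shape with 1s to rank 5; rejects ranks above 5."""
    if len(src) > 5:
        raise ValueError(f"{len(src)}-D data is not supported!")
    return [1] * (5 - len(src)) + list(src)

assert shape_nd_to_5d([24, 224, 224]) == [1, 1, 24, 224, 224]
assert shape_nd_to_5d([]) == [1, 1, 1, 1, 1]  # a scalar input also becomes rank 5
```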
@@ -99,8 +118,8 @@ class SliceGpuFwdKernel : public GpuKernel {
       return false;
     }
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    if (input_shape.size() > 4) {
-      MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGpuFwdKernel only supports 4d or lower.";
+    if (input_shape.size() > 5) {
+      MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGpuFwdKernel only supports 5d or lower.";
       return false;
     }
     if (input_shape.size() == 0) {
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,6 +30,26 @@ void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) {
   return;
 }

+template <typename T>
+__global__ void PReluChannelSharedGradKernel(size_t size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr) {
+  T zero = static_cast<T>(0);
+  T w = w_addr[0];
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    T dy = dy_addr[pos];
+    T x = x_addr[pos];
+    dx_addr[pos] = x > zero ? dy : w * dy;
+    dwc_addr[pos] = x > zero ? zero : x * dy;
+  }
+}
+
+template <typename T>
+void PReluChannelSharedGrad(size_t input_size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr,
+                            cudaStream_t cuda_stream) {
+  PReluChannelSharedGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, dy_addr, x_addr,
+                                                                                        w_addr, dx_addr, dwc_addr);
+  return;
+}
+
 template void CalReLUGrad(int size, double *dy, double *y, double *dx, cudaStream_t cuda_stream);
 template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream);
 template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream);
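The kernel computes `dx` elementwise and only stages the weight gradient: each thread writes its per-element contribution into `dwc_addr`, and the sum is taken later (by `cudnnReduceTensor` in the kernel class added below). A NumPy sketch of the same math, illustrative only:

```python
import numpy as np

def prelu_channel_shared_grad(dy, x, w):
    """PReLU backward with a single weight shared across all channels."""
    dx = np.where(x > 0, dy, w * dy)    # dL/dx: identity on the positive side
    dwc = np.where(x > 0, 0.0, x * dy)  # per-element dL/dw contributions
    return dx, dwc                      # the CUDA kernel leaves the sum to a reduction

x = np.arange(-5, 19, dtype=np.float32).reshape(2, 2, 2, 3)
dx, dwc = prelu_channel_shared_grad(np.ones_like(x), x, np.float32(-0.5))
print(dwc.sum())  # -15.0, matching the expected dw in the test added below
```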
@@ -38,3 +58,7 @@ template void CalReLUGrad(int size, int16_t *dy, int16_t *y, int16_t *dx, cudaSt
 template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
 template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);
 template void CalReLUGrad(int size, uint8_t *dy, uint8_t *y, uint8_t *dx, cudaStream_t cuda_stream);
+template void PReluChannelSharedGrad(size_t input_size, float *dy_addr, float *x_addr, float *w_addr, float *dx_addr,
+                                     float *dwc_addr, cudaStream_t cuda_stream);
+template void PReluChannelSharedGrad(size_t input_size, half *dy_addr, half *x_addr, half *w_addr, half *dx_addr,
+                                     half *dwc_addr, cudaStream_t cuda_stream);
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,4 +20,8 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
 void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream);
+
+template <typename T>
+void PReluChannelSharedGrad(size_t input_size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr,
+                            cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
@@ -35,6 +35,25 @@ __global__ void Slice4D(const size_t s1, const size_t s2, const size_t s3, const
   }
 }

+template <typename T>
+__global__ void Slice5D(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5,
+                        const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5,
+                        const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                        const T *input, T *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4 * l5);
+       pos += blockDim.x * gridDim.x) {
+    size_t i = pos / (l2 * l3 * l4 * l5) % l1;
+    size_t j = pos / (l3 * l4 * l5) % l2;
+    size_t k = pos / (l4 * l5) % l3;
+    size_t o = pos / l5 % l4;
+    size_t q = pos % l5;
+
+    size_t offset =
+      (i + s1) * (d2 * d3 * d4 * d5) + (j + s2) * (d3 * d4 * d5) + (k + s3) * (d4 * d5) + (o + s4) * d5 + (q + s5);
+    output[pos] = input[offset];
+  }
+}
+
 template <typename T>
 __global__ void Slice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
                             const size_t l1, const size_t l2, const size_t l3, const size_t l4,
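`Slice5D` flattens the output volume into one index `pos` and recovers the five coordinates with div/mod; an illustrative NumPy cross-check of that index arithmetic:

```python
import numpy as np

def slice5d_reference(x, begin, size):
    """Reproduce Slice5D's pos -> (i, j, k, o, q) -> input offset arithmetic."""
    s1, s2, s3, s4, s5 = begin
    l1, l2, l3, l4, l5 = size
    d1, d2, d3, d4, d5 = x.shape
    flat_in = x.reshape(-1)
    out = np.empty(l1 * l2 * l3 * l4 * l5, dtype=x.dtype)
    for pos in range(out.size):
        i = pos // (l2 * l3 * l4 * l5) % l1
        j = pos // (l3 * l4 * l5) % l2
        k = pos // (l4 * l5) % l3
        o = pos // l5 % l4
        q = pos % l5
        offset = ((i + s1) * (d2 * d3 * d4 * d5) + (j + s2) * (d3 * d4 * d5) +
                  (k + s3) * (d4 * d5) + (o + s4) * d5 + (q + s5))
        out[pos] = flat_in[offset]
    return out.reshape(size)

x = np.random.randn(2, 3, 4, 5, 6).astype(np.float32)
assert (slice5d_reference(x, (0, 1, 1, 2, 3), (2, 2, 2, 2, 2))
        == x[0:2, 1:3, 1:3, 2:4, 3:5]).all()
```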
@@ -70,7 +89,13 @@ void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size
   Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4,
                                                                      input, output);
 }

+template <typename T>
+void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, const size_t l1,
+                   const size_t l2, const size_t l3, const size_t l4, const size_t l5, const size_t d1, const size_t d2,
+                   const size_t d3, const size_t d4, const size_t d5, const T *input, T *output, cudaStream_t stream) {
+  Slice5D<<<GET_BLOCKS(l1 * l2 * l3 * l4 * l5), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, s5, l1, l2, l3, l4, l5, d1,
+                                                                          d2, d3, d4, d5, input, output);
+}
 template <typename T>
 void CalSlice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                     const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
@@ -184,6 +209,39 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream);

+template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5,
+                            const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5,
+                            const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                            const double *input, double *output, cudaStream_t stream);
+template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5,
+                            const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5,
+                            const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                            const float *input, float *output, cudaStream_t stream);
+template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5,
+                            const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5,
+                            const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                            const half *input, half *output, cudaStream_t stream);
+template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5,
+                            const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5,
+                            const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                            const int64_t *input, int64_t *output, cudaStream_t stream);
+template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5,
+                            const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5,
+                            const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                            const int *input, int *output, cudaStream_t stream);
+template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5,
+                            const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5,
+                            const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                            const short *input, short *output, cudaStream_t stream);  // NOLINT
+template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5,
+                            const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5,
+                            const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                            const unsigned char *input, unsigned char *output, cudaStream_t stream);
+template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5,
+                            const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5,
+                            const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                            const bool *input, bool *output, cudaStream_t stream);
+
 template void CalSlice4DGrad<double>(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
                                      const size_t l1, const size_t l2, const size_t l3, const size_t l4,
                                      const size_t d1, const size_t d2, const size_t d3, const size_t d4,
@@ -26,6 +26,10 @@ void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size
                    const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4,
                    const T *input, T *output, cudaStream_t stream);
+template <typename T>
+void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, const size_t l1,
+                   const size_t l2, const size_t l3, const size_t l4, const size_t l5, const size_t d1, const size_t d2,
+                   const size_t d3, const size_t d4, const size_t d5, const T *input, T *output, cudaStream_t stream);
 template <typename T>
 void CalSlice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                     const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                     const size_t d3, const size_t d4, const T *dy, T *dx, cudaStream_t stream);
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(PReLUGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      PReLUGpuGradKernel, float)
+MS_REG_GPU_KERNEL_ONE(PReLUGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      PReLUGpuGradKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
@@ -0,0 +1,196 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_
+
+#include <vector>
+#include <string>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class PReLUGpuGradKernel : public GpuKernel {
+ public:
+  PReLUGpuGradKernel()
+      : data_format_(kOpFormat_NCDHW),
+        input_size_(0),
+        weight_size_(0),
+        reduce_workspace_size_(0),
+        spatial_count_(1),
+        is_null_input_(false),
+        channel_shared_(false),
+        channel_last_(false) {}
+  ~PReLUGpuGradKernel() override { DestroyResource(); }
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *dy_addr = GetDeviceAddress<T>(inputs, 0);
+    T *x_addr = GetDeviceAddress<T>(inputs, 1);
+    T *w_addr = GetDeviceAddress<T>(inputs, 2);
+    T *dx_addr = GetDeviceAddress<T>(outputs, 0);
+    T *dw_addr = GetDeviceAddress<T>(outputs, 1);
+    T *dw_collector_addr = GetDeviceAddress<T>(workspace, 0);
+    T *reduce_workspace_addr = GetDeviceAddress<T>(workspace, 1);
+
+    PReluChannelSharedGrad(input_size_ / sizeof(T), dy_addr, x_addr, w_addr, dx_addr, dw_collector_addr,
+                           reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    if (data_type_ == CUDNN_DATA_DOUBLE) {
+      T alpha = static_cast<T>(1.0f);
+      T beta = static_cast<T>(0.0f);
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
+                          reduce_workspace_size_, &alpha, grad_weight_collector_descriptor_, dw_collector_addr, &beta,
+                          grad_weight_descriptor_, dw_addr),
+        "cudnnReduceTensor failed.");
+    } else {
+      const float alphaf = static_cast<float>(1.0f);
+      const float betaf = static_cast<float>(0.0f);
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
+                          reduce_workspace_size_, &alphaf, grad_weight_collector_descriptor_, dw_collector_addr, &betaf,
+                          grad_weight_descriptor_, dw_addr),
+        "cudnnReduceTensor failed.");
+    }
+    return true;
+  }
+
+  void InitResource() override {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_),
+                                "cudnnCreateReduceTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&grad_weight_collector_descriptor_),
+                                "cudnnCreateTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&grad_weight_descriptor_),
+                                "cudnnCreateTensorDescriptor failed.");
+  }
+
+  void DestroyResource() noexcept override {
+    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_),
+                               "cudnnDestroyReduceTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(grad_weight_collector_descriptor_),
+                               "cudnnDestroyTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(grad_weight_descriptor_),
+                               "cudnnDestroyTensorDescriptor failed.");
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
+    input_size_ = sizeof(T);
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "PReLUGpuBwdKernel input is null.";
+    }
+    for (size_t i = 0; i < input_shape.size(); ++i) {
+      input_size_ *= input_shape[i];
+    }
+    weight_size_ = sizeof(T);
+    auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
+    is_null_input_ = CHECK_NULL_INPUT(weight_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "PReLUGpuBwdKernel input is null.";
+    }
+    for (auto dim : weight_shape) {
+      weight_size_ *= dim;
+    }
+    channel_shared_ = (weight_shape[0] == 1);
+    if (!channel_shared_) {
+      MS_LOG(WARNING)
+        << "PReLUGpuBwdKernel shares weight for all channels, but the given weight tensor has more than one element.";
+    }
+
+    spatial_count_ = 1;
+    if (channel_last_) {
+      for (size_t i = 1; i < input_shape.size() - 1; ++i) {
+        spatial_count_ *= input_shape[i];
+      }
+    } else {
+      for (size_t i = 2; i < input_shape.size(); ++i) {
+        spatial_count_ *= input_shape[i];
+      }
+    }
+
+    data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    int input_dim_length = input_shape.size();
+    std::vector<size_t> reduce_out_shape(input_dim_length, 1);
+    if (channel_last_) {
+      reduce_out_shape[input_dim_length - 1] = weight_shape[0];
+    } else {
+      reduce_out_shape[1] = weight_shape[0];
+    }
+    InitResource();
+    CudnnSetTensorNdDescriptor(reduce_out_shape, grad_weight_descriptor_, data_type_, kernel_node_);
+    CudnnSetTensorNdDescriptor(input_shape, grad_weight_collector_descriptor_, data_type_, kernel_node_);
+    cudnnDataType_t comp_type = (data_type_ == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT;
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, CUDNN_REDUCE_TENSOR_ADD, comp_type,
+                                     CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES),
+      "cudnnSetReduceTensorDescriptor failed");
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(weight_size_);
+    output_size_list_.push_back(input_size_);
+    output_size_list_.push_back(weight_size_);
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_, grad_weight_collector_descriptor_,
+                                     grad_weight_descriptor_, &reduce_workspace_size_),
+      "cudnnGetReductionWorkspaceSize failed.");
+    workspace_size_list_.push_back(input_size_);
+    workspace_size_list_.push_back(reduce_workspace_size_);
+  }
+
+ private:
+  cudnnHandle_t cudnn_handle_;
+  cudnnDataType_t data_type_;
+  cudnnReduceTensorDescriptor_t reduce_tensor_descriptor_;
+  cudnnTensorDescriptor_t grad_weight_collector_descriptor_;
+  cudnnTensorDescriptor_t grad_weight_descriptor_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  std::string data_format_ = kOpFormat_NCDHW;
+  size_t input_size_;
+  size_t weight_size_;
+  size_t reduce_workspace_size_;
+  size_t spatial_count_;
+  bool is_null_input_ = false;
+  bool channel_shared_ = false;
+  bool channel_last_ = false;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_
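`PReLUGpuGradKernel` computes the weight gradient in two stages: the CUDA kernel fills an input-sized `dw_collector` workspace with per-element contributions, and `cudnnReduceTensor` with `CUDNN_REDUCE_TENSOR_ADD` collapses that buffer down to the weight shape. A NumPy sketch of the data flow (illustrative; the shapes here are made up):

```python
import numpy as np

# stage 1: the elementwise kernel fills the collector (same shape as the input)
x = np.random.randn(2, 3, 4, 4).astype(np.float32)
dy = np.ones_like(x)
w = np.float32(-0.5)  # channel-shared weight, shape (1,)
dw_collector = np.where(x > 0, 0.0, x * dy).astype(np.float32)

# stage 2: cudnnReduceTensor(ADD) sums every axis down to the weight shape (1,)
dw = dw_collector.sum().reshape(1)
print(dw)
```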
@@ -11,10 +11,15 @@
     - [Script Parameters](#script-parameters)
     - [Training Process](#training-process)
         - [Training](#training)
-            - [running on Ascend](#running-on-ascend)
-        - [Distributed Training](#distributed-training)
+            - [Training on Ascend](#training-on-ascend)
+            - [Training on GPU](#training-on-gpu)
+        - [Distributed Training](#distributed-training)
+            - [Distributed training on Ascend](#distributed-training-on-ascend)
+            - [Distributed training on GPU](#distributed-training-on-gpu)
     - [Evaluation Process](#evaluation-process)
         - [Evaluation](#evaluation)
+            - [Evaluating on Ascend](#evaluating-on-ascend)
+            - [Evaluating on GPU](#evaluating-on-gpu)
 - [Model Description](#model-description)
     - [Performance](#performance)
         - [Evaluation Performance](#evaluation-performance)
@@ -36,16 +41,29 @@ Dataset used: [LUNA16](https://luna16.grand-challenge.org/)

 - Description: The dataset is used to automatically detect the location of nodules in volumetric CT images. 888 CT scans from the LIDC-IDRI database are provided. The complete dataset is divided into 10 subsets that should be used for the 10-fold cross-validation. All subsets are available as compressed zip files.

-- Dataset size: 888
-    - Train: 878 images
-    - Test: 10 images
+- Dataset size: 887
+    - Train: 877 images
+    - Test: 10 images (the last 10 images of subset9 in lexicographical order)
 - Data format: zip
-    - Note: Data will be processed in convert_nifti.py
+    - Note: Data will be processed in convert_nifti.py, and one of the 888 scans is ignored during processing.
+- Data Content Structure
+
+  ```text
+  .
+  └─LUNA16
+    ├── train
+    │   ├── image          // contains 877 image files
+    │   ├── seg            // contains 877 seg files
+    ├── val
+    │   ├── image          // contains 10 image files
+    │   ├── seg            // contains 10 seg files
+  ```

 ## [Environment Requirements](#contents)

-- Hardware (Ascend)
-    - Prepare a hardware environment with an Ascend processor.
+- Hardware (Ascend or GPU)
+    - Prepare a hardware environment with Ascend or GPU.
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please check the resources below:
@@ -79,6 +97,25 @@ bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH]

 # run evaluation example
 python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ > eval.log 2>&1 &
 ```

+- Run on GPU
+
+```shell
+# enter scripts directory
+cd scripts
+# run training example (fp32)
+bash ./run_standalone_train_gpu_fp32.sh [TRAINING_DATA_PATH]
+# run training example (fp16)
+bash ./run_standalone_train_gpu_fp16.sh [TRAINING_DATA_PATH]
+# run distributed training example (fp32)
+bash ./run_distribute_train_gpu_fp32.sh [TRAINING_DATA_PATH]
+# run distributed training example (fp16)
+bash ./run_distribute_train_gpu_fp16.sh [TRAINING_DATA_PATH]
+# run evaluation example (fp32)
+bash ./run_standalone_eval_gpu_fp32.sh [VALIDATING_DATA_PATH] [CHECKPOINT_FILE_PATH]
+# run evaluation example (fp16)
+bash ./run_standalone_eval_gpu_fp16.sh [VALIDATING_DATA_PATH] [CHECKPOINT_FILE_PATH]
+```
@@ -123,9 +160,15 @@ If you want to run in modelarts, please check the official documentation of [mod
 └─unet3d
     ├── README.md                                // descriptions about Unet3D
     ├── scripts
-    │   ├──run_disribute_train.sh                // shell script for distributed on Ascend
+    │   ├──run_distribute_train.sh               // shell script for distributed on Ascend
     │   ├──run_standalone_train.sh               // shell script for standalone on Ascend
     │   ├──run_standalone_eval.sh                // shell script for evaluation on Ascend
+    │   ├──run_distribute_train_gpu_fp32.sh      // shell script for distributed on GPU fp32
+    │   ├──run_distribute_train_gpu_fp16.sh      // shell script for distributed on GPU fp16
+    │   ├──run_standalone_train_gpu_fp32.sh      // shell script for standalone on GPU fp32
+    │   ├──run_standalone_train_gpu_fp16.sh      // shell script for standalone on GPU fp16
+    │   ├──run_standalone_eval_gpu_fp32.sh       // shell script for evaluation on GPU fp32
+    │   ├──run_standalone_eval_gpu_fp16.sh       // shell script for evaluation on GPU fp16
    ├── src
     │   ├──dataset.py                            // creating dataset
     │   ├──lr_schedule.py                        // learning rate scheduler
@@ -177,7 +220,23 @@ Parameters for both training and evaluation can be set in config.py

 ### Training

-#### running on Ascend
+#### Training on GPU
+
+```shell
+# enter scripts directory
+cd scripts
+# fp32
+bash ./run_standalone_train_gpu_fp32.sh /path_prefix/LUNA16/train
+# fp16
+bash ./run_standalone_train_gpu_fp16.sh /path_prefix/LUNA16/train
+```
+
+The script above runs in the background; you can view the results in the file `train.log`.
+
+After training, you'll get some checkpoint files under the train_fp[32|16]/output/ckpt_0/ folder by default.
+
+#### Training on Ascend

 ```shell
 python train.py --data_path=/path/to/data/ > train.log 2>&1 &
@@ -201,7 +260,25 @@ epoch time: 1180467.795 ms, per step time: 1380.664 ms

 ```

-#### Distributed Training
+### Distributed Training
+
+#### Distributed training on GPU (8P)
+
+```shell
+# enter scripts directory
+cd scripts
+# fp32
+bash ./run_distribute_train_gpu_fp32.sh /path_prefix/LUNA16/train
+# fp16
+bash ./run_distribute_train_gpu_fp16.sh /path_prefix/LUNA16/train
+```
+
+The shell script above runs distributed training in the background; you can view the results in the file `train_parallel_fp[32|16]/train.log`.
+
+After training, you'll get some checkpoint files under the `train_parallel_fp[32|16]/output/ckpt_[X]/` folder by default.
+
+#### Distributed training on Ascend

 > Notes:
 > RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as described in [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the hccl connection-check timeout from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compile time increases with model size.
@@ -235,6 +312,24 @@ epoch time: 140476.520 ms, per step time: 1312.865 ms

 ### Evaluation

+#### Evaluating on GPU
+
+```shell
+# enter scripts directory
+cd scripts
+# fp32, 1gpu
+bash ./run_standalone_eval_gpu_fp32.sh /path_prefix/LUNA16/val /path_prefix/train_fp32/output/ckpt_0/Unet3d-10_877.ckpt
+# fp16, 1gpu
+bash ./run_standalone_eval_gpu_fp16.sh /path_prefix/LUNA16/val /path_prefix/train_fp16/output/ckpt_0/Unet3d-10_877.ckpt
+# fp32, 8gpu
+bash ./run_standalone_eval_gpu_fp32.sh /path_prefix/LUNA16/val /path_prefix/train_parallel_fp32/output/ckpt_0/Unet3d-10_110.ckpt
+# fp16, 8gpu
+bash ./run_standalone_eval_gpu_fp16.sh /path_prefix/LUNA16/val /path_prefix/train_parallel_fp16/output/ckpt_0/Unet3d-10_110.ckpt
+```
+
+#### Evaluating on Ascend
+
 - evaluation on dataset when running on Ascend

 Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet3d/Unet3d-10_110.ckpt".
|
|||
|
||||
#### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------------------------------------- |
|
||||
| Model Version | Unet3D |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 |
|
||||
| uploaded Date | 03/18/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | LUNA16 |
|
||||
| Training Parameters | epoch = 10, batch_size = 1 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | SoftmaxCrossEntropyWithLogits |
|
||||
| Speed | 8pcs: 1795ms/step |
|
||||
| Total time | 8pcs: 0.62hours |
|
||||
| Parameters (M) | 34 |
|
||||
| Parameters | Ascend | GPU |
|
||||
| ------------------- | --------------------------------------------------------- | ---------------------------------------------------- |
|
||||
| Model Version | Unet3D | Unet3D |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | Nvidia V100 SXM2; CPU 1.526GHz; 72cores; Memory 42G; OS Ubuntu16|
|
||||
| uploaded Date | 03/18/2021 (month/day/year) | 05/21/2021(month/day/year) |
|
||||
| MindSpore Version | 1.2.0 | 1.2.0 |
|
||||
| Dataset | LUNA16 | LUNA16 |
|
||||
| Training Parameters | epoch = 10, batch_size = 1 | epoch = 10, batch_size = 1 |
|
||||
| Optimizer | Adam | Adam |
|
||||
| Loss Function | SoftmaxCrossEntropyWithLogits | SoftmaxCrossEntropyWithLogits |
|
||||
| Speed | 8pcs: 1795ms/step | 8pcs: 1883ms/step |
|
||||
| Total time | 8pcs: 0.62hours | 8pcs: 0.66hours |
|
||||
| Parameters (M) | 34 | 34 |
|
||||
| Scripts | [unet3d script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet3d) |
|
||||
|
||||
#### Inference Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | Unet3D |
|
||||
| Resource | Ascend 910; OS Euler2.8 |
|
||||
| Uploaded Date | 03/18/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | LUNA16 |
|
||||
| batch_size | 1 |
|
||||
| Dice | dice = 0.9502 |
|
||||
| Model for inference | 56M(.ckpt file) |
|
||||
| Parameters | Ascend | GPU |
|
||||
| ------------------- | --------------------------- | --------------------------- |
|
||||
| Model Version | Unet3D | Unet3D |
|
||||
| Resource | Ascend 910; OS Euler2.8 | Nvidia V100 SXM2; OS Ubuntu16|
|
||||
| Uploaded Date | 03/18/2021 (month/day/year) | 05/21/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 | 1.2.0 |
|
||||
| Dataset | LUNA16 | LUNA16 |
|
||||
| batch_size | 1 | 1 |
|
||||
| Dice | dice = 0.9502 | dice = 0.9601 |
|
||||
| Model for inference | 56M(.ckpt file) | 56M(.ckpt file) |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
|
|
|
@@ -1,4 +1,5 @@
 # Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_fp16_gpu: False
 enable_modelarts: False
 # Url for modelarts
 data_url: ""

@@ -41,6 +42,7 @@ file_format: ""
 ---
 # Help description for each configuration
 enable_modelarts: 'Whether training on modelarts, default: False'
+enable_fp16_gpu: 'Whether training on gpu with fp16, default: False'
 data_url: 'Dataset url for obs'
 train_url: 'Training output url for obs'
 checkpoint_url: 'The location of checkpoint for obs'
@@ -19,13 +19,13 @@ from mindspore import dtype as mstype
 from mindspore import Model, context, Tensor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.dataset import create_dataset
-from src.unet3d_model import UNet3d
+from src.unet3d_model import UNet3d, UNet3d_
 from src.utils import create_sliding_window, CalculateDice
 from src.model_utils.config import config
 from src.model_utils.moxing_adapter import moxing_wrapper

 device_id = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False, device_id=device_id)

 @moxing_wrapper()
 def test_net(data_path, ckpt_path):
@@ -35,7 +35,10 @@ def test_net(data_path, ckpt_path):
     eval_data_size = eval_dataset.get_dataset_size()
     print("train dataset length is:", eval_data_size)

-    network = UNet3d()
+    if config.device_target == 'Ascend':
+        network = UNet3d()
+    else:
+        network = UNet3d_()
     network.set_train(False)
     param_dict = load_checkpoint(ckpt_path)
     load_param_into_net(network, param_dict)
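The same device-dependent branch appears again in `train.py` below; condensed, the selection is just (illustrative):

```python
# UNet3d keeps the original Ascend implementation; UNet3d_ is the all-fp32
# variant that GPU runs either as-is or under amp (see train.py below)
network = UNet3d() if config.device_target == 'Ascend' else UNet3d_()
```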
@@ -0,0 +1,59 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# -ne 1 ]
+then
+  echo "Usage: bash run_distribute_train_gpu_fp16.sh [DATA_PATH]"
+  exit 1
+fi
+
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+PATH1=$(get_real_path $1)
+echo $PATH1
+if [ ! -d $PATH1 ]
+then
+  echo "error: DATA_PATH=$PATH1 is not a directory"
+  exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=8
+export RANK_SIZE=8
+
+if [ -d "train_parallel_fp16" ];
+then
+  rm -rf ./train_parallel_fp16
+fi
+
+mkdir ./train_parallel_fp16
+cp ../*.py ./train_parallel_fp16
+cp *.sh ./train_parallel_fp16
+cp ../*.yaml ./train_parallel_fp16
+cp -r ../src ./train_parallel_fp16
+cd ./train_parallel_fp16 || exit
+echo "start distributed training with $DEVICE_NUM GPUs."
+env > env.log
+mpirun --allow-run-as-root -n $DEVICE_NUM python train.py --run_distribute=True --data_path=$PATH1 --output_path './output' --device_target='GPU' --enable_fp16_gpu=True --checkpoint_path='./' > train.log 2>&1 &
+cd ..
@@ -0,0 +1,59 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# -ne 1 ]
+then
+  echo "Usage: bash run_distribute_train_gpu_fp32.sh [DATA_PATH]"
+  exit 1
+fi
+
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+PATH1=$(get_real_path $1)
+echo $PATH1
+if [ ! -d $PATH1 ]
+then
+  echo "error: DATA_PATH=$PATH1 is not a directory"
+  exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=8
+export RANK_SIZE=8
+
+if [ -d "train_parallel_fp32" ];
+then
+  rm -rf ./train_parallel_fp32
+fi
+
+mkdir ./train_parallel_fp32
+cp ../*.py ./train_parallel_fp32
+cp *.sh ./train_parallel_fp32
+cp ../*.yaml ./train_parallel_fp32
+cp -r ../src ./train_parallel_fp32
+cd ./train_parallel_fp32 || exit
+echo "start distributed training with $DEVICE_NUM GPUs."
+env > env.log
+mpirun --allow-run-as-root -n $DEVICE_NUM python train.py --run_distribute=True --data_path=$PATH1 --output_path './output' --device_target='GPU' --checkpoint_path='./' > train.log 2>&1 &
+cd ..
@@ -0,0 +1,76 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 2 ]
+then
+  echo "=============================================================================================================="
+  echo "Please run the script as: "
+  echo "bash scripts/run_standalone_eval_gpu_fp16.sh [DATA_PATH] [CHECKPOINT_FILE_PATH]"
+  echo "for example: bash run_standalone_eval_gpu_fp16.sh /path/to/data/ /path/to/checkpoint.ckpt"
+  echo "=============================================================================================================="
+  exit 1
+fi
+
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+PATH1=$(get_real_path $1)
+CHECKPOINT_FILE_PATH=$(get_real_path $2)
+echo $PATH1
+echo $CHECKPOINT_FILE_PATH
+
+if [ ! -d $PATH1 ]
+then
+  echo "error: DATA_PATH=$PATH1 is not a directory"
+  exit 1
+fi
+
+if [ ! -f $CHECKPOINT_FILE_PATH ]
+then
+  echo "error: CHECKPOINT_FILE_PATH=$CHECKPOINT_FILE_PATH is not a file"
+  exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=1
+export RANK_SIZE=$DEVICE_NUM
+export DEVICE_ID=0
+export RANK_ID=0
+
+if [ -d "eval_fp16" ];
+then
+  rm -rf ./eval_fp16
+fi
+
+mkdir ./eval_fp16
+cp ../*.py ./eval_fp16
+cp *.sh ./eval_fp16
+cp ../*.yaml ./eval_fp16
+cp -r ../src ./eval_fp16
+cd ./eval_fp16 || exit
+echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
+python eval.py --data_path=$PATH1 --checkpoint_file_path=$CHECKPOINT_FILE_PATH --device_target='GPU' > eval.log 2>&1 &
+echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
+cd ..
@@ -0,0 +1,76 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 2 ]
+then
+  echo "=============================================================================================================="
+  echo "Please run the script as: "
+  echo "bash scripts/run_standalone_eval_gpu_fp32.sh [DATA_PATH] [CHECKPOINT_FILE_PATH]"
+  echo "for example: bash run_standalone_eval_gpu_fp32.sh /path/to/data/ /path/to/checkpoint.ckpt"
+  echo "=============================================================================================================="
+  exit 1
+fi
+
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+PATH1=$(get_real_path $1)
+CHECKPOINT_FILE_PATH=$(get_real_path $2)
+echo $PATH1
+echo $CHECKPOINT_FILE_PATH
+
+if [ ! -d $PATH1 ]
+then
+  echo "error: DATA_PATH=$PATH1 is not a directory"
+  exit 1
+fi
+
+if [ ! -f $CHECKPOINT_FILE_PATH ]
+then
+  echo "error: CHECKPOINT_FILE_PATH=$CHECKPOINT_FILE_PATH is not a file"
+  exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=1
+export RANK_SIZE=$DEVICE_NUM
+export DEVICE_ID=0
+export RANK_ID=0
+
+if [ -d "eval_fp32" ];
+then
+  rm -rf ./eval_fp32
+fi
+
+mkdir ./eval_fp32
+cp ../*.py ./eval_fp32
+cp *.sh ./eval_fp32
+cp ../*.yaml ./eval_fp32
+cp -r ../src ./eval_fp32
+cd ./eval_fp32 || exit
+echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
+python eval.py --data_path=$PATH1 --checkpoint_file_path=$CHECKPOINT_FILE_PATH --device_target='GPU' > eval.log 2>&1 &
+echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
+cd ..
@@ -0,0 +1,60 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# -ne 1 ]
+then
+  echo "Usage: bash run_standalone_train_gpu_fp16.sh [DATA_PATH]"
+  exit 1
+fi
+
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+PATH1=$(get_real_path $1)
+echo $PATH1
+if [ ! -d $PATH1 ]
+then
+  echo "error: DATA_PATH=$PATH1 is not a directory"
+  exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=1
+export DEVICE_ID=0
+export RANK_ID=0
+export RANK_SIZE=1
+
+if [ -d "train_fp16" ];
+then
+  rm -rf ./train_fp16
+fi
+
+mkdir ./train_fp16
+cp ../*.py ./train_fp16
+cp *.sh ./train_fp16
+cp ../*.yaml ./train_fp16
+cp -r ../src ./train_fp16
+cd ./train_fp16 || exit
+echo "start training for device $DEVICE_ID"
+env > env.log
+python train.py --data_path=$PATH1 --output_path './output' --device_target='GPU' --checkpoint_path='./' --enable_fp16_gpu=True > train.log 2>&1 &
+cd ..
@@ -0,0 +1,60 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# -ne 1 ]
+then
+  echo "Usage: bash run_standalone_train_gpu_fp32.sh [DATA_PATH]"
+  exit 1
+fi
+
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+PATH1=$(get_real_path $1)
+echo $PATH1
+if [ ! -d $PATH1 ]
+then
+  echo "error: DATA_PATH=$PATH1 is not a directory"
+  exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=1
+export DEVICE_ID=0
+export RANK_ID=0
+export RANK_SIZE=1
+
+if [ -d "train_fp32" ];
+then
+  rm -rf ./train_fp32
+fi
+
+mkdir ./train_fp32
+cp ../*.py ./train_fp32
+cp *.sh ./train_fp32
+cp ../*.yaml ./train_fp32
+cp -r ../src ./train_fp32
+cd ./train_fp32 || exit
+echo "start training for device $DEVICE_ID"
+env > env.log
+python train.py --data_path=$PATH1 --output_path './output' --device_target='GPU' --checkpoint_path='./' > train.log 2>&1 &
+cd ..
@@ -117,9 +117,9 @@ def get_config():
                         help="Config file path")
     path_args, _ = parser.parse_known_args()
     default, helper = parse_yaml(path_args.config_path)
-    pprint(default)
     args = parse_cli_to_yaml(parser, default, helper, path_args.config_path)
     final_config = merge(args, default)
+    pprint(final_config)
     return Config(final_config)

 config = get_config()
@@ -19,6 +19,48 @@ from mindspore.ops import operations as P
 from src.unet3d_parts import Down, Up
 from src.model_utils.config import config

+class UNet3d_(nn.Cell):
+    """
+    UNet3d_ supports fp32 and fp16 (amp) training on GPU.
+    """
+    def __init__(self):
+        super(UNet3d_, self).__init__()
+        self.n_channels = config.in_channels
+        self.n_classes = config.num_classes
+
+        # down
+        self.down1 = Down(in_channel=self.n_channels, out_channel=16, dtype=mstype.float32)
+        self.down2 = Down(in_channel=16, out_channel=32, dtype=mstype.float32)
+        self.down3 = Down(in_channel=32, out_channel=64, dtype=mstype.float32)
+        self.down4 = Down(in_channel=64, out_channel=128, dtype=mstype.float32)
+        self.down5 = Down(in_channel=128, out_channel=256, stride=1, kernel_size=(1, 1, 1), \
+                          dtype=mstype.float32)
+
+        # up
+        self.up1 = Up(in_channel=256, down_in_channel=128, out_channel=64, \
+                      dtype=mstype.float32)
+        self.up2 = Up(in_channel=64, down_in_channel=64, out_channel=32, \
+                      dtype=mstype.float32)
+        self.up3 = Up(in_channel=32, down_in_channel=32, out_channel=16, \
+                      dtype=mstype.float32)
+        self.up4 = Up(in_channel=16, down_in_channel=16, out_channel=self.n_classes, \
+                      dtype=mstype.float32, is_output=True)
+
+    def construct(self, input_data):
+        x1 = self.down1(input_data)
+        x2 = self.down2(x1)
+        x3 = self.down3(x2)
+        x4 = self.down4(x3)
+        x5 = self.down5(x4)
+
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        return x
+
+
 class UNet3d(nn.Cell):
     def __init__(self):
         super(UNet3d, self).__init__()
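A minimal smoke test of the new GPU network might look like this (a sketch; the input shape is hypothetical, and `in_channels`/`num_classes` come from the YAML config):

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from src.unet3d_model import UNet3d_

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = UNet3d_()
x = Tensor(np.random.randn(1, 1, 96, 224, 224).astype(np.float32))  # NCDHW volume
print(net(x).shape)
```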
@@ -19,20 +19,23 @@ import mindspore.nn as nn
 import mindspore.common.dtype as mstype
 from mindspore import Tensor, Model, context
 from mindspore.context import ParallelMode
-from mindspore.communication.management import init
+from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
 from src.dataset import create_dataset
-from src.unet3d_model import UNet3d
+from src.unet3d_model import UNet3d, UNet3d_
 from src.lr_schedule import dynamic_lr
 from src.loss import SoftmaxCrossEntropyWithLogits
 from src.model_utils.config import config
 from src.model_utils.moxing_adapter import moxing_wrapper
 from src.model_utils.device_adapter import get_device_id, get_device_num

-device_id = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, \
-                    device_id=device_id)
+if config.device_target == 'Ascend':
+    device_id = int(os.getenv('DEVICE_ID'))
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False, \
+                        device_id=device_id)
+else:
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
 mindspore.set_seed(1)

 @moxing_wrapper()
@@ -42,8 +45,12 @@ def train_net(data_path,
     seg_dir = data_path + "/seg/"
     if run_distribute:
         init()
-        rank_id = get_device_id()
-        rank_size = get_device_num()
+        if config.device_target == 'Ascend':
+            rank_id = get_device_id()
+            rank_size = get_device_num()
+        else:
+            rank_id = get_rank()
+            rank_size = get_group_size()
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                           device_num=rank_size,
@@ -56,7 +63,10 @@ def train_net(data_path,
     train_data_size = train_dataset.get_dataset_size()
     print("train dataset length is:", train_data_size)

-    network = UNet3d()
+    if config.device_target == 'Ascend':
+        network = UNet3d()
+    else:
+        network = UNet3d_()

     loss = SoftmaxCrossEntropyWithLogits()
     lr = Tensor(dynamic_lr(config, train_data_size), mstype.float32)
@@ -64,7 +74,10 @@ def train_net(data_path,
     scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
     network.set_train()

-    model = Model(network, loss_fn=loss, optimizer=optimizer, loss_scale_manager=scale_manager)
+    if config.device_target == 'GPU' and config.enable_fp16_gpu:
+        model = Model(network, loss_fn=loss, optimizer=optimizer, loss_scale_manager=scale_manager, amp_level='O2')
+    else:
+        model = Model(network, loss_fn=loss, optimizer=optimizer, loss_scale_manager=scale_manager)

     time_cb = TimeMonitor(data_size=train_data_size)
     loss_cb = LossMonitor()
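On GPU the fp16 path is enabled through MindSpore's automatic mixed precision rather than a hand-cast network, so the only difference between the two branches is the `amp_level` argument; condensed (illustrative, `'O0'` being the documented default):

```python
amp = 'O2' if (config.device_target == 'GPU' and config.enable_fp16_gpu) else 'O0'
model = Model(network, loss_fn=loss, optimizer=optimizer,
              loss_scale_manager=scale_manager, amp_level=amp)
```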
@@ -72,7 +85,7 @@ def train_net(data_path,
                                    keep_checkpoint_max=config.keep_checkpoint_max)
     ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
     ckpoint_cb = ModelCheckpoint(prefix='Unet3d',
-                                 directory=ckpt_save_dir+'./ckpt_{}/'.format(device_id),
+                                 directory=ckpt_save_dir+'./ckpt_{}/'.format(rank_id),
                                  config=ckpt_config)
     callbacks_list = [loss_cb, time_cb, ckpoint_cb]
     print("============== Starting Training ==============")
@@ -0,0 +1,61 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class NetPReLUGrad(nn.Cell):
+    def __init__(self):
+        super(NetPReLUGrad, self).__init__()
+        self.prelu_grad = G.PReLUGrad()
+
+    def construct(self, dout, x, w):
+        return self.prelu_grad(dout, x, w)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_prelu_grad_fp32_channel_shared():
+    dout = Tensor(np.ones(shape=[2, 2, 2, 3]).astype(np.float32))
+    x = Tensor(np.arange(-5, 19).reshape(2, 2, 2, 3).astype(np.float32))
+    w = Tensor(np.array([-0.5]).astype(np.float32))
+    expect_dx = np.array([[[[-0.5000, -0.5000, -0.5000],
+                            [-0.5000, -0.5000, -0.5000]],
+                           [[1.0000, 1.0000, 1.0000],
+                            [1.0000, 1.0000, 1.0000]]],
+                          [[[1.0000, 1.0000, 1.0000],
+                            [1.0000, 1.0000, 1.0000]],
+                           [[1.0000, 1.0000, 1.0000],
+                            [1.0000, 1.0000, 1.0000]]]]).astype(np.float32)
+    expect_dw = np.array([-15.]).astype(np.float32)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    prelu_grad = NetPReLUGrad()
+    dx, dw = prelu_grad(dout, x, w)
+    assert (dx.asnumpy() == expect_dx).all()
+    assert (dw.asnumpy() == expect_dw).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    prelu_grad = NetPReLUGrad()
+    dx, dw = prelu_grad(dout, x, w)
+    assert (dx.asnumpy() == expect_dx).all()
+    assert (dw.asnumpy() == expect_dw).all()
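(For reference, the expected weight gradient follows directly from the kernel formula: with dy = 1 everywhere and a shared weight, dw = Σ x·dy over the inputs with x ≤ 0 = (−5) + (−4) + (−3) + (−2) + (−1) + 0 = −15.)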
@@ -70,6 +70,30 @@ def test_slice_4d():
     assert (output_ms.asnumpy() == output_np).all()


+class Slice5DNet(nn.Cell):
+    def __init__(self):
+        super(Slice5DNet, self).__init__()
+        self.slice = P.Slice()
+
+    def construct(self, x):
+        return self.slice(x, (0, 11, 1, 2, 3), (32, 7, 14, 10, 221))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_slice_5d():
+    x_np = np.random.randn(32, 32, 24, 224, 224).astype(np.float32)
+    output_np = x_np[:, 11:18, 1:15, 2:12, 3:224]
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x_ms = Tensor(x_np)
+    net = Slice5DNet()
+    output_ms = net(x_ms)
+
+    assert (output_ms.asnumpy() == output_np).all()
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard