!39701 Add GPU support for the MedianGrad operator

Merge pull request !39701 from 黄晓/mediangrad_0804
i-robot 2022-08-05 02:08:55 +00:00 committed by Gitee
commit 60587fcb49
9 changed files with 651 additions and 3 deletions
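The kernels below implement the backward pass of Median. As a rough reference for the semantics (a minimal NumPy sketch; the function names and the keep_dims=False assumption are illustrative, not part of this patch): in the global case the incoming gradient is shared equally by every element equal to the median, and in the per-axis case it is scattered back to the positions recorded in `indices`.

import numpy as np

def median_grad_global(y_grad, x, y):
    # Every element equal to the global median shares the incoming gradient equally.
    mask = (x == y)
    return (y_grad / mask.sum()) * mask

def median_grad_along_axis(y_grad, indices, x_shape, axis):
    # Scatter each output gradient back to the median position along `axis`
    # (reduced shapes assumed, i.e. keep_dims=False).
    x_grad = np.zeros(x_shape, dtype=y_grad.dtype)
    np.put_along_axis(x_grad, np.expand_dims(indices, axis), np.expand_dims(y_grad, axis), axis)
    return x_grad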


@@ -0,0 +1,96 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "median_grad_impl.cuh"
#include <iostream>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
template <typename T>
__global__ void Count_Repeat(const T *x, const T *y, int64_t size, int *repeat_val) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
if (x[pos] == *y) {
MsAtomicAdd(repeat_val, 1);
}
}
}
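// Global-median backward pass: positions equal to the median share the incoming gradient *y_grad equally among the *repeat_val occurrences; all other positions receive a zero gradient.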
template <typename T, typename V>
__global__ void GlobalMedianGradComputer(const T *y_grad, const T *x, const T *y, V *output, int *repeat_val,
const int64_t size) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
if (x[pos] == *y) {
output[pos] = *y_grad / *repeat_val;
} else {
output[pos] = 0;
}
}
}
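// Per-axis backward pass: decode the flat index `pos` of y_grad with the strides of the keep-dims output (elem_num_each_dim_y), replace the coordinate along `axis` with the stored median index, re-encode with the input strides (elem_num_each_dim_x), and write y_grad there.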
template <typename T, typename S, typename V>
__global__ void MedianGradComputer(const T *y_grad, const S *indices, const T *y, V *output, int *elem_num_each_dim_x,
int *elem_num_each_dim_y, int64_t axis, int64_t input_dim, int64_t size) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
int elements_remain = pos;
int temp = 0;
int update_pos = 0;
for (int i = 0; i < input_dim; i++) {
temp = elements_remain / elem_num_each_dim_y[i];
elements_remain %= elem_num_each_dim_y[i];
if (i == axis) {
update_pos += *(indices + pos) * elem_num_each_dim_x[i];
} else {
update_pos += temp * elem_num_each_dim_x[i];
}
}
*(output + update_pos) = *(y_grad + pos);
}
}
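// Host-side launcher: dispatches to the global or the per-axis kernel depending on global_median.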
template <typename T, typename S, typename V>
void MedianGrad(const T *y_grad, const T *x, const T *y, const S *indices, V *output, const int64_t axis,
bool global_median, const int64_t input0_size, const int64_t input1_size, int64_t input_dim,
int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream) {
if (global_median) {
Count_Repeat<T><<<GET_BLOCKS(input1_size), GET_THREADS, 0, cuda_stream>>>(x, y, input1_size, repeat_val);
GlobalMedianGradComputer<T, V>
<<<GET_BLOCKS(input1_size), GET_THREADS, 0, cuda_stream>>>(y_grad, x, y, output, repeat_val, input1_size);
} else {
MedianGradComputer<T, S, V><<<GET_BLOCKS(input0_size), GET_THREADS, 0, cuda_stream>>>(
y_grad, indices, y, output, elem_num_each_dim_x, elem_num_each_dim_y, axis, input_dim, input0_size);
}
}
template CUDA_LIB_EXPORT void MedianGrad<int16_t, int64_t, float>(
const int16_t *input0_value, const int16_t *input1_value, const int16_t *input2_value, const int64_t *input3_value,
float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void MedianGrad<int32_t, int64_t, float>(
const int32_t *input0_value, const int32_t *input1_value, const int32_t *input2_value, const int64_t *input3_value,
float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void MedianGrad<int64_t, int64_t, float>(
const int64_t *input0_value, const int64_t *input1_value, const int64_t *input2_value, const int64_t *input3_value,
float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void MedianGrad<float, int64_t, float>(
const float *input0_value, const float *input1_value, const float *input2_value, const int64_t *input3_value,
float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void MedianGrad<double, int64_t, double>(
const double *input0_value, const double *input1_value, const double *input2_value, const int64_t *input3_value,
double *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);


@@ -0,0 +1,29 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_GRAD_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_GRAD_IMPL_CUH_
#include <vector>
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T, typename S, typename V>
void MedianGrad(const T *input0_value, const T *input1_value, const T *input2_value, const S *input3_value, V *output,
const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
const int64_t input1_dim_, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_GRAD_IMPL_CUH_


@@ -35,7 +35,7 @@ __device__ __forceinline__ unsigned int warp_ballot(int predicate) {
template <typename T>
static __device__ __host__ T round_up(T a, T b) {
- return (a / b) * b;
+ return ((a + b - 1) / b) * b;
}
template <typename T>


@@ -0,0 +1,62 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/median_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_THREE(MedianGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
MedianGradGpuKernelMod, int16_t, int64_t, float)
MS_REG_GPU_KERNEL_THREE(MedianGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
MedianGradGpuKernelMod, int32_t, int64_t, float)
MS_REG_GPU_KERNEL_THREE(MedianGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
MedianGradGpuKernelMod, int64_t, int64_t, float)
MS_REG_GPU_KERNEL_THREE(MedianGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
MedianGradGpuKernelMod, float, int64_t, float)
MS_REG_GPU_KERNEL_THREE(MedianGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
MedianGradGpuKernelMod, double, int64_t, double)
} // namespace kernel
} // namespace mindspore


@@ -0,0 +1,244 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MEDIAN_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MEDIAN_GRAD_GPU_KERNEL_H_
#include <vector>
#include <map>
#include "mindspore/core/ops/grad/median_grad.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_grad_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t kMedianOutputsNum = 1;
constexpr size_t kInputsNum4 = 4;
constexpr size_t kInputsNum3 = 3;
template <typename T, typename S, typename V>
class MedianGradGpuKernelMod : public NativeGpuKernelMod {
public:
MedianGradGpuKernelMod() : global_median_(false), keep_dims_(false), axis_(0), input1_dim_(0), input0_size_(0), input1_size_(0) {}
~MedianGradGpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
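// Input order: y_grad, x, y, indices; indices is only read when global_median_ is false.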
T *y_grad = GetDeviceAddress<T>(inputs, kIndex0);
T *x = GetDeviceAddress<T>(inputs, kIndex1);
T *y = GetDeviceAddress<T>(inputs, kIndex2);
S *indices = nullptr;
V *output0_addr = GetDeviceAddress<V>(outputs, kIndex0);
if (!global_median_) {
indices = GetDeviceAddress<S>(inputs, kIndex3);
}
int *elem_num_each_dim_x = GetDeviceAddress<int>(workspace, kIndex0);
int *elem_num_each_dim_y = GetDeviceAddress<int>(workspace, kIndex1);
int *repeat_val = GetDeviceAddress<int>(workspace, kIndex2);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemsetAsync(output0_addr, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemSet Failed");
if (!global_median_) {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(elem_num_each_dim_x, &elem_num_each_dim_x_[0], sizeof(int) * input1_dim_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync elem_num_each_dim_x failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(elem_num_each_dim_y, &elem_num_each_dim_y_[0], sizeof(int) * input1_dim_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync elem_num_each_dim_y failed");
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemsetAsync(repeat_val, 0, sizeof(int), reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemset failed in repeat_val.");
MedianGrad(y_grad, x, y, indices, output0_addr, axis_, global_median_, input0_size_, input1_size_, input1_dim_,
elem_num_each_dim_x, elem_num_each_dim_y, repeat_val, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
kernel_name_ = base_operator->name();
auto kernel_ptr = std::dynamic_pointer_cast<ops::MedianGrad>(base_operator);
if (kernel_ptr == nullptr) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' cast Median ops failed!";
return false;
}
if (((inputs.size() != kInputsNum4) && (inputs.size() != kInputsNum3)) || outputs.size() > kMedianOutputsNum) {
MS_LOG(ERROR) << kernel_name_ << ": input size should be 3 or 4, "
<< "but got " << inputs.size() << ", and output size should be 1, but got " << outputs.size();
return false;
}
global_median_ = kernel_ptr->get_global_median();
keep_dims_ = kernel_ptr->get_keep_dims();
axis_ = kernel_ptr->get_axis();
input_shape_ = inputs[1]->GetShapeVector();
input1_dim_ = input_shape_.size();
std::vector<int64_t> input0_shape = inputs[0]->GetShapeVector();
if (axis_ < -input1_dim_ || axis_ >= input1_dim_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << input1_dim_ << ","
<< input1_dim_ << "), but got " << axis_;
}
if (axis_ < 0) {
axis_ += input1_dim_;
}
input1_size_ = 1;
input0_size_ = 1;
for (size_t i = 0; i < input_shape_.size(); i++) {
input1_size_ *= input_shape_[i];
}
for (size_t i = 0; i < input0_shape.size(); i++) {
input0_size_ *= input0_shape[i];
}
if (global_median_) {
input_shape_.clear();
input_shape_.push_back(input1_size_);
} else {
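// Rebuild the output shape as if keep_dims were true (axis size forced to 1) and precompute row-major strides of the input and of that shape; Launch() copies them to device memory.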
std::vector<int64_t> shape_keepdim;
for (int64_t i = 0; i < input1_dim_; i++) {
if (i == axis_) {
shape_keepdim.push_back(1);
} else {
shape_keepdim.push_back(input_shape_[i]);
}
}
int elem_num_x = 1;
int elem_num_y = 1;
for (size_t i = 0; i < shape_keepdim.size(); i++) {
elem_num_each_dim_x_.insert(elem_num_each_dim_x_.begin(), elem_num_x);
elem_num_x *= input_shape_[shape_keepdim.size() - 1 - i];
elem_num_each_dim_y_.insert(elem_num_each_dim_y_.begin(), elem_num_y);
elem_num_y *= shape_keepdim[shape_keepdim.size() - 1 - i];
}
}
ResetResource();
InitWorkSpaceSizeList();
return true;
}
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
int ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != 0) {
return ret;
}
input_shape_ = inputs[1]->GetShapeVector();
std::vector<int64_t> input0_shape = inputs[0]->GetShapeVector();
input1_dim_ = input_shape_.size();
input1_size_ = 1;
input0_size_ = 1;
for (size_t i = 0; i < input_shape_.size(); i++) {
input1_size_ *= input_shape_[i];
}
for (size_t i = 0; i < input0_shape.size(); i++) {
input0_size_ *= input0_shape[i];
}
if (global_median_) {
input_shape_.clear();
input_shape_.push_back(input1_size_);
} else {
std::vector<int64_t> shape_keepdim;
for (int64_t i = 0; i < input1_dim_; i++) {
if (i == axis_) {
shape_keepdim.push_back(1);
} else {
shape_keepdim.push_back(input_shape_[i]);
}
}
int elem_num_x = 1;
int elem_num_y = 1;
elem_num_each_dim_x_.clear();
elem_num_each_dim_y_.clear();
for (size_t i = 0; i < shape_keepdim.size(); i++) {
elem_num_each_dim_x_.insert(elem_num_each_dim_x_.begin(), elem_num_x);
elem_num_x *= input_shape_[shape_keepdim.size() - 1 - i];
elem_num_each_dim_y_.insert(elem_num_each_dim_y_.begin(), elem_num_y);
elem_num_y *= shape_keepdim[shape_keepdim.size() - 1 - i];
}
}
InitWorkSpaceSizeList();
return KRET_OK;
}
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64)};
return support_list;
}
protected:
void ResetResource() noexcept {
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
private:
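// Workspace: strides of the input, strides of the keep-dims output, and a single int holding the repeat count of the global median.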
void InitWorkSpaceSizeList() {
workspace_size_list_.push_back(input1_dim_ * sizeof(int));
workspace_size_list_.push_back(input1_dim_ * sizeof(int));
workspace_size_list_.push_back(sizeof(int));
}
bool global_median_;
bool keep_dims_;
int64_t axis_;
int64_t input1_dim_;
int64_t input0_size_;
int64_t input1_size_;
std::vector<int64_t> input_shape_;
std::vector<int> elem_num_each_dim_x_;
std::vector<int> elem_num_each_dim_y_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MEDIAN_GRAD_GPU_KERNEL_H_


@@ -27,6 +27,37 @@
namespace mindspore {
namespace ops {
void MedianGrad::Init(const bool global_median, const int64_t axis, const bool keep_dims) {
this->set_global_median(global_median);
this->set_axis(axis);
this->set_keep_dims(keep_dims);
}
void MedianGrad::set_global_median(const bool global_median) {
(void)this->AddAttr(kGlobalMedian, api::MakeValue(global_median));
}
void MedianGrad::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
void MedianGrad::set_axis(const int64_t &axis) {
int64_t f = axis;
(void)this->AddAttr(kAxis, api::MakeValue(f));
}
bool MedianGrad::get_global_median() const {
auto value_ptr = GetAttr(kGlobalMedian);
return GetValue<bool>(value_ptr);
}
bool MedianGrad::get_keep_dims() const {
auto value_ptr = GetAttr(kKeepDims);
return GetValue<bool>(value_ptr);
}
int64_t MedianGrad::get_axis() const {
auto value_ptr = GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
namespace {
abstract::ShapePtr MedianGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


@@ -28,11 +28,31 @@
namespace mindspore {
namespace ops {
constexpr auto kNameMedianGrad = "MedianGrad";
- class MedianGrad : public BaseOperator {
+ class MIND_API MedianGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MedianGrad);
/// \brief Constructor.
MedianGrad() : BaseOperator(kNameMedianGrad) { InitIOName({"y_grad", "x", "y", "indices"}, {"x_grad"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Median for the inputs.
void Init(const bool global_median = false, const int64_t axis = 0, const bool keep_dims = false);
/// \brief Set global_median.
void set_global_median(const bool global_median);
/// \brief Set keep_dims.
void set_keep_dims(const bool keep_dims);
/// \brief Set axis.
void set_axis(const int64_t &axis);
/// \brief Get global_median.
///
/// \return global_median.
bool get_global_median() const;
/// \brief Get keep_dims.
///
/// \return keep_dims.
bool get_keep_dims() const;
/// \brief Get axis.
///
/// \return axis.
int64_t get_axis() const;
};
abstract::AbstractBasePtr MedianGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


@@ -485,7 +485,7 @@ def get_median_vmap_rule(prim, axis_size):
axis += rank - 1
axis_new = axis + 1 if dim <= axis else axis
if keep_dims:
- dim_new = axis_new
+ dim_new = dim
else:
dim_new = dim - 1 if dim > axis_new else dim
return axis_new, dim_new
@@ -556,6 +556,36 @@
return vmap_rule
@vmap_rules_getters.register(G.MedianGrad)
def get_median_grad_vmap_rule(prim, axis_size):
"""VmapRule for MedianGrad."""
global_median = prim.global_median
axis = prim.axis
keep_dims = prim.keep_dims
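# Remap the reduce axis and x's batch dim into the batched layout: the reduce axis shifts right by one when the batch dim sits at or before it; with keep_dims the output batch dim is unchanged, otherwise it shifts left past the removed axis.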
@constexpr
def trans_grad_axis(axis, rank, dim, keep_dims):
if axis < 0:
axis += rank - 1
axis_new = axis + 1 if dim <= axis else axis
if keep_dims:
dim_new = dim
else:
dim_new = dim - 1 if dim > axis_new else dim
return axis_new, dim_new
def vmap_rule(dy_bdim, x_bdim, y_bdim, indices_bdim):
dy, _ = dy_bdim
x, x_dim = x_bdim
y, _ = y_bdim
indices, _ = indices_bdim
rank = len(x.shape)
axis_new, dim_new = trans_grad_axis(axis, rank, x_dim, keep_dims)
x_grad = G.MedianGrad(global_median, axis_new, keep_dims)(dy, x, y, indices)
return (x_grad, dim_new)
return vmap_rule
@vmap_rules_getters.register(linalg_ops.Svd)
def get_svd_vmap_rule(prim, axis_size):
"""VmapRule for 'Svd' operation."""


@@ -0,0 +1,136 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import torch
from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops.operations.math_ops import Median
import mindspore.ops.operations._grad_ops as G
from mindspore.ops.composite import GradOperation
class Grad(Cell):
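"""Wrapper cell that returns the gradient of the wrapped network with respect to its input."""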
def __init__(self, network):
super(Grad, self).__init__()
self.grad = GradOperation(get_all=False, sens_param=False)
self.network = network
def construct(self, input_x):
gout = self.grad(self.network)(input_x)
return gout
class MedianC(Cell):
def __init__(self, global_median, axis, keep_dims):
super().__init__()
self.global_median = global_median
self.axis = axis
self.keep_dims = keep_dims
self.median = Median(self.global_median, self.axis, self.keep_dims)
def construct(self, x):
return self.median(x)
class MedianGrad(Cell):
def __init__(self, global_median, axis, keep_dims):
super().__init__()
self.global_median = global_median
self.axis = axis
self.keep_dims = keep_dims
self.median_grad = G.MedianGrad(self.global_median, self.axis, self.keep_dims)
def construct(self, dy, x, y, indices):
return self.median_grad(dy, x, y, indices)
class MedianFactory():
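"""Builds random inputs and checks MindSpore Median forward/backward against torch.median as a reference."""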
def __init__(self, input_shape, global_median, axis=0, keep_dims=False, dtype=np.float32):
super().__init__()
self.dtype = dtype
self.input = np.random.randn(*input_shape).astype(self.dtype)
self.global_median = global_median
self.axis = axis
self.keep_dims = keep_dims
self.output_grad_np = np.random.randn(*input_shape).astype(dtype=dtype)
def forward_mindspore_impl(self):
net = MedianC(self.global_median, self.axis, self.keep_dims)
y, indices = net(Tensor(self.input))
return y.asnumpy(), indices.asnumpy()
def grad_mindspore_impl(self):
input_x = Tensor(self.input)
net = MedianC(self.global_median, self.axis, self.keep_dims)
grad_net = Grad(net)
res = grad_net(input_x)
return res.asnumpy()
def forward_pytorch_impl(self):
input_pt = torch.from_numpy(self.input)
indices = None
if self.global_median is False:
y, indices = torch.median(input_pt, axis=self.axis, keepdim=self.keep_dims)
else:
y = torch.median(input_pt)
indices_np = None if indices is None else indices.numpy().astype(np.int64)
return y.numpy().astype(self.dtype), indices_np
def global_grad_pytorch_impl(self):
input_pt = torch.from_numpy(self.input)
input_pt.requires_grad = True
y = torch.median(input_pt)
y.backward()
return input_pt.grad.numpy()
def grad_pytorch_impl(self):
input_pt = torch.from_numpy(self.input)
input_pt.requires_grad = True
y, _ = torch.median(input_pt, axis=self.axis, keepdim=self.keep_dims)
y.sum().backward()
return input_pt.grad.numpy()
def forward_cmp(self):
y_pytorch, _ = self.forward_pytorch_impl()
y_mindspore, _ = self.forward_mindspore_impl()
assert np.allclose(y_pytorch, y_mindspore)
def grad_cmp(self):
grad_ms = self.grad_mindspore_impl()
if self.global_median is False:
grad_torch = self.grad_pytorch_impl()
else:
grad_torch = self.global_grad_pytorch_impl()
assert np.allclose(grad_ms, grad_torch)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_median_gpu():
"""
Feature: Median and MedianGrad GPU kernels.
Description: Test Median forward and MedianGrad backward on GPU with different global_median settings.
Expectation: the results match the PyTorch reference.
"""
fact = MedianFactory(input_shape=(5, 5), global_median=True, axis=0, keep_dims=True)
fact.forward_cmp()
fact.grad_cmp()
fact2 = MedianFactory(input_shape=(5, 5, 5), global_median=False, axis=1, keep_dims=False)
fact2.forward_cmp()
fact2.grad_cmp()