diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_grad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_grad_impl.cu
new file mode 100644
index 00000000000..93db7e1fd3f
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_grad_impl.cu
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "median_grad_impl.cuh"
+#include <stdint.h>
+#include <cuda_runtime.h>
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
+
+template <typename T>
+__global__ void Count_Repeat(const T *x, const T *y, int64_t size, int *repeat_val) {
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    if (x[pos] == *y) {
+      MsAtomicAdd(repeat_val, 1);
+    }
+  }
+}
+
+template <typename T, typename V>
+__global__ void GlobalMedianGradComputer(const T *y_grad, const T *x, const T *y, V *output, int *repeat_val,
+                                         const int64_t size) {
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    if (x[pos] == *y) {
+      output[pos] = *y_grad / *repeat_val;
+    } else {
+      output[pos] = 0;
+    }
+  }
+}
+
+template <typename T, typename S, typename V>
+__global__ void MedianGradComputer(const T *y_grad, const S *indices, const T *y, V *output, int *elem_num_each_dim_x,
+                                   int *elem_num_each_dim_y, int64_t axis, int64_t input_dim, int64_t size) {
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    int elements_remain = pos;
+    int temp = 0;
+    int update_pos = 0;
+    for (int i = 0; i < input_dim; i++) {
+      temp = elements_remain / elem_num_each_dim_y[i];
+      elements_remain %= elem_num_each_dim_y[i];
+      if (i == axis) {
+        update_pos += *(indices + pos) * elem_num_each_dim_x[i];
+      } else {
+        update_pos += temp * elem_num_each_dim_x[i];
+      }
+    }
+    *(output + update_pos) = *(y_grad + pos);
+  }
+}
+
+template <typename T, typename S, typename V>
+void MedianGrad(const T *y_grad, const T *x, const T *y, const S *indices, V *output, const int64_t axis,
+                bool global_median, const int64_t input0_size, const int64_t input1_size, int64_t input_dim,
+                int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream) {
+  if (global_median) {
+    Count_Repeat<T><<<GET_BLOCKS(input1_size), GET_THREADS, 0, cuda_stream>>>(x, y, input1_size, repeat_val);
+    GlobalMedianGradComputer<T, V>
+      <<<GET_BLOCKS(input1_size), GET_THREADS, 0, cuda_stream>>>(y_grad, x, y, output, repeat_val, input1_size);
+  } else {
+    MedianGradComputer<T, S, V><<<GET_BLOCKS(input0_size), GET_THREADS, 0, cuda_stream>>>(
+      y_grad, indices, y, output, elem_num_each_dim_x, elem_num_each_dim_y, axis, input_dim, input0_size);
+  }
+}
+
+template CUDA_LIB_EXPORT void MedianGrad<int16_t, int64_t, float>(
+  const int16_t *input0_value, const int16_t *input1_value, const int16_t *input2_value, const int64_t *input3_value,
+  float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
+  int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MedianGrad<int32_t, int64_t, float>(
+  const int32_t *input0_value, const int32_t *input1_value, const int32_t *input2_value, const int64_t *input3_value,
+  float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
+  int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MedianGrad<int64_t, int64_t, float>(
+  const int64_t *input0_value, const int64_t *input1_value, const int64_t *input2_value, const int64_t *input3_value,
+  float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
+  int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MedianGrad<float, int64_t, float>(
+  const float *input0_value, const float *input1_value, const float *input2_value, const int64_t *input3_value,
+  float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
+  int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MedianGrad<double, int64_t, double>(
+  const double *input0_value, const double *input1_value, const double *input2_value, const int64_t *input3_value,
+  double *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
+  int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
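
Note on the two code paths above: with `global_median` the forward pass keeps no usable indices, so `Count_Repeat` counts how many entries tie with the median and the gradient is split evenly among them; otherwise the saved indices scatter `dy` back along `axis`. A NumPy mirror for small-scale checking (function names are mine, illustrative only, not part of the patch):

```python
import numpy as np

def global_median_grad(dy, x, y):
    # GlobalMedianGradComputer: dy (a scalar) is split evenly among every
    # element of x equal to the median y; Count_Repeat supplies the divisor.
    repeat_val = np.count_nonzero(x == y)
    return np.where(x == y, dy / repeat_val, 0.0)

def median_grad(dy, indices, x_shape, axis):
    # MedianGradComputer: rebuild each flat position of dy as a multi-index,
    # swap in the median index along `axis`, and scatter dy there
    # (dy and indices as produced with keep_dims=False).
    x_grad = np.zeros(x_shape, dtype=np.asarray(dy).dtype)
    np.put_along_axis(x_grad, np.expand_dims(indices, axis), np.expand_dims(dy, axis), axis)
    return x_grad
```

The equality test `x[pos] == *y` is exact on purpose: the forward median always returns an element of `x`, so a bitwise comparison is safe.
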
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_grad_impl.cuh
new file mode 100644
index 00000000000..c3f4cd66c03
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_grad_impl.cuh
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_GRAD_IMPL_CUH_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_GRAD_IMPL_CUH_
+#include <cuda_runtime.h>
+#include "include/cuda_fp16.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
+
+template <typename T, typename S, typename V>
+void MedianGrad(const T *input0_value, const T *input1_value, const T *input2_value, const S *input3_value, V *output,
+                const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
+                const int64_t input1_dim_, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val,
+                cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_GRAD_IMPL_CUH_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_impl.cu
index 4d86a480fcd..55ef052fee7 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_impl.cu
@@ -35,7 +35,7 @@ __device__ __forceinline__ unsigned int warp_ballot(int predicate) {
 
 template <typename T>
 static __device__ __host__ T round_up(T a, T b) {
-  return (a / b) * b;
+  return ((a + b - 1) / b) * b;
 }
 
 template <typename T>
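
The `median_impl.cu` hunk fixes `round_up`, which previously truncated (rounded down) whenever `a` was not already a multiple of `b`. In Python terms (illustrative only):

```python
def round_up_old(a, b):
    return (a // b) * b            # round_up_old(10, 4) == 8  (wrong: rounds down)

def round_up_new(a, b):
    return ((a + b - 1) // b) * b  # round_up_new(10, 4) == 12 (true ceiling)

assert round_up_new(10, 4) == 12 and round_up_new(12, 4) == 12
```
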
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/median_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/median_grad_gpu_kernel.cc
new file mode 100644
index 00000000000..69cca31713e
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/median_grad_gpu_kernel.cc
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/gpu/kernel/math/median_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_THREE(MedianGrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeInt16)
+                          .AddInputAttr(kNumberTypeInt16)
+                          .AddInputAttr(kNumberTypeInt16)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        MedianGradGpuKernelMod, int16_t, int64_t, float)
+MS_REG_GPU_KERNEL_THREE(MedianGrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        MedianGradGpuKernelMod, int32_t, int64_t, float)
+MS_REG_GPU_KERNEL_THREE(MedianGrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        MedianGradGpuKernelMod, int64_t, int64_t, float)
+MS_REG_GPU_KERNEL_THREE(MedianGrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        MedianGradGpuKernelMod, float, int64_t, float)
+MS_REG_GPU_KERNEL_THREE(MedianGrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat64)
+                          .AddInputAttr(kNumberTypeFloat64)
+                          .AddInputAttr(kNumberTypeFloat64)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        MedianGradGpuKernelMod, double, int64_t, double)
+}  // namespace kernel
+}  // namespace mindspore
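
For quick reference, the five registrations above fix the (input, indices, gradient) dtype triples. A gradient is floating point even for integer inputs, because the global-median path divides `dy` by the number of elements tied at the median. Illustrative summary (the dict name is mine):

```python
MEDIAN_GRAD_DTYPES = {
    # x dtype:  (indices dtype, x_grad dtype)
    "int16":   ("int64", "float32"),
    "int32":   ("int64", "float32"),
    "int64":   ("int64", "float32"),
    "float32": ("int64", "float32"),
    "float64": ("int64", "float64"),
}
```
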
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/median_grad_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/median_grad_gpu_kernel.h
new file mode 100644
index 00000000000..c7f7ed18d51
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/median_grad_gpu_kernel.h
@@ -0,0 +1,244 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MEDIAN_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MEDIAN_GRAD_GPU_KERNEL_H_
+
+#include <map>
+#include <vector>
+#include "mindspore/core/ops/grad/median_grad.h"
+#include "plugin/device/gpu/kernel/gpu_kernel.h"
+#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_grad_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+constexpr size_t kMedianOutputsNum = 1;
+constexpr size_t kInputsNum4 = 4;
+constexpr size_t kInputsNum3 = 3;
+template <typename T, typename S, typename V>
+class MedianGradGpuKernelMod : public NativeGpuKernelMod {
+ public:
+  MedianGradGpuKernelMod() : global_median_(false), keep_dims_(false), axis_(0) {}
+  ~MedianGradGpuKernelMod() = default;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *y_grad = GetDeviceAddress<T>(inputs, kIndex0);
+    T *x = GetDeviceAddress<T>(inputs, kIndex1);
+    T *y = GetDeviceAddress<T>(inputs, kIndex2);
+    S *indices = nullptr;
+    V *output0_addr = GetDeviceAddress<V>(outputs, kIndex0);
+    if (!global_median_) {
+      indices = GetDeviceAddress<S>(inputs, kIndex3);
+    }
+
+    int *elem_num_each_dim_x = GetDeviceAddress<int>(workspace, kIndex0);
+    int *elem_num_each_dim_y = GetDeviceAddress<int>(workspace, kIndex1);
+    int *repeat_val = GetDeviceAddress<int>(workspace, kIndex2);
+
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+      cudaMemsetAsync(output0_addr, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "cudaMemsetAsync output failed");
+
+    if (!global_median_) {
+      CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+        cudaMemcpyAsync(elem_num_each_dim_x, &elem_num_each_dim_x_[0], sizeof(int) * input1_dim_,
+                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+        "cudaMemcpyAsync elem_num_each_dim_x failed");
+      CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+        cudaMemcpyAsync(elem_num_each_dim_y, &elem_num_each_dim_y_[0], sizeof(int) * input1_dim_,
+                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+        "cudaMemcpyAsync elem_num_each_dim_y failed");
+    }
+
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+      cudaMemsetAsync(repeat_val, 0, sizeof(int), reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "cudaMemsetAsync repeat_val failed");
+
+    MedianGrad(y_grad, x, y, indices, output0_addr, axis_, global_median_, input0_size_, input1_size_, input1_dim_,
+               elem_num_each_dim_x, elem_num_each_dim_y, repeat_val, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override {
+    kernel_name_ = base_operator->name();
+    auto kernel_ptr = std::dynamic_pointer_cast<ops::MedianGrad>(base_operator);
+    if (kernel_ptr == nullptr) {
+      MS_LOG(ERROR) << "For '" << kernel_name_ << "', cast to MedianGrad ops failed!";
+      return false;
+    }
+    if (((inputs.size() != kInputsNum4) && (inputs.size() != kInputsNum3)) || outputs.size() != kMedianOutputsNum) {
+      MS_LOG(ERROR) << "For '" << kernel_name_ << "', input size should be 3 or 4, but got " << inputs.size()
+                    << "; output size should be 1, but got " << outputs.size();
+      return false;
+    }
+    global_median_ = kernel_ptr->get_global_median();
+    keep_dims_ = kernel_ptr->get_keep_dims();
+    axis_ = kernel_ptr->get_axis();
+    input_shape_ = inputs[1]->GetShapeVector();
+    input1_dim_ = input_shape_.size();
+    std::vector<int64_t> input0_shape = inputs[0]->GetShapeVector();
+
+    if (axis_ < -input1_dim_ || axis_ >= input1_dim_) {
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << input1_dim_ << ","
+                        << input1_dim_ << "), but got " << axis_;
+    }
+    if (axis_ < 0) {
+      axis_ += input1_dim_;
+    }
+    input1_size_ = 1;
+    input0_size_ = 1;
+    for (size_t i = 0; i < input_shape_.size(); i++) {
+      input1_size_ *= input_shape_[i];
+    }
+    for (size_t i = 0; i < input0_shape.size(); i++) {
+      input0_size_ *= input0_shape[i];
+    }
+    if (global_median_) {
+      input_shape_.clear();
+      input_shape_.push_back(input1_size_);
+    } else {
+      std::vector<int64_t> shape_keepdim;
+      for (int64_t i = 0; i < input1_dim_; i++) {
+        if (i == axis_) {
+          shape_keepdim.push_back(1);
+        } else {
+          shape_keepdim.push_back(input_shape_[i]);
+        }
+      }
+      int elem_num_x = 1;
+      int elem_num_y = 1;
+      for (size_t i = 0; i < shape_keepdim.size(); i++) {
+        elem_num_each_dim_x_.insert(elem_num_each_dim_x_.begin(), elem_num_x);
+        elem_num_x *= input_shape_[shape_keepdim.size() - 1 - i];
+        elem_num_each_dim_y_.insert(elem_num_each_dim_y_.begin(), elem_num_y);
+        elem_num_y *= shape_keepdim[shape_keepdim.size() - 1 - i];
+      }
+    }
+    ResetResource();
+    InitWorkSpaceSizeList();
+    return true;
+  }
+
+  int Resize(
+    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+    const std::vector<KernelTensorPtr> &outputs,
+    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
+    int ret = KernelMod::Resize(base_operator, inputs, outputs);
+    if (ret != 0) {
+      return ret;
+    }
+    input_shape_ = inputs[1]->GetShapeVector();
+    std::vector<int64_t> input0_shape = inputs[0]->GetShapeVector();
+    input1_dim_ = input_shape_.size();
+    input1_size_ = 1;
+    input0_size_ = 1;
+    for (size_t i = 0; i < input_shape_.size(); i++) {
+      input1_size_ *= input_shape_[i];
+    }
+    for (size_t i = 0; i < input0_shape.size(); i++) {
+      input0_size_ *= input0_shape[i];
+    }
+    if (global_median_) {
+      input_shape_.clear();
+      input_shape_.push_back(input1_size_);
+    } else {
+      std::vector<int64_t> shape_keepdim;
+      for (int64_t i = 0; i < input1_dim_; i++) {
+        if (i == axis_) {
+          shape_keepdim.push_back(1);
+        } else {
+          shape_keepdim.push_back(input_shape_[i]);
+        }
+      }
+      int elem_num_x = 1;
+      int elem_num_y = 1;
+      elem_num_each_dim_x_.clear();
+      elem_num_each_dim_y_.clear();
+      for (size_t i = 0; i < shape_keepdim.size(); i++) {
+        elem_num_each_dim_x_.insert(elem_num_each_dim_x_.begin(), elem_num_x);
+        elem_num_x *= input_shape_[shape_keepdim.size() - 1 - i];
+        elem_num_each_dim_y_.insert(elem_num_each_dim_y_.begin(), elem_num_y);
+        elem_num_y *= shape_keepdim[shape_keepdim.size() - 1 - i];
+      }
+    }
+    InitWorkSpaceSizeList();
+    return KRET_OK;
+  }
+
+  std::vector<KernelAttr> GetOpSupport() override {
+    static std::vector<KernelAttr> support_list = {KernelAttr()
+                                                     .AddInputAttr(kNumberTypeInt16)
+                                                     .AddInputAttr(kNumberTypeInt16)
+                                                     .AddInputAttr(kNumberTypeInt16)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeInt32)
+                                                     .AddInputAttr(kNumberTypeInt32)
+                                                     .AddInputAttr(kNumberTypeInt32)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddOutputAttr(kNumberTypeFloat64)};
+    return support_list;
+  }
+
+ protected:
+  void ResetResource() noexcept {
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ private:
+  void InitWorkSpaceSizeList() {
+    workspace_size_list_.push_back(input1_dim_ * sizeof(int));
+    workspace_size_list_.push_back(input1_dim_ * sizeof(int));
+    workspace_size_list_.push_back(sizeof(int));
+  }
+
+  bool global_median_;
+  bool keep_dims_;
+  int64_t axis_;
+  int64_t input1_dim_;
+  int64_t input0_size_;
+  int64_t input1_size_;
+  std::vector<int64_t> input_shape_;
+  std::vector<int> elem_num_each_dim_x_;
+  std::vector<int> elem_num_each_dim_y_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MEDIAN_GRAD_GPU_KERNEL_H_
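
Init() and Resize() above precompute per-dimension element counts (row-major strides) for x and for y viewed with keep_dims=True; the kernel uses them to translate a flat position in dy into a write position in x_grad. A small host-side sketch of the same bookkeeping (the helper name is mine, not part of the patch):

```python
def elem_num_each_dim(shape):
    """Element stride of each dimension, e.g. (2, 3, 4) -> [12, 4, 1]."""
    strides = []
    acc = 1
    for d in reversed(shape):
        strides.insert(0, acc)  # mirrors the insert-at-front loop in Init()
        acc *= d
    return strides

x_shape = (2, 3, 4)
axis = 1
keepdim_shape = tuple(1 if i == axis else d for i, d in enumerate(x_shape))
assert elem_num_each_dim(x_shape) == [12, 4, 1]       # elem_num_each_dim_x_
assert elem_num_each_dim(keepdim_shape) == [4, 4, 1]  # elem_num_each_dim_y_
```
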
diff --git a/mindspore/core/ops/grad/median_grad.cc b/mindspore/core/ops/grad/median_grad.cc
index 8c834570c6f..bca36900102 100644
--- a/mindspore/core/ops/grad/median_grad.cc
+++ b/mindspore/core/ops/grad/median_grad.cc
@@ -27,6 +27,37 @@
 namespace mindspore {
 namespace ops {
+void MedianGrad::Init(const bool global_median, const int64_t axis, const bool keep_dims) {
+  this->set_global_median(global_median);
+  this->set_axis(axis);
+  this->set_keep_dims(keep_dims);
+}
+
+void MedianGrad::set_global_median(const bool global_median) {
+  (void)this->AddAttr(kGlobalMedian, api::MakeValue(global_median));
+}
+
+void MedianGrad::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
+
+void MedianGrad::set_axis(const int64_t &axis) {
+  int64_t f = axis;
+  (void)this->AddAttr(kAxis, api::MakeValue(f));
+}
+
+bool MedianGrad::get_global_median() const {
+  auto value_ptr = GetAttr(kGlobalMedian);
+  return GetValue<bool>(value_ptr);
+}
+
+bool MedianGrad::get_keep_dims() const {
+  auto value_ptr = GetAttr(kKeepDims);
+  return GetValue<bool>(value_ptr);
+}
+
+int64_t MedianGrad::get_axis() const {
+  auto value_ptr = GetAttr(kAxis);
+  return GetValue<int64_t>(value_ptr);
+}
 namespace {
 abstract::ShapePtr MedianGradInferShape(const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
diff --git a/mindspore/core/ops/grad/median_grad.h b/mindspore/core/ops/grad/median_grad.h
index 1e6a116406e..4b002a7200a 100644
--- a/mindspore/core/ops/grad/median_grad.h
+++ b/mindspore/core/ops/grad/median_grad.h
@@ -28,11 +28,31 @@
 namespace mindspore {
 namespace ops {
 constexpr auto kNameMedianGrad = "MedianGrad";
-class MedianGrad : public BaseOperator {
+class MIND_API MedianGrad : public BaseOperator {
  public:
   MIND_API_BASE_MEMBER(MedianGrad);
   /// \brief Constructor.
   MedianGrad() : BaseOperator(kNameMedianGrad) { InitIOName({"y_grad", "x", "y", "indices"}, {"x_grad"}); }
+  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Median for the inputs.
+  void Init(const bool global_median = false, const int64_t axis = 0, const bool keep_dims = false);
+  /// \brief Set global_median.
+  void set_global_median(const bool global_median);
+  /// \brief Set keep_dims.
+  void set_keep_dims(const bool keep_dims);
+  /// \brief Set axis.
+  void set_axis(const int64_t &axis);
+  /// \brief Get global_median.
+  ///
+  /// \return global_median.
+  bool get_global_median() const;
+  /// \brief Get keep_dims.
+  ///
+  /// \return keep_dims.
+  bool get_keep_dims() const;
+  /// \brief Get axis.
+  ///
+  /// \return axis.
+  int64_t get_axis() const;
 };
 
 abstract::AbstractBasePtr MedianGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py
index 9a8da8ca26f..124449aa850 100644
--- a/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py
+++ b/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py
@@ -485,7 +485,7 @@ def get_median_vmap_rule(prim, axis_size):
             axis += rank - 1
         axis_new = axis + 1 if dim <= axis else axis
         if keep_dims:
-            dim_new = axis_new
+            dim_new = dim
         else:
             dim_new = dim - 1 if dim > axis_new else dim
         return axis_new, dim_new
@@ -556,6 +556,36 @@ def get_index_add_vmap_rule(prim, axis_size):
     return vmap_rule
 
 
+@vmap_rules_getters.register(G.MedianGrad)
+def get_median_grad_vmap_rule(prim, axis_size):
+    """VmapRule for MedianGrad."""
+    global_median = prim.global_median
+    axis = prim.axis
+    keep_dims = prim.keep_dims
+
+    @constexpr
+    def trans_grad_axis(axis, rank, dim, keep_dims):
+        if axis < 0:
+            axis += rank - 1
+        axis_new = axis + 1 if dim <= axis else axis
+        if keep_dims:
+            dim_new = dim
+        else:
+            dim_new = dim - 1 if dim > axis_new else dim
+        return axis_new, dim_new
+
+    def vmap_rule(dy_bdim, x_bdim, y_bdim, indices_bdim):
+        dy, _ = dy_bdim
+        x, x_dim = x_bdim
+        y, _ = y_bdim
+        indices, _ = indices_bdim
+        rank = len(x.shape)
+        axis_new, dim_new = trans_grad_axis(axis, rank, x_dim, keep_dims)
+        x_grad = G.MedianGrad(global_median, axis_new, keep_dims)(dy, x, y, indices)
+        return (x_grad, dim_new)
+
+    return vmap_rule
+
+
 @vmap_rules_getters.register(linalg_ops.Svd)
 def get_svd_vmap_rule(prim, axis_size):
     """VmapRule for 'Svd' operation."""
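
The one-line fix in `get_median_vmap_rule` (and the matching logic in the new grad rule) follows from rank bookkeeping: a keep-dims reduction preserves rank, so the batch dimension of the output stays wherever it was on the input; the old `dim_new = axis_new` wrongly reported the reduced axis as the batch dimension. A standalone copy of the axis translation for experimentation (illustrative, not part of the patch):

```python
def trans_grad_axis(axis, rank, dim, keep_dims):
    if axis < 0:
        axis += rank - 1                            # normalize against logical rank
    axis_new = axis + 1 if dim <= axis else axis    # shift past the batch dim
    if keep_dims:
        dim_new = dim                               # rank preserved: batch dim unchanged
    else:
        dim_new = dim - 1 if dim > axis_new else dim
    return axis_new, dim_new

# A batched x of logical rank 2 (physical rank 3) with the batch dim in front:
# reducing logical axis 0 becomes physical axis 1, and with keep_dims=True the
# gradient's batch dim stays at 0.
assert trans_grad_axis(axis=0, rank=3, dim=0, keep_dims=True) == (1, 0)
```
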
diff --git a/tests/st/ops/graph_kernel/test_median.py b/tests/st/ops/graph_kernel/test_median.py
new file mode 100644
index 00000000000..494605d33b9
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_median.py
@@ -0,0 +1,136 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import torch
+from mindspore import Tensor
+from mindspore.nn import Cell
+from mindspore.ops.operations.math_ops import Median
+import mindspore.ops.operations._grad_ops as G
+from mindspore.ops.composite import GradOperation
+
+
+class Grad(Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(get_all=False, sens_param=False)
+        self.network = network
+
+    def construct(self, input_x):
+        gout = self.grad(self.network)(input_x)
+        return gout
+
+
+class MedianC(Cell):
+    def __init__(self, global_median, axis, keep_dims):
+        super().__init__()
+        self.global_median = global_median
+        self.axis = axis
+        self.keep_dims = keep_dims
+        self.median = Median(self.global_median, self.axis, self.keep_dims)
+
+    def construct(self, x):
+        return self.median(x)
+
+
+class MedianGrad(Cell):
+    def __init__(self, global_median, axis, keep_dims):
+        super().__init__()
+        self.global_median = global_median
+        self.axis = axis
+        self.keep_dims = keep_dims
+        self.median_grad = G.MedianGrad(self.global_median, self.axis, self.keep_dims)
+
+    def construct(self, dy, x, y, indices):
+        return self.median_grad(dy, x, y, indices)
+
+
+class MedianFactory:
+    def __init__(self, input_shape, global_median, axis=0, keep_dims=False, dtype=np.float32):
+        self.dtype = dtype
+        self.input = np.random.randn(*input_shape).astype(self.dtype)
+        self.global_median = global_median
+        self.axis = axis
+        self.keep_dims = keep_dims
+        self.output_grad_np = np.random.randn(*input_shape).astype(dtype=dtype)
+
+    def forward_mindspore_impl(self):
+        net = MedianC(self.global_median, self.axis, self.keep_dims)
+        y, indices = net(Tensor(self.input))
+        return y.asnumpy(), indices.asnumpy()
+
+    def grad_mindspore_impl(self):
+        input_x = Tensor(self.input)
+        net = MedianC(self.global_median, self.axis, self.keep_dims)
+        grad_net = Grad(net)
+        res = grad_net(input_x)
+        return res.asnumpy()
+
+    def forward_pytorch_impl(self):
+        input_pt = torch.from_numpy(self.input)
+        indices = None
+        if self.global_median is False:
+            y, indices = torch.median(input_pt, dim=self.axis, keepdim=self.keep_dims)
+        else:
+            y = torch.median(input_pt)
+        indices_np = None if indices is None else indices.numpy().astype(np.int64)
+        return y.numpy().astype(self.dtype), indices_np
+
+    def global_grad_pytorch_impl(self):
+        input_pt = torch.from_numpy(self.input)
+        input_pt.requires_grad = True
+        y = torch.median(input_pt)
+        y.backward()
+        return input_pt.grad.numpy()
+
+    def grad_pytorch_impl(self):
+        input_pt = torch.from_numpy(self.input)
+        input_pt.requires_grad = True
+        y, _ = torch.median(input_pt, dim=self.axis, keepdim=self.keep_dims)
+        y.sum().backward()
+        return input_pt.grad.numpy()
+
+    def forward_cmp(self):
+        y_pytorch, _ = self.forward_pytorch_impl()
+        y_mindspore, _ = self.forward_mindspore_impl()
+        assert np.allclose(y_pytorch, y_mindspore)
+
+    def grad_cmp(self):
+        grad_ms = self.grad_mindspore_impl()
+        if self.global_median is False:
+            grad_torch = self.grad_pytorch_impl()
+        else:
+            grad_torch = self.global_grad_pytorch_impl()
+        assert np.allclose(grad_ms, grad_torch)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_median_gpu():
+    """
+    Feature: Test median.
+    Description: Test Median and MedianGrad on GPU with different global_median settings.
+    Expectation: the results match the PyTorch references.
+    """
+    fact = MedianFactory(input_shape=(5, 5), global_median=True, axis=0, keep_dims=True)
+    fact.forward_cmp()
+    fact.grad_cmp()
+    fact2 = MedianFactory(input_shape=(5, 5, 5), global_median=False, axis=1, keep_dims=False)
+    fact2.forward_cmp()
+    fact2.grad_cmp()