forked from mindspore-Ecosystem/mindspore
!39701 mediangrad算子GPU支持
Merge pull request !39701 from 黄晓/mediangrad_0804
This commit is contained in:
commit
60587fcb49
|
@ -0,0 +1,96 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "median_grad_impl.cuh"
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
|
||||
|
||||
template <typename T>
|
||||
__global__ void Count_Repeat(const T *x, const T *y, int64_t size, int *repeat_val) {
|
||||
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
if (x[pos] == *y) {
|
||||
MsAtomicAdd(repeat_val, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename V>
|
||||
__global__ void GlobalMedianGradComputer(const T *y_grad, const T *x, const T *y, V *output, int *repeat_val,
|
||||
const int64_t size) {
|
||||
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
if (x[pos] == *y) {
|
||||
output[pos] = *y_grad / *repeat_val;
|
||||
} else {
|
||||
output[pos] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S, typename V>
|
||||
__global__ void MedianGradComputer(const T *y_grad, const S *indices, const T *y, V *output, int *elem_num_each_dim_x,
|
||||
int *elem_num_each_dim_y, int64_t axis, int64_t input_dim, int64_t size) {
|
||||
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
int elements_remain = pos;
|
||||
int temp = 0;
|
||||
int update_pos = 0;
|
||||
for (int i = 0; i < input_dim; i++) {
|
||||
temp = elements_remain / elem_num_each_dim_y[i];
|
||||
elements_remain %= elem_num_each_dim_y[i];
|
||||
if (i == axis) {
|
||||
update_pos += *(indices + pos) * elem_num_each_dim_x[i];
|
||||
} else {
|
||||
update_pos += temp * elem_num_each_dim_x[i];
|
||||
}
|
||||
}
|
||||
*(output + update_pos) = *(y_grad + pos);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S, typename V>
|
||||
void MedianGrad(const T *y_grad, const T *x, const T *y, const S *indices, V *output, const int64_t axis,
|
||||
bool global_median, const int64_t input0_size, const int64_t input1_size, int64_t input_dim,
|
||||
int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream) {
|
||||
if (global_median) {
|
||||
Count_Repeat<T><<<GET_BLOCKS(input1_size), GET_THREADS, 0, cuda_stream>>>(x, y, input1_size, repeat_val);
|
||||
GlobalMedianGradComputer<T, V>
|
||||
<<<GET_BLOCKS(input1_size), GET_THREADS, 0, cuda_stream>>>(y_grad, x, y, output, repeat_val, input1_size);
|
||||
} else {
|
||||
MedianGradComputer<T, S, V><<<GET_BLOCKS(input0_size), GET_THREADS, 0, cuda_stream>>>(
|
||||
y_grad, indices, y, output, elem_num_each_dim_x, elem_num_each_dim_y, axis, input_dim, input0_size);
|
||||
}
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void MedianGrad<int16_t, int64_t, float>(
|
||||
const int16_t *input0_value, const int16_t *input1_value, const int16_t *input2_value, const int64_t *input3_value,
|
||||
float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
|
||||
int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void MedianGrad<int32_t, int64_t, float>(
|
||||
const int32_t *input0_value, const int32_t *input1_value, const int32_t *input2_value, const int64_t *input3_value,
|
||||
float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
|
||||
int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void MedianGrad<int64_t, int64_t, float>(
|
||||
const int64_t *input0_value, const int64_t *input1_value, const int64_t *input2_value, const int64_t *input3_value,
|
||||
float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
|
||||
int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void MedianGrad<float, int64_t, float>(
|
||||
const float *input0_value, const float *input1_value, const float *input2_value, const int64_t *input3_value,
|
||||
float *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
|
||||
int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void MedianGrad<double, int64_t, double>(
|
||||
const double *input0_value, const double *input1_value, const double *input2_value, const int64_t *input3_value,
|
||||
double *output, const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
|
||||
int64_t input_dim, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_GRAD_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_GRAD_IMPL_CUH_
|
||||
#include <vector>
|
||||
#include "include/cuda_fp16.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||
|
||||
template <typename T, typename S, typename V>
|
||||
void MedianGrad(const T *input0_value, const T *input1_value, const T *input2_value, const S *input3_value, V *output,
|
||||
const int64_t axis, bool global_median, const int64_t input0_size, const int64_t input1_size,
|
||||
const int64_t input1_dim_, int *elem_num_each_dim_x, int *elem_num_each_dim_y, int *repeat_val,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_GRAD_IMPL_CUH_
|
|
@ -35,7 +35,7 @@ __device__ __forceinline__ unsigned int warp_ballot(int predicate) {
|
|||
|
||||
template <typename T>
|
||||
static __device__ __host__ T round_up(T a, T b) {
|
||||
return (a / b) * b;
|
||||
return ((a + b - 1) / b) * b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/math/median_grad_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_THREE(MedianGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MedianGradGpuKernelMod, int16_t, int64_t, float)
|
||||
MS_REG_GPU_KERNEL_THREE(MedianGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MedianGradGpuKernelMod, int32_t, int64_t, float)
|
||||
MS_REG_GPU_KERNEL_THREE(MedianGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MedianGradGpuKernelMod, int64_t, int64_t, float)
|
||||
MS_REG_GPU_KERNEL_THREE(MedianGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MedianGradGpuKernelMod, float, int64_t, float)
|
||||
MS_REG_GPU_KERNEL_THREE(MedianGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
MedianGradGpuKernelMod, double, int64_t, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,244 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MEDIAN_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MEDIAN_GRAD_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "mindspore/core/ops/grad/median_grad.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_grad_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t kMedianOutputsNum = 1;
|
||||
constexpr size_t kInputsNum4 = 4;
|
||||
constexpr size_t kInputsNum3 = 3;
|
||||
template <typename T, typename S, typename V>
|
||||
class MedianGradGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
MedianGradGpuKernelMod() : global_median_(false), keep_dims_(false), axis_(0) {}
|
||||
~MedianGradGpuKernelMod() = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *y_grad = GetDeviceAddress<T>(inputs, kIndex0);
|
||||
T *x = GetDeviceAddress<T>(inputs, kIndex1);
|
||||
T *y = GetDeviceAddress<T>(inputs, kIndex2);
|
||||
S *indices = nullptr;
|
||||
V *output0_addr = GetDeviceAddress<V>(outputs, kIndex0);
|
||||
if (!global_median_) {
|
||||
indices = GetDeviceAddress<S>(inputs, kIndex3);
|
||||
}
|
||||
|
||||
int *elem_num_each_dim_x = GetDeviceAddress<int>(workspace, kIndex0);
|
||||
int *elem_num_each_dim_y = GetDeviceAddress<int>(workspace, kIndex1);
|
||||
int *repeat_val = GetDeviceAddress<int>(workspace, kIndex2);
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemsetAsync(output0_addr, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemSet Failed");
|
||||
|
||||
if (!global_median_) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(elem_num_each_dim_x, &elem_num_each_dim_x_[0], sizeof(int) * input1_dim_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync elem_num_each_dim_x failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(elem_num_each_dim_y, &elem_num_each_dim_y_[0], sizeof(int) * input1_dim_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync elem_num_each_dim_y failed");
|
||||
}
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemsetAsync(repeat_val, 0, sizeof(int), reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemset failed in repeat_val.");
|
||||
|
||||
MedianGrad(y_grad, x, y, indices, output0_addr, axis_, global_median_, input0_size_, input1_size_, input1_dim_,
|
||||
elem_num_each_dim_x, elem_num_each_dim_y, repeat_val, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override {
|
||||
kernel_name_ = base_operator->name();
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::MedianGrad>(base_operator);
|
||||
if (kernel_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' cast Median ops failed!";
|
||||
return false;
|
||||
}
|
||||
if (((inputs.size() != kInputsNum4) && (inputs.size() != kInputsNum3)) || outputs.size() > kMedianOutputsNum) {
|
||||
MS_LOG(ERROR) << kernel_name_ << ": input size should be 4 or 3"
|
||||
<< "but get " << inputs.size() << " and output size should be 1, but get " << outputs.size();
|
||||
return false;
|
||||
}
|
||||
global_median_ = kernel_ptr->get_global_median();
|
||||
keep_dims_ = kernel_ptr->get_keep_dims();
|
||||
axis_ = kernel_ptr->get_axis();
|
||||
input_shape_ = inputs[1]->GetShapeVector();
|
||||
input1_dim_ = input_shape_.size();
|
||||
std::vector<int64_t> input0_shape = inputs[0]->GetShapeVector();
|
||||
|
||||
if (axis_ < -input1_dim_ || axis_ >= input1_dim_) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << input1_dim_ << ","
|
||||
<< input1_dim_ << "), but got " << axis_;
|
||||
}
|
||||
if (axis_ < 0) {
|
||||
axis_ += input1_dim_;
|
||||
}
|
||||
input1_size_ = 1;
|
||||
input0_size_ = 1;
|
||||
for (size_t i = 0; i < input_shape_.size(); i++) {
|
||||
input1_size_ *= input_shape_[i];
|
||||
}
|
||||
for (size_t i = 0; i < input0_shape.size(); i++) {
|
||||
input0_size_ *= input0_shape[i];
|
||||
}
|
||||
if (global_median_) {
|
||||
input_shape_.clear();
|
||||
input_shape_.push_back(input1_size_);
|
||||
} else {
|
||||
std::vector<int64_t> shape_keepdim;
|
||||
for (int64_t i = 0; i < input1_dim_; i++) {
|
||||
if (i == axis_) {
|
||||
shape_keepdim.push_back(1);
|
||||
} else {
|
||||
shape_keepdim.push_back(input_shape_[i]);
|
||||
}
|
||||
}
|
||||
int elem_num_x = 1;
|
||||
int elem_num_y = 1;
|
||||
for (size_t i = 0; i < shape_keepdim.size(); i++) {
|
||||
elem_num_each_dim_x_.insert(elem_num_each_dim_x_.begin(), elem_num_x);
|
||||
elem_num_x *= input_shape_[shape_keepdim.size() - 1 - i];
|
||||
elem_num_each_dim_y_.insert(elem_num_each_dim_y_.begin(), elem_num_y);
|
||||
elem_num_y *= shape_keepdim[shape_keepdim.size() - 1 - i];
|
||||
}
|
||||
}
|
||||
ResetResource();
|
||||
InitWorkSpaceSizeList();
|
||||
return true;
|
||||
}
|
||||
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
|
||||
int ret = KernelMod::Resize(base_operator, inputs, outputs);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
input_shape_ = inputs[1]->GetShapeVector();
|
||||
std::vector<int64_t> input0_shape = inputs[0]->GetShapeVector();
|
||||
input1_dim_ = input_shape_.size();
|
||||
input1_size_ = 1;
|
||||
input0_size_ = 1;
|
||||
for (size_t i = 0; i < input_shape_.size(); i++) {
|
||||
input1_size_ *= input_shape_[i];
|
||||
}
|
||||
for (size_t i = 0; i < input0_shape.size(); i++) {
|
||||
input0_size_ *= input0_shape[i];
|
||||
}
|
||||
if (global_median_) {
|
||||
input_shape_.clear();
|
||||
input_shape_.push_back(input1_size_);
|
||||
} else {
|
||||
std::vector<int64_t> shape_keepdim;
|
||||
for (int64_t i = 0; i < input1_dim_; i++) {
|
||||
if (i == axis_) {
|
||||
shape_keepdim.push_back(1);
|
||||
} else {
|
||||
shape_keepdim.push_back(input_shape_[i]);
|
||||
}
|
||||
}
|
||||
int elem_num_x = 1;
|
||||
int elem_num_y = 1;
|
||||
elem_num_each_dim_x_.clear();
|
||||
elem_num_each_dim_y_.clear();
|
||||
for (size_t i = 0; i < shape_keepdim.size(); i++) {
|
||||
elem_num_each_dim_x_.insert(elem_num_each_dim_x_.begin(), elem_num_x);
|
||||
elem_num_x *= input_shape_[shape_keepdim.size() - 1 - i];
|
||||
elem_num_each_dim_y_.insert(elem_num_each_dim_y_.begin(), elem_num_y);
|
||||
elem_num_y *= shape_keepdim[shape_keepdim.size() - 1 - i];
|
||||
}
|
||||
}
|
||||
InitWorkSpaceSizeList();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
protected:
|
||||
void ResetResource() noexcept {
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
void InitWorkSpaceSizeList() {
|
||||
workspace_size_list_.push_back(input1_dim_ * sizeof(int));
|
||||
workspace_size_list_.push_back(input1_dim_ * sizeof(int));
|
||||
workspace_size_list_.push_back(sizeof(int));
|
||||
}
|
||||
|
||||
bool global_median_;
|
||||
bool keep_dims_;
|
||||
int64_t axis_;
|
||||
int64_t input1_dim_;
|
||||
int64_t input0_size_;
|
||||
int64_t input1_size_;
|
||||
std::vector<int64_t> input_shape_;
|
||||
std::vector<int> elem_num_each_dim_x_;
|
||||
std::vector<int> elem_num_each_dim_y_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MEDIAN_GRAD_GPU_KERNEL_H_
|
|
@ -27,6 +27,37 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void MedianGrad::Init(const bool global_median, const int64_t axis, const bool keep_dims) {
|
||||
this->set_global_median(global_median);
|
||||
this->set_axis(axis);
|
||||
this->set_keep_dims(keep_dims);
|
||||
}
|
||||
|
||||
void MedianGrad::set_global_median(const bool global_median) {
|
||||
(void)this->AddAttr(kGlobalMedian, api::MakeValue(global_median));
|
||||
}
|
||||
|
||||
void MedianGrad::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
|
||||
|
||||
void MedianGrad::set_axis(const int64_t &axis) {
|
||||
int64_t f = axis;
|
||||
(void)this->AddAttr(kAxis, api::MakeValue(f));
|
||||
}
|
||||
|
||||
bool MedianGrad::get_global_median() const {
|
||||
auto value_ptr = GetAttr(kGlobalMedian);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
bool MedianGrad::get_keep_dims() const {
|
||||
auto value_ptr = GetAttr(kKeepDims);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t MedianGrad::get_axis() const {
|
||||
auto value_ptr = GetAttr(kAxis);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
namespace {
|
||||
abstract::ShapePtr MedianGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -28,11 +28,31 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameMedianGrad = "MedianGrad";
|
||||
class MedianGrad : public BaseOperator {
|
||||
class MIND_API MedianGrad : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(MedianGrad);
|
||||
/// \brief Constructor.
|
||||
MedianGrad() : BaseOperator(kNameMedianGrad) { InitIOName({"y_grad", "x", "y", "indices"}, {"x_grad"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Median for the inputs.
|
||||
void Init(const bool global_median = false, const int64_t axis = 0, const bool keep_dims = false);
|
||||
/// \brief Set global_median.
|
||||
void set_global_median(const bool global_median);
|
||||
/// \brief Set keep_dims.
|
||||
void set_keep_dims(const bool keep_dims);
|
||||
/// \brief Set axis.
|
||||
void set_axis(const int64_t &axis);
|
||||
/// \brief Get global_median.
|
||||
///
|
||||
/// \return global_median.
|
||||
bool get_global_median() const;
|
||||
/// \brief Get keep_dims.
|
||||
///
|
||||
/// \return keep_dims.
|
||||
bool get_keep_dims() const;
|
||||
/// \brief Get axis.
|
||||
///
|
||||
/// \return axis.
|
||||
int64_t get_axis() const;
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr MedianGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -485,7 +485,7 @@ def get_median_vmap_rule(prim, axis_size):
|
|||
axis += rank - 1
|
||||
axis_new = axis + 1 if dim <= axis else axis
|
||||
if keep_dims:
|
||||
dim_new = axis_new
|
||||
dim_new = dim
|
||||
else:
|
||||
dim_new = dim - 1 if dim > axis_new else dim
|
||||
return axis_new, dim_new
|
||||
|
@ -556,6 +556,36 @@ def get_index_add_vmap_rule(prim, axis_size):
|
|||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register(G.MedianGrad)
|
||||
def get_median_grad_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for MedianGrad."""
|
||||
global_median = prim.global_median
|
||||
axis = prim.axis
|
||||
keep_dims = prim.keep_dims
|
||||
|
||||
@constexpr
|
||||
def trans_grad_axis(axis, rank, dim, keep_dims):
|
||||
if axis < 0:
|
||||
axis += rank - 1
|
||||
axis_new = axis + 1 if dim <= axis else axis
|
||||
if keep_dims:
|
||||
dim_new = dim
|
||||
else:
|
||||
dim_new = dim - 1 if dim > axis_new else dim
|
||||
return axis_new, dim_new
|
||||
|
||||
def vmap_rule(dy_bdim, x_bdim, y_bdim, indices_bdim):
|
||||
dy, _ = dy_bdim
|
||||
x, x_dim = x_bdim
|
||||
y, _ = y_bdim
|
||||
indices, _ = indices_bdim
|
||||
rank = len(x.shape)
|
||||
axis_new, dim_new = trans_grad_axis(axis, rank, x_dim, keep_dims)
|
||||
x_grad = G.MedianGrad(global_median, axis_new, keep_dims)(dy, x, y, indices)
|
||||
return (x_grad, dim_new)
|
||||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register(linalg_ops.Svd)
|
||||
def get_svd_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for 'Svd' operation."""
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops.operations.math_ops import Median
|
||||
import mindspore.ops.operations._grad_ops as G
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
|
||||
class Grad(Cell):
|
||||
def __init__(self, network):
|
||||
super(Grad, self).__init__()
|
||||
self.grad = GradOperation(get_all=False, sens_param=False)
|
||||
self.network = network
|
||||
|
||||
def construct(self, input_x):
|
||||
gout = self.grad(self.network)(input_x)
|
||||
return gout
|
||||
|
||||
|
||||
class MedianC(Cell):
|
||||
def __init__(self, global_median, axis, keep_dims):
|
||||
super().__init__()
|
||||
self.global_median = global_median
|
||||
self.axis = axis
|
||||
self.keep_dims = keep_dims
|
||||
self.median = Median(self.global_median, self.axis, self.keep_dims)
|
||||
|
||||
def construct(self, x):
|
||||
return self.median(x)
|
||||
|
||||
|
||||
class MedianGrad(Cell):
|
||||
def __init__(self, global_median, axis, keep_dims):
|
||||
super().__init__()
|
||||
self.global_median = global_median
|
||||
self.axis = axis
|
||||
self.keep_dims = keep_dims
|
||||
self.median_grad = G.MedianGrad(self.global_median, self.axis, self.keep_dims)
|
||||
|
||||
def construct(self, dy, x, y, indices):
|
||||
return self.median_grad(dy, x, y, indices)
|
||||
|
||||
|
||||
class MedianFactory():
|
||||
def __init__(self, input_shape, global_median, axis=0, keep_dims=False, dtype=np.float32):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.input = np.random.randn(*input_shape).astype(self.dtype)
|
||||
self.global_median = global_median
|
||||
self.axis = axis
|
||||
self.keep_dims = keep_dims
|
||||
self.output_grad_np = np.random.randn(*input_shape).astype(dtype=dtype)
|
||||
|
||||
def forward_mindspore_impl(self):
|
||||
net = MedianC(self.global_median, self.axis, self.keep_dims)
|
||||
y, indices = net(Tensor(self.input))
|
||||
return y.asnumpy(), indices.asnumpy()
|
||||
|
||||
def grad_mindspore_impl(self):
|
||||
input_x = Tensor(self.input)
|
||||
net = MedianC(self.global_median, self.axis, self.keep_dims)
|
||||
grad_net = Grad(net)
|
||||
res = grad_net(input_x)
|
||||
return res.asnumpy()
|
||||
|
||||
def forward_pytorch_impl(self):
|
||||
input_pt = torch.from_numpy(self.input)
|
||||
indices = None
|
||||
if self.global_median is False:
|
||||
y, indices = torch.median(input_pt, axis=self.axis, keepdim=self.keep_dims)
|
||||
else:
|
||||
y = torch.median(input_pt)
|
||||
indices_np = None if indices is None else indices.numpy().astype(np.int64)
|
||||
return y.numpy().astype(self.dtype), indices_np
|
||||
|
||||
def global_grad_pytorch_impl(self):
|
||||
input_pt = torch.from_numpy(self.input)
|
||||
input_pt.requires_grad = True
|
||||
y = torch.median(input_pt)
|
||||
y.backward()
|
||||
return input_pt.grad.numpy()
|
||||
|
||||
def grad_pytorch_impl(self):
|
||||
input_pt = torch.from_numpy(self.input)
|
||||
input_pt.requires_grad = True
|
||||
y, _ = torch.median(input_pt, axis=self.axis, keepdim=self.keep_dims)
|
||||
y.sum().backward()
|
||||
return input_pt.grad.numpy()
|
||||
|
||||
def forward_cmp(self):
|
||||
y_pytorch, _ = self.forward_pytorch_impl()
|
||||
y_mindspore, _ = self.forward_mindspore_impl()
|
||||
assert np.allclose(y_pytorch, y_mindspore)
|
||||
|
||||
def grad_cmp(self):
|
||||
grad_ms = self.grad_mindspore_impl()
|
||||
if self.global_median is False:
|
||||
grad_torch = self.grad_pytorch_impl()
|
||||
else:
|
||||
grad_torch = self.global_grad_pytorch_impl()
|
||||
assert np.allclose(grad_ms, grad_torch)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_median_gpu():
|
||||
"""
|
||||
Feature: Test median.
|
||||
Description: Test median and mediangrad in Gpu with different global_median parameter.
|
||||
Expectation: the result match given one.
|
||||
"""
|
||||
|
||||
fact = MedianFactory(input_shape=(5, 5), global_median=True, axis=0, keep_dims=True)
|
||||
fact.forward_cmp()
|
||||
fact.grad_cmp()
|
||||
fact2 = MedianFactory(input_shape=(5, 5, 5), global_median=False, axis=1, keep_dims=False)
|
||||
fact2.forward_cmp()
|
||||
fact2.grad_cmp()
|
Loading…
Reference in New Issue