[OP] Add MaxPoolGradGradWithArgmax CPU and GPU backends

yangruoqi713 2022-07-04 16:46:34 +08:00
parent 247a0d3d86
commit d4f97e8aaf
17 changed files with 790 additions and 24 deletions
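For context, a minimal usage sketch of the new primitive (modeled on the tests added below; the Net wrapper, shapes, and values are illustrative assumptions, not part of this commit). Both backends implement the same gather: each output element copies the incoming gradient at the flat argmax offset recorded by the forward max-pool.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")  # or "GPU"

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op = G.MaxPoolGradGradWithArgmax(pad_mode="VALID", kernel_size=2, strides=2)

    def construct(self, x, grad, argmax):
        return self.op(x, grad, argmax)

x = Tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))          # forward input (NCHW)
grad = Tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4) / 10)  # second-order gradient w.r.t. x
argmax = Tensor(np.array([[[[5, 7], [13, 15]]]], dtype=np.int32))        # flat per-batch offsets of the forward maxima
out = Net()(x, grad, argmax)  # gathers grad at argmax -> [[[[0.5, 0.7], [1.3, 1.5]]]]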

View File

@@ -0,0 +1,116 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/maxpool_grad_grad_with_argmax_cpu_kernel.h"
#include <algorithm>
#include <functional>
#include <unordered_map>
#include "utils/ms_utils.h"
#include "utils/profile.h"
#include "mindspore/ccsrc/kernel/common_utils.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMaxPoolGradGradWithArgmaxInputsNum = 3;
constexpr size_t kMaxPoolGradGradWithArgmaxOutputsNum = 1;
constexpr size_t kArgmaxIndex = 2;
} // namespace
bool MaxPoolGradGradWithArgmaxCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.size() != kMaxPoolGradGradWithArgmaxInputsNum || outputs.size() != kMaxPoolGradGradWithArgmaxOutputsNum) {
MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kMaxPoolGradGradWithArgmaxInputsNum
<< " and " << kMaxPoolGradGradWithArgmaxOutputsNum << ", but get " << inputs.size() << " and "
<< outputs.size();
return false;
}
return MatchKernelFunc(base_operator, inputs, outputs);
}
int MaxPoolGradGradWithArgmaxCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
if (ret != 0) {
return ret;
}
auto in_shapes = inputs[0]->GetShapeVector();
auto out_shapes = outputs[0]->GetShapeVector();
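// output_elements_ is the total size of the pooled output and is used as the work size in LaunchKernel;
// input_batch_stride_ (elements per batch of the forward input) and output_batch_stride_ (elements per
// batch of the pooled output) convert a flat output position into a batch index and an in-batch offset.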
output_elements_ = std::accumulate(out_shapes.begin(), out_shapes.end(), 1, std::multiplies<size_t>());
input_batch_stride_ = std::accumulate(in_shapes.begin() + 1, in_shapes.end(), 1, std::multiplies<size_t>());
output_batch_stride_ = std::accumulate(out_shapes.begin() + 1, out_shapes.end(), 1, std::multiplies<size_t>());
return KRET_OK;
}
template <typename T, typename I>
bool MaxPoolGradGradWithArgmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMaxPoolGradGradWithArgmaxInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMaxPoolGradGradWithArgmaxOutputsNum, kernel_name_);
T *grad = reinterpret_cast<T *>(inputs[1]->addr);
I *argmax = reinterpret_cast<I *>(inputs[kArgmaxIndex]->addr);
T *out = reinterpret_cast<T *>(outputs[0]->addr);
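// argmax stores, for each pooled output position, the flat offset of the forward maximum within its
// batch slice of the input, so the second-order gradient is grad gathered at those offsets.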
auto task = [this, grad, argmax, out](size_t start, size_t end) {
for (size_t pos = start; pos < end; pos++) {
const int pos_n = pos / this->output_batch_stride_;
out[pos] = grad[pos_n * this->input_batch_stride_ + argmax[pos]];
}
};
ParallelLaunchAutoSearch(task, output_elements_, this, &parallel_search_info_, pool_);
return true;
}
const std::vector<std::pair<KernelAttr, MaxPoolGradGradWithArgmaxCpuKernelMod::KernelRunFunc>>
&MaxPoolGradGradWithArgmaxCpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, MaxPoolGradGradWithArgmaxCpuKernelMod::KernelRunFunc>> func_list = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&MaxPoolGradGradWithArgmaxCpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&MaxPoolGradGradWithArgmaxCpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&MaxPoolGradGradWithArgmaxCpuKernelMod::LaunchKernel<float16, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&MaxPoolGradGradWithArgmaxCpuKernelMod::LaunchKernel<float16, int64_t>},
};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxPoolGradGradWithArgmax, MaxPoolGradGradWithArgmaxCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,66 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MAX_POOL_GRAD_GRAD_WITH_ARGMAX_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MAX_POOL_GRAD_GRAD_WITH_ARGMAX_CPU_KERNEL_H_
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <utility>
#include <unordered_map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
namespace mindspore {
namespace kernel {
class MaxPoolGradGradWithArgmaxCpuKernelMod : public NativeCpuKernelMod,
public MatchKernelHelper<MaxPoolGradGradWithArgmaxCpuKernelMod> {
public:
MaxPoolGradGradWithArgmaxCpuKernelMod() = default;
~MaxPoolGradGradWithArgmaxCpuKernelMod() = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T, typename I>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs);
size_t output_elements_ = 0;
size_t input_batch_stride_ = 0;
size_t output_batch_stride_ = 0;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MAX_POOL_GRAD_GRAD_WITH_ARGMAX_CPU_KERNEL_H_

View File

@@ -0,0 +1,61 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxpool_grad_grad_with_argmax_impl.cuh"
#include "include/cuda_fp16.h"
template <typename T, typename I>
__global__ void MaxPoolGradGradWithArgmax(const T *grad, const I *argmax, const int input_stride,
const int output_stride, const int output_elements, T *output) {
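// Grid-stride loop over the pooled output: each thread recovers its batch index and gathers the
// incoming gradient at the per-batch flat offset recorded in argmax.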
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (output_elements); pos += blockDim.x * gridDim.x) {
const int posn = pos / output_stride;
output[pos] = grad[posn * input_stride + argmax[pos]];
}
}
template <typename T, typename I>
void CalMaxPoolGradGradWithArgmax(const T *grad, const I *argmax, const int batch, const int input_stride,
const int output_stride, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream) {
const int output_elements = batch * output_stride;
MaxPoolGradGradWithArgmax<<<CUDA_BLOCKS(device_id, output_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
grad, argmax, input_stride, output_stride, output_elements, output);
}
template CUDA_LIB_EXPORT void CalMaxPoolGradGradWithArgmax<float, int32_t>(const float *grad, const int32_t *argmax,
const int batch, const int input_stride,
const int output_stride, float *output,
const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalMaxPoolGradGradWithArgmax<float, int64_t>(const float *grad, const int64_t *argmax,
const int batch, const int input_stride,
const int output_stride, float *output,
const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalMaxPoolGradGradWithArgmax<half, int32_t>(const half *grad, const int32_t *argmax,
const int batch, const int input_stride,
const int output_stride, half *output,
const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalMaxPoolGradGradWithArgmax<half, int64_t>(const half *grad, const int64_t *argmax,
const int batch, const int input_stride,
const int output_stride, half *output,
const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@@ -0,0 +1,25 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_GRAD_GRAD_WITH_ARGMAX_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_GRAD_GRAD_WITH_ARGMAX_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T, typename I>
CUDA_LIB_EXPORT void CalMaxPoolGradGradWithArgmax(const T *grad, const I *argmax, const int batch,
const int input_stride, const int output_stride, T *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXPOOL_GRAD_GRAD_WITH_ARGMAX_IMPL_CUH_

View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSHRINK_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSHRINK_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOL_GRAD_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOL_GRAD_GRAD_GPU_KERNEL_H_
#include <vector>
#include <map>
#include <utility>
@@ -104,4 +104,4 @@ class MaxPool3DGradGradGpuKernelMod : public MaxPoolGradGradGpuKernelMod {
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSHRINK_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOL_GRAD_GRAD_GPU_KERNEL_H_

View File

@@ -0,0 +1,111 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nn/maxpool_grad_grad_with_argmax_gpu_kernel.h"
#include <algorithm>
#include <functional>
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxpool_grad_grad_with_argmax_impl.cuh"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMaxPoolGradGradWithArgmaxInputsNum = 3;
constexpr size_t kMaxPoolGradGradWithArgmaxOutputsNum = 1;
constexpr size_t kArgmaxIndex = 2;
} // namespace
bool MaxPoolGradGradWithArgmaxGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.size() != kMaxPoolGradGradWithArgmaxInputsNum || outputs.size() != kMaxPoolGradGradWithArgmaxOutputsNum) {
MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kMaxPoolGradGradWithArgmaxInputsNum
<< " and " << kMaxPoolGradGradWithArgmaxOutputsNum << ", but get " << inputs.size() << " and "
<< outputs.size();
return false;
}
return MatchKernelFunc(base_operator, inputs, outputs);
}
int MaxPoolGradGradWithArgmaxGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != 0) {
return ret;
}
auto in_shapes = inputs[0]->GetShapeVector();
auto out_shapes = outputs[0]->GetShapeVector();
batch_ = out_shapes[0];
output_elements_ = std::accumulate(out_shapes.begin(), out_shapes.end(), 1, std::multiplies<size_t>());
input_batch_stride_ = std::accumulate(in_shapes.begin() + 1, in_shapes.end(), 1, std::multiplies<size_t>());
output_batch_stride_ = std::accumulate(out_shapes.begin() + 1, out_shapes.end(), 1, std::multiplies<size_t>());
return KRET_OK;
}
template <typename T, typename I>
bool MaxPoolGradGradWithArgmaxGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMaxPoolGradGradWithArgmaxInputsNum, kernel_name_);
T *grad_addr = GetDeviceAddress<T>(inputs, kIndex1);
I *argmax_addr = GetDeviceAddress<I>(inputs, kArgmaxIndex);
T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
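// Launch one thread per pooled output element on the caller's CUDA stream; the kernel performs the
// same gather (output[pos] = grad[batch * input_stride + argmax[pos]]) as the CPU backend.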
CalMaxPoolGradGradWithArgmax<T, I>(grad_addr, argmax_addr, batch_, input_batch_stride_, output_batch_stride_,
output_addr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
const std::vector<std::pair<KernelAttr, MaxPoolGradGradWithArgmaxGpuKernelMod::KernelRunFunc>>
&MaxPoolGradGradWithArgmaxGpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, MaxPoolGradGradWithArgmaxGpuKernelMod::KernelRunFunc>> func_list = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&MaxPoolGradGradWithArgmaxGpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&MaxPoolGradGradWithArgmaxGpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&MaxPoolGradGradWithArgmaxGpuKernelMod::LaunchKernel<half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&MaxPoolGradGradWithArgmaxGpuKernelMod::LaunchKernel<half, int64_t>},
};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, MaxPoolGradGradWithArgmax, MaxPoolGradGradWithArgmaxGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,64 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOL_GRAD_GRAD_WITH_ARGMAX_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOL_GRAD_GRAD_WITH_ARGMAX_GPU_KERNEL_H_
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "mindspore/core/mindapi/base/types.h"
namespace mindspore {
namespace kernel {
class MaxPoolGradGradWithArgmaxGpuKernelMod : public NativeGpuKernelMod,
public MatchKernelHelper<MaxPoolGradGradWithArgmaxGpuKernelMod> {
public:
MaxPoolGradGradWithArgmaxGpuKernelMod() = default;
~MaxPoolGradGradWithArgmaxGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T, typename I>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs);
void *cuda_stream_{nullptr};
int batch_ = 0;
size_t output_elements_ = 0;
size_t input_batch_stride_ = 0;
size_t output_batch_stride_ = 0;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOL_GRAD_GRAD_WITH_ARGMAX_GPU_KERNEL_H_

View File

@@ -566,6 +566,7 @@ GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradGrad, std::make_shared<Primitive>("MaxPoo
GVAR_DEF(PrimitivePtr, kPrimMaxPool3DGradGrad, std::make_shared<Primitive>("MaxPool3DGradGrad"));
GVAR_DEF(PrimitivePtr, kPrimMaxPoolWithArgmax, std::make_shared<Primitive>("MaxPoolWithArgmax"));
GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradWithArgmax, std::make_shared<Primitive>("MaxPoolGradWithArgmax"));
GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradGradWithArgmax, std::make_shared<Primitive>("MaxPoolGradGradWithArgmax"));
GVAR_DEF(PrimitivePtr, kPrimMaxPool3DWithArgmax, std::make_shared<Primitive>("MaxPool3DWithArgmax"));
GVAR_DEF(PrimitivePtr, kPrimApplyCenteredRMSProp, std::make_shared<Primitive>("ApplyCenteredRMSProp"));
GVAR_DEF(PrimitivePtr, kPrimAvgPool, std::make_shared<Primitive>("AvgPool"));

View File

@@ -0,0 +1,76 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/max_pool_grad_grad_with_argmax.h"
#include <algorithm>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MaxPoolGradGradWithArgmaxInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
const int64_t input_dim = 4;
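// x and grad are 4-D NCHW tensors with the shape of the forward input; argmax carries the pooled
// output shape, which is also the shape of this op's output.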
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("origin input shape size", SizeToLong(x_shape.size()), kEqual, input_dim,
primitive->name());
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("grad shape size", SizeToLong(grad_shape.size()), kEqual, input_dim,
primitive->name());
auto argmax_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("argmax shape size", SizeToLong(argmax_shape.size()), kEqual, input_dim,
primitive->name());
CheckAndConvertUtils::Check("grad_shape", grad_shape, kEqual, x_shape, primitive->name(), ValueError);
return std::make_shared<abstract::Shape>(argmax_shape);
}
TypePtr MaxPoolGradGradWithArgmaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
std::map<std::string, TypePtr> types;
const std::set<TypePtr> valid_index_types = {kInt32, kInt64};
(void)types.emplace("argmax", input_args[kInputIndex2]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_index_types, prim->name());
types.clear();
const std::set<TypePtr> valid_data_types = {kFloat16, kFloat32};
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("grad", input_args[kInputIndex1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_data_types, prim->name());
}
} // namespace
AbstractBasePtr MaxPoolGradGradWithArgmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = MaxPoolGradGradWithArgmaxInferType(primitive, input_args);
auto infer_shape = MaxPoolGradGradWithArgmaxInferShape(primitive, input_args);
MS_EXCEPTION_IF_NULL(infer_shape);
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
}
MIND_API_OPERATOR_IMPL(MaxPoolGradGradWithArgmax, MaxPoolGradGrad);
REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolGradGradWithArgmax, prim::kPrimMaxPoolGradGradWithArgmax,
MaxPoolGradGradWithArgmaxInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAX_POOL_GRAD_GRAD_WITH_ARGMAX_H_
#define MINDSPORE_CORE_OPS_MAX_POOL_GRAD_GRAD_WITH_ARGMAX_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "ops/grad/max_pool_grad_grad.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMaxPoolGradGradWithArgmax = "MaxPoolGradGradWithArgmax";
class MIND_API MaxPoolGradGradWithArgmax : public MaxPoolGradGrad {
public:
MIND_API_BASE_MEMBER(MaxPoolGradGradWithArgmax);
/// \brief Constructor.
MaxPoolGradGradWithArgmax() : MaxPoolGradGrad(kNameMaxPoolGradGradWithArgmax) {
InitIOName({"x", "grad", "argmax"}, {"output"});
}
};
abstract::AbstractBasePtr MaxPoolGradGradWithArgmaxInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAX_POOL_GRAD_GRAD_WITH_ARGMAX_H_

View File

@@ -25,9 +25,9 @@ from mindspore.ops.functional import vmap
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetPoolGradGrad(nn.Cell):
class NetMaxPool3DGradGrad(nn.Cell):
def __init__(self, mode, kernel, stride):
super(NetPoolGradGrad, self).__init__()
super(NetMaxPool3DGradGrad, self).__init__()
self.maxpool_grad_grad_fun = G.MaxPool3DGradGrad(pad_mode=mode,
kernel_size=kernel,
strides=stride)
@@ -62,7 +62,7 @@ def test_maxpool3d_grad_grad_fp16():
[[[8.5, 8.7],
[9.3, 9.5]]]]])).astype(np.float16)
maxpool3d_grad_grad = NetPoolGradGrad("VALID", 2, 2)
maxpool3d_grad_grad = NetMaxPool3DGradGrad("VALID", 2, 2)
output = maxpool3d_grad_grad(x, out, d)
assert np.allclose(output.asnumpy(), expect_result)
@@ -120,7 +120,7 @@ def test_maxpool3d_grad_grad_fp32():
[5.2, 5.3, 5.3],
[5.2, 5.3, 5.3]]]]])).astype(np.float32)
maxpool3d_grad_grad = NetPoolGradGrad("SAME", 3, 1)
maxpool3d_grad_grad = NetMaxPool3DGradGrad("SAME", 3, 1)
output = maxpool3d_grad_grad(x, out, d)
assert np.allclose(output.asnumpy(), expect_result)
@@ -135,7 +135,7 @@ def test_maxpool3d_grad_grad_vmap(axis):
Description: test the rightness of MaxPool3DGradGrad cpu kernel vmap feature.
Expectation: Success.
"""
maxpool3d_grad_grad = NetPoolGradGrad("SAME", 3, 1)
maxpool3d_grad_grad = NetMaxPool3DGradGrad("SAME", 3, 1)
x = np.random.random((2, 3, 5, 5, 5, axis)).astype(np.float32)
y = np.random.random((2, 3, 5, 5, 5, axis)).astype(np.float32)

View File

@@ -25,9 +25,9 @@ from mindspore.ops.functional import vmap
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetPoolGradGrad(nn.Cell):
class NetMaxPoolGradGrad(nn.Cell):
def __init__(self, mode, kernel, stride):
super(NetPoolGradGrad, self).__init__()
super(NetMaxPoolGradGrad, self).__init__()
self.maxpool_grad_grad_fun = G.MaxPoolGradGrad(pad_mode=mode,
kernel_size=kernel,
strides=stride)
@@ -61,7 +61,7 @@ def test_maxpool2d_grad_grad_fp16():
[[3.7, 3.9],
[4.5, 4.7]]]])).astype(np.float16)
maxpool2d_grad_grad = NetPoolGradGrad("VALID", 2, 2)
maxpool2d_grad_grad = NetMaxPoolGradGrad("VALID", 2, 2)
output = maxpool2d_grad_grad(x, out, d)
assert np.allclose(output.asnumpy(), expect_result)
@@ -100,7 +100,7 @@ def test_maxpool2d_grad_grad_fp32():
[4.6, 4.7, 4.8, 4.9, 4.9],
[4.6, 4.7, 4.8, 4.9, 4.9]]]])).astype(np.float32)
maxpool2d_grad_grad = NetPoolGradGrad("SAME", 3, 1)
maxpool2d_grad_grad = NetMaxPoolGradGrad("SAME", 3, 1)
output = maxpool2d_grad_grad(x, out, d)
assert np.allclose(output.asnumpy(), expect_result)
@@ -115,7 +115,7 @@ def test_maxpool2d_grad_grad_vmap(axis):
Description: test the rightness of MaxPool2dGradGrad cpu kernel vmap feature.
Expectation: Success.
"""
maxpool2d_grad_grad = NetPoolGradGrad("SAME", 3, 1)
maxpool2d_grad_grad = NetMaxPoolGradGrad("SAME", 3, 1)
x = np.random.random((2, 3, 5, 5, axis)).astype(np.float32)
y = np.random.random((2, 3, 5, 5, axis)).astype(np.float32)

View File

@@ -0,0 +1,107 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetMaxPoolGradGradWithArgmax(nn.Cell):
def __init__(self, mode, kernel, stride):
super(NetMaxPoolGradGradWithArgmax, self).__init__()
self.maxpool_grad_grad_argmax_fun = G.MaxPoolGradGradWithArgmax(pad_mode=mode,
kernel_size=kernel,
strides=stride)
def construct(self, x, grad, argmax):
return self.maxpool_grad_grad_argmax_fun(x, grad, argmax)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("argmax_type", [np.int32, np.int64])
def test_maxpool_grad_grad_argmax_fp16(argmax_type):
"""
Feature: MaxPoolGradGradWithArgmax cpu kernel
Description: test the correctness of the MaxPoolGradGradWithArgmax cpu kernel, pad_mode: VALID, dtype: float16
Expectation: the output is the same as the expected output
"""
data = (np.arange(1 * 2 * 6 * 6).astype(np.float16)).reshape(1, 2, 6, 6)
x = Tensor(data)
grad = Tensor(data / 10)
argmax = Tensor(np.array([[[[7, 9, 11],
[19, 21, 23],
[31, 33, 35]],
[[43, 45, 47],
[55, 57, 59],
[67, 69, 71]]]]).astype(argmax_type))
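# Each argmax entry is a flat offset into the C*H*W = 72-element batch slice of grad, so the
# expected output is grad gathered at those offsets (e.g. offset 7 -> 0.7).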
expect_result = (np.array([[[[0.7, 0.9, 1.1],
[1.9, 2.1, 2.3],
[3.1, 3.3, 3.5]],
[[4.3, 4.5, 4.7],
[5.5, 5.7, 5.9],
[6.7, 6.9, 7.1]]]])).astype(np.float16)
maxpool_grad_grad_argmax = NetMaxPoolGradGradWithArgmax("VALID", 2, 2)
output = maxpool_grad_grad_argmax(x, grad, argmax)
assert np.allclose(output.asnumpy(), expect_result)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("argmax_type", [np.int32, np.int64])
def test_maxpool_grad_grad_argmax_fp32(argmax_type):
"""
Feature: MaxPoolGradGradWithArgmax cpu kernel
Description: test the correctness of the MaxPoolGradGradWithArgmax cpu kernel, pad_mode: SAME, dtype: float32
Expectation: the output is the same as the expected output
"""
data = (np.arange(2 * 1 * 5 * 5).astype(np.float32)).reshape(2, 1, 5, 5)
x = Tensor(-1 * data)
grad = Tensor(data / 10)
argmax = Tensor(np.array([[[[0, 0, 1, 2, 3],
[0, 0, 1, 2, 3],
[5, 5, 6, 7, 8],
[10, 10, 11, 12, 13],
[15, 15, 16, 17, 18]]],
[[[0, 0, 1, 2, 3],
[0, 0, 1, 2, 3],
[5, 5, 6, 7, 8],
[10, 10, 11, 12, 13],
[15, 15, 16, 17, 18]]]]
).astype(argmax_type))
expect_result = (np.array(
[[[[0, 0, 0.1, 0.2, 0.3],
[0, 0, 0.1, 0.2, 0.3],
[0.5, 0.5, 0.6, 0.7, 0.8],
[1.0, 1.0, 1.1, 1.2, 1.3],
[1.5, 1.5, 1.6, 1.7, 1.8]]],
[[[2.5, 2.5, 2.6, 2.7, 2.8],
[2.5, 2.5, 2.6, 2.7, 2.8],
[3.0, 3.0, 3.1, 3.2, 3.3],
[3.5, 3.5, 3.6, 3.7, 3.8],
[4.0, 4.0, 4.1, 4.2, 4.3]]]])).astype(np.float32)
maxpool_grad_grad_argmax = NetMaxPoolGradGradWithArgmax("SAME", 3, 1)
output = maxpool_grad_grad_argmax(x, grad, argmax)
assert np.allclose(output.asnumpy(), expect_result)

View File

@@ -24,9 +24,9 @@ from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetPoolGradGrad(nn.Cell):
class NetMaxPool3DGradGrad(nn.Cell):
def __init__(self, mode, kernel, stride):
super(NetPoolGradGrad, self).__init__()
super(NetMaxPool3DGradGrad, self).__init__()
self.maxpool_grad_grad_fun = G.MaxPool3DGradGrad(pad_mode=mode,
kernel_size=kernel,
strides=stride)
@@ -61,7 +61,7 @@ def test_maxpool3d_grad_grad_fp16():
[12.7, 12.9, 13.1],
[13.9, 14.1, 14.3]]]]]).astype(np.float16)
maxpool3d_grad_grad = NetPoolGradGrad("VALID", 2, 2)
maxpool3d_grad_grad = NetMaxPool3DGradGrad("VALID", 2, 2)
output = maxpool3d_grad_grad(x, out, d)
assert np.allclose(output.asnumpy(), expect_result)
@@ -119,6 +119,6 @@ def test_maxpool3d_grad_grad_fp32():
[5.2, 5.3, 5.3],
[5.2, 5.3, 5.3]]]]])).astype(np.float32)
maxpool3d_grad_grad = NetPoolGradGrad("SAME", 3, 1)
maxpool3d_grad_grad = NetMaxPool3DGradGrad("SAME", 3, 1)
output = maxpool3d_grad_grad(x, out, d)
assert np.allclose(output.asnumpy(), expect_result)

View File

@@ -24,9 +24,9 @@ from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetPoolGradGrad(nn.Cell):
class NetMaxPoolGradGrad(nn.Cell):
def __init__(self, mode, kernel, stride):
super(NetPoolGradGrad, self).__init__()
super(NetMaxPoolGradGrad, self).__init__()
self.maxpool_grad_grad_fun = G.MaxPoolGradGrad(pad_mode=mode,
kernel_size=kernel,
strides=stride)
@@ -35,7 +35,7 @@ class NetPoolGradGrad(nn.Cell):
return self.maxpool_grad_grad_fun(x, out, grad)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maxpool2d_grad_grad_fp16():
@@ -58,12 +58,12 @@ def test_maxpool2d_grad_grad_fp16():
[3.1, 3.3, 3.5]
]]])).astype(np.float16)
maxpool2d_grad_grad = NetPoolGradGrad("VALID", 2, 2)
maxpool2d_grad_grad = NetMaxPoolGradGrad("VALID", 2, 2)
output = maxpool2d_grad_grad(x, out, d)
assert np.allclose(output.asnumpy(), expect_result)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maxpool2d_grad_grad_fp32():
@@ -88,6 +88,6 @@ def test_maxpool2d_grad_grad_fp32():
[-1.6, -1.7, -1.7],
[-1.6, -1.7, -1.7]]]]).astype(np.float32))
maxpool2d_grad_grad = NetPoolGradGrad("SAME", 3, 1)
maxpool2d_grad_grad = NetMaxPoolGradGrad("SAME", 3, 1)
output = maxpool2d_grad_grad(x, out, d)
assert np.allclose(output.asnumpy(), expect_result)

View File

@@ -0,0 +1,94 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetMaxPoolGradGradWithArgmax(nn.Cell):
def __init__(self, mode, kernel, stride):
super(NetMaxPoolGradGradWithArgmax, self).__init__()
self.maxpool_grad_grad_argmax_fun = G.MaxPoolGradGradWithArgmax(pad_mode=mode,
kernel_size=kernel,
strides=stride)
def construct(self, x, grad, argmax):
return self.maxpool_grad_grad_argmax_fun(x, grad, argmax)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("argmax_type", [np.int32, np.int64])
def test_maxpool_grad_grad_argmax_fp16(argmax_type):
"""
Feature: MaxPoolGradGradWithArgmax gpu kernel
Description: test the correctness of the MaxPoolGradGradWithArgmax gpu kernel, pad_mode: VALID, dtype: float16
Expectation: the output is the same as the expected output
"""
data = (np.arange(1 * 2 * 4 * 4).astype(np.float16)).reshape(1, 2, 4, 4)
x = Tensor(data)
grad = Tensor(data / 10)
argmax = Tensor(np.array([[[[5, 7],
[13, 15]],
[[21, 23],
[29, 31]]]]).astype(argmax_type))
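# argmax entries are flat offsets into the C*H*W = 32-element batch slice of grad, so the
# expected output is grad gathered at those offsets (offset 5 -> 0.5, 7 -> 0.7, ...).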
expect_result = (np.array([[[[0.5, 0.7],
[1.3, 1.5]],
[[2.1, 2.3],
[2.9, 3.1]]]])).astype(np.float16)
maxpool_grad_grad_argmax = NetMaxPoolGradGradWithArgmax("VALID", 2, 2)
output = maxpool_grad_grad_argmax(x, grad, argmax)
assert np.allclose(output.asnumpy(), expect_result)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("argmax_type", [np.int32, np.int64])
def test_maxpool_grad_grad_argmax_fp32(argmax_type):
"""
Feature: MaxPoolGradGradWithArgmax gpu kernel
Description: test the correctness of the MaxPoolGradGradWithArgmax gpu kernel, pad_mode: SAME, dtype: float32
Expectation: the output is the same as the expected output
"""
data = (np.arange(2 * 1 * 3 * 3).astype(np.float32)
).reshape(2, 1, 3, 3) * (-1)
x = Tensor(data)
grad = Tensor(data / 10)
argmax = Tensor(np.array([[[[0, 0, 1],
[0, 0, 1],
[3, 3, 4]]],
[[[0, 0, 1],
[0, 0, 1],
[3, 3, 4]]]]).astype(argmax_type))
expect_result = (np.array([[[[0, 0, -0.1],
[0, 0, -0.1],
[-0.3, -0.3, -0.4]]],
[[[-0.9, -0.9, -1.0],
[-0.9, -0.9, -1.0],
[-1.2, -1.2, -1.3]]]]).astype(np.float32))
maxpool_grad_grad_argmax = NetMaxPoolGradGradWithArgmax("SAME", 3, 1)
output = maxpool_grad_grad_argmax(x, grad, argmax)
assert np.allclose(output.asnumpy(), expect_result)

View File

@@ -2812,7 +2812,7 @@ test_case_nn_ops = [
'block': G.MaxPoolGradGradWithArgmax(),
'desc_inputs': [Tensor(np.random.rand(1, 1, 2, 2), mstype.float16),
Tensor(np.random.rand(1, 1, 2, 2), mstype.float16),
Tensor(np.zeros((1, 1, 2, 2)), mstype.uint16)],
Tensor(np.zeros((1, 1, 2, 2)), mstype.int32)],
'desc_bprop': [],
'skip': ['backward']}),
('Roll', {