diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_avg_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_avg_grad_cpu_kernel.h
index a448019f889..bd24220b55a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_avg_grad_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_avg_grad_cpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@ class AvgPoolingGradCPUKernel : public MKLCPUKernel {
   std::vector kernel_size_;
 };
 
-MS_REG_CPU_KERNEL(AvgPoolGradCpu,
+MS_REG_CPU_KERNEL(AvgPoolGrad,
                   KernelAttr()
                     .AddInputAttr(kNumberTypeFloat32)
                     .AddInputAttr(kNumberTypeFloat32)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.cc
index 2948c900d24..183197414ec 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,14 +32,14 @@ MS_REG_GPU_KERNEL_ONE(MaxPoolGrad,
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16),
                       PoolingGradGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu,
+MS_REG_GPU_KERNEL_ONE(AvgPoolGrad,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       PoolingGradGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu,
+MS_REG_GPU_KERNEL_ONE(AvgPoolGrad,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddInputAttr(kNumberTypeFloat16)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
index d92cac8f44e..e385d74c2e8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
@@ -254,7 +254,7 @@ class PoolingGradGpuKernel : public GpuKernel {
   }
   void SetPoolingMode(const CNodePtr &kernel_node) {
     mode_ = AnfAlgo::GetCNodeName(kernel_node);
-    if (mode_ == "AvgPoolGradGpu") {
+    if (mode_ == "AvgPoolGrad") {
       pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
       pad_value_ = 0.0;
     } else {
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.cc
new file mode 100644
index 00000000000..bb08950612a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.cc
@@ -0,0 +1,200 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h"
+
+#include <vector>
+#include <memory>
+#include <string>
+#include <algorithm>
+#include <functional>
+
+#include "utils/utils.h"
+#include "utils/ms_context.h"
+#include "utils/check_convert_utils.h"
+#include "backend/optimizer/common/helper.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kAvgPoolGradInputNum = 3;
+constexpr size_t kShapeDimNum = 4;
+constexpr float kKernelMatrixInitNum = 1.0;
+constexpr size_t kFloat32Len = 4;  // size of float32
+
+std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  std::vector<int64_t> shapes;
+  auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong);
+  return shapes;
+}
+
+int64_t windowed_output_size(int64_t input_size, int64_t ksize, int64_t stride, PadMode pad_mode, int64_t *pad_before,
+                             int64_t *pad_after) {
+  int64_t output = 0;
+  *pad_before = 0;
+  *pad_after = 0;
+  if (pad_mode == PadMode::VALID) {
+    output = (input_size - ksize + stride) / stride;
+  } else if (pad_mode == PadMode::SAME) {
+    output = (input_size + stride - 1) / stride;
+    int64_t pad_need = std::max(int64_t(0), (output - 1) * stride + ksize - input_size);
+    *pad_before = pad_need / 2;
+    *pad_after = pad_need - *pad_before;
+  } else {
+    MS_LOG(EXCEPTION) << "The pad mode of AvgPoolGrad should be SAME or VALID.";
+  }
+  return output;
+}
+
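As a quick check on the arithmetic above, here is a minimal Python transcription of the helper (a hypothetical standalone function mirroring the C++, not part of the patch itself), with worked SAME and VALID cases:

```python
def windowed_output_size(input_size, ksize, stride, pad_mode):
    """Mirror of the C++ helper: returns (output, pad_before, pad_after)."""
    if pad_mode == "VALID":
        return (input_size - ksize + stride) // stride, 0, 0
    if pad_mode == "SAME":
        output = (input_size + stride - 1) // stride
        pad_need = max(0, (output - 1) * stride + ksize - input_size)
        return output, pad_need // 2, pad_need - pad_need // 2
    raise ValueError("pad_mode should be SAME or VALID")

# SAME: input 5, kernel 3, stride 2 -> output 3, pad_need = 2*2 + 3 - 5 = 2 -> (3, 1, 1)
assert windowed_output_size(5, 3, 2, "SAME") == (3, 1, 1)
# VALID: input 5, kernel 3, stride 2 -> (5 - 3 + 2) // 2 = 2 windows
assert windowed_output_size(5, 3, 2, "VALID")[0] == 2
```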
+ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &x_shape,
+                                       const std::vector<int64_t> &k_size, const std::vector<int64_t> &stride,
+                                       const PadMode pad_mode, const TypeId x_dtype) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  if (x_shape.size() != kShapeDimNum || k_size.size() != kShapeDimNum || stride.size() != kShapeDimNum) {
+    MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size or strides of AvgPoolGrad should be 4.";
+  }
+  int64_t pad_top, pad_bottom, pad_left, pad_right;
+  int64_t h_output = windowed_output_size(x_shape[2], k_size[2], stride[2], pad_mode, &pad_top, &pad_bottom);
+  int64_t w_output = windowed_output_size(x_shape[3], k_size[3], stride[3], pad_mode, &pad_left, &pad_right);
+
+  // `assist_input_matrix` is a 2d matrix with input_shape after padding,
+  // the value of element which is padded is 0, else are 1.
+  // For each element of output, it is mapped for slide window: `[h*h_stride : h*h_stride + h_ksize,
+  // w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum of slide window is the
+  // number of input that associate with output element.
+  std::vector<std::vector<float>> assist_input_matrix;
+  std::vector<int64_t> in_shape_after_padding_2d = {x_shape[2] + pad_top + pad_bottom,
+                                                    x_shape[3] + pad_left + pad_right};
+  std::vector<float> tmp_zero_vector(in_shape_after_padding_2d[1], 0.0);
+  std::vector<float> tmp_one_vector(in_shape_after_padding_2d[1], 1.0);
+  for (int64_t i = 0; i < in_shape_after_padding_2d[1]; ++i) {
+    if (i < pad_left || i >= (in_shape_after_padding_2d[1] - pad_right)) {
+      tmp_one_vector[i] = 0.0;
+    }
+  }
+  for (int64_t i = 0; i < in_shape_after_padding_2d[0]; ++i) {
+    if (i < pad_top || i >= (in_shape_after_padding_2d[0] - pad_bottom)) {
+      assist_input_matrix.emplace_back(tmp_zero_vector);
+    } else {
+      assist_input_matrix.emplace_back(tmp_one_vector);
+    }
+  }
+
+  // calculate output
+  std::vector<float> hw_output(h_output * w_output, 0.0);
+  for (int64_t h = 0; h < h_output; ++h) {
+    for (int64_t w = 0; w < w_output; ++w) {
+      float curr_sum = 0;
+      for (int64_t i = h * stride[2]; i < h * stride[2] + k_size[2]; ++i) {
+        for (int64_t j = w * stride[3]; j < w * stride[3] + k_size[3]; ++j) {
+          curr_sum += assist_input_matrix[i][j];
+        }
+      }
+      if (curr_sum > 0) {
+        hw_output[h * w_output + w] = 1.0 / curr_sum;
+      }
+    }
+  }
+
+  // make output tensor
+  std::vector<int64_t> output_shape = {x_shape[0], x_shape[1], h_output, w_output};
+  auto output_size = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
+  std::vector<float> output(output_size, 0.0);
+  for (int64_t i = 0; i < output_shape[0] * output_shape[1]; ++i) {
+    size_t copy_size = hw_output.size() * kFloat32Len;
+    (void)memcpy_s(&output[i * hw_output.size()], copy_size, &hw_output[0], copy_size);
+  }
+  auto output_tensor = std::make_shared<tensor::Tensor>(x_dtype, output_shape, &output[0], kNumberTypeFloat32);
+  MS_EXCEPTION_IF_NULL(output_tensor);
+  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(x_dtype), output_shape);
+  MS_EXCEPTION_IF_NULL(abstract);
+  auto mean_matrix_vnode = kernel_graph->NewValueNode(abstract, output_tensor);
+  MS_EXCEPTION_IF_NULL(mean_matrix_vnode);
+  kernel_graph->AddValueNodeToGraph(mean_matrix_vnode);
+  return mean_matrix_vnode;
+}
+
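This routine is the C++ port of the `_get_mean_matrix` helper that the patch deletes from `grad_nn_ops.py` further down. A condensed NumPy sketch of the same computation (hypothetical signature; pads are assumed precomputed as above):

```python
import numpy as np

def mean_matrix(x_shape, ksize, stride, pad):  # pad = (top, bottom, left, right)
    n, c, h_in, w_in = x_shape
    (h_k, w_k), (h_s, w_s) = ksize[2:], stride[2:]
    pad_top, pad_bottom, pad_left, pad_right = pad
    h_out = (h_in + pad_top + pad_bottom - h_k) // h_s + 1
    w_out = (w_in + pad_left + pad_right - w_k) // w_s + 1
    # 1 inside the original image, 0 in the padded border.
    assist = np.zeros((h_in + pad_top + pad_bottom, w_in + pad_left + pad_right), np.float32)
    assist[pad_top:pad_top + h_in, pad_left:pad_left + w_in] = 1.0
    out = np.zeros((n, c, h_out, w_out), np.float32)
    for h in range(h_out):
        for w in range(w_out):
            # Number of real (unpadded) inputs covered by this sliding window.
            count = assist[h * h_s:h * h_s + h_k, w * w_s:w * w_s + w_k].sum()
            if count > 0:
                out[:, :, h, w] = 1.0 / count  # reciprocal of the valid-element count
    return out
```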
+ValueNodePtr CreateKernelMatrixValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &x_shape,
+                                         const std::vector<int64_t> &k_size, const TypeId x_dtype) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  if (x_shape.size() != kShapeDimNum || k_size.size() != kShapeDimNum) {
+    MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size of AvgPoolGrad should be 4.";
+  }
+  std::vector<int64_t> kernel_shape = {1, x_shape[1], k_size[2], k_size[3]};
+  auto data_size = std::accumulate(kernel_shape.begin(), kernel_shape.end(), int64_t(1), std::multiplies<int64_t>());
+  std::vector<float> data(data_size, kKernelMatrixInitNum);
+  auto kernel_matrix_tensor = std::make_shared<tensor::Tensor>(x_dtype, kernel_shape, &data[0], kNumberTypeFloat32);
+  MS_EXCEPTION_IF_NULL(kernel_matrix_tensor);
+  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(x_dtype), kernel_shape);
+  MS_EXCEPTION_IF_NULL(abstract);
+  auto kernel_matrix_vnode = kernel_graph->NewValueNode(abstract, kernel_matrix_tensor);
+  MS_EXCEPTION_IF_NULL(kernel_matrix_vnode);
+  kernel_graph->AddValueNodeToGraph(kernel_matrix_vnode);
+  return kernel_matrix_vnode;
+}
+}  // namespace
+
+const BaseRef AvgPoolGradUnifyMindIR::DefinePattern() const {
+  VarPtr X1 = std::make_shared<Var>();
+  VarPtr X2 = std::make_shared<Var>();
+  VarPtr G = std::make_shared<Var>();
+  VectorRef pattern({prim::kPrimAvgPoolGrad, X1, X2, G});
+  return pattern;
+}
+
+const AnfNodePtr AvgPoolGradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                 const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto avgpool_grad = CheckAnfNodeIfCNodeAndInputSize(node, kAvgPoolGradInputNum);
+
+  auto x_shape = GetInputXShape(avgpool_grad);
+  auto x_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(avgpool_grad, 0);
+  auto k_size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(avgpool_grad, kAttrKernelSize);
+  auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(avgpool_grad, kAttrStrides);
+  auto pad_mode = PadMode(AnfAlgo::GetNodeAttr<int64_t>(avgpool_grad, kAttrPadMode));
+
+  auto x_shape_vnode = CreateShapeValueNode(graph, x_shape);
+  auto mean_matrix_vnode = CreateMeanMatrixValueNode(graph, x_shape, k_size, stride, pad_mode, x_dtype);
+  auto kernel_matrix_vnode = CreateKernelMatrixValueNode(graph, x_shape, k_size, x_dtype);
+
+  std::vector<AnfNodePtr> avgpool_grad_vm_inputs = {NewValueNode(std::make_shared<Primitive>(kAvgPoolGradVmOpName)),
+                                                    x_shape_vnode, avgpool_grad->input(3), mean_matrix_vnode,
+                                                    kernel_matrix_vnode};
+  auto avgpool_grad_vm = graph->NewCNode(avgpool_grad_vm_inputs);
+  MS_EXCEPTION_IF_NULL(avgpool_grad_vm);
+  avgpool_grad_vm->set_scope(avgpool_grad->scope());
+  avgpool_grad_vm->set_abstract(avgpool_grad->abstract());
+  AnfAlgo::CopyNodeAttr(kAttrKernelSize, avgpool_grad, avgpool_grad_vm);
+  AnfAlgo::CopyNodeAttr(kAttrStrides, avgpool_grad, avgpool_grad_vm);
+  AnfAlgo::CopyNodeAttr(kAttrPadMode, avgpool_grad, avgpool_grad_vm);
+  AnfAlgo::CopyNodeAttr(kAttrFormat, avgpool_grad, avgpool_grad_vm);
+  auto input_names = std::vector<std::string>{"x_origin", "grad", "mean_matrix", "kernel_matrix"};
+  AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), avgpool_grad_vm);
+  auto output_names = std::vector<std::string>{"output"};
+  AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), avgpool_grad_vm);
+  return avgpool_grad_vm;
+}
+}  // namespace opt
+}  // namespace mindspore
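Putting the pieces together, the pass performs the rewrite sketched below (Python-style pseudo-code, not runnable; names follow `Process` above):

```python
# Matched pattern (three data inputs):
#     dx = AvgPoolGrad(x_origin, out, grad)
# Emitted replacement:
#     dx = AvgPoolGradVm(x_shape, grad, mean_matrix, kernel_matrix)
# where, at compile time:
#     x_shape       = constant ValueNode holding the inferred shape of x_origin
#     mean_matrix   = constant with 1/count per sliding window (padding excluded)
#     kernel_matrix = constant of ones with shape (1, C, k_h, k_w)
# kernel_size, strides, pad_mode, and format attrs are copied onto the new node.
```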
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h b/mindspore/ccsrc/backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h
new file mode 100644
index 00000000000..22da36af68a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AVG_POOL_GRAD_UNIFY_MINDIR_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AVG_POOL_GRAD_UNIFY_MINDIR_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class AvgPoolGradUnifyMindIR : public PatternProcessPass {
+ public:
+  explicit AvgPoolGradUnifyMindIR(bool multigraph = true)
+      : PatternProcessPass("avg_pool_grad_unify_mindir", multigraph) {}
+  ~AvgPoolGradUnifyMindIR() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AVG_POOL_GRAD_UNIFY_MINDIR_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc
index 01af6f5ee45..1f307b1ab4e 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc
@@ -104,47 +104,6 @@ ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const AnfNo
   return keep_prob_value;
 }
 
-ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape,
-                                  bool is_pynative = false) {
-  MS_LOG(INFO) << "CreateShapeValueNode start.";
-  MS_EXCEPTION_IF_NULL(func_graph);
-  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
-  MS_EXCEPTION_IF_NULL(kernel_graph);
-  ValuePtr shape_value = nullptr;
-  AbstractBasePtr abstract = nullptr;
-  if (is_pynative) {
-    // pynative mode need to create tensor
-    int64_t shape_dim = SizeToLong(shape.size());
-    std::vector<int64_t> shape_vec_shape = {shape_dim};
-    auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
-    MS_EXCEPTION_IF_NULL(shape_tensor);
-    auto data_ptr = shape_tensor->data_c();
-    MS_EXCEPTION_IF_NULL(data_ptr);
-    auto elem_num = shape.size() * kInt64Len;
-    auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
-    if (ret_code != 0) {
-      MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
-    }
-    shape_value = shape_tensor;
-    abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
-  } else {
-    std::vector<ValuePtr> dim_values{};
-    abstract::AbstractBasePtrList abs{};
-    for (const auto &dim : shape) {
-      dim_values.push_back(MakeValue(dim));
-      abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
-    }
-    shape_value = std::make_shared<ValueTuple>(dim_values);
-    abstract = std::make_shared<abstract::AbstractTuple>(abs);
-  }
-  MS_EXCEPTION_IF_NULL(shape_value);
-  MS_EXCEPTION_IF_NULL(abstract);
-  auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
-  MS_EXCEPTION_IF_NULL(shape_value_node);
-  kernel_graph->AddValueNodeToGraph(shape_value_node);
-  return shape_value_node;
-}
-
 std::vector<int64_t> CalDropoutGenMaskOutput(const std::vector<int64_t> &shape) {
   auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
   auto output_count = output_size / kMaskAlignNum;
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc
index 5eb80b6c3b9..c534c6fb418 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc
@@ -35,6 +35,8 @@ namespace mindspore {
 namespace opt {
 constexpr size_t kType32Len = 4;
+constexpr size_t kType64Len = 8;
+
 std::vector<int> Convert2Int(const std::vector<size_t> &v) {
   std::vector<int> result;
   (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
@@ -495,6 +497,46 @@ CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr
   return tuple_getitem;
 }
 
+ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape, bool to_tensor) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  ValuePtr shape_value = nullptr;
+  AbstractBasePtr abstract = nullptr;
+  if (to_tensor) {
+    // create Tensor
+    int64_t shape_dim = SizeToLong(shape.size());
+    std::vector<int64_t> shape_vec_shape = {shape_dim};
+    auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
+    MS_EXCEPTION_IF_NULL(shape_tensor);
+    auto data_ptr = shape_tensor->data_c();
+    MS_EXCEPTION_IF_NULL(data_ptr);
+    auto elem_num = shape.size() * kType64Len;
+    auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
+    if (ret_code != 0) {
+      MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
+    }
+    shape_value = shape_tensor;
+    abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
+  } else {
+    // create ValueTuple
+    std::vector<ValuePtr> dim_values{};
+    abstract::AbstractBasePtrList abs{};
+    for (const auto &dim : shape) {
+      dim_values.push_back(MakeValue(dim));
+      abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
+    }
+    shape_value = std::make_shared<ValueTuple>(dim_values);
+    abstract = std::make_shared<abstract::AbstractTuple>(abs);
+  }
+  MS_EXCEPTION_IF_NULL(shape_value);
+  MS_EXCEPTION_IF_NULL(abstract);
+  auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
+  MS_EXCEPTION_IF_NULL(shape_value_node);
+  kernel_graph->AddValueNodeToGraph(shape_value_node);
+  return shape_value_node;
+}
+
 void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<std::string> &input_attrs) {
   MS_EXCEPTION_IF_NULL(cnode);
   std::vector<AnfNodePtr> new_inputs;
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h
index ee3af39da30..a20a4234b51 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.h
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.h
@@ -168,6 +168,9 @@ void RemoveNopNode(session::KernelGraph *const graph);
 
 CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx);
 
+ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape,
+                                  bool to_tensor = false);
+
 bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
 
 std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
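For intuition, a tiny NumPy sketch (hypothetical values, not MindSpore API) of the two constant encodings the new `to_tensor` flag chooses between:

```python
import numpy as np

shape = [2, 3, 4, 5]
as_tuple = tuple(shape)                        # to_tensor=False: a ValueTuple of int64 scalars
as_tensor = np.asarray(shape, dtype=np.int64)  # to_tensor=True: a 1-D int64 Tensor of shape (4,)
assert as_tensor.nbytes == len(shape) * 8      # 8 == kType64Len bytes per element
```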
diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index 3f31f7f324e..5ffa1881a02 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -41,6 +41,7 @@
 #include "backend/optimizer/ascend/mindir/optimizer_unify_output.h"
 #include "backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h"
 #include "backend/optimizer/ascend/mindir/slice_grad_unify_mindir.h"
+#include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h"
 #include "runtime/device/kernel_adjust.h"
 #include "runtime/device/ascend/ascend_stream_assign.h"
 #include "backend/session/anf_runtime_algorithm.h"
@@ -225,6 +226,7 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
   unify_mindir_pm->AddPass(std::make_shared());
   unify_mindir_pm->AddPass(std::make_shared());
   unify_mindir_pm->AddPass(std::make_shared());
+  unify_mindir_pm->AddPass(std::make_shared<opt::AvgPoolGradUnifyMindIR>());
   unify_mindir_pm->AddPass(std::make_shared());
   unify_mindir_pm->AddPass(std::make_shared());
   unify_mindir_pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h
index 20342200977..d15de86529f 100644
--- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h
+++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>>
   {prim::kPrimMaxPool->name(), {{0}, {0}}},
   {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}},
   {kAvgPoolOpName, {{0}, {0}}},
-  {kAvgPoolGradGpuOpName, {{0, 1, 2}, {0}}},
+  {kAvgPoolGradOpName, {{0, 1, 2}, {0}}},
   {kFusedBatchNormEx, {{0}, {0}}},
   {kFusedBatchNormExWithActivation, {{0}, {0}}},
   {kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}},
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 137412eddbd..34842bef3f6 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -214,7 +214,8 @@ constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
 constexpr auto kGatherV2OpName = "Gather";
 constexpr auto kPaddingOpName = "Padding";
 constexpr auto kAvgPoolOpName = "AvgPool";
-constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
+constexpr auto kAvgPoolGradOpName = "AvgPoolGrad";
+constexpr auto kAvgPoolGradVmOpName = "AvgPoolGradVm";
 constexpr auto kmaxPoolGradOpName = "MaxPoolGrad";
 constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax";
 constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax";
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 5624c2a87f8..0617250edd9 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -216,7 +216,6 @@ inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive
 inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool");
 inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
 inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm");
-inline const PrimitivePtr kPrimAvgPoolGradCpu = std::make_shared<Primitive>("AvgPoolGradCpu");
 inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam");
 inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
 inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx");
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index d8104271b5e..db72b775fa6 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
 """Define the grad rules of neural network related operations."""
 
 import os
-import numpy as np
 from mindspore.ops import _selected_grad_ops as SG
 from mindspore.ops.primitive import constexpr
 from mindspore.common.tensor import Tensor
@@ -250,149 +249,20 @@ def get_bprop_max_pool_grad(self):
     return bprop
 
 
-def _windowed_output_size(input_size, ksize, stride, pad_mode):
-    """
-    helper func for AvgPoolGrad
-    """
-
-    tmp_output = 0
-    tmp_pad_need = 0
-    tmp_pad_before = 0
-    tmp_pad_after = 0
-    if pad_mode == 'VALID':
-        tmp_output = (input_size - ksize + stride) // stride
-        tmp_pad_before = 0
-        tmp_pad_after = 0
-    elif pad_mode == 'SAME':
-        tmp_output = (input_size + stride - 1) // stride
-        tmp_pad_need = max(0, (tmp_output - 1) * stride + ksize - input_size)
-        tmp_pad_before = tmp_pad_need // 2
-        tmp_pad_after = tmp_pad_need - tmp_pad_before
-    return tmp_output, tmp_pad_before, tmp_pad_after
-
-
-@constexpr
-def _get_mean_matrix(x_shape, ksize, stride, pad_mode, x_dtype):
-    """
-    helper func for AvgPoolGrad.
-
-    `assist_input_matrix` is a 2d matrix with input_shape after padding,
-    the value of element which is padded is 0, else are 1.
-    For each element of output, it is mapped for slide window: `[h*h_stride : h*h_stride + h_ksize,
-    w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum of slide window is the
-    number of input that associate with output element.
-    """
-
-    n_input, c_input, h_input, w_input = x_shape
-    h_ksize, w_ksize = ksize[2], ksize[3]
-    h_stride, w_stride = stride[2], stride[3]
-    n_output = n_input
-    c_output = c_input
-    h_output, w_output = 0, 0
-    pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
-    h_output, pad_top, pad_bottom = _windowed_output_size(h_input, h_ksize,
-                                                          h_stride, pad_mode)
-    w_output, pad_left, pad_right = _windowed_output_size(w_input, w_ksize,
-                                                          w_stride, pad_mode)
-
-    output_size = n_output * c_output * h_output * w_output
-    output_shape = (n_output, c_output, h_output, w_output)
-    output = np.array([0.0] * output_size)
-    output = np.reshape(output, output_shape)
-
-    in_shape_after_padding_2d = (h_input + pad_top + pad_bottom, w_input + pad_left + pad_right)
-    assist_input_matrix = np.ones(in_shape_after_padding_2d).astype(np.float32)
-    if pad_top > 0:
-        assist_input_matrix[:pad_top, :] = 0
-    if pad_bottom > 0:
-        assist_input_matrix[-pad_bottom:, :] = 0
-    if pad_left > 0:
-        assist_input_matrix[:, :pad_left] = 0
-    if pad_right > 0:
-        assist_input_matrix[:, -pad_right:] = 0
-
-    for h in range(h_output):
-        for w in range(w_output):
-            curr_input = assist_input_matrix[h * h_stride: h * h_stride + h_ksize, w * w_stride: w * w_stride + w_ksize]
-            curr_sum = np.sum(curr_input)
-            if curr_sum > 0:
-                output[:, :, h, w] = 1. / curr_sum
-    return Tensor(output, x_dtype)
-
-
-@constexpr
-def _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, pad_mode, x_dtype):
-    kernel_matrix = np.ones(kernel_matrix_shape)
-    return Tensor(kernel_matrix, x_dtype)
-
-
 @bprop_getters.register(P.AvgPool)
 def get_bprop_avg_pool_grad(self):
     """Grad definition for `AvgPool` operation."""
+    avgpool_grad = G.AvgPoolGrad(
+        kernel_size=self.kernel_size,
+        strides=self.strides,
+        pad_mode=self.pad_mode,
+        data_format=self.format)
 
-    # the parameter of AvgPoolGrad in GPU and TBE/CPU is not same
-    if self.target == "GPU":
-        avgpool_grad_gpu = G.AvgPoolGradGpu(
-            kernel_size=self.kernel_size,
-            strides=self.strides,
-            pad_mode=self.pad_mode,
-            data_format=self.format)
+    def bprop(x, out, dout):
+        dx = avgpool_grad(x, out, dout)
+        return (dx,)
 
-        def bprop_gpu(x, out, dout):
-            dx = avgpool_grad_gpu(x, out, dout)
-            return (dx,)
-
-        bprop_fn = bprop_gpu
-
-    elif self.target == "CPU":
-        avgpool_grad_cpu = G.AvgPoolGradCpu(
-            kernel_size=self.kernel_size,
-            strides=self.strides,
-            pad_mode=self.pad_mode,
-            data_format=self.format)
-
-        def bprop_cpu(x, out, dout):
-            dx = avgpool_grad_cpu(x, out, dout)
-            return (dx,)
-
-        bprop_fn = bprop_cpu
-
-    elif self.target == "GE":
-        avgpool_grad_ge = G.AvgPoolGrad(
-            kernel_size=self.kernel_size,
-            strides=self.strides,
-            pad_mode=self.pad_mode)
-        shape_op = P.Shape()
-
-        def bprop_ge(x, out, dout):
-            dx = avgpool_grad_ge(shape_op(x), dout)
-            return (dx,)
-
-        bprop_fn = bprop_ge
-
-    else:
-        avgpool_grad_vm = G.AvgPoolGradVm(
-            kernel_size=self.kernel_size,
-            strides=self.strides,
-            pad_mode=self.pad_mode)
-        k_size_nchw = avgpool_grad_vm.kernel_size
-        stride_nchw = avgpool_grad_vm.strides
-        pad_mode = self.pad_mode
-
-        def bprop_vm(x, out, dout):
-            x_shape_nchw = F.shape(x)
-            x_dtype = F.dtype(x)
-            kernel_matrix_shape = (1, x_shape_nchw[1],
-                                   k_size_nchw[2],
-                                   k_size_nchw[3])
-            mean_matrix = _get_mean_matrix(x_shape_nchw, k_size_nchw, stride_nchw, pad_mode, x_dtype)
-            kernel_matrix = _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, pad_mode, x_dtype)
-            dx = avgpool_grad_vm(x_shape_nchw, dout, mean_matrix, kernel_matrix)
-            return (dx,)
-
-        bprop_fn = bprop_vm
-
-    return bprop_fn
+    return bprop
 
 
 @bprop_getters.register(P.DropoutGenMask)
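With the backend-specific primitives gone, the backward pass is exercised the same way on every target. A minimal sketch (MindSpore 1.x API assumed; context/device setup abbreviated):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops
from mindspore.ops import operations as P

ms.context.set_context(mode=ms.context.PYNATIVE_MODE)  # any backend

avg_pool = P.AvgPool(kernel_size=2, strides=2, pad_mode="VALID")
x = Tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))

def forward(inp):
    return avg_pool(inp)

# Differentiating the forward pass now routes through the single AvgPoolGrad bprop.
dx = ops.GradOperation()(forward)(x)
print(dx.shape)  # (1, 1, 4, 4); each pixel sits in exactly one 2x2 window, so dx is uniformly 0.25
```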
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 02f17e6f41c..16f19c4278b 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -892,23 +892,6 @@ class _PoolGrad(PrimitiveWithInfer):
         self.add_prim_attr("strides", self.strides)
 
 
-class AvgPoolGrad(_PoolGrad):
-    """Gradients of the avg pool operation for ge."""
-
-    @prim_attr_register
-    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
-        super(AvgPoolGrad, self).__init__(kernel_size, strides, pad_mode)
-
-    def __infer__(self, origin_input, dout):
-        out = {
-            'value': None,
-            'shape': tuple(origin_input['value']),
-            'dtype': dout['dtype'],
-        }
-
-        return out
-
-
 class AvgPoolGradVm(_PoolGrad):
     """Gradients of the avg pool operation for vm."""
 
@@ -927,26 +910,12 @@ class AvgPoolGradVm(_PoolGrad):
         return out
 
 
-class AvgPoolGradGpu(_PoolGrad):
-    """Gradients of the avg pool operation for gpu."""
+class AvgPoolGrad(_PoolGrad):
+    """Gradients of the avg pool operation."""
 
     @prim_attr_register
     def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
-        super(AvgPoolGradGpu, self).__init__(kernel_size, strides, pad_mode, data_format)
-
-    def infer_shape(self, x1_shape, x2_shape, grad_shape):
-        return x1_shape
-
-    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
-        return x1_dtype
-
-
-class AvgPoolGradCpu(_PoolGrad):
-    """Gradients of the avg pool operation for cpu."""
-
-    @prim_attr_register
-    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
-        super(AvgPoolGradCpu, self).__init__(kernel_size, strides, pad_mode, data_format)
+        super(AvgPoolGrad, self).__init__(kernel_size, strides, pad_mode, data_format)
 
     def infer_shape(self, x1_shape, x2_shape, grad_shape):
         return x1_shape
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 3c2f7b83087..6190518ac5e 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1811,10 +1811,9 @@ class MaxPoolWithArgmax(_Pool):
          [33. 34. 35.]]]]
     """
 
+    @prim_attr_register
     def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
         super(MaxPoolWithArgmax, self).__init__(kernel_size, strides, pad_mode, data_format)
-        self.is_tbe = context.get_context("device_target") == "Ascend"
-        self.is_gpu = context.get_context("device_target") == "GPU"
 
     def infer_shape(self, x_shape):
         out_shape = _Pool.infer_shape(self, x_shape)
@@ -1897,14 +1896,6 @@ class AvgPool(_Pool):
 
     @prim_attr_register
     def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
-        if context.get_context("device_target") == "GPU":
-            self.target = "GPU"
-        elif context.get_context("device_target") == "CPU":
-            self.target = "CPU"
-        elif context.get_context("enable_ge"):
-            self.target = "GE"
-        else:
-            self.target = "OTHER"
         super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)