support extract_image_patches op

This commit is contained in:
liubuyu 2020-09-23 11:31:20 +08:00
parent 15b4115e55
commit 96622fc804
8 changed files with 165 additions and 6 deletions

View File

@ -64,6 +64,7 @@
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/pass/optimize_dependence.h"
#include "backend/optimizer/pass/erase_visit_attr.h"
@ -231,6 +232,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
auto optimizer = std::make_shared<GraphOptimizer>();
auto mixed_precision_pm = std::make_shared<PassManager>("cast_pm");
mixed_precision_pm->AddPass(std::make_shared<InsertCast>());
mixed_precision_pm->AddPass(std::make_shared<InsertReshapeForExtractImagePatchesOp>());
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());

View File

@ -0,0 +1,65 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
#include <memory>
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "base/core_ops.h"
namespace mindspore {
namespace opt {
const BaseRef InsertReshapeForExtractImagePatchesOp::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({prim::kPrimExtractImagePatches, Xs});
}
const AnfNodePtr InsertReshapeForExtractImagePatchesOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
auto extract = CheckAnfNodeIfCNodeAndInputSize(node, 2);
MS_EXCEPTION_IF_NULL(extract);
auto in_node = extract->input(1);
MS_EXCEPTION_IF_NULL(in_node);
auto extract_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(extract);
auto in_node_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(in_node);
MS_EXCEPTION_IF_NULL(extract_kernel_build_info);
MS_EXCEPTION_IF_NULL(in_node_kernel_build_info);
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
in_node};
auto reshape_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
reshape_builder->SetInputsFormat({kOpFormat_NC1HWC0});
reshape_builder->SetOutputsFormat({kOpFormat_NC1HWC0});
reshape_builder->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
reshape_builder->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
reshape_builder->SetKernelType(in_node_kernel_build_info->kernel_type());
reshape_builder->SetFusionType(in_node_kernel_build_info->fusion_type());
reshape_builder->SetProcessor(in_node_kernel_build_info->processor());
auto reshape = func_graph->NewCNode(reshape_inputs);
reshape->set_scope(in_node->scope());
auto shape_tmp = AnfAlgo::GetOutputInferShape(in_node, 0);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputDeviceDataType(in_node, 0)},
{{shape_tmp[0], shape_tmp[2], shape_tmp[3], shape_tmp[1]}}, reshape.get());
AnfAlgo::SetSelectKernelBuildInfo(reshape_builder->Build(), reshape.get());
AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), reshape);
AnfAlgo::SetNodeInput(extract, reshape, 0);
return extract;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "ir/anf.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class InsertReshapeForExtractImagePatchesOp : public PatternProcessPass {
public:
explicit InsertReshapeForExtractImagePatchesOp(bool multigraph = true)
: PatternProcessPass("insert_reshape_for_extract_image_patches_op", multigraph) {}
~InsertReshapeForExtractImagePatchesOp() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H

View File

@ -516,6 +516,10 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
if (trans::IsNeedPadding(format, infer_shape.size())) {
infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
}
if (node->isa<CNode>() && GetCNodeName(node) == kExtractImagePatchesOpName) {
auto shape_tmp = {infer_shape[0], infer_shape[3], infer_shape[1], infer_shape[2]};
return trans::TransShapeToDevice(shape_tmp, format);
}
return trans::TransShapeToDevice(infer_shape, format);
}

View File

@ -104,6 +104,7 @@ inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared<Primitive>("ExtractImagePatches");
// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");

View File

@ -542,12 +542,16 @@ class Unfold(Cell):
self.transpose = P.Transpose()
self.format_NHWC = (0, 2, 3, 1)
self.format_NCHW = (0, 3, 1, 2)
self.is_ge = context.get_context("enable_ge")
def construct(self, input_x):
x_transpose = self.transpose(input_x, self.format_NHWC)
ret = self.extract_image_patches(x_transpose)
ret_transpose = self.transpose(ret, self.format_NCHW)
return ret_transpose
if self.is_ge:
x_transpose = self.transpose(input_x, self.format_NHWC)
ret = self.extract_image_patches(x_transpose)
result = self.transpose(ret, self.format_NCHW)
else:
result = self.extract_image_patches(input_x)
return result
@constexpr

View File

@ -21,6 +21,7 @@ from mindspore.common.tensor import Tensor
from .grad_base import bprop_getters
from .. import functional as F
from .. import operations as P
from ...common import dtype as mstype
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
@ -75,11 +76,43 @@ def get_bprop_extract_image_patches(self):
fill = P.Fill()
slice_op = P.Slice()
transpose = P.Transpose()
cast = P.Cast()
matmul = P.MatMul()
_, ksizes_row, ksizes_col, _ = self.ksizes
def bprop(x, out, dout):
x_shape = get_shape(x)
x_batch, x_depth, x_row, x_col = x_shape
x_indices_num = x_row * x_col + 1
x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float32)
x_idx = reshape(x_idx, (1, 1, x_row, x_col))
x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))
out_shape = get_shape(out)
_, _, out_row, out_col = out_shape
out_indices_num = out_row * out_col * ksizes_row * ksizes_col
out_idx = F.tuple_to_array(range(out_indices_num))
out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
idx_tensor = reshape(idx_tensor, (-1, 2))
sp_shape = (x_indices_num, out_indices_num)
sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
grad = transpose(dout, (0, 2, 3, 1))
grad = reshape(grad, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
grad = transpose(grad, (1, 2, 3, 4, 0, 5))
grad = reshape(grad, (-1, x_batch * x_depth))
jac = matmul(sp_tensor, grad)
dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
dx = transpose(dx, (2, 3, 0, 1))
return (dx,)
def bprop_ge(x, out, dout):
x_shape = get_shape(x)
x_batch, x_row, x_col, x_depth = x_shape
x_indices_num = x_row * x_col + 1
@ -109,6 +142,9 @@ def get_bprop_extract_image_patches(self):
return (dx,)
if context.get_context("enable_ge"):
return bprop_ge
return bprop

View File

@ -17,6 +17,7 @@
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ... import context
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..operations.math_ops import _infer_shape_reduce
@ -200,10 +201,13 @@ class ExtractImagePatches(PrimitiveWithInfer):
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
self.add_prim_attr("padding", self.padding)
self.add_prim_attr("io_format", "NHWC")
self.is_ge = context.get_context("enable_ge")
def infer_shape(self, input_x):
"""infer shape"""
in_batch, in_row, in_col, in_depth = input_x
in_batch, in_depth, in_row, in_col = input_x
if self.is_ge:
in_batch, in_row, in_col, in_depth = input_x
_, ksize_row, ksize_col, _ = self.ksizes
_, stride_row, stride_col, _ = self.strides
_, rate_row, rate_col, _ = self.rates
@ -223,7 +227,9 @@ class ExtractImagePatches(PrimitiveWithInfer):
out_row = (in_row - 1) // stride_row + 1
out_col = (in_col - 1) // stride_col + 1
out_shape = [out_batch, out_row, out_col, out_depth]
out_shape = [out_batch, out_depth, out_row, out_col]
if self.is_ge:
out_shape = [out_batch, out_row, out_col, out_depth]
return out_shape
def infer_dtype(self, input_x):