forked from OSSInnovation/mindspore
support extract_image_patches op
This commit is contained in:
parent
15b4115e55
commit
96622fc804
|
@ -64,6 +64,7 @@
|
|||
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
|
||||
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
|
||||
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
|
||||
#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
#include "backend/optimizer/pass/optimize_dependence.h"
|
||||
#include "backend/optimizer/pass/erase_visit_attr.h"
|
||||
|
@ -231,6 +232,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
|
|||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
auto mixed_precision_pm = std::make_shared<PassManager>("cast_pm");
|
||||
mixed_precision_pm->AddPass(std::make_shared<InsertCast>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<InsertReshapeForExtractImagePatchesOp>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
|
||||
#include <memory>
|
||||
#include "backend/optimizer/ascend/ascend_helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/utils.h"
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef InsertReshapeForExtractImagePatchesOp::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({prim::kPrimExtractImagePatches, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr InsertReshapeForExtractImagePatchesOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto extract = CheckAnfNodeIfCNodeAndInputSize(node, 2);
|
||||
MS_EXCEPTION_IF_NULL(extract);
|
||||
auto in_node = extract->input(1);
|
||||
MS_EXCEPTION_IF_NULL(in_node);
|
||||
auto extract_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(extract);
|
||||
auto in_node_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(in_node);
|
||||
MS_EXCEPTION_IF_NULL(extract_kernel_build_info);
|
||||
MS_EXCEPTION_IF_NULL(in_node_kernel_build_info);
|
||||
std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
|
||||
in_node};
|
||||
auto reshape_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
reshape_builder->SetInputsFormat({kOpFormat_NC1HWC0});
|
||||
reshape_builder->SetOutputsFormat({kOpFormat_NC1HWC0});
|
||||
reshape_builder->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
|
||||
reshape_builder->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
|
||||
reshape_builder->SetKernelType(in_node_kernel_build_info->kernel_type());
|
||||
reshape_builder->SetFusionType(in_node_kernel_build_info->fusion_type());
|
||||
reshape_builder->SetProcessor(in_node_kernel_build_info->processor());
|
||||
|
||||
auto reshape = func_graph->NewCNode(reshape_inputs);
|
||||
reshape->set_scope(in_node->scope());
|
||||
auto shape_tmp = AnfAlgo::GetOutputInferShape(in_node, 0);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputDeviceDataType(in_node, 0)},
|
||||
{{shape_tmp[0], shape_tmp[2], shape_tmp[3], shape_tmp[1]}}, reshape.get());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(reshape_builder->Build(), reshape.get());
|
||||
AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), reshape);
|
||||
AnfAlgo::SetNodeInput(extract, reshape, 0);
|
||||
return extract;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/pattern_engine.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class InsertReshapeForExtractImagePatchesOp : public PatternProcessPass {
|
||||
public:
|
||||
explicit InsertReshapeForExtractImagePatchesOp(bool multigraph = true)
|
||||
: PatternProcessPass("insert_reshape_for_extract_image_patches_op", multigraph) {}
|
||||
~InsertReshapeForExtractImagePatchesOp() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
|
|
@ -516,6 +516,10 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
|
|||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
||||
infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
|
||||
}
|
||||
if (node->isa<CNode>() && GetCNodeName(node) == kExtractImagePatchesOpName) {
|
||||
auto shape_tmp = {infer_shape[0], infer_shape[3], infer_shape[1], infer_shape[2]};
|
||||
return trans::TransShapeToDevice(shape_tmp, format);
|
||||
}
|
||||
return trans::TransShapeToDevice(infer_shape, format);
|
||||
}
|
||||
|
||||
|
|
|
@ -104,6 +104,7 @@ inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
|
|||
inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
|
||||
inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
|
||||
inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
|
||||
inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared<Primitive>("ExtractImagePatches");
|
||||
|
||||
// NN
|
||||
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
|
||||
|
|
|
@ -542,12 +542,16 @@ class Unfold(Cell):
|
|||
self.transpose = P.Transpose()
|
||||
self.format_NHWC = (0, 2, 3, 1)
|
||||
self.format_NCHW = (0, 3, 1, 2)
|
||||
self.is_ge = context.get_context("enable_ge")
|
||||
|
||||
def construct(self, input_x):
|
||||
x_transpose = self.transpose(input_x, self.format_NHWC)
|
||||
ret = self.extract_image_patches(x_transpose)
|
||||
ret_transpose = self.transpose(ret, self.format_NCHW)
|
||||
return ret_transpose
|
||||
if self.is_ge:
|
||||
x_transpose = self.transpose(input_x, self.format_NHWC)
|
||||
ret = self.extract_image_patches(x_transpose)
|
||||
result = self.transpose(ret, self.format_NCHW)
|
||||
else:
|
||||
result = self.extract_image_patches(input_x)
|
||||
return result
|
||||
|
||||
|
||||
@constexpr
|
||||
|
|
|
@ -21,6 +21,7 @@ from mindspore.common.tensor import Tensor
|
|||
from .grad_base import bprop_getters
|
||||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
from ...common import dtype as mstype
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations import _grad_ops as G
|
||||
from ..operations import _inner_ops as inner
|
||||
|
@ -75,11 +76,43 @@ def get_bprop_extract_image_patches(self):
|
|||
fill = P.Fill()
|
||||
slice_op = P.Slice()
|
||||
transpose = P.Transpose()
|
||||
cast = P.Cast()
|
||||
matmul = P.MatMul()
|
||||
|
||||
_, ksizes_row, ksizes_col, _ = self.ksizes
|
||||
|
||||
def bprop(x, out, dout):
|
||||
x_shape = get_shape(x)
|
||||
x_batch, x_depth, x_row, x_col = x_shape
|
||||
x_indices_num = x_row * x_col + 1
|
||||
x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float32)
|
||||
x_idx = reshape(x_idx, (1, 1, x_row, x_col))
|
||||
x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
|
||||
x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))
|
||||
|
||||
out_shape = get_shape(out)
|
||||
_, _, out_row, out_col = out_shape
|
||||
out_indices_num = out_row * out_col * ksizes_row * ksizes_col
|
||||
out_idx = F.tuple_to_array(range(out_indices_num))
|
||||
out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
|
||||
|
||||
idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
|
||||
idx_tensor = reshape(idx_tensor, (-1, 2))
|
||||
sp_shape = (x_indices_num, out_indices_num)
|
||||
sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
|
||||
sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
|
||||
|
||||
grad = transpose(dout, (0, 2, 3, 1))
|
||||
grad = reshape(grad, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
|
||||
grad = transpose(grad, (1, 2, 3, 4, 0, 5))
|
||||
grad = reshape(grad, (-1, x_batch * x_depth))
|
||||
|
||||
jac = matmul(sp_tensor, grad)
|
||||
dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
|
||||
dx = transpose(dx, (2, 3, 0, 1))
|
||||
return (dx,)
|
||||
|
||||
def bprop_ge(x, out, dout):
|
||||
x_shape = get_shape(x)
|
||||
x_batch, x_row, x_col, x_depth = x_shape
|
||||
x_indices_num = x_row * x_col + 1
|
||||
|
@ -109,6 +142,9 @@ def get_bprop_extract_image_patches(self):
|
|||
|
||||
return (dx,)
|
||||
|
||||
if context.get_context("enable_ge"):
|
||||
return bprop_ge
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator as validator
|
||||
from ... import context
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
from ..operations.math_ops import _infer_shape_reduce
|
||||
|
@ -200,10 +201,13 @@ class ExtractImagePatches(PrimitiveWithInfer):
|
|||
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
|
||||
self.add_prim_attr("padding", self.padding)
|
||||
self.add_prim_attr("io_format", "NHWC")
|
||||
self.is_ge = context.get_context("enable_ge")
|
||||
|
||||
def infer_shape(self, input_x):
|
||||
"""infer shape"""
|
||||
in_batch, in_row, in_col, in_depth = input_x
|
||||
in_batch, in_depth, in_row, in_col = input_x
|
||||
if self.is_ge:
|
||||
in_batch, in_row, in_col, in_depth = input_x
|
||||
_, ksize_row, ksize_col, _ = self.ksizes
|
||||
_, stride_row, stride_col, _ = self.strides
|
||||
_, rate_row, rate_col, _ = self.rates
|
||||
|
@ -223,7 +227,9 @@ class ExtractImagePatches(PrimitiveWithInfer):
|
|||
out_row = (in_row - 1) // stride_row + 1
|
||||
out_col = (in_col - 1) // stride_col + 1
|
||||
|
||||
out_shape = [out_batch, out_row, out_col, out_depth]
|
||||
out_shape = [out_batch, out_depth, out_row, out_col]
|
||||
if self.is_ge:
|
||||
out_shape = [out_batch, out_row, out_col, out_depth]
|
||||
return out_shape
|
||||
|
||||
def infer_dtype(self, input_x):
|
||||
|
|
Loading…
Reference in New Issue