forked from mindspore-Ecosystem/mindspore
!1414 fix issue use reshape as flatten grad impl
Merge pull request !1414 from zhaozhenlong/fix-issues-reshape-replace-flattern-grad
This commit is contained in:
commit
44bf7c9330
|
@ -385,7 +385,8 @@ bool IsNopNode(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
|
||||
prim::kPrimSqueeze->name(), prim::kPrimFlatten->name()};
|
||||
prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
|
||||
kFlattenGradOpName};
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -197,3 +197,4 @@ from .cum_sum import _cum_sum_tbe
|
|||
from .apply_rms_prop import _apply_rms_prop_tbe
|
||||
from .cumprod import _cumprop_tbe
|
||||
from .reduce_prod import _reduce_prod_tbe
|
||||
from .flatten_grad import _flatten_grad_tbe
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Reshape op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
flatten_grad_op_info = TBERegOp("FlattenGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("reshape.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("reshape") \
|
||||
.partial_flag(True) \
|
||||
.attr("shape", "required", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
@op_info_register(flatten_grad_op_info)
|
||||
def _flatten_grad_tbe():
|
||||
"""Reshape TBE register"""
|
||||
return
|
|
@ -121,6 +121,16 @@ class NetForFlatten0D(nn.Cell):
|
|||
return self.flatten(x)
|
||||
|
||||
|
||||
class NetForFlattenComposed(nn.Cell):
|
||||
# make flatten op together with other ops for testing flatten grad
|
||||
def __init__(self):
|
||||
super(NetForFlattenComposed, self).__init__()
|
||||
self.flatten = P.Flatten()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.flatten(x+x) + y
|
||||
|
||||
|
||||
class ArgmaxNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ArgmaxNet, self).__init__()
|
||||
|
@ -695,7 +705,7 @@ test_case_nn_ops = [
|
|||
('Flatten', {
|
||||
'block': P.Flatten(),
|
||||
'desc_inputs': [[128, 32, 32, 64]],
|
||||
'desc_bprop': [[128 * 32 * 8 * 16]]}),
|
||||
'desc_bprop': [[128, 65536]]}),
|
||||
('LogSoftmax', {
|
||||
'block': P.LogSoftmax(),
|
||||
'desc_inputs': [[64, 2]],
|
||||
|
@ -897,6 +907,11 @@ test_case_nn_ops = [
|
|||
'desc_inputs': [Tensor(np.ones([8]).astype(np.int32)), Tensor(np.ones([8, 3]).astype(np.int32))],
|
||||
'desc_bprop': [Tensor(np.ones([8, 3]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
('Flatten_3', {
|
||||
'block': NetForFlattenComposed(),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))],
|
||||
'desc_bprop': [Tensor(np.ones([2, 12]).astype(np.int32))],
|
||||
'skip': []}),
|
||||
('ArgmaxNet', {
|
||||
'block': ArgmaxNet(),
|
||||
'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
|
||||
|
|
Loading…
Reference in New Issue