AdamApplyOne supports dynamic shape

Merge pull request  from huoxinyou/0525adam
i-robot 2022-06-06 02:49:21 +00:00 committed by Gitee
commit 27cb08f4f3
14 changed files with 624 additions and 17 deletions
docs/api/api_python/ops
mindspore
tests
st/ops/ascend
ut/cpp
pre_activate/ascend/ir_fusion
python_input/gtest_input/pre_activate

View File

@ -22,7 +22,7 @@ mindspore.ops.ScatterNdUpdate
**Inputs:**
- **input_x** (Parameter) - The input of ScatterNdUpdate; a Parameter of any dimension.
- **indices** (Tensor) - The indices that specify the update operation, with a data type of int32.
- **indices** (Tensor) - The indices that specify the update operation, with a data type of int32 or int64.
- **updates** (Tensor) - The Tensor used for the update operation on `input_x`, with the same type as the input. The shape is `indices.shape[:-1] + x.shape[indices.shape[-1]:]`.
**Outputs:**
@ -32,5 +32,5 @@ mindspore.ops.ScatterNdUpdate
**Raises:**
- **TypeError** - `use_locking` is not a bool.
- **TypeError** - `indices` is not an int32.
- **TypeError** - `indices` is not an int32 or an int64.
- **RuntimeError** - When the types of `input_x` and `updates` are inconsistent and a type conversion is required, an error is raised if `updates` cannot be converted to the data type required by the Parameter `input_x`.

View File

@ -306,6 +306,8 @@ AbstractBasePtr InferImplTensorArrayStack(const AnalysisEnginePtr &, const Primi
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplKMeansCentroids(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAdamApplyOne(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {

View File

@ -25,6 +25,12 @@
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "utils/shape_utils.h"
#include "ops/real_div.h"
#include "ops/add.h"
#include "ops/mul.h"
#include "ops/sub.h"
#include "ops/square.h"
#include "ops/assign.h"
namespace {
constexpr auto kRankSize = "rank_size";
@ -957,5 +963,47 @@ AbstractBasePtr InferImplTensorMove(const AnalysisEnginePtr &, const PrimitivePt
MS_EXCEPTION_IF_NULL(output);
return output;
}
AbstractBasePtr InferImplAdamApplyOne(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// An object of a subclass of AbstractBase
constexpr auto kAdamApplyOneNum = 10;
constexpr auto kAdamInputNum1 = 1;
constexpr auto kAdamInputNum2 = 2;
constexpr auto kAdamInputNum3 = 3;
constexpr auto kAdamInputNum4 = 4;
constexpr auto kAdamInputNum5 = 5;
constexpr auto kAdamInputNum6 = 6;
constexpr auto kAdamInputNum7 = 7;
constexpr auto kAdamInputNum8 = 8;
constexpr auto kAdamInputNum9 = 9;
CheckArgsSize(primitive->name(), args_spec_list, kAdamApplyOneNum);
auto input0 = args_spec_list[0];
auto input1 = args_spec_list[kAdamInputNum1];
auto input2 = args_spec_list[kAdamInputNum2];
auto input3 = args_spec_list[kAdamInputNum3];
auto input4 = args_spec_list[kAdamInputNum4];
auto mul0_x = args_spec_list[kAdamInputNum5];
auto mul1_x = args_spec_list[kAdamInputNum6];
auto mul2_x = args_spec_list[kAdamInputNum7];
auto mul3_x = args_spec_list[kAdamInputNum8];
auto add2_y = args_spec_list[kAdamInputNum9];
auto square0 = ops::SquareInfer(nullptr, primitive, {input0});
auto mul1 = ops::MulInfer(nullptr, primitive, {mul1_x, input0});
auto mul0 = ops::MulInfer(nullptr, primitive, {mul0_x, input2});
auto mul2 = ops::MulInfer(nullptr, primitive, {mul2_x, input1});
auto mul3 = ops::MulInfer(nullptr, primitive, {mul3_x, square0});
auto add0 = ops::AddInfer(nullptr, primitive, {mul0, mul1});
auto add1 = ops::AddInfer(nullptr, primitive, {mul2, mul3});
auto sqrt0 = InferImplSqrt(nullptr, primitive, {add1});
auto add2 = ops::AddInfer(nullptr, primitive, {add2_y, sqrt0});
auto true_div0 = ops::RealDivInfer(nullptr, primitive, {add0, add2});
auto mul4 = ops::MulInfer(nullptr, primitive, {input4, true_div0});
auto sub0 = ops::SubInfer(nullptr, primitive, {input3, mul4});
AbstractBasePtrList rets = {add1, add0, sub0};
return std::make_shared<AbstractTuple>(rets);
}
} // namespace abstract
} // namespace mindspore
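For reference, the three abstracts returned by InferImplAdamApplyOne follow the fused computation chained above. Treating input0 as the gradient, input2 and input1 as the first and second moments, input3 as the parameter, input4 as the learning rate, the mul*_x factors as the beta coefficients and add2_y as epsilon is an Adam-style reading and only an interpretation; the formulas themselves are exactly what the composed infer calls compute:
\begin{aligned}
\mathrm{add0} &= \mathrm{mul0\_x}\cdot\mathrm{input2} + \mathrm{mul1\_x}\cdot\mathrm{input0} \\
\mathrm{add1} &= \mathrm{mul2\_x}\cdot\mathrm{input1} + \mathrm{mul3\_x}\cdot\mathrm{input0}^{2} \\
\mathrm{sub0} &= \mathrm{input3} - \mathrm{input4}\cdot\frac{\mathrm{add0}}{\sqrt{\mathrm{add1}} + \mathrm{add2\_y}}
\end{aligned}
Because every step is element-wise, composing SquareInfer, MulInfer, AddInfer, InferImplSqrt, RealDivInfer and SubInfer is sufficient to propagate a dynamic input shape to all three outputs.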

View File

@ -346,6 +346,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
{prim::kPrimStack, R{ops::StackInfer, nullptr, true}},
{prim::kPrimRpcRecv, R{ops::RpcRecvInfer, nullptr, true}},
{prim::kPrimRpcSend, R{ops::RpcSendInfer, nullptr, true}},
{prim::kPrimAdamApplyOne, R{InferImplAdamApplyOne, nullptr, true}},
};
return prim_backend_eval_implement_map;
}

View File

@ -34,6 +34,8 @@ class MIND_API Assign : public BaseOperator {
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Assign for the inputs.
void Init() const {}
};
abstract::AbstractBasePtr AssignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -1105,6 +1105,10 @@ GVAR_DEF(PrimitivePtr, kPrimTensorArrayWrite, std::make_shared<Primitive>("Tenso
GVAR_DEF(PrimitivePtr, kPrimTensorArrayGather, std::make_shared<Primitive>("TensorArrayGather"));
GVAR_DEF(PrimitivePtr, kPrimKMeansCentroids, std::make_shared<Primitive>("KMeansCentroids"));
// AdamApplyOne
GVAR_DEF(PrimitivePtr, kPrimAdamApplyOne, std::make_shared<Primitive>("AdamApplyOne"));
GVAR_DEF(PrimitivePtr, kPrimAdamApplyOneAssign, std::make_shared<Primitive>("AdamApplyOneAssign"));
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)

View File

@ -36,6 +36,8 @@ class MIND_API Square : public BaseOperator {
/// \brief Init.
void Init() const {}
};
abstract::AbstractBasePtr SquareInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -579,3 +579,4 @@ from .parallel_resize_bilinear_grad import _parallel_resize_bilinear_grad_op_inf
from .p_s_r_o_i_pooling import _p_s_r_o_i_pooling_tbe
from .p_s_r_o_i_pooling_grad import _p_s_r_o_i_pooling_grad_tbe
from .renorm import _renorm_tbe
from .adam_apply_one_ds import _adam_apply_one_ds_tbe

View File

@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamApplyOne op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
adam_apply_one_ds_op_info = TBERegOp("AdamApplyOne") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("adam_apply_one.so") \
.compute_cost(10) \
.kernel_name("adam_apply_one") \
.partial_flag(True) \
.dynamic_shape(True) \
.input(0, "input0", False, "required", "all") \
.input(1, "input1", False, "required", "all") \
.input(2, "input2", False, "required", "all") \
.input(3, "input3", False, "required", "all") \
.input(4, "input4", False, "required", "all") \
.input(5, "mul0_x", False, "required", "all") \
.input(6, "mul1_x", False, "required", "all") \
.input(7, "mul2_x", False, "required", "all") \
.input(8, "mul3_x", False, "required", "all") \
.input(9, "add2_y", False, "required", "all") \
.output(0, "output0", False, "required", "all") \
.output(1, "output1", False, "required", "all") \
.output(2, "output2", False, "required", "all") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
DataType.None_None) \
.get_op_info()
@op_info_register(adam_apply_one_ds_op_info)
def _adam_apply_one_ds_tbe():
"""AdamApplyOne TBE register"""
return

View File

@ -1375,7 +1375,7 @@ def scatter_nd_add(input_x, indices, updates, use_locking=False):
Args:
input_x (Parameter): The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
indices (Tensor): The index to do add operation whose data type must be mindspore.int32.
indices (Tensor): The index to do add operation whose data type must be mindspore.int32 or mindspore.int64.
The rank of indices must be at least 2 and `indices.shape[-1] <= len(shape)`.
updates (Tensor): The tensor doing the addition operation with `input_x`,
the data type is the same as `input_x`, and the shape is `indices.shape[:-1] + x.shape[indices.shape[-1]:]`.
@ -1386,7 +1386,7 @@ def scatter_nd_add(input_x, indices, updates, use_locking=False):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices.shape[:-1] + x.shape[indices.shape[-1]:]`.
RuntimeError: If the data types of `input_x` and `updates` are inconsistent and `updates` cannot
be converted to the data type required by the Parameter `input_x`.
@ -1451,7 +1451,7 @@ def scatter_nd_sub(input_x, indices, updates, use_locking=False):
Args:
input_x (Parameter): The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
indices (Tensor): The index of input tensor, with int32 data type.
indices (Tensor): The index of input tensor, with int32 or int64 data type.
The rank of indices must be at least 2 and `indices.shape[-1] <= len(shape)`.
updates (Tensor): The tensor doing the subtraction operation with `input_x`, has the same type as input.
The shape is `indices.shape[:-1] + x.shape[indices.shape[-1]:]`.
@ -1462,7 +1462,7 @@ def scatter_nd_sub(input_x, indices, updates, use_locking=False):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices.shape[:-1] + x.shape[indices.shape[-1]:]`.
RuntimeError: If the data types of `input_x` and `updates` are inconsistent and `updates` cannot
be converted to the data type required by the Parameter `input_x`.
@ -1528,7 +1528,7 @@ def scatter_nd_div(input_x, indices, updates, use_locking=False):
Args:
input_x (Parameter): The target tensor, with data type of Parameter.
The shape is :math:`(N,*)`, where :math:`*` means any number of additional dimensions.
indices (Tensor): The index to do div operation whose data type must be mindspore.int32.
indices (Tensor): The index to do div operation whose data type must be mindspore.int32 or mindspore.int64.
The rank of indices must be at least 2 and `indices.shape[-1] <= len(shape)`.
updates (Tensor): The tensor to do the div operation with `input_x`.
The data type is the same as `input_x`, and the shape is `indices.shape[:-1] + x.shape[indices.shape[-1]:]`.
@ -1539,7 +1539,7 @@ def scatter_nd_div(input_x, indices, updates, use_locking=False):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices.shape[:-1] + x.shape[indices.shape[-1]:]`.
RuntimeError: If the data types of `input_x` and `updates` are inconsistent and `updates` cannot
be converted to the data type required by the Parameter `input_x`.
@ -1606,7 +1606,7 @@ def scatter_nd_min(input_x, indices, updates, use_locking=False):
Args:
input_x (Parameter): The target tensor, with data type of Parameter.
The shape is :math:`(N,*)`, where :math:`*` means any number of additional dimensions.
indices (Tensor): The index to do min operation whose data type must be mindspore.int32.
indices (Tensor): The index to do min operation whose data type must be mindspore.int32 or mindspore.int64.
The rank of indices must be at least 2 and `indices.shape[-1] <= len(shape)`.
updates (Tensor): The tensor to do the min operation with `input_x`.
The data type is the same as `input_x`, and the shape is `indices.shape[:-1] + x.shape[indices.shape[-1]:]`.
@ -1617,7 +1617,7 @@ def scatter_nd_min(input_x, indices, updates, use_locking=False):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices.shape[:-1] + x.shape[indices.shape[-1]:]`.
RuntimeError: If the data types of `input_x` and `updates` are inconsistent and `updates` cannot
be converted to the data type required by the Parameter `input_x`.
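A minimal usage sketch of the relaxed dtype check, assuming the functional interface `mindspore.ops.scatter_nd_add` defined in this file (values are illustrative only): after this change `indices` may be int64 as well as int32, and `updates` must still follow the `indices.shape[:-1] + x.shape[indices.shape[-1]:]` rule.
import numpy as np
from mindspore import Parameter, Tensor, ops

# Sketch only: int64 indices are accepted after this change (int32 still works).
input_x = Parameter(Tensor(np.ones((4, 3), np.float32)), name="x")
indices = Tensor(np.array([[0], [2]], np.int64))
# indices.shape[:-1] + x.shape[indices.shape[-1]:] = (2,) + (3,) = (2, 3)
updates = Tensor(np.ones((2, 3), np.float32))
out = ops.scatter_nd_add(input_x, indices, updates, use_locking=False)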

View File

@ -3890,7 +3890,7 @@ class StridedSlice(PrimitiveWithInfer):
new_axis_pos = bin(self.new_axis_mask)[-1:1:-1]
shrink_axis_pos = bin(self.shrink_axis_mask)[-1:1:-1]
if self.ellipsis_mask:
raise ValueError("Ellipsis Mask is currently not supported.")
raise ValueError("Ellipsis Mask is currently not supported in dynamic shape.")
ret_shape = []
ret_min_shape = []
ret_max_shape = []
@ -4319,7 +4319,7 @@ class ScatterNdUpdate(Primitive):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
- **indices** (Tensor) - The index of input tensor, with int32 data type.
- **indices** (Tensor) - The index of input tensor, with int32 or int64 data type.
- **updates** (Tensor) - The tensor to be updated to the input tensor, has the same type as input.
The shape is `indices.shape[:-1] + x.shape[indices.shape[-1]:]`.
@ -4328,7 +4328,7 @@ class ScatterNdUpdate(Primitive):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
RuntimeError: If the data types of `input_x` and `updates` are inconsistent and `updates` cannot
be converted to the data type required by the Parameter `input_x`.
@ -5099,7 +5099,7 @@ class ScatterNdMul(_ScatterNdOp):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
- **indices** (Tensor) - The index of input tensor, with int32 data type.
- **indices** (Tensor) - The index of input tensor, with int32 or int64 data type.
The rank of indices must be at least 2 and `indices_shape[-1] <= len(shape)`.
- **updates** (Tensor) - The tensor to be updated to the input tensor, has the same type as input.
The shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
@ -5109,7 +5109,7 @@ class ScatterNdMul(_ScatterNdOp):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
RuntimeError: If the data types of `input_x` and `updates` are inconsistent and `updates` cannot
be converted to the data type required by the Parameter `input_x`.
@ -5227,7 +5227,7 @@ class ScatterNdMax(_ScatterNdOp):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
- **indices** (Tensor) - The index of input tensor, with int32 data type.
- **indices** (Tensor) - The index of input tensor, with int32 or int64 data type.
The rank of indices must be at least 2 and `indices_shape[-1] <= len(shape)`.
- **updates** (Tensor) - The tensor to be updated to the input tensor, has the same type as input.
The shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
@ -5237,7 +5237,7 @@ class ScatterNdMax(_ScatterNdOp):
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
RuntimeError: If the data types of `input_x` and `updates` are inconsistent and `updates` cannot
be converted to the data type required by the Parameter `input_x`.
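A corresponding sketch for the primitive-style call, using the now-accepted int64 indices (values are illustrative only; `updates` follows the `indices.shape[:-1] + x.shape[indices.shape[-1]:]` rule above):
import numpy as np
from mindspore import Parameter, Tensor
from mindspore.ops import operations as P

np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], np.float32)
input_x = Parameter(Tensor(np_x), name="x")
indices = Tensor(np.array([[0, 0], [1, 1]], np.int64))  # int64 indices, newly accepted
updates = Tensor(np.array([1.0, 2.2], np.float32))       # shape (2,) = indices.shape[:-1]
scatter_nd_update = P.ScatterNdUpdate()
output = scatter_nd_update(input_x, indices, updates)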

View File

@ -0,0 +1,111 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import context, nn, set_seed
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
set_seed(2)
class AdamApplyOneNet(nn.Cell):
def __init__(self):
super(AdamApplyOneNet, self).__init__()
self.add = P.Add()
self.sub = P.Sub()
self.mul = P.Mul()
self.real_div = P.RealDiv()
self.sqrt = P.Sqrt()
self.square = P.Square()
def construct(self, input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
square0 = self.square(input0)
mul1 = self.mul(mul1_x, input0)
mul0 = self.mul(mul0_x, input2)
mul2 = self.mul(mul2_x, input1)
mul3 = self.mul(mul3_x, square0)
add0 = self.add(mul0, mul1)
add1 = self.add(mul2, mul3)
sqrt0 = self.sqrt(add1)
add2 = self.add(add2_y, sqrt0)
true_div0 = self.real_div(add0, add2)
mul4 = self.mul(input4, true_div0)
sub0 = self.sub(input3, mul4)
return add1, add0, sub0
def adam_apply_one_np(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
square0 = input0 * input0
mul1 = mul1_x * input0
mul0 = mul0_x * input2
mul2 = mul2_x * input1
mul3 = mul3_x * square0
add0 = mul0 + mul1
add1 = mul2 + mul3
sqrt0 = np.sqrt(add1)
add2 = add2_y + sqrt0
true_div0 = np.true_divide(add0, add2)
mul4 = input4 * true_div0
sub0 = input3 - mul4
return add1, add0, sub0
def compute_func(ms_net, np_net, is_dyn=False):
if is_dyn:
inputs = Tensor(shape=[2, None], dtype=mstype.float32)
ms_net.set_inputs(inputs, inputs, inputs, inputs, inputs, inputs, inputs, inputs, inputs, inputs)
input0 = np.array([[0.1, 0.3, 3.6], [0.4, 0.5, 3.2]]).astype(np.float32)
out0, out1, out2 = ms_net(Tensor(input0), Tensor(input0), Tensor(input0), Tensor(input0), \
Tensor(input0), Tensor(input0), Tensor(input0), Tensor(input0), Tensor(input0), Tensor(input0))
np0, np1, np2 = np_net(input0, input0, input0, input0, input0, input0, input0, input0, input0, input0)
assert np.all(out0.asnumpy() == np0)
assert np.all(out1.asnumpy() == np1)
assert np.all(out2.asnumpy() == np2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_adam_apply_one_dyn():
"""
Feature: Test Dynamic AdamApplyOne.
Description: The input shape is dynamic.
Expectation: Assert that results are consistent with numpy.
"""
ms_net = AdamApplyOneNet()
np_net = adam_apply_one_np
compute_func(ms_net, np_net, True)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_adam_apply_one():
"""
Feature: Test AdamApplyOne.
Description: The input shape is static.
Expectation: Assert that results are consistent with numpy.
"""
ms_net = AdamApplyOneNet()
np_net = adam_apply_one_np
compute_func(ms_net, np_net)
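The key lines in the dynamic branch of compute_func are Tensor(shape=[2, None], dtype=mstype.float32) plus Cell.set_inputs: a None entry marks a dimension whose size is only known at run time. A standalone sketch of the same mechanism, with a hypothetical SquareNet that is not part of this test file:
import numpy as np
from mindspore import Tensor, context, nn, ops
import mindspore.common.dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

class SquareNet(nn.Cell):
    def __init__(self):
        super(SquareNet, self).__init__()
        self.square = ops.Square()

    def construct(self, x):
        return self.square(x)

net = SquareNet()
dyn_x = Tensor(shape=[2, None], dtype=mstype.float32)  # None marks a dynamic dimension
net.set_inputs(dyn_x)                                  # compile the graph for dynamic shape
out = net(Tensor(np.ones((2, 5), np.float32)))         # any size is accepted in the None dim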

View File

@ -0,0 +1,241 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "plugin/device/ascend/optimizer/ir_fusion/adam_apply_one_fusion.h"
#include "include/common/debug/anf_ir_dump.h"
namespace mindspore {
namespace opt {
class TestHWAdamApplyOneDynFusion : public BackendCommon {
public:
TestHWAdamApplyOneDynFusion() : get_py_fun_("gtest_input.pre_activate.adam_apply_one_dyn_fusion_test", true) {}
~TestHWAdamApplyOneDynFusion() override = default;
UT::PyFuncGraphFetcher get_py_fun_;
};
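// The -1 entries below mark dynamic dimensions; min_shp/max_shp bound them for shape inference.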
AbstractBasePtr GetInputAbstract() {
std::vector<int64_t> shp{-1, 32, -1, 224};
std::vector<int64_t> min_shp{1, 32, 1, 224};
std::vector<int64_t> max_shp{2, 32, 224, 224};
auto input_shp = std::make_shared<abstract::Shape>(shp, min_shp, max_shp);
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, std::make_shared<Float>(32));
auto x_abstract = std::make_shared<abstract::AbstractTensor>(element, input_shp);
return x_abstract;
}
/// Feature: test AdamApplyOne dynamic shape
/// Description: The input shape is dynamic
/// Expectation: The pattern is fused into a single AdamApplyOne node and the optimized graph matches the expected graph
TEST_F(TestHWAdamApplyOneDynFusion, test_adam_apply_one_dyn_fusion) {
/*
* def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
* square0 = Square(input0)
* mul1 = Mul(mul1_x, input0)
* mul0 = Mul(mul0_x, input2)
* mul2 = Mul(mul2_x, input1)
* mul3 = Mul(mul3_x, square0)
* add0 = Add(mul0, mul1)
* add1 = Add(mul2, mul3)
* sqrt0 = Sqrt(add1)
* add2 = Add(sqrt0, add2_y)
* true_div0 = RealDiv(add0, add2)
* mul4 = Mul(input4, true_div0)
* sub0 = Sub(input3, mul4)
* outputs = make_tuple(add1, add0, sub0)
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_dyn_fusion", "before");
auto x_abstract = GetInputAbstract();
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 10; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamApplyOneFusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_dyn_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
/// Feature: test AdamApplyOne dynamic shape
/// Description: The input shape is dynamic
/// Expectation: The pattern is fused into a single AdamApplyOne node and the optimized graph matches the expected graph
TEST_F(TestHWAdamApplyOneDynFusion, test_adam_apply_one_dyn_cond1_fusion) {
/*
* def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
* square0 = Square(input0)
* mul1 = Mul(mul1_x, input0)
* mul0 = Mul(mul0_x, input2)
* mul2 = Mul(mul2_x, input1)
* mul3 = Mul(mul3_x, square0)
* add0 = Add(mul0, mul1)
* add1 = Add(mul2, mul3)
* sqrt0 = Sqrt(add1)
* add2 = Add(add2_y, sqrt0)
* true_div0 = RealDiv(add0, add2)
* mul4 = Mul(input4, true_div0)
* sub0 = Sub(input3, mul4)
* outputs = make_tuple(add1, add0, sub0)
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_dyn_fusion", "before_cond1");
auto x_abstract = GetInputAbstract();
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 10; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamApplyOneCond1Fusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_dyn_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
/// Feature: test AdamApplyOne dynamic shape
/// Description: The input shape is dynamic
/// Expectation: The pattern is fused into a single AdamApplyOne node and the optimized graph matches the expected graph
TEST_F(TestHWAdamApplyOneDynFusion, test_adam_apply_one_dyn_cond2_fusion) {
/*
* def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
* square0 = Square(input0)
* mul1 = Mul(mul1_x, input0)
* mul0 = Mul(mul0_x, input2)
* mul2 = Mul(mul2_x, input1)
* mul3 = Mul(square0, mul3_x)
* add0 = Add(mul0, mul1)
* add1 = Add(mul2, mul3)
* sqrt0 = Sqrt(add1)
* add2 = Add(sqrt0, add2_y)
* true_div0 = RealDiv(add0, add2)
* mul4 = Mul(true_div0, input4)
* sub0 = Sub(input3, mul4)
* outputs = make_tuple(add1, add0, sub0)
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_dyn_fusion", "before_cond2");
auto x_abstract = GetInputAbstract();
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 10; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamApplyOneCond2Fusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_dyn_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
/// Feature: test AdamApplyOne dynamic shape
/// Description: The input shape is dynamic
/// Expectation: The pattern is fused into a single AdamApplyOne node and the optimized graph matches the expected graph
TEST_F(TestHWAdamApplyOneDynFusion, test_adam_apply_one_dyn_cond3_fusion) {
/*
* def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
* square0 = Square(input0)
* mul1 = Mul(mul1_x, input0)
* mul0 = Mul(mul0_x, input2)
* mul2 = Mul(mul2_x, input1)
* mul3 = Mul(mul3_x, square0)
* add0 = Add(mul0, mul1)
* add1 = Add(mul2, mul3)
* sqrt0 = Sqrt(add1)
* add2 = Add(sqrt0, add2_y)
* true_div0 = RealDiv(add0, add2)
* mul4 = Mul(true_div0, input4)
* sub0 = Sub(input3, mul4)
* outputs = make_tuple(add1, add0, sub0)
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_dyn_fusion", "before_cond3");
auto x_abstract = GetInputAbstract();
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 10; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamApplyOneCond3Fusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_dyn_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
/// Feature: test AdamApplyOne dynamic shape
/// Description: The input shape is dynamic
/// Expectation: The pattern is fused into a single AdamApplyOne node and the optimized graph matches the expected graph
TEST_F(TestHWAdamApplyOneDynFusion, test_adam_apply_one_dyn_cond4_fusion) {
/*
* def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
* square0 = Square(input0)
* mul1 = Mul(mul1_x, input0)
* mul0 = Mul(mul0_x, input2)
* mul2 = Mul(mul2_x, input1)
* mul3 = Mul(mul3_x, square0)
* add0 = Add(mul0, mul1)
* add1 = Add(mul2, mul3)
* sqrt0 = Sqrt(add1)
* add2 = Add(add2_y, sqrt0)
* true_div0 = RealDiv(add0, add2)
* mul4 = Mul(true_div0, input4)
* sub0 = Sub(input3, mul4)
* outputs = make_tuple(add1, add0, sub0)
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_dyn_fusion", "before_cond4");
auto x_abstract = GetInputAbstract();
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 10; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamApplyOneCond4Fusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_dyn_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,144 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants
Add = P.Add()
Sub = P.Sub()
Mul = P.Mul()
RealDiv = P.RealDiv()
Sqrt = P.Sqrt()
Square = P.Square()
Assign = P.Assign()
make_tuple = Primitive('MakeTuple')
tuple_getitem = Primitive(Constants.kTupleGetItem)
AdamApplyOne = Primitive('AdamApplyOne')
AdamApplyOneAssign = Primitive('AdamApplyOneAssign')
class FnDict:
def __init__(self):
self.fn_dict = {}
def __call__(self, fn):
self.fn_dict[fn.__name__] = fn
def __getitem__(self, name):
return self.fn_dict.get(name, None)
def test_adam_apply_one_dyn_fusion(tag):
"""
Feature: test AdamApplyOne dynamic shape
Description: The input shape is dynamic
Expectation: Each before graph is fused into the after graph with a single AdamApplyOne node
"""
fns = FnDict()
@fns
def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
square0 = Square(input0)
mul1 = Mul(mul1_x, input0)
mul0 = Mul(mul0_x, input2)
mul2 = Mul(mul2_x, input1)
mul3 = Mul(mul3_x, square0)
add0 = Add(mul0, mul1)
add1 = Add(mul2, mul3)
sqrt0 = Sqrt(add1)
add2 = Add(sqrt0, add2_y)
true_div0 = RealDiv(add0, add2)
mul4 = Mul(input4, true_div0)
sub0 = Sub(input3, mul4)
outputs = make_tuple(add1, add0, sub0)
return outputs
@fns
def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
square0 = Square(input0)
mul1 = Mul(mul1_x, input0)
mul0 = Mul(mul0_x, input2)
mul2 = Mul(mul2_x, input1)
mul3 = Mul(mul3_x, square0)
add0 = Add(mul0, mul1)
add1 = Add(mul2, mul3)
sqrt0 = Sqrt(add1)
add2 = Add(add2_y, sqrt0)
true_div0 = RealDiv(add0, add2)
mul4 = Mul(input4, true_div0)
sub0 = Sub(input3, mul4)
outputs = make_tuple(add1, add0, sub0)
return outputs
@fns
def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
square0 = Square(input0)
mul1 = Mul(mul1_x, input0)
mul0 = Mul(mul0_x, input2)
mul2 = Mul(mul2_x, input1)
mul3 = Mul(square0, mul3_x)
add0 = Add(mul0, mul1)
add1 = Add(mul2, mul3)
sqrt0 = Sqrt(add1)
add2 = Add(sqrt0, add2_y)
true_div0 = RealDiv(add0, add2)
mul4 = Mul(true_div0, input4)
sub0 = Sub(input3, mul4)
outputs = make_tuple(add1, add0, sub0)
return outputs
@fns
def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
square0 = Square(input0)
mul1 = Mul(mul1_x, input0)
mul0 = Mul(mul0_x, input2)
mul2 = Mul(mul2_x, input1)
mul3 = Mul(mul3_x, square0)
add0 = Add(mul0, mul1)
add1 = Add(mul2, mul3)
sqrt0 = Sqrt(add1)
add2 = Add(sqrt0, add2_y)
true_div0 = RealDiv(add0, add2)
mul4 = Mul(true_div0, input4)
sub0 = Sub(input3, mul4)
outputs = make_tuple(add1, add0, sub0)
return outputs
@fns
def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
square0 = Square(input0)
mul1 = Mul(mul1_x, input0)
mul0 = Mul(mul0_x, input2)
mul2 = Mul(mul2_x, input1)
mul3 = Mul(mul3_x, square0)
add0 = Add(mul0, mul1)
add1 = Add(mul2, mul3)
sqrt0 = Sqrt(add1)
add2 = Add(add2_y, sqrt0)
true_div0 = RealDiv(add0, add2)
mul4 = Mul(true_div0, input4)
sub0 = Sub(input3, mul4)
outputs = make_tuple(add1, add0, sub0)
return outputs
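# The `after` graph below is the expected result of the fusion passes: the whole element-wise
# pattern collapses into one AdamApplyOne node whose three outputs are unpacked with tuple_getitem.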
@fns
def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
adam_apply_one = AdamApplyOne(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y)
outputs = make_tuple(tuple_getitem(adam_apply_one, 0), tuple_getitem(adam_apply_one, 1),
tuple_getitem(adam_apply_one, 2))
return make_tuple(outputs)
return fns[tag]