!36991 Refactor dynamic shape of Add and Dropout on NPU.

Merge pull request !36991 from yangshuo/master
i-robot 2022-07-11 13:04:17 +00:00 committed by Gitee
commit a025d9eed9
10 changed files with 218 additions and 51 deletions
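
The user-visible effect of the refactor is that Add and Dropout accept dynamic-shape inputs on Ascend. Below is a minimal sketch of the pattern the new tests exercise: Unique followed by Gather makes the leading dimension data-dependent, so shape inference has to carry min/max shapes instead of a fixed shape. The cell name, shapes, and keep_prob are illustrative only and not taken from the commit.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class DynAddDropout(nn.Cell):
    """Unique + Gather give the inputs a data-dependent leading dimension."""
    def __init__(self):
        super().__init__()
        self.unique = P.Unique()
        self.gather = P.Gather()
        self.add = P.Add()
        self.dropout = P.Dropout(keep_prob=0.9)

    def construct(self, x, y, indices):
        u, _ = self.unique(indices)   # length of u is only known at run time
        x = self.gather(x, u, 0)      # so x and y now have a dynamic first dim
        y = self.gather(y, u, 0)
        out, _ = self.dropout(self.add(x, y))
        return out

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
x = np.random.randn(3, 3, 4).astype(np.float32)
y = np.random.randn(3, 3, 4).astype(np.float32)
indices = np.random.randint(0, 3, size=3).astype(np.int32)
print(DynAddDropout()(Tensor(x), Tensor(y), Tensor(indices)).shape)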

View File

@@ -55,9 +55,6 @@ AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const Primitive
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@@ -245,23 +245,6 @@ AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<AbstractTuple>(args_list);
}
AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
ShapeVector shape = x->shape()->shape();
ShapeVector min_shape = x->shape()->min_shape();
ShapeVector max_shape = x->shape()->max_shape();
CheckMinMaxShape(shape, &min_shape, &max_shape);
auto output_shape =
std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
AbstractBasePtrList ret = {output_shape, output_shape};
return std::make_shared<AbstractTuple>(ret);
}
AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
CheckRequiredArgsSize(primitive->name(), args_spec_list, 5);

View File

@@ -52,9 +52,12 @@
#include "ops/tensor_scatter_arithmetic.h"
#include "ops/max_pool.h"
#include "ops/grad/max_pool_grad.h"
#include "ops/dropout.h"
namespace mindspore {
namespace abstract {
using ops::InferImplDropout;
PrimShapeDependMap &GetHostDependsMap() {
// Registration directly by the host_depends map will be deprecated and
// should be registered by the REGISTER_HOST_DEPENDS

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,15 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/dropout.h"
#include <set>
#include <vector>
#include <memory>
#include "ops/dropout.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
#include "abstract/ops/primitive_infer_map.h"
#include "abstract/param_validator.h"
namespace mindspore {
namespace ops {
@@ -37,7 +39,22 @@ float Dropout::get_keep_prob() const {
auto value_ptr = this->GetAttr(kKeepProb);
return GetValue<float>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameDropout, Dropout);
AbstractBasePtr InferImplDropout(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
auto op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto x = abstract::CheckArg<abstract::AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
ShapeVector shape = x->shape()->shape();
ShapeVector min_shape = x->shape()->min_shape();
ShapeVector max_shape = x->shape()->max_shape();
CheckAndConvertUtils::CheckMinMaxShape(shape, &min_shape, &max_shape);
auto output_shape = std::make_shared<abstract::AbstractTensor>(
x->element(), std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
AbstractBasePtrList ret = {output_shape, output_shape};
return std::make_shared<abstract::AbstractTuple>(ret);
}
} // namespace ops
} // namespace mindspore
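
As a reference point, the relocated InferImplDropout above builds a single AbstractTensor from the input's shape (with min/max shape for dynamic inputs) and returns it twice, so both Dropout outputs are declared with the input's shape. A small, hypothetical front-end check of that rule, assuming a configured Ascend (or another supported) device:

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
x = Tensor(np.random.randn(3, 3, 4).astype(np.float32))
# Dropout returns (output, mask); per the infer above, the first output
# keeps the input's shape ({output_shape, output_shape} in the C++ code).
out, _ = P.Dropout(keep_prob=0.5)(x)
assert out.shape == x.shape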

View File

@@ -41,8 +41,8 @@ class MIND_API Dropout : public BaseOperator {
/// \return keep_prob.
float get_keep_prob() const;
};
abstract::AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
abstract::AbstractBasePtr InferImplDropout(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_DROPOUT_H_

View File

@@ -33,6 +33,8 @@ from .acosh_grad_ds import _acosh_grad_ds_tbe
from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe
from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe
from .apply_centered_rms_prop_ds import _apply_centered_rms_prop_ds_tbe
from .add import _add_tbe
from .add_ds import _add_ds_tbe
from .add_n import _add_n_tbe
from .add_n_ds import _add_n_ds_tbe
from .addcdiv import _addcdiv_tbe
@@ -137,8 +139,6 @@ from .sigmoid_cross_entropy_with_logits_ds import _sigmoid_cross_entropy_with_lo
from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe
from .sparse_apply_adadelta import _sparse_apply_adadelta_tbe
from .sigmoid_cross_entropy_with_logits_grad_ds import _sigmoid_cross_entropy_with_logits_grad_ds_tbe
from .tensor_add import _tensor_add_tbe
from .tensor_add_ds import _tensor_add_ds_tbe
from .trans_data import _trans_data_tbe
from .trans_data_ds import _trans_data_ds_tbe
from .trans_data_rnn import _trans_data_rnn_tbe

View File

@@ -16,7 +16,7 @@
"""Add op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
tensor_add_op_info = TBERegOp("Add") \
add_op_info = TBERegOp("Add") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("add.so") \
@@ -36,7 +36,7 @@ tensor_add_op_info = TBERegOp("Add") \
.get_op_info()
@op_info_register(tensor_add_op_info)
def _tensor_add_tbe():
@op_info_register(add_op_info)
def _add_tbe():
"""Add TBE register"""
return

View File

@@ -16,7 +16,7 @@
"""TensorAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
tensor_add_op_info = TBERegOp("Add") \
add_op_info = TBERegOp("Add") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("add.so") \
@@ -37,7 +37,7 @@ tensor_add_op_info = TBERegOp("Add") \
.get_op_info()
@op_info_register(tensor_add_op_info)
def _tensor_add_ds_tbe():
@op_info_register(add_op_info)
def _add_ds_tbe():
"""Add TBE register"""
return

View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,37 +13,101 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class Net(nn.Cell):
class AddNet(nn.Cell):
def __init__(self):
super(Net, self).__init__()
super().__init__()
self.add = P.Add()
def construct(self, x_, y_):
return self.add(x_, y_)
x = np.random.randn(1, 3, 3, 4).astype(np.float32)
y = np.random.randn(1, 3, 3, 4).astype(np.float32)
class AddDynamicShapeNet(nn.Cell):
def __init__(self, axis=0):
super().__init__()
self.unique = P.Unique()
self.gather = P.Gather()
self.add = P.Add()
self.axis = axis
def construct(self, x_, y_, indices):
u_indices, _ = self.unique(indices)
x_ = self.gather(x_, u_indices, self.axis)
y_ = self.gather(y_, u_indices, self.axis)
return self.add(x_, y_)
def test_net():
def comput_expect(x, y):
return np.add(x, y)
def add_net(*args, is_dynamic=False):
op = args[0]
x = args[1]
y = args[2]
if is_dynamic:
out = op(Tensor(x), Tensor(y), Tensor(args[3]))
else:
out = op(Tensor(x), Tensor(y))
if is_dynamic:
print("input shape: ", x.shape)
print("output shape: ", out.shape)
else:
assert np.allclose(out.asnumpy(), comput_expect(x, y), 1e-3, 1e-3)
@pytest.mark.skip
def test_add(dtype=np.float16):
"""
Feature: test add operator in graph and pynative mode.
Description: test add.
Expectation: the result is correct
"""
x = np.random.randn(3, 3, 4).astype(dtype)
y = np.random.randn(3, 3, 4).astype(dtype)
indices = np.random.randint(0, 3, size=3)
net = AddNet()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
add = Net()
output = add(Tensor(x), Tensor(y))
print(x)
print(y)
print(output.asnumpy())
add_net(net, x, y)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
add = Net()
output = add(Tensor(x), Tensor(y))
print(x)
print(y)
print(output.asnumpy())
add_net(net, x, y)
net = AddDynamicShapeNet()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
add_net(net, x, y, indices, is_dynamic=True)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
add_net(net, x, y, indices, is_dynamic=True)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_add_float16():
"""
Feature: test add operator.
Description: test float16 input.
Expectation: the result is correct
"""
test_add(np.float16)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_add_float32():
"""
Feature: test add operator.
Description: test float32 input.
Expectation: the result is correct
"""
test_add(np.float32)

View File

@@ -0,0 +1,103 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.op = P.Dropout()
def construct(self, x_):
return self.op(x_)
class DynamicShapeNet(nn.Cell):
def __init__(self, axis=0):
super().__init__()
self.unique = P.Unique()
self.gather = P.Gather()
self.op = P.Dropout()
self.axis = axis
def construct(self, x_, indices):
u_indices, _ = self.unique(indices)
x_ = self.gather(x_, u_indices, self.axis)
return self.op(x_)
def dropout_net(*args, is_dynamic=False):
op = args[0]
x = args[1]
if is_dynamic:
out = op(Tensor(x), Tensor(args[2]))
else:
out = op(Tensor(x))
print("input shape: ", x.shape)
print("output shape: ", out[0].shape)
@pytest.mark.skip
def test_dropout(dtype=np.float16):
"""
Feature: test dropout operator in graph and pynative mode.
Description: test dropout.
Expectation: the result is correct
"""
x = np.random.randn(3, 3, 4).astype(dtype)
indices = np.random.randint(0, 3, size=3)
net = Net()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
dropout_net(net, x)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
dropout_net(net, x)
net = DynamicShapeNet()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
dropout_net(net, x, indices, is_dynamic=True)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
dropout_net(net, x, indices, is_dynamic=True)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_float16():
"""
Feature: test dropout operator.
Description: test float16 input.
Expectation: the result is correct
"""
test_dropout(np.float16)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_float32():
"""
Feature: test dropout operator.
Description: test float32 input.
Expectation: the result is correct
"""
test_dropout(np.float32)