!36991 refactor dynamic shape of add and dropout on npu.
Merge pull request !36991 from yangshuo/master
commit a025d9eed9
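This change moves Dropout shape inference from the shared abstract infer functions into ops/dropout.cc (renaming DropoutInfer to InferImplDropout and wiring it up via using ops::InferImplDropout), renames the TBE Add registrations from _tensor_add_tbe/_tensor_add_ds_tbe to _add_tbe/_add_ds_tbe, and rewrites the Ascend ST tests to cover dynamic shapes. In the tests, dynamic shape is produced with Unique + Gather: the length of Unique's output depends on runtime data, so every shape downstream of the Gather is unknown at compile time. A minimal standalone sketch of that pattern (DynNet is an illustrative name, not part of this commit; it uses only operators that appear in the diffs below):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class DynNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.unique = P.Unique()
        self.gather = P.Gather()
        self.add = P.Add()

    def construct(self, x, y, indices):
        u, _ = self.unique(indices)   # length of u depends on runtime values,
        x = self.gather(x, u, 0)      # so x, y and the sum have dynamic shape
        y = self.gather(y, u, 0)
        return self.add(x, y)


context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
x = np.ones((4, 3), np.float32)
out = DynNet()(Tensor(x), Tensor(x), Tensor(np.array([0, 1, 1, 2], np.int32)))
print(out.shape)  # (3, 3): one row per distinct index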
@@ -55,9 +55,6 @@ AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const Primitive
                                       const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                 const AbstractBasePtrList &args_spec_list);
-
 AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -245,23 +245,6 @@ AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr
   return std::make_shared<AbstractTuple>(args_list);
 }
 
-AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                 const AbstractBasePtrList &args_spec_list) {
-  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 1);
-  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(x->shape());
-  ShapeVector shape = x->shape()->shape();
-  ShapeVector min_shape = x->shape()->min_shape();
-  ShapeVector max_shape = x->shape()->max_shape();
-  CheckMinMaxShape(shape, &min_shape, &max_shape);
-  auto output_shape =
-    std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
-  AbstractBasePtrList ret = {output_shape, output_shape};
-  return std::make_shared<AbstractTuple>(ret);
-}
-
 AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const AbstractBasePtrList &args_spec_list) {
   CheckRequiredArgsSize(primitive->name(), args_spec_list, 5);
@@ -52,9 +52,12 @@
 #include "ops/tensor_scatter_arithmetic.h"
 #include "ops/max_pool.h"
 #include "ops/grad/max_pool_grad.h"
+#include "ops/dropout.h"
 
 namespace mindspore {
 namespace abstract {
+using ops::InferImplDropout;
+
 PrimShapeDependMap &GetHostDependsMap() {
   // Registration directly by the host_depends map will be deprecated and
   // should be registered by the REGISTER_HOST_DEPENDS
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,15 +13,17 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "ops/dropout.h"
 
 #include <set>
 #include <vector>
 #include <memory>
 
-#include "ops/dropout.h"
 #include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
 #include "mindapi/src/helper.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "abstract/param_validator.h"
 
 namespace mindspore {
 namespace ops {
@@ -37,7 +39,22 @@ float Dropout::get_keep_prob() const {
   auto value_ptr = this->GetAttr(kKeepProb);
   return GetValue<float>(value_ptr);
 }
 
 REGISTER_PRIMITIVE_C(kNameDropout, Dropout);
+AbstractBasePtr InferImplDropout(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
+  auto op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto x = abstract::CheckArg<abstract::AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  ShapeVector shape = x->shape()->shape();
+  ShapeVector min_shape = x->shape()->min_shape();
+  ShapeVector max_shape = x->shape()->max_shape();
+  CheckAndConvertUtils::CheckMinMaxShape(shape, &min_shape, &max_shape);
+  auto output_shape = std::make_shared<abstract::AbstractTensor>(
+    x->element(), std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
+  AbstractBasePtrList ret = {output_shape, output_shape};
+  return std::make_shared<abstract::AbstractTuple>(ret);
+}
 }  // namespace ops
 }  // namespace mindspore
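InferImplDropout above reproduces the logic deleted from the shared abstract infer functions, but now lives next to the operator. It validates a single tensor input, propagates the input's shape together with its min_shape/max_shape bounds so that dynamic dimensions survive inference, and returns a two-element tuple because Dropout yields both the output and the mask, with identical shapes. A minimal sketch of that contract from the Python side (keep_prob=0.9 is an arbitrary illustrative choice):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

dropout = P.Dropout(keep_prob=0.9)
x = Tensor(np.random.randn(3, 3, 4).astype(np.float32))
output, mask = dropout(x)                       # two outputs inferred from one input
assert output.shape == mask.shape == (3, 3, 4)  # both carry x's shape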
@@ -41,8 +41,8 @@ class MIND_API Dropout : public BaseOperator {
   /// \return keep_prob.
   float get_keep_prob() const;
 };
-abstract::AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+abstract::AbstractBasePtr InferImplDropout(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<abstract::AbstractBasePtr> &input_args);
 }  // namespace ops
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_OPS_DROPOUT_H_
@@ -33,6 +33,8 @@ from .acosh_grad_ds import _acosh_grad_ds_tbe
 from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe
 from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe
 from .apply_centered_rms_prop_ds import _apply_centered_rms_prop_ds_tbe
+from .add import _add_tbe
+from .add_ds import _add_ds_tbe
 from .add_n import _add_n_tbe
 from .add_n_ds import _add_n_ds_tbe
 from .addcdiv import _addcdiv_tbe
@@ -137,8 +139,6 @@ from .sigmoid_cross_entropy_with_logits_ds import _sigmoid_cross_entropy_with_lo
 from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe
 from .sparse_apply_adadelta import _sparse_apply_adadelta_tbe
 from .sigmoid_cross_entropy_with_logits_grad_ds import _sigmoid_cross_entropy_with_logits_grad_ds_tbe
-from .tensor_add import _tensor_add_tbe
-from .tensor_add_ds import _tensor_add_ds_tbe
 from .trans_data import _trans_data_tbe
 from .trans_data_ds import _trans_data_ds_tbe
 from .trans_data_rnn import _trans_data_rnn_tbe
@@ -16,7 +16,7 @@
 """Add op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
-tensor_add_op_info = TBERegOp("Add") \
+add_op_info = TBERegOp("Add") \
     .fusion_type("ELEMWISE") \
     .async_flag(False) \
     .binfile_name("add.so") \
@@ -36,7 +36,7 @@ tensor_add_op_info = TBERegOp("Add") \
     .get_op_info()
 
 
-@op_info_register(tensor_add_op_info)
-def _tensor_add_tbe():
+@op_info_register(add_op_info)
+def _add_tbe():
     """Add TBE register"""
     return
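The rename above follows the standard TBE registration pattern: a TBERegOp builder describes the kernel (op "Add", binary add.so), @op_info_register attaches that description to a stub function, and importing the stub from the package __init__ (the from .add import _add_tbe line added earlier) is what triggers registration. A hedged sketch of the full builder chain; the diff elides everything between .binfile_name and .get_op_info(), so the compute_cost/kernel_name/input/output/dtype_format lines here are assumed from typical registrations, not taken from this commit:

from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

example_add_op_info = TBERegOp("Add") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("add.so") \
    .compute_cost(10) \
    .kernel_name("add") \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(example_add_op_info)
def _example_add_tbe():
    """Illustrative stub; the real entry points are _add_tbe and _add_ds_tbe."""
    return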
@@ -16,7 +16,7 @@
 """TensorAdd op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
-tensor_add_op_info = TBERegOp("Add") \
+add_op_info = TBERegOp("Add") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("add.so") \
@@ -37,7 +37,7 @@ tensor_add_op_info = TBERegOp("Add") \
    .get_op_info()
 
 
-@op_info_register(tensor_add_op_info)
-def _tensor_add_ds_tbe():
+@op_info_register(add_op_info)
+def _add_ds_tbe():
    """Add TBE register"""
    return
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,37 +13,101 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
+import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
 
 
-class Net(nn.Cell):
+class AddNet(nn.Cell):
     def __init__(self):
-        super(Net, self).__init__()
+        super().__init__()
         self.add = P.Add()
 
     def construct(self, x_, y_):
         return self.add(x_, y_)
 
 
-x = np.random.randn(1, 3, 3, 4).astype(np.float32)
-y = np.random.randn(1, 3, 3, 4).astype(np.float32)
+class AddDynamicShapeNet(nn.Cell):
+    def __init__(self, axis=0):
+        super().__init__()
+        self.unique = P.Unique()
+        self.gather = P.Gather()
+        self.add = P.Add()
+        self.axis = axis
+
+    def construct(self, x_, y_, indices):
+        u_indices, _ = self.unique(indices)
+        x_ = self.gather(x_, u_indices, self.axis)
+        y_ = self.gather(y_, u_indices, self.axis)
+        return self.add(x_, y_)
 
 
-def test_net():
+def comput_expect(x, y):
+    return np.add(x, y)
+
+
+def add_net(*args, is_dynamic=False):
+    op = args[0]
+    x = args[1]
+    y = args[2]
+    if is_dynamic:
+        out = op(Tensor(x), Tensor(y), Tensor(args[3]))
+    else:
+        out = op(Tensor(x), Tensor(y))
+    if is_dynamic:
+        print("input shape: ", x.shape)
+        print("output shape: ", out.shape)
+    else:
+        assert np.allclose(out.asnumpy(), comput_expect(x, y), 1e-3, 1e-3)
+
+
+@pytest.mark.skip
+def test_add(dtype=np.float16):
+    """
+    Feature: test add operator in graph and pynative mode.
+    Description: test add.
+    Expectation: the result is correct
+    """
+    x = np.random.randn(3, 3, 4).astype(dtype)
+    y = np.random.randn(3, 3, 4).astype(dtype)
+    indices = np.random.randint(0, 3, size=3)
+
+    net = AddNet()
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    add = Net()
-    output = add(Tensor(x), Tensor(y))
-    print(x)
-    print(y)
-    print(output.asnumpy())
+    add_net(net, x, y)
 
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    add = Net()
-    output = add(Tensor(x), Tensor(y))
-    print(x)
-    print(y)
-    print(output.asnumpy())
+    add_net(net, x, y)
+    net = AddDynamicShapeNet()
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    add_net(net, x, y, indices, is_dynamic=True)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    add_net(net, x, y, indices, is_dynamic=True)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_add_float16():
+    """
+    Feature: test add operator.
+    Description: test float16 input.
+    Expectation: the result is correct
+    """
+    test_add(np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_add_float32():
+    """
+    Feature: test add operator.
+    Description: test float32 input.
+    Expectation: the result is correct
+    """
+    test_add(np.float32)
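Note the asymmetry in add_net above: the static path asserts the result against np.add with 1e-3 tolerances, while the dynamic path only prints shapes, because the gathered size depends on how many distinct values Unique sees at runtime. The shape relationship the dynamic test exercises can be checked in plain numpy:

import numpy as np

indices = np.random.randint(0, 3, size=3)
u = np.unique(indices)                     # runtime-dependent length
x = np.random.randn(3, 3, 4).astype(np.float32)
y = np.random.randn(3, 3, 4).astype(np.float32)
out = x[u] + y[u]                          # what the dynamic net computes
assert out.shape == (u.size, 3, 4)         # leading dim fixed only at runtime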
@@ -0,0 +1,103 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.op = P.Dropout()
+
+    def construct(self, x_):
+        return self.op(x_)
+
+
+class DynamicShapeNet(nn.Cell):
+    def __init__(self, axis=0):
+        super().__init__()
+        self.unique = P.Unique()
+        self.gather = P.Gather()
+        self.op = P.Dropout()
+        self.axis = axis
+
+    def construct(self, x_, indices):
+        u_indices, _ = self.unique(indices)
+        x_ = self.gather(x_, u_indices, self.axis)
+        return self.op(x_)
+
+
+def dropout_net(*args, is_dynamic=False):
+    op = args[0]
+    x = args[1]
+    if is_dynamic:
+        out = op(Tensor(x), Tensor(args[2]))
+    else:
+        out = op(Tensor(x))
+    print("input shape: ", x.shape)
+    print("output shape: ", out[0].shape)
+
+
+@pytest.mark.skip
+def test_dropout(dtype=np.float16):
+    """
+    Feature: test dropout operator in graph and pynative mode.
+    Description: test dropout.
+    Expectation: the result is correct
+    """
+    x = np.random.randn(3, 3, 4).astype(dtype)
+    indices = np.random.randint(0, 3, size=3)
+
+    net = Net()
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    dropout_net(net, x)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    dropout_net(net, x)
+
+    net = DynamicShapeNet()
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    dropout_net(net, x, indices, is_dynamic=True)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    dropout_net(net, x, indices, is_dynamic=True)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_float16():
+    """
+    Feature: test dropout operator.
+    Description: test float16 input.
+    Expectation: the result is correct
+    """
+    test_dropout(np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_float32():
+    """
+    Feature: test dropout operator.
+    Description: test float32 input.
+    Expectation: the result is correct
+    """
+    test_dropout(np.float32)