forked from mindspore-Ecosystem/mindspore
isnan isfinite isinf squaresumall identity oneslik
This commit is contained in:
parent
305f08811d
commit
50a66ae476
|
@ -33,6 +33,7 @@ from .fused_mul_add import FusedMulAdd
|
||||||
from .gelu import GeLU
|
from .gelu import GeLU
|
||||||
from .gelu_grad import GeLUGrad
|
from .gelu_grad import GeLUGrad
|
||||||
from .gkdropout import GkDropout
|
from .gkdropout import GkDropout
|
||||||
|
from .identity import Identity
|
||||||
from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
|
from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
|
||||||
from .lamb_apply_weight_assign import LambApplyWeightAssign
|
from .lamb_apply_weight_assign import LambApplyWeightAssign
|
||||||
from .layernorm import LayerNorm
|
from .layernorm import LayerNorm
|
||||||
|
@ -42,6 +43,7 @@ from .logsoftmax_grad import LogSoftmaxGrad
|
||||||
from .matmul import BatchMatMul, MatMul
|
from .matmul import BatchMatMul, MatMul
|
||||||
from .maximum_grad import MaximumGrad
|
from .maximum_grad import MaximumGrad
|
||||||
from .minimum_grad import MinimumGrad
|
from .minimum_grad import MinimumGrad
|
||||||
|
from .oneslike import OnesLike
|
||||||
from .reduce_mean import ReduceMean
|
from .reduce_mean import ReduceMean
|
||||||
from .relu import ReLU
|
from .relu import ReLU
|
||||||
from .relu_grad import ReluGrad
|
from .relu_grad import ReluGrad
|
||||||
|
@ -56,6 +58,7 @@ from .sqrt_grad import SqrtGrad
|
||||||
from .square import Square
|
from .square import Square
|
||||||
from .square_sum_v1 import SquareSumV1
|
from .square_sum_v1 import SquareSumV1
|
||||||
from .squared_difference import SquaredDifference
|
from .squared_difference import SquaredDifference
|
||||||
|
from .square_sum_all import SquareSumAll
|
||||||
from .squeeze import Squeeze
|
from .squeeze import Squeeze
|
||||||
from .tanh_grad import TanhGrad
|
from .tanh_grad import TanhGrad
|
||||||
from .tile import Tile
|
from .tile import Tile
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""generate json desc for Identity"""
|
||||||
|
from ._utils import Expander
|
||||||
|
|
||||||
|
class Identity(Expander):
|
||||||
|
"""Identity expander"""
|
||||||
|
def _expand(self, graph_builder):
|
||||||
|
input_x = self.inputs[0]
|
||||||
|
result = graph_builder.emit('Reshape', [input_x], attrs={'shape': input_x.shape})
|
||||||
|
return result
|
|
@ -0,0 +1,24 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""generate json desc for OnesLike"""
|
||||||
|
from ._utils import Expander
|
||||||
|
|
||||||
|
class OnesLike(Expander):
|
||||||
|
"""OnesLike expander"""
|
||||||
|
def _expand(self, graph_builder):
|
||||||
|
input_x = self.inputs[0]
|
||||||
|
const_one = graph_builder.value(input_x.dtype, 1)
|
||||||
|
result = graph_builder.emit('BroadcastTo', [const_one], attrs={'shape': input_x.shape})
|
||||||
|
return result
|
|
@ -0,0 +1,41 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""generate json desc for SquareSumAll"""
|
||||||
|
from ._utils import Expander
|
||||||
|
|
||||||
|
class SquareSumAll(Expander):
|
||||||
|
"""SquareSumAll expander"""
|
||||||
|
def _check(self):
|
||||||
|
"""check inputs"""
|
||||||
|
input_num = len(self.inputs)
|
||||||
|
if input_num != 2:
|
||||||
|
raise GKException("SquareSumAll inputs number should be 2, but got {}.".format(input_num))
|
||||||
|
|
||||||
|
def _expand(self, graph_builder):
|
||||||
|
"""do expand"""
|
||||||
|
x0 = self.inputs[0]
|
||||||
|
x1 = self.inputs[1]
|
||||||
|
|
||||||
|
ori_shape = x0.shape
|
||||||
|
axis = []
|
||||||
|
for i, _ in enumerate(ori_shape):
|
||||||
|
axis.append(i)
|
||||||
|
|
||||||
|
square_res0 = graph_builder.emit('Mul', [x0, x0])
|
||||||
|
square_res1 = graph_builder.emit('Mul', [x1, x1])
|
||||||
|
result0 = graph_builder.emit('ReduceSum', [square_res0], attrs={'reduce_axis': axis, 'keep_dims': False})
|
||||||
|
result1 = graph_builder.emit('ReduceSum', [square_res1], attrs={'reduce_axis': axis, 'keep_dims': False})
|
||||||
|
|
||||||
|
return result0, result1
|
|
@ -155,6 +155,9 @@ class PrimLib:
|
||||||
'Mul': Prim(ELEMWISE),
|
'Mul': Prim(ELEMWISE),
|
||||||
'Sub': Prim(ELEMWISE),
|
'Sub': Prim(ELEMWISE),
|
||||||
'Log': Prim(ELEMWISE),
|
'Log': Prim(ELEMWISE),
|
||||||
|
'IsNan': Prim(ELEMWISE),
|
||||||
|
'IsInf': Prim(ELEMWISE),
|
||||||
|
'IsFinite': Prim(ELEMWISE),
|
||||||
'Exp': Prim(ELEMWISE),
|
'Exp': Prim(ELEMWISE),
|
||||||
'Rsqrt': Prim(ELEMWISE),
|
'Rsqrt': Prim(ELEMWISE),
|
||||||
'Sqrt': Prim(ELEMWISE),
|
'Sqrt': Prim(ELEMWISE),
|
||||||
|
|
|
@ -94,6 +94,9 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
|
||||||
prim::kPrimLogicalAnd,
|
prim::kPrimLogicalAnd,
|
||||||
prim::kPrimLogicalOr,
|
prim::kPrimLogicalOr,
|
||||||
prim::kPrimLogicalNot,
|
prim::kPrimLogicalNot,
|
||||||
|
prim::kPrimIsNan,
|
||||||
|
prim::kPrimIsInf,
|
||||||
|
prim::kPrimIsFinite,
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
const auto &flags = context::GraphKernelFlags::GetInstance();
|
const auto &flags = context::GraphKernelFlags::GetInstance();
|
||||||
|
|
|
@ -88,6 +88,9 @@ std::vector<PrimitivePtr> GetExpandOps() {
|
||||||
prim::kPrimSquaredDifference,
|
prim::kPrimSquaredDifference,
|
||||||
prim::kPrimSqueeze,
|
prim::kPrimSqueeze,
|
||||||
prim::kPrimEqualCount,
|
prim::kPrimEqualCount,
|
||||||
|
prim::kPrimSquareSumAll,
|
||||||
|
prim::kPrimIdentityMath,
|
||||||
|
prim::kPrimOnesLike,
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
const auto &flags = context::GraphKernelFlags::GetInstance();
|
const auto &flags = context::GraphKernelFlags::GetInstance();
|
||||||
|
|
|
@ -2581,7 +2581,9 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt
|
||||||
// Check the data source actors.
|
// Check the data source actors.
|
||||||
for (const auto &data_source_actor : actor_set->data_source_actors_) {
|
for (const auto &data_source_actor : actor_set->data_source_actors_) {
|
||||||
MS_EXCEPTION_IF_NULL(data_source_actor);
|
MS_EXCEPTION_IF_NULL(data_source_actor);
|
||||||
if (data_source_actor->output_data_arrows_.size() + data_source_actor->output_result_arrows_.size() == 0) {
|
if (data_source_actor->output_data_arrows_.size() + data_source_actor->output_result_arrows_.size() +
|
||||||
|
data_source_actor->output_control_arrows_.size() ==
|
||||||
|
0) {
|
||||||
MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user.";
|
MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -487,6 +487,10 @@ inline const PrimitivePtr kPrimAtanGrad = std::make_shared<Primitive>("AtanGrad"
|
||||||
inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
|
inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
|
||||||
inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
|
inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
|
||||||
inline const PrimitivePtr kPrimIdentityMath = std::make_shared<Primitive>("Identity", kSideEffectPropagate);
|
inline const PrimitivePtr kPrimIdentityMath = std::make_shared<Primitive>("Identity", kSideEffectPropagate);
|
||||||
|
inline const PrimitivePtr kPrimIsNan = std::make_shared<Primitive>("IsNan");
|
||||||
|
inline const PrimitivePtr kPrimIsInf = std::make_shared<Primitive>("IsInf");
|
||||||
|
inline const PrimitivePtr kPrimIsFinite = std::make_shared<Primitive>("IsFinite");
|
||||||
|
inline const PrimitivePtr kPrimSquareSumAll = std::make_shared<Primitive>("SquareSumAll");
|
||||||
|
|
||||||
// Statements
|
// Statements
|
||||||
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
|
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
import mindspore.ops.operations as P
|
||||||
|
|
||||||
|
class Net(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.identity = P.Identity()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.identity(x)
|
||||||
|
|
||||||
|
def get_output(x, enable_graph_kernel=False):
|
||||||
|
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||||
|
net = Net()
|
||||||
|
output = net(x)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def test_basic(dtype):
|
||||||
|
expect_np = np.random.normal(0, 10, (16, 32)).astype(dtype)
|
||||||
|
x = Tensor(expect_np)
|
||||||
|
output = get_output(x, True)
|
||||||
|
output_np = output.asnumpy().copy()
|
||||||
|
assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_1():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
test_basic(np.float16)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_2():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
test_basic(np.float32)
|
|
@ -0,0 +1,56 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
import mindspore.ops.operations as P
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
|
||||||
|
class Net(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.oneslike = P.OnesLike()
|
||||||
|
|
||||||
|
def construct(self, shape, dtype, x):
|
||||||
|
return self.oneslike(x)
|
||||||
|
|
||||||
|
def get_output(shape, dtype, nptype, enable_graph_kernel=False):
|
||||||
|
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||||
|
net = Net()
|
||||||
|
x = Tensor(np.random.normal(0, 10, shape).astype(nptype))
|
||||||
|
output = net(shape, dtype, x)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def test_basic(shape, dtype, nptype):
|
||||||
|
expect = get_output(shape, dtype, nptype, False)
|
||||||
|
output = get_output(shape, dtype, nptype, True)
|
||||||
|
assert np.allclose(expect.asnumpy(), output.asnumpy(), 1.e-4, 1.e-7)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_1():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
test_basic((2, 16), mstype.float16, np.float16)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_2():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
test_basic((4, 32), mstype.float32, np.float32)
|
|
@ -0,0 +1,61 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
import mindspore.ops.operations as P
|
||||||
|
|
||||||
|
class Net(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.squaresumall = P.SquareSumAll()
|
||||||
|
|
||||||
|
def construct(self, x0, x1):
|
||||||
|
return self.squaresumall(x0, x1)
|
||||||
|
|
||||||
|
def get_output(inp0, inp1, enable_graph_kernel=False):
|
||||||
|
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||||
|
net = Net()
|
||||||
|
output = net(inp0, inp1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def test_basic(datatype):
|
||||||
|
inp0 = Tensor(np.random.normal(1, 0.1, [800, 96]).astype(datatype))
|
||||||
|
inp1 = Tensor(np.random.normal(1, 0.1, [800, 96]).astype(datatype))
|
||||||
|
expect = get_output(inp0, inp1, False)
|
||||||
|
output = get_output(inp0, inp1, True)
|
||||||
|
expect_np0 = expect[0].asnumpy().copy()
|
||||||
|
output_np0 = output[0].asnumpy().copy()
|
||||||
|
expect_np1 = expect[1].asnumpy().copy()
|
||||||
|
output_np1 = output[1].asnumpy().copy()
|
||||||
|
assert np.allclose(expect_np0, output_np0, 1.e-4, 1.e-7)
|
||||||
|
assert np.allclose(expect_np1, output_np1, 1.e-4, 1.e-7)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_1():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
test_basic(np.float16)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_2():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
test_basic(np.float32)
|
Loading…
Reference in New Issue