!215 add ops: LogicalNot, LogicalAnd, LogicalOr, NotEqual, EqualCount, Asinh, Acosh

* add ops: LogicalNot, LogicalAnd, LogicalOr, NotEqual, EqualCount, Asinh, Acosh
This commit is contained in:
wangrao124 2021-07-13 06:35:24 +00:00
parent c168ecce09
commit 7cddde47b0
7 changed files with 122 additions and 0 deletions

View File

@ -57,3 +57,4 @@ from .squared_difference import SquaredDifference
from .squeeze import Squeeze
from .tanh_grad import TanhGrad
from .tile import Tile
from .equal_count import EqualCount

View File

@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for equal_count"""
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same
class EqualCount(Expander):
"""EqualCount expander"""
def __init__(self, expand_info):
super().__init__(expand_info)
self.shape_x = self.inputs[0]['shape']
self.shape_y = self.inputs[1]['shape']
self.dtype_x = self.inputs[0]['data_type']
self.dtype_y = self.inputs[1]['data_type']
def _check(self):
if self.shape_x != self.shape_y:
raise GKException("For 'EqualCount' the `x_shape` should be == `y_shape`: {}, \
but got {}".format(self.shape_y, self.shape_x))
if self.dtype_x != self.dtype_y:
raise GKException("For 'EqualCount' the data type of `y` should the same as `x`, but `x` with {}, \
and `y` with {}".format(self.dtype_x, self.dtype_y))
def _expand(self, graph_builder):
input_x = self.inputs[0]
input_y = self.inputs[1]
eql_val = graph_builder.emit('Equal', [input_x, input_y])
cast_val = graph_builder.emit('Cast', [eql_val], attrs={'dst_type': 'float32'})
axis = list(range(len(input_x.shape)))
result = graph_builder.emit('ReduceSum', [cast_val], attrs={'reduce_axis': axis, 'keep_dims': True})
if result.dtype != input_x.dtype:
result = graph_builder.emit('Cast', [result], attrs={'dst_type': input_x.dtype})
return result

View File

@ -165,10 +165,14 @@ class PrimLib:
'Maximum': Prim(ELEMWISE),
'Reciprocal': Prim(ELEMWISE),
'Equal': Prim(ELEMWISE),
'NotEqual': Prim(ELEMWISE),
'Greater': Prim(ELEMWISE),
'GreaterEqual': Prim(ELEMWISE),
'Less': Prim(ELEMWISE),
'LessEqual': Prim(ELEMWISE),
'LogicalNot': Prim(ELEMWISE),
'LogicalAnd': Prim(ELEMWISE),
'LogicalOr': Prim(ELEMWISE),
'Square': Prim(ELEMWISE),
'AddN': Prim(ELEMWISE),
'Select': Prim(ELEMWISE, 8),
@ -182,6 +186,8 @@ class PrimLib:
'Asin': Prim(ELEMWISE),
'ACos': Prim(ELEMWISE),
'Tanh': Prim(ELEMWISE),
'Asinh': Prim(ELEMWISE),
'Acosh': Prim(ELEMWISE),
'InplaceAssign': Prim(ELEMWISE),
'@ReduceInit': Prim(ELEMWISE),
'Reshape': Prim(RESHAPE),

View File

@ -60,11 +60,17 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
prim::kPrimRealDiv,
prim::kPrimReduceSum,
prim::kPrimEqual,
prim::kPrimNotEqual,
prim::kPrimLogicalAnd,
prim::kPrimLogicalOr,
prim::kPrimLogicalNot,
prim::kPrimAssign,
prim::kPrimInplaceAssign,
prim::kPrimAtan,
prim::kPrimAtan2,
prim::kPrimExpm1,
prim::kPrimAsinh,
prim::kPrimAcosh,
#if ENABLE_D
prim::kPrimMatMul,
prim::KPrimTransData,

View File

@ -86,6 +86,7 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimSoftmaxCrossEntropyWithLogits,
prim::kPrimSquaredDifference,
prim::kPrimSqueeze,
prim::kPrimEqualCount,
#endif
};
const auto &flags = context::GraphKernelFlags::GetInstance();

View File

@ -136,6 +136,7 @@ inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>(kNotEqual)
inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd");
inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr");
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
inline const PrimitivePtr kPrimEqualCount = std::make_shared<Primitive>("EqualCount");
inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");

View File

@ -0,0 +1,58 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class EqualCount(nn.Cell):
def __init__(self):
super(EqualCount, self).__init__()
self.op = P.EqualCount()
def construct(self, *inp):
return self.op(*inp)
def get_output(*inp, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
output = EqualCount()(*inp)
return output
def basic_test(datatype):
x = Tensor(np.array([[1, 1, 1, 1], [3, 3, 3, 3]]).astype(datatype))
y = Tensor(np.array([[1, 2, 1, 2], [1, 1, 3, 3]]).astype(datatype))
expect = get_output(x, y, enable_graph_kernel=False)
output = get_output(x, y, enable_graph_kernel=True)
expect_np = expect.asnumpy().copy()
output_np = output.asnumpy().copy()
assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_fp16():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
basic_test(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_fp32():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
basic_test(np.float32)