!215 add ops: LogicalNot, LogicalAnd, LogicalOr, NotEqual, EqualCount, Asinh, Acosh
* add ops: LogicalNot, LogicalAnd, LogicalOr, NotEqual, EqualCount, Asinh, Acosh
This commit is contained in:
parent
c168ecce09
commit
7cddde47b0
|
@ -57,3 +57,4 @@ from .squared_difference import SquaredDifference
|
|||
from .squeeze import Squeeze
|
||||
from .tanh_grad import TanhGrad
|
||||
from .tile import Tile
|
||||
from .equal_count import EqualCount
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for equal_count"""
|
||||
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
@VLD.check_all_formats_same
|
||||
class EqualCount(Expander):
|
||||
"""EqualCount expander"""
|
||||
|
||||
def __init__(self, expand_info):
|
||||
super().__init__(expand_info)
|
||||
self.shape_x = self.inputs[0]['shape']
|
||||
self.shape_y = self.inputs[1]['shape']
|
||||
self.dtype_x = self.inputs[0]['data_type']
|
||||
self.dtype_y = self.inputs[1]['data_type']
|
||||
|
||||
def _check(self):
|
||||
if self.shape_x != self.shape_y:
|
||||
raise GKException("For 'EqualCount' the `x_shape` should be == `y_shape`: {}, \
|
||||
but got {}".format(self.shape_y, self.shape_x))
|
||||
if self.dtype_x != self.dtype_y:
|
||||
raise GKException("For 'EqualCount' the data type of `y` should the same as `x`, but `x` with {}, \
|
||||
and `y` with {}".format(self.dtype_x, self.dtype_y))
|
||||
|
||||
def _expand(self, graph_builder):
|
||||
input_x = self.inputs[0]
|
||||
input_y = self.inputs[1]
|
||||
|
||||
eql_val = graph_builder.emit('Equal', [input_x, input_y])
|
||||
cast_val = graph_builder.emit('Cast', [eql_val], attrs={'dst_type': 'float32'})
|
||||
axis = list(range(len(input_x.shape)))
|
||||
result = graph_builder.emit('ReduceSum', [cast_val], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||
|
||||
if result.dtype != input_x.dtype:
|
||||
result = graph_builder.emit('Cast', [result], attrs={'dst_type': input_x.dtype})
|
||||
return result
|
|
@ -165,10 +165,14 @@ class PrimLib:
|
|||
'Maximum': Prim(ELEMWISE),
|
||||
'Reciprocal': Prim(ELEMWISE),
|
||||
'Equal': Prim(ELEMWISE),
|
||||
'NotEqual': Prim(ELEMWISE),
|
||||
'Greater': Prim(ELEMWISE),
|
||||
'GreaterEqual': Prim(ELEMWISE),
|
||||
'Less': Prim(ELEMWISE),
|
||||
'LessEqual': Prim(ELEMWISE),
|
||||
'LogicalNot': Prim(ELEMWISE),
|
||||
'LogicalAnd': Prim(ELEMWISE),
|
||||
'LogicalOr': Prim(ELEMWISE),
|
||||
'Square': Prim(ELEMWISE),
|
||||
'AddN': Prim(ELEMWISE),
|
||||
'Select': Prim(ELEMWISE, 8),
|
||||
|
@ -182,6 +186,8 @@ class PrimLib:
|
|||
'Asin': Prim(ELEMWISE),
|
||||
'ACos': Prim(ELEMWISE),
|
||||
'Tanh': Prim(ELEMWISE),
|
||||
'Asinh': Prim(ELEMWISE),
|
||||
'Acosh': Prim(ELEMWISE),
|
||||
'InplaceAssign': Prim(ELEMWISE),
|
||||
'@ReduceInit': Prim(ELEMWISE),
|
||||
'Reshape': Prim(RESHAPE),
|
||||
|
|
|
@ -60,11 +60,17 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
|
|||
prim::kPrimRealDiv,
|
||||
prim::kPrimReduceSum,
|
||||
prim::kPrimEqual,
|
||||
prim::kPrimNotEqual,
|
||||
prim::kPrimLogicalAnd,
|
||||
prim::kPrimLogicalOr,
|
||||
prim::kPrimLogicalNot,
|
||||
prim::kPrimAssign,
|
||||
prim::kPrimInplaceAssign,
|
||||
prim::kPrimAtan,
|
||||
prim::kPrimAtan2,
|
||||
prim::kPrimExpm1,
|
||||
prim::kPrimAsinh,
|
||||
prim::kPrimAcosh,
|
||||
#if ENABLE_D
|
||||
prim::kPrimMatMul,
|
||||
prim::KPrimTransData,
|
||||
|
|
|
@ -86,6 +86,7 @@ std::vector<PrimitivePtr> GetExpandOps() {
|
|||
prim::kPrimSoftmaxCrossEntropyWithLogits,
|
||||
prim::kPrimSquaredDifference,
|
||||
prim::kPrimSqueeze,
|
||||
prim::kPrimEqualCount,
|
||||
#endif
|
||||
};
|
||||
const auto &flags = context::GraphKernelFlags::GetInstance();
|
||||
|
|
|
@ -136,6 +136,7 @@ inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>(kNotEqual)
|
|||
inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd");
|
||||
inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr");
|
||||
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
|
||||
inline const PrimitivePtr kPrimEqualCount = std::make_shared<Primitive>("EqualCount");
|
||||
|
||||
inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
|
||||
inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class EqualCount(nn.Cell):
|
||||
def __init__(self):
|
||||
super(EqualCount, self).__init__()
|
||||
self.op = P.EqualCount()
|
||||
|
||||
def construct(self, *inp):
|
||||
return self.op(*inp)
|
||||
|
||||
def get_output(*inp, enable_graph_kernel=False):
|
||||
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||
output = EqualCount()(*inp)
|
||||
return output
|
||||
|
||||
def basic_test(datatype):
|
||||
x = Tensor(np.array([[1, 1, 1, 1], [3, 3, 3, 3]]).astype(datatype))
|
||||
y = Tensor(np.array([[1, 2, 1, 2], [1, 1, 3, 3]]).astype(datatype))
|
||||
expect = get_output(x, y, enable_graph_kernel=False)
|
||||
output = get_output(x, y, enable_graph_kernel=True)
|
||||
expect_np = expect.asnumpy().copy()
|
||||
output_np = output.asnumpy().copy()
|
||||
assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gpu_fp16():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
basic_test(np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gpu_fp32():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
basic_test(np.float32)
|
Loading…
Reference in New Issue