forked from mindspore-Ecosystem/mindspore
[GraphKernel] add sponge ops.
This commit is contained in:
parent
a993f5a46a
commit
7d55cef106
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit 1e6b226a0417d23d2d0a2333d5e80f13fe9e8d0f
|
||||
Subproject commit a26aa4330ed1e15dede3b983e79c78290083b3ac
|
|
@ -53,6 +53,7 @@ from .softmax_grad_ext import SoftmaxGradExt
|
|||
from .sqrt_grad import SqrtGrad
|
||||
from .square import Square
|
||||
from .square_sum_v1 import SquareSumV1
|
||||
from .squared_difference import SquaredDifference
|
||||
from .squeeze import Squeeze
|
||||
from .tanh_grad import TanhGrad
|
||||
from .tile import Tile
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for squared_difference"""
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
@VLD.check_all_formats_same
|
||||
class SquaredDifference(Expander):
|
||||
"""SquaredDifference expander"""
|
||||
|
||||
def _expand(self, graph_builder):
|
||||
input_x = self.inputs[0]
|
||||
input_y = self.inputs[1]
|
||||
|
||||
sub_val = graph_builder.emit('Sub', [input_x, input_y])
|
||||
result = graph_builder.emit('Mul', [sub_val, sub_val])
|
||||
|
||||
return result
|
|
@ -176,6 +176,11 @@ class PrimLib:
|
|||
'ReduceMax': Prim(REDUCE),
|
||||
'ReduceMin': Prim(REDUCE),
|
||||
'Assign': Prim(ELEMWISE),
|
||||
'Sign': Prim(ELEMWISE),
|
||||
'Sin': Prim(ELEMWISE),
|
||||
'Cos': Prim(ELEMWISE),
|
||||
'Asin': Prim(ELEMWISE),
|
||||
'ACos': Prim(ELEMWISE),
|
||||
'Tanh': Prim(ELEMWISE),
|
||||
'InplaceAssign': Prim(ELEMWISE),
|
||||
'@ReduceInit': Prim(ELEMWISE),
|
||||
|
|
|
@ -67,6 +67,11 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
|
|||
prim::KPrimTransData,
|
||||
prim::kPrimBatchMatMul,
|
||||
#elif ENABLE_GPU
|
||||
prim::kPrimSin,
|
||||
prim::kPrimCos,
|
||||
prim::kPrimAsin,
|
||||
prim::kPrimACos,
|
||||
prim::kPrimSign,
|
||||
prim::kPrimReduceMax,
|
||||
prim::kPrimReduceMin,
|
||||
prim::kPrimGreater,
|
||||
|
|
|
@ -84,6 +84,7 @@ std::vector<PrimitivePtr> GetExpandOps() {
|
|||
prim::kPrimSigmoidCrossEntropyWithLogits,
|
||||
prim::kPrimSigmoidCrossEntropyWithLogitsGrad,
|
||||
prim::kPrimSoftmaxCrossEntropyWithLogits,
|
||||
prim::kPrimSquaredDifference,
|
||||
prim::kPrimSqueeze,
|
||||
#endif
|
||||
};
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.squared_difference = P.SquaredDifference()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.squared_difference(x, y)
|
||||
|
||||
|
||||
def get_output(x, y, enable_graph_kernel=False):
|
||||
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||
net = Net()
|
||||
output = net(x, y)
|
||||
return output
|
||||
|
||||
|
||||
def test_squared_difference(shape1, shape2, dtype):
|
||||
x = Tensor(np.random.normal(0, 10, shape1).astype(dtype))
|
||||
y = Tensor(np.random.normal(0, 10, shape2).astype(dtype))
|
||||
expect = get_output(x, y, False)
|
||||
output = get_output(x, y, True)
|
||||
|
||||
expect_np = expect.asnumpy().copy()
|
||||
output_np = output.asnumpy().copy()
|
||||
|
||||
assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_squared_difference_gpu():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_squared_difference((4, 3), (4, 3), np.float16)
|
||||
test_squared_difference((6, 2), (1), np.int32)
|
||||
test_squared_difference((1), (4, 3), np.float32)
|
Loading…
Reference in New Issue