forked from mindspore-Ecosystem/mindspore
!40423 Add dynamic shape support for Square and SoftPlusGrad.
Merge pull request !40423 from hezhenhao1/add_softplusgrad
Commit ae207a653d
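
For reviewers: the dynamic-shape support added here is driven through Cell.set_inputs() with placeholder Tensors whose dynamic dimensions are set to None, which is the pattern the new test cases exercise. A minimal sketch of that usage, assuming a standard MindSpore install (the SquareNet wrapper below is illustrative, not part of this diff):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P

class SquareNet(nn.Cell):
    def __init__(self):
        super(SquareNet, self).__init__()
        self.ops = P.Square()

    def construct(self, x):
        return self.ops(x)

context.set_context(mode=context.GRAPH_MODE)
square_net = SquareNet()
real_input = Tensor(np.random.randn(2, 3).astype(np.float32))
# None marks a dimension as dynamic; the concrete shape is only bound at run time.
input_dyn = Tensor(shape=[None, None], dtype=real_input.dtype)
square_net.set_inputs(input_dyn)
print(square_net(real_input))
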
@@ -272,13 +272,13 @@ tensor::TensorPtr ConstData() {
CNodePtr SquareOp(const FuncGraphPtr &graph, const AnfNodePtr &cond, int64_t switch_idx,
                  const tensor::TensorPtr &const_data) {
-  auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>();
+  auto prim_square = prim::kPrimSquare;
  // for the depended node , add two const data to merge the flow ,one for depended node with same switch,
  // the other use the opposite
  auto ctrl_data = NewValueNode(const_data);
  auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx);

-  std::vector<AnfNodePtr> square_nodes{NewValueNode(PrimSquare), ctrl_node};
+  std::vector<AnfNodePtr> square_nodes{NewValueNode(prim_square), ctrl_node};
  auto square_op = graph->NewCNode(square_nodes);

  return square_op;
@@ -264,6 +264,7 @@ BuiltInTypeMap &GetMethodMap() {
    {"var", std::string("var")},            // P.ReduceSum
    {"std", std::string("std")},            // P.ReduceSum
    {"sum", std::string("sum")},            // P.ReduceSum
+    {"square", std::string("square")},      // P.Square()
    {"repeat", std::string("repeat")},      // C.repeat_elements
    {"bernoulli", prim::kPrimBernoulli},    // P.Bernoulli()
    {"ceil", std::string("ceil")},          // P.Ceil
@@ -15,6 +15,7 @@
 */

#include "ops/square.h"
+#include <complex>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
@@ -122,11 +123,20 @@ ValuePtr SquareInferValue(const PrimitivePtr &prim, const std::vector<AbstractBa
      ImpleSquare<double>(x_datac, result_datac, data_size);
      break;
    }
+    case kNumberTypeComplex64: {
+      ImpleSquare<std::complex<float>>(x_datac, result_datac, data_size);
+      break;
+    }
+    case kNumberTypeComplex128: {
+      ImpleSquare<std::complex<double>>(x_datac, result_datac, data_size);
+      break;
+    }
    default: {
-      MS_EXCEPTION(TypeError) << "For '" << prim->name()
-                              << "', the supported data type is ['int8', 'int16', 'int32', 'int64', 'uint8', "
-                                 "'uint16','uint32', 'uint64','float16', 'float32', 'float64'], but got "
-                              << x_tensor->ToString();
+      MS_EXCEPTION(TypeError)
+        << "For '" << prim->name()
+        << "', the supported data type is ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16','uint32', "
+           "'uint64','float16', 'float32', 'float64', 'complex64', 'complex128'], but got "
+        << x_tensor->ToString();
    }
  }
  return result_tensor;
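
Side note on the new complex64/complex128 branches above: squaring a complex value z = a + bi yields z^2 = (a^2 - b^2) + 2abi, so the complex specializations presumably amount to ordinary element-wise complex multiplication. A quick numpy illustration (not part of the commit):

import numpy as np

z = np.array([1 + 2j, 3 - 1j], dtype=np.complex64)
print(np.square(z))                                  # [-3.+4.j  8.-6.j]
print(z.real**2 - z.imag**2 + 2j * z.real * z.imag)  # same values
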
@@ -2206,6 +2206,11 @@ def bitwise_xor(x, y):
    return F.bitwise_xor(x, y)


+def square(x):
+    """Returns square of a tensor element-wise."""
+    return F.square(x)
+
+
def tan(x):
    """Returns tangent of `x`."""
    return F.tan(x)
@@ -1045,6 +1045,29 @@ class Tensor(Tensor_):
        self._init_check()
        return tensor_operator_registry.get('broadcast_to')(x.shape)(self)

+    def square(self):
+        """
+        Returns square of a tensor element-wise.
+
+        .. math::
+
+            out_{i} = (x_{i})^2
+
+        Returns:
+            Tensor, has the same shape and dtype as the `x`.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+
+        Examples:
+            >>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
+            >>> output = x.square()
+            >>> print(output)
+            [1. 4. 9.]
+        """
+        self._init_check()
+        return tensor_operator_registry.get('square')(self)
+
    def tan(self):
        """
        Computes tangent of `x` element-wise.
@@ -144,6 +144,7 @@ from .math_func import (
    ge,
    tensor_sub,
    sub,
+    square,
    tensor_mul,
    mul,
    tensor_div,
@@ -289,6 +290,7 @@ from .nn_func import (
    softmax,
    pdist,
    pad,
+    prelu,
    mirror_pad,
    nll_loss,
    smooth_l1_loss,
@@ -160,6 +160,7 @@ truncate_mod_ = P.TruncateMod()
trunc_ = P.Trunc()
sparse_segment_mean_ = SparseSegmentMean()
xlogy_ = P.Xlogy()
+square_ = P.Square()


#####################################
@@ -3414,6 +3415,35 @@ def std(input_x, axis=(), unbiased=True, keep_dims=False):
    return output


+def square(x):
+    """
+    Returns square of a tensor element-wise.
+
+    .. math::
+
+        out_{i} = (x_{i})^2
+
+    Args:
+        x (Tensor): The input tensor with a dtype of Number, its rank must be in [0, 7] inclusive.
+
+    Returns:
+        Tensor, has the same shape and dtype as the `x`.
+
+    Raises:
+        TypeError: If `x` is not a Tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
+        >>> output = ops.square(x)
+        >>> print(output)
+        [1. 4. 9.]
+    """
+    return square_(x)
+
+
def outer(x1, x2):
    """
    Return outer product of `x1` and `x2`. If `x1` is a vector of size n and `x2` is a vector of size m,
@@ -5787,6 +5817,7 @@ __all__ = [
    'logit',
    'logsumexp',
    'ldexp',
+    'square',
    'sin',
    'cos',
    'tan',
@@ -34,6 +34,7 @@ softsign_ = P.Softsign()
hardswish_ = P.HSwish()
mish_ = NN_OPS.Mish()
selu_ = NN_OPS.SeLU()
+prelu_ = NN_OPS.PReLU()


def adaptive_avg_pool2d(input_x, output_size):
@@ -1234,6 +1235,57 @@ def pad(input_x, paddings):
    return slice_(out, slice_begin, slice_size)


+def prelu(x, weight):
+    r"""
+    Parametric Rectified Linear Unit activation function.
+
+    PReLU is described in the paper `Delving Deep into Rectifiers: Surpassing Human-Level Performance on
+    ImageNet Classification <https://arxiv.org/abs/1502.01852>`_. Defined as follows:
+
+    .. math::
+        prelu(x_i)= \max(0, x_i) + \min(0, w * x_i),
+
+    where :math:`x_i` is an element of a channel of the input, `w` is the weight of the channel.
+
+    Note:
+        Scalar or 1-D input x is not supported on Ascend.
+
+    Args:
+        x (Tensor): The input Tensor of the activation function. The data type is float16 or float32.
+            The shape is :math:`(N, C, *)` where :math:`*` means, any number of additional dimensions.
+        weight (Tensor): Weight Tensor. The data type is float16 or float32.
+            The weight can only be a vector, and the length is the same as the number of channels C of the `input_x`.
+            On GPU devices, when the input is a scalar, the shape is 1.
+
+    Returns:
+        Tensor, with the same type as `x`.
+
+        For detailed information, please refer to :class:`mindspore.nn.PReLU`.
+
+    Raises:
+        TypeError: If dtype of `x` or `weight` is neither float16 nor float32.
+        TypeError: If the `x` or the `weight` is not a Tensor.
+        ValueError: If the `x` is a 0-D or 1-D Tensor on Ascend.
+        ValueError: If the `weight` is not a 1-D Tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> x = Tensor(np.arange(-6, 6).reshape((2, 3, 2)), mindspore.float32)
+        >>> weight = Tensor(np.array([0.1, 0.6, -0.3]), mindspore.float32)
+        >>> output = ops.prelu(x, weight)
+        >>> print(output)
+        [[[-0.60 -0.50]
+          [-2.40 -1.80]
+          [ 0.60 0.30]]
+         [[ 0.00 1.00]
+          [ 2.00 3.00]
+          [ 4.0 5.00]]]
+    """
+    return prelu_(x, weight)
+
+
def mirror_pad(input_x, paddings, mode):
    """
    Pads the input tensor according to the paddings and mode.
@@ -2135,6 +2187,7 @@ __all__ = [
    'softmax',
    'pdist',
    'pad',
+    'prelu',
    'mirror_pad',
    'cross_entropy',
    'grid_sample',
@@ -54,7 +54,6 @@ merge = P.Merge()
geswitch = P.GeSwitch()
strided_slice = P.StridedSlice()
check_bprop = P.CheckBprop()
-square = P.Square()
sqrt = P.Sqrt()
reduce_sum = P.ReduceSum()
reduce_max = P.ReduceMax()
@@ -346,6 +345,7 @@ tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('atan2', atan2)
tensor_operator_registry.register('abs', P.Abs)
+tensor_operator_registry.register('square', square)
tensor_operator_registry.register('tan', P.Tan)
tensor_operator_registry.register('acos', acos)
tensor_operator_registry.register('cos', cos)
@@ -68,29 +68,3 @@ def test_net(dtype):
    output = backword_net(Tensor(x), Tensor(sens))
    print(len(output))
    print(output[0].asnumpy())
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('dtype', [np.float32, np.float64])
-def test_square_dy(dtype):
-    """
-    Feature: ALL To ALL
-    Description: test cases for Square dynamic shape.
-    Expectation: the result match to numpy
-    """
-    context.set_context(mode=context.GRAPH_MODE)
-    input_x_np = np.random.randn(2, 3, 3, 4).astype(dtype)
-    benchmark_output = np.square(input_x_np)
-    loss = 1e-6
-    square_net = Net()
-    real_input = Tensor(input_x_np)
-    dy_shape = [None for _ in input_x_np.shape]
-    input_dyn = Tensor(shape=dy_shape, dtype=real_input.dtype)
-    square_net.set_inputs(input_dyn)
-    ms_result = square_net(real_input)
-    np.testing.assert_allclose(benchmark_output, ms_result.asnumpy(), rtol=loss, atol=loss)
-    context.set_context(mode=context.PYNATIVE_MODE)
-    ms_result = square_net(real_input)
-    np.testing.assert_allclose(benchmark_output, ms_result.asnumpy(), rtol=loss, atol=loss)
@@ -0,0 +1,75 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.nn as nn
+from mindspore import Tensor, context
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+
+
+class SoftplusNet(nn.Cell):
+    def __init__(self):
+        super(SoftplusNet, self).__init__()
+        self.softplus = P.Softplus()
+
+    def construct(self, x):
+        return self.softplus(x)
+
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = C.GradOperation(get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, input_data, sens):
+        gout = self.grad(self.network)(input_data, sens)
+        return gout
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype', [np.float16, np.float32])
+def test_dynamic_shape_softplus_grad(dtype):
+    """
+    Feature: ALL To ALL
+    Description: test cases for SoftplusGrad dynamic shape.
+    Expectation: the result match to numpy
+    """
+    np.random.seed(0)
+    x_np = np.random.randn(2, 3, 4).astype(dtype)
+    dout_np = np.random.randn(2, 3, 4).astype(dtype)
+    expect = dout_np * np.exp(x_np) / (1 + np.exp(x_np))
+    loss = 1e-3
+    net = SoftplusNet()
+    grad_net = Grad(net)
+    x_tensor = Tensor(x_np)
+    dout_tensor = Tensor(dout_np)
+    dy_shape = [None for _ in x_tensor.shape]
+    x_dyn = Tensor(shape=dy_shape, dtype=x_tensor.dtype)
+    dout_dyn = Tensor(shape=dy_shape, dtype=x_tensor.dtype)
+    grad_net.set_inputs(x_dyn, dout_dyn)
+
+    # Graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    ms_result = grad_net(x_tensor, dout_tensor)[0]
+    np.testing.assert_allclose(expect, ms_result.asnumpy(), rtol=loss, atol=loss)
+    context.set_context(mode=context.PYNATIVE_MODE)
+    ms_result = grad_net(x_tensor, dout_tensor)[0]
+    np.testing.assert_allclose(expect, ms_result.asnumpy(), rtol=loss, atol=loss)
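
The expected value in the new SoftplusGrad test relies on the identity softplus'(x) = sigmoid(x) = exp(x) / (1 + exp(x)), so the gradient is just the incoming sens scaled by sigmoid(x). A quick finite-difference sanity check of that identity (illustrative only, not part of the test file):

import numpy as np

def softplus(v):
    return np.log1p(np.exp(v))

x = np.random.randn(5).astype(np.float64)
eps = 1e-6
numeric = (softplus(x + eps) - softplus(x - eps)) / (2 * eps)
analytic = np.exp(x) / (1 + np.exp(x))
np.testing.assert_allclose(numeric, analytic, rtol=1e-6, atol=1e-6)
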
@@ -0,0 +1,58 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.nn as nn
+from mindspore import Tensor, context
+from mindspore.ops import operations as P
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.ops = P.Square()
+
+    def construct(self, x):
+        return self.ops(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
+def test_dynamic_shape_square(dtype):
+    """
+    Feature: ALL To ALL
+    Description: test cases for Square dynamic shape.
+    Expectation: the result match to numpy
+    """
+    input_x_np = np.random.randn(2, 3, 3, 4).astype(dtype)
+    benchmark_output = np.square(input_x_np)
+    loss = 1e-6
+    square_net = Net()
+    real_input = Tensor(input_x_np)
+    dy_shape = [None for _ in input_x_np.shape]
+    input_dyn = Tensor(shape=dy_shape, dtype=real_input.dtype)
+    square_net.set_inputs(input_dyn)
+    # Graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    ms_result = square_net(real_input)
+    np.testing.assert_allclose(benchmark_output, ms_result.asnumpy(), rtol=loss, atol=loss)
+    # PyNative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    ms_result = square_net(real_input)
+    np.testing.assert_allclose(benchmark_output, ms_result.asnumpy(), rtol=loss, atol=loss)
@@ -54,29 +54,3 @@ def test_square_normal(dtype):
    output_ms = P.Square()(Tensor(x_np))
    output_np = np.square(x_np)
    assert np.allclose(output_ms.asnumpy(), output_np)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('dtype', [np.float32, np.float64])
-def test_square_dynamic(dtype):
-    """
-    Feature: ALL To ALL
-    Description: test cases for Square dynamic shape.
-    Expectation: the result match to numpy
-    """
-    context.set_context(mode=context.GRAPH_MODE)
-    input_x_np = np.random.randn(2, 3, 3, 4).astype(dtype)
-    benchmark_output = np.square(input_x_np)
-    loss = 1e-6
-    square_net = SquareNet()
-    real_input = Tensor(input_x_np)
-    dy_shape = [None for _ in input_x_np.shape]
-    input_dyn = Tensor(shape=dy_shape, dtype=real_input.dtype)
-    square_net.set_inputs(input_dyn)
-    ms_result = square_net(real_input)
-    np.testing.assert_allclose(benchmark_output, ms_result.asnumpy(), rtol=loss, atol=loss)
-    context.set_context(mode=context.PYNATIVE_MODE)
-    ms_result = square_net(real_input)
-    np.testing.assert_allclose(benchmark_output, ms_result.asnumpy(), rtol=loss, atol=loss)
@@ -62,10 +62,11 @@ class LayerNorm(nn.Cell):
        self.add = P.Add()
        self.mul = P.Mul()
        self.div = P.RealDiv()
+        self.square = P.Square()

    def construct(self, x):
        mean = self.mean(x, -1)
-        variance = self.mean(F.square(self.sub(x, mean)))
+        variance = self.mean(self.square(self.sub(x, mean)))
        output = self.div(self.sub(x, mean), F.sqrt(self.add(variance, self.eps)))
        rescaled_output = self.add(self.mul(output, self.gamma), self.beta)
        return rescaled_output
@@ -62,7 +62,7 @@ class LayerNorm(nn.Cell):
        self.add = P.Add()
        self.mul = P.Mul()
        self.div = P.RealDiv()
-
+        self.square = P.Square()
        self.reshape = P.Reshape()
        self.shape = P.Shape()

@@ -73,7 +73,7 @@ class LayerNorm(nn.Cell):
        x = self.reshape(x, x_shape)
        x = self.reshape(x, x_target_shape)
        mean = self.mean(x, -1)
-        variance = self.mean(F.square(self.sub(x, mean)))
+        variance = self.mean(self.square(self.sub(x, mean)))
        output = self.div(self.sub(x, mean), F.sqrt(self.add(variance, self.eps)))
        rescaled_output = self.add(self.mul(output, self.gamma), self.beta)
        output_shape = self.shape(rescaled_output) + (1,)
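
For reference, the layer normalization these LayerNorm test models compute reduces to the following numpy sketch (gamma, beta and eps below are illustrative placeholders for the cell's parameters):

import numpy as np

x = np.random.randn(2, 4).astype(np.float32)
gamma, beta, eps = 1.0, 0.0, 1e-7
mean = x.mean(axis=-1, keepdims=True)
variance = np.mean(np.square(x - mean), axis=-1, keepdims=True)
output = (x - mean) / np.sqrt(variance + eps)
rescaled_output = output * gamma + beta
print(rescaled_output)
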