forked from mindspore-Ecosystem/mindspore
!188 Support pow's second input could be tensor and fix bug in bprop of pow
Merge pull request !188 from zhangbuxue/fix_pow_bprop
This commit is contained in:
commit
5d0801edf2
|
@ -98,6 +98,13 @@ class FloorDivInfo : public ArithmeticBase {
|
|||
~FloorDivInfo() override = default;
|
||||
};
|
||||
|
||||
class PowInfo : public ArithmeticBase {
|
||||
public:
|
||||
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~PowInfo() override = default;
|
||||
};
|
||||
|
||||
class GreaterInfo : public ArithmeticBase {
|
||||
public:
|
||||
GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "parallel/ops_info/elementary_function_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status PowInfo::InferMirrorOps() {
|
||||
mirror_ops_.clear();
|
||||
|
||||
Shape tensor_map = inputs_tensor_map_[0];
|
||||
std::vector<Group> group;
|
||||
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << " : Create group failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
OperatorVector mirror_op;
|
||||
OperatorVector op_for_value;
|
||||
if (group.empty()) {
|
||||
MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
|
||||
return SUCCESS;
|
||||
} else {
|
||||
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
|
||||
mirror_ops_.push_back(mirror_op);
|
||||
mirror_ops_.push_back(op_for_value);
|
||||
std::string group_name = group[0].name();
|
||||
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -27,16 +27,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class PowInfo : public ActivationOther {
|
||||
public:
|
||||
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~PowInfo() override = default;
|
||||
|
||||
protected:
|
||||
Status InferMirrorOps() override;
|
||||
};
|
||||
|
||||
class ExpInfo : public ActivationOther {
|
||||
public:
|
||||
ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
|
||||
|
|
|
@ -58,7 +58,7 @@ class _PoolNd(Cell):
|
|||
pass
|
||||
|
||||
def extend_repr(self):
|
||||
return 'kernel_size={kernel_size}, strides={strides}, pad_mode={pad_mode}'.format(**self.__dict__)
|
||||
return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
|
||||
|
||||
|
||||
class MaxPool2d(_PoolNd):
|
||||
|
|
|
@ -336,14 +336,13 @@ def get_bprop_log(self):
|
|||
@bprop_getters.register(P.Pow)
|
||||
def get_bprop_pow(self):
|
||||
"""Grad definition for `Pow` operation."""
|
||||
pow_ = P.Pow()
|
||||
cast = P.Cast()
|
||||
dtype = P.DType()
|
||||
pow_op = P.Pow()
|
||||
ln = P.Log()
|
||||
|
||||
def bprop(x, power, out, dout):
|
||||
g = cast(F.tuple_to_array((power,)), dtype(x)) * pow_(x, power-1.0)
|
||||
dx = g * dout
|
||||
return dx, 0
|
||||
dx = power * pow_op(x, power - 1.0) * dout
|
||||
dpower = pow_op(x, power) * ln(x) * dout
|
||||
return dx, dpower
|
||||
return bprop
|
||||
|
||||
|
||||
|
|
|
@ -1097,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
|
|||
axis = self.axis
|
||||
x_rank = len(x_shape)
|
||||
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
|
||||
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
|
||||
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
|
||||
return ouput_shape, ouput_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
|
@ -1143,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
|
|||
axis = self.axis
|
||||
x_rank = len(x_shape)
|
||||
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
|
||||
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
|
||||
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
|
||||
return ouput_shape, ouput_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
|
|
|
@ -74,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
return _get_broadcast_shape(x_shape, y_shape, self.prim_name())
|
||||
return _get_broadcast_shape(x_shape, y_shape, self.name)
|
||||
|
||||
|
||||
class _MathBinaryOp(_BinaryOp):
|
||||
|
@ -89,7 +89,7 @@ class _MathBinaryOp(_BinaryOp):
|
|||
return x_dtype
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name())
|
||||
return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.name)
|
||||
|
||||
|
||||
class TensorAdd(_MathBinaryOp):
|
||||
|
@ -158,7 +158,7 @@ class AssignAdd(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, variable, value):
|
||||
args = {"value": value}
|
||||
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
|
||||
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return value
|
||||
|
||||
|
||||
|
@ -201,7 +201,7 @@ class AssignSub(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, variable, value):
|
||||
args = {"value": value}
|
||||
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
|
||||
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return value
|
||||
|
||||
|
||||
|
@ -222,16 +222,16 @@ class _Reduce(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, keep_dims=False):
|
||||
"""init Reduce"""
|
||||
validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name())
|
||||
validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
|
||||
|
||||
def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
|
||||
axis_v = axis['value']
|
||||
input_shp = input_x['shape']
|
||||
args = {'input_x': input_x['dtype']}
|
||||
validator.check_tensor_type_same(args, valid_dtype, self.prim_name())
|
||||
validator.check_tensor_type_same(args, valid_dtype, self.name)
|
||||
|
||||
input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name())
|
||||
input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
|
||||
return {'shape': input_shp,
|
||||
'dtype': input_x['dtype'],
|
||||
'value': None}
|
||||
|
@ -466,7 +466,7 @@ class CumProd(PrimitiveWithInfer):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, exclusive=False, reverse=False):
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
|
||||
self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
|
||||
|
||||
|
@ -474,7 +474,7 @@ class CumProd(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type, axis_type):
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
|
||||
validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
|
||||
return x_type
|
||||
|
@ -510,7 +510,7 @@ class MatMul(PrimitiveWithInfer):
|
|||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||
self.__setattr_flag__ = True
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||
|
||||
|
@ -521,7 +521,7 @@ class MatMul(PrimitiveWithInfer):
|
|||
|
||||
def infer_shape(self, x, y):
|
||||
self.check_shape_size(x, y)
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
|
||||
for i in range(len(x) - 2):
|
||||
if x[i] != y[i]:
|
||||
|
@ -546,7 +546,7 @@ class MatMul(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, x, y):
|
||||
args = {"x": x, "y": y}
|
||||
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name())
|
||||
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -590,7 +590,7 @@ class BatchMatMul(MatMul):
|
|||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||
self.__setattr_flag__ = True
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||
|
||||
|
@ -628,13 +628,13 @@ class CumSum(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, exclusive=False, reverse=False):
|
||||
"""init cumsum"""
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
validator.check_value_type('exclusive', exclusive, [bool], cls_name)
|
||||
validator.check_value_type('reverse', reverse, [bool], cls_name)
|
||||
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])
|
||||
|
||||
def __infer__(self, x, axis):
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
x_shp = x['shape']
|
||||
validator.check_value_type('axis', axis['value'], [int], cls_name)
|
||||
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
|
||||
|
@ -679,7 +679,7 @@ class AddN(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
|
||||
|
||||
def infer_shape(self, inputs):
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
|
||||
self.add_prim_attr('n', len(inputs))
|
||||
shp0 = inputs[0]
|
||||
|
@ -688,7 +688,7 @@ class AddN(PrimitiveWithInfer):
|
|||
return shp0
|
||||
|
||||
def infer_dtype(self, inputs):
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
|
||||
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
|
||||
args = {}
|
||||
|
@ -718,7 +718,7 @@ class Neg(PrimitiveWithInfer):
|
|||
return input_x
|
||||
|
||||
def infer_dtype(self, input_x):
|
||||
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
|
||||
return input_x
|
||||
|
||||
|
||||
|
@ -809,7 +809,7 @@ class Square(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
|
||||
return x_type
|
||||
|
||||
|
||||
|
@ -838,7 +838,7 @@ class Rsqrt(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
|
||||
return x_type
|
||||
|
||||
|
||||
|
@ -867,7 +867,7 @@ class Sqrt(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
|
||||
return x_type
|
||||
|
||||
|
||||
|
@ -897,14 +897,29 @@ class Reciprocal(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_subclass("x", x, mstype.tensor, self.prim_name())
|
||||
validator.check_subclass("x", x, mstype.tensor, self.name)
|
||||
return x
|
||||
|
||||
|
||||
class Pow(PrimitiveWithInfer):
|
||||
class Pow(_MathBinaryOp):
|
||||
"""
|
||||
Computes a tensor to the power of the second input.
|
||||
|
||||
The first input must be a tensor, and the second input should be a tensor or a number.
|
||||
When the inputs are two tensors, the shapes of them could be broadcast,
|
||||
and the data types of them should be the same.
|
||||
When the inputs are one tensor and one scalar, the scalar could not be a parameter,
|
||||
only could be a constant, and the type of the scalar is the same as the data type of the tensor.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number.
|
||||
- **input_y** (Union[Tensor, Number]) - The second input is a tensor whose data type is same as 'input_x' or
|
||||
a number.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'.
|
||||
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor.
|
||||
- **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to
|
||||
|
@ -927,17 +942,6 @@ class Pow(PrimitiveWithInfer):
|
|||
[1.0, 16.0, 64.0]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init Multiply"""
|
||||
|
||||
def infer_shape(self, x, power):
|
||||
return x
|
||||
|
||||
def infer_dtype(self, x, power):
|
||||
validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name())
|
||||
return x
|
||||
|
||||
|
||||
class Exp(PrimitiveWithInfer):
|
||||
"""
|
||||
|
@ -965,7 +969,7 @@ class Exp(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass("x", x_type, mstype.tensor, self.prim_name())
|
||||
validator.check_subclass("x", x_type, mstype.tensor, self.name)
|
||||
return x_type
|
||||
|
||||
|
||||
|
@ -994,7 +998,7 @@ class Log(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_subclass("x", x, mstype.tensor, self.prim_name())
|
||||
validator.check_subclass("x", x, mstype.tensor, self.name)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1176,7 +1180,7 @@ class Floor(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name())
|
||||
validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -1231,7 +1235,7 @@ class Acosh(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1247,7 +1251,7 @@ class _LogicBinaryOp(_BinaryOp):
|
|||
return mstype.tensor_type(mstype.bool_)
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name())
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.name)
|
||||
|
||||
|
||||
class Equal(_LogicBinaryOp):
|
||||
|
@ -1283,7 +1287,7 @@ class Equal(_LogicBinaryOp):
|
|||
"""
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
|
||||
|
||||
|
||||
class EqualCount(PrimitiveWithInfer):
|
||||
|
@ -1318,7 +1322,7 @@ class EqualCount(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
args = {'x': x_dtype, 'y': y_dtype}
|
||||
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name())
|
||||
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -1355,7 +1359,7 @@ class NotEqual(_LogicBinaryOp):
|
|||
"""
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
|
||||
|
||||
|
||||
class Greater(_LogicBinaryOp):
|
||||
|
@ -1491,7 +1495,7 @@ class LogicalNot(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name())
|
||||
validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name)
|
||||
return mstype.tensor_type(mstype.bool_)
|
||||
|
||||
|
||||
|
@ -1521,7 +1525,7 @@ class LogicalAnd(_LogicBinaryOp):
|
|||
"""
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
|
||||
|
||||
|
||||
class LogicalOr(_LogicBinaryOp):
|
||||
|
@ -1550,7 +1554,7 @@ class LogicalOr(_LogicBinaryOp):
|
|||
"""
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
|
||||
|
||||
class IsNan(PrimitiveWithInfer):
|
||||
"""
|
||||
|
@ -1699,13 +1703,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
|
|||
self.add_prim_attr("_side_effect_flag", True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
|
||||
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
|
||||
return [8]
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
|
||||
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
|
||||
return mstype.float32
|
||||
|
||||
|
||||
|
@ -1741,13 +1745,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
|
|||
self.add_prim_attr("_side_effect_flag", True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
|
||||
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
|
||||
return [8]
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
|
||||
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
|
||||
return mstype.float32
|
||||
|
||||
|
||||
|
@ -1775,7 +1779,7 @@ class Cos(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1803,7 +1807,7 @@ class ACos(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1831,7 +1835,7 @@ class Sin(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1876,11 +1880,11 @@ class NMSWithMask(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, iou_threshold=0.5):
|
||||
"""Init NMSWithMask"""
|
||||
validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name())
|
||||
validator.check_value_type("iou_threshold", iou_threshold, [float], self.name)
|
||||
self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask'])
|
||||
|
||||
def infer_shape(self, bboxes_shape):
|
||||
cls_name = self.prim_name()
|
||||
cls_name = self.name
|
||||
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
|
||||
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
|
||||
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
|
||||
|
@ -1888,7 +1892,7 @@ class NMSWithMask(PrimitiveWithInfer):
|
|||
return (bboxes_shape, (num,), (num,))
|
||||
|
||||
def infer_dtype(self, bboxes_dtype):
|
||||
validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name())
|
||||
validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name)
|
||||
return (bboxes_dtype, mstype.int32, mstype.bool_)
|
||||
|
||||
|
||||
|
@ -1917,7 +1921,7 @@ class Abs(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
|
||||
return x_type
|
||||
|
||||
def infer_value(self, x):
|
||||
|
@ -1959,7 +1963,7 @@ class Sign(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -1988,7 +1992,7 @@ class Round(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
|
||||
return x_type
|
||||
|
||||
|
||||
|
|
|
@ -194,9 +194,6 @@ class PrimitiveWithInfer(Primitive):
|
|||
Primitive.__init__(self, name)
|
||||
self.set_prim_type(prim_type.py_infer_shape)
|
||||
|
||||
def prim_name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def _clone(self):
|
||||
"""
|
||||
Deeply clones the primitive object.
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <vector>
|
||||
#include "common/common_test.h"
|
||||
#include "parallel/strategy.h"
|
||||
#include "parallel/ops_info/elementary_function_info.h"
|
||||
#include "parallel/ops_info/arithmetic_info.h"
|
||||
#include "parallel/device_manager.h"
|
||||
#include "parallel/step_parallel.h"
|
||||
|
||||
|
@ -56,14 +56,14 @@ void TestPowInfo::SetUp() {
|
|||
|
||||
std::unordered_map<std::string, ValuePtr> attr;
|
||||
|
||||
Shapes inputs_shape = {{32, 64, 128}};
|
||||
Shapes inputs_shape = {{32, 64, 128}, {32, 64, 128}};
|
||||
Shapes outputs_shape = {{32, 64, 128}};
|
||||
|
||||
pow = std::make_shared<PowInfo>("pow_info", inputs_shape, outputs_shape, attr);
|
||||
}
|
||||
|
||||
TEST_F(TestPowInfo, InferDevMatrixShape1) {
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8}};
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
|
||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||
|
||||
pow->Init(strategy);
|
||||
|
@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) {
|
|||
}
|
||||
|
||||
TEST_F(TestPowInfo, InferSliceShape1) {
|
||||
std::vector<Dimensions> str = {{2, 4, 8}};
|
||||
std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
|
||||
StrategyPtr strategy = NewStrategy(0, str);
|
||||
|
||||
pow->Init(strategy);
|
||||
|
@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
|
|||
}
|
||||
|
||||
TEST_F(TestPowInfo, GetTensorLayout1) {
|
||||
std::vector<Dimensions> str = {{2, 4, 8}};
|
||||
std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
|
||||
StrategyPtr strategy = NewStrategy(0, str);
|
||||
|
||||
pow->Init(strategy);
|
||||
|
@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
|
|||
}
|
||||
|
||||
TEST_F(TestPowInfo, GetForwardOp1) {
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8}};
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
|
||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||
|
||||
pow->Init(strategy);
|
||||
|
@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
|
|||
}
|
||||
|
||||
TEST_F(TestPowInfo, GetMirrorOPs1) {
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8}};
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
|
||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||
|
||||
pow->Init(strategy);
|
||||
|
@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
|
|||
}
|
||||
|
||||
TEST_F(TestPowInfo, CheckStrategy2) {
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8, 16}};
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}};
|
||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||
|
||||
Status ret = pow->Init(strategy);
|
||||
|
@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
|
|||
}
|
||||
|
||||
TEST_F(TestPowInfo, CheckStrategy3) {
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8}};
|
||||
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
|
||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||
|
||||
Status ret = pow->Init(strategy);
|
||||
|
|
|
@ -82,9 +82,10 @@ def test_sqrt():
|
|||
def test_pow():
|
||||
""" test_pow """
|
||||
input_tensor = Tensor(np.array([[2, 2], [3, 3]]))
|
||||
power = Tensor(np.array(3.0, np.int64))
|
||||
testpow = P.Pow()
|
||||
expect = np.array([[8, 8], [27, 27]])
|
||||
result = testpow(input_tensor, 3.0)
|
||||
result = testpow(input_tensor, power)
|
||||
assert np.all(result.asnumpy() == expect)
|
||||
|
||||
|
||||
|
|
|
@ -224,11 +224,15 @@ test_case_math_ops = [
|
|||
'block': P.Minimum(),
|
||||
'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
|
||||
'desc_bprop': [[2, 3, 3, 5]]}),
|
||||
('Pow', {
|
||||
('Pow_0', {
|
||||
'block': P.Pow(),
|
||||
'desc_const': [2.0],
|
||||
'desc_inputs': [[2, 3, 3, 5]],
|
||||
'desc_bprop': [[2, 3, 3, 5]]}),
|
||||
('Pow_1', {
|
||||
'block': P.Pow(),
|
||||
'desc_inputs': [[3, 5], [2, 3, 3, 5]],
|
||||
'desc_bprop': [[2, 3, 3, 5]]}),
|
||||
('Exp', {
|
||||
'block': P.Exp(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
|
|
|
@ -59,7 +59,7 @@ def test_matmul_pow():
|
|||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2), (2, 2))
|
||||
strategy2 = ((4, 2), )
|
||||
strategy2 = ((4, 2), ())
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
||||
|
|
|
@ -117,6 +117,7 @@ def vm_impl_pow(self):
|
|||
"""Generate vm_impl function for Pow."""
|
||||
def vm_impl(x, y):
|
||||
x = x.asnumpy()
|
||||
y = y.asnumpy()
|
||||
res = vm.power(x, y)
|
||||
return Tensor(res)
|
||||
return vm_impl
|
||||
|
|
Loading…
Reference in New Issue