Support pow's second input could be tensor and fix bug in bprop of pow

This commit is contained in:
buxue 2020-04-08 16:24:51 +08:00
parent 7cec28526a
commit 5841fe010e
13 changed files with 95 additions and 139 deletions

View File

@ -98,6 +98,13 @@ class FloorDivInfo : public ArithmeticBase {
~FloorDivInfo() override = default; ~FloorDivInfo() override = default;
}; };
class PowInfo : public ArithmeticBase {
public:
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
~PowInfo() override = default;
};
class GreaterInfo : public ArithmeticBase { class GreaterInfo : public ArithmeticBase {
public: public:
GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,

View File

@ -1,47 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "parallel/ops_info/elementary_function_info.h"
namespace mindspore {
namespace parallel {
Status PowInfo::InferMirrorOps() {
mirror_ops_.clear();
Shape tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group failed.";
return FAILED;
}
OperatorVector mirror_op;
OperatorVector op_for_value;
if (group.empty()) {
MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
return SUCCESS;
} else {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
mirror_ops_.push_back(op_for_value);
std::string group_name = group[0].name();
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -27,16 +27,6 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
class PowInfo : public ActivationOther {
public:
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~PowInfo() override = default;
protected:
Status InferMirrorOps() override;
};
class ExpInfo : public ActivationOther { class ExpInfo : public ActivationOther {
public: public:
ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)

View File

@ -58,7 +58,7 @@ class _PoolNd(Cell):
pass pass
def extend_repr(self): def extend_repr(self):
return 'kernel_size={kernel_size}, strides={strides}, pad_mode={pad_mode}'.format(**self.__dict__) return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
class MaxPool2d(_PoolNd): class MaxPool2d(_PoolNd):

View File

@ -336,14 +336,13 @@ def get_bprop_log(self):
@bprop_getters.register(P.Pow) @bprop_getters.register(P.Pow)
def get_bprop_pow(self): def get_bprop_pow(self):
"""Grad definition for `Pow` operation.""" """Grad definition for `Pow` operation."""
pow_ = P.Pow() pow_op = P.Pow()
cast = P.Cast() ln = P.Log()
dtype = P.DType()
def bprop(x, power, out, dout): def bprop(x, power, out, dout):
g = cast(F.tuple_to_array((power,)), dtype(x)) * pow_(x, power-1.0) dx = power * pow_op(x, power - 1.0) * dout
dx = g * dout dpower = pow_op(x, power) * ln(x) * dout
return dx, 0 return dx, dpower
return bprop return bprop

View File

@ -1097,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
axis = self.axis axis = self.axis
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name()) ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape return ouput_shape, ouput_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
@ -1143,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
axis = self.axis axis = self.axis
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name()) ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape return ouput_shape, ouput_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):

View File

@ -74,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
def infer_shape(self, x_shape, y_shape): def infer_shape(self, x_shape, y_shape):
return _get_broadcast_shape(x_shape, y_shape, self.prim_name()) return _get_broadcast_shape(x_shape, y_shape, self.name)
class _MathBinaryOp(_BinaryOp): class _MathBinaryOp(_BinaryOp):
@ -89,7 +89,7 @@ class _MathBinaryOp(_BinaryOp):
return x_dtype return x_dtype
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name()) return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.name)
class TensorAdd(_MathBinaryOp): class TensorAdd(_MathBinaryOp):
@ -158,7 +158,7 @@ class AssignAdd(PrimitiveWithInfer):
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
args = {"value": value} args = {"value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name()) validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
return value return value
@ -201,7 +201,7 @@ class AssignSub(PrimitiveWithInfer):
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
args = {"value": value} args = {"value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name()) validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
return value return value
@ -222,16 +222,16 @@ class _Reduce(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, keep_dims=False): def __init__(self, keep_dims=False):
"""init Reduce""" """init Reduce"""
validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name()) validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
axis_v = axis['value'] axis_v = axis['value']
input_shp = input_x['shape'] input_shp = input_x['shape']
args = {'input_x': input_x['dtype']} args = {'input_x': input_x['dtype']}
validator.check_tensor_type_same(args, valid_dtype, self.prim_name()) validator.check_tensor_type_same(args, valid_dtype, self.name)
input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name()) input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
return {'shape': input_shp, return {'shape': input_shp,
'dtype': input_x['dtype'], 'dtype': input_x['dtype'],
'value': None} 'value': None}
@ -466,7 +466,7 @@ class CumProd(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, exclusive=False, reverse=False): def __init__(self, exclusive=False, reverse=False):
cls_name = self.prim_name() cls_name = self.name
self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name) self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name) self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
@ -474,7 +474,7 @@ class CumProd(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type, axis_type): def infer_dtype(self, x_type, axis_type):
cls_name = self.prim_name() cls_name = self.name
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name) validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
validator.check_subclass("axis", axis_type, mstype.int_, cls_name) validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
return x_type return x_type
@ -510,7 +510,7 @@ class MatMul(PrimitiveWithInfer):
def __init__(self, transpose_a=False, transpose_b=False): def __init__(self, transpose_a=False, transpose_b=False):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
self.__setattr_flag__ = True self.__setattr_flag__ = True
cls_name = self.prim_name() cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
@ -521,7 +521,7 @@ class MatMul(PrimitiveWithInfer):
def infer_shape(self, x, y): def infer_shape(self, x, y):
self.check_shape_size(x, y) self.check_shape_size(x, y)
cls_name = self.prim_name() cls_name = self.name
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
for i in range(len(x) - 2): for i in range(len(x) - 2):
if x[i] != y[i]: if x[i] != y[i]:
@ -546,7 +546,7 @@ class MatMul(PrimitiveWithInfer):
def infer_dtype(self, x, y): def infer_dtype(self, x, y):
args = {"x": x, "y": y} args = {"x": x, "y": y}
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name()) validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
return x return x
@ -590,7 +590,7 @@ class BatchMatMul(MatMul):
def __init__(self, transpose_a=False, transpose_b=False): def __init__(self, transpose_a=False, transpose_b=False):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
self.__setattr_flag__ = True self.__setattr_flag__ = True
cls_name = self.prim_name() cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
@ -628,13 +628,13 @@ class CumSum(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, exclusive=False, reverse=False): def __init__(self, exclusive=False, reverse=False):
"""init cumsum""" """init cumsum"""
cls_name = self.prim_name() cls_name = self.name
validator.check_value_type('exclusive', exclusive, [bool], cls_name) validator.check_value_type('exclusive', exclusive, [bool], cls_name)
validator.check_value_type('reverse', reverse, [bool], cls_name) validator.check_value_type('reverse', reverse, [bool], cls_name)
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])
def __infer__(self, x, axis): def __infer__(self, x, axis):
cls_name = self.prim_name() cls_name = self.name
x_shp = x['shape'] x_shp = x['shape']
validator.check_value_type('axis', axis['value'], [int], cls_name) validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
@ -679,7 +679,7 @@ class AddN(PrimitiveWithInfer):
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
def infer_shape(self, inputs): def infer_shape(self, inputs):
cls_name = self.prim_name() cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
self.add_prim_attr('n', len(inputs)) self.add_prim_attr('n', len(inputs))
shp0 = inputs[0] shp0 = inputs[0]
@ -688,7 +688,7 @@ class AddN(PrimitiveWithInfer):
return shp0 return shp0
def infer_dtype(self, inputs): def infer_dtype(self, inputs):
cls_name = self.prim_name() cls_name = self.name
validator.check_value_type("inputs", inputs, [tuple, list], cls_name) validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
args = {} args = {}
@ -718,7 +718,7 @@ class Neg(PrimitiveWithInfer):
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
return input_x return input_x
@ -809,7 +809,7 @@ class Square(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type return x_type
@ -838,7 +838,7 @@ class Rsqrt(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type return x_type
@ -867,7 +867,7 @@ class Sqrt(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type return x_type
@ -897,14 +897,29 @@ class Reciprocal(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.prim_name()) validator.check_subclass("x", x, mstype.tensor, self.name)
return x return x
class Pow(PrimitiveWithInfer): class Pow(_MathBinaryOp):
""" """
Computes a tensor to the power of the second input. Computes a tensor to the power of the second input.
The first input must be a tensor, and the second input should be a tensor or a number.
When the inputs are two tensors, the shapes of them could be broadcast,
and the data types of them should be the same.
When the inputs are one tensor and one scalar, the scalar could not be a parameter,
only could be a constant, and the type of the scalar is the same as the data type of the tensor.
Inputs:
- **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number.
- **input_y** (Union[Tensor, Number]) - The second input is a tensor whose data type is same as 'input_x' or
a number.
Outputs:
Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'.
Inputs: Inputs:
- **input_x** (Tensor) - The input tensor. - **input_x** (Tensor) - The input tensor.
- **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to - **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to
@ -927,17 +942,6 @@ class Pow(PrimitiveWithInfer):
[1.0, 16.0, 64.0] [1.0, 16.0, 64.0]
""" """
@prim_attr_register
def __init__(self):
"""init Multiply"""
def infer_shape(self, x, power):
return x
def infer_dtype(self, x, power):
validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name())
return x
class Exp(PrimitiveWithInfer): class Exp(PrimitiveWithInfer):
""" """
@ -965,7 +969,7 @@ class Exp(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_subclass("x", x_type, mstype.tensor, self.prim_name()) validator.check_subclass("x", x_type, mstype.tensor, self.name)
return x_type return x_type
@ -994,7 +998,7 @@ class Log(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.prim_name()) validator.check_subclass("x", x, mstype.tensor, self.name)
return x return x
@ -1176,7 +1180,7 @@ class Floor(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name()) validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name)
return x_dtype return x_dtype
@ -1231,7 +1235,7 @@ class Acosh(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x return x
@ -1247,7 +1251,7 @@ class _LogicBinaryOp(_BinaryOp):
return mstype.tensor_type(mstype.bool_) return mstype.tensor_type(mstype.bool_)
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name()) return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.name)
class Equal(_LogicBinaryOp): class Equal(_LogicBinaryOp):
@ -1283,7 +1287,7 @@ class Equal(_LogicBinaryOp):
""" """
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name()) return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
class EqualCount(PrimitiveWithInfer): class EqualCount(PrimitiveWithInfer):
@ -1318,7 +1322,7 @@ class EqualCount(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
args = {'x': x_dtype, 'y': y_dtype} args = {'x': x_dtype, 'y': y_dtype}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name()) validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype return x_dtype
@ -1355,7 +1359,7 @@ class NotEqual(_LogicBinaryOp):
""" """
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name()) return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
class Greater(_LogicBinaryOp): class Greater(_LogicBinaryOp):
@ -1491,7 +1495,7 @@ class LogicalNot(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name()) validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name)
return mstype.tensor_type(mstype.bool_) return mstype.tensor_type(mstype.bool_)
@ -1521,7 +1525,7 @@ class LogicalAnd(_LogicBinaryOp):
""" """
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name()) return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
class LogicalOr(_LogicBinaryOp): class LogicalOr(_LogicBinaryOp):
@ -1550,7 +1554,7 @@ class LogicalOr(_LogicBinaryOp):
""" """
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name()) return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
class IsNan(PrimitiveWithInfer): class IsNan(PrimitiveWithInfer):
""" """
@ -1699,13 +1703,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
self.add_prim_attr("_side_effect_flag", True) self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
cls_name = self.prim_name() cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
return [8] return [8]
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name()) validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
return mstype.float32 return mstype.float32
@ -1741,13 +1745,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
self.add_prim_attr("_side_effect_flag", True) self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
cls_name = self.prim_name() cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
return [8] return [8]
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name()) validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
return mstype.float32 return mstype.float32
@ -1775,7 +1779,7 @@ class Cos(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x return x
@ -1803,7 +1807,7 @@ class ACos(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x return x
@ -1831,7 +1835,7 @@ class Sin(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x return x
@ -1876,11 +1880,11 @@ class NMSWithMask(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, iou_threshold=0.5): def __init__(self, iou_threshold=0.5):
"""Init NMSWithMask""" """Init NMSWithMask"""
validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name()) validator.check_value_type("iou_threshold", iou_threshold, [float], self.name)
self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask'])
def infer_shape(self, bboxes_shape): def infer_shape(self, bboxes_shape):
cls_name = self.prim_name() cls_name = self.name
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
@ -1888,7 +1892,7 @@ class NMSWithMask(PrimitiveWithInfer):
return (bboxes_shape, (num,), (num,)) return (bboxes_shape, (num,), (num,))
def infer_dtype(self, bboxes_dtype): def infer_dtype(self, bboxes_dtype):
validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name()) validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name)
return (bboxes_dtype, mstype.int32, mstype.bool_) return (bboxes_dtype, mstype.int32, mstype.bool_)
@ -1917,7 +1921,7 @@ class Abs(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
return x_type return x_type
def infer_value(self, x): def infer_value(self, x):
@ -1959,7 +1963,7 @@ class Sign(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -1988,7 +1992,7 @@ class Round(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
return x_type return x_type

View File

@ -194,9 +194,6 @@ class PrimitiveWithInfer(Primitive):
Primitive.__init__(self, name) Primitive.__init__(self, name)
self.set_prim_type(prim_type.py_infer_shape) self.set_prim_type(prim_type.py_infer_shape)
def prim_name(self):
return self.__class__.__name__
def _clone(self): def _clone(self):
""" """
Deeply clones the primitive object. Deeply clones the primitive object.

View File

@ -19,7 +19,7 @@
#include <vector> #include <vector>
#include "common/common_test.h" #include "common/common_test.h"
#include "parallel/strategy.h" #include "parallel/strategy.h"
#include "parallel/ops_info/elementary_function_info.h" #include "parallel/ops_info/arithmetic_info.h"
#include "parallel/device_manager.h" #include "parallel/device_manager.h"
#include "parallel/step_parallel.h" #include "parallel/step_parallel.h"
@ -56,14 +56,14 @@ void TestPowInfo::SetUp() {
std::unordered_map<std::string, ValuePtr> attr; std::unordered_map<std::string, ValuePtr> attr;
Shapes inputs_shape = {{32, 64, 128}}; Shapes inputs_shape = {{32, 64, 128}, {32, 64, 128}};
Shapes outputs_shape = {{32, 64, 128}}; Shapes outputs_shape = {{32, 64, 128}};
pow = std::make_shared<PowInfo>("pow_info", inputs_shape, outputs_shape, attr); pow = std::make_shared<PowInfo>("pow_info", inputs_shape, outputs_shape, attr);
} }
TEST_F(TestPowInfo, InferDevMatrixShape1) { TEST_F(TestPowInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 8}}; std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy); pow->Init(strategy);
@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) {
} }
TEST_F(TestPowInfo, InferSliceShape1) { TEST_F(TestPowInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 8}}; std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
pow->Init(strategy); pow->Init(strategy);
@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
} }
TEST_F(TestPowInfo, GetTensorLayout1) { TEST_F(TestPowInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 8}}; std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
pow->Init(strategy); pow->Init(strategy);
@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
} }
TEST_F(TestPowInfo, GetForwardOp1) { TEST_F(TestPowInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 8}}; std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy); pow->Init(strategy);
@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
} }
TEST_F(TestPowInfo, GetMirrorOPs1) { TEST_F(TestPowInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 8}}; std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy); pow->Init(strategy);
@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
} }
TEST_F(TestPowInfo, CheckStrategy2) { TEST_F(TestPowInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}}; std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy); Status ret = pow->Init(strategy);
@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
} }
TEST_F(TestPowInfo, CheckStrategy3) { TEST_F(TestPowInfo, CheckStrategy3) {
std::vector<Dimensions> inputs = {{2, 4, 8}}; std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy); Status ret = pow->Init(strategy);

View File

@ -82,9 +82,10 @@ def test_sqrt():
def test_pow(): def test_pow():
""" test_pow """ """ test_pow """
input_tensor = Tensor(np.array([[2, 2], [3, 3]])) input_tensor = Tensor(np.array([[2, 2], [3, 3]]))
power = Tensor(np.array(3.0, np.int64))
testpow = P.Pow() testpow = P.Pow()
expect = np.array([[8, 8], [27, 27]]) expect = np.array([[8, 8], [27, 27]])
result = testpow(input_tensor, 3.0) result = testpow(input_tensor, power)
assert np.all(result.asnumpy() == expect) assert np.all(result.asnumpy() == expect)

View File

@ -224,11 +224,15 @@ test_case_math_ops = [
'block': P.Minimum(), 'block': P.Minimum(),
'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]]}), 'desc_bprop': [[2, 3, 3, 5]]}),
('Pow', { ('Pow_0', {
'block': P.Pow(), 'block': P.Pow(),
'desc_const': [2.0], 'desc_const': [2.0],
'desc_inputs': [[2, 3, 3, 5]], 'desc_inputs': [[2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]]}), 'desc_bprop': [[2, 3, 3, 5]]}),
('Pow_1', {
'block': P.Pow(),
'desc_inputs': [[3, 5], [2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]]}),
('Exp', { ('Exp', {
'block': P.Exp(), 'block': P.Exp(),
'desc_inputs': [[2, 3]], 'desc_inputs': [[2, 3]],

View File

@ -59,7 +59,7 @@ def test_matmul_pow():
context.set_auto_parallel_context(device_num=8, global_rank=0) context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((2, 2), (2, 2)) strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), ) strategy2 = ((4, 2), ())
net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

View File

@ -117,6 +117,7 @@ def vm_impl_pow(self):
"""Generate vm_impl function for Pow.""" """Generate vm_impl function for Pow."""
def vm_impl(x, y): def vm_impl(x, y):
x = x.asnumpy() x = x.asnumpy()
y = y.asnumpy()
res = vm.power(x, y) res = vm.power(x, y)
return Tensor(res) return Tensor(res)
return vm_impl return vm_impl