diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h index daa2ad595c5..76ba500a038 100644 --- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h +++ b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h @@ -98,6 +98,13 @@ class FloorDivInfo : public ArithmeticBase { ~FloorDivInfo() override = default; }; +class PowInfo : public ArithmeticBase { + public: + PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {} + ~PowInfo() override = default; +}; + class GreaterInfo : public ArithmeticBase { public: GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, diff --git a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.cc b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.cc deleted file mode 100644 index d4f79aca652..00000000000 --- a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.cc +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/elementary_function_info.h" - -namespace mindspore { -namespace parallel { -Status PowInfo::InferMirrorOps() { - mirror_ops_.clear(); - - Shape tensor_map = inputs_tensor_map_[0]; - std::vector group; - if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group failed."; - return FAILED; - } - - OperatorVector mirror_op; - OperatorVector op_for_value; - if (group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror ops is empty."; - return SUCCESS; - } else { - mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); - mirror_ops_.push_back(mirror_op); - mirror_ops_.push_back(op_for_value); - std::string group_name = group[0].name(); - MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name; - } - - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h index 57b4650f26d..84b8030f37a 100644 --- a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h +++ b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h @@ -27,16 +27,6 @@ namespace mindspore { namespace parallel { -class PowInfo : public ActivationOther { - public: - PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~PowInfo() override = default; - - protected: - Status InferMirrorOps() override; -}; - class ExpInfo : public ActivationOther { public: ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index bf90fcc9de9..5d9b0ffa6cc 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -58,7 +58,7 @@ class _PoolNd(Cell): pass def extend_repr(self): - return 'kernel_size={kernel_size}, strides={strides}, pad_mode={pad_mode}'.format(**self.__dict__) + return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__) class MaxPool2d(_PoolNd): diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 1863ac8fdd3..81e078dc982 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -336,14 +336,13 @@ def get_bprop_log(self): @bprop_getters.register(P.Pow) def get_bprop_pow(self): """Grad definition for `Pow` operation.""" - pow_ = P.Pow() - cast = P.Cast() - dtype = P.DType() + pow_op = P.Pow() + ln = P.Log() def bprop(x, power, out, dout): - g = cast(F.tuple_to_array((power,)), dtype(x)) * pow_(x, power-1.0) - dx = g * dout - return dx, 0 + dx = power * pow_op(x, power - 1.0) * dout + dpower = pow_op(x, power) * ln(x) * dout + return dx, dpower return bprop diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index ac7f8ed699c..850e895ad01 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1097,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer): axis = self.axis x_rank = len(x_shape) validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) - ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name()) + ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) return ouput_shape, ouput_shape def infer_dtype(self, x_dtype): @@ -1143,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer): axis = self.axis x_rank = len(x_shape) validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) - ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name()) + ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) return ouput_shape, ouput_shape def infer_dtype(self, x_dtype): diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 106886c45c3..a1fe6e72b55 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -74,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) def infer_shape(self, x_shape, y_shape): - return _get_broadcast_shape(x_shape, y_shape, self.prim_name()) + return _get_broadcast_shape(x_shape, y_shape, self.name) class _MathBinaryOp(_BinaryOp): @@ -89,7 +89,7 @@ class _MathBinaryOp(_BinaryOp): return x_dtype def infer_dtype(self, x_dtype, y_dtype): - return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name()) + return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.name) class TensorAdd(_MathBinaryOp): @@ -158,7 +158,7 @@ class AssignAdd(PrimitiveWithInfer): def infer_dtype(self, variable, value): args = {"value": value} - validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name()) + validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) return value @@ -201,7 +201,7 @@ class AssignSub(PrimitiveWithInfer): def infer_dtype(self, variable, value): args = {"value": value} - validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name()) + validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) return value @@ -222,16 +222,16 @@ class _Reduce(PrimitiveWithInfer): @prim_attr_register def __init__(self, keep_dims=False): """init Reduce""" - validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name()) + validator.check_value_type('keep_dims', keep_dims, [bool], self.name) self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): axis_v = axis['value'] input_shp = input_x['shape'] args = {'input_x': input_x['dtype']} - validator.check_tensor_type_same(args, valid_dtype, self.prim_name()) + validator.check_tensor_type_same(args, valid_dtype, self.name) - input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name()) + input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name) return {'shape': input_shp, 'dtype': input_x['dtype'], 'value': None} @@ -466,7 +466,7 @@ class CumProd(PrimitiveWithInfer): """ @prim_attr_register def __init__(self, exclusive=False, reverse=False): - cls_name = self.prim_name() + cls_name = self.name self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name) self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name) @@ -474,7 +474,7 @@ class CumProd(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type, axis_type): - cls_name = self.prim_name() + cls_name = self.name validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name) validator.check_subclass("axis", axis_type, mstype.int_, cls_name) return x_type @@ -510,7 +510,7 @@ class MatMul(PrimitiveWithInfer): def __init__(self, transpose_a=False, transpose_b=False): self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) self.__setattr_flag__ = True - cls_name = self.prim_name() + cls_name = self.name validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) @@ -521,7 +521,7 @@ class MatMul(PrimitiveWithInfer): def infer_shape(self, x, y): self.check_shape_size(x, y) - cls_name = self.prim_name() + cls_name = self.name # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two for i in range(len(x) - 2): if x[i] != y[i]: @@ -546,7 +546,7 @@ class MatMul(PrimitiveWithInfer): def infer_dtype(self, x, y): args = {"x": x, "y": y} - validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name()) + validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name) return x @@ -590,7 +590,7 @@ class BatchMatMul(MatMul): def __init__(self, transpose_a=False, transpose_b=False): self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) self.__setattr_flag__ = True - cls_name = self.prim_name() + cls_name = self.name validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) @@ -628,13 +628,13 @@ class CumSum(PrimitiveWithInfer): @prim_attr_register def __init__(self, exclusive=False, reverse=False): """init cumsum""" - cls_name = self.prim_name() + cls_name = self.name validator.check_value_type('exclusive', exclusive, [bool], cls_name) validator.check_value_type('reverse', reverse, [bool], cls_name) self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y']) def __infer__(self, x, axis): - cls_name = self.prim_name() + cls_name = self.name x_shp = x['shape'] validator.check_value_type('axis', axis['value'], [int], cls_name) valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] @@ -679,7 +679,7 @@ class AddN(PrimitiveWithInfer): self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) def infer_shape(self, inputs): - cls_name = self.prim_name() + cls_name = self.name validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) self.add_prim_attr('n', len(inputs)) shp0 = inputs[0] @@ -688,7 +688,7 @@ class AddN(PrimitiveWithInfer): return shp0 def infer_dtype(self, inputs): - cls_name = self.prim_name() + cls_name = self.name validator.check_value_type("inputs", inputs, [tuple, list], cls_name) validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) args = {} @@ -718,7 +718,7 @@ class Neg(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) return input_x @@ -809,7 +809,7 @@ class Square(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) return x_type @@ -838,7 +838,7 @@ class Rsqrt(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) return x_type @@ -867,7 +867,7 @@ class Sqrt(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) return x_type @@ -897,14 +897,29 @@ class Reciprocal(PrimitiveWithInfer): return x def infer_dtype(self, x): - validator.check_subclass("x", x, mstype.tensor, self.prim_name()) + validator.check_subclass("x", x, mstype.tensor, self.name) return x -class Pow(PrimitiveWithInfer): +class Pow(_MathBinaryOp): """ Computes a tensor to the power of the second input. + The first input must be a tensor, and the second input should be a tensor or a number. + When the inputs are two tensors, the shapes of them could be broadcast, + and the data types of them should be the same. + When the inputs are one tensor and one scalar, the scalar could not be a parameter, + only could be a constant, and the type of the scalar is the same as the data type of the tensor. + + Inputs: + - **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number. + - **input_y** (Union[Tensor, Number]) - The second input is a tensor whose data type is same as 'input_x' or + a number. + + Outputs: + Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'. + + Inputs: - **input_x** (Tensor) - The input tensor. - **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to @@ -927,17 +942,6 @@ class Pow(PrimitiveWithInfer): [1.0, 16.0, 64.0] """ - @prim_attr_register - def __init__(self): - """init Multiply""" - - def infer_shape(self, x, power): - return x - - def infer_dtype(self, x, power): - validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name()) - return x - class Exp(PrimitiveWithInfer): """ @@ -965,7 +969,7 @@ class Exp(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_subclass("x", x_type, mstype.tensor, self.prim_name()) + validator.check_subclass("x", x_type, mstype.tensor, self.name) return x_type @@ -994,7 +998,7 @@ class Log(PrimitiveWithInfer): return x def infer_dtype(self, x): - validator.check_subclass("x", x, mstype.tensor, self.prim_name()) + validator.check_subclass("x", x, mstype.tensor, self.name) return x @@ -1176,7 +1180,7 @@ class Floor(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name()) + validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name) return x_dtype @@ -1231,7 +1235,7 @@ class Acosh(PrimitiveWithInfer): return x def infer_dtype(self, x): - validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) return x @@ -1247,7 +1251,7 @@ class _LogicBinaryOp(_BinaryOp): return mstype.tensor_type(mstype.bool_) def infer_dtype(self, x_dtype, y_dtype): - return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name()) + return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.name) class Equal(_LogicBinaryOp): @@ -1283,7 +1287,7 @@ class Equal(_LogicBinaryOp): """ def infer_dtype(self, x_dtype, y_dtype): - return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name()) + return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name) class EqualCount(PrimitiveWithInfer): @@ -1318,7 +1322,7 @@ class EqualCount(PrimitiveWithInfer): def infer_dtype(self, x_dtype, y_dtype): args = {'x': x_dtype, 'y': y_dtype} - validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name()) + validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) return x_dtype @@ -1355,7 +1359,7 @@ class NotEqual(_LogicBinaryOp): """ def infer_dtype(self, x_dtype, y_dtype): - return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name()) + return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name) class Greater(_LogicBinaryOp): @@ -1491,7 +1495,7 @@ class LogicalNot(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name()) + validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name) return mstype.tensor_type(mstype.bool_) @@ -1521,7 +1525,7 @@ class LogicalAnd(_LogicBinaryOp): """ def infer_dtype(self, x_dtype, y_dtype): - return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name()) + return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name) class LogicalOr(_LogicBinaryOp): @@ -1550,7 +1554,7 @@ class LogicalOr(_LogicBinaryOp): """ def infer_dtype(self, x_dtype, y_dtype): - return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name()) + return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name) class IsNan(PrimitiveWithInfer): """ @@ -1699,13 +1703,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer): self.add_prim_attr("_side_effect_flag", True) def infer_shape(self, x_shape): - cls_name = self.prim_name() + cls_name = self.name validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) return [8] def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name()) + validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name) return mstype.float32 @@ -1741,13 +1745,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer): self.add_prim_attr("_side_effect_flag", True) def infer_shape(self, x_shape): - cls_name = self.prim_name() + cls_name = self.name validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) return [8] def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name()) + validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name) return mstype.float32 @@ -1775,7 +1779,7 @@ class Cos(PrimitiveWithInfer): return x def infer_dtype(self, x): - validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) return x @@ -1803,7 +1807,7 @@ class ACos(PrimitiveWithInfer): return x def infer_dtype(self, x): - validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) return x @@ -1831,7 +1835,7 @@ class Sin(PrimitiveWithInfer): return x def infer_dtype(self, x): - validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) return x @@ -1876,11 +1880,11 @@ class NMSWithMask(PrimitiveWithInfer): @prim_attr_register def __init__(self, iou_threshold=0.5): """Init NMSWithMask""" - validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name()) + validator.check_value_type("iou_threshold", iou_threshold, [float], self.name) self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) def infer_shape(self, bboxes_shape): - cls_name = self.prim_name() + cls_name = self.name validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) @@ -1888,7 +1892,7 @@ class NMSWithMask(PrimitiveWithInfer): return (bboxes_shape, (num,), (num,)) def infer_dtype(self, bboxes_dtype): - validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name()) + validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name) return (bboxes_dtype, mstype.int32, mstype.bool_) @@ -1917,7 +1921,7 @@ class Abs(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) return x_type def infer_value(self, x): @@ -1959,7 +1963,7 @@ class Sign(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) return x_dtype @@ -1988,7 +1992,7 @@ class Round(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name()) + validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) return x_type diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 96e754f5f7d..d281b4f76ca 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -194,9 +194,6 @@ class PrimitiveWithInfer(Primitive): Primitive.__init__(self, name) self.set_prim_type(prim_type.py_infer_shape) - def prim_name(self): - return self.__class__.__name__ - def _clone(self): """ Deeply clones the primitive object. diff --git a/tests/ut/cpp/parallel/ops_info/pow_info_test.cc b/tests/ut/cpp/parallel/ops_info/pow_info_test.cc index f6ea2c3d3c3..7b37a90fd83 100644 --- a/tests/ut/cpp/parallel/ops_info/pow_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/pow_info_test.cc @@ -19,7 +19,7 @@ #include #include "common/common_test.h" #include "parallel/strategy.h" -#include "parallel/ops_info/elementary_function_info.h" +#include "parallel/ops_info/arithmetic_info.h" #include "parallel/device_manager.h" #include "parallel/step_parallel.h" @@ -56,14 +56,14 @@ void TestPowInfo::SetUp() { std::unordered_map attr; - Shapes inputs_shape = {{32, 64, 128}}; + Shapes inputs_shape = {{32, 64, 128}, {32, 64, 128}}; Shapes outputs_shape = {{32, 64, 128}}; pow = std::make_shared("pow_info", inputs_shape, outputs_shape, attr); } TEST_F(TestPowInfo, InferDevMatrixShape1) { - std::vector inputs = {{2, 4, 8}}; + std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); pow->Init(strategy); @@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) { } TEST_F(TestPowInfo, InferSliceShape1) { - std::vector str = {{2, 4, 8}}; + std::vector str = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, str); pow->Init(strategy); @@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) { } TEST_F(TestPowInfo, GetTensorLayout1) { - std::vector str = {{2, 4, 8}}; + std::vector str = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, str); pow->Init(strategy); @@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) { } TEST_F(TestPowInfo, GetForwardOp1) { - std::vector inputs = {{2, 4, 8}}; + std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); pow->Init(strategy); @@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) { } TEST_F(TestPowInfo, GetMirrorOPs1) { - std::vector inputs = {{2, 4, 8}}; + std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); pow->Init(strategy); @@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) { } TEST_F(TestPowInfo, CheckStrategy2) { - std::vector inputs = {{2, 4, 8, 16}}; + std::vector inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = pow->Init(strategy); @@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) { } TEST_F(TestPowInfo, CheckStrategy3) { - std::vector inputs = {{2, 4, 8}}; + std::vector inputs = {{2, 4, 8}, {2, 4, 8}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = pow->Init(strategy); diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py index 7c0cca9b400..ad1642228d6 100755 --- a/tests/ut/python/ops/test_math_ops.py +++ b/tests/ut/python/ops/test_math_ops.py @@ -82,9 +82,10 @@ def test_sqrt(): def test_pow(): """ test_pow """ input_tensor = Tensor(np.array([[2, 2], [3, 3]])) + power = Tensor(np.array(3.0, np.int64)) testpow = P.Pow() expect = np.array([[8, 8], [27, 27]]) - result = testpow(input_tensor, 3.0) + result = testpow(input_tensor, power) assert np.all(result.asnumpy() == expect) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index a6b064bdb0c..078ada84065 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -224,11 +224,15 @@ test_case_math_ops = [ 'block': P.Minimum(), 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], 'desc_bprop': [[2, 3, 3, 5]]}), - ('Pow', { + ('Pow_0', { 'block': P.Pow(), 'desc_const': [2.0], 'desc_inputs': [[2, 3, 3, 5]], 'desc_bprop': [[2, 3, 3, 5]]}), + ('Pow_1', { + 'block': P.Pow(), + 'desc_inputs': [[3, 5], [2, 3, 3, 5]], + 'desc_bprop': [[2, 3, 3, 5]]}), ('Exp', { 'block': P.Exp(), 'desc_inputs': [[2, 3]], diff --git a/tests/ut/python/parallel/test_element_wise_function.py b/tests/ut/python/parallel/test_element_wise_function.py index 2eb3a22ed2a..641eb19f20f 100644 --- a/tests/ut/python/parallel/test_element_wise_function.py +++ b/tests/ut/python/parallel/test_element_wise_function.py @@ -59,7 +59,7 @@ def test_matmul_pow(): context.set_auto_parallel_context(device_num=8, global_rank=0) strategy1 = ((2, 2), (2, 2)) - strategy2 = ((4, 2), ) + strategy2 = ((4, 2), ()) net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") diff --git a/tests/vm_impl/math_ops_vm_impl.py b/tests/vm_impl/math_ops_vm_impl.py index fd132280d10..01df0b824e2 100644 --- a/tests/vm_impl/math_ops_vm_impl.py +++ b/tests/vm_impl/math_ops_vm_impl.py @@ -117,6 +117,7 @@ def vm_impl_pow(self): """Generate vm_impl function for Pow.""" def vm_impl(x, y): x = x.asnumpy() + y = y.asnumpy() res = vm.power(x, y) return Tensor(res) return vm_impl