From 96573e6bd4a90bf40921a2bc75aaff6c0dec50f1 Mon Sep 17 00:00:00 2001
From: jin-xiulang
Date: Fri, 4 Nov 2022 09:38:02 +0800
Subject: [PATCH] transfer Reshape's infer_shape from Python to core/ops

---
 mindspore/core/abstract/ops/infer_functions.h |   2 -
 mindspore/core/abstract/ops/prim_arrays.cc    | 119 ---------------
 .../core/abstract/ops/primitive_infer_map.cc  |   1 -
 mindspore/core/ops/reshape.cc                 | 140 +++++++++++++++++-
 mindspore/core/ops/reshape.h                  |   3 -
 .../mindspore/ops/operations/array_ops.py     |  97 ++++--------
 6 files changed, 166 insertions(+), 196 deletions(-)

diff --git a/mindspore/core/abstract/ops/infer_functions.h b/mindspore/core/abstract/ops/infer_functions.h
index 75eccb1860d..40f6c35c052 100644
--- a/mindspore/core/abstract/ops/infer_functions.h
+++ b/mindspore/core/abstract/ops/infer_functions.h
@@ -180,8 +180,6 @@ AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &prim
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/ops/prim_arrays.cc b/mindspore/core/abstract/ops/prim_arrays.cc
index 3fca24c75e1..46c62008d56 100644
--- a/mindspore/core/abstract/ops/prim_arrays.cc
+++ b/mindspore/core/abstract/ops/prim_arrays.cc
@@ -677,125 +677,6 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
   return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
 }
 
-static ShapeVector GetShape(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list,
-                            const std::string &op_name) {
-  ShapeVector shape;
-  if (args_spec_list.size() == kSizeTwo) {
-    auto input_value = args_spec_list[1]->BuildValue();
-    if (input_value->isa<tensor::Tensor>()) {
-      shape = CheckAndConvertUtils::CheckTensorIntValue("shape", input_value, op_name);
-    } else {
-      shape = CheckAndConvertUtils::CheckTupleInt("input[shape]", input_value, op_name);
-    }
-  } else {
-    ValuePtr sh = primitive->GetAttr("shape");
-    MS_EXCEPTION_IF_NULL(sh);
-    if (sh->isa<ValueTuple>()) {
-      auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
-      MS_EXCEPTION_IF_NULL(reshape_value_tuple);
-      auto reshape_tuple = reshape_value_tuple->value();
-      (void)std::transform(std::begin(reshape_tuple), std::end(reshape_tuple), std::back_inserter(shape),
-                           [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
-    } else if (sh->isa<tensor::Tensor>()) {
-      shape = CheckAndConvertUtils::CheckTensorIntValue("shape", sh, "Reshape");
-    } else {
-      MS_EXCEPTION(ValueError) << "In stage of execution, the primitive[Reshape]'s input['shape'] must be a tuple or "
-                               << "constant Tensor.";
-    }
-  }
-  return shape;
-}
-
-AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                 const AbstractBasePtrList &args_spec_list) {
-  const std::string op_name = primitive->name();
-  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(x->shape());
-  // padding_axis_value is for the condition that the number of -1 in shape > 1, but just padding shape
-  std::vector<int64_t> padding_axis_value;
-  ValuePtr padding_axis = primitive->GetAttr("reshape_padding_axis");
primitive->GetAttr("reshape_padding_axis"); - if (padding_axis != nullptr) { - padding_axis_value = GetValue>(padding_axis); - } - - ShapeVector shape = GetShape(primitive, args_spec_list, op_name); - ShapeVector x_shape = x->shape()->shape(); - ShapeVector x_max_shape = x->shape()->max_shape(); - ShapeVector x_min_shape = x->shape()->min_shape(); - if (x_max_shape.empty()) { - x_max_shape = x_shape; - } - if (x_min_shape.empty()) { - x_min_shape = x_shape; - } - - auto max_shape = shape; - auto min_shape = shape; - int64_t x_num = 1; - int64_t x_min_num = 1; - int64_t x_max_num = 1; - for (int64_t value : x_shape) { - x_num = LongMulWithOverflowCheck(value, x_num); - } - for (int64_t value : x_min_shape) { - x_min_num = LongMulWithOverflowCheck(value, x_min_num); - } - for (int64_t value : x_max_shape) { - x_max_num = LongMulWithOverflowCheck(value, x_max_num); - } - - auto it_first = find(shape.begin(), shape.end(), -1); - if (it_first != shape.end()) { - if (!padding_axis_value.empty()) { - // the condition that the number of -1 in shape is > 1, but just paddingshape - for (size_t index = 0; index < padding_axis_value.size(); ++index) { - shape[IntToSize(padding_axis_value[index])] = x_shape[index]; - min_shape[IntToSize(padding_axis_value[index])] = x_min_shape[index]; - max_shape[IntToSize(padding_axis_value[index])] = x_max_shape[index]; - } - } else { - auto it_second = find(it_first + 1, shape.end(), -1); - if (it_second != shape.end()) { - MS_LOG(EXCEPTION) << "At most one component of input shape can be -1, but got " << shape; - } - auto index = LongToSize(std::distance(shape.begin(), it_first)); - int64_t infer_value = x_num; - int64_t infer_min_value = x_min_num; - int64_t infer_max_value = x_max_num; - for (size_t i = 0; i < shape.size(); ++i) { - int64_t value = shape[i]; - if (value != -1 && value != 0) { - infer_value = infer_value / value; - infer_min_value = infer_min_value / value; - infer_max_value = infer_max_value / value; - } - } - shape[index] = infer_value; - min_shape[index] = infer_min_value; - max_shape[index] = infer_max_value; - } - } - - int64_t shape_num = 1; - for (int64_t value : shape) { - shape_num = LongMulWithOverflowCheck(value, shape_num); - } - if (shape_num != x_num) { - MS_LOG(EXCEPTION) << "The accumulate of x_shape must be equal to out_shape, but got x_shape: " << x_shape - << ", and out_shape: " << shape; - } - - if (IsDynamic(min_shape) || IsDynamic(max_shape)) { - min_shape.clear(); - max_shape.clear(); - } - - AbstractTensorPtr ret = - std::make_shared(x->element(), std::make_shared(shape, min_shape, max_shape)); - return ret; -} - AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: one tensor. 
diff --git a/mindspore/core/abstract/ops/primitive_infer_map.cc b/mindspore/core/abstract/ops/primitive_infer_map.cc
index b09a3b04ca3..4c8b092ad04 100644
--- a/mindspore/core/abstract/ops/primitive_infer_map.cc
+++ b/mindspore/core/abstract/ops/primitive_infer_map.cc
@@ -401,7 +401,6 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
     {prim::kPrimTransposeNOD, R{InferImplTranspose, nullptr, true}},
     {prim::kPrimStridedSlice, R{ops::StridedSliceInfer, nullptr, true}},
     {prim::kPrimSlice, R{ops::SliceInfer, nullptr, true}},
-    {prim::kPrimReshape, R{InferImplReshape, nullptr, true}},
     {prim::kPrimSliceGrad, R{ops::SliceGradInfer, nullptr, true}},
     {prim::kPrimConcat, R{InferImplConcat, nullptr, true}},
     {prim::kPrimConcatOffset, R{InferImplConcatOffset, nullptr, true}},
diff --git a/mindspore/core/ops/reshape.cc b/mindspore/core/ops/reshape.cc
index 4e7b552bb70..102eb519965 100644
--- a/mindspore/core/ops/reshape.cc
+++ b/mindspore/core/ops/reshape.cc
@@ -14,17 +14,151 @@
  * limitations under the License.
  */
 #include "ops/reshape.h"
-
 #include <memory>
 #include <vector>
-
+#include <algorithm>
+#include <set>
+#include <string>
+#include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
 #include "abstract/ops/primitive_infer_map.h"
 #include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
+ShapeVector update_shape(const std::vector<int64_t> &padding_axis_value, const ShapeVector &x_shape,
+                         ShapeVector shape) {
+  // padding_axis_value is used when the number of -1 in shape is greater than 1; those axes are padded from x_shape.
+  if (std::any_of(x_shape.begin(), x_shape.end(), [](const int64_t shape_i) { return shape_i < 0; })) {
+    return shape;
+  }
+  int64_t x_num = 1;
+  for (int64_t value : x_shape) {
+    x_num = LongMulWithOverflowCheck(value, x_num);
+  }
+
+  auto it_first = find(shape.begin(), shape.end(), -1);
+  if (it_first != shape.end()) {
+    if (!padding_axis_value.empty()) {
+      // The number of -1 in shape is greater than 1: pad those axes from x_shape.
+      for (size_t index = 0; index < padding_axis_value.size(); ++index) {
+        shape[IntToSize(padding_axis_value[index])] = x_shape[index];
+      }
+    } else {
+      auto it_second = find(it_first + 1, shape.end(), -1);
+      if (it_second != shape.end()) {
+        MS_EXCEPTION(ValueError) << "At most one component of input shape can be -1, but got " << shape;
+      }
+      auto index = LongToSize(std::distance(shape.begin(), it_first));
+      int64_t infer_value = x_num;
+      for (size_t i = 0; i < shape.size(); ++i) {
+        int64_t value = shape[i];
+        if (value != -1 && value != 0) {
+          infer_value = infer_value / value;
+        }
+      }
+      shape[index] = infer_value;
+    }
+  }
+
+  int64_t shape_num = 1;
+  for (int64_t value : shape) {
+    shape_num = LongMulWithOverflowCheck(value, shape_num);
+  }
+  if (shape_num != x_num) {
+    MS_EXCEPTION(ValueError) << "The product of x_shape must be equal to the product of out_shape, but got x_shape: "
+                             << x_shape << ", and out_shape: " << shape;
+  }
+  return shape;
+}
+
 MIND_API_OPERATOR_IMPL(Reshape, BaseOperator);
-REGISTER_PRIMITIVE_C(kNameReshape, Reshape);
+class ReshapeInfer : public abstract::OpInferBase {
+ public:
+  BaseShapePtr InferShape(const PrimitivePtr &primitive,
+                          const std::vector<AbstractBasePtr> &input_args) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto prim_name = primitive->name();
+    constexpr size_t max_size = 2;
+    (void)CheckAndConvertUtils::CheckValue<size_t>("input size", input_args.size(), kLessEqual, max_size, prim_name);
+
+    auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+
+    std::vector<int64_t> output_shape;
+
+    if (input_args.size() == max_size) {
+      auto input_y = input_args[1];
+      MS_EXCEPTION_IF_NULL(input_y);
+      if (input_y->isa<abstract::AbstractTensor>()) {
+        auto y_value = input_y->BuildValue();
+        MS_EXCEPTION_IF_NULL(y_value);
+        if (y_value != kAnyValue) {
+          output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", y_value, prim_name);
+        } else {
+          abstract::ShapePtr y_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
+          auto shape_value = y_shape->shape();
+          if (shape_value.size() != 1) {
+            MS_EXCEPTION(TypeError) << "For '" << prim_name
+                                    << "', the rank of the 'shape' tensor must be 1, but got: " << shape_value.size()
+                                    << ".";
+          }
+          if (y_shape->IsDynamic()) {
+            output_shape.push_back(abstract::Shape::kShapeRankAny);
+          } else {
+            output_shape = GetShapeValue(primitive, input_y);
+          }
+          return std::make_shared<abstract::Shape>(output_shape);
+        }
+      } else if (input_y->isa<abstract::AbstractTuple>()) {
+        auto y_value = input_y->BuildValue();
+        MS_EXCEPTION_IF_NULL(y_value);
+        output_shape = CheckAndConvertUtils::CheckTupleInt("input[shape]", y_value, primitive->name());
+      } else {
+        MS_EXCEPTION(TypeError) << "input_y must be AbstractTensor or AbstractTuple, but got: " << input_y;
+      }
+    } else {
+      // When the shape is passed as an attribute, it must be constant.
+      ValuePtr sh = primitive->GetAttr("shape");
+      MS_EXCEPTION_IF_NULL(sh);
+      if (sh->isa<ValueTuple>()) {
+        auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
+        MS_EXCEPTION_IF_NULL(reshape_value_tuple);
+        auto reshape_tuple = reshape_value_tuple->value();
+        (void)std::transform(reshape_tuple.begin(), reshape_tuple.end(), std::back_inserter(output_shape),
+                             [=](const ValuePtr &e) -> int64_t {
+                               if (!e->isa<Int64Imm>()) {
+                                 MS_EXCEPTION(TypeError)
+                                   << "For primitive[" << prim_name << "], the 'shape'"
+                                   << " must be a tuple with all Int elements, but got " << sh->ToString();
+                               }
+                               return GetValue<int64_t>(e);
+                             });
+      } else if (sh->isa<tensor::Tensor>()) {
+        output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", sh, "Reshape");
+      } else {
+        MS_EXCEPTION(ValueError)
+          << "In stage of execution, the primitive[Reshape]'s input['shape'] must be a tuple or "
+          << "constant Tensor.";
+      }
+    }
+
+    std::vector<int64_t> padding_axis_value;
+    ValuePtr padding_axis = primitive->GetAttr("reshape_padding_axis");
+    if (padding_axis != nullptr) {
+      padding_axis_value = GetValue<std::vector<int64_t>>(padding_axis);
+    }
+    ShapeVector updated_shape = update_shape(padding_axis_value, x_shape, output_shape);
+    return std::make_shared<abstract::Shape>(updated_shape);
+  }
+
+  TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
+    for (const auto &item : input_args) {
+      MS_EXCEPTION_IF_NULL(item);
+    }
+    auto x_dtype = input_args[0]->BuildType();
+    std::set<TypePtr> template_types = {kTensorType};
+    (void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
+    return x_dtype;
+  }
+};
+REGISTER_PRIMITIVE_OP_INFER_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer, false);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/reshape.h b/mindspore/core/ops/reshape.h
index f1b0d7cc43e..74572ebf20f 100644
--- a/mindspore/core/ops/reshape.h
+++ b/mindspore/core/ops/reshape.h
@@ -37,9 +37,6 @@ class MIND_API Reshape : public BaseOperator {
   /// \brief Init.
   void Init() const {}
 };
-
-abstract::AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                       const std::vector<AbstractBasePtr> &input_args);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py
index 8a6b275c938..b72e85bbae7 100755
--- a/mindspore/python/mindspore/ops/operations/array_ops.py
+++ b/mindspore/python/mindspore/ops/operations/array_ops.py
@@ -752,7 +752,7 @@ class Col2Im(Primitive):
         self.add_prim_attr('stride', self.stride)
 
 
-class Reshape(PrimitiveWithInfer):
+class Reshape(PrimitiveWithCheck):
     """
     Rearranges the input Tensor based on the given shape.
@@ -776,85 +776,46 @@ class Reshape(PrimitiveWithInfer):
         """Initialize Reshape"""
         self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
 
-    @staticmethod
-    def _get_shape_and_range(x, shape):
-        """ get min and max shape when output shape is dynamic"""
-        x_shp = x['shape']
-        if is_shape_unknown(shape['shape']):
-            out_shape = [-2]
-            return out_shape
-
-        shape_rank = shape['shape'][0]
-        if not x_shp:
-            # x is a scalar, output shape fixed
-            out_shape = [1] * shape_rank
-            return out_shape
-
-        out_shape = [-1] * shape_rank
-
-        if "shape_value" in shape:
-            out_shape = shape["shape_value"]
-        return out_shape
-
-    def _update_shape_and_value(self, out, x, shape_v, dim_prod, neg_index):
-        """ update shape, value and min / max value of output when input shape is known"""
-        x_shp = x['shape']
-        if dim_prod <= 0:
-            raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
-                             f"the value of 'input_shape' is {shape_v}. "
-                             f"The product of 'input_shape' should > 0, but got {dim_prod}.")
-        arr_prod = np.prod(x_shp)
-        if neg_index != -1:
-            shape_v[neg_index] = int(arr_prod // dim_prod)
-            dim_prod *= shape_v[neg_index]
-        if dim_prod != arr_prod:
-            raise ValueError(f"For '{self.name}', the product of the 'input_x' shape "
-                             f"should be equal to product of 'input_shape', but got product of the"
-                             f" shape of 'input_x': {arr_prod}, product of 'input_shape': {dim_prod}.")
-        out['shape'] = tuple(shape_v)
-
-        if x['value'] is not None:
-            out['value'] = Tensor(x['value'].asnumpy().reshape(shape_v))
-
-        return out
-
-    def __infer__(self, x, shape):
-        shape_v = shape['value']
-        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
+    def infer_value(self, x, shape):
+        """infer value"""
         # for shape is not constant
-        if shape_v is None:
-            out_shape = self._get_shape_and_range(x, shape)
-            return {
-                'shape': out_shape,
-                'dtype': x['dtype'],
-                'value': None,
-            }
-
-        if isinstance(shape_v, Tensor_):
-            validator.check_tensor_dtype_valid("shape", shape['dtype'], [mstype.int32, mstype.int64], self.name)
-            shape_v = shape_v.asnumpy().tolist()
+        if shape is None or x is None:
+            return None
+        if isinstance(shape, Tensor_):
+            validator.check_tensor_dtype_valid("shape", mstype.tensor_type(shape.dtype),
+                                               [mstype.int32, mstype.int64], self.name)
+            shape = shape.asnumpy().tolist()
         else:
-            validator.check_value_type("shape", shape_v, [tuple], self.name)
-            shape_v = list(shape_v)
+            validator.check_value_type("shape", shape, [tuple], self.name)
+            shape = list(shape)
 
         neg_index = -1
         dim_prod = 1
-        for i, shp_i in enumerate(shape_v):
+        for i, shp_i in enumerate(shape):
             validator.check_value_type("shape[%d]" % i, shp_i, [int], self.name)
             if shp_i == -1:
                 if neg_index != -1:
                     raise ValueError(f"For '{self.name}', there can be at most one '-1' in 'input_shape', "
-                                     f"but got {shape_v}.")
+                                     f"but got {shape}.")
got {shape}.") neg_index = i else: dim_prod *= shp_i - - out = {'shape': shape_v, - 'dtype': x['dtype'], - 'value': None} - - if not is_shape_unknown(x['shape']): - out = self._update_shape_and_value(out, x, shape_v, dim_prod, neg_index) + out = None + if not is_shape_unknown(x.shape): + x_shp = x.shape + if dim_prod <= 0: + raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, " + f"the value of 'input_shape' is {shape}. " + f"The product of 'input_shape' should > 0, but got {dim_prod}.") + arr_prod = np.prod(x_shp) + if neg_index != -1: + shape[neg_index] = int(arr_prod // dim_prod) + dim_prod *= shape[neg_index] + if dim_prod != arr_prod: + raise ValueError(f"For '{self.name}', the product of the 'input_x' shape " + f"should be equal to product of 'input_shape', but got product of the" + f" shape of 'input_x': {arr_prod}, product of 'input_shape': {dim_prod}.") + out = Tensor(x.asnumpy().reshape(shape)) return out