From b3a39a0b20b9b8a8585a7c906df2c28958f1faa4 Mon Sep 17 00:00:00 2001 From: shen_jingxing Date: Sat, 29 May 2021 17:52:06 +0800 Subject: [PATCH] Tile --- mindspore/core/abstract/infer_functions.h | 2 -- mindspore/core/abstract/prim_arrays.cc | 33 ------------------- .../core/abstract/primitive_infer_map.cc | 3 +- mindspore/core/ops/tile.cc | 3 ++ 4 files changed, 5 insertions(+), 36 deletions(-) diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 07d1659c80c..abef3df91cf 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -252,8 +252,6 @@ AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 17d7b68e0ee..fa9ff35c6c2 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -81,39 +81,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti return std::make_shared(elems); } -AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - auto arg = CheckArg(op_name, args_spec_list, 0); - auto multiples = CheckArg(op_name, args_spec_list, 1); - - ShapePtr input_shape = arg->shape(); - (void)CheckTensorDType(arg, {kInt16, kFloat16, kInt32, kFloat32}, "Input 0 of Tile should be %s"); - - auto mul_shp_value = multiples->BuildValue(); - if (mul_shp_value->isa()) { - MS_LOG(EXCEPTION) << "shape's data field can't be anything: " << args_spec_list[1]->ToString(); - } - - ShapeVector mul_shp; - auto value_tuple_mul = mul_shp_value->cast(); - auto mul_shp_data = value_tuple_mul->value(); - (void)std::transform(std::begin(mul_shp_data), std::end(mul_shp_data), std::back_inserter(mul_shp), - [](const ValuePtr &e) -> int64_t { return GetValue(e); }); - if (input_shape->shape().size() != mul_shp_data.size()) { - MS_LOG(EXCEPTION) << "Tile requires input and multiples size equal, while the input size is " - << input_shape->shape().size() << ", value size is: " << mul_shp_data.size() << "."; - } - - ShapeVector result_shp; - for (size_t i = 0; i < mul_shp_data.size(); ++i) { - result_shp.push_back(input_shape->shape()[i] * mul_shp[i]); - } - return std::make_shared(arg->element(), std::make_shared(result_shp)); -} - AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: a tuple of tensor. diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 879e8b11b3e..5ba23c03732 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -25,6 +25,7 @@ #include "ops/add.h" #include "abstract/abstract_function.h" #include "abstract/infer_functions.h" +#include "ops/tile.h" namespace mindspore { namespace abstract { @@ -174,7 +175,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { {prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}}, {prim::kPrimSub, {InferImplSub, nullptr, false}}, {prim::kPrimEqual, {InferImplEqual, nullptr, true}}, - {prim::kPrimTile, {InferImplTile, nullptr, false}}, + {prim::kPrimTile, {ops::TileInfer, nullptr, true}}, {prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}}, {prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}}, {prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}}, diff --git a/mindspore/core/ops/tile.cc b/mindspore/core/ops/tile.cc index 596eca8c63a..261375df088 100644 --- a/mindspore/core/ops/tile.cc +++ b/mindspore/core/ops/tile.cc @@ -41,6 +41,9 @@ std::vector GetInferShape(const std::vector &input_shape, cons "length of dimension in input_x"; } for (size_t i = 0; i < multiples_w.size(); i++) { + if (infer_shape[i] == abstract::Shape::SHP_ANY) { + continue; + } infer_shape[i] *= multiples_w[i]; } return infer_shape;