!17179 Some supplements to the Tile operator.

From: @shen_jingxing
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2021-05-31 17:27:06 +08:00 committed by Gitee
commit c470060584
4 changed files with 5 additions and 36 deletions


@@ -252,8 +252,6 @@ AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


@@ -82,39 +82,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
return std::make_shared<AbstractTuple>(elems);
}
-AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-const AbstractBasePtrList &args_spec_list) {
-// Inputs: a tensor and a tuple.
-const std::string op_name = primitive->name();
-CheckArgsSize(op_name, args_spec_list, 2);
-auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-auto multiples = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
-ShapePtr input_shape = arg->shape();
-(void)CheckTensorDType(arg, {kInt16, kFloat16, kInt32, kFloat32}, "Input 0 of Tile should be %s");
-auto mul_shp_value = multiples->BuildValue();
-if (mul_shp_value->isa<AnyValue>()) {
-MS_LOG(EXCEPTION) << "shape's data field can't be anything: " << args_spec_list[1]->ToString();
-}
-ShapeVector mul_shp;
-auto value_tuple_mul = mul_shp_value->cast<ValueTuplePtr>();
-auto mul_shp_data = value_tuple_mul->value();
-(void)std::transform(std::begin(mul_shp_data), std::end(mul_shp_data), std::back_inserter(mul_shp),
-[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
-if (input_shape->shape().size() != mul_shp_data.size()) {
-MS_LOG(EXCEPTION) << "Tile requires input and multiples size equal, while the input size is "
-<< input_shape->shape().size() << ", value size is: " << mul_shp_data.size() << ".";
-}
-ShapeVector result_shp;
-for (size_t i = 0; i < mul_shp_data.size(); ++i) {
-result_shp.push_back(input_shape->shape()[i] * mul_shp[i]);
-}
-return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp));
-}
AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple of tensor.

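Note: the removed InferImplTile above derived the Tile output shape by multiplying each input dimension by the corresponding element of the multiples tuple, after checking that both have the same rank. Below is a minimal standalone C++ sketch of that shape rule only (hypothetical helper name, plain std::vector types rather than the MindSpore abstract classes):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Standalone illustration of the shape rule the removed InferImplTile enforced:
// ranks must match, and output_shape[i] = input_shape[i] * multiples[i].
std::vector<int64_t> TileOutputShape(const std::vector<int64_t> &input_shape,
                                     const std::vector<int64_t> &multiples) {
  if (input_shape.size() != multiples.size()) {
    throw std::invalid_argument("Tile requires input and multiples size equal");
  }
  std::vector<int64_t> result;
  result.reserve(input_shape.size());
  for (size_t i = 0; i < input_shape.size(); ++i) {
    result.push_back(input_shape[i] * multiples[i]);
  }
  return result;
}

int main() {
  // A (2, 3) tensor tiled with multiples (3, 2) yields a (6, 6) tensor.
  for (int64_t d : TileOutputShape({2, 3}, {3, 2})) {
    std::cout << d << " ";  // prints: 6 6
  }
  std::cout << std::endl;
  return 0;
}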

@@ -25,6 +25,7 @@
#include "ops/add.h"
#include "abstract/abstract_function.h"
#include "abstract/infer_functions.h"
+#include "ops/tile.h"
namespace mindspore {
namespace abstract {
@@ -174,7 +175,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
{prim::kPrimSub, {InferImplSub, nullptr, false}},
{prim::kPrimEqual, {InferImplEqual, nullptr, true}},
-{prim::kPrimTile, {InferImplTile, nullptr, false}},
+{prim::kPrimTile, {ops::TileInfer, nullptr, true}},
{prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}},
{prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
{prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},

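The one-line swap above re-points the backend eval-impl registration for kPrimTile from the removed frontend InferImplTile to ops::TileInfer in ops/tile.cc (hence the new ops/tile.h include). As a rough, simplified illustration of the dispatch idea only (hypothetical names and types, not the actual PrimitiveEvalImplMap layout):

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a shape-infer callback: maps the input shapes of
// a primitive to its output shape.
using InferShapeFn =
    std::function<std::vector<int64_t>(const std::vector<std::vector<int64_t>> &)>;

// Simplified stand-in for ops::TileInfer: inputs[0] is the tensor shape,
// inputs[1] the multiples.
std::vector<int64_t> TileInferSketch(const std::vector<std::vector<int64_t>> &inputs) {
  std::vector<int64_t> out;
  for (size_t i = 0; i < inputs[0].size(); ++i) {
    out.push_back(inputs[0][i] * inputs[1][i]);
  }
  return out;
}

int main() {
  // Analogous to the map entry {prim::kPrimTile, {ops::TileInfer, nullptr, true}}:
  // the primitive is the key, the registered callback does the shape inference.
  std::unordered_map<std::string, InferShapeFn> infer_map = {{"Tile", TileInferSketch}};
  for (int64_t d : infer_map.at("Tile")({{2, 3}, {4, 1}})) {
    std::cout << d << " ";  // prints: 8 3
  }
  std::cout << std::endl;
  return 0;
}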

@@ -41,6 +41,9 @@ std::vector<int64_t> GetInferShape(const std::vector<int64_t> &input_shape, cons
"length of dimension in input_x";
}
for (size_t i = 0; i < multiples_w.size(); i++) {
+if (infer_shape[i] == abstract::Shape::SHP_ANY) {
+continue;
+}
infer_shape[i] *= multiples_w[i];
}
return infer_shape;
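The three added lines above are the functional change in ops/tile.cc: when an input dimension is dynamic (abstract::Shape::SHP_ANY), it is left as the placeholder instead of being multiplied by the corresponding multiple. A minimal sketch of the same guard, using -1 as a stand-in for SHP_ANY (standalone code, not the ops/tile.cc source):

#include <cstdint>
#include <iostream>
#include <vector>

// -1 stands in for abstract::Shape::SHP_ANY (a dynamic, unknown dimension).
constexpr int64_t kDynamicDim = -1;

// Mirrors the fixed loop in GetInferShape: a dynamic dimension stays dynamic;
// only statically known dimensions are scaled by their multiple.
std::vector<int64_t> TileShapeWithDynamicDims(std::vector<int64_t> infer_shape,
                                              const std::vector<int64_t> &multiples) {
  for (size_t i = 0; i < multiples.size(); ++i) {
    if (infer_shape[i] == kDynamicDim) {
      continue;  // without this check, -1 * multiple would corrupt the placeholder
    }
    infer_shape[i] *= multiples[i];
  }
  return infer_shape;
}

int main() {
  // A (-1, 3) dynamic shape tiled by (2, 4) keeps dim 0 dynamic: (-1, 12).
  for (int64_t d : TileShapeWithDynamicDims({kDynamicDim, 3}, {2, 4})) {
    std::cout << d << " ";  // prints: -1 12
  }
  std::cout << std::endl;
  return 0;
}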