!17179 some supplement of Tile operator.
From: @shen_jingxing Reviewed-by: Signed-off-by:
This commit is contained in:
commit
c470060584
|
@ -252,8 +252,6 @@ AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -82,39 +82,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
|
|||
return std::make_shared<AbstractTuple>(elems);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tensor and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto multiples = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
|
||||
ShapePtr input_shape = arg->shape();
|
||||
(void)CheckTensorDType(arg, {kInt16, kFloat16, kInt32, kFloat32}, "Input 0 of Tile should be %s");
|
||||
|
||||
auto mul_shp_value = multiples->BuildValue();
|
||||
if (mul_shp_value->isa<AnyValue>()) {
|
||||
MS_LOG(EXCEPTION) << "shape's data field can't be anything: " << args_spec_list[1]->ToString();
|
||||
}
|
||||
|
||||
ShapeVector mul_shp;
|
||||
auto value_tuple_mul = mul_shp_value->cast<ValueTuplePtr>();
|
||||
auto mul_shp_data = value_tuple_mul->value();
|
||||
(void)std::transform(std::begin(mul_shp_data), std::end(mul_shp_data), std::back_inserter(mul_shp),
|
||||
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
|
||||
if (input_shape->shape().size() != mul_shp_data.size()) {
|
||||
MS_LOG(EXCEPTION) << "Tile requires input and multiples size equal, while the input size is "
|
||||
<< input_shape->shape().size() << ", value size is: " << mul_shp_data.size() << ".";
|
||||
}
|
||||
|
||||
ShapeVector result_shp;
|
||||
for (size_t i = 0; i < mul_shp_data.size(); ++i) {
|
||||
result_shp.push_back(input_shape->shape()[i] * mul_shp[i]);
|
||||
}
|
||||
return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp));
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple of tensor.
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "ops/add.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "ops/tile.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -174,7 +175,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
|||
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
|
||||
{prim::kPrimSub, {InferImplSub, nullptr, false}},
|
||||
{prim::kPrimEqual, {InferImplEqual, nullptr, true}},
|
||||
{prim::kPrimTile, {InferImplTile, nullptr, false}},
|
||||
{prim::kPrimTile, {ops::TileInfer, nullptr, true}},
|
||||
{prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}},
|
||||
{prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
|
||||
{prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},
|
||||
|
|
|
@ -41,6 +41,9 @@ std::vector<int64_t> GetInferShape(const std::vector<int64_t> &input_shape, cons
|
|||
"length of dimension in input_x";
|
||||
}
|
||||
for (size_t i = 0; i < multiples_w.size(); i++) {
|
||||
if (infer_shape[i] == abstract::Shape::SHP_ANY) {
|
||||
continue;
|
||||
}
|
||||
infer_shape[i] *= multiples_w[i];
|
||||
}
|
||||
return infer_shape;
|
||||
|
|
Loading…
Reference in New Issue