!17179 some supplement of Tile operator.
From: @shen_jingxing Reviewed-by: Signed-off-by:
This commit is contained in:
commit
c470060584
|
@ -252,8 +252,6 @@ AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const AbstractBasePtrList &args_spec_list);
|
|
||||||
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
|
|
@ -82,39 +82,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
|
||||||
return std::make_shared<AbstractTuple>(elems);
|
return std::make_shared<AbstractTuple>(elems);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
|
||||||
// Inputs: a tensor and a tuple.
|
|
||||||
const std::string op_name = primitive->name();
|
|
||||||
CheckArgsSize(op_name, args_spec_list, 2);
|
|
||||||
auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
|
||||||
auto multiples = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
|
||||||
|
|
||||||
ShapePtr input_shape = arg->shape();
|
|
||||||
(void)CheckTensorDType(arg, {kInt16, kFloat16, kInt32, kFloat32}, "Input 0 of Tile should be %s");
|
|
||||||
|
|
||||||
auto mul_shp_value = multiples->BuildValue();
|
|
||||||
if (mul_shp_value->isa<AnyValue>()) {
|
|
||||||
MS_LOG(EXCEPTION) << "shape's data field can't be anything: " << args_spec_list[1]->ToString();
|
|
||||||
}
|
|
||||||
|
|
||||||
ShapeVector mul_shp;
|
|
||||||
auto value_tuple_mul = mul_shp_value->cast<ValueTuplePtr>();
|
|
||||||
auto mul_shp_data = value_tuple_mul->value();
|
|
||||||
(void)std::transform(std::begin(mul_shp_data), std::end(mul_shp_data), std::back_inserter(mul_shp),
|
|
||||||
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
|
|
||||||
if (input_shape->shape().size() != mul_shp_data.size()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Tile requires input and multiples size equal, while the input size is "
|
|
||||||
<< input_shape->shape().size() << ", value size is: " << mul_shp_data.size() << ".";
|
|
||||||
}
|
|
||||||
|
|
||||||
ShapeVector result_shp;
|
|
||||||
for (size_t i = 0; i < mul_shp_data.size(); ++i) {
|
|
||||||
result_shp.push_back(input_shape->shape()[i] * mul_shp[i]);
|
|
||||||
}
|
|
||||||
return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp));
|
|
||||||
}
|
|
||||||
|
|
||||||
AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
// Inputs: a tuple of tensor.
|
// Inputs: a tuple of tensor.
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include "ops/add.h"
|
#include "ops/add.h"
|
||||||
#include "abstract/abstract_function.h"
|
#include "abstract/abstract_function.h"
|
||||||
#include "abstract/infer_functions.h"
|
#include "abstract/infer_functions.h"
|
||||||
|
#include "ops/tile.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
|
@ -174,7 +175,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
||||||
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
|
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
|
||||||
{prim::kPrimSub, {InferImplSub, nullptr, false}},
|
{prim::kPrimSub, {InferImplSub, nullptr, false}},
|
||||||
{prim::kPrimEqual, {InferImplEqual, nullptr, true}},
|
{prim::kPrimEqual, {InferImplEqual, nullptr, true}},
|
||||||
{prim::kPrimTile, {InferImplTile, nullptr, false}},
|
{prim::kPrimTile, {ops::TileInfer, nullptr, true}},
|
||||||
{prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}},
|
{prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}},
|
||||||
{prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
|
{prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
|
||||||
{prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},
|
{prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},
|
||||||
|
|
|
@ -41,6 +41,9 @@ std::vector<int64_t> GetInferShape(const std::vector<int64_t> &input_shape, cons
|
||||||
"length of dimension in input_x";
|
"length of dimension in input_x";
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < multiples_w.size(); i++) {
|
for (size_t i = 0; i < multiples_w.size(); i++) {
|
||||||
|
if (infer_shape[i] == abstract::Shape::SHP_ANY) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
infer_shape[i] *= multiples_w[i];
|
infer_shape[i] *= multiples_w[i];
|
||||||
}
|
}
|
||||||
return infer_shape;
|
return infer_shape;
|
||||||
|
|
Loading…
Reference in New Issue