add a process of make type and shape to be abstract

This commit is contained in:
lianliguang 2021-05-31 14:32:41 +08:00
parent 109770b743
commit 68f3b5784f
41 changed files with 205 additions and 148 deletions

View File

@ -505,12 +505,33 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
return py_args;
}
void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
const string kOutputNum = "output_num";
if (prim->IsCustomPrim()) {
// Raise error if output_num is not match the infer result.
auto output_num_value = prim->GetAttr(kOutputNum);
if (output_num_value == nullptr) {
MS_LOG(DEBUG) << "The output num may no need to check";
return;
}
int64_t output_num = GetValue<int64_t>(output_num_value);
if (res_spec->isa<AbstractTensor>() && output_num != 1) {
MS_LOG(EXCEPTION) << "Custom primitive " << prim->ToString() << " output_num " << output_num
<< " not matches the infer result.";
} else if (res_spec->isa<AbstractTuple>() &&
(res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
MS_LOG(EXCEPTION) << "Custom primitive " << prim->ToString() << " output_num " << output_num
<< " not matches the infer result.";
}
}
}
AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
// Convert to AbstractValue based on type and shape
auto out_dtype = output[ATTR_DTYPE];
if (output[ATTR_VALUE].is_none()) {
auto out_shape = output[ATTR_SHAPE];
return PyListDtype2AbstractTensor(out_shape, out_dtype, output);
return MakePyInferRes2Abstract(out_shape, out_dtype, output);
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr converted_ret = nullptr;
@ -527,18 +548,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
res_tensor->set_value(converted_ret);
SetValueRange(res_tensor, output);
}
if (prim_py->IsCustomPrim()) {
// Raise error if output_num is not match the infer result.
int64_t output_num = GetValue<int64_t>(prim_py->GetAttr("output_num"));
if (res_spec->isa<AbstractTensor>() && output_num != 1) {
MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
<< " not matches the infer result.";
} else if (res_spec->isa<AbstractTuple>() &&
(res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
<< " not matches the infer result.";
}
}
CheckCustomPrimOutputInferResult(prim_py, res_spec);
return res_spec;
}
} // end anonymous namespace

View File

@ -25,6 +25,7 @@
#include <cfloat>
#include "abstract/abstract_value.h"
#include "abstract/utils.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/parse_base.h"
#include "ir/value.h"
@ -357,9 +358,8 @@ void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
}
}
AbstractBasePtr PyList2DynamicShapeTensor(const py::object &shape_obj, const py::object &type_obj,
AbstractBasePtr MakePyInferRes2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
const py::object &output) {
AbstractBasePtr tensor = nullptr;
auto ret_vec = shape_obj.cast<ShapeVector>();
auto ret_dtype = type_obj.cast<TypePtr>();
ShapeVector min_shape_vec;
@ -379,15 +379,7 @@ AbstractBasePtr PyList2DynamicShapeTensor(const py::object &shape_obj, const py:
}
auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
if (ret_dtype->isa<TensorType>()) {
auto tensor_type = type_obj.cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
} else {
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
}
AbstractBasePtr tensor = MakeAbstractTensor(ret_shape, ret_dtype);
SetValueRange(tensor, output);
return tensor;
@ -404,17 +396,15 @@ static bool IsMonadType(const py::object &type_obj) {
static AbstractBasePtr ToMonadAbstract(const py::object &type_obj) {
if (py::isinstance<Type>(type_obj)) {
auto type = type_obj.cast<Type *>();
if (type->isa<UMonadType>()) {
return kUMonad->ToAbstract();
}
if (type->isa<IOMonadType>()) {
return kIOMonad->ToAbstract();
}
}
if (!type->isa<MonadType>()) {
MS_LOG(EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
}
return abstract::MakeMonadAbstract(type->cast<MonadTypePtr>());
}
MS_LOG(EXCEPTION) << "Not a type object: " << py::str(type_obj);
}
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
AbstractBasePtr MakePyInferRes2Abstract(const py::object &shape_obj, const py::object &type_obj,
const py::object &output) {
if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
auto ret_vec = shape_obj.cast<ShapeVector>();
@ -425,13 +415,13 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
return abs_scalar;
}
return PyList2DynamicShapeTensor(shape_obj, type_obj, output);
return MakePyInferRes2AbstractTensor(shape_obj, type_obj, output);
} else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
auto shape_tuple = shape_obj.cast<py::tuple>();
auto typeid_tuple = type_obj.cast<py::tuple>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_tuple.size(); ++it) {
auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
auto tensor_it = MakePyInferRes2Abstract(shape_tuple[it], typeid_tuple[it]);
ptr_list.push_back(tensor_it);
}
auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
@ -441,7 +431,7 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
auto typeid_list = type_obj.cast<py::list>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_list.size(); ++it) {
auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]);
auto tensor_it = MakePyInferRes2Abstract(shape_list[it], typeid_list[it]);
ptr_list.push_back(tensor_it);
}
auto list = std::make_shared<abstract::AbstractList>(ptr_list);

View File

@ -36,7 +36,7 @@ py::object ValuePtrToPyData(const ValuePtr &value);
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
const std::shared_ptr<py::object> &ret_val);
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
AbstractBasePtr MakePyInferRes2Abstract(const py::object &shape_obj, const py::object &type_obj,
const py::object &output = py::none());
void SetValueRange(const AbstractBasePtr &tensor, const py::object &output);
} // namespace mindspore

View File

@ -21,6 +21,7 @@
#include <string>
#include <sstream>
#include <memory>
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "abstract/param_validator.h"
@ -348,5 +349,96 @@ int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list,
}
return num_segments_value;
}
AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(type);
AbstractBasePtr tensor = nullptr;
auto ret_vec = shape->shape();
ShapeVector min_shape_vec;
ShapeVector max_shape_vec;
if (shape->IsDynamic()) {
if (!shape->min_shape().empty()) {
min_shape_vec = shape->min_shape();
}
if (!shape->max_shape().empty()) {
max_shape_vec = shape->max_shape();
}
}
auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
if (type->isa<TensorType>()) {
auto tensor_type = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
} else {
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
}
return tensor;
}
AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type) {
if (type->isa<UMonadType>()) {
return kUMonad->ToAbstract();
} else if (type->isa<IOMonadType>()) {
return kIOMonad->ToAbstract();
}
MS_EXCEPTION(UnknownError) << "Unsupported to convert type " << type->ToString() << " to monad abstract";
}
AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type) {
MS_EXCEPTION_IF_NULL(base_shape);
MS_EXCEPTION_IF_NULL(type);
if ((base_shape->isa<Shape>())) {
auto shape = base_shape->cast<ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
auto shape_vec = shape->shape();
// if the size of shape list is empty, return an scalar abstract
if (shape_vec.empty() && (!type->isa<TensorType>())) {
abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
return abs_scalar;
}
return MakeAbstractTensor(shape, type);
} else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
auto shape_tuple = base_shape->cast<TupleShapePtr>();
auto type_tuple = type->cast<TuplePtr>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_tuple->size(); ++it) {
auto tensor_it = MakeAbstract((*shape_tuple)[it], (*type_tuple)[it]);
ptr_list.push_back(tensor_it);
}
auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
return tuple;
} else if (base_shape->isa<ListShape>() && type->isa<List>()) {
auto shape_list = base_shape->cast<ListShapePtr>();
auto type_list = type->cast<ListPtr>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_list->size(); ++it) {
auto tensor_it = MakeAbstract((*shape_list)[it], (*type_list)[it]);
ptr_list.push_back(tensor_it);
}
auto list = std::make_shared<abstract::AbstractList>(ptr_list);
return list;
} else if (base_shape->isa<NoShape>() && type->isa<TypeNone>()) {
// AbstractNone indicates there is no output for this CNode node.
auto abstract_none = std::make_shared<abstract::AbstractNone>();
return abstract_none;
} else if (type->isa<Monad>()) {
// Return monad abstract if it is monad type.
return MakeMonadAbstract(type->cast<MonadTypePtr>());
} else {
// When sparse enabled, the undetermined might be raised and eliminated in opt passes
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (enable_sparse) {
return std::make_shared<abstract::AbstractUndetermined>();
}
MS_LOG(EXCEPTION) << "evaluator return invalid shape " << base_shape->ToString() << "or type. " << type->ToString();
}
}
} // namespace abstract
} // namespace mindspore

View File

@ -62,7 +62,9 @@ void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVec
// Get 3rd argument for UnsortedSegmentOps' inferImpl function
int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name);
AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_UTILS_H_

View File

@ -50,8 +50,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_C(kNameAdd, Add);
} // namespace ops

View File

@ -72,7 +72,8 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
}
std::set<TypePtr> valid_types = common_valid_types;
valid_types.insert(kBool);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
return elements[0]->BuildType();
}
} // namespace
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -83,11 +84,7 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto abs_type = AddNInferType(primitive, input_args);
if (abs_type->type_id() == kObjectTypeUndeterminedType) {
return std::make_shared<abstract::AbstractUndetermined>();
}
return std::make_shared<abstract::AbstractTensor>(abs_type, AddNInferShape(primitive, input_args));
return abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer, nullptr, true);
} // namespace ops

View File

@ -172,8 +172,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3D, prim::kPrimAvgPool3D, AvgPool3DInfer, nullptr, true);

View File

@ -145,9 +145,8 @@ AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const Prim
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInteger("BatchMatmul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
auto abs = std::make_shared<abstract::AbstractTensor>(BatchMatmulInferType(primitive, input_args),
BatchMatmulInferShape(primitive, input_args));
return abs;
return abstract::MakeAbstract(BatchMatmulInferShape(primitive, input_args),
BatchMatmulInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(BatchMatmul, prim::kPrimBatchMatMul, BatchMatmulInfer, nullptr, true);
} // namespace ops

View File

@ -84,8 +84,8 @@ std::vector<int64_t> BroadcastTo::get_shape() const {
}
AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args),
BroadcastToInferShape(primitive, input_args));
return abstract::MakeAbstract(BroadcastToInferShape(primitive, input_args),
BroadcastToInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer, nullptr, true);
} // namespace ops

View File

@ -313,8 +313,7 @@ Format Conv2D::get_format() const {
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInteger("Conv2d infer", input_args.size(), kGreaterEqual, 2, primitive->name());
return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args),
Conv2dInferShape(primitive, input_args));
return abstract::MakeAbstract(Conv2dInferShape(primitive, input_args), Conv2dInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer, nullptr, true);
} // namespace ops

View File

@ -56,8 +56,7 @@ void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const st
}
}
std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
CheckCTCLossInputs(input_args, op_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
@ -71,36 +70,33 @@ std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
if (min_shape.empty() || max_shape.empty()) {
loss_shape = std::make_shared<abstract::Shape>(batch);
gradient_shape = std::make_shared<abstract::Shape>(shape);
return std::vector<abstract::ShapePtr>{loss_shape, gradient_shape};
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{loss_shape, gradient_shape});
}
ShapeVector batch_min = {min_shape[1]};
ShapeVector batch_max = {max_shape[1]};
loss_shape = std::make_shared<abstract::Shape>(batch, batch_min, batch_max);
gradient_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
return std::vector<abstract::ShapePtr>{loss_shape, gradient_shape};
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{loss_shape, gradient_shape});
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[1]->BuildType(), {kInt64}, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[2]->BuildType(), {kInt32}, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[3]->BuildType(), {kInt32}, op_name);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
return CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), valid_types, op_name);
auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), valid_types, op_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
}
} // namespace
AbstractBasePtr CTCLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto type = InferType(primitive, input_args);
auto types = InferType(primitive, input_args);
auto shapes = InferShape(primitive, input_args);
auto loss = std::make_shared<abstract::AbstractTensor>(type, shapes[0]);
auto gradient = std::make_shared<abstract::AbstractTensor>(type, shapes[1]);
AbstractBasePtrList outputs = {loss, gradient};
return std::make_shared<abstract::AbstractTuple>(outputs);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_PRIMITIVE_EVAL_IMPL(CTCLoss, prim::kPrimCTCLoss, CTCLossInfer, nullptr, true);
} // namespace ops

View File

@ -116,8 +116,7 @@ AbstractBasePtr DropoutDoMaskInfer(const abstract::AnalysisEnginePtr &, const Pr
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 3, primitive->name());
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args));
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutDoMask, prim::kPrimDropoutDoMask, DropoutDoMaskInfer, nullptr, true);
} // namespace ops

View File

@ -147,8 +147,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
AbstractBasePtr DropoutGenMaskInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args));
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGenMask, prim::kPrimDropoutGenMask, DropoutGenMaskInfer, nullptr, true);
} // namespace ops

View File

@ -44,8 +44,7 @@ AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
MS_EXCEPTION_IF_NULL(value);
auto type = value->cast<TypePtr>();
MS_EXCEPTION_IF_NULL(type);
auto abstract = std::make_shared<abstract::AbstractType>(type);
return abstract;
return abstract::MakeAbstract(std::make_shared<abstract::NoShape>(), type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(DType, prim::kPrimDType, DTypeInfer, DTypeInferValue, false);

View File

@ -84,10 +84,10 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
if (ind_dyn || param_dyn) {
ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min);
ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max);
return std::make_shared<abstract::AbstractTensor>(
params->element(), std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
return abstract::MakeAbstract(std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape),
params->BuildType());
}
return std::make_shared<abstract::AbstractTensor>(params->element(), std::make_shared<abstract::Shape>(out_shape));
return abstract::MakeAbstract(std::make_shared<abstract::Shape>(out_shape), params->BuildType());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Gather, prim::kPrimGather, GatherInfer, nullptr, true);
} // namespace ops

View File

@ -68,9 +68,7 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
std::set<TypePtr> valid_types = {kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name);
CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name);
auto abs = std::make_shared<abstract::AbstractTensor>(GatherDInferType(primitive, input_args),
GatherDInferShape(primitive, input_args));
return abs;
return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
} // namespace ops

View File

@ -54,8 +54,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace
AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(GeLU, prim::kPrimGeLU, GeLUInfer, nullptr, true);

View File

@ -72,8 +72,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace
AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer, nullptr, true);

View File

@ -30,11 +30,12 @@ AbstractBasePtr LayerNormGradInfer(const abstract::AnalysisEnginePtr &, const Pr
auto x_backprob = input_args[0]->Broaden();
auto gamma_backprob = input_args[4]->Broaden();
auto beta_backprob = input_args[4]->Broaden();
AbstractBasePtrList args_list({x_backprob, gamma_backprob, beta_backprob});
return std::make_shared<abstract::AbstractTuple>(args_list);
auto shapes = std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
x_backprob->BuildShape(), gamma_backprob->BuildShape(), beta_backprob->BuildShape()});
auto types = std::make_shared<Tuple>(
std::vector<TypePtr>{x_backprob->BuildType(), gamma_backprob->BuildType(), beta_backprob->BuildType()});
return abstract::MakeAbstract(shapes, types);
}
void LayerNormGrad::Init(const int64_t begin_norm_axis, const int64_t begin_params_axis) {
this->set_begin_norm_axis(begin_norm_axis);
this->set_begin_params_axis(begin_params_axis);

View File

@ -62,8 +62,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace
AbstractBasePtr ReLUGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args));
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGrad, prim::kPrimReluGrad, ReLUGradInfer, nullptr, true);

View File

@ -55,8 +55,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace
AbstractBasePtr ReLUGradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args));
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGradV2, prim::kPrimReluGradV2, ReLUGradV2Infer, nullptr, true);

View File

@ -93,23 +93,22 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit
}
}
auto mean = input_x->Broaden();
auto var = input_x->Broaden();
std::vector<BaseShapePtr> shapes_list = {input_x->BuildShape()};
std::vector<TypePtr> types_list = {input_x->BuildType(), input_x->BuildType(), input_x->BuildType()};
auto mean_var_shape = CalLayerNormMeanAndVarShape(begin_norm_axis, input_shape->shape());
auto input_min_shape = input_shape->min_shape();
auto input_max_shape = input_shape->max_shape();
if (input_min_shape.empty() || input_max_shape.empty()) {
mean->set_shape(std::make_shared<abstract::Shape>(mean_var_shape));
var->set_shape(std::make_shared<abstract::Shape>(mean_var_shape));
shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape));
shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape));
} else {
auto mean_var_shape_min = CalLayerNormMeanAndVarShape(begin_norm_axis, input_min_shape);
auto mean_var_shape_max = CalLayerNormMeanAndVarShape(begin_norm_axis, input_min_shape);
mean->set_shape(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
var->set_shape(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
}
AbstractBasePtrList args_list({input_x->Broaden(), mean, var});
return std::make_shared<abstract::AbstractTuple>(args_list);
return abstract::MakeAbstract(std::make_shared<abstract::TupleShape>(shapes_list),
std::make_shared<Tuple>(types_list));
}
void LayerNorm::Init(const int64_t begin_norm_axis, const int64_t begin_params_axis, const float epsilon) {

View File

@ -40,8 +40,8 @@ TypePtr Log1pInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
auto prim_name = prim->name();
// check
std::set<TypePtr> valid_index_types = {kFloat16, kFloat32};
auto x_type =
CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
auto x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
return x_type;
}
} // namespace
@ -49,9 +49,7 @@ TypePtr Log1pInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
AbstractBasePtr Log1pInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto abs = std::make_shared<abstract::AbstractTensor>(Log1pInferType(primitive, input_args),
Log1pInferShape(primitive, input_args));
return abs;
return abstract::MakeAbstract(Log1pInferShape(primitive, input_args), Log1pInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(Log1p, prim::kPrimLog1p, Log1pInfer, nullptr, true);
} // namespace ops

View File

@ -65,8 +65,8 @@ TypePtr LogSoftmaxInferType(const PrimitivePtr &prim, const std::vector<Abstract
AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(LogSoftmaxInferType(primitive, input_args),
LogSoftmaxInferShape(primitive, input_args)->shape());
return abstract::MakeAbstract(LogSoftmaxInferShape(primitive, input_args),
LogSoftmaxInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(LogSoftmax, prim::kPrimLogSoftmax, LogSoftmaxInfer, nullptr, true);
} // namespace ops

View File

@ -126,8 +126,7 @@ bool MatMul::get_transpose_b() const {
AbstractBasePtr MatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInteger("MatMul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
return std::make_shared<abstract::AbstractTensor>(MatMulInferType(primitive, input_args),
MatMulInferShape(primitive, input_args)->shape());
return abstract::MakeAbstract(MatMulInferShape(primitive, input_args), MatMulInferType(primitive, input_args));
}
// Add
REGISTER_PRIMITIVE_C(kNameMatMul, MatMul);

View File

@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
AbstractBasePtr MulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_C(kNameMul, Mul);
} // namespace ops

View File

@ -73,8 +73,7 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
} // namespace
AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(OneHotInferType(primitive, input_args),
OneHotInferShape(primitive, input_args));
return abstract::MakeAbstract(OneHotInferShape(primitive, input_args), OneHotInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true);
} // namespace ops

View File

@ -45,8 +45,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
AbstractBasePtr OnesLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args));
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(OnesLike, prim::kPrimOnesLike, OnesLikeInfer, nullptr, true);
} // namespace ops

View File

@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
AbstractBasePtr RealDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_C(kNameRealDiv, RealDiv);
} // namespace ops

View File

@ -44,14 +44,15 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
CheckAndConvertUtils::CheckInteger("ReLU infer", input_args.size(), kEqual, 1, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto x_type = input_args[0]->BuildType();
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, common_valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, common_valid_types, prim_name);
return x_type;
}
} // namespace
AbstractBasePtr ReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
return std::make_shared<abstract::AbstractTensor>(type, shape);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReLU, prim::kPrimRelu, ReLUInfer, nullptr, true);

View File

@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace
AbstractBasePtr ReLU6Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args));
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReLU6, prim::kPrimRelu6, ReLU6Infer, nullptr, true);
} // namespace ops

View File

@ -51,8 +51,7 @@ std::vector<int64_t> GetOutputMaskShape(const std::vector<int64_t> &input_shape,
}
return mask_shape;
}
std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
@ -74,13 +73,13 @@ std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
if (min_shape.empty() || max_shape.empty()) {
inputs_shape = std::make_shared<abstract::Shape>(input_shape);
masks_shape = std::make_shared<abstract::Shape>(mask_shape);
return std::vector<abstract::ShapePtr>{inputs_shape, masks_shape};
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{inputs_shape, masks_shape});
}
auto min_mask_shape = GetOutputMaskShape(min_shape, x_dtype);
auto max_mask_shape = GetOutputMaskShape(max_shape, x_dtype);
inputs_shape = std::make_shared<abstract::Shape>(input_shape, min_shape, max_shape);
masks_shape = std::make_shared<abstract::Shape>(mask_shape, min_mask_shape, max_mask_shape);
return std::vector<abstract::ShapePtr>{inputs_shape, masks_shape};
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{inputs_shape, masks_shape});
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
@ -92,7 +91,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_EXCEPTION_IF_NULL(x_type);
std::set<TypePtr> valid_x_type = {kTensorType};
CheckAndConvertUtils::CheckSubClass("input_x", x_type, valid_x_type, prim_name);
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_x_type, prim_name);
auto type = CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_x_type, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
}
} // namespace
AbstractBasePtr ReLUV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -100,10 +100,7 @@ AbstractBasePtr ReLUV2Infer(const abstract::AnalysisEnginePtr &, const Primitive
MS_EXCEPTION_IF_NULL(primitive);
auto types = InferType(primitive, input_args);
auto shapes = InferShape(primitive, input_args);
auto input_dtype = std::make_shared<abstract::AbstractTensor>(types, shapes[0]);
auto mask_dtype = std::make_shared<abstract::AbstractTensor>(kUInt8, shapes[1]);
AbstractBasePtrList outputs = {input_dtype, mask_dtype};
return std::make_shared<abstract::AbstractTuple>(outputs);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReLUV2, prim::kPrimReluV2, ReLUV2Infer, nullptr, true);

View File

@ -48,7 +48,7 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
MS_EXCEPTION_IF_NULL(primitive);
// check
CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name());
return std::make_shared<abstract::AbstractTensor>(kInt32, ScalarSummaryInferShape(primitive, input_args));
return abstract::MakeAbstract(ScalarSummaryInferShape(primitive, input_args), kInt32);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true);
} // namespace ops

View File

@ -79,8 +79,7 @@ TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(SoftMaxInferType(primitive, input_args),
SoftMaxInferShape(primitive, input_args));
return abstract::MakeAbstract(SoftMaxInferShape(primitive, input_args), SoftMaxInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true);
} // namespace ops

View File

@ -49,9 +49,7 @@ TypePtr SoftplusInferType(const PrimitivePtr &prim, const std::vector<AbstractBa
AbstractBasePtr SoftplusInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto abs = std::make_shared<abstract::AbstractTensor>(SoftplusInferType(primitive, input_args),
SoftplusInferShape(primitive, input_args));
return abs;
return abstract::MakeAbstract(SoftplusInferShape(primitive, input_args), SoftplusInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(Softplus, prim::kPrimSoftplus, SoftplusInfer, nullptr, true);
} // namespace ops

View File

@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
AbstractBasePtr SubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_C(kNameSub, Sub);
} // namespace ops

View File

@ -48,7 +48,7 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
MS_EXCEPTION_IF_NULL(primitive);
// check
CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name());
return std::make_shared<abstract::AbstractTensor>(kInt32, TensorSummaryInferShape(primitive, input_args));
return abstract::MakeAbstract(TensorSummaryInferShape(primitive, input_args), kInt32);
}
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true);
} // namespace ops

View File

@ -91,8 +91,7 @@ TypePtr TileInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
AbstractBasePtr TileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(TileInferType(primitive, input_args),
TileInferShape(primitive, input_args)->shape());
return abstract::MakeAbstract(TileInferShape(primitive, input_args), TileInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_C(kNameTile, Tile);

View File

@ -52,9 +52,7 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args),
ZerosInferShape(primitive, input_args));
return abs;
return abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
}
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {

View File

@ -38,15 +38,15 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
auto infer_type = input_args[0]->BuildType();
auto valid_type = common_valid_types;
valid_type.insert(kBool);
return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, valid_type, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, valid_type, op_name);
return infer_type;
}
} // namespace
AbstractBasePtr ZerosLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args));
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ZerosLike, prim::kPrimZerosLike, ZerosLikeInfer, nullptr, true);
} // namespace ops