!15218 remove ConvertShapePtrToShape function, use ConvertShapePtrToShapeMap instead

From: @simson_wu
Reviewed-by: @ginfung,@chujinjin
Signed-off-by: @chujinjin
This commit is contained in:
mindspore-ci-bot 2021-04-17 17:17:25 +08:00 committed by Gitee
commit cd746e8d52
112 changed files with 167 additions and 259 deletions

View File

@ -30,11 +30,10 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}

View File

@ -26,11 +26,10 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve
auto prim_name = primitive->name();
// infer shape
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name);
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShape("m_shape", input_args[1]->GetShapeTrack(), prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[2]->GetShapeTrack(), prim_name);
auto grad_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("grad_shape", input_args[9]->GetShapeTrack(), prim_name);
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[9]->GetShapeTrack())[kShape];
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);

View File

@ -38,15 +38,13 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(element0);
auto element0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
std::map<std::string, TypePtr> types;
types.emplace("element0", element0->BuildType());
for (size_t i = 1; i < elements.size(); ++i) {
std::string elementi = "element" + std::to_string(i);
auto elementi_shape =
CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name);
auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(),
prim_name);
for (size_t j = 0; j < element0_shape.size(); ++j) {

View File

@ -60,7 +60,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name);
// Infer shape
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
// Infer type
auto v_tensor_type = input_args[0]->BuildType();

View File

@ -23,7 +23,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
axis = axis < 0 ? axis + x_rank : axis;

View File

@ -42,7 +42,7 @@ AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const Primitive
// Infer shape
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
if (axis < 0) {

View File

@ -29,7 +29,7 @@ AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name);
// Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto infer_shape = std::make_shared<abstract::Shape>(x_shape);
// Infer Type

View File

@ -47,8 +47,7 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
}
condition = TypeIdToType(kNumberTypeBool);
} else {
auto condition_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto condition_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name);
if (condition_shape[0] == 1) {
auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c());

View File

@ -25,9 +25,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto value_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name);
auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(value_shape);
}

View File

@ -27,7 +27,7 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name);
// Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto infer_shape = std::make_shared<abstract::Shape>(x_shape);
// Infer Type

View File

@ -30,9 +30,7 @@ namespace {
abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (input_shape.size() != 2) {
MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
}

View File

@ -82,7 +82,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};

View File

@ -75,20 +75,19 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name);
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
}
auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name);
auto bias = CheckAndConvertUtils::ConvertShapePtrToShape("bias", input_args[2]->BuildShape(), prim_name);
auto mean = CheckAndConvertUtils::ConvertShapePtrToShape("mean", input_args[3]->BuildShape(), prim_name);
auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name);
auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
std::vector<int64_t> input_shape_norm;
if (format == NCHW) {
input_shape_norm =
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
input_shape_norm = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
} else {
input_shape_norm.push_back(input_x[0]);
input_shape_norm.push_back(input_x[3]);

View File

@ -68,12 +68,10 @@ AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const Pr
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name);
auto variance_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
auto global_step_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("global_step_shape", input_args[3]->BuildShape(), op_name);
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto variance_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto global_step_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name);
CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name);
CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name);

View File

@ -55,7 +55,7 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types,
prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));

View File

@ -29,7 +29,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
auto out_shape = x_shape;
int64_t block_shape_prod = 1;

View File

@ -30,8 +30,8 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::v
auto prim_name = primitive->name();
// check
CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name);
CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name);
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));

View File

@ -34,10 +34,9 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
auto weight_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
std::vector<int64_t> infer_shape;
if (weight_shape.size() < 1) {

View File

@ -50,7 +50,7 @@ AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const Primit
MS_EXCEPTION_IF_NULL(item);
}
// infer shape
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
// infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::vector<TypePtr> output_types;

View File

@ -24,7 +24,7 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto value_ptr = primitive->GetAttr(kShape);
auto input_x = GetValue<std::vector<int64_t>>(value_ptr);
int64_t outer_dim_offset = input_x.size() - x_shape.size();

View File

@ -31,7 +31,7 @@ AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil");
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto infer_type = input_args[0]->BuildType();
auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());

View File

@ -43,8 +43,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(element0);
auto element0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
auto element0_rank = SizeToLong(element0_shape.size());
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank},
@ -56,8 +55,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
int64_t all_shp = element0_shape[axis];
for (size_t i = 1; i < elements.size(); ++i) {
std::string elementi = "element" + std::to_string(i);
auto elementi_shape =
CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name);
auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(),
prim_name);
for (int64_t j = 0; j < element0_rank; ++j) {

View File

@ -24,8 +24,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape");
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "ConstantOfShape");
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(input_shape);
}

View File

@ -79,8 +79,8 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};

View File

@ -28,8 +28,7 @@ namespace {
abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(input_shape);
}

View File

@ -24,11 +24,10 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}

View File

@ -49,7 +49,7 @@ AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
MS_EXCEPTION_IF_NULL(item);
}
// infer shape
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), prim_name);
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
// infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
return std::make_shared<abstract::AbstractTensor>(x_type, out_shape);

View File

@ -24,18 +24,14 @@ namespace ops {
AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
// auto input = input_args[0];
// Infer type
auto output0_type = kInt32;
auto output1_type = kFloat32;
// Infer shape
std::vector<int64_t> out_shape;
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto string_num = input_shape[0];
if (string_num == 0) {
out_shape.push_back(1);

View File

@ -54,7 +54,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
auto input_x = input_args[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input_x);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};

View File

@ -119,8 +119,8 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};

View File

@ -120,9 +120,9 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
auto boxes = input_args[0];
auto scores = input_args[1];
auto anchors = input_args[2];
auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShape("boxes_shape", boxes->BuildShape(), prim_name);
auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShape("scores_shape", scores->BuildShape(), prim_name);
auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShape("anchors_shape", anchors->BuildShape(), prim_name);
auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(boxes->BuildShape())[kShape];
auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(scores->BuildShape())[kShape];
auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(anchors->BuildShape())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]};

View File

@ -43,7 +43,7 @@ AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const Primitiv
CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name);
// Infer shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterEqual, 1, prim_name);
std::vector<int64_t> out_shape;
out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end());

View File

@ -31,11 +31,10 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}

View File

@ -36,7 +36,7 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
MS_EXCEPTION_IF_NULL(item);
}
// Infer shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto dim_val = GetValue<int64_t>(input_args[1]->BuildValue());
auto rank = x_shape.size();
CheckAndConvertUtils::CheckInRange<int64_t>("axis", dim_val, kIncludeBoth, {-rank - 1, rank}, prim_name);

View File

@ -29,9 +29,9 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), prim_name);
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kGreaterEqual, 1, prim_name);
CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name);
CheckAndConvertUtils::CheckInteger("min_shape", min_shape.size(), kEqual, 1, prim_name);

View File

@ -44,9 +44,9 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), op_name);
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name);
CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name);
CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name);

View File

@ -24,8 +24,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
in_shape.pop_back();
return std::make_shared<abstract::Shape>(in_shape);
}

View File

@ -33,7 +33,7 @@ AbstractBasePtr FftRealInfer(const abstract::AnalysisEnginePtr &, const Primitiv
MS_EXCEPTION_IF_NULL(item);
}
auto out_dtype = kFloat32;
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
out_shape.pop_back();
return std::make_shared<abstract::AbstractTensor>(out_dtype, std::make_shared<abstract::Shape>(out_shape));
}

View File

@ -25,7 +25,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto prod = 1;
int64_t size = x_shape.size();
for (int64_t i = 1; i < size; i++) {

View File

@ -28,11 +28,10 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}

View File

@ -53,7 +53,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};

View File

@ -53,8 +53,8 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
MS_EXCEPTION_IF_NULL(input_args[1]);
auto input0 = input_args[0];
auto input1 = input_args[1];
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input0->BuildShape(), prim_name);
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input1->BuildShape(), prim_name);
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input0->BuildShape())[kShape];
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape];
auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
if (has_bias) {
@ -78,8 +78,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
new_k = input1_shape[1];
}
if (has_bias) {
auto input2_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), prim_name);
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
if (input2_shape[0] != input1_shape[0]) {
MS_EXCEPTION(ValueError) << "Bias size invalid";
}

View File

@ -53,7 +53,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};

View File

@ -33,7 +33,7 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_shape_len = (int64_t)x_shape.size();
auto begin_v = input_args[1]->BuildValue();
auto size_v = input_args[2]->BuildValue();

View File

@ -29,8 +29,8 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
int64_t x_rank = x_shape.size();
CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name);
auto dim_v = GetValue<int64_t>(input_args[1]->BuildValue());

View File

@ -32,9 +32,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto indices_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("indices_shape", input_args[1]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto input_rank = input_shape.size();
auto indices_rank = indices_shape.size();
CheckAndConvertUtils::CheckInteger("Input of indices data", input_rank, kGreaterEqual,

View File

@ -28,8 +28,7 @@ namespace ops {
namespace {
abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(input_shape);
}

View File

@ -47,13 +47,11 @@ bool BatchNormGrad::get_is_training() const {
AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[1]);
MS_EXCEPTION_IF_NULL(input_args[2]);
MS_EXCEPTION_IF_NULL(input_args[3]);
auto y_backprop_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("y_backprop_shape", input_args[0]->BuildShape(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), op_name);
auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);
auto dx = input_args[1]->Broaden();

View File

@ -46,7 +46,7 @@ AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const Prim
MS_EXCEPTION_IF_NULL(input_args[0]);
// Infer shape
auto inshape = CheckAndConvertUtils::ConvertShapePtrToShape("inshape", input_args[0]->BuildShape(), prim_name);
auto inshape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
for (size_t i = 0; i < inshape.size() - 1; i++) {
inshape[i] = 1;
}

View File

@ -27,10 +27,9 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
auto weight_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
if (weight_shape.size() < 1) {
CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);

View File

@ -35,8 +35,7 @@ namespace {
abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}

View File

@ -21,10 +21,8 @@ namespace mindspore {
namespace ops {
AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name);
auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();

View File

@ -35,9 +35,9 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
prim_name);
// Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dout_shape", input_args[2]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);

View File

@ -40,9 +40,9 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name);
// Infer shape
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShape("prediction", input_args[0]->BuildShape(), prim_name);
auto target = CheckAndConvertUtils::ConvertShapePtrToShape("target", input_args[1]->BuildShape(), prim_name);
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShape("dloss", input_args[2]->BuildShape(), prim_name);
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError);

View File

@ -27,9 +27,8 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const
for (auto input : input_args) {
MS_EXCEPTION_IF_NULL(input);
}
auto op_name = primitive->name();
std::vector<int64_t> hits_shape;
auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
hits_shape.push_back(input[0]);
auto value_type = input_args[2]->BuildType();

View File

@ -46,7 +46,7 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim
}
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size());
auto axiss = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
for (auto &axis : axiss) {

View File

@ -24,7 +24,7 @@ namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Log");
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape);
}

View File

@ -24,8 +24,7 @@ namespace ops {
namespace {
abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}

View File

@ -78,7 +78,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 4, prim_name);
return std::make_shared<abstract::Shape>(in_shape);
}

View File

@ -32,14 +32,14 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name);
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name);
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name);
CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name);
CheckAndConvertUtils::CheckInteger("input1_shape", input1.size(), kGreaterEqual, 1, op_name);
if (input_args.size() == 3) {
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), op_name);
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input2_shape", input2.size(), kEqual, 1, op_name);
CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name);
}

View File

@ -32,9 +32,9 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name);
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("h_shape", input_args[1]->BuildShape(), prim_name);
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("c_shape", input_args[2]->BuildShape(), prim_name);
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name);

View File

@ -26,8 +26,8 @@ abstract::ShapePtr MatMulInferShape(const PrimitivePtr &primitive, const std::ve
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto trans_a = GetValue<bool>(primitive->GetAttr(kTransposeA));
auto trans_b = GetValue<bool>(primitive->GetAttr(kTransposeB));

View File

@ -30,9 +30,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto assist_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("assist_shape", input_args[1]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto assist_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("assist rank", (int64_t)assist_shape.size(), kGreaterEqual, 2, prim_name);
CheckAndConvertUtils::Check("x_shape rank", (int64_t)x_shape.size() + 1, kLessEqual, "assist rank",

View File

@ -82,7 +82,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};

View File

@ -25,10 +25,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto first_input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name);
auto second_input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("second_input_shape", input_args[1]->BuildShape(), prim_name);
auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name);
CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name);
std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1],

View File

@ -31,7 +31,7 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name);
auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue());
CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name);

View File

@ -28,9 +28,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(input_shape);
}

View File

@ -27,8 +27,8 @@ namespace mindspore {
namespace ops {
abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) {
MS_LOG(INFO) << "Do infer shape for op " << op_name;
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->GetShapeTrack(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
if (x_shape == y_shape) {
return std::make_shared<abstract::Shape>(x_shape);
}

View File

@ -23,7 +23,7 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve
std::string name) {
CheckAndConvertUtils::CheckInteger("len of input_x", (int64_t)x_shapes.size(), kGreaterEqual, 1, name);
CheckAndConvertUtils::CheckSubClass("input_x[0]", x_types[0], {TypeIdToType(kObjectTypeTensorType)}, name);
auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape[0]", x_shapes[0], name);
auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shapes[0])[kShape];
int64_t rank_base = output_shape.size();
int64_t N = x_shapes.size();
// CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-rank_base-1, rank_base}, name);
@ -37,7 +37,7 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve
MS_EXCEPTION_IF_NULL(type0);
CheckAndConvertUtils::Check("x_type[" + std::to_string(i) + "]", type->type_id(), kEqual, "base", type0->type_id(),
name);
auto shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape" + std::to_string(i), x_shapes[i], name);
auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shapes[i])[kShape];
if (shape != output_shape) {
MS_EXCEPTION(ValueError) << "For '" + name + "' element " + std::to_string(i) +
"shape in input can't pack with first element.";

View File

@ -25,7 +25,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto paddings_attr = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Pad");
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()),
prim_name);
int64_t size = paddings_attr.size();

View File

@ -25,8 +25,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto prim_name = primitive->name();
auto x = input_args[0]->BuildShape();
auto w = input_args[1]->BuildShape();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", x, prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", w, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape];
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kNotEqual, 1, prim_name);
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 1, prim_name);

View File

@ -112,7 +112,6 @@ void PriorBox::Init(const std::vector<int64_t> &min_sizes, const std::vector<int
AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
std::vector<float> different_aspect_ratios{1.0f};
auto aspect_ratios = GetValue<std::vector<float>>(primitive->GetAttr(kAspectRatios));
@ -129,7 +128,7 @@ AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const Primiti
}
auto min_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kMinSizes));
int64_t num_priors_box = min_sizes.size() * different_aspect_ratios.size() + min_sizes.size();
auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
int64_t h = input[0] * input[1] * num_priors_box * 4;
std::vector<int64_t> output_shape{1, h, 1, 2};
return std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);

View File

@ -32,13 +32,12 @@ void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) {
AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_type);
auto dst_type = GetValue<int64_t>(primitive->GetAttr(kDstT));
MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type)));
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(dst_type)), input_shape);
}
REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast);

View File

@ -34,8 +34,7 @@ AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const Primi
MS_EXCEPTION_IF_NULL(item);
}
// infer shape
auto in_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
// infer type
std::set<TypePtr> valid_x_type = {kTensorType};
auto x_type = CheckAndConvertUtils::CheckTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);

View File

@ -71,8 +71,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_x_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_x_shape", input_args[0]->BuildShape(), prim_name);
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto keep_dims = GetValue<bool>(primitive->GetAttr(kKeepDims));
auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name);

View File

@ -49,8 +49,7 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P
CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name);
// Infer shape
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name);
std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]};
auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize));

View File

@ -44,10 +44,8 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const
MS_EXCEPTION_IF_NULL(item);
}
// infer shape
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto seq_lengths =
CheckAndConvertUtils::ConvertShapePtrToShape("seq_lengths", input_args[1]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto seq_lengths = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto seq_dim = GetValue<int64_t>(primitive->GetAttr(kSeqDim));
auto batch_dim = GetValue<int64_t>(primitive->GetAttr(kBatchDim));
CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name);

View File

@ -24,8 +24,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape);
}

View File

@ -24,9 +24,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto first_input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name);
auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto out_shape = first_input_shape;
out_shape[out_shape.size() - 1] = GetValue<int64_t>(primitive->GetAttr(kFftLength)) / 2 + 1;
out_shape.push_back(2);

View File

@ -62,9 +62,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi
// Infer shape
auto new_h = GetValue<int64_t>(primitive->GetAttr(kPooledH));
auto new_w = GetValue<int64_t>(primitive->GetAttr(kPooledW));
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShape("roi_shape", input_args[1]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
std::vector<int64_t> output_shape;
output_shape.push_back(roi_shape[0]);
output_shape.push_back(new_h);

View File

@ -23,7 +23,7 @@ namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "round");
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape);
}

View File

@ -30,7 +30,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->GetShapeTrack(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name);
return std::make_shared<abstract::Shape>(in_shape);
}

View File

@ -29,7 +29,7 @@ abstract::ShapePtr ScalarSummaryInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name);
return std::make_shared<abstract::Shape>(ShapeVector(1));
}

View File

@ -29,10 +29,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
for (const auto &shape : shape_value_element) {
CheckAndConvertUtils::CheckInteger("shape value", shape, kGreaterThan, 0, "ScatterNd");
}
auto indices_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("indices_shape", input_args[0]->BuildShape(), "ScatterNd");
auto update_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("update_shape", input_args[1]->BuildShape(), "ScatterNd");
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual, update_shape[0],
"ScatterNd");
return std::make_shared<abstract::Shape>(shape_value_element);

View File

@ -34,8 +34,8 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
prim_name);
// Infer shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
// Infer type

View File

@ -31,7 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Sin");
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape);
}

View File

@ -23,7 +23,6 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
if (input_args.size() != 1) {
MS_LOG(ERROR) << "Skip Gram should have one input";
}
@ -31,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
if (infer_value == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
}
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}

View File

@ -40,8 +40,8 @@ AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const Pri
CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name);
// Infer shape
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShape("prediction", input_args[0]->BuildShape(), prim_name);
auto target = CheckAndConvertUtils::ConvertShapePtrToShape("target", input_args[0]->BuildShape(), prim_name);
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
// Infer type

View File

@ -34,10 +34,8 @@ AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
prim_name);
// Infer shape
auto logits_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("logits_shape", input_args[0]->BuildShape(), prim_name);
auto labels_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("labels_shape", input_args[1]->BuildShape(), prim_name);
auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto labels_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("logits shape", logits_shape, kEqual, "labels shape", labels_shape, prim_name, TypeError);
std::vector<int64_t> loss_shape = {logits_shape[0]};
auto dlogits_shape = logits_shape;

View File

@ -29,8 +29,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name);
std::vector<int64_t> output_shape(input_shape.size());
auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));

View File

@ -29,7 +29,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
auto out_shape = x_shape;
int64_t block_shape_prod = 1;

View File

@ -43,8 +43,7 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi
MS_EXCEPTION_IF_NULL(item);
}
// infer shape
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
std::vector<int64_t> output_shape;
if (GetValue<bool>(primitive->GetAttr(kIsGrad)) != 0) {
output_shape = input_shape;

View File

@ -33,8 +33,7 @@ AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const Pr
MS_EXCEPTION_IF_NULL(item);
}
// infer shape
auto dense_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("dense_shape", input_args[3]->BuildShape(), prim_name);
auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
// infer type
auto values_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
return std::make_shared<abstract::AbstractTensor>(values_type, dense_shape);

View File

@ -29,7 +29,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
std::vector<int64_t> infer_shape;
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto len = SizeToLong(in_shape.size());
if (axis.empty()) {
std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape),

View File

@ -21,7 +21,6 @@ namespace ops {
namespace {
abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
if (input_args.size() != 1) {
MS_LOG(ERROR) << "Invalid output size:" << input_args.size();
@ -29,11 +28,9 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v
if (input_args.size() < 1) {
MS_LOG(ERROR) << "Invalid input size " << input_args.size();
}
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
for (int64_t i = 1; i < (int64_t)input_args.size(); ++i) {
auto input_shape_tmp =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[i]->BuildShape(), prim_name);
auto input_shape_tmp = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape];
if (input_shape_tmp.size() != input_shape.size()) {
MS_LOG(ERROR) << "All input shape size should be the same!";
}

View File

@ -108,7 +108,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
auto temp_strides_v = input_args[3]->cast<abstract::AbstractTuplePtr>()->BuildValue();
auto strides_v = GetValue<std::vector<int64_t>>(temp_strides_v);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
int64_t x_rank = x_shape.size();
int64_t slice_len = begin_v.size();
std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask)));

View File

@ -33,7 +33,7 @@ AbstractBasePtr TanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
CheckAndConvertUtils::CheckInteger("tan_infer", input_args.size(), kEqual, 1, prim_name);
// Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto infer_shape = std::make_shared<abstract::Shape>(x_shape);
// Infer Type

View File

@ -24,11 +24,8 @@ namespace {
abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input0 shape", input_args[0]->BuildShape(), prim_name);
auto input1_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input1 shape", input_args[1]->BuildShape(), prim_name);
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
if (input0_shape.size() < 1) {
MS_LOG(ERROR) << "input0_shape.size():" << input0_shape.size() << " must be greater than 0!";
}

View File

@ -52,9 +52,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const
for (const auto &input : input_args) {
MS_EXCEPTION_IF_NULL(input);
}
auto op_name = primitive->name();
auto input0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name);
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
int64_t num = std::accumulate(input0_shape.begin(), input0_shape.end(), 1LL, std::multiplies<int64_t>());
if (num == 0) {
MS_LOG(ERROR) << "Try to stack a empty tensorlist!";
@ -62,8 +60,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const
if (input_args[1]->BuildShape() == nullptr) {
MS_LOG(ERROR) << "ele_shape->data_c() is nullptr";
}
auto input1_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name);
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
input1_shape.insert(input1_shape.begin(), 1);
return std::make_shared<abstract::AbstractTensor>(input_args[0]->BuildType(), input1_shape);
}

View File

@ -29,7 +29,7 @@ abstract::ShapePtr TensorSummaryInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name);
return std::make_shared<abstract::Shape>(ShapeVector(1));
}

Some files were not shown because too many files have changed in this diff Show More