forked from mindspore-Ecosystem/mindspore
!15218 remove ConvertShapePtrToShape function, use ConvertShapePtrToShapeMap instead
From: @simson_wu Reviewed-by: @ginfung,@chujinjin Signed-off-by: @chujinjin
This commit is contained in:
commit
cd746e8d52
|
@ -30,11 +30,10 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -26,11 +26,10 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve
|
|||
auto prim_name = primitive->name();
|
||||
|
||||
// infer shape
|
||||
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShape("m_shape", input_args[1]->GetShapeTrack(), prim_name);
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[2]->GetShapeTrack(), prim_name);
|
||||
auto grad_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("grad_shape", input_args[9]->GetShapeTrack(), prim_name);
|
||||
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
|
||||
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[9]->GetShapeTrack())[kShape];
|
||||
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
|
||||
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
|
||||
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);
|
||||
|
|
|
@ -38,15 +38,13 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
|
||||
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(element0);
|
||||
auto element0_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
|
||||
auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
|
||||
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("element0", element0->BuildType());
|
||||
for (size_t i = 1; i < elements.size(); ++i) {
|
||||
std::string elementi = "element" + std::to_string(i);
|
||||
auto elementi_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name);
|
||||
auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(),
|
||||
prim_name);
|
||||
for (size_t j = 0; j < element0_shape.size(); ++j) {
|
||||
|
|
|
@ -60,7 +60,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
||||
// Infer type
|
||||
auto v_tensor_type = input_args[0]->BuildType();
|
||||
|
|
|
@ -23,7 +23,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_rank = SizeToLong(x_shape.size());
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
|
||||
axis = axis < 0 ? axis + x_rank : axis;
|
||||
|
|
|
@ -42,7 +42,7 @@ AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
|
||||
// Infer shape
|
||||
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_rank = SizeToLong(x_shape.size());
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
|
||||
if (axis < 0) {
|
||||
|
|
|
@ -29,7 +29,7 @@ AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
|
||||
// Infer Shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto infer_shape = std::make_shared<abstract::Shape>(x_shape);
|
||||
|
||||
// Infer Type
|
||||
|
|
|
@ -47,8 +47,7 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
}
|
||||
condition = TypeIdToType(kNumberTypeBool);
|
||||
} else {
|
||||
auto condition_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto condition_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name);
|
||||
if (condition_shape[0] == 1) {
|
||||
auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c());
|
||||
|
|
|
@ -25,9 +25,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto value_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(value_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
|
||||
// Infer Shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto infer_shape = std::make_shared<abstract::Shape>(x_shape);
|
||||
|
||||
// Infer Type
|
||||
|
|
|
@ -30,9 +30,7 @@ namespace {
|
|||
abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
if (input_shape.size() != 2) {
|
||||
MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
|
||||
}
|
||||
|
|
|
@ -82,7 +82,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
|
|
|
@ -75,20 +75,19 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name);
|
||||
|
||||
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
|
||||
}
|
||||
auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name);
|
||||
auto bias = CheckAndConvertUtils::ConvertShapePtrToShape("bias", input_args[2]->BuildShape(), prim_name);
|
||||
auto mean = CheckAndConvertUtils::ConvertShapePtrToShape("mean", input_args[3]->BuildShape(), prim_name);
|
||||
auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name);
|
||||
auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
|
||||
auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
|
||||
|
||||
std::vector<int64_t> input_shape_norm;
|
||||
if (format == NCHW) {
|
||||
input_shape_norm =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
input_shape_norm = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
} else {
|
||||
input_shape_norm.push_back(input_x[0]);
|
||||
input_shape_norm.push_back(input_x[3]);
|
||||
|
|
|
@ -68,12 +68,10 @@ AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name);
|
||||
auto variance_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto global_step_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("global_step_shape", input_args[3]->BuildShape(), op_name);
|
||||
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto variance_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto global_step_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name);
|
||||
CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name);
|
||||
CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name);
|
||||
|
|
|
@ -55,7 +55,7 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types,
|
||||
prim_name);
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
|
||||
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
auto out_shape = x_shape;
|
||||
int64_t block_shape_prod = 1;
|
||||
|
|
|
@ -30,8 +30,8 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::v
|
|||
auto prim_name = primitive->name();
|
||||
// check
|
||||
CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name);
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
|
|
|
@ -34,10 +34,9 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto weight_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
|
||||
std::vector<int64_t> infer_shape;
|
||||
if (weight_shape.size() < 1) {
|
||||
|
|
|
@ -50,7 +50,7 @@ AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
// infer shape
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
// infer type
|
||||
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
std::vector<TypePtr> output_types;
|
||||
|
|
|
@ -24,7 +24,7 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto value_ptr = primitive->GetAttr(kShape);
|
||||
auto input_x = GetValue<std::vector<int64_t>>(value_ptr);
|
||||
int64_t outer_dim_offset = input_x.size() - x_shape.size();
|
||||
|
|
|
@ -31,7 +31,7 @@ AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil");
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());
|
||||
|
|
|
@ -43,8 +43,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
|
||||
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(element0);
|
||||
auto element0_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
|
||||
auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
|
||||
auto element0_rank = SizeToLong(element0_shape.size());
|
||||
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank},
|
||||
|
@ -56,8 +55,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
int64_t all_shp = element0_shape[axis];
|
||||
for (size_t i = 1; i < elements.size(); ++i) {
|
||||
std::string elementi = "element" + std::to_string(i);
|
||||
auto elementi_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name);
|
||||
auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(),
|
||||
prim_name);
|
||||
for (int64_t j = 0; j < element0_rank; ++j) {
|
||||
|
|
|
@ -24,8 +24,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape");
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "ConstantOfShape");
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -79,8 +79,8 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
||||
|
|
|
@ -28,8 +28,7 @@ namespace {
|
|||
abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -24,11 +24,10 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
// infer shape
|
||||
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
// infer type
|
||||
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
return std::make_shared<abstract::AbstractTensor>(x_type, out_shape);
|
||||
|
|
|
@ -24,18 +24,14 @@ namespace ops {
|
|||
AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
// auto input = input_args[0];
|
||||
|
||||
// Infer type
|
||||
auto output0_type = kInt32;
|
||||
auto output1_type = kFloat32;
|
||||
|
||||
// Infer shape
|
||||
std::vector<int64_t> out_shape;
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto string_num = input_shape[0];
|
||||
if (string_num == 0) {
|
||||
out_shape.push_back(1);
|
||||
|
|
|
@ -54,7 +54,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
auto input_x = input_args[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_x);
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
||||
|
|
|
@ -119,8 +119,8 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
||||
|
|
|
@ -120,9 +120,9 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
|
|||
auto boxes = input_args[0];
|
||||
auto scores = input_args[1];
|
||||
auto anchors = input_args[2];
|
||||
auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShape("boxes_shape", boxes->BuildShape(), prim_name);
|
||||
auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShape("scores_shape", scores->BuildShape(), prim_name);
|
||||
auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShape("anchors_shape", anchors->BuildShape(), prim_name);
|
||||
auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(boxes->BuildShape())[kShape];
|
||||
auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(scores->BuildShape())[kShape];
|
||||
auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(anchors->BuildShape())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]};
|
||||
|
|
|
@ -43,7 +43,7 @@ AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterEqual, 1, prim_name);
|
||||
std::vector<int64_t> out_shape;
|
||||
out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end());
|
||||
|
|
|
@ -31,11 +31,10 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
// Infer shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto dim_val = GetValue<int64_t>(input_args[1]->BuildValue());
|
||||
auto rank = x_shape.size();
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("axis", dim_val, kIncludeBoth, {-rank - 1, rank}, prim_name);
|
||||
|
|
|
@ -29,9 +29,9 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), prim_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kGreaterEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("min_shape", min_shape.size(), kEqual, 1, prim_name);
|
||||
|
|
|
@ -44,9 +44,9 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), op_name);
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), op_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name);
|
||||
CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name);
|
||||
CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name);
|
||||
|
|
|
@ -24,8 +24,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
in_shape.pop_back();
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ AbstractBasePtr FftRealInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto out_dtype = kFloat32;
|
||||
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
out_shape.pop_back();
|
||||
return std::make_shared<abstract::AbstractTensor>(out_dtype, std::make_shared<abstract::Shape>(out_shape));
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto prod = 1;
|
||||
int64_t size = x_shape.size();
|
||||
for (int64_t i = 1; i < size; i++) {
|
||||
|
|
|
@ -28,11 +28,10 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
|
|
|
@ -53,8 +53,8 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
||||
auto input0 = input_args[0];
|
||||
auto input1 = input_args[1];
|
||||
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input0->BuildShape(), prim_name);
|
||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input1->BuildShape(), prim_name);
|
||||
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input0->BuildShape())[kShape];
|
||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape];
|
||||
auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
|
||||
if (has_bias) {
|
||||
|
@ -78,8 +78,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
new_k = input1_shape[1];
|
||||
}
|
||||
if (has_bias) {
|
||||
auto input2_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), prim_name);
|
||||
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
if (input2_shape[0] != input1_shape[0]) {
|
||||
MS_EXCEPTION(ValueError) << "Bias size invalid";
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
|
|
|
@ -33,7 +33,7 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_shape_len = (int64_t)x_shape.size();
|
||||
auto begin_v = input_args[1]->BuildValue();
|
||||
auto size_v = input_args[2]->BuildValue();
|
||||
|
|
|
@ -29,8 +29,8 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
int64_t x_rank = x_shape.size();
|
||||
CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name);
|
||||
auto dim_v = GetValue<int64_t>(input_args[1]->BuildValue());
|
||||
|
|
|
@ -32,9 +32,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto indices_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("indices_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto input_rank = input_shape.size();
|
||||
auto indices_rank = indices_shape.size();
|
||||
CheckAndConvertUtils::CheckInteger("Input of indices data", input_rank, kGreaterEqual,
|
||||
|
|
|
@ -28,8 +28,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -47,13 +47,11 @@ bool BatchNormGrad::get_is_training() const {
|
|||
AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[2]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[3]);
|
||||
auto y_backprop_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("y_backprop_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), op_name);
|
||||
auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);
|
||||
|
||||
auto dx = input_args[1]->Broaden();
|
||||
|
|
|
@ -46,7 +46,7 @@ AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
|
||||
// Infer shape
|
||||
auto inshape = CheckAndConvertUtils::ConvertShapePtrToShape("inshape", input_args[0]->BuildShape(), prim_name);
|
||||
auto inshape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
for (size_t i = 0; i < inshape.size() - 1; i++) {
|
||||
inshape[i] = 1;
|
||||
}
|
||||
|
|
|
@ -27,10 +27,9 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto weight_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
|
||||
if (weight_shape.size() < 1) {
|
||||
CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);
|
||||
|
|
|
@ -35,8 +35,7 @@ namespace {
|
|||
abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -21,10 +21,8 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
|
||||
auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto element = tensor_type->element();
|
||||
|
|
|
@ -35,9 +35,9 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
|
|||
prim_name);
|
||||
|
||||
// Infer Shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dout_shape", input_args[2]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
|
||||
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);
|
||||
|
||||
|
|
|
@ -40,9 +40,9 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
|
|||
CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShape("prediction", input_args[0]->BuildShape(), prim_name);
|
||||
auto target = CheckAndConvertUtils::ConvertShapePtrToShape("target", input_args[1]->BuildShape(), prim_name);
|
||||
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShape("dloss", input_args[2]->BuildShape(), prim_name);
|
||||
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
|
||||
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError);
|
||||
|
||||
|
|
|
@ -27,9 +27,8 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const
|
|||
for (auto input : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
}
|
||||
auto op_name = primitive->name();
|
||||
std::vector<int64_t> hits_shape;
|
||||
auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
hits_shape.push_back(input[0]);
|
||||
|
||||
auto value_type = input_args[2]->BuildType();
|
||||
|
|
|
@ -46,7 +46,7 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
}
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_rank = SizeToLong(x_shape.size());
|
||||
auto axiss = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
|
||||
for (auto &axis : axiss) {
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Log");
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -24,8 +24,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -78,7 +78,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 4, prim_name);
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
|
|
@ -32,14 +32,14 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name);
|
||||
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name);
|
||||
CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name);
|
||||
CheckAndConvertUtils::CheckInteger("input1_shape", input1.size(), kGreaterEqual, 1, op_name);
|
||||
|
||||
if (input_args.size() == 3) {
|
||||
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), op_name);
|
||||
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input2_shape", input2.size(), kEqual, 1, op_name);
|
||||
CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name);
|
||||
}
|
||||
|
|
|
@ -32,9 +32,9 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name);
|
||||
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("h_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("c_shape", input_args[2]->BuildShape(), prim_name);
|
||||
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
|
||||
int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
|
||||
CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name);
|
||||
|
|
|
@ -26,8 +26,8 @@ abstract::ShapePtr MatMulInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto trans_a = GetValue<bool>(primitive->GetAttr(kTransposeA));
|
||||
auto trans_b = GetValue<bool>(primitive->GetAttr(kTransposeB));
|
||||
|
||||
|
|
|
@ -30,9 +30,8 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto assist_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("assist_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto assist_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("assist rank", (int64_t)assist_shape.size(), kGreaterEqual, 2, prim_name);
|
||||
CheckAndConvertUtils::Check("x_shape rank", (int64_t)x_shape.size() + 1, kLessEqual, "assist rank",
|
||||
|
|
|
@ -82,7 +82,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
|
|
|
@ -25,10 +25,8 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto first_input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto second_input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("second_input_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name);
|
||||
std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1],
|
||||
|
|
|
@ -31,7 +31,7 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name);
|
||||
auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue());
|
||||
CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name);
|
||||
|
|
|
@ -28,9 +28,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -27,8 +27,8 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_LOG(INFO) << "Do infer shape for op " << op_name;
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->GetShapeTrack(), op_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
|
||||
if (x_shape == y_shape) {
|
||||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve
|
|||
std::string name) {
|
||||
CheckAndConvertUtils::CheckInteger("len of input_x", (int64_t)x_shapes.size(), kGreaterEqual, 1, name);
|
||||
CheckAndConvertUtils::CheckSubClass("input_x[0]", x_types[0], {TypeIdToType(kObjectTypeTensorType)}, name);
|
||||
auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape[0]", x_shapes[0], name);
|
||||
auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shapes[0])[kShape];
|
||||
int64_t rank_base = output_shape.size();
|
||||
int64_t N = x_shapes.size();
|
||||
// CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-rank_base-1, rank_base}, name);
|
||||
|
@ -37,7 +37,7 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve
|
|||
MS_EXCEPTION_IF_NULL(type0);
|
||||
CheckAndConvertUtils::Check("x_type[" + std::to_string(i) + "]", type->type_id(), kEqual, "base", type0->type_id(),
|
||||
name);
|
||||
auto shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape" + std::to_string(i), x_shapes[i], name);
|
||||
auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shapes[i])[kShape];
|
||||
if (shape != output_shape) {
|
||||
MS_EXCEPTION(ValueError) << "For '" + name + "' element " + std::to_string(i) +
|
||||
"shape in input can't pack with first element.";
|
||||
|
|
|
@ -25,7 +25,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto paddings_attr = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Pad");
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()),
|
||||
prim_name);
|
||||
int64_t size = paddings_attr.size();
|
||||
|
|
|
@ -25,8 +25,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto prim_name = primitive->name();
|
||||
auto x = input_args[0]->BuildShape();
|
||||
auto w = input_args[1]->BuildShape();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", x, prim_name);
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", w, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape];
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape];
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kNotEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 1, prim_name);
|
||||
|
|
|
@ -112,7 +112,6 @@ void PriorBox::Init(const std::vector<int64_t> &min_sizes, const std::vector<int
|
|||
AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
std::vector<float> different_aspect_ratios{1.0f};
|
||||
auto aspect_ratios = GetValue<std::vector<float>>(primitive->GetAttr(kAspectRatios));
|
||||
|
@ -129,7 +128,7 @@ AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const Primiti
|
|||
}
|
||||
auto min_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kMinSizes));
|
||||
int64_t num_priors_box = min_sizes.size() * different_aspect_ratios.size() + min_sizes.size();
|
||||
auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
int64_t h = input[0] * input[1] * num_priors_box * 4;
|
||||
std::vector<int64_t> output_shape{1, h, 1, 2};
|
||||
return std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);
|
||||
|
|
|
@ -32,13 +32,12 @@ void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) {
|
|||
AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_type);
|
||||
auto dst_type = GetValue<int64_t>(primitive->GetAttr(kDstT));
|
||||
MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type)));
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(dst_type)), input_shape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast);
|
||||
|
|
|
@ -34,8 +34,7 @@ AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
// infer shape
|
||||
auto in_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
// infer type
|
||||
std::set<TypePtr> valid_x_type = {kTensorType};
|
||||
auto x_type = CheckAndConvertUtils::CheckTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
|
||||
|
|
|
@ -71,8 +71,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto input_x_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
||||
auto keep_dims = GetValue<bool>(primitive->GetAttr(kKeepDims));
|
||||
auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name);
|
||||
|
|
|
@ -49,8 +49,7 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name);
|
||||
std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]};
|
||||
auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize));
|
||||
|
|
|
@ -44,10 +44,8 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
// infer shape
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto seq_lengths =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("seq_lengths", input_args[1]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto seq_lengths = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto seq_dim = GetValue<int64_t>(primitive->GetAttr(kSeqDim));
|
||||
auto batch_dim = GetValue<int64_t>(primitive->GetAttr(kBatchDim));
|
||||
CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name);
|
||||
|
|
|
@ -24,8 +24,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -24,9 +24,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto first_input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto out_shape = first_input_shape;
|
||||
out_shape[out_shape.size() - 1] = GetValue<int64_t>(primitive->GetAttr(kFftLength)) / 2 + 1;
|
||||
out_shape.push_back(2);
|
||||
|
|
|
@ -62,9 +62,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
// Infer shape
|
||||
auto new_h = GetValue<int64_t>(primitive->GetAttr(kPooledH));
|
||||
auto new_w = GetValue<int64_t>(primitive->GetAttr(kPooledW));
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShape("roi_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
std::vector<int64_t> output_shape;
|
||||
output_shape.push_back(roi_shape[0]);
|
||||
output_shape.push_back(new_h);
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "round");
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name);
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ abstract::ShapePtr ScalarSummaryInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name);
|
||||
return std::make_shared<abstract::Shape>(ShapeVector(1));
|
||||
}
|
||||
|
|
|
@ -29,10 +29,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
for (const auto &shape : shape_value_element) {
|
||||
CheckAndConvertUtils::CheckInteger("shape value", shape, kGreaterThan, 0, "ScatterNd");
|
||||
}
|
||||
auto indices_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("indices_shape", input_args[0]->BuildShape(), "ScatterNd");
|
||||
auto update_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("update_shape", input_args[1]->BuildShape(), "ScatterNd");
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual, update_shape[0],
|
||||
"ScatterNd");
|
||||
return std::make_shared<abstract::Shape>(shape_value_element);
|
||||
|
|
|
@ -34,8 +34,8 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
|
|||
prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
|
||||
|
||||
// Infer type
|
||||
|
|
|
@ -31,7 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Sin");
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,6 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
if (input_args.size() != 1) {
|
||||
MS_LOG(ERROR) << "Skip Gram should have one input";
|
||||
}
|
||||
|
@ -31,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
if (infer_value == nullptr) {
|
||||
MS_LOG(INFO) << "Do infer shape in runtime.";
|
||||
}
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -40,8 +40,8 @@ AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShape("prediction", input_args[0]->BuildShape(), prim_name);
|
||||
auto target = CheckAndConvertUtils::ConvertShapePtrToShape("target", input_args[0]->BuildShape(), prim_name);
|
||||
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
|
||||
|
||||
// Infer type
|
||||
|
|
|
@ -34,10 +34,8 @@ AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
|
|||
prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto logits_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("logits_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto labels_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("labels_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto labels_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("logits shape", logits_shape, kEqual, "labels shape", labels_shape, prim_name, TypeError);
|
||||
std::vector<int64_t> loss_shape = {logits_shape[0]};
|
||||
auto dlogits_shape = logits_shape;
|
||||
|
|
|
@ -29,8 +29,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name);
|
||||
std::vector<int64_t> output_shape(input_shape.size());
|
||||
auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
auto out_shape = x_shape;
|
||||
int64_t block_shape_prod = 1;
|
||||
|
|
|
@ -43,8 +43,7 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
// infer shape
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
std::vector<int64_t> output_shape;
|
||||
if (GetValue<bool>(primitive->GetAttr(kIsGrad)) != 0) {
|
||||
output_shape = input_shape;
|
||||
|
|
|
@ -33,8 +33,7 @@ AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
// infer shape
|
||||
auto dense_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("dense_shape", input_args[3]->BuildShape(), prim_name);
|
||||
auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
|
||||
// infer type
|
||||
auto values_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
return std::make_shared<abstract::AbstractTensor>(values_type, dense_shape);
|
||||
|
|
|
@ -29,7 +29,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
|
||||
std::vector<int64_t> infer_shape;
|
||||
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto len = SizeToLong(in_shape.size());
|
||||
if (axis.empty()) {
|
||||
std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape),
|
||||
|
|
|
@ -21,7 +21,6 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
|
||||
if (input_args.size() != 1) {
|
||||
MS_LOG(ERROR) << "Invalid output size:" << input_args.size();
|
||||
|
@ -29,11 +28,9 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v
|
|||
if (input_args.size() < 1) {
|
||||
MS_LOG(ERROR) << "Invalid input size " << input_args.size();
|
||||
}
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
for (int64_t i = 1; i < (int64_t)input_args.size(); ++i) {
|
||||
auto input_shape_tmp =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[i]->BuildShape(), prim_name);
|
||||
auto input_shape_tmp = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape];
|
||||
if (input_shape_tmp.size() != input_shape.size()) {
|
||||
MS_LOG(ERROR) << "All input shape size should be the same!";
|
||||
}
|
||||
|
|
|
@ -108,7 +108,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
|
|||
auto temp_strides_v = input_args[3]->cast<abstract::AbstractTuplePtr>()->BuildValue();
|
||||
auto strides_v = GetValue<std::vector<int64_t>>(temp_strides_v);
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
int64_t x_rank = x_shape.size();
|
||||
int64_t slice_len = begin_v.size();
|
||||
std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask)));
|
||||
|
|
|
@ -33,7 +33,7 @@ AbstractBasePtr TanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
|
|||
CheckAndConvertUtils::CheckInteger("tan_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
|
||||
// Infer Shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto infer_shape = std::make_shared<abstract::Shape>(x_shape);
|
||||
|
||||
// Infer Type
|
||||
|
|
|
@ -24,11 +24,8 @@ namespace {
|
|||
abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto input0_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input0 shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input1_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input1 shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
if (input0_shape.size() < 1) {
|
||||
MS_LOG(ERROR) << "input0_shape.size():" << input0_shape.size() << " must be greater than 0!";
|
||||
}
|
||||
|
|
|
@ -52,9 +52,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const
|
|||
for (const auto &input : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
}
|
||||
auto op_name = primitive->name();
|
||||
auto input0_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
int64_t num = std::accumulate(input0_shape.begin(), input0_shape.end(), 1LL, std::multiplies<int64_t>());
|
||||
if (num == 0) {
|
||||
MS_LOG(ERROR) << "Try to stack a empty tensorlist!";
|
||||
|
@ -62,8 +60,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const
|
|||
if (input_args[1]->BuildShape() == nullptr) {
|
||||
MS_LOG(ERROR) << "ele_shape->data_c() is nullptr";
|
||||
}
|
||||
auto input1_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name);
|
||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
input1_shape.insert(input1_shape.begin(), 1);
|
||||
return std::make_shared<abstract::AbstractTensor>(input_args[0]->BuildType(), input1_shape);
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ abstract::ShapePtr TensorSummaryInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name);
|
||||
return std::make_shared<abstract::Shape>(ShapeVector(1));
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue