forked from mindspore-Ecosystem/mindspore
!22721 clean codecheck for ops r1.3
Merge pull request !22721 from wangnan39/codecheck_clean_r1.3
This commit is contained in:
commit
493aed5f16
|
@ -28,19 +28,19 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve
|
|||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name);
|
||||
|
||||
// infer shape
|
||||
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
|
||||
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[9]->GetShapeTrack())[kShape];
|
||||
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
|
||||
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
|
||||
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShapeTrack())[kShape];
|
||||
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
|
||||
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
|
||||
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);
|
||||
|
||||
// infer type
|
||||
auto var_type = input_args[0]->BuildType();
|
||||
auto m_type = input_args[1]->BuildType();
|
||||
auto v_type = input_args[2]->BuildType();
|
||||
auto grad_type = input_args[9]->BuildType();
|
||||
auto var_type = input_args[kInputIndex0]->BuildType();
|
||||
auto m_type = input_args[kInputIndex1]->BuildType();
|
||||
auto v_type = input_args[kInputIndex2]->BuildType();
|
||||
auto grad_type = input_args[kInputIndex9]->BuildType();
|
||||
auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
|
||||
auto infer_m_type = CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
|
||||
auto infer_v_type = CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
|
||||
|
|
|
@ -24,33 +24,22 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return BroadCastInferShape(prim_name, input_args);
|
||||
}
|
||||
auto output_shape = BroadCastInferShape(prim_name, input_args);
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("Add infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
auto output_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name);
|
||||
return abstract::MakeAbstract(output_shape, output_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAdd, Add);
|
||||
} // namespace ops
|
||||
|
|
|
@ -61,7 +61,8 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("apply_momentum_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
|
||||
const int64_t input_num = 5;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
|
@ -70,11 +71,11 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
||||
// Infer type
|
||||
auto v_tensor_type = input_args[0]->BuildType();
|
||||
auto a_tensor_type = input_args[1]->BuildType();
|
||||
auto l_type = input_args[2]->BuildType();
|
||||
auto g_type = input_args[3]->BuildType();
|
||||
auto m_type = input_args[4]->BuildType();
|
||||
auto v_tensor_type = input_args[kInputIndex0]->BuildType();
|
||||
auto a_tensor_type = input_args[kInputIndex1]->BuildType();
|
||||
auto l_type = input_args[kInputIndex2]->BuildType();
|
||||
auto g_type = input_args[kInputIndex3]->BuildType();
|
||||
auto m_type = input_args[kInputIndex4]->BuildType();
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
|
||||
|
|
|
@ -30,8 +30,9 @@ AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const Primi
|
|||
const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger(
|
||||
"Assign infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(args_spec_list)), kEqual, 2, prim_name);
|
||||
"infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(args_spec_list)), kEqual, input_num, prim_name);
|
||||
auto check_types = common_valid_types;
|
||||
(void)check_types.emplace(kBool);
|
||||
auto variable_type = args_spec_list[0]->BuildType();
|
||||
|
|
|
@ -83,7 +83,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
|
||||
const int64_t x_size = 4;
|
||||
const int64_t attr_size = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_size, op_name);
|
||||
if (format == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
}
|
||||
|
@ -94,8 +96,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto in_h = in_shape[2];
|
||||
auto in_w = in_shape[3];
|
||||
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
|
||||
(void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, 4, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, 4, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, attr_size, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, attr_size, op_name);
|
||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) { return stride <= 0; })) {
|
||||
MS_LOG(EXCEPTION) << "Strides is not valid, strides must be positive.";
|
||||
}
|
||||
|
|
|
@ -38,12 +38,15 @@ void GetAttrs(const PrimitivePtr &primitive, std::vector<int64_t> *kernel_size,
|
|||
// attr kernel size
|
||||
*kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
if (kernel_size->size() != kKernelDims) {
|
||||
MS_LOG(EXCEPTION) << "kernel_size of AvgPool3D must be 3.";
|
||||
MS_LOG(EXCEPTION) << "`kernel_size` of AvgPool3D must be 3.";
|
||||
}
|
||||
// attr strides
|
||||
*strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
|
||||
if (strides->size() != kStridesDims) {
|
||||
MS_LOG(EXCEPTION) << "strides of AvgPool3D must be 3.";
|
||||
MS_LOG(EXCEPTION) << "`strides` of AvgPool3D must be 3.";
|
||||
}
|
||||
if (std::any_of(strides->begin(), strides->end(), [](int64_t stride) { return stride <= 0; })) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid strides, strides must be all positive.";
|
||||
}
|
||||
// sttr pad_list
|
||||
*pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList));
|
||||
|
@ -135,6 +138,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto kernel_d = kernel_size[0];
|
||||
auto kernel_h = kernel_size[1];
|
||||
auto kernel_w = kernel_size[2];
|
||||
|
||||
auto stride_d = strides[0];
|
||||
auto stride_h = strides[1];
|
||||
auto stride_w = strides[2];
|
||||
|
@ -148,7 +152,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
std::vector<int64_t> out_shape =
|
||||
GetOutputShape(in_shape, kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w, new_pad_list, ceil_mode);
|
||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
|
||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t shp_v) { return shp_v <= 0; })) {
|
||||
MS_LOG(EXCEPTION) << "output size is not valid.";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
|
|
|
@ -144,7 +144,9 @@ bool BatchMatmul::get_transpose_b() const {
|
|||
AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
(void)CheckAndConvertUtils::CheckInteger("BatchMatmul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("BatchMatmul infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
return abstract::MakeAbstract(BatchMatmulInferShape(primitive, input_args),
|
||||
BatchMatmulInferType(primitive, input_args));
|
||||
}
|
||||
|
|
|
@ -73,17 +73,18 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
// Infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("batch_norm_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
|
||||
const int64_t input_num = 5;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
|
||||
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
if (format == NHWC) {
|
||||
input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
|
||||
}
|
||||
auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
|
||||
auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
|
||||
auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||
auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
|
||||
|
||||
std::vector<int64_t> input_shape_norm;
|
||||
if (format == NCHW) {
|
||||
|
@ -106,19 +107,19 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
}
|
||||
|
||||
// Infer type
|
||||
auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto scale_type = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto bias_type = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
auto input_x_type =
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[kInputIndex0]->BuildType(), valid_types, prim_name);
|
||||
std::map<std::string, TypePtr> args;
|
||||
args.emplace("scale", input_args[1]->BuildType());
|
||||
args.emplace("bias", input_args[2]->BuildType());
|
||||
args.emplace("scale", input_args[kInputIndex1]->BuildType());
|
||||
args.emplace("bias", input_args[kInputIndex2]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
||||
std::map<std::string, TypePtr> args_moving;
|
||||
args_moving.emplace("scale", input_args[2]->BuildType());
|
||||
args_moving.emplace("bias", input_args[3]->BuildType());
|
||||
args_moving.emplace("scale", input_args[kInputIndex2]->BuildType());
|
||||
args_moving.emplace("bias", input_args[kInputIndex3]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
|
||||
|
||||
auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
|
||||
|
|
|
@ -59,15 +59,18 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
|
||||
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
|
||||
auto out_shape = x_shape;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("block_size size", SizeToLong(block_size.size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("crops size", SizeToLong(crops.size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("crops[0] size", SizeToLong(crops[0].size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("crops[1] size", SizeToLong(crops[1].size()), kEqual, 4, prim_name);
|
||||
const int64_t attr_size = 4;
|
||||
const int64_t x_rank = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, x_rank, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("block_size size", SizeToLong(block_size.size()), kEqual, attr_size,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("crops size", SizeToLong(crops.size()), kEqual, attr_size, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("crops[0] size", SizeToLong(crops[0].size()), kEqual, attr_size, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("crops[1] size", SizeToLong(crops[1].size()), kEqual, attr_size, prim_name);
|
||||
for (size_t i = 0; i < 2; ++i) {
|
||||
auto x_block_prod = out_shape[i + 2] * block_size[i];
|
||||
auto crops_sum = crops[i][0] + crops[i][1];
|
||||
CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", 4, prim_name);
|
||||
CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", attr_size, prim_name);
|
||||
out_shape[i + 2] = x_block_prod - crops_sum;
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape[0] % (block_size[0]*block_size[1])",
|
||||
|
|
|
@ -30,7 +30,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
const int64_t x_rank = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, x_rank, prim_name);
|
||||
auto out_shape = x_shape;
|
||||
int64_t block_shape_prod = 1;
|
||||
size_t offset = 2;
|
||||
|
@ -62,7 +63,8 @@ TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
|
|||
} // namespace
|
||||
|
||||
void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kCrops, SizeToLong(crops.size()), kEqual, 2, this->name());
|
||||
const int64_t crop_size = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger(kCrops, SizeToLong(crops.size()), kEqual, crop_size, this->name());
|
||||
size_t h = crops.size();
|
||||
size_t w = crops[0].size();
|
||||
std::vector<size_t> temp_w = {2, 2};
|
||||
|
@ -80,7 +82,9 @@ std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const {
|
|||
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
||||
}
|
||||
void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
|
||||
const int64_t block_size = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, block_size,
|
||||
this->name());
|
||||
for (size_t i = 0; i < block_shape.size(); i++) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name());
|
||||
}
|
||||
|
|
|
@ -35,15 +35,20 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto bias = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(bias);
|
||||
(void)CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto input_shape = shape_map[kShape];
|
||||
auto min_shape = shape_map[kMinShape];
|
||||
auto max_shape = shape_map[kMaxShape];
|
||||
CheckAndConvertUtils::CheckInRange("bias_add_infer", input_shape.size(), kIncludeBoth, {2, 5}, prim_name);
|
||||
const int64_t x_min_rank = 2;
|
||||
const int64_t x_max_rank = 5;
|
||||
CheckAndConvertUtils::CheckInRange("bias_add_infer", input_shape.size(), kIncludeBoth, {x_min_rank, x_max_rank},
|
||||
prim_name);
|
||||
auto bias_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("bias rank", SizeToLong(bias_shape.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, 2, prim_name);
|
||||
const int64_t x_size = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, x_size, prim_name);
|
||||
auto data_format_ptr = primitive->GetAttr("format");
|
||||
int64_t data_format = Format::NCHW;
|
||||
if (data_format_ptr != nullptr) {
|
||||
|
@ -71,7 +76,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("biasadd_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("biasadd_infer", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -34,9 +34,9 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
|
||||
std::vector<int64_t> infer_shape;
|
||||
if (weight_shape.size() < 1) {
|
||||
|
@ -50,19 +50,20 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
|
||||
TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("binary_cross_entropy_infer", SizeToLong(input_args.size()), kEqual, 3,
|
||||
prim->name());
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("binary_cross_entropy_infer", SizeToLong(input_args.size()), kEqual,
|
||||
input_num, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x_shape", input_args[0]->BuildType());
|
||||
(void)types.emplace("y_shape", input_args[1]->BuildType());
|
||||
(void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("y_shape", input_args[kInputIndex1]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
if (input_args[3]->BuildType() != nullptr) {
|
||||
(void)types.emplace("x_shape", input_args[0]->BuildType());
|
||||
(void)types.emplace("weight_shape", input_args[2]->BuildType());
|
||||
if (input_args[kInputIndex3]->BuildType() != nullptr) {
|
||||
(void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("weight_shape", input_args[kInputIndex2]->BuildType());
|
||||
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
return infer_type;
|
||||
|
|
|
@ -109,8 +109,9 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
auto w_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto x_shape = x_shape_map[kShape];
|
||||
auto w_shape = w_shape_map[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kEqual, 4, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("w shape size", SizeToLong(w_shape.size()), kEqual, 4, primitive->name());
|
||||
const int64_t shape_size = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kEqual, shape_size, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("w shape size", SizeToLong(w_shape.size()), kEqual, shape_size, prim_name);
|
||||
auto x_min_shape = x_shape_map[kMinShape];
|
||||
auto x_max_shape = x_shape_map[kMaxShape];
|
||||
auto w_min_shape = w_shape_map[kMinShape];
|
||||
|
@ -252,7 +253,8 @@ void Conv2D::set_pad_mode(const PadMode &pad_mode) {
|
|||
}
|
||||
|
||||
void Conv2D::set_pad(const std::vector<int64_t> &pad) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
const int64_t pad_size = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
|
||||
(void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
|
@ -316,7 +318,8 @@ Format Conv2D::get_format() const {
|
|||
|
||||
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, 2,
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
|
|
|
@ -53,7 +53,8 @@ void Conv2DTranspose::set_out_channel(int64_t out_channel) {
|
|||
}
|
||||
|
||||
void Conv2DTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
|
||||
const int64_t kernel_len = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, kernel_len, name());
|
||||
for (int64_t item : kernel_size) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
|
||||
}
|
||||
|
@ -61,7 +62,8 @@ void Conv2DTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
|||
}
|
||||
|
||||
void Conv2DTranspose::set_stride(const std::vector<int64_t> &stride) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name());
|
||||
const int64_t stride_size = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, stride_size, name());
|
||||
for (int64_t item : stride) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
|
||||
}
|
||||
|
@ -69,7 +71,9 @@ void Conv2DTranspose::set_stride(const std::vector<int64_t> &stride) {
|
|||
}
|
||||
|
||||
void Conv2DTranspose::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name());
|
||||
const int64_t dilation_size = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, dilation_size,
|
||||
name());
|
||||
(void)AddAttr(kDilation, MakeValue(dilation));
|
||||
}
|
||||
|
||||
|
@ -87,7 +91,8 @@ void Conv2DTranspose::set_pad_mode(const PadMode &pad_mode) {
|
|||
}
|
||||
|
||||
void Conv2DTranspose::set_pad(const std::vector<int64_t> &pad) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
|
||||
const int64_t pad_size = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
|
||||
(void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
|
@ -105,7 +110,8 @@ void Conv2DTranspose::set_format(const Format &format) {
|
|||
}
|
||||
|
||||
void Conv2DTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name());
|
||||
const int64_t pad_size = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, pad_size, name());
|
||||
(void)this->AddAttr(kPadList, MakeValue(pad_list));
|
||||
}
|
||||
|
||||
|
|
|
@ -44,7 +44,8 @@ AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -29,7 +29,9 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const std::string &op_name) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 4, op_name);
|
||||
const int64_t input_num = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
op_name);
|
||||
auto inputs = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
|
||||
auto labels_indices = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
|
||||
auto labels_values = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 2);
|
||||
|
@ -40,11 +42,18 @@ void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const st
|
|||
auto labels_values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(labels_values->BuildShape())[kShape];
|
||||
auto sequence_length_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(sequence_length->BuildShape())[kShape];
|
||||
|
||||
(void)CheckAndConvertUtils::CheckInteger("inputs rank", inputs_shape.size(), kEqual, 3, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("label_indices rank", labels_indices_shape.size(), kEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("label_indices second dim", labels_indices_shape[1], kEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("label_values rank", labels_values_shape.size(), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("sequence_length rank", sequence_length_shape.size(), kEqual, 1, op_name);
|
||||
const int64_t input_size = 3;
|
||||
const int64_t label_indice_size = 2;
|
||||
const int64_t label_indice_last_dim = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("inputs rank", SizeToLong(inputs_shape.size()), kEqual, input_size, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("label_indices rank", SizeToLong(labels_indices_shape.size()), kEqual,
|
||||
label_indice_size, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("label_indices second dim", labels_indices_shape[1], kEqual,
|
||||
label_indice_last_dim, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("label_values rank", int64_t(labels_values_shape.size()), kEqual, 1,
|
||||
op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("sequence_length rank", int64_t(sequence_length_shape.size()), kEqual, 1,
|
||||
op_name);
|
||||
|
||||
if (labels_indices_shape[0] != labels_values_shape[0]) {
|
||||
MS_EXCEPTION(ValueError) << "For CTCLoss first dim of label_indices and label_value must be same, but got "
|
||||
|
@ -82,11 +91,15 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
|
||||
TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[1]->BuildType(), {kInt64}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[2]->BuildType(), {kInt32}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[3]->BuildType(), {kInt32}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[kInputIndex1]->BuildType(), {kInt64},
|
||||
op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[kInputIndex2]->BuildType(), {kInt32},
|
||||
op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[kInputIndex3]->BuildType(), {kInt32},
|
||||
op_name);
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), valid_types, op_name);
|
||||
auto type =
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[kInputIndex0]->BuildType(), valid_types, op_name);
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -44,7 +44,8 @@ AbstractBasePtr CumSumInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -59,7 +59,8 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
if (format == NHWC) {
|
||||
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
const int64_t x_rank = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, x_rank, prim_name);
|
||||
int64_t block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)",
|
||||
x_shape[1] % (block_size * block_size), kEqual, 0, prim_name);
|
||||
|
|
|
@ -117,14 +117,14 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("detection_post_process_infer", SizeToLong(input_args.size()), kEqual, 3,
|
||||
prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[2]);
|
||||
auto boxes = input_args[0];
|
||||
auto scores = input_args[1];
|
||||
auto anchors = input_args[2];
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
|
||||
auto boxes = input_args[kInputIndex0];
|
||||
auto scores = input_args[kInputIndex1];
|
||||
auto anchors = input_args[kInputIndex2];
|
||||
auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(boxes->BuildShape())[kShape];
|
||||
auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(scores->BuildShape())[kShape];
|
||||
auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(anchors->BuildShape())[kShape];
|
||||
|
|
|
@ -60,7 +60,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION(ValueError) << "DropoutDoMask input mask do not match input, input_x shape: " << x_shape->ToString()
|
||||
<< ", mask shape: " << mask_shape->ToString();
|
||||
}
|
||||
auto keep_prop = input_args[2];
|
||||
auto keep_prop = input_args[kInputIndex2];
|
||||
if (keep_prop->isa<abstract::AbstractTensor>()) {
|
||||
auto keep_prop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(keep_prop->BuildShape())[kShape];
|
||||
if (!keep_prop_shape.empty()) {
|
||||
|
@ -72,7 +72,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
auto keep_prop = input_args[2];
|
||||
auto keep_prop = input_args[kInputIndex2];
|
||||
MS_EXCEPTION_IF_NULL(keep_prop);
|
||||
auto keep_prop_value = keep_prop->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(keep_prop_value);
|
||||
|
@ -116,7 +116,9 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
AbstractBasePtr DropoutDoMaskInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 3, primitive->name());
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer shape", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutDoMask, prim::kPrimDropoutDoMask, DropoutDoMaskInfer, nullptr, true);
|
||||
|
|
|
@ -94,7 +94,9 @@ ShapeVector CalOutputShape(const AbstractBasePtrList shape_list) {
|
|||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer shape", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
op_name);
|
||||
AbstractBasePtr shape_args = input_args[0];
|
||||
MS_EXCEPTION_IF_NULL(shape_args);
|
||||
|
||||
|
|
|
@ -55,7 +55,8 @@ size_t CheckInputsAndGetShape(const AbstractBasePtr &input_arg, const string &pr
|
|||
abstract::TupleShapePtr Infer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
auto x_shape = CheckInputsAndGetShape(input_args[0], prim_name);
|
||||
auto y_shape = CheckInputsAndGetShape(input_args[1], prim_name);
|
||||
|
||||
|
|
|
@ -37,7 +37,8 @@ AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -47,7 +48,8 @@ AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const
|
|||
MS_EXCEPTION_IF_NULL(indices);
|
||||
const std::set<TypePtr> int_valid_types = {kInt8, kInt16, kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices->BuildType(), int_valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("offset", input_args[2]->BuildType(), int_valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("offset", input_args[kInputIndex2]->BuildType(), int_valid_types,
|
||||
prim_name);
|
||||
MS_EXCEPTION_IF_NULL(params->shape());
|
||||
auto params_shp = params->shape()->shape();
|
||||
MS_EXCEPTION_IF_NULL(indices->shape());
|
||||
|
|
|
@ -26,21 +26,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
AbstractBasePtr EqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -50,14 +42,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr EqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
(void)InferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), kBool);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||
auto out_shape = BroadCastInferShape(op_name, input_args);
|
||||
return abstract::MakeAbstract(out_shape, kBool);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameEqual, Equal);
|
||||
} // namespace ops
|
||||
|
|
|
@ -31,7 +31,8 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -29,9 +29,9 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kGreaterEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("min_shape", SizeToLong(min_shape.size()), kEqual, 1, prim_name);
|
||||
|
@ -51,9 +51,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("min", input_args[1]->BuildType());
|
||||
(void)types.emplace("max", input_args[2]->BuildType());
|
||||
(void)types.emplace("x", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("min", input_args[kInputIndex1]->BuildType());
|
||||
(void)types.emplace("max", input_args[kInputIndex2]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -44,17 +44,17 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name);
|
||||
CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name);
|
||||
CheckAndConvertUtils::Check("min shape", min_shape[0], kEqual, "x shape", x_shape[x_shape.size() - 1], op_name);
|
||||
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
auto min_type = input_args[1]->BuildType();
|
||||
auto max_type = input_args[2]->BuildType();
|
||||
auto x_type = input_args[kInputIndex0]->BuildType();
|
||||
auto min_type = input_args[kInputIndex1]->BuildType();
|
||||
auto max_type = input_args[kInputIndex2]->BuildType();
|
||||
std::vector<std::string> type_name = {"x", "min", "max"};
|
||||
std::vector<TypePtr> type = {x_type, min_type, max_type};
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
|
|
|
@ -26,11 +26,12 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto input_dtype = input_args[0]->cast<abstract::AbstractTypePtr>();
|
||||
auto input_dtype = input_args[kInputIndex0]->cast<abstract::AbstractTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_dtype);
|
||||
auto dtype_value = input_dtype->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(dtype_value);
|
||||
|
@ -39,10 +40,10 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
auto valid_types = common_valid_types;
|
||||
valid_types.insert(kBool);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("output datatype", dtype, valid_types, prim_name);
|
||||
auto out_shape = GetValue<std::vector<int64_t>>(input_args[1]->BuildValue());
|
||||
auto x_type = input_args[2]->BuildType();
|
||||
auto out_shape = GetValue<std::vector<int64_t>>(input_args[kInputIndex1]->BuildValue());
|
||||
auto x_type = input_args[kInputIndex2]->BuildType();
|
||||
auto x_type_id = x_type->type_id();
|
||||
auto x_value = input_args[2]->BuildValue();
|
||||
auto x_value = input_args[kInputIndex2]->BuildValue();
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape));
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(x_type_id, out_shape);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
|
@ -54,7 +55,7 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
auto float_value = GetValue<float>(x_value);
|
||||
SetTensorData(tensor->data_c(), float_value, mem_size);
|
||||
} else {
|
||||
MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[2]->ToString();
|
||||
MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[kInputIndex2]->ToString();
|
||||
}
|
||||
abs->set_value(tensor);
|
||||
return abs;
|
||||
|
|
|
@ -47,8 +47,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("y", input_args[kInputIndex1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -62,7 +62,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
if (format == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
|
||||
const int64_t x_rank = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_rank, op_name);
|
||||
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
|
||||
auto batch = in_shape[0];
|
||||
|
|
|
@ -40,7 +40,8 @@ void Conv2dTransposeFusion::Init(int64_t in_channel, int64_t out_channel, const
|
|||
}
|
||||
|
||||
void Conv2dTransposeFusion::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
|
||||
const size_t kernel_len = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, kernel_len, name());
|
||||
for (int64_t item : kernel_size) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
|
||||
}
|
||||
|
@ -48,7 +49,8 @@ void Conv2dTransposeFusion::set_kernel_size(const std::vector<int64_t> &kernel_s
|
|||
}
|
||||
|
||||
void Conv2dTransposeFusion::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kEqual, 2, name());
|
||||
const size_t dilation_size = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kEqual, dilation_size, name());
|
||||
for (int64_t item : dilation) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kDilation, item, kGreaterEqual, 1, name());
|
||||
}
|
||||
|
|
|
@ -62,8 +62,8 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
|
||||
auto input0 = input_args[0];
|
||||
auto input1 = input_args[1];
|
||||
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input0->BuildShape())[kShape];
|
||||
|
@ -93,7 +93,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
new_k = input1_shape[1];
|
||||
}
|
||||
if (has_bias) {
|
||||
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
if (input2_shape[0] != input1_shape[0]) {
|
||||
MS_EXCEPTION(ValueError) << "Bias size is invalid";
|
||||
}
|
||||
|
|
|
@ -36,9 +36,9 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
auto op_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_shape_len = x_shape.size();
|
||||
auto begin_v = input_args[1]->BuildValue();
|
||||
auto size_v = input_args[2]->BuildValue();
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
auto begin_v = input_args[kInputIndex1]->BuildValue();
|
||||
auto size_v = input_args[kInputIndex2]->BuildValue();
|
||||
auto x_type = input_args[kInputIndex0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type);
|
||||
MS_EXCEPTION_IF_NULL(begin_v);
|
||||
MS_EXCEPTION_IF_NULL(size_v);
|
||||
|
|
|
@ -14,10 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/gather.h"
|
||||
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "ops/gather.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -33,28 +35,30 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
|
||||
// check
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
(void)CheckAndConvertUtils::CheckSubClass("params_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("params_type", input_args[kInputIndex0]->BuildType(), valid_params_types,
|
||||
op_name);
|
||||
std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[1]->BuildType(), int_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[2]->BuildType(), int_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[kInputIndex1]->BuildType(), int_types,
|
||||
op_name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[kInputIndex2]->BuildType(), int_types, op_name);
|
||||
|
||||
bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
|
||||
bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
|
||||
int64_t axis_val = 0;
|
||||
// 3rd input is a Tensor when Gather is a dynamic shape operator
|
||||
if (input_args[2]->isa<abstract::AbstractTensor>()) {
|
||||
auto axis = input_args[2]->cast<abstract::AbstractTensorPtr>();
|
||||
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
|
||||
auto axis = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(axis);
|
||||
auto axis_value_ptr = axis->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(axis_value_ptr);
|
||||
auto axis_tensor = axis_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(axis_tensor);
|
||||
axis_val = *static_cast<int64_t *>(axis_tensor->data_c());
|
||||
} else if (input_args[2]->isa<abstract::AbstractScalar>()) {
|
||||
auto axis = input_args[2]->cast<abstract::AbstractScalarPtr>();
|
||||
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
|
||||
auto axis = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
|
||||
axis_val = GetValue<int64_t>(axis->BuildValue());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[2]->type_name();
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract type:" << input_args[kInputIndex2]->type_name();
|
||||
}
|
||||
auto params_shp = params->shape()->shape();
|
||||
auto indices_shp = indices->shape()->shape();
|
||||
|
|
|
@@ -29,11 +29,11 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
int64_t x_rank = x_shape.size();
|
||||
CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name);
|
||||
auto value_ptr = input_args[1]->BuildValue();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
int64_t x_rank = SizeToLong(x_shape.size());
|
||||
CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", SizeToLong(index_shape.size()), prim_name);
|
||||
auto value_ptr = input_args[kInputIndex1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
auto dim_v = GetValue<int64_t>(value_ptr);
|
||||
CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, "negative index_rank", -x_rank, prim_name);
|
||||
|
@@ -66,8 +66,9 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
auto prim_name = primitive->name();
|
||||
// check
|
||||
std::set<TypePtr> valid_types = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[kInputIndex2]->BuildType(), valid_types,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("dim_type", input_args[kInputIndex1]->BuildType(), valid_types, prim_name);
|
||||
return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
|
||||
|
|
|
@@ -28,7 +28,8 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@@ -30,7 +30,8 @@ constexpr size_t k5DInputDims = 5;
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@@ -48,7 +49,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@@ -58,10 +58,10 @@ AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);
|
||||
|
||||
auto dx = input_args[1]->Broaden();
|
||||
auto dscale = input_args[2]->Broaden();
|
||||
auto reserve_1 = input_args[3]->Broaden();
|
||||
auto reserve_2 = input_args[4]->Broaden();
|
||||
auto dx = input_args[kInputIndex1]->Broaden();
|
||||
auto dscale = input_args[kInputIndex2]->Broaden();
|
||||
auto reserve_1 = input_args[kInputIndex3]->Broaden();
|
||||
auto reserve_2 = input_args[kInputIndex4]->Broaden();
|
||||
|
||||
AbstractBasePtrList rets = {dx, dscale, dscale, reserve_1, reserve_2};
|
||||
return std::make_shared<abstract::AbstractTuple>(rets);
|
||||
|
|
|
@@ -27,10 +27,11 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
|
||||
if (weight_shape.size() < 1) {
|
||||
CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);
|
||||
}
|
||||
|
@@ -40,12 +41,12 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
|
|||
TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x_shape", input_args[0]->BuildType());
|
||||
(void)types.emplace("y_shape", input_args[1]->BuildType());
|
||||
(void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("y_shape", input_args[kInputIndex1]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
if (input_args[3]->BuildType() != nullptr) {
|
||||
(void)types.emplace("x_shape", input_args[0]->BuildType());
|
||||
(void)types.emplace("weight_shape", input_args[2]->BuildType());
|
||||
if (input_args[kInputIndex3]->BuildType() != nullptr) {
|
||||
(void)types.emplace("x_shape", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("weight_shape", input_args[kInputIndex2]->BuildType());
|
||||
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
return infer_type;
|
||||
|
|
|
@@ -29,7 +29,7 @@ abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto w_size_v = input_args[2]->BuildValue();
|
||||
auto w_size_v = input_args[kInputIndex2]->BuildValue();
|
||||
auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("w_size", w_size_v, prim_name);
|
||||
return std::make_shared<abstract::Shape>(ret_shape);
|
||||
}
|
||||
|
@@ -157,8 +157,9 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, 3, prim_name);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@@ -67,7 +67,7 @@ abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_size_v = input_args[2]->BuildValue();
|
||||
auto x_size_v = input_args[kInputIndex2]->BuildValue();
|
||||
auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("x_size", x_size_v, prim_name);
|
||||
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
|
||||
|
@@ -93,7 +93,9 @@ AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, co
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@@ -149,7 +151,8 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
|
|||
}
|
||||
|
||||
void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
|
||||
const int64_t pad_size = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
|
||||
(void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
|
|
|
@@ -162,8 +162,9 @@ AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, c
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("group_conv_2D_infer", SizeToLong(input_args.size()), kGreaterEqual, 2,
|
||||
prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("group_conv_2D_infer", SizeToLong(input_args.size()), kGreaterEqual,
|
||||
input_num, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
|
||||
// Infer shape
|
||||
|
|
|
@@ -26,10 +26,11 @@ AbstractBasePtr LayerNormGradInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
// Outputs: x_backprob, gamma_backprob, beta_backprob
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 5, op_name);
|
||||
auto x_backprob = input_args[0]->Broaden();
|
||||
auto gamma_backprob = input_args[4]->Broaden();
|
||||
auto beta_backprob = input_args[4]->Broaden();
|
||||
const int64_t input_num = 5;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num, op_name);
|
||||
auto x_backprob = input_args[kInputIndex0]->Broaden();
|
||||
auto gamma_backprob = input_args[kInputIndex4]->Broaden();
|
||||
auto beta_backprob = input_args[kInputIndex4]->Broaden();
|
||||
auto shapes = std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
|
||||
x_backprob->BuildShape(), gamma_backprob->BuildShape(), beta_backprob->BuildShape()});
|
||||
auto types = std::make_shared<Tuple>(
|
||||
|
|
|
@@ -30,7 +30,8 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@@ -47,7 +48,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("ReLUGrad infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
|
||||
|
|
|
@@ -27,13 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
|
@@ -41,9 +35,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return shape_element;
|
||||
}
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("ReLUGradV2 infer", input_args.size(), kEqual, 2, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto x_type_map = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type_map);
|
||||
|
@@ -55,7 +47,15 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
AbstractBasePtr ReLUGradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return abstract::MakeAbstract(InferShape(input_args), InferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGradV2, prim::kPrimReluGradV2, ReLUGradV2Infer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@@ -31,23 +31,24 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer",
|
||||
SizeToLong(input_args.size()), kEqual, 3, prim_name);
|
||||
SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
|
||||
// Infer Shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
|
||||
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
|
||||
(void)CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);
|
||||
|
||||
// Infer type
|
||||
const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
|
||||
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
|
||||
std::map<std::string, TypePtr> args;
|
||||
args.emplace("x_type", input_args[0]->BuildType());
|
||||
args.emplace("y_type", input_args[1]->BuildType());
|
||||
args.emplace("dout_type", input_args[2]->BuildType());
|
||||
args.emplace("x_type", input_args[kInputIndex0]->BuildType());
|
||||
args.emplace("y_type", input_args[kInputIndex1]->BuildType());
|
||||
args.emplace("dout_type", input_args[kInputIndex2]->BuildType());
|
||||
auto dout_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
||||
return std::make_shared<abstract::AbstractTensor>(dout_type, x_shape);
|
||||
}
|
||||
|
|
|
@@ -38,13 +38,14 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", SizeToLong(input_args.size()), kEqual, 3,
|
||||
prim_name);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", SizeToLong(input_args.size()), kEqual,
|
||||
input_num, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
|
||||
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError);
|
||||
|
||||
|
@@ -52,9 +53,9 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
|
|||
const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
|
||||
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
|
||||
std::map<std::string, TypePtr> args;
|
||||
args.emplace("prediction", input_args[0]->BuildType());
|
||||
args.emplace("target", input_args[1]->BuildType());
|
||||
args.emplace("dloss", input_args[2]->BuildType());
|
||||
args.emplace("prediction", input_args[kInputIndex0]->BuildType());
|
||||
args.emplace("target", input_args[kInputIndex1]->BuildType());
|
||||
args.emplace("dloss", input_args[kInputIndex2]->BuildType());
|
||||
auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
||||
|
||||
return std::make_shared<abstract::AbstractTensor>(dloss_type, prediction);
|
||||
|
|
|
@@ -14,10 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ops/hashtable_lookup.h"
|
||||
|
||||
#include <vector>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@@ -32,7 +33,7 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const
|
|||
(void)CheckAndConvertUtils::CheckInteger("logits size", SizeToLong(input.size()), kGreaterEqual, 1, op_name);
|
||||
hits_shape.push_back(input[0]);
|
||||
|
||||
auto value_type = input_args[2]->BuildType();
|
||||
auto value_type = input_args[kInputIndex2]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
auto tensor_type = value_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
|
|
|
@@ -33,15 +33,17 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
const int64_t input0_size = 2;
|
||||
const int64_t input0_last_dim = 32;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
|
||||
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input0 rank", SizeToLong(input0.size()), kEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name);
|
||||
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input0 rank", SizeToLong(input0.size()), kEqual, input0_size, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, input0_last_dim, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input1 rank", SizeToLong(input1.size()), kGreaterEqual, 1, op_name);
|
||||
|
||||
if (input_args.size() == 3) {
|
||||
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input2 rank", SizeToLong(input2.size()), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name);
|
||||
}
|
||||
|
|
|
@@ -32,16 +32,23 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
|
|||
// infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("lstm_prim_infer", SizeToLong(input_args.size()), kEqual, 4, prim_name);
|
||||
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
const int64_t input_num = 4;
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
|
||||
int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape.size()", SizeToLong(x_input_shape.size()), kEqual, 3, prim_name);
|
||||
const int64_t shape_size = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape.size()", SizeToLong(x_input_shape.size()), kEqual, shape_size,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape[2]", x_input_shape[2], kEqual, input_x_size, prim_name);
|
||||
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape.size()", SizeToLong(h_input_shape.size()), kEqual, 3, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape.size()", SizeToLong(h_input_shape.size()), kEqual, shape_size,
|
||||
prim_name);
|
||||
CheckAndConvertUtils::Check("h_shape", h_input_shape, kEqual, "c_shape", c_input_shape, prim_name);
|
||||
|
||||
int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers));
|
||||
|
@@ -81,15 +88,11 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
|
|||
std::vector<int64_t> state_shape = {1, 1};
|
||||
|
||||
// infer type
|
||||
(void)CheckAndConvertUtils::CheckInteger("lstm_prim_infer", SizeToLong(input_args.size()), kEqual, 4, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto infer_type0 = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto infer_type1 = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto infer_type2 = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto infer_type3 = input_args[3]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto infer_type4 = input_args[4]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto infer_type0 = input_args[kInputIndex0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto infer_type1 = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto infer_type2 = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto infer_type3 = input_args[kInputIndex3]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto infer_type4 = input_args[kInputIndex4]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto output0 = std::make_shared<abstract::AbstractTensor>(infer_type0, x_shape);
|
||||
auto output1 = std::make_shared<abstract::AbstractTensor>(infer_type1, y_shape);
|
||||
auto output2 = std::make_shared<abstract::AbstractTensor>(infer_type2, c_shape);
|
||||
|
|
|
@@ -125,7 +125,8 @@ bool MatMul::get_transpose_b() const {
|
|||
|
||||
AbstractBasePtr MatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("MatMul infer", SizeToLong(input_args.size()), kGreaterEqual, 2,
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("MatMul infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
return abstract::MakeAbstract(MatMulInferShape(primitive, input_args), MatMulInferType(primitive, input_args));
|
||||
}
|
||||
|
|
|
@@ -88,7 +88,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
if (format == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
|
||||
const int64_t x_rank = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_rank, op_name);
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
auto pad_mode_value = (primitive->GetAttr(kPadMode));
|
||||
|
|
|
@@ -25,10 +25,14 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input0_size = 3;
|
||||
const int64_t input1_size = 1;
|
||||
auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input 0 rank", SizeToLong(first_input_shape.size()), kEqual, 3, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input 1 rank", SizeToLong(second_input_shape.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input 0 rank", SizeToLong(first_input_shape.size()), kEqual, input0_size,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input 1 rank", SizeToLong(second_input_shape.size()), kEqual, input1_size,
|
||||
prim_name);
|
||||
std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1],
|
||||
GetValue<int64_t>(primitive->GetAttr(kDctCoeffNum))};
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
|
|
|
@@ -35,14 +35,16 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim->name());
|
||||
auto op_name = prim->name();
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, op_name);
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@@ -30,7 +30,8 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@@ -42,7 +43,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("Mul infer", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
op_name);
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
|
|
|
@@ -27,23 +27,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
|
@@ -53,8 +38,15 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
(void)InferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), kBool);
|
||||
return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), kBool);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
|
||||
} // namespace ops
|
||||
|
|
|
@@ -59,11 +59,12 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
|
||||
TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32, kInt64}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64},
|
||||
op_name);
|
||||
std::map<std::string, TypePtr> args = {{"on_value", input_args[2]->BuildType()},
|
||||
{"off_dtype", input_args[3]->BuildType()}};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[kInputIndex0]->BuildType(), {kInt32, kInt64},
|
||||
op_name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("depth", input_args[kInputIndex1]->BuildType(),
|
||||
{kInt8, kInt16, kInt32, kInt64}, op_name);
|
||||
std::map<std::string, TypePtr> args = {{"on_value", input_args[kInputIndex2]->BuildType()},
|
||||
{"off_dtype", input_args[kInputIndex3]->BuildType()}};
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name);
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@@ -252,6 +252,27 @@ constexpr auto kSplitDim = "split_dim";
|
|||
constexpr auto kPadTop = "pad_top";
|
||||
constexpr auto kTransFormat = "trans_format";
|
||||
constexpr auto kApproximate = "approximate";
|
||||
|
||||
enum Index : size_t {
|
||||
kInputIndex0 = 0,
|
||||
kInputIndex1,
|
||||
kInputIndex2,
|
||||
kInputIndex3,
|
||||
kInputIndex4,
|
||||
kInputIndex5,
|
||||
kInputIndex6,
|
||||
kInputIndex7,
|
||||
kInputIndex8,
|
||||
kInputIndex9,
|
||||
kInputIndex10,
|
||||
kInputIndex11,
|
||||
kInputIndex12,
|
||||
kInputIndex13,
|
||||
kInputIndex14,
|
||||
kInputIndex15,
|
||||
kInputIndex16,
|
||||
};
|
||||
|
||||
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
|
||||
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
|
||||
|
||||
|
|
|
@@ -26,9 +26,10 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto w = input_args[1]->BuildShape();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape];
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape];
|
||||
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kGreaterEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, 1, prim_name);
|
||||
const int64_t x_rank = 2;
|
||||
const int64_t w_rank = 1;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kGreaterEqual, x_rank, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, w_rank, prim_name);
|
||||
if (w_shape[0] != x_shape[1] && w_shape[0] != 1) {
|
||||
MS_LOG(EXCEPTION) << "For " << prim_name << ", channel of input_x and weight must be matched, "
|
||||
<< "while channel of input_x is " << x_shape[1] << ", weight_shape[0] is " << w_shape[0];
|
||||
|
|
|
@@ -61,13 +61,14 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
int64_t shape_size = 0;
|
||||
if (input_args.size() == 3) {
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
|
||||
MS_EXCEPTION_IF_NULL(input_args[1]->BuildValue());
|
||||
MS_EXCEPTION_IF_NULL(input_args[2]->BuildValue());
|
||||
auto start_tensor = input_args[0]->BuildValue()->cast<tensor::TensorPtr>();
|
||||
auto limit_tensor = input_args[1]->BuildValue()->cast<tensor::TensorPtr>();
|
||||
auto delta_tensor = input_args[2]->BuildValue()->cast<tensor::TensorPtr>();
|
||||
const size_t max_input_num = 3;
|
||||
if (input_args.size() == max_input_num) {
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]->BuildValue());
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]->BuildValue());
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]->BuildValue());
|
||||
auto start_tensor = input_args[kInputIndex0]->BuildValue()->cast<tensor::TensorPtr>();
|
||||
auto limit_tensor = input_args[kInputIndex1]->BuildValue()->cast<tensor::TensorPtr>();
|
||||
auto delta_tensor = input_args[kInputIndex2]->BuildValue()->cast<tensor::TensorPtr>();
|
||||
auto dtype = start_tensor->data_type();
|
||||
switch (dtype) {
|
||||
case kNumberTypeInt:
|
||||
|
|
|
@@ -29,7 +29,8 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@@ -42,7 +43,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("RealDiv infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kGreaterEqual, input_num, op_name);
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
|
|
|
@@ -31,7 +31,8 @@ AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@@ -50,7 +50,8 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
|
||||
// Infer shape
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_shape_rank", SizeToLong(input_shape.size()), kEqual, 4, prim_name);
|
||||
const int64_t shape_size = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input rank", SizeToLong(input_shape.size()), kEqual, shape_size, prim_name);
|
||||
std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]};
|
||||
auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize));
|
||||
(void)out_shape.insert(out_shape.end(), size.begin(), size.end());
|
||||
|
|
|
@@ -39,7 +39,9 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@@ -52,7 +52,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("roi_pooling_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
||||
|
||||
|
|
|
@@ -24,13 +24,13 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto shape_value = input_args[2]->BuildValue();
|
||||
auto shape_value = input_args[kInputIndex2]->BuildValue();
|
||||
auto shape_value_element = GetValue<std::vector<int64_t>>(shape_value);
|
||||
for (const auto &shape : shape_value_element) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("shape value", shape, kGreaterThan, 0, "ScatterNd");
|
||||
}
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual,
|
||||
update_shape[0], "ScatterNd");
|
||||
return std::make_shared<abstract::Shape>(shape_value_element);
|
||||
|
|
|
@@ -30,8 +30,9 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", SizeToLong(input_args.size()),
|
||||
kEqual, 2, prim_name);
|
||||
kEqual, input_num, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
|
|
@@ -30,7 +30,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input shape", SizeToLong(input_shape.size()), kEqual, 4, prim_name);
|
||||
const int64_t x_rank = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kEqual, x_rank, prim_name);
|
||||
std::vector<int64_t> output_shape(input_shape.size());
|
||||
auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
|
||||
auto paddings = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
|
||||
|
|
|
@@ -30,7 +30,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
const int64_t shape_size = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, shape_size, prim_name);
|
||||
auto out_shape = x_shape;
|
||||
int64_t block_shape_prod = 1;
|
||||
const size_t offset = 2;
|
||||
|
@@ -60,7 +61,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
|
||||
void SpaceToBatchND::set_paddings(std::vector<std::vector<int64_t>> paddings) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings.size()), kEqual, 2, this->name());
|
||||
const int64_t pad_size = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings.size()), kEqual, pad_size, this->name());
|
||||
size_t h = paddings.size();
|
||||
size_t w = paddings[0].size();
|
||||
std::vector<size_t> temp_w = {2, 2};
|
||||
|
@@ -78,7 +80,9 @@ std::vector<std::vector<int64_t>> SpaceToBatchND::get_paddings() const {
|
|||
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
||||
}
|
||||
void SpaceToBatchND::set_block_shape(std::vector<int64_t> block_shape) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
|
||||
const int64_t block_size = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, block_size,
|
||||
this->name());
|
||||
for (size_t i = 0; i < block_shape.size(); i++) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape[i]), kGreaterEqual, 1, this->name());
|
||||
}
|
||||
|
|
|
@@ -38,7 +38,9 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@@ -28,12 +28,13 @@ AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 3, prim_name);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
// infer shape
|
||||
auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
|
||||
auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||
// infer type
|
||||
auto values_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
return std::make_shared<abstract::AbstractTensor>(values_type, dense_shape);
|
||||
|
|
|
@@ -173,18 +173,18 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
|
||||
MS_EXCEPTION_IF_NULL(strided_slice_prim);
|
||||
auto tuple_begin_v = input_args[1]->cast<abstract::AbstractTuplePtr>();
|
||||
auto tuple_begin_v = input_args[kInputIndex1]->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_begin_v);
|
||||
auto temp_begin_v = tuple_begin_v->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(temp_begin_v);
|
||||
auto begin_v = GetValue<std::vector<int64_t>>(temp_begin_v);
|
||||
|
||||
auto tuple_end_v = input_args[2]->cast<abstract::AbstractTuplePtr>();
|
||||
auto tuple_end_v = input_args[kInputIndex2]->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_end_v);
|
||||
auto temp_end_v = tuple_end_v->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(temp_end_v);
|
||||
auto end_v = GetValue<std::vector<int64_t>>(temp_end_v);
|
||||
auto strides_v = CheckAndGetValidStrides(input_args[3]);
|
||||
auto strides_v = CheckAndGetValidStrides(input_args[kInputIndex3]);
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
|
||||
|
|
|
@@ -30,7 +30,8 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@@ -42,7 +43,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("Sub infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kGreaterEqual, input_num, op_name);
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
|
|
|
@@ -77,7 +77,8 @@ abstract::ShapePtr TileInferShape(const PrimitivePtr &primitive, const std::vect
|
|||
TypePtr TileInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("tile_prim_infer", input_args.size(), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@@ -32,7 +32,8 @@ AbstractBasePtr TopKInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("top_k_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("top_k_infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
|
||||
// Infer dtype
|
||||
for (const auto &item : input_args) {
|
||||
|
|
|
@@ -55,11 +55,11 @@ AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, con
|
|||
}
|
||||
|
||||
const std::set<TypePtr> valid_num_segments_types = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[2]->BuildType(), valid_num_segments_types,
|
||||
prim_name);
|
||||
int64_t size_segment_ids_shp = segment_ids_shape.size();
|
||||
int64_t size_x_shpe = x_shape.size();
|
||||
for (int64_t i = size_segment_ids_shp; i < size_x_shpe; ++i) {
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[kInputIndex2]->BuildType(),
|
||||
valid_num_segments_types, prim_name);
|
||||
auto size_segment_ids_shp = segment_ids_shape.size();
|
||||
auto size_x_shpe = x_shape.size();
|
||||
for (size_t i = size_segment_ids_shp; i < size_x_shpe; ++i) {
|
||||
(void)shp.emplace_back(x_shape[i]);
|
||||
}
|
||||
|
||||
|
|
|
@@ -29,16 +29,18 @@ AbstractBasePtr WhereInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
MS_EXCEPTION_IF_NULL(input);
|
||||
}
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 3, op_name);
|
||||
auto input0_type_ = input_args[0]->BuildType()->cast<TensorTypePtr>();
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
op_name);
|
||||
auto input0_type_ = input_args[kInputIndex0]->BuildType()->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input0_type_);
|
||||
auto input0_type = input0_type_->element();
|
||||
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto num = input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
|
||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto num1 = input_args[1]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
|
||||
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
auto num2 = input_args[2]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
|
||||
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto num = input_args[kInputIndex0]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
|
||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto num1 = input_args[kInputIndex1]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
|
||||
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
auto num2 = input_args[kInputIndex2]->BuildValue()->cast<tensor::TensorPtr>()->ElementsNum();
|
||||
auto nummax = num > num1 ? num : (num1 > num2 ? num1 : num2);
|
||||
size_t axisout = 0;
|
||||
size_t temp = 0;
|
||||