code check clean

This commit is contained in:
王南 2021-09-14 15:23:56 +08:00
parent f5c3901a3a
commit bd4707f7c9
33 changed files with 101 additions and 94 deletions

View File

@ -80,7 +80,7 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
auto pad_mode_ptr = primitive->GetAttr("pad_mode");
if (pad_mode_ptr != nullptr) {
int64_t pad_mode;
(void)CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true);
CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true);
if (pad_mode == PadMode::VALID) {
padding = 0;
} else if (pad_mode == PadMode::SAME) {
@ -298,7 +298,7 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
std::vector<int64_t> padding =
CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), padding_start_idx, padding_num_element);
int64_t pad_mode;
(void)CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
std::vector<int64_t> output_hw;
std::vector<int64_t> pad_list;
std::vector<int64_t> output_hw_min;

View File

@ -47,7 +47,9 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
(void)CheckAndConvertUtils::CheckInteger("condition's rank", SizeToLong(condition_values.size()), kLessEqual, 1,
op_name);
if (condition_values.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("condition[0]", SizeToLong(condition_values[0]), kEqual, 1, op_name);
if (!condition_values[0]) {
MS_EXCEPTION(ValueError) << "condition value must be `true` when only one value contained.";
}
}
condition = TypeIdToType(kNumberTypeBool);
} else {
@ -56,7 +58,9 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
if (condition_shape[0] == 1) {
auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c());
MS_EXCEPTION_IF_NULL(condition_value);
(void)CheckAndConvertUtils::CheckInteger("condition[0]", *condition_value, kEqual, 1, op_name);
if (!*condition_value) {
MS_EXCEPTION(ValueError) << "condition value must be `true` when only one value contained.";
}
}
condition = input_args[0]->BuildType();
}

View File

@ -113,7 +113,7 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
pad_needed_w = std::max((int64_t)0, pad_needed_w);
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2)));
pad_list->push_back(pad_needed_w - pad_list->at(2));
pad_list->push_back(pad_needed_w - pad_list->at(kInputIndex2));
}
} else if (pad_mode == PadMode::PAD) {
(void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
@ -128,9 +128,10 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
}
}
if (x_w != Shape::SHP_ANY) {
out_w = static_cast<int64_t>(std::floor(
1 + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) /
stride[1]));
out_w =
static_cast<int64_t>(std::floor(1 + ((x_w * 1.0) + pad_list->at(kInputIndex2) + pad_list->at(kInputIndex3) -
kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) /
stride[1]));
if (is_min_shape && out_w < 1) {
out_w = 1L;
}

View File

@ -27,7 +27,8 @@ namespace {
abstract::ShapePtr DiagInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input rank", input_shape.size(), kGreaterEqual, 1, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("input rank", SizeToLong(input_shape.size()), kGreaterEqual, 1L,
primitive->name());
std::vector<int64_t> out_shape(input_shape);
(void)out_shape.insert(out_shape.end(), input_shape.begin(), input_shape.end());
return std::make_shared<abstract::Shape>(out_shape);

View File

@ -29,6 +29,7 @@
namespace mindspore {
namespace ops {
namespace {
const int64_t mask_convert_len = 128;
ShapeVector CalDynamicOutputShape(const ValuePtrList value_list) {
int64_t count = 1;
size_t x_rank = value_list.size();
@ -52,8 +53,8 @@ ShapeVector CalDynamicOutputShape(const ValuePtrList value_list) {
}
// convert to bytes(8 bits) mask, using round up
int64_t n128s = count / 128;
if ((count % 128) != 0) {
int64_t n128s = count / mask_convert_len;
if ((count % mask_convert_len) != 0) {
n128s++;
}
int64_t bytes_count = n128s * 16;
@ -87,8 +88,8 @@ ShapeVector CalOutputShape(const AbstractBasePtrList shape_list) {
}
// convert to bytes(8 bits) mask, using round up
int64_t n128s = count / 128;
if ((count % 128) != 0) {
int64_t n128s = count / mask_convert_len;
if ((count % mask_convert_len) != 0) {
n128s++;
}
int64_t bytes_count = n128s * 16;
@ -115,7 +116,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
if (shape->shape().size() != 1) {
MS_EXCEPTION(TypeError) << "Input `shape` must be a 1-D Tensor.";
}
size_t shape_rank = shape->shape()[0];
size_t shape_rank = LongToSize(shape->shape()[0]);
auto shape_max = shape_abstract->get_max_value();
MS_EXCEPTION_IF_NULL(shape_max);

View File

@ -27,7 +27,7 @@
namespace mindspore {
namespace ops {
namespace {
size_t CheckInputsAndGetShape(const AbstractBasePtr &input_arg, const string &prim_name) {
int64_t CheckInputsAndGetShape(const AbstractBasePtr &input_arg, const string &prim_name) {
MS_EXCEPTION_IF_NULL(input_arg);
if (input_arg->isa<abstract::AbstractTensor>()) {
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_arg->BuildShape())[kShape];
@ -57,13 +57,13 @@ abstract::TupleShapePtr Infer(const PrimitivePtr &primitive, const std::vector<A
auto prim_name = primitive->name();
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
auto x_shape = CheckInputsAndGetShape(input_args[0], prim_name);
auto y_shape = CheckInputsAndGetShape(input_args[1], prim_name);
auto x_shape0 = CheckInputsAndGetShape(input_args[0], prim_name);
auto y_shape0 = CheckInputsAndGetShape(input_args[1], prim_name);
ShapeVector shape{abstract::Shape::SHP_ANY};
ShapeVector min_shape{1L};
size_t max_size = x_shape > y_shape ? x_shape : y_shape;
ShapeVector max_shape{SizeToLong(max_size)};
int64_t max_size = x_shape0 > y_shape0 ? x_shape0 : y_shape0;
ShapeVector max_shape{max_size};
auto out_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape, out_shape});
}

View File

@ -30,9 +30,9 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
if (weight_shape.size() < 1) {
(void)CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);
CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);
}
return std::make_shared<abstract::Shape>(x_shape);
}

View File

@ -43,9 +43,9 @@ void TransStrideTo4D(const PrimitivePtr &primitive, const std::vector<AbstractBa
auto stride_value = GetValue<std::vector<int64_t>>(stride);
if (stride_value.size() == kStride2dSize) {
std::vector<int64_t> stride_value_4d(stride_value);
stride_value_4d.insert(stride_value_4d.begin(), 1);
stride_value_4d.insert(stride_value_4d.begin(), 1);
(void)primitive->set_attr(kStride, MakeValue(stride_value_4d));
(void)stride_value_4d.insert(stride_value_4d.begin(), 1);
(void)stride_value_4d.insert(stride_value_4d.begin(), 1);
primitive->set_attr(kStride, MakeValue(stride_value_4d));
}
return;
}

View File

@ -41,7 +41,7 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
// default pad mode is valid
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
int64_t pad_mode;
(void)CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
ShapeVector pad_list = {0, 0, 0, 0};
if (!attr_pad_list_prt->isa<None>()) {
pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
@ -56,14 +56,18 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
int64_t pad_bottom = abstract::Shape::SHP_ANY;
int64_t pad_left = abstract::Shape::SHP_ANY;
int64_t pad_right = abstract::Shape::SHP_ANY;
if (dout_shape_norm[2] != abstract::Shape::SHP_ANY && x_size_v[2] != abstract::Shape::SHP_ANY) {
auto pad_needed_h = (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2];
if (dout_shape_norm[kInputIndex2] != abstract::Shape::SHP_ANY &&
x_size_v[kInputIndex2] != abstract::Shape::SHP_ANY) {
auto pad_needed_h =
(dout_shape_norm[kInputIndex2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[kInputIndex2];
pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h;
pad_top = pad_needed_h / 2;
pad_bottom = pad_needed_h - pad_top;
}
if (dout_shape_norm[3] != abstract::Shape::SHP_ANY && x_size_v[3] != abstract::Shape::SHP_ANY) {
auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3];
if (dout_shape_norm[kInputIndex3] != abstract::Shape::SHP_ANY &&
x_size_v[kInputIndex3] != abstract::Shape::SHP_ANY) {
auto pad_needed_w =
(dout_shape_norm[kInputIndex3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[kInputIndex3];
pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L;
pad_left = pad_needed_w / 2;
pad_right = pad_needed_w - pad_left;
@ -203,10 +207,10 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
std::vector<int64_t> pad = get_pad();
if (pad_mode == PAD) {
for (auto item : pad) {
(void)CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name());
CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name());
}
} else {
(void)CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
}
int64_t swi = pad_mode;
(void)AddAttr(kPadMode, MakeValue(swi));

View File

@ -29,7 +29,8 @@ abstract::ShapePtr HShrinkGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, input_num, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
primitive->name());
auto prim_name = primitive->name();
auto gradients_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto features_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
@ -42,7 +43,8 @@ abstract::ShapePtr HShrinkGradInferShape(const PrimitivePtr &primitive,
TypePtr HShrinkGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, input_num, prim->name());
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -33,7 +33,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, input_num, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -47,10 +48,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, input_num, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim->name());
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
(void)types.emplace("grads", input_args[0]->BuildType());

View File

@ -39,16 +39,16 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
(void)CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);
// Infer type
const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
std::map<std::string, TypePtr> args;
args.emplace("x_type", input_args[kInputIndex0]->BuildType());
args.emplace("y_type", input_args[kInputIndex1]->BuildType());
args.emplace("dout_type", input_args[kInputIndex2]->BuildType());
(void)args.emplace("x_type", input_args[kInputIndex0]->BuildType());
(void)args.emplace("y_type", input_args[kInputIndex1]->BuildType());
(void)args.emplace("dout_type", input_args[kInputIndex2]->BuildType());
auto dout_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(dout_type, x_shape);
}

View File

@ -53,9 +53,9 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
std::map<std::string, TypePtr> args;
args.emplace("prediction", input_args[kInputIndex0]->BuildType());
args.emplace("target", input_args[kInputIndex1]->BuildType());
args.emplace("dloss", input_args[kInputIndex2]->BuildType());
(void)args.emplace("prediction", input_args[kInputIndex0]->BuildType());
(void)args.emplace("target", input_args[kInputIndex1]->BuildType());
(void)args.emplace("dloss", input_args[kInputIndex2]->BuildType());
auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(dloss_type, prediction);

View File

@ -28,7 +28,7 @@ abstract::ShapePtr SoftMarginLossGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
auto predict = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto label = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto dout = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
@ -41,7 +41,7 @@ abstract::ShapePtr SoftMarginLossGradInferShape(const PrimitivePtr &primitive,
TypePtr SoftMarginLossGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
(void)types.emplace("logits", input_args[kInputIndex0]->BuildType());

View File

@ -33,7 +33,8 @@ abstract::ShapePtr SoftShrinkGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, input_num, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
primitive->name());
auto input_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto prim_name = primitive->name();

View File

@ -28,19 +28,13 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1L, primitive->name());
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, primitive->name());
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1, primitive->name());
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types,
primitive->name());

View File

@ -27,7 +27,7 @@ ShapeVector CalLayerNormMeanAndVarShape(int64_t begin_norm_axis, const ShapeVect
if (begin_norm_axis == -1) {
mean_var_shape_value[input_rank - 1] = 1;
} else {
for (size_t i = begin_norm_axis; i < input_rank; i++) {
for (size_t i = LongToSize(begin_norm_axis); i < input_rank; i++) {
mean_var_shape_value[i] = 1;
}
}

View File

@ -59,9 +59,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
types.emplace("start", input_args[0]->BuildType());
types.emplace("end", input_args[1]->BuildType());
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
types.emplace("weight", input_args[kInputIndex2]->BuildType());
(void)types.emplace("weight", input_args[kInputIndex2]->BuildType());
} else {
CheckAndConvertUtils::CheckSubClass("weight", input_args[kInputIndex2]->BuildType(), {kFloat}, op_name);
(void)CheckAndConvertUtils::CheckSubClass("weight", input_args[kInputIndex2]->BuildType(), {kFloat}, op_name);
}
return CheckAndConvertUtils::CheckTensorTypeSame(types, {kFloat16, kFloat32}, op_name);
}

View File

@ -28,10 +28,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, input_num, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
auto input_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
auto input_shape = input_shape_map[kShape];
auto mask_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
@ -57,15 +54,15 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
auto op_name = prim->name();
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, input_num, op_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("mask", input_args[1]->BuildType(), {kBool}, op_name);
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
std::map<std::string, TypePtr> types;
types.emplace("input", input_args[kInputIndex0]->BuildType());
types.emplace("value", input_args[kInputIndex2]->BuildType());
(void)types.emplace("input", input_args[kInputIndex0]->BuildType());
(void)types.emplace("value", input_args[kInputIndex2]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, {kFloat16, kFloat32, kInt8, kInt32}, op_name);
} else {
CheckAndConvertUtils::CheckSubClass("value", input_args[kInputIndex2]->BuildType(), {kFloat}, op_name);
(void)CheckAndConvertUtils::CheckSubClass("value", input_args[kInputIndex2]->BuildType(), {kFloat}, op_name);
return CheckAndConvertUtils::CheckTensorTypeValid("input", input_args[0]->BuildType(),
{kFloat16, kFloat32, kInt8, kInt32}, op_name);
}

View File

@ -36,7 +36,7 @@ AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
args.insert({"input[" + std::to_string(i) + "]", inputs_type[i]});
}
std::set<TypePtr> template_type = common_valid_types;
template_type.emplace(kBool);
(void)template_type.emplace(kBool);
auto infered_type = CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name);
std::vector<int64_t> in_shape0 = inputs_shape[0]->cast<abstract::ShapePtr>()->shape();

View File

@ -40,7 +40,7 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve
MS_EXCEPTION_IF_NULL(depth);
int64_t depth_value;
if (depth->isa<tensor::Tensor>()) {
CheckAndConvertUtils::CheckTensorTypeValid("depth", input_args[1]->BuildType(), {kInt64}, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("depth", input_args[1]->BuildType(), {kInt64}, op_name);
auto depth_data = depth->cast<tensor::TensorPtr>()->data_c();
MS_EXCEPTION_IF_NULL(depth_data);
auto data_value = reinterpret_cast<int64_t *>(depth_data);

View File

@ -44,10 +44,12 @@ std::vector<int64_t> GetOutputMaskShape(const std::vector<int64_t> &input_shape,
mask_shape.push_back(input_shape[i]);
}
}
const int64_t shape_end_4d = 4;
const int64_t shape_end_2d = 2;
if (x_dtype == kUInt8 || x_dtype == kInt8) {
(void)mask_shape.insert(mask_shape.end(), 4);
(void)mask_shape.insert(mask_shape.end(), shape_end_4d);
} else {
(void)mask_shape.insert(mask_shape.end(), 2);
(void)mask_shape.insert(mask_shape.end(), shape_end_2d);
}
return mask_shape;
}

View File

@ -28,7 +28,7 @@ abstract::ShapePtr SoftMarginLossInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
auto predict = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto label = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("logits shape", predict, kEqual, "labels shape", label, op_name, ValueError);
@ -43,7 +43,7 @@ abstract::ShapePtr SoftMarginLossInferShape(const PrimitivePtr &primitive,
TypePtr SoftMarginLossInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
(void)types.emplace("logits", input_args[0]->BuildType());

View File

@ -69,7 +69,7 @@ void SpaceToBatchND::set_paddings(std::vector<std::vector<int64_t>> paddings) {
CheckAndConvertUtils::Check(kPaddings, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name());
for (size_t i = 0; i < h; i++) {
for (size_t j = 0; j < w; j++) {
(void)CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings[i][j]), kGreaterEqual, 0, this->name());
(void)CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings[i][j]), kGreaterEqual, 0L, this->name());
}
}
(void)this->AddAttr(kPaddings, MakeValue(paddings));
@ -84,7 +84,7 @@ void SpaceToBatchND::set_block_shape(std::vector<int64_t> block_shape) {
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, block_size,
this->name());
for (size_t i = 0; i < block_shape.size(); i++) {
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape[i]), kGreaterEqual, 1, this->name());
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape[i]), kGreaterEqual, 1L, this->name());
}
(void)this->AddAttr(kBlockShape, MakeValue(block_shape));
}

View File

@ -24,7 +24,7 @@ void SpaceToDepth::Init(const int64_t block_size, const Format &format) {
}
void SpaceToDepth::set_block_size(const int64_t block_size) {
(void)CheckAndConvertUtils::Check(kBlockSize, block_size, kGreaterEqual, "", 2, this->name());
CheckAndConvertUtils::Check(kBlockSize, block_size, kGreaterEqual, "", 2, this->name());
(void)AddAttr(kBlockSize, MakeValue(block_size));
}

View File

@ -50,15 +50,15 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
same_shape_args_map.insert({"shape of mom ", mom_shape});
same_shape_args_map.insert({"shape of grad ", grad_shape});
for (auto &elem : same_shape_args_map) {
(void)CheckAndConvertUtils::Check(elem.first, elem.second, kEqual, "var shape", var_shape, prim_name);
CheckAndConvertUtils::Check(elem.first, elem.second, kEqual, "var shape", var_shape, prim_name);
}
// Indices must be rank 1
(void)CheckAndConvertUtils::CheckInteger("indices dim", indices_shape.size(), kEqual, 1, prim_name);
// Dimension of var must be equal or greater than 1
(void)CheckAndConvertUtils::CheckInteger("dimension of var", var_shape.size(), kGreaterEqual, 1, prim_name);
// Indices shape must be equal to the first dimension of var
(void)CheckAndConvertUtils::Check("indices shape", indices_shape[0], kEqual, "the first dimension of var",
var_shape[0], prim_name);
CheckAndConvertUtils::Check("indices shape", indices_shape[0], kEqual, "the first dimension of var", var_shape[0],
prim_name);
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{var_shape_ptr, ms_shape_ptr, mom_shape_ptr});
}

View File

@ -27,10 +27,10 @@ namespace {
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1L, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInteger("x_rank", x_rank, kGreaterEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("x_rank", x_rank, kGreaterEqual, 1, prim_name);
auto split_dim = GetValue<int64_t>(primitive->GetAttr("split_dim"));
CheckAndConvertUtils::CheckInRange("split_dim", split_dim, kIncludeLeft, {-x_rank, x_rank}, prim_name);
if (split_dim < 0) {

View File

@ -24,8 +24,11 @@ namespace {
abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
if (input0_shape.size() < 1) {
MS_LOG(ERROR) << "input0_shape.size():" << input0_shape.size() << " must be greater than 0!";
}
@ -33,11 +36,6 @@ abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive,
if (dim0 < 0) {
MS_LOG(ERROR) << "input[0] dim0:" << dim0 << " must be greater than or equal to 0!";
}
auto input1 = &input1_shape[0];
MS_EXCEPTION_IF_NULL(input1);
if (input1 == nullptr) {
MS_LOG(ERROR) << "input1 is nullptr";
}
std::vector<int64_t> infer_shape = {1, dim0};
return std::make_shared<abstract::Shape>(infer_shape);
}

View File

@ -41,6 +41,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
} else {
auto perm_value = input_args[1]->BuildValue();
MS_EXCEPTION_IF_NULL(perm_value);
if (perm_value->isa<tensor::Tensor>()) {
p_value = CheckAndConvertUtils::CheckTensorIntValue("perm value", perm_value, op_name);
} else {
@ -66,7 +67,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}
}
std::vector<int64_t> in_shape(p_value);
(void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](int i) { return x_shape[i]; });
(void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](int64_t i) { return x_shape[i]; });
if (!x_min_shape.empty() && !x_max_shape.empty()) {
std::vector<int64_t> min_shape;
std::vector<int64_t> max_shape;

View File

@ -57,9 +57,9 @@ AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, con
const std::set<TypePtr> valid_num_segments_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("num_segments", input_args[kInputIndex2]->BuildType(),
valid_num_segments_types, prim_name);
int64_t size_segment_ids_shp = SizeToLong(segment_ids_shape.size());
int64_t size_x_shpe = SizeToLong(x_shape.size());
for (int64_t i = size_segment_ids_shp; i < size_x_shpe; ++i) {
size_t size_segment_ids_shp = segment_ids_shape.size();
size_t size_x_shape = x_shape.size();
for (size_t i = size_segment_ids_shp; i < size_x_shape; ++i) {
(void)shp.emplace_back(x_shape[i]);
}

View File

@ -21,7 +21,7 @@
namespace mindspore {
namespace ops {
void Unsqueeze::Init(const std::vector<int64_t> axis) { (void)this->set_axis(axis); }
void Unsqueeze::Init(const std::vector<int64_t> axis) { this->set_axis(axis); }
void Unsqueeze::set_axis(const std::vector<int64_t> axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }

View File

@ -37,7 +37,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
auto op_name = primitive->name();
auto infer_type = input_args[0]->BuildType();
auto valid_type = common_valid_types;
valid_type.insert(kBool);
(void)valid_type.insert(kBool);
(void)CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, valid_type, op_name);
return infer_type;
}

View File

@ -528,6 +528,9 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name,
ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
const std::string &prim_name) {
if (value == nullptr) {
MS_EXCEPTION(ValueError) << "The " << prim_name << "'s " << type_name << " value is nullptr.";
}
ShapeVector tensor_value;
if (!value->isa<tensor::Tensor>()) {
MS_EXCEPTION(ValueError) << "The " << prim_name << "'s " << type_name << " must be a tensor.";