forked from mindspore-Ecosystem/mindspore
!17040 codedex_clean 0526
From: @ding_fei_fei Reviewed-by: Signed-off-by:
This commit is contained in:
commit
e9d2684cdb
|
@ -57,7 +57,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("apply_momentum_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
|
|
@ -24,7 +24,7 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("Atan_infer", int64_t(input_args.size()), kEqual, 1, prim_name);
|
||||
|
||||
// Infer Shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
|
|
@ -87,7 +87,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
if (format == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
|
||||
CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
|
||||
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
|
||||
auto batch = in_shape[0];
|
||||
|
@ -112,14 +112,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
if (format == NHWC) {
|
||||
out_shape = {batch, out_h, out_w, channel};
|
||||
}
|
||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
|
||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t arg) { return arg <= 0; })) {
|
||||
MS_LOG(EXCEPTION) << "Kernel size is not valid.";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
|
||||
TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr arg) { return arg == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
return input_args[0]->BuildType();
|
||||
|
@ -128,8 +128,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
|
||||
} // namespace ops
|
||||
|
|
|
@ -30,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
auto out_shape = x_shape;
|
||||
int64_t block_shape_prod = 1;
|
||||
int64_t offset = 2;
|
||||
|
@ -52,7 +52,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
|
||||
void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) {
|
||||
CheckAndConvertUtils::CheckInteger(kCrops, crops.size(), kEqual, 2, this->name());
|
||||
CheckAndConvertUtils::CheckInteger(kCrops, SizeToLong(crops.size()), kEqual, 2, this->name());
|
||||
int64_t h = crops.size();
|
||||
int64_t w = crops[0].size();
|
||||
std::vector<int64_t> temp_w = {2, 2};
|
||||
|
@ -80,7 +80,7 @@ std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const {
|
|||
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
||||
}
|
||||
void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) {
|
||||
CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape.size(), kEqual, 2, this->name());
|
||||
CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
|
||||
for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) {
|
||||
CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name());
|
||||
}
|
||||
|
@ -98,8 +98,7 @@ void BatchToSpaceND::Init(std::vector<int64_t> block_shape, std::vector<std::vec
|
|||
}
|
||||
AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameBatchToSpaceND, BatchToSpaceND);
|
||||
} // namespace ops
|
||||
|
|
|
@ -33,7 +33,7 @@ abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
|
||||
TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", input_args.size(), kEqual, 3, prim->name());
|
||||
CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", SizeToLong(input_args.size()), kEqual, 3, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ void Conv2dTranspose::set_out_channel(int64_t out_channel) {
|
|||
}
|
||||
|
||||
void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kKernelSize, kernel_size.size(), kEqual, 2, name());
|
||||
CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
|
||||
for (int64_t item : kernel_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
|||
}
|
||||
|
||||
void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
|
||||
CheckAndConvertUtils::CheckInteger(kStride, stride.size(), kEqual, 2, name());
|
||||
CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name());
|
||||
for (int64_t item : stride) {
|
||||
CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
|
||||
}
|
||||
|
@ -88,7 +88,7 @@ void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
|
|||
}
|
||||
|
||||
void Conv2dTranspose::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
CheckAndConvertUtils::CheckInteger(kDilation, dilation.size(), kGreaterEqual, 2, name());
|
||||
CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name());
|
||||
AddAttr(kDilation, MakeValue(dilation));
|
||||
}
|
||||
|
||||
|
@ -106,7 +106,7 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
|
|||
}
|
||||
|
||||
void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) {
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
|
@ -124,7 +124,7 @@ void Conv2dTranspose::set_format(const Format &format) {
|
|||
}
|
||||
|
||||
void Conv2dTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||
CheckAndConvertUtils::CheckInteger(kPadList, pad_list.size(), kEqual, 4, name());
|
||||
CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name());
|
||||
this->AddAttr(kPadList, MakeValue(pad_list));
|
||||
}
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("input number", int64_t(input_args.size()), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
if (format == NHWC) {
|
||||
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
int64_t block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
|
||||
CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)", x_shape[1] % (block_size * block_size),
|
||||
kEqual, 0, prim_name);
|
||||
|
|
|
@ -26,7 +26,7 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64};
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &arg) { return arg == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
|
|
|
@ -130,7 +130,7 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
|
|||
}
|
||||
|
||||
void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
|
|
|
@ -31,8 +31,8 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", input_args.size(), kEqual, 3,
|
||||
prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", SizeToLong(input_args.size()),
|
||||
kEqual, 3, prim_name);
|
||||
|
||||
// Infer Shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
|
|
@ -27,14 +27,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto prim_name = primitive->name();
|
||||
auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("first input rank", SizeToLong(first_input_shape.size()), kEqual, 3, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("second input rank", SizeToLong(second_input_shape.size()), kEqual, 1, prim_name);
|
||||
std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1],
|
||||
GetValue<int64_t>(primitive->GetAttr(kDctCoeffNum))};
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -84,8 +84,7 @@ int64_t Mfcc::get_dct_coeff_num() const { return GetValue<int64_t>(GetAttr(kDctC
|
|||
|
||||
AbstractBasePtr MfccInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameMfcc, Mfcc);
|
||||
} // namespace ops
|
||||
|
|
|
@ -29,7 +29,7 @@ int64_t NonMaxSuppression::get_center_point_box() const {
|
|||
}
|
||||
void NonMaxSuppression::Init(const int64_t center_point_box) { this->set_center_point_box(center_point_box); }
|
||||
|
||||
AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime.";
|
||||
return std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>{});
|
||||
|
|
|
@ -52,7 +52,7 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("roi_pooling_infer", input_args.size(), kEqual, 2, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("roi_pooling_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto shape_value = input_args[2]->BuildValue();
|
||||
auto shape_value_element = GetValue<std::vector<int64_t>>(shape_value);
|
||||
for (const auto &shape : shape_value_element) {
|
||||
|
@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr ScatterNdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameScatterNd, ScatterNd);
|
||||
} // namespace ops
|
||||
|
|
|
@ -34,7 +34,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
return infer_type;
|
||||
}
|
||||
|
@ -65,8 +65,7 @@ void SkipGram::Init(const bool include_all_grams, const int64_t max_skip_size, c
|
|||
|
||||
AbstractBasePtr SkipGramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameSkipGram, SkipGram);
|
||||
} // namespace ops
|
||||
|
|
|
@ -30,8 +30,8 @@ AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("softmax_cross_entropy_with_logics_infer", input_args.size(), kEqual, 2,
|
||||
prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("softmax_cross_entropy_with_logics_infer", SizeToLong(input_args.size()), kEqual,
|
||||
2, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
|
|
@ -29,12 +29,12 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v
|
|||
MS_LOG(ERROR) << "Invalid input size " << input_args.size();
|
||||
}
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
for (int64_t i = 1; i < (int64_t)input_args.size(); ++i) {
|
||||
for (int64_t i = 1; i < SizeToLong(input_args.size()); ++i) {
|
||||
auto input_shape_tmp = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape];
|
||||
if (input_shape_tmp.size() != input_shape.size()) {
|
||||
MS_LOG(ERROR) << "All input shape size should be the same!";
|
||||
}
|
||||
for (int64_t j = 0; j < (int64_t)input_shape.size(); ++j) {
|
||||
for (int64_t j = 0; j < SizeToLong(input_shape.size()); ++j) {
|
||||
if (input_shape_tmp.at(j) != input_shape.at(j)) {
|
||||
MS_LOG(ERROR) << "All input shape should be the same!";
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v
|
|||
infer_shape.insert(infer_shape.begin() + GetValue<int64_t>(primitive->GetAttr(kAxis)), input_args.size());
|
||||
|
||||
auto infer_type0 = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
for (int64_t i = 1; i < (int64_t)input_args.size(); i++) {
|
||||
for (int64_t i = 1; i < SizeToLong(input_args.size()); i++) {
|
||||
if (input_args[i]->BuildType()->cast<TensorTypePtr>()->element() == infer_type0) {
|
||||
MS_LOG(ERROR) << "All input should have the same data type!input[" << i
|
||||
<< "] data type = " << input_args[i]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
|
|
|
@ -34,12 +34,13 @@ AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, con
|
|||
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
// Infer shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterThan, 0, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x_shape", SizeToLong(x_shape.size()), kGreaterThan, 0, prim_name);
|
||||
auto shp = x_shape;
|
||||
auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kGreaterThan, 0, prim_name);
|
||||
CheckAndConvertUtils::Check("input_x", x_shape.size(), kGreaterEqual, "segment_ids_shape", segment_ids_shape.size(),
|
||||
prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kGreaterThan, 0,
|
||||
prim_name);
|
||||
CheckAndConvertUtils::Check("input_x", int64_t(x_shape.size()), kGreaterEqual, "segment_ids_shape",
|
||||
int64_t(segment_ids_shape.size()), prim_name);
|
||||
|
||||
if ((x_shape.end() != find(x_shape.begin(), x_shape.end(), -1)) &&
|
||||
(segment_ids_shape.end() != find(segment_ids_shape.begin(), segment_ids_shape.end(), -1))) {
|
||||
|
|
Loading…
Reference in New Issue