From d51e693f5e7fd90b29dee885dcc918049cc77506 Mon Sep 17 00:00:00 2001 From: OwenSec Date: Wed, 1 Feb 2023 15:45:15 +0800 Subject: [PATCH] fix segmentation fault of multimarginloss --- mindspore/core/ops/multi_margin_loss.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/core/ops/multi_margin_loss.cc b/mindspore/core/ops/multi_margin_loss.cc index 87c1ef15a8d..e8dc42d7518 100644 --- a/mindspore/core/ops/multi_margin_loss.cc +++ b/mindspore/core/ops/multi_margin_loss.cc @@ -134,13 +134,13 @@ string MultiMarginLoss::get_reduction() const { AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); + CheckAndConvertUtils::CheckInRange("multi_margin_loss_input_nums", input_args.size(), kIncludeBoth, {kDim2, kDim3}, + primitive->name()); MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]); MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]); if (input_args.size() == kDim3) { MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]); } - CheckAndConvertUtils::CheckInRange("multi_margin_loss_input_nums", input_args.size(), kIncludeBoth, {kDim2, kDim3}, - primitive->name()); auto types = MultiMarginLossInferType(primitive, input_args); auto shapes = MultiMarginLossInferShape(primitive, input_args); return abstract::MakeAbstract(shapes, types);