forked from mindspore-Ecosystem/mindspore
!48302 fix segmentation fault raised by multimarginloss
Merge pull request !48302 from OwenSec/master
This commit is contained in:
commit
ab68927414
|
@ -134,13 +134,13 @@ string MultiMarginLoss::get_reduction() const {
|
|||
AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
CheckAndConvertUtils::CheckInRange("multi_margin_loss_input_nums", input_args.size(), kIncludeBoth, {kDim2, kDim3},
|
||||
primitive->name());
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
|
||||
if (input_args.size() == kDim3) {
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
|
||||
}
|
||||
CheckAndConvertUtils::CheckInRange("multi_margin_loss_input_nums", input_args.size(), kIncludeBoth, {kDim2, kDim3},
|
||||
primitive->name());
|
||||
auto types = MultiMarginLossInferType(primitive, input_args);
|
||||
auto shapes = MultiMarginLossInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
|
|
Loading…
Reference in New Issue