!49017 fix segmentation fault bug for softmax when input is None

Merge pull request !49017 from gengdongjie/fix_issues
This commit is contained in:
i-robot 2023-02-21 01:25:56 +00:00 committed by Gitee
commit 11a8e7bc3d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 12 additions and 0 deletions

View File

@ -27,6 +27,9 @@
namespace mindspore {
namespace ops {
constexpr auto kNameSoftmaxMinInputSize = 1;
constexpr auto kNameSoftmaxMaxInputSize = 2;
void Softmax::set_axis(const std::vector<int64_t> &axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
std::vector<int64_t> Softmax::get_axis() const {
@ -48,6 +51,11 @@ void Softmax::Init(const int64_t axis) {
namespace {
abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
if (input_args.size() < kNameSoftmaxMinInputSize || input_args.size() > kNameSoftmaxMaxInputSize) {
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the input args size should be " << kNameSoftmaxMinInputSize
<< " or " << kNameSoftmaxMaxInputSize << " , but get " << input_args.size();
}
auto op_name = primitive->name();
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
(void)CheckAndConvertUtils::CheckValue<size_t>("length of axis", axis.size(), kGreaterEqual, 1, op_name);
@ -68,6 +76,10 @@ abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::v
}
TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
if (input_args.size() < kNameSoftmaxMinInputSize || input_args.size() > kNameSoftmaxMaxInputSize) {
MS_LOG(EXCEPTION) << "For '" << prim->name() << "', the input args size should be " << kNameSoftmaxMinInputSize
<< " or " << kNameSoftmaxMaxInputSize << " , but get " << input_args.size();
}
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "For '" << prim->name()
<< ", the input args used for infer shape and type is necessary, but missing it.";