forked from mindspore-Ecosystem/mindspore
!49017 fix segmentation fault bug for softmax when input is None
Merge pull request !49017 from gengdongjie/fix_issues
This commit is contained in:
commit
11a8e7bc3d
|
@ -27,6 +27,9 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSoftmaxMinInputSize = 1;
|
||||
constexpr auto kNameSoftmaxMaxInputSize = 2;
|
||||
|
||||
void Softmax::set_axis(const std::vector<int64_t> &axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
|
||||
|
||||
std::vector<int64_t> Softmax::get_axis() const {
|
||||
|
@ -48,6 +51,11 @@ void Softmax::Init(const int64_t axis) {
|
|||
namespace {
|
||||
abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
if (input_args.size() < kNameSoftmaxMinInputSize || input_args.size() > kNameSoftmaxMaxInputSize) {
|
||||
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the input args size should be " << kNameSoftmaxMinInputSize
|
||||
<< " or " << kNameSoftmaxMaxInputSize << " , but get " << input_args.size();
|
||||
}
|
||||
|
||||
auto op_name = primitive->name();
|
||||
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
|
||||
(void)CheckAndConvertUtils::CheckValue<size_t>("length of axis", axis.size(), kGreaterEqual, 1, op_name);
|
||||
|
@ -68,6 +76,10 @@ abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::v
|
|||
}
|
||||
|
||||
TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (input_args.size() < kNameSoftmaxMinInputSize || input_args.size() > kNameSoftmaxMaxInputSize) {
|
||||
MS_LOG(EXCEPTION) << "For '" << prim->name() << "', the input args size should be " << kNameSoftmaxMinInputSize
|
||||
<< " or " << kNameSoftmaxMaxInputSize << " , but get " << input_args.size();
|
||||
}
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "For '" << prim->name()
|
||||
<< ", the input args used for infer shape and type is necessary, but missing it.";
|
||||
|
|
Loading…
Reference in New Issue