forked from mindspore-Ecosystem/mindspore
!49017 fix segmentation fault bug for softmax when input is None
Merge pull request !49017 from gengdongjie/fix_issues
This commit is contained in:
commit
11a8e7bc3d
|
@ -27,6 +27,9 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
constexpr auto kNameSoftmaxMinInputSize = 1;
|
||||||
|
constexpr auto kNameSoftmaxMaxInputSize = 2;
|
||||||
|
|
||||||
void Softmax::set_axis(const std::vector<int64_t> &axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
|
void Softmax::set_axis(const std::vector<int64_t> &axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
|
||||||
|
|
||||||
std::vector<int64_t> Softmax::get_axis() const {
|
std::vector<int64_t> Softmax::get_axis() const {
|
||||||
|
@ -48,6 +51,11 @@ void Softmax::Init(const int64_t axis) {
|
||||||
namespace {
|
namespace {
|
||||||
abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
if (input_args.size() < kNameSoftmaxMinInputSize || input_args.size() > kNameSoftmaxMaxInputSize) {
|
||||||
|
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the input args size should be " << kNameSoftmaxMinInputSize
|
||||||
|
<< " or " << kNameSoftmaxMaxInputSize << " , but get " << input_args.size();
|
||||||
|
}
|
||||||
|
|
||||||
auto op_name = primitive->name();
|
auto op_name = primitive->name();
|
||||||
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
|
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
|
||||||
(void)CheckAndConvertUtils::CheckValue<size_t>("length of axis", axis.size(), kGreaterEqual, 1, op_name);
|
(void)CheckAndConvertUtils::CheckValue<size_t>("length of axis", axis.size(), kGreaterEqual, 1, op_name);
|
||||||
|
@ -68,6 +76,10 @@ abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::v
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
if (input_args.size() < kNameSoftmaxMinInputSize || input_args.size() > kNameSoftmaxMaxInputSize) {
|
||||||
|
MS_LOG(EXCEPTION) << "For '" << prim->name() << "', the input args size should be " << kNameSoftmaxMinInputSize
|
||||||
|
<< " or " << kNameSoftmaxMaxInputSize << " , but get " << input_args.size();
|
||||||
|
}
|
||||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << prim->name()
|
MS_LOG(EXCEPTION) << "For '" << prim->name()
|
||||||
<< ", the input args used for infer shape and type is necessary, but missing it.";
|
<< ", the input args used for infer shape and type is necessary, but missing it.";
|
||||||
|
|
Loading…
Reference in New Issue