diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 94a8b265274..9ae43c7a86e 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -257,6 +257,8 @@ std::shared_ptr PrimitiveC::UnPackFromPrimitive(const Primitive &pri return NewPrimitiveC(prim, inputs); } else if (op_type == "tuple_getitem") { return NewPrimitiveC(prim, inputs); + } else if (op_type == "Softmax") { + return NewPrimitiveC(prim, inputs); } else { MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type; return nullptr; diff --git a/mindspore/lite/src/ops/softmax.cc b/mindspore/lite/src/ops/softmax.cc index 2640d3edbc5..af518125d02 100644 --- a/mindspore/lite/src/ops/softmax.cc +++ b/mindspore/lite/src/ops/softmax.cc @@ -21,6 +21,36 @@ namespace lite { #ifdef PRIMITIVE_WRITEABLE int SoftMax::GetAxis() const { return this->primitive_->value.AsSoftMax()->axis; } +int SoftMax::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_SoftMax; + } + if (this->primitive_->value.type != schema::PrimitiveType_SoftMax) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::SoftMaxT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + auto prim_axis = GetValue(prim.GetAttr("axis")); + attr->axis = prim_axis; + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} + void SoftMax::SetAxis(int axis) { this->primitive_->value.AsSoftMax()->axis = axis; } #else diff --git a/mindspore/lite/src/ops/softmax.h b/mindspore/lite/src/ops/softmax.h index aa7dc5db88a..3659d88820b 100644 --- a/mindspore/lite/src/ops/softmax.h +++ b/mindspore/lite/src/ops/softmax.h @@ -31,6 +31,7 @@ class SoftMax : public PrimitiveC { MS_DECLARE_PARENT(SoftMax, PrimitiveC); SoftMax() = default; explicit SoftMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetAxis(int axis); #else