forked from mindspore-Ecosystem/mindspore
supplement softmax UnPackAttr
This commit is contained in:
parent
8d41931456
commit
a4617f667f
|
@ -257,6 +257,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri
|
|||
return NewPrimitiveC<Transpose>(prim, inputs);
|
||||
} else if (op_type == "tuple_getitem") {
|
||||
return NewPrimitiveC<TupleGetItem>(prim, inputs);
|
||||
} else if (op_type == "Softmax") {
|
||||
return NewPrimitiveC<SoftMax>(prim, inputs);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
|
||||
return nullptr;
|
||||
|
|
|
@ -21,6 +21,36 @@ namespace lite {
|
|||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int SoftMax::GetAxis() const { return this->primitive_->value.AsSoftMax()->axis; }
|
||||
|
||||
int SoftMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_SoftMax;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_SoftMax) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::SoftMaxT();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
|
||||
attr->axis = prim_axis;
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void SoftMax::SetAxis(int axis) { this->primitive_->value.AsSoftMax()->axis = axis; }
|
||||
|
||||
#else
|
||||
|
|
|
@ -31,6 +31,7 @@ class SoftMax : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(SoftMax, PrimitiveC);
|
||||
SoftMax() = default;
|
||||
explicit SoftMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetAxis(int axis);
|
||||
|
||||
#else
|
||||
|
|
Loading…
Reference in New Issue