supplement softmax UnPackAttr

This commit is contained in:
lyvette 2020-08-31 11:00:52 +08:00
parent 8d41931456
commit a4617f667f
3 changed files with 33 additions and 0 deletions

View File

@ -257,6 +257,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri
return NewPrimitiveC<Transpose>(prim, inputs);
} else if (op_type == "tuple_getitem") {
return NewPrimitiveC<TupleGetItem>(prim, inputs);
} else if (op_type == "Softmax") {
return NewPrimitiveC<SoftMax>(prim, inputs);
} else {
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
return nullptr;

View File

@ -21,6 +21,36 @@ namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int SoftMax::GetAxis() const { return this->primitive_->value.AsSoftMax()->axis; }
int SoftMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_SoftMax;
}
if (this->primitive_->value.type != schema::PrimitiveType_SoftMax) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SoftMaxT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
attr->axis = prim_axis;
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
void SoftMax::SetAxis(int axis) { this->primitive_->value.AsSoftMax()->axis = axis; }
#else

View File

@ -31,6 +31,7 @@ class SoftMax : public PrimitiveC {
MS_DECLARE_PARENT(SoftMax, PrimitiveC);
SoftMax() = default;
explicit SoftMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetAxis(int axis);
#else