add approximate activation attr for gelu

xuanyue 2021-07-08 17:10:53 +08:00
parent f457c355f5
commit 94d7c38c6f
9 changed files with 19 additions and 1 deletion


@@ -34,7 +34,7 @@ static bool CheckInputsDataType(const TensorC *const *inputs, size_t inputs_size
int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
-int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
+int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1);
if (ret != NNACL_OK) {
return ret;
}
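
For context: the tightened guard (1 -> 3) suggests that Slice's begin and size arrive as extra input tensors here. A hedged paraphrase of what the call now enforces, assuming "MinSize" in the helper's name means minimum counts rather than exact counts, and with the input layout an assumption on my part:

// Illustrative only; the helper returns its own NNACL error code.
//   inputs[0] = data, inputs[1] = begin, inputs[2] = size (assumed)
if (inputs_size < 3 || outputs_size < 1) {
  return NNACL_ERR;
}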


@@ -54,6 +54,14 @@ ActivationType Activation::get_activation_type() const {
auto value_ptr = GetAttr(kActivationType);
return ActivationType(GetValue<int64_t>(value_ptr));
}
+void Activation::set_approximate(bool approximate) { this->AddAttr(kApproximate, MakeValue(approximate)); }
+bool Activation::get_approximate() const {
+  auto value_ptr = this->GetAttr(kApproximate);
+  return GetValue<bool>(value_ptr);
+}
void Activation::Init(const float alpha, const float min_val, const float max_val,
const ActivationType &activation_type) {
this->set_alpha(alpha);
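
A caveat on the new getter: GetAttr returns a null value pointer when the attribute was never set (for example, a graph exported before this commit), and GetValue<bool> cannot take a null pointer. A defensive variant, sketched under that assumption, that falls back to the schema default:

bool Activation::get_approximate() const {
  auto value_ptr = this->GetAttr(kApproximate);
  // Absent attribute -> schema default (false) instead of failing
  // on graphs that predate kApproximate.
  return value_ptr == nullptr ? false : GetValue<bool>(value_ptr);
}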


@@ -38,6 +38,8 @@ class Activation : public PrimitiveC {
float get_min_val() const;
float get_max_val() const;
ActivationType get_activation_type() const;
+void set_approximate(bool approximate);
+bool get_approximate() const;
};
} // namespace ops
} // namespace mindspore
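
A minimal caller-side sketch of the extended interface, mirroring what the fusion pass below does (the include path and surrounding setup are assumptions):

#include "ops/activation.h"

auto gelu = std::make_shared<mindspore::ops::Activation>();
gelu->set_activation_type(mindspore::GELU);
gelu->set_approximate(true);               // request the tanh-based form
bool is_approx = gelu->get_approximate();  // reads kApproximate back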


@@ -251,6 +251,7 @@ constexpr auto kNumberSplit = "number_split";
constexpr auto kSplitDim = "split_dim";
constexpr auto kPadTop = "pad_top";
constexpr auto kTransFormat = "trans_format";
+constexpr auto kApproximate = "approximate";
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};


@@ -229,6 +229,7 @@ table Activation {
alpha: float;
min_val: float;
max_val: float;
+approximate: bool = false;
}
table ActivationGrad {
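
Because the field declares a default, models serialized before this change stay readable: FlatBuffers omits defaulted scalars from the buffer, and the generated accessor synthesizes the default on read. A sketch, with the generated type and namespace assumed:

// schema::Activation is the flatc-generated C++ type; the accessor
// name follows the usual FlatBuffers convention for scalar fields.
bool ReadApproximate(const mindspore::schema::Activation *prim) {
  return prim->approximate();  // false for buffers written before the field existed
}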


@@ -229,6 +229,7 @@ OP_ATTR_ENUM_WITH_VALUE(activation_type, ActivationType, 0)
OP_ATTR(alpha, float)
OP_ATTR(min_val, float)
OP_ATTR(max_val, float)
+OP_ATTR_WITH_VALUE(approximate, bool, false)
OP_SCHEMA_DEF_END(Activation)
OP_SCHEMA_DEF(ActivationGrad)
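
Presumably OP_ATTR_WITH_VALUE registers the attribute together with a compiled-in default; whatever the macro expands to, the false here has to stay in sync with the approximate: bool = false default in the .fbs table above.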


@@ -29,6 +29,7 @@ CNodePtr GeLUFusion::CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNod
MS_ASSERT(node != nullptr);
auto gelu_prim = std::make_shared<ops::Activation>();
gelu_prim->set_activation_type(mindspore::GELU);
+gelu_prim->set_approximate(approximate_);
auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
MS_ASSERT(input_node != nullptr);
auto gelu_cnode = func_graph->NewCNode(gelu_prim, {input_node});
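
For reference, the two semantics the flag selects between (standard GeLU definitions, not spelled out in this diff). Exact GeLU uses the Gaussian CDF:

  \mathrm{GELU}(x) = x\,\Phi(x) = \tfrac{x}{2}\left(1 + \operatorname{erf}\!\left(x/\sqrt{2}\right)\right)

and the approximate variant is the tanh expansion:

  \mathrm{GELU}_{\text{approx}}(x) = \tfrac{x}{2}\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\left(x + 0.044715\,x^{3}\right)\right)\right)

whose 0.044715 coefficient is exactly the MUL3_X constant matched by the TF fusion pattern below.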


@@ -41,6 +41,9 @@ class GeLUFusion : public PatternProcessPass {
private:
CNodePtr CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const;
+protected:
+mutable bool approximate_{false};
};
} // namespace opt
} // namespace mindspore
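
The flag is mutable because the hooks that set it run in const methods (see TfGeLUFusion::CheckPattern below): a const pattern match can still record which GeLU variant it saw, and CreateGeLUNode, also const, forwards that via set_approximate.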


@@ -82,6 +82,7 @@ bool TfGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
if (mul3_x < 0 || fabs(mul3_x - MUL3_X) > DIFF_THRESHOLD) {
return false;
}
+approximate_ = true;
return true;
}
} // namespace opt
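
In other words, this pattern only matches graphs that spell out the tanh expansion (the mul3_x coefficient checked against MUL3_X above), so a GeLU fused from a TensorFlow graph is tagged approximate = true, while the base-class default of false is kept for patterns that never set the flag.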