forked from mindspore-Ecosystem/mindspore

add activation attr for gelu

commit 94d7c38c6f (parent f457c355f5)
@@ -34,7 +34,7 @@ static bool CheckInputsDataType(const TensorC *const *inputs, size_t inputs_size
 
 int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     OpParameter *parameter) {
-  int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
+  int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1);
   if (ret != NNACL_OK) {
     return ret;
   }
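Context for the hunk above: SliceInferShape now requires at least three inputs rather than one, which matches Slice's three-tensor form (data, begin, size), so inference rejects calls that omit the begin/size tensors. That reading is inferred from the signature, not stated in the commit.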
@@ -54,6 +54,14 @@ ActivationType Activation::get_activation_type() const {
   auto value_ptr = GetAttr(kActivationType);
   return ActivationType(GetValue<int64_t>(value_ptr));
 }
+
+void Activation::set_approximate(bool approximate) { this->AddAttr(kApproximate, MakeValue(approximate)); }
+
+bool Activation::get_approximate() const {
+  auto value_ptr = this->GetAttr(kApproximate);
+  return GetValue<bool>(value_ptr);
+}
+
 void Activation::Init(const float alpha, const float min_val, const float max_val,
                       const ActivationType &activation_type) {
   this->set_alpha(alpha);
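The hunk above adds set_approximate/get_approximate to the Activation primitive, storing the flag in the attribute map under kApproximate. A minimal usage sketch (hypothetical driver code, not part of the commit; the include path is an assumption):

// Hypothetical sketch: build a GELU Activation primitive and toggle the
// new attribute added by this commit.
#include <memory>
#include "ops/activation.h"  // assumed header for mindspore::ops::Activation

void BuildGeluPrim() {
  auto act = std::make_shared<mindspore::ops::Activation>();
  act->set_activation_type(mindspore::GELU);  // pick the GELU activation
  act->set_approximate(true);                 // request the tanh-approximated GELU
  bool approx = act->get_approximate();       // reads the kApproximate attr back
  (void)approx;                               // silence unused-variable warnings
}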
@@ -38,6 +38,8 @@ class Activation : public PrimitiveC {
   float get_min_val() const;
   float get_max_val() const;
   ActivationType get_activation_type() const;
+  void set_approximate(bool approximate);
+  bool get_approximate() const;
 };
 }  // namespace ops
 }  // namespace mindspore
@@ -251,6 +251,7 @@ constexpr auto kNumberSplit = "number_split";
 constexpr auto kSplitDim = "split_dim";
 constexpr auto kPadTop = "pad_top";
 constexpr auto kTransFormat = "trans_format";
+constexpr auto kApproximate = "approximate";
 const std::set<TypePtr> common_valid_types = {kInt8,   kInt16,  kInt32,   kInt64,   kUInt8, kUInt16,
                                               kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
 
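kApproximate is the string key under which AddAttr/GetAttr (previous hunk) store and look up the flag; declaring it alongside the other attribute names keeps the key consistent across the setter, the getter, and anything else that reads the attribute map.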
@@ -229,6 +229,7 @@ table Activation {
     alpha: float;
     min_val: float;
     max_val: float;
+    approximate: bool = false;
 }
 
 table ActivationGrad {
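Giving the new schema field a false default presumably keeps older flatbuffer models loadable: a serialized Activation without the field reads back as approximate = false, i.e. the exact (erf-based) GELU.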
@@ -229,6 +229,7 @@ OP_ATTR_ENUM_WITH_VALUE(activation_type, ActivationType, 0)
 OP_ATTR(alpha, float)
 OP_ATTR(min_val, float)
 OP_ATTR(max_val, float)
+OP_ATTR_WITH_VALUE(approximate, bool, false)
 OP_SCHEMA_DEF_END(Activation)
 
 OP_SCHEMA_DEF(ActivationGrad)
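OP_ATTR_WITH_VALUE mirrors the schema default from the previous hunk, so the generated op definition carries the same false default as the .fbs table.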
@@ -29,6 +29,7 @@ CNodePtr GeLUFusion::CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
   MS_ASSERT(node != nullptr);
   auto gelu_prim = std::make_shared<ops::Activation>();
   gelu_prim->set_activation_type(mindspore::GELU);
+  gelu_prim->set_approximate(approximate_);
   auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
   MS_ASSERT(input_node != nullptr);
   auto gelu_cnode = func_graph->NewCNode(gelu_prim, {input_node});
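With this line the fusion pass stamps the pattern-derived approximate_ flag onto the fused Activation node, so a backend choosing a GELU kernel can distinguish the exact and tanh-approximated variants.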
@@ -41,6 +41,9 @@ class GeLUFusion : public PatternProcessPass {
 
  private:
  CNodePtr CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const;
+
+ protected:
+  mutable bool approximate_{false};
 };
 }  // namespace opt
 }  // namespace mindspore
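approximate_ is mutable because the pattern-check and node-creation hooks are const member functions (see CheckPattern below and CreateGeLUNode above), and protected so that concrete fusions such as TfGeLUFusion can set it when the matched pattern is the approximate form.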
@@ -82,6 +82,7 @@ bool TfGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
   if (mul3_x < 0 || fabs(mul3_x - MUL3_X) > DIFF_THRESHOLD) {
     return false;
   }
+  approximate_ = true;
   return true;
 }
 }  // namespace opt
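For context (not part of the diff): the TensorFlow subgraph matched here computes the tanh approximation of GELU, which is why the pass flips approximate_ to true once the pattern's constants check out; MUL3_X is presumably the 0.044715 coefficient below, verified within DIFF_THRESHOLD. The exact and approximate forms, in LaTeX:

\mathrm{GELU}(x) = x\,\Phi(x) = \frac{x}{2}\left(1 + \mathrm{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)

\mathrm{GELU}_{\tanh}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right)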