fastgelu update.

This commit is contained in:
linqingke 2021-09-01 11:26:43 +08:00
parent 608d5ddd71
commit 6af059dac3
4 changed files with 11 additions and 1 deletions

View File

@ -251,6 +251,7 @@ class GeLUCost : public SqrtCost {
// Taking account of input and output
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using FastGeLUCost = GeLUCost;
using BesselI0eCost = GeLUCost;
using BesselI1eCost = GeLUCost;
using L2NormalizeCost = GeLUCost;

View File

@ -78,6 +78,7 @@ class RegisterAction {
// operator register
REGISTER(MatMulInfo);
REGISTER(GeLUInfo);
REGISTER(FastGeLUInfo);
REGISTER(VirtualDatasetInfo);
REGISTER(BatchParallelInfo);
REGISTER(TanhInfo);

View File

@ -89,6 +89,14 @@ class GeLUInfo : public ActivationOther {
~GeLUInfo() override = default;
};
class FastGeLUInfo : public ActivationOther {
public:
FastGeLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<FastGeLUCost>()) {}
~FastGeLUInfo() override = default;
};
class TanhInfo : public ActivationOther {
public:
TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,

View File

@ -36,7 +36,7 @@ class PanguAlphaConfig:
dropout_rate=0.1,
eod_token=6,
use_past=False,
hidden_act='gelu',
hidden_act='fast_gelu',
eod_reset=True,
enable_offload=False,
parallel_config=None):