fastgelu update.
This commit is contained in:
parent
608d5ddd71
commit
6af059dac3
|
@ -251,6 +251,7 @@ class GeLUCost : public SqrtCost {
|
|||
// Taking account of input and output
|
||||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
|
||||
using FastGeLUCost = GeLUCost;
|
||||
using BesselI0eCost = GeLUCost;
|
||||
using BesselI1eCost = GeLUCost;
|
||||
using L2NormalizeCost = GeLUCost;
|
||||
|
|
|
@ -78,6 +78,7 @@ class RegisterAction {
|
|||
// operator register
|
||||
REGISTER(MatMulInfo);
|
||||
REGISTER(GeLUInfo);
|
||||
REGISTER(FastGeLUInfo);
|
||||
REGISTER(VirtualDatasetInfo);
|
||||
REGISTER(BatchParallelInfo);
|
||||
REGISTER(TanhInfo);
|
||||
|
|
|
@ -89,6 +89,14 @@ class GeLUInfo : public ActivationOther {
|
|||
~GeLUInfo() override = default;
|
||||
};
|
||||
|
||||
class FastGeLUInfo : public ActivationOther {
|
||||
public:
|
||||
FastGeLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<FastGeLUCost>()) {}
|
||||
~FastGeLUInfo() override = default;
|
||||
};
|
||||
|
||||
class TanhInfo : public ActivationOther {
|
||||
public:
|
||||
TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
|
|
|
@ -36,7 +36,7 @@ class PanguAlphaConfig:
|
|||
dropout_rate=0.1,
|
||||
eod_token=6,
|
||||
use_past=False,
|
||||
hidden_act='gelu',
|
||||
hidden_act='fast_gelu',
|
||||
eod_reset=True,
|
||||
enable_offload=False,
|
||||
parallel_config=None):
|
||||
|
|
Loading…
Reference in New Issue