forked from mindspore-Ecosystem/mindspore
!23460 Update pangu performance.
Merge pull request !23460 from linqingke/gpt
This commit is contained in:
commit
6aac4e42d0
|
@ -251,6 +251,7 @@ class GeLUCost : public SqrtCost {
|
||||||
// Taking account of input and output
|
// Taking account of input and output
|
||||||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||||
};
|
};
|
||||||
|
using FastGeLUCost = GeLUCost;
|
||||||
using BesselI0eCost = GeLUCost;
|
using BesselI0eCost = GeLUCost;
|
||||||
using BesselI1eCost = GeLUCost;
|
using BesselI1eCost = GeLUCost;
|
||||||
using L2NormalizeCost = GeLUCost;
|
using L2NormalizeCost = GeLUCost;
|
||||||
|
|
|
@ -78,6 +78,7 @@ class RegisterAction {
|
||||||
// operator register
|
// operator register
|
||||||
REGISTER(MatMulInfo);
|
REGISTER(MatMulInfo);
|
||||||
REGISTER(GeLUInfo);
|
REGISTER(GeLUInfo);
|
||||||
|
REGISTER(FastGeLUInfo);
|
||||||
REGISTER(VirtualDatasetInfo);
|
REGISTER(VirtualDatasetInfo);
|
||||||
REGISTER(BatchParallelInfo);
|
REGISTER(BatchParallelInfo);
|
||||||
REGISTER(TanhInfo);
|
REGISTER(TanhInfo);
|
||||||
|
|
|
@ -89,6 +89,14 @@ class GeLUInfo : public ActivationOther {
|
||||||
~GeLUInfo() override = default;
|
~GeLUInfo() override = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class FastGeLUInfo : public ActivationOther {
|
||||||
|
public:
|
||||||
|
FastGeLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
|
const PrimitiveAttrs &attrs)
|
||||||
|
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<FastGeLUCost>()) {}
|
||||||
|
~FastGeLUInfo() override = default;
|
||||||
|
};
|
||||||
|
|
||||||
class TanhInfo : public ActivationOther {
|
class TanhInfo : public ActivationOther {
|
||||||
public:
|
public:
|
||||||
TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
|
|
|
@ -36,7 +36,7 @@ class PanguAlphaConfig:
|
||||||
dropout_rate=0.1,
|
dropout_rate=0.1,
|
||||||
eod_token=6,
|
eod_token=6,
|
||||||
use_past=False,
|
use_past=False,
|
||||||
hidden_act='gelu',
|
hidden_act='fast_gelu',
|
||||||
eod_reset=True,
|
eod_reset=True,
|
||||||
enable_offload=False,
|
enable_offload=False,
|
||||||
parallel_config=None):
|
parallel_config=None):
|
||||||
|
|
Loading…
Reference in New Issue