forked from mindspore-Ecosystem/mindspore
fix bert precison bug
This commit is contained in:
parent
2d44dd1cb3
commit
da123c5b3e
|
@ -35,6 +35,7 @@ enum MatchCountPriority : int {
|
||||||
MATCH_COUNT_PRIORITY_BEGIN = 0,
|
MATCH_COUNT_PRIORITY_BEGIN = 0,
|
||||||
MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
|
MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
|
||||||
MATCH_FORMAT_COUNT,
|
MATCH_FORMAT_COUNT,
|
||||||
|
MATCH_SPECIAL_FORMAT_COUNT,
|
||||||
MATCH_5D_FORMAT_COUNT,
|
MATCH_5D_FORMAT_COUNT,
|
||||||
MATCH_OUTPUT_DTYPE_COUNT,
|
MATCH_OUTPUT_DTYPE_COUNT,
|
||||||
MATCH_COUNT_PRIORITY_END
|
MATCH_COUNT_PRIORITY_END
|
||||||
|
@ -81,6 +82,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
if (AnfAlgo::GetCNodeName(kernel_node) == "LayerNormBetaGammaBackprop" ||
|
||||||
|
AnfAlgo::GetCNodeName(kernel_node) == "LayerNormXBackprop") {
|
||||||
|
if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, 0) != kernel_build_info.GetInputFormat(0)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
|
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
|
||||||
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
|
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
|
||||||
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
|
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
|
||||||
|
@ -154,7 +161,7 @@ bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
|
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
|
||||||
|
@ -174,12 +181,11 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (input_anf_node->isa<ValueNode>()) {
|
|
||||||
if (AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
|
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
|
||||||
|
if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) &&
|
||||||
|
kSpecialFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kSpecialFormatSet.end()) {
|
||||||
|
(*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++;
|
||||||
|
}
|
||||||
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
|
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
|
||||||
}
|
}
|
||||||
if (kernel_build_info.GetInputDeviceType(input_index) ==
|
if (kernel_build_info.GetInputDeviceType(input_index) ==
|
||||||
|
@ -203,7 +209,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
||||||
(*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
|
(*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
} // namespace
|
||||||
|
|
||||||
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
|
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
|
|
|
@ -195,6 +195,9 @@ const std::set<std::string> kOptOperatorSet = {
|
||||||
kApplyRMSPropOpName,
|
kApplyRMSPropOpName,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const std::set<std::string> kSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
|
||||||
|
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
|
||||||
|
|
||||||
static inline void ChangeFileMode(const std::string& file_name, mode_t mode) {
|
static inline void ChangeFileMode(const std::string& file_name, mode_t mode) {
|
||||||
if (access(file_name.c_str(), F_OK) != 0) {
|
if (access(file_name.c_str(), F_OK) != 0) {
|
||||||
MS_LOG(DEBUG) << "File `" << file_name << "` does not exist.";
|
MS_LOG(DEBUG) << "File `" << file_name << "` does not exist.";
|
||||||
|
|
|
@ -32,10 +32,10 @@ from mindspore.ops.op_info_register import op_info_register
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"dtype": [
|
"dtype": [
|
||||||
"float16","float","float16","float16","float16","float16","float","float","float","float"
|
"float16","float","float16","float","float16","float16","float16","float16","float","float","float","float"
|
||||||
],
|
],
|
||||||
"format": [
|
"format": [
|
||||||
"FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
|
"FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
|
||||||
],
|
],
|
||||||
"name": "x",
|
"name": "x",
|
||||||
"need_compile": false,
|
"need_compile": false,
|
||||||
|
@ -47,10 +47,10 @@ from mindspore.ops.op_info_register import op_info_register
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"dtype": [
|
"dtype": [
|
||||||
"float16","float","float16","float16","float16","float16","float","float","float","float"
|
"float16","float","float16","float","float16","float16","float16","float16","float","float","float","float"
|
||||||
],
|
],
|
||||||
"format": [
|
"format": [
|
||||||
"FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
|
"FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
|
||||||
],
|
],
|
||||||
"name": "y",
|
"name": "y",
|
||||||
"need_compile": true,
|
"need_compile": true,
|
||||||
|
|
|
@ -153,8 +153,7 @@ def test_bert_tdt():
|
||||||
batch_size = int(os.getenv('BATCH_SIZE', '16'))
|
batch_size = int(os.getenv('BATCH_SIZE', '16'))
|
||||||
config = get_config(version=version, batch_size=batch_size)
|
config = get_config(version=version, batch_size=batch_size)
|
||||||
netwithloss = BertNetworkWithLoss(config, True)
|
netwithloss = BertNetworkWithLoss(config, True)
|
||||||
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=10000, start_learning_rate=1e-4,
|
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
|
||||||
end_learning_rate=0.0, power=10.0, warmup_steps=0, decay_filter=lambda x: False)
|
|
||||||
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
|
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
|
||||||
netwithgrads.set_train(True)
|
netwithgrads.set_train(True)
|
||||||
model = Model(netwithgrads)
|
model = Model(netwithgrads)
|
||||||
|
@ -178,10 +177,10 @@ def test_bert_tdt():
|
||||||
param.default_input = weight_variable(value.asnumpy().shape)
|
param.default_input = weight_variable(value.asnumpy().shape)
|
||||||
model.train(ds.get_repeat_count(), ds, callbacks=parallel_callback, dataset_sink_mode=False)
|
model.train(ds.get_repeat_count(), ds, callbacks=parallel_callback, dataset_sink_mode=False)
|
||||||
loss_value = np.array(parallel_callback.loss_list)
|
loss_value = np.array(parallel_callback.loss_list)
|
||||||
expect_out = [12.191790, 11.739655, 11.523477, 11.320723, 11.113152, 11.203759, 10.841681, 10.826849,
|
expect_out = [12.19179, 11.965041, 11.969687, 11.97815, 11.969171, 12.603289, 12.165594,
|
||||||
10.616718, 10.486609]
|
12.824818, 12.38842, 12.604046]
|
||||||
logger.info("expected loss value output: {}".format(expect_out))
|
logger.info("expected loss value output: {}".format(expect_out))
|
||||||
assert allclose(loss_value, expect_out, 0.001, 0.001)
|
assert allclose(loss_value, expect_out, 0.00001, 0.00001)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_bert_tdt()
|
test_bert_tdt()
|
||||||
|
|
Loading…
Reference in New Issue