forked from mindspore-Ecosystem/mindspore
modify bert test file
This commit is contained in:
parent
97276fa522
commit
5b176f258b
|
@ -82,12 +82,6 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
|
|||
}
|
||||
return true;
|
||||
};
|
||||
if (AnfAlgo::GetCNodeName(kernel_node) == "LayerNormBetaGammaBackprop" ||
|
||||
AnfAlgo::GetCNodeName(kernel_node) == "LayerNormXBackprop") {
|
||||
if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, 0) != kernel_build_info.GetInputFormat(0)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
|
||||
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
|
||||
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
|
||||
|
@ -161,7 +155,7 @@ bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best
|
|||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
|
||||
|
|
|
@ -27,7 +27,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell
|
||||
from mindspore.nn.optim import Lamb
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore import log as logger
|
||||
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]
|
||||
|
|
Loading…
Reference in New Issue