!17713 [GraphKernel] fix bgcf and enable graph kernel on GPU.

From: @chenlei_autodiff
Reviewed-by: @ckey_dou,@gaoxiong1
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-06-07 16:38:42 +08:00 committed by Gitee
commit df491ffe53
2 changed files with 8 additions and 29 deletions

View File

@ -126,25 +126,6 @@ bool GenJson(const AnfNodePtrList &op_nodes, const std::pair<AnfNodePtrList, Anf
return true;
}
// Reports whether every element of `tensor` has the same byte pattern as its first element.
// A tensor whose DataSize() is 1 is accepted immediately. Otherwise the raw buffer is scanned
// element by element and each one is compared byte-wise against element 0.
// Raises through MS_EXCEPTION_IF_NULL when `tensor` is null.
bool TensorElementAllTheSame(const tensor::TensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  if (tensor->DataSize() == 1) {
    return true;
  }
  // View the buffer as raw bytes so the comparison does not depend on the element type.
  const char *base = static_cast<char *>(tensor->data_c());
  const auto elem_bytes = static_cast<size_t>(tensor->data().itemsize());
  const auto elem_count = IntToSize(tensor->DataSize());
  for (size_t elem = 1; elem < elem_count; ++elem) {
    const char *candidate = base + elem * elem_bytes;
    for (size_t byte = 0; byte < elem_bytes; ++byte) {
      // First mismatching byte proves the tensor is not a broadcast of a single value.
      if (base[byte] != candidate[byte]) {
        return false;
      }
    }
  }
  return true;
}
AnfNodePtr ConvertToScalarTensor(const AnfNodePtr &value_node) {
auto tensor = GetValueNode<tensor::TensorPtr>(value_node);
MS_EXCEPTION_IF_NULL(tensor);
@ -222,17 +203,13 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
if (tensor == nullptr || tensor->DataSize() == 1) {
continue;
}
if (TensorElementAllTheSame(tensor)) {
scalar_tensors.emplace_back(tnode);
auto tensor_iter = std::find_if(
v_replace.begin(), v_replace.end(),
[&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); });
if (tensor_iter == v_replace.end()) {
v_replace.emplace_back(tensor, AnfNodePtrList{tnode});
} else {
auto tensor_iter = std::find_if(
v_replace.begin(), v_replace.end(),
[&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); });
if (tensor_iter == v_replace.end()) {
v_replace.emplace_back(tensor, AnfNodePtrList{tnode});
} else {
tensor_iter->second.push_back(tnode);
}
tensor_iter->second.push_back(tnode);
}
}
}

View File

@ -101,6 +101,8 @@ def run_train():
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
if config.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
train_graph, _, sampled_graph_list = load_graph(config.datapath)
train_ds = create_dataset(train_graph, sampled_graph_list, config.workers, batch_size=config.batch_pairs,