forked from mindspore-Ecosystem/mindspore
!17713 [GraphKernel] fix bgcf and enbale graph kernel on GPU.
From: @chenlei_autodiff Reviewed-by: @ckey_dou,@gaoxiong1 Signed-off-by:
This commit is contained in:
commit
df491ffe53
|
@ -126,25 +126,6 @@ bool GenJson(const AnfNodePtrList &op_nodes, const std::pair<AnfNodePtrList, Anf
|
|||
return true;
|
||||
}
|
||||
|
||||
bool TensorElementAllTheSame(const tensor::TensorPtr &tensor) {
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
if (tensor->DataSize() == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto data = static_cast<char *>(tensor->data_c());
|
||||
auto itemsize = static_cast<size_t>(tensor->data().itemsize());
|
||||
auto total_cnt = IntToSize(tensor->DataSize());
|
||||
for (size_t i = 1; i < total_cnt; ++i) {
|
||||
for (size_t ei = 0; ei < itemsize; ++ei) {
|
||||
if (data[ei] != data[i * itemsize + ei]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertToScalarTensor(const AnfNodePtr &value_node) {
|
||||
auto tensor = GetValueNode<tensor::TensorPtr>(value_node);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
|
@ -222,17 +203,13 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
|
|||
if (tensor == nullptr || tensor->DataSize() == 1) {
|
||||
continue;
|
||||
}
|
||||
if (TensorElementAllTheSame(tensor)) {
|
||||
scalar_tensors.emplace_back(tnode);
|
||||
auto tensor_iter = std::find_if(
|
||||
v_replace.begin(), v_replace.end(),
|
||||
[&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); });
|
||||
if (tensor_iter == v_replace.end()) {
|
||||
v_replace.emplace_back(tensor, AnfNodePtrList{tnode});
|
||||
} else {
|
||||
auto tensor_iter = std::find_if(
|
||||
v_replace.begin(), v_replace.end(),
|
||||
[&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); });
|
||||
if (tensor_iter == v_replace.end()) {
|
||||
v_replace.emplace_back(tensor, AnfNodePtrList{tnode});
|
||||
} else {
|
||||
tensor_iter->second.push_back(tnode);
|
||||
}
|
||||
tensor_iter->second.push_back(tnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -101,6 +101,8 @@ def run_train():
|
|||
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
if config.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
|
||||
train_graph, _, sampled_graph_list = load_graph(config.datapath)
|
||||
train_ds = create_dataset(train_graph, sampled_graph_list, config.workers, batch_size=config.batch_pairs,
|
||||
|
|
Loading…
Reference in New Issue