forked from mindspore-Ecosystem/mindspore

fix gatherv2 and dataset bug

parent 635acb6c27
commit debfd38b75
@@ -618,19 +618,11 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
   for (auto &replace_input : replace_graph->first) {
     auto pre_node = node->input(IntToSize(replace_input.second));
     manager->SetEdge(replace_input.first, 1, pre_node);
-    auto replace_input_cnode = replace_input.first->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(replace_input_cnode);
-    (void)replace_input_cnode->set_operator_info(node->operator_info());
-    replace_input_cnode->set_in_forward_flag(true);  // mark this new cnode is forward node
   }
   // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called
   auto replace_output = replace_graph->second;
   MS_EXCEPTION_IF_NULL(replace_output);
   (void)manager->Replace(node, replace_output);
-  CNodePtr replace_output_cnode = replace_graph->second->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(replace_output_cnode);
-  (void)replace_output_cnode->set_operator_info(node->operator_info());
-  replace_output_cnode->set_in_forward_flag(true);  // mark this new cnode is forward node
 }
 
 int32_t GetTupleGetItemIndex(const CNodePtr &cnode) {
@@ -1994,14 +1986,27 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
         BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
       }
 
-      // StepReplace
-      StepReplace(distribute_operator, cnode);
-
       HandleSpecialNode(distribute_operator, cnode);
     } else if (IsValueNode<Tensor>(node)) {
      StepSplitTensor(node, manager);
     }
   }
+
+  for (auto &node : all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (node->isa<CNode>()) {
+      auto cnode = node->cast<CNodePtr>();
+      if (!IsValueNode<Primitive>(cnode->input(0))) {
+        continue;
+      }
+      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
+      if (distribute_operator == nullptr) {
+        continue;
+      }
+      // StepReplace
+      StepReplace(distribute_operator, cnode);
+    }
+  }
 }
 
 namespace {
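In the hunk above, StepReplace moves out of the loop that inserts communication and into a second pass over all_nodes that re-derives the distributed operator for each CNode. Below is a minimal, runnable Python sketch of that two-pass structure; the helper names (get_distribute_operator, insert_communication, step_replace) are hypothetical stand-ins for the C++ functions, not MindSpore APIs.

# Illustrative sketch only: a two-pass traversal in the spirit of the change above.
def get_distribute_operator(node):
    # Stand-in: pretend each node carries its operator info directly.
    return node.get("operator_info")

def insert_communication(op, node):
    node["communication_inserted"] = True

def step_replace(op, node):
    # Runs only in the second pass, after every node has been through pass 1.
    node["replaced"] = node.get("communication_inserted", False)

def parallel_communication(all_nodes):
    # Pass 1: insert forward/backward communication for every operator node.
    for node in all_nodes:
        op = get_distribute_operator(node)
        if op is not None:
            insert_communication(op, node)
    # Pass 2: only then apply the replacement step to each operator node.
    for node in all_nodes:
        op = get_distribute_operator(node)
        if op is not None:
            step_replace(op, node)

nodes = [{"operator_info": "GatherV2"}, {"operator_info": None}]
parallel_communication(nodes)
print(nodes[0]["replaced"])  # True

The sketch only mirrors the control flow of the diff; the two continue branches in the added loop (no Primitive input, no operator info) correspond to the `op is not None` guard here.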
@@ -83,12 +83,6 @@ class _DatasetIter:
         self.dataset = dataset
         dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
         self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
-        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
-        # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
-        # times the batch dimension of tensors for run
-        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
-            device_num = _get_device_num()
-            self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
 
     def __iter__(self):
         self.ind = 0
@@ -119,6 +113,12 @@ class _DatasetIterMSLoopSink(_DatasetIter):
     def __init__(self, dataset):
         super(_DatasetIterMSLoopSink, self).__init__(dataset)
         self.loop_count = self.get_loop_count(dataset)
+        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
+        # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
+        # times the batch dimension of tensors for run. Now only support LoopSink.
+        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            device_num = _get_device_num()
+            self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
 
         def op():
             return tuple()
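The block removed from _DatasetIter.__init__ reappears above in _DatasetIterMSLoopSink.__init__, so the full-batch shape expansion now happens only for the loop-sink iterator ("Now only support LoopSink"), and it expands self.dataset_shapes set by the base class rather than the local dataset_shapes. As a rough illustration of what that expansion does, here is a minimal sketch that assumes _to_full_shapes simply multiplies the leading (batch) dimension of each dataset shape by the device count; the real MindSpore helper may handle more cases.

# Simplified stand-in for _to_full_shapes, assuming only the leading (batch)
# dimension is scaled by the device count; the actual helper may differ.
def to_full_shapes(dataset_shapes, device_num):
    full_shapes = []
    for shape in dataset_shapes:
        full = list(shape)
        full[0] = full[0] * device_num  # compile against the global batch size
        full_shapes.append(tuple(full))
    return full_shapes

# With 8 devices, a per-device batch shape (32, 224, 224, 3) compiles as (256, 224, 224, 3).
print(to_full_shapes([(32, 224, 224, 3), (32,)], 8))
# [(256, 224, 224, 3), (256,)]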
@@ -170,4 +170,3 @@ def test_gatherv2_auto1():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
-