forked from mindspore-Ecosystem/mindspore
fix gatherv2 and dataset bug
This commit is contained in:
parent
635acb6c27
commit
debfd38b75
|
@ -618,19 +618,11 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
|
|||
for (auto &replace_input : replace_graph->first) {
|
||||
auto pre_node = node->input(IntToSize(replace_input.second));
|
||||
manager->SetEdge(replace_input.first, 1, pre_node);
|
||||
auto replace_input_cnode = replace_input.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(replace_input_cnode);
|
||||
(void)replace_input_cnode->set_operator_info(node->operator_info());
|
||||
replace_input_cnode->set_in_forward_flag(true); // mark this new cnode is forward node
|
||||
}
|
||||
// "(void)manager->Replace(replace_graph->first, pre_node);" can not be called
|
||||
auto replace_output = replace_graph->second;
|
||||
MS_EXCEPTION_IF_NULL(replace_output);
|
||||
(void)manager->Replace(node, replace_output);
|
||||
CNodePtr replace_output_cnode = replace_graph->second->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(replace_output_cnode);
|
||||
(void)replace_output_cnode->set_operator_info(node->operator_info());
|
||||
replace_output_cnode->set_in_forward_flag(true); // mark this new cnode is forward node
|
||||
}
|
||||
|
||||
int32_t GetTupleGetItemIndex(const CNodePtr &cnode) {
|
||||
|
@ -1994,14 +1986,27 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
|
||||
}
|
||||
|
||||
// StepReplace
|
||||
StepReplace(distribute_operator, cnode);
|
||||
|
||||
HandleSpecialNode(distribute_operator, cnode);
|
||||
} else if (IsValueNode<Tensor>(node)) {
|
||||
StepSplitTensor(node, manager);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &node : all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
continue;
|
||||
}
|
||||
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
|
||||
if (distribute_operator == nullptr) {
|
||||
continue;
|
||||
}
|
||||
// StepReplace
|
||||
StepReplace(distribute_operator, cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -83,12 +83,6 @@ class _DatasetIter:
|
|||
self.dataset = dataset
|
||||
dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
|
||||
self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
|
||||
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
|
||||
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
|
||||
# times the batch dimension of tensors for run
|
||||
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
device_num = _get_device_num()
|
||||
self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
||||
|
||||
def __iter__(self):
|
||||
self.ind = 0
|
||||
|
@ -119,6 +113,12 @@ class _DatasetIterMSLoopSink(_DatasetIter):
|
|||
def __init__(self, dataset):
|
||||
super(_DatasetIterMSLoopSink, self).__init__(dataset)
|
||||
self.loop_count = self.get_loop_count(dataset)
|
||||
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
|
||||
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
|
||||
# times the batch dimension of tensors for run. Now only support LoopSink.
|
||||
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
device_num = _get_device_num()
|
||||
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
|
||||
|
||||
def op():
|
||||
return tuple()
|
||||
|
|
|
@ -170,4 +170,3 @@ def test_gatherv2_auto1():
|
|||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
|
Loading…
Reference in New Issue