fix gatherv2 and dataset bug

This commit is contained in:
lichenever 2020-05-13 09:44:37 +08:00
parent 635acb6c27
commit debfd38b75
3 changed files with 22 additions and 18 deletions

View File

@ -618,19 +618,11 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
for (auto &replace_input : replace_graph->first) { for (auto &replace_input : replace_graph->first) {
auto pre_node = node->input(IntToSize(replace_input.second)); auto pre_node = node->input(IntToSize(replace_input.second));
manager->SetEdge(replace_input.first, 1, pre_node); manager->SetEdge(replace_input.first, 1, pre_node);
auto replace_input_cnode = replace_input.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(replace_input_cnode);
(void)replace_input_cnode->set_operator_info(node->operator_info());
replace_input_cnode->set_in_forward_flag(true); // mark this new cnode is forward node
} }
// "(void)manager->Replace(replace_graph->first, pre_node);" can not be called // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called
auto replace_output = replace_graph->second; auto replace_output = replace_graph->second;
MS_EXCEPTION_IF_NULL(replace_output); MS_EXCEPTION_IF_NULL(replace_output);
(void)manager->Replace(node, replace_output); (void)manager->Replace(node, replace_output);
CNodePtr replace_output_cnode = replace_graph->second->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(replace_output_cnode);
(void)replace_output_cnode->set_operator_info(node->operator_info());
replace_output_cnode->set_in_forward_flag(true); // mark this new cnode is forward node
} }
int32_t GetTupleGetItemIndex(const CNodePtr &cnode) { int32_t GetTupleGetItemIndex(const CNodePtr &cnode) {
@ -1994,14 +1986,27 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
BackwardCommunication(distribute_operator, cnode, sens_loss_pairs); BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
} }
// StepReplace
StepReplace(distribute_operator, cnode);
HandleSpecialNode(distribute_operator, cnode); HandleSpecialNode(distribute_operator, cnode);
} else if (IsValueNode<Tensor>(node)) { } else if (IsValueNode<Tensor>(node)) {
StepSplitTensor(node, manager); StepSplitTensor(node, manager);
} }
} }
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
if (distribute_operator == nullptr) {
continue;
}
// StepReplace
StepReplace(distribute_operator, cnode);
}
}
} }
namespace { namespace {

View File

@ -83,12 +83,6 @@ class _DatasetIter:
self.dataset = dataset self.dataset = dataset
dataset_types, dataset_shapes = _get_types_and_shapes(dataset) dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
# times the batch dimension of tensors for run
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
device_num = _get_device_num()
self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
def __iter__(self): def __iter__(self):
self.ind = 0 self.ind = 0
@ -119,6 +113,12 @@ class _DatasetIterMSLoopSink(_DatasetIter):
def __init__(self, dataset): def __init__(self, dataset):
super(_DatasetIterMSLoopSink, self).__init__(dataset) super(_DatasetIterMSLoopSink, self).__init__(dataset)
self.loop_count = self.get_loop_count(dataset) self.loop_count = self.get_loop_count(dataset)
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
# times the batch dimension of tensors for run. Now only support LoopSink.
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
device_num = _get_device_num()
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
def op(): def op():
return tuple() return tuple()

View File

@ -170,4 +170,3 @@ def test_gatherv2_auto1():
x = Tensor(np.ones([64, 32]), dtype=ms.float32) x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y) _executor.compile(net, x, y)