Fix Sponge net.

This commit is contained in:
gaoyong10 2021-07-03 21:04:50 +08:00
parent 0553fb0a6a
commit b3f0a29fb1
3 changed files with 10 additions and 8 deletions

View File

@ -407,7 +407,7 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
bool is_send = false;
for (const auto &backend_node : backend_parameters_[from_index]) {
for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(backend_node.first); ++j) {
if (AnfAlgo::OutputAddrExist(backend_node.first, j, false) &&
if (backend_node.first->kernel_info() != nullptr && AnfAlgo::OutputAddrExist(backend_node.first, j, false) &&
AnfAlgo::GetMutableOutputAddr(backend_node.first, j, false).get() == input_device_tensors_[from_index]) {
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, j,
result_arrow->to_input_index_, context);

View File

@ -1367,11 +1367,13 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
<< ", internal parameter:" << AnfAlgo::GetNodeDebugString(internal_parameter);
}
auto actor_pair = graph_output_to_actor_[front_output_with_index];
MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
<< ", corresponding front node:" << front_output_node->fullname_with_scope()
<< " with index:" << front_output_with_index.second
<< ", from actor:" << actor_pair.first->GetAID().Name() << " with index:" << actor_pair.second
<< ", to actor:" << to_actor->GetAID().Name() << " with index:" << to_kernel_with_input_idx.second;
if (actor_pair.first != nullptr) {
MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
<< ", corresponding front node:" << front_output_node->fullname_with_scope()
<< " with index:" << front_output_with_index.second
<< ", from actor:" << actor_pair.first->GetAID().Name() << " with index:" << actor_pair.second
<< ", to actor:" << to_actor->GetAID().Name() << " with index:" << to_kernel_with_input_idx.second;
}
if (IsDeviceQueueDSActor(front_output_node)) {
auto from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor_pair.first);

View File

@ -226,7 +226,7 @@ class SideEffectTwoAssignTwoAddnDependencyNet(Cell):
return grad_out
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ctrl_while_by_while_and_if_in_first_while():
@ -262,7 +262,7 @@ def test_ctrl_while_by_while_and_if_in_first_while():
net(input_me_a)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ctrl_while_by_while_and_while_in_first_while():