Trans format when the tensor is 5D format but the graph input is 4D format

This commit is contained in:
caifubi 2022-02-16 20:02:35 +08:00
parent a0453df907
commit 4dd49a4b72
1 changed files with 1 additions and 1 deletions

View File

@ -396,7 +396,7 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vect
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(device_address);
if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType()) &&
!device_address->is_ptr_persisted()) {
!device_address->is_ptr_persisted() && tensor_address->format() == device_address->format()) {
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
tensor_address->SetNodeIndex(input_node, 0);
}