!49573 Create device address for tuple-in-tuple valuenode.

Merge pull request !49573 from gaoyong10/r1.10
i-robot 2023-03-02 03:17:10 +00:00 committed by Gitee
commit 3cf15d2784
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 50 additions and 4 deletions

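The heart of the change is how the device address is sized for a value node whose abstract is a nested (tuple-in-tuple) structure: the abstract is indexed down to the concrete tensor or scalar, the element type is read from it, and the memory size is the element byte size multiplied by every dimension of the shape. The following standalone sketch reproduces only that size arithmetic; OutputSizeSketch and the float32 example values are illustrative stand-ins, not the MindSpore helpers (GetTypeByte, FetchAbstractByIndex) used in the diff below.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Sketch of the size rule used by the new FetchOutputSizeByNode: start from the
// byte size of the element type and multiply by every dimension of the shape.
// An empty shape (scalar) leaves the size at the element byte size.
size_t OutputSizeSketch(const std::vector<int64_t> &shape, size_t type_byte_size) {
  int64_t size = static_cast<int64_t>(type_byte_size);
  return static_cast<size_t>(
      std::accumulate(shape.begin(), shape.end(), size, std::multiplies<int64_t>()));
}

int main() {
  // A float32 tensor of shape (2, 3) nested inside a tuple-in-tuple value node:
  // 4 bytes * 2 * 3 = 24 bytes for its device address.
  std::cout << OutputSizeSketch({2, 3}, sizeof(float)) << std::endl;  // prints 24
  // A float32 scalar element of the same tuple: just the 4-byte element size.
  std::cout << OutputSizeSketch({}, sizeof(float)) << std::endl;      // prints 4
  return 0;
}

FetchTypeIdByNode in the diff applies the same idea to the element type: for a value node it is taken from the abstract (walking into TensorType when present), otherwise from the node's inferred output data type.
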
@@ -289,6 +289,52 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
  UpdateRefCount(address.get(), true);
}
+
+TypeId FetchTypeIdByNode(const AnfNodePtr &node, size_t index) {
+  MS_EXCEPTION_IF_NULL(node);
+  TypeId type_id = kTypeUnknown;
+  if (node->isa<ValueNode>() && node->abstract() != nullptr) {
+    // For valuenode, fetch type from abstract.
+    const auto &abs = FetchAbstractByIndex(node->abstract(), index);
+    MS_EXCEPTION_IF_NULL(abs);
+    const auto &type = abs->BuildType();
+    MS_EXCEPTION_IF_NULL(type);
+    if (type->isa<TensorType>()) {
+      const auto &tensor_type = type->cast<TensorTypePtr>();
+      MS_EXCEPTION_IF_NULL(tensor_type);
+      const auto &element = tensor_type->element();
+      type_id = element->type_id();
+    } else {
+      type_id = type->type_id();
+    }
+  } else {
+    type_id = common::AnfAlgo::GetOutputInferDataType(node, index);
+  }
+  return type_id;
+}
+
+size_t FetchOutputSizeByNode(const AnfNodePtr &node, size_t index, TypeId type_id) {
+  MS_EXCEPTION_IF_NULL(node);
+  size_t size = GetTypeByte(TypeIdToType(type_id));
+  if (node->isa<ValueNode>() && node->abstract() != nullptr) {
+    const auto &abs = FetchAbstractByIndex(node->abstract(), index);
+    MS_EXCEPTION_IF_NULL(abs);
+    const auto &shape_ptr = abs->BuildShape();
+    MS_EXCEPTION_IF_NULL(shape_ptr);
+    if (shape_ptr->isa<abstract::Shape>()) {
+      const auto &shapes = shape_ptr->cast<abstract::ShapePtr>()->shape();
+      size = std::accumulate(shapes.begin(), shapes.end(), size, std::multiplies<int64_t>());
+    } else if (abs->isa<abstract::AbstractMonad>() || abs->isa<abstract::AbstractScalar>()) {
+      MS_LOG(DEBUG) << "For scalar, the output shape is 1.";
+    } else {
+      MS_LOG(EXCEPTION) << "Invalid abstract:" << abs->ToString() << " for node:" << node->DebugString()
+                        << " index:" << index;
+    }
+  } else {
+    size = AnfAlgo::GetOutputTensorMemSize(node, index);
+  }
+  return size;
+}
// Create a device tensor for front node.
// When the condition input of the switch and switchlayer or the output of a subgraph is a parameter or value node,
// there is no corresponding backend node for this parameter, so a device tensor needs to be created for it.
@@ -326,8 +372,7 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
  }
  // Set type.
-  TypeId type_id = kTypeUnknown;
-  type_id = common::AnfAlgo::GetOutputInferDataType(node, front_node_with_index.second);
+  TypeId type_id = FetchTypeIdByNode(node, front_node_with_index.second);
  if (builder->GetAllOutputDeviceTypes().size() > front_node_with_index.second) {
    builder->SetOutputDeviceType(type_id, front_node_with_index.second);
  } else {
@@ -338,8 +383,9 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
    builder->SetOutputsDeviceType(types);
  }
-  size_t size = 0;
-  size = AnfAlgo::GetOutputTensorMemSize(node, front_node_with_index.second);
+  // Fetch mem size by shape, the shape is first obtained from the abstract to deal with the scenario where
+  // the value node is a multi-level tuple.
+  size_t size = FetchOutputSizeByNode(node, front_node_with_index.second, type_id);
  device::DeviceAddressPtr address = nullptr;
  if (node->isa<ValueNode>()) {
    const auto &node_value = node->cast<ValueNodePtr>()->value();