forked from mindspore-Ecosystem/mindspore
!49573 Create device address for tuple-in-tuple valuenode.
Merge pull request !49573 from gaoyong10/r1.10
This commit is contained in:
commit
3cf15d2784
|
@ -289,6 +289,52 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
|
|||
UpdateRefCount(address.get(), true);
|
||||
}
|
||||
|
||||
TypeId FetchTypeIdByNode(const AnfNodePtr &node, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
TypeId type_id = kTypeUnknown;
|
||||
if (node->isa<ValueNode>() && node->abstract() != nullptr) {
|
||||
// For valuenode, fetch type from abstract.
|
||||
const auto &abs = FetchAbstractByIndex(node->abstract(), index);
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
const auto &type = abs->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->isa<TensorType>()) {
|
||||
const auto &tensor_type = type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
const auto &element = tensor_type->element();
|
||||
type_id = element->type_id();
|
||||
} else {
|
||||
type_id = type->type_id();
|
||||
}
|
||||
} else {
|
||||
type_id = common::AnfAlgo::GetOutputInferDataType(node, index);
|
||||
}
|
||||
return type_id;
|
||||
}
|
||||
|
||||
size_t FetchOutputSizeByNode(const AnfNodePtr &node, size_t index, TypeId type_id) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
size_t size = GetTypeByte(TypeIdToType(type_id));
|
||||
if (node->isa<ValueNode>() && node->abstract() != nullptr) {
|
||||
const auto &abs = FetchAbstractByIndex(node->abstract(), index);
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
const auto &shape_ptr = abs->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
if (shape_ptr->isa<abstract::Shape>()) {
|
||||
const auto &shapes = shape_ptr->cast<abstract::ShapePtr>()->shape();
|
||||
size = std::accumulate(shapes.begin(), shapes.end(), size, std::multiplies<int64_t>());
|
||||
} else if (abs->isa<abstract::AbstractMonad>() || abs->isa<abstract::AbstractScalar>()) {
|
||||
MS_LOG(DEBUG) << "For scalar, the output shape is 1.";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract;" << abs->ToString() << " for node:" << node->DebugString()
|
||||
<< " index:" << index;
|
||||
}
|
||||
} else {
|
||||
size = AnfAlgo::GetOutputTensorMemSize(node, index);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
// Create a device tensor for front node.
|
||||
// When the condition input of the switch and switchlayer or the output of a subgraph is a parameter or value node,
|
||||
// there is no corresponding backend node for this parameter, so a device tensor needs to be created for it.
|
||||
|
@ -326,8 +372,7 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
|
|||
}
|
||||
|
||||
// Set type.
|
||||
TypeId type_id = kTypeUnknown;
|
||||
type_id = common::AnfAlgo::GetOutputInferDataType(node, front_node_with_index.second);
|
||||
TypeId type_id = FetchTypeIdByNode(node, front_node_with_index.second);
|
||||
if (builder->GetAllOutputDeviceTypes().size() > front_node_with_index.second) {
|
||||
builder->SetOutputDeviceType(type_id, front_node_with_index.second);
|
||||
} else {
|
||||
|
@ -338,8 +383,9 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
|
|||
builder->SetOutputsDeviceType(types);
|
||||
}
|
||||
|
||||
size_t size = 0;
|
||||
size = AnfAlgo::GetOutputTensorMemSize(node, front_node_with_index.second);
|
||||
// Fetch mem size by shape, the shape is first obtained from the abstract to deal with the scenario where
|
||||
// the value node is a multi-level tuple.
|
||||
size_t size = FetchOutputSizeByNode(node, front_node_with_index.second, type_id);
|
||||
device::DeviceAddressPtr address = nullptr;
|
||||
if (node->isa<ValueNode>()) {
|
||||
const auto &node_value = node->cast<ValueNodePtr>()->value();
|
||||
|
|
Loading…
Reference in New Issue