!49481 ValueNode supports the list

Merge pull request !49481 from limingqi107/bug_fix3
yanghaoran 2023-03-01 01:55:06 +00:00 committed by Gitee
commit 004a7e16d5
7 changed files with 46 additions and 49 deletions
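The recurring change in the diffs below is to test for the common base class ValueSequence instead of ValueTuple, so that list-valued value nodes (ValueList) take the same path as tuples. A minimal sketch of the pattern, assuming the MindSpore hierarchy in which both ValueTuple and ValueList derive from ValueSequence; the helper name CollectTensors is hypothetical and not part of this commit:

// Illustrative sketch only: recurse through a value that may be a tensor or a nested tuple/list.
void CollectTensors(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(value);
  MS_EXCEPTION_IF_NULL(tensors);
  if (value->isa<ValueSequence>()) {  // matches both ValueTuple and ValueList
    auto sequence = value->cast<ValueSequencePtr>();
    MS_EXCEPTION_IF_NULL(sequence);
    for (const auto &element : sequence->value()) {
      CollectTensors(element, tensors);  // recurse into nested sequences
    }
  } else if (value->isa<tensor::Tensor>()) {
    (void)tensors->emplace_back(value->cast<tensor::TensorPtr>());
  }
}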

View File

@@ -173,8 +173,8 @@ void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs,
} else if (utils::isa<ValuePtr>(arg)) {
auto value = utils::cast<ValuePtr>(arg);
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
if (value->isa<ValueSequence>()) {
auto value_tuple = value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
auto tuple_value = value_tuple->value();
(void)std::transform(tuple_value.begin(), tuple_value.end(), std::back_inserter(*inputs),
@@ -202,23 +202,6 @@ void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs,
}
namespace {
// Move these function to anonymous namespace
void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) {
MS_EXCEPTION_IF_NULL(flatted_value);
for (auto value_element : value) {
MS_EXCEPTION_IF_NULL(value_element);
if (utils::isa<tensor::TensorPtr>(value_element)) {
(void)flatted_value->emplace_back(value_element);
} else if (utils::isa<ValueTuplePtr>(value_element)) {
auto value_tuple_element = value_element->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple_element);
FlatValueTupleValue(value_tuple_element->value(), flatted_value);
} else {
MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contains Tensor and ValueTuple.";
}
}
}
void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
MS_EXCEPTION_IF_NULL(flatted_value);
if (utils::isa<ValueSequencePtr>(arg)) {
@@ -647,8 +630,8 @@ namespace {
void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(value);
MS_EXCEPTION_IF_NULL(outputs);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
if (value->isa<ValueSequence>()) {
auto value_tuple = value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
for (size_t i = 0; i < value_tuple->size(); ++i) {
ValuePtr element = value_tuple->value()[i];
@@ -661,7 +644,7 @@ void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
auto scalar = element->cast<ScalarPtr>();
MS_EXCEPTION_IF_NULL(scalar);
outputs->emplace_back(ScalarToTensor(scalar));
} else if (element->isa<ValueTuple>()) {
} else if (element->isa<ValueSequence>()) {
VectorRef tuple;
TensorValueToVector(element, &tuple);
outputs->emplace_back(tuple);
@@ -687,7 +670,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const Vec
ValuePtr value = GetValueNode(graph_output);
TensorValueToVector(value, &output_tmp);
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
if (value->isa<ValueSequence>()) {
outputs->emplace_back(output_tmp);
} else if (value->isa<tensor::Tensor>() || value->isa<Scalar>()) {
*outputs = output_tmp;
@@ -952,9 +935,9 @@ void MindRTBackendBase::ConstructOutputs(const AnfNodePtr &output_node,
if (output_node->isa<ValueNode>()) {
auto value = output_node->cast<ValueNodePtr>()->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
if (value->isa<ValueSequence>()) {
outputs->emplace_back(value);
(*output_position) += CountValueNum(value->cast<ValueTuplePtr>());
(*output_position) += CountValueNum(value->cast<ValueSequencePtr>());
} else if (outputs_num != 0) {
outputs->emplace_back(value);
(*output_position) += outputs_num;

View File

@@ -81,7 +81,7 @@ COMMON_EXPORT void TensorValueToTensor(const ValuePtr &value, std::vector<tensor
COMMON_EXPORT ValuePtr ShallowCopyTensorValue(const ValuePtr &value);
COMMON_EXPORT size_t CountValueNum(const ValueTuplePtr &value_tuple);
COMMON_EXPORT size_t CountValueNum(const ValueSequencePtr &value_sequence);
COMMON_EXPORT bool IsAKGSparseOP(const AnfNodePtr &cnode);

View File

@@ -186,7 +186,7 @@ void DeviceAddressUtils::CreateDeviceAddressForTensorValue(const DeviceContext *
std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, tensor_size, output_format, output_type_id, trans::GetRuntimePaddingShape(value_node, output_idx));
nullptr, tensor_size, output_format, output_type_id, tensor->shape());
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
MS_EXCEPTION_IF_NULL(address);
address->set_from_persistent_mem(true);

View File

@@ -271,14 +271,10 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
if (node_value->isa<Scalar>()) {
return {};
}
if (node_value->isa<ValueTuple>()) {
const auto &value_tuple = node_value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
const auto &values = value_tuple->value();
if (values.size() == 0) {
MS_LOG(INFO) << "Empty tuple for node:" << node->DebugString();
return {0};
}
if (node_value->isa<ValueSequence>()) {
MS_LOG(INFO) << "GetRuntimePaddingShape does not support the value sequence for value node:"
<< node->fullname_with_scope() << ", debug name:" << node->DebugString();
return {0};
}
auto tensor = node_value->cast<tensor::TensorPtr>();
if (tensor == nullptr) {

View File

@@ -111,6 +111,11 @@ void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_
}
};
ShapeVector host_shape = {};
// GetRuntimePaddingShape doesn't support the value tuple node.
if (!node->isa<ValueNode>()) {
host_shape = trans::GetRuntimePaddingShape(node, 0);
}
auto get_tensor_num = (host_tensor->isa<tensor::MapTensor>() ? kMapTensorNum : kNormalTensorNum);
for (size_t i = 0; i < get_tensor_num; ++i) {
const auto &real_host_tensor = get_tensor_by_index(i);
@@ -118,8 +123,11 @@ void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_
// Copy data from host tensor to device.
auto host_tensor_size = LongToSize(real_host_tensor->data().nbytes());
auto host_tensor_type = real_host_tensor->data_type();
if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), host_tensor_size, host_tensor_type,
real_host_tensor->data_c(), real_host_tensor->device_info().host_format_)) {
if (node->isa<ValueNode>()) {
host_shape = real_host_tensor->shape();
}
if (!device_tensor->SyncHostToDevice(host_shape, host_tensor_size, host_tensor_type, real_host_tensor->data_c(),
real_host_tensor->device_info().host_format_)) {
std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope() +
", host tensor size: " + std::to_string(host_tensor_size) +
", host tensor type: " + std::to_string(static_cast<int>(host_tensor_type)) +

View File

@@ -120,8 +120,8 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
MS_EXCEPTION_IF_NULL(value);
if (value->isa<None>()) {
return ret;
} else if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
} else if (value->isa<ValueSequence>()) {
auto value_tuple = value->cast<ValueSequencePtr>();
auto value_tuple_size = CountValueNum(value_tuple);
for (size_t i = 0; i < value_tuple_size; ++i) {
(void)ret.emplace_back(node, i);
@@ -142,8 +142,14 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
outputs_num = AnfUtils::GetOutputTensorNum(node);
}
}
// Call node maybe a real cnode and the last interface cannot get output num exactly, so we should get
// output num from abstract again.
// A call node may be a real cnode, and an unreal node cannot report its output num exactly, so we should get
// the output num from the abstract again. For example, TupleGetItem/MakeTuple with multi-level nesting:
// '''G = op() ---> assume that the output of G is a multi-member tuple
// A = MakeTuple(E, F, G)
// B = MakeTuple(H, A)
// C = TupleGetItem(B, 1) ---> equal to A
// D = TupleGetItem(C, 2) ---> VisitKernel will return {G, 0}, but the whole G with all its members is expected
// return D'''
if (common::AnfAlgo::IsCallNode(node) || (!AnfUtils::IsRealCNodeKernel(node))) {
MS_EXCEPTION_IF_NULL(node->abstract());
outputs_num = AnfAlgo::GetOutputNumByAbstract(node->abstract());

View File

@@ -381,8 +381,8 @@ ValuePtr CreateValueFromTensor(const tensor::TensorPtr &tensor) {
void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) {
MS_EXCEPTION_IF_NULL(value);
MS_EXCEPTION_IF_NULL(tensors);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
if (value->isa<ValueSequence>()) {
auto value_tuple = value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
for (size_t i = 0; i < value_tuple->size(); ++i) {
ValuePtr element = value_tuple->value()[i];
@@ -390,8 +390,12 @@ void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *
auto tensor = element->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
tensors->emplace_back(tensor);
} else if (element->isa<ValueTuple>()) {
} else if (element->isa<ValueSequence>()) {
TensorValueToTensor(element, tensors);
} else if (element->isa<Scalar>()) {
auto scalar = element->cast<ScalarPtr>();
MS_EXCEPTION_IF_NULL(scalar);
tensors->emplace_back(ScalarToTensor(scalar));
}
}
} else if (value->isa<tensor::Tensor>()) {
@@ -425,15 +429,15 @@ ValuePtr ShallowCopyTensorValue(const ValuePtr &value) {
}
}
size_t CountValueNum(const ValueTuplePtr &value_tuple) {
MS_EXCEPTION_IF_NULL(value_tuple);
size_t CountValueNum(const ValueSequencePtr &value_sequence) {
MS_EXCEPTION_IF_NULL(value_sequence);
size_t cnt = 0;
const auto &value_list = value_tuple->value();
const auto &value_list = value_sequence->value();
for (const auto &value : value_list) {
if (value->isa<None>()) {
continue;
} else if (value->isa<ValueTuple>()) {
cnt += CountValueNum(value->cast<ValueTuplePtr>());
} else if (value->isa<ValueSequence>()) {
cnt += CountValueNum(value->cast<ValueSequencePtr>());
} else {
cnt++;
}