forked from mindspore-Ecosystem/mindspore
!49481 valuenode support the list
Merge pull request !49481 from limingqi107/bug_fix3
This commit is contained in:
commit
004a7e16d5
|
@ -173,8 +173,8 @@ void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs,
|
|||
} else if (utils::isa<ValuePtr>(arg)) {
|
||||
auto value = utils::cast<ValuePtr>(arg);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<ValueTuple>()) {
|
||||
auto value_tuple = value->cast<ValueTuplePtr>();
|
||||
if (value->isa<ValueSequence>()) {
|
||||
auto value_tuple = value->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
auto tuple_value = value_tuple->value();
|
||||
(void)std::transform(tuple_value.begin(), tuple_value.end(), std::back_inserter(*inputs),
|
||||
|
@ -202,23 +202,6 @@ void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs,
|
|||
}
|
||||
|
||||
namespace {
|
||||
// Move these function to anonymous namespace
|
||||
void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) {
|
||||
MS_EXCEPTION_IF_NULL(flatted_value);
|
||||
for (auto value_element : value) {
|
||||
MS_EXCEPTION_IF_NULL(value_element);
|
||||
if (utils::isa<tensor::TensorPtr>(value_element)) {
|
||||
(void)flatted_value->emplace_back(value_element);
|
||||
} else if (utils::isa<ValueTuplePtr>(value_element)) {
|
||||
auto value_tuple_element = value_element->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple_element);
|
||||
FlatValueTupleValue(value_tuple_element->value(), flatted_value);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contains Tensor and ValueTuple.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
|
||||
MS_EXCEPTION_IF_NULL(flatted_value);
|
||||
if (utils::isa<ValueSequencePtr>(arg)) {
|
||||
|
@ -647,8 +630,8 @@ namespace {
|
|||
void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
if (value->isa<ValueTuple>()) {
|
||||
auto value_tuple = value->cast<ValueTuplePtr>();
|
||||
if (value->isa<ValueSequence>()) {
|
||||
auto value_tuple = value->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
for (size_t i = 0; i < value_tuple->size(); ++i) {
|
||||
ValuePtr element = value_tuple->value()[i];
|
||||
|
@ -661,7 +644,7 @@ void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
|
|||
auto scalar = element->cast<ScalarPtr>();
|
||||
MS_EXCEPTION_IF_NULL(scalar);
|
||||
outputs->emplace_back(ScalarToTensor(scalar));
|
||||
} else if (element->isa<ValueTuple>()) {
|
||||
} else if (element->isa<ValueSequence>()) {
|
||||
VectorRef tuple;
|
||||
TensorValueToVector(element, &tuple);
|
||||
outputs->emplace_back(tuple);
|
||||
|
@ -687,7 +670,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const Vec
|
|||
ValuePtr value = GetValueNode(graph_output);
|
||||
TensorValueToVector(value, &output_tmp);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<ValueTuple>()) {
|
||||
if (value->isa<ValueSequence>()) {
|
||||
outputs->emplace_back(output_tmp);
|
||||
} else if (value->isa<tensor::Tensor>() || value->isa<Scalar>()) {
|
||||
*outputs = output_tmp;
|
||||
|
@ -952,9 +935,9 @@ void MindRTBackendBase::ConstructOutputs(const AnfNodePtr &output_node,
|
|||
if (output_node->isa<ValueNode>()) {
|
||||
auto value = output_node->cast<ValueNodePtr>()->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<ValueTuple>()) {
|
||||
if (value->isa<ValueSequence>()) {
|
||||
outputs->emplace_back(value);
|
||||
(*output_position) += CountValueNum(value->cast<ValueTuplePtr>());
|
||||
(*output_position) += CountValueNum(value->cast<ValueSequencePtr>());
|
||||
} else if (outputs_num != 0) {
|
||||
outputs->emplace_back(value);
|
||||
(*output_position) += outputs_num;
|
||||
|
|
|
@ -81,7 +81,7 @@ COMMON_EXPORT void TensorValueToTensor(const ValuePtr &value, std::vector<tensor
|
|||
|
||||
COMMON_EXPORT ValuePtr ShallowCopyTensorValue(const ValuePtr &value);
|
||||
|
||||
COMMON_EXPORT size_t CountValueNum(const ValueTuplePtr &value_tuple);
|
||||
COMMON_EXPORT size_t CountValueNum(const ValueSequencePtr &value_sequence);
|
||||
|
||||
COMMON_EXPORT bool IsAKGSparseOP(const AnfNodePtr &cnode);
|
||||
|
||||
|
|
|
@ -186,7 +186,7 @@ void DeviceAddressUtils::CreateDeviceAddressForTensorValue(const DeviceContext *
|
|||
std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
|
||||
|
||||
device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(
|
||||
nullptr, tensor_size, output_format, output_type_id, trans::GetRuntimePaddingShape(value_node, output_idx));
|
||||
nullptr, tensor_size, output_format, output_type_id, tensor->shape());
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
address->set_from_persistent_mem(true);
|
||||
|
|
|
@ -271,14 +271,10 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
|||
if (node_value->isa<Scalar>()) {
|
||||
return {};
|
||||
}
|
||||
if (node_value->isa<ValueTuple>()) {
|
||||
const auto &value_tuple = node_value->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
const auto &values = value_tuple->value();
|
||||
if (values.size() == 0) {
|
||||
MS_LOG(INFO) << "Empty tuple for node:" << node->DebugString();
|
||||
return {0};
|
||||
}
|
||||
if (node_value->isa<ValueSequence>()) {
|
||||
MS_LOG(INFO) << "GetRuntimePaddingShape does not support the value sequence for value node:"
|
||||
<< node->fullname_with_scope() << ", debug name:" << node->DebugString();
|
||||
return {0};
|
||||
}
|
||||
auto tensor = node_value->cast<tensor::TensorPtr>();
|
||||
if (tensor == nullptr) {
|
||||
|
|
|
@ -111,6 +111,11 @@ void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_
|
|||
}
|
||||
};
|
||||
|
||||
ShapeVector host_shape = {};
|
||||
// GetRuntimePaddingShape doesn't support the value tuple node.
|
||||
if (!node->isa<ValueNode>()) {
|
||||
host_shape = trans::GetRuntimePaddingShape(node, 0);
|
||||
}
|
||||
auto get_tensor_num = (host_tensor->isa<tensor::MapTensor>() ? kMapTensorNum : kNormalTensorNum);
|
||||
for (size_t i = 0; i < get_tensor_num; ++i) {
|
||||
const auto &real_host_tensor = get_tensor_by_index(i);
|
||||
|
@ -118,8 +123,11 @@ void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_
|
|||
// Copy data from host tensor to device.
|
||||
auto host_tensor_size = LongToSize(real_host_tensor->data().nbytes());
|
||||
auto host_tensor_type = real_host_tensor->data_type();
|
||||
if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), host_tensor_size, host_tensor_type,
|
||||
real_host_tensor->data_c(), real_host_tensor->device_info().host_format_)) {
|
||||
if (node->isa<ValueNode>()) {
|
||||
host_shape = real_host_tensor->shape();
|
||||
}
|
||||
if (!device_tensor->SyncHostToDevice(host_shape, host_tensor_size, host_tensor_type, real_host_tensor->data_c(),
|
||||
real_host_tensor->device_info().host_format_)) {
|
||||
std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope() +
|
||||
", host tensor size: " + std::to_string(host_tensor_size) +
|
||||
", host tensor type: " + std::to_string(static_cast<int>(host_tensor_type)) +
|
||||
|
|
|
@ -120,8 +120,8 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
|
|||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<None>()) {
|
||||
return ret;
|
||||
} else if (value->isa<ValueTuple>()) {
|
||||
auto value_tuple = value->cast<ValueTuplePtr>();
|
||||
} else if (value->isa<ValueSequence>()) {
|
||||
auto value_tuple = value->cast<ValueSequencePtr>();
|
||||
auto value_tuple_size = CountValueNum(value_tuple);
|
||||
for (size_t i = 0; i < value_tuple_size; ++i) {
|
||||
(void)ret.emplace_back(node, i);
|
||||
|
@ -142,8 +142,14 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
|
|||
outputs_num = AnfUtils::GetOutputTensorNum(node);
|
||||
}
|
||||
}
|
||||
// Call node maybe a real cnode and the last interface cannot get output num exactly, so we should get
|
||||
// output num from abstract again.
|
||||
// Call node maybe a real cnode and the unreal node cannot get output num exactly, so we should get
|
||||
// output num from abstract again. For example the TupleGetItem/Makeple multi-level nesting:
|
||||
// '''G = op() ---> Assume that the output of G is a multi-member tuple
|
||||
// A = MakeTuple(E, F, G)
|
||||
// B = MakeTuple(H, A)
|
||||
// C = TupleGetItem(B, 1) ---> Euqal the A
|
||||
// D = TupleGetItem(C, 2) ---> VisitKernel will return the {G, 0}, but expect the whole G with all the members
|
||||
// return D'''
|
||||
if (common::AnfAlgo::IsCallNode(node) || (!AnfUtils::IsRealCNodeKernel(node))) {
|
||||
MS_EXCEPTION_IF_NULL(node->abstract());
|
||||
outputs_num = AnfAlgo::GetOutputNumByAbstract(node->abstract());
|
||||
|
|
|
@ -381,8 +381,8 @@ ValuePtr CreateValueFromTensor(const tensor::TensorPtr &tensor) {
|
|||
void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
MS_EXCEPTION_IF_NULL(tensors);
|
||||
if (value->isa<ValueTuple>()) {
|
||||
auto value_tuple = value->cast<ValueTuplePtr>();
|
||||
if (value->isa<ValueSequence>()) {
|
||||
auto value_tuple = value->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
for (size_t i = 0; i < value_tuple->size(); ++i) {
|
||||
ValuePtr element = value_tuple->value()[i];
|
||||
|
@ -390,8 +390,12 @@ void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *
|
|||
auto tensor = element->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
tensors->emplace_back(tensor);
|
||||
} else if (element->isa<ValueTuple>()) {
|
||||
} else if (element->isa<ValueSequence>()) {
|
||||
TensorValueToTensor(element, tensors);
|
||||
} else if (element->isa<Scalar>()) {
|
||||
auto scalar = element->cast<ScalarPtr>();
|
||||
MS_EXCEPTION_IF_NULL(scalar);
|
||||
tensors->emplace_back(ScalarToTensor(scalar));
|
||||
}
|
||||
}
|
||||
} else if (value->isa<tensor::Tensor>()) {
|
||||
|
@ -425,15 +429,15 @@ ValuePtr ShallowCopyTensorValue(const ValuePtr &value) {
|
|||
}
|
||||
}
|
||||
|
||||
size_t CountValueNum(const ValueTuplePtr &value_tuple) {
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
size_t CountValueNum(const ValueSequencePtr &value_sequence) {
|
||||
MS_EXCEPTION_IF_NULL(value_sequence);
|
||||
size_t cnt = 0;
|
||||
const auto &value_list = value_tuple->value();
|
||||
const auto &value_list = value_sequence->value();
|
||||
for (const auto &value : value_list) {
|
||||
if (value->isa<None>()) {
|
||||
continue;
|
||||
} else if (value->isa<ValueTuple>()) {
|
||||
cnt += CountValueNum(value->cast<ValueTuplePtr>());
|
||||
} else if (value->isa<ValueSequence>()) {
|
||||
cnt += CountValueNum(value->cast<ValueSequencePtr>());
|
||||
} else {
|
||||
cnt++;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue