!49478 Dynamic shape support valuenode of empty tuple.

Merge pull request !49478 from gaoyong10/dynamic_shape_02
This commit is contained in:
i-robot 2023-02-28 06:25:40 +00:00 committed by Gitee
commit 8d1506efce
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 66 additions and 18 deletions

View File

@ -64,21 +64,28 @@ size_t KernelTensor::GetSizeInBytes() const {
}
TypeId GetSeqElementsDtype(const abstract::AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(abs);
if (!abs->isa<abstract::AbstractSequence>()) {
return TypeId::kTypeUnknown;
}
TypePtr type_ptr;
auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(seq_abs);
if (seq_abs->dynamic_len()) {
if (seq_abs->dynamic_len_element_abs() == nullptr) {
return TypeId::kTypeUnknown;
}
type_ptr = seq_abs->dynamic_len_element_abs()->BuildType();
} else {
if (seq_abs->elements().empty()) {
if (seq_abs->elements().empty() || seq_abs->elements()[0] == nullptr) {
return TypeId::kTypeUnknown;
}
type_ptr = seq_abs->elements()[0]->BuildType();
}
MS_EXCEPTION_IF_NULL(type_ptr);
if (type_ptr->isa<TensorType>()) {
auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_ptr);
auto elem = tensor_ptr->element();
if (elem == nullptr) {
return TypeId::kTypeUnknown;
@ -122,14 +129,21 @@ TypeId KernelTensor::GetDtype() const {
}
ShapeVector GetSequenceFlattenShape(const abstract::AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(abs);
if (!abs->isa<abstract::AbstractSequence>()) {
return {};
}
auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(seq_abs);
if (seq_abs->dynamic_len()) {
return {-1};
}
if (seq_abs->elements().empty() || seq_abs->elements()[0] == nullptr) {
MS_LOG(INFO) << "Empty sequence abstract:" << seq_abs->ToString();
return {0};
}
auto type_ptr = seq_abs->elements()[0]->BuildType();
MS_EXCEPTION_IF_NULL(type_ptr);
if (!type_ptr->isa<TensorType>()) {
return {(int64_t)seq_abs->elements().size()};
}

View File

@ -430,9 +430,10 @@ inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t inde
return nullptr;
}
// When the input is an empty tuple, the input size will be 0.
if (addr_list[index]->size == 0) {
MS_LOG(WARNING) << "The size of device address is zero, address index: " << index
<< ", and the length of 'addr_list' is " << addr_list.size();
MS_LOG(INFO) << "The size of device address is zero, address index: " << index
<< ", and the length of 'addr_list' is " << addr_list.size();
return nullptr;
}
return reinterpret_cast<T *>(addr_list[index]->addr);

View File

@ -114,7 +114,7 @@ void GetInputFormat(const CNodePtr &kernel_node, std::vector<std::string> *input
}
bool InputDtypeMatch(TypeId InputAttr, TypeId input_type, bool strict) {
if (InputAttr == input_type) {
if (InputAttr == input_type || kTypeUnknown == input_type) {
return true;
}
if (!strict && InputAttr == kNumberTypeInt32 && (input_type == kNumberTypeInt16 || input_type == kNumberTypeInt64)) {
@ -135,6 +135,9 @@ bool OutputDtypeMatched(const kernel::KernelAttr &kernel_attr, const std::vector
}
auto output_num = output_types.size();
for (size_t i = 0; i < output_num; ++i) {
if (output_types[i] == kTypeUnknown) {
continue;
}
if (kernel_attr.GetOutputAttr(i).dtype != output_types[i]) {
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetOutputAttr(i).dtype
<< ", actual output dtype:" << output_types[i];

View File

@ -67,15 +67,17 @@ bool SequenceAddCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of 'input_0 + input_1': {"
<< input_0_size + input_1_size << "} is not equal to the size of output: {" << output_size << "}";
}
auto cp_ret = memcpy_s(output_addr, input_0_size, input_0_addr, input_0_size);
if (cp_ret != EOK) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", memcpy error, errorno: " << cp_ret;
if (input_0_size != 0) {
auto cp_ret = memcpy_s(output_addr, input_0_size, input_0_addr, input_0_size);
if (cp_ret != EOK) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", memcpy error, errorno: " << cp_ret;
}
}
cp_ret = memcpy_s(output_addr + input_0_size / sizeof(T), input_1_size, input_1_addr, input_1_size);
if (cp_ret != EOK) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", memcpy error, errorno: " << cp_ret;
if (input_1_size != 0) {
auto cp_ret = memcpy_s(output_addr + input_0_size / sizeof(T), input_1_size, input_1_addr, input_1_size);
if (cp_ret != EOK) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", memcpy error, errorno: " << cp_ret;
}
}
return true;
}

View File

@ -271,6 +271,15 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
if (node_value->isa<Scalar>()) {
return {};
}
if (node_value->isa<ValueTuple>()) {
const auto &value_tuple = node_value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
const auto &values = value_tuple->value();
if (values.size() == 0) {
MS_LOG(INFO) << "Empty tuple for node:" << node->DebugString();
return {0};
}
}
auto tensor = node_value->cast<tensor::TensorPtr>();
if (tensor == nullptr) {
MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";

View File

@ -169,6 +169,10 @@ void ValueTupleToValue(const ValuePtr &value, std::vector<ValuePtr> *const value
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
if (value_tuple->size() == 0) {
(void)values->emplace_back(std::make_shared<tensor::Tensor>());
return;
}
for (size_t i = 0; i < value_tuple->size(); ++i) {
ValuePtr element = value_tuple->value()[i];
MS_EXCEPTION_IF_NULL(element);
@ -715,6 +719,11 @@ void DataPrepareActor::PrepareDataForControlValueNode(const KernelWithIndex &nod
device_tensor->GetSize());
}
if (tensor->data_ptr() == nullptr && device_tensor->GetSize() == 0) {
MS_LOG(INFO) << "Empty tuple sync";
return;
}
auto host_tensor_size = LongToSize(tensor->data().nbytes());
auto host_tensor_type = tensor->data_type();
auto shape = tensor->shape();

View File

@ -296,7 +296,10 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cons
}
continue;
}
if (host_tensor->data_ptr() == nullptr && device_tensor->GetSize() == 0) {
MS_LOG(INFO) << "Empty tuple sync";
continue;
}
// Sync data from host_tensor to device_tensor.
if (!device_tensor->SyncHostToDevice(
trans::GetRuntimePaddingShape(data_node_with_indexs_[i].first, data_node_with_indexs_[i].second),

View File

@ -705,6 +705,9 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
if (type_ptr->isa<Tuple>()) {
auto tuple_ptr = type_ptr->cast<TuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_ptr);
if (tuple_ptr->size() == 0) {
return kTypeUnknown;
}
if (tuple_ptr->dynamic_len()) {
MS_EXCEPTION_IF_NULL(tuple_ptr->dynamic_element_type());
if (tuple_ptr->dynamic_element_type()->isa<TensorType>()) {
@ -715,9 +718,6 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
}
return tuple_ptr->dynamic_element_type()->type_id();
}
if (tuple_ptr->size() == 0) {
return kTypeUnknown;
}
MS_EXCEPTION_IF_NULL(tuple_ptr);
if (output_idx >= tuple_ptr->size()) {
MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
@ -1961,7 +1961,11 @@ abstract::BaseShapePtr AnfAlgo::GetDynamicSequenceShape(const AnfNodePtr &node,
MS_LOG(EXCEPTION) << "Not dynamic abstract in node:" << node->DebugString() << " for dynamic sequence shape.";
}
const auto &element_abs = sequence_abs->dynamic_len_element_abs();
MS_EXCEPTION_IF_NULL(element_abs);
if (element_abs == nullptr) {
MS_LOG(INFO) << "No element abs for node:" << node->DebugString() << " index:" << output_idx;
ShapeVector empty_shape{0};
return std::make_shared<abstract::Shape>(empty_shape);
}
return element_abs->BuildShape();
}
} // namespace common

View File

@ -517,7 +517,7 @@ std::vector<ShapeVector> BaseShapeToShapeVector(const abstract::BaseShapePtr &ba
} else if (element_base_shape->isa<abstract::NoShape>()) {
return std::vector<ShapeVector>(tuple_shape->size(), {1});
}
} else if (base_shape->isa<abstract::NoShape>()) {
} else if (base_shape->isa<abstract::NoShape>() || base_shape->isa<abstract::DynamicSequenceShape>()) {
return {};
}
MS_LOG(WARNING) << "Invalid shape:" << base_shape->ToString();

View File

@ -527,6 +527,9 @@ py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr
if (dynamic_len || dynamic_len_element_abs != nullptr) {
if (dynamic_len_element_abs == nullptr) {
MS_LOG(INFO) << "Dynamic length sequence with no specified element abstract convert to empty tuple.";
for (size_t i = 0; i < value_size; i++) {
ref_tuple[i] = BaseRefToPyData(value_list[i]);
}
return ref_tuple;
}
if (dynamic_len_element_abs->isa<abstract::AbstractNone>()) {