Fix the input format inconsistency bug when copying kernel actor input data

limingqi107 2022-05-11 11:19:10 +08:00
parent 87c2bd0c45
commit 3cde5d57c7
4 changed files with 25 additions and 8 deletions


@@ -48,7 +48,8 @@ void KernelActor::Init() {
     const auto &input_device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel_, i, false);
     MS_EXCEPTION_IF_NULL(input_device_tensor);
     (void)real_input_data_infos_.emplace_back(
-      std::make_pair(input_device_tensor->format(), input_device_tensor->host_shape()));
+      std::make_shared<InputDataInfo>(input_device_tensor->format(), input_device_tensor->host_shape(),
+                                      input_device_tensor->GetSize(), input_device_tensor->type_id()));
   }
   // Init the device tensors and kernel launch info.
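
For context on the hunk above: instead of caching a bare format/host_shape pair per input, Init() now caches a shared InputDataInfo that also carries the size and type id, so later copies can be sized and typed from the kernel's real requirements. The following sketch is a minimal, self-contained illustration of that caching pattern; DemoDeviceTensor and its fields are simplified stand-ins, not the real MindSpore DeviceTensor API.

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

using ShapeVector = std::vector<int64_t>;
enum class TypeId { kNumberTypeFloat32, kNumberTypeInt32 };

// Mirrors the InputDataInfo struct added by this commit (see the header hunk below).
struct InputDataInfo {
  InputDataInfo(std::string format, ShapeVector shape, size_t size, TypeId type_id)
      : format_(std::move(format)), shape_(std::move(shape)), size_(size), type_id_(type_id) {}
  std::string format_;
  ShapeVector shape_;
  size_t size_;
  TypeId type_id_;
};

// Simplified stand-in for the device tensor returned for each formal input at init time.
struct DemoDeviceTensor {
  std::string format;
  ShapeVector host_shape;
  size_t size;
  TypeId type_id;
};

int main() {
  const std::vector<DemoDeviceTensor> inputs = {
      {"NCHW", {1, 3, 224, 224}, 1 * 3 * 224 * 224 * sizeof(float), TypeId::kNumberTypeFloat32}};
  std::vector<std::shared_ptr<InputDataInfo>> real_input_data_infos;
  // Capture format/shape/size/type for every formal input once, at actor init time.
  for (const auto &t : inputs) {
    (void)real_input_data_infos.emplace_back(
        std::make_shared<InputDataInfo>(t.format, t.host_shape, t.size, t.type_id));
  }
  return 0;
}
```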
@@ -311,8 +312,9 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
     SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is out of range.");
   }
   auto &real_input_info = real_input_data_infos_[input_data->index_];
+  MS_EXCEPTION_IF_NULL(real_input_info);
   if ((input_data->data_->DeviceType() == device_contexts_[0]->GetDeviceAddressType()) &&
-      (input_data->data_->format() == real_input_info.first)) {
+      (input_data->data_->format() == real_input_info->format_)) {
     return;
   }
@@ -320,14 +322,15 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
     SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is out of range.");
   }
   if (copy_input_device_tensors_[input_data->index_] == nullptr) {
-    copy_input_device_tensors_[input_data->index_] =
-      device_contexts_[0]->CreateDeviceAddress(nullptr, input_data->data_->GetSize(), real_input_info.first,
-                                               input_data->data_->type_id(), real_input_info.second);
+    copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->CreateDeviceAddress(
+      nullptr, real_input_info->size_, real_input_info->format_, real_input_info->type_id_, real_input_info->shape_);
   }
   auto &new_device_tensor = copy_input_device_tensors_[input_data->index_];
   MS_EXCEPTION_IF_NULL(new_device_tensor);
   // Dynamic shape needs to update the size.
-  new_device_tensor->SetSize(input_data->data_->GetSize());
+  if (AnfUtils::IsShapeDynamic(real_input_info->shape_)) {
+    new_device_tensor->SetSize(input_data->data_->GetSize());
+  }
   // Update the input device tensor.
   input_device_tensors_[input_data->index_] = new_device_tensor.get();
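
Summarizing the two hunks above: the copy device address is now created from the cached real-input info (size, format, type id, shape) instead of from whatever data arrived, and its size is only overwritten from the incoming data when the real input shape is dynamic. The sketch below restates that decision logic outside the actor framework; DemoDeviceAddress, DemoInputInfo, and MakeCopyTarget are illustrative stand-ins, not MindSpore APIs.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

using ShapeVector = std::vector<int64_t>;

// Illustrative stand-in for a device-side buffer descriptor.
struct DemoDeviceAddress {
  std::string format;
  ShapeVector shape;
  size_t size;
  void SetSize(size_t s) { size = s; }
};

// Cached "real" input requirement, as captured at actor init time.
struct DemoInputInfo {
  std::string format_;
  ShapeVector shape_;
  size_t size_;
};

// Same rule as the AnfUtils::IsShapeDynamic overload added by this commit:
// any negative dimension marks the shape as dynamic.
bool IsShapeDynamic(const ShapeVector &shape) {
  return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
}

// Build the copy target from the cached info, then fix up the size only for dynamic shapes,
// where the actual byte size is only known from the data that really arrived.
std::shared_ptr<DemoDeviceAddress> MakeCopyTarget(const DemoInputInfo &real, size_t received_size) {
  auto copy = std::make_shared<DemoDeviceAddress>(DemoDeviceAddress{real.format_, real.shape_, real.size_});
  if (IsShapeDynamic(real.shape_)) {
    copy->SetSize(received_size);
  }
  return copy;
}

int main() {
  const DemoInputInfo static_info{"NCHW", {1, 3, 224, 224}, 602112};
  const DemoInputInfo dynamic_info{"NCHW", {-1, 3, 224, 224}, 0};
  auto a = MakeCopyTarget(static_info, 1204224);   // keeps the cached size 602112
  auto b = MakeCopyTarget(dynamic_info, 1204224);  // takes the received size 1204224
  return (a->size == 602112 && b->size == 1204224) ? 0 : 1;
}
```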


@@ -39,6 +39,15 @@ using mindspore::kernel::Address;
 using mindspore::kernel::KernelLaunchInfo;
 using mindspore::tensor::TensorPtr;
+struct InputDataInfo {
+  InputDataInfo(const std::string &format, const ShapeVector &shape, size_t size, TypeId type_id)
+      : format_(format), shape_(shape), size_(size), type_id_(type_id) {}
+  std::string format_;
+  ShapeVector shape_;
+  size_t size_;
+  TypeId type_id_;
+};
 // The kernel actor is used to receive the device tensors and control info to launch kernel.
 // The processing flow is RunOpData/RunOpControl -> CheckRunningCondition -> SendMemoryAllocReq
 // -> OnMemoryAllocFinish -> LaunchKernel -> SendMemoryFreeReq -> SendOutput.
@@ -129,8 +138,8 @@ class KernelActor : public DebugAwareActor {
   // The received input device type and format may be different from the formal parameter in the control flow scenarios,
   // so it needs to be copied from the input data to the real data that the kernel launch needs.
   std::vector<DeviceTensorPtr> copy_input_device_tensors_;
-  // Real data info <format, host_shape> that kernel launch needs, used to check the consistency of received input data.
-  std::vector<std::pair<std::string, ShapeVector>> real_input_data_infos_;
+  // Real data info that kernel launch needs, used to check the consistency of received input data.
+  std::vector<std::shared_ptr<InputDataInfo>> real_input_data_infos_;
   // The device tensors for memory alloc and free.
   // output + workspace
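
The comment on real_input_data_infos_ above says the cached infos are used to check the consistency of received input data. Below is a hedged sketch of that check, modeled on the condition in CopyInputDeviceTensor shown earlier: a copy is only needed when the received tensor's device type or format differs from what the kernel actually expects. DeviceType, NeedCopyInput, and the field names here are simplified assumptions for illustration, not the project's real types.

```cpp
#include <memory>
#include <string>

// Simplified device-type enum; the real project uses its own device address types.
enum class DeviceType { kCPU, kGPU, kAscend };

// Cached requirement for one formal input (format only, for brevity).
struct DemoRealInputInfo {
  std::string format_;
};

// The data that actually arrived from an upstream actor.
struct DemoReceivedData {
  DeviceType device_type;
  std::string format;
};

// True when the received input cannot be fed to the kernel as-is and must be copied
// into a freshly created device address that matches the cached real input info.
bool NeedCopyInput(const DemoReceivedData &received, DeviceType local_device_type,
                   const std::shared_ptr<DemoRealInputInfo> &real_input_info) {
  return received.device_type != local_device_type || received.format != real_input_info->format_;
}

int main() {
  const auto info = std::make_shared<DemoRealInputInfo>(DemoRealInputInfo{"NC1HWC0"});
  const DemoReceivedData same{DeviceType::kAscend, "NC1HWC0"};
  const DemoReceivedData other_format{DeviceType::kAscend, "NCHW"};
  return (!NeedCopyInput(same, DeviceType::kAscend, info) &&
          NeedCopyInput(other_format, DeviceType::kAscend, info))
             ? 0
             : 1;
}
```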


@@ -142,6 +142,10 @@ bool AnfUtils::IsShapeDynamic(const std::vector<size_t> &shape) {
   return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
 }
+bool AnfUtils::IsShapeDynamic(const std::vector<int64_t> &shape) {
+  return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
+}
 bool AnfUtils::IsNodeOutputDynamicShape(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto base_shape = node->Shape();
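
The new overload follows the same rule as the existing size_t version directly above it: a shape counts as dynamic if any dimension is negative (conventionally -1 for an unknown dimension). A minimal, self-contained sketch of that semantics, not the MindSpore source itself:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;

// A shape is dynamic when at least one dimension is still unknown, encoded as a negative value.
bool IsShapeDynamic(const ShapeVector &shape) {
  return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
}

int main() {
  std::cout << IsShapeDynamic({1, 3, 224, 224}) << "\n";   // 0: fully static shape
  std::cout << IsShapeDynamic({-1, 3, 224, 224}) << "\n";  // 1: unknown batch dimension
  return 0;
}
```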


@@ -65,6 +65,7 @@ class MS_CORE_API AnfUtils {
   static bool IsDimUnknown(const abstract::ShapePtr &shape);
   static bool IsShapeDynamic(const abstract::ShapePtr &shape);
   static bool IsShapeDynamic(const std::vector<size_t> &shape);
+  static bool IsShapeDynamic(const std::vector<int64_t> &shape);
   static bool IsNodeOutputDynamicShape(const CNodePtr &node);
   static bool IsDimUnknown(const AnfNodePtr &node);
   // Check whether the anf node is a real kernel that can run on device; parameter and constant are real kernels too.