forked from mindspore-Ecosystem/mindspore
Fix the bug of input format inconsistency
This commit is contained in:
parent
87c2bd0c45
commit
3cde5d57c7
|
@ -48,7 +48,8 @@ void KernelActor::Init() {
|
|||
const auto &input_device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel_, i, false);
|
||||
MS_EXCEPTION_IF_NULL(input_device_tensor);
|
||||
(void)real_input_data_infos_.emplace_back(
|
||||
std::make_pair(input_device_tensor->format(), input_device_tensor->host_shape()));
|
||||
std::make_shared<InputDataInfo>(input_device_tensor->format(), input_device_tensor->host_shape(),
|
||||
input_device_tensor->GetSize(), input_device_tensor->type_id()));
|
||||
}
|
||||
|
||||
// Init the device tensors and kernel launch info.
|
||||
|
@ -311,8 +312,9 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
|
|||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is of range.");
|
||||
}
|
||||
auto &real_input_info = real_input_data_infos_[input_data->index_];
|
||||
MS_EXCEPTION_IF_NULL(real_input_info);
|
||||
if ((input_data->data_->DeviceType() == device_contexts_[0]->GetDeviceAddressType()) &&
|
||||
(input_data->data_->format() == real_input_info.first)) {
|
||||
(input_data->data_->format() == real_input_info->format_)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -320,14 +322,15 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
|
|||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is of range.");
|
||||
}
|
||||
if (copy_input_device_tensors_[input_data->index_] == nullptr) {
|
||||
copy_input_device_tensors_[input_data->index_] =
|
||||
device_contexts_[0]->CreateDeviceAddress(nullptr, input_data->data_->GetSize(), real_input_info.first,
|
||||
input_data->data_->type_id(), real_input_info.second);
|
||||
copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->CreateDeviceAddress(
|
||||
nullptr, real_input_info->size_, real_input_info->format_, real_input_info->type_id_, real_input_info->shape_);
|
||||
}
|
||||
auto &new_device_tensor = copy_input_device_tensors_[input_data->index_];
|
||||
MS_EXCEPTION_IF_NULL(new_device_tensor);
|
||||
// Dynamic shape need update size.
|
||||
new_device_tensor->SetSize(input_data->data_->GetSize());
|
||||
if (AnfUtils::IsShapeDynamic(real_input_info->shape_)) {
|
||||
new_device_tensor->SetSize(input_data->data_->GetSize());
|
||||
}
|
||||
// Update the input device tensor.
|
||||
input_device_tensors_[input_data->index_] = new_device_tensor.get();
|
||||
|
||||
|
|
|
@ -39,6 +39,15 @@ using mindspore::kernel::Address;
|
|||
using mindspore::kernel::KernelLaunchInfo;
|
||||
using mindspore::tensor::TensorPtr;
|
||||
|
||||
struct InputDataInfo {
|
||||
InputDataInfo(const std::string &format, const ShapeVector &shape, size_t size, TypeId type_id)
|
||||
: format_(format), shape_(shape), size_(size), type_id_(type_id) {}
|
||||
std::string format_;
|
||||
ShapeVector shape_;
|
||||
size_t size_;
|
||||
TypeId type_id_;
|
||||
};
|
||||
|
||||
// The kernel actor is used to receive the device tensors and control info to launch kernel.
|
||||
// The processing flow is RunOpData/RunOpControl -> CheckRunningCondition -> SendMemoryAllocReq
|
||||
// -> OnMemoryAllocFinish -> LaunchKernel -> SendMemoryFreeReq -> SendOutput.
|
||||
|
@ -129,8 +138,8 @@ class KernelActor : public DebugAwareActor {
|
|||
// The received input device type and format may be different from the formal parameter in the control flow scenarios,
|
||||
// so it needs to be copied from the input data to real data that kernel launch needs.
|
||||
std::vector<DeviceTensorPtr> copy_input_device_tensors_;
|
||||
// Real data info <format, host_shape> that kernel launch needs, used to check the consistency of received input data.
|
||||
std::vector<std::pair<std::string, ShapeVector>> real_input_data_infos_;
|
||||
// Real data info that kernel launch needs, used to check the consistency of received input data.
|
||||
std::vector<std::shared_ptr<InputDataInfo>> real_input_data_infos_;
|
||||
|
||||
// The device tensors for memory alloc and free.
|
||||
// output + workspace
|
||||
|
|
|
@ -142,6 +142,10 @@ bool AnfUtils::IsShapeDynamic(const std::vector<size_t> &shape) {
|
|||
return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
|
||||
}
|
||||
|
||||
// Returns true if any dimension of `shape` is negative; a negative dimension
// marks a dynamic (unknown until runtime) axis.
bool AnfUtils::IsShapeDynamic(const std::vector<int64_t> &shape) {
  return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
}
|
||||
|
||||
bool AnfUtils::IsNodeOutputDynamicShape(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto base_shape = node->Shape();
|
||||
|
|
|
@ -65,6 +65,7 @@ class MS_CORE_API AnfUtils {
|
|||
static bool IsDimUnknown(const abstract::ShapePtr &shape);
|
||||
static bool IsShapeDynamic(const abstract::ShapePtr &shape);
|
||||
static bool IsShapeDynamic(const std::vector<size_t> &shape);
|
||||
static bool IsShapeDynamic(const std::vector<int64_t> &shape);
|
||||
static bool IsNodeOutputDynamicShape(const CNodePtr &node);
|
||||
static bool IsDimUnknown(const AnfNodePtr &node);
|
||||
// check whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
|
||||
|
|
Loading…
Reference in New Issue