unified runtime: fix kernels that do not support dynamic shape
commit 99d6cc1ac0 (parent 8961541339)
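The core of the change: FetchOutputDeviceTensor now takes the OpContext, so that when the number of preallocated output addresses does not match the output sizes reported by the kernel mod (a typical symptom of a kernel without dynamic-shape support), the actor reports the failure through SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY instead of throwing via MS_LOG(EXCEPTION). Below is a minimal standalone sketch of that validation idea, not the MindSpore implementation; OutputCheckResult and ValidateOutputs are illustrative names only. The diff that follows shows the equivalent check wired into KernelActor::FetchOutputDeviceTensor.

// Minimal sketch (assumed, simplified types): validate output buffers against
// the sizes a kernel reports and return a readable error instead of aborting.
#include <cstddef>
#include <string>
#include <vector>

struct OutputCheckResult {
  bool ok;
  std::string error;  // Human-readable reason when ok == false.
};

// A mismatch between preallocated output addresses and the kernel's reported
// output sizes usually means the kernel does not support dynamic shape.
OutputCheckResult ValidateOutputs(const std::vector<void *> &output_addresses,
                                  const std::vector<size_t> &output_size_list,
                                  const std::string &actor_name) {
  if (output_addresses.size() != output_size_list.size()) {
    return {false, "The outputs number(" + std::to_string(output_size_list.size()) +
                       ") is wrong, " + actor_name +
                       " may not support the dynamic shape, please check."};
  }
  return {true, ""};
}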
@@ -97,7 +97,7 @@ void KernelActor::Run(OpContext<DeviceTensor> *const context) {
   }

   FetchInputDeviceTensor(context);
-  FetchOutputDeviceTensor();
+  FetchOutputDeviceTensor(context);
   if (memory_alloc_list_.size() > 0) {
     SendMemoryAllocReq(context);
   } else {
@@ -119,7 +119,7 @@ void KernelActor::RunOpControlWithInputTensor(AID *const input_control, OpContex
     device_contexts_[0]->UpdateDynamicShape(kernel_);
   }

-  FetchOutputDeviceTensor();
+  FetchOutputDeviceTensor(context);
   if (memory_alloc_list_.size() > 0) {
     SendMemoryAllocReq(context);
   }
@@ -323,15 +323,18 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context)
   }
 }

-void KernelActor::FetchOutputDeviceTensor() {
+void KernelActor::FetchOutputDeviceTensor(OpContext<DeviceTensor> *const context) {
   MS_EXCEPTION_IF_NULL(kernel_info_);
   auto &output_addresses = kernel_info_->output_address_list();
   const auto &kernel_mod = kernel_info_->kernel_mod();
   MS_EXCEPTION_IF_NULL(kernel_mod);
   const auto &output_size_list = kernel_mod->GetOutputSizeList();

+  // May exist in the kernel which does not support the dynamic shape.
   if (output_addresses.size() != output_size_list.size()) {
-    MS_LOG(EXCEPTION) << "The outputs number is not equal.";
+    std::string error_info = "The outputs number(" + std::to_string(output_size_list.size()) + ") is wrong, " +
+                             GetAID().Name() + " may not support the dynamic shape, please check.";
+    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
   }

   for (size_t i = 0; i < output_addresses.size(); ++i) {
@@ -85,7 +85,7 @@ class KernelActor : public DebugAwareActor {

   // Fetch the device tensor for launch.
   void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
-  void FetchOutputDeviceTensor();
+  void FetchOutputDeviceTensor(OpContext<DeviceTensor> *const context);
   void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
   // In step mode, push the input tensors which contain valid device address into input_device_tensors_ directly.
   void PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors);