unified runtime fix the kernel not support dynamic shape

This commit is contained in:
limingqi107 2021-12-31 14:10:51 +08:00
parent 8961541339
commit 99d6cc1ac0
2 changed files with 8 additions and 5 deletions

View File

@ -97,7 +97,7 @@ void KernelActor::Run(OpContext<DeviceTensor> *const context) {
}
FetchInputDeviceTensor(context);
FetchOutputDeviceTensor();
FetchOutputDeviceTensor(context);
if (memory_alloc_list_.size() > 0) {
SendMemoryAllocReq(context);
} else {
@ -119,7 +119,7 @@ void KernelActor::RunOpControlWithInputTensor(AID *const input_control, OpContex
device_contexts_[0]->UpdateDynamicShape(kernel_);
}
FetchOutputDeviceTensor();
FetchOutputDeviceTensor(context);
if (memory_alloc_list_.size() > 0) {
SendMemoryAllocReq(context);
}
@ -323,15 +323,18 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context)
}
}
void KernelActor::FetchOutputDeviceTensor() {
void KernelActor::FetchOutputDeviceTensor(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(kernel_info_);
auto &output_addresses = kernel_info_->output_address_list();
const auto &kernel_mod = kernel_info_->kernel_mod();
MS_EXCEPTION_IF_NULL(kernel_mod);
const auto &output_size_list = kernel_mod->GetOutputSizeList();
// May exist in the kernel which does not support the dynamic shape.
if (output_addresses.size() != output_size_list.size()) {
MS_LOG(EXCEPTION) << "The outputs number is not equal.";
std::string error_info = "The outputs number(" + std::to_string(output_size_list.size()) + ") is wrong, " +
GetAID().Name() + " may not support the dynamic shape, please check.";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
}
for (size_t i = 0; i < output_addresses.size(); ++i) {

View File

@ -85,7 +85,7 @@ class KernelActor : public DebugAwareActor {
// Fetch the device tensor for launch.
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
void FetchOutputDeviceTensor();
void FetchOutputDeviceTensor(OpContext<DeviceTensor> *const context);
void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
// In step mode, push the input tensors which contain valid device address into input_device_tensors_ directly.
void PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors);