!28211 Bugfix for PyNative heterogeneous execution in MindRT

Merge pull request !28211 from caifubi/master-pynative-mindrt-heter
This commit is contained in:
i-robot 2021-12-27 03:55:44 +00:00 committed by Gitee
commit 708f5559b3
3 changed files with 41 additions and 6 deletions

View File

@ -285,19 +285,50 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
return new_value_node;
}
// Resolves the device target a single op should run on.
// If the primitive carries a kAttrPrimitiveTarget attribute, its string value is returned;
// otherwise the MS_CTX_DEVICE_TARGET value configured in the global MsContext is used.
// Raises (via MS_EXCEPTION_IF_NULL) when the context instance or the primitive is null.
std::string GetOpRunDeviceTarget(const PrimitivePtr &op_prim) {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  const std::string &context_target = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &prim_attrs = op_prim->attrs();
  auto target_attr = prim_attrs.find(kAttrPrimitiveTarget);
  // A per-primitive target attribute takes precedence over the context-wide setting.
  return target_attr == prim_attrs.end() ? context_target : GetValue<std::string>(target_attr->second);
}
// Decides whether the properties of an input tensor's device address must be dropped when
// building a run-op parameter in heterogeneous scenarios.
// For example, the format of device_address in input_tensor is 5D format,
// and it's invalid for CPU graph parameter.
// Returns true when the tensor has no device address, or when the name mapped from the
// address's device type differs from op_device_target; returns false only when they match.
bool NeedDiscardTensorProperties(const std::string &op_device_target,
                                 const device::DeviceAddressPtr &tensor_device_address) {
  if (tensor_device_address == nullptr) {
    return true;
  }
  // Translate the address's device type into its name so it can be compared with the op target.
  const auto &address_target_name = device::kDeviceTypeToName.at(tensor_device_address->DeviceType());
  return op_device_target != address_target_name;
}
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
int64_t tensor_mask) {
const OpRunInfo &op_run_info, int64_t tensor_mask) {
MS_EXCEPTION_IF_NULL(graph);
auto param = graph->NewParameter();
MS_EXCEPTION_IF_NULL(param);
if (tensor_mask == kParameterWeightTensorMask) {
param->set_default_param(input_tensor);
}
// set the kernel info of parameter
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(input_tensor);
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
if (device_address == nullptr) {
if (NeedDiscardTensorProperties(op_run_info.device_target, device_address)) {
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
@ -1264,7 +1295,8 @@ OpRunInfo SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInf
.next_input_index = 0,
.graph_info = graph_info,
.tensor_mask = tensor_info.input_tensors_mask,
.input_tensors = tensor_info.input_tensors};
.input_tensors = tensor_info.input_tensors,
.device_target = GetOpRunDeviceTarget(primitive)};
return op_run_info;
}
@ -2231,7 +2263,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
inputs.push_back(value_node);
continue;
}
auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
auto parameter = ConstructRunOpParameter(graph, input_tensors[i], op_run_info, tensors_mask[i]);
inputs.push_back(parameter);
auto mutable_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(mutable_inputs);

View File

@ -73,6 +73,7 @@ struct OpRunInfo {
std::string graph_info;
std::vector<int64_t> tensor_mask;
std::vector<tensor::TensorPtr> input_tensors;
std::string device_target = "Unknown";
};
struct InputTensorInfo {

View File

@ -2177,7 +2177,8 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
static_cast<int>(op_exec_info->next_input_index),
graph_info,
tensors_mask,
input_tensors};
input_tensors,
cur_target};
#else
session::OpRunInfo op_run_info = {false,
op_exec_info->op_name,
@ -2190,7 +2191,8 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
op_exec_info->next_input_index,
graph_info,
tensors_mask,
input_tensors};
input_tensors,
cur_target};
#endif
VectorRef outputs;