forked from mindspore-Ecosystem/mindspore
!28211 Bugfix for PyNative Heterogeneous in MindRT
Merge pull request !28211 from caifubi/master-pynative-mindrt-heter
This commit is contained in:
commit
708f5559b3
|
@ -285,19 +285,50 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
|
|||
return new_value_node;
|
||||
}
|
||||
|
||||
std::string GetOpRunDeviceTarget(const PrimitivePtr &op_prim) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
const std::string &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(op_prim);
|
||||
const auto &attr_map = op_prim->attrs();
|
||||
auto iter = attr_map.find(kAttrPrimitiveTarget);
|
||||
if (iter != attr_map.end()) {
|
||||
return GetValue<std::string>(iter->second);
|
||||
}
|
||||
return device_target;
|
||||
}
|
||||
|
||||
// Need to discard input tensor properties in heterogeneous scenarios.
|
||||
// For example, the format of device_address in input_tensor is 5D format,
|
||||
// and it's invalid for CPU graph parameter.
|
||||
bool NeedDiscardTensorProperties(const std::string &op_device_target,
|
||||
const device::DeviceAddressPtr &tensor_device_address) {
|
||||
if (tensor_device_address == nullptr) {
|
||||
return true;
|
||||
}
|
||||
auto tensor_device_address_type = tensor_device_address->DeviceType();
|
||||
auto tensor_device_address_type_str = device::kDeviceTypeToName.at(tensor_device_address_type);
|
||||
if (op_device_target == tensor_device_address_type_str) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
|
||||
int64_t tensor_mask) {
|
||||
const OpRunInfo &op_run_info, int64_t tensor_mask) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto param = graph->NewParameter();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
if (tensor_mask == kParameterWeightTensorMask) {
|
||||
param->set_default_param(input_tensor);
|
||||
}
|
||||
|
||||
// set the kernel info of parameter
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(input_tensor);
|
||||
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
|
||||
if (device_address == nullptr) {
|
||||
if (NeedDiscardTensorProperties(op_run_info.device_target, device_address)) {
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
|
||||
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
|
||||
|
@ -1264,7 +1295,8 @@ OpRunInfo SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInf
|
|||
.next_input_index = 0,
|
||||
.graph_info = graph_info,
|
||||
.tensor_mask = tensor_info.input_tensors_mask,
|
||||
.input_tensors = tensor_info.input_tensors};
|
||||
.input_tensors = tensor_info.input_tensors,
|
||||
.device_target = GetOpRunDeviceTarget(primitive)};
|
||||
return op_run_info;
|
||||
}
|
||||
|
||||
|
@ -2231,7 +2263,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
|
|||
inputs.push_back(value_node);
|
||||
continue;
|
||||
}
|
||||
auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
|
||||
auto parameter = ConstructRunOpParameter(graph, input_tensors[i], op_run_info, tensors_mask[i]);
|
||||
inputs.push_back(parameter);
|
||||
auto mutable_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(mutable_inputs);
|
||||
|
|
|
@ -73,6 +73,7 @@ struct OpRunInfo {
|
|||
std::string graph_info;
|
||||
std::vector<int64_t> tensor_mask;
|
||||
std::vector<tensor::TensorPtr> input_tensors;
|
||||
std::string device_target = "Unknown";
|
||||
};
|
||||
|
||||
struct InputTensorInfo {
|
||||
|
|
|
@ -2177,7 +2177,8 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
|
|||
static_cast<int>(op_exec_info->next_input_index),
|
||||
graph_info,
|
||||
tensors_mask,
|
||||
input_tensors};
|
||||
input_tensors,
|
||||
cur_target};
|
||||
#else
|
||||
session::OpRunInfo op_run_info = {false,
|
||||
op_exec_info->op_name,
|
||||
|
@ -2190,7 +2191,8 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
|
|||
op_exec_info->next_input_index,
|
||||
graph_info,
|
||||
tensors_mask,
|
||||
input_tensors};
|
||||
input_tensors,
|
||||
cur_target};
|
||||
#endif
|
||||
|
||||
VectorRef outputs;
|
||||
|
|
Loading…
Reference in New Issue