!47859 fix wrong device target at pynative mode

Merge pull request !47859 from zhoufeng/fix-wrong-device-target-at-pynative-r2a
This commit is contained in:
i-robot 2023-01-13 11:43:25 +00:00 committed by Gitee
commit 1ca418936b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 9 additions and 7 deletions

View File

@ -136,6 +136,12 @@ void GetSingleOpGraphInfo(const FrontendOpRunInfoPtr &op_run_info, const std::st
}
} // namespace
std::string ForwardExecutor::device_target() const {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
return ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
}
GradExecutorPtr ForwardExecutor::grad() const {
auto grad_executor = grad_executor_.lock();
MS_EXCEPTION_IF_NULL(grad_executor);
@ -149,9 +155,6 @@ void ForwardExecutor::Init() {
MS_LOG(DEBUG) << "Init ForwardExecutor";
compile::SetMindRTEnable();
python_adapter::set_python_env_flag(true);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
device_target_ = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
init_ = true;
}
@ -263,7 +266,7 @@ compile::MindRTBackendPtr ForwardExecutor::GetMindRtBackend(const std::string &d
ValuePtr ForwardExecutor::RunOpWithBackendPolicy(const FrontendOpRunInfoPtr &op_run_info) {
MS_EXCEPTION_IF_NULL(op_run_info);
ValuePtr result;
auto backend_policy = GetBackendPolicy(device_target_);
auto backend_policy = GetBackendPolicy(device_target());
if (backend_policy == kMsBackendVmOnly) {
#ifndef ENABLE_TEST
if (kVmOperators.find(op_run_info->base_op_run_info.op_name) != kVmOperators.end()) {
@ -399,7 +402,7 @@ std::string ForwardExecutor::GetCurrentDeviceTarget(const PrimitivePtr &op_prim)
if (iter != attr_map.end()) {
return GetValue<std::string>(iter->second);
}
return device_target_;
return device_target();
}
void ForwardExecutor::Sync() {

View File

@ -75,7 +75,7 @@ class ForwardExecutor {
inline void set_is_ms_function_compiling(bool is_ms_function_compiling) {
is_ms_function_compiling_ = is_ms_function_compiling;
}
inline std::string device_target() { return device_target_; }
std::string device_target() const;
private:
GradExecutorPtr grad() const;
@ -102,7 +102,6 @@ class ForwardExecutor {
bool is_ms_function_compiling_{false};
uint32_t device_id_{0};
std::string last_target_{"Unknown"};
std::string device_target_;
std::stack<CellPtr> forward_cell_stack_;
GradExecutorWeakPtr grad_executor_;
CastOperationPtr cast_operation_;