forked from mindspore-Ecosystem/mindspore
!47859 fix wrong device target at pynative mode
Merge pull request !47859 from zhoufeng/fix-wrong-device-target-at-pynative-r2a
This commit is contained in:
commit
1ca418936b
|
@ -136,6 +136,12 @@ void GetSingleOpGraphInfo(const FrontendOpRunInfoPtr &op_run_info, const std::st
|
|||
}
|
||||
} // namespace
|
||||
|
||||
std::string ForwardExecutor::device_target() const {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
return ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
}
|
||||
|
||||
GradExecutorPtr ForwardExecutor::grad() const {
|
||||
auto grad_executor = grad_executor_.lock();
|
||||
MS_EXCEPTION_IF_NULL(grad_executor);
|
||||
|
@ -149,9 +155,6 @@ void ForwardExecutor::Init() {
|
|||
MS_LOG(DEBUG) << "Init ForwardExecutor";
|
||||
compile::SetMindRTEnable();
|
||||
python_adapter::set_python_env_flag(true);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
device_target_ = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
init_ = true;
|
||||
}
|
||||
|
||||
|
@ -263,7 +266,7 @@ compile::MindRTBackendPtr ForwardExecutor::GetMindRtBackend(const std::string &d
|
|||
ValuePtr ForwardExecutor::RunOpWithBackendPolicy(const FrontendOpRunInfoPtr &op_run_info) {
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
ValuePtr result;
|
||||
auto backend_policy = GetBackendPolicy(device_target_);
|
||||
auto backend_policy = GetBackendPolicy(device_target());
|
||||
if (backend_policy == kMsBackendVmOnly) {
|
||||
#ifndef ENABLE_TEST
|
||||
if (kVmOperators.find(op_run_info->base_op_run_info.op_name) != kVmOperators.end()) {
|
||||
|
@ -399,7 +402,7 @@ std::string ForwardExecutor::GetCurrentDeviceTarget(const PrimitivePtr &op_prim)
|
|||
if (iter != attr_map.end()) {
|
||||
return GetValue<std::string>(iter->second);
|
||||
}
|
||||
return device_target_;
|
||||
return device_target();
|
||||
}
|
||||
|
||||
void ForwardExecutor::Sync() {
|
||||
|
|
|
@ -75,7 +75,7 @@ class ForwardExecutor {
|
|||
inline void set_is_ms_function_compiling(bool is_ms_function_compiling) {
|
||||
is_ms_function_compiling_ = is_ms_function_compiling;
|
||||
}
|
||||
inline std::string device_target() { return device_target_; }
|
||||
std::string device_target() const;
|
||||
|
||||
private:
|
||||
GradExecutorPtr grad() const;
|
||||
|
@ -102,7 +102,6 @@ class ForwardExecutor {
|
|||
bool is_ms_function_compiling_{false};
|
||||
uint32_t device_id_{0};
|
||||
std::string last_target_{"Unknown"};
|
||||
std::string device_target_;
|
||||
std::stack<CellPtr> forward_cell_stack_;
|
||||
GradExecutorWeakPtr grad_executor_;
|
||||
CastOperationPtr cast_operation_;
|
||||
|
|
Loading…
Reference in New Issue