!22320 support heterogeneous for pynative mode

Merge pull request !22320 from chujinjin/support_heterogeneous_for_pynative
This commit is contained in:
i-robot 2021-12-01 08:34:44 +00:00 committed by Gitee
commit 8e496e44f0
12 changed files with 292 additions and 59 deletions

View File

@ -754,7 +754,7 @@ void AscendSession::PrepareForOutputTensor(const KernelGraphPtr &graph,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
VectorRef *outputs) const {
// Create DeviceAddress For Output Tensor(contain: Shape, Format, DType)
auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
runtime_instance->RunOpMallocPre(*graph, input_tensors);
runtime_instance->UpdateRefNodeOutputMem(*graph);
// CREATE OUTPUT TENSOR ADDRESS
@ -810,6 +810,7 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
}
MS_EXCEPTION_IF_NULL(input_tensors);
ProcessInputTensorsForHeterogeneous("Ascend", *input_tensors);
bool cache_miss = run_op_graphs_.find(graph_info) == run_op_graphs_.end();
auto graph = CreateKernelGraph(graph_info, op_run_info, input_tensors, tensors_mask, cache_miss);
EraseValueNodeTensor(tensors_mask, input_tensors);
@ -847,8 +848,8 @@ void AscendSession::RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_r
const std::vector<int64_t> &tensors_mask) {
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(op_run_info);
ProcessInputTensorsForHeterogeneous("Ascend", *input_tensors);
const auto &graph = BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
@ -857,6 +858,7 @@ void AscendSession::RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_r
tensor->WaitDevice();
}
}
// malloc mem
RunOpRemoveNopNode(graph);
RunOpMemoryAlloc(*input_tensors, graph.get());
@ -1758,7 +1760,7 @@ void AscendSession::SyncStream() const {
std::shared_ptr<device::Bucket> AscendSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) {
auto bucket = std::make_shared<device::ascend::AscendBucket>(bucket_id, bucket_size);
auto kernel_runtime = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
auto kernel_runtime = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(kernel_runtime);
auto compute_stream = kernel_runtime->compute_stream();
auto communication_stream = kernel_runtime->communication_stream();

View File

@ -284,8 +284,10 @@ void CPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
const std::vector<int64_t> &tensors_mask) {
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(op_run_info);
ProcessInputTensorsForHeterogeneous("CPU", *input_tensors);
const auto &kernel_graph = BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
// Remove reorder after PS feature finish adapting push/pull in auto_monad.
auto execution_order = kernel_graph->execution_order();
Reorder(&execution_order);

View File

@ -321,6 +321,30 @@ size_t UpdateGraphInputAbstract(const AnfNodePtr input_node, const tensor::Tenso
}
return size;
}
bool CheckIfNeedSync(const tensor::TensorPtr &tensor, const DeviceAddressPtr &device_address,
const ParameterPtr &pk_node) {
MS_EXCEPTION_IF_NULL(tensor);
MS_EXCEPTION_IF_NULL(pk_node);
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
bool need_sync = false;
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
if (tensor_address == nullptr || tensor_address != device_address) {
need_sync = true;
}
} else if (tensor->NeedSyncHostToDevice() || tensor_address == nullptr) {
need_sync = true;
} else if (tensor_address != device_address) {
if (tensor_address->DeviceType() == device_address->DeviceType()) {
AnfAlgo::SetOutputAddr(tensor_address, 0, pk_node.get());
} else {
need_sync = true;
}
}
return need_sync;
}
} // namespace
void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
@ -348,21 +372,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
auto pk_node = input_node->cast<ParameterPtr>();
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
MS_EXCEPTION_IF_NULL(device_address);
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
bool need_sync = false;
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
if (tensor_address == nullptr || tensor_address != device_address) {
need_sync = true;
}
} else if (tensor->NeedSyncHostToDevice() || tensor_address == nullptr) {
need_sync = true;
} else if (tensor_address != device_address) {
if (tensor_address->DeviceType() == device_address->DeviceType()) {
AnfAlgo::SetOutputAddr(tensor_address, 0, pk_node.get());
} else {
need_sync = true;
}
}
bool need_sync = CheckIfNeedSync(tensor, device_address, pk_node);
if (need_sync) {
if (AnfAlgo::IsParameterWeight(pk_node) || UpdatedByAssign(kernel_graph, input_node) ||
ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
@ -681,6 +691,7 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
const std::vector<int64_t> &tensors_mask) {
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(op_run_info);
ProcessInputTensorsForHeterogeneous("GPU", *input_tensors);
const auto &kernel_graph = BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
@ -690,6 +701,7 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
tensor->WaitDevice();
}
}
// run op
MS_EXCEPTION_IF_NULL(kernel_graph);
RunOpRemoveNopNode(kernel_graph);

View File

@ -20,6 +20,7 @@
#include <queue>
#include <utility>
#include <functional>
#include <unordered_map>
#include "utils/hash_map.h"
#include "ops/primitive_c.h"
@ -2289,6 +2290,33 @@ void SessionBasic::RunGraphImpl(const GraphId &graph_id, const std::vector<tenso
MS_LOG(INFO) << "Status record: end run graph. graph id: " << graph_id;
}
device::DeviceAddressType DeviceTargetToDeviceType(const std::string &device_target) {
static const std::unordered_map<std::string, device::DeviceAddressType> target_type = {
{"Unknown", device::DeviceAddressType::kUnknown},
{"Ascend", device::DeviceAddressType::kAscend},
{"CPU", device::DeviceAddressType::kCPU},
{"GPU", device::DeviceAddressType::kGPU},
{"Davinci", device::DeviceAddressType::kAscend}};
auto iter = target_type.find(device_target);
if (iter == target_type.end()) {
MS_LOG(EXCEPTION) << "Not support device target: " << device_target;
}
return iter->second;
}
void SessionBasic::ProcessInputTensorsForHeterogeneous(const std::string &cur_target,
const std::vector<tensor::TensorPtr> &input_tensors) {
for (auto &tensor : input_tensors) {
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (device_address != nullptr) {
if (device_address->DeviceType() != DeviceTargetToDeviceType(cur_target)) {
tensor->data_sync();
tensor->set_device_address(nullptr);
}
}
}
}
void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {
MS_LOG(INFO) << "Clean task in Queue";

View File

@ -244,6 +244,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
const std::vector<int64_t> &tensors_mask) {}
void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
void ProcessInputTensorsForHeterogeneous(const std::string &cur_target,
const std::vector<tensor::TensorPtr> &input_tensors);
virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, size_t> &cnode_refcount) {}

View File

@ -83,8 +83,8 @@ const std::set<std::string> kVmOperators = {"make_ref", "HookBackward", "InsertG
"mixed_precision_cast"};
const char kOpsFunctionModelName[] = "mindspore.ops.functional";
const char kGrad[] = "grad";
std::shared_ptr<session::SessionBasic> kSession = nullptr;
std::shared_ptr<compile::MindRTBackend> mind_rt_backend = nullptr;
std::map<std::string, std::shared_ptr<session::SessionBasic>> kSessionBackends;
std::map<std::string, std::shared_ptr<compile::MindRTBackend>> kMindRtBackends;
PyObjectIdCache g_pyobj_id_cache;
template <typename T, typename... Args>
@ -224,6 +224,41 @@ TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_in
return max_type;
}
std::string GetCurrentDeviceTarget(const std::string &device_target, const PrimitivePyPtr &op_prim) {
MS_EXCEPTION_IF_NULL(op_prim);
const auto &attr_map = op_prim->attrs();
auto iter = attr_map.find("primitive_target");
if (iter != attr_map.end()) {
return GetValue<std::string>(iter->second);
}
return device_target;
}
session::SessionPtr GetCurrentSession(const std::string &device_target, uint32_t device_id) {
auto iter = kSessionBackends.find(device_target);
if (iter == kSessionBackends.end()) {
auto session = session::SessionFactory::Get().Create(device_target);
MS_EXCEPTION_IF_NULL(session);
session->Init(device_id);
kSessionBackends[device_target] = session;
return session;
} else {
return iter->second;
}
}
compile::MindRTBackendPtr GetMindRtBackend(const std::string &device_target, uint32_t device_id) {
auto iter = kMindRtBackends.find(device_target);
if (iter == kMindRtBackends.end()) {
auto backend = std::make_shared<compile::MindRTBackend>("ms", device_target, device_id);
MS_EXCEPTION_IF_NULL(backend);
kMindRtBackends[device_target] = backend;
return backend;
} else {
return iter->second;
}
}
void GetDstType(const py::tuple &py_args,
const mindspore::HashMap<SignatureEnumDType, std::vector<size_t>> &type_indexes,
mindspore::HashMap<SignatureEnumDType, TypeId> *dst_type) {
@ -748,12 +783,15 @@ void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<ten
<< new_tensor->GetShapeAndDataTypeInfo();
pre_tensor->set_shape(new_tensor->shape());
pre_tensor->set_data_type(new_tensor->data_type());
if (device_target != kCPUDevice) {
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
MS_EXCEPTION_IF_NULL(device_address);
if (device_target != kCPUDevice && device_address->DeviceType() != device::DeviceAddressType::kCPU) {
pre_tensor->set_device_address(new_tensor->device_address());
continue;
}
if (mind_rt_backend != nullptr) {
mind_rt_backend->SyncLazyTasks();
for (auto &item : kMindRtBackends) {
MS_EXCEPTION_IF_NULL(item.second);
item.second->SyncLazyTasks();
}
// Replace data in device address when run in CPU device.
if (pre_tensor->device_address() != nullptr) {
@ -773,6 +811,11 @@ void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<ten
auto ret_code = std::memcpy(old_ptr, new_ptr, old_device_address->GetSize());
MS_EXCEPTION_IF_CHECK_FAIL(ret_code == old_ptr, "Memory copy failed");
}
} else {
pre_tensor->set_device_address(device_address);
pre_tensor->data_sync();
pre_tensor->set_device_address(nullptr);
pre_tensor->set_sync_status(kNeedSyncHostToDevice);
}
}
}
@ -1998,21 +2041,28 @@ py::object ForwardExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, Pynativ
return std::move(tuple_result);
}
void ForwardExecutor::CheckIfNeedSyncForHeterogeneous(const std::string &cur_target) {
if (last_target_ != "Unknown" && last_target_ != cur_target) {
auto executor = PynativeExecutor::GetInstance();
executor->Sync();
}
last_target_ = cur_target;
}
py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
MS_EXCEPTION_IF_NULL(op_exec_info);
MS_EXCEPTION_IF_NULL(status);
compile::SetMindRTEnable();
MS_LOG(DEBUG) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
compile::SetMindRTEnable();
const std::string &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto enable_mind_rt = ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
if (kSession == nullptr && !ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
kSession = session::SessionFactory::Get().Create(device_target);
MS_EXCEPTION_IF_NULL(kSession);
kSession->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
}
std::string cur_target = GetCurrentDeviceTarget(device_target, op_exec_info->py_primitive);
CheckIfNeedSyncForHeterogeneous(cur_target);
std::vector<tensor::TensorPtr> input_tensors;
std::vector<int64_t> tensors_mask;
@ -2021,7 +2071,6 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
ConvertAttrToUnifyMindIR(op_exec_info);
// get graph info for checking it whether existing in the cache
GetSingleOpGraphInfo(op_exec_info, input_tensors, tensors_mask, &graph_info);
VectorRef outputs;
#if defined(__APPLE__)
session::OpRunInfo op_run_info = {op_exec_info->op_name,
op_exec_info->py_primitive.get(),
@ -2048,17 +2097,16 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
input_tensors};
#endif
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
kSession->RunOp(&op_run_info, &outputs);
VectorRef outputs;
if (!enable_mind_rt || cur_target == "Ascend") {
auto cur_session = GetCurrentSession(cur_target, device_id);
MS_EXCEPTION_IF_NULL(cur_session);
cur_session->RunOp(&op_run_info, &outputs);
} else {
if (mind_rt_backend == nullptr) {
const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
mind_rt_backend = std::make_shared<compile::MindRTBackend>("ms", device_target, device_id);
}
auto cur_mind_rt_backend = GetMindRtBackend(cur_target, device_id);
MS_EXCEPTION_IF_NULL(cur_mind_rt_backend);
mindspore::ScopedLongRunning long_running;
mind_rt_backend->RunOp(&op_run_info, &outputs);
cur_mind_rt_backend->RunOp(&op_run_info, &outputs);
}
if (op_exec_info->is_dynamic_shape) {
@ -3240,8 +3288,9 @@ void PynativeExecutor::ClearGrad(const py::object &cell, const py::args &args) {
void PynativeExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear all res";
session::PynativeTaskManager::GetInstance().Reset();
if (mind_rt_backend != nullptr) {
mind_rt_backend->ClearOpBuilderResource();
for (auto &item : kMindRtBackends) {
MS_EXCEPTION_IF_NULL(item.second);
item.second->ClearOpBuilderResource();
}
SetLazyBuild(false);
cell_depth_ = 0;
@ -3260,8 +3309,8 @@ void PynativeExecutor::ClearRes() {
}
ad::CleanRes();
pipeline::ReclaimOptimizer();
kSession = nullptr;
mind_rt_backend = nullptr;
kSessionBackends.clear();
kMindRtBackends.clear();
g_pyobj_id_cache.clear();
}
@ -3303,17 +3352,19 @@ void PynativeExecutor::Sync() {
ExecuteAllTask();
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
if (kSession == nullptr) {
MS_EXCEPTION(NotExistsError) << "No session has been created!";
for (auto &item : kSessionBackends) {
MS_EXCEPTION_IF_NULL(item.second);
item.second->SyncStream();
}
kSession->SyncStream();
} else {
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
MS_EXCEPTION_IF_NULL(device_context);
(void)device_context->SyncStream();
for (auto &item : kMindRtBackends) {
MS_EXCEPTION_IF_NULL(item.second);
item.second->SyncStream();
}
for (auto &item : kSessionBackends) {
MS_EXCEPTION_IF_NULL(item.second);
item.second->SyncStream();
}
}
}
@ -3337,8 +3388,9 @@ bool PynativeExecutor::IsTopCell() const { return cell_depth_ == 0; }
void PynativeExecutor::ExecuteAllTask() {
session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks();
if (mind_rt_backend != nullptr) {
mind_rt_backend->SyncLazyTasks();
for (auto &item : kMindRtBackends) {
MS_EXCEPTION_IF_NULL(item.second);
item.second->SyncLazyTasks();
}
}

View File

@ -358,6 +358,7 @@ class ForwardExecutor {
py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
void DoSignatureCast(const PrimitivePyPtr &prim, const mindspore::HashMap<SignatureEnumDType, TypeId> &dst_type,
const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);
void CheckIfNeedSyncForHeterogeneous(const std::string &cur_target);
private:
GradExecutorWeakPtr grad_executor_;
@ -365,6 +366,7 @@ class ForwardExecutor {
ImplicitCastCache implicit_cast_map_;
mindspore::HashMap<std::string, abstract::AbstractBasePtr> node_abs_map_;
bool lazy_build_{false};
std::string last_target_{"Unknown"};
};
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {

View File

@ -288,9 +288,14 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<Tens
auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
if (host_tensor_address != nullptr) {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, input_node.get());
host_tensor_address->SetNodeIndex(input_node, 0);
continue;
if (host_tensor_address->DeviceType() != device_context->GetDeviceAddressType()) {
input_tensor->data_sync();
input_tensor->set_device_address(nullptr);
} else {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, input_node.get());
host_tensor_address->SetNodeIndex(input_node, 0);
continue;
}
}
if (!AnfAlgo::OutputAddrExist(input_node, 0, false)) {
@ -496,6 +501,12 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
} else {
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
<< ", device tensor type:" << device_tensor->DeviceType();
if (strategy_ == GraphExecutionStrategy::kStep) {
tensor->data_sync();
host_tensor_address = device_tensor;
tensor->set_device_address(host_tensor_address);
is_need_sync = true;
}
}
}
// Maybe the same host_tensor_address corresponds to the different front_node in shared weight scene,

View File

@ -616,7 +616,9 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
}
bool AscendDeviceContext::BindDeviceToCurrentThread() const {
runtime_instance_->SetContext();
if (initialized_) {
runtime_instance_->SetContext();
}
return true;
}

View File

@ -971,6 +971,13 @@ void MindRTBackend::SyncLazyTasks() const { runtime::OpLazyBuilder::GetInstance(
void MindRTBackend::ClearOpBuilderResource() const { runtime::OpLazyBuilder::GetInstance().Reset(); }
void MindRTBackend::SyncStream() {
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
MS_EXCEPTION_IF_NULL(device_context);
(void)device_context->SyncStream();
}
std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {
MS_EXCEPTION_IF_NULL(root_graph);
MS_EXCEPTION_IF_NULL(graph_compiler_);

View File

@ -121,6 +121,10 @@ class MindRTBackend : public Backend {
void SyncLazyTasks() const;
// Clear resource when python exit.
void ClearOpBuilderResource() const;
// Get the device target.
std::string GetDeviceTarget() { return device_name_; }
// Sync default stream in PyNative mode.
void SyncStream();
private:
// The parameter func_graph is a graph, it can be either a root graph or a sub graph,
@ -194,6 +198,7 @@ class MindRTBackend : public Backend {
int ms_execution_mode_{kGraphMode};
int real_execution_mode_{kGraphMode};
};
using MindRTBackendPtr = std::shared_ptr<compile::MindRTBackend>;
} // namespace compile
} // namespace mindspore
#endif

View File

@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_pynative_heterogeneous """
import numpy as np
import pytest
from mindspore import context, Tensor
from mindspore.nn import Cell
import mindspore.ops as ops
class MulRelu(Cell):
def __init__(self):
super(MulRelu, self).__init__()
self.relu1 = ops.ReLU()
self.relu2 = ops.ReLU()
self.mul = ops.Mul()
def construct(self, inp1, inp2):
x1 = self.relu1(inp1)
x2 = self.relu2(inp2)
y = self.mul(x1, x2)
return y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_heterogeneous_default_ascend_prim_cpu():
"""
Feature: PyNative heterogeneous.
Description: Default device target is Ascend, the relu1 set to CPU.
Expectation: The output of device is equal to the output of heterogeneous.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
net = MulRelu()
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
output_device = net(inp1, inp2)
net.relu1.add_prim_attr("primitive_target", "CPU")
output_heter = net(inp1, inp2)
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_heterogeneous_default_cpu_prim_ascend():
"""
Feature: PyNative heterogeneous.
Description: Default device target is CPU, the relu1 set to Ascend.
Expectation: The output of device is equal to the output of heterogeneous.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
net = MulRelu()
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
output_device = net(inp1, inp2)
net.relu1.add_prim_attr("primitive_target", "Ascend")
output_heter = net(inp1, inp2)
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_heterogeneous_default_gpu_prim_cpu():
"""
Feature: PyNative heterogeneous.
Description: Default device target is GPU, the relu1 set to CPU.
Expectation: The output of device is equal to the output of heterogeneous.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
net = MulRelu()
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
output_device = net(inp1, inp2)
net.relu1.add_prim_attr("primitive_target", "CPU")
output_heter = net(inp1, inp2)
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_heterogeneous_default_cpu_prim_gpu():
"""
Feature: PyNative heterogeneous.
Description: Default device target is CPU, the relu1 set to GPU.
Expectation: The output of device is equal to the output of heterogeneous.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
net = MulRelu()
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
output_device = net(inp1, inp2)
net.relu1.add_prim_attr("primitive_target", "GPU")
output_heter = net(inp1, inp2)
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)