forked from mindspore-Ecosystem/mindspore
!21490 Ascend control use vm
Merge pull request !21490 from chenfei_mindspore/ascend-control-use-vm
This commit is contained in:
commit
fca1cb34c8
|
@ -523,30 +523,14 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
|
||||||
InitRuntimeResource();
|
InitRuntimeResource();
|
||||||
// multiple graph handle
|
// multiple graph handle
|
||||||
if (graph_id == final_graph_id_) {
|
if (graph_id == final_graph_id_) {
|
||||||
if (!graph->executable()) {
|
MS_LOG(EXCEPTION) << "Unexpected graph id:" << graph_id << ", final_graph_id_:" << final_graph_id_;
|
||||||
return;
|
|
||||||
}
|
|
||||||
SetFinalGraphSummaryFlag(graph);
|
|
||||||
// OptChildGraphs
|
|
||||||
auto graph_order = GetGraphOrder(final_graph_id_);
|
|
||||||
auto &graph_type = GetGraphOrderType(final_graph_id_);
|
|
||||||
for (size_t i = 0; i < graph_order.size(); i++) {
|
|
||||||
if (!(graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START)) {
|
|
||||||
auto child_graph = GetGraph(graph_order[i]);
|
|
||||||
CompileChildGraph(child_graph);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SetSummaryNodes(graph.get());
|
|
||||||
// merge child graph
|
|
||||||
MergeGraphExecOrder();
|
|
||||||
} else {
|
|
||||||
auto single_graph = GetGraph(graph_id);
|
|
||||||
MS_EXCEPTION_IF_NULL(single_graph);
|
|
||||||
CompileChildGraph(single_graph);
|
|
||||||
// set the distinction label of single graph
|
|
||||||
single_graph->set_stream_distinction_label(graph_id);
|
|
||||||
single_graph->UpdateExecuteKernelStreamLabel();
|
|
||||||
}
|
}
|
||||||
|
auto single_graph = GetGraph(graph_id);
|
||||||
|
MS_EXCEPTION_IF_NULL(single_graph);
|
||||||
|
CompileChildGraph(single_graph);
|
||||||
|
// set the distinction label of single graph
|
||||||
|
single_graph->set_stream_distinction_label(graph_id);
|
||||||
|
single_graph->UpdateExecuteKernelStreamLabel();
|
||||||
// adjust execution order because merge child graph and other special operations
|
// adjust execution order because merge child graph and other special operations
|
||||||
AdjustKernel(graph);
|
AdjustKernel(graph);
|
||||||
#if ENABLE_CPU && ENABLE_D
|
#if ENABLE_CPU && ENABLE_D
|
||||||
|
|
|
@ -614,9 +614,18 @@ bool TaskEmitAction(const ResourcePtr &res) {
|
||||||
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
|
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
|
||||||
} else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
} else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||||
std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||||
if (device_target == kAscendDevice && backend != kMsVm) {
|
auto manager = func_graph->manager();
|
||||||
|
auto graphs = manager->func_graphs();
|
||||||
|
bool exist_while =
|
||||||
|
std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
|
||||||
|
if (device_target == kAscendDevice && backend != kMsVm && !exist_while) {
|
||||||
|
MS_LOG(INFO) << "Run graph mode with multigraph sink.";
|
||||||
bc_ptr->set_is_multi_graph_sink(true);
|
bc_ptr->set_is_multi_graph_sink(true);
|
||||||
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
|
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
|
||||||
|
} else {
|
||||||
|
MS_LOG(INFO) << "Run graph mode with vm.";
|
||||||
|
bc_ptr->set_is_multi_graph_sink(false);
|
||||||
|
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -142,20 +142,21 @@ std::string GetCompileExceptionInfo() {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetGpuLoopSink(const ResourcePtr &resource) {
|
void SetLoopCount(const ResourcePtr &resource) {
|
||||||
MS_EXCEPTION_IF_NULL(resource);
|
MS_EXCEPTION_IF_NULL(resource);
|
||||||
auto func_graph = resource->func_graph();
|
auto func_graph = resource->func_graph();
|
||||||
if (func_graph != nullptr && func_graph->manager() != nullptr) {
|
if (func_graph != nullptr && func_graph->manager() != nullptr) {
|
||||||
auto manager = func_graph->manager();
|
auto manager = func_graph->manager();
|
||||||
size_t graph_nums = manager->func_graphs().size();
|
size_t graph_nums = manager->func_graphs().size();
|
||||||
int64_t sinksize = ConfigManager::GetInstance().iter_num();
|
int64_t loop_size = ConfigManager::GetInstance().iter_num();
|
||||||
if (graph_nums == 1 || MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
|
const auto context_ptr = MsContext::GetInstance();
|
||||||
resource->set_gpu_loopsink(true, sinksize);
|
if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
|
||||||
} else {
|
resource->set_vm_loop(!context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK), loop_size);
|
||||||
resource->set_gpu_loopsink(false, sinksize);
|
} else if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice) {
|
||||||
|
bool run_with_mind_rt = graph_nums == 1 || context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT);
|
||||||
|
resource->set_vm_loop(!run_with_mind_rt, loop_size);
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource->gpu_loopsink_flag() << ", set loopsink size to "
|
MS_LOG(INFO) << "Change vm_loop_flag to " << resource->vm_loop_flag() << ", set loop_size to " << loop_size;
|
||||||
<< sinksize;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -827,7 +828,7 @@ void Pipeline::Run(const std::string &phase_s) {
|
||||||
MS_LOG(DEBUG) << "Action " << action.first << " end.";
|
MS_LOG(DEBUG) << "Action " << action.first << " end.";
|
||||||
};
|
};
|
||||||
if (action.first == "task_emit") {
|
if (action.first == "task_emit") {
|
||||||
SetGpuLoopSink(resource_);
|
SetLoopCount(resource_);
|
||||||
} else if (action.first == "validate") {
|
} else if (action.first == "validate") {
|
||||||
CacheValidateFuncGraph(phase_s, resource_);
|
CacheValidateFuncGraph(phase_s, resource_);
|
||||||
}
|
}
|
||||||
|
@ -1003,13 +1004,17 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
|
||||||
MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s;
|
MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s;
|
||||||
}
|
}
|
||||||
// Set loopsink size for each phase.
|
// Set loopsink size for each phase.
|
||||||
bool is_loopsink = info_[phase_s]->resource->gpu_loopsink_flag();
|
bool vm_loop_flag = info_[phase_s]->resource->vm_loop_flag();
|
||||||
int64_t sinksize = info_[phase_s]->resource->gpu_loopsink_size();
|
int64_t loop_size = info_[phase_s]->resource->loop_size();
|
||||||
ConfigManager::GetInstance().set_gpu_loopsink_size(is_loopsink ? sinksize : 1);
|
int64_t vm_loop = 1;
|
||||||
// If target is not gpu or is loopsink, keep vmloop 1.
|
if (vm_loop_flag) {
|
||||||
bool g = (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
vm_loop = loop_size;
|
||||||
int64_t vm_loop = (!g || is_loopsink) ? 1 : sinksize;
|
} else {
|
||||||
MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << (is_loopsink ? sinksize : 1);
|
// Set the loop size in config if graphs nums is 1(is_loop_sin=True), then there will be a loop embrace
|
||||||
|
// 'Execute(graph)' in GPUSession.
|
||||||
|
ConfigManager::GetInstance().set_gpu_loopsink_size(loop_size);
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << vm_loop;
|
||||||
py::object ret;
|
py::object ret;
|
||||||
MS_LOG(DEBUG) << "Eval run" << backend;
|
MS_LOG(DEBUG) << "Eval run" << backend;
|
||||||
for (int64_t i = 0; i < vm_loop; i++) {
|
for (int64_t i = 0; i < vm_loop; i++) {
|
||||||
|
@ -1159,9 +1164,6 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
|
||||||
// Convert CNodeList to LinConvertResult.
|
// Convert CNodeList to LinConvertResult.
|
||||||
auto segment = std::make_shared<GraphSegment>(std::vector<AnfNodePtr>{app_init}, false);
|
auto segment = std::make_shared<GraphSegment>(std::vector<AnfNodePtr>{app_init}, false);
|
||||||
auto runner = convert_fn(segment, "");
|
auto runner = convert_fn(segment, "");
|
||||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
|
||||||
backend->Link(runner.graph_id);
|
|
||||||
}
|
|
||||||
ConfigManager::GetInstance().set_iter_num(size);
|
ConfigManager::GetInstance().set_iter_num(size);
|
||||||
// PS cache does not support loop sink.
|
// PS cache does not support loop sink.
|
||||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||||
|
|
|
@ -75,14 +75,14 @@ class Resource : public ResourceBase {
|
||||||
const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; }
|
const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; }
|
||||||
void set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; }
|
void set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; }
|
||||||
|
|
||||||
void set_gpu_loopsink(const bool &flag, const int64_t size) {
|
void set_vm_loop(const bool &flag, const int64_t size) {
|
||||||
gpu_loopsink_flag_ = flag;
|
vm_loop_flag_ = flag;
|
||||||
gpu_loopsink_size_ = size;
|
loop_size_ = size;
|
||||||
}
|
}
|
||||||
void set_is_load(bool flag) { is_load_ = flag; }
|
void set_is_load(bool flag) { is_load_ = flag; }
|
||||||
bool is_load() { return is_load_; }
|
bool is_load() { return is_load_; }
|
||||||
bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
|
bool vm_loop_flag() { return vm_loop_flag_; }
|
||||||
int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
|
int64_t loop_size() { return loop_size_; }
|
||||||
// Reclaim resource and clear the cache.
|
// Reclaim resource and clear the cache.
|
||||||
// ExecutorPy::Compile() can be called multiple times, so cache
|
// ExecutorPy::Compile() can be called multiple times, so cache
|
||||||
// should be cleared.
|
// should be cleared.
|
||||||
|
@ -94,10 +94,10 @@ class Resource : public ResourceBase {
|
||||||
abstract::AbstractBasePtrList args_spec_;
|
abstract::AbstractBasePtrList args_spec_;
|
||||||
py::object input_;
|
py::object input_;
|
||||||
bool is_cleaned_;
|
bool is_cleaned_;
|
||||||
bool gpu_loopsink_flag_{false};
|
|
||||||
// The func_graph_ is loaded from mindir
|
// The func_graph_ is loaded from mindir
|
||||||
bool is_load_{false};
|
bool is_load_{false};
|
||||||
int64_t gpu_loopsink_size_{1};
|
bool vm_loop_flag_{false};
|
||||||
|
int64_t loop_size_{1};
|
||||||
};
|
};
|
||||||
|
|
||||||
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
|
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
|
||||||
|
|
|
@ -289,14 +289,6 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
|
||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MsBackend::Link(GraphId graph_id) {
|
|
||||||
MS_EXCEPTION_IF_NULL(target_sess_);
|
|
||||||
if (graph_id == kInvalidGraphId) {
|
|
||||||
graph_id = target_sess_->GetFinalRunGraph();
|
|
||||||
}
|
|
||||||
target_sess_->BuildGraph(graph_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
|
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
|
||||||
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
|
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
|
||||||
target_sess_ = session::SessionFactory::Get().Create(target);
|
target_sess_ = session::SessionFactory::Get().Create(target);
|
||||||
|
|
|
@ -61,7 +61,6 @@ class Backend {
|
||||||
virtual bool GetCond(const BaseRef &c, bool *value);
|
virtual bool GetCond(const BaseRef &c, bool *value);
|
||||||
virtual bool GetIndex(const BaseRef &c, int64_t *value);
|
virtual bool GetIndex(const BaseRef &c, int64_t *value);
|
||||||
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
|
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
|
||||||
virtual void Link(GraphId) {}
|
|
||||||
virtual void SetDebugger() {}
|
virtual void SetDebugger() {}
|
||||||
|
|
||||||
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
|
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
|
||||||
|
@ -82,7 +81,6 @@ class MsBackend : public Backend {
|
||||||
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
|
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
|
||||||
|
|
||||||
VectorRef MsSimuRunGraph(const GraphId &g);
|
VectorRef MsSimuRunGraph(const GraphId &g);
|
||||||
void Link(GraphId) override;
|
|
||||||
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
|
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
|
||||||
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
|
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
|
||||||
void ClearSessionGraphs();
|
void ClearSessionGraphs();
|
||||||
|
|
|
@ -580,9 +580,6 @@ BackendPtr CreateBackend() {
|
||||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||||
backend->set_is_multi_graph_sink(false);
|
backend->set_is_multi_graph_sink(false);
|
||||||
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
|
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
|
||||||
} else {
|
|
||||||
backend->set_is_multi_graph_sink(true);
|
|
||||||
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return backend;
|
return backend;
|
||||||
|
|
|
@ -758,13 +758,9 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
|
||||||
for (auto &item : func_graph->parameter_default_value()) {
|
for (auto &item : func_graph->parameter_default_value()) {
|
||||||
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
|
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
|
||||||
}
|
}
|
||||||
|
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
|
||||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
|
new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||||
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
|
|
||||||
new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||||
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,16 +52,20 @@ def test_single_for_01():
|
||||||
|
|
||||||
# graph mode
|
# graph mode
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
for_net_foward = SingleForNet()
|
||||||
|
graph_forward_res = for_net_foward(x, y, z)
|
||||||
|
|
||||||
for_net = SingleForNet()
|
for_net = SingleForNet()
|
||||||
net = GradNet(for_net)
|
net = GradNet(for_net)
|
||||||
graph_forward_res = for_net(x, y, z)
|
|
||||||
graph_backward_res = net(x, y, z)
|
graph_backward_res = net(x, y, z)
|
||||||
|
|
||||||
# pynative mode
|
# pynative mode
|
||||||
context.set_context(mode=context.PYNATIVE_MODE)
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
for_net_foward = SingleForNet()
|
||||||
|
pynative_forward_res = for_net_foward(x, y, z)
|
||||||
|
|
||||||
for_net = SingleForNet()
|
for_net = SingleForNet()
|
||||||
net = GradNet(for_net)
|
net = GradNet(for_net)
|
||||||
pynative_forward_res = for_net(x, y, z)
|
|
||||||
pynative_backward_res = net(x, y, z)
|
pynative_backward_res = net(x, y, z)
|
||||||
|
|
||||||
assert graph_forward_res == pynative_forward_res
|
assert graph_forward_res == pynative_forward_res
|
||||||
|
|
|
@ -23,6 +23,7 @@ from mindspore.common import dtype as mstype
|
||||||
grad_all = C.GradOperation(get_all=True)
|
grad_all = C.GradOperation(get_all=True)
|
||||||
context.set_context(device_target="Ascend")
|
context.set_context(device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
def test_for_in_if_01():
|
def test_for_in_if_01():
|
||||||
class ForInIfNet(nn.Cell):
|
class ForInIfNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -69,6 +70,7 @@ def test_for_in_if_01():
|
||||||
assert graph_forward_res == pynative_forward_res
|
assert graph_forward_res == pynative_forward_res
|
||||||
assert graph_backward_res == pynative_backward_res
|
assert graph_backward_res == pynative_backward_res
|
||||||
|
|
||||||
|
|
||||||
def test_for_in_if_02():
|
def test_for_in_if_02():
|
||||||
class ForInIfNet(nn.Cell):
|
class ForInIfNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -100,7 +102,7 @@ def test_for_in_if_02():
|
||||||
def construct(self, *inputs):
|
def construct(self, *inputs):
|
||||||
return grad_all(self.net)(*inputs)
|
return grad_all(self.net)(*inputs)
|
||||||
|
|
||||||
x = Tensor([10], mstype.int32)
|
x = Tensor([10], mstype.float32)
|
||||||
|
|
||||||
# graph mode
|
# graph mode
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
@ -152,7 +154,7 @@ def test_for_in_if_03():
|
||||||
def construct(self, *inputs):
|
def construct(self, *inputs):
|
||||||
return grad_all(self.net)(*inputs)
|
return grad_all(self.net)(*inputs)
|
||||||
|
|
||||||
x = Tensor([10], mstype.int32)
|
x = Tensor([10], mstype.float32)
|
||||||
|
|
||||||
# graph mode
|
# graph mode
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import Tensor, nn
|
from mindspore import Tensor, nn
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
|
@ -23,6 +24,7 @@ from mindspore.common import dtype as mstype
|
||||||
grad_all = C.GradOperation(get_all=True)
|
grad_all = C.GradOperation(get_all=True)
|
||||||
context.set_context(device_target="Ascend")
|
context.set_context(device_target="Ascend")
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="not supported for in while")
|
||||||
def test_for_in_while_01():
|
def test_for_in_while_01():
|
||||||
class ForInWhileNet(nn.Cell):
|
class ForInWhileNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -74,7 +76,7 @@ def test_for_in_while_01():
|
||||||
assert graph_forward_res == pynative_forward_res
|
assert graph_forward_res == pynative_forward_res
|
||||||
assert graph_backward_res == pynative_backward_res
|
assert graph_backward_res == pynative_backward_res
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="not supported for in while")
|
||||||
def test_for_in_while_02():
|
def test_for_in_while_02():
|
||||||
class ForInWhileNet(nn.Cell):
|
class ForInWhileNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -105,16 +105,20 @@ class GradNet(nn.Cell):
|
||||||
def control_flow_if_after_if(input_net, x, y):
|
def control_flow_if_after_if(input_net, x, y):
|
||||||
# graph mode
|
# graph mode
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
forward_net = input_net()
|
||||||
|
|
||||||
net = input_net()
|
net = input_net()
|
||||||
grad_net = GradNet(net)
|
grad_net = GradNet(net)
|
||||||
graph_forward_res = net(x, y)
|
graph_forward_res = forward_net(x, y)
|
||||||
graph_backward_res = grad_net(x, y)
|
graph_backward_res = grad_net(x, y)
|
||||||
|
|
||||||
# pynative mode
|
# pynative mode
|
||||||
context.set_context(mode=context.PYNATIVE_MODE)
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
forward_net = input_net()
|
||||||
|
|
||||||
net = input_net()
|
net = input_net()
|
||||||
grad_net = GradNet(net)
|
grad_net = GradNet(net)
|
||||||
pynative_forward_res = net(x, y)
|
pynative_forward_res = forward_net(x, y)
|
||||||
pynative_backward_res = grad_net(x, y)
|
pynative_backward_res = grad_net(x, y)
|
||||||
|
|
||||||
assert graph_forward_res == pynative_forward_res
|
assert graph_forward_res == pynative_forward_res
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import Tensor, nn
|
from mindspore import Tensor, nn
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
|
@ -21,6 +22,7 @@ from mindspore.common.parameter import Parameter
|
||||||
grad_all = C.GradOperation(get_all=True)
|
grad_all = C.GradOperation(get_all=True)
|
||||||
context.set_context(device_target="Ascend")
|
context.set_context(device_target="Ascend")
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="not supported for in while")
|
||||||
def test_if_after_for_in_while():
|
def test_if_after_for_in_while():
|
||||||
class IfAfterForInWhileNet(nn.Cell):
|
class IfAfterForInWhileNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
|
@ -54,7 +55,7 @@ class BackwardNet(nn.Cell):
|
||||||
grads = self.grad(self.forward_net)(*inputs)
|
grads = self.grad(self.forward_net)(*inputs)
|
||||||
return grads
|
return grads
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="not supported for in while")
|
||||||
def test_forward():
|
def test_forward():
|
||||||
x = Tensor(np.array(1), mstype.int32)
|
x = Tensor(np.array(1), mstype.int32)
|
||||||
y = Tensor(np.array(3), mstype.int32)
|
y = Tensor(np.array(3), mstype.int32)
|
||||||
|
@ -62,7 +63,7 @@ def test_forward():
|
||||||
out = forward_net(x, y)
|
out = forward_net(x, y)
|
||||||
print("forward out:", out)
|
print("forward out:", out)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="not supported for in while")
|
||||||
def test_backward():
|
def test_backward():
|
||||||
x = Tensor(np.array(1), mstype.int32)
|
x = Tensor(np.array(1), mstype.int32)
|
||||||
y = Tensor(np.array(3), mstype.int32)
|
y = Tensor(np.array(3), mstype.int32)
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import Tensor, nn
|
from mindspore import Tensor, nn
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
|
@ -22,7 +23,7 @@ from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
grad_all = C.GradOperation(get_all=True)
|
grad_all = C.GradOperation(get_all=True)
|
||||||
context.set_context(device_target="Ascend")
|
context.set_context(device_target="Ascend")
|
||||||
|
@pytest.mark.skip(reason="not supported for in while")
|
||||||
def test_for_after_for_in_while_01():
|
def test_for_after_for_in_while_01():
|
||||||
class ForAfterForInWhileNet(nn.Cell):
|
class ForAfterForInWhileNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -87,7 +88,7 @@ def test_for_after_for_in_while_01():
|
||||||
assert graph_forward_res == pynative_forward_res
|
assert graph_forward_res == pynative_forward_res
|
||||||
assert graph_backward_res == pynative_backward_res
|
assert graph_backward_res == pynative_backward_res
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="not supported for in while")
|
||||||
def test_for_after_for_in_while_02():
|
def test_for_after_for_in_while_02():
|
||||||
class ForAfterForInWhileNet(nn.Cell):
|
class ForAfterForInWhileNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Loading…
Reference in New Issue