forked from mindspore-Ecosystem/mindspore

pynative-support-dynamic-shape

parent 000fb7e332
commit fd5be43598
@@ -390,6 +390,8 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra
  MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
  // malloc mem
  RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get());
  // Build dynamic kernel
  BuildDynamicKernel(graph);
  // load input data to device
  LoadInputData(graph, input_tensors);
  // run op

@@ -510,6 +512,17 @@ void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  if (!runtime_instance->GenDynamicKernel(kernel_graph.get())) {
    MS_LOG(DEBUG) << "Graph:" << kernel_graph->graph_id() << " failed to generate dynamic kernel!";
  }
  MS_LOG(INFO) << "Finish!";
}

void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);

@@ -90,6 +90,7 @@ class AscendSession : public SessionBasic {
  void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
  void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void MemoryAlloc(KernelGraph *kernel_graph) const;
  void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
                        KernelGraph *kernel_graph) const;

@@ -1315,6 +1315,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
                                                                   const std::vector<tensor::TensorPtr> &input_tensors,
                                                                   const std::vector<int> &tensors_mask) {
  auto graph = std::make_shared<KernelGraph>();
  graph->set_graph_id(run_op_graph_id_);
  run_op_graph_id_++;
  std::vector<AnfNodePtr> inputs;
  // set input[0]
  PrimitivePtr op_prim = op_run_info.primitive;

@@ -1343,9 +1345,12 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
  MS_EXCEPTION_IF_NULL(cnode);
  // set abstract, which includes the inferred shapes and types
  cnode->set_abstract(op_run_info.abstract);
  // get output dynamic shape info
  AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(op_run_info.is_dynamic_shape), cnode);
  // set execution order
  std::vector<CNodePtr> exe_order = {cnode};
  graph->set_execution_order(exe_order);
  graph->UpdateGraphDynamicAttr();
  // set output
  CreateOutputNode(cnode, graph);
  graph->SetInputNodes();

@@ -50,12 +50,13 @@ struct OpRunInfo {
  PrimitivePtr primitive;
  AbstractBasePtr abstract;
  ValuePtr value = nullptr;
  bool is_dynamic_shape = false;
};
using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class Executor;
class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
 public:
  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0), run_op_graph_id_(0) {
#if !defined(_WIN32) && !defined(_WIN64)
    debugger_ = nullptr;
#endif

@@ -182,6 +183,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
  CallBackFunc summary_callback_;
  static GraphId graph_sum_;
  uint32_t device_id_;
  uint32_t run_op_graph_id_;
  std::shared_ptr<Executor> executor_;
#if !defined(_WIN32) && !defined(_WIN64)
  std::shared_ptr<Debugger> debugger_;

@@ -53,8 +53,9 @@ struct OpExecInfo {
  std::string prim_id;
  PrimitivePyPtr py_primitive;
  AbstractBasePtr abstract;
  ValuePtr value = nullptr;
  bool is_dynamic_shape = false;

  ValuePtr value = nullptr;
  py::list op_inputs;
  py::dict op_attrs;
  std::vector<bool> inputs_mask;

@@ -758,6 +758,13 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
    cnode->set_abstract(op_exec_info->abstract);
  }

  // get output dynamic shape info
  MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
  auto abstract_info = op_exec_info->abstract->ToString();
  if (abstract_info.find("-1") != string::npos) {
    op_exec_info->is_dynamic_shape = true;
  }

  op_exec_info->inputs_mask = op_masks;
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (op_exec_info->abstract != nullptr) {

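A minimal Python sketch of the dynamic-shape check added above (the helper name and the sample strings are invented for illustration, not real MindSpore output): a dynamic dimension is rendered as -1 in the abstract's string form, so a plain substring search is enough to flag the op, at the cost of being a fairly coarse test.

def is_dynamic_shape(abstract_str):
    # Mirrors abstract_info.find("-1") != string::npos above: any "-1" in the
    # rendered abstract marks the op as dynamic shape.
    return "-1" in abstract_str

assert is_dynamic_shape("AbstractTensor(shape: (-1, 2), element: Int32)")
assert not is_dynamic_shape("AbstractTensor(shape: (3, 2), element: Int32)")
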
@@ -1301,7 +1308,7 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
  // get graph info for checking whether it already exists in the cache
  std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
  session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive, op_exec_info->abstract,
                                    op_exec_info->value};
                                    op_exec_info->value, op_exec_info->is_dynamic_shape};
  session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask);
  EraseValueNodeTensor(tensors_mask, &input_tensors);
  VectorRef outputs;

@@ -358,16 +358,13 @@ bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(INFO) << "Generate node:" << cnode->fullname_with_scope() << " dynamic kernel";
    auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto dynamic_kernel = kernel_mod->GenDynamicKernel(cnode, stream_);
    MS_EXCEPTION_IF_NULL(dynamic_kernel);
    dynamic_kernel->Initialize();
    dynamic_kernels.emplace_back(dynamic_kernel);
  }
  auto ret = graph_dynamic_kernel_map_.try_emplace(graph->graph_id(), dynamic_kernels);
  if (!ret.second) {
    MS_LOG(ERROR) << "Graph:" << graph->graph_id() << " has already generated dynamic kernels";
    return false;
  }
  graph_dynamic_kernel_map_[graph->graph_id()] = dynamic_kernels;
  MS_LOG(INFO) << "GenDynamicKernel end";
  return true;
}

@@ -28,7 +28,7 @@ void DynamicKernel::Initialize() {
  MS_LOG(INFO) << "Init Start";
  is_dynamic_shape_ = AnfAlgo::IsDynamicShape(cnode_ptr_);
  if (!is_dynamic_shape_) {
    MS_LOG(INFO) << "cnode is not dynamic shape:" << cnode_ptr_->fullname_with_scope();
    MS_LOG(DEBUG) << "cnode is not dynamic shape:" << cnode_ptr_->fullname_with_scope();
    return;
  }

@@ -37,7 +37,7 @@ void DynamicKernel::Initialize() {

  auto have_depends = AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode_ptr_);
  if (!have_depends) {
    MS_LOG(WARNING) << "No dynamic_shape_depends found";
    MS_LOG(DEBUG) << "No dynamic_shape_depends found";
    return;
  }
  MS_LOG(INFO) << "Have depends";

@@ -799,18 +799,38 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList

bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  for (const auto &kernel : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);

    AddressPtrList kernel_inputs;
    AddressPtrList kernel_workspaces;
    AddressPtrList kernel_outputs;
    GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
    auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
    if (!ret) {
      MS_LOG(ERROR) << "Launch kernel failed.";
      return false;
  std::vector<DynamicKernelPtr> dynamic_kernel_list;
  auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
  if (iter != graph_dynamic_kernel_map_.end()) {
    dynamic_kernel_list = iter->second;
  }
  if (!dynamic_kernel_list.empty() && dynamic_kernel_list.size() != kernels.size()) {
    MS_LOG(EXCEPTION) << "The size of dynamic kernels " << dynamic_kernel_list.size()
                      << " should be equal to the size of kernels " << kernels.size();
  }
  for (size_t i = 0; i < kernels.size(); ++i) {
    if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr &&
        dynamic_kernel_list[i]->is_dynamic_shape()) {
      dynamic_kernel_list[i]->InferShape();
      dynamic_kernel_list[i]->UpdateArgs();
      dynamic_kernel_list[i]->Execute();
      if (!SyncStream()) {
        MS_LOG(ERROR) << "SyncStream failed";
        return false;
      }
      dynamic_kernel_list[i]->PostExecute();
    } else {
      auto kernel_mod = AnfAlgo::GetKernelMod(kernels[i]);
      MS_EXCEPTION_IF_NULL(kernel_mod);
      AddressPtrList kernel_inputs;
      AddressPtrList kernel_workspaces;
      AddressPtrList kernel_outputs;
      GenLaunchArgs(*kernel_mod, kernels[i], &kernel_inputs, &kernel_workspaces, &kernel_outputs);
      auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
      if (!ret) {
        MS_LOG(ERROR) << "Launch kernel failed.";
        return false;
      }
    }
  }
  return true;

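A hedged Python sketch of the dispatch LaunchKernelMod now performs (the names below are invented for the sketch, not MindSpore APIs): a dynamic-shape kernel re-infers its output shapes and rebuilds its launch arguments before every launch, and the stream is synchronized before PostExecute so the real output shapes can be read back; every other kernel keeps the original static launch path.

def launch_kernels(static_launchers, dynamic_kernels):
    # static_launchers: one zero-argument callable per kernel, returning True/False.
    # dynamic_kernels: empty, or one entry (None or a dict of callables) per kernel,
    # mirroring the size check in the C++ code above.
    if dynamic_kernels and len(dynamic_kernels) != len(static_launchers):
        raise RuntimeError("dynamic kernel list must match kernel list")
    for i, launch_static in enumerate(static_launchers):
        dyn = dynamic_kernels[i] if dynamic_kernels else None
        if dyn is not None and dyn["is_dynamic_shape"]:
            dyn["infer_shape"]()   # recompute output shapes from the current inputs
            dyn["update_args"]()   # rebuild launch arguments for the new shapes
            dyn["execute"]()       # launch; the real code then syncs the stream
            dyn["post_execute"]()  # read back the actual output shapes
        elif not launch_static():  # original static path with precomputed args
            return False
    return True

# Example: one static kernel and one dynamic-shape kernel.
noop = lambda: None
assert launch_kernels([lambda: True, lambda: True],
                      [None, {"is_dynamic_shape": True, "infer_shape": noop,
                              "update_args": noop, "execute": noop, "post_execute": noop}])
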
@@ -0,0 +1,74 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import FTRL, LazyAdam
from mindspore.ops import operations as P

context.set_context(enable_sparse=True,
                    mode=context.PYNATIVE_MODE,
                    device_target="Ascend")


class NetWithSparseGatherV2(nn.Cell):
    def __init__(self):
        super(NetWithSparseGatherV2, self).__init__()
        self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
        self.weight2 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight2")
        self.axis = 1
        self.gather = P.SparseGatherV2()

    def construct(self, indices, label):
        return self.gather(self.weight1, indices, self.axis) + self.weight2


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_ftrl_net():
    indices = Tensor(np.array([0, 0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()

    optimizer = FTRL(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
    optimizer.target = 'Ascend'
    train_network = TrainOneStepCell(net, optimizer)
    output = train_network(indices, label)
    assert np.allclose(output.asnumpy(), np.array([[[2, 2]], [[2, 2]], [[2, 2]]]))
    assert np.allclose(net.weight1.asnumpy(), np.array([[[0.7884067, 0.7884067]],
                                                        [[0.68213105, 0.68213105]],
                                                        [[1.0, 1.0]]]))
    assert np.allclose(net.weight2.asnumpy(), np.array([[[0.6821311, 0.6821311]],
                                                        [[0.6821311, 0.6821311]],
                                                        [[0.6821311, 0.6821311]]]))


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_lazy_adam_net():
    indices = Tensor(np.array([0, 0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()

    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
    optimizer.target = 'Ascend'
    train_network = TrainOneStepCell(net, optimizer)
    output = train_network(indices, label)
    assert np.allclose(output.asnumpy(), np.array([[[2, 2]], [[2, 2]], [[2, 2]]]))
    assert np.allclose(net.weight1.asnumpy(), np.array([[[0.9, 0.9]], [[0.9, 0.9]], [[1.0, 1.0]]]))
    assert np.allclose(net.weight2.asnumpy(), np.array([[[0.9, 0.9]], [[0.9, 0.9]], [[0.9, 0.9]]]))

@@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.unique = P.Unique()

    def construct(self, x):
        return self.unique(x)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_unique():
    x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
    unique = Net()
    output = unique(x)
    expect1 = np.array([1, 2, 3])
    expect2 = np.array([0, 0, 1, 1, 2, 2])
    assert (output[0].asnumpy() == expect1).all()
    assert (output[1].asnumpy() == expect2).all()
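
A small NumPy aside (plain NumPy, not MindSpore) on why Unique is the op chosen to exercise the PyNative dynamic-shape path: the length of its first output depends on the input values, so its shape is only known at run time.

import numpy as np

x = np.array([1, 1, 2, 2, 3, 3], dtype=np.int32)
uniques, inverse = np.unique(x, return_inverse=True)
print(uniques)  # [1 2 3]       -> run-time-dependent shape (3,)
print(inverse)  # [0 0 1 1 2 2] -> matches expect2 in the test above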
|
Loading…
Reference in New Issue