forked from mindspore-Ecosystem/mindspore
[ME][Auto-Monad] Insert TensorMove for the Load whose refkey appears more than once,
or the load is an input of call or partial, or the first input of the load is call or partial.
This commit is contained in:
parent
de34e90e4d
commit
6dcab5a498
|
@ -22,6 +22,7 @@
|
|||
#include <utility>
|
||||
#include "utils/hash_map.h"
|
||||
#include "utils/hash_set.h"
|
||||
#include "utils/utils.h"
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore::pipeline {
|
||||
|
@ -35,6 +36,15 @@ class OrderEnforcer {
|
|||
~OrderEnforcer() = default;
|
||||
|
||||
void Run() {
|
||||
// In order to store current value of parameter, insert TensorMove for Load:
|
||||
// whose refkey appears more than once,
|
||||
// or the load is input of call or partial,
|
||||
// or the first input of load is call or partial.
|
||||
std::vector<CNodePtr> need_insert_loads = GetNeedInsertLoads();
|
||||
for (auto &node : need_insert_loads) {
|
||||
InsertTensorMoveForLoad(node->cast<CNodePtr>());
|
||||
}
|
||||
|
||||
auto nodes = MakeTopoSortMap();
|
||||
for (auto &node : nodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
|
@ -341,6 +351,138 @@ class OrderEnforcer {
|
|||
});
|
||||
}
|
||||
|
||||
std::string GetRefKey(const AnfNodePtr &node) {
|
||||
auto abs = node->abstract();
|
||||
if (abs == nullptr) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
return GetRefKey(node->cast<CNodePtr>()->input(1));
|
||||
}
|
||||
return "";
|
||||
}
|
||||
auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
|
||||
if (abs_ref == nullptr) {
|
||||
return "";
|
||||
}
|
||||
auto ref_key = abs_ref->ref_key_value();
|
||||
if (ref_key == nullptr) {
|
||||
return "";
|
||||
}
|
||||
return ref_key->name();
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> GetAllLoads(const AnfNodePtrList &check_nodes) {
|
||||
std::vector<CNodePtr> need_insert_loads;
|
||||
for (auto &node : check_nodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
auto load = node->cast<CNodePtr>();
|
||||
(void)need_insert_loads.emplace_back(load);
|
||||
}
|
||||
}
|
||||
return need_insert_loads;
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> GetSpecialLoads(const std::map<std::string, std::vector<CNodePtr>> &loads_map1,
|
||||
const std::map<std::string, std::vector<CNodePtr>> &loads_map2,
|
||||
const std::map<std::string, std::vector<CNodePtr>> &loads_map3) {
|
||||
std::vector<CNodePtr> need_insert_loads;
|
||||
for (auto &refkey_load : loads_map1) {
|
||||
auto &loads = refkey_load.second;
|
||||
if (loads.size() > 1) {
|
||||
(void)std::transform(loads.begin(), loads.end(), std::back_inserter(need_insert_loads),
|
||||
[](const CNodePtr &load) { return load; });
|
||||
}
|
||||
}
|
||||
for (auto &refkey_load_special : loads_map2) {
|
||||
auto &loads = refkey_load_special.second;
|
||||
// If loads size > 1, mean has exist in refkey_loads.
|
||||
if (loads.size() == 1) {
|
||||
(void)need_insert_loads.emplace_back(loads[0]);
|
||||
}
|
||||
}
|
||||
for (auto &refkey_load_special : loads_map3) {
|
||||
auto &loads = refkey_load_special.second;
|
||||
// If loads size > 1, mean has exist in refkey_loads.
|
||||
if (loads.size() == 1) {
|
||||
(void)need_insert_loads.emplace_back(loads[0]);
|
||||
}
|
||||
}
|
||||
return need_insert_loads;
|
||||
}
|
||||
|
||||
bool CheckLoadInput(const AnfNodePtr &input) {
|
||||
return IsPrimitiveCNode(input, prim::kPrimCall) || IsPrimitiveCNode(input, prim::kPrimPartial) ||
|
||||
(input->isa<CNode>() && (IsValueNode<FuncGraph>(input->cast<CNodePtr>()->input(0)) ||
|
||||
IsPrimitiveCNode(input->cast<CNodePtr>()->input(0), prim::kPrimSwitch) ||
|
||||
IsPrimitiveCNode(input->cast<CNodePtr>()->input(0), prim::kPrimSwitchLayer)));
|
||||
}
|
||||
|
||||
// Classify all Load nodes of the graph and return the ones that need a
// TensorMove inserted: loads whose refkey appears more than once, loads used
// as inputs of a call/partial, and loads whose own input is a call/partial.
std::vector<CNodePtr> GetNeedInsertLoads() {
  auto check_nodes = TopoSort(func_graph_->get_return());
  // Env switch: when set, every Load gets a TensorMove (debug/override mode).
  static bool enable_all_load = common::GetEnv("MS_DEV_ENABLE_LOAD_INSERT_TENSORMOVE") == "1";
  // Insert TensorMove for all Load nodes
  if (enable_all_load) {
    return GetAllLoads(check_nodes);
  }
  // refkey -> all loads of that refkey.
  std::map<std::string, std::vector<CNodePtr>> refkey_loads;
  // refkey -> loads that appear as inputs of a call/partial node.
  std::map<std::string, std::vector<CNodePtr>> refkey_loads_in_call_or_partial;
  // refkey -> loads whose first input is itself a call/partial.
  std::map<std::string, std::vector<CNodePtr>> refkey_loads_input_is_call_or_partial;
  for (auto &node : check_nodes) {
    // Record load refkey
    if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
      auto load = node->cast<CNodePtr>();
      auto input = load->input(1);
      auto refkey = GetRefKey(input);
      if (refkey == "") {
        MS_LOG(WARNING) << "Load without ref key:" << load->DebugString();
        continue;
      }
      (void)refkey_loads[refkey].emplace_back(load);
      // Skip Depend wrappers to reach the real input of the Load.
      while (IsPrimitiveCNode(input, prim::kPrimDepend)) {
        input = input->cast<CNodePtr>()->input(1);
      }
      // If Load(call/partial, monad), we should insert TensorMove for the load node.
      if (CheckLoadInput(input)) {
        (void)refkey_loads_input_is_call_or_partial[refkey].emplace_back(load);
      }
    }
    // Find special load which is in call or partial
    if (!IsPrimitiveCNode(node, prim::kPrimCall) && !IsPrimitiveCNode(node, prim::kPrimPartial) &&
        !(node->isa<CNode>() && IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0)))) {
      continue;
    }
    // node is a call/partial (or a func-graph call); scan its real inputs.
    auto cnode = node->cast<CNodePtr>();
    for (size_t index = 1; index < cnode->inputs().size(); ++index) {
      auto input = cnode->input(index);
      if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
        auto load = input->cast<CNodePtr>();
        auto refkey = GetRefKey(load->input(1));
        if (refkey == "") {
          MS_LOG(WARNING) << "Load without ref key:" << load->DebugString();
          continue;
        }
        // Already covered by the "refkey appears more than once" rule.
        if (refkey_loads[refkey].size() > 1) {
          continue;
        }
        (void)refkey_loads_in_call_or_partial[refkey].emplace_back(load);
      }
    }
  }
  return GetSpecialLoads(refkey_loads, refkey_loads_in_call_or_partial, refkey_loads_input_is_call_or_partial);
}
|
||||
|
||||
void InsertTensorMoveForLoad(const CNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
return;
|
||||
}
|
||||
auto prim = std::make_shared<Primitive>(kTensorMoveOpName);
|
||||
std::vector<AnfNodePtr> new_inputs{NewValueNode(prim)};
|
||||
(void)new_inputs.emplace_back(node);
|
||||
auto real_load = func_graph_->NewCNode(new_inputs);
|
||||
real_load->set_abstract(node->abstract());
|
||||
MS_LOG(DEBUG) << "Insert TensorMove " << real_load->DebugString() << " for load " << node->DebugString();
|
||||
manager_->Replace(node, real_load);
|
||||
}
|
||||
|
||||
const FuncGraphPtr &func_graph_;
|
||||
FuncGraphManagerPtr manager_;
|
||||
mindspore::HashMap<AnfNodePtr, size_t> topo_sort_map_;
|
||||
|
|
|
@ -58,6 +58,9 @@ void ValidateOperation(const AnfNodePtr &node) {
|
|||
if (prim->HasAttr("is_load")) {
|
||||
return;
|
||||
}
|
||||
if (prim->name() == "TensorMove") {
|
||||
return;
|
||||
}
|
||||
if (prim->HasPyEvaluator()) {
|
||||
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
|
||||
return;
|
||||
|
|
|
@ -16,10 +16,13 @@ import pytest
|
|||
from mindspore.nn import Cell
|
||||
from mindspore import context, Tensor, Parameter
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
import mindspore as ms
|
||||
import numpy as np
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class AutoMonadAddnAdamNet(Cell):
|
||||
def __init__(self, var, m, v):
|
||||
|
@ -121,3 +124,77 @@ def test_auto_monad_read_dependency_two_assign_two_addn():
|
|||
out1 = net(Tensor([9.0], ms.float32))
|
||||
out2 = benchmarknet(Tensor([9.0], ms.float32))
|
||||
allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001)
|
||||
|
||||
|
||||
class ForwardNet(Cell):
    """Net that assigns to a Parameter inside a while loop and reads it back."""

    def __init__(self):
        super(ForwardNet, self).__init__()
        # Parameter overwritten by F.assign on every loop iteration.
        self.weight = Parameter(Tensor(np.array(0), ms.int32), name="param")

    def construct(self, x):
        out = 0
        i = 0
        while i < 3:
            # Each read of self.weight after this assign must observe the
            # freshly assigned value — exercised by the TensorMove insertion.
            F.assign(self.weight, i)
            out = x * self.weight + out
            i = i + 1
        return out
|
||||
|
||||
|
||||
class BackwardNet(Cell):
    """Wraps a forward net and computes gradients w.r.t. all of its inputs."""

    def __init__(self, net):
        super(BackwardNet, self).__init__(auto_prefix=False)
        self.forward_net = net
        # get_all=True: return a gradient for every input of forward_net.
        self.grad = C.GradOperation(get_all=True)

    def construct(self, *inputs):
        grads = self.grad(self.forward_net)(*inputs)
        return grads
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_load_convert_tensormove():
    """
    Feature: Auto monad feature: record the value of load.
    Description: record the value of load.
    Expectation: No exception.
    """
    input_x = Tensor(np.array(1), ms.int32)
    forward_net = ForwardNet()
    backward_net = BackwardNet(forward_net)
    graph_mode_grads = backward_net(input_x)
    expected_output = (Tensor(np.array(3), ms.int32),)
    assert np.all(graph_mode_grads == expected_output)
|
||||
|
||||
|
||||
class ForwardNet2(Cell):
    """Net that assigns to a Parameter inside a while loop, with no input."""

    def __init__(self):
        super(ForwardNet2, self).__init__()
        # Parameter overwritten by F.assign on every loop iteration.
        self.weight = Parameter(Tensor(np.array(0), ms.int32), name="param")

    def construct(self):
        out = 0
        i = 0
        while i < 3:
            # The read of self.weight below must see the value just assigned
            # (0, then 1, then 2), accumulating to 3.
            F.assign(self.weight, i)
            out = self.weight + out
            i = i + 1
        return out
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_load_convert_tensormove_2():
    """
    Feature: Auto monad feature: record the value of load.
    Description: record the value of load.
    Expectation: No exception.
    """
    forward_net = ForwardNet2()
    forward_result = forward_net()
    # 0 + 1 + 2 accumulated over the three loop iterations.
    assert forward_result == 3
|
||||
|
|
Loading…
Reference in New Issue