[ME][Auto-Monad] Insert TensorMove for a Load whose refkey appears more than once,

or when the load is an input of a call or partial, or when the first input of the load is a call or partial.
This commit is contained in:
Margaret_wangrui 2022-01-04 16:40:10 +08:00
parent de34e90e4d
commit 6dcab5a498
3 changed files with 223 additions and 1 deletions

View File

@ -22,6 +22,7 @@
#include <utility>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "utils/utils.h"
#include "base/core_ops.h"
namespace mindspore::pipeline {
@ -35,6 +36,15 @@ class OrderEnforcer {
~OrderEnforcer() = default;
void Run() {
// In order to store current value of parameter, insert TensorMove for Load:
// whose refkey appears more than once,
// or the load is input of call or partial,
// or the first input of load is call or partial.
std::vector<CNodePtr> need_insert_loads = GetNeedInsertLoads();
for (auto &node : need_insert_loads) {
InsertTensorMoveForLoad(node->cast<CNodePtr>());
}
auto nodes = MakeTopoSortMap();
for (auto &node : nodes) {
if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
@ -341,6 +351,138 @@ class OrderEnforcer {
});
}
std::string GetRefKey(const AnfNodePtr &node) {
auto abs = node->abstract();
if (abs == nullptr) {
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
return GetRefKey(node->cast<CNodePtr>()->input(1));
}
return "";
}
auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
if (abs_ref == nullptr) {
return "";
}
auto ref_key = abs_ref->ref_key_value();
if (ref_key == nullptr) {
return "";
}
return ref_key->name();
}
std::vector<CNodePtr> GetAllLoads(const AnfNodePtrList &check_nodes) {
std::vector<CNodePtr> need_insert_loads;
for (auto &node : check_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
auto load = node->cast<CNodePtr>();
(void)need_insert_loads.emplace_back(load);
}
}
return need_insert_loads;
}
std::vector<CNodePtr> GetSpecialLoads(const std::map<std::string, std::vector<CNodePtr>> &loads_map1,
const std::map<std::string, std::vector<CNodePtr>> &loads_map2,
const std::map<std::string, std::vector<CNodePtr>> &loads_map3) {
std::vector<CNodePtr> need_insert_loads;
for (auto &refkey_load : loads_map1) {
auto &loads = refkey_load.second;
if (loads.size() > 1) {
(void)std::transform(loads.begin(), loads.end(), std::back_inserter(need_insert_loads),
[](const CNodePtr &load) { return load; });
}
}
for (auto &refkey_load_special : loads_map2) {
auto &loads = refkey_load_special.second;
// If loads size > 1, mean has exist in refkey_loads.
if (loads.size() == 1) {
(void)need_insert_loads.emplace_back(loads[0]);
}
}
for (auto &refkey_load_special : loads_map3) {
auto &loads = refkey_load_special.second;
// If loads size > 1, mean has exist in refkey_loads.
if (loads.size() == 1) {
(void)need_insert_loads.emplace_back(loads[0]);
}
}
return need_insert_loads;
}
// True when `input` is a call/partial node, or a CNode whose first input is a
// FuncGraph value node, a Switch, or a SwitchLayer (i.e. an indirect call).
bool CheckLoadInput(const AnfNodePtr &input) {
  if (IsPrimitiveCNode(input, prim::kPrimCall) || IsPrimitiveCNode(input, prim::kPrimPartial)) {
    return true;
  }
  if (!input->isa<CNode>()) {
    return false;
  }
  const auto &first_input = input->cast<CNodePtr>()->input(0);
  return IsValueNode<FuncGraph>(first_input) || IsPrimitiveCNode(first_input, prim::kPrimSwitch) ||
         IsPrimitiveCNode(first_input, prim::kPrimSwitchLayer);
}
// Collects the Load nodes that need a TensorMove inserted:
//   1. Loads whose refkey appears more than once in the graph.
//   2. Loads passed as an argument to a call/partial.
//   3. Loads whose first input is a call/partial, i.e. Load(call(...), U).
std::vector<CNodePtr> GetNeedInsertLoads() {
  auto check_nodes = TopoSort(func_graph_->get_return());
  // Debug switch: insert TensorMove for every Load node unconditionally.
  static const bool enable_all_load = common::GetEnv("MS_DEV_ENABLE_LOAD_INSERT_TENSORMOVE") == "1";
  if (enable_all_load) {
    return GetAllLoads(check_nodes);
  }
  std::map<std::string, std::vector<CNodePtr>> refkey_loads;
  std::map<std::string, std::vector<CNodePtr>> refkey_loads_in_call_or_partial;
  std::map<std::string, std::vector<CNodePtr>> refkey_loads_input_is_call_or_partial;
  for (auto &node : check_nodes) {
    // Cases 1 & 3: record each Load by its refkey.
    if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
      auto load = node->cast<CNodePtr>();
      auto input = load->input(1);
      auto refkey = GetRefKey(input);
      if (refkey.empty()) {
        MS_LOG(WARNING) << "Load without ref key:" << load->DebugString();
        continue;
      }
      (void)refkey_loads[refkey].emplace_back(load);
      // Look through Depend nodes to reach the real first input of the Load.
      while (IsPrimitiveCNode(input, prim::kPrimDepend)) {
        input = input->cast<CNodePtr>()->input(1);
      }
      // If Load(call/partial, monad), we should insert TensorMove for the load node.
      if (CheckLoadInput(input)) {
        (void)refkey_loads_input_is_call_or_partial[refkey].emplace_back(load);
      }
    }
    // Case 2: find special loads that are arguments of a call/partial.
    if (!IsPrimitiveCNode(node, prim::kPrimCall) && !IsPrimitiveCNode(node, prim::kPrimPartial) &&
        !(node->isa<CNode>() && IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0)))) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    for (size_t index = 1; index < cnode->inputs().size(); ++index) {
      auto input = cnode->input(index);
      if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
        auto load = input->cast<CNodePtr>();
        auto refkey = GetRefKey(load->input(1));
        if (refkey.empty()) {
          MS_LOG(WARNING) << "Load without ref key:" << load->DebugString();
          continue;
        }
        // Already covered by case 1: the refkey appears more than once.
        if (refkey_loads[refkey].size() > 1) {
          continue;
        }
        (void)refkey_loads_in_call_or_partial[refkey].emplace_back(load);
      }
    }
  }
  return GetSpecialLoads(refkey_loads, refkey_loads_in_call_or_partial, refkey_loads_input_is_call_or_partial);
}
void InsertTensorMoveForLoad(const CNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimLoad)) {
return;
}
auto prim = std::make_shared<Primitive>(kTensorMoveOpName);
std::vector<AnfNodePtr> new_inputs{NewValueNode(prim)};
(void)new_inputs.emplace_back(node);
auto real_load = func_graph_->NewCNode(new_inputs);
real_load->set_abstract(node->abstract());
MS_LOG(DEBUG) << "Insert TensorMove " << real_load->DebugString() << " for load " << node->DebugString();
manager_->Replace(node, real_load);
}
const FuncGraphPtr &func_graph_;
FuncGraphManagerPtr manager_;
mindspore::HashMap<AnfNodePtr, size_t> topo_sort_map_;

View File

@ -58,6 +58,9 @@ void ValidateOperation(const AnfNodePtr &node) {
if (prim->HasAttr("is_load")) {
return;
}
if (prim->name() == "TensorMove") {
return;
}
if (prim->HasPyEvaluator()) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
return;

View File

@ -16,10 +16,13 @@ import pytest
from mindspore.nn import Cell
from mindspore import context, Tensor, Parameter
import mindspore.ops.operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
import mindspore as ms
import numpy as np
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE)
class AutoMonadAddnAdamNet(Cell):
def __init__(self, var, m, v):
@ -121,3 +124,77 @@ def test_auto_monad_read_dependency_two_assign_two_addn():
out1 = net(Tensor([9.0], ms.float32))
out2 = benchmarknet(Tensor([9.0], ms.float32))
allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001)
class ForwardNet(Cell):
    """Runs a fixed 3-step loop, assigning the loop index to a parameter and
    accumulating x * weight each iteration."""

    def __init__(self):
        super(ForwardNet, self).__init__()
        # Parameter mutated in-place inside the loop via F.assign.
        self.weight = Parameter(Tensor(np.array(0), ms.int32), name="param")

    def construct(self, x):
        acc = 0
        step = 0
        while step < 3:
            F.assign(self.weight, step)
            acc = x * self.weight + acc
            step = step + 1
        return acc
class BackwardNet(Cell):
    """Wraps a forward network and returns gradients w.r.t. all its inputs."""

    def __init__(self, net):
        super(BackwardNet, self).__init__(auto_prefix=False)
        self.forward_net = net
        self.grad = C.GradOperation(get_all=True)

    def construct(self, *inputs):
        return self.grad(self.forward_net)(*inputs)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_load_convert_tensormove():
    """
    Feature: Auto monad feature: record the value of load.
    Description: record the value of load.
    Expectation: No exception.
    """
    input_x = Tensor(np.array(1), ms.int32)
    forward_net = ForwardNet()
    backward_net = BackwardNet(forward_net)
    grads = backward_net(input_x)
    expected = (Tensor(np.array(3), ms.int32),)
    assert np.all(grads == expected)
class ForwardNet2(Cell):
    """Like ForwardNet but takes no input: accumulates the parameter itself
    over a fixed 3-step loop."""

    def __init__(self):
        super(ForwardNet2, self).__init__()
        # Parameter mutated in-place inside the loop via F.assign.
        self.weight = Parameter(Tensor(np.array(0), ms.int32), name="param")

    def construct(self):
        acc = 0
        step = 0
        while step < 3:
            F.assign(self.weight, step)
            acc = self.weight + acc
            step = step + 1
        return acc
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_load_convert_tensormove_2():
    """
    Feature: Auto monad feature: record the value of load.
    Description: record the value of load.
    Expectation: No exception.
    """
    forward_net = ForwardNet2()
    result = forward_net()
    assert result == 3