split UMonad in inputs of op

wenfangpei 2021-04-01 14:47:59 +08:00
parent d346a861bc
commit 0085a273e7
7 changed files with 127 additions and 58 deletions
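
In short: SplitAssign previously matched Assign, AssignAdd and AssignSub itself; this commit narrows it to plain Assign and handles the other UMonad-carrying ops (AssignAdd, AssignSub, LambApplyOptimizerAssign) in a new OpUMonadExpander that strips the UMonad input just before graph-kernel expansion. A minimal before/after sketch of the rewrite ProcessNode performs (node names are illustrative, not taken from the commit):

    // before: the op carries a trailing UMonad input for execution ordering
    //   out = AssignAdd(param, delta, U)
    // after: the monad is folded into a Depend on the ordered input, and U is
    // dropped from the op, so the expander sees tensor inputs only
    //   dep = Depend(param, U)
    //   out = AssignAdd(dep, delta)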

View File

@@ -26,6 +26,7 @@
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "backend/optimizer/graph_kernel/split_umonad.h"
#include "backend/optimizer/graph_kernel/substitute_dropout.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "mindspore/core/ir/graph_utils.h"
@@ -37,10 +38,14 @@
namespace mindspore {
namespace opt {
namespace {
+constexpr size_t kAssignInputIdx = 1;
+constexpr size_t kLambInputIdx = 12;
std::vector<PrimitivePtr> GetExpandOps() {
  std::vector<PrimitivePtr> expand_ops = {
    prim::kPrimSquare,
    prim::kPrimGeLUGrad,
+    prim::kPrimAssignAdd,
#if ENABLE_D
    prim::kPrimTile,
    prim::kPrimSqrtGrad,
@@ -69,7 +74,6 @@ std::vector<PrimitivePtr> GetExpandOps() {
    prim::kPrimSigmoidCrossEntropyWithLogits,
    prim::kPrimSigmoidCrossEntropyWithLogitsGrad,
    prim::kPrimSoftmaxCrossEntropyWithLogits,
-    prim::kPrimAssignAdd,
#endif
  };
  const auto &flags = context::GraphKernelFlags::GetInstance();
@@ -167,6 +171,22 @@ AnfNodePtr DefaultExpander::Run(const AnfNodePtr &node) {
  return graph_kernel_node;
}
+ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
+  std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
+    {prim::kPrimDropout, std::make_shared<DropoutExpander>()},
+    {prim::kPrimAssignAdd, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
+    {prim::kPrimAssignSub, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
+    {prim::kLambApplyOptimizerAssign, std::make_shared<OpUMonadExpander>(kLambInputIdx)},
+  };
+  for (auto &e : expanders) {
+    if (IsPrimitiveCNode(node, e.first)) {
+      return e.second;
+    }
+  }
+  return std::make_shared<DefaultExpander>();
+}
bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
@@ -192,18 +212,6 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
  return changed;
}
-ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
-  std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
-    {prim::kPrimDropout, std::make_shared<DropoutExpander>()},
-  };
-  for (auto &e : expanders) {
-    if (IsPrimitiveCNode(node, e.first)) {
-      return e.second;
-    }
-  }
-  return std::make_shared<DefaultExpander>();
-}
bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
  expand_ops_ = GetExpandOps();
  return DoExpand(func_graph);
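
GetExpander's dispatch table now parameterizes OpUMonadExpander by the index of the input the UMonad orders (kAssignInputIdx = 1 for the Assign-family ops, kLambInputIdx = 12 for LambApplyOptimizerAssign). A hypothetical extra registration, assuming another in-place op with its ordered input at index 1, would be one more table entry (sketch only, not part of this commit):

    // kPrimSomeInplaceOp is a placeholder name, not a real primitive
    {prim::kPrimSomeInplaceOp, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},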

View File

@@ -37,7 +37,7 @@
#include "backend/optimizer/graph_kernel/value_graph_binder.h"
#include "backend/optimizer/graph_kernel/parallel_fusion.h"
#include "backend/optimizer/graph_kernel/optimize_assign.h"
-#include "backend/optimizer/graph_kernel/split_assign.h"
+#include "backend/optimizer/graph_kernel/split_umonad.h"
#include "backend/optimizer/graph_kernel/reorder_ops.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include "backend/optimizer/graph_kernel/axis_normalizer.h"

View File

@@ -13,7 +13,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#include "backend/optimizer/graph_kernel/split_assign.h"
+#include "backend/optimizer/graph_kernel/split_umonad.h"
#include <vector>
#include <string>
@@ -35,31 +35,63 @@ const BaseRef SplitAssign::DefinePattern() const {
  return VectorRef({v, Xs, Us, UMonad});
}
-bool CanSplit(const AnfNodePtr &node) {
-  return IsPrimitiveCNode(node, prim::kPrimAssignAdd) || IsPrimitiveCNode(node, prim::kPrimAssign) ||
-         IsPrimitiveCNode(node, prim::kPrimAssignSub);
-}
+bool CanSplit(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimAssign); }
+
+AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int input_idx) {
+  MS_EXCEPTION_IF_NULL(node);
+  CNodePtr cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  // Get original op's abstract and inputs
+  AbstractBasePtr original_abstract = cnode->abstract()->Clone();
+  auto original_inputs = cnode->inputs();
+  int input_node_size = cnode->size() - 1;
+  // Create depend node
+  AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[input_idx],
+                                  original_inputs[input_node_size]};
+  auto depend_cnode = func_graph->NewCNode(depend_inputs);
+  depend_cnode->set_abstract(original_inputs[input_idx]->abstract());
+  depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
+  // Create new node, delete U from inputs.
+  AnfNodePtrList new_inputs = {cnode->input(0)};
+  for (int i = 1; i < input_node_size; i++) {
+    if (i == input_idx) {
+      new_inputs.push_back(depend_cnode);
+    } else {
+      new_inputs.push_back(cnode->input(i));
+    }
+  }
+  auto new_cnode = func_graph->NewCNode(new_inputs);
+  new_cnode->set_abstract(original_abstract);
+  new_cnode->set_kernel_info(cnode->kernel_info_ptr());
+  return new_cnode;
+}
const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!CanSplit(node)) return node;
-  CNodePtr cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  CheckCNodeInputSize(cnode, kAssignInputTensorNum);
-  // Get original assign op's abstract and inputs
-  AbstractBasePtr original_abstract = cnode->abstract()->Clone();
-  auto original_inputs = cnode->inputs();
-  // Create depend node
-  AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[1], original_inputs[3]};
-  auto depend_cnode = func_graph->NewCNode(depend_inputs);
-  depend_cnode->set_abstract(original_inputs[1]->abstract());
-  depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
-  // Create new assign node, delete U from inputs.
-  AnfNodePtrList new_assign_inputs = {cnode->input(0), depend_cnode, original_inputs[2]};
-  auto new_assign_cnode = func_graph->NewCNode(new_assign_inputs);
-  new_assign_cnode->set_abstract(original_abstract);
-  new_assign_cnode->set_kernel_info(cnode->kernel_info_ptr());
-  return new_assign_cnode;
+  return ProcessNode(node->func_graph(), node, 1);
}
+
+AnfNodePtr OpUMonadExpander::Run(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  bool has_umonad = false;
+  for (unsigned int i = 1; i < cnode->size(); i++) {
+    if (HasAbstractUMonad(cnode->input(i))) {
+      has_umonad = true;
+      break;
+    }
+  }
+  if (has_umonad) {
+    auto new_node = ProcessNode(node->func_graph(), node, input_idx_);
+    return DefaultExpander::Run(new_node);
+  }
+  return DefaultExpander::Run(node);
+}
} // namespace opt
} // namespace mindspore
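
OpUMonadExpander::Run only rewrites a node when some input actually carries an abstract UMonad; otherwise it falls straight through to DefaultExpander::Run. A hedged sketch of the call site in DoExpand (paraphrased from the surrounding pass, not copied from this diff):

    // for each expandable cnode, in topological order:
    //   ExpanderPtr expander = GetExpander(node);   // OpUMonadExpander for AssignAdd etc.
    //   AnfNodePtr new_node = expander->Run(node);  // monad detached, then expanded
    //   if (new_node != nullptr) { /* replace node in the graph, changed = true */ }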

View File

@@ -13,11 +13,11 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_
#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
namespace mindspore {
namespace opt {
class SplitAssign : public PatternProcessPass {
@@ -27,6 +27,16 @@ class SplitAssign : public PatternProcessPass {
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
+
+class OpUMonadExpander : public DefaultExpander {
+ public:
+  explicit OpUMonadExpander(int input_idx) : input_idx_(input_idx) {}
+  ~OpUMonadExpander() = default;
+  AnfNodePtr Run(const AnfNodePtr &node) override;
+
+ private:
+  int input_idx_;
+};
} // namespace opt
} // namespace mindspore
-#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
+#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_

View File

@@ -219,7 +219,9 @@ bool ExtendOutputForUpdateState::ProcessIndex(const FuncGraphPtr &func_graph, co
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  for (auto user : mng->node_users()[getitems_[index]]) {
-    user.first->cast<CNodePtr>()->set_input(user.second, new_node);
+    if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
+      user.first->cast<CNodePtr>()->set_input(user.second, new_node);
+    }
  }
  return true;
}
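
With the new guard, only UpdateState users of the getitem are redirected to the new node; data consumers keep reading the original output. Illustrative effect (node names invented for the sketch):

    // u = UpdateState(U, getitem)  ->  u = UpdateState(U, new_node)  // redirected
    // y = Cast(getitem)            ->  y = Cast(getitem)             // unchanged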

View File

@@ -32,26 +32,38 @@ class AssignAdd(nn.Cell):
        self.add(self.var, y)
        return self.var
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_assign_add():
-    x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
-    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
-    context.set_context(mode=context.GRAPH_MODE,
-                        enable_graph_kernel=True, device_target="GPU")
+def get_output(x2, y2, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
    add = AssignAdd(x2)
    result_gk_on_1 = add(y2)
    add_2 = AssignAdd(result_gk_on_1)
    result_gk_on_2 = add_2(y2)
+    output = [result_gk_on_1, result_gk_on_2]
+    return output
-    context.set_context(mode=context.GRAPH_MODE,
-                        enable_graph_kernel=False, device_target="GPU")
-    add_beta = AssignAdd(x2)
-    result_gk_off_1 = add_beta(y2)
-    add_beta_2 = AssignAdd(result_gk_off_1)
-    result_gk_off_2 = add_beta_2(y2)
-    assert (result_gk_on_1.asnumpy() == result_gk_off_1.asnumpy()).all()
-    assert (result_gk_on_2.asnumpy() == result_gk_off_2.asnumpy()).all()
+def assign_add():
+    x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+    expect = get_output(x2, y2, False)
+    output = get_output(x2, y2, True)
+    e1, e2 = list(expect)
+    o1, o2 = list(output)
+    assert np.allclose(o1.asnumpy(), e1.asnumpy())
+    assert np.allclose(o2.asnumpy(), e2.asnumpy())
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_assign_add_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    assign_add()
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_assign_add_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    assign_add()

View File

@@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
+import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@@ -67,6 +68,10 @@ def lamb_apply_optimizer_assign():
    assert np.allclose(o2.asnumpy(), e2.asnumpy())
    assert np.allclose(o3.asnumpy(), e3.asnumpy())
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
def test_lamb_apply_optimizer_assign_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    lamb_apply_optimizer_assign()