forked from mindspore-Ecosystem/mindspore

Commit 0085a273e7 (parent d346a861bc): split UMonad in inputs of op

This commit generalizes the old SplitAssign-only handling into split_umonad: a reusable ProcessNode helper splits a trailing UMonad input off an op by folding it into a Depend node, and a new OpUMonadExpander applies that rewrite before graph-kernel expansion. AssignAdd, AssignSub and LambApplyOptimizerAssign are registered with it, ExtendOutputForUpdateState now rewires only UpdateState users, and the AssignAdd tests gain GPU and Ascend entry points.
@@ -26,6 +26,7 @@
 #include "backend/kernel_compiler/common_utils.h"
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "backend/optimizer/graph_kernel/split_umonad.h"
 #include "backend/optimizer/graph_kernel/substitute_dropout.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "mindspore/core/ir/graph_utils.h"
@@ -37,10 +38,14 @@
 namespace mindspore {
 namespace opt {
 namespace {
+constexpr size_t kAssignInputIdx = 1;
+constexpr size_t kLambInputIdx = 12;
+
 std::vector<PrimitivePtr> GetExpandOps() {
   std::vector<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
     prim::kPrimGeLUGrad,
+    prim::kPrimAssignAdd,
 #if ENABLE_D
     prim::kPrimTile,
     prim::kPrimSqrtGrad,
@@ -69,7 +74,6 @@ std::vector<PrimitivePtr> GetExpandOps() {
     prim::kPrimSigmoidCrossEntropyWithLogits,
     prim::kPrimSigmoidCrossEntropyWithLogitsGrad,
     prim::kPrimSoftmaxCrossEntropyWithLogits,
-    prim::kPrimAssignAdd,
 #endif
   };
   const auto &flags = context::GraphKernelFlags::GetInstance();
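Read together, the two GetExpandOps() hunks above relocate prim::kPrimAssignAdd rather than merely adding it. A comment sketch of the net effect (paraphrase, not verbatim source):

// Before this commit, prim::kPrimAssignAdd was listed only inside the
// #if ENABLE_D block, so AssignAdd expansion was Ascend-only.
// After it, the op sits in the unconditional part of expand_ops,
// making AssignAdd an expansion candidate on every backend.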
@@ -167,6 +171,22 @@ AnfNodePtr DefaultExpander::Run(const AnfNodePtr &node) {
   return graph_kernel_node;
 }
 
+ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
+  std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
+    {prim::kPrimDropout, std::make_shared<DropoutExpander>()},
+    {prim::kPrimAssignAdd, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
+    {prim::kPrimAssignSub, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
+    {prim::kLambApplyOptimizerAssign, std::make_shared<OpUMonadExpander>(kLambInputIdx)},
+  };
+
+  for (auto &e : expanders) {
+    if (IsPrimitiveCNode(node, e.first)) {
+      return e.second;
+    }
+  }
+  return std::make_shared<DefaultExpander>();
+}
+
 bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
   bool changed = false;
   auto todos = TopoSort(func_graph->get_return());
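The new GetExpander is a first-match lookup table over (primitive, expander) pairs. The shape of that dispatch is easy to check in isolation; here is a minimal standalone sketch with toy types (illustrative only; the real code matches a PrimitivePtr against the node via IsPrimitiveCNode and returns ExpanderPtr):

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Expander {
  virtual std::string Name() const { return "Default"; }
  virtual ~Expander() = default;
};

struct UMonadExpander : Expander {
  explicit UMonadExpander(int input_idx) : input_idx_(input_idx) {}
  std::string Name() const override { return "UMonad(idx=" + std::to_string(input_idx_) + ")"; }
  int input_idx_;
};

std::shared_ptr<Expander> GetExpander(const std::string &op) {
  // Mirrors the pair list in the hunk above: specific ops first, default last.
  std::vector<std::pair<std::string, std::shared_ptr<Expander>>> expanders = {
    {"AssignAdd", std::make_shared<UMonadExpander>(1)},
    {"AssignSub", std::make_shared<UMonadExpander>(1)},
    {"LambApplyOptimizerAssign", std::make_shared<UMonadExpander>(12)},
  };
  for (auto &e : expanders) {
    if (op == e.first) return e.second;  // first match wins
  }
  return std::make_shared<Expander>();   // all other ops use the default
}

int main() {
  std::cout << GetExpander("AssignSub")->Name() << "\n";  // UMonad(idx=1)
  std::cout << GetExpander("Mul")->Name() << "\n";        // Default
}

The index each OpUMonadExpander is built with names the input that the UMonad gets folded onto: input 1 (the written-to parameter) for the Assign-family ops, and input 12 for LambApplyOptimizerAssign per the constants above.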
@@ -192,18 +212,6 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
   return changed;
 }
-
-ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
-  std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
-    {prim::kPrimDropout, std::make_shared<DropoutExpander>()},
-  };
-  for (auto &e : expanders) {
-    if (IsPrimitiveCNode(node, e.first)) {
-      return e.second;
-    }
-  }
-  return std::make_shared<DefaultExpander>();
-}
 
 bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
   expand_ops_ = GetExpandOps();
   return DoExpand(func_graph);
@@ -37,7 +37,7 @@
 #include "backend/optimizer/graph_kernel/value_graph_binder.h"
 #include "backend/optimizer/graph_kernel/parallel_fusion.h"
 #include "backend/optimizer/graph_kernel/optimize_assign.h"
-#include "backend/optimizer/graph_kernel/split_assign.h"
+#include "backend/optimizer/graph_kernel/split_umonad.h"
 #include "backend/optimizer/graph_kernel/reorder_ops.h"
 #include "backend/optimizer/graph_kernel/update_state_formatter.h"
 #include "backend/optimizer/graph_kernel/axis_normalizer.h"
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "backend/optimizer/graph_kernel/split_assign.h"
+#include "backend/optimizer/graph_kernel/split_umonad.h"
 
 #include <vector>
 #include <string>
@@ -35,31 +35,63 @@ const BaseRef SplitAssign::DefinePattern() const {
   return VectorRef({v, Xs, Us, UMonad});
 }
 
-bool CanSplit(const AnfNodePtr &node) {
-  return IsPrimitiveCNode(node, prim::kPrimAssignAdd) || IsPrimitiveCNode(node, prim::kPrimAssign) ||
-         IsPrimitiveCNode(node, prim::kPrimAssignSub);
+bool CanSplit(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimAssign); }
+
+AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int input_idx) {
+  MS_EXCEPTION_IF_NULL(node);
+  CNodePtr cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+
+  // Get original op's abstract and inputs
+  AbstractBasePtr original_abstract = cnode->abstract()->Clone();
+  auto original_inputs = cnode->inputs();
+
+  int input_node_size = cnode->size() - 1;
+  // Create depend node
+  AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[input_idx],
+                                  original_inputs[input_node_size]};
+  auto depend_cnode = func_graph->NewCNode(depend_inputs);
+  depend_cnode->set_abstract(original_inputs[input_idx]->abstract());
+  depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
+  // Create new node, delete U from inputs.
+  AnfNodePtrList new_inputs = {cnode->input(0)};
+  for (int i = 1; i < input_node_size; i++) {
+    if (i == input_idx) {
+      new_inputs.push_back(depend_cnode);
+    } else {
+      new_inputs.push_back(cnode->input(i));
+    }
+  }
+  auto new_cnode = func_graph->NewCNode(new_inputs);
+  new_cnode->set_abstract(original_abstract);
+  new_cnode->set_kernel_info(cnode->kernel_info_ptr());
+  return new_cnode;
 }
 
 const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(node);
   if (!CanSplit(node)) return node;
-  CNodePtr cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  CheckCNodeInputSize(cnode, kAssignInputTensorNum);
-  // Get original assign op's abstract and inputs
-  AbstractBasePtr original_abstract = cnode->abstract()->Clone();
-  auto original_inputs = cnode->inputs();
-  // Create depend node
-  AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[1], original_inputs[3]};
-  auto depend_cnode = func_graph->NewCNode(depend_inputs);
-  depend_cnode->set_abstract(original_inputs[1]->abstract());
-  depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
-  // Create new assign node, delete U from inputs.
-  AnfNodePtrList new_assign_inputs = {cnode->input(0), depend_cnode, original_inputs[2]};
-  auto new_assign_cnode = func_graph->NewCNode(new_assign_inputs);
-  new_assign_cnode->set_abstract(original_abstract);
-  new_assign_cnode->set_kernel_info(cnode->kernel_info_ptr());
-  return new_assign_cnode;
+  return ProcessNode(node->func_graph(), node, 1);
+}
+
+AnfNodePtr OpUMonadExpander::Run(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+
+  bool has_umonad = false;
+  for (unsigned int i = 1; i < cnode->size(); i++) {
+    if (HasAbstractUMonad(cnode->input(i))) {
+      has_umonad = true;
+      break;
+    }
+  }
+  if (has_umonad) {
+    auto new_node = ProcessNode(node->func_graph(), node, input_idx_);
+    return DefaultExpander::Run(new_node);
+  }
+
+  return DefaultExpander::Run(node);
+}
 
 }  // namespace opt
 }  // namespace mindspore
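The heart of the renamed file is ProcessNode: drop the trailing UMonad input and re-attach it by wrapping the chosen input in a Depend node, so execution ordering is preserved without the monad appearing among the op's data inputs. A self-contained toy model of that rewrite, with strings standing in for AnfNodePtr (illustrative only; index 0 here plays the role of input_idx = 1 above, since this toy input list omits the leading primitive):

#include <iostream>
#include <string>
#include <vector>

struct ToyNode {
  std::string op;
  std::vector<std::string> inputs;  // last input is assumed to be the UMonad
};

// Mirror of ProcessNode's rewrite: remove the trailing UMonad and wrap
// input `input_idx` in a Depend node carrying it.
ToyNode SplitUMonad(const ToyNode &node, size_t input_idx) {
  ToyNode result{node.op, {}};
  const std::string &umonad = node.inputs.back();
  for (size_t i = 0; i + 1 < node.inputs.size(); ++i) {
    if (i == input_idx) {
      result.inputs.push_back("Depend(" + node.inputs[i] + ", " + umonad + ")");
    } else {
      result.inputs.push_back(node.inputs[i]);
    }
  }
  return result;
}

int main() {
  // Assign(param, value, U)  ->  Assign(Depend(param, U), value)
  ToyNode assign{"Assign", {"param", "value", "U"}};
  ToyNode split = SplitUMonad(assign, 0);
  std::cout << split.op << "(";
  for (size_t i = 0; i < split.inputs.size(); ++i) {
    std::cout << (i ? ", " : "") << split.inputs[i];
  }
  std::cout << ")\n";  // prints: Assign(Depend(param, U), value)
}

OpUMonadExpander::Run applies this rewrite only when some input actually carries a UMonad abstract, then hands the monad-free node to DefaultExpander::Run.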
@@ -13,11 +13,11 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_
 
 #include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
 namespace mindspore {
 namespace opt {
 class SplitAssign : public PatternProcessPass {
@@ -27,6 +27,16 @@ class SplitAssign : public PatternProcessPass {
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 };
+
+class OpUMonadExpander : public DefaultExpander {
+ public:
+  explicit OpUMonadExpander(int input_idx) : input_idx_(input_idx) {}
+  ~OpUMonadExpander() = default;
+  AnfNodePtr Run(const AnfNodePtr &node) override;
+
+ private:
+  int input_idx_;
+};
 }  // namespace opt
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_
@@ -219,7 +219,9 @@ bool ExtendOutputForUpdateState::ProcessIndex(const FuncGraphPtr &func_graph, co
   auto mng = func_graph->manager();
   MS_EXCEPTION_IF_NULL(mng);
   for (auto user : mng->node_users()[getitems_[index]]) {
-    user.first->cast<CNodePtr>()->set_input(user.second, new_node);
+    if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
+      user.first->cast<CNodePtr>()->set_input(user.second, new_node);
+    }
   }
   return true;
 }
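This ProcessIndex change narrows the rewiring: previously every user of getitems_[index] was redirected to new_node; now only users that are UpdateState nodes are, so ordinary data consumers keep reading the original getitem output. As a guard-style sketch (paraphrase, not verbatim source):

// for each (user, slot) consuming getitems_[index]:
//   if (user is an UpdateState CNode)
//     user->set_input(slot, new_node);  // reroute only the monad chain
//   // data users are left untouched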
@@ -32,26 +32,38 @@ class AssignAdd(nn.Cell):
         self.add(self.var, y)
         return self.var
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_assign_add():
-    x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
-    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
-
-    context.set_context(mode=context.GRAPH_MODE,
-                        enable_graph_kernel=True, device_target="GPU")
+def get_output(x2, y2, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
     add = AssignAdd(x2)
     result_gk_on_1 = add(y2)
     add_2 = AssignAdd(result_gk_on_1)
     result_gk_on_2 = add_2(y2)
-    context.set_context(mode=context.GRAPH_MODE,
-                        enable_graph_kernel=False, device_target="GPU")
-    add_beta = AssignAdd(x2)
-    result_gk_off_1 = add_beta(y2)
-    add_beta_2 = AssignAdd(result_gk_off_1)
-    result_gk_off_2 = add_beta_2(y2)
-    assert (result_gk_on_1.asnumpy() == result_gk_off_1.asnumpy()).all()
-    assert (result_gk_on_2.asnumpy() == result_gk_off_2.asnumpy()).all()
+    output = [result_gk_on_1, result_gk_on_2]
+    return output
+
+
+def assign_add():
+    x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+
+    expect = get_output(x2, y2, False)
+    output = get_output(x2, y2, True)
+    e1, e2 = list(expect)
+    o1, o2 = list(output)
+
+    assert np.allclose(o1.asnumpy(), e1.asnumpy())
+    assert np.allclose(o2.asnumpy(), e2.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_assign_add_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    assign_add()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_assign_add_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    assign_add()
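The test refactor follows the expect/output pattern already used by the lamb test below: get_output runs the same network with graph kernel off (expect) and on (output), assign_add compares the two with np.allclose, and one thin pytest entry point per target (GPU, Ascend) carries the platform marks, which is what lets the new Ascend case reuse the whole body.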
@@ -14,6 +14,7 @@
 # ============================================================================
 
 import numpy as np
+import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -67,6 +68,10 @@ def lamb_apply_optimizer_assign():
     assert np.allclose(o2.asnumpy(), e2.asnumpy())
     assert np.allclose(o3.asnumpy(), e3.asnumpy())
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
 def test_lamb_apply_optimizer_assign_ascend():
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     lamb_apply_optimizer_assign()