unify the output num of optimizer ops
parent 4754d1f3ed
commit cd9173fdfd
@@ -292,7 +292,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
     ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
   }
   ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>());
-  ir_fusion_pm->AddPass(std::make_shared<AddInputToOutput>());
   ir_fusion_pm->AddPass(std::make_shared<InsertTranspose>());
   ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
   ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
@@ -0,0 +1,126 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/mindir/optimizer_unify_output.h"
+
+#include <vector>
+#include <memory>
+
+#include "abstract/abstract_value.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kFtrlOutputNum = 3;
+constexpr size_t kMomentumOutputNum = 2;
+constexpr size_t kRMSPropOutputNum = 3;
+constexpr size_t kCenteredRMSPropOutputNum = 4;
+
+CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t output_size) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+
+  auto cnode_ptr = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode_ptr);
+
+  auto abstract = cnode_ptr->abstract();
+  MS_EXCEPTION_IF_NULL(abstract);
+
+  if (AnfAlgo::HasNodeAttr("optim_output_passed", cnode_ptr) && abstract->isa<abstract::AbstractTuple>()) {
+    return nullptr;
+  }
+  AnfAlgo::SetNodeAttr("optim_output_passed", MakeValue(true), cnode_ptr);
+
+  std::vector<AbstractBasePtr> abstract_list;
+  for (size_t i = 0; i < output_size; i++) {
+    abstract_list.push_back(abstract->Clone());
+  }
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  cnode_ptr->set_abstract(abstract_tuple);
+
+  auto index = NewValueNode(static_cast<int64_t>(0));
+  auto get_item = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_ptr, index});
+  MS_EXCEPTION_IF_NULL(get_item);
+
+  get_item->set_abstract(abstract->Clone());
+  return get_item;
+}
+}  // namespace
+
+const BaseRef FtrlUnifyOutput::DefinePattern() const {
+  VarPtr var = std::make_shared<Var>();
+  VarPtr accum = std::make_shared<Var>();
+  VarPtr linear = std::make_shared<Var>();
+  VarPtr grad = std::make_shared<Var>();
+  VarPtr lr = std::make_shared<Var>();
+  VarPtr l1 = std::make_shared<Var>();
+  VarPtr l2 = std::make_shared<Var>();
+  VarPtr lr_power = std::make_shared<Var>();
+  VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power});
+  return pattern;
+}
+
+const AnfNodePtr FtrlUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
+  return ProcessOutput(graph, node, kFtrlOutputNum);
+}
+
+const BaseRef MomentumUnifyOutput::DefinePattern() const {
+  VarPtr var = std::make_shared<Var>();
+  VarPtr accum = std::make_shared<Var>();
+  VarPtr lr = std::make_shared<Var>();
+  VarPtr grad = std::make_shared<Var>();
+  VarPtr momentum = std::make_shared<Var>();
+  VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum});
+  return pattern;
+}
+
+const AnfNodePtr MomentumUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                              const EquivPtr &) const {
+  return ProcessOutput(graph, node, kMomentumOutputNum);
+}
+
+const BaseRef RMSPropUnifyOutput::DefinePattern() const {
+  VarPtr inputs = std::make_shared<SeqVar>();
+  VectorRef pattern({prim::kPrimApplyRMSProp, inputs});
+  return pattern;
+}
+
+const AnfNodePtr RMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                             const EquivPtr &) const {
+  return ProcessOutput(graph, node, kRMSPropOutputNum);
+}
+
+const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const {
+  VarPtr var = std::make_shared<Var>();
+  VarPtr mg = std::make_shared<Var>();
+  VarPtr ms = std::make_shared<Var>();
+  VarPtr mom = std::make_shared<Var>();
+  VarPtr grad = std::make_shared<Var>();
+  VarPtr lr = std::make_shared<Var>();
+  VarPtr rho = std::make_shared<Var>();
+  VarPtr momentum = std::make_shared<Var>();
+  VarPtr epsilon = std::make_shared<Var>();
+  VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon});
+  return pattern;
+}
+
+const AnfNodePtr CenteredRMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                     const EquivPtr &) const {
+  return ProcessOutput(graph, node, kCenteredRMSPropOutputNum);
+}
+}  // namespace opt
+}  // namespace mindspore
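Aside, not part of the diff: ProcessOutput above is the whole mechanism. It tags the node with "optim_output_passed", clones the node's single abstract output_size times into an AbstractTuple, and returns a TupleGetItem(node, 0) that the PatternProcessPass machinery substitutes for the original node, so every consumer still sees one tensor. Below is a rough before/after sketch of the graph, written in the style of the gtest_input files later in this diff; the names are illustrative pseudocode of the C++ rewrite, not code from this commit.

from mindspore.ops import Primitive
from mindspore.ops import operations as P

apply_momentum = P.ApplyMomentum()
relu = P.ReLU()
# The test inputs in this diff spell this Primitive(Constants.kTupleGetItem);
# the literal name below is only illustrative.
tuple_getitem = Primitive('TupleGetItem')


def before(var, accum, lr, grad, momentum):
    # The optimizer node is consumed directly as a single tensor.
    return relu(apply_momentum(var, accum, lr, grad, momentum))


def after(var, accum, lr, grad, momentum):
    # MomentumUnifyOutput gives the node a kMomentumOutputNum-element tuple
    # abstract and re-routes every user through element 0.
    res = apply_momentum(var, accum, lr, grad, momentum)
    return relu(tuple_getitem(res, 0))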
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class FtrlUnifyOutput : public PatternProcessPass {
+ public:
+  explicit FtrlUnifyOutput(bool multigraph = true) : PatternProcessPass("ftrl_unify_output", multigraph) {}
+  ~FtrlUnifyOutput() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+
+class MomentumUnifyOutput : public PatternProcessPass {
+ public:
+  explicit MomentumUnifyOutput(bool multigraph = true) : PatternProcessPass("momentum_unify_output", multigraph) {}
+  ~MomentumUnifyOutput() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+
+class CenteredRMSPropUnifyOutput : public PatternProcessPass {
+ public:
+  explicit CenteredRMSPropUnifyOutput(bool multigraph = true)
+      : PatternProcessPass("centered_rmsprop_unify_output", multigraph) {}
+  ~CenteredRMSPropUnifyOutput() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+
+class RMSPropUnifyOutput : public PatternProcessPass {
+ public:
+  explicit RMSPropUnifyOutput(bool multigraph = true) : PatternProcessPass("rmsprop_unify_output", multigraph) {}
+  ~RMSPropUnifyOutput() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_
@@ -38,6 +38,7 @@
 #include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h"
 #include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h"
 #include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h"
+#include "backend/optimizer/ascend/mindir/optimizer_unify_output.h"
 #include "backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h"
 #include "backend/optimizer/ascend/mindir/slice_grad_unify_mindir.h"
 #include "runtime/device/kernel_adjust.h"
@@ -217,6 +218,10 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
   unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropInputUnifyMindIR>());
   unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropFilterUnifyMindIR>());
   unify_mindir_pm->AddPass(std::make_shared<opt::SliceGradUnifyMindIR>());
+  unify_mindir_pm->AddPass(std::make_shared<opt::FtrlUnifyOutput>());
+  unify_mindir_pm->AddPass(std::make_shared<opt::MomentumUnifyOutput>());
+  unify_mindir_pm->AddPass(std::make_shared<opt::RMSPropUnifyOutput>());
+  unify_mindir_pm->AddPass(std::make_shared<opt::CenteredRMSPropUnifyOutput>());
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -143,13 +143,13 @@ ATTR_MAP(SparseApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<b
 OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}};
 REG_ADPT_DESC(SparseApplyFtrlD, kNameSparseApplyFtrlD, ADPT_DESC(SparseApplyFtrlD))

-// ApplyFtrlD
-INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
+// ApplyFtrl
+INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
                         {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
                         {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
-ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};
-REG_ADPT_DESC(ApplyFtrlD, kNameApplyFtrl, ADPT_DESC(ApplyFtrlD))
+ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
+REG_ADPT_DESC(ApplyFtrl, kNameApplyFtrl, ADPT_DESC(ApplyFtrl))

 // ApplyRMSPropD
 INPUT_MAP(ApplyRMSPropD) = {
@@ -161,12 +161,11 @@ ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool
 OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}};
 REG_ADPT_DESC(ApplyRMSPropD, kNameApplyRMSProp, ADPT_DESC(ApplyRMSPropD))

-// ApplyCenteredRMSPropD
-INPUT_MAP(ApplyCenteredRMSPropD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
+// ApplyCenteredRMSProp
+INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
                                    {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
                                    {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
-ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyCenteredRMSPropD) = {
-    {0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}};
-REG_ADPT_DESC(ApplyCenteredRMSPropD, kNameApplyCenteredRMSProp, ADPT_DESC(ApplyCenteredRMSPropD))
+ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
+REG_ADPT_DESC(ApplyCenteredRMSProp, kNameApplyCenteredRMSProp, ADPT_DESC(ApplyCenteredRMSProp))
 }  // namespace mindspore::transform
@@ -62,8 +62,8 @@ DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD)
 DECLARE_OP_ADAPTER(LarsV2Update)
 DECLARE_OP_USE_OUTPUT(LarsV2Update)

-DECLARE_OP_ADAPTER(ApplyFtrlD)
-DECLARE_OP_USE_OUTPUT(ApplyFtrlD)
+DECLARE_OP_ADAPTER(ApplyFtrl)
+DECLARE_OP_USE_OUTPUT(ApplyFtrl)

 DECLARE_OP_ADAPTER(SparseApplyFtrlD)
 DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
@@ -72,7 +72,7 @@ DECLARE_OP_ADAPTER(ApplyRMSPropD)
 DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
 DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)

-DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD)
-DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD)
+DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
+DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
 }  // namespace mindspore::transform
 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_TRAINING_OPS_DECLARE_H_
@@ -239,6 +239,7 @@ inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
     std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
 inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
 inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
+inline const PrimitivePtr kPrimApplyFtrl = std::make_shared<Primitive>("ApplyFtrl");
 inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
 inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("Lrn");
 inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
@@ -452,7 +453,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_
 inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
 inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");

-// Other primitve not used by backend but used in core;
+// Other primitive not used by backend but used in core;
 inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
 inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");

@@ -308,7 +308,6 @@ def _concat_grad_uniform(input_shapes, input_nums):
 def get_bprop_concat(self):
     """Generate bprop for Concat"""
     axis = self.axis
-    is_ascend = context.get_context('device_target') == "Ascend"

     def bprop(x, out, dout):
         dx = ()
@@ -318,7 +317,7 @@ def get_bprop_concat(self):
         for i in range(input_nums):
             input_shapes = input_shapes + (shape_op(x[i]),)
         is_uniform = _concat_grad_uniform(input_shapes, input_nums)
-        if is_uniform and is_ascend:
+        if is_uniform:
             dx = P.Split(axis, input_nums)(dout)
         else:
             for i in range(input_nums):
|
||||||
validator.check_value_type('gradient_scale', gradient_scale, [float], self.name)
|
validator.check_value_type('gradient_scale', gradient_scale, [float], self.name)
|
||||||
self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
|
self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
|
||||||
outputs=['output'])
|
outputs=['output'])
|
||||||
self.is_tbe = context.get_context("device_target") == "Ascend"
|
|
||||||
self.is_ge = context.get_context("enable_ge")
|
|
||||||
|
|
||||||
def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
|
def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
|
||||||
if not self.is_ge and self.is_tbe:
|
|
||||||
return v_shape, v_shape
|
|
||||||
return v_shape
|
return v_shape
|
||||||
|
|
||||||
def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
|
def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
|
||||||
|
@ -2429,9 +2425,7 @@ class ApplyMomentum(PrimitiveWithInfer):
|
||||||
validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
|
validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
|
||||||
validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
|
validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
|
||||||
validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
|
validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
|
||||||
if not self.is_ge and self.is_tbe:
|
return v_dtype
|
||||||
return g_dtype, g_dtype
|
|
||||||
return g_dtype
|
|
||||||
|
|
||||||
|
|
||||||
class SmoothL1Loss(PrimitiveWithInfer):
|
class SmoothL1Loss(PrimitiveWithInfer):
|
||||||
|
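After the two hunks above, ApplyMomentum's inference no longer branches on is_tbe/is_ge: the primitive presents one output everywhere, and the Ascend-only tuple form is reintroduced internally by the MomentumUnifyOutput pass. A minimal usage sketch (values arbitrary; assumes a MindSpore install), not part of the commit:

import numpy as np
import mindspore
from mindspore import Parameter, Tensor, nn
from mindspore.ops import operations as P


class MomentumNet(nn.Cell):
    def __init__(self):
        super(MomentumNet, self).__init__()
        self.apply_momentum = P.ApplyMomentum()
        self.var = Parameter(Tensor(np.ones([2, 2]), mindspore.float32), name="var")
        self.accum = Parameter(Tensor(np.zeros([2, 2]), mindspore.float32), name="accum")

    def construct(self, lr, grad, momentum):
        # Single output: the updated var; accum is updated in place.
        return self.apply_momentum(self.var, self.accum, lr, grad, momentum)


net = MomentumNet()
out = net(Tensor(0.01, mindspore.float32),
          Tensor(np.ones([2, 2]), mindspore.float32),
          Tensor(0.9, mindspore.float32))
# One tensor on GPU and Ascend alike; previously Ascend inferred a tuple here.
print(isinstance(out, tuple), out.shape)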
@@ -2763,9 +2757,8 @@ class ApplyRMSProp(PrimitiveWithInfer):
         >>> momentum = 1e-10
         >>> epsilon = 0.001
         >>> output = apply_rms(input_x, mean_square, moment, learning_rate, grad, decay, momentum, epsilon)
-        >>> print(output)
-        (Tensor(shape=[], dtype=Float32, value= 0.100112), Tensor(shape=[], dtype=Float32, value= 4),
-        Tensor(shape=[], dtype=Float32, value= 0.899888))
+        >>> output
+        Tensor(shape=[], dtype=Float32, value= 0.100112)
     """

     @prim_attr_register
@@ -2773,16 +2766,12 @@ class ApplyRMSProp(PrimitiveWithInfer):
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
         self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad',
                                         'rho', 'momentum', 'epsilon'], outputs=['output'])
-        self.is_ge = context.get_context("enable_ge")
-        self.is_d = context.get_context("device_target") == "Ascend"

     def infer_shape(self, var_shape, mean_square_shape, moment_shape, learning_rate_shape, grad_shape, decay_shape,
                     momentum_shape, epsilon_shape):
         validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
-        if not self.is_ge and self.is_d:
-            return var_shape, var_shape, var_shape
         return var_shape

     def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype,
@@ -2795,8 +2784,6 @@ class ApplyRMSProp(PrimitiveWithInfer):
         validator.check_types_same_and_valid(args_decay, valid_dtypes, self.name)
         args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
         validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True)
-        if not self.is_ge and self.is_d:
-            return var_dtype, var_dtype, var_dtype
         return var_dtype

     def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
@@ -2867,22 +2854,15 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
         >>> epsilon = 0.05
         >>> output = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad,
         ... learning_rate, decay, momentum, epsilon)
-        >>> print(output)
-        (Tensor(shape=[2, 2], dtype=Float32, value=
+        >>> output
+        Tensor(shape=[2, 2], dtype=Float32, value=
         [[-2.00000000e+00, -5.02492237e+00],
-         [-8.04984474e+00, -1.10747662e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
-        [[ 0.00000000e+00, 1.00000000e+00],
-         [ 2.00000000e+00, 3.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
-        [[ 0.00000000e+00, 1.00000000e+00],
-         [ 4.00000000e+00, 9.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
-        [[ 0.00000000e+00, 4.02492237e+00],
-         [ 8.04984474e+00, 1.20747662e+01]]))
+         [-8.04984474e+00, -1.10747662e+01]])
     """

     @prim_attr_register
     def __init__(self, use_locking=False):
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
-        self.is_ascend = context.get_context("device_target") == "Ascend"

     def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
                     learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
@@ -2890,8 +2870,6 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
         validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
-        if self.is_ascend:
-            return var_shape, mean_gradient_shape, mean_square_shape, moment_shape
         return var_shape

     def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
@@ -2905,8 +2883,6 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
         validator.check_types_same_and_valid(args_rho, valid_dtypes, self.name)
         args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype}
         validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True)
-        if self.is_ascend:
-            return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype
         return var_dtype


@@ -6176,15 +6152,8 @@ class ApplyFtrl(PrimitiveWithInfer):
             Default: -0.5. It must be a float number or a scalar tensor with float16 or float32 data type.

     Outputs:
-        There are three outputs for Ascend environment.
-
-        - **var** (Tensor) - represents the updated `var`.
-        - **accum** (Tensor) - represents the updated `accum`.
-        - **linear** (Tensor) - represents the updated `linear`.
-
-        There is only one output for GPU environment.
-
-        - **var** (Tensor) - This value is always zero and the input parameters has been updated in-place.
+        - **var** (Tensor) - represents the updated `var`. As the input parameters has been updated in-place, this
+          value is always zero when the platforms is GPU.

     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -6217,26 +6186,10 @@ class ApplyFtrl(PrimitiveWithInfer):
         >>> net = ApplyFtrlNet()
         >>> input_x = Tensor(np.random.randint(-4, 4, (2, 2)), mindspore.float32)
         >>> output = net(input_x)
-        >>> is_tbe = context.get_context("device_target") == "Ascend"
-        >>> if is_tbe:
-        ... print(output)
-        (Tensor(shape=[2, 2], dtype=Float32, value=
+        >>> output
+        Tensor(shape=[2, 2], dtype=Float32, value=
         [[ 4.61418092e-01, 5.30964255e-01],
-         [ 2.68715084e-01, 3.82065028e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
-        [[ 1.64236546e+01, 9.64589405e+00],
-         [ 1.43758726e+00, 9.89177322e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
-        [[-1.86994812e+03, -1.64906018e+03],
-         [-3.22187836e+02, -1.20163989e+03]]))
-        ... else:
-        ... print(net.var.asnumpy())
-        [[0.4614181 0.5309642 ]
-        [0.2687151 0.38206503]]
-        ... print(net.accum.asnumpy())
-        [[16.423655 9.645894 ]
-        [ 1.4375873 9.891773 ]]
-        ... print(net.linear.asnumpy())
-        [[-1869.9479 -1649.0599]
-        [ -322.1879 -1201.6399]]
+         [ 2.68715084e-01, 3.82065028e-01]])
     """

     @prim_attr_register
@@ -6244,14 +6197,11 @@ class ApplyFtrl(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
                                 outputs=['output'])
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
-        self.is_tbe = context.get_context("device_target") == "Ascend"

     def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
                     lr_power_shape):
         validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
         validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
-        if self.is_tbe:
-            return var_shape, var_shape, var_shape
         return var_shape

     def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
@@ -6263,8 +6213,6 @@ class ApplyFtrl(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_types_same({"l1": l1_type}, valid_dtypes, self.name)
         validator.check_scalar_or_tensor_types_same({"l2": l2_type}, valid_dtypes, self.name)
         validator.check_scalar_or_tensor_types_same({"lr_power": lr_power_type}, valid_dtypes, self.name)
-        if self.is_tbe:
-            return var_type, var_type, var_type
         return var_type


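The same single-output contract now holds for ApplyFtrl: callers get back one tensor and read the updated accum/linear from the parameters themselves, as the revised docstring above describes. A minimal sketch along the lines of the docstring's ApplyFtrlNet (values arbitrary; assumes a MindSpore install), not part of the commit:

import numpy as np
import mindspore
from mindspore import Parameter, Tensor, nn
from mindspore.ops import operations as P


class ApplyFtrlNet(nn.Cell):
    def __init__(self):
        super(ApplyFtrlNet, self).__init__()
        self.apply_ftrl = P.ApplyFtrl()
        self.lr = 0.001
        self.l1 = 0.0
        self.l2 = 0.0
        self.lr_power = -0.5
        self.var = Parameter(Tensor(np.ones([2, 2]), mindspore.float32), name="var")
        self.accum = Parameter(Tensor(np.ones([2, 2]), mindspore.float32), name="accum")
        self.linear = Parameter(Tensor(np.zeros([2, 2]), mindspore.float32), name="linear")

    def construct(self, grad):
        return self.apply_ftrl(self.var, self.accum, self.linear, grad,
                               self.lr, self.l1, self.l2, self.lr_power)


net = ApplyFtrlNet()
out = net(Tensor(np.random.rand(2, 2), mindspore.float32))
print(out.shape)                  # (2, 2): one tensor, no per-device branching
print(net.accum.asnumpy().shape)  # updated optimizer state lives in the Parameters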
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -48,7 +48,8 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
     if (!node->isa<CNode>()) {
       return false;
     }
-    return AnfAlgo::GetCNodeName(node->cast<CNodePtr>()) == "ApplyMomentum";
+    auto node_name = AnfAlgo::GetCNodeName(node->cast<CNodePtr>());
+    return node_name == "ApplyMomentum" || node_name == "AssignAdd";
   }
 };

@@ -103,9 +104,9 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond3) {
   get_py_fun_.SetDoResolve(true);
   FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond3", "before");
   ASSERT_TRUE(g != nullptr);
-  std::vector<int64_t> shp_x{1, 64, 112, 112};
+  std::vector<int64_t> shp_x{3, 2};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
-  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
   EXPECT_NE(kg, nullptr);

@@ -1,74 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "common/backend_common_test.h"
-#include "common/py_func_graph_fetcher.h"
-#include "debug/anf_ir_dump.h"
-
-#define private public
-#define protected public
-#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
-#undef private
-#undef protected
-
-namespace mindspore {
-namespace opt {
-class TestHWAddInputToOutput : public BackendCommon {
- public:
-  TestHWAddInputToOutput() : getPyFun_("gtest_input.pre_activate.add_input_to_output_test", true) {}
-  ~TestHWAddInputToOutput() override = default;
-
- public:
-  UT::PyFuncGraphFetcher getPyFun_;
-};
-
-class MockOpFinder : public OpFinder {
- public:
-  MockOpFinder() = default;
-  ~MockOpFinder() override = default;
-  int GetOpRegisteredOutputNum(const std::string &op_name, const CNodePtr &cnode) override { return 2; }
-};
-
-TEST_F(TestHWAddInputToOutput, test_add_input_to_output) {
-  FuncGraphPtr g = getPyFun_.CallAndParseRet("test_add_input_to_output", "before");
-  EXPECT_NE(g, nullptr);
-  std::vector<int64_t> shp{2, 32, 224, 224};
-  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
-  AbstractBasePtrList args_spec_list;
-  for (size_t i = 0; i < 5; ++i) {
-    args_spec_list.push_back(x_abstract);
-  }
-  auto kg = GetKernelGraph(g, args_spec_list);
-  EXPECT_NE(kg, nullptr);
-  auto ret = kg->get_return();
-  EXPECT_NE(ret, nullptr);
-  auto make_tuple = ret->input(1);
-  EXPECT_NE(make_tuple, nullptr);
-  auto momentum = make_tuple->cast<CNodePtr>()->input(1);
-  EXPECT_NE(momentum, nullptr);
-  EXPECT_NE(momentum->abstract(), nullptr);
-  EXPECT_FALSE(momentum->abstract()->isa<abstract::AbstractTuple>());
-
-  auto optimizer = std::make_shared<opt::GraphOptimizer>();
-  auto pm = std::make_shared<opt::PassManager>();
-  auto pass = std::make_shared<opt::AddInputToOutput>();
-  pass->op_finder_ = std::make_shared<MockOpFinder>();
-  pm->AddPass(pass);
-  optimizer->AddPassManager(pm);
-  (void)optimizer->Optimize(kg);
-  EXPECT_TRUE(momentum->abstract()->isa<abstract::AbstractTuple>());
-}
-}  // namespace opt
-}  // namespace mindspore
@@ -1,39 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-from mindspore.ops import operations as P
-
-ApplyMomentum = P.ApplyMomentum()
-
-
-class FnDict:
-    def __init__(self):
-        self.fnDict = {}
-
-    def __call__(self, fn):
-        self.fnDict[fn.__name__] = fn
-
-    def __getitem__(self, name):
-        return self.fnDict[name]
-
-
-def test_add_input_to_output(tag):
-    fns = FnDict()
-
-    @fns
-    def before(input0, input1, input2, input3, input4):
-        return ApplyMomentum(input0, input1, input2, input3, input4)
-
-    return fns[tag]
@@ -22,7 +22,7 @@ broadcast = P.Broadcast(1)
 memcpy_async = Primitive('memcpy_async')
 make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive(Constants.kTupleGetItem)
-apply_momentun = P.ApplyMomentum()
+assign_add = P.AssignAdd()
 control_depend = P.ControlDepend()
 relu = P.ReLU()

@@ -84,14 +84,14 @@ def test_insert_memcpy_async_for_hccl_op_cond3(tag):
     fns = FnDict()

     @fns
-    def before(a, b, c, d, e):
-        res = apply_momentun(a, b, c, d, e)
+    def before(a, b):
+        res = assign_add(a, b)
         res = all_reduce(res)
         return res

     @fns
-    def after(a, b, c, d, e):
-        res = apply_momentun(a, b, c, d, e)
+    def after(a, b):
+        res = assign_add(a, b)
         res = memcpy_async(res)
         res = all_reduce(res)
         return make_tuple(res)
@@ -48,6 +48,6 @@ def test_momentum_lossscale_fusion(tag):

     @fns
     def after(input0, input1, input2, input3, input4):
-        return make_tuple(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant))
+        return make_tuple(tuple_getitem(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant), 0))

     return fns[tag]