forked from mindspore-Ecosystem/mindspore
process cast when activate graph kernel in amp
This commit is contained in:
parent
7dce9f5f4e
commit
13126653ec
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
|
#include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -26,13 +27,15 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node) {
|
bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std::vector<PrimitivePtr> &black_list) {
|
||||||
auto main_primitive = AnfAlgo::GetCNodePrimitive(main);
|
auto main_primitive = AnfAlgo::GetCNodePrimitive(main);
|
||||||
auto node_primitive = AnfAlgo::GetCNodePrimitive(node);
|
auto node_primitive = AnfAlgo::GetCNodePrimitive(node);
|
||||||
if (main_primitive != nullptr && node_primitive != nullptr) {
|
if (main_primitive != nullptr && node_primitive != nullptr) {
|
||||||
// Some ops such as Reshape are not real ops, so CSE on them brings no gain. For ops fusion, keeping these ops
|
// Some ops such as Reshape are not real ops, so CSE on them brings no gain. For ops fusion, keeping these ops
|
||||||
// alone can also prevent some redundant-output cases (input -> reshape -> output).
|
// alone can also prevent some redundant-output cases (input -> reshape -> output).
|
||||||
if (main_primitive->name() != node_primitive->name() || IsPrimitiveCNode(node, prim::kPrimReshape)) {
|
if (main_primitive->name() != node_primitive->name() ||
|
||||||
|
std::any_of(black_list.begin(), black_list.end(),
|
||||||
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,12 +128,12 @@ bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return IsCNodePrimitveEqual(c_main, c_node);
|
return IsCNodePrimitveEqual(c_main, c_node, black_list_);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) {
|
bool GraphKernelCSE::Run(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>();
|
auto graphkernel_backend_cse = std::make_shared<GraphKernelBackendCSE>(black_list_);
|
||||||
return graphkernel_backend_cse->Cse(func_graph, func_graph->manager());
|
return graphkernel_backend_cse->Cse(func_graph, func_graph->manager());
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -13,27 +13,35 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_
|
||||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
#include "backend/optimizer/pass/common_subexpression_elimination.h"
|
#include "backend/optimizer/pass/common_subexpression_elimination.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class GraphKernelCSE : public Pass {
|
class GraphKernelCSE : public Pass {
|
||||||
public:
|
public:
|
||||||
GraphKernelCSE() : Pass("graph_kernel_cse") {}
|
explicit GraphKernelCSE(const std::vector<PrimitivePtr> &black_list = {})
|
||||||
|
: Pass("graph_kernel_cse"), black_list_(black_list) {}
|
||||||
~GraphKernelCSE() override = default;
|
~GraphKernelCSE() override = default;
|
||||||
bool Run(const FuncGraphPtr &func_graph) override;
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<PrimitivePtr> black_list_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class GraphKernelBackendCSE : public BackendCSE {
|
class GraphKernelBackendCSE : public BackendCSE {
|
||||||
public:
|
public:
|
||||||
GraphKernelBackendCSE() = default;
|
explicit GraphKernelBackendCSE(const std::vector<PrimitivePtr> &black_list = {}) : black_list_(black_list) {}
|
||||||
~GraphKernelBackendCSE() override = default;
|
~GraphKernelBackendCSE() override = default;
|
||||||
bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const override;
|
bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const override;
|
||||||
bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const override;
|
bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<PrimitivePtr> black_list_;
|
||||||
};
|
};
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CSE_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CSE_H_
|
||||||
|
|
|
@ -34,7 +34,7 @@ namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) {
|
bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) {
|
||||||
std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape};
|
std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape, prim::kPrimCast};
|
||||||
auto &users = mng->node_users();
|
auto &users = mng->node_users();
|
||||||
return std::any_of(shape_ops.begin(), shape_ops.end(),
|
return std::any_of(shape_ops.begin(), shape_ops.end(),
|
||||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) &&
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) &&
|
||||||
|
|
|
@ -120,7 +120,9 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||||
pm->AddPass(std::make_shared<opt::AdamFusion>());
|
pm->AddPass(std::make_shared<opt::AdamFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
|
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
|
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
|
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||||
|
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
|
||||||
|
}
|
||||||
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
|
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
|
||||||
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
||||||
|
@ -165,15 +167,17 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
|
||||||
}
|
}
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
|
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
|
||||||
|
std::vector<PrimitivePtr> black_list = {prim::kPrimReshape, prim::kPrimCast};
|
||||||
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
|
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
|
||||||
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>());
|
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>());
|
||||||
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
|
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
|
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
|
||||||
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
|
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
|
||||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
|
||||||
pm->AddPass(std::make_shared<opt::TensorPromotion>());
|
pm->AddPass(std::make_shared<opt::TensorPromotion>());
|
||||||
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
|
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
|
||||||
|
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
||||||
// After Simplify and Splitter, a lot of redundant getitem/maketuple
|
// After Simplify and Splitter, a lot of redundant getitem/maketuple
|
||||||
// will be exposed; use the GetitemTuple pass to delete them.
|
// will be exposed; use the GetitemTuple pass to delete them.
|
||||||
pm->AddPass(std::make_shared<opt::GetitemTuple>());
|
pm->AddPass(std::make_shared<opt::GetitemTuple>());
|
||||||
|
|
Loading…
Reference in New Issue