forked from mindspore-Ecosystem/mindspore
!121 Add a checking mechanism for the need of Renormalize pass in Parse pipeline
Merge pull request !121 from thlinh/dev_Apr02_add_watch_for_renormalize
This commit is contained in:
commit
d04a62b700
|
@ -52,7 +52,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
|
||||
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
|
||||
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
||||
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
|
||||
zero_like_fill_zero_ =
|
||||
MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM);
|
||||
|
||||
// ops eliminate
|
||||
item_tuple_eliminate_ =
|
||||
|
@ -81,7 +82,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
|
||||
get_make_ref_eliminate_ =
|
||||
MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
|
||||
replace_refkey_by_param_ = MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>);
|
||||
|
||||
replace_refkey_by_param_ =
|
||||
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
|
||||
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
|
||||
|
||||
// Gradient transforms
|
||||
|
|
|
@ -31,14 +31,14 @@
|
|||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
|
||||
const PrimitivePtr& prim) {
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim,
|
||||
const RenormAction& renorm_action) {
|
||||
auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); };
|
||||
return std::make_shared<Substitution>(transform, name, fn);
|
||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
|
||||
}
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
|
||||
const std::vector<PrimitivePtr>& prims) {
|
||||
const std::vector<PrimitivePtr>& prims, const RenormAction& renorm_action) {
|
||||
auto fn = [prims](const AnfNodePtr& node) -> bool {
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
|
@ -52,12 +52,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::
|
|||
return false;
|
||||
};
|
||||
|
||||
return std::make_shared<Substitution>(transform, name, fn);
|
||||
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
|
||||
}
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
|
||||
const PredicateFuncType& predicate) {
|
||||
return std::make_shared<Substitution>(transform, name, predicate);
|
||||
const PredicateFuncType& predicate, const RenormAction& renorm_action) {
|
||||
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
|
||||
}
|
||||
|
||||
AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const {
|
||||
|
@ -74,6 +74,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode
|
|||
}
|
||||
}
|
||||
#endif
|
||||
if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) {
|
||||
if (renorm_action_ == FORCE_RENORM) {
|
||||
optimizer->add_node_to_renormalize(result);
|
||||
} else {
|
||||
// renorm_action_ is CHECK_RENORM
|
||||
if (result->abstract() == nullptr) {
|
||||
optimizer->add_node_to_renormalize(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -36,24 +36,34 @@ using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
|
|||
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
|
||||
using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>;
|
||||
|
||||
// Define the interaction mode between an Optimize pass and Renormalize pass
|
||||
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
|
||||
// CHECK_RENORM: check if the new node is un-typed to decide if the next Renormalize will be executted
|
||||
enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
|
||||
|
||||
class Substitution {
|
||||
public:
|
||||
TransformFuncType transform_{nullptr};
|
||||
std::string name_;
|
||||
PredicateFuncType predicate_{nullptr};
|
||||
explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate)
|
||||
: transform_(transform), name_(name), predicate_(predicate) {}
|
||||
// an enum to mark this Substitution relation to renormalize pass
|
||||
RenormAction renorm_action_;
|
||||
explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate,
|
||||
const RenormAction &renorm_action)
|
||||
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
|
||||
~Substitution() = default;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const;
|
||||
};
|
||||
|
||||
using SubstitutionPtr = std::shared_ptr<Substitution>;
|
||||
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim);
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
|
||||
const RenormAction &action_renorm = CHECK_RENORM);
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
|
||||
const std::vector<PrimitivePtr> &prims);
|
||||
const std::vector<PrimitivePtr> &prims,
|
||||
const RenormAction &action_renorm = CHECK_RENORM);
|
||||
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
|
||||
const PredicateFuncType &predicate);
|
||||
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
|
||||
|
||||
class SubstitutionList {
|
||||
public:
|
||||
|
|
|
@ -87,11 +87,12 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
|
|||
class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
||||
public:
|
||||
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
|
||||
: name_(name), resource_(resource_ptr), run_only_once_(false) {}
|
||||
: name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {}
|
||||
virtual ~Optimizer() = default;
|
||||
|
||||
void Init(const OptPassGroupMap &passes, bool run_only_once) {
|
||||
run_only_once_ = run_only_once;
|
||||
is_watch_renormalize_ = false;
|
||||
|
||||
for (auto &iter : passes) {
|
||||
const std::string &name = iter.first;
|
||||
|
@ -118,9 +119,13 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
}
|
||||
|
||||
static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr,
|
||||
const OptPassGroupMap &passes, bool run_only_once = false) {
|
||||
const OptPassGroupMap &passes, bool run_only_once = false,
|
||||
bool watch_renormalize = false) {
|
||||
OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr);
|
||||
optimizer->Init(passes, run_only_once);
|
||||
if (watch_renormalize) {
|
||||
optimizer->enable_watch_renormalize();
|
||||
}
|
||||
return optimizer;
|
||||
}
|
||||
|
||||
|
@ -138,7 +143,16 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
if (opt.is_renormalize()) {
|
||||
auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
|
||||
if (resource_ptr != nullptr) {
|
||||
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
|
||||
if (is_watch_renormalize_) {
|
||||
if (untyped_nodes_.size() > 0) {
|
||||
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
|
||||
clear_untyped_nodes();
|
||||
} else {
|
||||
MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty.";
|
||||
}
|
||||
} else {
|
||||
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
|
||||
}
|
||||
}
|
||||
} else if (opt(func_graph, shared_from_this())) {
|
||||
changes = true;
|
||||
|
@ -180,12 +194,26 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
|
||||
const std::string name() const { return name_; }
|
||||
|
||||
void add_node_to_renormalize(AnfNodePtr anode) {
|
||||
if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) {
|
||||
untyped_nodes_.push_back(anode);
|
||||
}
|
||||
}
|
||||
|
||||
void clear_untyped_nodes() { untyped_nodes_.clear(); }
|
||||
|
||||
void enable_watch_renormalize() { is_watch_renormalize_ = true; }
|
||||
void disable_watch_renormalize() { is_watch_renormalize_ = false; }
|
||||
bool is_watch_renormalize() { return is_watch_renormalize_; }
|
||||
|
||||
private:
|
||||
const std::string name_;
|
||||
pipeline::ResourceBasePtr resource_;
|
||||
std::vector<OptPass> passes_;
|
||||
std::vector<std::string> pass_names_;
|
||||
bool run_only_once_;
|
||||
std::vector<AnfNodePtr> untyped_nodes_;
|
||||
bool is_watch_renormalize_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -185,8 +185,8 @@ void InitOpt(const ResourcePtr& res) {
|
|||
if (g_pass_opts.size() == 0) {
|
||||
opt::irpass::OptimizeIRPassLib irpass;
|
||||
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
|
||||
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass));
|
||||
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass));
|
||||
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
|
||||
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
|
||||
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue