forked from OSSInnovation/mindspore
!4126 Add new parameter
Merge pull request !4126 from BowenK/new_parameter
This commit is contained in:
commit
8d693306f4
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "frontend/optimizer/pass_group.h"
|
||||
#include "frontend/optimizer/py_pass_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -35,14 +36,15 @@ bool PassGroup::DeletePass(const std::string &pass_name) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const {
|
||||
bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes,
|
||||
const MatchResultPtr &res) const {
|
||||
if (func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
bool changed = false;
|
||||
for (const auto &pass : passes) {
|
||||
if (pass != nullptr) {
|
||||
if (pass->Run(func_graph)) {
|
||||
if (pass->Run(func_graph, res)) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
@ -54,8 +56,9 @@ bool PassGroup::Run(const FuncGraphPtr &func_graph) const {
|
|||
bool changed = false;
|
||||
// run all passes
|
||||
bool change = true;
|
||||
auto res = PyPassManager::GetInstance()->GetMatchResult();
|
||||
while (change) {
|
||||
change = Run(func_graph, passes_);
|
||||
change = Run(func_graph, passes_, res);
|
||||
changed = change || changed;
|
||||
if (run_only_once_) {
|
||||
break;
|
||||
|
|
|
@ -41,12 +41,14 @@ class PassGroup {
|
|||
// @return false, graph not changed
|
||||
bool Run(const FuncGraphPtr &func_graph) const;
|
||||
// Run the given graph passes on the input graph
|
||||
// @param [inout] graph The graph to be optimized
|
||||
// @param [inout] func_graph The graph to be optimized
|
||||
// @param [in] passes The given graph passes
|
||||
// @param [inout] res MatchResult used to collect all matched patterns and nodes
|
||||
// @return true, graph changed
|
||||
// @return false, graph not changed
|
||||
bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const;
|
||||
bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes, const MatchResultPtr &res) const;
|
||||
std::string name() const { return name_; }
|
||||
void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; }
|
||||
|
||||
private:
|
||||
const std::string name_;
|
||||
|
|
|
@ -96,6 +96,7 @@ MatchResultPtr IsIn::match(const AnfNodePtr &node) {
|
|||
for (auto &iter : patterns_) {
|
||||
auto res = iter->match(node);
|
||||
if (res != nullptr) {
|
||||
res->add_entry(shared_from_base<IsIn>(), node);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
@ -151,6 +152,9 @@ REGISTER_PYBIND_DEFINE(
|
|||
(void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>());
|
||||
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
|
||||
.def(py::init<tensor::TensorPtr>());
|
||||
(void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
|
||||
.def(py::init<string, tensor::TensorPtr, bool, bool, bool>());
|
||||
(void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>());
|
||||
}));
|
||||
} // namespace python_pass
|
||||
} // namespace opt
|
||||
|
|
|
@ -42,6 +42,10 @@ class CallWith;
|
|||
using CallWithPtr = std::shared_ptr<CallWith>;
|
||||
class NewTensor;
|
||||
using NewTensorPtr = std::shared_ptr<NewTensor>;
|
||||
class NewParameter;
|
||||
using NewParameterPtr = std::shared_ptr<NewParameter>;
|
||||
class Imm;
|
||||
using ImmPtr = std::shared_ptr<Imm>;
|
||||
struct PatternHasher;
|
||||
struct PatternEqual;
|
||||
using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>;
|
||||
|
@ -55,6 +59,7 @@ class Pattern : public Base {
|
|||
string unique_name() const { return unique_name_; }
|
||||
vector<PatternPtr> inputs() { return inputs_; }
|
||||
bool should_replace() { return should_replace_; }
|
||||
void set_should_replace(bool should_replace) { should_replace_ = should_replace; }
|
||||
virtual void reset() {}
|
||||
|
||||
protected:
|
||||
|
@ -86,14 +91,14 @@ class IsPrimTypeOf : public Pattern {
|
|||
~IsPrimTypeOf() = default;
|
||||
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
|
||||
: primitives_(prims), name_(name), matched_prim_(nullptr) {
|
||||
unique_name_ = std::to_string(g_id_++) + "_" + name;
|
||||
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
|
||||
should_replace_ = should_replace;
|
||||
if (!should_replace) {
|
||||
matched_prim_ = prims[0];
|
||||
}
|
||||
}
|
||||
IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
|
||||
unique_name_ = std::to_string(g_id_++) + "_" + name;
|
||||
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
|
||||
// Make primitives_
|
||||
for (auto &iter : types) {
|
||||
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
|
||||
|
@ -126,19 +131,20 @@ class CallWith : public Pattern {
|
|||
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
|
||||
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
|
||||
prim_pattern_ = prim_pattern;
|
||||
unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name();
|
||||
unique_name_ = std::to_string(g_id_++) + "CallWithPattern_" + prim_pattern->unique_name();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
// NOTE: should_replace_ is overrided by it prim_pattern(if exists) silently.
|
||||
should_replace_ = prim_pattern->should_replace();
|
||||
}
|
||||
CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
|
||||
prim_ = prim;
|
||||
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
|
||||
unique_name_ = std::to_string(g_id_++) + "CallWithPrim_" + prim_->ToString();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
}
|
||||
CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
|
||||
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
|
||||
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
|
||||
unique_name_ = std::to_string(g_id_++) + "CallWithStr_" + prim_->ToString();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
}
|
||||
|
@ -159,7 +165,7 @@ class IsIn : public Pattern {
|
|||
IsIn() { unique_name_ = std::to_string(g_id_++); }
|
||||
~IsIn() = default;
|
||||
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||
unique_name_ = std::to_string(g_id_++);
|
||||
unique_name_ = std::to_string(g_id_++) + "IsIn";
|
||||
for (auto &iter : patterns) {
|
||||
unique_name_ = unique_name_ + "_" + iter->unique_name();
|
||||
}
|
||||
|
@ -176,9 +182,9 @@ class IsNot : public Pattern {
|
|||
IsNot() { unique_name_ = std::to_string(g_id_++); }
|
||||
~IsNot() = default;
|
||||
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||
unique_name_ = std::to_string(g_id_++);
|
||||
unique_name_ = std::to_string(g_id_++) + "IsNot";
|
||||
for (auto &iter : patterns) {
|
||||
unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name();
|
||||
unique_name_ = unique_name_ + "_" + iter->unique_name();
|
||||
}
|
||||
}
|
||||
MS_DECLARE_PARENT(IsNot, Pattern);
|
||||
|
@ -200,7 +206,10 @@ class NewTensor : public Pattern {
|
|||
public:
|
||||
NewTensor() { unique_name_ = std::to_string(g_id_++); }
|
||||
~NewTensor() = default;
|
||||
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; }
|
||||
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) {
|
||||
should_replace_ = false;
|
||||
unique_name_ = std::to_string(g_id_++) + "NewTensor";
|
||||
}
|
||||
MS_DECLARE_PARENT(NewTensor, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override {
|
||||
MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n";
|
||||
|
@ -211,6 +220,54 @@ class NewTensor : public Pattern {
|
|||
tensor::TensorPtr input_tensor_;
|
||||
};
|
||||
|
||||
class NewParameter : public Pattern {
|
||||
public:
|
||||
NewParameter() { unique_name_ = std::to_string(g_id_++); }
|
||||
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel,
|
||||
bool should_replace)
|
||||
: para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
|
||||
should_replace_ = should_replace;
|
||||
unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
|
||||
// clone input tensor
|
||||
default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
|
||||
built_ = false;
|
||||
}
|
||||
MS_DECLARE_PARENT(NewParameter, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override {
|
||||
MS_LOG(EXCEPTION) << "Find NewParameter in pattern, NewParameter should only appear in the target.\n";
|
||||
}
|
||||
string para_name() { return para_name_; }
|
||||
tensor::TensorPtr default_tensor() { return default_tensor_; }
|
||||
bool requires_grad() { return requires_grad_; }
|
||||
bool layerwise_parallel() { return layerwise_parallel_; }
|
||||
bool built() { return built_; }
|
||||
void set_built(bool built) { built_ = built; }
|
||||
void reset() override { built_ = false; }
|
||||
|
||||
private:
|
||||
string para_name_;
|
||||
bool requires_grad_;
|
||||
bool layerwise_parallel_;
|
||||
bool built_;
|
||||
tensor::TensorPtr default_tensor_;
|
||||
};
|
||||
|
||||
class Imm : public Pattern {
|
||||
public:
|
||||
Imm() { unique_name_ = std::to_string(g_id_++); }
|
||||
explicit Imm(int value) : value_(value) {
|
||||
should_replace_ = false;
|
||||
unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value);
|
||||
}
|
||||
MS_DECLARE_PARENT(Imm, Pattern);
|
||||
// NOTE: Doesn't support Imm in src pattern currently.
|
||||
MatchResultPtr match(const AnfNodePtr &node) override { return nullptr; }
|
||||
int value() { return value_; }
|
||||
|
||||
private:
|
||||
int value_;
|
||||
};
|
||||
|
||||
class MatchResult {
|
||||
public:
|
||||
MatchResult() {}
|
||||
|
|
|
@ -21,13 +21,26 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "ir/manager.h"
|
||||
#include "pybind_api/ir/primitive_py.h"
|
||||
#include "ir/scalar.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "pipeline/jit/parse/parse_base.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "frontend/optimizer/py_pass_manager.h"
|
||||
#include "utils/info.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/draw.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace python_pass {
|
||||
namespace internal {
|
||||
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res);
|
||||
const char PARAMETER_MODULE[] = "mindspore.common.parameter";
|
||||
const char PARAMETER_CLASS[] = "Parameter";
|
||||
const char SET_PARAM[] = "__setattr__";
|
||||
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph);
|
||||
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res);
|
||||
void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input,
|
||||
bool requires_grad, bool layerwise_parallel);
|
||||
|
||||
std::string GetNodeRepr(AnfNodePtr node) {
|
||||
if (node != nullptr) {
|
||||
|
@ -42,8 +55,10 @@ std::string GetNodeRepr(AnfNodePtr node) {
|
|||
repr += ")";
|
||||
return repr;
|
||||
}
|
||||
if (node->isa<ValueNode>()) {
|
||||
return GetValueNode(node)->ToString();
|
||||
if (node->isa<Parameter>()) {
|
||||
return "[Parameter]" + node->ToString();
|
||||
} else if (node->isa<ValueNode>()) {
|
||||
return "[Value]" + GetValueNode(node)->ToString();
|
||||
}
|
||||
return node->ToString();
|
||||
}
|
||||
|
@ -82,7 +97,7 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
|
|||
return std::make_shared<ValueNode>(input_tensor);
|
||||
}
|
||||
|
||||
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res) {
|
||||
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
|
||||
auto call_with_pattern = pattern->cast<CallWithPtr>();
|
||||
MS_EXCEPTION_IF_NULL(call_with_pattern);
|
||||
auto prim = call_with_pattern->prim_value();
|
||||
|
@ -91,15 +106,70 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
|
|||
}
|
||||
auto prim_pattern = call_with_pattern->prim_pattern();
|
||||
MS_EXCEPTION_IF_NULL(prim_pattern);
|
||||
return ProcessSinglePattern(prim_pattern, res);
|
||||
return ProcessSinglePattern(prim_pattern, res, fg);
|
||||
}
|
||||
|
||||
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res) {
|
||||
AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
|
||||
auto new_para_pattern = pattern->cast<NewParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_para_pattern);
|
||||
if (!new_para_pattern->built()) {
|
||||
static int parameter_id = 0;
|
||||
auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++);
|
||||
auto para_node = std::make_shared<Parameter>(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(para_node);
|
||||
para_node->set_name(para_name);
|
||||
// Set function graph
|
||||
para_node->set_func_graph(func_graph);
|
||||
// Set Debug Info
|
||||
auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
|
||||
para_node->set_debug_info(debug_info);
|
||||
// Set abstract
|
||||
auto default_value = new_para_pattern->default_tensor();
|
||||
MS_EXCEPTION_IF_NULL(default_value);
|
||||
para_node->set_abstract(default_value->ToAbstract()->Broaden());
|
||||
res->add_entry(pattern, para_node);
|
||||
func_graph->add_parameter(para_node);
|
||||
// Reflect back to Cell._params
|
||||
internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
|
||||
new_para_pattern->layerwise_parallel());
|
||||
MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
|
||||
new_para_pattern->set_built(true);
|
||||
return para_node;
|
||||
} else {
|
||||
// Built, fetch the node
|
||||
auto para_node = res->get_node(pattern);
|
||||
MS_EXCEPTION_IF_NULL(para_node);
|
||||
return para_node;
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
|
||||
auto imm_pattern = pattern->cast<ImmPtr>();
|
||||
MS_EXCEPTION_IF_NULL(imm_pattern);
|
||||
auto value = imm_pattern->value();
|
||||
auto scalar_value_ptr = std::make_shared<Int32Imm>(value);
|
||||
return std::make_shared<ValueNode>(scalar_value_ptr);
|
||||
}
|
||||
|
||||
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
|
||||
if (pattern->should_replace()) {
|
||||
// Find replacement in the MatchResult
|
||||
auto target_node = res->get_node(pattern);
|
||||
if (target_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find target node in pattern match result, pattern: " + pattern->unique_name() + "\n";
|
||||
// If it's base pattern(in contrast to complex pattern like CallWith/IsIn/IsNot), raise runtime exception.
|
||||
if (pattern->isa<IsPrimTypeOf>() || pattern->isa<NewTensor>() || pattern->isa<NewParameter>()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find target node, pattern: " + pattern->unique_name() + "\n";
|
||||
return nullptr;
|
||||
}
|
||||
// Try to build this pattern and add to MatchResult, since this pattern is defined inside target
|
||||
auto new_node = BuildTarget(pattern, func_graph, res);
|
||||
if (new_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Try to build pattern node but FAILED. pattern: " + pattern->unique_name() + "\n";
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
if (pattern->isa<NewParameter>()) {
|
||||
return target_node;
|
||||
}
|
||||
return target_node;
|
||||
}
|
||||
|
@ -109,7 +179,19 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
|
|||
} else if (pattern->isa<NewTensor>()) {
|
||||
return BuildNewTensor(pattern, res);
|
||||
} else if (pattern->isa<CallWith>()) {
|
||||
return BuildPrimitiveValueNode(pattern, res);
|
||||
return BuildPrimitiveValueNode(pattern, res, func_graph);
|
||||
} else if (pattern->isa<NewParameter>()) {
|
||||
return BuildNewParameter(pattern, res, func_graph);
|
||||
} else if (pattern->isa<Imm>()) {
|
||||
return BuildImmNode(pattern, res);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
|
||||
const FuncGraphPtr &func_graph) {
|
||||
if (pattern->isa<CallWith>()) {
|
||||
return BuildPrimitiveValueNode(pattern, res, func_graph);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -117,91 +199,154 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
|
|||
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
|
||||
auto target_inputs = pattern->inputs();
|
||||
if (target_inputs.size() == 0) {
|
||||
return ProcessSinglePattern(pattern, res);
|
||||
auto new_node = ProcessSinglePattern(pattern, res, func_graph);
|
||||
if (new_node != nullptr) {
|
||||
res->add_entry(pattern, new_node);
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
// Build up the AnfNode in a recursive manner
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
auto prim_value_node = ProcessSinglePattern(pattern, res);
|
||||
auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(prim_value_node);
|
||||
new_inputs.push_back(prim_value_node);
|
||||
for (auto &iter : target_inputs) {
|
||||
if (iter == pattern) {
|
||||
MS_LOG(EXCEPTION) << "Circle references: Pattern takes itself as input. Got pattern: " + pattern->unique_name() +
|
||||
"\n";
|
||||
MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n";
|
||||
}
|
||||
new_inputs.push_back(BuildTarget(iter, func_graph, res));
|
||||
auto input_node = BuildTarget(iter, func_graph, res);
|
||||
if (input_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n";
|
||||
}
|
||||
new_inputs.push_back(input_node);
|
||||
}
|
||||
return func_graph->NewCNode(new_inputs);
|
||||
auto new_node = func_graph->NewCNode(new_inputs);
|
||||
res->add_entry(pattern, new_node);
|
||||
return new_node;
|
||||
}
|
||||
|
||||
void DrawNode(string name, AnfNodePtr node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
auto new_func_graph = std::make_shared<FuncGraph>();
|
||||
new_func_graph->set_output(node, true);
|
||||
if (save_graphs) {
|
||||
auto ir_dump_path = save_graphs_path + "/" + name + ".ir";
|
||||
auto dot_dump_path = save_graphs_path + "/" + name + ".dot";
|
||||
DumpIR(ir_dump_path, new_func_graph);
|
||||
draw::Draw(dot_dump_path, new_func_graph);
|
||||
}
|
||||
}
|
||||
|
||||
void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input,
|
||||
bool requires_grad, bool layerwise_parallel) {
|
||||
// 1. Get current cell object
|
||||
auto ppm = opt::python_pass::PyPassManager::GetInstance();
|
||||
auto resource = ppm->GetResource();
|
||||
py::object top_cell = resource->input();
|
||||
if (py::isinstance<py::none>(top_cell)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get top cell from resource.";
|
||||
}
|
||||
// 2. New a Parameter object with the above-specified args
|
||||
py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS);
|
||||
py::object new_parameter = parameter_class(default_input, param_name, requires_grad, layerwise_parallel);
|
||||
// 3. Add the new python Parameter object to Cell's _params atttributes
|
||||
top_cell.attr(SET_PARAM)(param_name, new_parameter);
|
||||
// 4. Set default_param for param_node
|
||||
ValuePtr param_value = nullptr;
|
||||
bool converted = parse::ConvertData(new_parameter, ¶m_value, false);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Failed to convert new parameter to ValuePtr.";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
auto param_node = param->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
param_node->set_default_param(param_value);
|
||||
}
|
||||
|
||||
void Reset(PatternPtr pattern) {
|
||||
if (pattern->isa<IsPrimTypeOf>()) {
|
||||
auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
|
||||
prim_pattern->reset();
|
||||
return;
|
||||
} else if (pattern->isa<NewParameter>()) {
|
||||
auto new_param_pattern = pattern->cast<NewParameterPtr>();
|
||||
new_param_pattern->reset();
|
||||
return;
|
||||
} else if (pattern->isa<CallWith>()) {
|
||||
auto call_with_pattern = pattern->cast<CallWithPtr>();
|
||||
for (auto sub_pattern : call_with_pattern->inputs()) {
|
||||
Reset(sub_pattern);
|
||||
}
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
||||
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(src_pattern_);
|
||||
MS_EXCEPTION_IF_NULL(dst_pattern_);
|
||||
auto res = src_pattern_->match(node);
|
||||
if (res != nullptr) {
|
||||
res->dump();
|
||||
MS_LOG(WARNING) << "Matched pattern: " + src_pattern_->unique_name();
|
||||
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) {
|
||||
auto match_res = src_pattern_->match(node);
|
||||
if (match_res != nullptr) {
|
||||
MS_LOG(DEBUG) << "Matched pattern: " + src_pattern_->unique_name() + " node : " + internal::GetNodeRepr(node);
|
||||
res->merge(match_res);
|
||||
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res);
|
||||
dst_pattern_->reset();
|
||||
MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
|
||||
internal::Reset(dst_pattern());
|
||||
MS_LOG(WARNING) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
|
||||
return new_node;
|
||||
}
|
||||
src_pattern_->reset();
|
||||
internal::Reset(src_pattern());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool PythonPass::Run(const FuncGraphPtr &func_graph) {
|
||||
bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(dst_pattern_);
|
||||
if (src_pattern_ == nullptr) {
|
||||
// Add NewParameter
|
||||
auto new_para_pattern = dst_pattern_->cast<NewParameterPtr>();
|
||||
if (new_para_pattern == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null.";
|
||||
}
|
||||
auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name();
|
||||
MS_LOG(DEBUG) << "Adding New parameter : " + para_name;
|
||||
auto para_node = std::make_shared<Parameter>(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(para_node);
|
||||
para_node->set_name(para_name);
|
||||
// Set function graph
|
||||
para_node->set_func_graph(func_graph);
|
||||
// Set Debug Info
|
||||
auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
|
||||
para_node->set_debug_info(debug_info);
|
||||
// Set abstract
|
||||
auto default_value = new_para_pattern->default_tensor();
|
||||
MS_EXCEPTION_IF_NULL(default_value);
|
||||
para_node->set_abstract(default_value->ToAbstract()->Broaden());
|
||||
res->add_entry(dst_pattern_, para_node);
|
||||
func_graph->add_parameter(para_node);
|
||||
// Reflect back to Cell._params
|
||||
internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
|
||||
new_para_pattern->layerwise_parallel());
|
||||
MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
|
||||
return true;
|
||||
}
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->AddFuncGraph(func_graph);
|
||||
auto seen = NewSeenGeneration();
|
||||
// 1024 is for the initial capacity of deque
|
||||
std::deque<AnfNodePtr> todo(1024);
|
||||
todo.push_back(func_graph->output());
|
||||
auto graph_nodes_sorted = TopoSort(func_graph->output());
|
||||
bool changes = false;
|
||||
|
||||
auto &all_nodes = manager->all_nodes();
|
||||
while (!todo.empty()) {
|
||||
AnfNodePtr node = todo.front();
|
||||
todo.pop_front();
|
||||
// Check whether this node has been matched.
|
||||
if (node == nullptr || node->seen_ == seen || !internal::IsTraversable(node) || !all_nodes.contains(node)) {
|
||||
continue;
|
||||
}
|
||||
node->seen_ = seen;
|
||||
// Select nodes that this transform can be applied.
|
||||
AnfNodePtr new_node = Run(func_graph, node);
|
||||
bool change = (new_node != nullptr);
|
||||
// Traverse once
|
||||
for (auto &node : graph_nodes_sorted) {
|
||||
AnfNodePtr new_node = Run(func_graph, node, res);
|
||||
if (new_node != nullptr && new_node != node) {
|
||||
internal::DrawNode(dst_pattern_->unique_name(), new_node);
|
||||
(void)manager->Replace(node, new_node);
|
||||
} else if (new_node == nullptr) {
|
||||
new_node = node;
|
||||
}
|
||||
if (run_only_once_) {
|
||||
return change;
|
||||
}
|
||||
// Find success, and add them to todo list
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
|
||||
}
|
||||
auto &node_users = manager->node_users();
|
||||
if (change && node_users.find(node) != node_users.end()) {
|
||||
for (auto &use : node_users[node]) {
|
||||
auto use_node = use.first;
|
||||
if (use_node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
todo.push_back(use_node);
|
||||
if (use_node->seen_ == seen) {
|
||||
use_node->seen_--;
|
||||
}
|
||||
}
|
||||
changes = true;
|
||||
}
|
||||
}
|
||||
return changes;
|
||||
|
|
|
@ -34,20 +34,20 @@ using NodeEquivPtr = std::shared_ptr<NodeEquiv>;
|
|||
|
||||
class PythonPass {
|
||||
public:
|
||||
explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false,
|
||||
bool multigraph = true)
|
||||
: src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {}
|
||||
explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false)
|
||||
: src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once) {}
|
||||
~PythonPass() = default;
|
||||
bool Run(const FuncGraphPtr &func_graph);
|
||||
bool Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res);
|
||||
std::string name() const { return name_; }
|
||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
|
||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res);
|
||||
PatternPtr src_pattern() { return src_pattern_; }
|
||||
PatternPtr dst_pattern() { return dst_pattern_; }
|
||||
|
||||
private:
|
||||
PatternPtr src_pattern_;
|
||||
PatternPtr dst_pattern_;
|
||||
const std::string name_;
|
||||
bool run_only_once_;
|
||||
bool multigraph_ = true;
|
||||
};
|
||||
|
||||
using PythonPassPtr = std::shared_ptr<PythonPass>;
|
||||
|
|
|
@ -45,14 +45,19 @@ PyPassManagerPtr PyPassManager::GetInstance() {
|
|||
PyPassManager::PyPassManager() {
|
||||
phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>();
|
||||
phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
|
||||
res_ = std::make_shared<MatchResult>();
|
||||
}
|
||||
|
||||
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||
Phase phase, bool run_only_once, bool multigraph) {
|
||||
auto cur_pm = GetPassGroup(phase);
|
||||
MS_EXCEPTION_IF_NULL(cur_pm);
|
||||
PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once, multigraph);
|
||||
cur_pm->AddPass(new_pass);
|
||||
Phase phase, bool run_only_once) {
|
||||
auto cur_pg = GetPassGroup(phase);
|
||||
MS_EXCEPTION_IF_NULL(cur_pg);
|
||||
cur_pg->SetRunOnlyOnce(run_only_once);
|
||||
MS_EXCEPTION_IF_NULL(pattern);
|
||||
MS_EXCEPTION_IF_NULL(target);
|
||||
MS_EXCEPTION_IF_NULL(cur_pg);
|
||||
PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once);
|
||||
cur_pg->AddPass(new_pass);
|
||||
}
|
||||
|
||||
void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
|
||||
|
@ -63,6 +68,21 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
|
|||
}
|
||||
}
|
||||
|
||||
void PyPassManager::GenNewParameter(const PatternPtr ¶meter) {
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
// Add new parameter after resolve
|
||||
// NOTE: Add NewParameter at early stage will cause CSE problems
|
||||
auto cur_pg = GetPassGroup(Phase::OPT);
|
||||
MS_EXCEPTION_IF_NULL(cur_pg);
|
||||
cur_pg->SetRunOnlyOnce(true);
|
||||
auto new_para_pattern = parameter->cast<NewParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_para_pattern);
|
||||
auto pass_name = new_para_pattern->para_name();
|
||||
parameter->set_should_replace(false);
|
||||
auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
|
||||
cur_pg->AddPass(new_pass);
|
||||
}
|
||||
|
||||
void PyPassManager::ClearRes() {
|
||||
MS_LOG(INFO) << "Clear PyPassManager resources!";
|
||||
global_instance = nullptr;
|
||||
|
@ -75,7 +95,9 @@ REGISTER_PYBIND_DEFINE(
|
|||
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
|
||||
.def(py::init([]() { return PyPassManager::GetInstance(); }))
|
||||
.def("registe", &PyPassManager::Registe, "Registe python pass")
|
||||
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass");
|
||||
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass")
|
||||
.def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
|
||||
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph");
|
||||
}));
|
||||
} // namespace python_pass
|
||||
} // namespace opt
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "ir/graph_utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "frontend/optimizer/pattern.h"
|
||||
#include "frontend/optimizer/py_pass.h"
|
||||
#include "frontend/optimizer/pass_group.h"
|
||||
|
@ -53,12 +53,21 @@ class PyPassManager {
|
|||
static PyPassManagerPtr GetInstance();
|
||||
virtual ~PyPassManager() = default;
|
||||
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||
Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true);
|
||||
Phase phase = Phase::RESOLVE, bool run_only_once = false);
|
||||
void Unregiste(const std::string &pass_name, Phase phase);
|
||||
void GenNewParameter(const PatternPtr ¶meter);
|
||||
PassGroupPtr GetPassGroup(Phase phase);
|
||||
void ClearRes();
|
||||
MatchResultPtr GetMatchResult() { return res_; }
|
||||
void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
|
||||
bool ShouldRenorm() { return should_renorm_; }
|
||||
void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
|
||||
pipeline::ResourcePtr GetResource() { return resource_; }
|
||||
|
||||
private:
|
||||
bool should_renorm_ = true;
|
||||
MatchResultPtr res_;
|
||||
pipeline::ResourcePtr resource_;
|
||||
static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;
|
||||
};
|
||||
} // namespace python_pass
|
||||
|
|
|
@ -448,8 +448,21 @@ void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
|
|||
MS_EXCEPTION_IF_NULL(res->manager());
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
auto ppm = opt::python_pass::PyPassManager::GetInstance();
|
||||
ppm->SetResource(res);
|
||||
if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) {
|
||||
MS_LOG(DEBUG) << "No match.\n";
|
||||
} else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
|
||||
MS_LOG(DEBUG) << "Entered PyStub Renorm";
|
||||
// Renomalize
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
|
||||
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
|
||||
res->set_func_graph(new_fg);
|
||||
res->set_args_spec(args_spec);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -477,6 +490,7 @@ static std::vector<ActionItem> CommonPipeline() {
|
|||
}
|
||||
// Add resolve-stage python pass stub
|
||||
actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub));
|
||||
|
||||
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
|
||||
// Evaluate type and shape, and specialize
|
||||
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
|
||||
|
|
|
@ -1,81 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Python pass register"""
|
||||
from inspect import isfunction
|
||||
from mindspore.common.graph_pattern import Pattern
|
||||
from mindspore._c_expression import PyPassManager_
|
||||
from mindspore._c_expression import phase
|
||||
|
||||
class PyPassManager(PyPassManager_):
|
||||
r"""
|
||||
Used to registe and unregiste python passes which can be used to alter graphs.
|
||||
|
||||
Args:
|
||||
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
|
||||
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
|
||||
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument has invalid type.
|
||||
"""
|
||||
def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
|
||||
if not isinstance(pipeline_phase, phase):
|
||||
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||
if not isinstance(run_only_once, bool):
|
||||
raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}")
|
||||
if not isinstance(multi_graph, bool):
|
||||
raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}")
|
||||
PyPassManager_.__init__(self)
|
||||
self.phase_ = pipeline_phase
|
||||
self.run_only_once_ = run_only_once
|
||||
self.multi_graph_ = multi_graph
|
||||
|
||||
def registe(self, py_pass):
|
||||
if not isfunction(py_pass):
|
||||
raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}")
|
||||
pattern, target = py_pass()
|
||||
pass_name = py_pass.__name__
|
||||
if not isinstance(pattern, Pattern):
|
||||
raise TypeError(f"Expecting pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
||||
if not isinstance(target, Pattern):
|
||||
raise TypeError(f"Expecting target of Pattern type, got : ({type(target)}){target}")
|
||||
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_)
|
||||
|
||||
def unregiste(self, py_pass, pipeline_phase=phase.opt):
|
||||
if not isinstance(pipeline_phase, phase):
|
||||
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||
if isinstance(py_pass, str):
|
||||
super().unregiste(py_pass, pipeline_phase)
|
||||
return
|
||||
if isfunction(py_pass):
|
||||
super().unregiste(py_pass.__name__, pipeline_phase)
|
||||
return
|
||||
raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
||||
|
||||
def __call__(self, py_pass):
|
||||
self.registe(py_pass)
|
||||
return py_pass
|
||||
|
||||
def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
|
||||
"""
|
||||
Examples:
|
||||
>>> @registe_pass()
|
||||
>>> def toy_pass():
|
||||
>>> def pattern():
|
||||
>>> pass
|
||||
>>> def target():
|
||||
>>> pass
|
||||
"""
|
||||
return PyPassManager(pipeline_phase, run_only_once, multi_graph)
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Top-level reference to python pass."""
|
|
@ -15,7 +15,8 @@
|
|||
"""Patterns for describing graphs"""
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_
|
||||
from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\
|
||||
NewParameter_, Imm
|
||||
|
||||
__all__ = [
|
||||
"IsIn",
|
||||
|
@ -24,17 +25,25 @@ __all__ = [
|
|||
"IsNot",
|
||||
"AnyPattern",
|
||||
"NewTensor",
|
||||
"NewParameter",
|
||||
"Imm"
|
||||
]
|
||||
|
||||
class IsIn(IsIn_):
|
||||
"""
|
||||
r"""
|
||||
Express a pattern which allows a list of patterns.
|
||||
"""
|
||||
def __init__(self, patterns=None, should_replace=True):
|
||||
r"""
|
||||
Args:
|
||||
patterns(list/tuple): list of allowed patterns
|
||||
patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`],
|
||||
list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
|
||||
each element should be one of the exposed Pattern instance.
|
||||
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
|
||||
|
||||
Raises:
|
||||
ValueError: raise if should_replace is False
|
||||
TypeError: raise type error for invalid inputs.
|
||||
"""
|
||||
if not should_replace:
|
||||
raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \
|
||||
|
@ -52,19 +61,28 @@ class IsIn(IsIn_):
|
|||
class IsPrimTypeOf(IsPrimTypeOf_):
|
||||
r"""
|
||||
Express a pattern of certain primitive type(s).
|
||||
NOTE: This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
|
||||
please refer to CallWith pattern.
|
||||
|
||||
NOTE:
|
||||
This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
|
||||
please refer to CallWith pattern.
|
||||
"""
|
||||
def __init__(self, types, name=None, should_replace=True):
|
||||
r"""
|
||||
Args:
|
||||
types (str/(list/tuple of Primitives)): Specify allowed types.
|
||||
types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
|
||||
tuple[:class:`mindspore.ops.Primitive`]):
|
||||
Specify allowed types.
|
||||
If it is a string, the form could be
|
||||
1) a single primitive type, e.g. 'Conv2D'
|
||||
2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
|
||||
It can also be a list of Primitives, e.g. [ops.Conv2D(1, 6)]
|
||||
name (str): name of the pattern, optional
|
||||
should_replace
|
||||
It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
|
||||
name (str): name of the pattern, optional. Default: None.
|
||||
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
|
||||
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
|
||||
Default: True.
|
||||
|
||||
Raises:
|
||||
TypeError: raise type error for invalid argument.
|
||||
"""
|
||||
if name is not None and not isinstance(name, str):
|
||||
raise TypeError(f"Expect string, got : {name}")
|
||||
|
@ -91,12 +109,21 @@ class CallWith(CallWith_):
|
|||
r"""
|
||||
Express a primitive CNode.
|
||||
"""
|
||||
def __init__(self, prim_pattern, inputs=None, should_replace=False):
|
||||
def __init__(self, prim_pattern, inputs=None, should_replace=True):
|
||||
r"""
|
||||
Args:
|
||||
prim_pattern (Pattern/Primitive/str): Primitive ValueNode in the Primitive CNode.
|
||||
inputs (list/tuple): Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs;
|
||||
if specified, input patterns should be of right order.
|
||||
prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
|
||||
:class:`mindspore.ops.Primitive`]): Primitive ValueNode in the Primitive CNode.
|
||||
inputs (Union[list[:class:`mindspore.graph_utils.graph_pattern`],
|
||||
tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
|
||||
Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
|
||||
patterns should be of right order and each element should be one of the exposed Pattern instance.
|
||||
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
|
||||
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
|
||||
Default: True.
|
||||
|
||||
Raises:
|
||||
TypeError: raise type error for invalid argument.
|
||||
"""
|
||||
if not isinstance(prim_pattern, (Pattern, str, Primitive)):
|
||||
raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
|
||||
|
@ -110,17 +137,23 @@ class CallWith(CallWith_):
|
|||
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
|
||||
CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace)
|
||||
|
||||
|
||||
class IsNot(IsNot_):
|
||||
r"""
|
||||
Express a pattern which forbids a list of patterns.
|
||||
NOTE: IsNot pattern should not be the root pattern.
|
||||
|
||||
NOTE:
|
||||
IsNot pattern should not be the root pattern.
|
||||
"""
|
||||
def __init__(self, patterns=None, should_replace=True):
|
||||
r"""
|
||||
Args:
|
||||
patterns(list/tuple): list of forbiden patterns
|
||||
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
|
||||
should be one of the exposed Pattern instance.
|
||||
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
|
||||
|
||||
Raises:
|
||||
ValueError: raise if should_replace is False.
|
||||
TypeError: raise type error for invalid argument.
|
||||
"""
|
||||
if not should_replace:
|
||||
raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \
|
||||
|
@ -142,13 +175,48 @@ class NewTensor(NewTensor_):
|
|||
def __init__(self, input_tensor, should_replace=False):
|
||||
r"""
|
||||
Args:
|
||||
input_tensor(Tensor): new tensor to be used in the target
|
||||
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
|
||||
should_replace(bool): added this for interface consistency. NewTensor should only appear in the target.
|
||||
|
||||
Raises:
|
||||
ValueError: raise if should_replace is True
|
||||
TypeError: raise type error for invalid argument.
|
||||
"""
|
||||
if should_replace:
|
||||
raise ValueError("NewTensor should only appear in the target, thus should_replace can onlyu be False.")
|
||||
raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.")
|
||||
self.input_tensor = input_tensor
|
||||
if isinstance(input_tensor, Tensor):
|
||||
NewTensor_.__init__(self, input_tensor)
|
||||
else:
|
||||
raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
|
||||
|
||||
class NewParameter(NewParameter_):
|
||||
r"""
|
||||
New Parameter to be used in the target.
|
||||
"""
|
||||
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False):
|
||||
r"""
|
||||
Args:
|
||||
para_name(str): name for the new Parameter
|
||||
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
|
||||
requires_grad(bool): True if the parameter requires gradient. Default: True
|
||||
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
|
||||
should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new
|
||||
parameter everytime a pass target got built. Default: False
|
||||
|
||||
Raises:
|
||||
TypeError: raise type error for invalid argument.
|
||||
"""
|
||||
self.para_name = para_name
|
||||
self.default_tensor = default_tensor
|
||||
self.requires_grad = requires_grad
|
||||
self.layerwise_parallel = layerwise_parallel
|
||||
self.should_replace = should_replace
|
||||
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
|
||||
isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool):
|
||||
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
|
||||
self.layerwise_parallel, self.should_replace)
|
||||
else:
|
||||
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
|
||||
layerwise_parallel(bool) should_replace(bool), got : {para_name}, {default_tensor}, \
|
||||
{requires_grad}, {layerwise_parallel}, {should_replace}")
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Top-level reference to python pass."""
|
||||
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm
|
||||
|
||||
__all__ = [
|
||||
"registe_pass",
|
||||
"unregiste_pass",
|
||||
"gen_new_parameter",
|
||||
"cancel_new_parameter",
|
||||
"set_renorm"
|
||||
]
|
|
@ -0,0 +1,170 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Python pass register"""
|
||||
from inspect import isfunction
|
||||
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
|
||||
from mindspore._c_expression import PyPassManager_, phase
|
||||
|
||||
|
||||
__all__ = [
|
||||
"registe_pass",
|
||||
"unregiste_pass",
|
||||
"gen_new_parameter",
|
||||
"cancel_new_parameter",
|
||||
"set_renorm"
|
||||
]
|
||||
class PyPassManager(PyPassManager_):
|
||||
r"""
|
||||
Used to registe and unregiste python passes which can be used to alter graphs.
|
||||
|
||||
Args:
|
||||
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
|
||||
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
|
||||
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument has invalid type.
|
||||
"""
|
||||
def __init__(self, pipeline_phase=phase.opt, run_only_once=False):
|
||||
if not isinstance(pipeline_phase, phase):
|
||||
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||
if not isinstance(run_only_once, bool):
|
||||
raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
|
||||
PyPassManager_.__init__(self)
|
||||
self.phase_ = pipeline_phase
|
||||
self.run_only_once_ = run_only_once
|
||||
|
||||
def registe(self, py_pass):
|
||||
if not isfunction(py_pass):
|
||||
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
|
||||
pattern, target = py_pass()
|
||||
pass_name = py_pass.__name__
|
||||
if not isinstance(pattern, Pattern):
|
||||
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
||||
if not isinstance(target, Pattern):
|
||||
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
|
||||
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_)
|
||||
|
||||
def unregiste(self, py_pass, pipeline_phase=phase.opt):
|
||||
if not isinstance(pipeline_phase, phase):
|
||||
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||
if isinstance(py_pass, str):
|
||||
super().unregiste(py_pass, pipeline_phase)
|
||||
return
|
||||
if isfunction(py_pass):
|
||||
super().unregiste(py_pass.__name__, pipeline_phase)
|
||||
return
|
||||
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
||||
|
||||
def __call__(self, py_pass):
|
||||
self.registe(py_pass)
|
||||
return py_pass
|
||||
|
||||
def gen_new_parameter(self, pattern):
|
||||
if not isinstance(pattern, NewParameter):
|
||||
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
|
||||
super().gen_new_parameter(pattern)
|
||||
|
||||
def set_renorm(self, should_renorm):
|
||||
if not isinstance(should_renorm, bool):
|
||||
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
|
||||
super().set_renorm(should_renorm)
|
||||
|
||||
def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
|
||||
"""
|
||||
Registe python pass to specified pipeline phase which would be used in compilation.
|
||||
|
||||
Args:
|
||||
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
|
||||
registed. Support phase.resolve and phase.opt. Default: phase.opt.
|
||||
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
|
||||
|
||||
Returns:
|
||||
This function should be used as a decorator, return the decoratorated pass function.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf
|
||||
>>> @registe_pass()
|
||||
>>> def toy_pass():
|
||||
>>> pattern = IsPrimTypeOf("ReLU")
|
||||
>>> target = IsPrimTypeOf("ReLU6")
|
||||
>>> return pattern, target
|
||||
"""
|
||||
return PyPassManager(pipeline_phase, run_only_once)
|
||||
|
||||
def unregiste_pass(py_pass, pipeline_phase=phase.opt):
|
||||
"""
|
||||
Unregiste python pass.
|
||||
|
||||
Args:
|
||||
py_pass(Union(str, function)): target python pass to unregiste.
|
||||
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
|
||||
unregisted. Support phase.resolve and phase.opt. Default: phase.opt.
|
||||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(py_pass, pipeline_phase)
|
||||
|
||||
def gen_new_parameter(pattern):
|
||||
"""
|
||||
Generate specified parameter every time a network gets compiled.
|
||||
|
||||
NOTE:
|
||||
In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without
|
||||
gen_new_parameter, every pass match would build a new Parameter.
|
||||
This would registe a pass to add new parameter in the compilation pipeline, so later compilation would
|
||||
ALSO add this parameter unless the pass is unregisted. To unregiste this pass, call
|
||||
cancel_new_parameter(pattern)
|
||||
|
||||
Args:
|
||||
pattern (NewParameter): NewParameter type, could be used to build nested patterns across multiple passes
|
||||
after gen_new_parameter.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument has invalid type.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.graph_utils.graph_pattern import NewParameter
|
||||
>>> abc = NewParameter("abc")
|
||||
>>> gen_new_parameter(abc)
|
||||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.gen_new_parameter(pattern)
|
||||
|
||||
def cancel_new_parameter(pattern):
|
||||
"""
|
||||
Use with gen_new_parameter to unregiste gen_new_parameter pass.
|
||||
|
||||
Args:
|
||||
pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern
|
||||
describes.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.graph_utils.graph_pattern import NewParameter
|
||||
>>> abc = NewParameter("abc")
|
||||
>>> gen_new_parameter(abs)
|
||||
>>> # some compilations
|
||||
>>> cancel_new_parameter(abc)
|
||||
"""
|
||||
if not isinstance(pattern, NewParameter):
|
||||
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(pattern.para_name)
|
||||
|
||||
def set_renorm(should_renorm):
|
||||
"""
|
||||
Set whether or not to do renorm after modified graph in python pass(es).
|
||||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.set_renorm(should_renorm)
|
|
@ -152,7 +152,7 @@ class Primitive(Primitive_):
|
|||
Check if certain inputs should go to the backend. Subclass in need should override this method.
|
||||
|
||||
Args:
|
||||
*args(Primitive args): Same as arguments of current Primitive.
|
||||
args(Primitive args): Same as arguments of current Primitive.
|
||||
|
||||
Returns:
|
||||
A tuple consisting of two elements. The first element indicates whether we should filter out current
|
||||
|
|
|
@ -19,10 +19,12 @@ import mindspore.nn as nn
|
|||
from mindspore import context
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.python_pass_register import registe_pass, PyPassManager
|
||||
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
|
||||
cancel_new_parameter
|
||||
from mindspore.common.api import _generate_pip_args
|
||||
from mindspore._c_expression import generate_key, Executor_
|
||||
from mindspore.common.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor
|
||||
from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\
|
||||
NewParameter, Imm
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -56,12 +58,39 @@ def test_softmax_relu():
|
|||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(softmax_relu_pass)
|
||||
unregiste_pass(softmax_relu_pass)
|
||||
assert "ReLU" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
def test_isin_pattern():
|
||||
def test_softmax_relu_sigmoid():
|
||||
"""
|
||||
Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).
|
||||
|
||||
NOTE:
|
||||
Sigmoid pattern only exists in the target.
|
||||
"""
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = AnyPattern()
|
||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
||||
pattern = CallWith(softmax_pattern, inputs=[x])
|
||||
sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False)
|
||||
call_sigmoid = CallWith(sigmoid_pattern, [x])
|
||||
relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
|
||||
target = CallWith(relu_pattern, inputs=[call_sigmoid])
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
||||
unregiste_pass(softmax_relu_pass)
|
||||
assert "ReLU" in transformed_repr
|
||||
assert "Sigmoid" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
|
||||
def test_isin_pattern_0():
|
||||
"""
|
||||
Test IsIn pattern which expresses the IsIn/OneOf semantics.
|
||||
"""
|
||||
|
@ -81,16 +110,41 @@ def test_isin_pattern():
|
|||
target = CallWith(relu6_pattern, inputs=[x])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(softmax_relu_pass)
|
||||
unregiste_pass(softmax_relu_pass)
|
||||
assert "ReLU6" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
def test_isin_pattern_1():
|
||||
"""
|
||||
Test IsIn. IsIn is used as nested inputs for the target in this case.
|
||||
"""
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_neg_pass():
|
||||
x = AnyPattern()
|
||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
||||
call_softmax = CallWith(softmax_pattern, inputs=[x])
|
||||
relu_pattern = IsPrimTypeOf(P.ReLU())
|
||||
call_relu = CallWith(relu_pattern, inputs=[x])
|
||||
|
||||
pattern = IsIn([call_softmax, call_relu])
|
||||
neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False)
|
||||
target = CallWith(neg_ops, inputs=[pattern])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
|
||||
print(transformed_repr)
|
||||
unregiste_pass(softmax_neg_pass)
|
||||
assert "Neg" in transformed_repr
|
||||
assert "Softmax" in transformed_repr
|
||||
|
||||
def test_isnot_pattern_0():
|
||||
"""
|
||||
Test IsNot pattern which expresses the IsNot semantics.
|
||||
Case: IsNot pass failed to match
|
||||
"""
|
||||
set_renorm(False)
|
||||
class ConvBN(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ConvBN, self).__init__()
|
||||
|
@ -132,11 +186,11 @@ def test_isnot_pattern_0():
|
|||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(single_bn_pass)
|
||||
ppm.unregiste(bn_pass)
|
||||
unregiste_pass(single_bn_pass)
|
||||
unregiste_pass(bn_pass)
|
||||
assert "ReLU6" not in transformed_repr
|
||||
assert "Softmax" in transformed_repr
|
||||
set_renorm(True)
|
||||
|
||||
def test_isnot_pattern_1():
|
||||
"""
|
||||
|
@ -160,12 +214,15 @@ def test_isnot_pattern_1():
|
|||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(single_bn_pass)
|
||||
unregiste_pass(single_bn_pass)
|
||||
assert "ReLU6" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
def test_newtensor_pattern():
|
||||
"""
|
||||
Test NewTensor pattern in the target
|
||||
"""
|
||||
set_renorm(False)
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
|
@ -181,7 +238,84 @@ def test_newtensor_pattern():
|
|||
target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False)
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(softmax_addn_pass)
|
||||
unregiste_pass(softmax_addn_pass)
|
||||
assert "AddN" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
set_renorm(True)
|
||||
|
||||
def test_newparameter_pattern():
|
||||
"""
|
||||
Test NewParameter pattern in the target
|
||||
"""
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_addn_pass():
|
||||
x = AnyPattern()
|
||||
softmax = P.Softmax()
|
||||
pattern = CallWith(softmax, inputs=[x])
|
||||
|
||||
default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||
default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||
new_para_0 = NewParameter("Merlin", default_tensor0)
|
||||
new_para_1 = NewParameter("Arthur", default_tensor1)
|
||||
target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False)
|
||||
target = CallWith("make_tuple", inputs=[target_0], should_replace=False)
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
print(transformed_repr)
|
||||
unregiste_pass(softmax_addn_pass)
|
||||
assert "MatMul" in transformed_repr
|
||||
assert "make_tuple" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
def test_imm_pattern():
|
||||
"""
|
||||
Test NewParameter pattern in the target
|
||||
"""
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_addn_pass():
|
||||
x = AnyPattern()
|
||||
softmax = P.Softmax()
|
||||
pattern = CallWith(softmax, inputs=[x])
|
||||
imm = Imm(0)
|
||||
target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False)
|
||||
target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False)
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
print(transformed_repr)
|
||||
unregiste_pass(softmax_addn_pass)
|
||||
assert "make_tuple" in transformed_repr
|
||||
assert "tuple_getitem" in transformed_repr
|
||||
assert "Softmax" in transformed_repr
|
||||
|
||||
def test_gen_new_parameter():
|
||||
"""
|
||||
Test gen_new_parameter
|
||||
"""
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||
new_para = NewParameter("Merlin", default_tensor, should_replace=True)
|
||||
gen_new_parameter(new_para)
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_make_tuple_pass():
|
||||
x = AnyPattern()
|
||||
softmax = P.Softmax()
|
||||
pattern = CallWith(softmax, inputs=[x])
|
||||
|
||||
target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False)
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
print(transformed_repr)
|
||||
assert "Merlin" in transformed_repr
|
||||
unregiste_pass(softmax_make_tuple_pass)
|
||||
cancel_new_parameter(new_para)
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
print(transformed_repr)
|
||||
assert "Merlin" not in transformed_repr
|
||||
|
|
Loading…
Reference in New Issue