diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc index 48cf2b05e4a..fdfffd87918 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc @@ -158,12 +158,10 @@ PatternNodePtr PatternTree::BuildTree(const std::string &pattern_str) { } cur_node->AddInput(BuildTree(op_inputs)); return cur_node; - } else { return std::make_shared(pattern_str); } } - return nullptr; } @@ -276,7 +274,6 @@ bool DfsMatchGraph(const graphkernel::NodePtr &tmp_node, const PatternNodePtr &t return false; } } - } else { for (size_t i = 0; i < tmp_pattern_inputs.size(); i++) { if (!DfsMatchGraph(tmp_node_inputs[i], tmp_pattern_inputs[i], para_to_ref, const_to_ref, res)) { @@ -387,7 +384,6 @@ class ExtraReduce1PatternTree : public PatternTree { for (auto &i : GetValue>(first_reduce->attrs().find("axis")->second)) { axis_set.insert(i); } - } else { auto first_axis = GetValue>(first_reduce->attrs().find("axis")->second); auto second_axis = GetValue>(origin_root->attrs().find("axis")->second); @@ -538,7 +534,6 @@ std::unordered_map> GetExpressions() { std::unordered_set enable_ids{flags.enable_simplify_exprs_only.begin(), flags.enable_simplify_exprs_only.end()}; std::unordered_set disable_ids{flags.disable_simplify_exprs.begin(), flags.disable_simplify_exprs.end()}; - for (auto &e : expressions) { if (!enable_ids.empty()) { if (enable_ids.count(std::to_string(e.id)) == 0) continue; @@ -640,33 +635,29 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) { expressions_map_ = GetExpressions(); for (auto node : func_graph->GetOrderedCnodes()) { if (AnfAlgo::IsGraphKernel(node)) { - try { - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph); - bool find_pattern = true; - bool change_anf_graph = false; - while (find_pattern) { - find_pattern = false; - find_pattern = DoArithmeticTrans(lg) || find_pattern; - find_pattern = DoConstantFold(lg) || find_pattern; - change_anf_graph = change_anf_graph || find_pattern; - } - if (!change_anf_graph) continue; - ReorganizeEmptyGraph(lg); - AnfNodePtrList outputs; - auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs); - new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - auto cnode = node->cast(); - AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); - EliminateRedundantParameters(new_funcgraph, &inputs); - auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs); - SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); - mng->Replace(node, new_node); - mng->AddFuncGraph(new_funcgraph); - do_simplify = true; - } catch (const graphkernel::GKException &e) { - MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph"; + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph); + bool find_pattern = true; + bool change_anf_graph = false; + while (find_pattern) { + find_pattern = false; + find_pattern = DoArithmeticTrans(lg) || find_pattern; + find_pattern = DoConstantFold(lg) || find_pattern; + change_anf_graph = change_anf_graph || find_pattern; } + if (!change_anf_graph) continue; + ReorganizeEmptyGraph(lg); + AnfNodePtrList outputs; + auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs); + new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + auto cnode = node->cast(); + AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); + EliminateRedundantParameters(new_funcgraph, &inputs); + auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs); + SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); + mng->Replace(node, new_node); + mng->AddFuncGraph(new_funcgraph); + do_simplify = true; } } return do_simplify; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/bias_add.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/bias_add.cc similarity index 83% rename from mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/bias_add.h rename to mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/bias_add.cc index b1d1c85a66f..7c8d98d045a 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/bias_add.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/bias_add.cc @@ -13,14 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_ - #include #include #include #include +#include "backend/optimizer/graph_kernel/expanders/expander_factory.h" #include "backend/optimizer/graph_kernel/expanders/utils.h" namespace mindspore { @@ -34,7 +32,8 @@ class BiasAdd : public OpExpander { support_format->AddFormat({kOpFormat_NCHW, kOpFormat_DEFAULT}); support_format->AddFormat({kOpFormat_NHWC, kOpFormat_DEFAULT}); validators_.emplace_back(std::move(support_format)); - validators_.emplace_back(new CheckAttr({"format"})); + auto attrs = std::initializer_list{"format"}; + validators_.emplace_back(std::make_unique(attrs)); } ~BiasAdd() = default; NodePtrList Expand() override { @@ -42,19 +41,19 @@ class BiasAdd : public OpExpander { auto input_x = inputs[0]; auto input_y = inputs[1]; if (input_x->format == kOpFormat_NCHW) { - input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, {1, 2}))}}); + input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDimsInferShape(input_y->shape, {1, 2}))}}); } else if (input_x->format == kOpFormat_DEFAULT) { auto data_format = GetValue(attrs_["format"]); size_t channel_idx = (data_format == kOpFormat_NHWC) ? input_x->shape.size() - 1 : 1; std::vector axis(input_x->shape.size() - channel_idx - 1, -1); if (!axis.empty()) { - input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, axis))}}); + input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDimsInferShape(input_y->shape, axis))}}); } } return {gb.Emit("Add", {input_x, input_y})}; } }; +OP_EXPANDER_REGISTER("BiasAdd", BiasAdd); } // namespace expanders } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/expander_factory.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/expander_factory.h index cfd4b64d11d..76db539b098 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/expander_factory.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/expander_factory.h @@ -22,22 +22,15 @@ #include #include "backend/optimizer/graph_kernel/expanders/utils.h" -#include "backend/optimizer/graph_kernel/expanders/reshape.h" -#include "backend/optimizer/graph_kernel/expanders/bias_add.h" namespace mindspore { namespace opt { namespace expanders { -#define OP_EXPANDER_CREATOR(cls) []() -> std::shared_ptr { return std::make_shared(); } - class OpExpanderFactory { public: static OpExpanderFactory &Instance() { - static std::unique_ptr instance = nullptr; - if (instance == nullptr) { - instance.reset(new OpExpanderFactory()); - } - return *instance; + static OpExpanderFactory instance; + return instance; } std::shared_ptr GetExpander(const std::string &op) { if (auto iter = creators.find(op); iter != creators.end()) { @@ -49,16 +42,24 @@ class OpExpanderFactory { } ~OpExpanderFactory() = default; - private: using RegFunc = std::function()>; - void Register(std::string &&op, RegFunc &&func) { creators.insert({op, func}); } - OpExpanderFactory() { - Register("BiasAdd", OP_EXPANDER_CREATOR(expanders::BiasAdd)); - Register("ExpandDims", OP_EXPANDER_CREATOR(expanders::ExpandDims)); - } + void Register(const std::string &op, const RegFunc &func) { creators[op] = func; } + private: std::unordered_map creators; }; + +class OpExpanderRegister { + public: + OpExpanderRegister(const std::string &name, const OpExpanderFactory::RegFunc &func) { + OpExpanderFactory::Instance().Register(name, func); + } + ~OpExpanderRegister() = default; +}; + +#define OP_EXPANDER_REGISTER(name, cls) \ + static const OpExpanderRegister g_##cls##_expander_reg( \ + name, []() -> std::shared_ptr { return std::make_shared(); }) } // namespace expanders } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/reshape.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/reshape.cc similarity index 71% rename from mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/reshape.h rename to mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/reshape.cc index 1f12bf49295..0df513d9243 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/reshape.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/reshape.cc @@ -13,25 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_ - #include #include -#include "backend/optimizer/graph_kernel/model/node.h" -#include "backend/optimizer/graph_kernel/expanders/utils.h" +#include "backend/optimizer/graph_kernel/expanders/expander_factory.h" namespace mindspore { namespace opt { namespace expanders { class ExpandDims : public OpExpander { public: - ExpandDims() { validators_.emplace_back(new CheckAttr({"axis"})); } - ~ExpandDims() {} + ExpandDims() { + std::initializer_list attrs{"axis"}; + validators_.emplace_back(std::make_unique(attrs)); + } + ~ExpandDims() = default; NodePtrList Expand() override { const auto &inputs = gb.Get()->inputs(); - auto &input_x = inputs[0]; + const auto &input_x = inputs[0]; auto shape = MakeValue(ExpandDims::InferShape(input_x->shape, GetAxisList(this->attrs_["axis"]))); auto result = gb.Emit("Reshape", {input_x}, {{"shape", shape}}); return {result}; @@ -42,9 +41,7 @@ class ExpandDims : public OpExpander { for (auto x : axis) { int64_t rank = static_cast(new_shape.size()); if (x > rank || x < -rank - 1) { - std::ostringstream oss; - oss << "ExpandDims axis " << x << " is out of range of size " << new_shape.size(); - throw graphkernel::GKException(oss.str()); + MS_LOG(EXCEPTION) << "ExpandDims axis " << x << " is out of range of size " << new_shape.size(); } if (x >= 0) { new_shape.insert(new_shape.begin() + x, 1LL); @@ -55,7 +52,11 @@ class ExpandDims : public OpExpander { return new_shape; } }; +OP_EXPANDER_REGISTER("ExpandDims", ExpandDims); + +ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector &axis) { + return ExpandDims::InferShape(shape, axis); +} } // namespace expanders } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/utils.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/utils.cc index e1349981703..e90f970ecee 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/utils.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/utils.cc @@ -31,27 +31,31 @@ graphkernel::LiteGraphPtr OpExpander::Run(const BaseInfoList &inputs, const Base this->outputs_info_ = outputs; this->attrs_ = attrs; this->processor_ = processor; - for (const auto &v : validators_) { - v->Check(*this); + if (std::any_of(validators_.begin(), validators_.end(), + [this](const std::unique_ptr &v) { return !(v->Check(*this)); })) { + return nullptr; + } + if (!this->CheckInputs()) { + return nullptr; } - this->CheckInputs(); for (auto &inp : inputs) { (void)gb.Parameter(inp); } auto result = this->Expand(); gb.SetOutputs(result); - this->CheckOutputs(); + if (!this->CheckOutputs()) { + return nullptr; + } return gb.Get(); } -void OpExpander::CheckOutputs() { +bool OpExpander::CheckOutputs() { // check the output shape/type/format are same as the original basic node's output. const NodePtrList &outputs = gb.Get()->GetOutputs(); if (outputs.size() != this->outputs_info_.size()) { - std::ostringstream oss; - oss << "the output num was not equal to the original output num : " << outputs.size() << " vs " - << outputs_info_.size(); - throw graphkernel::GKException(oss.str()); + MS_LOG(INFO) << "the output num was not equal to the original output num : " << outputs.size() << " vs " + << outputs_info_.size(); + return false; } for (size_t i = 0; i < outputs.size(); i++) { if (outputs[i]->shape != outputs_info_[i].shape) { @@ -65,21 +69,21 @@ void OpExpander::CheckOutputs() { oss << s << ","; } oss << "]"; - throw graphkernel::GKException(oss.str()); + MS_LOG(INFO) << oss.str(); + return false; } if (outputs[i]->type != outputs_info_[i].type) { - std::ostringstream oss; - oss << "Op " << this->op_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: [" - << outputs_info_[i].type << "]"; - throw graphkernel::GKException(oss.str()); + MS_LOG(INFO) << "Op " << this->op_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: [" + << outputs_info_[i].type << "]"; + return false; } if (outputs[i]->format != outputs_info_[i].format) { - std::ostringstream oss; - oss << "Op " << this->op_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: [" - << outputs_info_[i].format << "]"; - throw graphkernel::GKException(oss.str()); + MS_LOG(INFO) << "Op " << this->op_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: [" + << outputs_info_[i].format << "]"; + return false; } } + return true; } std::vector GetAxisList(const ValuePtr &value) { diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/utils.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/utils.h index abe2c3ca502..5a26ef42907 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/utils.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/expanders/utils.h @@ -37,9 +37,9 @@ class OpExpander { virtual ~OpExpander() = default; protected: - virtual void CheckInputs() {} + virtual bool CheckInputs() { return true; } virtual NodePtrList Expand() = 0; - void CheckOutputs(); + bool CheckOutputs(); graphkernel::LiteGraph::GraphBuilder gb; std::string op_; @@ -57,37 +57,36 @@ class OpExpander { class Validator { public: - virtual void Check(const OpExpander &e) = 0; + virtual bool Check(const OpExpander &e) = 0; }; class CheckAllFormatsSame : public Validator { public: - void Check(const OpExpander &e) override { - if (e.inputs_info_.empty()) return; + bool Check(const OpExpander &e) override { + if (e.inputs_info_.empty()) return true; const auto &fmt_0 = e.inputs_info_[0].format; for (size_t i = 1; i < e.inputs_info_.size(); i++) { if (e.inputs_info_[i].format != fmt_0) { - std::ostringstream oss; - oss << "Unmatched format for op " << e.op_; - throw graphkernel::GKException(oss.str()); + MS_LOG(INFO) << "Unmatched format for op " << e.op_; + return false; } } + return true; } }; class CheckAttr : public Validator { public: - CheckAttr() = default; CheckAttr(std::initializer_list l) : attrs_(l) {} ~CheckAttr() = default; - void Check(const OpExpander &e) override { + bool Check(const OpExpander &e) override { for (auto &a : attrs_) { if (e.attrs_.count(a) == 0) { - std::ostringstream oss; - oss << "attr " << a << " does not exist. op " << e.op_; - throw graphkernel::GKException(oss.str()); + MS_LOG(INFO) << "attr " << a << " does not exist. op " << e.op_; + return false; } } + return true; } private: @@ -97,7 +96,7 @@ class CheckAttr : public Validator { class SupportFormat : public Validator { public: void AddFormat(std::initializer_list l) { formats_.emplace_back(l); } - void Check(const OpExpander &e) override { + bool Check(const OpExpander &e) override { for (auto &formats : formats_) { if (formats.size() != e.inputs_info_.size()) { continue; @@ -110,12 +109,11 @@ class SupportFormat : public Validator { } } if (match) { - return; + return true; } } - std::ostringstream oss; - oss << "unsupported format for op " << e.op_; - throw graphkernel::GKException(oss.str()); + MS_LOG(INFO) << "unsupported format for op " << e.op_; + return false; } private: @@ -123,6 +121,7 @@ class SupportFormat : public Validator { }; std::vector GetAxisList(const ValuePtr &value); +ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector &axis); } // namespace expanders } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index c5c99500db5..74274853675 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -153,13 +153,12 @@ FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) { outputs[i].format = AnfAlgo::GetOutputFormat(node, i); } auto &attrs = AnfAlgo::GetCNodePrimitive(node)->attrs(); - try { - auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext()); - return LiteGraph2AnfGraph(litegraph); - } catch (const graphkernel::GKException &e) { - MS_LOG(INFO) << e.what() << ", undo expanding this op"; + auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext()); + if (litegraph == nullptr) { + MS_LOG(INFO) << "undo expanding " << node->fullname_with_scope(); return nullptr; } + return LiteGraph2AnfGraph(litegraph); } AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) { diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h index edab0cd7e19..a651c71a3cd 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h @@ -152,15 +152,6 @@ class OutputNode : public Node { void Dump(std::ostringstream &os) const override { ; } NType NodeType() override { return NType::Output; } }; - -class GKException : public std::exception { - public: - explicit GKException(const std::string &message) : msg_(message) {} - const char *what() const noexcept override { return msg_.c_str(); } - - protected: - std::string msg_; -}; } // namespace graphkernel } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc index 8ea98b98e92..26c06295959 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc @@ -200,29 +200,35 @@ NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs, const // default format shape to fractal_Nz format shape DShape ToNz(const DShape &default_shape) { - if (default_shape.size() != 1 && default_shape.size() != 2) { - throw GKException("shape is too long"); + constexpr size_t nz_size = 2; + auto len = default_shape.size(); + DShape leading_shape; + DShape tail_shape; + if (default_shape.size() > nz_size) { + leading_shape.insert(leading_shape.end(), default_shape.begin(), default_shape.end() - nz_size); } - DShape output_shape; - if (default_shape.size() == 1 || (default_shape.size() == 2 && default_shape[0] == 1)) { - output_shape = {default_shape[default_shape.size() - 1] / 16, 1, 1, 16}; - if (default_shape[default_shape.size() - 1] % 16 != 0) { - throw GKException("should be multiplies of 16"); + if (default_shape.size() == 1 || (default_shape.size() >= nz_size && default_shape[len - nz_size] == 1)) { + // (32) or (N, 1, 32) -> (N, 2, 1, 1, 16) + if (default_shape.back() % 16 != 0) { + MS_LOG(EXCEPTION) << "default_shape[-1] should be multiplies of 16, but got " << default_shape.back(); } - - } else if (default_shape.size() == 2 || default_shape[1] == 1) { - output_shape = {1, default_shape[0] / 16, 16, 1}; - if (default_shape[0] % 16 != 0) { - throw GKException("should be multiplies of 16"); + tail_shape = {default_shape.back() / 16, 1, 1, 16}; + } else if (default_shape.size() >= nz_size || default_shape[1] == 1) { + // (N, 32, 1) -> (N, 1, 2, 16, 1) + if (default_shape[len - nz_size] % 16 != 0) { + MS_LOG(EXCEPTION) << "default_shape[-2] should be multiplies of 16, but got " << default_shape[len - nz_size]; } - + tail_shape = {1, default_shape[0] / 16, 16, 1}; } else { - output_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16}; - if (default_shape[0] % 16 != 0 || default_shape[1] % 16 != 0) { - throw GKException("should be multiplies of 16"); + // (N, 32, 48) -> (N, 3, 2, 16, 16) + if (default_shape.back() % 16 != 0 || default_shape[len - nz_size] % 16 != 0) { + MS_LOG(EXCEPTION) << "default_shape[-1] and default_shape[-2]should be multiplies of 16, but got " + << default_shape.back() << " " << default_shape[len - nz_size]; } + tail_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16}; } - return output_shape; + leading_shape.insert(leading_shape.end(), tail_shape.begin(), tail_shape.end()); + return leading_shape; } DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) { @@ -252,7 +258,7 @@ DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) { output_shape[i] = align_shape[i]; } if (output_shape[i] != align_shape[i]) { - throw GKException("shape broadcast failed"); + MS_LOG(EXCEPTION) << "Shape broadcast failed. " << output_shape[i] << " vs " << align_shape[i]; } } } @@ -272,7 +278,7 @@ DShape ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { })) { return BroadcastShape(inputs, true); } - throw GKException("Only support default and fractal_nz"); + MS_LOG(EXCEPTION) << "Unsupported format."; } DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { @@ -374,22 +380,20 @@ DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return new_shape; } -void CheckNd(const std::vector &shape, size_t n) { - if (shape.size() != n) { - std::ostringstream info; - info << "input dimension should be " << n << ", but got " << shape.size(); - throw GKException(info.str()); - } -} - DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + auto check_nd = [](const std::vector &shape, size_t n) { + if (shape.size() != n) { + MS_LOG(EXCEPTION) << "input dimension should be " << n << ", but got " << shape.size(); + } + }; auto shape0 = inputs[0]->shape; auto shape1 = inputs[1]->shape; - CheckNd(shape0, 4); - CheckNd(shape1, 4); + check_nd(shape0, 4); + check_nd(shape1, 4); + CHECK_ATTR(attrs, "format"); if (inputs[0]->format != kOpFormat_NHWC && inputs[1]->format != kOpFormat_NHWC && GetValue(attrs.find("format")->second) != kOpFormat_NHWC) { - throw GKException("check NHWC format failed"); + MS_LOG(EXCEPTION) << "check NHWC format failed"; } auto n = shape0[0]; auto h = shape0[1]; @@ -405,10 +409,10 @@ DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { auto kernel_size = GetListInt(attrs.find("kernel_size")->second); auto stride = GetListInt(attrs.find("stride")->second); auto dilation = GetListInt(attrs.find("dilation")->second); - CheckNd(pad_list, 4); - CheckNd(kernel_size, 2); - CheckNd(stride, 4); - CheckNd(dilation, 4); + check_nd(pad_list, 4); + check_nd(kernel_size, 2); + check_nd(stride, 4); + check_nd(dilation, 4); bool has_pad = false; if (pad_list[0] != pad_list[1] || pad_list[2] != pad_list[3]) { has_pad = true; @@ -464,19 +468,17 @@ DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) std::vector nhwc2nchw = {0, 3, 1, 2}; if (perm == nhwc2nchw) return kOpFormat_DEFAULT; } - std::ostringstream info; - info << "Unsupported Transpose. ori_format = " << ori_format << ", perm = " << attrs.find("perm")->second->ToString(); - throw GKException(info.str()); + return kOpFormat_DEFAULT; } DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { std::vector shape0 = inputs[0]->shape; std::vector shape1 = inputs[1]->shape; if (shape0.size() != 2 || shape1.size() != 2) { - std::ostringstream info; - info << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size(); - throw GKException(info.str()); + MS_LOG(EXCEPTION) << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size(); } + CHECK_ATTR(attrs, "transpose_a"); + CHECK_ATTR(attrs, "transpose_b"); auto transpose_a = GetValue(attrs.find("transpose_a")->second); auto transpose_b = GetValue(attrs.find("transpose_b")->second); int64_t m = transpose_a ? shape0[1] : shape0[0]; @@ -491,6 +493,7 @@ DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { } TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { + CHECK_ATTR(attrs, "dst_type"); if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type; auto dst_type = attrs.find("dst_type")->second; if (dst_type->isa()) { @@ -502,6 +505,8 @@ TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { std::vector shape0 = inputs[0]->shape; size_t n = shape0.size(); + CHECK_ATTR(attrs, "head"); + CHECK_ATTR(attrs, "tail"); std::vector pad_before = GetListInt(attrs.find("head")->second); std::vector pad_after = GetListInt(attrs.find("tail")->second); if (pad_before.size() != n || pad_after.size() != n) { @@ -518,6 +523,7 @@ DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { std::vector shape0 = inputs[0]->shape; size_t n = shape0.size(); + CHECK_ATTR(attrs, "tail"); std::vector unpad_after = GetListInt(attrs.find("tail")->second); if (unpad_after.size() != n) { MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size(); @@ -531,13 +537,12 @@ DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { void ComplexOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) { if (inputs[0]->type != TypeId::kNumberTypeFloat32) { - throw GKException("Complex's input[0] should be float32"); + MS_LOG(EXCEPTION) << "Complex's input[0] should be float32"; } if (inputs[0]->type != inputs[1]->type) { MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch"; } } - } // namespace graphkernel } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h index 3ccede1c471..eca01f79c93 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h @@ -251,7 +251,7 @@ class CImagOp : public ElemwiseOp { protected: void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override { if (inputs[0]->type != TypeId::kNumberTypeComplex64) { - throw GKException("CImag's input[0] should be complex64"); + MS_LOG(EXCEPTION) << "CImag's input[0] should be complex64"; } }; @@ -266,7 +266,7 @@ class CRealOp : public ElemwiseOp { protected: void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override { if (inputs[0]->type != TypeId::kNumberTypeComplex64) { - throw GKException("CReal's input[0] should be complex64"); + MS_LOG(EXCEPTION) << "CReal's input[0] should be complex64"; } }; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc index ce8f2308aea..a013298d1dc 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc @@ -229,9 +229,7 @@ class TransformOp { perm = perm_map[{format_b_, format_a_}]; } if (perm.empty()) { - std::ostringstream oss; - oss << "unsupported format: " << format_a_ << " to " << format_b_; - throw graphkernel::GKException(oss.str()); + MS_LOG(EXCEPTION) << "unsupported format: " << format_a_ << " to " << format_b_; } auto op = graphkernel::OpRegistry::Instance().NewOp("Transpose", "new_trans"); op->SetAttr("perm", MakeValue(perm)); @@ -438,23 +436,19 @@ bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) { bool changed = false; for (auto node : todos) { if (!AnfAlgo::IsGraphKernel(node)) continue; - try { - auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - auto litegraph = AnfGraph2LiteGraph(sub_func_graph); - if (Process(litegraph)) { - changed = true; - AnfNodePtrList outputs; - auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs); - new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - auto cnode = node->cast(); - AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); - auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs); - SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); - mng->Replace(node, new_node); - mng->AddFuncGraph(new_funcgraph); - } - } catch (const graphkernel::GKException &e) { - MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph"; + auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + auto litegraph = AnfGraph2LiteGraph(sub_func_graph); + if (Process(litegraph)) { + changed = true; + AnfNodePtrList outputs; + auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs); + new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + auto cnode = node->cast(); + AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); + auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs); + SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); + mng->Replace(node, new_node); + mng->AddFuncGraph(new_funcgraph); } } return changed;