!22878 Remove "GKException" from c++ code.

Merge pull request !22878 from DeshiChen/0902_gkexception
This commit is contained in:
i-robot 2021-09-06 06:59:25 +00:00 committed by Gitee
commit 6d0bdd83da
11 changed files with 163 additions and 179 deletions

View File

@ -158,12 +158,10 @@ PatternNodePtr PatternTree::BuildTree(const std::string &pattern_str) {
}
cur_node->AddInput(BuildTree(op_inputs));
return cur_node;
} else {
return std::make_shared<PatternNode>(pattern_str);
}
}
return nullptr;
}
@ -276,7 +274,6 @@ bool DfsMatchGraph(const graphkernel::NodePtr &tmp_node, const PatternNodePtr &t
return false;
}
}
} else {
for (size_t i = 0; i < tmp_pattern_inputs.size(); i++) {
if (!DfsMatchGraph(tmp_node_inputs[i], tmp_pattern_inputs[i], para_to_ref, const_to_ref, res)) {
@ -387,7 +384,6 @@ class ExtraReduce1PatternTree : public PatternTree {
for (auto &i : GetValue<std::vector<int64_t>>(first_reduce->attrs().find("axis")->second)) {
axis_set.insert(i);
}
} else {
auto first_axis = GetValue<std::vector<int64_t>>(first_reduce->attrs().find("axis")->second);
auto second_axis = GetValue<std::vector<int64_t>>(origin_root->attrs().find("axis")->second);
@ -538,7 +534,6 @@ std::unordered_map<std::string, std::vector<PatternTreePtr>> GetExpressions() {
std::unordered_set<std::string> enable_ids{flags.enable_simplify_exprs_only.begin(),
flags.enable_simplify_exprs_only.end()};
std::unordered_set<std::string> disable_ids{flags.disable_simplify_exprs.begin(), flags.disable_simplify_exprs.end()};
for (auto &e : expressions) {
if (!enable_ids.empty()) {
if (enable_ids.count(std::to_string(e.id)) == 0) continue;
@ -640,33 +635,29 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
expressions_map_ = GetExpressions();
for (auto node : func_graph->GetOrderedCnodes()) {
if (AnfAlgo::IsGraphKernel(node)) {
try {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph);
bool find_pattern = true;
bool change_anf_graph = false;
while (find_pattern) {
find_pattern = false;
find_pattern = DoArithmeticTrans(lg) || find_pattern;
find_pattern = DoConstantFold(lg) || find_pattern;
change_anf_graph = change_anf_graph || find_pattern;
}
if (!change_anf_graph) continue;
ReorganizeEmptyGraph(lg);
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
EliminateRedundantParameters(new_funcgraph, &inputs);
auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
do_simplify = true;
} catch (const graphkernel::GKException &e) {
MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph";
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph);
bool find_pattern = true;
bool change_anf_graph = false;
while (find_pattern) {
find_pattern = false;
find_pattern = DoArithmeticTrans(lg) || find_pattern;
find_pattern = DoConstantFold(lg) || find_pattern;
change_anf_graph = change_anf_graph || find_pattern;
}
if (!change_anf_graph) continue;
ReorganizeEmptyGraph(lg);
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
EliminateRedundantParameters(new_funcgraph, &inputs);
auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
do_simplify = true;
}
}
return do_simplify;

View File

@ -13,14 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
#include <string>
#include <utility>
#include <vector>
#include <memory>
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
#include "backend/optimizer/graph_kernel/expanders/utils.h"
namespace mindspore {
@ -34,7 +32,8 @@ class BiasAdd : public OpExpander {
support_format->AddFormat({kOpFormat_NCHW, kOpFormat_DEFAULT});
support_format->AddFormat({kOpFormat_NHWC, kOpFormat_DEFAULT});
validators_.emplace_back(std::move(support_format));
validators_.emplace_back(new CheckAttr({"format"}));
auto attrs = std::initializer_list<std::string>{"format"};
validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
}
~BiasAdd() = default;
NodePtrList Expand() override {
@ -42,19 +41,19 @@ class BiasAdd : public OpExpander {
auto input_x = inputs[0];
auto input_y = inputs[1];
if (input_x->format == kOpFormat_NCHW) {
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, {1, 2}))}});
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDimsInferShape(input_y->shape, {1, 2}))}});
} else if (input_x->format == kOpFormat_DEFAULT) {
auto data_format = GetValue<std::string>(attrs_["format"]);
size_t channel_idx = (data_format == kOpFormat_NHWC) ? input_x->shape.size() - 1 : 1;
std::vector<int64_t> axis(input_x->shape.size() - channel_idx - 1, -1);
if (!axis.empty()) {
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, axis))}});
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDimsInferShape(input_y->shape, axis))}});
}
}
return {gb.Emit("Add", {input_x, input_y})};
}
};
OP_EXPANDER_REGISTER("BiasAdd", BiasAdd);
} // namespace expanders
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_

View File

@ -22,22 +22,15 @@
#include <memory>
#include "backend/optimizer/graph_kernel/expanders/utils.h"
#include "backend/optimizer/graph_kernel/expanders/reshape.h"
#include "backend/optimizer/graph_kernel/expanders/bias_add.h"
namespace mindspore {
namespace opt {
namespace expanders {
#define OP_EXPANDER_CREATOR(cls) []() -> std::shared_ptr<OpExpander> { return std::make_shared<cls>(); }
class OpExpanderFactory {
public:
static OpExpanderFactory &Instance() {
static std::unique_ptr<OpExpanderFactory> instance = nullptr;
if (instance == nullptr) {
instance.reset(new OpExpanderFactory());
}
return *instance;
static OpExpanderFactory instance;
return instance;
}
std::shared_ptr<OpExpander> GetExpander(const std::string &op) {
if (auto iter = creators.find(op); iter != creators.end()) {
@ -49,16 +42,24 @@ class OpExpanderFactory {
}
~OpExpanderFactory() = default;
private:
using RegFunc = std::function<std::shared_ptr<OpExpander>()>;
void Register(std::string &&op, RegFunc &&func) { creators.insert({op, func}); }
OpExpanderFactory() {
Register("BiasAdd", OP_EXPANDER_CREATOR(expanders::BiasAdd));
Register("ExpandDims", OP_EXPANDER_CREATOR(expanders::ExpandDims));
}
void Register(const std::string &op, const RegFunc &func) { creators[op] = func; }
private:
std::unordered_map<std::string, RegFunc> creators;
};
class OpExpanderRegister {
public:
OpExpanderRegister(const std::string &name, const OpExpanderFactory::RegFunc &func) {
OpExpanderFactory::Instance().Register(name, func);
}
~OpExpanderRegister() = default;
};
#define OP_EXPANDER_REGISTER(name, cls) \
static const OpExpanderRegister g_##cls##_expander_reg( \
name, []() -> std::shared_ptr<OpExpander> { return std::make_shared<cls>(); })
} // namespace expanders
} // namespace opt
} // namespace mindspore

View File

@ -13,25 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
#include <memory>
#include <vector>
#include "backend/optimizer/graph_kernel/model/node.h"
#include "backend/optimizer/graph_kernel/expanders/utils.h"
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
namespace mindspore {
namespace opt {
namespace expanders {
class ExpandDims : public OpExpander {
public:
ExpandDims() { validators_.emplace_back(new CheckAttr({"axis"})); }
~ExpandDims() {}
ExpandDims() {
std::initializer_list<std::string> attrs{"axis"};
validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
}
~ExpandDims() = default;
NodePtrList Expand() override {
const auto &inputs = gb.Get()->inputs();
auto &input_x = inputs[0];
const auto &input_x = inputs[0];
auto shape = MakeValue(ExpandDims::InferShape(input_x->shape, GetAxisList(this->attrs_["axis"])));
auto result = gb.Emit("Reshape", {input_x}, {{"shape", shape}});
return {result};
@ -42,9 +41,7 @@ class ExpandDims : public OpExpander {
for (auto x : axis) {
int64_t rank = static_cast<int64_t>(new_shape.size());
if (x > rank || x < -rank - 1) {
std::ostringstream oss;
oss << "ExpandDims axis " << x << " is out of range of size " << new_shape.size();
throw graphkernel::GKException(oss.str());
MS_LOG(EXCEPTION) << "ExpandDims axis " << x << " is out of range of size " << new_shape.size();
}
if (x >= 0) {
new_shape.insert(new_shape.begin() + x, 1LL);
@ -55,7 +52,11 @@ class ExpandDims : public OpExpander {
return new_shape;
}
};
OP_EXPANDER_REGISTER("ExpandDims", ExpandDims);
ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis) {
return ExpandDims::InferShape(shape, axis);
}
} // namespace expanders
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_

View File

@ -31,27 +31,31 @@ graphkernel::LiteGraphPtr OpExpander::Run(const BaseInfoList &inputs, const Base
this->outputs_info_ = outputs;
this->attrs_ = attrs;
this->processor_ = processor;
for (const auto &v : validators_) {
v->Check(*this);
if (std::any_of(validators_.begin(), validators_.end(),
[this](const std::unique_ptr<Validator> &v) { return !(v->Check(*this)); })) {
return nullptr;
}
if (!this->CheckInputs()) {
return nullptr;
}
this->CheckInputs();
for (auto &inp : inputs) {
(void)gb.Parameter(inp);
}
auto result = this->Expand();
gb.SetOutputs(result);
this->CheckOutputs();
if (!this->CheckOutputs()) {
return nullptr;
}
return gb.Get();
}
void OpExpander::CheckOutputs() {
bool OpExpander::CheckOutputs() {
// check the output shape/type/format are same as the original basic node's output.
const NodePtrList &outputs = gb.Get()->GetOutputs();
if (outputs.size() != this->outputs_info_.size()) {
std::ostringstream oss;
oss << "the output num was not equal to the original output num : " << outputs.size() << " vs "
<< outputs_info_.size();
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "the output num was not equal to the original output num : " << outputs.size() << " vs "
<< outputs_info_.size();
return false;
}
for (size_t i = 0; i < outputs.size(); i++) {
if (outputs[i]->shape != outputs_info_[i].shape) {
@ -65,21 +69,21 @@ void OpExpander::CheckOutputs() {
oss << s << ",";
}
oss << "]";
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << oss.str();
return false;
}
if (outputs[i]->type != outputs_info_[i].type) {
std::ostringstream oss;
oss << "Op " << this->op_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: ["
<< outputs_info_[i].type << "]";
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "Op " << this->op_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: ["
<< outputs_info_[i].type << "]";
return false;
}
if (outputs[i]->format != outputs_info_[i].format) {
std::ostringstream oss;
oss << "Op " << this->op_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: ["
<< outputs_info_[i].format << "]";
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "Op " << this->op_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: ["
<< outputs_info_[i].format << "]";
return false;
}
}
return true;
}
std::vector<int64_t> GetAxisList(const ValuePtr &value) {

View File

@ -37,9 +37,9 @@ class OpExpander {
virtual ~OpExpander() = default;
protected:
virtual void CheckInputs() {}
virtual bool CheckInputs() { return true; }
virtual NodePtrList Expand() = 0;
void CheckOutputs();
bool CheckOutputs();
graphkernel::LiteGraph::GraphBuilder gb;
std::string op_;
@ -57,37 +57,36 @@ class OpExpander {
class Validator {
public:
virtual void Check(const OpExpander &e) = 0;
virtual bool Check(const OpExpander &e) = 0;
};
class CheckAllFormatsSame : public Validator {
public:
void Check(const OpExpander &e) override {
if (e.inputs_info_.empty()) return;
bool Check(const OpExpander &e) override {
if (e.inputs_info_.empty()) return true;
const auto &fmt_0 = e.inputs_info_[0].format;
for (size_t i = 1; i < e.inputs_info_.size(); i++) {
if (e.inputs_info_[i].format != fmt_0) {
std::ostringstream oss;
oss << "Unmatched format for op " << e.op_;
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "Unmatched format for op " << e.op_;
return false;
}
}
return true;
}
};
class CheckAttr : public Validator {
public:
CheckAttr() = default;
CheckAttr(std::initializer_list<std::string> l) : attrs_(l) {}
~CheckAttr() = default;
void Check(const OpExpander &e) override {
bool Check(const OpExpander &e) override {
for (auto &a : attrs_) {
if (e.attrs_.count(a) == 0) {
std::ostringstream oss;
oss << "attr " << a << " does not exist. op " << e.op_;
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "attr " << a << " does not exist. op " << e.op_;
return false;
}
}
return true;
}
private:
@ -97,7 +96,7 @@ class CheckAttr : public Validator {
class SupportFormat : public Validator {
public:
void AddFormat(std::initializer_list<std::string> l) { formats_.emplace_back(l); }
void Check(const OpExpander &e) override {
bool Check(const OpExpander &e) override {
for (auto &formats : formats_) {
if (formats.size() != e.inputs_info_.size()) {
continue;
@ -110,12 +109,11 @@ class SupportFormat : public Validator {
}
}
if (match) {
return;
return true;
}
}
std::ostringstream oss;
oss << "unsupported format for op " << e.op_;
throw graphkernel::GKException(oss.str());
MS_LOG(INFO) << "unsupported format for op " << e.op_;
return false;
}
private:
@ -123,6 +121,7 @@ class SupportFormat : public Validator {
};
std::vector<int64_t> GetAxisList(const ValuePtr &value);
ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis);
} // namespace expanders
} // namespace opt
} // namespace mindspore

View File

@ -153,13 +153,12 @@ FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
outputs[i].format = AnfAlgo::GetOutputFormat(node, i);
}
auto &attrs = AnfAlgo::GetCNodePrimitive(node)->attrs();
try {
auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext());
return LiteGraph2AnfGraph(litegraph);
} catch (const graphkernel::GKException &e) {
MS_LOG(INFO) << e.what() << ", undo expanding this op";
auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext());
if (litegraph == nullptr) {
MS_LOG(INFO) << "undo expanding " << node->fullname_with_scope();
return nullptr;
}
return LiteGraph2AnfGraph(litegraph);
}
AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {

View File

@ -152,15 +152,6 @@ class OutputNode : public Node {
void Dump(std::ostringstream &os) const override { ; }
NType NodeType() override { return NType::Output; }
};
class GKException : public std::exception {
public:
explicit GKException(const std::string &message) : msg_(message) {}
const char *what() const noexcept override { return msg_.c_str(); }
protected:
std::string msg_;
};
} // namespace graphkernel
} // namespace opt
} // namespace mindspore

View File

@ -200,29 +200,35 @@ NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs, const
// default format shape to fractal_Nz format shape
DShape ToNz(const DShape &default_shape) {
if (default_shape.size() != 1 && default_shape.size() != 2) {
throw GKException("shape is too long");
constexpr size_t nz_size = 2;
auto len = default_shape.size();
DShape leading_shape;
DShape tail_shape;
if (default_shape.size() > nz_size) {
leading_shape.insert(leading_shape.end(), default_shape.begin(), default_shape.end() - nz_size);
}
DShape output_shape;
if (default_shape.size() == 1 || (default_shape.size() == 2 && default_shape[0] == 1)) {
output_shape = {default_shape[default_shape.size() - 1] / 16, 1, 1, 16};
if (default_shape[default_shape.size() - 1] % 16 != 0) {
throw GKException("should be multiplies of 16");
if (default_shape.size() == 1 || (default_shape.size() >= nz_size && default_shape[len - nz_size] == 1)) {
// (32) or (N, 1, 32) -> (N, 2, 1, 1, 16)
if (default_shape.back() % 16 != 0) {
MS_LOG(EXCEPTION) << "default_shape[-1] should be multiplies of 16, but got " << default_shape.back();
}
} else if (default_shape.size() == 2 || default_shape[1] == 1) {
output_shape = {1, default_shape[0] / 16, 16, 1};
if (default_shape[0] % 16 != 0) {
throw GKException("should be multiplies of 16");
tail_shape = {default_shape.back() / 16, 1, 1, 16};
} else if (default_shape.size() >= nz_size || default_shape[1] == 1) {
// (N, 32, 1) -> (N, 1, 2, 16, 1)
if (default_shape[len - nz_size] % 16 != 0) {
MS_LOG(EXCEPTION) << "default_shape[-2] should be multiplies of 16, but got " << default_shape[len - nz_size];
}
tail_shape = {1, default_shape[0] / 16, 16, 1};
} else {
output_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16};
if (default_shape[0] % 16 != 0 || default_shape[1] % 16 != 0) {
throw GKException("should be multiplies of 16");
// (N, 32, 48) -> (N, 3, 2, 16, 16)
if (default_shape.back() % 16 != 0 || default_shape[len - nz_size] % 16 != 0) {
MS_LOG(EXCEPTION) << "default_shape[-1] and default_shape[-2]should be multiplies of 16, but got "
<< default_shape.back() << " " << default_shape[len - nz_size];
}
tail_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16};
}
return output_shape;
leading_shape.insert(leading_shape.end(), tail_shape.begin(), tail_shape.end());
return leading_shape;
}
DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
@ -252,7 +258,7 @@ DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
output_shape[i] = align_shape[i];
}
if (output_shape[i] != align_shape[i]) {
throw GKException("shape broadcast failed");
MS_LOG(EXCEPTION) << "Shape broadcast failed. " << output_shape[i] << " vs " << align_shape[i];
}
}
}
@ -272,7 +278,7 @@ DShape ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
})) {
return BroadcastShape(inputs, true);
}
throw GKException("Only support default and fractal_nz");
MS_LOG(EXCEPTION) << "Unsupported format.";
}
DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
@ -374,22 +380,20 @@ DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
return new_shape;
}
void CheckNd(const std::vector<int64_t> &shape, size_t n) {
if (shape.size() != n) {
std::ostringstream info;
info << "input dimension should be " << n << ", but got " << shape.size();
throw GKException(info.str());
}
}
DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
auto check_nd = [](const std::vector<int64_t> &shape, size_t n) {
if (shape.size() != n) {
MS_LOG(EXCEPTION) << "input dimension should be " << n << ", but got " << shape.size();
}
};
auto shape0 = inputs[0]->shape;
auto shape1 = inputs[1]->shape;
CheckNd(shape0, 4);
CheckNd(shape1, 4);
check_nd(shape0, 4);
check_nd(shape1, 4);
CHECK_ATTR(attrs, "format");
if (inputs[0]->format != kOpFormat_NHWC && inputs[1]->format != kOpFormat_NHWC &&
GetValue<std::string>(attrs.find("format")->second) != kOpFormat_NHWC) {
throw GKException("check NHWC format failed");
MS_LOG(EXCEPTION) << "check NHWC format failed";
}
auto n = shape0[0];
auto h = shape0[1];
@ -405,10 +409,10 @@ DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
auto stride = GetListInt(attrs.find("stride")->second);
auto dilation = GetListInt(attrs.find("dilation")->second);
CheckNd(pad_list, 4);
CheckNd(kernel_size, 2);
CheckNd(stride, 4);
CheckNd(dilation, 4);
check_nd(pad_list, 4);
check_nd(kernel_size, 2);
check_nd(stride, 4);
check_nd(dilation, 4);
bool has_pad = false;
if (pad_list[0] != pad_list[1] || pad_list[2] != pad_list[3]) {
has_pad = true;
@ -464,19 +468,17 @@ DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs)
std::vector<int64_t> nhwc2nchw = {0, 3, 1, 2};
if (perm == nhwc2nchw) return kOpFormat_DEFAULT;
}
std::ostringstream info;
info << "Unsupported Transpose. ori_format = " << ori_format << ", perm = " << attrs.find("perm")->second->ToString();
throw GKException(info.str());
return kOpFormat_DEFAULT;
}
DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
std::vector<int64_t> shape0 = inputs[0]->shape;
std::vector<int64_t> shape1 = inputs[1]->shape;
if (shape0.size() != 2 || shape1.size() != 2) {
std::ostringstream info;
info << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size();
throw GKException(info.str());
MS_LOG(EXCEPTION) << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size();
}
CHECK_ATTR(attrs, "transpose_a");
CHECK_ATTR(attrs, "transpose_b");
auto transpose_a = GetValue<bool>(attrs.find("transpose_a")->second);
auto transpose_b = GetValue<bool>(attrs.find("transpose_b")->second);
int64_t m = transpose_a ? shape0[1] : shape0[0];
@ -491,6 +493,7 @@ DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
}
TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
CHECK_ATTR(attrs, "dst_type");
if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type;
auto dst_type = attrs.find("dst_type")->second;
if (dst_type->isa<Type>()) {
@ -502,6 +505,8 @@ TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
std::vector<int64_t> shape0 = inputs[0]->shape;
size_t n = shape0.size();
CHECK_ATTR(attrs, "head");
CHECK_ATTR(attrs, "tail");
std::vector<int64_t> pad_before = GetListInt(attrs.find("head")->second);
std::vector<int64_t> pad_after = GetListInt(attrs.find("tail")->second);
if (pad_before.size() != n || pad_after.size() != n) {
@ -518,6 +523,7 @@ DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
std::vector<int64_t> shape0 = inputs[0]->shape;
size_t n = shape0.size();
CHECK_ATTR(attrs, "tail");
std::vector<int64_t> unpad_after = GetListInt(attrs.find("tail")->second);
if (unpad_after.size() != n) {
MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size();
@ -531,13 +537,12 @@ DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
void ComplexOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
if (inputs[0]->type != TypeId::kNumberTypeFloat32) {
throw GKException("Complex's input[0] should be float32");
MS_LOG(EXCEPTION) << "Complex's input[0] should be float32";
}
if (inputs[0]->type != inputs[1]->type) {
MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch";
}
}
} // namespace graphkernel
} // namespace opt
} // namespace mindspore

View File

@ -251,7 +251,7 @@ class CImagOp : public ElemwiseOp {
protected:
void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
throw GKException("CImag's input[0] should be complex64");
MS_LOG(EXCEPTION) << "CImag's input[0] should be complex64";
}
};
@ -266,7 +266,7 @@ class CRealOp : public ElemwiseOp {
protected:
void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
throw GKException("CReal's input[0] should be complex64");
MS_LOG(EXCEPTION) << "CReal's input[0] should be complex64";
}
};

View File

@ -229,9 +229,7 @@ class TransformOp {
perm = perm_map[{format_b_, format_a_}];
}
if (perm.empty()) {
std::ostringstream oss;
oss << "unsupported format: " << format_a_ << " to " << format_b_;
throw graphkernel::GKException(oss.str());
MS_LOG(EXCEPTION) << "unsupported format: " << format_a_ << " to " << format_b_;
}
auto op = graphkernel::OpRegistry::Instance().NewOp("Transpose", "new_trans");
op->SetAttr("perm", MakeValue(perm));
@ -438,23 +436,19 @@ bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) {
bool changed = false;
for (auto node : todos) {
if (!AnfAlgo::IsGraphKernel(node)) continue;
try {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
if (Process(litegraph)) {
changed = true;
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
}
} catch (const graphkernel::GKException &e) {
MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph";
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
if (Process(litegraph)) {
changed = true;
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
}
}
return changed;