!22878 Remove "GKException" from c++ code.
Merge pull request !22878 from DeshiChen/0902_gkexception
commit 6d0bdd83da
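The change follows one pattern across all of these files: instead of throwing graphkernel::GKException and unwinding through try/catch blocks, failure paths now report errors with MS_LOG (EXCEPTION to abort with a message, INFO/WARNING plus a bool or nullptr return for recoverable cases), and callers check the returned value. A minimal sketch of the validator half of that pattern, with hypothetical names standing in for the real OpExpander/Validator classes:

#include <algorithm>
#include <iostream>
#include <memory>
#include <vector>

struct Expander {};  // stand-in for OpExpander

// Old style: Check() returned void and threw GKException on failure.
// New style: Check() returns bool and the caller short-circuits.
struct Validator {
  virtual ~Validator() = default;
  virtual bool Check(const Expander &e) = 0;
};

struct AlwaysOk : Validator {
  bool Check(const Expander &) override { return true; }
};

struct AlwaysFails : Validator {
  bool Check(const Expander &) override {
    std::cerr << "check failed\n";  // stands in for MS_LOG(INFO)
    return false;
  }
};

// Mirrors the new OpExpander::Run control flow: any failing validator
// turns into a nullptr result instead of a thrown exception.
std::shared_ptr<Expander> Run(const std::vector<std::unique_ptr<Validator>> &validators, const Expander &e) {
  if (std::any_of(validators.begin(), validators.end(),
                  [&e](const std::unique_ptr<Validator> &v) { return !v->Check(e); })) {
    return nullptr;
  }
  return std::make_shared<Expander>(e);
}

int main() {
  std::vector<std::unique_ptr<Validator>> validators;
  validators.push_back(std::make_unique<AlwaysOk>());
  validators.push_back(std::make_unique<AlwaysFails>());
  Expander e;
  std::cout << (Run(validators, e) == nullptr ? "expansion undone" : "expanded") << "\n";
}

Callers such as DefaultExpander::CreateExpandFuncGraph then test for nullptr and skip the expansion, which is what the removed catch blocks used to do.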
@@ -158,12 +158,10 @@ PatternNodePtr PatternTree::BuildTree(const std::string &pattern_str) {
       }
       cur_node->AddInput(BuildTree(op_inputs));
       return cur_node;
-
     } else {
       return std::make_shared<PatternNode>(pattern_str);
     }
   }
-
   return nullptr;
 }
@@ -276,7 +274,6 @@ bool DfsMatchGraph(const graphkernel::NodePtr &tmp_node, const PatternNodePtr &t
         return false;
       }
     }
-
   } else {
     for (size_t i = 0; i < tmp_pattern_inputs.size(); i++) {
       if (!DfsMatchGraph(tmp_node_inputs[i], tmp_pattern_inputs[i], para_to_ref, const_to_ref, res)) {
@@ -387,7 +384,6 @@ class ExtraReduce1PatternTree : public PatternTree {
       for (auto &i : GetValue<std::vector<int64_t>>(first_reduce->attrs().find("axis")->second)) {
         axis_set.insert(i);
       }
-
     } else {
       auto first_axis = GetValue<std::vector<int64_t>>(first_reduce->attrs().find("axis")->second);
       auto second_axis = GetValue<std::vector<int64_t>>(origin_root->attrs().find("axis")->second);
@@ -538,7 +534,6 @@ std::unordered_map<std::string, std::vector<PatternTreePtr>> GetExpressions() {
   std::unordered_set<std::string> enable_ids{flags.enable_simplify_exprs_only.begin(),
                                              flags.enable_simplify_exprs_only.end()};
   std::unordered_set<std::string> disable_ids{flags.disable_simplify_exprs.begin(), flags.disable_simplify_exprs.end()};
-
   for (auto &e : expressions) {
     if (!enable_ids.empty()) {
       if (enable_ids.count(std::to_string(e.id)) == 0) continue;
@@ -640,33 +635,29 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
   expressions_map_ = GetExpressions();
   for (auto node : func_graph->GetOrderedCnodes()) {
     if (AnfAlgo::IsGraphKernel(node)) {
-      try {
-        auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
-        graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph);
-        bool find_pattern = true;
-        bool change_anf_graph = false;
-        while (find_pattern) {
-          find_pattern = false;
-          find_pattern = DoArithmeticTrans(lg) || find_pattern;
-          find_pattern = DoConstantFold(lg) || find_pattern;
-          change_anf_graph = change_anf_graph || find_pattern;
-        }
-        if (!change_anf_graph) continue;
-        ReorganizeEmptyGraph(lg);
-        AnfNodePtrList outputs;
-        auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
-        new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
-        auto cnode = node->cast<CNodePtr>();
-        AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
-        EliminateRedundantParameters(new_funcgraph, &inputs);
-        auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
-        SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
-        mng->Replace(node, new_node);
-        mng->AddFuncGraph(new_funcgraph);
-        do_simplify = true;
-      } catch (const graphkernel::GKException &e) {
-        MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph";
-      }
+      auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+      graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph);
+      bool find_pattern = true;
+      bool change_anf_graph = false;
+      while (find_pattern) {
+        find_pattern = false;
+        find_pattern = DoArithmeticTrans(lg) || find_pattern;
+        find_pattern = DoConstantFold(lg) || find_pattern;
+        change_anf_graph = change_anf_graph || find_pattern;
+      }
+      if (!change_anf_graph) continue;
+      ReorganizeEmptyGraph(lg);
+      AnfNodePtrList outputs;
+      auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
+      new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
+      auto cnode = node->cast<CNodePtr>();
+      AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
+      EliminateRedundantParameters(new_funcgraph, &inputs);
+      auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
+      SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
+      mng->Replace(node, new_node);
+      mng->AddFuncGraph(new_funcgraph);
+      do_simplify = true;
     }
   }
   return do_simplify;
@@ -13,14 +13,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
 
 #include <string>
 #include <utility>
 #include <vector>
 #include <memory>
 
+#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
 #include "backend/optimizer/graph_kernel/expanders/utils.h"
 
 namespace mindspore {
@@ -34,7 +32,8 @@ class BiasAdd : public OpExpander {
     support_format->AddFormat({kOpFormat_NCHW, kOpFormat_DEFAULT});
     support_format->AddFormat({kOpFormat_NHWC, kOpFormat_DEFAULT});
     validators_.emplace_back(std::move(support_format));
-    validators_.emplace_back(new CheckAttr({"format"}));
+    auto attrs = std::initializer_list<std::string>{"format"};
+    validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
   }
   ~BiasAdd() = default;
   NodePtrList Expand() override {
@@ -42,19 +41,19 @@ class BiasAdd : public OpExpander {
     auto input_x = inputs[0];
     auto input_y = inputs[1];
     if (input_x->format == kOpFormat_NCHW) {
-      input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, {1, 2}))}});
+      input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDimsInferShape(input_y->shape, {1, 2}))}});
     } else if (input_x->format == kOpFormat_DEFAULT) {
       auto data_format = GetValue<std::string>(attrs_["format"]);
       size_t channel_idx = (data_format == kOpFormat_NHWC) ? input_x->shape.size() - 1 : 1;
       std::vector<int64_t> axis(input_x->shape.size() - channel_idx - 1, -1);
       if (!axis.empty()) {
-        input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, axis))}});
+        input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDimsInferShape(input_y->shape, axis))}});
       }
     }
     return {gb.Emit("Add", {input_x, input_y})};
   }
 };
+OP_EXPANDER_REGISTER("BiasAdd", BiasAdd);
 } // namespace expanders
 } // namespace opt
 } // namespace mindspore
 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
@@ -22,22 +22,15 @@
 #include <memory>
 
 #include "backend/optimizer/graph_kernel/expanders/utils.h"
-#include "backend/optimizer/graph_kernel/expanders/reshape.h"
-#include "backend/optimizer/graph_kernel/expanders/bias_add.h"
 
 namespace mindspore {
 namespace opt {
 namespace expanders {
-#define OP_EXPANDER_CREATOR(cls) []() -> std::shared_ptr<OpExpander> { return std::make_shared<cls>(); }
-
 class OpExpanderFactory {
  public:
   static OpExpanderFactory &Instance() {
-    static std::unique_ptr<OpExpanderFactory> instance = nullptr;
-    if (instance == nullptr) {
-      instance.reset(new OpExpanderFactory());
-    }
-    return *instance;
+    static OpExpanderFactory instance;
+    return instance;
   }
   std::shared_ptr<OpExpander> GetExpander(const std::string &op) {
     if (auto iter = creators.find(op); iter != creators.end()) {
@@ -49,16 +42,24 @@ class OpExpanderFactory {
   }
   ~OpExpanderFactory() = default;
 
- private:
   using RegFunc = std::function<std::shared_ptr<OpExpander>()>;
-  void Register(std::string &&op, RegFunc &&func) { creators.insert({op, func}); }
-  OpExpanderFactory() {
-    Register("BiasAdd", OP_EXPANDER_CREATOR(expanders::BiasAdd));
-    Register("ExpandDims", OP_EXPANDER_CREATOR(expanders::ExpandDims));
-  }
+  void Register(const std::string &op, const RegFunc &func) { creators[op] = func; }
+
+ private:
   std::unordered_map<std::string, RegFunc> creators;
 };
 
+class OpExpanderRegister {
+ public:
+  OpExpanderRegister(const std::string &name, const OpExpanderFactory::RegFunc &func) {
+    OpExpanderFactory::Instance().Register(name, func);
+  }
+  ~OpExpanderRegister() = default;
+};
+
+#define OP_EXPANDER_REGISTER(name, cls) \
+  static const OpExpanderRegister g_##cls##_expander_reg( \
+    name, []() -> std::shared_ptr<OpExpander> { return std::make_shared<cls>(); })
 } // namespace expanders
 } // namespace opt
 } // namespace mindspore
@@ -13,25 +13,24 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
 
 #include <memory>
 #include <vector>
 
 #include "backend/optimizer/graph_kernel/model/node.h"
 #include "backend/optimizer/graph_kernel/expanders/utils.h"
+#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
 
 namespace mindspore {
 namespace opt {
 namespace expanders {
 class ExpandDims : public OpExpander {
  public:
-  ExpandDims() { validators_.emplace_back(new CheckAttr({"axis"})); }
-  ~ExpandDims() {}
+  ExpandDims() {
+    std::initializer_list<std::string> attrs{"axis"};
+    validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
+  }
+  ~ExpandDims() = default;
   NodePtrList Expand() override {
     const auto &inputs = gb.Get()->inputs();
-    auto &input_x = inputs[0];
+    const auto &input_x = inputs[0];
     auto shape = MakeValue(ExpandDims::InferShape(input_x->shape, GetAxisList(this->attrs_["axis"])));
     auto result = gb.Emit("Reshape", {input_x}, {{"shape", shape}});
     return {result};
@@ -42,9 +41,7 @@ class ExpandDims : public OpExpander {
     for (auto x : axis) {
       int64_t rank = static_cast<int64_t>(new_shape.size());
       if (x > rank || x < -rank - 1) {
-        std::ostringstream oss;
-        oss << "ExpandDims axis " << x << " is out of range of size " << new_shape.size();
-        throw graphkernel::GKException(oss.str());
+        MS_LOG(EXCEPTION) << "ExpandDims axis " << x << " is out of range of size " << new_shape.size();
       }
       if (x >= 0) {
         new_shape.insert(new_shape.begin() + x, 1LL);
@@ -55,7 +52,11 @@ class ExpandDims : public OpExpander {
     return new_shape;
   }
 };
-
+OP_EXPANDER_REGISTER("ExpandDims", ExpandDims);
+
+ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis) {
+  return ExpandDims::InferShape(shape, axis);
+}
 } // namespace expanders
 } // namespace opt
 } // namespace mindspore
 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
@@ -31,27 +31,31 @@ graphkernel::LiteGraphPtr OpExpander::Run(const BaseInfoList &inputs, const Base
   this->outputs_info_ = outputs;
   this->attrs_ = attrs;
   this->processor_ = processor;
-  for (const auto &v : validators_) {
-    v->Check(*this);
+  if (std::any_of(validators_.begin(), validators_.end(),
+                  [this](const std::unique_ptr<Validator> &v) { return !(v->Check(*this)); })) {
+    return nullptr;
   }
-  this->CheckInputs();
+  if (!this->CheckInputs()) {
+    return nullptr;
+  }
   for (auto &inp : inputs) {
     (void)gb.Parameter(inp);
   }
   auto result = this->Expand();
   gb.SetOutputs(result);
-  this->CheckOutputs();
+  if (!this->CheckOutputs()) {
+    return nullptr;
+  }
   return gb.Get();
 }
 
-void OpExpander::CheckOutputs() {
+bool OpExpander::CheckOutputs() {
   // check the output shape/type/format are same as the original basic node's output.
   const NodePtrList &outputs = gb.Get()->GetOutputs();
   if (outputs.size() != this->outputs_info_.size()) {
-    std::ostringstream oss;
-    oss << "the output num was not equal to the original output num : " << outputs.size() << " vs "
-        << outputs_info_.size();
-    throw graphkernel::GKException(oss.str());
+    MS_LOG(INFO) << "the output num was not equal to the original output num : " << outputs.size() << " vs "
+                 << outputs_info_.size();
+    return false;
   }
   for (size_t i = 0; i < outputs.size(); i++) {
     if (outputs[i]->shape != outputs_info_[i].shape) {
@@ -65,21 +69,21 @@ void OpExpander::CheckOutputs() {
         oss << s << ",";
       }
       oss << "]";
-      throw graphkernel::GKException(oss.str());
+      MS_LOG(INFO) << oss.str();
+      return false;
     }
     if (outputs[i]->type != outputs_info_[i].type) {
-      std::ostringstream oss;
-      oss << "Op " << this->op_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: ["
-          << outputs_info_[i].type << "]";
-      throw graphkernel::GKException(oss.str());
+      MS_LOG(INFO) << "Op " << this->op_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: ["
+                   << outputs_info_[i].type << "]";
+      return false;
     }
     if (outputs[i]->format != outputs_info_[i].format) {
-      std::ostringstream oss;
-      oss << "Op " << this->op_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: ["
-          << outputs_info_[i].format << "]";
-      throw graphkernel::GKException(oss.str());
+      MS_LOG(INFO) << "Op " << this->op_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: ["
+                   << outputs_info_[i].format << "]";
+      return false;
     }
   }
+  return true;
 }
 
 std::vector<int64_t> GetAxisList(const ValuePtr &value) {
@@ -37,9 +37,9 @@ class OpExpander {
   virtual ~OpExpander() = default;
 
  protected:
-  virtual void CheckInputs() {}
+  virtual bool CheckInputs() { return true; }
   virtual NodePtrList Expand() = 0;
-  void CheckOutputs();
+  bool CheckOutputs();
 
   graphkernel::LiteGraph::GraphBuilder gb;
   std::string op_;
@@ -57,37 +57,36 @@ class OpExpander {
 
 class Validator {
  public:
-  virtual void Check(const OpExpander &e) = 0;
+  virtual bool Check(const OpExpander &e) = 0;
 };
 
 class CheckAllFormatsSame : public Validator {
  public:
-  void Check(const OpExpander &e) override {
-    if (e.inputs_info_.empty()) return;
+  bool Check(const OpExpander &e) override {
+    if (e.inputs_info_.empty()) return true;
     const auto &fmt_0 = e.inputs_info_[0].format;
     for (size_t i = 1; i < e.inputs_info_.size(); i++) {
       if (e.inputs_info_[i].format != fmt_0) {
-        std::ostringstream oss;
-        oss << "Unmatched format for op " << e.op_;
-        throw graphkernel::GKException(oss.str());
+        MS_LOG(INFO) << "Unmatched format for op " << e.op_;
+        return false;
       }
     }
+    return true;
   }
 };
 
 class CheckAttr : public Validator {
  public:
   CheckAttr() = default;
   CheckAttr(std::initializer_list<std::string> l) : attrs_(l) {}
   ~CheckAttr() = default;
-  void Check(const OpExpander &e) override {
+  bool Check(const OpExpander &e) override {
     for (auto &a : attrs_) {
       if (e.attrs_.count(a) == 0) {
-        std::ostringstream oss;
-        oss << "attr " << a << " does not exist. op " << e.op_;
-        throw graphkernel::GKException(oss.str());
+        MS_LOG(INFO) << "attr " << a << " does not exist. op " << e.op_;
+        return false;
       }
     }
+    return true;
   }
 
  private:
@@ -97,7 +96,7 @@ class CheckAttr : public Validator {
 class SupportFormat : public Validator {
  public:
   void AddFormat(std::initializer_list<std::string> l) { formats_.emplace_back(l); }
-  void Check(const OpExpander &e) override {
+  bool Check(const OpExpander &e) override {
     for (auto &formats : formats_) {
       if (formats.size() != e.inputs_info_.size()) {
         continue;
@@ -110,12 +109,11 @@ class SupportFormat : public Validator {
         }
       }
       if (match) {
-        return;
+        return true;
       }
     }
-    std::ostringstream oss;
-    oss << "unsupported format for op " << e.op_;
-    throw graphkernel::GKException(oss.str());
+    MS_LOG(INFO) << "unsupported format for op " << e.op_;
+    return false;
   }
 
  private:
@@ -123,6 +121,7 @@ class SupportFormat : public Validator {
 };
 
 std::vector<int64_t> GetAxisList(const ValuePtr &value);
+ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis);
 } // namespace expanders
 } // namespace opt
 } // namespace mindspore
@@ -153,13 +153,12 @@ FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
     outputs[i].format = AnfAlgo::GetOutputFormat(node, i);
   }
   auto &attrs = AnfAlgo::GetCNodePrimitive(node)->attrs();
-  try {
-    auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext());
-    return LiteGraph2AnfGraph(litegraph);
-  } catch (const graphkernel::GKException &e) {
-    MS_LOG(INFO) << e.what() << ", undo expanding this op";
+  auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext());
+  if (litegraph == nullptr) {
+    MS_LOG(INFO) << "undo expanding " << node->fullname_with_scope();
     return nullptr;
   }
+  return LiteGraph2AnfGraph(litegraph);
 }
 
 AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
@@ -152,15 +152,6 @@ class OutputNode : public Node {
   void Dump(std::ostringstream &os) const override { ; }
   NType NodeType() override { return NType::Output; }
 };
-
-class GKException : public std::exception {
- public:
-  explicit GKException(const std::string &message) : msg_(message) {}
-  const char *what() const noexcept override { return msg_.c_str(); }
-
- protected:
-  std::string msg_;
-};
 } // namespace graphkernel
 } // namespace opt
 } // namespace mindspore
@@ -200,29 +200,35 @@ NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs, const
 
 // default format shape to fractal_Nz format shape
 DShape ToNz(const DShape &default_shape) {
-  if (default_shape.size() != 1 && default_shape.size() != 2) {
-    throw GKException("shape is too long");
+  constexpr size_t nz_size = 2;
+  auto len = default_shape.size();
+  DShape leading_shape;
+  DShape tail_shape;
+  if (default_shape.size() > nz_size) {
+    leading_shape.insert(leading_shape.end(), default_shape.begin(), default_shape.end() - nz_size);
   }
-  DShape output_shape;
-  if (default_shape.size() == 1 || (default_shape.size() == 2 && default_shape[0] == 1)) {
-    output_shape = {default_shape[default_shape.size() - 1] / 16, 1, 1, 16};
-    if (default_shape[default_shape.size() - 1] % 16 != 0) {
-      throw GKException("should be multiplies of 16");
+  if (default_shape.size() == 1 || (default_shape.size() >= nz_size && default_shape[len - nz_size] == 1)) {
+    // (32) or (N, 1, 32) -> (N, 2, 1, 1, 16)
+    if (default_shape.back() % 16 != 0) {
+      MS_LOG(EXCEPTION) << "default_shape[-1] should be multiplies of 16, but got " << default_shape.back();
     }
-
-  } else if (default_shape.size() == 2 || default_shape[1] == 1) {
-    output_shape = {1, default_shape[0] / 16, 16, 1};
-    if (default_shape[0] % 16 != 0) {
-      throw GKException("should be multiplies of 16");
+    tail_shape = {default_shape.back() / 16, 1, 1, 16};
+  } else if (default_shape.size() >= nz_size || default_shape[1] == 1) {
+    // (N, 32, 1) -> (N, 1, 2, 16, 1)
+    if (default_shape[len - nz_size] % 16 != 0) {
+      MS_LOG(EXCEPTION) << "default_shape[-2] should be multiplies of 16, but got " << default_shape[len - nz_size];
     }
-
+    tail_shape = {1, default_shape[0] / 16, 16, 1};
   } else {
-    output_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16};
-    if (default_shape[0] % 16 != 0 || default_shape[1] % 16 != 0) {
-      throw GKException("should be multiplies of 16");
+    // (N, 32, 48) -> (N, 3, 2, 16, 16)
+    if (default_shape.back() % 16 != 0 || default_shape[len - nz_size] % 16 != 0) {
+      MS_LOG(EXCEPTION) << "default_shape[-1] and default_shape[-2]should be multiplies of 16, but got "
+                        << default_shape.back() << " " << default_shape[len - nz_size];
     }
+    tail_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16};
   }
-  return output_shape;
+  leading_shape.insert(leading_shape.end(), tail_shape.begin(), tail_shape.end());
+  return leading_shape;
 }
 
 DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
@@ -252,7 +258,7 @@ DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
         output_shape[i] = align_shape[i];
       }
       if (output_shape[i] != align_shape[i]) {
-        throw GKException("shape broadcast failed");
+        MS_LOG(EXCEPTION) << "Shape broadcast failed. " << output_shape[i] << " vs " << align_shape[i];
       }
     }
   }
@@ -272,7 +278,7 @@ DShape ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
   })) {
     return BroadcastShape(inputs, true);
   }
-  throw GKException("Only support default and fractal_nz");
+  MS_LOG(EXCEPTION) << "Unsupported format.";
 }
 
 DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
@@ -374,22 +380,20 @@ DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
   return new_shape;
 }
 
-void CheckNd(const std::vector<int64_t> &shape, size_t n) {
-  if (shape.size() != n) {
-    std::ostringstream info;
-    info << "input dimension should be " << n << ", but got " << shape.size();
-    throw GKException(info.str());
-  }
-}
-
 DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
+  auto check_nd = [](const std::vector<int64_t> &shape, size_t n) {
+    if (shape.size() != n) {
+      MS_LOG(EXCEPTION) << "input dimension should be " << n << ", but got " << shape.size();
+    }
+  };
   auto shape0 = inputs[0]->shape;
   auto shape1 = inputs[1]->shape;
-  CheckNd(shape0, 4);
-  CheckNd(shape1, 4);
+  check_nd(shape0, 4);
+  check_nd(shape1, 4);
   CHECK_ATTR(attrs, "format");
   if (inputs[0]->format != kOpFormat_NHWC && inputs[1]->format != kOpFormat_NHWC &&
       GetValue<std::string>(attrs.find("format")->second) != kOpFormat_NHWC) {
-    throw GKException("check NHWC format failed");
+    MS_LOG(EXCEPTION) << "check NHWC format failed";
   }
   auto n = shape0[0];
   auto h = shape0[1];
@@ -405,10 +409,10 @@ DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
   auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
   auto stride = GetListInt(attrs.find("stride")->second);
   auto dilation = GetListInt(attrs.find("dilation")->second);
-  CheckNd(pad_list, 4);
-  CheckNd(kernel_size, 2);
-  CheckNd(stride, 4);
-  CheckNd(dilation, 4);
+  check_nd(pad_list, 4);
+  check_nd(kernel_size, 2);
+  check_nd(stride, 4);
+  check_nd(dilation, 4);
   bool has_pad = false;
   if (pad_list[0] != pad_list[1] || pad_list[2] != pad_list[3]) {
     has_pad = true;
@@ -464,19 +468,17 @@ DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs)
     std::vector<int64_t> nhwc2nchw = {0, 3, 1, 2};
     if (perm == nhwc2nchw) return kOpFormat_DEFAULT;
   }
-  std::ostringstream info;
-  info << "Unsupported Transpose. ori_format = " << ori_format << ", perm = " << attrs.find("perm")->second->ToString();
-  throw GKException(info.str());
+  return kOpFormat_DEFAULT;
 }
 
 DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
   std::vector<int64_t> shape0 = inputs[0]->shape;
   std::vector<int64_t> shape1 = inputs[1]->shape;
   if (shape0.size() != 2 || shape1.size() != 2) {
-    std::ostringstream info;
-    info << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size();
-    throw GKException(info.str());
+    MS_LOG(EXCEPTION) << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size();
   }
   CHECK_ATTR(attrs, "transpose_a");
   CHECK_ATTR(attrs, "transpose_b");
   auto transpose_a = GetValue<bool>(attrs.find("transpose_a")->second);
   auto transpose_b = GetValue<bool>(attrs.find("transpose_b")->second);
   int64_t m = transpose_a ? shape0[1] : shape0[0];
@@ -491,6 +493,7 @@ DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
 }
 
 TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
   CHECK_ATTR(attrs, "dst_type");
+  if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type;
   auto dst_type = attrs.find("dst_type")->second;
   if (dst_type->isa<Type>()) {
@@ -502,6 +505,8 @@ TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
 DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
   std::vector<int64_t> shape0 = inputs[0]->shape;
   size_t n = shape0.size();
+  CHECK_ATTR(attrs, "head");
+  CHECK_ATTR(attrs, "tail");
   std::vector<int64_t> pad_before = GetListInt(attrs.find("head")->second);
   std::vector<int64_t> pad_after = GetListInt(attrs.find("tail")->second);
   if (pad_before.size() != n || pad_after.size() != n) {
@@ -518,6 +523,7 @@ DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
 DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
   std::vector<int64_t> shape0 = inputs[0]->shape;
   size_t n = shape0.size();
+  CHECK_ATTR(attrs, "tail");
   std::vector<int64_t> unpad_after = GetListInt(attrs.find("tail")->second);
   if (unpad_after.size() != n) {
     MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size();
@@ -531,13 +537,12 @@ DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
 
 void ComplexOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
   if (inputs[0]->type != TypeId::kNumberTypeFloat32) {
-    throw GKException("Complex's input[0] should be float32");
+    MS_LOG(EXCEPTION) << "Complex's input[0] should be float32";
   }
   if (inputs[0]->type != inputs[1]->type) {
     MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch";
   }
 }
 
 } // namespace graphkernel
 } // namespace opt
 } // namespace mindspore
@@ -251,7 +251,7 @@ class CImagOp : public ElemwiseOp {
  protected:
   void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
     if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
-      throw GKException("CImag's input[0] should be complex64");
+      MS_LOG(EXCEPTION) << "CImag's input[0] should be complex64";
     }
   };
@@ -266,7 +266,7 @@ class CRealOp : public ElemwiseOp {
  protected:
   void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
     if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
-      throw GKException("CReal's input[0] should be complex64");
+      MS_LOG(EXCEPTION) << "CReal's input[0] should be complex64";
     }
   };
@@ -229,9 +229,7 @@ class TransformOp {
       perm = perm_map[{format_b_, format_a_}];
     }
     if (perm.empty()) {
-      std::ostringstream oss;
-      oss << "unsupported format: " << format_a_ << " to " << format_b_;
-      throw graphkernel::GKException(oss.str());
+      MS_LOG(EXCEPTION) << "unsupported format: " << format_a_ << " to " << format_b_;
     }
     auto op = graphkernel::OpRegistry::Instance().NewOp("Transpose", "new_trans");
     op->SetAttr("perm", MakeValue(perm));
@@ -438,23 +436,19 @@ bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) {
   bool changed = false;
   for (auto node : todos) {
     if (!AnfAlgo::IsGraphKernel(node)) continue;
-    try {
-      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
-      auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
-      if (Process(litegraph)) {
-        changed = true;
-        AnfNodePtrList outputs;
-        auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
-        new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
-        auto cnode = node->cast<CNodePtr>();
-        AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
-        auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
-        SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
-        mng->Replace(node, new_node);
-        mng->AddFuncGraph(new_funcgraph);
-      }
-    } catch (const graphkernel::GKException &e) {
-      MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph";
-    }
+    auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+    auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
+    if (Process(litegraph)) {
+      changed = true;
+      AnfNodePtrList outputs;
+      auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
+      new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
+      auto cnode = node->cast<CNodePtr>();
+      AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
+      auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
+      SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
+      mng->Replace(node, new_node);
+      mng->AddFuncGraph(new_funcgraph);
+    }
   }
   return changed;