forked from mindspore-Ecosystem/mindspore
Add expanders in c++ code
transplant the op expander code from python to c++, base on LiteGraph. the c++ expander will be called in priority if it was registered in OpExpanderFactory. add two examples, BiasAdd and ExpandDims. remove BiasAdd from python expanders. since the ExpandDims is also imported by other ops (e.g. BatchNorm), we don't remove it now.
This commit is contained in:
parent
dd7c7dcb83
commit
9add26ad99
|
@ -18,7 +18,6 @@ from .addn import AddN
|
|||
from .assign_add import AssignAdd
|
||||
from .batchnorm import BatchNorm
|
||||
from .batchnorm_grad import BatchNormGrad
|
||||
from .bias_add import BiasAdd
|
||||
from .bias_add_grad import BiasAddGrad
|
||||
from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
|
||||
from .conv2d import Conv2D
|
||||
|
@ -26,7 +25,6 @@ from .complex import CAbs, CAdd, CDiv, CMul, CSub
|
|||
from .dropout_grad import DropoutGrad
|
||||
from .equal_count import EqualCount
|
||||
from .erfc import Erfc
|
||||
from .expand_dims import ExpandDims
|
||||
from .fused_adam import FusedAdam
|
||||
from .fused_adam_weight_decay import FusedAdamWeightDecay
|
||||
from .fused_mul_add import FusedMulAdd
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for bias_add"""
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
from .expand_dims import ExpandDims
|
||||
|
||||
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
|
||||
@VLD.add_format(DF.NCHW, DF.DEFAULT)
|
||||
@VLD.add_format(DF.NHWC, DF.DEFAULT)
|
||||
class BiasAdd(Expander):
|
||||
"""BiasAdd expander"""
|
||||
|
||||
def _expand(self, graph_builder):
|
||||
input_x, input_y = self.inputs
|
||||
|
||||
if input_x.data_format == DF.NCHW:
|
||||
input_y_expand = graph_builder.emit(
|
||||
'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, [1, 2])})
|
||||
result = graph_builder.emit('Add', [input_x, input_y_expand])
|
||||
elif input_x.data_format == DF.DEFAULT:
|
||||
if len(input_x.shape) == 2:
|
||||
result = graph_builder.emit('Add', [input_x, input_y])
|
||||
elif len(input_x.shape) == 3:
|
||||
input_y_expand = graph_builder.emit(
|
||||
'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, 1)})
|
||||
result = graph_builder.emit('Add', [input_x, input_y_expand])
|
||||
else: # len == 4
|
||||
input_y_expand = graph_builder.emit(
|
||||
'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, [1, 2])})
|
||||
result = graph_builder.emit('Add', [input_x, input_y_expand])
|
||||
else: # NHWC
|
||||
result = graph_builder.emit('Add', [input_x, input_y])
|
||||
|
||||
return result
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/expanders/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace expanders {
|
||||
class BiasAdd : public OpExpander {
|
||||
public:
|
||||
BiasAdd() {
|
||||
auto support_format = std::make_unique<SupportFormat>();
|
||||
support_format->AddFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||
support_format->AddFormat({kOpFormat_NCHW, kOpFormat_DEFAULT});
|
||||
support_format->AddFormat({kOpFormat_NHWC, kOpFormat_DEFAULT});
|
||||
validators_.emplace_back(std::move(support_format));
|
||||
validators_.emplace_back(new CheckAttr({"format"}));
|
||||
}
|
||||
~BiasAdd() = default;
|
||||
NodePtrList Expand() override {
|
||||
const auto &inputs = gb.Get()->inputs();
|
||||
auto input_x = inputs[0];
|
||||
auto input_y = inputs[1];
|
||||
if (input_x->format == kOpFormat_NCHW) {
|
||||
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, {1, 2}))}});
|
||||
} else if (input_x->format == kOpFormat_DEFAULT) {
|
||||
auto data_format = GetValue<std::string>(attrs_["format"]);
|
||||
size_t channel_idx = (data_format == kOpFormat_NHWC) ? input_x->shape.size() - 1 : 1;
|
||||
std::vector<int64_t> axis(input_x->shape.size() - channel_idx - 1, -1);
|
||||
if (!axis.empty()) {
|
||||
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, axis))}});
|
||||
}
|
||||
}
|
||||
return {gb.Emit("Add", {input_x, input_y})};
|
||||
}
|
||||
};
|
||||
} // namespace expanders
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/expanders/utils.h"
|
||||
#include "backend/optimizer/graph_kernel/expanders/reshape.h"
|
||||
#include "backend/optimizer/graph_kernel/expanders/bias_add.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace expanders {
|
||||
#define OP_EXPANDER_CREATOR(cls) []() -> std::shared_ptr<OpExpander> { return std::make_shared<cls>(); }
|
||||
|
||||
class OpExpanderFactory {
|
||||
public:
|
||||
static OpExpanderFactory &Instance() {
|
||||
static std::unique_ptr<OpExpanderFactory> instance = nullptr;
|
||||
if (instance == nullptr) {
|
||||
instance.reset(new OpExpanderFactory());
|
||||
}
|
||||
return *instance;
|
||||
}
|
||||
std::shared_ptr<OpExpander> GetExpander(const std::string &op) {
|
||||
if (auto iter = creators.find(op); iter != creators.end()) {
|
||||
auto expander_ptr = iter->second();
|
||||
expander_ptr->op_ = op;
|
||||
return expander_ptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
~OpExpanderFactory() = default;
|
||||
|
||||
private:
|
||||
using RegFunc = std::function<std::shared_ptr<OpExpander>()>;
|
||||
void Register(std::string &&op, RegFunc &&func) { creators.insert({op, func}); }
|
||||
OpExpanderFactory() {
|
||||
Register("BiasAdd", OP_EXPANDER_CREATOR(expanders::BiasAdd));
|
||||
Register("ExpandDims", OP_EXPANDER_CREATOR(expanders::ExpandDims));
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, RegFunc> creators;
|
||||
};
|
||||
} // namespace expanders
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/model/node.h"
|
||||
#include "backend/optimizer/graph_kernel/expanders/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace expanders {
|
||||
class ExpandDims : public OpExpander {
|
||||
public:
|
||||
ExpandDims() { validators_.emplace_back(new CheckAttr({"axis"})); }
|
||||
~ExpandDims() {}
|
||||
NodePtrList Expand() override {
|
||||
const auto &inputs = gb.Get()->inputs();
|
||||
auto &input_x = inputs[0];
|
||||
auto shape = MakeValue(ExpandDims::InferShape(input_x->shape, GetAxisList(this->attrs_["axis"])));
|
||||
auto result = gb.Emit("Reshape", {input_x}, {{"shape", shape}});
|
||||
return {result};
|
||||
}
|
||||
|
||||
static ShapeVector InferShape(const ShapeVector &shape, const std::vector<int64_t> &axis) {
|
||||
ShapeVector new_shape = shape;
|
||||
for (auto x : axis) {
|
||||
int64_t rank = static_cast<int64_t>(new_shape.size());
|
||||
if (x > rank || x < -rank - 1) {
|
||||
std::ostringstream oss;
|
||||
oss << "ExpandDims axis " << x << " is out of range of size " << new_shape.size();
|
||||
throw graphkernel::GKException(oss.str());
|
||||
}
|
||||
if (x >= 0) {
|
||||
new_shape.insert(new_shape.begin() + x, 1LL);
|
||||
} else {
|
||||
new_shape.insert(new_shape.begin() + (x + rank + 1), 1LL);
|
||||
}
|
||||
}
|
||||
return new_shape;
|
||||
}
|
||||
};
|
||||
} // namespace expanders
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/expanders/utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
|
||||
#include "backend/optimizer/graph_kernel/model/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace expanders {
|
||||
graphkernel::LiteGraphPtr OpExpander::Run(const BaseInfoList &inputs, const BaseInfoList &outputs,
|
||||
const graphkernel::DAttrs &attrs, const std::string &processor) {
|
||||
this->inputs_info_ = inputs;
|
||||
this->outputs_info_ = outputs;
|
||||
this->attrs_ = attrs;
|
||||
this->processor_ = processor;
|
||||
for (const auto &v : validators_) {
|
||||
v->Check(*this);
|
||||
}
|
||||
this->CheckInputs();
|
||||
for (auto &inp : inputs) {
|
||||
gb.Parameter(inp);
|
||||
}
|
||||
auto result = this->Expand();
|
||||
gb.SetOutputs(result);
|
||||
this->CheckOutputs();
|
||||
return gb.Get();
|
||||
}
|
||||
|
||||
void OpExpander::CheckOutputs() {
|
||||
// check the output shape/type/format are same as the original basic node's output.
|
||||
const NodePtrList &outputs = gb.Get()->GetOutputs();
|
||||
if (outputs.size() != this->outputs_info_.size()) {
|
||||
std::ostringstream oss;
|
||||
oss << "the output num was not equal to the original output num : " << outputs.size() << " vs "
|
||||
<< outputs_info_.size();
|
||||
throw graphkernel::GKException(oss.str());
|
||||
}
|
||||
for (size_t i = 0; i < outputs.size(); i++) {
|
||||
if (outputs[i]->shape != outputs_info_[i].shape) {
|
||||
std::ostringstream oss;
|
||||
oss << "Op " << this->op_ << "'s output shape [";
|
||||
for (auto s : outputs[i]->shape) {
|
||||
oss << s << ",";
|
||||
}
|
||||
oss << "] is wrong. expect: [";
|
||||
for (auto s : outputs_info_[i].shape) {
|
||||
oss << s << ",";
|
||||
}
|
||||
oss << "]";
|
||||
throw graphkernel::GKException(oss.str());
|
||||
}
|
||||
if (outputs[i]->type != outputs_info_[i].type) {
|
||||
std::ostringstream oss;
|
||||
oss << "Op " << this->op_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: ["
|
||||
<< outputs_info_[i].type << "]";
|
||||
throw graphkernel::GKException(oss.str());
|
||||
}
|
||||
if (outputs[i]->format != outputs_info_[i].format) {
|
||||
std::ostringstream oss;
|
||||
oss << "Op " << this->op_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: ["
|
||||
<< outputs_info_[i].format << "]";
|
||||
throw graphkernel::GKException(oss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetAxisList(const ValuePtr &value) {
|
||||
std::vector<int64_t> result;
|
||||
auto get_int_value = [](const ValuePtr &value) -> int64_t {
|
||||
return value->isa<Int64Imm>() ? GetValue<int64_t>(value) : static_cast<int64_t>(GetValue<int>(value));
|
||||
};
|
||||
if (value->isa<ValueSequeue>()) {
|
||||
const auto &vals = value->cast<ValueSequeuePtr>()->value();
|
||||
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(result), get_int_value);
|
||||
} else {
|
||||
result.push_back(get_int_value(value));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace expanders
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
|
||||
#include "backend/optimizer/graph_kernel/model/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace expanders {
|
||||
using graphkernel::NodePtrList;
|
||||
using BaseInfoList = std::vector<graphkernel::NodeBase>;
|
||||
class Validator;
|
||||
|
||||
class OpExpander {
|
||||
public:
|
||||
graphkernel::LiteGraphPtr Run(const BaseInfoList &inputs, const BaseInfoList &outputs,
|
||||
const graphkernel::DAttrs &attrs, const std::string &processor);
|
||||
virtual ~OpExpander() = default;
|
||||
|
||||
protected:
|
||||
virtual void CheckInputs() {}
|
||||
virtual NodePtrList Expand() = 0;
|
||||
void CheckOutputs();
|
||||
|
||||
graphkernel::LiteGraph::GraphBuilder gb;
|
||||
std::string op_;
|
||||
BaseInfoList inputs_info_;
|
||||
BaseInfoList outputs_info_;
|
||||
graphkernel::DAttrs attrs_;
|
||||
std::string processor_;
|
||||
std::vector<std::unique_ptr<Validator>> validators_;
|
||||
|
||||
friend class OpExpanderFactory;
|
||||
friend class CheckAllFormatsSame;
|
||||
friend class CheckAttr;
|
||||
friend class SupportFormat;
|
||||
};
|
||||
|
||||
class Validator {
|
||||
public:
|
||||
virtual void Check(const OpExpander &e) = 0;
|
||||
};
|
||||
|
||||
class CheckAllFormatsSame : public Validator {
|
||||
public:
|
||||
void Check(const OpExpander &e) override {
|
||||
if (e.inputs_info_.empty()) return;
|
||||
const auto &fmt_0 = e.inputs_info_[0].format;
|
||||
for (size_t i = 1; i < e.inputs_info_.size(); i++) {
|
||||
if (e.inputs_info_[i].format != fmt_0) {
|
||||
std::ostringstream oss;
|
||||
oss << "Unmatched format for op " << e.op_;
|
||||
throw graphkernel::GKException(oss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class CheckAttr : public Validator {
|
||||
public:
|
||||
CheckAttr() = default;
|
||||
CheckAttr(std::initializer_list<std::string> l) : attrs_(l) {}
|
||||
~CheckAttr() = default;
|
||||
void Check(const OpExpander &e) override {
|
||||
for (auto &a : attrs_) {
|
||||
if (e.attrs_.count(a) == 0) {
|
||||
std::ostringstream oss;
|
||||
oss << "attr " << a << " does not exist. op " << e.op_;
|
||||
throw graphkernel::GKException(oss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::string> attrs_;
|
||||
};
|
||||
|
||||
class SupportFormat : public Validator {
|
||||
public:
|
||||
void AddFormat(std::initializer_list<std::string> l) { formats_.emplace_back(l); }
|
||||
void Check(const OpExpander &e) override {
|
||||
for (auto &formats : formats_) {
|
||||
if (formats.size() != e.inputs_info_.size()) {
|
||||
continue;
|
||||
}
|
||||
bool match = true;
|
||||
for (size_t i = 0; i < formats.size(); i++) {
|
||||
if (e.inputs_info_[i].format != formats[i]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
std::ostringstream oss;
|
||||
oss << "unsupported format for op " << e.op_;
|
||||
throw graphkernel::GKException(oss.str());
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::vector<std::string>> formats_;
|
||||
};
|
||||
|
||||
std::vector<int64_t> GetAxisList(const ValuePtr &value);
|
||||
} // namespace expanders
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_
|
|
@ -35,6 +35,7 @@
|
|||
#include "pybind_api/ir/primitive_py.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "vm/segment_runner.h"
|
||||
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -99,14 +100,14 @@ std::vector<PrimitivePtr> GetExpandOps() {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool DefaultExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
|
||||
bool PyExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
|
||||
DumpOption dump_option;
|
||||
dump_option.extract_opinfo_from_anfnode = true;
|
||||
kernel::AkgKernelJsonGenerator json_generator(dump_option);
|
||||
return json_generator.CollectJson(node, kernel_json);
|
||||
}
|
||||
|
||||
FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
|
||||
FuncGraphPtr PyExpander::CreateExpandFuncGraph(const CNodePtr &node) {
|
||||
nlohmann::json kernel_json;
|
||||
if (!ExpandJsonInfo(node, &kernel_json)) {
|
||||
MS_LOG(ERROR) << "Expand json info to: " << node->DebugString(2) << " failed, ori_json:\n" << kernel_json.dump();
|
||||
|
@ -131,7 +132,36 @@ FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
|
|||
return JsonDescToAnf(kernel_desc_str);
|
||||
}
|
||||
|
||||
AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
|
||||
FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
|
||||
auto expander_ptr = expanders::OpExpanderFactory::Instance().GetExpander(AnfAlgo::GetCNodeName(node));
|
||||
if (expander_ptr == nullptr) {
|
||||
return PyExpander::CreateExpandFuncGraph(node);
|
||||
}
|
||||
expanders::BaseInfoList inputs(node->size() - 1);
|
||||
expanders::BaseInfoList outputs(AnfAlgo::GetOutputTensorNum(node));
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(node, i);
|
||||
std::transform(shape.begin(), shape.end(), std::back_inserter(inputs[i].shape), SizeToLong);
|
||||
inputs[i].type = AnfAlgo::GetInputDeviceDataType(node, i);
|
||||
inputs[i].format = AnfAlgo::GetInputFormat(node, i);
|
||||
}
|
||||
for (size_t i = 0; i < outputs.size(); i++) {
|
||||
auto shape = AnfAlgo::GetOutputDeviceShape(node, i);
|
||||
std::transform(shape.begin(), shape.end(), std::back_inserter(outputs[i].shape), SizeToLong);
|
||||
outputs[i].type = AnfAlgo::GetOutputDeviceDataType(node, i);
|
||||
outputs[i].format = AnfAlgo::GetOutputFormat(node, i);
|
||||
}
|
||||
auto &attrs = AnfAlgo::GetCNodePrimitive(node)->attrs();
|
||||
try {
|
||||
auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext());
|
||||
return LiteGraph2AnfGraph(litegraph);
|
||||
} catch (const graphkernel::GKException &e) {
|
||||
MS_LOG(INFO) << e.what() << ", undo expanding this op";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
|
||||
auto func_graph = old_node->func_graph();
|
||||
std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
|
||||
AnfNodePtrList kernel_nodes;
|
||||
|
@ -146,7 +176,7 @@ AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func
|
|||
return graph_kernel_node;
|
||||
}
|
||||
|
||||
AnfNodePtr DefaultExpander::Run(const AnfNodePtr &node) {
|
||||
AnfNodePtr PyExpander::Run(const AnfNodePtr &node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto new_func_graph = CreateExpandFuncGraph(cnode);
|
||||
|
@ -205,6 +235,7 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool GraphKernelComplexExpander::CanExpand(const CNodePtr &node) const {
|
||||
bool has_complex = false;
|
||||
auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(node);
|
||||
|
|
|
@ -30,7 +30,7 @@ class Expander {
|
|||
};
|
||||
using ExpanderPtr = std::shared_ptr<Expander>;
|
||||
|
||||
class DefaultExpander : public Expander {
|
||||
class PyExpander : public Expander {
|
||||
public:
|
||||
AnfNodePtr Run(const AnfNodePtr &node) override;
|
||||
|
||||
|
@ -39,6 +39,12 @@ class DefaultExpander : public Expander {
|
|||
virtual AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
|
||||
virtual FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
|
||||
};
|
||||
|
||||
class DefaultExpander : public PyExpander {
|
||||
protected:
|
||||
FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node) override;
|
||||
};
|
||||
|
||||
class ComplexOpExpander : public DefaultExpander {
|
||||
protected:
|
||||
bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, data_format="NCHW"):
|
||||
super(Net, self).__init__()
|
||||
self.bias_add = P.BiasAdd(data_format)
|
||||
|
||||
def construct(self, x, b):
|
||||
return self.bias_add(x, b)
|
||||
|
||||
|
||||
def get_output(x, b, data_format, enable_graph_kernel):
|
||||
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||
net = Net(data_format)
|
||||
output = net(x, b)
|
||||
return output
|
||||
|
||||
|
||||
def test_bias_add(shape1, shape2, data_format, dtype):
|
||||
np.random.seed(0)
|
||||
x = Tensor(np.random.normal(0, 10, shape1).astype(dtype))
|
||||
b = Tensor(np.ones(shape2).astype(dtype))
|
||||
expect = get_output(x, b, data_format, False)
|
||||
output = get_output(x, b, data_format, True)
|
||||
|
||||
expect_np = expect.asnumpy().copy()
|
||||
output_np = output.asnumpy().copy()
|
||||
|
||||
assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add_gpu():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_bias_add((2, 3), (3,), "NCHW", np.float32)
|
||||
test_bias_add((2, 3, 4, 5), (3,), "NCHW", np.float32)
|
||||
test_bias_add((2, 3, 4, 5), (5,), "NHWC", np.float32)
|
Loading…
Reference in New Issue