Add expanders in c++ code

transplant the op expander code from python to c++, base on LiteGraph.
the c++ expander will be called in priority if it was registered in OpExpanderFactory.

add two examples, BiasAdd and ExpandDims.
remove BiasAdd from python expanders.
since the ExpandDims is also imported by other ops (e.g. BatchNorm), we don't remove it now.
This commit is contained in:
dayschan 2021-08-12 14:40:02 +08:00
parent dd7c7dcb83
commit 9add26ad99
10 changed files with 516 additions and 55 deletions

View File

@ -18,7 +18,6 @@ from .addn import AddN
from .assign_add import AssignAdd
from .batchnorm import BatchNorm
from .batchnorm_grad import BatchNormGrad
from .bias_add import BiasAdd
from .bias_add_grad import BiasAddGrad
from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
from .conv2d import Conv2D
@ -26,7 +25,6 @@ from .complex import CAbs, CAdd, CDiv, CMul, CSub
from .dropout_grad import DropoutGrad
from .equal_count import EqualCount
from .erfc import Erfc
from .expand_dims import ExpandDims
from .fused_adam import FusedAdam
from .fused_adam_weight_decay import FusedAdamWeightDecay
from .fused_mul_add import FusedMulAdd

View File

@ -1,48 +0,0 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for bias_add"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
from .expand_dims import ExpandDims
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.DEFAULT)
@VLD.add_format(DF.NHWC, DF.DEFAULT)
class BiasAdd(Expander):
"""BiasAdd expander"""
def _expand(self, graph_builder):
input_x, input_y = self.inputs
if input_x.data_format == DF.NCHW:
input_y_expand = graph_builder.emit(
'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, [1, 2])})
result = graph_builder.emit('Add', [input_x, input_y_expand])
elif input_x.data_format == DF.DEFAULT:
if len(input_x.shape) == 2:
result = graph_builder.emit('Add', [input_x, input_y])
elif len(input_x.shape) == 3:
input_y_expand = graph_builder.emit(
'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, 1)})
result = graph_builder.emit('Add', [input_x, input_y_expand])
else: # len == 4
input_y_expand = graph_builder.emit(
'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, [1, 2])})
result = graph_builder.emit('Add', [input_x, input_y_expand])
else: # NHWC
result = graph_builder.emit('Add', [input_x, input_y])
return result

View File

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_
#include <string>
#include <utility>
#include <vector>
#include <memory>
#include "backend/optimizer/graph_kernel/expanders/utils.h"
namespace mindspore {
namespace opt {
namespace expanders {
class BiasAdd : public OpExpander {
public:
BiasAdd() {
auto support_format = std::make_unique<SupportFormat>();
support_format->AddFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
support_format->AddFormat({kOpFormat_NCHW, kOpFormat_DEFAULT});
support_format->AddFormat({kOpFormat_NHWC, kOpFormat_DEFAULT});
validators_.emplace_back(std::move(support_format));
validators_.emplace_back(new CheckAttr({"format"}));
}
~BiasAdd() = default;
NodePtrList Expand() override {
const auto &inputs = gb.Get()->inputs();
auto input_x = inputs[0];
auto input_y = inputs[1];
if (input_x->format == kOpFormat_NCHW) {
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, {1, 2}))}});
} else if (input_x->format == kOpFormat_DEFAULT) {
auto data_format = GetValue<std::string>(attrs_["format"]);
size_t channel_idx = (data_format == kOpFormat_NHWC) ? input_x->shape.size() - 1 : 1;
std::vector<int64_t> axis(input_x->shape.size() - channel_idx - 1, -1);
if (!axis.empty()) {
input_y = gb.Emit("Reshape", {input_y}, {{"shape", MakeValue(ExpandDims::InferShape(input_y->shape, axis))}});
}
}
return {gb.Emit("Add", {input_x, input_y})};
}
};
} // namespace expanders
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_BIAS_ADD_H_

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_
#include <unordered_map>
#include <functional>
#include <string>
#include <memory>
#include "backend/optimizer/graph_kernel/expanders/utils.h"
#include "backend/optimizer/graph_kernel/expanders/reshape.h"
#include "backend/optimizer/graph_kernel/expanders/bias_add.h"
namespace mindspore {
namespace opt {
namespace expanders {
#define OP_EXPANDER_CREATOR(cls) []() -> std::shared_ptr<OpExpander> { return std::make_shared<cls>(); }
class OpExpanderFactory {
public:
static OpExpanderFactory &Instance() {
static std::unique_ptr<OpExpanderFactory> instance = nullptr;
if (instance == nullptr) {
instance.reset(new OpExpanderFactory());
}
return *instance;
}
std::shared_ptr<OpExpander> GetExpander(const std::string &op) {
if (auto iter = creators.find(op); iter != creators.end()) {
auto expander_ptr = iter->second();
expander_ptr->op_ = op;
return expander_ptr;
}
return nullptr;
}
~OpExpanderFactory() = default;
private:
using RegFunc = std::function<std::shared_ptr<OpExpander>()>;
void Register(std::string &&op, RegFunc &&func) { creators.insert({op, func}); }
OpExpanderFactory() {
Register("BiasAdd", OP_EXPANDER_CREATOR(expanders::BiasAdd));
Register("ExpandDims", OP_EXPANDER_CREATOR(expanders::ExpandDims));
}
std::unordered_map<std::string, RegFunc> creators;
};
} // namespace expanders
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_
#include <memory>
#include <vector>
#include "backend/optimizer/graph_kernel/model/node.h"
#include "backend/optimizer/graph_kernel/expanders/utils.h"
namespace mindspore {
namespace opt {
namespace expanders {
class ExpandDims : public OpExpander {
public:
ExpandDims() { validators_.emplace_back(new CheckAttr({"axis"})); }
~ExpandDims() {}
NodePtrList Expand() override {
const auto &inputs = gb.Get()->inputs();
auto &input_x = inputs[0];
auto shape = MakeValue(ExpandDims::InferShape(input_x->shape, GetAxisList(this->attrs_["axis"])));
auto result = gb.Emit("Reshape", {input_x}, {{"shape", shape}});
return {result};
}
static ShapeVector InferShape(const ShapeVector &shape, const std::vector<int64_t> &axis) {
ShapeVector new_shape = shape;
for (auto x : axis) {
int64_t rank = static_cast<int64_t>(new_shape.size());
if (x > rank || x < -rank - 1) {
std::ostringstream oss;
oss << "ExpandDims axis " << x << " is out of range of size " << new_shape.size();
throw graphkernel::GKException(oss.str());
}
if (x >= 0) {
new_shape.insert(new_shape.begin() + x, 1LL);
} else {
new_shape.insert(new_shape.begin() + (x + rank + 1), 1LL);
}
}
return new_shape;
}
};
} // namespace expanders
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_RESHAPE_H_

View File

@ -0,0 +1,100 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/expanders/utils.h"
#include <algorithm>
#include <string>
#include <vector>
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
#include "backend/optimizer/graph_kernel/model/node.h"
namespace mindspore {
namespace opt {
namespace expanders {
graphkernel::LiteGraphPtr OpExpander::Run(const BaseInfoList &inputs, const BaseInfoList &outputs,
const graphkernel::DAttrs &attrs, const std::string &processor) {
this->inputs_info_ = inputs;
this->outputs_info_ = outputs;
this->attrs_ = attrs;
this->processor_ = processor;
for (const auto &v : validators_) {
v->Check(*this);
}
this->CheckInputs();
for (auto &inp : inputs) {
gb.Parameter(inp);
}
auto result = this->Expand();
gb.SetOutputs(result);
this->CheckOutputs();
return gb.Get();
}
void OpExpander::CheckOutputs() {
// check the output shape/type/format are same as the original basic node's output.
const NodePtrList &outputs = gb.Get()->GetOutputs();
if (outputs.size() != this->outputs_info_.size()) {
std::ostringstream oss;
oss << "the output num was not equal to the original output num : " << outputs.size() << " vs "
<< outputs_info_.size();
throw graphkernel::GKException(oss.str());
}
for (size_t i = 0; i < outputs.size(); i++) {
if (outputs[i]->shape != outputs_info_[i].shape) {
std::ostringstream oss;
oss << "Op " << this->op_ << "'s output shape [";
for (auto s : outputs[i]->shape) {
oss << s << ",";
}
oss << "] is wrong. expect: [";
for (auto s : outputs_info_[i].shape) {
oss << s << ",";
}
oss << "]";
throw graphkernel::GKException(oss.str());
}
if (outputs[i]->type != outputs_info_[i].type) {
std::ostringstream oss;
oss << "Op " << this->op_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: ["
<< outputs_info_[i].type << "]";
throw graphkernel::GKException(oss.str());
}
if (outputs[i]->format != outputs_info_[i].format) {
std::ostringstream oss;
oss << "Op " << this->op_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: ["
<< outputs_info_[i].format << "]";
throw graphkernel::GKException(oss.str());
}
}
}
std::vector<int64_t> GetAxisList(const ValuePtr &value) {
std::vector<int64_t> result;
auto get_int_value = [](const ValuePtr &value) -> int64_t {
return value->isa<Int64Imm>() ? GetValue<int64_t>(value) : static_cast<int64_t>(GetValue<int>(value));
};
if (value->isa<ValueSequeue>()) {
const auto &vals = value->cast<ValueSequeuePtr>()->value();
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(result), get_int_value);
} else {
result.push_back(get_int_value(value));
}
return result;
}
} // namespace expanders
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,129 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_
#include <string>
#include <memory>
#include <vector>
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
#include "backend/optimizer/graph_kernel/model/node.h"
namespace mindspore {
namespace opt {
namespace expanders {
using graphkernel::NodePtrList;
using BaseInfoList = std::vector<graphkernel::NodeBase>;
class Validator;
class OpExpander {
public:
graphkernel::LiteGraphPtr Run(const BaseInfoList &inputs, const BaseInfoList &outputs,
const graphkernel::DAttrs &attrs, const std::string &processor);
virtual ~OpExpander() = default;
protected:
virtual void CheckInputs() {}
virtual NodePtrList Expand() = 0;
void CheckOutputs();
graphkernel::LiteGraph::GraphBuilder gb;
std::string op_;
BaseInfoList inputs_info_;
BaseInfoList outputs_info_;
graphkernel::DAttrs attrs_;
std::string processor_;
std::vector<std::unique_ptr<Validator>> validators_;
friend class OpExpanderFactory;
friend class CheckAllFormatsSame;
friend class CheckAttr;
friend class SupportFormat;
};
class Validator {
public:
virtual void Check(const OpExpander &e) = 0;
};
class CheckAllFormatsSame : public Validator {
public:
void Check(const OpExpander &e) override {
if (e.inputs_info_.empty()) return;
const auto &fmt_0 = e.inputs_info_[0].format;
for (size_t i = 1; i < e.inputs_info_.size(); i++) {
if (e.inputs_info_[i].format != fmt_0) {
std::ostringstream oss;
oss << "Unmatched format for op " << e.op_;
throw graphkernel::GKException(oss.str());
}
}
}
};
class CheckAttr : public Validator {
public:
CheckAttr() = default;
CheckAttr(std::initializer_list<std::string> l) : attrs_(l) {}
~CheckAttr() = default;
void Check(const OpExpander &e) override {
for (auto &a : attrs_) {
if (e.attrs_.count(a) == 0) {
std::ostringstream oss;
oss << "attr " << a << " does not exist. op " << e.op_;
throw graphkernel::GKException(oss.str());
}
}
}
private:
std::vector<std::string> attrs_;
};
class SupportFormat : public Validator {
public:
void AddFormat(std::initializer_list<std::string> l) { formats_.emplace_back(l); }
void Check(const OpExpander &e) override {
for (auto &formats : formats_) {
if (formats.size() != e.inputs_info_.size()) {
continue;
}
bool match = true;
for (size_t i = 0; i < formats.size(); i++) {
if (e.inputs_info_[i].format != formats[i]) {
match = false;
break;
}
}
if (match) {
return;
}
}
std::ostringstream oss;
oss << "unsupported format for op " << e.op_;
throw graphkernel::GKException(oss.str());
}
private:
std::vector<std::vector<std::string>> formats_;
};
std::vector<int64_t> GetAxisList(const ValuePtr &value);
} // namespace expanders
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_

View File

@ -35,6 +35,7 @@
#include "pybind_api/ir/primitive_py.h"
#include "runtime/device/kernel_info.h"
#include "vm/segment_runner.h"
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
namespace mindspore {
namespace opt {
@ -99,14 +100,14 @@ std::vector<PrimitivePtr> GetExpandOps() {
}
} // namespace
bool DefaultExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
bool PyExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
DumpOption dump_option;
dump_option.extract_opinfo_from_anfnode = true;
kernel::AkgKernelJsonGenerator json_generator(dump_option);
return json_generator.CollectJson(node, kernel_json);
}
FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
FuncGraphPtr PyExpander::CreateExpandFuncGraph(const CNodePtr &node) {
nlohmann::json kernel_json;
if (!ExpandJsonInfo(node, &kernel_json)) {
MS_LOG(ERROR) << "Expand json info to: " << node->DebugString(2) << " failed, ori_json:\n" << kernel_json.dump();
@ -131,7 +132,36 @@ FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
return JsonDescToAnf(kernel_desc_str);
}
AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
auto expander_ptr = expanders::OpExpanderFactory::Instance().GetExpander(AnfAlgo::GetCNodeName(node));
if (expander_ptr == nullptr) {
return PyExpander::CreateExpandFuncGraph(node);
}
expanders::BaseInfoList inputs(node->size() - 1);
expanders::BaseInfoList outputs(AnfAlgo::GetOutputTensorNum(node));
for (size_t i = 0; i < inputs.size(); i++) {
auto shape = AnfAlgo::GetInputDeviceShape(node, i);
std::transform(shape.begin(), shape.end(), std::back_inserter(inputs[i].shape), SizeToLong);
inputs[i].type = AnfAlgo::GetInputDeviceDataType(node, i);
inputs[i].format = AnfAlgo::GetInputFormat(node, i);
}
for (size_t i = 0; i < outputs.size(); i++) {
auto shape = AnfAlgo::GetOutputDeviceShape(node, i);
std::transform(shape.begin(), shape.end(), std::back_inserter(outputs[i].shape), SizeToLong);
outputs[i].type = AnfAlgo::GetOutputDeviceDataType(node, i);
outputs[i].format = AnfAlgo::GetOutputFormat(node, i);
}
auto &attrs = AnfAlgo::GetCNodePrimitive(node)->attrs();
try {
auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext());
return LiteGraph2AnfGraph(litegraph);
} catch (const graphkernel::GKException &e) {
MS_LOG(INFO) << e.what() << ", undo expanding this op";
return nullptr;
}
}
AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
auto func_graph = old_node->func_graph();
std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
AnfNodePtrList kernel_nodes;
@ -146,7 +176,7 @@ AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func
return graph_kernel_node;
}
AnfNodePtr DefaultExpander::Run(const AnfNodePtr &node) {
AnfNodePtr PyExpander::Run(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto new_func_graph = CreateExpandFuncGraph(cnode);
@ -205,6 +235,7 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
}
return changed;
}
bool GraphKernelComplexExpander::CanExpand(const CNodePtr &node) const {
bool has_complex = false;
auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(node);

View File

@ -30,7 +30,7 @@ class Expander {
};
using ExpanderPtr = std::shared_ptr<Expander>;
class DefaultExpander : public Expander {
class PyExpander : public Expander {
public:
AnfNodePtr Run(const AnfNodePtr &node) override;
@ -39,6 +39,12 @@ class DefaultExpander : public Expander {
virtual AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
virtual FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
};
class DefaultExpander : public PyExpander {
protected:
FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node) override;
};
class ComplexOpExpander : public DefaultExpander {
protected:
bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);

View File

@ -0,0 +1,59 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class Net(nn.Cell):
def __init__(self, data_format="NCHW"):
super(Net, self).__init__()
self.bias_add = P.BiasAdd(data_format)
def construct(self, x, b):
return self.bias_add(x, b)
def get_output(x, b, data_format, enable_graph_kernel):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net = Net(data_format)
output = net(x, b)
return output
def test_bias_add(shape1, shape2, data_format, dtype):
np.random.seed(0)
x = Tensor(np.random.normal(0, 10, shape1).astype(dtype))
b = Tensor(np.ones(shape2).astype(dtype))
expect = get_output(x, b, data_format, False)
output = get_output(x, b, data_format, True)
expect_np = expect.asnumpy().copy()
output_np = output.asnumpy().copy()
assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bias_add_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_bias_add((2, 3), (3,), "NCHW", np.float32)
test_bias_add((2, 3, 4, 5), (3,), "NCHW", np.float32)
test_bias_add((2, 3, 4, 5), (5,), "NHWC", np.float32)