Support complex number ops in graph_kernel on GPU

This commit is contained in:
zengzitao 2021-07-05 15:22:50 +08:00
parent ed6cb44e7f
commit aa019a639f
17 changed files with 440 additions and 31 deletions

View File

@ -22,6 +22,7 @@ from .bias_add import BiasAdd
from .bias_add_grad import BiasAddGrad
from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
from .conv2d import Conv2D
from .complex import CAbs, CAdd, CDiv, CMul, CSub
from .dropout_grad import DropoutGrad
from .expand_dims import ExpandDims
from .fused_adam import FusedAdam

View File

@ -0,0 +1,21 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""complex expanders init"""
from .abs import CAbs
from .add import CAdd
from .div import CDiv
from .mul import CMul
from .sub import CSub

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for cabs"""
from mindspore._extends.graph_kernel.expanders._utils import Expander


class CAbs(Expander):
    """CAbs expander"""

    def _expand(self, graph_builder):
        input_x = self.inputs[0]
        x_real = graph_builder.emit('CReal', [input_x])
        x_imag = graph_builder.emit('CImag', [input_x])
        square_x_real = graph_builder.emit('Mul', [x_real, x_real])
        square_x_imag = graph_builder.emit('Mul', [x_imag, x_imag])
        square_sum = graph_builder.emit('Add', [square_x_real, square_x_imag])
        result = graph_builder.emit('Sqrt', [square_sum])
        return result
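For reference, CAbs computes the complex modulus |z| = sqrt(real(z)^2 + imag(z)^2). A minimal NumPy sketch (not part of the commit) checking the same identity the expander emits:

import numpy as np

z = np.array([3 + 4j, 1 - 1j], dtype=np.complex64)
# Same decomposition as the emitted graph: Mul/Mul -> Add -> Sqrt.
expanded = np.sqrt(z.real * z.real + z.imag * z.imag)
assert np.allclose(expanded, np.abs(z))  # [5.0, 1.4142...]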

View File

@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for cadd"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CAdd(Expander):
    """CAdd expander"""

    def _expand(self, graph_builder):
        input_x, input_y = self.inputs
        x_real = graph_builder.emit('CReal', [input_x])
        y_real = graph_builder.emit('CReal', [input_y])
        x_imag = graph_builder.emit('CImag', [input_x])
        y_imag = graph_builder.emit('CImag', [input_y])
        result_real = graph_builder.emit('Add', [x_real, y_real])
        result_imag = graph_builder.emit('Add', [x_imag, y_imag])
        result = graph_builder.emit('Complex', [result_real, result_imag])
        return result
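CAdd (like CSub further below) is purely componentwise on the real and imaginary parts; a quick NumPy check (not part of the commit):

import numpy as np

x = np.array([1 + 2j], dtype=np.complex64)
y = np.array([3 - 4j], dtype=np.complex64)
# Componentwise add on real/imag, then recombine, equals complex add.
assert np.allclose((x.real + y.real) + 1j * (x.imag + y.imag), x + y)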

View File

@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for cdiv"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CDiv(Expander):
    """CDiv expander"""

    def _expand(self, graph_builder):
        """CDiv implementation"""
        input_x, input_y = self.inputs
        x_real = graph_builder.emit('CReal', [input_x])
        y_real = graph_builder.emit('CReal', [input_y])
        x_imag = graph_builder.emit('CImag', [input_x])
        y_imag = graph_builder.emit('CImag', [input_y])
        square_y_real = graph_builder.emit('Mul', [y_real, y_real])
        square_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
        final_denominator = graph_builder.emit('Add', [square_y_real, square_y_imag])
        x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
        x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
        x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
        x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
        final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag])
        final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag])
        result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator])
        result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator])
        result = graph_builder.emit('Complex', [result_real, result_imag])
        return result
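The emitted graph implements (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c^2+d^2); a NumPy sketch (not part of the commit) verifying the identity:

import numpy as np

x = np.array([1 + 2j], dtype=np.complex64)
y = np.array([3 - 4j], dtype=np.complex64)
denom = y.real * y.real + y.imag * y.imag            # final_denominator
real = (x.real * y.real + x.imag * y.imag) / denom   # final_numerator_real / denom
imag = (x.imag * y.real - x.real * y.imag) / denom   # final_numerator_imag / denom
assert np.allclose(real + 1j * imag, x / y)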

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for cmul"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CMul(Expander):
    """CMul expander"""

    def _expand(self, graph_builder):
        """CMul implementation"""
        input_x, input_y = self.inputs
        x_real = graph_builder.emit('CReal', [input_x])
        y_real = graph_builder.emit('CReal', [input_y])
        x_imag = graph_builder.emit('CImag', [input_x])
        y_imag = graph_builder.emit('CImag', [input_y])
        x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
        x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
        x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
        x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
        result_real = graph_builder.emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag])
        result_imag = graph_builder.emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real])
        result = graph_builder.emit('Complex', [result_real, result_imag])
        return result
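The graph follows (a+bi)(c+di) = (ac-bd) + (ad+bc)i; a NumPy sketch (not part of the commit):

import numpy as np

x = np.array([1 + 2j], dtype=np.complex64)
y = np.array([3 - 4j], dtype=np.complex64)
real = x.real * y.real - x.imag * y.imag  # Sub of the two like products
imag = x.real * y.imag + x.imag * y.real  # Add of the two cross products
assert np.allclose(real + 1j * imag, x * y)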

View File

@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for csub"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CSub(Expander):
    """CSub expander"""

    def _expand(self, graph_builder):
        input_x, input_y = self.inputs
        x_real = graph_builder.emit('CReal', [input_x])
        y_real = graph_builder.emit('CReal', [input_y])
        x_imag = graph_builder.emit('CImag', [input_x])
        y_imag = graph_builder.emit('CImag', [input_y])
        result_real = graph_builder.emit('Sub', [x_real, y_real])
        result_imag = graph_builder.emit('Sub', [x_imag, y_imag])
        result = graph_builder.emit('Complex', [result_real, result_imag])
        return result

View File

@ -192,6 +192,9 @@ class PrimLib:
        'UnPadAkg': Prim(OPAQUE),
        'PadAkg': Prim(OPAQUE),
        'Conv2D': Prim(OPAQUE),
        'CReal': Prim(ELEMWISE),
        'CImag': Prim(ELEMWISE),
        'Complex': Prim(ELEMWISE),
    }

    default_primtive = Prim(UNKNOWN)
@ -466,6 +469,7 @@ class GraphVisitor:
            for i in range(len(graph.ops)-1, -1, -1):
                self.visit(graph.ops[i])


class AlignShape(GraphVisitor):
    """Align shape"""
@ -484,6 +488,7 @@ class AlignShape(GraphVisitor):
        if align_dim > out_dim:
            op.output.shape = [1] * (align_dim - out_dim) + op.output.shape


class AddControlBuddy(GraphVisitor):
    """Add control buddy"""

View File

@ -255,6 +255,38 @@ class _CompareOp(_Elemwise):
        return "bool"


class CImag(OpInfer):
    def _check_type(self):
        if self.inputs[0].dtype != "complex64":
            raise GKException(
                "CImag's input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype))

    def _infer_type(self):
        return "float32"


class CReal(OpInfer):
    def _check_type(self):
        if self.inputs[0].dtype != "complex64":
            raise GKException(
                "CReal's input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype))

    def _infer_type(self):
        return "float32"


class Complex(OpInfer):
    def _check_type(self):
        if self.inputs[0].dtype != "float32":
            raise GKException(
                "Complex's input[0] should be of type float32, but got {}".format(self.inputs[0].dtype))
        if self.inputs[0].dtype != self.inputs[1].dtype:
            raise GKException("Complex's input types mismatch ({} vs {})".format(
                self.inputs[0].dtype, self.inputs[1].dtype))

    def _infer_type(self):
        return "complex64"


class Less(_CompareOp):
    pass
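These infer classes encode the dtype contract between complex64 and its two float32 components, which matches NumPy's behaviour; a small illustration (not part of the commit):

import numpy as np

z = np.array([1 + 2j], dtype=np.complex64)
assert z.real.dtype == np.float32  # CReal: complex64 -> float32
assert z.imag.dtype == np.float32  # CImag: complex64 -> float32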

View File

@ -44,8 +44,7 @@ const std::unordered_map<std::string, TypeId> type_id_maps = {
{"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt},
{"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16},
{"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64},
{"bool", TypeId::kNumberTypeBool},
};
{"bool", TypeId::kNumberTypeBool}, {"complex64", TypeId::kNumberTypeComplex64}};
const std::map<TypeId, std::string> type_id_str_map = {
{TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
@ -55,8 +54,7 @@ const std::map<TypeId, std::string> type_id_str_map = {
{TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
{TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
{TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
{TypeId::kNumberTypeBool, "bool"},
};
{TypeId::kNumberTypeBool, "bool"}, {TypeId::kNumberTypeComplex64, "complex64"}};
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
{"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
@ -64,11 +62,11 @@ const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
};
const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
{"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
{"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
{"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
{"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
};
{"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
{"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
{"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
{"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
{"complex64", sizeof(float) * 2}};
const std::unordered_map<std::string, FusionType> fusion_type_maps = {
{"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE},

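As a cross-check of the new dtype_nbyte_map entry: a complex64 value is two float32 values, i.e. 8 bytes (a quick NumPy check, not part of the commit):

import numpy as np

assert np.dtype(np.complex64).itemsize == 2 * np.dtype(np.float32).itemsize == 8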
View File

@ -0,0 +1,77 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/eliminate_redundant_complex.h"
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
#include "frontend/optimizer/irpass.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "debug/anf_ir_dump.h"
namespace mindspore {
namespace opt {
namespace {
bool EliminateRedundantComplexInGraphKernel(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  auto todos = TopoSort(func_graph->get_return());
  bool changed = false;
  for (const auto &node : todos) {
    auto cnode = node->cast<CNodePtr>();
    // Find every Complex node in the graph kernel sub-graph.
    if (cnode != nullptr && IsPrimitiveCNode(cnode, std::make_shared<Primitive>("Complex"))) {
      auto original_users = mng->node_users()[cnode];
      for (const auto &getitem_iter : original_users) {
        auto getitem = getitem_iter.first;
        auto getitem_cnode = getitem->cast<CNodePtr>();
        // Replace CReal/CImag users of the Complex node with the corresponding Complex input.
        if (IsPrimitiveCNode(getitem_cnode, std::make_shared<Primitive>("CReal"))) {
          mng->Replace(getitem, cnode->inputs()[1]);
          changed = true;
        } else if (IsPrimitiveCNode(getitem_cnode, std::make_shared<Primitive>("CImag"))) {
          mng->Replace(getitem, cnode->inputs()[2]);
          changed = true;
        }
      }
    }
  }
  return changed;
}
}  // namespace

bool EliminateRedundantComplex::Run(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
  std::reverse(todos.begin(), todos.end());
  for (const auto &node : todos) {
    auto cnode = node->cast<CNodePtr>();
    // Only look inside graph kernel nodes.
    if (cnode != nullptr && AnfAlgo::IsGraphKernel(cnode)) {
      auto graph_kernel_fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
      MS_EXCEPTION_IF_NULL(graph_kernel_fg);
      changed = EliminateRedundantComplexInGraphKernel(graph_kernel_fg) || changed;
    }
  }
  return changed;
}
}  // namespace opt
}  // namespace mindspore
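In effect the pass rewrites CReal(Complex(r, i)) -> r and CImag(Complex(r, i)) -> i. A toy Python sketch of the same rewrite on a tuple-encoded expression graph (hypothetical helper, not the MindSpore API):

def eliminate_redundant_complex(expr):
    """Rewrite CReal(Complex(r, i)) -> r and CImag(Complex(r, i)) -> i."""
    if not isinstance(expr, tuple):
        return expr
    op = expr[0]
    args = [eliminate_redundant_complex(a) for a in expr[1:]]
    if op in ('CReal', 'CImag') and isinstance(args[0], tuple) and args[0][0] == 'Complex':
        # inputs()[1] is the real part, inputs()[2] the imaginary part.
        return args[0][1] if op == 'CReal' else args[0][2]
    return (op, *args)

assert eliminate_redundant_complex(('CReal', ('Complex', 'r', 'i'))) == 'r'
assert eliminate_redundant_complex(('CImag', ('Complex', 'r', 'i'))) == 'i'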

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_COMPLEX_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_COMPLEX_H_
#include <map>
#include <memory>
#include "backend/optimizer/common/pass.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace opt {
class EliminateRedundantComplex : public Pass {
 public:
  EliminateRedundantComplex() : Pass("eliminate_redundant_complex") {}
  ~EliminateRedundantComplex() override = default;
  bool Run(const FuncGraphPtr &graph) override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_COMPLEX_H_

View File

@ -20,6 +20,7 @@
#include <set>
#include <utility>
#include <vector>
#include <algorithm>
#include "utils/context/graph_kernel_flags.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
@ -205,7 +206,7 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(mng);
for (const auto &n : todos) {
auto node = n->cast<CNodePtr>();
if (node == nullptr || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) ||
if (node == nullptr || AnfAlgo::IsGraphKernel(node) || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) ||
!CanExpand(node)) {
continue;
}
@ -221,10 +222,57 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
}
return changed;
}
bool GraphKernelComplexExpander::CanExpand(const CNodePtr &node) const {
  bool has_complex = false;
  auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(node);
  for (size_t i = 0; i < all_inputs_type.size(); ++i) {
    if (all_inputs_type[i] == kNumberTypeFloat64 || all_inputs_type[i] == kNumberTypeComplex64) {
      has_complex = true;
      break;
    }
  }
  return has_complex;
}

// For now this expander only serves complex ops; it will be removed later.
ExpanderPtr GraphKernelComplexExpander::GetExpander(const AnfNodePtr &node) {
  return std::make_shared<ComplexOpExpander>();
}

bool ComplexOpExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
  auto cnode = node->cast<CNodePtr>();
  auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(cnode);
  for (size_t i = 0; i < all_inputs_type.size(); ++i) {
    if (all_inputs_type[i] == kNumberTypeFloat64 || all_inputs_type[i] == kNumberTypeComplex64) {
      all_inputs_type[i] = kNumberTypeComplex64;
    }
  }
  auto all_outputs_type = AnfAlgo::GetAllOutputDeviceTypes(cnode);
  for (size_t i = 0; i < all_outputs_type.size(); ++i) {
    if (all_outputs_type[i] == kNumberTypeFloat64) {
      all_outputs_type[i] = kNumberTypeComplex64;
    }
  }
  auto all_inputs_format = AnfAlgo::GetAllInputFormats(cnode);
  auto all_outputs_format = AnfAlgo::GetAllOutputFormats(cnode);
  auto graph_sel_info =
    BuildSelectKernelBuildInfo(all_inputs_format, all_inputs_type, all_outputs_format, all_outputs_type);
  AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, cnode.get());
  std::vector<size_t> original_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
  ShapeVector real_shape;
  std::copy(original_shape.begin(), original_shape.end(), std::back_inserter(real_shape));
  auto complex_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(real_shape));
  TypeId complex_type = kNumberTypeComplex64;
  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(complex_type), complex_shape_ptr);
  cnode->set_abstract(abstract);
  if (!DefaultExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
  (*kernel_json)["name"] = std::string("C") + AnfAlgo::GetCNodeName(cnode);
  return true;
}

bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
  expand_ops_ = GetExpandOps();
  return DoExpand(func_graph);
}

bool GraphKernelComplexExpander::Run(const FuncGraphPtr &func_graph) { return DoExpand(func_graph); }
} // namespace opt
} // namespace mindspore
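A note on the renaming at the end of ComplexOpExpander::ExpandJsonInfo: prefixing the op name with "C" is what routes a basic op to the complex expanders registered in expanders/__init__.py. A trivial illustration (not part of the commit):

def to_complex_kernel_name(op_name):
    # 'Abs' -> 'CAbs', 'Add' -> 'CAdd', ... matching the expanders above.
    return 'C' + op_name

assert to_complex_kernel_name('Div') == 'CDiv'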

View File

@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
#include <memory>
#include <vector>
#include <string>
#include <nlohmann/json.hpp>
#include "backend/optimizer/common/pass.h"
#include "ir/func_graph.h"
@ -34,22 +35,26 @@ class DefaultExpander : public Expander {
AnfNodePtr Run(const AnfNodePtr &node) override;
protected:
bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
virtual bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
virtual void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
virtual AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
virtual FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
};
class ComplexOpExpander : public DefaultExpander {
 protected:
  bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) override;
};
class GraphKernelExpander : public Pass {
public:
GraphKernelExpander() : Pass("graph_kernel_expander") {}
explicit GraphKernelExpander(const std::string &name) : Pass(name) {}
~GraphKernelExpander() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
ExpanderPtr GetExpander(const AnfNodePtr &node);
bool DoExpand(const FuncGraphPtr &func_graph);
bool CanExpand(const CNodePtr &node) const {
protected:
virtual ExpanderPtr GetExpander(const AnfNodePtr &node);
virtual bool DoExpand(const FuncGraphPtr &func_graph);
virtual bool CanExpand(const CNodePtr &node) const {
return std::any_of(expand_ops_.begin(), expand_ops_.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
@ -57,6 +62,16 @@ class GraphKernelExpander : public Pass {
private:
std::vector<PrimitivePtr> expand_ops_;
};
class GraphKernelComplexExpander : public GraphKernelExpander {
 public:
  GraphKernelComplexExpander() : GraphKernelExpander("graph_kernel_complex_expander") {}
  ~GraphKernelComplexExpander() override = default;
  bool Run(const FuncGraphPtr &func_graph) override;

 protected:
  ExpanderPtr GetExpander(const AnfNodePtr &node) override;
  bool CanExpand(const CNodePtr &node) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_

View File

@ -27,6 +27,7 @@
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
#include "backend/optimizer/graph_kernel/eliminate_redundant_complex.h"
#include "backend/optimizer/graph_kernel/insert_pad.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
@ -74,6 +75,10 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() const {
PassManagerPtr GraphKernelOptimizer::Cluster() const {
auto pm = std::make_shared<GraphKernelPassManager>(1, "cluster");
// Expand complex-number ops into composite kernels
pm->AddPass(std::make_shared<GraphKernelComplexExpander>(), OptLevel_1, false);
// Expand complex basic kernels to composite kernels
pm->AddPass(std::make_shared<GraphKernelExpander>(), OptLevel_1);
@ -108,6 +113,9 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const {
// Common subexpression elimination
pm->AddPass(std::make_shared<GraphKernelCSE>(), OptLevel_2);
// Eliminate redundant Complex ops
pm->AddPass(std::make_shared<EliminateRedundantComplex>(), OptLevel_2, false);
return pm;
}

View File

@ -159,7 +159,7 @@ class Complex64 : public Number {
TypeId generic_type_id() const override { return kNumberTypeComplex64; }
TypePtr DeepCopy() const override { return std::make_shared<Complex64>(); }
std::string ToString() const override { return GetTypeName("Complex64"); }
std::string ToString() const override { return GetTypeName("Complex"); }
std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); }
std::string DumpText() const override {
return nbits() == 0 ? std::string("Complex64") : std::string("C") + std::to_string(nbits());

View File

@ -31,10 +31,10 @@ from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay, thor
from mindspore import log as logger
from mindspore.common import set_seed
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
BertTrainAccumulationAllReduceEachWithLossScaleCell, \
BertTrainAccumulationAllReducePostWithLossScaleCell, \
BertTrainOneStepWithLossScaleCellForAdam, \
AdamWeightDecayForBert, AdamWeightDecayOp
BertTrainAccumulationAllReduceEachWithLossScaleCell, \
BertTrainAccumulationAllReducePostWithLossScaleCell, \
BertTrainOneStepWithLossScaleCellForAdam, \
AdamWeightDecayForBert, AdamWeightDecayOp
from src.dataset import create_bert_dataset
from src.utils import LossCallBack, BertLearningRate
from src.model_utils.config import config as cfg, bert_net_cfg
@ -120,12 +120,6 @@ def _get_optimizer(args_opt, network):
return optimizer
def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
"""Judge whether is suitable to enable graph kernel."""
return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \
cfg.bert_network in ('base', 'large') and cfg.optimizer == 'AdamWeightDecay'
def _set_graph_kernel_context(device_target):
"""Add suitable graph kernel context for different configs."""
if device_target == 'GPU':
@ -247,7 +241,6 @@ def run_pretrain():
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer, sens=cfg.Thor.loss_scale,
enable_clip_grad=False)
model = Model(net_with_grads)
model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer)
model.train(new_repeat_count, ds, callbacks=callback,