support complex in gpu about graph_kernel
This commit is contained in:
parent
ed6cb44e7f
commit
aa019a639f
|
@ -22,6 +22,7 @@ from .bias_add import BiasAdd
|
|||
from .bias_add_grad import BiasAddGrad
|
||||
from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
|
||||
from .conv2d import Conv2D
|
||||
from .complex import CAbs, CAdd, CDiv, CMul, CSub
|
||||
from .dropout_grad import DropoutGrad
|
||||
from .expand_dims import ExpandDims
|
||||
from .fused_adam import FusedAdam
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""complex expanders init"""
|
||||
|
||||
from .abs import CAbs
|
||||
from .add import CAdd
|
||||
from .div import CDiv
|
||||
from .mul import CMul
|
||||
from .sub import CSub
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for cabs"""
|
||||
from mindspore._extends.graph_kernel.expanders._utils import Expander
|
||||
|
||||
|
||||
class CAbs(Expander):
    """Expander that rewrites the complex absolute value into real-valued ops."""

    def _expand(self, graph_builder):
        """Emit Sqrt(re * re + im * im) for the single input and return the result."""
        complex_in = self.inputs[0]
        real_part = graph_builder.emit('CReal', [complex_in])
        imag_part = graph_builder.emit('CImag', [complex_in])
        real_sq = graph_builder.emit('Mul', [real_part, real_part])
        imag_sq = graph_builder.emit('Mul', [imag_part, imag_part])
        sum_of_squares = graph_builder.emit('Add', [real_sq, imag_sq])
        return graph_builder.emit('Sqrt', [sum_of_squares])
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for cadd"""
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CAdd(Expander):
    """Expander that rewrites complex addition into component-wise real ops."""

    def _expand(self, graph_builder):
        """Emit Complex(re(x) + re(y), im(x) + im(y)) for the two inputs."""
        lhs, rhs = self.inputs
        lhs_re = graph_builder.emit('CReal', [lhs])
        rhs_re = graph_builder.emit('CReal', [rhs])
        lhs_im = graph_builder.emit('CImag', [lhs])
        rhs_im = graph_builder.emit('CImag', [rhs])
        sum_re = graph_builder.emit('Add', [lhs_re, rhs_re])
        sum_im = graph_builder.emit('Add', [lhs_im, rhs_im])
        return graph_builder.emit('Complex', [sum_re, sum_im])
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for cdiv"""
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CDiv(Expander):
    """Expander that rewrites complex division into real-valued ops."""

    def _expand(self, graph_builder):
        """Emit x / y as (x * conj(y)) / |y|^2, computed on real and imaginary parts.

        With x = a + bi and y = c + di the emitted graph computes
        real = (a*c + b*d) / (c*c + d*d) and imag = (b*c - a*d) / (c*c + d*d).
        """
        lhs, rhs = self.inputs
        a = graph_builder.emit('CReal', [lhs])
        c = graph_builder.emit('CReal', [rhs])
        b = graph_builder.emit('CImag', [lhs])
        d = graph_builder.emit('CImag', [rhs])
        # Denominator: squared magnitude of the divisor.
        c_sq = graph_builder.emit('Mul', [c, c])
        d_sq = graph_builder.emit('Mul', [d, d])
        denominator = graph_builder.emit('Add', [c_sq, d_sq])
        # Cross products that make up the two numerators.
        ac = graph_builder.emit('Mul', [a, c])
        bd = graph_builder.emit('Mul', [b, d])
        ad = graph_builder.emit('Mul', [a, d])
        bc = graph_builder.emit('Mul', [b, c])
        numerator_re = graph_builder.emit('Add', [ac, bd])
        numerator_im = graph_builder.emit('Sub', [bc, ad])
        quotient_re = graph_builder.emit('RealDiv', [numerator_re, denominator])
        quotient_im = graph_builder.emit('RealDiv', [numerator_im, denominator])
        return graph_builder.emit('Complex', [quotient_re, quotient_im])
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for cmul"""
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CMul(Expander):
    """Expander that rewrites complex multiplication into real-valued ops."""

    def _expand(self, graph_builder):
        """Emit x * y on components.

        With x = a + bi and y = c + di the emitted graph computes
        real = a*c - b*d and imag = a*d + b*c.
        """
        lhs, rhs = self.inputs
        a = graph_builder.emit('CReal', [lhs])
        c = graph_builder.emit('CReal', [rhs])
        b = graph_builder.emit('CImag', [lhs])
        d = graph_builder.emit('CImag', [rhs])
        ac = graph_builder.emit('Mul', [a, c])
        bd = graph_builder.emit('Mul', [b, d])
        ad = graph_builder.emit('Mul', [a, d])
        bc = graph_builder.emit('Mul', [b, c])
        product_re = graph_builder.emit('Sub', [ac, bd])
        product_im = graph_builder.emit('Add', [ad, bc])
        return graph_builder.emit('Complex', [product_re, product_im])
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for csub"""
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CSub(Expander):
    """Expander that rewrites complex subtraction into component-wise real ops."""

    def _expand(self, graph_builder):
        """Emit Complex(re(x) - re(y), im(x) - im(y)) for the two inputs."""
        lhs, rhs = self.inputs
        lhs_re = graph_builder.emit('CReal', [lhs])
        rhs_re = graph_builder.emit('CReal', [rhs])
        lhs_im = graph_builder.emit('CImag', [lhs])
        rhs_im = graph_builder.emit('CImag', [rhs])
        diff_re = graph_builder.emit('Sub', [lhs_re, rhs_re])
        diff_im = graph_builder.emit('Sub', [lhs_im, rhs_im])
        return graph_builder.emit('Complex', [diff_re, diff_im])
|
|
@ -192,6 +192,9 @@ class PrimLib:
|
|||
'UnPadAkg': Prim(OPAQUE),
|
||||
'PadAkg': Prim(OPAQUE),
|
||||
'Conv2D': Prim(OPAQUE),
|
||||
'CReal': Prim(ELEMWISE),
|
||||
'CImag': Prim(ELEMWISE),
|
||||
'Complex': Prim(ELEMWISE),
|
||||
}
|
||||
|
||||
default_primtive = Prim(UNKNOWN)
|
||||
|
@ -466,6 +469,7 @@ class GraphVisitor:
|
|||
for i in range(len(graph.ops)-1, -1, -1):
|
||||
self.visit(graph.ops[i])
|
||||
|
||||
|
||||
class AlignShape(GraphVisitor):
|
||||
"""Align shape"""
|
||||
|
||||
|
@ -484,6 +488,7 @@ class AlignShape(GraphVisitor):
|
|||
if align_dim > out_dim:
|
||||
op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
|
||||
|
||||
|
||||
class AddControlBuddy(GraphVisitor):
|
||||
"""Add control buddy"""
|
||||
|
||||
|
|
|
@ -255,6 +255,38 @@ class _CompareOp(_Elemwise):
|
|||
return "bool"
|
||||
|
||||
|
||||
class CImag(OpInfer):
    """Infer rules for CImag: takes a complex64 input and yields float32."""

    def _check_type(self):
        """Raise GKException unless the dtype of input[0] is complex64."""
        input_dtype = self.inputs[0].dtype
        if input_dtype != "complex64":
            # The previous wording ("complex64 condition") was misleading: the
            # input is not a condition, only its dtype is being validated.
            raise GKException("CImag's input[0] should be complex64 but got {}".format(input_dtype))

    def _infer_type(self):
        """Return the output dtype, which is always float32."""
        return "float32"
|
||||
|
||||
|
||||
class CReal(OpInfer):
    """Infer rules for CReal: takes a complex64 input and yields float32."""

    def _check_type(self):
        """Raise GKException unless the dtype of input[0] is complex64."""
        input_dtype = self.inputs[0].dtype
        if input_dtype != "complex64":
            # The previous wording ("complex64 condition") was misleading: the
            # input is not a condition, only its dtype is being validated.
            raise GKException("CReal's input[0] should be complex64 but got {}".format(input_dtype))

    def _infer_type(self):
        """Return the output dtype, which is always float32."""
        return "float32"
|
||||
|
||||
|
||||
class Complex(OpInfer):
    """Infer rules for Complex: combines two float32 inputs into complex64."""

    def _check_type(self):
        """Raise GKException unless input[0] is float32 and input[1] has the same dtype."""
        first_dtype = self.inputs[0].dtype
        second_dtype = self.inputs[1].dtype
        if first_dtype != "float32":
            # The previous wording ("float32 condition") was misleading: the
            # input is not a condition, only its dtype is being validated.
            raise GKException("Complex's input[0] should be float32 but got {}".format(first_dtype))
        if first_dtype != second_dtype:
            raise GKException("Complex's input mismatch ({} vs {})".format(first_dtype, second_dtype))

    def _infer_type(self):
        """Return the output dtype, which is always complex64."""
        return "complex64"
|
||||
|
||||
|
||||
class Less(_CompareOp):
    """Infer rules for Less; adds nothing of its own and inherits everything from _CompareOp."""
    pass
|
||||
|
||||
|
|
|
@ -44,8 +44,7 @@ const std::unordered_map<std::string, TypeId> type_id_maps = {
|
|||
{"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt},
|
||||
{"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16},
|
||||
{"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64},
|
||||
{"bool", TypeId::kNumberTypeBool},
|
||||
};
|
||||
{"bool", TypeId::kNumberTypeBool}, {"complex64", TypeId::kNumberTypeComplex64}};
|
||||
|
||||
const std::map<TypeId, std::string> type_id_str_map = {
|
||||
{TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
|
||||
|
@ -55,8 +54,7 @@ const std::map<TypeId, std::string> type_id_str_map = {
|
|||
{TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
|
||||
{TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
|
||||
{TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
|
||||
{TypeId::kNumberTypeBool, "bool"},
|
||||
};
|
||||
{TypeId::kNumberTypeBool, "bool"}, {TypeId::kNumberTypeComplex64, "complex64"}};
|
||||
|
||||
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
|
||||
{"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
|
||||
|
@ -64,11 +62,11 @@ const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
|
|||
};
|
||||
|
||||
const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
|
||||
{"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
|
||||
{"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
|
||||
{"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
|
||||
{"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
|
||||
};
|
||||
{"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
|
||||
{"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
|
||||
{"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
|
||||
{"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
|
||||
{"complex64", sizeof(float) * 2}};
|
||||
|
||||
const std::unordered_map<std::string, FusionType> fusion_type_maps = {
|
||||
{"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE},
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/eliminate_redundant_complex.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
bool EliminateRedudantComplexInGraphkernel(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
bool changed = false;
|
||||
for (const auto &node : todos) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
// Find all Complex node in graphkernel sub_graph
|
||||
if (cnode != nullptr && IsPrimitiveCNode(cnode, std::make_shared<Primitive>("Complex"))) {
|
||||
auto original_users = mng->node_users()[cnode];
|
||||
for (const auto &getitem_iter : original_users) {
|
||||
auto getitem = getitem_iter.first;
|
||||
auto getitem_cnode = getitem->cast<CNodePtr>();
|
||||
// Find all complex users which are CReal or CImag, then use Complex inputs replace them.
|
||||
if (IsPrimitiveCNode(getitem_cnode, std::make_shared<Primitive>("CReal"))) {
|
||||
mng->Replace(getitem, cnode->inputs()[1]);
|
||||
changed = true;
|
||||
} else if (IsPrimitiveCNode(getitem_cnode, std::make_shared<Primitive>("CImag"))) {
|
||||
mng->Replace(getitem, cnode->inputs()[2]);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool EliminateRedundantComplex::Run(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
bool changed = false;
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
std::reverse(todos.begin(), todos.end());
|
||||
for (const auto &node : todos) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
// Check whether graph_kernel node
|
||||
if (cnode != nullptr && AnfAlgo::IsGraphKernel(cnode)) {
|
||||
auto graph_kernel_fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph_kernel_fg);
|
||||
changed = EliminateRedudantComplexInGraphkernel(graph_kernel_fg) || changed;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_COMPLEX_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_COMPLEX_H_
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class EliminateRedundantComplex : public Pass {
|
||||
public:
|
||||
EliminateRedundantComplex() : Pass("eliminate_redundant_complex") {}
|
||||
~EliminateRedundantComplex() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_COMPLEX_H_
|
|
@ -20,6 +20,7 @@
|
|||
#include <set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||
|
@ -205,7 +206,7 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(mng);
|
||||
for (const auto &n : todos) {
|
||||
auto node = n->cast<CNodePtr>();
|
||||
if (node == nullptr || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) ||
|
||||
if (node == nullptr || AnfAlgo::IsGraphKernel(node) || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) ||
|
||||
!CanExpand(node)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -221,10 +222,57 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
return changed;
|
||||
}
|
||||
bool GraphKernelComplexExpander::CanExpand(const CNodePtr &node) const {
|
||||
bool has_complex = false;
|
||||
auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(node);
|
||||
for (size_t i = 0; i < all_inputs_type.size(); ++i) {
|
||||
if (all_inputs_type[i] == kNumberTypeFloat64 || all_inputs_type[i] == kNumberTypeComplex64) {
|
||||
has_complex = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return has_complex;
|
||||
}
|
||||
|
||||
// Just test for complex op, then will be deleted
|
||||
ExpanderPtr GraphKernelComplexExpander::GetExpander(const AnfNodePtr &node) {
|
||||
return std::make_shared<ComplexOpExpander>();
|
||||
}
|
||||
bool ComplexOpExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(cnode);
|
||||
for (size_t i = 0; i < all_inputs_type.size(); ++i) {
|
||||
if (all_inputs_type[i] == kNumberTypeFloat64 || all_inputs_type[i] == kNumberTypeComplex64) {
|
||||
all_inputs_type[i] = kNumberTypeComplex64;
|
||||
}
|
||||
}
|
||||
|
||||
auto all_outputs_type = AnfAlgo::GetAllOutputDeviceTypes(cnode);
|
||||
for (size_t i = 0; i < all_outputs_type.size(); ++i) {
|
||||
if (all_outputs_type[i] == kNumberTypeFloat64) {
|
||||
all_outputs_type[i] = kNumberTypeComplex64;
|
||||
}
|
||||
}
|
||||
auto all_inputs_format = AnfAlgo::GetAllInputFormats(cnode);
|
||||
auto all_outputs_format = AnfAlgo::GetAllOutputFormats(cnode);
|
||||
auto graph_sel_info =
|
||||
BuildSelectKernelBuildInfo(all_inputs_format, all_inputs_type, all_outputs_format, all_outputs_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, cnode.get());
|
||||
std::vector<size_t> original_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
||||
ShapeVector real_shape;
|
||||
std::copy(original_shape.begin(), original_shape.end(), std::back_inserter(real_shape));
|
||||
auto complex_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(real_shape));
|
||||
TypeId complex_type = kNumberTypeComplex64;
|
||||
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(complex_type), complex_shape_ptr);
|
||||
cnode->set_abstract(abstract);
|
||||
if (!DefaultExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
|
||||
(*kernel_json)["name"] = std::string("C") + AnfAlgo::GetCNodeName(cnode);
|
||||
return true;
|
||||
}
|
||||
bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
|
||||
expand_ops_ = GetExpandOps();
|
||||
return DoExpand(func_graph);
|
||||
}
|
||||
// Runs the shared expansion loop directly; unlike GraphKernelExpander::Run it does not
// populate the expand-op list, because node selection is done by the CanExpand override.
bool GraphKernelComplexExpander::Run(const FuncGraphPtr &func_graph) {
  const bool changed = DoExpand(func_graph);
  return changed;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
@ -34,22 +35,26 @@ class DefaultExpander : public Expander {
|
|||
AnfNodePtr Run(const AnfNodePtr &node) override;
|
||||
|
||||
protected:
|
||||
bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
|
||||
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
|
||||
AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
|
||||
FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
|
||||
virtual bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
|
||||
virtual void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
|
||||
virtual AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
|
||||
virtual FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
|
||||
};
|
||||
class ComplexOpExpander : public DefaultExpander {
|
||||
protected:
|
||||
bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
|
||||
};
|
||||
|
||||
class GraphKernelExpander : public Pass {
|
||||
public:
|
||||
GraphKernelExpander() : Pass("graph_kernel_expander") {}
|
||||
explicit GraphKernelExpander(const std::string &name) : Pass(name) {}
|
||||
~GraphKernelExpander() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
ExpanderPtr GetExpander(const AnfNodePtr &node);
|
||||
bool DoExpand(const FuncGraphPtr &func_graph);
|
||||
bool CanExpand(const CNodePtr &node) const {
|
||||
protected:
|
||||
virtual ExpanderPtr GetExpander(const AnfNodePtr &node);
|
||||
virtual bool DoExpand(const FuncGraphPtr &func_graph);
|
||||
virtual bool CanExpand(const CNodePtr &node) const {
|
||||
return std::any_of(expand_ops_.begin(), expand_ops_.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
}
|
||||
|
@ -57,6 +62,16 @@ class GraphKernelExpander : public Pass {
|
|||
private:
|
||||
std::vector<PrimitivePtr> expand_ops_;
|
||||
};
|
||||
class GraphKernelComplexExpander : public GraphKernelExpander {
|
||||
public:
|
||||
GraphKernelComplexExpander() : GraphKernelExpander("graph_kernel_complex_expander") {}
|
||||
~GraphKernelComplexExpander() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
protected:
|
||||
ExpanderPtr GetExpander(const AnfNodePtr &node) override;
|
||||
bool CanExpand(const CNodePtr &node) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
|
||||
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
|
||||
#include "backend/optimizer/graph_kernel/eliminate_redundant_complex.h"
|
||||
#include "backend/optimizer/graph_kernel/insert_pad.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
|
||||
|
@ -74,6 +75,10 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() const {
|
|||
|
||||
PassManagerPtr GraphKernelOptimizer::Cluster() const {
|
||||
auto pm = std::make_shared<GraphKernelPassManager>(1, "cluster");
|
||||
|
||||
// Expand complex op to composite kernels
|
||||
pm->AddPass(std::make_shared<GraphKernelComplexExpander>(), OptLevel_1, false);
|
||||
|
||||
// Expand complex basic kernels to composite kernels
|
||||
pm->AddPass(std::make_shared<GraphKernelExpander>(), OptLevel_1);
|
||||
|
||||
|
@ -108,6 +113,9 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const {
|
|||
|
||||
// Common subexpression elimination
|
||||
pm->AddPass(std::make_shared<GraphKernelCSE>(), OptLevel_2);
|
||||
|
||||
// Eliminate redundant Complex ops
|
||||
pm->AddPass(std::make_shared<EliminateRedundantComplex>(), OptLevel_2, false);
|
||||
return pm;
|
||||
}
|
||||
|
||||
|
|
|
@ -159,7 +159,7 @@ class Complex64 : public Number {
|
|||
|
||||
TypeId generic_type_id() const override { return kNumberTypeComplex64; }
|
||||
TypePtr DeepCopy() const override { return std::make_shared<Complex64>(); }
|
||||
std::string ToString() const override { return GetTypeName("Complex64"); }
|
||||
std::string ToString() const override { return GetTypeName("Complex"); }
|
||||
std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); }
|
||||
std::string DumpText() const override {
|
||||
return nbits() == 0 ? std::string("Complex64") : std::string("C") + std::to_string(nbits());
|
||||
|
|
|
@ -31,10 +31,10 @@ from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay, thor
|
|||
from mindspore import log as logger
|
||||
from mindspore.common import set_seed
|
||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
||||
BertTrainAccumulationAllReduceEachWithLossScaleCell, \
|
||||
BertTrainAccumulationAllReducePostWithLossScaleCell, \
|
||||
BertTrainOneStepWithLossScaleCellForAdam, \
|
||||
AdamWeightDecayForBert, AdamWeightDecayOp
|
||||
BertTrainAccumulationAllReduceEachWithLossScaleCell, \
|
||||
BertTrainAccumulationAllReducePostWithLossScaleCell, \
|
||||
BertTrainOneStepWithLossScaleCellForAdam, \
|
||||
AdamWeightDecayForBert, AdamWeightDecayOp
|
||||
from src.dataset import create_bert_dataset
|
||||
from src.utils import LossCallBack, BertLearningRate
|
||||
from src.model_utils.config import config as cfg, bert_net_cfg
|
||||
|
@ -120,12 +120,6 @@ def _get_optimizer(args_opt, network):
|
|||
return optimizer
|
||||
|
||||
|
||||
def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
|
||||
"""Judge whether is suitable to enable graph kernel."""
|
||||
return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \
|
||||
cfg.bert_network in ('base', 'large') and cfg.optimizer == 'AdamWeightDecay'
|
||||
|
||||
|
||||
def _set_graph_kernel_context(device_target):
|
||||
"""Add suitable graph kernel context for different configs."""
|
||||
if device_target == 'GPU':
|
||||
|
@ -247,7 +241,6 @@ def run_pretrain():
|
|||
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer, sens=cfg.Thor.loss_scale,
|
||||
enable_clip_grad=False)
|
||||
|
||||
|
||||
model = Model(net_with_grads)
|
||||
model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer)
|
||||
model.train(new_repeat_count, ds, callbacks=callback,
|
||||
|
|
Loading…
Reference in New Issue