forked from mindspore-Ecosystem/mindspore
add low precison
This commit is contained in:
parent
8666a336d5
commit
4105a247b7
|
@ -0,0 +1,273 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "base/core_ops.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/ascend/ascend_helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/decrease_compute_precision.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
// Add CastCNode
|
||||
CNodePtr AddCastCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
|
||||
const TypeId &input_type, const TypeId &output_type, const std::vector<size_t> &origin_shape,
|
||||
const TypeId &origin_type) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::string input_format = format;
|
||||
std::string output_format = format;
|
||||
CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input});
|
||||
MS_EXCEPTION_IF_NULL(cast);
|
||||
// set kernel build info
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat({input_format});
|
||||
builder.SetOutputsFormat({output_format});
|
||||
builder.SetInputsDeviceType({input_type});
|
||||
builder.SetOutputsDeviceType({output_type});
|
||||
builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
if (cast->kernel_info() == nullptr) {
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
cast->set_kernel_info(kernel_info);
|
||||
}
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
|
||||
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
|
||||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
|
||||
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), cast);
|
||||
return cast;
|
||||
}
|
||||
|
||||
// Update Output Abatract and BuildInfo as Input Changed
|
||||
void UpdateOutputInfo(const AnfNodePtr &cnode) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
|
||||
ShapeVector out_shape = GetShape(cnode);
|
||||
auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(out_shape));
|
||||
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId::kNumberTypeFloat16), abs_shape_ptr);
|
||||
cnode->set_abstract(abstract);
|
||||
std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(cnode);
|
||||
std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(cnode);
|
||||
for (size_t i = 0; i < input_types.size(); i++) {
|
||||
input_types[i] = TypeId::kNumberTypeFloat16;
|
||||
}
|
||||
std::vector<std::string> output_formats = AnfAlgo::GetAllOutputFormats(cnode);
|
||||
std::vector<TypeId> output_types = {TypeId::kNumberTypeFloat16};
|
||||
auto graph_sel_info = BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types, cnode);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, cnode.get());
|
||||
}
|
||||
}
|
||||
|
||||
CNodePtr InsertCastForGraphKernel(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto mng = func_graph->manager();
|
||||
size_t in_num = AnfAlgo::GetInputNum(cnode); // include monads.
|
||||
for (size_t input_index = 0; input_index < in_num; ++input_index) {
|
||||
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
|
||||
if (HasAbstractMonad(cur_input)) {
|
||||
continue;
|
||||
}
|
||||
auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
|
||||
auto in_node = prev_node.first;
|
||||
auto in_index = prev_node.second;
|
||||
auto ori_shape = AnfAlgo::GetOutputDeviceShape(in_node, in_index);
|
||||
auto ori_format = AnfAlgo::GetOutputFormat(in_node, in_index);
|
||||
auto ori_dtype = AnfAlgo::GetOutputDeviceDataType(in_node, in_index);
|
||||
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
|
||||
if (cur_input->isa<ValueNode>()) {
|
||||
ori_dtype = cur_input->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>()->data_type();
|
||||
}
|
||||
auto new_dtype = TypeId::kNumberTypeFloat16;
|
||||
if (ori_dtype == TypeId::kNumberTypeFloat32) {
|
||||
if (cur_input->isa<ValueNode>()) {
|
||||
auto valuePtr = cur_input->cast<ValueNodePtr>();
|
||||
auto itensor = std::make_shared<tensor::Tensor>(
|
||||
TypeId::kNumberTypeFloat16, valuePtr->value()->cast<tensor::TensorPtr>()->shape(),
|
||||
valuePtr->value()->cast<tensor::TensorPtr>()->data_c(), TypeId::kNumberTypeFloat32);
|
||||
auto value_node = std::make_shared<ValueNode>(itensor);
|
||||
value_node->set_abstract(itensor->ToAbstract());
|
||||
mng->Replace(cur_input, value_node);
|
||||
}
|
||||
auto cast = AddCastCNode(func_graph, cur_input, dev_fmt, ori_dtype, new_dtype, ori_shape, new_dtype);
|
||||
MS_EXCEPTION_IF_NULL(cast);
|
||||
cast->set_scope(cnode->scope());
|
||||
ShapeVector out_shape = GetShape(cur_input);
|
||||
auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(out_shape));
|
||||
auto abstract =
|
||||
std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId::kNumberTypeFloat16), abs_shape_ptr);
|
||||
cast->set_abstract(abstract);
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
|
||||
mng->Replace(cur_input, cast);
|
||||
}
|
||||
}
|
||||
CNodePtr new_node = nullptr;
|
||||
new_node = std::make_shared<CNode>(*cnode);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
UpdateOutputInfo(new_node);
|
||||
return new_node;
|
||||
}
|
||||
|
||||
bool DecreaseComputePrecision::Process(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
bool changed = false;
|
||||
// Cast Down CNODES
|
||||
for (auto node : todos) {
|
||||
if (node->isa<CNode>() && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimCast)) {
|
||||
if (AnfAlgo::GetOutputDeviceDataType(cnode->input(1), 0) == kNumberTypeFloat16) {
|
||||
auto in_node = cnode->input(1);
|
||||
mng->Replace(node, in_node);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetOutputDeviceDataType(cnode->input(1), 0) == kNumberTypeFloat32 &&
|
||||
AnfAlgo::GetOutputDeviceDataType(cnode, 0) == kNumberTypeFloat16) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
auto new_node = InsertCastForGraphKernel(func_graph, cnode);
|
||||
mng->Replace(node, new_node);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (changed) {
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
|
||||
// Cast Up Outputs
|
||||
auto old_output = func_graph->output()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(old_output);
|
||||
auto add_cast = [&func_graph](const CNodePtr &old_cnode, bool is_output, std::vector<AnfNodePtr> &new_inputs) {
|
||||
AnfNodePtrList inputs1 = {NewValueNode(prim::kPrimCast), old_cnode};
|
||||
auto cnode1 = func_graph->NewCNode(inputs1);
|
||||
func_graph->AddNode(cnode1);
|
||||
ShapeVector cast_shape = GetShape(old_cnode);
|
||||
auto shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(cast_shape));
|
||||
auto new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId::kNumberTypeFloat32), shape_ptr);
|
||||
cnode1->set_abstract(new_abstract);
|
||||
cnode1->set_scope(old_cnode->scope());
|
||||
SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(kFloat32->type_id())), cnode1);
|
||||
MS_EXCEPTION_IF_NULL(cnode1);
|
||||
cnode1->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
std::vector<std::string> cnode_input_format = {GetFormat(old_cnode)};
|
||||
std::vector<TypeId> cnode_input_type = {kNumberTypeFloat16};
|
||||
std::vector<std::string> cnode_output_format = {GetFormat(old_cnode)};
|
||||
std::vector<TypeId> cnode_output_type = {kNumberTypeFloat32};
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
||||
graph_info_builder.SetInputsFormat(cnode_input_format);
|
||||
graph_info_builder.SetInputsDeviceType(cnode_input_type);
|
||||
graph_info_builder.SetOutputsFormat(cnode_output_format);
|
||||
graph_info_builder.SetOutputsDeviceType(cnode_output_type);
|
||||
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
auto info_1 = graph_info_builder.Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode1.get());
|
||||
if (is_output) {
|
||||
func_graph->set_output(cnode1);
|
||||
} else {
|
||||
new_inputs.emplace_back(cnode1);
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
if (AnfAlgo::CheckPrimitiveType(old_output, prim::kPrimMakeTuple)) {
|
||||
new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
auto all_out = AnfAlgo::GetAllOutput(old_output);
|
||||
for (const auto &out : all_out) {
|
||||
auto c_out = out->cast<CNodePtr>();
|
||||
if (c_out) {
|
||||
add_cast(c_out, false, new_inputs);
|
||||
}
|
||||
}
|
||||
old_output->set_inputs(new_inputs);
|
||||
} else {
|
||||
add_cast(old_output, true, new_inputs);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool IsCastUnAware(const FuncGraphPtr &func_graph) {
|
||||
std::vector<PrimitivePtr> cast_aware_list = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceAll};
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
|
||||
auto graph_name = GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
if (graph_name.find("atomic") != std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
for (auto node : todos) {
|
||||
if (node->isa<CNode>()) {
|
||||
if (std::find(cast_aware_list.begin(), cast_aware_list.end(), AnfAlgo::GetCNodePrimitive(node)) !=
|
||||
cast_aware_list.end()) {
|
||||
return false;
|
||||
}
|
||||
auto itype_id = AnfAlgo::GetOutputDeviceDataType(node, 0);
|
||||
if (itype_id != TypeId::kNumberTypeFloat16 && itype_id != TypeId::kNumberTypeFloat32) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DecreaseComputePrecision::Run(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
bool changed = false;
|
||||
for (const auto &node : todos) {
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_ERROR_IF_NULL(sub_func_graph);
|
||||
if (IsCastUnAware(sub_func_graph)) {
|
||||
changed = Process(sub_func_graph) || changed;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (changed) {
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_DECREASE_COMPUTE_PRECISION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_DECREASE_COMPUTE_PRECISION_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DecreaseComputePrecision : public Pass {
|
||||
public:
|
||||
explicit DecreaseComputePrecision(const std::vector<PrimitivePtr> &black_list = {})
|
||||
: Pass("decrease_compute_precision"), black_list_(black_list) {}
|
||||
~DecreaseComputePrecision() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
bool Process(const FuncGraphPtr &func_graph);
|
||||
AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type, const std::string &format) const;
|
||||
std::vector<PrimitivePtr> black_list_;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_DECREASE_COMPUTE_PRECISION_H_
|
|
@ -0,0 +1,297 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/manager.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/optimizer/graph_kernel/decrease_transfer_precision.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
static const size_t GK_MIN_SIZE = 2; // 2
|
||||
|
||||
int64_t ObtainGetItemIndex(const AnfNodePtr &getitem) {
|
||||
auto index_node = getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
auto value_ptr = GetValueNode(index_node);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
bool IsPreNodeReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_tuple_out, size_t index) {
|
||||
auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(gk_graph);
|
||||
if (is_tuple_out) {
|
||||
auto tuple_output = gk_graph->output()->cast<CNodePtr>();
|
||||
if (AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
|
||||
MS_EXCEPTION(UnknownError) << "\nThe return op is not a MakeTuple node\n";
|
||||
}
|
||||
auto input_node = tuple_output->input(index + 1);
|
||||
if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimReduceSum->name()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t GetGraphKernelSize(const AnfNodePtr &node) {
|
||||
auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(gk_graph);
|
||||
return gk_graph->GetOrderedCnodes().size();
|
||||
}
|
||||
|
||||
bool IsCandidateNode(const AnfNodePtr &node) {
|
||||
bool is_gk = AnfAlgo::IsGraphKernel(node);
|
||||
if (is_gk) {
|
||||
auto num = GetGraphKernelSize(node);
|
||||
if (num > GK_MIN_SIZE) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
auto graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
if (graph_name.find("atomic") == std::string::npos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsAllUserCandidateNode(const AnfNodeIndexSet &users) {
|
||||
// check whether all user are graph kernel when more than one users for the in_node
|
||||
bool result = std::all_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int> &node_index) {
|
||||
return IsCandidateNode(node_index.first);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
bool DecreaseTransferPrecision::Run(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
auto users_map = mng->node_users();
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
bool changed = false;
|
||||
for (const auto &node : todos) {
|
||||
auto is_candidate = IsCandidateNode(node);
|
||||
if (is_candidate) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
for (size_t index = 1; index < cnode->size(); index++) {
|
||||
auto dtype = AnfAlgo::GetInputDeviceDataType(node, index - 1);
|
||||
if (dtype != kNumberTypeFloat32) {
|
||||
continue;
|
||||
}
|
||||
auto item = cnode->input(index);
|
||||
if (!item->cast<CNodePtr>()) {
|
||||
continue;
|
||||
}
|
||||
auto in_node = item->cast<CNodePtr>();
|
||||
if (IsPrimitive(in_node->input(0), prim::kPrimTupleGetItem)) {
|
||||
auto tuple_node = in_node->input(1);
|
||||
auto tuple_index = ObtainGetItemIndex(in_node);
|
||||
auto has_reduce_output = IsPreNodeReduce(func_graph, tuple_node, true, tuple_index);
|
||||
auto fail_flag = !IsCandidateNode(tuple_node) ||
|
||||
(users_map[in_node].size() > 1 && IsAllUserCandidateNode(users_map[in_node])) ||
|
||||
has_reduce_output;
|
||||
if (fail_flag) {
|
||||
continue;
|
||||
}
|
||||
// mutate father
|
||||
Process_Father(func_graph, tuple_node, true, tuple_index);
|
||||
in_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(in_node)));
|
||||
// mutate sons
|
||||
for (auto each_out : users_map[in_node]) {
|
||||
Process_Son(func_graph, each_out.first, each_out.second);
|
||||
}
|
||||
}
|
||||
if (IsCandidateNode(in_node)) {
|
||||
auto fail_flag = !IsAllUserCandidateNode(users_map[in_node]);
|
||||
if (fail_flag) {
|
||||
continue;
|
||||
}
|
||||
// mutate father
|
||||
Process_Father(func_graph, in_node, false, 0);
|
||||
// mutate sons
|
||||
Process_Son(func_graph, cnode, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool DecreaseTransferPrecision::Process_Father(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
bool is_tuple_out, size_t index) {
|
||||
auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(gk_graph);
|
||||
auto mng = gk_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
|
||||
// lambda func for cast fp32 to fp16
|
||||
auto func_add_cast_fp16 = [&gk_graph](const AnfNodePtr &old_output) {
|
||||
AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_output};
|
||||
auto cnode = gk_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
gk_graph->AddNode(cnode);
|
||||
cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_output)));
|
||||
cnode->set_scope(old_output->scope());
|
||||
SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(kFloat16->type_id())), cnode);
|
||||
cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
|
||||
std::vector<TypeId> cnode_input_type = {kNumberTypeFloat32};
|
||||
std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
|
||||
std::vector<TypeId> cnode_output_type = {kNumberTypeFloat16};
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
||||
graph_info_builder.SetInputsFormat(cnode_input_format);
|
||||
graph_info_builder.SetInputsDeviceType(cnode_input_type);
|
||||
graph_info_builder.SetOutputsFormat(cnode_output_format);
|
||||
graph_info_builder.SetOutputsDeviceType(cnode_output_type);
|
||||
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
auto info_1 = graph_info_builder.Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
|
||||
return cnode;
|
||||
};
|
||||
|
||||
if (!is_tuple_out) {
|
||||
auto old_output = gk_graph->output()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(old_output);
|
||||
if (AnfAlgo::GetCNodeName(old_output) == prim::kPrimCast->name() &&
|
||||
AnfAlgo::GetInputDeviceDataType(old_output, 0) == kNumberTypeFloat16 &&
|
||||
AnfAlgo::GetOutputDeviceDataType(old_output, 0) == kNumberTypeFloat32) {
|
||||
auto real_output = old_output->input(1);
|
||||
gk_graph->set_output(real_output);
|
||||
} else {
|
||||
auto cnode = func_add_cast_fp16(old_output);
|
||||
gk_graph->set_output(cnode);
|
||||
}
|
||||
|
||||
// get kernel build info
|
||||
node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(node)));
|
||||
auto gk_builder_info =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
|
||||
std::vector<TypeId> gk_output_type = {kNumberTypeFloat16};
|
||||
gk_builder_info->SetOutputsDeviceType(gk_output_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
|
||||
return true;
|
||||
} else {
|
||||
// cast for graph kernel with make tuple output
|
||||
auto tuple_output = gk_graph->output()->cast<CNodePtr>();
|
||||
if (AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
|
||||
MS_EXCEPTION(UnknownError) << "\nThe return op is not a MakeTuple node\n";
|
||||
}
|
||||
auto input_node = tuple_output->input(index + 1);
|
||||
auto cnode = func_add_cast_fp16(input_node);
|
||||
tuple_output->set_input(index + 1, cnode);
|
||||
|
||||
// Update MakeTuple node abstract
|
||||
AbstractBasePtrList abstract_list;
|
||||
for (size_t i = 1; i < tuple_output->size(); ++i) {
|
||||
abstract_list.emplace_back(tuple_output->input(i)->abstract());
|
||||
}
|
||||
tuple_output->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
|
||||
// Update Graph Kernel abstract
|
||||
node->set_abstract(tuple_output->abstract());
|
||||
|
||||
// Update Graph Kernel Build Kernel Info
|
||||
auto old_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
|
||||
auto gk_builder_info = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(old_builder_info);
|
||||
auto origin_outputs_type = old_builder_info->GetAllOutputDeviceTypes();
|
||||
std::vector<TypeId> gk_output_type;
|
||||
for (size_t i = 0; i < origin_outputs_type.size(); ++i) {
|
||||
gk_output_type.push_back(origin_outputs_type[i]);
|
||||
}
|
||||
gk_output_type[index] = kNumberTypeFloat16;
|
||||
gk_builder_info->SetOutputsDeviceType(gk_output_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
bool DecreaseTransferPrecision::Process_Son(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t index) {
|
||||
auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(gk_graph);
|
||||
auto mng = gk_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto old_input = gk_graph->get_inputs()[index - 1];
|
||||
MS_EXCEPTION_IF_NULL(old_input);
|
||||
|
||||
auto user_nodes = mng->node_users()[old_input];
|
||||
// get kernel build info
|
||||
auto gk_builder_info =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
|
||||
auto ori_input_format = AnfAlgo::GetAllInputDeviceTypes(node);
|
||||
std::vector<TypeId> &new_inputs_type = ori_input_format;
|
||||
new_inputs_type[index - 1] = kNumberTypeFloat16;
|
||||
gk_builder_info->SetInputsDeviceType(new_inputs_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
|
||||
AbstractBasePtr old_abstract = node->abstract()->Clone();
|
||||
node->set_abstract(old_abstract);
|
||||
|
||||
for (const auto &user : user_nodes) {
|
||||
auto user_node = user.first;
|
||||
if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
|
||||
AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat16) {
|
||||
mng->Replace(user_node, old_input);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
auto tensor_input = node->cast<CNodePtr>()->input(index);
|
||||
AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_input};
|
||||
auto cnode = gk_graph->NewCNode(inputs);
|
||||
|
||||
gk_graph->AddNode(cnode);
|
||||
cnode->set_abstract(old_input->abstract());
|
||||
cnode->set_scope(old_input->scope());
|
||||
SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(kFloat32->type_id())), cnode);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
old_input->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_input)));
|
||||
cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
|
||||
std::vector<TypeId> cnode_input_type = {kNumberTypeFloat16};
|
||||
std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
|
||||
std::vector<TypeId> cnode_output_type = {kNumberTypeFloat32};
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder node_info_builder;
|
||||
node_info_builder.SetInputsFormat(cnode_input_format);
|
||||
node_info_builder.SetInputsDeviceType(cnode_input_type);
|
||||
node_info_builder.SetOutputsFormat(cnode_output_format);
|
||||
node_info_builder.SetOutputsDeviceType(cnode_output_type);
|
||||
node_info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
node_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
node_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
auto info_1 = node_info_builder.Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
|
||||
mng->Replace(old_input, cnode);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_DECREASE_TRANSFER_PRECISION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_DECREASE_TRANSFER_PRECISION_H_
|
||||
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DecreaseTransferPrecision : public Pass {
|
||||
public:
|
||||
DecreaseTransferPrecision() : Pass("decrease_transfer_precision") {}
|
||||
~DecreaseTransferPrecision() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph);
|
||||
|
||||
private:
|
||||
bool Process_Father(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_tuple_out = false,
|
||||
size_t index = 0);
|
||||
bool Process_Son(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t index);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_DECREASE_TRANSFER_PRECISION_H_
|
|
@ -42,6 +42,8 @@
|
|||
#include "backend/optimizer/graph_kernel/reorder_ops.h"
|
||||
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
|
||||
#include "backend/optimizer/graph_kernel/axis_normalizer.h"
|
||||
#include "backend/optimizer/graph_kernel/decrease_compute_precision.h"
|
||||
#include "backend/optimizer/graph_kernel/decrease_transfer_precision.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_pass_manager.h"
|
||||
|
||||
|
@ -152,6 +154,10 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
|
|||
auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_stitch_fusion);
|
||||
pm->AddPass(std::make_shared<StitchAtomicCleanInsertter>(), level, is_gpu);
|
||||
|
||||
// Enable low precision
|
||||
auto level_low_precision = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_low_precision);
|
||||
pm->AddPass(std::make_shared<DecreaseTransferPrecision>(), level_low_precision);
|
||||
pm->AddPass(std::make_shared<DecreaseComputePrecision>(), level_low_precision, is_ascend);
|
||||
return pm;
|
||||
}
|
||||
|
||||
|
|
|
@ -184,6 +184,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
|
|||
reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion, opt_level == OptLevel_3);
|
||||
reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level >= OptLevel_2);
|
||||
reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3);
|
||||
reg.AddFlag("enable_low_precision", &enable_low_precision);
|
||||
|
||||
// Integer flags
|
||||
reg.AddFlag("online_tuning", &online_tuning);
|
||||
|
@ -211,6 +212,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {
|
|||
json["enable_stitch_fusion"] = enable_stitch_fusion;
|
||||
json["enable_recompute_fusion"] = enable_recompute_fusion;
|
||||
json["enable_parallel_fusion"] = enable_parallel_fusion;
|
||||
json["enable_low_precision"] = enable_low_precision;
|
||||
|
||||
json["opt_level"] = opt_level;
|
||||
json["online_tuning"] = online_tuning;
|
||||
|
|
|
@ -79,6 +79,13 @@ class GraphKernelFlags {
|
|||
*/
|
||||
bool enable_parallel_fusion;
|
||||
|
||||
/**
|
||||
* Enable low precision in data transferring between graph kernel and computing in graph kernel
|
||||
* in graph kernel.
|
||||
* Experimental feature, enabled by the enable_low_precision flag
|
||||
*/
|
||||
bool enable_low_precision;
|
||||
|
||||
/**
|
||||
* Optimization level, value from 0 to 3.
|
||||
* 0: Disable GraphKernel
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops as ops
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
def test_case_1():
|
||||
class Net1(Cell):
|
||||
def __init__(self):
|
||||
super(Net1, self).__init__()
|
||||
self.sub = ops.Sub()
|
||||
self.mul = ops.Mul()
|
||||
self.sum = ops.ReduceSum(keep_dims=False)
|
||||
self.add = ops.Add()
|
||||
self.pow = ops.Pow()
|
||||
def construct(self, x, y, z):
|
||||
t1 = self.sub(x, y)
|
||||
t2 = self.mul(t1, x)
|
||||
t3 = self.add(y, t2)
|
||||
t4 = self.add(t3, t3)
|
||||
t5 = z + 1.0
|
||||
t6 = self.sum(t4)
|
||||
t7 = self.add(t5, t6)
|
||||
return t7
|
||||
def get_output(x, y, z, net, enable_graph_kernel=False):
|
||||
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||
net_obj = net()
|
||||
output = net_obj(x, y, z)
|
||||
return output
|
||||
|
||||
N = 8
|
||||
x = Tensor(np.random.uniform(1, 2, [N, N, N]).astype(np.float32))
|
||||
y = Tensor(np.random.uniform(1, 2, [N, N, N]).astype(np.float32))
|
||||
z = Tensor(np.random.uniform(1, 2, [N, N, N]).astype(np.float32))
|
||||
expect = get_output(x, y, z, Net1, False)
|
||||
output = get_output(x, y, z, Net1, True)
|
||||
expect_np = expect.asnumpy().copy()
|
||||
output_np = output.asnumpy().copy()
|
||||
assert np.allclose(expect_np, output_np, 1.e-2, 1.e-2)
|
||||
|
||||
|
||||
def test_case_2():
|
||||
class Net2(Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.sum = P.ReduceSum(keep_dims=True)
|
||||
self.add = P.Add()
|
||||
self.neg = P.Neg()
|
||||
def construct(self, x, y):
|
||||
sqrt_res = self.sqrt(x)
|
||||
add_res = self.add(y, sqrt_res)
|
||||
neg_res = self.neg(add_res)
|
||||
return neg_res
|
||||
|
||||
def get_output(x, y, net, enable_graph_kernel=False):
|
||||
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||
net_obj = net()
|
||||
output = net_obj(x, y)
|
||||
return output
|
||||
|
||||
N = 16
|
||||
x = Tensor(np.random.uniform(1, 2, [N, N]).astype(np.float32))
|
||||
y = Tensor(np.random.uniform(1, 2, [N, N]).astype(np.float32))
|
||||
expect = get_output(x, y, Net2, False)
|
||||
output = get_output(x, y, Net2, True)
|
||||
expect_np = expect[0].asnumpy().copy()
|
||||
output_np = output[0].asnumpy().copy()
|
||||
assert np.allclose(expect_np, output_np, 1.e-2, 1.e-2)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gpu_case_1():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
context.set_context(graph_kernel_flags="--enable_low_precision=true --disable_pass=highlevelopt2.atomic_clean")
|
||||
test_case_1()
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gpu_case_2():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
context.set_context(graph_kernel_flags="--enable_low_precision=true")
|
||||
test_case_2()
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ascend_case_1():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(graph_kernel_flags="--enable_low_precision=true --disable_pass=highlevelopt2.atomic_clean")
|
||||
test_case_1()
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ascend_case_2():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(graph_kernel_flags="--enable_low_precision=true")
|
||||
test_case_2()
|
Loading…
Reference in New Issue