add low precison

This commit is contained in:
cy 2021-07-28 11:39:38 +08:00
parent 8666a336d5
commit 4105a247b7
8 changed files with 786 additions and 0 deletions

View File

@ -0,0 +1,273 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <vector>
#include <string>
#include <map>
#include <memory>
#include "base/core_ops.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/decrease_compute_precision.h"
namespace mindspore {
namespace opt {
// Add CastCNode
CNodePtr AddCastCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type, const std::vector<size_t> &origin_shape,
const TypeId &origin_type) {
MS_EXCEPTION_IF_NULL(func_graph);
std::string input_format = format;
std::string output_format = format;
CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input});
MS_EXCEPTION_IF_NULL(cast);
// set kernel build info
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
builder.SetOutputsFormat({output_format});
builder.SetInputsDeviceType({input_type});
builder.SetOutputsDeviceType({output_type});
builder.SetFusionType(kernel::FusionType::OPAQUE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::AKG_KERNEL);
if (cast->kernel_info() == nullptr) {
auto kernel_info = std::make_shared<device::KernelInfo>();
cast->set_kernel_info(kernel_info);
}
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), cast);
return cast;
}
// Update Output Abatract and BuildInfo as Input Changed
void UpdateOutputInfo(const AnfNodePtr &cnode) {
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
ShapeVector out_shape = GetShape(cnode);
auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(out_shape));
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId::kNumberTypeFloat16), abs_shape_ptr);
cnode->set_abstract(abstract);
std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(cnode);
std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(cnode);
for (size_t i = 0; i < input_types.size(); i++) {
input_types[i] = TypeId::kNumberTypeFloat16;
}
std::vector<std::string> output_formats = AnfAlgo::GetAllOutputFormats(cnode);
std::vector<TypeId> output_types = {TypeId::kNumberTypeFloat16};
auto graph_sel_info = BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types, cnode);
AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, cnode.get());
}
}
CNodePtr InsertCastForGraphKernel(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto mng = func_graph->manager();
size_t in_num = AnfAlgo::GetInputNum(cnode); // include monads.
for (size_t input_index = 0; input_index < in_num; ++input_index) {
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
if (HasAbstractMonad(cur_input)) {
continue;
}
auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
auto in_node = prev_node.first;
auto in_index = prev_node.second;
auto ori_shape = AnfAlgo::GetOutputDeviceShape(in_node, in_index);
auto ori_format = AnfAlgo::GetOutputFormat(in_node, in_index);
auto ori_dtype = AnfAlgo::GetOutputDeviceDataType(in_node, in_index);
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
if (cur_input->isa<ValueNode>()) {
ori_dtype = cur_input->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>()->data_type();
}
auto new_dtype = TypeId::kNumberTypeFloat16;
if (ori_dtype == TypeId::kNumberTypeFloat32) {
if (cur_input->isa<ValueNode>()) {
auto valuePtr = cur_input->cast<ValueNodePtr>();
auto itensor = std::make_shared<tensor::Tensor>(
TypeId::kNumberTypeFloat16, valuePtr->value()->cast<tensor::TensorPtr>()->shape(),
valuePtr->value()->cast<tensor::TensorPtr>()->data_c(), TypeId::kNumberTypeFloat32);
auto value_node = std::make_shared<ValueNode>(itensor);
value_node->set_abstract(itensor->ToAbstract());
mng->Replace(cur_input, value_node);
}
auto cast = AddCastCNode(func_graph, cur_input, dev_fmt, ori_dtype, new_dtype, ori_shape, new_dtype);
MS_EXCEPTION_IF_NULL(cast);
cast->set_scope(cnode->scope());
ShapeVector out_shape = GetShape(cur_input);
auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(out_shape));
auto abstract =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId::kNumberTypeFloat16), abs_shape_ptr);
cast->set_abstract(abstract);
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
mng->Replace(cur_input, cast);
}
}
CNodePtr new_node = nullptr;
new_node = std::make_shared<CNode>(*cnode);
MS_EXCEPTION_IF_NULL(new_node);
UpdateOutputInfo(new_node);
return new_node;
}
bool DecreaseComputePrecision::Process(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto todos = TopoSort(func_graph->get_return());
bool changed = false;
// Cast Down CNODES
for (auto node : todos) {
if (node->isa<CNode>() && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
auto cnode = node->cast<CNodePtr>();
if (IsPrimitiveCNode(cnode, prim::kPrimCast)) {
if (AnfAlgo::GetOutputDeviceDataType(cnode->input(1), 0) == kNumberTypeFloat16) {
auto in_node = cnode->input(1);
mng->Replace(node, in_node);
changed = true;
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(cnode->input(1), 0) == kNumberTypeFloat32 &&
AnfAlgo::GetOutputDeviceDataType(cnode, 0) == kNumberTypeFloat16) {
continue;
}
}
auto new_node = InsertCastForGraphKernel(func_graph, cnode);
mng->Replace(node, new_node);
changed = true;
}
}
if (changed) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
// Cast Up Outputs
auto old_output = func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(old_output);
auto add_cast = [&func_graph](const CNodePtr &old_cnode, bool is_output, std::vector<AnfNodePtr> &new_inputs) {
AnfNodePtrList inputs1 = {NewValueNode(prim::kPrimCast), old_cnode};
auto cnode1 = func_graph->NewCNode(inputs1);
func_graph->AddNode(cnode1);
ShapeVector cast_shape = GetShape(old_cnode);
auto shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(cast_shape));
auto new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId::kNumberTypeFloat32), shape_ptr);
cnode1->set_abstract(new_abstract);
cnode1->set_scope(old_cnode->scope());
SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(kFloat32->type_id())), cnode1);
MS_EXCEPTION_IF_NULL(cnode1);
cnode1->set_kernel_info(std::make_shared<device::KernelInfo>());
std::vector<std::string> cnode_input_format = {GetFormat(old_cnode)};
std::vector<TypeId> cnode_input_type = {kNumberTypeFloat16};
std::vector<std::string> cnode_output_format = {GetFormat(old_cnode)};
std::vector<TypeId> cnode_output_type = {kNumberTypeFloat32};
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
graph_info_builder.SetInputsFormat(cnode_input_format);
graph_info_builder.SetInputsDeviceType(cnode_input_type);
graph_info_builder.SetOutputsFormat(cnode_output_format);
graph_info_builder.SetOutputsDeviceType(cnode_output_type);
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto info_1 = graph_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode1.get());
if (is_output) {
func_graph->set_output(cnode1);
} else {
new_inputs.emplace_back(cnode1);
}
};
std::vector<AnfNodePtr> new_inputs;
if (AnfAlgo::CheckPrimitiveType(old_output, prim::kPrimMakeTuple)) {
new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
auto all_out = AnfAlgo::GetAllOutput(old_output);
for (const auto &out : all_out) {
auto c_out = out->cast<CNodePtr>();
if (c_out) {
add_cast(c_out, false, new_inputs);
}
}
old_output->set_inputs(new_inputs);
} else {
add_cast(old_output, true, new_inputs);
}
return changed;
}
bool IsCastUnAware(const FuncGraphPtr &func_graph) {
std::vector<PrimitivePtr> cast_aware_list = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceAll};
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto graph_name = GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
if (graph_name.find("atomic") != std::string::npos) {
return false;
}
auto todos = TopoSort(func_graph->get_return());
for (auto node : todos) {
if (node->isa<CNode>()) {
if (std::find(cast_aware_list.begin(), cast_aware_list.end(), AnfAlgo::GetCNodePrimitive(node)) !=
cast_aware_list.end()) {
return false;
}
auto itype_id = AnfAlgo::GetOutputDeviceDataType(node, 0);
if (itype_id != TypeId::kNumberTypeFloat16 && itype_id != TypeId::kNumberTypeFloat32) {
return false;
}
}
}
return true;
}
bool DecreaseComputePrecision::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto todos = TopoSort(func_graph->get_return());
bool changed = false;
for (const auto &node : todos) {
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_ERROR_IF_NULL(sub_func_graph);
if (IsCastUnAware(sub_func_graph)) {
changed = Process(sub_func_graph) || changed;
}
}
}
if (changed) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_DECREASE_COMPUTE_PRECISION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_DECREASE_COMPUTE_PRECISION_H_
#include <vector>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace opt {
class DecreaseComputePrecision : public Pass {
public:
explicit DecreaseComputePrecision(const std::vector<PrimitivePtr> &black_list = {})
: Pass("decrease_compute_precision"), black_list_(black_list) {}
~DecreaseComputePrecision() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool Process(const FuncGraphPtr &func_graph);
AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type, const std::string &format) const;
std::vector<PrimitivePtr> black_list_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_DECREASE_COMPUTE_PRECISION_H_

View File

@ -0,0 +1,297 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <string>
#include <algorithm>
#include <memory>
#include <utility>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/tensor.h"
#include "ir/manager.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/kernel_info.h"
#include "backend/optimizer/graph_kernel/decrease_transfer_precision.h"
namespace mindspore {
namespace opt {
static const size_t GK_MIN_SIZE = 2; // 2
int64_t ObtainGetItemIndex(const AnfNodePtr &getitem) {
auto index_node = getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem);
auto value_ptr = GetValueNode(index_node);
return GetValue<int64_t>(value_ptr);
}
bool IsPreNodeReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_tuple_out, size_t index) {
auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(gk_graph);
if (is_tuple_out) {
auto tuple_output = gk_graph->output()->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
MS_EXCEPTION(UnknownError) << "\nThe return op is not a MakeTuple node\n";
}
auto input_node = tuple_output->input(index + 1);
if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimReduceSum->name()) {
return true;
}
}
return false;
}
size_t GetGraphKernelSize(const AnfNodePtr &node) {
auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(gk_graph);
return gk_graph->GetOrderedCnodes().size();
}
bool IsCandidateNode(const AnfNodePtr &node) {
bool is_gk = AnfAlgo::IsGraphKernel(node);
if (is_gk) {
auto num = GetGraphKernelSize(node);
if (num > GK_MIN_SIZE) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
if (graph_name.find("atomic") == std::string::npos) {
return true;
}
}
}
return false;
}
bool IsAllUserCandidateNode(const AnfNodeIndexSet &users) {
// check whether all user are graph kernel when more than one users for the in_node
bool result = std::all_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int> &node_index) {
return IsCandidateNode(node_index.first);
});
return result;
}
bool DecreaseTransferPrecision::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto users_map = mng->node_users();
auto todos = TopoSort(func_graph->get_return());
bool changed = false;
for (const auto &node : todos) {
auto is_candidate = IsCandidateNode(node);
if (is_candidate) {
auto cnode = node->cast<CNodePtr>();
for (size_t index = 1; index < cnode->size(); index++) {
auto dtype = AnfAlgo::GetInputDeviceDataType(node, index - 1);
if (dtype != kNumberTypeFloat32) {
continue;
}
auto item = cnode->input(index);
if (!item->cast<CNodePtr>()) {
continue;
}
auto in_node = item->cast<CNodePtr>();
if (IsPrimitive(in_node->input(0), prim::kPrimTupleGetItem)) {
auto tuple_node = in_node->input(1);
auto tuple_index = ObtainGetItemIndex(in_node);
auto has_reduce_output = IsPreNodeReduce(func_graph, tuple_node, true, tuple_index);
auto fail_flag = !IsCandidateNode(tuple_node) ||
(users_map[in_node].size() > 1 && IsAllUserCandidateNode(users_map[in_node])) ||
has_reduce_output;
if (fail_flag) {
continue;
}
// mutate father
Process_Father(func_graph, tuple_node, true, tuple_index);
in_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(in_node)));
// mutate sons
for (auto each_out : users_map[in_node]) {
Process_Son(func_graph, each_out.first, each_out.second);
}
}
if (IsCandidateNode(in_node)) {
auto fail_flag = !IsAllUserCandidateNode(users_map[in_node]);
if (fail_flag) {
continue;
}
// mutate father
Process_Father(func_graph, in_node, false, 0);
// mutate sons
Process_Son(func_graph, cnode, index);
}
}
}
}
return changed;
}
bool DecreaseTransferPrecision::Process_Father(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
bool is_tuple_out, size_t index) {
auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(gk_graph);
auto mng = gk_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
// lambda func for cast fp32 to fp16
auto func_add_cast_fp16 = [&gk_graph](const AnfNodePtr &old_output) {
AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_output};
auto cnode = gk_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode);
gk_graph->AddNode(cnode);
cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_output)));
cnode->set_scope(old_output->scope());
SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(kFloat16->type_id())), cnode);
cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
std::vector<TypeId> cnode_input_type = {kNumberTypeFloat32};
std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(old_output, 0)};
std::vector<TypeId> cnode_output_type = {kNumberTypeFloat16};
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
graph_info_builder.SetInputsFormat(cnode_input_format);
graph_info_builder.SetInputsDeviceType(cnode_input_type);
graph_info_builder.SetOutputsFormat(cnode_output_format);
graph_info_builder.SetOutputsDeviceType(cnode_output_type);
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto info_1 = graph_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
return cnode;
};
if (!is_tuple_out) {
auto old_output = gk_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(old_output);
if (AnfAlgo::GetCNodeName(old_output) == prim::kPrimCast->name() &&
AnfAlgo::GetInputDeviceDataType(old_output, 0) == kNumberTypeFloat16 &&
AnfAlgo::GetOutputDeviceDataType(old_output, 0) == kNumberTypeFloat32) {
auto real_output = old_output->input(1);
gk_graph->set_output(real_output);
} else {
auto cnode = func_add_cast_fp16(old_output);
gk_graph->set_output(cnode);
}
// get kernel build info
node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(node)));
auto gk_builder_info =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
std::vector<TypeId> gk_output_type = {kNumberTypeFloat16};
gk_builder_info->SetOutputsDeviceType(gk_output_type);
AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
return true;
} else {
// cast for graph kernel with make tuple output
auto tuple_output = gk_graph->output()->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
MS_EXCEPTION(UnknownError) << "\nThe return op is not a MakeTuple node\n";
}
auto input_node = tuple_output->input(index + 1);
auto cnode = func_add_cast_fp16(input_node);
tuple_output->set_input(index + 1, cnode);
// Update MakeTuple node abstract
AbstractBasePtrList abstract_list;
for (size_t i = 1; i < tuple_output->size(); ++i) {
abstract_list.emplace_back(tuple_output->input(i)->abstract());
}
tuple_output->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
// Update Graph Kernel abstract
node->set_abstract(tuple_output->abstract());
// Update Graph Kernel Build Kernel Info
auto old_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
auto gk_builder_info = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(old_builder_info);
auto origin_outputs_type = old_builder_info->GetAllOutputDeviceTypes();
std::vector<TypeId> gk_output_type;
for (size_t i = 0; i < origin_outputs_type.size(); ++i) {
gk_output_type.push_back(origin_outputs_type[i]);
}
gk_output_type[index] = kNumberTypeFloat16;
gk_builder_info->SetOutputsDeviceType(gk_output_type);
AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
return true;
}
}
bool DecreaseTransferPrecision::Process_Son(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t index) {
auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(gk_graph);
auto mng = gk_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto old_input = gk_graph->get_inputs()[index - 1];
MS_EXCEPTION_IF_NULL(old_input);
auto user_nodes = mng->node_users()[old_input];
// get kernel build info
auto gk_builder_info =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
auto ori_input_format = AnfAlgo::GetAllInputDeviceTypes(node);
std::vector<TypeId> &new_inputs_type = ori_input_format;
new_inputs_type[index - 1] = kNumberTypeFloat16;
gk_builder_info->SetInputsDeviceType(new_inputs_type);
AnfAlgo::SetSelectKernelBuildInfo(gk_builder_info->Build(), node.get());
AbstractBasePtr old_abstract = node->abstract()->Clone();
node->set_abstract(old_abstract);
for (const auto &user : user_nodes) {
auto user_node = user.first;
if (IsPrimitiveCNode(user_node, prim::kPrimCast) &&
AnfAlgo::GetOutputDeviceDataType(user_node, 0) == kNumberTypeFloat16) {
mng->Replace(user_node, old_input);
return true;
}
}
auto tensor_input = node->cast<CNodePtr>()->input(index);
AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), old_input};
auto cnode = gk_graph->NewCNode(inputs);
gk_graph->AddNode(cnode);
cnode->set_abstract(old_input->abstract());
cnode->set_scope(old_input->scope());
SetNodeAttrSafely("dst_type", MakeValue(kernel::TypeId2String(kFloat32->type_id())), cnode);
MS_EXCEPTION_IF_NULL(cnode);
old_input->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat16, GetShape(old_input)));
cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
std::vector<std::string> cnode_input_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
std::vector<TypeId> cnode_input_type = {kNumberTypeFloat16};
std::vector<std::string> cnode_output_format = {AnfAlgo::GetOutputFormat(tensor_input, 0)};
std::vector<TypeId> cnode_output_type = {kNumberTypeFloat32};
kernel::KernelBuildInfo::KernelBuildInfoBuilder node_info_builder;
node_info_builder.SetInputsFormat(cnode_input_format);
node_info_builder.SetInputsDeviceType(cnode_input_type);
node_info_builder.SetOutputsFormat(cnode_output_format);
node_info_builder.SetOutputsDeviceType(cnode_output_type);
node_info_builder.SetProcessor(kernel::GetProcessorFromContext());
node_info_builder.SetKernelType(KernelType::AKG_KERNEL);
node_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto info_1 = node_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
mng->Replace(old_input, cnode);
return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_DECREASE_TRANSFER_PRECISION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_DECREASE_TRANSFER_PRECISION_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class DecreaseTransferPrecision : public Pass {
public:
DecreaseTransferPrecision() : Pass("decrease_transfer_precision") {}
~DecreaseTransferPrecision() override = default;
bool Run(const FuncGraphPtr &func_graph);
private:
bool Process_Father(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_tuple_out = false,
size_t index = 0);
bool Process_Son(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t index);
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_DECREASE_TRANSFER_PRECISION_H_

View File

@ -42,6 +42,8 @@
#include "backend/optimizer/graph_kernel/reorder_ops.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include "backend/optimizer/graph_kernel/axis_normalizer.h"
#include "backend/optimizer/graph_kernel/decrease_compute_precision.h"
#include "backend/optimizer/graph_kernel/decrease_transfer_precision.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/graph_kernel/graph_kernel_pass_manager.h"
@ -152,6 +154,10 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_stitch_fusion);
pm->AddPass(std::make_shared<StitchAtomicCleanInsertter>(), level, is_gpu);
// Enable low precision
auto level_low_precision = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_low_precision);
pm->AddPass(std::make_shared<DecreaseTransferPrecision>(), level_low_precision);
pm->AddPass(std::make_shared<DecreaseComputePrecision>(), level_low_precision, is_ascend);
return pm;
}

View File

@ -184,6 +184,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion, opt_level == OptLevel_3);
reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level >= OptLevel_2);
reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3);
reg.AddFlag("enable_low_precision", &enable_low_precision);
// Integer flags
reg.AddFlag("online_tuning", &online_tuning);
@ -211,6 +212,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {
json["enable_stitch_fusion"] = enable_stitch_fusion;
json["enable_recompute_fusion"] = enable_recompute_fusion;
json["enable_parallel_fusion"] = enable_parallel_fusion;
json["enable_low_precision"] = enable_low_precision;
json["opt_level"] = opt_level;
json["online_tuning"] = online_tuning;

View File

@ -79,6 +79,13 @@ class GraphKernelFlags {
*/
bool enable_parallel_fusion;
/**
* Enable low precision in data transferring between graph kernel and computing in graph kernel
* in graph kernel.
* Experimental feature, enabled by the enable_low_precision flag
*/
bool enable_low_precision;
/**
* Optimization level, value from 0 to 3.
* 0: Disable GraphKernel

View File

@ -0,0 +1,120 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops as ops
import mindspore.ops.operations as P
def test_case_1():
class Net1(Cell):
def __init__(self):
super(Net1, self).__init__()
self.sub = ops.Sub()
self.mul = ops.Mul()
self.sum = ops.ReduceSum(keep_dims=False)
self.add = ops.Add()
self.pow = ops.Pow()
def construct(self, x, y, z):
t1 = self.sub(x, y)
t2 = self.mul(t1, x)
t3 = self.add(y, t2)
t4 = self.add(t3, t3)
t5 = z + 1.0
t6 = self.sum(t4)
t7 = self.add(t5, t6)
return t7
def get_output(x, y, z, net, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net_obj = net()
output = net_obj(x, y, z)
return output
N = 8
x = Tensor(np.random.uniform(1, 2, [N, N, N]).astype(np.float32))
y = Tensor(np.random.uniform(1, 2, [N, N, N]).astype(np.float32))
z = Tensor(np.random.uniform(1, 2, [N, N, N]).astype(np.float32))
expect = get_output(x, y, z, Net1, False)
output = get_output(x, y, z, Net1, True)
expect_np = expect.asnumpy().copy()
output_np = output.asnumpy().copy()
assert np.allclose(expect_np, output_np, 1.e-2, 1.e-2)
def test_case_2():
class Net2(Cell):
def __init__(self):
super(Net2, self).__init__()
self.sqrt = P.Sqrt()
self.sum = P.ReduceSum(keep_dims=True)
self.add = P.Add()
self.neg = P.Neg()
def construct(self, x, y):
sqrt_res = self.sqrt(x)
add_res = self.add(y, sqrt_res)
neg_res = self.neg(add_res)
return neg_res
def get_output(x, y, net, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net_obj = net()
output = net_obj(x, y)
return output
N = 16
x = Tensor(np.random.uniform(1, 2, [N, N]).astype(np.float32))
y = Tensor(np.random.uniform(1, 2, [N, N]).astype(np.float32))
expect = get_output(x, y, Net2, False)
output = get_output(x, y, Net2, True)
expect_np = expect[0].asnumpy().copy()
output_np = output[0].asnumpy().copy()
assert np.allclose(expect_np, output_np, 1.e-2, 1.e-2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_case_1():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.set_context(graph_kernel_flags="--enable_low_precision=true --disable_pass=highlevelopt2.atomic_clean")
test_case_1()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_case_2():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.set_context(graph_kernel_flags="--enable_low_precision=true")
test_case_2()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_case_1():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(graph_kernel_flags="--enable_low_precision=true --disable_pass=highlevelopt2.atomic_clean")
test_case_1()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_case_2():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(graph_kernel_flags="--enable_low_precision=true")
test_case_2()