From 689158f79b137aca33ab7864d7add51cc093de36 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Mon, 7 Jun 2021 09:28:18 +0800 Subject: [PATCH] FracZ format conversion when conv2d group > 1 --- .../ascend/ascend_backend_optimization.cc | 3 + .../ascend/buffer_fusion/ub_pattern_fusion.cc | 8 + .../ir_fusion/refresh_parameter_format.cc | 4 +- .../ascend/ir_fusion/set_fracz_group_attr.cc | 167 +++++++++++++++ .../ascend/ir_fusion/set_fracz_group_attr.h | 37 ++++ .../backend/session/anf_runtime_algorithm.cc | 14 +- mindspore/ccsrc/common/trans.cc | 193 +++++++++++++++++- mindspore/ccsrc/common/trans.h | 16 +- .../device/ascend/ascend_device_address.cc | 21 +- .../device/ascend/ascend_device_address.h | 5 +- .../device/ascend/ascend_kernel_runtime.cc | 7 +- .../device/ascend/ascend_kernel_runtime.h | 4 +- .../runtime/device/cpu/cpu_device_address.cc | 2 +- .../runtime/device/cpu/cpu_device_address.h | 6 +- .../runtime/device/cpu/cpu_kernel_runtime.cc | 5 + .../runtime/device/cpu/cpu_kernel_runtime.h | 4 +- .../ccsrc/runtime/device/device_address.h | 14 +- .../runtime/device/gpu/gpu_device_address.cc | 2 +- .../runtime/device/gpu/gpu_device_address.h | 6 +- .../runtime/device/gpu/gpu_kernel_runtime.cc | 5 + .../runtime/device/gpu/gpu_kernel_runtime.h | 2 + .../ccsrc/runtime/device/kernel_runtime.cc | 22 +- .../ccsrc/runtime/device/kernel_runtime.h | 4 +- mindspore/ccsrc/utils/utils.h | 7 + mindspore/core/ir/anf.h | 5 + 25 files changed, 517 insertions(+), 46 deletions(-) create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.cc create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.h diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 4c45b0d2485..c2db7c398aa 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ 
b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -69,6 +69,7 @@ #include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h" #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.h" #include "backend/optimizer/ascend/format_type/insert_trans_op.h" #include "backend/optimizer/ascend/format_type/trans_op_format_refine.h" #include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h" @@ -369,6 +370,8 @@ void AscendBackendOptimization(const std::shared_ptr &kern other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); optimizer->AddPassManager(other_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc index c6d7d55234a..c63a1ebc61b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc @@ -99,6 +99,14 @@ CNodePtr CreateFusionOp(const std::vector &inputs_list, const std::v ValuePtr output_names_v = MakeValue(output_names); fusion_op->set_attr("input_names", input_names_v); fusion_op->set_attr("output_names", output_names_v); + for (auto node : anf_nodes) { + auto cnode = node->cast(); + if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) { + auto fracz_group = AnfAlgo::GetNodeAttr(node, kAttrFracZGroup); + fusion_op->set_attr(kAttrFracZGroup, MakeValue(fracz_group)); + break; + } + } std::vector fusion_inputs_list = inputs_list; auto value_node = std::make_shared(fusion_op); 
(void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc index 1170f77c19a..20ee55e1049 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc @@ -23,9 +23,7 @@ namespace mindspore { namespace opt { void DoRefresh(const CNodePtr &cnode) { - if (cnode == nullptr) { - MS_LOG(EXCEPTION) << "node is nullptr"; - } + MS_EXCEPTION_IF_NULL(cnode); size_t input_num = AnfAlgo::GetInputTensorNum(cnode); for (size_t input_index = 0; input_index < input_num; input_index++) { auto input_kernel_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0).first; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.cc new file mode 100644 index 00000000000..e591b3f3dc5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.cc @@ -0,0 +1,167 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr auto kTupleGetItemName = "TupleGetItem"; +constexpr auto kDependName = "Depend"; +constexpr auto kLoadName = "Load"; +constexpr size_t kConvFilterInputIndex = 2; +constexpr size_t kElemFilterInputIndex = 1; + +void SetAttrForInputNode(const AnfNodePtr &node, int64_t groups) { + if (node == nullptr) { + return; + } + if (node->isa()) { + auto param = node->cast(); + MS_EXCEPTION_IF_NULL(param); + param->set_fracz_group(groups); + MS_LOG(INFO) << "set parameter " << param->fullname_with_scope() << " with fracz_group: " << groups; + } else if (node->isa()) { + AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), node); + if (AnfAlgo::GetCNodeName(node) == kTransDataOpName) { + AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), node); + } + auto cnode = node->cast(); + SetAttrForInputNode(cnode->input(kElemFilterInputIndex), groups); + } +} + +void SetAttrForConvInput(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto groups = AnfAlgo::GetNodeAttr(cnode, kAttrGroups); + if (groups > 1) { + SetAttrForInputNode(cnode->input(kConvFilterInputIndex), groups); + } +} + +void SetAttrForOptParamInput(const AnfNodePtr &node, int64_t groups) { + // For optimizer, there may be other parameters used by opt need to be set. + // For example, moments param used by FusedMulApplyMomentum. 
+ MS_EXCEPTION_IF_NULL(node); + auto opt_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(opt_cnode); + auto opt_inputs = opt_cnode->inputs(); + for (size_t i = 1; i < opt_inputs.size(); ++i) { + auto input_node = opt_inputs[i]; + if (input_node->isa() && AnfAlgo::GetCNodeName(input_node) == kLoadName) { + auto input_cnode = input_node->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + input_node = input_cnode->input(kElemFilterInputIndex); + } + if (input_node->isa()) { + auto param = input_node->cast(); + MS_EXCEPTION_IF_NULL(param); + param->set_fracz_group(groups); + } + } +} + +void SetFracZGroupIdxForAllReduce(const AnfNodePtr &node, const int64_t index) { + // When Allreduce do fusion, there may be several FracZ outputs with groups=1 or groups>1, + // so we need to record the output index with groups>1 + MS_EXCEPTION_IF_NULL(node); + auto allreduce = node->cast(); + MS_EXCEPTION_IF_NULL(allreduce); + if (AnfAlgo::HasNodeAttr(kAttrFracZGroupIdx, allreduce)) { + auto fz_group_idx = AnfAlgo::GetNodeAttr>(allreduce, kAttrFracZGroupIdx); + fz_group_idx.push_back(index); + AnfAlgo::SetNodeAttr(kAttrFracZGroupIdx, MakeValue(fz_group_idx), allreduce); + } else { + AnfAlgo::SetNodeAttr(kAttrFracZGroupIdx, MakeValue(std::vector{index}), allreduce); + } +} + +void SetAttrForOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, int64_t groups, + int64_t getitem_idx = 0) { + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::GetCNodeName(node) != kTupleGetItemName) { + AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), node); + } + for (auto node_index : manager->node_users()[node]) { + auto output_node = node_index.first; + auto output_name = AnfAlgo::GetCNodeName(output_node); + if (kOptOperatorSet.find(output_name) != kOptOperatorSet.end()) { + SetAttrForOptParamInput(output_node, groups); + AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), output_node); + } else if (output_name == kTransDataOpName) { + // Trans to other 
format, no need to recurse, but need to set Groups attr for TransData + AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), output_node); + } else if (output_name == kAllReduceOpName) { + int64_t index = static_cast(node_index.second) - 1; + SetFracZGroupIdxForAllReduce(output_node, index); + SetAttrForOutputNode(manager, output_node, groups, index); + } else if (output_name == kTupleGetItemName) { + auto getitem = output_node->cast(); + MS_EXCEPTION_IF_NULL(getitem); + auto getitem_input2 = getitem->input(kInputNodeOutputIndexInTupleGetItem); + auto output_idx = GetValue(GetValueNode(getitem_input2)); + if (output_idx == getitem_idx) { + SetAttrForOutputNode(manager, output_node, groups); + return; + } + } else { + SetAttrForOutputNode(manager, output_node, groups); + } + } +} + +void SetAttrForConvOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(func_graph); + auto groups = AnfAlgo::GetNodeAttr(cnode, kAttrGroups); + if (groups > 1) { + auto kernel_graph = func_graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto manager = kernel_graph->manager(); + SetAttrForOutputNode(manager, cnode, groups); + } +} +} // namespace + +bool SetFraczGroupAttr::Run(const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + MS_LOG(ERROR) << "func_graph is nullptr."; + return false; + } + std::vector node_list = TopoSort(func_graph->get_return()); + for (auto node : node_list) { + if (node == nullptr || !node->isa()) { + continue; + } + auto cnode = node->cast(); + if (cnode == nullptr) { + continue; + } + auto node_name = AnfAlgo::GetCNodeName(cnode); + if (node_name == kConv2DOpName || node_name == kConv2DBackpropInputOpName) { + SetAttrForConvInput(cnode); + } else if (node_name == kConv2DBackpropFilterOpName) { + SetAttrForConvOutput(func_graph, cnode); + } + } + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.h 
b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.h new file mode 100644 index 00000000000..ce55ec52d88 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SET_FRACZ_GROUP_ATTR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SET_FRACZ_GROUP_ATTR_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" + +namespace mindspore { +namespace opt { +class SetFraczGroupAttr : public Pass { + public: + SetFraczGroupAttr() : Pass("set_fracz_group_attr") {} + ~SetFraczGroupAttr() override = default; + bool Run(const FuncGraphPtr &graph) override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SET_FRACZ_GROUP_ATTR_H_ diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index ae537695285..67a31f0230b 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -590,7 +590,7 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_ auto format = AnfAlgo::GetOutputFormat(node, output_index); if (shape.empty() && format != 
kOpFormat_DEFAULT) { shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index)); - shape = trans::TransShapeToDevice(shape, format); + shape = trans::TransShapeToDevice(shape, format, node, output_index); } // scalar's output shape is a empty vector size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); @@ -804,7 +804,7 @@ std::vector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr & if (trans::IsNeedPadding(format, infer_shape.size())) { infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx)); } - return trans::TransShapeToDevice(infer_shape, format); + return trans::TransShapeToDevice(infer_shape, format, node, output_idx); } std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) { @@ -817,7 +817,8 @@ std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n if (trans::IsNeedPadding(format, infer_shape.size())) { infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx)); } - return trans::TransShapeToDevice(infer_shape, format); + auto input_node_index = GetPrevNodeOutput(node, input_idx); + return trans::TransShapeToDevice(infer_shape, format, input_node_index.first, input_node_index.second); } std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { @@ -1414,7 +1415,7 @@ bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) { AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) { MS_EXCEPTION_IF_NULL(node); auto get_input_index = index + 1; - if (index + 1 >= node->inputs().size()) { + if (get_input_index >= node->inputs().size()) { MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just" << node->inputs().size() << " trace: " << trace::DumpSourceLines(node); } @@ -1973,7 +1974,8 @@ std::vector 
AnfRuntimeAlgorithm::GetInputRealDeviceShapeIfExist(const An auto max_shape = GetInputMaxShape(anf_node, index); std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize); auto format = GetInputFormat(anf_node, index); - trans::TransShapeToDevice(device_shape, format); + auto input_node_index = GetPrevNodeOutput(anf_node, index); + trans::TransShapeToDevice(device_shape, format, input_node_index.first, input_node_index.second); } return device_shape; } @@ -1985,7 +1987,7 @@ std::vector AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A auto max_shape = GetOutputMaxShape(anf_node, index); std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize); auto format = GetOutputFormat(anf_node, index); - trans::TransShapeToDevice(device_shape, format); + trans::TransShapeToDevice(device_shape, format, anf_node, index); } return device_shape; } diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 9b06236948d..e6732e7210b 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -49,6 +49,29 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, } } +// greatest common divisor +size_t Gcd(size_t a, size_t b) { + if (b == 0) { + return 0; + } + size_t c = b; + while (a % b != 0) { + c = a % b; + a = b; + b = c; + } + return c; +} + +// least common multiple +size_t Lcm(size_t a, size_t b) { + if (b == 0) { + return 0; + } + size_t ret = (a * b) / (Gcd(a, b)); + return ret; +} + template T DivCeil(T n1, T n2) { if (n2 != 0) { @@ -57,6 +80,10 @@ T DivCeil(T n1, T n2) { return 0; } +size_t GetShapeSize(const std::vector &shape) { + return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); +} + enum DataTypeTransMode { FROM_FLOAT_TO_FLOAT16, FROM_FLOAT_TO_INT32, @@ -378,8 +405,54 @@ std::vector PaddingShapeTo4dByDefault(const std::vector &shape) } return shape_4d; } + +std::vector FracZDeviceShapeWithGroups(const std::vector &shape, const int64_t groups = 1) { + if (!CheckDims(shape)) { + MS_LOG(EXCEPTION) << "Check dims failed."; + } + size_t group_size = LongToSize(groups); + size_t cin_ori = shape[kC]; + size_t cout_ori = shape[kN] / group_size; + size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size); + size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize; + size_t c1_dim = cin_opt / kCubeSize; + size_t g_dim = DivCeil(group_size, e_mult); + size_t n1 = DivCeil(cout_ori * e_mult, kCubeSize); + std::vector device_shape; + device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]); + device_shape.push_back(n1); + device_shape.push_back(kNiSize); + device_shape.push_back(kCubeSize); + return device_shape; +} } // namespace + +int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) { + if (node == nullptr) { + return 1; + } + if (node->isa()) { + auto cnode = node->cast(); + if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) { + if (AnfAlgo::GetCNodeName(cnode) == kAllReduceOpName) { + // if index does not exist
in fracz_group_idx, return default value 1 + auto fz_group_idx = AnfAlgo::GetNodeAttr>(cnode, kAttrFracZGroupIdx); + int64_t out_index = SizeToLong(index); + auto fz_iter = std::find(std::begin(fz_group_idx), std::end(fz_group_idx), out_index); + if (fz_iter == std::end(fz_group_idx)) { + return 1; + } + } + return AnfAlgo::GetNodeAttr(cnode, kAttrFracZGroup); + } + } else if (node->isa()) { + auto param = node->cast(); + MS_EXCEPTION_IF_NULL(param); + return param->fracz_group(); + } + return 1; +} + bool IsNeedPadding(const std::string &format, const size_t shape_size) { if (shape_size == 0) { return false; @@ -547,7 +620,8 @@ std::vector PaddingShapeTo5dDefault(const std::vector &shape) { return shape_5d; } -std::vector TransShapeToDevice(const std::vector &shape, const std::string &format) { +std::vector TransShapeToDevice(const std::vector &shape, const std::string &format, + const int64_t groups) { using DeviceShapeTransfer = std::function(const std::vector &)>; const std::map device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape}, @@ -565,6 +639,9 @@ std::vector TransShapeToDevice(const std::vector &shape, const s if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { return shape; } + if (groups > 1 && format == kOpFormat_FRAC_Z) { + return FracZDeviceShapeWithGroups(shape, groups); + } auto temp_shape = shape; std::vector device_shape; if (format == kOpFormat_FRAC_NZ) { @@ -610,6 +687,15 @@ std::vector TransShapeToDevice(const std::vector &shape, const s return iter->second(temp_shape); } +std::vector TransShapeToDevice(const std::vector &shape, const std::string &format, + const AnfNodePtr &node, const size_t index) { + int64_t groups = 1; + if (format == kOpFormat_FRAC_Z) { + groups = GetAttrGroups(node, index); + } + return TransShapeToDevice(shape, format, groups); +} + bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { if (args.host_shape.size() != kNchwDims) { MS_LOG(ERROR) << "Invalid host 
shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; @@ -649,7 +735,7 @@ bool TransDataType(const TypeIdArgs &args, void *result) { return true; } -bool TransFormat(const FormatArgs &args, void *result) { +bool TransFormat(const FormatArgs &args, void *result, int64_t groups) { MS_LOG(DEBUG) << "Start trans format."; if (abstract::TypeIdSize(args.src_data_type) < 1) { MS_LOG(ERROR) << "Invalid datatype.."; @@ -658,6 +744,9 @@ bool TransFormat(const FormatArgs &args, void *result) { if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) { return NchwTo4D(args, result); } + if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) { + return NchwToFracZWithGroups(args, result, groups); + } auto iter = kTransFormatMapOfHostToDevice.find(args.device_format); if (iter == kTransFormatMapOfHostToDevice.end()) { MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]"; @@ -665,7 +754,15 @@ bool TransFormat(const FormatArgs &args, void *result) { return iter->second(args, result); } -bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { +bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) { + int64_t groups = 1; + if (args.device_format == kOpFormat_FRAC_Z) { + groups = GetAttrGroups(node, index); + } + return TransFormat(args, result, groups); +} + +bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups) { const std::map format_trans_map{ {kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw}, {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw}, @@ -680,6 +777,9 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) { return ToNchw(args, result); } + if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) { + return FracZToNchwWithGroups(args, result, groups); + } auto 
iter = format_trans_map.find(args.device_format); if (iter == format_trans_map.end()) { MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]"; @@ -687,6 +787,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { return iter->second(args, result); } +bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) { + int64_t groups = 1; + if (args.device_format == kOpFormat_FRAC_Z) { + groups = GetAttrGroups(node, index); + } + return TransFormatFromDeviceToHost(args, result, groups); +} + bool NchwTo4D(const FormatArgs &args, void *result) { // trans nchw to 4d MS_LOG(DEBUG) << "Trans format from nchw to 4d."; @@ -1477,5 +1585,82 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) { } return true; } + +bool NchwFracZTransWithGroups(const FormatArgs &args, void *result, bool to_device, int64_t groups) { + MS_LOG(DEBUG) << "Trans format from nchw to frac_z"; + MS_EXCEPTION_IF_NULL(result); + if (args.host_shape.size() != kNchwDims) { + MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; + return false; + } + auto size = abstract::TypeIdSize(args.src_data_type); + if (size < 1) { + MS_LOG(ERROR) << "Illegal dtype"; + return false; + } + auto n_dim = args.host_shape[kN]; + auto c_dim = args.host_shape[kC]; + auto h_dim = args.host_shape[kH]; + auto w_dim = args.host_shape[kW]; + size_t d_dim = 1; + size_t group_size = LongToSize(groups); + auto cin_ori = c_dim; + auto cout_ori = n_dim / group_size; + if (cin_ori == 0 || cout_ori == 0) { + MS_LOG(ERROR) << "cin_ori, cout_ori must not equal to 0"; + return false; + } + size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size); + size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize; + size_t cout_opt = DivCeil(e_mult * cout_ori, kCubeSize) * kCubeSize; + size_t c1_dim = cin_opt / kCubeSize; 
+ size_t dst_size = to_device ? GetShapeSize(args.device_shape) * size : GetShapeSize(args.host_shape) * size; + if (dst_size == 0) { + return true; + } + auto ret = memset_s(result, dst_size, 0, dst_size); + if (ret != EOK) { + MS_LOG(ERROR) << "memset failed"; + return false; + } + for (size_t g = 0; g < group_size; ++g) { + for (size_t d = 0; d < d_dim; ++d) { + for (size_t c = 0; c < c_dim; ++c) { + for (size_t h = 0; h < h_dim; ++h) { + for (size_t w = 0; w < w_dim; ++w) { + for (size_t n = 0; n < cout_ori; ++n) { + size_t e_val = g % e_mult; + size_t dst_ci = e_val * cin_ori + c; + size_t dst_co = e_val * cout_ori + n; + size_t src_co = g * cout_ori + n; + size_t temporary = dst_ci % kCubeSize; + size_t dev_idx = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * kCubeSize + + d * c1_dim * h_dim * w_dim * cout_opt * kCubeSize + + (dst_ci / kCubeSize) * h_dim * w_dim * cout_opt * kCubeSize + + h * w_dim * cout_opt * kCubeSize + w * cout_opt * kCubeSize + dst_co * kCubeSize + + temporary; + size_t hst_idx = + src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim + h * w_dim + w; + if (to_device) { + SetData(size, false, hst_idx, dev_idx, args, result); + } else { + SetData(size, false, dev_idx, hst_idx, args, result); + } + } + } + } + } + } + } + return true; +} + +bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups) { + return NchwFracZTransWithGroups(args, result, true, groups); +} + +bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups) { + return NchwFracZTransWithGroups(args, result, false, groups); +} } // namespace trans } // namespace mindspore diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index f8fb88f3a9a..b8992c07383 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * 
Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,10 +64,16 @@ void StringToAxisVector4D(const std::string &reshape_type_str, std::vector void StringToAxisVector5D(const std::string &reshape_type_str, std::vector *reshape_type_vec); ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index); bool IsNeedPadding(const std::string &format, const size_t shape_size); -std::vector TransShapeToDevice(const std::vector &shape, const std::string &format); +int64_t GetNodeGroups(const AnfNodePtr &node); +std::vector TransShapeToDevice(const std::vector &shape, const std::string &format, + const int64_t groups = 1); +std::vector TransShapeToDevice(const std::vector &shape, const std::string &format, + const AnfNodePtr &node, const size_t index); bool TransDataType(const TypeIdArgs &args, void *result); -bool TransFormat(const FormatArgs &args, void *result); -bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result); +bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1); +bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index); +bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups = 1); +bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index); // host to device bool NchwTo4D(const FormatArgs &args, void *result); @@ -79,6 +85,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result); bool NchwToNc1hwc04(const FormatArgs &args, void *result); bool NchwToC1hwncoc0(const FormatArgs &args, void *result); bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result); +bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups); // device to host bool ToNchw(const FormatArgs &args, void *result); @@ -89,6 +96,7 @@ bool Nc1hwc04ToNchw(const FormatArgs &args, void *result); bool 
FracZ3DToNcdhw(const FormatArgs &args, void *result); bool C1hwncoc0ToNchw(const FormatArgs &args, void *result); bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result); +bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups); using FormatTransfer = std::function; const std::map kTransFormatMapOfHostToDevice{ {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz}, diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index 04215057e6f..2900a0c8d63 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -480,8 +480,9 @@ std::vector AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::js std::vector AscendDeviceAddress::GetDeviceShape(std::vector *host_shape) const { std::vector device_shape; + auto node_index = GetNodeIndex(); if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) { - device_shape = trans::TransShapeToDevice(*host_shape, format_); + device_shape = trans::TransShapeToDevice(*host_shape, format_, node_index.first, node_index.second); } else { if (host_shape_.empty()) { *host_shape = trans::PaddingShape(*host_shape, format_); @@ -489,7 +490,7 @@ std::vector AscendDeviceAddress::GetDeviceShape(std::vector *hos host_shape->clear(); (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), LongToSize); } - device_shape = trans::TransShapeToDevice(*host_shape, format_); + device_shape = trans::TransShapeToDevice(*host_shape, format_, node_index.first, node_index.second); } return device_shape; } @@ -518,11 +519,12 @@ bool 
AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh } auto host_tmp = std::vector(size_); SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); + auto node_index = GetNodeIndex(); if (type_id_ != type) { const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_}; auto host = std::vector(size_); - sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data()); + sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data(), node_index.first, node_index.second); if (!sync_ok) { MS_LOG(ERROR) << "Trans format failed."; return false; @@ -537,7 +539,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh } else { const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_}; - sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr); + sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr, node_index.first, node_index.second); if (!sync_ok) { MS_LOG(ERROR) << "Trans format failed."; return false; @@ -604,12 +606,13 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh if (host_shape.empty()) { host_shape.emplace_back(1); } + auto node_index = GetNodeIndex(); std::vector device_shape; if (format_ == kOpFormat_FRAC_NZ) { - device_shape = trans::TransShapeToDevice(host_shape, format_); + device_shape = trans::TransShapeToDevice(host_shape, format_, node_index.first, node_index.second); } else { host_shape = trans::PaddingShape(host_shape, format_); - device_shape = trans::TransShapeToDevice(host_shape, format_); + device_shape = trans::TransShapeToDevice(host_shape, format_, node_index.first, node_index.second); } if (type_id_ != type) { auto shape_size = abstract::ShapeSize(host_shape); @@ -623,7 +626,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh const trans::FormatArgs 
format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_}; auto dst_tmp = std::vector(size_); - sync_ok = trans::TransFormat(format_args, dst_tmp.data()); + sync_ok = trans::TransFormat(format_args, dst_tmp.data(), node_index.first, node_index.second); if (!sync_ok) { MS_LOG(ERROR) << "Trans format failed."; return false; @@ -632,7 +635,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh } else { const trans::FormatArgs format_args{host_ptr, size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_}; auto host_tmp = std::vector(size_); - sync_ok = trans::TransFormat(format_args, host_tmp.data()); + sync_ok = trans::TransFormat(format_args, host_tmp.data(), node_index.first, node_index.second); if (!sync_ok) { MS_LOG(ERROR) << "Trans format failed."; return false; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h index 4c5b666b69f..1b88c5b7bee 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -38,6 +38,9 @@ class AscendDeviceAddress : public DeviceAddress { explicit AscendDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id) : DeviceAddress(ptr, size, format, type_id) {} + explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, + const AnfNodePtr &node, size_t out_index) + : DeviceAddress(ptr, size, format, type_id, node, out_index) {} ~AscendDeviceAddress() override; bool SyncDeviceToHost(size_t size, void *host_ptr) const override; bool SyncHostToDevice(size_t size, const void *host_ptr) const override; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 6fb1111e124..8e6c3a26a29 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -333,6 +333,11 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size return std::make_shared(device_ptr, device_size, format, type_id); } +DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id, const AnfNodePtr &node, size_t out_index) { + return std::make_shared(device_ptr, device_size, format, type_id, node, out_index); +} + bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { if (!is_task_sink) { GenKernelEvents(graph); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index 3f3b09c142f..1ea870e5e1e 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -66,6 +66,8 @@ class AscendKernelRuntime : public KernelRuntime { protected: DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id) override; + DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id, + const AnfNodePtr &node, size_t out_index) override; bool KernelMemNotReuse(const AnfNodePtr &node) override; void KernelLaunchProfiling(const std::string &kernel_name) override; diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc index 410de8ac867..c4e39bd1964 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h index dac5d5a5e9f..85b57d2ee8a 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -31,6 +31,10 @@ class CPUDeviceAddress : public DeviceAddress { CPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) : DeviceAddress(ptr, size, format, type_id) {} + CPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node, + size_t out_index) + : DeviceAddress(ptr, size, format, type_id, node, out_index) {} + ~CPUDeviceAddress() override = default; bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override; diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index c13a6747c29..a6347fd601f 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -175,6 +175,11 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t return std::make_shared(device_ptr, device_size, format, type_id); } +DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id, const AnfNodePtr &node, size_t out_index) { + return std::make_shared(device_ptr, device_size, format, type_id, node, out_index); +} + tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput( session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index, std::map *tensor_to_node) { diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h index 0c402865369..2c229f79614 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -53,6 +53,8 @@ class CPUKernelRuntime : public KernelRuntime { bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override { return true; }; DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id) override; + DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id, + const AnfNodePtr &node, size_t out_index) override; private: tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index, diff --git a/mindspore/ccsrc/runtime/device/device_address.h b/mindspore/ccsrc/runtime/device/device_address.h index 70ee2014042..06c81216219 100644 --- a/mindspore/ccsrc/runtime/device/device_address.h +++ b/mindspore/ccsrc/runtime/device/device_address.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,6 +21,7 @@ #include #include #include +#include #include "ir/dtype.h" #include "ir/device_sync.h" #include "utils/shape_utils.h" @@ -52,6 +53,7 @@ class GPUDeviceContext; namespace mindspore { namespace device { +using KernelWithIndex = std::pair; enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice }; enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU }; static const std::map kDeviceTypeToName = {{DeviceAddressType::kUnknown, "Unknown"}, @@ -64,6 +66,9 @@ class DeviceAddress : public mindspore::DeviceSync { explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {} explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) : ptr_(ptr), size_(size), format_(format), type_id_(type_id) {} + explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node, + size_t out_index) + : ptr_(ptr), size_(size), format_(format), type_id_(type_id), node_index_({node, out_index}) {} virtual ~DeviceAddress() { ptr_ = nullptr; } const void *GetPtr() const { return ptr_; } size_t GetSize() const { return size_; } @@ -75,6 +80,7 @@ class DeviceAddress : public mindspore::DeviceSync { virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } void *GetMutablePtr() const override { return ptr_; } + virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; } // The related interface of reference count operation. void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; } @@ -100,6 +106,10 @@ class DeviceAddress : public mindspore::DeviceSync { const void *ptr() const { return ptr_; } size_t size() const { return size_; } void set_ptr(void *ptr) { ptr_ = ptr; } + KernelWithIndex GetNodeIndex() const { + return node_index_.first.expired() ? 
KernelWithIndex{nullptr, node_index_.second} + : KernelWithIndex{node_index_.first.lock(), node_index_.second}; + } void *ptr_{nullptr}; size_t size_{0}; size_t original_ref_count_{1}; @@ -110,6 +120,8 @@ class DeviceAddress : public mindspore::DeviceSync { bool from_mem_pool_{false}; uint8_t *communication_ptr_{nullptr}; ShapeVector host_shape_{}; + // {node, out_index} + std::pair node_index_{AnfNodePtr(nullptr), 0}; friend class KernelRuntime; friend class MemoryManager; friend class mindspore::device::ascend::tasksink::TaskGenerator; diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc index 71948593543..a0332788be8 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h index ef5f4a9c3e8..4b14e66c210 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -34,6 +34,9 @@ class GPUDeviceAddress : public DeviceAddress { GPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} GPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) : DeviceAddress(ptr, size, format, type_id) {} + GPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node, + size_t out_index) + : DeviceAddress(ptr, size, format, type_id, node, out_index) {} ~GPUDeviceAddress() override; bool SyncDeviceToHost(size_t size, void *host_ptr) const override; @@ -51,6 +54,7 @@ class GPUDeviceAddress : public DeviceAddress { bool LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt, const ShapeVector &host_shape, TypeId host_type, size_t slot, bool keep_prev) const override; #endif + private: DeviceAddressStatus status_{DeviceAddressStatus::kInDevice}; }; diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index 54d42cf4227..4eabb2bd0a9 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -211,6 +211,11 @@ DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t return std::make_shared(device_ptr, device_size, format, type_id); } +DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id, const AnfNodePtr &node, size_t out_index) { + return std::make_shared(device_ptr, device_size, format, type_id, node, out_index); +} + bool GPUKernelRuntime::InitDevice() { if (GPUDeviceManager::GetInstance().device_count() <= 0) { MS_LOG(ERROR) << "No GPU device found."; diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h index d7df42f2ce6..360d0639376 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h +++ 
b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h @@ -58,6 +58,8 @@ class GPUKernelRuntime : public KernelRuntime { protected: DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id) override; + DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id, + const AnfNodePtr &node, size_t out_index) override; bool SyncStream() override; bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index d84914a0471..6e2bdaae74a 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -165,7 +165,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector } auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); auto device_address = - CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index); MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(mem_manager_); auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size); @@ -198,7 +198,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) { } std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); - auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); + auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, kernel, i); device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i)); MS_EXCEPTION_IF_NULL(device_address); auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); @@ -334,14 +334,15 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { } const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name); MS_EXCEPTION_IF_NULL(address.addr); - device_address = - CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + device_address = CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index), + output_type_id, item, index); AnfAlgo::SetOutputAddr(device_address, index, item.get()); continue; } #endif auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); - device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + device_address = + CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), 
output_type_id, item, index); MS_LOG(INFO) << "Assign Static Memory for Input node, size:" << tensor_size << " node:" << item->fullname_with_scope() << " index: " << index; if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) { @@ -512,7 +513,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode for (size_t j = 0; j < align_size_list.size(); ++j) { std::string output_format = AnfAlgo::GetOutputFormat(node, j); auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j); - auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type); + auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, node, j); MS_EXCEPTION_IF_NULL(address); if (output_ptr == nullptr) { output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true); @@ -548,7 +549,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, } std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index); auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index); - auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type); + auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, anf_node, index); AnfAlgo::SetOutputAddr(address, index, anf_node.get()); return address; } @@ -641,7 +642,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in } std::string output_format = AnfAlgo::GetOutputFormat(node, i); auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); - auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); + auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, node, i); MS_EXCEPTION_IF_NULL(device_address); uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false); 
MS_EXCEPTION_IF_NULL(ptr); @@ -681,7 +682,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); } auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); - DeviceAddressPtr address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id); + DeviceAddressPtr address = + CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, value_node, output_idx); MS_EXCEPTION_IF_NULL(address); if (ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_INFER) && !mem_manager_->MallocMemFromMemPool(address, node_size)) { diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index de7cfb4be9c..c1172f48709 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -114,6 +114,8 @@ class KernelRuntime { protected: virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id) = 0; + virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id, const AnfNodePtr &node, size_t out_index) = 0; virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index); virtual bool KernelMemNotReuse(const AnfNodePtr &node); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 03d670b9689..564691dddad 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -358,6 +358,9 @@ constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active constexpr auto kAttrFpBpEnd = "fpbp_end"; constexpr auto kAttrFusion = "fusion"; constexpr auto kAttrGroup = "group"; +constexpr auto kAttrGroups = "groups"; +constexpr auto kAttrFracZGroup = "fracz_group"; +constexpr auto kAttrFracZGroupIdx = "fracz_group_idx"; constexpr auto kAttrOp = "op"; constexpr auto kAttrDestRank = "dest_rank"; constexpr auto kAttrSrcRank = "src_rank"; @@ -450,6 +453,7 @@ const size_t kShape2dDims = 2; const size_t kShape5dDims = 5; const size_t kShape1dDims = 1; const size_t kCubeSize = 16; +const size_t kNiSize = 16; const size_t kMemAlignSize = 512; const size_t kBNChannelMultipleFactor = 4; const int kParameterDataTensorMask = 0; @@ -565,9 +569,12 @@ const std::set kOptOperatorSet = {kMomentumOpName, kApplyProximalAdagradOpName, kApplyProximalGradientDescentOpName, kApplyRMSPropOpName, + kAdamApplyOneWithDecayOpName, + kAdamApplyOneWithDecayAssignOpName, kFusedAdamWeightDecayName, kFusedAdamName, kFusedSparseAdamName, + kFusedMulApplyMomentumOpName, kFusedWeightScaleApplyMomentum, kFusedScaleApplyMomentum, kApplyCenteredRMSPropOpName, diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 105021c205a..e3c94277a5d 100644 --- a/mindspore/core/ir/anf.h +++ 
b/mindspore/core/ir/anf.h @@ -418,6 +418,9 @@ class Parameter : public ANode { void set_used_by_dynamic_kernel(bool used) { is_used_by_dynamic_kernel_ = used; } bool is_used_by_dynamic_kernel() { return is_used_by_dynamic_kernel_; } + void set_fracz_group(int64_t fracz_group) { fracz_group_ = fracz_group; } + int64_t fracz_group() { return fracz_group_; } + private: std::string name_; bool has_default_; @@ -426,6 +429,8 @@ class Parameter : public ANode { ValuePtr default_param_; // The count of graphs using the parameter. int used_graph_count_; + // groups attr in FracZ format + int64_t fracz_group_ = 1; }; using ParameterPtr = std::shared_ptr;