!18538 support FracZ format conversion when conv2d group > 1
Merge pull request !18538 from yuchaojie/op_select
commit 0966595cfd

@@ -69,6 +69,7 @@
#include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h"
#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.h"
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#include "backend/optimizer/ascend/format_type/trans_op_format_refine.h"
#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"

@@ -369,6 +370,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
  other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
  other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
  other_pm->AddPass(std::make_shared<SplitOpOptimizer>());
  other_pm->AddPass(std::make_shared<SetFraczGroupAttr>());
  other_pm->AddPass(std::make_shared<EliminateRedundantOp>());
  optimizer->AddPassManager(other_pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
@@ -99,6 +99,14 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v
  ValuePtr output_names_v = MakeValue(output_names);
  fusion_op->set_attr("input_names", input_names_v);
  fusion_op->set_attr("output_names", output_names_v);
  for (auto node : anf_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
      auto fracz_group = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFracZGroup);
      fusion_op->set_attr(kAttrFracZGroup, MakeValue(fracz_group));
      break;
    }
  }
  std::vector<AnfNodePtr> fusion_inputs_list = inputs_list;
  auto value_node = std::make_shared<ValueNode>(fusion_op);
  (void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node);
@@ -23,9 +23,7 @@
namespace mindspore {
namespace opt {
void DoRefresh(const CNodePtr &cnode) {
  if (cnode == nullptr) {
    MS_LOG(EXCEPTION) << "node is nullptr";
  }
  MS_EXCEPTION_IF_NULL(cnode);
  size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  for (size_t input_index = 0; input_index < input_num; input_index++) {
    auto input_kernel_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0).first;
@@ -0,0 +1,167 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "runtime/device/kernel_info.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
namespace {
constexpr auto kTupleGetItemName = "TupleGetItem";
constexpr auto kDependName = "Depend";
constexpr auto kLoadName = "Load";
constexpr size_t kConvFilterInputIndex = 2;
constexpr size_t kElemFilterInputIndex = 1;

void SetAttrForInputNode(const AnfNodePtr &node, int64_t groups) {
  if (node == nullptr) {
    return;
  }
  if (node->isa<Parameter>()) {
    auto param = node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    param->set_fracz_group(groups);
    MS_LOG(INFO) << "Set parameter " << param->fullname_with_scope() << " with fracz_group: " << groups;
  } else if (node->isa<CNode>()) {
    AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), node);
    if (AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
      AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), node);
    }
    auto cnode = node->cast<CNodePtr>();
    SetAttrForInputNode(cnode->input(kElemFilterInputIndex), groups);
  }
}

void SetAttrForConvInput(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto groups = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrGroups);
  if (groups > 1) {
    SetAttrForInputNode(cnode->input(kConvFilterInputIndex), groups);
  }
}

void SetAttrForOptParamInput(const AnfNodePtr &node, int64_t groups) {
  // For an optimizer, there may be other parameters used by the optimizer that need to be set,
  // for example the moments parameter used by FusedMulApplyMomentum.
  MS_EXCEPTION_IF_NULL(node);
  auto opt_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(opt_cnode);
  auto opt_inputs = opt_cnode->inputs();
  for (size_t i = 1; i < opt_inputs.size(); ++i) {
    auto input_node = opt_inputs[i];
    if (input_node->isa<CNode>() && AnfAlgo::GetCNodeName(input_node) == kLoadName) {
      auto input_cnode = input_node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(input_cnode);
      input_node = input_cnode->input(kElemFilterInputIndex);
    }
    if (input_node->isa<Parameter>()) {
      auto param = input_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      param->set_fracz_group(groups);
    }
  }
}

void SetFracZGroupIdxForAllReduce(const AnfNodePtr &node, const int64_t index) {
  // When AllReduce does fusion, there may be several FracZ outputs with groups=1 or groups>1,
  // so we need to record the output indices with groups>1.
  MS_EXCEPTION_IF_NULL(node);
  auto allreduce = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(allreduce);
  if (AnfAlgo::HasNodeAttr(kAttrFracZGroupIdx, allreduce)) {
    auto fz_group_idx = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(allreduce, kAttrFracZGroupIdx);
    fz_group_idx.push_back(index);
    AnfAlgo::SetNodeAttr(kAttrFracZGroupIdx, MakeValue(fz_group_idx), allreduce);
  } else {
    AnfAlgo::SetNodeAttr(kAttrFracZGroupIdx, MakeValue(std::vector<int64_t>{index}), allreduce);
  }
}

void SetAttrForOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, int64_t groups,
                          int64_t getitem_idx = 0) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(node);
  if (AnfAlgo::GetCNodeName(node) != kTupleGetItemName) {
    AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), node);
  }
  for (auto node_index : manager->node_users()[node]) {
    auto output_node = node_index.first;
    auto output_name = AnfAlgo::GetCNodeName(output_node);
    if (kOptOperatorSet.find(output_name) != kOptOperatorSet.end()) {
      SetAttrForOptParamInput(output_node, groups);
      AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), output_node);
    } else if (output_name == kTransDataOpName) {
      // Trans to another format: no need to recurse, but the Groups attr must still be set for TransData.
      AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), output_node);
    } else if (output_name == kAllReduceOpName) {
      int64_t index = static_cast<int64_t>(node_index.second) - 1;
      SetFracZGroupIdxForAllReduce(output_node, index);
      SetAttrForOutputNode(manager, output_node, groups, index);
    } else if (output_name == kTupleGetItemName) {
      auto getitem = output_node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(getitem);
      auto getitem_input2 = getitem->input(kInputNodeOutputIndexInTupleGetItem);
      auto output_idx = GetValue<int64_t>(GetValueNode(getitem_input2));
      if (output_idx == getitem_idx) {
        SetAttrForOutputNode(manager, output_node, groups);
        return;
      }
    } else {
      SetAttrForOutputNode(manager, output_node, groups);
    }
  }
}

void SetAttrForConvOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto groups = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrGroups);
  if (groups > 1) {
    auto kernel_graph = func_graph->cast<KernelGraphPtr>();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    auto manager = kernel_graph->manager();
    SetAttrForOutputNode(manager, cnode, groups);
  }
}
}  // namespace

bool SetFraczGroupAttr::Run(const FuncGraphPtr &func_graph) {
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "func_graph is nullptr.";
    return false;
  }
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
  for (auto node : node_list) {
    if (node == nullptr || !node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    auto node_name = AnfAlgo::GetCNodeName(cnode);
    if (node_name == kConv2DOpName || node_name == kConv2DBackpropInputOpName) {
      SetAttrForConvInput(cnode);
    } else if (node_name == kConv2DBackpropFilterOpName) {
      SetAttrForConvOutput(func_graph, cnode);
    }
  }
  return true;
}
}  // namespace opt
}  // namespace mindspore
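The new pass is wired into the Ascend backend through the other_pm pass manager shown in the first hunk. A minimal sketch of that flow, assuming the GraphOptimizer/PassManager API and a kernel_graph already built by the session; the snippet itself is illustrative, not part of the diff:

// --- illustrative sketch, not part of the diff ---
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto other_pm = std::make_shared<opt::PassManager>("other_pass");  // pass-manager name assumed
other_pm->AddPass(std::make_shared<opt::SetFraczGroupAttr>());     // runs Run(func_graph) over the kernel graph
optimizer->AddPassManager(other_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();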
@@ -0,0 +1,37 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SET_FRACZ_GROUP_ATTR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SET_FRACZ_GROUP_ATTR_H_

#include <vector>
#include <memory>
#include <utility>
#include "ir/anf.h"
#include "backend/optimizer/common/pass.h"

namespace mindspore {
namespace opt {
class SetFraczGroupAttr : public Pass {
 public:
  SetFraczGroupAttr() : Pass("set_fracz_group_attr") {}
  ~SetFraczGroupAttr() override = default;
  bool Run(const FuncGraphPtr &graph) override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SET_FRACZ_GROUP_ATTR_H_
@@ -599,7 +599,7 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_
  auto format = AnfAlgo::GetOutputFormat(node, output_index);
  if (shape.empty() && format != kOpFormat_DEFAULT) {
    shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index));
    shape = trans::TransShapeToDevice(shape, format);
    shape = trans::TransShapeToDevice(shape, format, node, output_index);
  }
  // scalar's output shape is an empty vector
  size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
@@ -813,7 +813,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx));
  }
  return trans::TransShapeToDevice(infer_shape, format);
  return trans::TransShapeToDevice(infer_shape, format, node, output_idx);
}

std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {

@@ -826,7 +826,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx));
  }
  return trans::TransShapeToDevice(infer_shape, format);
  auto input_node_index = GetPrevNodeOutput(node, input_idx);
  return trans::TransShapeToDevice(infer_shape, format, input_node_index.first, input_node_index.second);
}

std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
@@ -1423,7 +1424,7 @@ bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  auto get_input_index = index + 1;
  if (index + 1 >= node->inputs().size()) {
  if (get_input_index >= node->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input index " << get_input_index << " but the node input size is just "
                      << node->inputs().size() << " trace: " << trace::DumpSourceLines(node);
  }
@@ -1982,7 +1983,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputRealDeviceShapeIfExist(const An
    auto max_shape = GetInputMaxShape(anf_node, index);
    std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
    auto format = GetInputFormat(anf_node, index);
    trans::TransShapeToDevice(device_shape, format);
    auto input_node_index = GetPrevNodeOutput(anf_node, index);
    trans::TransShapeToDevice(device_shape, format, input_node_index.first, input_node_index.second);
  }
  return device_shape;
}

@@ -1994,7 +1996,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A
    auto max_shape = GetOutputMaxShape(anf_node, index);
    std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
    auto format = GetOutputFormat(anf_node, index);
    trans::TransShapeToDevice(device_shape, format);
    trans::TransShapeToDevice(device_shape, format, anf_node, index);
  }
  return device_shape;
}
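The pattern across these hunks: a FRAC_Z device shape can now depend on the groups attribute of the node that owns the tensor, so shape queries pass a {node, index} pair through to trans::TransShapeToDevice, and input-side queries first resolve the producing node. A minimal sketch of the input-side lookup, mirroring GetInputDeviceShape above (illustrative only):

// --- illustrative sketch, not part of the diff ---
auto prev = AnfAlgo::GetPrevNodeOutput(node, input_idx);  // {producer node, its output index}
auto device_shape = trans::TransShapeToDevice(infer_shape, format, prev.first, prev.second);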
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -49,6 +49,29 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx,
  }
}

// greatest common divisor
size_t Gcd(size_t a, size_t b) {
  if (b == 0) {
    return 0;
  }
  size_t c = b;
  while (a % b != 0) {
    c = a % b;
    a = b;
    b = c;
  }
  return c;
}

// least common multiple
size_t Lcm(size_t a, size_t b) {
  if (b == 0) {
    return 0;
  }
  size_t ret = (a * b) / (Gcd(a, b));
  return ret;
}

template <typename T>
T DivCeil(T n1, T n2) {
  if (n2 != 0) {
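A quick sanity check of the two helpers (illustrative values, not part of the diff). Note that both return 0 when b == 0, a guard value rather than the mathematical convention Gcd(a, 0) == a:

// --- illustrative sketch, not part of the diff ---
assert(Gcd(24, 16) == 8);
assert(Lcm(24, 16) == 48);
assert(Gcd(7, 0) == 0);  // guard value, protects the division inside Lcm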
@@ -57,6 +80,10 @@ T DivCeil(T n1, T n2) {
  return 0;
}

size_t GetShapeSize(const std::vector<size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
}

enum DataTypeTransMode {
  FROM_FLOAT_TO_FLOAT16,
  FROM_FLOAT_TO_INT32,
@@ -378,8 +405,54 @@ std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape)
  }
  return shape_4d;
}

std::vector<size_t> FracZDeviceShapeWithGroups(const std::vector<size_t> &shape, const int64_t groups = 1) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  size_t group_size = LongToSize(groups);
  size_t cin_ori = shape[kC];
  size_t cout_ori = shape[kN] / group_size;
  size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size);
  size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize;
  size_t c1_dim = cin_opt / kCubeSize;
  size_t g_dim = DivCeil(group_size, e_mult);
  size_t n1 = DivCeil(cout_ori * e_mult, kCubeSize);
  std::vector<size_t> device_shape;
  device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
  device_shape.push_back(n1);
  device_shape.push_back(kNiSize);
  device_shape.push_back(kCubeSize);
  return device_shape;
}
}  // namespace

int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
  if (node == nullptr) {
    return 1;
  }
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
      if (AnfAlgo::GetCNodeName(cnode) == kAllReduceOpName) {
        // if index does not exist in fracz_group_idx, return the default value 1
        auto fz_group_idx = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrFracZGroupIdx);
        int64_t out_index = SizeToLong(index);
        auto fz_iter = std::find(std::begin(fz_group_idx), std::end(fz_group_idx), out_index);
        if (fz_iter == std::end(fz_group_idx)) {
          return 1;
        }
      }
      return AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFracZGroup);
    }
  } else if (node->isa<Parameter>()) {
    auto param = node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    return param->fracz_group();
  }
  return 1;
}

bool IsNeedPadding(const std::string &format, const size_t shape_size) {
  if (shape_size == 0) {
    return false;
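A worked example of FracZDeviceShapeWithGroups (illustrative numbers): take a grouped Conv2D weight of NCHW shape {32, 8, 3, 3} with groups = 4, so cin_ori = 8, cout_ori = 32 / 4 = 8, and kCubeSize = kNiSize = 16:

// --- illustrative sketch, not part of the diff ---
// e_mult  = min(Lcm(Lcm(8, 16) / 8, Lcm(8, 16) / 8), 4) = min(Lcm(2, 2), 4) = 2
// cin_opt = DivCeil(2 * 8, 16) * 16 = 16, so c1_dim = 16 / 16 = 1
// g_dim   = DivCeil(4, 2) = 2
// n1      = DivCeil(8 * 2, 16) = 1
// device_shape = {g_dim * c1_dim * H * W, n1, kNiSize, kCubeSize}
//              = {2 * 1 * 3 * 3, 1, 16, 16} = {18, 1, 16, 16}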
@@ -547,7 +620,8 @@ std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape) {
  return shape_5d;
}

std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
                                       const int64_t groups) {
  using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
  const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
                                                                    {kOpFormat_NHWC, NhwcDeviceShape},

@@ -565,6 +639,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
    return shape;
  }
  if (groups > 1 && format == kOpFormat_FRAC_Z) {
    return FracZDeviceShapeWithGroups(shape, groups);
  }
  auto temp_shape = shape;
  std::vector<size_t> device_shape;
  if (format == kOpFormat_FRAC_NZ) {

@@ -610,6 +687,15 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
  return iter->second(temp_shape);
}

std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
                                       const AnfNodePtr &node, const size_t index) {
  int64_t groups = 1;
  if (format == kOpFormat_FRAC_Z) {
    groups = GetAttrGroups(node, index);
  }
  return TransShapeToDevice(shape, format, groups);
}

bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
@@ -649,7 +735,7 @@ bool TransDataType(const TypeIdArgs &args, void *result) {
  return true;
}

bool TransFormat(const FormatArgs &args, void *result) {
bool TransFormat(const FormatArgs &args, void *result, int64_t groups) {
  MS_LOG(DEBUG) << "Start trans format.";
  if (abstract::TypeIdSize(args.src_data_type) < 1) {
    MS_LOG(ERROR) << "Invalid datatype.";

@@ -658,6 +744,9 @@ bool TransFormat(const FormatArgs &args, void *result) {
  if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
    return NchwTo4D(args, result);
  }
  if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
    return NchwToFracZWithGroups(args, result, groups);
  }
  auto iter = kTransFormatMapOfHostToDevice.find(args.device_format);
  if (iter == kTransFormatMapOfHostToDevice.end()) {
    MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";

@@ -665,7 +754,15 @@ bool TransFormat(const FormatArgs &args, void *result) {
  return iter->second(args, result);
}

bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) {
  int64_t groups = 1;
  if (args.device_format == kOpFormat_FRAC_Z) {
    groups = GetAttrGroups(node, index);
  }
  return TransFormat(args, result, groups);
}

bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups) {
  const std::map<std::string, FormatTransfer> format_trans_map{
    {kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
    {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},

@@ -680,6 +777,9 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
  if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
    return ToNchw(args, result);
  }
  if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
    return FracZToNchwWithGroups(args, result, groups);
  }
  auto iter = format_trans_map.find(args.device_format);
  if (iter == format_trans_map.end()) {
    MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";

@@ -687,6 +787,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
  return iter->second(args, result);
}

bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) {
  int64_t groups = 1;
  if (args.device_format == kOpFormat_FRAC_Z) {
    groups = GetAttrGroups(node, index);
  }
  return TransFormatFromDeviceToHost(args, result, groups);
}

bool NchwTo4D(const FormatArgs &args, void *result) {
  // trans nchw to 4d
  MS_LOG(DEBUG) << "Trans format from nchw to 4d.";
@@ -1477,5 +1585,82 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
  }
  return true;
}

bool NchwFracZTransWithGroups(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
  MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype";
    return false;
  }
  auto n_dim = args.host_shape[kN];
  auto c_dim = args.host_shape[kC];
  auto h_dim = args.host_shape[kH];
  auto w_dim = args.host_shape[kW];
  size_t d_dim = 1;
  size_t group_size = LongToSize(groups);
  auto cin_ori = c_dim;
  auto cout_ori = n_dim / group_size;
  if (cin_ori == 0 || cout_ori == 0) {
    MS_LOG(ERROR) << "cin_ori and cout_ori must not be 0";
    return false;
  }
  size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size);
  size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize;
  size_t cout_opt = DivCeil(e_mult * cout_ori, kCubeSize) * kCubeSize;
  size_t c1_dim = cin_opt / kCubeSize;
  size_t dst_size = to_device ? GetShapeSize(args.device_shape) * size : GetShapeSize(args.host_shape) * size;
  if (dst_size == 0) {
    return true;
  }
  auto ret = memset_s(result, dst_size, 0, dst_size);
  if (ret != EOK) {
    MS_LOG(ERROR) << "memset failed";
    return false;
  }
  for (size_t g = 0; g < group_size; ++g) {
    for (size_t d = 0; d < d_dim; ++d) {
      for (size_t c = 0; c < c_dim; ++c) {
        for (size_t h = 0; h < h_dim; ++h) {
          for (size_t w = 0; w < w_dim; ++w) {
            for (size_t n = 0; n < cout_ori; ++n) {
              size_t e_val = g % e_mult;
              size_t dst_ci = e_val * cin_ori + c;
              size_t dst_co = e_val * cout_ori + n;
              size_t src_co = g * cout_ori + n;
              size_t temporary = dst_ci % kCubeSize;
              size_t dev_idx = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * kCubeSize +
                               d * c1_dim * h_dim * w_dim * cout_opt * kCubeSize +
                               (dst_ci / kCubeSize) * h_dim * w_dim * cout_opt * kCubeSize +
                               h * w_dim * cout_opt * kCubeSize + w * cout_opt * kCubeSize + dst_co * kCubeSize +
                               temporary;
              size_t hst_idx =
                src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim + h * w_dim + w;
              if (to_device) {
                SetData(size, false, hst_idx, dev_idx, args, result);
              } else {
                SetData(size, false, dev_idx, hst_idx, args, result);
              }
            }
          }
        }
      }
    }
  }
  return true;
}

bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups) {
  return NchwFracZTransWithGroups(args, result, true, groups);
}

bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups) {
  return NchwFracZTransWithGroups(args, result, false, groups);
}
}  // namespace trans
}  // namespace mindspore
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -64,10 +64,16 @@ void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis>
void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
bool IsNeedPadding(const std::string &format, const size_t shape_size);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);
int64_t GetNodeGroups(const AnfNodePtr &node);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
                                       const int64_t groups = 1);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
                                       const AnfNodePtr &node, const size_t index);
bool TransDataType(const TypeIdArgs &args, void *result);
bool TransFormat(const FormatArgs &args, void *result);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1);
bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups = 1);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index);

// host to device
bool NchwTo4D(const FormatArgs &args, void *result);

@@ -79,6 +85,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result);
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result);
bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups);

// device to host
bool ToNchw(const FormatArgs &args, void *result);

@@ -89,6 +96,7 @@ bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
bool FracZ3DToNcdhw(const FormatArgs &args, void *result);
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups);
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
  {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
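Callers can either pass an explicit group count or let the node-aware overloads resolve it; a hedged usage sketch (node/output_index/groups naming assumed, not from the diff):

// --- illustrative sketch, not part of the diff ---
// Node-aware overload: looks up groups via GetAttrGroups when the format is FRAC_Z.
auto dev_shape = trans::TransShapeToDevice(host_shape, kOpFormat_FRAC_Z, node, output_index);
// Equivalent explicit form when the group count is already known:
auto dev_shape2 = trans::TransShapeToDevice(host_shape, kOpFormat_FRAC_Z, groups);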
@@ -1,5 +1,5 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -480,8 +480,9 @@ std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::js

std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const {
  std::vector<size_t> device_shape;
  auto node_index = GetNodeIndex();
  if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
    device_shape = trans::TransShapeToDevice(*host_shape, format_);
    device_shape = trans::TransShapeToDevice(*host_shape, format_, node_index.first, node_index.second);
  } else {
    if (host_shape_.empty()) {
      *host_shape = trans::PaddingShape(*host_shape, format_);

@@ -489,7 +490,7 @@ std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *hos
      host_shape->clear();
      (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), LongToSize);
    }
    device_shape = trans::TransShapeToDevice(*host_shape, format_);
    device_shape = trans::TransShapeToDevice(*host_shape, format_, node_index.first, node_index.second);
  }
  return device_shape;
}

@@ -518,11 +519,12 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
  }
  auto host_tmp = std::vector<uint8_t>(size_);
  SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
  auto node_index = GetNodeIndex();
  if (type_id_ != type) {
    const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_,
                                        host_shape, device_shape, type_id_};
    auto host = std::vector<uint8_t>(size_);
    sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data());
    sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data(), node_index.first, node_index.second);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans format failed.";
      return false;

@@ -537,7 +539,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
  } else {
    const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_,
                                        host_shape, device_shape, type_id_};
    sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr);
    sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr, node_index.first, node_index.second);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans format failed.";
      return false;

@@ -604,12 +606,13 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
  if (host_shape.empty()) {
    host_shape.emplace_back(1);
  }
  auto node_index = GetNodeIndex();
  std::vector<size_t> device_shape;
  if (format_ == kOpFormat_FRAC_NZ) {
    device_shape = trans::TransShapeToDevice(host_shape, format_);
    device_shape = trans::TransShapeToDevice(host_shape, format_, node_index.first, node_index.second);
  } else {
    host_shape = trans::PaddingShape(host_shape, format_);
    device_shape = trans::TransShapeToDevice(host_shape, format_);
    device_shape = trans::TransShapeToDevice(host_shape, format_, node_index.first, node_index.second);
  }
  if (type_id_ != type) {
    auto shape_size = abstract::ShapeSize(host_shape);

@@ -623,7 +626,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
    const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_,
                                        host_shape, device_shape, type_id_};
    auto dst_tmp = std::vector<uint8_t>(size_);
    sync_ok = trans::TransFormat(format_args, dst_tmp.data());
    sync_ok = trans::TransFormat(format_args, dst_tmp.data(), node_index.first, node_index.second);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans format failed.";
      return false;

@@ -632,7 +635,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
  } else {
    const trans::FormatArgs format_args{host_ptr, size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_};
    auto host_tmp = std::vector<uint8_t>(size_);
    sync_ok = trans::TransFormat(format_args, host_tmp.data());
    sync_ok = trans::TransFormat(format_args, host_tmp.data(), node_index.first, node_index.second);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans format failed.";
      return false;
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -38,6 +38,9 @@ class AscendDeviceAddress : public DeviceAddress {
  explicit AscendDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
  explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id)
      : DeviceAddress(ptr, size, format, type_id) {}
  explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
                               const AnfNodePtr &node, size_t out_index)
      : DeviceAddress(ptr, size, format, type_id, node, out_index) {}
  ~AscendDeviceAddress() override;
  bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
  bool SyncHostToDevice(size_t size, const void *host_ptr) const override;
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -333,6 +333,11 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
}

DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                          TypeId type_id, const AnfNodePtr &node, size_t out_index) {
  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
}

bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
  if (!is_task_sink) {
    GenKernelEvents(graph);
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -66,6 +66,8 @@ class AscendKernelRuntime : public KernelRuntime {
 protected:
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                       TypeId type_id) override;
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
                                       const AnfNodePtr &node, size_t out_index) override;
  bool KernelMemNotReuse(const AnfNodePtr &node) override;

  void KernelLaunchProfiling(const std::string &kernel_name) override;
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -31,6 +31,10 @@ class CPUDeviceAddress : public DeviceAddress {
  CPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
      : DeviceAddress(ptr, size, format, type_id) {}

  CPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
                   size_t out_index)
      : DeviceAddress(ptr, size, format, type_id, node, out_index) {}

  ~CPUDeviceAddress() override = default;

  bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
@@ -175,6 +175,11 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
  return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id);
}

DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                       TypeId type_id, const AnfNodePtr &node, size_t out_index) {
  return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
}

tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
  session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,
  std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -53,6 +53,8 @@ class CPUKernelRuntime : public KernelRuntime {
  bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override { return true; };
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                       TypeId type_id) override;
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
                                       const AnfNodePtr &node, size_t out_index) override;

 private:
  tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -21,6 +21,7 @@
#include <vector>
#include <memory>
#include <map>
#include <utility>
#include "ir/dtype.h"
#include "ir/device_sync.h"
#include "utils/shape_utils.h"

@@ -52,6 +53,7 @@ class GPUDeviceContext;

namespace mindspore {
namespace device {
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice };
enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU };
static const std::map<DeviceAddressType, std::string> kDeviceTypeToName = {{DeviceAddressType::kUnknown, "Unknown"},

@@ -64,6 +66,9 @@ class DeviceAddress : public mindspore::DeviceSync {
  explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {}
  explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
      : ptr_(ptr), size_(size), format_(format), type_id_(type_id) {}
  explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
                         size_t out_index)
      : ptr_(ptr), size_(size), format_(format), type_id_(type_id), node_index_({node, out_index}) {}
  virtual ~DeviceAddress() { ptr_ = nullptr; }
  const void *GetPtr() const { return ptr_; }
  size_t GetSize() const { return size_; }

@@ -75,6 +80,7 @@ class DeviceAddress : public mindspore::DeviceSync {
  virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
  virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
  void *GetMutablePtr() const override { return ptr_; }
  virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }

  // The related interface of reference count operation.
  void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; }

@@ -100,6 +106,10 @@ class DeviceAddress : public mindspore::DeviceSync {
  const void *ptr() const { return ptr_; }
  size_t size() const { return size_; }
  void set_ptr(void *ptr) { ptr_ = ptr; }
  KernelWithIndex GetNodeIndex() const {
    return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second}
                                       : KernelWithIndex{node_index_.first.lock(), node_index_.second};
  }
  void *ptr_{nullptr};
  size_t size_{0};
  size_t original_ref_count_{1};

@@ -110,6 +120,8 @@ class DeviceAddress : public mindspore::DeviceSync {
  bool from_mem_pool_{false};
  uint8_t *communication_ptr_{nullptr};
  ShapeVector host_shape_{};
  // {node, out_index}
  std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
  friend class KernelRuntime;
  friend class MemoryManager;
  friend class mindspore::device::ascend::tasksink::TaskGenerator;
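node_index_ is held as a weak pointer, so a DeviceAddress does not extend the lifetime of its producing node; once the node is gone, GetNodeIndex() yields {nullptr, index} and trans::GetAttrGroups falls back to groups = 1. A sketch of the lookup as used from within a subclass such as AscendDeviceAddress (GetNodeIndex() appears to be non-public, an assumption based on the surrounding members):

// --- illustrative sketch, not part of the diff ---
auto node_index = GetNodeIndex();  // {nullptr, idx} once the weak pointer expires
auto device_shape = trans::TransShapeToDevice(*host_shape, format_, node_index.first, node_index.second);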
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -34,6 +34,9 @@ class GPUDeviceAddress : public DeviceAddress {
  GPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
  GPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
      : DeviceAddress(ptr, size, format, type_id) {}
  GPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
                   size_t out_index)
      : DeviceAddress(ptr, size, format, type_id, node, out_index) {}
  ~GPUDeviceAddress() override;

  bool SyncDeviceToHost(size_t size, void *host_ptr) const override;

@@ -51,6 +54,7 @@ class GPUDeviceAddress : public DeviceAddress {
  bool LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt,
                     const ShapeVector &host_shape, TypeId host_type, size_t slot, bool keep_prev) const override;
#endif

 private:
  DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
};
@@ -211,6 +211,11 @@ DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
  return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id);
}

DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                       TypeId type_id, const AnfNodePtr &node, size_t out_index) {
  return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
}

bool GPUKernelRuntime::InitDevice() {
  if (GPUDeviceManager::GetInstance().device_count() <= 0) {
    MS_LOG(ERROR) << "No GPU device found.";
@@ -58,6 +58,8 @@ class GPUKernelRuntime : public KernelRuntime {
 protected:
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                       TypeId type_id) override;
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
                                       const AnfNodePtr &node, size_t out_index) override;
  bool SyncStream() override;
  bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override;
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -165,7 +165,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
  }
  auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
  auto device_address =
    CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
    CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index);
  MS_EXCEPTION_IF_NULL(device_address);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);

@@ -198,7 +198,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
  }
  std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
  auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, kernel, i);
  device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
  MS_EXCEPTION_IF_NULL(device_address);
  auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);

@@ -331,14 +331,15 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
      }
      const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
      MS_EXCEPTION_IF_NULL(address.addr);
      device_address =
        CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
      device_address = CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index),
                                           output_type_id, item, index);
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
      continue;
    }
#endif
    auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
    device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
    device_address =
      CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index);
    MS_LOG(INFO) << "Assign Static Memory for Input node, size:" << tensor_size
                 << " node:" << item->fullname_with_scope() << " index: " << index;
    if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) {

@@ -509,7 +510,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type);
    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, node, j);
    MS_EXCEPTION_IF_NULL(address);
    if (output_ptr == nullptr) {
      output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);

@@ -545,7 +546,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node,
  }
  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, anf_node, index);
  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  return address;
}

@@ -638,7 +639,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
  }
  std::string output_format = AnfAlgo::GetOutputFormat(node, i);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
  auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
  auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, node, i);
  MS_EXCEPTION_IF_NULL(device_address);
  uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
  MS_EXCEPTION_IF_NULL(ptr);

@@ -678,7 +679,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
    output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  }
  auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
  DeviceAddressPtr address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
  DeviceAddressPtr address =
    CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, value_node, output_idx);
  MS_EXCEPTION_IF_NULL(address);
  if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
      !mem_manager_->MallocMemFromMemPool(address, node_size)) {
@@ -1,5 +1,5 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -114,6 +114,8 @@ class KernelRuntime {
 protected:
  virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                               TypeId type_id) = 0;
  virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                               TypeId type_id, const AnfNodePtr &node, size_t out_index) = 0;
  virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
  virtual bool KernelMemNotReuse(const AnfNodePtr &node);
@@ -358,6 +358,9 @@ constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active
constexpr auto kAttrFpBpEnd = "fpbp_end";
constexpr auto kAttrFusion = "fusion";
constexpr auto kAttrGroup = "group";
constexpr auto kAttrGroups = "groups";
constexpr auto kAttrFracZGroup = "fracz_group";
constexpr auto kAttrFracZGroupIdx = "fracz_group_idx";
constexpr auto kAttrOp = "op";
constexpr auto kAttrDestRank = "dest_rank";
constexpr auto kAttrSrcRank = "src_rank";

@@ -450,6 +453,7 @@ const size_t kShape2dDims = 2;
const size_t kShape5dDims = 5;
const size_t kShape1dDims = 1;
const size_t kCubeSize = 16;
const size_t kNiSize = 16;
const size_t kMemAlignSize = 512;
const size_t kBNChannelMultipleFactor = 4;
const int kParameterDataTensorMask = 0;

@@ -565,9 +569,12 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
                                               kApplyProximalAdagradOpName,
                                               kApplyProximalGradientDescentOpName,
                                               kApplyRMSPropOpName,
                                               kAdamApplyOneWithDecayOpName,
                                               kAdamApplyOneWithDecayAssignOpName,
                                               kFusedAdamWeightDecayName,
                                               kFusedAdamName,
                                               kFusedSparseAdamName,
                                               kFusedMulApplyMomentumOpName,
                                               kFusedWeightScaleApplyMomentum,
                                               kFusedScaleApplyMomentum,
                                               kApplyCenteredRMSPropOpName,
@@ -424,6 +424,9 @@ class Parameter : public ANode {
  void set_has_dynamic_shape(bool flag) { has_dynamic_shape_ = flag; }
  bool has_dynamic_shape() const { return has_dynamic_shape_; }

  void set_fracz_group(int64_t fracz_group) { fracz_group_ = fracz_group; }
  int64_t fracz_group() { return fracz_group_; }

 private:
  std::string name_;
  bool has_default_;

@@ -432,6 +435,8 @@ class Parameter : public ANode {
  ValuePtr default_param_;
  // The count of graphs using the parameter.
  int used_graph_count_;
  // groups attr in FracZ format
  int64_t fracz_group_ = 1;
};
using ParameterPtr = std::shared_ptr<Parameter>;
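These accessors give SetFraczGroupAttr a place to record the group count on FracZ weight parameters, which trans::GetAttrGroups later reads back. A minimal sketch (illustrative only; the value 4 is an assumed example):

// --- illustrative sketch, not part of the diff ---
auto param = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
param->set_fracz_group(4);              // written by the pass for a conv with groups = 4
int64_t groups = param->fracz_group();  // read back when translating shapes and formats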