Support FracZ format conversion when Conv2D group > 1

yuchaojie 2021-06-07 09:28:18 +08:00
parent 7c1ea7e3fa
commit 689158f79b
25 changed files with 517 additions and 46 deletions
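
The grouped FracZ weight layout packs e_mult neighboring groups into one cube-aligned block, where e_mult is derived from the per-group input/output channel counts via lcm alignment; the new TransShapeToDevice / TransFormat overloads take the producing node so the groups value can be read back from the fracz_group attribute set by the new pass. Below is a minimal standalone sketch of the shape rule (illustrative only — FracZShapeWithGroups is a hypothetical name, and it uses C++17 std::lcm in place of the in-tree Gcd/Lcm/DivCeil helpers shown later in this diff):

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical standalone sketch; assumes groups divides N and C > 0.
std::vector<size_t> FracZShapeWithGroups(size_t n, size_t c, size_t h, size_t w, size_t groups) {
  constexpr size_t kCube = 16;   // cube edge length on Ascend (kCubeSize below)
  size_t cin_ori = c;            // input channels per group
  size_t cout_ori = n / groups;  // output channels per group
  // Number of groups packed into one cube-aligned block.
  size_t e_mult = std::min({std::lcm(std::lcm(cin_ori, kCube) / cin_ori,
                                     std::lcm(cout_ori, kCube) / cout_ori),
                            groups});
  size_t cin_opt = (e_mult * cin_ori + kCube - 1) / kCube * kCube;  // DivCeil * kCube
  size_t c1_dim = cin_opt / kCube;
  size_t g_dim = (groups + e_mult - 1) / e_mult;
  size_t n1 = (cout_ori * e_mult + kCube - 1) / kCube;
  return {g_dim * c1_dim * h * w, n1, kCube, kCube};  // kNiSize == kCubeSize == 16
}

With groups == 1 this degenerates to the ordinary FracZ shape [C1*H*W, N1, 16, 16].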

View File

@ -69,6 +69,7 @@
#include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h"
#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.h"
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#include "backend/optimizer/ascend/format_type/trans_op_format_refine.h"
#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
@ -369,6 +370,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
other_pm->AddPass(std::make_shared<SplitOpOptimizer>());
other_pm->AddPass(std::make_shared<SetFraczGroupAttr>());
other_pm->AddPass(std::make_shared<EliminateRedundantOp>());
optimizer->AddPassManager(other_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();

View File

@ -99,6 +99,14 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v
ValuePtr output_names_v = MakeValue(output_names);
fusion_op->set_attr("input_names", input_names_v);
fusion_op->set_attr("output_names", output_names_v);
for (auto node : anf_nodes) {
auto cnode = node->cast<CNodePtr>();
if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
auto fracz_group = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFracZGroup);
fusion_op->set_attr(kAttrFracZGroup, MakeValue(fracz_group));
break;
}
}
std::vector<AnfNodePtr> fusion_inputs_list = inputs_list;
auto value_node = std::make_shared<ValueNode>(fusion_op);
(void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node);

View File

@ -23,9 +23,7 @@
namespace mindspore {
namespace opt {
void DoRefresh(const CNodePtr &cnode) {
if (cnode == nullptr) {
MS_LOG(EXCEPTION) << "node is nullptr";
}
MS_EXCEPTION_IF_NULL(cnode);
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < input_num; input_index++) {
auto input_kernel_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0).first;

View File

@ -0,0 +1,167 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fusion/set_fracz_group_attr.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "runtime/device/kernel_info.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
constexpr auto kTupleGetItemName = "TupleGetItem";
constexpr auto kDependName = "Depend";
constexpr auto kLoadName = "Load";
constexpr size_t kConvFilterInputIndex = 2;
constexpr size_t kElemFilterInputIndex = 1;
void SetAttrForInputNode(const AnfNodePtr &node, int64_t groups) {
if (node == nullptr) {
return;
}
if (node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
param->set_fracz_group(groups);
MS_LOG(INFO) << "set parameter " << param->fullname_with_scope() << " with fracz_group: " << groups;
} else if (node->isa<CNode>()) {
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), node);
if (AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), node);
}
auto cnode = node->cast<CNodePtr>();
SetAttrForInputNode(cnode->input(kElemFilterInputIndex), groups);
}
}
void SetAttrForConvInput(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto groups = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrGroups);
if (groups > 1) {
SetAttrForInputNode(cnode->input(kConvFilterInputIndex), groups);
}
}
void SetAttrForOptParamInput(const AnfNodePtr &node, int64_t groups) {
// For optimizer nodes, other parameters used by the optimizer may also need
// the attr, e.g. the moments parameter used by FusedMulApplyMomentum.
MS_EXCEPTION_IF_NULL(node);
auto opt_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(opt_cnode);
auto opt_inputs = opt_cnode->inputs();
for (size_t i = 1; i < opt_inputs.size(); ++i) {
auto input_node = opt_inputs[i];
if (input_node->isa<CNode>() && AnfAlgo::GetCNodeName(input_node) == kLoadName) {
auto input_cnode = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
input_node = input_cnode->input(kElemFilterInputIndex);
}
if (input_node->isa<Parameter>()) {
auto param = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
param->set_fracz_group(groups);
}
}
}
void SetFracZGroupIdxForAllReduce(const AnfNodePtr &node, const int64_t index) {
// When AllReduce is fused, its outputs may mix FracZ tensors with groups=1
// and groups>1, so record the indices of the outputs whose groups>1
MS_EXCEPTION_IF_NULL(node);
auto allreduce = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(allreduce);
if (AnfAlgo::HasNodeAttr(kAttrFracZGroupIdx, allreduce)) {
auto fz_group_idx = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(allreduce, kAttrFracZGroupIdx);
fz_group_idx.push_back(index);
AnfAlgo::SetNodeAttr(kAttrFracZGroupIdx, MakeValue(fz_group_idx), allreduce);
} else {
AnfAlgo::SetNodeAttr(kAttrFracZGroupIdx, MakeValue(std::vector<int64_t>{index}), allreduce);
}
}
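// Illustrative example (not part of the original commit): if the fused
// AllReduce produces grouped FracZ tensors at output indices 1 and 4, the
// calls above accumulate kAttrFracZGroupIdx = {1, 4}, and only those
// outputs are later converted with the grouped FracZ shape.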
void SetAttrForOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, int64_t groups,
int64_t getitem_idx = 0) {
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::GetCNodeName(node) != kTupleGetItemName) {
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), node);
}
for (auto node_index : manager->node_users()[node]) {
auto output_node = node_index.first;
auto output_name = AnfAlgo::GetCNodeName(output_node);
if (kOptOperatorSet.find(output_name) != kOptOperatorSet.end()) {
SetAttrForOptParamInput(output_node, groups);
AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups), output_node);
} else if (output_name == kTransDataOpName) {
// TransData converts to another format, so no recursion is needed; just set
// the Groups attr on the TransData node
AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), output_node);
} else if (output_name == kAllReduceOpName) {
int64_t index = static_cast<int64_t>(node_index.second) - 1;
SetFracZGroupIdxForAllReduce(output_node, index);
SetAttrForOutputNode(manager, output_node, groups, index);
} else if (output_name == kTupleGetItemName) {
auto getitem = output_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem);
auto getitem_input2 = getitem->input(kInputNodeOutputIndexInTupleGetItem);
auto output_idx = GetValue<int64_t>(GetValueNode(getitem_input2));
if (output_idx == getitem_idx) {
SetAttrForOutputNode(manager, output_node, groups);
return;
}
} else {
SetAttrForOutputNode(manager, output_node, groups);
}
}
}
void SetAttrForConvOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
auto groups = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrGroups);
if (groups > 1) {
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto manager = kernel_graph->manager();
SetAttrForOutputNode(manager, cnode, groups);
}
}
} // namespace
bool SetFraczGroupAttr::Run(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func_graph is nullptr.";
return false;
}
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
for (auto node : node_list) {
if (node == nullptr || !node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
auto node_name = AnfAlgo::GetCNodeName(cnode);
if (node_name == kConv2DOpName || node_name == kConv2DBackpropInputOpName) {
SetAttrForConvInput(cnode);
} else if (node_name == kConv2DBackpropFilterOpName) {
SetAttrForConvOutput(func_graph, cnode);
}
}
return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SET_FRACZ_GROUP_ATTR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SET_FRACZ_GROUP_ATTR_H_
#include <vector>
#include <memory>
#include <utility>
#include "ir/anf.h"
#include "backend/optimizer/common/pass.h"
namespace mindspore {
namespace opt {
class SetFraczGroupAttr : public Pass {
public:
SetFraczGroupAttr() : Pass("set_fracz_group_attr") {}
~SetFraczGroupAttr() override = default;
bool Run(const FuncGraphPtr &graph) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SET_FRACZ_GROUP_ATTR_H_

View File

@ -590,7 +590,7 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_
auto format = AnfAlgo::GetOutputFormat(node, output_index);
if (shape.empty() && format != kOpFormat_DEFAULT) {
shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index));
shape = trans::TransShapeToDevice(shape, format);
shape = trans::TransShapeToDevice(shape, format, node, output_index);
}
// a scalar's output shape is an empty vector
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
@ -804,7 +804,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &
if (trans::IsNeedPadding(format, infer_shape.size())) {
infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx));
}
return trans::TransShapeToDevice(infer_shape, format);
return trans::TransShapeToDevice(infer_shape, format, node, output_idx);
}
std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
@ -817,7 +817,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
if (trans::IsNeedPadding(format, infer_shape.size())) {
infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx));
}
return trans::TransShapeToDevice(infer_shape, format);
auto input_node_index = GetPrevNodeOutput(node, input_idx);
return trans::TransShapeToDevice(infer_shape, format, input_node_index.first, input_node_index.second);
}
std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
@ -1414,7 +1415,7 @@ bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
auto get_input_index = index + 1;
if (index + 1 >= node->inputs().size()) {
if (get_input_index >= node->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
<< node->inputs().size() << " trace: " << trace::DumpSourceLines(node);
}
@ -1973,7 +1974,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputRealDeviceShapeIfExist(const An
auto max_shape = GetInputMaxShape(anf_node, index);
std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
auto format = GetInputFormat(anf_node, index);
trans::TransShapeToDevice(device_shape, format);
auto input_node_index = GetPrevNodeOutput(anf_node, index);
trans::TransShapeToDevice(device_shape, format, input_node_index.first, input_node_index.second);
}
return device_shape;
}
@ -1985,7 +1987,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A
auto max_shape = GetOutputMaxShape(anf_node, index);
std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
auto format = GetOutputFormat(anf_node, index);
trans::TransShapeToDevice(device_shape, format);
trans::TransShapeToDevice(device_shape, format, anf_node, index);
}
return device_shape;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -49,6 +49,29 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx,
}
}
// greatest common divisor
size_t Gcd(size_t a, size_t b) {
if (b == 0) {
return 0;
}
size_t c = b;
while (a % b != 0) {
c = a % b;
a = b;
b = c;
}
return c;
}
// least common multiple
size_t Lcm(size_t a, size_t b) {
if (b == 0) {
return 0;
}
size_t ret = (a * b) / (Gcd(a, b));
return ret;
}
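// E.g. Gcd(12, 18) == 6 and Lcm(4, 6) == 12; both return 0 when b == 0 so
// callers see a degenerate value instead of a division by zero.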
template <typename T>
T DivCeil(T n1, T n2) {
if (n2 != 0) {
@ -57,6 +80,10 @@ T DivCeil(T n1, T n2) {
return 0;
}
size_t GetShapeSize(const std::vector<size_t> &shape) {
return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
}
enum DataTypeTransMode {
FROM_FLOAT_TO_FLOAT16,
FROM_FLOAT_TO_INT32,
@ -378,8 +405,54 @@ std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape)
}
return shape_4d;
}
std::vector<size_t> FracZDeviceShapeWithGroups(const std::vector<size_t> &shape, const int64_t groups = 1) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
size_t group_size = LongToSize(groups);
size_t cin_ori = shape[kC];
size_t cout_ori = shape[kN] / group_size;
size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size);
size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize;
size_t c1_dim = cin_opt / kCubeSize;
size_t g_dim = DivCeil(group_size, e_mult);
size_t n1 = DivCeil(cout_ori * e_mult, kCubeSize);
std::vector<size_t> device_shape;
device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
device_shape.push_back(n1);
device_shape.push_back(kNiSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
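// Worked example (illustrative): a [N=32, C=2, H=3, W=3] weight with
// groups=16 gives cin_ori=2, cout_ori=2, e_mult=min(Lcm(8,8),16)=8,
// cin_opt=16, c1_dim=1, g_dim=2, n1=1, i.e. a device shape of
// [2*1*3*3, 1, 16, 16] = [18, 1, 16, 16].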
} // namespace
int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
if (node == nullptr) {
return 1;
}
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
if (AnfAlgo::GetCNodeName(cnode) == kAllReduceOpName) {
// if the index is not recorded in fracz_group_idx, return the default value 1
auto fz_group_idx = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrFracZGroupIdx);
int64_t out_index = SizeToLong(index);
auto fz_iter = std::find(std::begin(fz_group_idx), std::end(fz_group_idx), out_index);
if (fz_iter == std::end(fz_group_idx)) {
return 1;
}
}
return AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFracZGroup);
}
} else if (node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
return param->fracz_group();
}
return 1;
}
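// Illustrative: for an AllReduce carrying fracz_group=16 and
// fracz_group_idx={1, 4}, GetAttrGroups(allreduce, 4) yields 16 while
// GetAttrGroups(allreduce, 0) falls back to 1.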
bool IsNeedPadding(const std::string &format, const size_t shape_size) {
if (shape_size == 0) {
return false;
@ -547,7 +620,8 @@ std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape) {
return shape_5d;
}
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
const int64_t groups) {
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
{kOpFormat_NHWC, NhwcDeviceShape},
@ -565,6 +639,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
return shape;
}
if (groups > 1 && format == kOpFormat_FRAC_Z) {
return FracZDeviceShapeWithGroups(shape, groups);
}
auto temp_shape = shape;
std::vector<size_t> device_shape;
if (format == kOpFormat_FRAC_NZ) {
@ -610,6 +687,15 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
return iter->second(temp_shape);
}
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
const AnfNodePtr &node, const size_t index) {
int64_t groups = 1;
if (format == kOpFormat_FRAC_Z) {
groups = GetAttrGroups(node, index);
}
return TransShapeToDevice(shape, format, groups);
}
bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
if (args.host_shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
@ -649,7 +735,7 @@ bool TransDataType(const TypeIdArgs &args, void *result) {
return true;
}
bool TransFormat(const FormatArgs &args, void *result) {
bool TransFormat(const FormatArgs &args, void *result, int64_t groups) {
MS_LOG(DEBUG) << "Start trans format.";
if (abstract::TypeIdSize(args.src_data_type) < 1) {
MS_LOG(ERROR) << "Invalid datatype..";
@ -658,6 +744,9 @@ bool TransFormat(const FormatArgs &args, void *result) {
if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
return NchwTo4D(args, result);
}
if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
return NchwToFracZWithGroups(args, result, groups);
}
auto iter = kTransFormatMapOfHostToDevice.find(args.device_format);
if (iter == kTransFormatMapOfHostToDevice.end()) {
MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
@ -665,7 +754,15 @@ bool TransFormat(const FormatArgs &args, void *result) {
return iter->second(args, result);
}
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) {
int64_t groups = 1;
if (args.device_format == kOpFormat_FRAC_Z) {
groups = GetAttrGroups(node, index);
}
return TransFormat(args, result, groups);
}
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups) {
const std::map<std::string, FormatTransfer> format_trans_map{
{kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
@ -680,6 +777,9 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
return ToNchw(args, result);
}
if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
return FracZToNchwWithGroups(args, result, groups);
}
auto iter = format_trans_map.find(args.device_format);
if (iter == format_trans_map.end()) {
MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
@ -687,6 +787,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
return iter->second(args, result);
}
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) {
int64_t groups = 1;
if (args.device_format == kOpFormat_FRAC_Z) {
groups = GetAttrGroups(node, index);
}
return TransFormatFromDeviceToHost(args, result, groups);
}
bool NchwTo4D(const FormatArgs &args, void *result) {
// trans nchw to 4d
MS_LOG(DEBUG) << "Trans format from nchw to 4d.";
@ -1477,5 +1585,82 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
}
return true;
}
bool NchwFracZTransWithGroups(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
MS_EXCEPTION_IF_NULL(result);
if (args.host_shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
auto size = abstract::TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype";
return false;
}
auto n_dim = args.host_shape[kN];
auto c_dim = args.host_shape[kC];
auto h_dim = args.host_shape[kH];
auto w_dim = args.host_shape[kW];
size_t d_dim = 1;
size_t group_size = LongToSize(groups);
auto cin_ori = c_dim;
auto cout_ori = n_dim / group_size;
if (cin_ori == 0 || cout_ori == 0) {
MS_LOG(ERROR) << "cin_ori, cout_ori must not equal to 0";
return false;
}
size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size);
size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize;
size_t cout_opt = DivCeil(e_mult * cout_ori, kCubeSize) * kCubeSize;
size_t c1_dim = cin_opt / kCubeSize;
size_t dst_size = to_device ? GetShapeSize(args.device_shape) * size : GetShapeSize(args.host_shape) * size;
if (dst_size == 0) {
return true;
}
auto ret = memset_s(result, dst_size, 0, dst_size);
if (ret != EOK) {
MS_LOG(ERROR) << "memset failed";
return false;
}
for (size_t g = 0; g < group_size; ++g) {
for (size_t d = 0; d < d_dim; ++d) {
for (size_t c = 0; c < c_dim; ++c) {
for (size_t h = 0; h < h_dim; ++h) {
for (size_t w = 0; w < w_dim; ++w) {
for (size_t n = 0; n < cout_ori; ++n) {
size_t e_val = g % e_mult;
size_t dst_ci = e_val * cin_ori + c;
size_t dst_co = e_val * cout_ori + n;
size_t src_co = g * cout_ori + n;
size_t temporary = dst_ci % kCubeSize;
size_t dev_idx = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * kCubeSize +
d * c1_dim * h_dim * w_dim * cout_opt * kCubeSize +
(dst_ci / kCubeSize) * h_dim * w_dim * cout_opt * kCubeSize +
h * w_dim * cout_opt * kCubeSize + w * cout_opt * kCubeSize + dst_co * kCubeSize +
temporary;
size_t hst_idx =
src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim + h * w_dim + w;
if (to_device) {
SetData(size, false, hst_idx, dev_idx, args, result);
} else {
SetData(size, false, dev_idx, hst_idx, args, result);
}
}
}
}
}
}
}
return true;
}
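// Illustrative mapping for the loop above: with groups=16 and e_mult=8,
// source group g=9 is placed in device block g/e_mult=1 at sub-slot
// e_val=g%e_mult=1, so its elements land at dst_ci=e_val*cin_ori+c and
// dst_co=e_val*cout_ori+n within that block.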
bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups) {
return NchwFracZTransWithGroups(args, result, true, groups);
}
bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups) {
return NchwFracZTransWithGroups(args, result, false, groups);
}
} // namespace trans
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -64,10 +64,16 @@ void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis>
void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
bool IsNeedPadding(const std::string &format, const size_t shape_size);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);
int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
const int64_t groups = 1);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
const AnfNodePtr &node, const size_t index);
bool TransDataType(const TypeIdArgs &args, void *result);
bool TransFormat(const FormatArgs &args, void *result);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1);
bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups = 1);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index);
// host to device
bool NchwTo4D(const FormatArgs &args, void *result);
@ -79,6 +85,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result);
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result);
bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups);
// device to host
bool ToNchw(const FormatArgs &args, void *result);
@ -89,6 +96,7 @@ bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
bool FracZ3DToNcdhw(const FormatArgs &args, void *result);
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups);
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -480,8 +480,9 @@ std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::js
std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const {
std::vector<size_t> device_shape;
auto node_index = GetNodeIndex();
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
device_shape = trans::TransShapeToDevice(*host_shape, format_);
device_shape = trans::TransShapeToDevice(*host_shape, format_, node_index.first, node_index.second);
} else {
if (host_shape_.empty()) {
*host_shape = trans::PaddingShape(*host_shape, format_);
@ -489,7 +490,7 @@ std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *hos
host_shape->clear();
(void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), LongToSize);
}
device_shape = trans::TransShapeToDevice(*host_shape, format_);
device_shape = trans::TransShapeToDevice(*host_shape, format_, node_index.first, node_index.second);
}
return device_shape;
}
@ -518,11 +519,12 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
}
auto host_tmp = std::vector<uint8_t>(size_);
SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
auto node_index = GetNodeIndex();
if (type_id_ != type) {
const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_,
host_shape, device_shape, type_id_};
auto host = std::vector<uint8_t>(size_);
sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data());
sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data(), node_index.first, node_index.second);
if (!sync_ok) {
MS_LOG(ERROR) << "Trans format failed.";
return false;
@ -537,7 +539,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
} else {
const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_,
host_shape, device_shape, type_id_};
sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr);
sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr, node_index.first, node_index.second);
if (!sync_ok) {
MS_LOG(ERROR) << "Trans format failed.";
return false;
@ -604,12 +606,13 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
if (host_shape.empty()) {
host_shape.emplace_back(1);
}
auto node_index = GetNodeIndex();
std::vector<size_t> device_shape;
if (format_ == kOpFormat_FRAC_NZ) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
device_shape = trans::TransShapeToDevice(host_shape, format_, node_index.first, node_index.second);
} else {
host_shape = trans::PaddingShape(host_shape, format_);
device_shape = trans::TransShapeToDevice(host_shape, format_);
device_shape = trans::TransShapeToDevice(host_shape, format_, node_index.first, node_index.second);
}
if (type_id_ != type) {
auto shape_size = abstract::ShapeSize(host_shape);
@ -623,7 +626,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_,
host_shape, device_shape, type_id_};
auto dst_tmp = std::vector<uint8_t>(size_);
sync_ok = trans::TransFormat(format_args, dst_tmp.data());
sync_ok = trans::TransFormat(format_args, dst_tmp.data(), node_index.first, node_index.second);
if (!sync_ok) {
MS_LOG(ERROR) << "Trans format failed.";
return false;
@ -632,7 +635,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
} else {
const trans::FormatArgs format_args{host_ptr, size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_};
auto host_tmp = std::vector<uint8_t>(size_);
sync_ok = trans::TransFormat(format_args, host_tmp.data());
sync_ok = trans::TransFormat(format_args, host_tmp.data(), node_index.first, node_index.second);
if (!sync_ok) {
MS_LOG(ERROR) << "Trans format failed.";
return false;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,6 +38,9 @@ class AscendDeviceAddress : public DeviceAddress {
explicit AscendDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id)
: DeviceAddress(ptr, size, format, type_id) {}
explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
const AnfNodePtr &node, size_t out_index)
: DeviceAddress(ptr, size, format, type_id, node, out_index) {}
~AscendDeviceAddress() override;
bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
bool SyncHostToDevice(size_t size, const void *host_ptr) const override;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -333,6 +333,11 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
}
DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id, const AnfNodePtr &node, size_t out_index) {
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
}
bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
if (!is_task_sink) {
GenKernelEvents(graph);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -66,6 +66,8 @@ class AscendKernelRuntime : public KernelRuntime {
protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
const AnfNodePtr &node, size_t out_index) override;
bool KernelMemNotReuse(const AnfNodePtr &node) override;
void KernelLaunchProfiling(const std::string &kernel_name) override;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,6 +31,10 @@ class CPUDeviceAddress : public DeviceAddress {
CPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
: DeviceAddress(ptr, size, format, type_id) {}
CPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
size_t out_index)
: DeviceAddress(ptr, size, format, type_id, node, out_index) {}
~CPUDeviceAddress() override = default;
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;

View File

@ -175,6 +175,11 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id);
}
DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id, const AnfNodePtr &node, size_t out_index) {
return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
}
tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -53,6 +53,8 @@ class CPUKernelRuntime : public KernelRuntime {
bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override { return true; };
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
const AnfNodePtr &node, size_t out_index) override;
private:
tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,6 +21,7 @@
#include <vector>
#include <memory>
#include <map>
#include <utility>
#include "ir/dtype.h"
#include "ir/device_sync.h"
#include "utils/shape_utils.h"
@ -52,6 +53,7 @@ class GPUDeviceContext;
namespace mindspore {
namespace device {
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice };
enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU };
static const std::map<DeviceAddressType, std::string> kDeviceTypeToName = {{DeviceAddressType::kUnknown, "Unknown"},
@ -64,6 +66,9 @@ class DeviceAddress : public mindspore::DeviceSync {
explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {}
explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
: ptr_(ptr), size_(size), format_(format), type_id_(type_id) {}
explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
size_t out_index)
: ptr_(ptr), size_(size), format_(format), type_id_(type_id), node_index_({node, out_index}) {}
virtual ~DeviceAddress() { ptr_ = nullptr; }
const void *GetPtr() const { return ptr_; }
size_t GetSize() const { return size_; }
@ -75,6 +80,7 @@ class DeviceAddress : public mindspore::DeviceSync {
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
void *GetMutablePtr() const override { return ptr_; }
virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }
// The related interface of reference count operation.
void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; }
@ -100,6 +106,10 @@ class DeviceAddress : public mindspore::DeviceSync {
const void *ptr() const { return ptr_; }
size_t size() const { return size_; }
void set_ptr(void *ptr) { ptr_ = ptr; }
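// Returns {nullptr, out_index} once the weakly referenced node has been
// destroyed; the trans helpers then fall back to groups=1.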
KernelWithIndex GetNodeIndex() const {
return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second}
: KernelWithIndex{node_index_.first.lock(), node_index_.second};
}
void *ptr_{nullptr};
size_t size_{0};
size_t original_ref_count_{1};
@ -110,6 +120,8 @@ class DeviceAddress : public mindspore::DeviceSync {
bool from_mem_pool_{false};
uint8_t *communication_ptr_{nullptr};
ShapeVector host_shape_{};
// {node, out_index}
std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
friend class KernelRuntime;
friend class MemoryManager;
friend class mindspore::device::ascend::tasksink::TaskGenerator;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,6 +34,9 @@ class GPUDeviceAddress : public DeviceAddress {
GPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
GPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
: DeviceAddress(ptr, size, format, type_id) {}
GPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const AnfNodePtr &node,
size_t out_index)
: DeviceAddress(ptr, size, format, type_id, node, out_index) {}
~GPUDeviceAddress() override;
bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
@ -51,6 +54,7 @@ class GPUDeviceAddress : public DeviceAddress {
bool LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt,
const ShapeVector &host_shape, TypeId host_type, size_t slot, bool keep_prev) const override;
#endif
private:
DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
};

View File

@ -211,6 +211,11 @@ DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id);
}
DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id, const AnfNodePtr &node, size_t out_index) {
return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id, node, out_index);
}
bool GPUKernelRuntime::InitDevice() {
if (GPUDeviceManager::GetInstance().device_count() <= 0) {
MS_LOG(ERROR) << "No GPU device found.";

View File

@ -58,6 +58,8 @@ class GPUKernelRuntime : public KernelRuntime {
protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
const AnfNodePtr &node, size_t out_index) override;
bool SyncStream() override;
bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -165,7 +165,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
}
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
auto device_address =
CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index);
MS_EXCEPTION_IF_NULL(device_address);
MS_EXCEPTION_IF_NULL(mem_manager_);
auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
@ -198,7 +198,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
}
std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, kernel, i);
device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
MS_EXCEPTION_IF_NULL(device_address);
auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
@ -334,14 +334,15 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
}
const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
MS_EXCEPTION_IF_NULL(address.addr);
device_address =
CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
device_address = CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index),
output_type_id, item, index);
AnfAlgo::SetOutputAddr(device_address, index, item.get());
continue;
}
#endif
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
device_address =
CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, item, index);
MS_LOG(INFO) << "Assign Static Memory for Input node, size:" << tensor_size
<< " node:" << item->fullname_with_scope() << " index: " << index;
if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) {
@ -512,7 +513,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
for (size_t j = 0; j < align_size_list.size(); ++j) {
std::string output_format = AnfAlgo::GetOutputFormat(node, j);
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type);
auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, node, j);
MS_EXCEPTION_IF_NULL(address);
if (output_ptr == nullptr) {
output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
@ -548,7 +549,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node,
}
std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, anf_node, index);
AnfAlgo::SetOutputAddr(address, index, anf_node.get());
return address;
}
@ -641,7 +642,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
}
std::string output_format = AnfAlgo::GetOutputFormat(node, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, node, i);
MS_EXCEPTION_IF_NULL(device_address);
uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
MS_EXCEPTION_IF_NULL(ptr);
@ -681,7 +682,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
}
auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
DeviceAddressPtr address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
DeviceAddressPtr address =
CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, value_node, output_idx);
MS_EXCEPTION_IF_NULL(address);
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
!mem_manager_->MallocMemFromMemPool(address, node_size)) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -114,6 +114,8 @@ class KernelRuntime {
protected:
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) = 0;
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id, const AnfNodePtr &node, size_t out_index) = 0;
virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
virtual bool KernelMemNotReuse(const AnfNodePtr &node);

View File

@ -358,6 +358,9 @@ constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active
constexpr auto kAttrFpBpEnd = "fpbp_end";
constexpr auto kAttrFusion = "fusion";
constexpr auto kAttrGroup = "group";
constexpr auto kAttrGroups = "groups";
constexpr auto kAttrFracZGroup = "fracz_group";
constexpr auto kAttrFracZGroupIdx = "fracz_group_idx";
constexpr auto kAttrOp = "op";
constexpr auto kAttrDestRank = "dest_rank";
constexpr auto kAttrSrcRank = "src_rank";
@ -450,6 +453,7 @@ const size_t kShape2dDims = 2;
const size_t kShape5dDims = 5;
const size_t kShape1dDims = 1;
const size_t kCubeSize = 16;
const size_t kNiSize = 16;
const size_t kMemAlignSize = 512;
const size_t kBNChannelMultipleFactor = 4;
const int kParameterDataTensorMask = 0;
@ -565,9 +569,12 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
kApplyProximalAdagradOpName,
kApplyProximalGradientDescentOpName,
kApplyRMSPropOpName,
kAdamApplyOneWithDecayOpName,
kAdamApplyOneWithDecayAssignOpName,
kFusedAdamWeightDecayName,
kFusedAdamName,
kFusedSparseAdamName,
kFusedMulApplyMomentumOpName,
kFusedWeightScaleApplyMomentum,
kFusedScaleApplyMomentum,
kApplyCenteredRMSPropOpName,

View File

@ -418,6 +418,9 @@ class Parameter : public ANode {
void set_used_by_dynamic_kernel(bool used) { is_used_by_dynamic_kernel_ = used; }
bool is_used_by_dynamic_kernel() { return is_used_by_dynamic_kernel_; }
void set_fracz_group(int64_t fracz_group) { fracz_group_ = fracz_group; }
int64_t fracz_group() const { return fracz_group_; }
private:
std::string name_;
bool has_default_;
@ -426,6 +429,8 @@ class Parameter : public ANode {
ValuePtr default_param_;
// The count of graphs using the parameter.
int used_graph_count_;
// groups attr in FracZ format
int64_t fracz_group_ = 1;
};
using ParameterPtr = std::shared_ptr<Parameter>;