forked from mindspore-Ecosystem/mindspore
change the padding strategy & refactor insert transdata
This commit is contained in:
parent 60958d6b25
commit 5d225f934f
@@ -20,6 +20,8 @@
 #include <utility>
 #include "./securec.h"
 #include "common/utils.h"
+#include "session/anf_runtime_algorithm.h"
+#include "kernel/kernel.h"
 #include "device/convert_tensor_utils.h"
 #include "utils/convert_utils.h"
 #include "utils/log_adapter.h"
@@ -27,6 +29,33 @@
 
 namespace mindspore {
 namespace trans {
+namespace {
+std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
+  std::vector<size_t> shape_4d(4, 1);
+  switch (shape.size()) {
+    case 0:
+      return shape_4d;
+    case 1:
+      shape_4d[1] = shape[0];
+      break;
+    case 2:
+      shape_4d[1] = shape[0];
+      shape_4d[2] = shape[1];
+      break;
+    case 3:
+      shape_4d[1] = shape[0];
+      shape_4d[2] = shape[1];
+      shape_4d[3] = shape[2];
+      break;
+    case 4:
+      std::copy(shape.begin(), shape.end(), shape_4d.begin());
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
+  }
+  return shape_4d;
+}
+}  // namespace
 const size_t kNchwDims = 4;
 const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1},
                                            {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
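
Note: the default strategy above maps an under-4D host shape into the trailing C/H/W slots of an NCHW-style 4D shape. A minimal standalone sketch of the expected mappings (a reimplementation with standard types for illustration, not the MindSpore source):

#include <cassert>
#include <cstddef>
#include <vector>

// Mirrors PaddingShapeTo4dByDefault: missing dimensions become 1, and a
// 1-D/2-D/3-D shape fills the C, C/H, or C/H/W slots respectively.
std::vector<size_t> PadTo4dByDefault(const std::vector<size_t> &shape) {
  if (shape.size() >= 4) {
    return shape;  // 4-D passes through; >4-D raises an exception in the real code
  }
  std::vector<size_t> shape_4d(4, 1);
  for (size_t i = 0; i < shape.size(); ++i) {
    shape_4d[i + 1] = shape[i];
  }
  return shape_4d;
}

int main() {
  assert((PadTo4dByDefault({}) == std::vector<size_t>{1, 1, 1, 1}));          // scalar
  assert((PadTo4dByDefault({16}) == std::vector<size_t>{1, 16, 1, 1}));       // {C}
  assert((PadTo4dByDefault({16, 32}) == std::vector<size_t>{1, 16, 32, 1}));  // {C, H}
  return 0;
}
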
@@ -154,38 +183,64 @@ size_t TypeIdSize(const TypeId data_type) {
   return unsupported_type_error;
 }
 
-std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape) {
+bool IsNeedPadding(const std::string &format, const size_t shape_size) {
+  if (shape_size == 0) {
+    return false;
+  }
+  if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
+    return false;
+  } else if (shape_size < 4) {
+    return true;
+  }
+  return false;
+}
+
+std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
+  std::vector<int> shape;
+  std::vector<size_t> host_shape;
+  if (node->isa<ValueNode>()) {
+    auto value_node = node->cast<ValueNodePtr>();
+    auto node_value = value_node->value();
+    auto tensor = node_value->cast<tensor::TensorPtr>();
+    if (tensor == nullptr) {
+      MS_LOG(EXCEPTION) << "The value of node [" << node->DebugString() << "] cannot be converted to a tensor";
+    }
+    shape = tensor->shape();
+    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
+    if (host_shape.empty()) {
+      host_shape.push_back(1);
+    }
+  } else {
+    host_shape = AnfAlgo::GetOutputInferShape(node, index);
+  }
+  if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
+    host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
+  }
+  std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
+  return shape;
+}
+
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
+  if (padding_axis.empty() || shape.size() != padding_axis.size()) {
+    return PaddingShapeTo4dByDefault(shape);
+  }
   std::vector<size_t> shape_4d(4, 1);
-  switch (shape.size()) {
-    case 0:
-      break;
-    case 1:
-      shape_4d[1] = shape[0];
-      break;
-    case 2:
-      shape_4d[0] = shape[0];
-      shape_4d[1] = shape[1];
-      break;
-    case 3:
-      MS_LOG(EXCEPTION) << "Unexpected shape size = 3,it should has a default format";
-    case 4:
-      for (size_t i = 0; i < 4; ++i) {
-        shape_4d[i] = shape[i];
-      }
-      break;
-    default:
-      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
-  }
+  for (size_t index = 0; index < padding_axis.size(); index++) {
+    shape_4d[padding_axis[index]] = shape[index];
+  }
   return shape_4d;
 }
 
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
   if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
     return shape;
   }
+  auto temp_shape = shape;
   std::vector<size_t> device_shape;
   if (format == kOpFormat_FRAC_NZ) {
     if (shape.size() < 2) {
       MS_EXCEPTION(NotSupportError) << "Format " << format << " is not support shape " << shape.size();
     }
     if (shape.size() > 2) {
       MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
     } else {
       (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
     }
     auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
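
When a reshape type is recorded for a node, PaddingShapeTo4d places each host dimension at the axis named by padding_axis instead of using the default layout. A standalone sketch (the N/C/H/W enumerator values are an assumption about kernel::Axis, not quoted from kernel/kernel.h):

#include <cassert>
#include <cstddef>
#include <vector>

enum Axis : size_t { N = 0, C, H, W };  // assumed layout of kernel::Axis

std::vector<size_t> PadTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis) {
  std::vector<size_t> shape_4d(4, 1);
  for (size_t i = 0; i < padding_axis.size(); ++i) {
    shape_4d[padding_axis[i]] = shape[i];  // place dimension i at its named axis
  }
  return shape_4d;
}

int main() {
  // A 2-D weight recorded as {C, N} pads to {N, C, 1, 1} rather than the default {1, C, H, 1}.
  assert((PadTo4d({64, 32}, {C, N}) == std::vector<size_t>{32, 64, 1, 1}));
  return 0;
}
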
@@ -197,35 +252,36 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
     return device_shape;
   }
   if (shape.size() != 4) {
-    MS_LOG(EXCEPTION) << "shape_4d size should be 4";
+    MS_LOG(WARNING) << "Got a device shape request for a shape whose size is less than 4; padding it to 4D by default first";
+    temp_shape = PaddingShapeTo4dByDefault(shape);
   }
   if (format == kOpFormat_NC1HWC0) {
-    size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
+    size_t C1 = (temp_shape[1] + kCubeSize - 1) / kCubeSize;
     size_t C0 = kCubeSize;
-    device_shape.push_back(shape[0]);
+    device_shape.push_back(temp_shape[0]);
     device_shape.push_back(C1);
-    device_shape.push_back(shape[2]);
-    device_shape.push_back(shape[3]);
+    device_shape.push_back(temp_shape[2]);
+    device_shape.push_back(temp_shape[3]);
     device_shape.push_back(C0);
     return device_shape;
   } else if (format == kOpFormat_FRAC_Z) {
-    size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
-    size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
-    device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
+    size_t cout16 = ((temp_shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
+    size_t cin16 = ((temp_shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
+    device_shape.push_back(temp_shape[2] * temp_shape[3] * cin16 / kCubeSize);
     device_shape.push_back(cout16 / kCubeSize);
     device_shape.push_back(kCubeSize);
     device_shape.push_back(kCubeSize);
     return device_shape;
   } else if (format == kOpFormat_NHWC) {
-    device_shape.push_back(shape[0]);
-    device_shape.push_back(shape[2]);
-    device_shape.push_back(shape[3]);
-    device_shape.push_back(shape[1]);
+    device_shape.push_back(temp_shape[0]);
+    device_shape.push_back(temp_shape[2]);
+    device_shape.push_back(temp_shape[3]);
+    device_shape.push_back(temp_shape[1]);
     return device_shape;
-  } else if (format == kOpFormat_NCHW) {
-    return shape;
   } else if (format == kOpFormat_HWCN) {
-    return {shape[2], shape[3], shape[1], shape[0]};
+    return {temp_shape[2], temp_shape[3], temp_shape[1], temp_shape[0]};
+  } else if (format == kOpFormat_NCHW) {
+    return temp_shape;
   }
   MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
 }
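
A worked example of the NC1HWC0 branch above (standalone, assuming kCubeSize == 16, the cube unit used throughout this file):

#include <cassert>
#include <cstddef>
#include <vector>

constexpr size_t kCube = 16;  // assumed value of kCubeSize

// NCHW -> NC1HWC0: the channel dim C is split into C1 blocks of C0 = 16 lanes.
std::vector<size_t> ToNc1hwc0(const std::vector<size_t> &nchw) {
  size_t c1 = (nchw[1] + kCube - 1) / kCube;  // ceil(C / 16)
  return {nchw[0], c1, nchw[2], nchw[3], kCube};
}

int main() {
  // {N=2, C=17, H=5, W=5} -> {2, ceil(17/16)=2, 5, 5, 16}
  assert((ToNc1hwc0({2, 17, 5, 5}) == std::vector<size_t>{2, 2, 5, 5, 16}));
  return 0;
}
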
@@ -24,6 +24,7 @@
 #include <utility>
 #include <vector>
 #include "ir/dtype.h"
+#include "kernel/kernel.h"
 #include "ir/dtype/type.h"
 
 namespace mindspore {
@@ -49,7 +50,10 @@ size_t TypeIdSize(const TypeId data_type);
 size_t ShapeSize(const std::vector<size_t> &shape);
 size_t CubeSizeByType(const TypeId data_type);
 
-std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape);
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape,
+                                     const std::vector<kernel::Axis> &padding_axis = {});
+std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
+bool IsNeedPadding(const std::string &format, const size_t shape_size);
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);
 bool TransDataType(const TypeIdArgs &args, void *result);
 bool TransFormat(const FormatArgs &args, void *result);
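
Taken together, the header replaces the single old TransShapeTo4d entry point with a check-then-pad pipeline. A sketch of the intended call pattern (it mirrors GetOutputDeviceShape later in this diff; node and idx are hypothetical caller variables):

// Derive the device-side shape of one node output.
std::vector<size_t> DeviceShapeOf(const AnfNodePtr &node, size_t idx) {
  std::vector<size_t> host_shape = AnfAlgo::GetOutputInferShape(node, idx);
  std::string format = AnfAlgo::GetOutputFormat(node, idx);
  if (trans::IsNeedPadding(format, host_shape.size())) {
    // pad to 4D, honoring the recorded padding axes when present
    host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, idx));
  }
  return trans::TransShapeToDevice(host_shape, format);
}
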
@@ -141,7 +141,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
   if (format_ == kOpFormat_FRAC_NZ) {
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   } else {
-    host_shape = trans::TransShapeTo4d(host_shape);
+    host_shape = trans::PaddingShapeTo4d(host_shape);
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
@@ -224,7 +224,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
   if (format_ == kOpFormat_FRAC_NZ) {
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   } else {
-    host_shape = trans::TransShapeTo4d(host_shape);
+    host_shape = trans::PaddingShapeTo4d(host_shape);
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
@@ -27,6 +27,7 @@
 #include "utils/context/ms_context.h"
 #include "device/ascend/profiling/profiling_manager.h"
 #include "hccl/hcom.h"
+#include "common/trans.h"
 #include "runtime/context.h"
 #include "device/ascend/ascend_stream_assign.h"
 #include "device/ascend/ascend_memory_pool.h"
@@ -150,7 +151,7 @@ void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path,
     auto output_size = AnfAlgo::GetOutputTensorNum(node);
     for (size_t j = 0; j < output_size; ++j) {
       auto addr = AnfAlgo::GetOutputAddr(node, j);
-      auto shape = AnfAlgo::GetOutputInferShape(node, j);
+      auto shape = trans::GetRuntimePaddingShape(node, j);
       auto type = AnfAlgo::GetOutputInferDataType(node, j);
       auto format = kOpFormat_DEFAULT;
       string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j);
@@ -181,7 +182,7 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p
       continue;
     }
     auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX);
-    auto shape = AnfAlgo::GetOutputInferShape(item, PRAMATER_OUTPUT_INDEX);
+    auto shape = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX);
     auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX);
     auto format = kOpFormat_DEFAULT;
     string filepath = dump_path + '/' + parameter_name + '_' + "output_0";
@@ -184,7 +184,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
     }
     if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
       if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) &&
-          kSpecialFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kSpecialFormatSet.end()) {
+          kNeedTransFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kNeedTransFormatSet.end()) {
         (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++;
       }
       (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
@@ -210,19 +210,22 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
       (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
     }
   }
-}  // namespace
+}
 
 void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
     auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
     MS_EXCEPTION_IF_NULL(input_kernel_node);
+    if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index)) {
+      continue;
+    }
     auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
     MS_EXCEPTION_IF_NULL(input_with_index.first);
     auto real_input_node = input_with_index.first;
     if (real_input_node->isa<CNode>()) {
       continue;
     }
     if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
       continue;
     }
     std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
       std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
     // we set special device info of an input tensor.
@@ -25,6 +25,7 @@
 
 #include "session/anf_runtime_algorithm.h"
 #include "utils/context/ms_context.h"
+#include "common/trans.h"
 #include "utils/config_manager.h"
 #include "common/utils.h"
 #include "kernel/kernel_build_info.h"
@@ -391,7 +392,8 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &c
     auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
     MS_EXCEPTION_IF_NULL(device_address);
     tensor->set_device_address(device_address);
-    if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
+    if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                          LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                           tensor->data_c(false))) {
       MS_LOG(INFO) << "SyncHostToDevice failed.";
       return false;
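
Call sites that used the raw tensor shape for host-to-device sync now pass the runtime padding shape, so the host view matches the padded layout the device buffer was allocated with. One corner worth noting from GetRuntimePaddingShape: a scalar value node yields {1} rather than an empty shape (standalone sketch of that fallback):

#include <cassert>
#include <vector>

std::vector<int> RuntimeShapeOfScalar() {
  std::vector<int> host_shape;  // tensor->shape() of a scalar is empty
  if (host_shape.empty()) {
    host_shape.push_back(1);    // same fallback as GetRuntimePaddingShape
  }
  return host_shape;
}

int main() {
  assert((RuntimeShapeOfScalar() == std::vector<int>{1}));
  return 0;
}
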
@@ -31,6 +31,7 @@ class KernelInfo {
  public:
   KernelInfo() {
     kernel_mod_ = nullptr;
+    is_feature_map_ = false;
     select_kernel_build_info_ = nullptr;
     output_address_list_ = {};
     workspace_address_list_ = {};
@@ -45,6 +46,7 @@ class KernelInfo {
   void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
     select_kernel_build_info_ = select_kernel_build_info;
   }
+  void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; }
   const DeviceAddress *GetOutputAddr(size_t index) const;
   DeviceAddressPtr GetMutableOutputAddr(size_t index) const;
   bool OutputAddrExist(size_t index) const;
@@ -63,8 +65,10 @@ class KernelInfo {
   void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
   uint32_t graph_id() const { return graph_id_; }
   bool operator==(const KernelInfo &other) const;
+  bool is_feature_map() const { return is_feature_map_; }
 
  private:
+  bool is_feature_map_;
   kernel::KernelBuildInfoPtr select_kernel_build_info_;
   std::vector<std::shared_ptr<DeviceAddress>> output_address_list_;
   std::vector<std::shared_ptr<DeviceAddress>> workspace_address_list_;
@@ -105,7 +105,7 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod
   std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
   auto format = AnfAlgo::GetOutputFormat(node, output_index);
   if (shape.empty() && format != kOpFormat_DEFAULT) {
-    shape = trans::TransShapeTo4d(shape);
+    shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
     shape = trans::TransShapeToDevice(shape, format);
   }
   // scalar's output shape is an empty vector
@@ -401,8 +401,9 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
   auto address = CreateDeviceAddress(ptr, node_size, AnfAlgo::GetOutputFormat(value_node, output_idx), output_type_id);
   MS_EXCEPTION_IF_NULL(address);
   AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
-  if (!address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(), tensor->data_c(false))) {
-    MS_EXCEPTION(NotExistsError) << "kValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
+  if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
+                                 tensor->data_c(false))) {
+    MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
                                  << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is "
                                  << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
   }
@@ -421,19 +422,6 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
     MS_EXCEPTION_IF_NULL(node_value);
     if (node_value->isa<Tensor>()) {
       AssignValueNodeTensor(value_node, node_value, 0);
-    } else if (node_value->isa<ValueTuple>()) {
-      auto value_tuple = node_value->cast<ValueTuplePtr>();
-      if (value_tuple == nullptr) {
-        MS_LOG(WARNING) << "value_tuple is null";
-        continue;
-      }
-      size_t i = 0;
-      auto value_list = value_tuple->value();
-      for (auto value_ptr : value_list) {
-        if (value_ptr->isa<Tensor>()) {
-          AssignValueNodeTensor(value_node, value_ptr, i++);
-        }
-      }
     } else if (node_value->isa<StringImm>()) {
       auto value = GetValue<std::string>(node_value);
       size_t tensor_size = value.size();
@@ -59,30 +59,20 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
 
 size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
 
-bool KernelBuildInfo::GetInputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const {
-  MS_EXCEPTION_IF_NULL(reshape_type);
-  reshape_type->clear();
+std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
   if (input_index >= input_reshape_type_.size()) {
-    MS_LOG(WARNING) << "The index [" << input_index << "] is exceed the number of input node size "
-                    << input_reshape_type_.size();
-    return false;
+    MS_LOG(EXCEPTION) << "The index [" << input_index << "] exceeds the number of inputs: "
+                      << input_reshape_type_.size();
   }
-  (void)std::copy(input_reshape_type_[input_index].begin(), input_reshape_type_[input_index].end(),
-                  std::inserter(*reshape_type, (*reshape_type).begin()));
-  return true;
+  return input_reshape_type_[input_index];
 }
 
-bool KernelBuildInfo::GetOutputReshapeType(size_t output_index, std::vector<Axis> *reshape_type) const {
-  MS_EXCEPTION_IF_NULL(reshape_type);
-  reshape_type->clear();
+std::vector<Axis> KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
   if (output_index >= output_reshape_type_.size()) {
-    MS_LOG(WARNING) << "The index [" << output_index << "] is exceed the number of output node dixr"
-                    << output_reshape_type_.size();
-    return false;
+    MS_LOG(EXCEPTION) << "The index [" << output_index << "] exceeds the number of outputs: "
+                      << output_reshape_type_.size();
   }
-  (void)std::copy(output_reshape_type_[output_index].begin(), output_reshape_type_[output_index].end(),
-                  std::inserter(*reshape_type, (*reshape_type).begin()));
-  return true;
+  return output_reshape_type_[output_index];
 }
 
 std::string KernelBuildInfo::ToString() const {
@@ -115,6 +105,10 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
   return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
 }
 
+bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
+
+bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
+
 void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
   MS_EXCEPTION_IF_NULL(kernel_build_info_);
   kernel_build_info_->kernel_type_ = kernel_type;
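
The reshape-type accessors drop the bool-plus-out-parameter shape for a plain return value, so call sites shrink from several lines to one. A hedged before/after sketch (build_info and input_idx are hypothetical caller variables):

// Before: out-parameter, failure signalled through the return value.
std::vector<kernel::Axis> reshape_type;
if (!build_info->GetInputReshapeType(input_idx, &reshape_type)) {
  MS_LOG(EXCEPTION) << "Failed to get reshape type";
}

// After: the vector is returned directly; an out-of-range index throws inside the
// getter, and IsInputDefaultPadding() tells callers when no reshape type was recorded.
std::vector<kernel::Axis> reshape_type2 = build_info->GetInputReshapeType(input_idx);
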
@@ -54,9 +54,13 @@ class KernelBuildInfo {
 
   TypeId GetOutputDeviceType(size_t output_index) const;
 
-  bool GetInputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const;
+  std::vector<Axis> GetInputReshapeType(size_t input_index) const;
 
-  bool GetOutputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const;
+  bool IsInputDefaultPadding() const;
+
+  bool IsOutputDefaultPadding() const;
+
+  std::vector<Axis> GetOutputReshapeType(size_t input_index) const;
 
   std::vector<std::string> GetAllInputFormats() const;
 
@@ -18,20 +18,21 @@
 #include <set>
 #include "common/trans.h"
 #include "common/utils.h"
+#include "utils/utils.h"
 #include "device/kernel_info.h"
 #include "kernel/oplib/oplib.h"
 #include "operator/ops.h"
 #include "session/anf_runtime_algorithm.h"
 #include "session/kernel_graph.h"
 #include "utils/context/ms_context.h"
-#include "utils/utils.h"
 
 namespace mindspore {
 namespace opt {
 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
 namespace {
-kernel::KernelBuildInfoPtr CreateKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-                                                 const AnfNodePtr &node, const kernel::KernelBuildInfo ori_build_info) {
+kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
+                                                  const AnfNodePtr &node,
+                                                  const kernel::KernelBuildInfo ori_build_info) {
   KernelBuildInfoBuilder builder;
   builder.SetInputsFormat({input_format});
   builder.SetOutputsFormat({output_format});
@@ -54,9 +55,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
   CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
   MS_EXCEPTION_IF_NULL(trans_node);
   if (need_padding) {
-    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
-                                        {trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0))},
-                                        trans_node.get());
+    // if padding is needed, set the transdata node's shape to the padded shape
+    AnfAlgo::SetOutputInferTypeAndShape(
+      {AnfAlgo::GetOutputInferDataType(input, 0)},
+      {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))},
+      trans_node.get());
   } else {
     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                         {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
@@ -92,9 +95,11 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
 AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                 const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(node);
-  bool padding_flag = false;
   auto input_node = AnfAlgo::GetInputNode(node, index);
-  if (input_node->isa<ValueNode>() || input_node->isa<Parameter>()) {
+  auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
+  MS_EXCEPTION_IF_NULL(node_with_index.first);
+  auto real_input = node_with_index.first;
+  if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
     input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
     MS_EXCEPTION_IF_NULL(input_node);
     AnfAlgo::SetNodeInput(node, input_node, index);
@@ -106,33 +111,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
   std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
   std::string origin_format = kOpFormat_DEFAULT;
   std::string dest_format = AnfAlgo::GetInputFormat(node, index);
-  if (dest_format == kOpFormat_C1HWNCoC0) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    return replace_input;
-  }
-  if (dest_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    MS_LOG(DEBUG) << "Inserted Translate45, index: " << index;
-    return replace_input;
-  } else if (dest_format == kOpFormat_FRAC_NZ) {
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    MS_LOG(DEBUG) << "inserted translate " << AnfAlgo::GetInputFormat(node, index) << " To default, index: " << index;
-    return replace_input;
-  } else if (dest_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    MS_LOG(DEBUG) << "Inserted Translate45, index: " << index;
-    return replace_input;
+  if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
+    MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
+                  << " To DefaultFormat, index: " << index;
+    return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName,
+                                 true);
   }
   return input_node;
 }
@@ -140,7 +123,6 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
 AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(node);
-  bool padding_flag = false;
   std::string output_format;
   std::vector<size_t> origin_shape;
   if (!AnfAlgo::IsRealKernel(node)) {
@@ -156,46 +138,14 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
   }
   std::string origin_format = output_format;
   std::string dest_format = kOpFormat_DEFAULT;
-  if (output_format == kOpFormat_C1HWNCoC0) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                     dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    return replace_input;
-  }
-  if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                      dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_output);
-    MS_LOG(DEBUG) << "Inserted Trans54";
-    return replace_output;
-  } else if (output_format == kOpFormat_FRAC_NZ) {
-    AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                      dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_output);
-    MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: 0";
-    return replace_output;
-  } else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                      dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_output);
-    MS_LOG(DEBUG) << "Inserted Trans54";
-    return replace_output;
+  if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
+    MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default, index: 0";
+    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName,
+                                 false);
   }
   return node;
 }
 
-void GetTransDataInputFormat(const AnfNodePtr &node, size_t idx, std::string *input_format) {
-  MS_EXCEPTION_IF_NULL(input_format);
-  if (AnfAlgo::IsRealKernel(node)) {
-    *input_format = AnfAlgo::GetOutputFormat(node, idx);
-  } else {
-    *input_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0);
-  }
-}
-
 AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -203,46 +153,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
   std::vector<AnfNodePtr> make_tuple_inputs;
   make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
   for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
-    bool padding_flag = false;
-
-    std::string output_format;
-    GetTransDataInputFormat(node, output_idx, &output_format);
+    std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
     if (output_format == kOpFormat_NC1KHKWHWC0) {
-      MS_LOG(EXCEPTION) << "got the hw format" << output_format << " when insert the transdata node "
+      MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node "
                         << node->DebugString();
     }
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
-    std::string origin_format = output_format;
     std::string dest_format = kOpFormat_DEFAULT;
-    if (output_format == kOpFormat_C1HWNCoC0) {
-      padding_flag = (origin_shape.size() != kShape4dDims);
-      AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                       origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_input);
-      return replace_input;
-    }
-    if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
-      padding_flag = (origin_shape.size() != kShape4dDims);
-      // Insert a 5to4 trans op.
-      AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                        origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_output);
-      MS_LOG(DEBUG) << "Inserted Translate54";
-      make_tuple_inputs.push_back(replace_output);
-    } else if (output_format == kOpFormat_FRAC_NZ) {
-      AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                        origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_output);
-      MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: " << output_idx;
-      make_tuple_inputs.push_back(replace_output);
-    } else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
-      padding_flag = (origin_shape.size() != kShape4dDims);
-      AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                        origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_output);
-      MS_LOG(DEBUG) << "Inserted Translate54";
-      make_tuple_inputs.push_back(replace_output);
+    if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
+      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format,
+                                                           dest_format, kTransDataOpName, false));
     } else {
       // No need insert trans op.
       make_tuple_inputs.push_back(tuple_getitem);
@@ -253,16 +174,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
   }
 }  // namespace
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                 const KernelSelectPtr &kernel_select, size_t insert_index, const bool padding_flag,
+                                 const KernelSelectPtr &kernel_select, size_t insert_index,
                                  const std::string &origin_format, const std::string &dest_format,
                                  const std::string &op_name, bool is_insert_input) {
   AnfNodePtr trans_node = nullptr;
-  AnfNodePtr input_node = nullptr;
+  AnfNodePtr input_node = node;
   AnfNodePtr trans_data = nullptr;
   MS_EXCEPTION_IF_NULL(node);
   if (origin_format.empty() || dest_format.empty()) {
     MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << dest_format;
   }
+  // if we insert a transdata for an input, we need to change the input
   if (is_insert_input) {
     if (!node->isa<CNode>()) {
       MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
@@ -270,29 +192,34 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     input_node = AnfAlgo::GetInputNode(cnode, insert_index);
-    if (padding_flag) {
-      auto padd_shape = trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0));
-      auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padd_shape);
-      trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, padding_flag, op_name);
-    } else {
-      trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name);
-    }
   }
+  bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
+                       op_name == kTransDataOpName);
+  if (!need_padding) {
+    // padding is not needed; insert the transdata only
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
+    trans_node = trans_data;
+  } else if (is_insert_input) {
+    // if padding is needed on an input, insert:
+    // reshape[padding shape] -> transdata[padding shape] -> node
+    auto padding_shape =
+      trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
+    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
+    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name);
+    trans_node = trans_data;
   } else {
-    input_node = node;
-    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name);
-    if (padding_flag) {
-      auto reshape_node =
-        CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
-      trans_node = reshape_node;
-    } else {
-      trans_node = trans_data;
-    }
+    // if padding is needed on an output, insert:
+    // node -> transdata[padding shape] -> reshape[ori_shape]
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
+    auto reshape_node =
+      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
+    trans_node = reshape_node;
   }
+  // refresh the transdata's format to the origin format & dest format
   MS_EXCEPTION_IF_NULL(trans_data);
   MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
   auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
-  auto kernel_build_info = CreateKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
+  auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
   return trans_node;
 }
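
The refactor moves the padding decision inside AddTransOpNodeToGraph: the caller-supplied padding_flag is gone, and need_padding is derived from the destination format, the host rank, and the op name. A standalone sketch of that predicate (the literal format and op-name strings are assumptions standing in for the kOpFormat_*/kTransDataOpName constants):

#include <cassert>
#include <cstddef>
#include <string>

bool NeedPadding(const std::string &dest_format, size_t shape_size, const std::string &op_name) {
  if (op_name != "TransData") return false;   // only transdata nodes pad
  if (shape_size == 0) return false;          // scalars never pad
  if (dest_format == "DefaultFormat" || dest_format == "FRACTAL_NZ") return false;
  return shape_size < 4;                      // 1-3D into a 4D-based format: pad first
}

int main() {
  assert(NeedPadding("NC1HWC0", 2, "TransData"));      // pad, then insert transdata
  assert(!NeedPadding("FRACTAL_NZ", 2, "TransData"));  // FRAC_NZ handles low ranks itself
  assert(!NeedPadding("NC1HWC0", 4, "TransData"));     // already 4-D
  return 0;
}

When padding applies, the insertion order depends on the direction: on an input the chain is reshape[padded] -> transdata -> node; on an output it is node -> transdata -> reshape[original shape].
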
@@ -376,7 +303,17 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
   for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
     TypeId origin_type;
-    if (!AnfAlgo::IsFeatureMapInput(cnode, input_index)) {
+    auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
+    auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
+    auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
+      if (node->isa<ValueNode>()) {
+        return true;
+      } else if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
+        return true;
+      }
+      return false;
+    };
+    auto real_input_node = kernel_with_index.first;
+    if (is_weight_boundary(real_input_node)) {
       // weight
       origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
     } else {
@@ -48,7 +48,7 @@ class KernelQuery {
 using KernelQueryPtr = std::shared_ptr<KernelQuery>;
 
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool padding_flag,
+                                 const KernelSelectPtr &kernel_select, size_t insert_index,
                                  const std::string &origin_format, const std::string &dest_format,
                                  const std::string &op_name, bool is_insert_input);
 
@@ -105,10 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
   // insert trans
   if (origin_format != cur_format) {
     auto kernel_select = std::make_shared<KernelSelect>();
-    bool need_padding =
-      (cur_format == kOpFormat_NC1HWC0 && AnfAlgo::GetOutputInferShape(final_node, 0).size() != kShape4dDims);
-    final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, need_padding, cur_format,
-                                       origin_format, kTransDataOpName, false);
+    final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format,
+                                       kTransDataOpName, false);
     final_index = 0;
     MS_EXCEPTION_IF_NULL(final_node);
     MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
@@ -1,99 +1,99 @@
 /**
  * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 #include "pre_activate/ascend/ir_fusion/transdata_split.h"
 #include <set>
 #include "pre_activate/ascend/ascend_helper.h"
 #include "session/anf_runtime_algorithm.h"
 #include "debug/anf_ir_dump.h"
 
 namespace mindspore {
 namespace opt {
 const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
                                                                   {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
                                                                   {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
                                                                   {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
 
 bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
   bool changed = false;
   std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
   for (auto &node : node_list) {
     if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
       CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
       if (IsFormatInvaild(node)) {
         changed = DoSplit(func_graph, node);
       }
     }
   }
   return changed;
 }
 bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   auto input_format = AnfAlgo::GetInputFormat(node, 0);
   auto output_format = AnfAlgo::GetOutputFormat(node, 0);
   auto format_pair = std::make_pair(input_format, output_format);
 
   return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end();
 }
 // transdata cannot trans frac_z to nchw directly, so split it into transdata(frac_z->HWCN) and transpose(HWCN->NCHW)
 bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   auto input_node = node->cast<CNodePtr>()->input(1);
   MS_EXCEPTION_IF_NULL(input_node);
 
   auto input_format = AnfAlgo::GetInputFormat(node, 0);
   auto output_format = AnfAlgo::GetOutputFormat(node, 0);
   AnfNodePtr new_transdata_node = nullptr;
   AnfNodePtr new_transpose_node = nullptr;
   AnfNodePtr new_replace_node = nullptr;
   // if output_format is default, split into transdata->transpose; else transpose->transdata
   if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
     // trans input_format to hwcn
-    new_transdata_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format, kOpFormat_HWCN,
-                                               kTransDataOpName, true);
+    new_transdata_node =
+      AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true);
     // trans hwcn to default_format
-    new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, false, kOpFormat_HWCN,
-                                               output_format, prim::kPrimTranspose->name(), false);
+    new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN,
+                                               output_format, prim::kPrimTranspose->name(), false);
     AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
     new_replace_node = new_transpose_node;
   } else {
     // trans default to hwcn
-    new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format, kOpFormat_HWCN,
-                                               prim::kPrimTranspose->name(), true);
+    new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN,
+                                               prim::kPrimTranspose->name(), true);
     AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
 
     // trans hwcn to output_format
-    new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, false, kOpFormat_HWCN,
-                                               output_format, kTransDataOpName, false);
+    new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN,
+                                               output_format, kTransDataOpName, false);
     new_replace_node = new_transdata_node;
   }
   FuncGraphManagerPtr manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   manager->AddFuncGraph(func_graph);
 
   if (!manager->Replace(node, new_replace_node)) {
-    MS_LOG(EXCEPTION) << "manager replace node failed";
+    MS_LOG(EXCEPTION) << "Manager replace node failed";
   }
-  MS_LOG(INFO) << "transdata node:" << cnode->DebugString() << "split success.";
+  MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << " split success.";
   return true;
 }
 }  // namespace opt
 }  // namespace mindspore
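
The two perm attributes encode the HWCN pivot; a standalone check that {3, 2, 0, 1} really maps HWCN back to NCHW and {2, 3, 1, 0} maps NCHW to HWCN (a verification sketch, not MindSpore code):

#include <cassert>
#include <string>
#include <vector>

// Transpose semantics: out[i] = in[perm[i]].
std::vector<std::string> Permute(const std::vector<std::string> &in, const std::vector<int> &perm) {
  std::vector<std::string> out;
  for (int p : perm) {
    out.push_back(in[p]);
  }
  return out;
}

int main() {
  std::vector<std::string> hwcn = {"H", "W", "C", "N"};
  assert((Permute(hwcn, {3, 2, 0, 1}) == std::vector<std::string>{"N", "C", "H", "W"}));
  std::vector<std::string> nchw = {"N", "C", "H", "W"};
  assert((Permute(nchw, {2, 3, 1, 0}) == std::vector<std::string>{"H", "W", "C", "N"}));
  return 0;
}
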
@@ -289,6 +289,11 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
 
 std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
   MS_EXCEPTION_IF_NULL(node);
+  if (output_idx > GetOutputTensorNum(node)) {
+    MS_LOG(EXCEPTION) << "Output index:" << output_idx
+                      << " is out of the node output range: " << GetOutputTensorNum(node) << " #node ["
+                      << node->DebugString() << "]";
+  }
   auto kernel_info = node->kernel_info();
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
@@ -298,6 +303,11 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
 
 std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
   MS_EXCEPTION_IF_NULL(node);
+  if (input_idx > GetInputTensorNum(node)) {
+    MS_LOG(EXCEPTION) << "Input index:" << input_idx
+                      << " is out of the node input range: " << GetInputTensorNum(node) << " #node ["
+                      << node->DebugString() << "]";
+  }
   auto kernel_info = node->kernel_info();
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
@@ -362,62 +372,60 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNo
 std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
   auto format = GetOutputFormat(node, output_idx);
   auto infer_shape = GetOutputInferShape(node, output_idx);
   // if format is default_format or NC1KHKWHWC0, device shape = original shape
   if (format == kOpFormat_DEFAULT || format == kOpFormat_NC1KHKWHWC0) {
     return infer_shape;
   }
   // scalar shape
   if (infer_shape.empty()) {
     return infer_shape;
   }
-  if (format == kOpFormat_FRAC_NZ) {
-    return trans::TransShapeToDevice(infer_shape, format);
+  if (trans::IsNeedPadding(format, infer_shape.size())) {
+    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
   }
-  // else trans infer shape to 4d and then calculate device shape
-  return trans::TransShapeToDevice(trans::TransShapeTo4d(infer_shape), format);
+  return trans::TransShapeToDevice(infer_shape, format);
 }
 
 std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
   auto format = GetInputFormat(node, input_idx);
   auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
   // if format is default_format or NC1KHKWHWC0, device shape = original shape
   if (format == kOpFormat_DEFAULT || format == kOpFormat_NC1KHKWHWC0) {
     return infer_shape;
   }
   if (infer_shape.empty()) {
     return infer_shape;
   }
-  if (format == kOpFormat_FRAC_NZ) {
-    return trans::TransShapeToDevice(infer_shape, format);
+  if (trans::IsNeedPadding(format, infer_shape.size())) {
+    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
   }
-  // else trans infer shape to 4d and then calculate device shape
-  return trans::TransShapeToDevice(trans::TransShapeTo4d(infer_shape), format);
+  return trans::TransShapeToDevice(infer_shape, format);
 }
 
 std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
   MS_EXCEPTION_IF_NULL(node);
   if (input_idx > GetInputTensorNum(node)) {
     MS_LOG(EXCEPTION) << "The index:" << input_idx
                       << " is out of range of the node's input size: " << GetInputTensorNum(node) << " #node["
                       << node->DebugString() << "]";
   }
   auto kernel_info = node->kernel_info();
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
   MS_EXCEPTION_IF_NULL(build_info);
-  std::vector<kernel::Axis> result;
-  if (!build_info->GetInputReshapeType(input_idx, &result)) {
-    MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !";
+  if (build_info->IsInputDefaultPadding()) {
+    return {};
   }
-  return result;
+  return build_info->GetInputReshapeType(input_idx);
 }
 
 std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
   MS_EXCEPTION_IF_NULL(node);
   if (output_idx > GetOutputTensorNum(node)) {
     MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                       << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
   }
   auto kernel_info = node->kernel_info();
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
   MS_EXCEPTION_IF_NULL(build_info);
-  std::vector<kernel::Axis> result;
-  if (!build_info->GetOutputReshapeType(output_idx, &result)) {
-    MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !";
+  if (build_info->IsOutputDefaultPadding()) {
+    return {};
   }
-  return result;
+  return build_info->GetOutputReshapeType(output_idx);
 }
 
 TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
@ -463,6 +471,10 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &nod
|
|||
|
||||
TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (output_idx > GetOutputTensorNum(node)) {
|
||||
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
|
||||
<< GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
|
||||
}
|
||||
auto kernel_info = node->kernel_info();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
auto build_info = kernel_info->select_kernel_build_info();
|
||||
|
@ -472,6 +484,10 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
|
|||
|
||||
TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (input_idx > GetInputTensorNum(node)) {
|
||||
MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
|
||||
<< GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
|
||||
}
|
||||
auto kernel_info = node->kernel_info();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
auto build_info = kernel_info->select_kernel_build_info();
|
||||
|
@ -496,11 +512,15 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
|
|||
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
|
||||
}
|
||||
}
|
||||
if (output_idx > GetOutputTensorNum(node)) {
|
||||
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
|
||||
<< GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
|
||||
}
|
||||
auto kernel_info = node->kernel_info();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
auto addr = kernel_info->GetOutputAddr(output_idx);
|
||||
if (addr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "output_idx " << output_idx << " of node " << node->DebugString()
|
||||
MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
|
||||
<< " output addr is not exist";
|
||||
}
|
||||
return addr;
|
||||
|

@@ -517,11 +537,15 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
    }
  }
+  if (output_idx > GetOutputTensorNum(node)) {
+    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
+                      << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
+  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
  if (addr == nullptr) {
-    MS_LOG(EXCEPTION) << "output_idx" << output_idx << " of node " << node->DebugString()
+    MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;

@@ -530,6 +554,10 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
// get output device addr of anf_node
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
+  if (output_idx > GetOutputTensorNum(node)) {
+    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
+                      << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
+  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->OutputAddrExist(output_idx);
@ -769,22 +797,24 @@ AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index)
|
|||
return node->input(get_input_index);
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<ValueNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto kernel_info = node->kernel_info();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
return kernel_info->is_feature_map();
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
|
||||
if (!node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature";
|
||||
MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map";
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input_node = cnode->input(input_index + 1);
|
||||
auto node_with_index = VisitKernel(input_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(node_with_index.first);
|
||||
if (node_with_index.first->isa<ValueNode>()) {
|
||||
return false;
|
||||
}
|
||||
if (node_with_index.first->isa<Parameter>()) {
|
||||
return !AnfAlgo::IsParameterWeight(node_with_index.first->cast<ParameterPtr>());
|
||||
}
|
||||
return true;
|
||||
return IsFeatureMapOutput(input_node);
|
||||
}
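
The propagation rule these two helpers implement: a value node is never a
feature map, a parameter carries the flag it was given when created (false for
weights), and a cnode reports whatever its kernel_info recorded. A hedged
sketch of a call site, assuming the repo's headers (nodes are illustrative):

    auto cnode = node->cast<CNodePtr>();
    // input 0 is the primitive, so real input i lives at index i + 1
    auto input_node = cnode->input(input_index + 1);
    // ValueNode -> false, weight Parameter -> false, otherwise the kernel_info flag
    bool from_feature_map = AnfAlgo::IsFeatureMapOutput(input_node);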

size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {

@@ -101,7 +101,9 @@ class AnfRuntimeAlgorithm {
  static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
  // get input shapes which will built and run in device
  static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
+  // Get Input Padding Axis
  static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
+  // Get Output Padding Axis
  static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
  // get output data type inferred by ME of anf node
  static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);

@@ -165,6 +167,9 @@ class AnfRuntimeAlgorithm {
  // get graph id
  static uint32_t GetGraphId(const AnfNode *node);
  static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
+  // charge if the node's output is a feature map output
+  static bool IsFeatureMapOutput(const AnfNodePtr &node);
+  // charge if the node's input is from a feature map output
  static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index);
  // get real input index for some tbe ops which input order is different between me and tbe impl
  static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);

@@ -18,6 +18,7 @@
#include "operator/ops.h"
#include "ir/meta_tensor.h"
#include "ir/anf.h"
+#include "common/trans.h"
#include "device/kernel_runtime.h"
#include "device/ascend/kernel_select_ascend.h"
#include "device/ascend/kernel_build_ascend.h"

@@ -730,8 +731,8 @@ void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor
  size_t tensor_size = front_tensor->data().nbytes();
  auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
  MS_EXCEPTION_IF_NULL(addr);
-  if (!addr->SyncHostToDevice(front_tensor->shape(), tensor_size, front_tensor->data_type(),
-                              front_tensor->data_c(false))) {
+  if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
+                              front_tensor->data_type(), front_tensor->data_c(false))) {
    MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
  }
  MS_LOG(INFO) << "Finish!";
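
The sync now uses the backend parameter's runtime shape instead of the raw
host tensor shape, so low-rank host tensors are padded to 4-D whenever the
device format requires it. A hedged illustration of the effect, following the
default padding rule (shape values are examples):

    // a host tensor of shape {32, 64} bound to a format that needs padding is
    // synced as {1, 32, 64, 1}, so host and device agree on the element layout
    auto runtime_shape = trans::GetRuntimePaddingShape(backend_parameter, 0);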

@@ -143,6 +143,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
  // create kernel_info for the new cnode
  auto kernel_info = std::make_shared<device::KernelInfo>();
+  // if the node only has the primitive (such as getNext) or any of its inputs is a
+  // feature map, then the node's output is a feature map as well
+  if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(),
+                                        [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) {
+    kernel_info->SetFeatureMapFlag(true);
+  }
  cnode->set_kernel_info(kernel_info);
  AnfAlgo::SetGraphId(graph_id_, cnode.get());
  return cnode;
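
A hedged example of the rule, assuming the repo's session headers (the choice
of primitive is illustrative):

    // a cnode whose only input is its primitive (e.g. GetNext) produces fresh
    // data, so it is flagged as a feature map output
    std::vector<AnfNodePtr> get_next_inputs = {NewValueNode(prim::kPrimGetNext)};
    auto get_next = kernel_graph->NewCNode(get_next_inputs);
    // any cnode consuming get_next inherits the flag via IsFeatureMapOutput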

@@ -162,22 +168,26 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
  ParameterPtr new_parameter = add_parameter();
  MS_EXCEPTION_IF_NULL(new_parameter);
+  // create kernel_info for the new parameter
+  auto kernel_info = std::make_shared<device::KernelInfo>();
  size_t output_tensor_num = 1;
  // parameter == nullptr means creating a brand-new parameter rather than cloning one
  if (parameter == nullptr) {
    new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
+    kernel_info->SetFeatureMapFlag(true);
  } else {
    // otherwise create the new parameter from the old one
    new_parameter->set_abstract(parameter->abstract());
    new_parameter->set_name(parameter->name());
    if (parameter->has_default()) {
      if (AnfAlgo::IsParameterWeight(parameter)) {
        new_parameter->set_default_param(parameter->default_param());
+        kernel_info->SetFeatureMapFlag(false);
+      } else {
+        kernel_info->SetFeatureMapFlag(true);
+      }
      // if the output is a tuple tensor, a for loop can handle each tensor
      output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter);
    }
-  // create kernel_info from new parameter
-  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_parameter->set_kernel_info(kernel_info);
  // create kernel_build_info for new parameter
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
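
After this change every freshly created parameter carries a kernel_info whose
feature-map flag encodes its role. A hedged summary of the branch outcomes
(weight_p and input_p are hypothetical parameters):

    auto fresh  = kernel_graph->NewParameter();          // no source parameter -> flag true
    auto cloned = kernel_graph->NewParameter(weight_p);  // weight default      -> flag false
    auto data   = kernel_graph->NewParameter(input_p);   // non-weight default  -> flag true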

@@ -217,6 +227,7 @@ std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNo
  AddValueNodeToGraph(new_value_node);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_value_node->set_kernel_info(kernel_info);
+  kernel_info->SetFeatureMapFlag(false);
  // create kernel_build_info for new value node
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  // set the format of value_node to DEFAULT_FORMAT

@@ -240,6 +251,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
  new_value_node->set_abstract(value_node->abstract());
  // create kernel_info for new value node
  auto kernel_info = std::make_shared<device::KernelInfo>();
+  kernel_info->SetFeatureMapFlag(false);
  new_value_node->set_kernel_info(kernel_info);
  // create kernel_build_info for new value node
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
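
Both value-node paths pin the flag to false, matching IsFeatureMapOutput's
rule that constants never come from the dataset. The invariant, stated as a
hedged one-liner:

    // for any ValueNode v in the graph: AnfAlgo::IsFeatureMapOutput(v) == false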

@@ -20,6 +20,7 @@
#include "pipeline/parse/data_converter.h"
#include "ir/manager.h"
#include "operator/ops.h"
+#include "common/trans.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "session/anf_runtime_algorithm.h"

@@ -124,7 +125,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->enable_pynative_infer()) {
    tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
-  } else if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
+  } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
+                                        LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                        tensor->data_c(true))) {
    MS_LOG(INFO) << "output sync device to host error!!!";
  tensor->set_dirty(false);

@@ -369,7 +371,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
  }
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  // construct abstract of parameter
  auto abstract = std::make_shared<abstract::AbstractTensor>(input_tensor);
  param->set_abstract(abstract);
  return param;

@@ -548,7 +550,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
  if (need_sync) {
    tensor->set_device_address(device_address);
    MS_EXCEPTION_IF_NULL(device_address);
-    if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
+    if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                          LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                          tensor->data_c(false))) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
    }

@@ -620,8 +623,8 @@ void SessionBasic::Summary(KernelGraph *graph) {
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  MS_EXCEPTION_IF_NULL(address);
-  if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
-                                 tensor->data_c(true))) {
+  if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
+                                 tensor->data_type(), tensor->data_c(true))) {
    MS_LOG(ERROR) << "Failed to sync output from device to host.";
  }
  tensor->set_dirty(false);

@@ -197,8 +197,8 @@ const std::set<std::string> kOptOperatorSet = {
  kApplyRMSPropOpName,
};

-const std::set<std::string> kSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
-                                                 kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
+const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
+                                                   kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
static inline void ChangeFileMode(const std::string& file_name, mode_t mode) {
  if (access(file_name.c_str(), F_OK) != 0) {
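
The rename makes the intent explicit: these are the device formats that a
plain reshape cannot reach, so moving data between them and the default format
requires a TransData conversion. A hedged sketch of the membership test the
set supports (the helper name is hypothetical):

    static inline bool IsNeedTransFormat(const std::string &format) {
      // true for FRAC_Z, NC1KHKWHWC0, NC1HWC0, FRAC_NZ and C1HWNCoC0
      return kNeedTransFormatSet.find(format) != kNeedTransFormatSet.end();
    }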

@@ -80,6 +80,8 @@ TEST_F(TestHWLayerNormBetaGammaBackpropFusion, layernorm_beta_gamma_backprop_fus
  builder1.SetOutputsDeviceType({kNumberTypeFloat32});
+  cast0->set_kernel_info(std::make_shared<device::KernelInfo>());
+  cast1->set_kernel_info(std::make_shared<device::KernelInfo>());
  cast0->set_abstract(x_abstract);
  cast1->set_abstract(x_abstract);
  AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast0.get());
  AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast1.get());


@@ -211,8 +211,8 @@ TEST_F(AnfRuntimeAlgorithmTest, EraseNodeAttr) {
TEST_F(AnfRuntimeAlgorithmTest, GetInputTensorNum) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  // test cnode node
-  auto parameter_one = kernel_graph->add_parameter();
-  auto parameter_two = kernel_graph->add_parameter();
+  auto parameter_one = kernel_graph->NewParameter();
+  auto parameter_two = kernel_graph->NewParameter();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimTensorAdd), parameter_one, parameter_two};
  auto add = kernel_graph->NewCNode(add_inputs);
  EXPECT_EQ(AnfAlgo::GetInputTensorNum(add), 2);
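
The test swaps add_parameter for NewParameter because NewCNode's feature-map
logic now reads each input's kernel_info, which a raw parameter does not
carry. A hedged illustration of the distinction:

    auto raw = kernel_graph->add_parameter();   // no device::KernelInfo attached
    auto built = kernel_graph->NewParameter();  // kernel_info with feature-map flag set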

@@ -247,9 +247,11 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputTensorNum) {

TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) {
  auto kernel_graph = std::make_shared<KernelGraph>();
-  std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
+                                    kernel_graph->NewParameter()};
  auto add = kernel_graph->NewCNode(inputs);
  std::vector<size_t> shape = {1, 2, 3, 4};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get());
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());
  auto d_kernel_info = add->kernel_info();

@@ -266,8 +268,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) {

TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) {
  auto kernel_graph = std::make_shared<KernelGraph>();
-  std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
+                                    kernel_graph->NewParameter()};
  auto add = kernel_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());

@@ -345,7 +347,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputInferShape) {
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  // test parameter node as input
-  auto parameter_node = kernel_graph->add_parameter();
+  auto parameter_node = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_node);
  parameter_node->set_abstract(x_abstract);
  EXPECT_THROW(AnfAlgo::GetPrevNodeOutputInferShape(parameter_node, 0), std::runtime_error);

@@ -387,13 +389,13 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
-  auto parameter_one = kernel_graph->add_parameter();
+  auto parameter_one = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_one);
  parameter_one->set_abstract(x_abstract);
-  auto parameter_two = kernel_graph->add_parameter();
+  auto parameter_two = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_two);
  parameter_two->set_abstract(x_abstract);
-  auto parameter_third = kernel_graph->add_parameter();
+  auto parameter_third = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_third);
  parameter_third->set_abstract(x_abstract);
  // test cnode as input

@@ -466,8 +468,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) {

TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) {
  auto kernel_graph = std::make_shared<KernelGraph>();
-  std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
+                                    kernel_graph->NewParameter()};
  auto add = kernel_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());


@@ -140,11 +140,11 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) {
  std::vector<int> shape = {2, 32, 224, 224};
  auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);

-  auto x_parameter = kernel_graph->add_parameter();
+  auto x_parameter = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(x_parameter);
  x_parameter->set_name("x_parameter");
  x_parameter->set_abstract(abstract);
-  auto y_parameter = kernel_graph->add_parameter();
+  auto y_parameter = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(y_parameter);
  y_parameter->set_name("y_parameter");
  y_parameter->set_abstract(abstract);

@@ -153,7 +153,7 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) {
  MS_EXCEPTION_IF_NULL(add);
  add->set_abstract(abstract);

-  auto z_parameter = kernel_graph->add_parameter();
+  auto z_parameter = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(z_parameter);
  z_parameter->set_name("z_parameter");
  z_parameter->set_abstract(abstract);