Merge pull request !31337 from liubuyu/code_dex
This commit is contained in:
i-robot 2022-03-17 01:15:44 +00:00 committed by Gitee
commit 4f24ab161c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 69 additions and 84 deletions

View File

@ -115,10 +115,7 @@ static bool KernelBuildParallelCompile(const std::vector<CNodePtr> &kernels) {
return akg_ret;
}
// Calls TbeUtils::LoadCache() first, then compiles the given kernels through
// device::ascend::KernelBuildParallelCompile and returns that call's result.
bool KernelBuild(const std::vector<CNodePtr> &kernels) {
  TbeUtils::LoadCache();
  const bool compiled = device::ascend::KernelBuildParallelCompile(kernels);
  return compiled;
}
// Compiles the given kernels by delegating to
// device::ascend::KernelBuildParallelCompile and forwards its result unchanged.
bool KernelBuild(const std::vector<CNodePtr> &kernels) {
  return device::ascend::KernelBuildParallelCompile(kernels);
}
namespace {
bool IsAtomicNode(const CNodePtr &kernel_node) {

View File

@ -870,6 +870,8 @@ void TbeKernelCompileManager::TbeInitialize() {
auto init_ret = DispatchCompileTask(init_json);
auto json_ret = TurnStrToJson(init_ret);
PrintInitResult(json_ret);
// load cache before kernel build
TbeUtils::LoadCache();
}
void TbeKernelCompileManager::TbeFinalize() {

View File

@ -36,6 +36,15 @@ constexpr int64_t k4 = 4;
static const std::set<TypeId> C0_64 = {kNumberTypeInt4};
static const std::set<TypeId> C0_32 = {kNumberTypeUInt8, kNumberTypeInt8};
namespace {
// Positions of h and w inside the hw_shape vector (read with hw_shape.at()).
constexpr size_t hw_h = 1;
constexpr size_t hw_w = 2;

// Offsets counted back from the end of device_shape for the FRAC_NZ layout,
// i.e. used as device_shape[device_shape.size() - offset].
constexpr size_t fnz_w1 = 4;
constexpr size_t fnz_h1 = 3;
constexpr size_t fnz_h0 = 2;
constexpr size_t fnz_w0 = 1;

// Positions of n0, ni and c0 inside device_shape for the FRAC_Z layout
// (read with device_shape.at()).
constexpr size_t fz_n0 = 1;
constexpr size_t fz_ni = 2;
constexpr size_t fz_c0 = 3;
// Returns true when at least one entry of shape_list equals
// abstract::Shape::SHP_ANY; returns false otherwise (including for an empty list).
bool HasShapeDynamic(const ShapeVector &shape_list) {
  const auto is_any_dim = [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; };
  return std::any_of(shape_list.begin(), shape_list.end(), is_any_dim);
}
@ -239,7 +248,7 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
if (IsNeedPadding(format, host_shape.size())) {
host_shape = PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index));
}
std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
(void)std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
return shape;
}
@ -295,7 +304,7 @@ void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const int64_t data_siz
}
}
bool DataTypeTransfer::CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) {
bool DataTypeTransfer::CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) const {
using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const int64_t)>;
const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
{DataTypeTransMode::FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
@ -323,10 +332,10 @@ bool DataTypeTransfer::CastKernel(const TypeIdArgs &args, void *dst, int64_t dat
{DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>}};
if (mode == DataTypeTransMode::FROM_FLOAT_TO_FLOAT16) {
device::FloatToHalf(dst, args.data, data_size);
device::FloatToHalf(dst, args.data, LongToSize(data_size));
return true;
} else if (mode == DataTypeTransMode::FROM_FLOAT16_TO_FLOAT) {
device::HalfToFloat(dst, args.data, data_size);
device::HalfToFloat(dst, args.data, LongToSize(data_size));
return true;
}
auto iter = cast_kernel_map.find(mode);
@ -577,7 +586,7 @@ ShapeVector DeviceShapeTransfer::C1HWNCOC0DeviceShape(const ShapeVector &shape,
return device_shape;
}
ShapeVector DeviceShapeTransfer::FRAC_ZC04DeviceShape(const ShapeVector &shape, const TypeId &type) {
ShapeVector DeviceShapeTransfer::FRAC_ZC04DeviceShape(const ShapeVector &shape, const TypeId &) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
@ -628,7 +637,7 @@ ShapeVector DeviceShapeTransfer::ChannelLastDeviceShape(const ShapeVector &shape
axis[dim - 1] = 1;
ShapeVector device_shape;
(void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
[&shape](int64_t n) { return shape[n]; });
[&shape](size_t n) { return shape[n]; });
return device_shape;
}
@ -656,7 +665,7 @@ ShapeVector DeviceShapeTransfer::FRAC_NZDeviceShape(const ShapeVector &shape, co
return device_shape;
}
ShapeVector DeviceShapeTransfer::FRAC_ZN_LSTMDeviceShape(const ShapeVector &shape, const TypeId &type) {
ShapeVector DeviceShapeTransfer::FRAC_ZN_LSTMDeviceShape(const ShapeVector &shape, const TypeId &) {
ShapeVector device_shape;
const int64_t lstm_ni = 4;
const int64_t ni = 16;
@ -971,15 +980,7 @@ bool FormatTransfer::TO_NCHW(const FormatArgs &args, void *result) {
bool FormatTransfer::NCHW_TO_FRAC_Z(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
MS_EXCEPTION_IF_NULL(result);
if (args.host_shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
return false;
}
auto size = Common4DCheck(args);
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
@ -1048,15 +1049,15 @@ bool FormatTransfer::NCHW_TO_FRAC_NZ(const FormatArgs &args, void *result) {
return false;
}
auto times = hw_shape.at(0);
auto h = hw_shape.at(1);
auto w = hw_shape.at(2);
auto h = hw_shape.at(hw_h);
auto w = hw_shape.at(hw_w);
auto hw = h * w;
auto shape_size = args.device_shape.size();
auto w1 = args.device_shape[shape_size - 4];
auto h1 = args.device_shape[shape_size - 3];
auto h0 = args.device_shape[shape_size - 2];
auto w0 = args.device_shape[shape_size - 1];
auto w1 = args.device_shape[shape_size - fnz_w1];
auto h1 = args.device_shape[shape_size - fnz_h1];
auto h0 = args.device_shape[shape_size - fnz_h0];
auto w0 = args.device_shape[shape_size - fnz_w0];
auto h1h0w0 = h1 * h0 * w0;
auto w1h1h0w0 = w1 * h1h0w0;
auto num_w1 = w / w0;
@ -1136,15 +1137,7 @@ bool FormatTransfer::NCHW_TO_FRAC_ZC04(const FormatArgs &args, void *result) {
bool FormatTransfer::NCHW_TO_NC1HWC0(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
MS_EXCEPTION_IF_NULL(result);
if (args.host_shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
return false;
}
auto size = Common4DCheck(args);
auto total_size = abstract::ShapeSize(args.device_shape) * size;
if (total_size != SizeToLong(args.device_size)) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
@ -1349,15 +1342,7 @@ bool FormatTransfer::NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result) {
bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
MS_EXCEPTION_IF_NULL(result);
if (args.host_shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
return false;
}
auto size = Common4DCheck(args);
auto n_dim = args.host_shape[kN];
auto c_dim = args.host_shape[kC];
auto h_dim = args.host_shape[kH];
@ -1387,7 +1372,7 @@ bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *re
if (dst_size == 0) {
return true;
}
auto ret = memset_s(result, dst_size, 0, dst_size);
auto ret = memset_s(result, LongToSize(dst_size), 0, LongToSize(dst_size));
if (ret != EOK) {
MS_LOG(ERROR) << "memset failed";
return false;
@ -1426,15 +1411,7 @@ bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *re
bool FormatTransfer::NC1HWC0_TO_NCHW(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
MS_EXCEPTION_IF_NULL(result);
if (args.host_shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
return false;
}
auto size = Common4DCheck(args);
auto total_size = abstract::ShapeSize(args.device_shape) * size;
if (total_size != SizeToLong(args.device_size)) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
@ -1517,24 +1494,16 @@ bool FormatTransfer::C1HWNCOC0_TO_NCHW(const FormatArgs &args, void *result) {
bool FormatTransfer::FRAC_Z_TO_NCHW(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
MS_EXCEPTION_IF_NULL(result);
if (args.host_shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
return false;
}
auto size = Common4DCheck(args);
auto total_size = abstract::ShapeSize(args.device_shape) * size;
if (total_size != SizeToLong(args.device_size)) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
return false;
}
auto n0 = args.device_shape.at(1);
auto ni = args.device_shape.at(2);
auto c0 = args.device_shape.at(3);
auto n0 = args.device_shape.at(fz_n0);
auto ni = args.device_shape.at(fz_ni);
auto c0 = args.device_shape.at(fz_c0);
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
@ -1590,15 +1559,15 @@ bool FormatTransfer::FRAC_NZ_TO_NCHW(const FormatArgs &args, void *result) {
return false;
}
auto times = hw_shape.at(0);
auto h = hw_shape.at(1);
auto w = hw_shape.at(2);
auto h = hw_shape.at(hw_h);
auto w = hw_shape.at(hw_w);
auto hw = h * w;
auto shape_size = args.device_shape.size();
auto w1 = args.device_shape[shape_size - 4];
auto h1 = args.device_shape[shape_size - 3];
auto h0 = args.device_shape[shape_size - 2];
auto w0 = args.device_shape[shape_size - 1];
auto w1 = args.device_shape[shape_size - fnz_w1];
auto h1 = args.device_shape[shape_size - fnz_h1];
auto h0 = args.device_shape[shape_size - fnz_h0];
auto w0 = args.device_shape[shape_size - fnz_w0];
auto h1h0w0 = h1 * h0 * w0;
auto w1h1h0w0 = w1 * h1h0w0;
auto num_w1 = w / w0;
@ -1747,6 +1716,18 @@ bool FormatTransfer::FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *re
return NCHW_TO_FRAC_Z_WITH_GROPUS(args, result, false, groups);
}
// Shared precondition check for the format-transfer routines that expect a
// kNchwDims-dimensional host shape.
//   args: format arguments; only host_shape and src_data_type are inspected here.
// Returns SizeToLong(abstract::TypeIdSize(args.src_data_type)).
// Raises via MS_LOG(EXCEPTION) when the host shape does not have kNchwDims
// dimensions or when the computed type size is smaller than 1.
int64_t FormatTransfer::Common4DCheck(const FormatArgs &args) {
  const auto host_dims = args.host_shape.size();
  if (host_dims != kNchwDims) {
    MS_LOG(EXCEPTION) << "Invalid host shape, host shape dims:" << host_dims << ", expect dims:" << kNchwDims;
  }

  const auto type_size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (type_size < 1) {
    MS_LOG(EXCEPTION) << "Illegal dtype: " << args.src_data_type;
  }
  return type_size;
}
// ######################## RANGE TRANS ########################
RangePair ShapeRangeTransfer::GetRealRange(const RangePair &ori_range, const std::string &format, const TypeId &type) {
const std::set<std::string> no_need_change = {kOpFormat_ND, kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NCDHW};
@ -1844,9 +1825,10 @@ RangePair ShapeRangeTransfer::FRAC_ZRange(const RangePair &ori_range, const Type
const std::pair<int64_t, int64_t> r0 = {ori_range[kH].first * ori_range[kW].first * cin16.first / cube,
ori_range[kH].second * ori_range[kW].second * cin16.second / cube};
const std::pair<int64_t, int64_t> r1 = {cout16.first / kNiSize, cout16.second / kNiSize};
const std::pair<int64_t, int64_t> co = {kNiSize, kNiSize};
dst_range.push_back(r0);
dst_range.push_back(r1);
dst_range.push_back({kNiSize, kNiSize});
dst_range.push_back(co);
dst_range.push_back(c0);
return dst_range;
}
@ -1865,9 +1847,10 @@ RangePair ShapeRangeTransfer::FRAC_NZRange(const RangePair &ori_range, const Typ
(ori_range[ori_size - 1].second - 1) / cube + 1};
const std::pair<int64_t, int64_t> h1 = {(ori_range[ori_size - kDims2].first - 1) / kNiSize + 1,
(ori_range[ori_size - kDims2].second - 1) / kNiSize + 1};
const std::pair<int64_t, int64_t> co = {kNiSize, kNiSize};
dst_range.push_back(w1);
dst_range.push_back(h1);
dst_range.push_back({kNiSize, kNiSize});
dst_range.push_back(co);
dst_range.push_back(c0);
return dst_range;
}

View File

@ -147,7 +147,7 @@ class DataTypeTransfer {
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), DataTypeTransMode::FROM_BOOL_TO_UINT8},
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), DataTypeTransMode::FROM_BOOL_TO_FLOAT16}};
bool CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode);
bool CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) const;
};
/**
@ -251,6 +251,9 @@ class FormatTransfer {
static bool FRAC_Z3D_TO_NCDHW(const FormatArgs &args, void *result);
static bool NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result);
static bool FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups);
// Common check for 4D host shapes: validates the host shape dims and the source data type, returns the data type size.
static int64_t Common4DCheck(const FormatArgs &args);
};
/**
@ -382,7 +385,7 @@ std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape) {
shape_4d[kW] = shape[kH];
break;
case kNchwDims:
std::copy(shape.begin(), shape.end(), shape_4d.begin());
(void)std::copy(shape.begin(), shape.end(), shape_4d.begin());
break;
default:
MS_LOG(EXCEPTION) << "Unexpected shape : " << shape;
@ -449,13 +452,13 @@ template <typename T>
// Computes the device shape of `shape` for `format`, using the node/index based
// overload of DeviceShapeTransfer::GetDeviceShapeByFormat.
// The input is converted element-wise to an int64_t ShapeVector with
// static_cast<int64_t>, handed to GetDeviceShapeByFormat together with node,
// index, type and is_output, and the result is converted back to T with
// static_cast<T>.
// NOTE(review): each std::transform below appears twice, once bare and once
// with a (void) cast. This looks like the before/after lines of a diff rather
// than two intended calls -- confirm that only the (void) form should remain.
std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node,
size_t index, TypeId type, bool is_output = true) {
ShapeVector shape_before;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before),
[](T num) { return static_cast<int64_t>(num); });
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before),
[](T num) { return static_cast<int64_t>(num); });
DeviceShapeTransfer deviceShapeTransfer;
auto res = deviceShapeTransfer.GetDeviceShapeByFormat(shape_before, format, node, index, type, is_output);
std::vector<T> out_shape;
std::transform(res.begin(), res.end(), std::back_inserter(out_shape),
[](int64_t num) { return static_cast<T>(num); });
(void)std::transform(res.begin(), res.end(), std::back_inserter(out_shape),
[](int64_t num) { return static_cast<T>(num); });
return out_shape;
}
@ -463,13 +466,13 @@ template <typename T>
// Computes the device shape of `shape` for `format`, using the overload of
// DeviceShapeTransfer::GetDeviceShapeByFormat that takes type, groups and
// input_hidden_size (defaults: groups = 1, input_hidden_size = {kAlign16, kAlign16}).
// The input is converted element-wise to an int64_t ShapeVector with
// static_cast<int64_t>, and the returned shape is converted back to T with
// static_cast<T>.
// NOTE(review): each std::transform below appears twice, once bare and once
// with a (void) cast. This looks like the before/after lines of a diff rather
// than two intended calls -- confirm that only the (void) form should remain.
std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, TypeId type,
int64_t groups = 1, const ShapeVector &input_hidden_size = {kAlign16, kAlign16}) {
ShapeVector shape_before;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before),
[](T num) { return static_cast<int64_t>(num); });
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before),
[](T num) { return static_cast<int64_t>(num); });
DeviceShapeTransfer deviceShapeTransfer;
auto res = deviceShapeTransfer.GetDeviceShapeByFormat(shape_before, format, type, groups, input_hidden_size);
std::vector<T> out_shape;
std::transform(res.begin(), res.end(), std::back_inserter(out_shape),
[](int64_t num) { return static_cast<T>(num); });
(void)std::transform(res.begin(), res.end(), std::back_inserter(out_shape),
[](int64_t num) { return static_cast<T>(num); });
return out_shape;
}
} // namespace trans