forked from mindspore-Ecosystem/mindspore
commit
4f24ab161c
|
@ -115,10 +115,7 @@ static bool KernelBuildParallelCompile(const std::vector<CNodePtr> &kernels) {
|
|||
return akg_ret;
|
||||
}
|
||||
|
||||
bool KernelBuild(const std::vector<CNodePtr> &kernels) {
|
||||
TbeUtils::LoadCache();
|
||||
return device::ascend::KernelBuildParallelCompile(kernels);
|
||||
}
|
||||
bool KernelBuild(const std::vector<CNodePtr> &kernels) { return device::ascend::KernelBuildParallelCompile(kernels); }
|
||||
|
||||
namespace {
|
||||
bool IsAtomicNode(const CNodePtr &kernel_node) {
|
||||
|
|
|
@ -870,6 +870,8 @@ void TbeKernelCompileManager::TbeInitialize() {
|
|||
auto init_ret = DispatchCompileTask(init_json);
|
||||
auto json_ret = TurnStrToJson(init_ret);
|
||||
PrintInitResult(json_ret);
|
||||
// load cache before kernel build
|
||||
TbeUtils::LoadCache();
|
||||
}
|
||||
|
||||
void TbeKernelCompileManager::TbeFinalize() {
|
||||
|
|
|
@ -36,6 +36,15 @@ constexpr int64_t k4 = 4;
|
|||
static const std::set<TypeId> C0_64 = {kNumberTypeInt4};
|
||||
static const std::set<TypeId> C0_32 = {kNumberTypeUInt8, kNumberTypeInt8};
|
||||
namespace {
|
||||
const size_t hw_h = 1;
|
||||
const size_t hw_w = 2;
|
||||
const size_t fnz_w1 = 4;
|
||||
const size_t fnz_h1 = 3;
|
||||
const size_t fnz_h0 = 2;
|
||||
const size_t fnz_w0 = 1;
|
||||
const size_t fz_n0 = 1;
|
||||
const size_t fz_ni = 2;
|
||||
const size_t fz_c0 = 3;
|
||||
bool HasShapeDynamic(const ShapeVector &shape_list) {
|
||||
return std::any_of(shape_list.begin(), shape_list.end(), [](int64_t v) { return v == abstract::Shape::SHP_ANY; });
|
||||
}
|
||||
|
@ -239,7 +248,7 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
|||
if (IsNeedPadding(format, host_shape.size())) {
|
||||
host_shape = PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index));
|
||||
}
|
||||
std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
|
||||
(void)std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
|
||||
return shape;
|
||||
}
|
||||
|
||||
|
@ -295,7 +304,7 @@ void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const int64_t data_siz
|
|||
}
|
||||
}
|
||||
|
||||
bool DataTypeTransfer::CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) {
|
||||
bool DataTypeTransfer::CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) const {
|
||||
using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const int64_t)>;
|
||||
const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
|
||||
{DataTypeTransMode::FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
|
||||
|
@ -323,10 +332,10 @@ bool DataTypeTransfer::CastKernel(const TypeIdArgs &args, void *dst, int64_t dat
|
|||
{DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>}};
|
||||
|
||||
if (mode == DataTypeTransMode::FROM_FLOAT_TO_FLOAT16) {
|
||||
device::FloatToHalf(dst, args.data, data_size);
|
||||
device::FloatToHalf(dst, args.data, LongToSize(data_size));
|
||||
return true;
|
||||
} else if (mode == DataTypeTransMode::FROM_FLOAT16_TO_FLOAT) {
|
||||
device::HalfToFloat(dst, args.data, data_size);
|
||||
device::HalfToFloat(dst, args.data, LongToSize(data_size));
|
||||
return true;
|
||||
}
|
||||
auto iter = cast_kernel_map.find(mode);
|
||||
|
@ -577,7 +586,7 @@ ShapeVector DeviceShapeTransfer::C1HWNCOC0DeviceShape(const ShapeVector &shape,
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
ShapeVector DeviceShapeTransfer::FRAC_ZC04DeviceShape(const ShapeVector &shape, const TypeId &type) {
|
||||
ShapeVector DeviceShapeTransfer::FRAC_ZC04DeviceShape(const ShapeVector &shape, const TypeId &) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
|
@ -628,7 +637,7 @@ ShapeVector DeviceShapeTransfer::ChannelLastDeviceShape(const ShapeVector &shape
|
|||
axis[dim - 1] = 1;
|
||||
ShapeVector device_shape;
|
||||
(void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
|
||||
[&shape](int64_t n) { return shape[n]; });
|
||||
[&shape](size_t n) { return shape[n]; });
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
|
@ -656,7 +665,7 @@ ShapeVector DeviceShapeTransfer::FRAC_NZDeviceShape(const ShapeVector &shape, co
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
ShapeVector DeviceShapeTransfer::FRAC_ZN_LSTMDeviceShape(const ShapeVector &shape, const TypeId &type) {
|
||||
ShapeVector DeviceShapeTransfer::FRAC_ZN_LSTMDeviceShape(const ShapeVector &shape, const TypeId &) {
|
||||
ShapeVector device_shape;
|
||||
const int64_t lstm_ni = 4;
|
||||
const int64_t ni = 16;
|
||||
|
@ -971,15 +980,7 @@ bool FormatTransfer::TO_NCHW(const FormatArgs &args, void *result) {
|
|||
bool FormatTransfer::NCHW_TO_FRAC_Z(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
|
||||
return false;
|
||||
}
|
||||
auto size = Common4DCheck(args);
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
|
@ -1048,15 +1049,15 @@ bool FormatTransfer::NCHW_TO_FRAC_NZ(const FormatArgs &args, void *result) {
|
|||
return false;
|
||||
}
|
||||
auto times = hw_shape.at(0);
|
||||
auto h = hw_shape.at(1);
|
||||
auto w = hw_shape.at(2);
|
||||
auto h = hw_shape.at(hw_h);
|
||||
auto w = hw_shape.at(hw_w);
|
||||
auto hw = h * w;
|
||||
|
||||
auto shape_size = args.device_shape.size();
|
||||
auto w1 = args.device_shape[shape_size - 4];
|
||||
auto h1 = args.device_shape[shape_size - 3];
|
||||
auto h0 = args.device_shape[shape_size - 2];
|
||||
auto w0 = args.device_shape[shape_size - 1];
|
||||
auto w1 = args.device_shape[shape_size - fnz_w1];
|
||||
auto h1 = args.device_shape[shape_size - fnz_h1];
|
||||
auto h0 = args.device_shape[shape_size - fnz_h0];
|
||||
auto w0 = args.device_shape[shape_size - fnz_w0];
|
||||
auto h1h0w0 = h1 * h0 * w0;
|
||||
auto w1h1h0w0 = w1 * h1h0w0;
|
||||
auto num_w1 = w / w0;
|
||||
|
@ -1136,15 +1137,7 @@ bool FormatTransfer::NCHW_TO_FRAC_ZC04(const FormatArgs &args, void *result) {
|
|||
bool FormatTransfer::NCHW_TO_NC1HWC0(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
|
||||
return false;
|
||||
}
|
||||
auto size = Common4DCheck(args);
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
|
@ -1349,15 +1342,7 @@ bool FormatTransfer::NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result) {
|
|||
|
||||
bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
|
||||
return false;
|
||||
}
|
||||
auto size = Common4DCheck(args);
|
||||
auto n_dim = args.host_shape[kN];
|
||||
auto c_dim = args.host_shape[kC];
|
||||
auto h_dim = args.host_shape[kH];
|
||||
|
@ -1387,7 +1372,7 @@ bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *re
|
|||
if (dst_size == 0) {
|
||||
return true;
|
||||
}
|
||||
auto ret = memset_s(result, dst_size, 0, dst_size);
|
||||
auto ret = memset_s(result, LongToSize(dst_size), 0, LongToSize(dst_size));
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "memset failed";
|
||||
return false;
|
||||
|
@ -1426,15 +1411,7 @@ bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *re
|
|||
bool FormatTransfer::NC1HWC0_TO_NCHW(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
|
||||
return false;
|
||||
}
|
||||
auto size = Common4DCheck(args);
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
|
@ -1517,24 +1494,16 @@ bool FormatTransfer::C1HWNCOC0_TO_NCHW(const FormatArgs &args, void *result) {
|
|||
bool FormatTransfer::FRAC_Z_TO_NCHW(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
|
||||
return false;
|
||||
}
|
||||
auto size = Common4DCheck(args);
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto n0 = args.device_shape.at(1);
|
||||
auto ni = args.device_shape.at(2);
|
||||
auto c0 = args.device_shape.at(3);
|
||||
auto n0 = args.device_shape.at(fz_n0);
|
||||
auto ni = args.device_shape.at(fz_ni);
|
||||
auto c0 = args.device_shape.at(fz_c0);
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
|
@ -1590,15 +1559,15 @@ bool FormatTransfer::FRAC_NZ_TO_NCHW(const FormatArgs &args, void *result) {
|
|||
return false;
|
||||
}
|
||||
auto times = hw_shape.at(0);
|
||||
auto h = hw_shape.at(1);
|
||||
auto w = hw_shape.at(2);
|
||||
auto h = hw_shape.at(hw_h);
|
||||
auto w = hw_shape.at(hw_w);
|
||||
auto hw = h * w;
|
||||
|
||||
auto shape_size = args.device_shape.size();
|
||||
auto w1 = args.device_shape[shape_size - 4];
|
||||
auto h1 = args.device_shape[shape_size - 3];
|
||||
auto h0 = args.device_shape[shape_size - 2];
|
||||
auto w0 = args.device_shape[shape_size - 1];
|
||||
auto w1 = args.device_shape[shape_size - fnz_w1];
|
||||
auto h1 = args.device_shape[shape_size - fnz_h1];
|
||||
auto h0 = args.device_shape[shape_size - fnz_h0];
|
||||
auto w0 = args.device_shape[shape_size - fnz_w0];
|
||||
auto h1h0w0 = h1 * h0 * w0;
|
||||
auto w1h1h0w0 = w1 * h1h0w0;
|
||||
auto num_w1 = w / w0;
|
||||
|
@ -1747,6 +1716,18 @@ bool FormatTransfer::FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *re
|
|||
return NCHW_TO_FRAC_Z_WITH_GROPUS(args, result, false, groups);
|
||||
}
|
||||
|
||||
int64_t FormatTransfer::Common4DCheck(const FormatArgs &args) {
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(EXCEPTION) << "Invalid host shape, host shape dims:" << args.host_shape.size()
|
||||
<< ", expect dims:" << kNchwDims;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(EXCEPTION) << "Illegal dtype: " << args.src_data_type;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
// ######################## RANGE TRANS ########################
|
||||
RangePair ShapeRangeTransfer::GetRealRange(const RangePair &ori_range, const std::string &format, const TypeId &type) {
|
||||
const std::set<std::string> no_need_change = {kOpFormat_ND, kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NCDHW};
|
||||
|
@ -1844,9 +1825,10 @@ RangePair ShapeRangeTransfer::FRAC_ZRange(const RangePair &ori_range, const Type
|
|||
const std::pair<int64_t, int64_t> r0 = {ori_range[kH].first * ori_range[kW].first * cin16.first / cube,
|
||||
ori_range[kH].second * ori_range[kW].second * cin16.second / cube};
|
||||
const std::pair<int64_t, int64_t> r1 = {cout16.first / kNiSize, cout16.second / kNiSize};
|
||||
const std::pair<int64_t, int64_t> co = {kNiSize, kNiSize};
|
||||
dst_range.push_back(r0);
|
||||
dst_range.push_back(r1);
|
||||
dst_range.push_back({kNiSize, kNiSize});
|
||||
dst_range.push_back(co);
|
||||
dst_range.push_back(c0);
|
||||
return dst_range;
|
||||
}
|
||||
|
@ -1865,9 +1847,10 @@ RangePair ShapeRangeTransfer::FRAC_NZRange(const RangePair &ori_range, const Typ
|
|||
(ori_range[ori_size - 1].second - 1) / cube + 1};
|
||||
const std::pair<int64_t, int64_t> h1 = {(ori_range[ori_size - kDims2].first - 1) / kNiSize + 1,
|
||||
(ori_range[ori_size - kDims2].second - 1) / kNiSize + 1};
|
||||
const std::pair<int64_t, int64_t> co = {kNiSize, kNiSize};
|
||||
dst_range.push_back(w1);
|
||||
dst_range.push_back(h1);
|
||||
dst_range.push_back({kNiSize, kNiSize});
|
||||
dst_range.push_back(co);
|
||||
dst_range.push_back(c0);
|
||||
return dst_range;
|
||||
}
|
||||
|
|
|
@ -147,7 +147,7 @@ class DataTypeTransfer {
|
|||
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), DataTypeTransMode::FROM_BOOL_TO_UINT8},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), DataTypeTransMode::FROM_BOOL_TO_FLOAT16}};
|
||||
|
||||
bool CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode);
|
||||
bool CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) const;
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -251,6 +251,9 @@ class FormatTransfer {
|
|||
static bool FRAC_Z3D_TO_NCDHW(const FormatArgs &args, void *result);
|
||||
static bool NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result);
|
||||
static bool FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups);
|
||||
|
||||
// common check_func
|
||||
static int64_t Common4DCheck(const FormatArgs &args);
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -382,7 +385,7 @@ std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape) {
|
|||
shape_4d[kW] = shape[kH];
|
||||
break;
|
||||
case kNchwDims:
|
||||
std::copy(shape.begin(), shape.end(), shape_4d.begin());
|
||||
(void)std::copy(shape.begin(), shape.end(), shape_4d.begin());
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unexpected shape : " << shape;
|
||||
|
@ -449,13 +452,13 @@ template <typename T>
|
|||
std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node,
|
||||
size_t index, TypeId type, bool is_output = true) {
|
||||
ShapeVector shape_before;
|
||||
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before),
|
||||
[](T num) { return static_cast<int64_t>(num); });
|
||||
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before),
|
||||
[](T num) { return static_cast<int64_t>(num); });
|
||||
DeviceShapeTransfer deviceShapeTransfer;
|
||||
auto res = deviceShapeTransfer.GetDeviceShapeByFormat(shape_before, format, node, index, type, is_output);
|
||||
std::vector<T> out_shape;
|
||||
std::transform(res.begin(), res.end(), std::back_inserter(out_shape),
|
||||
[](int64_t num) { return static_cast<T>(num); });
|
||||
(void)std::transform(res.begin(), res.end(), std::back_inserter(out_shape),
|
||||
[](int64_t num) { return static_cast<T>(num); });
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
|
@ -463,13 +466,13 @@ template <typename T>
|
|||
std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, TypeId type,
|
||||
int64_t groups = 1, const ShapeVector &input_hidden_size = {kAlign16, kAlign16}) {
|
||||
ShapeVector shape_before;
|
||||
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before),
|
||||
[](T num) { return static_cast<int64_t>(num); });
|
||||
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_before),
|
||||
[](T num) { return static_cast<int64_t>(num); });
|
||||
DeviceShapeTransfer deviceShapeTransfer;
|
||||
auto res = deviceShapeTransfer.GetDeviceShapeByFormat(shape_before, format, type, groups, input_hidden_size);
|
||||
std::vector<T> out_shape;
|
||||
std::transform(res.begin(), res.end(), std::back_inserter(out_shape),
|
||||
[](int64_t num) { return static_cast<T>(num); });
|
||||
(void)std::transform(res.begin(), res.end(), std::back_inserter(out_shape),
|
||||
[](int64_t num) { return static_cast<T>(num); });
|
||||
return out_shape;
|
||||
}
|
||||
} // namespace trans
|
||||
|
|
Loading…
Reference in New Issue