forked from mindspore-Ecosystem/mindspore
!27845 using data type to get cube size stage 2
Merge pull request !27845 from liubuyu/SBB
This commit is contained in:
commit
4fd4014bc0
|
@ -611,5 +611,901 @@ ShapeVector DeviceShapeTransfer::GetAttrInputAndHiddenSize(const AnfNodePtr &nod
|
|||
input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrHiddenSize);
|
||||
return input_hidden_size;
|
||||
}
|
||||
|
||||
/**###################### DATA FORMAT TRANS ################################*/
|
||||
inline void SetData(int64_t size, bool pad_zero, int64_t src_idx, int64_t dst_idx, const FormatArgs &args,
|
||||
void *result) {
|
||||
switch (size) {
|
||||
case b1:
|
||||
static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
|
||||
break;
|
||||
case b2:
|
||||
static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
|
||||
break;
|
||||
case b4:
|
||||
static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
|
||||
break;
|
||||
case b8:
|
||||
static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Trans data not support size " << size;
|
||||
}
|
||||
}
|
||||
|
||||
bool FormatTransfer::TransDataByFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index,
|
||||
bool is_forward) {
|
||||
int64_t groups = 1;
|
||||
if (args.device_format == kOpFormat_FRAC_Z && node != nullptr) {
|
||||
groups = AnfAlgo::GetAttrGroups(node, index);
|
||||
}
|
||||
if (is_forward) {
|
||||
return TransDataForwardCore(args, result, groups);
|
||||
}
|
||||
return TransDataBackwordCore(args, result, groups);
|
||||
}
|
||||
|
||||
bool FormatTransfer::TransDataForwardCore(const FormatArgs &args, void *result, int64_t groups) {
|
||||
MS_LOG(DEBUG) << "Start trans format.";
|
||||
if (abstract::TypeIdSize(args.src_data_type) < 1) {
|
||||
MS_LOG(ERROR) << "Invalid datatype:" << args.src_data_type;
|
||||
return false;
|
||||
}
|
||||
if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
|
||||
return NCHW_TO_FRAC_Z_WITH_GROPUS(args, result, true, groups);
|
||||
}
|
||||
auto iter = format_trans_fp_map.find(args.device_format);
|
||||
if (iter == format_trans_fp_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
|
||||
}
|
||||
return iter->second(args, result);
|
||||
}
|
||||
|
||||
bool FormatTransfer::TransDataBackwordCore(const FormatArgs &args, void *result, int64_t groups) {
|
||||
MS_LOG(DEBUG) << "Start trans format.";
|
||||
if (abstract::TypeIdSize(args.src_data_type) < 1) {
|
||||
MS_LOG(ERROR) << "Invalid datatype..";
|
||||
return false;
|
||||
}
|
||||
if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
|
||||
return FRAC_Z_TO_NCHW_WITH_GROUPS(args, result, groups);
|
||||
}
|
||||
auto iter = format_trans_bp_map.find(args.device_format);
|
||||
if (iter == format_trans_bp_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
|
||||
}
|
||||
return iter->second(args, result);
|
||||
}
|
||||
|
||||
bool FormatTransfer::CheckArgs(const FormatArgs &args, int64_t *size) {
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(size);
|
||||
*size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (*size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * (*size);
|
||||
if (total_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::TransShapeToHW_NZ(const ShapeVector &host_shape, ShapeVector *hw_shape) {
|
||||
MS_EXCEPTION_IF_NULL(hw_shape);
|
||||
if (host_shape.empty()) {
|
||||
MS_LOG(ERROR) << "Size of vector is 0.";
|
||||
return false;
|
||||
}
|
||||
switch (host_shape.size()) {
|
||||
case 1:
|
||||
hw_shape->push_back(1);
|
||||
hw_shape->push_back(1);
|
||||
hw_shape->push_back(host_shape[0]);
|
||||
return true;
|
||||
default:
|
||||
auto size = host_shape.size();
|
||||
if (size < kDim2) {
|
||||
MS_LOG(ERROR) << "Illegal size.";
|
||||
return false;
|
||||
}
|
||||
int64_t times = 1;
|
||||
for (size_t i = 0; i != size - kDim2; i++) {
|
||||
times *= host_shape[i];
|
||||
}
|
||||
hw_shape->push_back(times);
|
||||
hw_shape->push_back(host_shape[size - kDim2]);
|
||||
hw_shape->push_back(host_shape[size - kDim1]);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
bool FormatTransfer::NCHW_TO_4D(const FormatArgs &args, void *result) {
|
||||
// trans nchw to NHWC or HWCN
|
||||
MS_LOG(DEBUG) << "Trans format from nchw to " << args.device_format;
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
int64_t size = 0;
|
||||
if (!CheckArgs(args, &size)) {
|
||||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
for (int64_t ni = 0; ni < n; ni++) {
|
||||
for (int64_t ci = 0; ci < c; ci++) {
|
||||
for (int64_t hi = 0; hi < h; hi++) {
|
||||
for (int64_t wi = 0; wi < w; wi++) {
|
||||
auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
|
||||
int64_t dst_idx = 0;
|
||||
if (args.device_format == kOpFormat_NHWC) {
|
||||
dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
|
||||
} else if (args.device_format == kOpFormat_HWCN) {
|
||||
dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
|
||||
}
|
||||
SetData(size, false, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::TO_NCHW(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format to nchw from " << args.device_format;
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
int64_t size = 0;
|
||||
if (!CheckArgs(args, &size)) {
|
||||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
for (int64_t ni = 0; ni < n; ni++) {
|
||||
for (int64_t ci = 0; ci < c; ci++) {
|
||||
for (int64_t hi = 0; hi < h; hi++) {
|
||||
for (int64_t wi = 0; wi < w; wi++) {
|
||||
auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
|
||||
int64_t src_idx = 0;
|
||||
if (args.device_format == kOpFormat_NHWC) {
|
||||
src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
|
||||
} else if (args.device_format == kOpFormat_HWCN) {
|
||||
src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
|
||||
}
|
||||
SetData(size, false, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NCHW_TO_FRAC_Z(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
auto c0 = GetCubeSizeByType(args.src_data_type);
|
||||
auto c1 = DivCeil(c, c0);
|
||||
auto hw = h * w;
|
||||
auto chw = c * hw;
|
||||
auto hwc0 = hw * c0;
|
||||
auto nchw = n * chw;
|
||||
|
||||
auto hf_cnt = DivCeil(n, kNiSize);
|
||||
auto vf_cnt = c1 * hw;
|
||||
auto fractal_ele_cnt = c0 * kNiSize;
|
||||
auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
|
||||
auto dst_size = total_ele_cnt * size;
|
||||
if (dst_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size."
|
||||
<< "dst size is :" << dst_size << "device size is :" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int64_t vfi = 0; vfi < vf_cnt; vfi++) {
|
||||
auto vf_base_i = vfi * hf_cnt; // vertical fractal matrix base index
|
||||
for (int64_t hfi = 0; hfi < hf_cnt; hfi++) {
|
||||
auto gfi = vf_base_i + hfi; // global fractal matrix index
|
||||
auto src_n_offset = hfi * chw * kNiSize;
|
||||
auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
|
||||
for (int64_t row = 0; row < c0; row++) {
|
||||
auto src_ci = vfi / hw * c0 + row;
|
||||
auto src_row_offset = src_f_offset + row * hw;
|
||||
for (int64_t col = 0; col < kNiSize; col++) {
|
||||
auto src_ni = hfi * kNiSize + col;
|
||||
auto src_idx = src_row_offset + chw * col;
|
||||
auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
|
||||
auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
|
||||
SetData(size, pad_zero, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NCHW_TO_FRAC_NZ(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
ShapeVector hw_shape;
|
||||
if (!TransShapeToHW_NZ(args.host_shape, &hw_shape)) {
|
||||
MS_LOG(ERROR) << "Trans shape failed..";
|
||||
return false;
|
||||
}
|
||||
if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
|
||||
MS_LOG(ERROR) << "Invalid shape size.";
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto dst_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (dst_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
auto times = hw_shape.at(0);
|
||||
auto h = hw_shape.at(1);
|
||||
auto w = hw_shape.at(2);
|
||||
auto hw = h * w;
|
||||
|
||||
auto shape_size = args.device_shape.size();
|
||||
auto w1 = args.device_shape[shape_size - 4];
|
||||
auto h1 = args.device_shape[shape_size - 3];
|
||||
auto h0 = args.device_shape[shape_size - 2];
|
||||
auto w0 = args.device_shape[shape_size - 1];
|
||||
auto h1h0w0 = h1 * h0 * w0;
|
||||
auto w1h1h0w0 = w1 * h1h0w0;
|
||||
auto num_w1 = w / w0;
|
||||
|
||||
for (int64_t times_idx = 0; times_idx < times; times_idx++) {
|
||||
auto times_head = times_idx * w1h1h0w0;
|
||||
auto src_times_head = times_idx * hw;
|
||||
for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
|
||||
auto h1h0_head = times_head + h1h0_idx * w0;
|
||||
auto src_h_head = src_times_head + h1h0_idx * w;
|
||||
for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
|
||||
for (int64_t i = 0; i < w0; ++i) {
|
||||
int64_t src_idx = src_h_head + w1_idx * w0 + i;
|
||||
int64_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
|
||||
SetData(size, false, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
auto w1_head = num_w1 * w0;
|
||||
for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
|
||||
auto src_w_idx = w1_head + w0_idx;
|
||||
int64_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
|
||||
int64_t src_idx = src_h_head + src_w_idx;
|
||||
SetData(size, false, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NCHW_TO_FRAC_ZC04(const FormatArgs &args, void *result) {
|
||||
// trans nchw to FracZc04
|
||||
MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
int64_t size = 0;
|
||||
if (!CheckArgs(args, &size)) {
|
||||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
auto cube = GetCubeSizeByType(args.src_data_type);
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
const int64_t c0 = 4;
|
||||
auto c1 = DivCeil(c, c0);
|
||||
auto hwc0 = h * w * c0;
|
||||
auto hwc = h * w * c;
|
||||
auto nhwc = n * h * w * c;
|
||||
auto n_cnt = DivCeil(n, kNiSize);
|
||||
auto v_cnt = DivCeil(h * w * c0 * c1, cube);
|
||||
int64_t dst_idx = 0;
|
||||
|
||||
for (int64_t vi = 0; vi < v_cnt; vi++) {
|
||||
for (int64_t ni = 0; ni < n_cnt; ni++) {
|
||||
for (int64_t col = 0; col < kNiSize; col++) {
|
||||
for (int64_t row = 0; row < kNiSize; row++) {
|
||||
int64_t cur_cube_n = kNiSize * ni + col;
|
||||
int64_t cur_cube_c1hwc0 = kNiSize * vi + row;
|
||||
auto desc_g = cur_cube_n / n;
|
||||
auto desc_n = cur_cube_n % n;
|
||||
auto desc_c1 = cur_cube_c1hwc0 / hwc0;
|
||||
auto desc_c0 = cur_cube_c1hwc0 % c0;
|
||||
auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
|
||||
auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
|
||||
auto c_idx = desc_c1 * c0 + desc_c0;
|
||||
auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
|
||||
auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
|
||||
SetData(size, pad_zero, src_idx, dst_idx, args, result);
|
||||
dst_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NCHW_TO_NC1HWC0(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
auto c0 = GetCubeSizeByType(args.src_data_type);
|
||||
if (args.device_format == kOpFormat_NC1HWC0_C04) {
|
||||
c0 = kCubeSize_C04;
|
||||
}
|
||||
auto c1 = DivCeil(c, c0);
|
||||
auto hw = h * w;
|
||||
auto chw = c * hw;
|
||||
auto c1hwc0 = c1 * hw * c0;
|
||||
auto wc0 = w * c0;
|
||||
|
||||
for (int64_t n_idx = 0; n_idx < n; n_idx++) {
|
||||
int64_t n_head_addr = n_idx * c1hwc0;
|
||||
for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) {
|
||||
int64_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
|
||||
for (int64_t h_idx = 0; h_idx < h; h_idx++) {
|
||||
int64_t h_head_addr = c1_head_addr + h_idx * wc0;
|
||||
for (int64_t w_idx = 0; w_idx < w; w_idx++) {
|
||||
int64_t w_head_addr = h_head_addr + w_idx * c0;
|
||||
for (int64_t c0_idx = 0; c0_idx < c0; c0_idx++) {
|
||||
int64_t dst_idx = c0_idx + w_head_addr;
|
||||
int64_t c_idx = c0_idx + c1_idx * c0;
|
||||
int64_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
|
||||
auto pad_zero = c_idx >= c;
|
||||
SetData(size, pad_zero, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NCHW_TO_NC1HWC04(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
|
||||
return NCHW_TO_NC1HWC0(args, result);
|
||||
}
|
||||
|
||||
bool FormatTransfer::NCHW_TO_C1HWNCOC0(const FormatArgs &args, void *result) {
|
||||
// trans nchw to c1hwncoc0
|
||||
MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
int64_t size = 0;
|
||||
if (!CheckArgs(args, &size)) {
|
||||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
const int co_idx = 4;
|
||||
const int c0_idx = 5;
|
||||
auto c1 = args.device_shape[0];
|
||||
auto co = args.device_shape[co_idx];
|
||||
auto c0 = args.device_shape[c0_idx];
|
||||
|
||||
for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
|
||||
for (int64_t h_i = 0; h_i < h; h_i++) {
|
||||
for (int64_t w_i = 0; w_i < w; w_i++) {
|
||||
for (int64_t n_i = 0; n_i < n; n_i++) {
|
||||
for (int64_t co_i = 0; co_i < co; co_i++) {
|
||||
for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
|
||||
int64_t dst_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
|
||||
co_i * c0 + c0_i;
|
||||
int64_t c_i = c0_i + c1_i * c0;
|
||||
int64_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
|
||||
auto pad_zero = !(c_i < c && c0_i == co_i);
|
||||
SetData(size, pad_zero, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NCDHW_TO_NDC1HWC0(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
||||
if (args.host_shape.size() != kNcdhw) {
|
||||
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto n = args.host_shape[N_ncdhw];
|
||||
auto c = args.host_shape[C_ncdhw];
|
||||
auto d = args.host_shape[D_ncdhw];
|
||||
auto h = args.host_shape[H_ncdhw];
|
||||
auto w = args.host_shape[W_ncdhw];
|
||||
auto c0 = GetCubeSizeByType(args.src_data_type);
|
||||
auto c1 = DivCeil(c, c0);
|
||||
const int64_t cdhw = c * d * h * w;
|
||||
const int64_t dhw = d * h * w;
|
||||
const int64_t hw = h * w;
|
||||
const int64_t dc1hwc0 = d * c1 * h * w * c0;
|
||||
const int64_t c1hwc0 = c1 * h * w * c0;
|
||||
const int64_t hwc0 = h * w * c0;
|
||||
const int64_t wc0 = w * c0;
|
||||
|
||||
for (int64_t n_i = 0; n_i < n; n_i++) {
|
||||
int64_t n_head = n_i * dc1hwc0;
|
||||
for (int64_t d_i = 0; d_i < d; d_i++) {
|
||||
int64_t d_head = n_head + d_i * c1hwc0;
|
||||
for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
|
||||
int64_t c1_head = d_head + c1_i * hwc0;
|
||||
for (int64_t h_i = 0; h_i < h; h_i++) {
|
||||
int64_t h_head = c1_head + h_i * wc0;
|
||||
for (int64_t w_i = 0; w_i < w; w_i++) {
|
||||
int64_t w_head = h_head + w_i * c0;
|
||||
for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
|
||||
int64_t dst_i = c0_i + w_head;
|
||||
int64_t c_i = c0_i + c1_i * c0;
|
||||
int64_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
|
||||
auto pad_zero = c_i >= c;
|
||||
SetData(size, pad_zero, src_i, dst_i, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
||||
if (args.host_shape.size() != kNcdhw) {
|
||||
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto n = args.host_shape[N_ncdhw];
|
||||
auto c = args.host_shape[C_ncdhw];
|
||||
auto d = args.host_shape[D_ncdhw];
|
||||
auto h = args.host_shape[H_ncdhw];
|
||||
auto w = args.host_shape[W_ncdhw];
|
||||
|
||||
auto n1n0 = DivCeil(n, kNiSize) * kNiSize;
|
||||
auto c0 = GetCubeSizeByType(args.src_data_type);
|
||||
auto c1 = DivCeil(c, c0);
|
||||
auto hw = h * w;
|
||||
auto dhw = d * hw;
|
||||
auto cdhw = c * dhw;
|
||||
auto n1n0c0 = n1n0 * c0;
|
||||
auto wn1n0c0 = w * n1n0c0;
|
||||
auto hwn1n0c0 = h * wn1n0c0;
|
||||
auto c1hwn1n0c0 = c1 * hwn1n0c0;
|
||||
|
||||
for (int64_t d_i = 0; d_i < d; d_i++) {
|
||||
for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
|
||||
for (int64_t h_i = 0; h_i < h; h_i++) {
|
||||
for (int64_t w_i = 0; w_i < w; w_i++) {
|
||||
for (int64_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
|
||||
for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
|
||||
auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
|
||||
// ncdhw
|
||||
int64_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
|
||||
auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
|
||||
SetData(size, pad_zero, src_i, dst_i, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype";
|
||||
return false;
|
||||
}
|
||||
auto n_dim = args.host_shape[kN];
|
||||
auto c_dim = args.host_shape[kC];
|
||||
auto h_dim = args.host_shape[kH];
|
||||
auto w_dim = args.host_shape[kW];
|
||||
auto d_dim = 1;
|
||||
auto cin_ori = c_dim;
|
||||
if (groups <= 0) {
|
||||
MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
|
||||
}
|
||||
// cppcheck-suppress *
|
||||
auto cout_ori = n_dim / groups;
|
||||
if (cin_ori == 0 || cout_ori == 0) {
|
||||
MS_LOG(ERROR) << "cin_ori, cout_ori must not equal to 0";
|
||||
return false;
|
||||
}
|
||||
auto cube_k = GetCubeSizeByType(args.src_data_type);
|
||||
auto e_mult = std::min(Lcm(Lcm(cin_ori, cube_k) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), groups);
|
||||
if (e_mult == 0) {
|
||||
MS_LOG(EXCEPTION) << "The value of e_mult should be greater than 0, but got " << e_mult;
|
||||
}
|
||||
auto cin_opt = DivCeil(e_mult * cin_ori, cube_k) * cube_k;
|
||||
auto cout_opt = DivCeil(e_mult * cout_ori, kCubeSize) * kCubeSize;
|
||||
// cppcheck-suppress *
|
||||
auto c1_dim = cin_opt / cube_k;
|
||||
auto dst_size =
|
||||
to_device ? abstract::ShapeSize(args.device_shape) * size : abstract::ShapeSize(args.host_shape) * size;
|
||||
if (dst_size == 0) {
|
||||
return true;
|
||||
}
|
||||
auto ret = memset_s(result, dst_size, 0, dst_size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "memset failed";
|
||||
return false;
|
||||
}
|
||||
for (int64_t g = 0; g < groups; ++g) {
|
||||
for (int64_t d = 0; d < d_dim; ++d) {
|
||||
for (int64_t c = 0; c < c_dim; ++c) {
|
||||
for (int64_t h = 0; h < h_dim; ++h) {
|
||||
for (int64_t w = 0; w < w_dim; ++w) {
|
||||
for (int64_t n = 0; n < cout_ori; ++n) {
|
||||
int64_t e_val = g % e_mult;
|
||||
int64_t dst_ci = e_val * cin_ori + c;
|
||||
int64_t dst_co = e_val * cout_ori + n;
|
||||
int64_t src_co = g * cout_ori + n;
|
||||
int64_t temporary = dst_ci % cube_k;
|
||||
int64_t dev_idx = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * cube_k +
|
||||
d * c1_dim * h_dim * w_dim * cout_opt * cube_k +
|
||||
(dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k + h * w_dim * cout_opt * cube_k +
|
||||
w * cout_opt * cube_k + dst_co * cube_k + temporary;
|
||||
int64_t hst_idx =
|
||||
src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim + h * w_dim + w;
|
||||
if (to_device) {
|
||||
SetData(size, false, hst_idx, dev_idx, args, result);
|
||||
} else {
|
||||
SetData(size, false, dev_idx, hst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NC1HWC0_TO_NCHW(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
auto c1 = args.device_shape[1];
|
||||
auto c0 = args.device_shape[4];
|
||||
|
||||
auto hw = h * w;
|
||||
auto chw = c * hw;
|
||||
auto wc0 = w * c0;
|
||||
auto hwc0 = h * wc0;
|
||||
auto c1hwc0 = c1 * hwc0;
|
||||
|
||||
for (int64_t n_idx = 0; n_idx < n; n_idx++) {
|
||||
int64_t n_head_addr = n_idx * chw;
|
||||
for (int64_t c_idx = 0; c_idx < c; c_idx++) {
|
||||
int64_t c_head_addr = n_head_addr + c_idx * hw;
|
||||
for (int64_t h_idx = 0; h_idx < h; h_idx++) {
|
||||
int64_t h_head_addr = c_head_addr + h_idx * w;
|
||||
for (int64_t w_idx = 0; w_idx < w; w_idx++) {
|
||||
int64_t dst_idx = h_head_addr + w_idx;
|
||||
int64_t c1_idx = c_idx / c0;
|
||||
int64_t c0_idx = c_idx % c0;
|
||||
int64_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
|
||||
SetData(size, false, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NC1HWC04_TO_NCHW(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
|
||||
return NC1HWC0_TO_NCHW(args, result);
|
||||
}
|
||||
|
||||
bool FormatTransfer::C1HWNCOC0_TO_NCHW(const FormatArgs &args, void *result) {
|
||||
// trans c1hwncoc0 to nchw
|
||||
MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
int64_t size = 0;
|
||||
if (!CheckArgs(args, &size)) {
|
||||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
const int co_idx = 4;
|
||||
const int c0_idx = 5;
|
||||
auto co = args.device_shape[co_idx];
|
||||
auto c0 = args.device_shape[c0_idx];
|
||||
auto cube_k = GetCubeSizeByType(args.src_data_type);
|
||||
for (int64_t n_i = 0; n_i < n; n_i++) {
|
||||
for (int64_t c_i = 0; c_i < c; c_i++) {
|
||||
for (int64_t h_i = 0; h_i < h; h_i++) {
|
||||
for (int64_t w_i = 0; w_i < w; w_i++) {
|
||||
int64_t dst_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
|
||||
int64_t c1_i = c_i / cube_k;
|
||||
int64_t c0_i = c_i % cube_k;
|
||||
int64_t co_i = c0_i;
|
||||
int64_t src_idx =
|
||||
c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
|
||||
SetData(size, false, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::FRAC_Z_TO_NCHW(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
if (args.host_shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto n0 = args.device_shape.at(1);
|
||||
auto ni = args.device_shape.at(2);
|
||||
auto c0 = args.device_shape.at(3);
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
auto nc = ni * n0;
|
||||
auto ncc0 = nc * c0;
|
||||
auto wncc0 = w * ncc0;
|
||||
auto hwncc0 = h * wncc0;
|
||||
auto hw = h * w;
|
||||
auto chw = c * hw;
|
||||
|
||||
for (int64_t n_idx = 0; n_idx < n; n_idx++) {
|
||||
int64_t n_head_addr = n_idx * chw;
|
||||
for (int64_t c_idx = 0; c_idx < c; c_idx++) {
|
||||
int64_t c_head_addr = n_head_addr + c_idx * hw;
|
||||
for (int64_t h_idx = 0; h_idx < h; h_idx++) {
|
||||
int64_t h_head_addr = c_head_addr + h_idx * w;
|
||||
for (int64_t w_idx = 0; w_idx < w; w_idx++) {
|
||||
auto dst_idx = h_head_addr + w_idx;
|
||||
auto c1_idx = c_idx / c0;
|
||||
auto c0_idx = c_idx % c0;
|
||||
auto nc_idx = n_idx;
|
||||
auto src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
|
||||
SetData(size, false, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::FRAC_NZ_TO_NCHW(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
ShapeVector hw_shape;
|
||||
if (!TransShapeToHW_NZ(args.host_shape, &hw_shape)) {
|
||||
MS_LOG(ERROR) << "Trans shape failed..";
|
||||
return false;
|
||||
}
|
||||
if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
|
||||
MS_LOG(ERROR) << "Invalid shape size.";
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto dst_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (dst_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
auto times = hw_shape.at(0);
|
||||
auto h = hw_shape.at(1);
|
||||
auto w = hw_shape.at(2);
|
||||
auto hw = h * w;
|
||||
|
||||
auto shape_size = args.device_shape.size();
|
||||
auto w1 = args.device_shape[shape_size - 4];
|
||||
auto h1 = args.device_shape[shape_size - 3];
|
||||
auto h0 = args.device_shape[shape_size - 2];
|
||||
auto w0 = args.device_shape[shape_size - 1];
|
||||
auto h1h0w0 = h1 * h0 * w0;
|
||||
auto w1h1h0w0 = w1 * h1h0w0;
|
||||
auto num_w1 = w / w0;
|
||||
|
||||
for (int64_t times_idx = 0; times_idx < times; times_idx++) {
|
||||
auto times_head = times_idx * w1h1h0w0;
|
||||
auto src_times_head = times_idx * hw;
|
||||
for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
|
||||
auto h1h0_head = times_head + h1h0_idx * w0;
|
||||
auto src_h_head = src_times_head + h1h0_idx * w;
|
||||
for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
|
||||
for (int64_t i = 0; i < w0; ++i) {
|
||||
int64_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
|
||||
int64_t dst_idx = src_h_head + w1_idx * w0 + i;
|
||||
SetData(size, false, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
auto w1_head = num_w1 * w0;
|
||||
for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
|
||||
auto src_w_idx = w1_head + w0_idx;
|
||||
int64_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
|
||||
int64_t dst_idx = src_h_head + src_w_idx;
|
||||
SetData(size, false, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result) {
|
||||
MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
||||
if (args.host_shape.size() != kNcdhw) {
|
||||
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
||||
return false;
|
||||
}
|
||||
auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
||||
if (total_size != SizeToLong(args.device_size)) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[N_ncdhw];
|
||||
auto c = args.host_shape[C_ncdhw];
|
||||
auto d = args.host_shape[D_ncdhw];
|
||||
auto h = args.host_shape[H_ncdhw];
|
||||
auto w = args.host_shape[W_ncdhw];
|
||||
auto c1 = args.device_shape[C1_ndc1hwc0];
|
||||
auto c0 = args.device_shape[C0_ndc1hwc0];
|
||||
const int64_t cdhw = c * d * h * w;
|
||||
const int64_t dhw = d * h * w;
|
||||
const int64_t hw = h * w;
|
||||
const int64_t dc1hwc0 = d * c1 * h * w * c0;
|
||||
const int64_t c1hwc0 = c1 * h * w * c0;
|
||||
const int64_t hwc0 = h * w * c0;
|
||||
const int64_t wc0 = w * c0;
|
||||
|
||||
for (int64_t n_i = 0; n_i < n; n_i++) {
|
||||
int64_t n_head = n_i * cdhw;
|
||||
for (int64_t c_i = 0; c_i < c; c_i++) {
|
||||
int64_t c_head = n_head + c_i * dhw;
|
||||
for (int64_t d_i = 0; d_i < d; d_i++) {
|
||||
int64_t d_head = c_head + d_i * hw;
|
||||
for (int64_t h_i = 0; h_i < h; h_i++) {
|
||||
int64_t h_head = d_head + h_i * w;
|
||||
for (int64_t w_i = 0; w_i < w; w_i++) {
|
||||
int64_t dst_i = h_head + w_i;
|
||||
int64_t c1_i = c_i / c0;
|
||||
int64_t c0_i = c_i % c0;
|
||||
auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
|
||||
SetData(size, false, src_idx, dst_i, args, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FormatTransfer::FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups) {
|
||||
MS_LOG(DEBUG) << "Trans format from frac_z to nchw with groups=" << groups;
|
||||
return NCHW_TO_FRAC_Z_WITH_GROPUS(args, result, false, groups);
|
||||
}
|
||||
} // namespace TEMP
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -190,6 +190,68 @@ class DeviceShapeTransfer {
|
|||
const ShapeVector &input_hidden_size = {kAlign16, kAlign16});
|
||||
};
|
||||
|
||||
/**
|
||||
* Trans data at host according to the node's format
|
||||
* */
|
||||
class FormatTransfer {
|
||||
public:
|
||||
FormatTransfer() = default;
|
||||
~FormatTransfer() = default;
|
||||
|
||||
bool TransDataByFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index, bool is_forward);
|
||||
bool TransDataForwardCore(const FormatArgs &args, void *result, int64_t groups = 1);
|
||||
bool TransDataBackwordCore(const FormatArgs &args, void *result, int64_t groups = 1);
|
||||
|
||||
private:
|
||||
using TransferCore = std::function<bool(const FormatArgs &, void *)>;
|
||||
// fp map
|
||||
const std::map<std::string, TransferCore> format_trans_fp_map = {{kOpFormat_HWCN, NCHW_TO_4D},
|
||||
{kOpFormat_NHWC, NCHW_TO_4D},
|
||||
{kOpFormat_FRAC_Z, NCHW_TO_FRAC_Z},
|
||||
{kOpFormat_FRAC_NZ, NCHW_TO_FRAC_NZ},
|
||||
{kOpFormat_NC1HWC0, NCHW_TO_NC1HWC0},
|
||||
{kOpFormat_NDC1HWC0, NCDHW_TO_NDC1HWC0},
|
||||
{kOpFormat_C1HWNCoC0, NCHW_TO_C1HWNCOC0},
|
||||
{kOpFormat_NC1HWC0_C04, NCHW_TO_NC1HWC04},
|
||||
{kOpFormat_FRACTAL_Z_3D, NCDHW_TO_FRAC_Z3D},
|
||||
{kOpFormat_FRACTAL_Z_C04, NCHW_TO_FRAC_ZC04}};
|
||||
// bp map
|
||||
const std::map<std::string, TransferCore> format_trans_bp_map = {{kOpFormat_HWCN, TO_NCHW},
|
||||
{kOpFormat_NHWC, TO_NCHW},
|
||||
{kOpFormat_FRAC_Z, FRAC_Z_TO_NCHW},
|
||||
{kOpFormat_FRAC_NZ, FRAC_NZ_TO_NCHW},
|
||||
{kOpFormat_NC1HWC0, NC1HWC0_TO_NCHW},
|
||||
{kOpFormat_NDC1HWC0, NDC1HWC0_TO_NCDHW},
|
||||
{kOpFormat_C1HWNCoC0, C1HWNCOC0_TO_NCHW},
|
||||
{kOpFormat_NC1HWC0_C04, NC1HWC04_TO_NCHW},
|
||||
{kOpFormat_FRACTAL_Z_3D, FRAC_Z3D_TO_NCDHW}};
|
||||
|
||||
static bool CheckArgs(const FormatArgs &args, int64_t *size);
|
||||
static bool TransShapeToHW_NZ(const ShapeVector &host_shape, ShapeVector *hw_shape);
|
||||
// HOST TO DEVICE
|
||||
static bool NCHW_TO_4D(const FormatArgs &args, void *result);
|
||||
static bool NCHW_TO_FRAC_Z(const FormatArgs &args, void *result);
|
||||
static bool NCHW_TO_NC1HWC0(const FormatArgs &args, void *result);
|
||||
static bool NCHW_TO_FRAC_NZ(const FormatArgs &args, void *result);
|
||||
static bool NCHW_TO_NC1HWC04(const FormatArgs &args, void *result);
|
||||
static bool NCHW_TO_FRAC_ZC04(const FormatArgs &args, void *result);
|
||||
static bool NCHW_TO_C1HWNCOC0(const FormatArgs &args, void *result);
|
||||
static bool NCDHW_TO_NDC1HWC0(const FormatArgs &args, void *result);
|
||||
static bool NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result);
|
||||
static bool NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *result, bool to_device, int64_t groups);
|
||||
|
||||
// DEVICE TO HOST
|
||||
static bool TO_NCHW(const FormatArgs &args, void *result);
|
||||
static bool FRAC_Z_TO_NCHW(const FormatArgs &args, void *result);
|
||||
static bool FRAC_NZ_TO_NCHW(const FormatArgs &args, void *result);
|
||||
static bool NC1HWC0_TO_NCHW(const FormatArgs &args, void *result);
|
||||
static bool NC1HWC04_TO_NCHW(const FormatArgs &args, void *result);
|
||||
static bool C1HWNCOC0_TO_NCHW(const FormatArgs &args, void *result);
|
||||
static bool FRAC_Z3D_TO_NCDHW(const FormatArgs &args, void *result);
|
||||
static bool NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result);
|
||||
static bool FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups);
|
||||
};
|
||||
|
||||
/**
|
||||
* Padding shape to 5D by default mode
|
||||
* */
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <utility>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/any.h"
|
||||
#include "utils/misc.h"
|
||||
|
@ -46,7 +47,10 @@ AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec);
|
|||
ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy);
|
||||
|
||||
MS_CORE_API size_t TypeIdSize(const TypeId data_type);
|
||||
size_t ShapeSize(const std::vector<size_t> &shape);
|
||||
// Returns the product of all dimensions in `shape`, computed in type T.
// An empty shape yields 1.
template <typename T>
T ShapeSize(const std::vector<T> &shape) {
  T product = static_cast<T>(1);
  for (const auto &dim : shape) {
    product = product * dim;
  }
  return product;
}
|
||||
|
||||
// Check dynamic shape routine
|
||||
void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
|
||||
|
|
Loading…
Reference in New Issue