add format transfer for fracz_c04 and nc1hwc04

This commit is contained in:
liubuyu 2020-05-14 10:54:22 +08:00
parent a2a3b1c6c5
commit 2fe76bcfbd
2 changed files with 70 additions and 2 deletions

View File

@ -378,8 +378,8 @@ std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
} }
std::vector<size_t> device_shape; std::vector<size_t> device_shape;
size_t c0 = 4; size_t c0 = 4;
size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize); auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize);
size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize); auto no = DivCeil(shape.at(0), kCubeSize);
device_shape.push_back(first_dim); device_shape.push_back(first_dim);
device_shape.push_back(no); device_shape.push_back(no);
device_shape.push_back(kCubeSize); device_shape.push_back(kCubeSize);
@ -495,6 +495,10 @@ bool TransFormat(const FormatArgs &args, void *result) {
return NchwToNc1hwc0(args, result); return NchwToNc1hwc0(args, result);
} else if (args.device_format == kOpFormat_C1HWNCoC0) { } else if (args.device_format == kOpFormat_C1HWNCoC0) {
return NchwToC1hwncoc0(args, result); return NchwToC1hwncoc0(args, result);
} else if (args.device_format == kOpFormat_FRACTAL_Z_C04) {
return NchwToFracZc04(args, result);
} else if (args.device_format == kOpFormat_NC1HWC0_C04) {
return NchwToNc1hwc04(args, result);
} }
return true; return true;
} }
@ -513,6 +517,8 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
return Nc1hwc0ToNchw(args, result); return Nc1hwc0ToNchw(args, result);
} else if (args.device_format == kOpFormat_C1HWNCoC0) { } else if (args.device_format == kOpFormat_C1HWNCoC0) {
return C1hwncoc0ToNchw(args, result); return C1hwncoc0ToNchw(args, result);
} else if (args.device_format == kOpFormat_NC1HWC0_C04) {
return Nc1hwc04ToNchw(args, result);
} }
return true; return true;
} }
@ -632,6 +638,65 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
return true; return true;
} }
bool NchwToFracZc04(const FormatArgs &args, void *result) {
// trans nchw to FracZc04
MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
MS_EXCEPTION_IF_NULL(result);
size_t size = 0;
size_t total_size = 0;
if (!CheckArgs(args, &size, &total_size)) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
size_t cube = kCubeSize;
size_t n = args.host_shape[0];
size_t c = args.host_shape[1];
size_t h = args.host_shape[2];
size_t w = args.host_shape[3];
size_t c0 = 4;
size_t c1 = DivCeil(c, c0);
size_t hwc0 = h * w * c0;
size_t hwc = h * w * c;
size_t nhwc = n * h * w * c;
size_t n_cnt = DivCeil(n, cube);
size_t v_cnt = DivCeil(h * w * c0 * c1, cube);
size_t dst_idx = 0;
for (size_t vi = 0; vi < v_cnt; vi++) {
for (size_t ni = 0; ni < n_cnt; ni++) {
for (size_t col = 0; col < cube; col++) {
for (size_t row = 0; row < cube; row++) {
size_t cur_cube_n = cube * ni + col;
size_t cur_cube_c1hwc0 = cube * vi + row;
auto desc_g = cur_cube_n / n;
auto desc_n = cur_cube_n % n;
auto desc_c1 = cur_cube_c1hwc0 / hwc0;
auto desc_c0 = cur_cube_c1hwc0 % c0;
auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
auto c_idx = desc_c1 * c0 + desc_c0;
auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
SetDataBysize(size, pad_zero);
dst_idx++;
}
}
}
}
return true;
}
bool NchwToNc1hwc04(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
return NchwToNc1hwc0(args, result);
}
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
return Nc1hwc0ToNchw(args, result);
}
bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) { bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
MS_EXCEPTION_IF_NULL(hw_shape); MS_EXCEPTION_IF_NULL(hw_shape);
if (host_shape.empty()) { if (host_shape.empty()) {

View File

@ -64,11 +64,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
bool NchwToFracZ(const FormatArgs &args, void *result); bool NchwToFracZ(const FormatArgs &args, void *result);
bool NchwToFracNz(const FormatArgs &args, void *result); bool NchwToFracNz(const FormatArgs &args, void *result);
bool NchwToNc1hwc0(const FormatArgs &args, void *result); bool NchwToNc1hwc0(const FormatArgs &args, void *result);
bool NchwToFracZc04(const FormatArgs &args, void *result);
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result); bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
// device to host // device to host
bool FracZToNchw(const FormatArgs &args, void *result); bool FracZToNchw(const FormatArgs &args, void *result);
bool FracNzToNchw(const FormatArgs &args, void *result); bool FracNzToNchw(const FormatArgs &args, void *result);
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result); bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result); bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
} // namespace trans } // namespace trans
} // namespace mindspore } // namespace mindspore