diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 954261f912e..731c71c7811 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -378,8 +378,8 @@ std::vector FracZc04DeviceShape(const std::vector &shape) { } std::vector device_shape; size_t c0 = 4; - size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize); - size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize); + auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize); + auto no = DivCeil(shape.at(0), kCubeSize); device_shape.push_back(first_dim); device_shape.push_back(no); device_shape.push_back(kCubeSize); @@ -495,6 +495,10 @@ bool TransFormat(const FormatArgs &args, void *result) { return NchwToNc1hwc0(args, result); } else if (args.device_format == kOpFormat_C1HWNCoC0) { return NchwToC1hwncoc0(args, result); + } else if (args.device_format == kOpFormat_FRACTAL_Z_C04) { + return NchwToFracZc04(args, result); + } else if (args.device_format == kOpFormat_NC1HWC0_C04) { + return NchwToNc1hwc04(args, result); } return true; } @@ -513,6 +517,8 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { return Nc1hwc0ToNchw(args, result); } else if (args.device_format == kOpFormat_C1HWNCoC0) { return C1hwncoc0ToNchw(args, result); + } else if (args.device_format == kOpFormat_NC1HWC0_C04) { + return Nc1hwc04ToNchw(args, result); } return true; } @@ -632,6 +638,65 @@ bool FracZToNchw(const FormatArgs &args, void *result) { return true; } +bool NchwToFracZc04(const FormatArgs &args, void *result) { + // trans nchw to FracZc04 + MS_LOG(DEBUG) << "Trans format from nchw to FracZc04."; + MS_EXCEPTION_IF_NULL(result); + size_t size = 0; + size_t total_size = 0; + if (!CheckArgs(args, &size, &total_size)) { + MS_LOG(ERROR) << "Check args failed."; + return false; + } + size_t cube = kCubeSize; + size_t n = args.host_shape[0]; + size_t c = args.host_shape[1]; + size_t h = args.host_shape[2]; + size_t w = args.host_shape[3]; + + size_t c0 = 4; + size_t c1 = DivCeil(c, c0); + size_t hwc0 = h * w * c0; + size_t hwc = h * w * c; + size_t nhwc = n * h * w * c; + + size_t n_cnt = DivCeil(n, cube); + size_t v_cnt = DivCeil(h * w * c0 * c1, cube); + size_t dst_idx = 0; + + for (size_t vi = 0; vi < v_cnt; vi++) { + for (size_t ni = 0; ni < n_cnt; ni++) { + for (size_t col = 0; col < cube; col++) { + for (size_t row = 0; row < cube; row++) { + size_t cur_cube_n = cube * ni + col; + size_t cur_cube_c1hwc0 = cube * vi + row; + auto desc_g = cur_cube_n / n; + auto desc_n = cur_cube_n % n; + auto desc_c1 = cur_cube_c1hwc0 / hwc0; + auto desc_c0 = cur_cube_c1hwc0 % c0; + auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0); + auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0; + auto c_idx = desc_c1 * c0 + desc_c0; + auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w; + auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c; + SetDataBysize(size, pad_zero); + dst_idx++; + } + } + } + } + return true; +} + +bool NchwToNc1hwc04(const FormatArgs &args, void *result) { + MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04."; + return NchwToNc1hwc0(args, result); +} +bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) { + MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw."; + return Nc1hwc0ToNchw(args, result); +} + bool TransShapeToNz(const std::vector &host_shape, std::vector *hw_shape) { MS_EXCEPTION_IF_NULL(hw_shape); if (host_shape.empty()) { diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index 0593466c381..e15b95e6d3e 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -64,11 +64,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result); bool NchwToFracZ(const FormatArgs &args, void *result); bool NchwToFracNz(const FormatArgs &args, void *result); bool NchwToNc1hwc0(const FormatArgs &args, void *result); +bool NchwToFracZc04(const FormatArgs &args, void *result); +bool NchwToNc1hwc04(const FormatArgs &args, void *result); bool NchwToC1hwncoc0(const FormatArgs &args, void *result); // device to host bool FracZToNchw(const FormatArgs &args, void *result); bool FracNzToNchw(const FormatArgs &args, void *result); bool Nc1hwc0ToNchw(const FormatArgs &args, void *result); +bool Nc1hwc04ToNchw(const FormatArgs &args, void *result); bool C1hwncoc0ToNchw(const FormatArgs &args, void *result); } // namespace trans } // namespace mindspore