forked from mindspore-Ecosystem/mindspore
add format transfer for fracz_c04 and nc1hwc04
This commit is contained in:
parent
a2a3b1c6c5
commit
2fe76bcfbd
|
@ -378,8 +378,8 @@ std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
|
||||||
}
|
}
|
||||||
std::vector<size_t> device_shape;
|
std::vector<size_t> device_shape;
|
||||||
size_t c0 = 4;
|
size_t c0 = 4;
|
||||||
size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize);
|
auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize);
|
||||||
size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize);
|
auto no = DivCeil(shape.at(0), kCubeSize);
|
||||||
device_shape.push_back(first_dim);
|
device_shape.push_back(first_dim);
|
||||||
device_shape.push_back(no);
|
device_shape.push_back(no);
|
||||||
device_shape.push_back(kCubeSize);
|
device_shape.push_back(kCubeSize);
|
||||||
|
@ -495,6 +495,10 @@ bool TransFormat(const FormatArgs &args, void *result) {
|
||||||
return NchwToNc1hwc0(args, result);
|
return NchwToNc1hwc0(args, result);
|
||||||
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
|
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
|
||||||
return NchwToC1hwncoc0(args, result);
|
return NchwToC1hwncoc0(args, result);
|
||||||
|
} else if (args.device_format == kOpFormat_FRACTAL_Z_C04) {
|
||||||
|
return NchwToFracZc04(args, result);
|
||||||
|
} else if (args.device_format == kOpFormat_NC1HWC0_C04) {
|
||||||
|
return NchwToNc1hwc04(args, result);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -513,6 +517,8 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
|
||||||
return Nc1hwc0ToNchw(args, result);
|
return Nc1hwc0ToNchw(args, result);
|
||||||
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
|
} else if (args.device_format == kOpFormat_C1HWNCoC0) {
|
||||||
return C1hwncoc0ToNchw(args, result);
|
return C1hwncoc0ToNchw(args, result);
|
||||||
|
} else if (args.device_format == kOpFormat_NC1HWC0_C04) {
|
||||||
|
return Nc1hwc04ToNchw(args, result);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -632,6 +638,65 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool NchwToFracZc04(const FormatArgs &args, void *result) {
|
||||||
|
// trans nchw to FracZc04
|
||||||
|
MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
|
||||||
|
MS_EXCEPTION_IF_NULL(result);
|
||||||
|
size_t size = 0;
|
||||||
|
size_t total_size = 0;
|
||||||
|
if (!CheckArgs(args, &size, &total_size)) {
|
||||||
|
MS_LOG(ERROR) << "Check args failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t cube = kCubeSize;
|
||||||
|
size_t n = args.host_shape[0];
|
||||||
|
size_t c = args.host_shape[1];
|
||||||
|
size_t h = args.host_shape[2];
|
||||||
|
size_t w = args.host_shape[3];
|
||||||
|
|
||||||
|
size_t c0 = 4;
|
||||||
|
size_t c1 = DivCeil(c, c0);
|
||||||
|
size_t hwc0 = h * w * c0;
|
||||||
|
size_t hwc = h * w * c;
|
||||||
|
size_t nhwc = n * h * w * c;
|
||||||
|
|
||||||
|
size_t n_cnt = DivCeil(n, cube);
|
||||||
|
size_t v_cnt = DivCeil(h * w * c0 * c1, cube);
|
||||||
|
size_t dst_idx = 0;
|
||||||
|
|
||||||
|
for (size_t vi = 0; vi < v_cnt; vi++) {
|
||||||
|
for (size_t ni = 0; ni < n_cnt; ni++) {
|
||||||
|
for (size_t col = 0; col < cube; col++) {
|
||||||
|
for (size_t row = 0; row < cube; row++) {
|
||||||
|
size_t cur_cube_n = cube * ni + col;
|
||||||
|
size_t cur_cube_c1hwc0 = cube * vi + row;
|
||||||
|
auto desc_g = cur_cube_n / n;
|
||||||
|
auto desc_n = cur_cube_n % n;
|
||||||
|
auto desc_c1 = cur_cube_c1hwc0 / hwc0;
|
||||||
|
auto desc_c0 = cur_cube_c1hwc0 % c0;
|
||||||
|
auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
|
||||||
|
auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
|
||||||
|
auto c_idx = desc_c1 * c0 + desc_c0;
|
||||||
|
auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
|
||||||
|
auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
|
||||||
|
SetDataBysize(size, pad_zero);
|
||||||
|
dst_idx++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool NchwToNc1hwc04(const FormatArgs &args, void *result) {
|
||||||
|
MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
|
||||||
|
return NchwToNc1hwc0(args, result);
|
||||||
|
}
|
||||||
|
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) {
|
||||||
|
MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
|
||||||
|
return Nc1hwc0ToNchw(args, result);
|
||||||
|
}
|
||||||
|
|
||||||
bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
|
bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
|
||||||
MS_EXCEPTION_IF_NULL(hw_shape);
|
MS_EXCEPTION_IF_NULL(hw_shape);
|
||||||
if (host_shape.empty()) {
|
if (host_shape.empty()) {
|
||||||
|
|
|
@ -64,11 +64,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
|
||||||
bool NchwToFracZ(const FormatArgs &args, void *result);
|
bool NchwToFracZ(const FormatArgs &args, void *result);
|
||||||
bool NchwToFracNz(const FormatArgs &args, void *result);
|
bool NchwToFracNz(const FormatArgs &args, void *result);
|
||||||
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
|
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
|
||||||
|
bool NchwToFracZc04(const FormatArgs &args, void *result);
|
||||||
|
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
|
||||||
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
|
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
|
||||||
// device to host
|
// device to host
|
||||||
bool FracZToNchw(const FormatArgs &args, void *result);
|
bool FracZToNchw(const FormatArgs &args, void *result);
|
||||||
bool FracNzToNchw(const FormatArgs &args, void *result);
|
bool FracNzToNchw(const FormatArgs &args, void *result);
|
||||||
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
|
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
|
||||||
|
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
|
||||||
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
|
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
|
||||||
} // namespace trans
|
} // namespace trans
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue