forked from mindspore-Ecosystem/mindspore
Enable TensorCore for training wide&deep and deepfm on GPU
This commit is contained in:
parent c5ed68dd05
commit 689945ba43
@@ -26,22 +26,14 @@ namespace opt {
 namespace {
 using vec = std::vector<size_t>;
 
-// A[K,M] && B[K,N] M,N pad 32, K pad 4
-auto GetPadShape1 = [](size_t K, size_t M, size_t N) {
-  size_t pad_K = ((K - 1) / 4 + 1) * 4;
+// M,N pad 32, K pad 16
+auto GetPadShape = [](size_t K, size_t M, size_t N) {
+  size_t pad_K = ((K - 1) / 16 + 1) * 16;
   size_t pad_M = ((M - 1) / 32 + 1) * 32;
   size_t pad_N = ((N - 1) / 32 + 1) * 32;
   return std::tuple(pad_K, pad_M, pad_N);
 };
 
-// M,N pad 16, K pad 8
-auto GetPadShape2 = [](size_t K, size_t M, size_t N) {
-  size_t pad_K = ((K - 1) / 8 + 1) * 8;
-  size_t pad_M = ((M - 1) / 16 + 1) * 16;
-  size_t pad_N = ((N - 1) / 16 + 1) * 16;
-  return std::tuple(pad_K, pad_M, pad_N);
-};
-
 // Get (K M .. pad_N) when tran_a is true and tran_b is false
 auto TransANotTransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape_a, vec *pad_shape_b) {
   size_t K, M, N, pad_K, pad_M, pad_N;
@@ -49,7 +41,7 @@ auto TransANotTransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape
   K = shape_a[size - 2];
   M = shape_a[size - 1];
   N = shape_b[size - 1];
-  std::tie(pad_K, pad_M, pad_N) = GetPadShape1(K, M, N);
+  std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
   pad_shape_a->push_back(pad_K);
   pad_shape_a->push_back(pad_M);
   pad_shape_b->push_back(pad_K);
@@ -64,7 +56,7 @@ auto TransATransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape_a,
   K = shape_a[size - 2];
   M = shape_a[size - 1];
   N = shape_b[size - 2];
-  std::tie(pad_K, pad_M, pad_N) = GetPadShape2(K, M, N);
+  std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
   pad_shape_a->push_back(pad_K);
   pad_shape_a->push_back(pad_M);
   pad_shape_b->push_back(pad_N);
@@ -79,7 +71,7 @@ auto NotTransATransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape
   K = shape_a[size - 1];
   M = shape_a[size - 2];
   N = shape_b[size - 2];
-  std::tie(pad_K, pad_M, pad_N) = GetPadShape2(K, M, N);
+  std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
   pad_shape_a->push_back(pad_M);
   pad_shape_a->push_back(pad_K);
   pad_shape_b->push_back(pad_N);
@@ -94,7 +86,7 @@ auto NotTransANotTransB = [](const vec &shape_a, const vec &shape_b, vec *pad_sh
   K = shape_a[size - 1];
   M = shape_a[size - 2];
   N = shape_b[size - 1];
-  std::tie(pad_K, pad_M, pad_N) = GetPadShape2(K, M, N);
+  std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
   pad_shape_a->push_back(pad_M);
   pad_shape_a->push_back(pad_K);
   pad_shape_b->push_back(pad_K);

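The unified GetPadShape above rounds the reduction dimension K up to a multiple of 16 and the output dimensions M, N up to multiples of 32 before the MatMul is padded, presumably so the fused fp16 MatMul meets Tensor Core tile-alignment requirements. A minimal Python sketch of the same rounding rule, for illustration only (round_up and get_pad_shape are names chosen here, not part of the pass):

# Illustrative only: mirrors the rounding in the unified GetPadShape lambda above.
def round_up(x, multiple):
    # ((x - 1) // multiple + 1) * multiple rounds x up to the next multiple.
    return ((x - 1) // multiple + 1) * multiple

def get_pad_shape(k, m, n):
    # K is padded to a multiple of 16, M and N to multiples of 32.
    return round_up(k, 16), round_up(m, 32), round_up(n, 32)

print(get_pad_shape(70, 100, 20))  # -> (80, 128, 32)
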
@@ -65,6 +65,7 @@ if __name__ == '__main__':
     elif args_opt.device_target == "GPU":
         init()
         context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=args_opt.device_target)
+        context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(device_num=get_group_size(),
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
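In this distributed branch, the new flag is added right after graph kernel fusion is enabled and before the data-parallel context is configured. A self-contained sketch of that setup with the imports the hunk does not show (helper name and import paths follow recent MindSpore releases and are illustrative):

# Illustrative sketch of the distributed-GPU setup in the hunk above.
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_group_size

def setup_distributed_gpu():
    init()  # set up communication across ranks
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
    # Restrict graph kernel clustering to MatMul, as this commit does.
    context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=get_group_size(),
                                      parallel_mode=ParallelMode.DATA_PARALLEL)
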
@@ -79,6 +80,7 @@ if __name__ == '__main__':
         context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
     elif args_opt.device_target == "GPU":
         context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=args_opt.device_target)
+        context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
     else:
         context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     rank_size = None

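The same two-line change recurs in every GPU entry point below: turn on graph kernel fusion and restrict its clustering to MatMul. A compact sketch of that shared idiom (the helper name is invented here for illustration):

# Illustrative helper capturing the pattern repeated in the remaining hunks.
from mindspore import context

def enable_gpu_matmul_fusion(device_target):
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    if device_target == "GPU":
        context.set_context(enable_graph_kernel=True)
        context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")

# e.g. call enable_gpu_matmul_fusion(config.device_target) before building the network
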
@@ -91,4 +91,6 @@ if __name__ == "__main__":
     _enable_graph_kernel = config.device_target == "GPU"
     context.set_context(mode=context.GRAPH_MODE,
                         enable_graph_kernel=_enable_graph_kernel, device_target=config.device_target)
+    if _enable_graph_kernel:
+        context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
     test_train(config)

@@ -108,5 +108,7 @@ if __name__ == "__main__":
     _enable_graph_kernel = wide_deep_config.device_target == "GPU"
     context.set_context(mode=context.GRAPH_MODE,
                         enable_graph_kernel=_enable_graph_kernel, device_target=wide_deep_config.device_target)
+    if _enable_graph_kernel:
+        context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
     context.set_context(enable_sparse=wide_deep_config.sparse)
     test_train_eval(wide_deep_config)

@@ -121,6 +121,12 @@ if __name__ == "__main__":
     wide_deep_config.argparse_init()
 
     context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
+
+    _enable_graph_kernel = wide_deep_config.device_target == "GPU"
+    if _enable_graph_kernel:
+        context.set_context(enable_graph_kernel=True)
+        context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
+
     context.set_context(enable_sparse=wide_deep_config.sparse)
     init()
     context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank()))

@@ -166,4 +166,5 @@ if __name__ == "__main__":
     context.set_context(enable_sparse=True)
     if wide_deep_config.device_target == "GPU":
         context.set_context(enable_graph_kernel=True)
+        context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
     train_and_eval(wide_deep_config)

@@ -127,6 +127,7 @@ if __name__ == "__main__":
     context.set_context(enable_sparse=True)
     if wide_deep_config.device_target == "GPU":
         context.set_context(enable_graph_kernel=True)
+        context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
     context.set_ps_context(enable_ps=True)
 
     train_and_eval(wide_deep_config)