fix resnet50 of thor

This commit is contained in:
wangzidong 2021-05-28 17:06:01 +08:00
parent 40ca285ab3
commit fea69fe0e6
15 changed files with 2059 additions and 3887 deletions

View File

@ -23,7 +23,7 @@ cus_fused_abs_max1_op_info = TBERegOp("CusFusedAbsMax1") \
.async_flag(False) \
.binfile_name("fusedabsmax1.so") \
.compute_cost(10) \
.kernel_name("CusFusedAbsMax1") \
.kernel_name("cus_fused_abs_max1") \
.partial_flag(True) \
.attr("origin_shape", "required", "listInt", "all") \
.input(0, "x1", False, "required", "all") \
@ -32,47 +32,39 @@ cus_fused_abs_max1_op_info = TBERegOp("CusFusedAbsMax1") \
.get_op_info()
@op_info_register(cus_fused_abs_max1_op_info)
def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"):
"""CusFusedAbsMax1"""
input_x_shape = input_x.get("shape")
output_shape = output.get("shape")
dtype = input_x.get("dtype")
def _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res):
"""_update_tik"""
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_ub[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_ub, broadcast_0_local_ub, broadcast_0_local_ub[2048], 32, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_ub, broadcast_0_local_ub, broadcast_0_local_ub[1024], 16, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_ub, broadcast_0_local_ub, broadcast_0_local_ub[512], 8, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_ub, broadcast_0_local_ub, broadcast_0_local_ub[256], 4, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_ub, broadcast_0_local_ub, broadcast_0_local_ub[128], 2, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_ub, broadcast_0_local_ub, broadcast_0_local_ub[64], 1, 1, 1, 1,
8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_ub, 0, 1, 8, 0, 0)
return tik_instance, res
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
support_shape = [((1, 128, 128), "float32"),
((2, 128, 128), "float32"),
((4, 128, 128), "float32"),
((8, 128, 128), "float32"),
((16, 128, 128), "float32"),
((5, 128, 128), "float32"),
((9, 128, 128), "float32"),
((18, 128, 128), "float32"),
((36, 128, 128), "float32"),
((32, 128, 128), "float32"),
((1, 64, 64), "float32"),
((32, 64), "float32")
]
ori_shape = tuple(origin_shape)
input_info = (tuple(input_x_shape), dtype)
if input_info not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_info))
if input_info == ((1, 128, 128), "float32"):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
def shape0(tik_instance, input_x_shape, input_x, res):
"""shape0"""
total_elements0 = 1
for val in input_x_shape:
total_elements *= val
total_elements0 *= val
blocks = 32
each_block_element = total_elements // blocks
each_block_element = total_elements0 // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1,
each_block_element // 8, 0, 0)
@ -81,35 +73,21 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1,
8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
elif input_info == ((2, 128, 128), "float32"):
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
return tik_instance, res
def shape1(tik_instance, input_x_shape, ori_shape, input_x, res):
"""shape1"""
if ori_shape == (147, 147):
phase_1 = 16384
phase_2 = 1216
blocks = 32
each_block_element = phase_1 // blocks + 64
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[512 * block_index], 0, 1, 512 // 8, 0, 0)
line_id = block_index % 19
@ -120,35 +98,17 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1,
1, 8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
elif ori_shape in ((256, 256), None, (-1, -1)):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
total_elements1 = 1
for val in input_x_shape:
total_elements *= val
total_elements1 *= val
blocks = 32
each_block_element = total_elements // blocks
each_block_element = total_elements1 // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1,
each_block_element // 8, 0, 0)
@ -158,37 +118,23 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1,
1, 8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
else:
raise RuntimeError("origin shape %s is not supported" % str(ori_shape))
elif input_info == ((4, 128, 128), "float32"):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
return tik_instance, res
def shape2(tik_instance, input_x_shape, input_x, res):
"""shape2"""
total_elements2 = 1
for val in input_x_shape:
total_elements *= val
total_elements2 *= val
blocks = 32
each_block_element = total_elements // blocks
each_block_element = total_elements2 // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1,
each_block_element // 8, 0, 0)
@ -199,47 +145,32 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1,
8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
elif input_info == ((8, 128, 128), "float32"):
if ori_shape == (1000, 1000):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
return tik_instance, res
def shape3_1000(tik_instance, input_x, res):
"""shape3_1000"""
blocks = 32
each_block_element = 7 * 128 * 128 // 32 + 4 * 128
phase_1 = 7 * 128 * 128 // 32
phase_0 = 7 * 128 * 128 // 32
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0)
tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0,
tik_instance.data_move(input_x_ub, input_x[phase_0 * block_index], 0, 1, phase_0 // 8, 0, 0)
tik_instance.data_move(input_x_ub[phase_0], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0,
0)
move_idx = block_index % 8
tik_instance.data_move(input_x_ub[phase_1 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1,
tik_instance.data_move(input_x_ub[phase_0 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1,
128 // 8, 0, 0)
repeat_time = each_block_element // 64
tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8)
vmask = 1000 - 7 * 128 - 64
with tik_instance.for_range(0, 4) as loop_idx:
tik_instance.vmax(vmask, input_x_ub[3584 + 128 * loop_idx], input_x_ub[3584 + 128 * loop_idx],
input_x_ub[3584 + 128 * loop_idx + 64], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 4) as loop_idx0:
tik_instance.vmax(vmask, input_x_ub[3584 + 128 * loop_idx0], input_x_ub[3584 + 128 * loop_idx0],
input_x_ub[3584 + 128 * loop_idx0 + 64], 1, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub[512], input_x_ub[2048], 24, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[512], 8, 1, 1, 1, 8, 8, 8)
@ -247,42 +178,26 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 4) as loop_idx:
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx], 1, 1, 1, 1, 8,
with tik_instance.for_range(0, 4) as loop_idx0:
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx0], 1, 1, 1, 1, 8,
8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1,
1, 8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
elif ori_shape == (1001, 1001):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
return tik_instance, res
def shape3_1001(tik_instance, input_x, res):
"""shape3_1001"""
blocks = 32
each_block_element = 7 * 128 * 128 // 32 + 4 * 128
phase_1 = 7 * 128 * 128 // 32
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0)
tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0,
0)
tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0,
0)
tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, 0)
tik_instance.data_move(input_x_ub[phase_1], input_x[114688 + block_index * 384], 0, 1, 384 // 8, 0, 0)
move_idx = block_index % 9
tik_instance.data_move(input_x_ub[phase_1 + 384], input_x[114688 + 96 * 128 + move_idx * 128], 0, 1,
128 // 8, 0, 0)
@ -301,37 +216,28 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
with tik_instance.for_range(0, 4) as loop_idx:
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[3584 + 128 * loop_idx], 1, 1, 1, 1, 8,
8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1,
1, 8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
return tik_instance, res
def shape3(tik_instance, input_x_shape, ori_shape, input_x, res):
"""shape3"""
if ori_shape == (1000, 1000):
tik_instance, res = shape3_1000(tik_instance, input_x, res)
elif ori_shape == (1001, 1001):
tik_instance, res = shape3_1001(tik_instance, input_x, res)
elif ori_shape in ((1024, 1024), None, (-1, -1)):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
total_elements3 = 1
for val in input_x_shape:
total_elements *= val
total_elements3 *= val
blocks = 32
each_block_element = total_elements // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
each_block_element = total_elements3 // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index0:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1,
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index0], 0, 1,
each_block_element // 8, 0, 0)
repeat_time = each_block_element // 64
tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8)
@ -341,42 +247,28 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1,
1, 1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1,
1, 8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index0, res)
else:
raise RuntimeError("origin shape %s is not supported" % str(ori_shape))
elif input_info == ((16, 128, 128), "float32"):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
return tik_instance, res
def shape4(tik_instance, input_x_shape, input_x, res):
"""shape4"""
total_elements4 = 1
for val in input_x_shape:
total_elements *= val
total_elements4 *= val
blocks = 32
each_block_element = total_elements // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
each_block_element = total_elements4 // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index1:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1,
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index1], 0, 1,
each_block_element // 8, 0, 0)
repeat_time = each_block_element // 64
tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8)
repeat_time1 = each_block_element // 64
tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time1, 1, 1, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 64, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 32, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[1024], 16, 1, 1, 1, 8, 8, 8)
@ -384,35 +276,21 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1,
8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
elif input_info == ((32, 128, 128), "float32"):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index1, res)
return tik_instance, res
def shape5(tik_instance, input_x_shape, input_x, res):
"""shape5"""
total_elements5 = 1
for val in input_x_shape:
total_elements *= val
total_elements5 *= val
blocks = 32
each_block_element = total_elements // blocks
each_block_element = total_elements5 // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1,
each_block_element // 8, 0, 0)
@ -426,36 +304,22 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[256], 4, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1,
8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
elif input_info == ((5, 128, 128), "float32"):
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
return tik_instance, res
def shape6(tik_instance, ori_shape, input_x, res):
"""shape6"""
if ori_shape == (576, 576):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 69632
total_elements6 = 69632
blocks = 32
each_block_element = total_elements // blocks
each_block_element = total_elements6 // blocks
phase_1 = 2048
phase_2 = 128
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[phase_1 * block_index], 0, 1, phase_1 // 8, 0, 0)
tik_instance.data_move(input_x_ub[phase_1], input_x[65536 + phase_2 * block_index * 2], 0, 1, 8, 0, 0)
@ -470,37 +334,23 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[2048], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1,
8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
else:
raise RuntimeError("origin shape %s is not supported" % str(ori_shape))
elif input_info == ((9, 128, 128), "float32"):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
return tik_instance, res
def shape7(tik_instance, input_x_shape, input_x, res):
"""shape7"""
total_elements7 = 1
for val in input_x_shape:
total_elements *= val
total_elements7 *= val
blocks = 32
each_block_element = total_elements // blocks
each_block_element = total_elements7 // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1,
each_block_element // 8, 0, 0)
@ -516,35 +366,21 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub[4096], input_x_ub[4096], input_x_ub[4096 + 64], 1, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[4096], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1,
8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
elif input_info == ((18, 128, 128), "float32"):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
return tik_instance, res
def shape8(tik_instance, input_x_shape, input_x, res):
"""shape8"""
total_elements8 = 1
for val in input_x_shape:
total_elements *= val
total_elements8 *= val
blocks = 32
each_block_element = total_elements // blocks
each_block_element = total_elements8 // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1,
each_block_element // 8, 0, 0)
@ -562,35 +398,21 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 128], 2, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub[8192], input_x_ub[8192], input_x_ub[8192 + 64], 1, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[8192], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1,
8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
elif input_info == ((36, 128, 128), "float32"):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
return tik_instance, res
def shape9(tik_instance, input_x_shape, input_x, res):
"""shape9"""
total_elements9 = 1
for val in input_x_shape:
total_elements *= val
total_elements9 *= val
blocks = 32
each_block_element = total_elements // blocks
each_block_element = total_elements9 // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1,
each_block_element // 8, 0, 0)
@ -617,58 +439,90 @@ def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_m
8)
tik_instance.vmax(64, input_x_ub[16384], input_x_ub[16384], input_x_ub[16384 + 64], 1, 1, 1, 1, 8, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[16384], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1,
8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
elif input_info == ((1, 64, 64), "float32"):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
total_elements = 1
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
return tik_instance, res
def shape10(tik_instance, input_x_shape, input_x, res):
"""shape10"""
total_elements10 = 1
for val in input_x_shape:
total_elements *= val
total_elements10 *= val
blocks = 32
each_block_element = total_elements // blocks
each_block_element = total_elements10 // blocks
with tik_instance.for_range(0, blocks, block_num=blocks) as block_index:
input_x_ub = tik_instance.Tensor("float32", (each_block_element,), name="input_x_ub",
scope=tik.scope_ubuf)
broadcast_0_local_UB = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_UB",
broadcast_0_local_ub = tik_instance.Tensor("float32", (4096,), name="broadcast_0_local_ub",
scope=tik.scope_ubuf)
tik_instance.data_move(input_x_ub, input_x[each_block_element * block_index], 0, 1,
each_block_element // 8, 0, 0)
repeat_time = each_block_element // 64
tik_instance.vabs(64, input_x_ub, input_x_ub, repeat_time, 1, 1, 8, 8)
tik_instance.vmax(64, input_x_ub, input_x_ub, input_x_ub[64], 1, 1, 1, 1, 8, 8, 8)
with tik_instance.for_range(0, 64) as cc0:
data_temp = tik_instance.Scalar("float32")
data_temp.set_as(input_x_ub[cc0])
tik_instance.vector_dup(64, broadcast_0_local_UB[cc0 * 64], data_temp, 1, 1, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[2048], 32, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[1024], 16, 1, 1,
1, 8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[512], 8, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[256], 4, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[128], 2, 1, 1, 1,
8, 8, 8)
tik_instance.vmax(64, broadcast_0_local_UB, broadcast_0_local_UB, broadcast_0_local_UB[64], 1, 1, 1, 1,
8, 8, 8)
tik_instance.data_move(res[block_index, 0], broadcast_0_local_UB, 0, 1, 8, 0, 0)
tik_instance, res = _update_tik(tik_instance, input_x_ub, broadcast_0_local_ub, block_index, res)
return tik_instance, res
def _get_tik_instance():
"""_get_tik_instance"""
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
return tik_instance
@op_info_register(cus_fused_abs_max1_op_info)
def cus_fused_abs_max1(input_x, output, origin_shape=None, kernel_name="cus_fused_abs_max1"):
"""CusFusedAbsMax1"""
input_x_shape = input_x.get("shape")
output_shape = output.get("shape")
dtype = input_x.get("dtype")
tik_instance = _get_tik_instance()
support_shape = [((1, 128, 128), "float32"),
((2, 128, 128), "float32"),
((4, 128, 128), "float32"),
((8, 128, 128), "float32"),
((16, 128, 128), "float32"),
((5, 128, 128), "float32"),
((9, 128, 128), "float32"),
((18, 128, 128), "float32"),
((36, 128, 128), "float32"),
((32, 128, 128), "float32"),
((1, 64, 64), "float32"),
((32, 64), "float32")
]
ori_shape = tuple(origin_shape)
input_info = (tuple(input_x_shape), dtype)
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)
if input_info not in support_shape:
raise RuntimeError("input_shape %s is not supported" % str(input_info))
if input_info == ((1, 128, 128), "float32"):
tik_instance, res = shape0(tik_instance, input_x_shape, input_x, res)
elif input_info == ((2, 128, 128), "float32"):
tik_instance, res = shape1(tik_instance, input_x_shape, ori_shape, input_x, res)
elif input_info == ((4, 128, 128), "float32"):
tik_instance, res = shape2(tik_instance, input_x_shape, input_x, res)
elif input_info == ((8, 128, 128), "float32"):
tik_instance, res = shape3(tik_instance, input_x_shape, ori_shape, input_x, res)
elif input_info == ((16, 128, 128), "float32"):
tik_instance, res = shape4(tik_instance, input_x_shape, input_x, res)
elif input_info == ((32, 128, 128), "float32"):
tik_instance, res = shape5(tik_instance, input_x_shape, input_x, res)
elif input_info == ((5, 128, 128), "float32"):
tik_instance, res = shape6(tik_instance, ori_shape, input_x, res)
elif input_info == ((9, 128, 128), "float32"):
tik_instance, res = shape7(tik_instance, input_x_shape, input_x, res)
elif input_info == ((18, 128, 128), "float32"):
tik_instance, res = shape8(tik_instance, input_x_shape, input_x, res)
elif input_info == ((36, 128, 128), "float32"):
tik_instance, res = shape9(tik_instance, input_x_shape, input_x, res)
elif input_info == ((1, 64, 64), "float32"):
tik_instance, res = shape10(tik_instance, input_x_shape, input_x, res)
elif input_info == ((32, 64), "float32"):
input_x = tik_instance.Tensor("float32", input_x_shape, name="input_x", scope=tik.scope_gm)
res = tik_instance.Tensor("float32", output_shape, name="res", scope=tik.scope_gm)

File diff suppressed because it is too large Load Diff

View File

@ -46,46 +46,12 @@ matmul_cube_dense_left_op_info = TBERegOp("CusMatMulCubeDenseLeft") \
.dtype_format(DataType.F16_Default, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
.get_op_info()
# @util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
@op_info_register(matmul_cube_dense_left_op_info)
def cus_matmul_cube_dense_left(input_x1, input_x2, bias=None, output_y=None, trans_a=False, trans_b=False,
kernel_name="cus_matmul_cube_dense_left"):
"""
calculating matrix multiplication with bias, C = A*B + bias, support input
data with fractal format.
Parameters:
shape_a: list or tuple
Shape of the first tensor a with rank > 1
shape_b: list or tuple
Shape of the second tensor b with the same type with a,
and shape_a, shape_b must be 2 dims
src_dtype: str
The data type of input, support "float32", "float16"
dst_dtype: str
The data type of output, support "float32", "float16"
trans_a: bool
If True, shape_a == transposed before multiplication
trans_b: bool
If True, shape_b == transposed before multiplication
is_fractal: bool
If True, the input data format of a and b must be fractal format
shape_bias: list or tuple
Shape of bias, only support the input data format with ND
Returns
-------
None
"""
print("!!!!come into zzt~~~~~~~!!!!")
def shape_gen1(input_x1, input_x2, output_y, kernel_name, trans_a, trans_b):
"""shape gen1"""
shape_a = input_x1.get("ori_shape")
shape_b = input_x2.get("ori_shape")
shape_output = output_y.get("ori_shape")
print("============")
print(input_x1.get("format"), input_x2.get("format"))
print(shape_a, shape_b)
print("============")
if input_x2.get("format") == "FRACTAL_Z":
n, c, h, w = shape_b
c0 = 16
@ -115,7 +81,6 @@ def cus_matmul_cube_dense_left(input_x1, input_x2, bias=None, output_y=None, tra
shape_a = _get_input_shape(shape_a)
shape_b = _get_input_shape(shape_b)
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_a)
util.check_shape_rule(shape_b)
@ -127,7 +92,10 @@ def cus_matmul_cube_dense_left(input_x1, input_x2, bias=None, output_y=None, tra
shape_b = [shape_b[1], shape_b[0]]
trans_b = bool(1 - trans_b)
return shape_a, shape_b, trans_a, trans_b, shape_output
def shape_gen2(bias, input_x1, output_y, shape_a, shape_b, trans_a, trans_b):
"""shape gen2"""
shape_bias = ()
if bias is not None and bool(bias):
shape_bias = bias.get("shape")
@ -174,22 +142,10 @@ def cus_matmul_cube_dense_left(input_x1, input_x2, bias=None, output_y=None, tra
format_a = "FRACTAL_NZ"
shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2], shape_b_temp[3])
format_b = "FRACTAL_NZ"
return shape_a_temp, format_a, shape_b_temp, format_b, shape_bias, src_dtype, dst_dtype
print("=======================================")
print(shape_a_temp, shape_b_temp)
print(format_a, format_b)
print("=======================================")
tensor_bias = None
tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a',
dtype=src_dtype)
tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b',
dtype=src_dtype)
if shape_bias:
tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias',
dtype=dst_dtype)
if shape_a_temp[0] == 63 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 63:
def core(shape_a_temp, shape_b_temp, shape_output, kernel_name):
"""core func"""
if util.get_product_version() == util.VERSION_MINI:
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
else:
@ -250,7 +206,56 @@ def cus_matmul_cube_dense_left(input_x1, input_x2, bias=None, output_y=None, tra
tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[input_x1, input_x2], outputs=[resmatmul])
return tik_instance
print("come into tbe, shape is error!")
@op_info_register(matmul_cube_dense_left_op_info)
def cus_matmul_cube_dense_left(input_x1, input_x2, bias=None, output_y=None, trans_a=False, trans_b=False,
kernel_name="cus_matmul_cube_dense_left"):
"""
calculating matrix multiplication with bias, C = A*B + bias, support input
data with fractal format.
Parameters:
shape_a: list or tuple
Shape of the first tensor a with rank > 1
shape_b: list or tuple
Shape of the second tensor b with the same type with a,
and shape_a, shape_b must be 2 dims
src_dtype: str
The data type of input, support "float32", "float16"
dst_dtype: str
The data type of output, support "float32", "float16"
trans_a: bool
If True, shape_a == transposed before multiplication
trans_b: bool
If True, shape_b == transposed before multiplication
is_fractal: bool
If True, the input data format of a and b must be fractal format
shape_bias: list or tuple
Shape of bias, only support the input data format with ND
Returns
-------
None
"""
shape_a, shape_b, trans_a, trans_b, shape_output = shape_gen1(input_x1, input_x2, output_y, kernel_name,
trans_a, trans_b)
shape_a_temp, format_a, shape_b_temp, format_b, shape_bias, src_dtype, dst_dtype = shape_gen2(bias, input_x1,
output_y, shape_a,
shape_b, trans_a,
trans_b)
tensor_bias = None
tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a',
dtype=src_dtype)
tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b',
dtype=src_dtype)
if shape_bias:
tensor_bias = tvm.placeholder(shape_bias, name='tensor_bias',
dtype=dst_dtype)
if shape_a_temp[0] == 63 and shape_a_temp[1] == 63 and shape_b_temp[0] == 128 and shape_b_temp[1] == 63:
tik_instance = core(shape_a_temp, shape_b_temp, shape_output, kernel_name)
return tik_instance
result = te.lang.cce.matmul(tensor_a, tensor_b, trans_a, trans_b, format_a=format_a,
format_b=format_b, dst_dtype=dst_dtype, tensor_bias=tensor_bias)

View File

@ -130,7 +130,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer):
"""Initialize CusFusedAbsMax1"""
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
self.origin_shape = origin_shape
from mindspore.ops._op_impl._custom_op.fused_abs_max1_impl import CusFusedAbsMax1
from mindspore.ops._op_impl._custom_op.fused_abs_max1_impl import cus_fused_abs_max1
def infer_shape(self, data1_shape):
ll = []
@ -169,7 +169,7 @@ class CusImg2Col(PrimitiveWithInfer):
self.strides = strides
self.dilates = dilates
self.mode = mode
from mindspore.ops._op_impl._custom_op.img2col_impl import CusImg2Col
from mindspore.ops._op_impl._custom_op.img2col_impl import cus_img2col
def infer_shape(self, data1_shape):
bs, c, h, w = data1_shape

View File

@ -22,7 +22,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.crossentropy import CrossEntropy
from src.config import config
from src.dataset import create_dataset
from src.resnet_thor import resnet50 as resnet
from src.resnet import resnet50 as resnet
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')

View File

@ -25,17 +25,18 @@ config = ed({
"momentum": 0.9,
"weight_decay": 5e-4,
"epoch_size": 45,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"save_checkpoint_epochs": 2,
"keep_checkpoint_max": 15,
"save_checkpoint_path": "./",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.045,
"lr_decay": 6,
"lr_end_epoch": 70,
"damping_init": 0.03,
"damping_decay": 0.87,
"lr_init": 0.05803,
"lr_decay": 4.04839,
"lr_end_epoch": 53,
"damping_init": 0.02714,
"damping_decay": 0.50036,
"frequency": 834,
"use_dynamic_frequency": False,
"first_stage_steps": 835,

View File

@ -16,12 +16,63 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropy(_Loss):
class Loss(nn.Cell):
"""
Base class for other losses.
"""
def __init__(self, reduction='mean'):
super(Loss, self).__init__()
if reduction is None:
reduction = 'none'
if reduction not in ('mean', 'sum', 'none'):
raise ValueError(f"reduction method for {reduction.lower()} is not supported")
self.average = True
self.reduce = True
if reduction == 'sum':
self.average = False
if reduction == 'none':
self.reduce = False
self.reduce_mean = P.ReduceMean()
self.reduce_sum = P.ReduceSum()
self.mul = P.Mul()
self.cast = P.Cast()
def get_axis(self, x):
shape = F.shape(x)
length = F.tuple_len(shape)
perm = F.make_range(0, length)
return perm
def get_loss(self, x, weights=1.0):
"""
Computes the weighted loss
Args:
weights: Optional `Tensor` whose rank is either 0, or the same rank as inputs, and must be broadcastable to
inputs (i.e., all dimensions must be either `1`, or the same as the corresponding inputs dimension).
"""
input_dtype = x.dtype
x = self.cast(x, mstype.float32)
weights = self.cast(weights, mstype.float32)
x = self.mul(weights, x)
if self.reduce and self.average:
x = self.reduce_mean(x, self.get_axis(x))
if self.reduce and not self.average:
x = self.reduce_sum(x, self.get_axis(x))
x = self.cast(x, input_dtype)
return x
def construct(self, base, target):
raise NotImplementedError
class CrossEntropy(Loss):
"""CrossEntropy"""
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropy, self).__init__()

View File

@ -1,191 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset help for minddata dataset"""
import math
import os
from mindspore._checkparam import Validator
from mindspore import context
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.nn.wrap import GetNextSingleOp
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes
def _send_data(dataset, epoch_num):
"""Engine dataset to write data to tdt queue."""
if not hasattr(dataset, '__has_sent__'):
exec_dataset = dataset.__transfer_dataset__
exec_dataset.send(epoch_num)
dataset.__has_sent__ = True
def _send_data_no_flag(dataset, epoch_num):
"""Engine dataset to write data to tdt queue directly."""
exec_dataset = dataset.__transfer_dataset__
exec_dataset.send(epoch_num)
class DatasetHelper:
"""
Help function to use the MindData dataset.
According to different contexts, change the iterations of dataset and use the same iteration for loop in different
contexts.
Note:
The iteration of DatasetHelper will provide one epoch data.
Args:
dataset (DataSet): The training dataset iterator.
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
sink_size (int): Control the amount of data in each sink.
If sink_size=-1, sink the complete dataset for each epoch.
If sink_size>0, sink sink_size data for each epoch. Default: -1.
epoch_num (int): Control the number of epoch data to send. Default: 1.
Examples:
>>> dataset_helper = DatasetHelper(dataset)
>>> for inputs in dataset_helper:
>>> outputs = network(*inputs)
"""
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1):
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
Validator.check_is_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
if dataset_sink_mode:
if context.get_context("device_target") == "Ascend":
iterclass = _DatasetIterMSLoopSink
self.iter = iterclass(dataset, sink_size, epoch_num, iter_first_order)
elif context.get_context("device_target") == "GPU":
iterclass = _DatasetIterMS
self.iter = iterclass(dataset, sink_size, epoch_num)
elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
def __iter__(self):
return self.iter.__iter__()
# A temp solution for loop sink. Delete later
def types_shapes(self):
"""Get the types and shapes from dataset on the current configuration."""
return self.iter.types_shapes()
def sink_size(self):
"""Get sink_size for each iteration."""
return self.iter.get_sink_size()
def stop_send(self):
"""Free up resources about data sink."""
self.iter.stop_send()
class _DatasetIter:
"""Base iter for dataset helper"""
def __init__(self, dataset, sink_size, epoch_num):
self.dataset = dataset
self.sink_size = sink_size
self.sink_count = 1
if not hasattr(dataset, '__transfer_dataset__'):
if hasattr(dataset, '__loop_size__'):
self.sink_size = dataset.__loop_size__
dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size)
if not hasattr(dataset, '__no_send__'):
_send_data(dataset, epoch_num)
else:
_send_data_no_flag(dataset, epoch_num)
self.stop_send = dataset.__transfer_dataset__.stop_send
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
def __iter__(self):
self.index = 0
return self
def __next__(self):
if self.index >= self.sink_count:
raise StopIteration()
self.index += 1
return self.op()
def types_shapes(self):
return self.dataset_types, self.dataset_shapes
def get_sink_count(self, dataset):
sink_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__
if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'sink_size {loop_size} are not matched.')
sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
return sink_count
def get_sink_size(self):
"""get sink_size to device"""
sink_size = 1
if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__
else:
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
if self.sink_size > 0:
sink_size = self.sink_size
else:
sink_size = self.dataset.get_dataset_size()
return sink_size
class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context when device_target is Ascend"""
def __init__(self, dataset, sink_size, epoch_num, iter_first_order):
super().__init__(dataset, sink_size, epoch_num)
sink_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__ + iter_first_order
if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'sink_size {loop_size} are not matched.')
sink_count = math.ceil(dataset.get_dataset_size() / loop_size) * 2
self.sink_count = sink_count
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
self.sink_count = 1
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
if _need_to_full():
device_num = _get_device_num()
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
def op():
return tuple()
self.op = op
class _DatasetIterMS(_DatasetIter):
"""Iter for MS when enable_loop_sink is False."""
def __init__(self, dataset, sink_size, epoch_num):
super().__init__(dataset, sink_size, epoch_num)
if sink_size > 0:
self.sink_count = sink_size
else:
self.sink_count = dataset.get_dataset_size()
queue_name = dataset.__transfer_dataset__.queue_name
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)

View File

@ -1,135 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad reducer cell for distributed training"""
from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce
import mindspore.common.dtype as mstype
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
def _init_allreduce_operators(length, split_indices):
""" initialize allreduce communication operators"""
indices = split_indices[0]
fusion = split_indices[1]
op_list = ()
j = 0
for i in range(length):
if j <= len(indices)-1:
temp = indices[j]
else:
temp = length
if i >= temp:
j = j + 1
fusion = fusion + 1
op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
op.add_prim_attr('fusion', fusion)
op_list = op_list + (op,)
return op_list
@reduce_opt.register("Function", "Number", "Function", "Tensor")
def _tensors_allreduce_mean(mul, degree, allreduce, parameters):
"""
Apply allreduce on parameters.
Args:
mul(Primitive): The mul operator for parameters.
degree (int): The mean coefficient.
allreduce (Primitive): The communication operator for parameters.
parameters (Tensor): The parameters before operation.
Returns:
Tensor, the parameters after operation.
"""
degree = F.scalar_cast(degree, F.dtype(parameters))
parameters = allreduce(parameters)
cast_op = P.Cast()
return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters)))
_get_datatype = C.MultitypeFuncGraph("_get_datatype")
@_get_datatype.register("Tensor")
def _tensors_get_datatype(parameters):
"""
Acquire parameters datatype.
Args:
parameters (Tensor): The parameters before operation.
Returns:
mstype, the datatype of parameters.
"""
return F.dtype(parameters)
_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, parameters):
"""
Cast parameters to datatype.
Args:
datatype (mstype): the destination datatype of parameters.
parameters (Tensor): The parameters before operation.
Returns:
Tensor, the parameters after operation.
"""
return F.cast(parameters, datatype)
class DistributedGradReducerThor(Cell):
"""
A distributed optimizer.
Constructs a parameters reducer Cell, which applies communication and average operations on
single-process parameters values.
Args:
parameter_length (int): length of the parameters to be updated.
split_indices(tuple): parameter split indices.
mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False.
degree (int): The mean coefficient. Usually it equals to device number. Default: None.
Raises:
ValueError: If degree is not a int or less than 0.
"""
def __init__(self, parameter_length, split_indices, mean=True, degree=None):
super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
self.hyper_map = C.HyperMap()
self.mul = P.Mul()
if degree is None:
self.degree = get_group_size()
else:
if not isinstance(degree, int) or degree <= 0:
raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
self.degree = degree
self.mean = mean
self.op_list = _init_allreduce_operators(parameter_length, split_indices)
def construct(self, parameters):
datatypes = self.hyper_map(F.partial(_get_datatype), parameters)
parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters)
new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters)
new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters)
return new_parameters

View File

@ -1,267 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model."""
import math
from mindspore.train.callback import RunContext
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.train.model import Model
from mindspore.train.dataset_helper import connect_network_with_dataset
from mindspore.parallel._utils import _need_to_full, _to_full_tensor
from mindspore.common.dtype import pytype_to_dtype
from mindspore._c_expression import init_exec_dataset
from src.dataset_helper import DatasetHelper
def _convert_type(types):
"""
Convert from numpy type to tensor type.
Args:
types (list): Numpy type list of element in dataset.
Returns:
list, list of element in dataset.
"""
ms_types = []
for np_type in types:
ms_type = pytype_to_dtype(np_type)
ms_types.append(ms_type)
return ms_types
def _get_types_and_shapes(dataset):
"""Get dataset types and shapes."""
dataset_types = _convert_type(dataset.output_types())
dataset_shapes = dataset.output_shapes()
return dataset_types, dataset_shapes
def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
"""Initialize and execute the dataset graph."""
batch_size = exec_dataset.get_batch_size()
input_indexs = exec_dataset.input_indexs
# transform data format
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name,
dataset_size,
batch_size,
dataset_types,
dataset_shapes,
input_indexs,
phase=phase,
need_run=False)
class Model_Thor(Model):
"""
High-Level API for Training or Testing.
`Model` groups layers into an object with training and inference features.
Args:
network (Cell): A training or testing network.
loss_fn (Cell): Objective function, if loss_fn is None, the
network should contain the logic of loss and grads calculation, and the logic
of parallel if needed. Default: None.
optimizer (Cell): Optimizer for updating the weights. Default: None.
metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during
training and testing. eg: {'accuracy', 'recall'}. Default: None.
eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
`eval_network`. Default: None.
eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the
`eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
elements, including the positions of loss value, predicted value and label. The loss
value would be passed to the `Loss` metric, the predicted value and label would be passed
to other metric. Default: None.
amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
precision training. Supports [O0, O2, O3]. Default: "O0".
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
O2 is recommended on GPU, O3 is recommended on Ascend.
loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled. Otherwise,
scale the loss by LossScaleManager. It is a key argument.
e.g. Use `loss_scale_manager=None` to set the value.
keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If it is set to true, the level setting before
will be overwritten. Default: True.
"""
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
eval_indexes=None, amp_level="O0", frequency=834, use_dynamic_frequency=False,
first_stage_steps=5, **kwargs):
super(Model_Thor, self).__init__(network, loss_fn, optimizer, metrics, eval_network,
eval_indexes, amp_level, **kwargs)
self._frequency = frequency
self._use_dynamic_frequency = use_dynamic_frequency
self._first_stage_steps = first_stage_steps
self._train_network = self._build_train_network()
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
epoch_num=1, iter_first_order=1):
"""Initializes dataset."""
if dataset_sink_mode and not is_train:
dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
if dataset_sink_mode and context.get_context("device_target") != "GPU":
network = connect_network_with_dataset(network, dataset_helper)
network.set_train(is_train)
network.phase = phase
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network.set_auto_parallel()
return dataset_helper, network
def _get_iter_second_steps(self, cb_params, sink_size):
"""get first stage steps for second order."""
iter_second_steps = 1
if self._use_dynamic_frequency:
global_steps = (cb_params.cur_epoch_num - 1) * sink_size + cb_params.cur_step_num
if global_steps <= self._first_stage_steps:
iter_second_steps = self._first_stage_steps
return iter_second_steps
def _get_ascend_sink_count(self, cb_params, dataset_helper, sink_size, iter_first_order, ori_sink_count):
"""get ascend sink count for each epoch."""
if context.get_context("device_target") == "Ascend":
if self._use_dynamic_frequency and cb_params.cur_epoch_num == 1:
fix_fre_sink_size = sink_size - self._first_stage_steps - iter_first_order
first_epoch_sink_count = math.ceil(fix_fre_sink_size / self._frequency) * 2 + 2
dataset_helper.iter.sink_count = first_epoch_sink_count
else:
dataset_helper.iter.sink_count = ori_sink_count
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
"""
Training process. The data would be passed to network through dataset channel.
Args:
epoch (int): Total number of iterations on the data.
train_dataset (Dataset): A training dataset iterator. If there is no
loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned. The data and label would be passed to the network and loss
function respectively.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
sink_size (int): Control the amount of data in each sink. Default: -1.
"""
if sink_size == -1:
epoch_num = epoch
else:
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
iter_first_order = self._frequency - 1
iter_second_order = 1
train_dataset.__loop_size__ = iter_second_order
dataset_helper, train_network = self._exec_preprocess(self._train_network,
is_train=True,
phase='train',
dataset=train_dataset,
dataset_sink_mode=True,
sink_size=sink_size,
epoch_num=epoch_num,
iter_first_order=iter_first_order)
self._train_network = train_network
cb_params.train_network = self._train_network
cb_params.cur_step_num = 0
run_context = RunContext(cb_params)
list_callback.begin(run_context)
# used to stop training for early stop, such as stopAtTIme or stopATStep
should_stop = False
switch_branch_one = True
index_first_order = 0
train_network_init_flag = True
has_do_dataset_init = False
ori_sink_count = dataset_helper.iter.sink_count
for i in range(epoch):
cb_params.cur_epoch_num = i + 1
list_callback.epoch_begin(run_context)
self._get_ascend_sink_count(cb_params, dataset_helper, sink_size, iter_first_order, ori_sink_count)
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper:
if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context)
if context.get_context("device_target") == "GPU":
if switch_branch_one:
cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=True)
self._train_network.phase = 'train0'
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
is_first_stage = self._use_dynamic_frequency and cb_params.cur_epoch_num == 1 \
and cb_params.cur_step_num < self._first_stage_steps
if is_first_stage:
continue
else:
switch_branch_one = not switch_branch_one
list_callback.step_end(run_context)
else:
cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=False)
train_network_init_flag = False
self._train_network.phase = 'train1'
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
index_first_order += 1
if index_first_order == iter_first_order:
index_first_order = 0
switch_branch_one = not switch_branch_one
list_callback.step_end(run_context)
else:
if switch_branch_one:
cb_params.cur_step_num += self._get_iter_second_steps(cb_params, sink_size)
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=True)
self._train_network.phase = 'train0'
else:
cb_params.cur_step_num += iter_first_order
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=False)
train_network_init_flag = False
self._train_network.phase = 'train1'
if not has_do_dataset_init:
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
has_do_dataset_init = True
switch_branch_one = not switch_branch_one
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
list_callback.epoch_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
break
dataset_helper.stop_send()
list_callback.end(run_context)
__all__ = ["Model_Thor"]

View File

@ -0,0 +1,573 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
from scipy.stats import truncnorm
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
fan_in = in_channel * kernel_size * kernel_size
scale = 1.0
scale /= max(1., fan_in)
stddev = (scale ** 0.5) / .87962566103423978
mu, sigma = 0, stddev
weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
return Tensor(weight, dtype=mstype.float32)
def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def calculate_gain(nonlinearity, param=None):
"""calculate_gain"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
res = 0
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
res = 1
elif nonlinearity == 'tanh':
res = 5.0 / 3
elif nonlinearity == 'relu':
res = math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
return res
def _calculate_fan_in_and_fan_out(tensor):
"""_calculate_fan_in_and_fan_out"""
dimensions = len(tensor)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor[1]
fan_out = tensor[0]
else:
num_input_fmaps = tensor[1]
num_output_fmaps = tensor[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = tensor[2] * tensor[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
else:
weight_shape = (out_channel, in_channel, 3, 3)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
if res_base:
return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
padding=1, pad_mode='pad', weight_init=weight)
return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
padding=0, pad_mode='same', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
else:
weight_shape = (out_channel, in_channel, 1, 1)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
if res_base:
return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
padding=0, pad_mode='pad', weight_init=weight)
return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
padding=0, pad_mode='same', weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
else:
weight_shape = (out_channel, in_channel, 7, 7)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
if res_base:
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _bn(channel, res_base=False):
if res_base:
return nn.BatchNorm2d(channel, eps=1e-5, momentum=0.1,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _bn_last(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _fc(in_channel, out_channel, use_se=False):
if use_se:
weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
else:
weight_shape = (out_channel, in_channel)
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
use_se (bool): Enable SE-ResNet50 net. Default: False.
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4
def __init__(self,
in_channel,
out_channel,
stride=1,
use_se=False, se_block=False):
super(ResidualBlock, self).__init__()
self.stride = stride
self.use_se = use_se
self.se_block = se_block
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
self.bn1 = _bn(channel)
if self.use_se and self.stride != 1:
self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
else:
self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
self.bn2 = _bn(channel)
self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
self.bn3 = _bn_last(out_channel)
if self.se_block:
self.se_global_pool = P.ReduceMean(keep_dims=False)
self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
self.se_sigmoid = nn.Sigmoid()
self.se_mul = P.Mul()
self.relu = nn.ReLU()
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
if self.use_se:
if stride == 1:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
stride, use_se=self.use_se), _bn(out_channel)])
else:
self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
_conv1x1(in_channel, out_channel, 1,
use_se=self.use_se), _bn(out_channel)])
else:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
use_se=self.use_se), _bn(out_channel)])
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
if self.use_se and self.stride != 1:
out = self.e2(out)
else:
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.se_block:
out_se = out
out = self.se_global_pool(out, (2, 3))
out = self.se_dense_0(out)
out = self.relu(out)
out = self.se_dense_1(out)
out = self.se_sigmoid(out)
out = F.reshape(out, F.shape(out) + (1, 1))
out = self.se_mul(out, out_se)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = out + identity
out = self.relu(out)
return out
class ResidualBlockBase(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
use_se (bool): Enable SE-ResNet50 net. Default: False.
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
res_base (bool): Enable parameter setting of resnet18. Default: True.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlockBase(3, 256, stride=2)
"""
def __init__(self,
in_channel,
out_channel,
stride=1,
use_se=False,
se_block=False,
res_base=True):
super(ResidualBlockBase, self).__init__()
self.res_base = res_base
self.conv1 = _conv3x3(in_channel, out_channel, stride=stride, res_base=self.res_base)
self.bn1d = _bn(out_channel)
self.conv2 = _conv3x3(out_channel, out_channel, stride=1, res_base=self.res_base)
self.bn2d = _bn(out_channel)
self.relu = nn.ReLU()
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
use_se=use_se, res_base=self.res_base),
_bn(out_channel, res_base)])
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1d(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2d(out)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = out + identity
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
use_se (bool): Enable SE-ResNet50 net. Default: False.
se_block(bool): Use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
res_base (bool): Enable parameter setting of resnet18. Default: False.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes,
use_se=False,
res_base=False):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.use_se = use_se
self.res_base = res_base
self.se_block = False
if self.use_se:
self.se_block = True
if self.use_se:
self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
self.bn1_0 = _bn(32)
self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
self.bn1_1 = _bn(32)
self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
else:
self.conv1 = _conv7x7(3, 64, stride=2, res_base=self.res_base)
self.bn1 = _bn(64, self.res_base)
self.relu = P.ReLU()
if self.res_base:
self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0],
use_se=self.use_se)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1],
use_se=self.use_se)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2],
use_se=self.use_se,
se_block=self.se_block)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3],
use_se=self.use_se,
se_block=self.se_block)
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []
resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
layers.append(resnet_block)
if se_block:
for _ in range(1, layer_num - 1):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
layers.append(resnet_block)
else:
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def construct(self, x):
if self.use_se:
x = self.conv1_0(x)
x = self.bn1_0(x)
x = self.relu(x)
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu(x)
x = self.conv1_2(x)
else:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
if self.res_base:
x = self.pad(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
out = self.mean(c5, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
return out
def resnet18(class_num=10):
"""
Get ResNet18 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet18 neural network.
Examples:
>>> net = resnet18(10)
"""
return ResNet(ResidualBlockBase,
[2, 2, 2, 2],
[64, 64, 128, 256],
[64, 128, 256, 512],
[1, 2, 2, 2],
class_num,
res_base=True)
def resnet50(class_num=10):
"""
Get ResNet50 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet50 neural network.
Examples:
>>> net = resnet50(10)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)
def se_resnet50(class_num=1001):
"""
Get SE-ResNet50 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of SE-ResNet50 neural network.
Examples:
>>> net = se-resnet50(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num,
use_se=True)
def resnet101(class_num=1001):
"""
Get ResNet101 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet101 neural network.
Examples:
>>> net = resnet101(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 23, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)

View File

@ -1,409 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore import context
from src.thor_layer import Conv2d_Thor, Dense_Thor, Conv2d_Thor_GPU, Dense_Thor_GPU
def calculate_gain(nonlinearity, param=None):
"""calculate_gain"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
res = 0
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
res = 1
elif nonlinearity == 'tanh':
res = 5.0 / 3
elif nonlinearity == 'relu':
res = math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
return res
def _calculate_fan_in_and_fan_out(tensor):
"""_calculate_fan_in_and_fan_out"""
dimensions = len(tensor)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor[1]
fan_out = tensor[0]
else:
num_input_fmaps = tensor[1]
num_output_fmaps = tensor[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = tensor[2] * tensor[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
weight_shape = (out_channel, in_channel, 3, 3)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
if context.get_context('device_target') == "Ascend":
layer = Conv2d_Thor(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
else:
layer = Conv2d_Thor_GPU(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
return layer
def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
weight_shape = (out_channel, in_channel, 1, 1)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
if context.get_context('device_target') == "Ascend":
layer = Conv2d_Thor(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
else:
layer = Conv2d_Thor_GPU(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
return layer
def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
weight_shape = (out_channel, in_channel, 7, 7)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
if context.get_context('device_target') == "Ascend":
layer = Conv2d_Thor(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
else:
layer = Conv2d_Thor_GPU(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
return layer
def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _bn_last(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _fc(in_channel, out_channel, damping, loss_scale, frequency, batch_size=32):
weight_shape = (out_channel, in_channel)
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
if context.get_context('device_target') == "Ascend":
layer = Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency,
batch_size=batch_size)
else:
layer = Dense_Thor_GPU(in_channel, out_channel, has_bias=False, weight_init=weight,
bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency,
batch_size=batch_size)
return layer
class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4
def __init__(self,
in_channel,
out_channel,
stride=1,
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32):
super(ResidualBlock, self).__init__()
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1, damping=damping, loss_scale=loss_scale,
frequency=frequency, batch_size=batch_size)
self.bn1 = _bn(channel)
self.conv2 = _conv3x3(channel, channel, stride=stride, damping=damping, loss_scale=loss_scale,
frequency=frequency, batch_size=batch_size)
self.bn2 = _bn(channel)
self.conv3 = _conv1x1(channel, out_channel, stride=1, damping=damping, loss_scale=loss_scale,
frequency=frequency, batch_size=batch_size)
self.bn3 = _bn_last(out_channel)
self.relu = nn.ReLU()
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
damping=damping, loss_scale=loss_scale,
frequency=frequency,
batch_size=batch_size),
_bn(out_channel)])
self.add = P.Add()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = self.add(out, identity)
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes,
damping,
loss_scale,
frequency,
batch_size,
include_top=True):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale,
frequency=frequency, batch_size=batch_size)
self.bn1 = _bn(64)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0],
damping=damping,
loss_scale=loss_scale,
frequency=frequency,
batch_size=batch_size)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1],
damping=damping,
loss_scale=loss_scale,
frequency=frequency,
batch_size=batch_size)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2], damping=damping,
loss_scale=loss_scale,
frequency=frequency,
batch_size=batch_size)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3],
damping=damping,
loss_scale=loss_scale,
frequency=frequency,
batch_size=batch_size)
self.include_top = include_top
if self.include_top:
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale,
frequency=frequency, batch_size=batch_size)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride,
damping, loss_scale, frequency, batch_size):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []
resnet_block = block(in_channel, out_channel, stride=stride,
damping=damping, loss_scale=loss_scale, frequency=frequency,
batch_size=batch_size)
layers.append(resnet_block)
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1,
damping=damping, loss_scale=loss_scale, frequency=frequency,
batch_size=batch_size)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
if not self.include_top:
return x
out = self.mean(c5, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
return out
def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size=32, include_top=True):
"""
Get ResNet50 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet50 neural network.
Examples:
>>> net = resnet50(10)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num,
damping,
loss_scale,
frequency,
batch_size,
include_top)

View File

@ -1,301 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""THOR"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from src.grad_reducer_thor import DistributedGradReducerThor
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")
@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
if if_apply:
return op_add((weight * weight_decay, gradient))
return gradient
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
"""Apply momentum optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
return success
class THOR_GPU(Optimizer):
"""
THOR
"""
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max,
weight_decay=0.0, loss_scale=1.0, use_nesterov=False, decay_filter=lambda x: x.name not in []):
super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale)
Validator.check_value_type("momentum", momentum, [float], self.cls_name)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32))
self.params = self.parameters
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
1.0]
self.feature_map_new = [x ** 0.5 for x in self.feature_map]
self.transpose = P.Transpose()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.matmul = P.MatMul()
self.matrix_A = ParameterTuple(matrix_A)
self.matrix_G = ParameterTuple(matrix_G)
self.A_inv_max = ParameterTuple(A_inv_max)
self.G_inv_max = ParameterTuple(G_inv_max)
self.assign = P.Assign()
self.mul = P.Mul()
mean = _get_gradients_mean()
degree = _get_device_num()
parameter_length = len(self.feature_map)
self.grad_reducer_thorA = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree)
self.grad_reducer_thorG = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree)
self.weight_decay = weight_decay
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.update_gradient = P.UpdateThorGradient(split_dim=128)
def construct(self, gradients):
params = self.params
moments = self.moments
gradients = self.scale_grad(gradients)
new_grads = ()
if self.thor:
matrix_A_allreduce = ()
matrix_G_allreduce = ()
for i in range(54):
g = gradients[i * 3]
matrix_A = self.matrix_A[i]
matrix_G = self.matrix_G[i]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
matrix_A = self.mul(matrix_A, self.feature_map_new[i])
matrix_G = self.mul(matrix_G, self.feature_map_new[i])
matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
matrix_A_allreduce = self.grad_reducer_thorA(matrix_A_allreduce)
matrix_G_allreduce = self.grad_reducer_thorG(matrix_G_allreduce)
for i in range(54):
g = gradients[i * 3]
g_shape = self.shape(g)
g = self.reshape(g, (g_shape[0], -1))
matrix_A = matrix_A_allreduce[i]
matrix_G = matrix_G_allreduce[i]
g = self.update_gradient(matrix_G, g, matrix_A)
fake_A = self.assign(self.matrix_A[i], matrix_A)
fake_G = self.assign(self.matrix_G[i], matrix_G)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
if i == 53:
new_grads = new_grads + (g,)
else:
g = self.reshape(g, g_shape)
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
else:
for i in range(54):
g = gradients[i * 3]
g_shape = self.shape(g)
g = self.reshape(g, (g_shape[0], -1))
matrix_A = self.matrix_A[i]
matrix_G = self.matrix_G[i]
g = self.update_gradient(matrix_G, g, matrix_A)
if i == 53:
new_grads = new_grads + (g,)
else:
g = self.reshape(g, g_shape)
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
gradients = new_grads
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
params, gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
return success
class THOR(Optimizer):
"""THOR"""
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
loss_scale=1.0,
decay_filter=lambda x: x.name not in []):
super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32))
self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
self.matrix_A = ParameterTuple(matrix_A)
self.matrix_G = ParameterTuple(matrix_G)
self.A_inv_max = ParameterTuple(A_inv_max)
self.G_inv_max = ParameterTuple(G_inv_max)
self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
self.transpose = P.Transpose()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.mul = P.Mul()
self.weight_idx = []
for i in range(len(self.params)):
if "conv" in self.params[i].name or "end_point" in self.params[i].name:
self.weight_idx.append(i)
self.weight_idx.append(len(self.params))
self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
1.0]
mean = _get_gradients_mean()
degree = _get_device_num()
parameter_length = len(self.feature_map)
self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree)
self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree)
self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree)
self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree)
self.matrix_A_inv = ()
self.matrix_G_inv = ()
self.matrix_max_inv = ()
for i in range(54):
self.matrix_max_inv = self.matrix_max_inv + (
Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
self.assign = P.Assign()
self.cast = P.Cast()
self.thor = True
self.weight_decay = weight_decay * loss_scale
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
def construct(self, gradients):
params = self.params
moments = self.moments
if self.thor:
matrix_A_allreduce = ()
matrix_G_allreduce = ()
matrix_A_max_allreduce = ()
matrix_G_max_allreduce = ()
for i in range(54):
g = gradients[i * 3]
matrix_A = self.matrix_A[i]
matrix_G = self.matrix_G[i]
A_max = self.A_inv_max[i]
G_max = self.G_inv_max[i]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
A_max = F.depend(A_max, g)
G_max = F.depend(G_max, g)
matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
matrix_A_max_allreduce = matrix_A_max_allreduce + (A_max,)
matrix_G_max_allreduce = matrix_G_max_allreduce + (G_max,)
matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)
new_grads = ()
for i in range(54):
g = gradients[i * 3]
temp_a = matrix_A_allreduce[i]
temp_g = matrix_G_allreduce[i]
temp_a = self.cast(temp_a, mstype.float32)
temp_g = self.cast(temp_g, mstype.float32)
matrix_A_inv_max = self.log(matrix_A_max_allreduce[i])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(matrix_G_max_allreduce[i])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
temp_max = self.mul(temp_max, self.feature_map[i])
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
if i == 53:
g = self.cube_matmul_left_fc(temp_g, g)
g = self.cube_matmul_right_fc(g, temp_a, temp_max)
else:
g = self.cube_matmul_left(temp_g, g)
g = self.cube_matmul_right_mul(g, temp_a, temp_max)
fake_A = self.assign(self.matrix_A[i], temp_a)
fake_G = self.assign(self.matrix_G[i], temp_g)
fake_max = self.assign(self.matrix_max_inv[i], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
if i == 53:
new_grads = new_grads + (g,)
else:
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
gradients = new_grads
else:
new_grads = ()
for i in range(54):
g = gradients[i * 3]
matrix_A = self.matrix_A[i]
matrix_G = self.matrix_G[i]
matrix_max = self.matrix_max_inv[i]
if i == 53:
g = self.cube_matmul_left_fc(matrix_G, g)
g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
new_grads = new_grads + (g,)
else:
g = self.cube_matmul_left(matrix_G, g)
g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
gradients = new_grads
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
params, gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
return success

View File

@ -1,771 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""thor_layer"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator, twice
from mindspore._extends import cell_attr_register
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import operations as P
C0 = 16
def caculate_device_shape(matrix_dim, channel, is_A):
ll = (0)
if is_A:
if channel // C0 == 0:
matrix_dim = (matrix_dim / channel) * C0
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
else:
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
return ll
def caculate_matmul_shape(matrix_A_dim, matrix_G_dim, split_dim):
split_dimA = split_dim
split_dimG = split_dim
if matrix_A_dim % split_dim == 0:
batch_w = matrix_A_dim // split_dim
else:
if matrix_A_dim < split_dim:
batch_w = 1
split_dimA = matrix_A_dim
else:
batch_w = matrix_A_dim // split_dim + 1
if matrix_G_dim % split_dim == 0:
batch_h = matrix_G_dim // split_dim
else:
if matrix_G_dim < split_dim:
batch_h = 1
split_dimG = matrix_G_dim
else:
batch_h = matrix_G_dim // split_dim + 1
matrix_A_shape = (batch_h, batch_w, split_dimA, split_dimA)
matrix_G_shape = (batch_h, split_dimG, split_dimG)
return matrix_A_shape, matrix_G_shape
class _Conv(Cell):
r"""Applies a N-D convolution over an input signal composed of several input
planes.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
data_format,
has_bias,
weight_init,
bias_init,
):
super(_Conv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.pad_mode = pad_mode
self.padding = padding
self.dilation = dilation
self.group = group
self.data_format = data_format
self.has_bias = has_bias
if not (isinstance(in_channels, int) and in_channels > 0):
raise ValueError('Attr \'in_channels\' of \'Conv2D\' Op passed '
+ str(in_channels) + ', should be a int and greater than 0.')
if (not isinstance(kernel_size, tuple)) or len(kernel_size) != 2 or \
(not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
kernel_size[0] < 1 or kernel_size[1] < 1:
raise ValueError('Attr \'kernel_size\' of \'Conv2D\' Op passed '
+ str(self.kernel_size) + ', should be a int or tuple and equal to or greater than 1.')
if in_channels % group != 0:
raise ValueError('Attr \'in_channels\' of \'Conv2D\' Op must be divisible by '
'attr \'group\' of \'Conv2D\' Op.')
if out_channels % group != 0:
raise ValueError('Attr \'out_channels\' of \'Conv2D\' Op must be divisible by '
'attr \'group\' of \'Conv2D\' Op.')
self.weight = Parameter(initializer(
weight_init, [out_channels, in_channels // group, *kernel_size]))
if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]))
else:
if bias_init != 'zeros':
logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.")
self.bias = None
def construct(self, *inputs):
raise NotImplementedError
class Conv2d_Thor_GPU(_Conv):
"""Conv2d_Thor"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
data_format='NCHW',
has_bias=False,
weight_init='normal',
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32,
bias_init='zeros'):
self.thor = True
self.hw = kernel_size * kernel_size
kernel_size = twice(kernel_size)
super(Conv2d_Thor_GPU, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
data_format,
has_bias,
weight_init,
bias_init,
)
self.conv2d = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group
)
self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
self.matrix_G_dim = self.out_channels
split_dim = 128
matrix_A_shape, matrix_G_shape = caculate_matmul_shape(self.matrix_A_dim, self.matrix_G_dim, split_dim)
self.matrix_A_inv = Parameter(np.zeros(matrix_A_shape).astype(np.float32), requires_grad=False)
self.matrix_G_inv = Parameter(np.zeros(matrix_G_shape).astype(np.float32), requires_grad=False)
self.broadcast_to = P.BroadcastTo(matrix_A_shape)
self.cov_step = Parameter(initializer(0, [1], mstype.int32), requires_grad=False)
self.img2col = P.Im2Col(kernel_size=kernel_size, stride=stride, pad_mode="same")
self.matmul = P.MatMul(transpose_b=True)
self.shape = P.Shape()
self.reshape = P.Reshape()
self.mul = P.Mul()
self.getG = P.InsertGradientOf(self.save_gradient)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.batch_size = Tensor(batch_size, mstype.float16)
self.transpose = P.Transpose()
self.cast = P.Cast()
self.gather = P.Gather()
self.freq = Tensor(frequency, mstype.int32)
self.axis = 0
self.sqrt = P.Sqrt()
self.reduce_mean = P.ReduceMean(keep_dims=False)
self.damping = Parameter(Tensor(damping), requires_grad=False)
self.dampingA = Tensor(np.identity(self.matrix_A_dim), mstype.float32)
self.dampingG = Tensor(np.identity(self.matrix_G_dim), mstype.float32)
self.cholesky = P.CholeskyTrsm(split_dim=split_dim)
self.vector_matmul = P.BatchMatMul(transpose_a=True)
def save_gradient(self, dout):
"""save_gradient"""
out = dout
dout = self.mul(dout, self.loss_scale)
dout = self.mul(dout, self.batch_size)
dout = self.reduce_mean(dout, 0)
dout_shape = self.shape(dout)
dout = self.reshape(dout, (dout_shape[0], -1))
dout_shape = self.shape(dout)
normalizer = dout_shape[1]
dout = self.cast(dout, mstype.float32)
matrix_G = self.matmul(dout, dout)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, 0)
damping_step = self.cast(damping_step, mstype.float32)
self.cov_step = self.cov_step + self.freq
damping = self.mul(damping_step, 1.0 / normalizer)
damping = self.sqrt(damping)
matrix_G = matrix_G + damping * self.dampingG
matrix_G = self.cholesky(matrix_G)
matrix_G = self.vector_matmul(matrix_G, matrix_G)
self.matrix_G_inv = matrix_G
return out
def construct(self, x):
if self.thor:
matrix_A = self.img2col(x)
matrix_A_shape = self.shape(matrix_A)
matrix_A = self.reshape(matrix_A, (matrix_A_shape[0]*matrix_A_shape[1]*matrix_A_shape[2],
matrix_A_shape[3], -1))
matrix_A = self.reduce_mean(matrix_A, 1)
matrix_A_shape = self.shape(matrix_A)
normalizer = matrix_A_shape[1]
matrix_A = self.cast(matrix_A, mstype.float32)
matrix_A = self.matmul(matrix_A, matrix_A)
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
damping = self.mul(damping_step, 1.0 / normalizer)
damping = self.sqrt(damping)
matrix_A = matrix_A + damping * self.dampingA
matrix_A = self.cholesky(matrix_A)
matrix_A = self.vector_matmul(matrix_A, matrix_A)
matrix_A = self.broadcast_to(matrix_A)
self.matrix_A_inv = matrix_A
out = self.conv2d(x, self.weight)
out = self.getG(out)
else:
out = self.conv2d(x, self.weight)
return out
def extra_repr(self):
"""extra_repr"""
s = 'input_channels={}, output_channels={}, kernel_size={},' \
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
'group={}, data_format={}, has_bias={},' \
'weight_init={}, bias_init={}'.format(
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.pad_mode,
self.padding,
self.dilation,
self.group,
self.data_format,
self.has_bias,
self.weight,
self.bias)
if self.has_bias:
s += ', bias={}'.format(self.bias)
return s
class Dense_Thor_GPU(Cell):
"""Dense_Thor"""
@cell_attr_register(attrs=['has_bias', 'activation'])
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32,
has_bias=True,
activation=None):
super(Dense_Thor_GPU, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
self.has_bias = Validator.check_bool(has_bias)
self.thor = True
if isinstance(weight_init, Tensor):
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]))
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]))
self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd()
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
split_dim = 128
matrix_A_shape, matrix_G_shape = caculate_matmul_shape(self.in_channels, self.out_channels, split_dim)
self.matrix_A_inv = Parameter(Tensor(np.zeros(matrix_A_shape).astype(np.float32)), requires_grad=False)
self.matrix_G_inv = Parameter(Tensor(np.zeros(matrix_G_shape).astype(np.float32)), requires_grad=False)
self.broadcast_to = P.BroadcastTo(matrix_A_shape)
self.cov_step = Parameter(initializer(0, [1], mstype.int32), requires_grad=False)
self.shape = P.Shape()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.mul = P.Mul()
self.cube_matmul = P.MatMul(transpose_a=True)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.batch_size = Tensor(batch_size, mstype.float16)
self.getG = P.InsertGradientOf(self.save_gradient)
self.damping = Parameter(Tensor(damping), requires_grad=False)
self.dampingA = Tensor(np.identity(in_channels), mstype.float32)
self.dampingG = Tensor(np.identity(out_channels), mstype.float32)
self.cast = P.Cast()
self.gather = P.Gather()
self.freq = Tensor(frequency, mstype.int32)
self.axis = 0
self.add = P.Add()
self.sqrt = P.Sqrt()
self.cholesky = P.CholeskyTrsm(split_dim=split_dim)
self.vector_matmul = P.BatchMatMul(transpose_a=True)
def save_gradient(self, dout):
"""save_gradient"""
out = dout
dout = self.mul(dout, self.loss_scale)
dout = self.mul(dout, self.batch_size)
dout_shape = self.shape(dout)
normalizer = dout_shape[0]
dout = self.cast(dout, mstype.float32)
matrix_G = self.cube_matmul(dout, dout)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, 0)
damping_step = self.cast(damping_step, mstype.float32)
self.cov_step = self.cov_step + self.freq
damping = self.sqrt(damping_step)
matrix_G = matrix_G + damping * self.dampingG
matrix_G = self.cholesky(matrix_G)
matrix_G = self.vector_matmul(matrix_G, matrix_G)
self.matrix_G_inv = matrix_G
return out
def construct(self, x):
"""construct"""
if self.thor:
inputs = self.cast(x, mstype.float32)
inputs = self.cube_matmul(inputs, inputs)
inputs_shape = self.shape(inputs)
normalizer = inputs_shape[0]
matrix_A = self.mul(inputs, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
damping = self.sqrt(damping_step)
matrix_A = matrix_A + damping * self.dampingA
matrix_A = self.cholesky(matrix_A)
matrix_A = self.vector_matmul(matrix_A, matrix_A)
matrix_A = self.broadcast_to(matrix_A)
self.matrix_A_inv = matrix_A
output = self.matmul(x, self.weight)
output = self.getG(output)
else:
output = self.matmul(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
if self.activation_flag:
return self.activation(output)
return output
def extend_repr(self):
"""extend_repr"""
s = 'in_channels={}, out_channels={}'.format(self.in_channels, self.out_channels)
if self.has_bias:
s += ', has_bias={}'.format(self.has_bias)
if self.activation_flag:
s += ', activation={}'.format(self.activation)
return s
class Conv2d_Thor(_Conv):
"""Conv2d_Thor"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
data_format='NCHW',
has_bias=False,
weight_init='normal',
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32,
bias_init='zeros'):
self.thor = True
ksizes = (1, kernel_size, kernel_size, 1)
self.hw = kernel_size * kernel_size
strides = (1, stride, stride, 1)
kernel_size = twice(kernel_size)
super(Conv2d_Thor, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
data_format,
has_bias,
weight_init,
bias_init,
)
self.conv2d = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group
)
self.batch_size = batch_size
self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.matrix_combine = P.CusMatrixCombine()
self.cholesky = P.CusCholeskyTrsm()
self.transpose02314 = P.CusTranspose02314()
self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
self.matrix_G_dim = self.out_channels
self.matrix_A_device_shape, self.matrix_A_device_dim = caculate_device_shape(self.matrix_A_dim,
self.in_channels, True)
self.matrix_G_device_shape, self.matrix_G_device_dim = caculate_device_shape(self.matrix_G_dim,
self.in_channels, False)
self.matrix_A_device_temp_shape = (
self.matrix_A_device_shape[0], self.matrix_A_device_shape[2], self.matrix_A_device_shape[1],
self.matrix_A_device_shape[3])
self.matrix_G_device_temp_shape = (
self.matrix_G_device_shape[0], self.matrix_G_device_shape[2], self.matrix_G_device_shape[1],
self.matrix_G_device_shape[3])
self.matrix_A_inv = Parameter(
Tensor(np.reshape(np.identity(self.matrix_A_device_dim).astype(np.float16), self.matrix_A_device_shape)),
requires_grad=False)
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), requires_grad=False)
self.matrix_G_inv = Parameter(
Tensor(np.reshape(np.identity(self.matrix_G_device_dim).astype(np.float16), self.matrix_G_device_shape)),
requires_grad=False)
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), requires_grad=False)
self.fake_G = Tensor(
np.reshape(np.identity(self.matrix_G_device_dim).astype(np.float16), self.matrix_G_device_shape))
self.shape = P.Shape()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.cov_step = Parameter(initializer(0, [1], mstype.int32), requires_grad=False)
self.mul = P.Mul()
self.cast = P.Cast()
self.damping = Tensor(damping)
self.vector_matmul = P.CusBatchMatMul()
self.diag_block_dim = 128
self.channels_slice_flag = False
if self.in_channels % C0 != 0:
self.channels_slice_flag = True
self.padA_flag = False
if (self.matrix_A_dim // self.diag_block_dim) * self.diag_block_dim != self.matrix_A_dim \
and self.matrix_A_dim > self.diag_block_dim:
self.padA_flag = True
pad_dim = self.diag_block_dim - self.matrix_A_dim % self.diag_block_dim
self.padA = P.Pad(((0, pad_dim), (0, pad_dim)))
self.device_shape_pad_flag = False
if self.matrix_A_dim != self.matrix_A_device_dim:
self.device_shape_pad_flag = True
self.device_shape_pad = P.Pad(((0, 0), (0, C0 - self.in_channels), (0, 0), (0, C0 - self.in_channels)))
self.slice = P.Slice()
self.gather = P.Gather()
self.freq = Tensor(frequency, mstype.int32)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.axis = 0
dampingA_dim = self.matrix_A_dim
if (self.matrix_A_dim % self.diag_block_dim) != 0 and self.matrix_A_dim > self.diag_block_dim:
dampingA_dim = (self.matrix_A_dim // self.diag_block_dim + 1) * self.diag_block_dim
dampingG_dim = self.matrix_G_dim
if (self.matrix_G_dim % self.diag_block_dim) != 0 and self.matrix_G_dim > self.diag_block_dim:
dampingG_dim = (self.matrix_G_dim // self.diag_block_dim + 1) * self.diag_block_dim
self.dampingA = Tensor(np.identity(dampingA_dim), mstype.float32)
self.dampingG = Tensor(np.identity(dampingG_dim), mstype.float32)
self.fused_abs_max1 = P.CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim])
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
self.getG = P.InsertGradientOf(self.save_gradient)
def save_gradient(self, dout):
"""save_gradient"""
out = dout
dout = self.mul(dout, self.loss_scale)
dout = self.mul(dout, 32.0)
dout = self.transpose02314(dout)
dout_shape = self.shape(dout)
normalizer = dout_shape[0]
matrix_G = self.cube_matmul(dout, dout)
normalizer = self.cast(normalizer, mstype.float32)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, 0)
self.cov_step = self.cov_step + self.freq
damping_step = self.cast(damping_step, mstype.float32)
damping = self.mul(damping_step, 32.0 / normalizer)
damping = self.sqrt(damping)
dampingG = self.cast(self.dampingG, mstype.float32)
matrix_G = matrix_G + damping * dampingG
matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv_max)
self.G_inv_max = matrix_G_inv_max
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv = self.reshape(matrix_G_inv, self.matrix_G_device_temp_shape)
matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))
matrix_G = self.cast(matrix_G_inv, mstype.float16)
self.matrix_G_inv = matrix_G
return out
def construct(self, x):
if self.thor:
matrix_A = self.img2col(x)
matrix_A_shape = self.shape(matrix_A)
normalizer = matrix_A_shape[0]
matrix_A = self.cube_matmul(matrix_A, matrix_A)
if self.channels_slice_flag:
matrix_A = self.reshape(matrix_A, (self.hw, C0, self.hw, C0))
matrix_A = self.slice(matrix_A, (0, 0, 0, 0), (self.hw, self.in_channels, self.hw, self.in_channels))
matrix_A = self.reshape(matrix_A, (self.matrix_A_dim, self.matrix_A_dim))
normalizer = self.cast(normalizer, mstype.float32)
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
if self.padA_flag:
matrix_A = self.padA(matrix_A)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
damping = self.mul(damping_step, 32.0 / normalizer)
damping = self.sqrt(damping)
damping_A = self.cast(self.dampingA, mstype.float32)
matrix_A = matrix_A + damping * damping_A
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max1(matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv_max)
self.A_inv_max = matrix_A_inv_max
matrix_A_inv = self.matrix_combine(matrix_A_inv)
matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
if self.padA_flag:
matrix_A_inv = self.slice(matrix_A_inv, (0, 0), (self.matrix_A_dim, self.matrix_A_dim))
if self.device_shape_pad_flag:
matrix_A_inv = self.reshape(matrix_A_inv, (self.hw, self.in_channels, self.hw, self.in_channels))
matrix_A_inv = self.device_shape_pad(matrix_A_inv)
matrix_A_inv = self.reshape(matrix_A_inv, self.matrix_A_device_temp_shape)
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
self.matrix_A_inv = matrix_A_inv
out = self.conv2d(x, self.weight)
out = self.getG(out)
else:
out = self.conv2d(x, self.weight)
return out
def extra_repr(self):
"""extra_repr"""
s = 'input_channels={}, output_channels={}, kernel_size={},' \
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
'group={}, data_format={}, has_bias={},' \
'weight_init={}, bias_init={}'.format(
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.pad_mode,
self.padding,
self.dilation,
self.group,
self.data_format,
self.has_bias,
self.weight,
self.bias)
if self.has_bias:
s += ', bias={}'.format(self.bias)
return s
class Dense_Thor(Cell):
"""Dense_Thor"""
@cell_attr_register(attrs=['has_bias', 'activation'])
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32,
has_bias=True,
activation=None):
super(Dense_Thor, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
self.has_bias = Validator.check_bool(has_bias)
self.thor = True
self.batch_size = batch_size
if isinstance(weight_init, Tensor):
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]))
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]))
self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd()
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
self.matrix_A_inv = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)), requires_grad=False)
self.matrix_G_inv = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)), requires_grad=False)
self.fake_G = Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16))
self.matmul = P.MatMul(transpose_b=True)
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.matrix_combine = P.CusMatrixCombine()
self.cholesky = P.CusCholeskyTrsm()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.cov_step = Parameter(initializer(0, [1], mstype.int32), requires_grad=False)
self.mul = P.Mul()
self.cast = P.Cast()
self.damping = Tensor(damping)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.vector_matmul = P.CusBatchMatMul()
self.pad = P.Pad(((0, 23), (0, 23)))
self.pad1 = P.Pad(((0, 7), (0, 7)))
self.slice = P.Slice()
self.gather = P.Gather()
self.assignadd = P.AssignAdd()
self.freq = Tensor(frequency, mstype.int32)
self.axis = 0
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), requires_grad=False)
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), requires_grad=False)
self.fused_abs_max1 = P.CusFusedAbsMax1([1001, 1001])
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.log = P.Log()
self.exp = P.Exp()
self.dampingA = Tensor(np.identity(2048), mstype.float32)
self.dampingG = Tensor(np.identity(1024), mstype.float32)
self.add = P.Add()
self.sqrt = P.Sqrt()
self.getG = P.InsertGradientOf(self.save_gradient)
def save_gradient(self, dout):
"""save_gradient"""
out = dout
dout = self.mul(dout, self.loss_scale)
dout = self.mul(dout, 32.0)
normalizer = 32
matrix_G = self.cube_matmul(dout, dout)
normalizer = self.cast(normalizer, mstype.float32)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
matrix_G = self.pad(matrix_G)
damping_step = self.gather(self.damping, self.cov_step, 0)
damping_step = self.cast(damping_step, mstype.float32)
self.cov_step = self.cov_step + self.freq
damping = self.sqrt(damping_step)
dampingG = self.cast(self.dampingG, mstype.float32)
matrix_G = matrix_G + damping * dampingG
matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max1(matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv_max)
self.G_inv_max = matrix_G_inv_max
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1001, 1001))
matrix_G_inv = self.pad1(matrix_G_inv)
matrix_G_inv_shape = self.shape(matrix_G_inv)
matrix_G_inv = self.reshape(matrix_G_inv, (matrix_G_inv_shape[0] / 16, 16, matrix_G_inv_shape[0] / 16, 16))
matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))
matrix_G_inv = self.cast(matrix_G_inv, mstype.float16)
self.matrix_G_inv = matrix_G_inv
return out
def construct(self, x):
"""construct"""
if self.thor:
inputs = self.cube_matmul(x, x)
normalizer = 32
normalizer = self.cast(normalizer, mstype.float32)
matrix_A = self.mul(inputs, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
damping = self.sqrt(damping_step)
dampingA = self.cast(self.dampingA, mstype.float32)
matrix_A = matrix_A + damping * dampingA
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv_max)
self.A_inv_max = matrix_A_inv_max
matrix_A_inv = self.matrix_combine(matrix_A_inv)
matrix_A_inv_shape = self.shape(matrix_A_inv)
matrix_A_inv = self.reshape(matrix_A_inv, (matrix_A_inv_shape[0] / 16, 16, matrix_A_inv_shape[0] / 16, 16))
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
self.matrix_A_inv = matrix_A_inv
output = self.matmul(x, self.weight)
output = self.getG(output)
else:
output = self.matmul(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
if self.activation_flag:
return self.activation(output)
return output
def extend_repr(self):
"""extend_repr"""
s = 'in_channels={}, out_channels={}'.format(self.in_channels, self.out_channels)
if self.has_bias:
s += ', has_bias={}'.format(self.has_bias)
if self.activation_flag:
s += ', activation={}'.format(self.activation)
return s

View File

@ -24,11 +24,14 @@ from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.parallel import set_algo_parameters
from mindspore.train.train_thor import ConvertModelUtils
from mindspore.nn.optim import thor
from mindspore.train.model import Model
from src.model_thor import Model_Thor as Model
from src.resnet_thor import resnet50
from src.resnet import resnet50 as resnet
from src.dataset import create_dataset
from src.crossentropy import CrossEntropy
from src.crossentropy import CrossEntropy as CrossEntropySmooth
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
@ -38,16 +41,32 @@ parser.add_argument('--device_num', type=int, default=1, help='Device num')
args_opt = parser.parse_args()
if args_opt.device_target == "Ascend":
from src.thor import THOR
from src.config import config
else:
from src.thor import THOR_GPU as THOR
from src.config import config_gpu as config
set_seed(1)
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
"""remove useless parameters according to filter_list"""
for key in list(origin_dict.keys()):
for name in param_filter:
if name in key:
print("Delete parameter from checkpoint: ", key)
del origin_dict[key]
break
def apply_eval(eval_param):
eval_model = eval_param["model"]
eval_ds = eval_param["dataset"]
metrics_name = eval_param["metrics_name"]
res = eval_model.eval(eval_ds)
return res[metrics_name]
def get_thor_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
"""get_model_lr"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
@ -66,7 +85,7 @@ def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, dec
return learning_rate
def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
def get_thor_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
"""get_model_damping"""
damping_each_step = []
total_steps = steps_per_epoch * total_epochs
@ -88,46 +107,50 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
if args_opt.run_distribute:
# Ascend target
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, all_reduce_fusion_config=[107])
gradients_mean=True)
set_algo_parameters(elementwise_op_strategy_follow=True)
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
init()
# GPU target
else:
init()
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, all_reduce_fusion_config=[107])
ckpt_save_dir = ckpt_save_dir + "ckpt_" + str(get_rank()) + "/"
gradients_mean=True)
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=config.batch_size, target=target)
step_size = dataset.get_dataset_size()
# define net
step_size = dataset.get_dataset_size()
damping = get_model_damping(0, config.damping_init, config.damping_decay, 70, step_size)
lr = get_model_lr(0, config.lr_init, config.lr_decay, config.lr_end_epoch, step_size, decay_epochs=39)
net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale,
frequency=config.frequency, batch_size=config.batch_size)
net = resnet(class_num=config.class_num)
# define loss, model
# init lr
lr = get_thor_lr(0, config.lr_init, config.lr_decay, config.lr_end_epoch, step_size, decay_epochs=39)
lr = Tensor(lr)
# define loss
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), config.momentum,
filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
config.weight_decay, config.loss_scale)
loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,
keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency,
use_dynamic_frequency=config.use_dynamic_frequency,
first_stage_steps=config.first_stage_steps)
metrics = {"acc"}
damping = get_thor_damping(0, config.damping_init, config.damping_decay, 70, step_size)
split_indices = [26, 53]
opt = thor(net, lr, Tensor(damping), config.momentum, config.weight_decay, config.loss_scale,
config.batch_size, split_indices=split_indices, frequency=config.frequency)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics,
amp_level="O2", keep_batchnorm_fp32=False)
model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
@ -140,4 +163,6 @@ if __name__ == '__main__':
cb += [ckpt_cb]
# train model
model.train(config.epoch_size, dataset, callbacks=cb)
dataset_sink_mode = True
model.train(config.epoch_size, dataset, callbacks=cb,
sink_size=dataset.get_dataset_size(), dataset_sink_mode=dataset_sink_mode)