forked from mindspore-Ecosystem/mindspore
!11915 Change TensorAdd to Add, merge from r1.1 to master
From: @liangzhibo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
Commit: e897eb4c41

File diff suppressed because one or more lines are too long
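At the user level the change is only the operator name: everything that was emitted and registered as TensorAdd is now Add. A minimal sketch of the migration (assuming the mindspore.ops primitives of this release; TensorAdd may survive only as a deprecated alias, and exact behavior depends on the version):

import mindspore as ms
from mindspore import ops, Tensor

# Before this change: add = ops.TensorAdd()
# After this change the same elementwise primitive is named Add.
add = ops.Add()
x = Tensor([1.0, 2.0, 3.0], ms.float32)
y = Tensor([4.0, 5.0, 6.0], ms.float32)
print(add(x, y))  # expected output: [5. 7. 9.]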
@@ -36,24 +36,24 @@ def expand_biasadd(expand_info):
 'ExpandDims', [input_y], attrs={'axis': 1})
 input_y_expand = graph_builder.emit(
 'ExpandDims', [input_y_expand], attrs={'axis': 2})
-result = graph_builder.emit('TensorAdd', [input_x, input_y_expand])
+result = graph_builder.emit('Add', [input_x, input_y_expand])
 elif input_x.data_format == "DefaultFormat":
 if len(input_x.shape) == 2:
-result = graph_builder.emit('TensorAdd', [input_x, input_y])
+result = graph_builder.emit('Add', [input_x, input_y])
 elif len(input_x.shape) == 3:
 input_y_expand = graph_builder.emit(
 'ExpandDims', [input_y], attrs={'axis': 1})
 result = graph_builder.emit(
-'TensorAdd', [input_x, input_y_expand])
+'Add', [input_x, input_y_expand])
 else:
 input_y_expand = graph_builder.emit(
 'ExpandDims', [input_y], attrs={'axis': 1})
 input_y_expand = graph_builder.emit(
 'ExpandDims', [input_y_expand], attrs={'axis': 2})
 result = graph_builder.emit(
-'TensorAdd', [input_x, input_y_expand])
+'Add', [input_x, input_y_expand])
 else:
-result = graph_builder.emit('TensorAdd', [input_x, input_y])
+result = graph_builder.emit('Add', [input_x, input_y])

 # set graph output.
 graph_scope.set_output(result)
@@ -49,13 +49,13 @@ def expand_fusedadam(expand_info):
 # compute result
 beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
 one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
-next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
+next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
 beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
 grad_square = graph_builder.emit('Mul', [gradient, gradient])
 one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
-next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
+next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
 sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
-sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
+sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
 update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
 update_with_lr = graph_builder.emit('Mul', [lr, update])
 next_para = graph_builder.emit('Sub', [param, update_with_lr])
@@ -52,16 +52,16 @@ def expand_fusedadamweightdecay(expand_info):
 # compute result
 beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
 one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
-next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
+next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
 beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
 grad_square = graph_builder.emit('Mul', [gradient, gradient])
 one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
-next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
+next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
 sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
-sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
+sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
 update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
 param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param])
-update = graph_builder.emit('TensorAdd', [update, param_with_weight_decay])
+update = graph_builder.emit('Add', [update, param_with_weight_decay])
 update_with_lr = graph_builder.emit('Mul', [lr, update])
 next_para = graph_builder.emit('Sub', [param, update_with_lr])

@@ -42,7 +42,7 @@ def expand_gelu(expand_info):
 pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
 const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
 mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
-tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
+tanh_res = graph_builder.emit('Add', [input_x, mul_1])
 const_csvalue_sqrt_two_div_pi = graph_builder.value(
 tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format'])
 y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
@@ -51,7 +51,7 @@ def expand_gelu(expand_info):
 tanh_y = graph_builder.emit('Tanh', [y])
 const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format'])
 const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format'])
-tanh_y_add_one = graph_builder.emit('TensorAdd', [tanh_y, const_one])
+tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
 mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
 result = graph_builder.emit('Mul', [const_half, mul_x])

@@ -55,18 +55,18 @@ def expand_gelugrad(expand_info):
 # cal mul_right
 mul_double = graph_builder.emit('Mul', [input_x, input_x])
 mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
-mul_add_one = graph_builder.emit('TensorAdd', [const_one, mul_double_mul_tri])
+mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri])
 mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])

 # cal tanh_para
 mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
 mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
-mul_add_x = graph_builder.emit('TensorAdd', [input_x, mul_triple_mul_csvalue])
+mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue])
 tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])

 # cal 0.5 * (1.0 + tanh(tahn_para))
 tanh_res = graph_builder.emit('Tanh', [tanh_para])
-tanh_res_add_one = graph_builder.emit('TensorAdd', [const_one, tanh_res])
+tanh_res_add_one = graph_builder.emit('Add', [const_one, tanh_res])
 half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one])

 # cal 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
@@ -77,7 +77,7 @@ def expand_gelugrad(expand_info):
 mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right])

 # cal result
-result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final])
+result_tmp = graph_builder.emit('Add', [half_mul_tanh_res_add_one, mul_final])
 result = graph_builder.emit('Mul', [input_dy, result_tmp])

 # set graph output.
@@ -68,13 +68,13 @@ def expand_layernorm(expand_info):
 # Calculate normalize
 normalize_sub = graph_builder.emit('Sub', [input_x, mean])
 epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format)
-normalize_add = graph_builder.emit('TensorAdd', [variance, epsilon_v])
+normalize_add = graph_builder.emit('Add', [variance, epsilon_v])
 normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add])
 normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])

 # Calculate scale and translate
 scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
-res = graph_builder.emit('TensorAdd', [scale_mul, input_beta])
+res = graph_builder.emit('Add', [scale_mul, input_beta])

 # set graph output.
 graph_scope.set_output(res, mean, variance)
@@ -66,7 +66,7 @@ def expand_layernormgrad(expand_info):
 mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format)

 # cal dg db
-var_eps = graph_builder.emit('TensorAdd', [variance, eps])
+var_eps = graph_builder.emit('Add', [variance, eps])
 sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
 rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
 x_sub_mean = graph_builder.emit('Sub', [x, mean])
@@ -100,10 +100,10 @@ def expand_layernormgrad(expand_info):
 neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
 sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
 mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
-add_tmp = graph_builder.emit('TensorAdd', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
+add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
 dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
-dx_tmp = graph_builder.emit('TensorAdd', [dx_1, dx_2])
+dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
-dx = graph_builder.emit('TensorAdd', [dx_tmp, dx_3])
+dx = graph_builder.emit('Add', [dx_tmp, dx_3])

 # set graph output.
 graph_scope.set_output(dx, dg, db)
@@ -131,7 +131,7 @@ class PrimLib:
 ]

 primtives = {
-'TensorAdd': Prim(ELEMWISE),
+'Add': Prim(ELEMWISE),
 'Abs': Prim(ELEMWISE),
 'Neg': Prim(ELEMWISE),
 'Mul': Prim(ELEMWISE),
@@ -238,7 +238,7 @@ void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out,
 void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
 MS_EXCEPTION_IF_NULL(kernel_node);
 std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-if (kernel_name == prim::kPrimTensorAdd->name()) {
+if (kernel_name == prim::kPrimAdd->name()) {
 operate_type_ = ADD;
 } else if (kernel_name == prim::kPrimSub->name()) {
 operate_type_ = SUB;
@@ -37,8 +37,7 @@ class TensorAddCPUKernel : public MKLCPUKernel {
 };

 MS_REG_CPU_KERNEL(
-TensorAdd,
-KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
 TensorAddCPUKernel);
 } // namespace kernel
 } // namespace mindspore
@@ -51,8 +51,7 @@ MS_REG_GPU_KERNEL_ONE(
 Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
 BroadcastOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
-TensorAdd,
-KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
 BroadcastOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
 FloorDiv,
@@ -103,8 +102,7 @@ MS_REG_GPU_KERNEL_ONE(
 Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
 BroadcastOpGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(
-TensorAdd,
-KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+Add, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
 BroadcastOpGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(
 FloorDiv,
@@ -133,7 +131,7 @@ MS_REG_GPU_KERNEL_ONE(
 Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
 BroadcastOpGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(
-TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
 BroadcastOpGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(
 Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
@@ -171,7 +169,7 @@ MS_REG_GPU_KERNEL_ONE(
 Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
 BroadcastOpGpuKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(
-TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+Add, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
 BroadcastOpGpuKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(
 Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
@@ -145,7 +145,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
 static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {
 {"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
 {"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
-{"TensorAdd", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
+{"Add", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
 {"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN},
 };

@@ -1063,7 +1063,7 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i

 std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
 static std::map<std::string, std::string> buffer_fussion_op_map = {
-{parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}, {parallel::TENSOR_ADD, parallel::ADD}};
+{parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}};
 string result = origin_type;
 auto iter = buffer_fussion_op_map.find(origin_type);
 if (iter != buffer_fussion_op_map.end()) {
@@ -99,7 +99,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
 AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) {
 auto eltwise_input = cnode->input(1);
 MS_EXCEPTION_IF_NULL(eltwise_input);
-if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
+if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) {
 MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
 }
 }
@@ -28,7 +28,7 @@ const BaseRef AdamApplyOneFusion::DefinePattern() const {
 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
 return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
 }

@@ -41,7 +41,7 @@ const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const {
 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
+VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
 return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
 }

@@ -54,7 +54,7 @@ const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const {
 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
 return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
 }

@@ -67,7 +67,7 @@ const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const {
 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
 return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
 }

@@ -80,7 +80,7 @@ const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
+VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
 return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
 }

@@ -94,7 +94,7 @@ const BaseRef AdamApplyOneAssignFusion::DefinePattern() const {
 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
 VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
 VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
 VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@@ -114,7 +114,7 @@ const BaseRef AdamApplyOneAssignCond1Fusion::DefinePattern() const {
 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
+VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
 VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
 VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
 VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@@ -134,7 +134,7 @@ const BaseRef AdamApplyOneAssignCond2Fusion::DefinePattern() const {
 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
 VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
 VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
 VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@@ -154,7 +154,7 @@ const BaseRef AdamApplyOneAssignCond3Fusion::DefinePattern() const {
 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
 VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
 VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
 VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@@ -174,7 +174,7 @@ const BaseRef AdamApplyOneAssignCond4Fusion::DefinePattern() const {
 VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
 VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
 VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
+VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
 VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
 VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
 VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@@ -38,8 +38,8 @@ class AdamApplyOneFusion : public PatternProcessPass {
 mul_x_input_vars_.push_back(std::make_shared<Var>());
 }
 add2_y_ = std::make_shared<Var>();
-add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
-add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
 sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name()));
 }

@@ -59,10 +59,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const {
 VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
 VectorRef add1({add1_var_, mul2, mul3});
 VectorRef sqrt0({sqrt, add1});
-VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
+VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
 VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
 VectorRef real_div0({real_div, add0, add2});
-VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+VectorRef add3({prim::kPrimAdd, mul4, real_div0});
 VectorRef mul5({prim::kPrimMul, input4_, add3});
 VectorRef sub0({prim::kPrimSub, input3_, mul5});
 return sub0;
@@ -79,10 +79,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const {
 VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
 VectorRef add1({add1_var_, mul2, mul3});
 VectorRef sqrt0({sqrt, add1});
-VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
 VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
 VectorRef real_div0({real_div, add0, add2});
-VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+VectorRef add3({prim::kPrimAdd, mul4, real_div0});
 VectorRef mul5({prim::kPrimMul, add3, input4_});
 VectorRef sub0({prim::kPrimSub, input3_, mul5});
 return sub0;
@@ -99,10 +99,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const {
 VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
 VectorRef add1({add1_var_, mul2, mul3});
 VectorRef sqrt0({sqrt, add1});
-VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
 VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
 VectorRef real_div0({real_div, add0, add2});
-VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+VectorRef add3({prim::kPrimAdd, mul4, real_div0});
 VectorRef mul5({prim::kPrimMul, add3, input4_});
 VectorRef sub0({prim::kPrimSub, input3_, mul5});
 return sub0;
@@ -119,10 +119,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const {
 VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
 VectorRef add1({add1_var_, mul2, mul3});
 VectorRef sqrt0({sqrt, add1});
-VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
+VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
 VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
 VectorRef real_div0({real_div, add0, add2});
-VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+VectorRef add3({prim::kPrimAdd, mul4, real_div0});
 VectorRef mul5({prim::kPrimMul, add3, input4_});
 VectorRef sub0({prim::kPrimSub, input3_, mul5});
 return sub0;
@@ -139,10 +139,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const {
 VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
 VectorRef add1({add1_var_, mul2, mul3});
 VectorRef sqrt0({sqrt, add1});
-VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
 VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
 VectorRef real_div0({real_div, add0, add2});
-VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+VectorRef add3({prim::kPrimAdd, mul4, real_div0});
 VectorRef mul5({prim::kPrimMul, add3, input4_});
 VectorRef sub0({prim::kPrimSub, input3_, mul5});
 return sub0;
@@ -159,10 +159,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond1::DefinePattern() const {
 VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
 VectorRef add1({add1_var_, mul2, mul3});
 VectorRef sqrt0({sqrt, add1});
-VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
+VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
 VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
 VectorRef real_div0({real_div, add0, add2});
-VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+VectorRef add3({prim::kPrimAdd, mul4, real_div0});
 VectorRef mul5({prim::kPrimMul, input4_, add3});
 VectorRef sub0({sub0_var_, input3_, mul5});
 VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@@ -184,10 +184,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond2::DefinePattern() const {
 VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
 VectorRef add1({add1_var_, mul2, mul3});
 VectorRef sqrt0({sqrt, add1});
-VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
 VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
 VectorRef real_div0({real_div, add0, add2});
-VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+VectorRef add3({prim::kPrimAdd, mul4, real_div0});
 VectorRef mul5({prim::kPrimMul, add3, input4_});
 VectorRef sub0({sub0_var_, input3_, mul5});
 VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@@ -209,10 +209,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond3::DefinePattern() const {
 VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
 VectorRef add1({add1_var_, mul2, mul3});
 VectorRef sqrt0({sqrt, add1});
-VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
 VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
 VectorRef real_div0({real_div, add0, add2});
-VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+VectorRef add3({prim::kPrimAdd, mul4, real_div0});
 VectorRef mul5({prim::kPrimMul, add3, input4_});
 VectorRef sub0({sub0_var_, input3_, mul5});
 VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@@ -234,10 +234,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond4::DefinePattern() const {
 VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
 VectorRef add1({add1_var_, mul2, mul3});
 VectorRef sqrt0({sqrt, add1});
-VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
+VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
 VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
 VectorRef real_div0({real_div, add0, add2});
-VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+VectorRef add3({prim::kPrimAdd, mul4, real_div0});
 VectorRef mul5({prim::kPrimMul, add3, input4_});
 VectorRef sub0({sub0_var_, input3_, mul5});
 VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@@ -259,10 +259,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond5::DefinePattern() const {
 VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
 VectorRef add1({add1_var_, mul2, mul3});
 VectorRef sqrt0({sqrt, add1});
-VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
+VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
 VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
 VectorRef real_div0({real_div, add0, add2});
-VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
+VectorRef add3({prim::kPrimAdd, mul4, real_div0});
 VectorRef mul5({prim::kPrimMul, add3, input4_});
 VectorRef sub0({sub0_var_, input3_, mul5});
 VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@@ -38,8 +38,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
 mul3_x_ = std::make_shared<Var>();
 mul4_x_ = std::make_shared<Var>();
 add2_y_ = std::make_shared<Var>();
-add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
-add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
 sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name()));
 }
 ~AdamApplyOneWithDecayRule() override = default;
@@ -130,11 +130,11 @@ const BaseRef LambNextMVRuleCond1::DefinePattern() const {
 auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
 auto real_div1 = VectorRef({real_div1_var_, add1, input2_});

-auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
+auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1});
 auto sqrt0 = VectorRef({prim_rsqrt, add2});
 auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});

-return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+return VectorRef({prim::kPrimAdd, mul4, real_div2});
 }

 BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
@@ -147,7 +147,7 @@ BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});

 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1});
+VectorRef add4 = VectorRef({prim::kPrimAdd, add2_y_, sqrt1});
 VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
 return real_div4;
 }
@@ -166,11 +166,11 @@ const BaseRef LambNextMVRuleCond2::DefinePattern() const {
 auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
 auto real_div1 = VectorRef({real_div1_var_, add1, input2_});

-auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
+auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1});
 auto sqrt0 = VectorRef({prim_rsqrt, add2});
 auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});

-return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+return VectorRef({prim::kPrimAdd, mul4, real_div2});
 }

 BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
@@ -183,7 +183,7 @@ BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});

 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
 VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
 return real_div4;
 }
@@ -202,11 +202,11 @@ const BaseRef LambNextMVRuleCond3::DefinePattern() const {
 auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
 auto real_div1 = VectorRef({real_div1_var_, add1, input2_});

-auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
+auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_});
 auto sqrt0 = VectorRef({prim_rsqrt, add2});
 auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});

-return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+return VectorRef({prim::kPrimAdd, mul4, real_div2});
 }

 BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
@@ -219,7 +219,7 @@ BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});

 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
 VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
 return real_div4;
 }
@@ -238,11 +238,11 @@ const BaseRef LambNextMVRuleCond4::DefinePattern() const {
 auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
 auto real_div1 = VectorRef({real_div1_var_, add1, input2_});

-auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
+auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_});
 auto sqrt0 = VectorRef({prim_rsqrt, add2});
 auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0});

-return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
+return VectorRef({prim::kPrimAdd, real_div2, mul4});
 }

 BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
@@ -255,7 +255,7 @@ BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});

 VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
 VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
 return real_div4;
 }
@@ -49,8 +49,8 @@ class LambNextMVRule : public MultipleOutputPatternProcessPass {
 real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
 real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
 real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
-add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
-add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
 }
 ~LambNextMVRule() override = default;
 const BaseRef DefinePattern() const override = 0;
@@ -124,10 +124,10 @@ BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
 VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
 VectorRef mul4 = VectorRef({mul4_var_, Zs});

-VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
+VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1});
 VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
 VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
-VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
 return add3;
 }

@ -141,14 +141,14 @@ const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
|
||||||
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
||||||
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
|
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
|
||||||
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
|
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
|
||||||
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
|
VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
|
||||||
VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
|
VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
|
||||||
VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
|
VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
|
||||||
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||||
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
|
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
|
||||||
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
|
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
|
||||||
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
|
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
|
||||||
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
|
VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
|
||||||
return add5;
|
return add5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,10 +165,10 @@ BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
|
||||||
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
|
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
|
||||||
VectorRef mul4 = VectorRef({mul4_var_, Zs});
|
VectorRef mul4 = VectorRef({mul4_var_, Zs});
|
||||||
|
|
||||||
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
|
VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1});
|
||||||
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
|
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
|
||||||
VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
|
VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
|
||||||
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
|
VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
|
||||||
return add3;
|
return add3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,14 +182,14 @@ const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
|
||||||
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
||||||
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
|
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
|
||||||
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
|
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
|
||||||
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1});
|
VectorRef add4 = VectorRef({prim::kPrimAdd, constant_add2_y_, sqrt1});
|
||||||
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
|
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
|
||||||
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
|
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
|
||||||
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||||
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
|
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
|
||||||
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
|
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
|
||||||
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
|
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
|
||||||
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
|
VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
|
||||||
return add5;
|
return add5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -206,10 +206,10 @@ BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
|
||||||
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
|
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
|
||||||
VectorRef mul4 = VectorRef({mul4_var_, Zs});
|
VectorRef mul4 = VectorRef({mul4_var_, Zs});
|
||||||
|
|
||||||
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
|
VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_});
|
||||||
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
|
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
|
||||||
VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
|
VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
|
||||||
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
|
VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
|
||||||
return add3;
|
return add3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -223,14 +223,14 @@ const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
|
||||||
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
||||||
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
|
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
|
||||||
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
|
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
|
||||||
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
|
VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
|
||||||
VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
|
VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
|
||||||
VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
|
VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
|
||||||
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||||
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
|
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
|
||||||
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
|
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
|
||||||
VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]});
|
VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]});
|
||||||
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
|
VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
|
||||||
return add5;
|
return add5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -248,10 +248,10 @@ BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const {
|
||||||
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
|
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
|
||||||
VectorRef mul4 = VectorRef({mul4_var_, Zs});
|
VectorRef mul4 = VectorRef({mul4_var_, Zs});
|
||||||
|
|
||||||
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
|
VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_});
|
||||||
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
|
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
|
||||||
VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
|
VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
|
||||||
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
|
VectorRef add3 = VectorRef({prim::kPrimAdd, real_div2, mul4});
|
||||||
return add3;
|
return add3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -265,14 +265,14 @@ const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const {
|
||||||
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
|
||||||
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
|
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
|
||||||
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
|
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
|
||||||
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
|
VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
|
||||||
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
|
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
|
||||||
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
|
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
|
||||||
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||||
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
|
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
|
||||||
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
|
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
|
||||||
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
|
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
|
||||||
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
|
VectorRef add5 = VectorRef({prim::kPrimAdd, real_div4, mul4});
|
||||||
return add5;
|
return add5;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@@ -38,8 +38,8 @@ class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass {
 mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
 real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
 real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
-add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
-add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
+add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
 }

 ~LambNextMVWithDecayRule() override = default;
@@ -66,7 +66,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
 return false;
 }
 auto add5 = node->cast<CNodePtr>();
-if (AnfAlgo::GetCNodeName(add5) != prim::kPrimTensorAdd->name() || add5->inputs().size() != kAddInputNum) {
+if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || add5->inputs().size() != kAddInputNum) {
 return false;
 }
 auto real_div4_anf = add5->input(1);
@@ -82,7 +82,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
 return false;
 }
 auto add4 = add4_anf->cast<CNodePtr>();
-if (AnfAlgo::GetCNodeName(add4) != prim::kPrimTensorAdd->name() || add4->inputs().size() != kAddInputNum) {
+if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || add4->inputs().size() != kAddInputNum) {
 return false;
 }
 auto sqrt1_anf = add4->input(1);
@@ -140,17 +140,17 @@ const BaseRef LambNextMVWithDecayV1Rule::DefinePattern() const {
 const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
 VectorRef mul3({prim::kPrimMul, mul3_sub1_, input0_});
 VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
-VectorRef add1({prim::kPrimTensorAdd, mul2, mul3});
+VectorRef add1({prim::kPrimAdd, mul2, mul3});
 VectorRef real_div1({prim_real_div, add1, input2_});
-VectorRef add2({prim::kPrimTensorAdd, real_div1, add2_y_});
+VectorRef add2({prim::kPrimAdd, real_div1, add2_y_});
 VectorRef mul0({prim::kPrimMul, mul0_x_, input4_});
 VectorRef mul1({prim::kPrimMul, mul1_sub_, input3_});
 VectorRef sqrt0({prim_rsqrt, add2});
-VectorRef add0({prim::kPrimTensorAdd, mul0, mul1});
+VectorRef add0({prim::kPrimAdd, mul0, mul1});
 VectorRef real_div0({prim_real_div, add0, input5_});
 VectorRef real_div2({prim::kPrimMul, real_div0, sqrt0});
 VectorRef mul4({prim::kPrimMul, mul4_x_, input6_});
-VectorRef add3({prim::kPrimTensorAdd, real_div2, mul4});
+VectorRef add3({prim::kPrimAdd, real_div2, mul4});
 return add3;
 }
@@ -54,7 +54,7 @@ const BaseRef LambNextRightRule::DefinePattern() const {
 VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})});
 VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3});
 return VectorRef(
-{prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_});
+{prim::kPrimAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_});
 }

 const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
@@ -32,7 +32,7 @@ class LambNextRightRule : public PatternProcessPass {
 mul3_x_(std::make_shared<Var>()),
 true_div1_recip_(std::make_shared<Var>()),
 add2_y_(std::make_shared<Var>()),
-add1_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()))) {}
+add1_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()))) {}

 ~LambNextRightRule() override = default;
 const BaseRef DefinePattern() const override;
@@ -58,7 +58,7 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_
 const BaseRef MulAddFusion::DefinePattern() const {
 VarPtr x = std::make_shared<Var>();
 VarPtr y = std::make_shared<Var>();
-VectorRef pattern({prim::kPrimTensorAdd, x, y});
+VectorRef pattern({prim::kPrimAdd, x, y});
 return pattern;
 }
@@ -51,13 +51,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
 } // namespace

 const BaseRef AdamFusion::DefinePattern() const {
-VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
-VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
+VectorRef next_m = VectorRef(
+{prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
 VectorRef next_v =
-VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
+VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
 VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
-VectorRef update = VectorRef(
-{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
+VectorRef update =
+VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
 VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
 VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
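As a reading aid, here is a minimal numpy sketch (hypothetical values) of the Adam step that AdamFusion::DefinePattern describes; beta1_, beta2_, eps_ and lr_ correspond to the pattern variables above.

```python
import numpy as np

param, m, v = np.array([1.0]), np.array([0.10]), np.array([0.20])
grad = np.array([0.05])
beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 0.01

next_m = beta1 * m + (1 - beta1) * grad             # Add(Mul, Mul)
next_v = beta2 * v + (1 - beta2) * np.square(grad)  # Add(Mul, Mul(Square))
update = next_m / (eps + np.sqrt(next_v))           # RealDiv(next_m, Add(eps, Sqrt(next_v)))
next_param = param - lr * update                    # Sub(param, Mul(lr, update))
print(next_param)
```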
@@ -51,14 +51,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
 } // namespace

 const BaseRef AdamWeightDecayFusion::DefinePattern() const {
-VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
-VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
+VectorRef next_m = VectorRef(
+{prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
 VectorRef next_v =
-VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
+VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
 VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
-VectorRef update = VectorRef(
-{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
-VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});
+VectorRef update =
+VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
+VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});

 VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
 VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});

@@ -51,7 +51,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
 } // namespace

 const BaseRef AddReluGradV2Fusion::DefinePattern() const {
-VectorRef relu_grad = VectorRef({prim::kPrimReluGradV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_}), mask_});
+VectorRef relu_grad = VectorRef({prim::kPrimReluGradV2, VectorRef({prim::kPrimAdd, x1_, x2_}), mask_});
 return relu_grad;
 }

@@ -51,7 +51,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
 } // namespace

 const BaseRef AddReluV2Fusion::DefinePattern() const {
-VectorRef relu = VectorRef({prim::kPrimReluV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_})});
+VectorRef relu = VectorRef({prim::kPrimReluV2, VectorRef({prim::kPrimAdd, x1_, x2_})});
 return relu;
 }

@@ -30,7 +30,7 @@ namespace opt {
 const BaseRef BatchNormAddReluFusion::DefinePattern() const {
 VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_});
 VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_});
-VectorRef tensor_add = VectorRef({prim::kPrimTensorAdd, tuple_get_item, z_});
+VectorRef tensor_add = VectorRef({prim::kPrimAdd, tuple_get_item, z_});
 VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
 return relu;
 }
@@ -42,7 +42,7 @@ const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const Anf
 MS_EXCEPTION_IF_NULL(B);
 int64_t num_input = AnfAlgo::GetNodeAttr<int64_t>(node, "n");
 if (num_input == 2) {
-auto prim = std::make_shared<Primitive>(prim::kPrimTensorAdd->name());
+auto prim = std::make_shared<Primitive>(prim::kPrimAdd->name());
 MS_EXCEPTION_IF_NULL(prim);
 std::vector<AnfNodePtr> inputs = {NewValueNode(prim), A, B};
 auto add_new = graph->NewCNode(inputs);
@@ -47,7 +47,7 @@ AnfNodePtr NewCNodeWithInfo(const AnfNodePtrList &inputs, const AnfNodePtr &ori_
 }

 AnfNodePtr SimplifyAdd(const AnfNodePtr &node) {
-if (!IsPrimitiveCNode(node, prim::kPrimTensorAdd)) {
+if (!IsPrimitiveCNode(node, prim::kPrimAdd)) {
 return nullptr;
 }
 PatternNode<AnfNodePtr> x, y, z;
@@ -57,13 +57,13 @@ AnfNodePtr SimplifyAdd(const AnfNodePtr &node) {
 PConstant<AnfNodePtr> any_const_2(node);

 auto add_distri_lambda = [&node, &x, &y, &any_const]() -> AnfNodePtr {
-auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node);
+auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), y.GetNode(node)}, node);
 auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), node_tmp, any_const.GetNode(node)}, node);
 return new_cnode;
 };
 auto add_union_lambda = [&node, &x, &any_const, &any_const_2]() -> AnfNodePtr {
 auto new_rhs = any_const.AddByPatternConst(any_const_2, x.GetNode(node));
-auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), new_rhs}, node);
+auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), new_rhs}, node);
 return new_cnode;
 };
 // A + 0 = A
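The two lambdas above encode simple algebraic rewrites; below is a quick numpy check of the identities they appear to rely on (illustrative values only).

```python
import numpy as np

x, y = np.array([1.0, 2.0]), np.array([3.0, 4.0])
c, c1, c2 = 0.5, 2.0, 3.0

# add_distri_lambda: factor out a shared constant, x*c + y*c -> (x + y) * c
assert np.allclose(x * c + y * c, (x + y) * c)
# add_union_lambda: fold two constants, (x + c1) + c2 -> x + (c1 + c2)
assert np.allclose((x + c1) + c2, x + (c1 + c2))
```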
@@ -88,7 +88,7 @@ AnfNodePtr SimplifySub(const AnfNodePtr &node) {
 PConstant<AnfNodePtr> any_const(node);
 auto sub_toadd_lambda = [&node, &x, &any_const]() -> AnfNodePtr {
 auto new_rhs = any_const.ValueNodeWithOprations(prim::kPrimNeg);
-auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), new_rhs}, node);
+auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), new_rhs}, node);
 return new_cnode;
 };
 // A - 0 = A
@@ -269,7 +269,7 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
 return new_cnode;
 };
 auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr {
-auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimTensorAdd), x.GetNode(node), y.GetNode(node)}, node);
+auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), y.GetNode(node)}, node);
 auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node);
 return new_cnode;
 };
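The exp_merge_lambda rewrite rests on the identity exp(x) * exp(y) = exp(x + y); a one-line numpy check:

```python
import numpy as np

x, y = np.array([0.1, 0.2]), np.array([0.3, 0.4])
# merging two Exp nodes through a single Add preserves the result
assert np.allclose(np.exp(x) * np.exp(y), np.exp(x + y))
```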
@@ -741,14 +741,14 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
 std::vector<PrimitivePtr> GetFusibleOpList() {
 #if ENABLE_D
 std::vector<PrimitivePtr> fusible_basic_ops = {
-prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
+prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
 prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
 prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
 prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
 prim::kPrimCast, prim::kPrimRealDiv};
 #elif ENABLE_GPU
 std::vector<PrimitivePtr> fusible_basic_ops = {
-prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
+prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
 prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
 prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
 prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
@@ -52,7 +52,7 @@ namespace opt {
 namespace irpass {
 OptimizeIRPassLib::OptimizeIRPassLib() {
 arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
-{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
+{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimAdd,
 prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
 arithmetic_simplify2_ =
 MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
@@ -272,7 +272,7 @@ class AddNEliminater : public AnfVisitor {
 if (tuple_inputs.size() == 3) {
 // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
 MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2);
-ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
+ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
 std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
 tuple_inputs[2]};
 mng->Replace(node, func_graph->NewCNode(new_xs));
@@ -299,7 +299,7 @@ class AddNEliminater : public AnfVisitor {
 ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
 auto new_addn = func_graph->NewCNode(
 {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
-ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
+ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
 auto new_add =
 func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
 (void)mng->Replace(node, new_add);
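The AddNEliminater rewrites above rely on a two-input AddN being equivalent to a single Add; below is a hedged sketch of that equivalence (assumes a working MindSpore install, illustrative values).

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

a = Tensor(np.ones((2, 2), np.float32))
b = Tensor(np.full((2, 2), 2.0, np.float32))

out_addn = P.AddN()((a, b))  # AddN over a tuple of two tensors
out_add = P.Add()(a, b)      # the single Add the pass substitutes
# both results hold the same values
```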
@@ -860,7 +860,7 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
 if (ops[iter_ops]->type() == L2_NORMALIZE) {
 return PrepareL2Normalize(ops, iter_ops, basic_stra);
 }
-if (ops[iter_ops]->type() == TENSOR_ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL ||
+if (ops[iter_ops]->type() == ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL ||
 ops[iter_ops]->type() == DIV) {
 return CheckBroadcast(ops, iter_ops, basic_stra);
 }
@@ -78,7 +78,7 @@ const std::map<std::string, OperatorType> DictOpType{
 // Elm-wise OP
 {TRANSPOSE, OperatorType::kRecElmWiseOp},
 {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
-{TENSOR_ADD, OperatorType::kRecElmWiseOp},
+{ADD, OperatorType::kRecElmWiseOp},
 {TENSOR_DOT, OperatorType::kRecElmWiseOp},
 {SUB, OperatorType::kRecElmWiseOp},
 {MUL, OperatorType::kRecElmWiseOp},
@@ -86,7 +86,7 @@ REGISTER(LogSoftmaxInfo);
 REGISTER(ActivationInfo);
 REGISTER(SoftmaxCrossEntropyWithLogitsInfo);
 REGISTER(SubInfo);
-REGISTER(TensorAddInfo);
+REGISTER(AddInfo);
 REGISTER(BiasAddInfo);
 REGISTER(MulInfo);
 REGISTER(DivInfo);
@@ -60,12 +60,11 @@ class SubInfo : public ArithmeticBase {
 ~SubInfo() override = default;
 };

-class TensorAddInfo : public ArithmeticBase {
+class AddInfo : public ArithmeticBase {
 public:
-TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-const PrimitiveAttrs &attrs)
+AddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {}
-~TensorAddInfo() override = default;
+~AddInfo() override = default;
 };

 class MulInfo : public ArithmeticBase {
@@ -191,7 +191,7 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)});
 auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)});
 auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast});
-auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(TENSOR_ADD), mul2, CreateInt32Tensor(1)});
+auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(ADD), mul2, CreateInt32Tensor(1)});
 auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add});
 auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)});
 Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_);
@@ -200,7 +200,7 @@ constexpr char MAXPOOLV2[] = "MaxPoolV2";
 constexpr char L2_NORMALIZE[] = "L2Normalize";
 constexpr char TRANSPOSE[] = "Transpose";
 constexpr char RESHAPE[] = "Reshape";
-constexpr char TENSOR_ADD[] = "TensorAdd";
+constexpr char ADD[] = "Add";
 constexpr char BIAS_ADD[] = "BiasAdd";
 constexpr char SUB[] = "Sub";
 constexpr char MUL[] = "Mul";
@@ -315,7 +315,6 @@ constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin";
 constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax";
 constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
 constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
-constexpr char ADD[] = "Add";
 constexpr char DROPOUT[] = "Dropout";
 constexpr char KStridedSlice[] = "StridedSlice";
 constexpr char UNIQUE[] = "Unique";
@@ -151,7 +151,7 @@ bool IsSplittableOperator(const std::string &op_name) {
 // clang-format off
 static const std::set<std::string> splittable_op =
 {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
-FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
+FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
 REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
 MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, PACK,
 LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
@@ -165,7 +165,7 @@ class OpNameInfo {
 #define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \
 OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); }

-OPERATOR_ONNX_CONVERT_DEFINE(TensorAdd, Add, OpNameInfo())
+OPERATOR_ONNX_CONVERT_DEFINE(Add, Add, OpNameInfo())
 OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo())

 OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo())
@@ -257,7 +257,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo())
 #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name

 void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
-fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)());
+fn(OP_CONVERT_FUNCTION_NAME(Add)());
 fn(OP_CONVERT_FUNCTION_NAME(Mul)());

 fn(OP_CONVERT_FUNCTION_NAME(ReLU)());
@@ -29,7 +29,7 @@ REG_ADPT_DESC(StateSetItem, prim::kPrimStateSetItem->name(), ADPT_DESC(Assign))
 INPUT_MAP(Add) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
 ATTR_MAP(Add) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Add) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(Add, prim::kPrimTensorAdd->name(),
+REG_ADPT_DESC(Add, prim::kPrimAdd->name(),
 std::make_shared<OpAdapterDesc>(
 std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})),
 std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}}))))
@@ -215,7 +215,7 @@ constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
 constexpr auto kmaxPoolGradOpName = "MaxPoolGrad";
 constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax";
 constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax";
-constexpr auto kTensorAddOpName = "TensorAdd";
+constexpr auto kTensorAddOpName = "Add";
 constexpr auto kCastOpName = "Cast";
 constexpr auto kGreaterEqualOpName = "GreaterEqual";
 constexpr auto kAbsOpName = "Abs";
@@ -46,7 +46,7 @@ class ExportToQuantInferNetwork:
 Returns:
 Cell, Infer network.
 """
-__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
+__quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]

 def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
 network = Validator.check_isinstance('network', network, (nn.Cell,))
@@ -225,7 +225,7 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
 Returns:
 Cell, Infer network.
 """
-__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
+__quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]

 def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
 super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir)
@@ -173,7 +173,7 @@ class QuantizationAwareTraining(Quantizer):
 >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
 >>> net_qat = quantizer.quantize(net)
 """
-__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
+__quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]

 def __init__(self,
 bn_fold=True,
@@ -91,8 +91,8 @@ AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const Primitive
 AbstractBasePtr InferImplMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
 const AbstractBasePtrList &args_spec_list);

-AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
 const AbstractBasePtrList &args_spec_list);

 AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
 const AbstractBasePtrList &args_spec_list);
@@ -60,8 +60,8 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr
 return out->Broaden();
 }

-AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
 const AbstractBasePtrList &args_spec_list) {
 // Inputs: two tensors.
 const std::string op_name = primitive->name();
 CheckArgsSize(op_name, args_spec_list, 2);
@@ -37,7 +37,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
 {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
 {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
 {prim::kPrimMul, {InferImplMul, true}},
-{prim::kPrimTensorAdd, {InferImplTensorAdd, true}},
+{prim::kPrimAdd, {InferImplAdd, true}},
 {prim::kPrimSquare, {InferImplSquare, true}},
 {prim::kPrimSqrt, {InferImplSqrt, true}},
 {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
@@ -236,7 +236,7 @@ inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primiti
 inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");

 // Maths
-inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
+inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>("Add");
 inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
 inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
 inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
@@ -49,6 +49,6 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
 return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
 InferShape(primitive, input_args)->shape());
 }
-REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimTensorAdd, AddInfer);
+REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer);
 REGISTER_PRIMITIVE_C(kNameAdd, Add);
 } // namespace mindspore
@@ -989,7 +989,7 @@ class PConstant : public PBase<PConstant<T> > {
 }

 // Arithmetic operations
-BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true);
+BIN_OPERATION_PATTERN(operator+, prim::kPrimAdd, true);
 BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true);
 BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false);
 BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false);
@@ -225,7 +225,7 @@ class LambNextMV(GraphKernel):
 def __init__(self):
 super(LambNextMV, self).__init__()
 self.mul = P.Mul()
-self.add = P.TensorAdd()
+self.add = P.Add()
 self.div = P.RealDiv()
 self.sqrt = P.Sqrt()
 self.rsqrt = P.Rsqrt()
@@ -651,7 +651,7 @@ class LogSigmoid(Cell):
 super(LogSigmoid, self).__init__()
 self.mul = P.Mul()
 self.exp = P.Exp()
-self.add = P.TensorAdd()
+self.add = P.Add()
 self.rec = P.Reciprocal()
 self.log = P.Log()
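LogSigmoid presumably composes these operators into log(1 / (1 + exp(-x))), the standard log-sigmoid definition (an assumption, not stated in the diff itself); a small numpy sketch of that formula:

```python
import numpy as np

x = np.array([-1.0, 0.0, 1.0])
log_sigmoid = np.log(np.reciprocal(1.0 + np.exp(-x)))
print(log_sigmoid)  # approximately [-1.3133, -0.6931, -0.3133]
```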
@@ -441,13 +441,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):

 self.mul = P.Mul()
 self.inf_mask_mul = P.Mul()
-self.bias_add = P.TensorAdd()
-self.inf_add = P.TensorAdd()
+self.bias_add = P.Add()
+self.inf_add = P.Add()
 self.merge_op = None
 self.count_op = P.UnsortedSegmentSum()
 self.abs = P.Abs()
 self.equal = P.Equal()
-self.add = P.TensorAdd()
+self.add = P.Add()
 self.cast = P.Cast()
 self.div_no_nan = P.DivNoNan()
 self.expand = P.ExpandDims()
@@ -99,8 +99,8 @@ class BatchNormFoldCell(Cell):
 else:
 batch_mean = P.ZerosLike()(variance)
 batch_std = P.OnesLike()(variance)
-running_mean = P.TensorAdd()(mean, 0.)
-running_std = P.Sqrt()(P.TensorAdd()(variance, self.epsilon))
+running_mean = P.Add()(mean, 0.)
+running_std = P.Sqrt()(P.Add()(variance, self.epsilon))
 return batch_mean, batch_std, running_mean, running_std
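A small numeric check of the non-training branch above: the folded statistics reduce to running_mean = mean and running_std = sqrt(variance + epsilon) (illustrative values only).

```python
import numpy as np

mean, variance, epsilon = np.array([0.5]), np.array([0.25]), 1e-5
running_mean = mean + 0.0                  # Add(mean, 0.)
running_std = np.sqrt(variance + epsilon)  # Sqrt(Add(variance, eps)) ~ 0.50001
```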
@@ -559,7 +559,7 @@ class Conv2dBnFoldQuantOneConv(Cell):
 return s

 def construct(self, x):
-running_std = P.Sqrt()(P.TensorAdd()(self.moving_variance, self.eps))
+running_std = P.Sqrt()(P.Add()(self.moving_variance, self.eps))
 scale_factor = self.gamma / running_std
 if self.channel_axis:
 scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
@@ -1236,7 +1236,7 @@ class TensorAddQuant(Cell):
 ema=True,
 ema_decay=ema_decay,
 quant_dtype=quant_dtype)
-self.add = P.TensorAdd()
+self.add = P.Add()

 def construct(self, x1, x2):
 x = self.add(x1, x2)
@@ -155,9 +155,9 @@ def bprop_batchmatmul(self):
 return bprop


-@bprop_getters.register(P.TensorAdd)
+@bprop_getters.register(P.Add)
 def get_bprop_tensor_add(self):
-"""Grad definition for `TensorAdd` operation."""
+"""Grad definition for `Add` operation."""

 def bprop(x, y, out, dout):
 return binop_grad_common(x, y, dout, dout)
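The registered bprop passes dout to both inputs, since d(x + y)/dx = d(x + y)/dy = 1; binop_grad_common additionally reduces over broadcast axes. A minimal numpy illustration (no broadcasting involved):

```python
import numpy as np

x, y = np.array([1.0, 2.0]), np.array([3.0, 4.0])
dout = np.array([0.1, 0.2])  # gradient flowing back from out = x + y
dx, dy = dout, dout          # both inputs receive dout unchanged
```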
@ -13,10 +13,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""TensorAdd op"""
|
"""Add op"""
|
||||||
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
|
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
|
||||||
|
|
||||||
op_info = AkgAscendRegOp("TensorAdd") \
|
op_info = AkgAscendRegOp("Add") \
|
||||||
.fusion_type("ELEMWISE") \
|
.fusion_type("ELEMWISE") \
|
||||||
.input(0, "x") \
|
.input(0, "x") \
|
||||||
.input(1, "y") \
|
.input(1, "y") \
|
||||||
|
@ -38,5 +38,5 @@ op_info = AkgAscendRegOp("TensorAdd") \
|
||||||
|
|
||||||
@op_info_register(op_info)
|
@op_info_register(op_info)
|
||||||
def _add_akg():
|
def _add_akg():
|
||||||
"""TensorAdd Akg register"""
|
"""Add Akg register"""
|
||||||
return
|
return
|
||||||
|
|
|
@ -13,10 +13,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""TensorAdd op"""
|
"""Add op"""
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
tensor_add_op_info = TBERegOp("TensorAdd") \
|
tensor_add_op_info = TBERegOp("Add") \
|
||||||
.fusion_type("ELEMWISE") \
|
.fusion_type("ELEMWISE") \
|
||||||
.async_flag(False) \
|
.async_flag(False) \
|
||||||
.binfile_name("add.so") \
|
.binfile_name("add.so") \
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
"""TensorAdd op"""
|
"""TensorAdd op"""
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
tensor_add_op_info = TBERegOp("TensorAdd") \
|
tensor_add_op_info = TBERegOp("Add") \
|
||||||
.fusion_type("ELEMWISE") \
|
.fusion_type("ELEMWISE") \
|
||||||
.async_flag(False) \
|
.async_flag(False) \
|
||||||
.binfile_name("add.so") \
|
.binfile_name("add.so") \
|
||||||
|
|
|
@ -395,7 +395,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
||||||
>>> from mindspore.ops import Primitive, operations as P
|
>>> from mindspore.ops import Primitive, operations as P
|
||||||
>>> from mindspore import dtype as mstype
|
>>> from mindspore import dtype as mstype
|
||||||
>>>
|
>>>
|
||||||
>>> tensor_add = P.TensorAdd()
|
>>> tensor_add = P.Add()
|
||||||
>>> add = MultitypeFuncGraph('add')
|
>>> add = MultitypeFuncGraph('add')
|
||||||
>>> @add.register("Number", "Number")
|
>>> @add.register("Number", "Number")
|
||||||
... def add_scala(x, y):
|
... def add_scala(x, y):
|
||||||
|
|
|
@ -51,7 +51,7 @@ merge = P.Merge()
|
||||||
geswitch = P.GeSwitch()
|
geswitch = P.GeSwitch()
|
||||||
addn = P.AddN()
|
addn = P.AddN()
|
||||||
absolute = P.Abs()
|
absolute = P.Abs()
|
||||||
tensor_add = P.TensorAdd()
|
tensor_add = P.Add()
|
||||||
neg_tensor = P.Neg()
|
neg_tensor = P.Neg()
|
||||||
tensor_lt = P.Less()
|
tensor_lt = P.Less()
|
||||||
tensor_le = P.LessEqual()
|
tensor_le = P.LessEqual()
|
||||||
|
|
|
@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
||||||
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
||||||
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
|
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
|
||||||
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
|
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
|
||||||
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
|
Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
|
||||||
MatrixInverse)
|
MatrixInverse)
|
||||||
|
|
||||||
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
|
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
|
||||||
|
@ -102,6 +102,7 @@ __all__ = [
|
||||||
'Sort',
|
'Sort',
|
||||||
'EditDistance',
|
'EditDistance',
|
||||||
'CropAndResize',
|
'CropAndResize',
|
||||||
|
'Add',
|
||||||
'TensorAdd',
|
'TensorAdd',
|
||||||
'Argmax',
|
'Argmax',
|
||||||
'Argmin',
|
'Argmin',
|
||||||
|
|
|
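Taken together, the two hunks above keep both spellings importable from mindspore.ops.operations: Add is the new primitive class, while TensorAdd survives only as a compatibility alias (a deprecation wrapper added to math_ops.py later in this diff). A minimal usage sketch of what that export surface means for user code:

    # Both names resolve after this change; Add is the class, TensorAdd the deprecated helper.
    from mindspore.ops.operations import Add, TensorAdd

    new_op = Add()        # preferred spelling from this commit on
    old_op = TensorAdd()  # still importable; returns an Add instance and logs a deprecation warning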
@@ -106,7 +106,7 @@ class GeSwitch(PrimitiveWithInfer):
         ...     def __init__(self):
         ...         super(Net, self).__init__()
         ...         self.square = ops.Square()
-        ...         self.add = ops.TensorAdd()
+        ...         self.add = ops.Add()
         ...         self.value = Tensor(np.full((1), 3), mindspore.float32)
         ...         self.switch = ops.GeSwitch()
         ...         self.merge = ops.Merge()

@@ -66,7 +66,7 @@ class ScalarSummary(PrimitiveWithInfer):
         ...     def __init__(self,):
         ...         super(SummaryDemo, self).__init__()
         ...         self.summary = ops.ScalarSummary()
-        ...         self.add = ops.TensorAdd()
+        ...         self.add = ops.Add()
         ...
         ...     def construct(self, x, y):
         ...         name = "x"

@@ -149,7 +149,7 @@ class TensorSummary(PrimitiveWithInfer):
         ...     def __init__(self,):
         ...         super(SummaryDemo, self).__init__()
         ...         self.summary = ops.TensorSummary()
-        ...         self.add = ops.TensorAdd()
+        ...         self.add = ops.Add()
         ...
         ...     def construct(self, x, y):
         ...         x = self.add(x, y)

@@ -191,7 +191,7 @@ class HistogramSummary(PrimitiveWithInfer):
         ...     def __init__(self,):
         ...         super(SummaryDemo, self).__init__()
         ...         self.summary = ops.HistogramSummary()
-        ...         self.add = ops.TensorAdd()
+        ...         self.add = ops.Add()
         ...
         ...     def construct(self, x, y):
         ...         x = self.add(x, y)

@@ -409,7 +409,7 @@ class Assert(PrimitiveWithInfer):
         ...     def __init__(self):
         ...         super(AssertDemo, self).__init__()
         ...         self.assert1 = ops.Assert(summarize=10)
-        ...         self.add = ops.TensorAdd()
+        ...         self.add = ops.Add()
         ...
         ...     def construct(self, x, y):
         ...         data = self.add(x, y)

@@ -18,6 +18,7 @@
 import copy
 
 import numpy as np
+from mindspore import log as logger
 from ... import context
 from .. import signature as sig
 from ..._checkparam import Validator as validator

@@ -114,7 +115,7 @@ class _BitwiseBinaryOp(_MathBinaryOp):
         return _BitwiseBinaryOp._check_bitwise_op_input_type(x1_type, x2_type, self.name)
 
 
-class TensorAdd(_MathBinaryOp):
+class Add(_MathBinaryOp):
     r"""
     Adds two input tensors element-wise.
 

@@ -143,7 +144,7 @@ class TensorAdd(_MathBinaryOp):
         ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
-        >>> add = ops.TensorAdd()
+        >>> add = ops.Add()
         >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
         >>> input_y = Tensor(np.array([4, 5, 6]).astype(np.float32))
        >>> output = add(input_x, input_y)
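For reference, the renamed docstring example above runs as-is once the standard imports are added; a minimal sketch (the surrounding imports are assumptions, not part of this hunk):

    import numpy as np
    import mindspore.ops as ops
    from mindspore import Tensor

    # Element-wise addition with the renamed primitive.
    add = ops.Add()
    input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
    input_y = Tensor(np.array([4, 5, 6]).astype(np.float32))
    output = add(input_x, input_y)
    print(output)  # expected: [5. 7. 9.]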
@@ -160,6 +161,10 @@ class TensorAdd(_MathBinaryOp):
             return Tensor(out)
         return None
 
+def TensorAdd():
+    """Warning: This will be changed later"""
+    logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.")
+    return Add()
 
 class AssignAdd(PrimitiveWithInfer):
     """
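The shim above is what keeps old call sites alive: TensorAdd() is now a plain module-level function that logs the deprecation warning and hands back an Add primitive, so existing networks keep building unchanged. A minimal sketch of the migration path:

    from mindspore.ops import operations as P

    # New spelling, preferred from this commit onward.
    add = P.Add()

    # Old spelling still resolves, but goes through the wrapper above: it logs
    # "WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add."
    # and returns an Add() instance, so downstream code behaves identically.
    legacy_add = P.TensorAdd()
    assert isinstance(legacy_add, P.Add)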
@@ -16,7 +16,7 @@
 
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore.ops.operations import TensorAdd
+from mindspore.ops.operations import Add
 
 from src.var_init import KaimingNormal
 

@@ -91,7 +91,7 @@ class InvertedResidual(nn.Cell):
         ])
 
         self.conv = nn.SequentialCell(layers)
-        self.add = TensorAdd()
+        self.add = Add()
         self.cast = P.Cast()
 
     def construct(self, x):

@@ -198,7 +198,7 @@ class BasicBlock(nn.Cell):
         self.bn2 = ms_fused_bn(planes)
         self.relu = P.ReLU()
         self.downsample = downsample
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         residual = x
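The same one-line substitution repeats across the model definitions below: cells that built self.add = P.TensorAdd() (or TensorAdd()) in __init__ now build P.Add() and call it unchanged in construct. A minimal sketch of the pattern; the cell and layer names here are illustrative, not taken from any single file in this diff:

    import mindspore.nn as nn
    from mindspore.ops import operations as P

    class ResidualAddCell(nn.Cell):
        """Toy residual cell showing where the renamed primitive sits."""
        def __init__(self, channels):
            super(ResidualAddCell, self).__init__()
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, pad_mode="same")
            self.add = P.Add()   # was: P.TensorAdd()
            self.relu = nn.ReLU()

        def construct(self, x):
            identity = x
            out = self.conv(x)
            out = self.add(out, identity)  # same call signature as before the rename
            return self.relu(out)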
@@ -102,7 +102,7 @@ class Bottleneck(nn.Cell):
         self.relu = nn.ReLU()
         self.downsample = downsample
 
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x

@@ -222,7 +222,7 @@ class ResidualBlockUsing(nn.Cell):
                 self.bn_down_sample = self.bn_down_sample.set_train()
             if not weights_update:
                 self.conv_down_sample.weight.requires_grad = False
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x

@@ -218,7 +218,7 @@ class ResidualBlockUsing(nn.Cell):
                 self.bn_down_sample = self.bn_down_sample.set_train()
             if not weights_update:
                 self.conv_down_sample.weight.requires_grad = False
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x
@@ -16,7 +16,7 @@
 import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore.ops.operations import TensorAdd
+from mindspore.ops.operations import Add
 from mindspore import Tensor
 
 __all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']

@@ -129,7 +129,7 @@ class InvertedResidual(nn.Cell):
                 nn.BatchNorm2d(oup),
             ])
         self.conv = nn.SequentialCell(layers)
-        self.add = TensorAdd()
+        self.add = Add()
         self.cast = P.Cast()
 
     def construct(self, x):

@@ -120,7 +120,7 @@ class InvertedResidual(nn.Cell):
                 nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True)
             ])
         self.conv = nn.SequentialCell(layers)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         out = self.conv(x)

@@ -123,7 +123,7 @@ class InvertedResidual(nn.Cell):
                 nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1, has_bn=True)
             ])
         self.conv = nn.SequentialCell(layers)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         out = self.conv(x)

@@ -197,7 +197,7 @@ class ResUnit(nn.Cell):
                                   padding=0, act_type=act_type, use_act=False)
         if num_in != num_out or stride != 1:
             self.use_short_cut_conv = False
-        self.add = P.TensorAdd() if self.use_short_cut_conv else None
+        self.add = P.Add() if self.use_short_cut_conv else None
 
     def construct(self, x):
         """construct"""

@@ -49,7 +49,7 @@ class DiceLoss(_Loss):
         self.logical_or = P.LogicalOr()
         self.equal = P.Equal()
         self.zeros_like = P.ZerosLike()
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.gather = P.Gather()
 
     def ohem_batch(self, scores, gt_texts, training_masks):

@@ -61,7 +61,7 @@ class ResidualBlock(nn.Cell):
                                              kernel_size=1, stride=stride)
             self.bn_down_sample = _bn(out_channels, momentum=momentum)
 
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x

@@ -152,7 +152,7 @@ class ResidualBlock(nn.Cell):
             else:
                 self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                      use_se=self.use_se), _bn(out_channel)])
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x

@@ -119,7 +119,7 @@ class ResidualBlock(nn.Cell):
 
         if self.down_sample:
             self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)])
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x

@@ -85,7 +85,7 @@ class ResidualBlock(nn.Cell):
             self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel,
                                                     kernel_size=1, stride=stride,
                                                     pad_mode='same', padding=0, has_bn=True, activation='relu')
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.relu = P.ReLU()
 
     def construct(self, x):

@@ -215,7 +215,7 @@ class ResidualBlock(nn.Cell):
                                                                  frequency=frequency,
                                                                  batch_size=batch_size),
                                                         _bn(out_channel)])
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x

@@ -333,7 +333,7 @@ class Dense_Thor_GPU(Cell):
         self.gather = P.Gather()
         self.freq = Tensor(frequency, mstype.int32)
         self.axis = 0
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.sqrt = P.Sqrt()
         self.cholesky = P.CholeskyTrsm(split_dim=split_dim)
         self.vector_matmul = P.BatchMatMul(transpose_a=True)

@@ -690,7 +690,7 @@ class Dense_Thor(Cell):
         self.exp = P.Exp()
         self.dampingA = Tensor(np.identity(2048), mstype.float32)
         self.dampingG = Tensor(np.identity(1024), mstype.float32)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.sqrt = P.Sqrt()
         self.getG = P.InsertGradientOf(self.save_gradient)
 
@@ -16,7 +16,7 @@
 ResNet based ResNext
 """
 import mindspore.nn as nn
-from mindspore.ops.operations import TensorAdd, Split, Concat
+from mindspore.ops.operations import Add, Split, Concat
 from mindspore.ops import operations as P
 from mindspore.common.initializer import TruncatedNormal
 

@@ -105,7 +105,7 @@ class BasicBlock(nn.Cell):
             self.down_sample = down_sample
             self.down_sample_flag = True
 
-        self.add = TensorAdd()
+        self.add = Add()
 
     def construct(self, x):
         identity = x

@@ -176,7 +176,7 @@ class Bottleneck(nn.Cell):
             self.down_sample_flag = True
 
         self.cast = P.Cast()
-        self.add = TensorAdd()
+        self.add = Add()
 
     def construct(self, x):
         identity = x

@@ -95,7 +95,7 @@ class ResidualBlock(nn.Cell):
         if self.down_sample:
             self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
                                                         _bn(out_channel)])
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x

@@ -68,7 +68,7 @@ class ShuffleV1Block(nn.Cell):
             outputs = oup
 
         self.relu = nn.ReLU()
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.concat = P.Concat(1)
         self.shape = P.Shape()
         self.transpose = P.Transpose()

@@ -170,7 +170,7 @@ class SqueezeNet_Residual(nn.Cell):
 
         self.relu = nn.ReLU()
         self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.dropout = nn.Dropout(keep_prob=0.5)
         self.mean = P.ReduceMean(keep_dims=True)
         self.flatten = nn.Flatten()
@@ -133,7 +133,7 @@ class InvertedResidual(nn.Cell):
                 _bn(oup),
             ])
         self.conv = nn.SequentialCell(layers)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.cast = P.Cast()
         self.last_relu = last_relu
         self.relu = nn.ReLU6()

@@ -68,7 +68,7 @@ class Block(nn.Cell):
         if strides != 1:
             rep.append(nn.MaxPool2d(3, strides, pad_mode="same"))
         self.rep = nn.SequentialCell(*rep)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, inp):
         x = self.rep(inp)

@@ -62,7 +62,7 @@ class ResidualBlock(nn.Cell):
         out_chls = out_channels//2
         self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
         self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x

@@ -59,7 +59,7 @@ class ResidualBlock(nn.Cell):
         out_chls = out_channels//2
         self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
         self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x

@@ -107,7 +107,7 @@ class BasicBlock(nn.Cell):
         self.downsample = (in_channels != out_channels)
         if self.downsample:
             self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x

@@ -76,7 +76,7 @@ class ResidualBlock(nn.Cell):
         out_chls = out_channels
         self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
         self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, x):
         identity = x
@@ -111,7 +111,7 @@ class CspDarkNet53(nn.Cell):
         self.outchannel = 1024
         self.detect = detect
         self.concat = P.Concat(axis=1)
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
         self.conv0 = conv_block(3, 32, kernel_size=3, stride=1)
         self.conv1 = conv_block(32, 64, kernel_size=3, stride=2)

@@ -188,7 +188,7 @@ class EmbeddingPostprocessor(nn.Cell):
                                              use_one_hot=False)
         self.layernorm = nn.LayerNorm((embedding_size,))
         self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, token_type_ids, word_embeddings):
         """Postprocessors apply positional and token type embeddings to word embeddings."""

@@ -226,7 +226,7 @@ class BertOutput(nn.Cell):
                               weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.dropout_prob = dropout_prob
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
         self.cast = P.Cast()
 

@@ -444,7 +444,7 @@ class BertAttention(nn.Cell):
         if self.has_attention_mask:
             self.expand_dims = P.ExpandDims()
             self.sub = P.Sub()
-            self.add = P.TensorAdd()
+            self.add = P.Add()
             self.cast = P.Cast()
             self.get_dtype = P.DType()
         if do_return_2d_tensor:

@@ -227,7 +227,7 @@ class EmbeddingPostprocessor(nn.Cell):
                                              frequency=frequency)
         self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
         self.layernorm = nn.LayerNorm((embedding_size,))
-        self.add = P.TensorAdd()
+        self.add = P.Add()
 
     def construct(self, token_type_ids, word_embeddings):
         """construct of EmbeddingPostprocessor"""

@@ -275,7 +275,7 @@ class BertOutput(nn.Cell):
                               batch_size=batch_size).to_float(compute_type)
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.dropout_prob = dropout_prob
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
         self.cast = P.Cast()
 
@@ -522,7 +522,7 @@ class BertAttention(nn.Cell):
         if self.has_attention_mask:
             self.expand_dims = P.ExpandDims()
             self.sub = P.Sub()
-            self.add = P.TensorAdd()
+            self.add = P.Add()
             self.cast = P.Cast()
             self.get_dtype = P.DType()
         if do_return_2d_tensor:

@@ -35,7 +35,7 @@ class LengthPenalty(nn.Cell):
     def __init__(self, weight=1.0, compute_type=mstype.float32):
         super(LengthPenalty, self).__init__()
         self.weight = weight
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.pow = P.Pow()
         self.div = P.RealDiv()
         self.five = Tensor(5.0, mstype.float32)

@@ -188,7 +188,7 @@ class BeamSearchDecoder(nn.Cell):
         self.decoder = decoder
         self.is_using_while = is_using_while
 
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.expand = P.ExpandDims()
         self.reshape = P.Reshape()
         self.shape_flat = (-1,)

@@ -36,7 +36,7 @@ class LengthPenalty(nn.Cell):
         super(LengthPenalty, self).__init__()
         self.weight = weight
 
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.pow = P.Pow()
         self.div = P.RealDiv()
 

@@ -178,7 +178,7 @@ class BeamSearchDecoder(nn.Cell):
 
         self.decoder = decoder
 
-        self.add = P.TensorAdd()
+        self.add = P.Add()
         self.expand = P.ExpandDims()
         self.reshape = P.Reshape()
         self.shape_flat = (-1,)

@@ -138,7 +138,7 @@ class MultiHeadAttention(nn.Cell):
         if self.has_attention_mask:
             self.expand_dims = P.ExpandDims()
             self.sub = P.Sub()
-            self.add = P.TensorAdd()
+            self.add = P.Add()
             self.cast = P.Cast()
             self.get_dtype = P.DType()
 
Some files were not shown because too many files have changed in this diff.