forked from mindspore-Ecosystem/mindspore
Change GatherV2 to Gather r1.1 to master
This commit is contained in:
parent
8a61767f32
commit
9fa0499fa0
|
@ -187,8 +187,8 @@
|
|||
{"op_name": "ReduceMean", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "reduce_mean.so", "compute_cost": 10, "kernel_name": "reduce_mean_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": "reduce"}
|
||||
{"op_name": "Tile", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "multiples", "param_type": "optional", "type": "listInt", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "tile_d.so", "compute_cost": 10, "kernel_name": "tile_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": "dynamicFormat"}
|
||||
{"op_name": "AtomicAddrClean", "inputs": [], "outputs": [], "attr": [{"name": "automic_add_mem_size", "param_type": "required", "type": "listUInt64", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [], "imply_type": "TBE", "async_flag": false, "binfile_name": "atomic_addr_clean.so", "compute_cost": 10, "kernel_name": "atomic_addr_clean", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||
{"op_name": "GatherV2", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int32", "DefaultFormat"], ["int16", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int64", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int64", "DefaultFormat"], ["int16", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint64", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gather_v2_d.so", "compute_cost": 10, "kernel_name": "gather_v2_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||
{"op_name": "GatherV2", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "axis", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"], ["int8", "FracZ"]], [["int8", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"], ["int8", "FracZ"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"], ["uint8", "FracZ"]], [["uint8", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"], ["uint8", "FracZ"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"]], [["int32", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"]], [["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"], ["float16", "FracZ"]], [["float16", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"], ["float16", "FracZ"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"], ["float32", "FracZ"]], [["float32", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gather_v2.so", "compute_cost": 10, "kernel_name": "gather_v2", "partial_flag": true, "reshape_type": "", 
"dynamic_format": false, "dynamic_shape": true, "need_check_supported": false, "op_pattern": ""}
|
||||
{"op_name": "Gather", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int32", "DefaultFormat"], ["int16", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int64", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int64", "DefaultFormat"], ["int16", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint64", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gather_v2_d.so", "compute_cost": 10, "kernel_name": "gather_v2_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||
{"op_name": "Gather", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "axis", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"], ["int8", "FracZ"]], [["int8", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"], ["int8", "FracZ"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"], ["uint8", "FracZ"]], [["uint8", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"], ["uint8", "FracZ"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"]], [["int32", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"]], [["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"], ["float16", "FracZ"]], [["float16", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"], ["float16", "FracZ"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"], ["float32", "FracZ"]], [["float32", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gather_v2.so", "compute_cost": 10, "kernel_name": "gather_v2", "partial_flag": true, "reshape_type": "", 
"dynamic_format": false, "dynamic_shape": true, "need_check_supported": false, "op_pattern": ""}
|
||||
{"op_name": "GatherNd", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int64", "DefaultFormat"], ["float16", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["bool", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"]], [["bool", "DefaultFormat"], ["int64", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gather_nd.so", "compute_cost": 10, "kernel_name": "gather_nd", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||
{"op_name": "BNTrainingReduce", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}], "outputs": [{"index": 0, "name": "sum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "square_sum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float32", ""], ["float32", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bn_training_reduce.so", "compute_cost": 10, "kernel_name": "bn_training_reduce", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": "dynamicFormat"}
|
||||
{"op_name": "BNTrainingReduceGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "x_norm", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 2, "name": "diff_scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "diff_offset", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "batch_mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "batch_variance", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}], "attr": [{"name": "epsilon", "param_type": "optional", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bn_training_reduce_grad.so", "compute_cost": 10, "kernel_name": "bn_training_reduce_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": "dynamicFormat"}
|
||||
|
|
|
@ -43,7 +43,7 @@ class GatherV2CPUKernel : public CPUKernel {
|
|||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
GatherV2,
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2CPUKernel);
|
||||
} // namespace kernel
|
||||
|
|
|
@ -19,26 +19,26 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherV2,
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2GpuFwdKernel, float, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherV2,
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2GpuFwdKernel, float, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherV2,
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2GpuFwdKernel, half, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherV2,
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2GpuFwdKernel, half, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(GatherV2,
|
||||
MS_REG_GPU_KERNEL_TWO(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
|
@ -46,7 +46,7 @@ MS_REG_GPU_KERNEL_TWO(GatherV2,
|
|||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2GpuFwdKernel, float, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(GatherV2,
|
||||
MS_REG_GPU_KERNEL_TWO(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
|
@ -54,7 +54,7 @@ MS_REG_GPU_KERNEL_TWO(GatherV2,
|
|||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2GpuFwdKernel, float, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(GatherV2,
|
||||
MS_REG_GPU_KERNEL_TWO(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
|
@ -62,7 +62,7 @@ MS_REG_GPU_KERNEL_TWO(GatherV2,
|
|||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2GpuFwdKernel, half, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(GatherV2,
|
||||
MS_REG_GPU_KERNEL_TWO(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
|
|
|
@ -85,8 +85,8 @@ CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node
|
|||
if (origin_node->size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "In dynamic shape scene, gatherv2 should have 3 inputs";
|
||||
}
|
||||
std::vector<AnfNodePtr> gatherv2_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimGatherV2->name())),
|
||||
pad, origin_node->input(2), origin_node->input(3)};
|
||||
std::vector<AnfNodePtr> gatherv2_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimGather->name())), pad,
|
||||
origin_node->input(2), origin_node->input(3)};
|
||||
auto gather_v2 = graph->NewCNode(gatherv2_inputs);
|
||||
MS_EXCEPTION_IF_NULL(gather_v2);
|
||||
gather_v2->set_scope(origin_node->scope());
|
||||
|
@ -146,7 +146,7 @@ bool CheckInputs(const CNodePtr &origin_node) {
|
|||
|
||||
const BaseRef GatherV2DsFission::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
VectorRef pattern({prim::kPrimGatherV2, Xs});
|
||||
VectorRef pattern({prim::kPrimGather, Xs});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
|
|||
Register(prim::kPrimReduceMin->name(), {1});
|
||||
Register(prim::kPrimReduceSum->name(), {1});
|
||||
Register(prim::kPrimReduceMean->name(), {1});
|
||||
Register(prim::kPrimGatherV2->name(), {2});
|
||||
Register(prim::kPrimGather->name(), {2});
|
||||
Register(prim::kPrimGatherD->name(), {1});
|
||||
Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5});
|
||||
Register(prim::kPrimEmbeddingLookupCommGrad->name(), {1});
|
||||
|
|
|
@ -62,7 +62,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
|
|||
{prim::kPrimCast, {2}},
|
||||
{prim::kPrimTranspose, {2}},
|
||||
{prim::kPrimOneHot, {2}},
|
||||
{prim::kPrimGatherV2, {3}},
|
||||
{prim::kPrimGather, {3}},
|
||||
{prim::kPrimReshape, {2}},
|
||||
{prim::kPrimAssign, {1}},
|
||||
{prim::kPrimAssignAdd, {1}},
|
||||
|
@ -508,7 +508,7 @@ bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const Abstrac
|
|||
abstract::AbstractTuplePtr false_branch_tuple = false_branch_abs->cast<abstract::AbstractTuplePtr>();
|
||||
if (true_branch_tuple->elements().size() != false_branch_tuple->elements().size()) {
|
||||
MS_LOG(ERROR) << "true branch size:" << true_branch_tuple->elements().size()
|
||||
<< ", not equal to false banch size:" << false_branch_tuple->elements().size() << " ";
|
||||
<< ", not equal to false branch size:" << false_branch_tuple->elements().size() << " ";
|
||||
return false;
|
||||
}
|
||||
bool all_compatible = true;
|
||||
|
|
|
@ -616,7 +616,7 @@ Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_pt
|
|||
return s;
|
||||
}
|
||||
auto name = ops[incoming_op_index]->name().substr(0, pos);
|
||||
if (name == "GatherV2") {
|
||||
if (name == "Gather") {
|
||||
return s;
|
||||
} else if (name == "GatherV2P") {
|
||||
return PrepareGatherV2POutputStrategy(ops, incoming_op_index);
|
||||
|
@ -849,7 +849,7 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
|
|||
if (ops[iter_ops]->type() == GATHERV2) {
|
||||
auto pos = ops[iter_ops]->name().find("Info");
|
||||
auto name = ops[iter_ops]->name().substr(0, pos);
|
||||
if (name == "GatherV2") {
|
||||
if (name == "Gather") {
|
||||
return PrepareGatherV2(ops, iter_ops, basic_stra);
|
||||
} else if (name == "GatherV2P") {
|
||||
return PrepareGatherV2P(ops, iter_ops, basic_stra);
|
||||
|
|
|
@ -426,7 +426,7 @@ AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNo
|
|||
AnfNodePtrList gatherv2_nodes;
|
||||
auto user_set = graph->manager()->node_users()[node];
|
||||
for (auto &ele : user_set) {
|
||||
if (IsPrimitiveCNode(ele.first, prim::kPrimGatherV2)) {
|
||||
if (IsPrimitiveCNode(ele.first, prim::kPrimGather)) {
|
||||
gatherv2_nodes.emplace_back(ele.first);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -140,7 +140,7 @@ REGISTER(ReLU6Info);
|
|||
REGISTER(ReLUV2Info);
|
||||
REGISTER(SoftplusInfo);
|
||||
REGISTER(SoftsignInfo);
|
||||
REGISTER(GatherV2Info);
|
||||
REGISTER(GatherInfo);
|
||||
REGISTER(SparseGatherV2Info);
|
||||
REGISTER(SqrtInfo);
|
||||
REGISTER(SigmoidInfo);
|
||||
|
@ -180,7 +180,7 @@ REGISTER(UniformCandidateSamplerInfo);
|
|||
REGISTER(UnsortedSegmentSumInfo);
|
||||
REGISTER(UnsortedSegmentMinInfo);
|
||||
REGISTER(UnsortedSegmentMaxInfo);
|
||||
REGISTER(GatherV2PInfo);
|
||||
REGISTER(GatherPInfo);
|
||||
REGISTER(EmbeddingLookupInfo);
|
||||
REGISTER(TileInfo);
|
||||
REGISTER(BroadcastToInfo);
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status GatherV2Info::GetAttrs() {
|
||||
Status GatherInfo::GetAttrs() {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size();
|
||||
return FAILED;
|
||||
|
@ -70,7 +70,7 @@ Status GatherV2Info::GetAttrs() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) {
|
||||
Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
|
||||
<< inputs_shape_.size();
|
||||
|
@ -104,7 +104,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2Info::InferDevMatrixShape() {
|
||||
Status GatherInfo::InferDevMatrixShape() {
|
||||
Strategys stra = strategy_->GetInputDim();
|
||||
dev_matrix_shape_ = stra.at(0);
|
||||
return SUCCESS;
|
||||
|
@ -114,7 +114,7 @@ Status GatherV2Info::InferDevMatrixShape() {
|
|||
// If index is a n dimension tensor, output dimension is input dimension plus (n - 1).
|
||||
// Tensor map dimension is equal to the corresponding input and output dimension.
|
||||
// If index's dimension is more than 1, we insert -1 for the output tensor map.
|
||||
Status GatherV2Info::InferTensorMap() {
|
||||
Status GatherInfo::InferTensorMap() {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
|
||||
<< inputs_shape_.size();
|
||||
|
@ -158,7 +158,7 @@ Status GatherV2Info::InferTensorMap() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2Info::InferTensorInfo() {
|
||||
Status GatherInfo::InferTensorInfo() {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
|
||||
<< inputs_shape_.size();
|
||||
|
@ -219,7 +219,7 @@ OperatorVector CreateSubOp(int64_t sub_value) {
|
|||
return ops;
|
||||
}
|
||||
|
||||
Status GatherV2Info::InferTensorSubOps() {
|
||||
Status GatherInfo::InferTensorSubOps() {
|
||||
sub_ops_.clear();
|
||||
if ((index_size_ == 0) || (axis_strategy_ == 1)) {
|
||||
return SUCCESS;
|
||||
|
@ -252,7 +252,7 @@ Status GatherV2Info::InferTensorSubOps() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2Info::Init(const StrategyPtr &strategy) {
|
||||
Status GatherInfo::Init(const StrategyPtr &strategy) {
|
||||
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init failed.";
|
||||
return FAILED;
|
||||
|
@ -266,7 +266,7 @@ Status GatherV2Info::Init(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) {
|
||||
Status GatherInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
|
||||
return FAILED;
|
||||
|
@ -275,7 +275,7 @@ Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2Info::GenerateStrategies(int64_t stage_id) {
|
||||
Status GatherInfo::GenerateStrategies(int64_t stage_id) {
|
||||
if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) {
|
||||
MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
|
||||
<< outputs_shape_.size() << "is wrong.";
|
||||
|
@ -301,9 +301,9 @@ Status GatherV2Info::GenerateStrategies(int64_t stage_id) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
|
||||
std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
|
||||
<< inputs_shape_.size();
|
||||
|
|
|
@ -36,15 +36,15 @@ constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3;
|
|||
// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of
|
||||
// the input.
|
||||
// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1.
|
||||
class GatherV2Info : public OperatorInfo {
|
||||
class GatherInfo : public OperatorInfo {
|
||||
public:
|
||||
GatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
GatherInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2Cost>()),
|
||||
axis_(-1),
|
||||
index_size_(0),
|
||||
axis_strategy_(1) {}
|
||||
~GatherV2Info() override = default;
|
||||
~GatherInfo() override = default;
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
|
||||
Status GatherPInfo::GetManualSplitWithoutOffsetAttr() {
|
||||
auto manual_split_without_offset_iter = attrs_.find("manual_split");
|
||||
if (manual_split_without_offset_iter != attrs_.end()) {
|
||||
manual_split_ = true;
|
||||
|
@ -68,7 +68,7 @@ Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::GetManualSplitAttr() {
|
||||
Status GatherPInfo::GetManualSplitAttr() {
|
||||
auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
|
||||
if (manual_split_with_offset_iter != attrs_.end()) {
|
||||
manual_split_ = true;
|
||||
|
@ -118,7 +118,7 @@ Status GatherV2PInfo::GetManualSplitAttr() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::GetAttrs() {
|
||||
Status GatherPInfo::GetAttrs() {
|
||||
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
|
||||
if (target_ != CPU) {
|
||||
if (input_value_.at(2) == nullptr) {
|
||||
|
@ -172,7 +172,7 @@ Status GatherV2PInfo::GetAttrs() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
|
||||
Status GatherPInfo::CheckManualSplit(const Strategys &strategy) {
|
||||
if (strategy.size() != 2) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
|
||||
return FAILED;
|
||||
|
@ -228,7 +228,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -306,7 +306,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::InferMirrorOps() {
|
||||
Status GatherPInfo::InferMirrorOps() {
|
||||
// There is no mirror operators for manual split
|
||||
if (manual_split_) {
|
||||
return SUCCESS;
|
||||
|
@ -336,7 +336,7 @@ Status GatherV2PInfo::InferMirrorOps() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::InferDevMatrixShape() {
|
||||
Status GatherPInfo::InferDevMatrixShape() {
|
||||
dev_matrix_shape_.clear();
|
||||
out_dev_matrix_shape_.clear();
|
||||
// infer input dev_matrix_shape
|
||||
|
@ -386,7 +386,7 @@ Status GatherV2PInfo::InferDevMatrixShape() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
void GatherV2PInfo::InferInputsTensorMap() {
|
||||
void GatherPInfo::InferInputsTensorMap() {
|
||||
// infer input tensor map
|
||||
// param_strategy(axis) != 1
|
||||
size_t param_size = inputs_shape_.at(0).size();
|
||||
|
@ -413,7 +413,7 @@ void GatherV2PInfo::InferInputsTensorMap() {
|
|||
inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
|
||||
}
|
||||
|
||||
void GatherV2PInfo::InferOutputsTensorMap() {
|
||||
void GatherPInfo::InferOutputsTensorMap() {
|
||||
// infer output tensor map
|
||||
size_t param_size = inputs_shape_.at(0).size();
|
||||
size_t index_size = inputs_shape_.at(1).size();
|
||||
|
@ -460,7 +460,7 @@ void GatherV2PInfo::InferOutputsTensorMap() {
|
|||
outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::InferTensorMap() {
|
||||
Status GatherPInfo::InferTensorMap() {
|
||||
if (manual_split_) {
|
||||
inputs_tensor_map_.push_back({1, 0});
|
||||
inputs_tensor_map_.push_back({-1, 1});
|
||||
|
@ -472,7 +472,7 @@ Status GatherV2PInfo::InferTensorMap() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::InferTensorInfo() {
|
||||
Status GatherPInfo::InferTensorInfo() {
|
||||
// infer tensor shape
|
||||
Shape input_shape = inputs_shape_.at(0);
|
||||
Shape input_index_shape = inputs_shape_.at(1);
|
||||
|
@ -505,7 +505,7 @@ Status GatherV2PInfo::InferTensorInfo() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::InferBias() {
|
||||
Status GatherPInfo::InferBias() {
|
||||
CheckGlobalDeviceManager();
|
||||
int64_t rank = g_device_manager->rank_index_in_stage();
|
||||
auto input_shape = inputs_shape_.at(0);
|
||||
|
@ -559,7 +559,7 @@ Status GatherV2PInfo::InferBias() {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::InferOffset() {
|
||||
Status GatherPInfo::InferOffset() {
|
||||
CheckGlobalDeviceManager();
|
||||
size_t rank = g_device_manager->rank_index_in_stage();
|
||||
|
||||
|
@ -580,7 +580,7 @@ Status GatherV2PInfo::InferOffset() {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::InferGroup() {
|
||||
Status GatherPInfo::InferGroup() {
|
||||
auto param_strategy = strategy_->GetInputDim().at(0);
|
||||
size_t dim = LongToSize(axis_);
|
||||
if (param_strategy.at(LongToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
|
||||
|
@ -610,7 +610,7 @@ Status GatherV2PInfo::InferGroup() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::InferForwardCommunication() {
|
||||
Status GatherPInfo::InferForwardCommunication() {
|
||||
if (manual_split_) {
|
||||
return SUCCESS;
|
||||
}
|
||||
|
@ -647,7 +647,7 @@ Status GatherV2PInfo::InferForwardCommunication() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||
Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||
GenerateGraph gen_g = GenerateGraph();
|
||||
if (gen_g.Init(cnode) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
||||
|
@ -705,7 +705,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
||||
ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) {
|
||||
if (manual_split_ && target_ != CPU) {
|
||||
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
|
||||
|
@ -724,7 +724,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
|||
return replace_graph_;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::ComputeReplaceOp() {
|
||||
Status GatherPInfo::ComputeReplaceOp() {
|
||||
int64_t bias = 0;
|
||||
if (manual_split_) {
|
||||
if (InferOffset() != SUCCESS) {
|
||||
|
@ -752,7 +752,7 @@ Status GatherV2PInfo::ComputeReplaceOp() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
|
||||
Status GatherPInfo::Init(const StrategyPtr &strategy) {
|
||||
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init failed.";
|
||||
return FAILED;
|
||||
|
@ -765,7 +765,7 @@ Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||
Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
|
||||
|
@ -783,9 +783,9 @@ Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
Status GatherPInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status GatherV2PInfo::GenerateStrategies(int64_t stage_id) {
|
||||
Status GatherPInfo::GenerateStrategies(int64_t stage_id) {
|
||||
if (GetAttrs() != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -814,7 +814,7 @@ Status GatherV2PInfo::GenerateStrategies(int64_t stage_id) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
|
||||
std::shared_ptr<Strategys> GatherPInfo::GenerateBatchStrategies() {
|
||||
if (GetAttrs() != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
|
||||
}
|
||||
|
|
|
@ -29,17 +29,17 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class GatherV2PInfo : public OperatorInfo {
|
||||
class GatherPInfo : public OperatorInfo {
|
||||
public:
|
||||
GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
|
||||
GatherPInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
|
||||
axis_(0),
|
||||
bias_(0),
|
||||
index_offset_(0),
|
||||
slice_size_(0),
|
||||
replace_op_name_(replace_op_name) {}
|
||||
~GatherV2PInfo() override = default;
|
||||
~GatherPInfo() override = default;
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
|
||||
|
@ -85,19 +85,19 @@ class GatherV2PInfo : public OperatorInfo {
|
|||
std::vector<int64_t> index_offsets_;
|
||||
};
|
||||
|
||||
class SparseGatherV2Info : public GatherV2PInfo {
|
||||
class SparseGatherV2Info : public GatherPInfo {
|
||||
public:
|
||||
SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2)
|
||||
: GatherV2PInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
|
||||
: GatherPInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
|
||||
~SparseGatherV2Info() override = default;
|
||||
};
|
||||
|
||||
class EmbeddingLookupInfo : public GatherV2PInfo {
|
||||
class EmbeddingLookupInfo : public GatherPInfo {
|
||||
public:
|
||||
EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||
: GatherPInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~EmbeddingLookupInfo() override = default;
|
||||
};
|
||||
} // namespace parallel
|
||||
|
|
|
@ -249,7 +249,7 @@ constexpr char MINIMUM[] = "Minimum";
|
|||
constexpr char EQUAL[] = "Equal";
|
||||
constexpr char NOT_EQUAL[] = "NotEqual";
|
||||
constexpr char LOGICALNOT[] = "LogicalNot";
|
||||
constexpr char GATHERV2[] = "GatherV2";
|
||||
constexpr char GATHERV2[] = "Gather";
|
||||
constexpr char SPARSE_GATHERV2[] = "SparseGatherV2";
|
||||
constexpr char STRIDEDSLICE[] = "StridedSlice";
|
||||
constexpr char SLICE[] = "Slice";
|
||||
|
|
|
@ -2699,7 +2699,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
}
|
||||
if (operator_info->name().find(EMBEDDING_LOOKUP) != std::string::npos ||
|
||||
operator_info->name().find(GATHERV2) != std::string::npos) {
|
||||
auto gatherv2_info = std::dynamic_pointer_cast<GatherV2PInfo>(operator_info);
|
||||
auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info);
|
||||
auto param_split_shapes = gatherv2_info->param_split_shapes();
|
||||
auto index_offsets = gatherv2_info->index_offsets();
|
||||
if (param_split_shapes.size() != index_offsets.size()) {
|
||||
|
|
|
@ -148,7 +148,7 @@ std::string GetRealOpType(const std::string &op_type) {
|
|||
static const std::map<std::string, std::string> kOpTypeMap = {
|
||||
{"SparseApplyFtrl", "SparseApplyFtrlD"},
|
||||
{"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"},
|
||||
{"SparseGatherV2", "GatherV2"},
|
||||
{"SparseGatherV2", "Gather"},
|
||||
{"Pad", "PadD"},
|
||||
{"Concat", "ConcatD"},
|
||||
};
|
||||
|
|
|
@ -247,7 +247,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(
|
|||
.Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
||||
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
||||
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo())
|
||||
|
@ -970,7 +970,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
}
|
||||
|
||||
// MindSpore GatherV2(x, indices, axis) --> ONNX Pow(x, indices)
|
||||
if (node->IsApply(prim::kPrimGatherV2)) {
|
||||
if (node->IsApply(prim::kPrimGather)) {
|
||||
return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ INPUT_MAP(GatherV2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}};
|
|||
INPUT_ATTR_MAP(GatherV2D) = {{3, ATTR_DESC(axis, AnyTraits<int64_t>())}};
|
||||
ATTR_MAP(GatherV2D) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(GatherV2D) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(GatherV2D, prim::kPrimGatherV2->name(), ADPT_DESC(GatherV2D))
|
||||
REG_ADPT_DESC(GatherV2D, prim::kPrimGather->name(), ADPT_DESC(GatherV2D))
|
||||
|
||||
// ScatterNdD
|
||||
INPUT_MAP(ScatterNdD) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(x)}};
|
||||
|
|
|
@ -208,7 +208,7 @@ constexpr auto kPushOpName = "Push";
|
|||
constexpr auto kPullOpName = "Pull";
|
||||
constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
|
||||
constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
|
||||
constexpr auto kGatherV2OpName = "GatherV2";
|
||||
constexpr auto kGatherV2OpName = "Gather";
|
||||
constexpr auto kPaddingOpName = "Padding";
|
||||
constexpr auto kAvgPoolOpName = "AvgPool";
|
||||
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
|
||||
|
|
|
@ -64,7 +64,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimPad, {InferImplPad, true}},
|
||||
{prim::kPrimUnique, {InferImplUnique, true}},
|
||||
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
||||
{prim::kPrimGatherV2, {InferImplGatherV2, true}},
|
||||
{prim::kPrimGather, {InferImplGatherV2, true}},
|
||||
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
|
||||
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
|
||||
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace prim {
|
||||
constexpr auto kGather = "Gather";
|
||||
// Here list all primitives used in backend or some special primitives used by core.
|
||||
// Arithmetic
|
||||
inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
|
||||
|
@ -86,8 +87,8 @@ inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
|
|||
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
|
||||
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
|
||||
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
|
||||
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
|
||||
inline const PrimitivePtr kPrimGatherD = std::make_shared<Primitive>("GatherD");
|
||||
inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>(kGather);
|
||||
inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
|
||||
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
|
||||
inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");
|
||||
|
@ -351,7 +352,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_
|
|||
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
||||
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
|
||||
|
||||
// Other primitve not used by backend but used in core;
|
||||
// Other primitive not used by backend but used in core;
|
||||
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
|
||||
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
|
||||
|
||||
|
|
|
@ -607,7 +607,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
|
|||
return NewPrimitiveC<While>(prim, inputs, quantType);
|
||||
} else if (op_type == "MirrorPad") {
|
||||
return NewPrimitiveC<Pad>(prim, inputs, quantType);
|
||||
} else if (op_type == "GatherV2") {
|
||||
} else if (op_type == "Gather") {
|
||||
return NewPrimitiveC<Gather>(prim, inputs, quantType);
|
||||
} else if (op_type == "OnesLike") {
|
||||
return NewPrimitiveC<OnesLike>(prim, inputs, quantType);
|
||||
|
|
|
@ -97,6 +97,7 @@ STATUS TFGatherParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
status = AddOpInput(tf_op, 1, inputs);
|
||||
return status;
|
||||
}
|
||||
|
||||
TFNodeRegistrar g_tfGatherV2Parser("GatherV2", new TFGatherParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -69,7 +69,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
|
|||
{tflite::BuiltinOperator_RANGE, "Range"},
|
||||
{tflite::BuiltinOperator_RANK, "Rank"},
|
||||
{tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, "LocalResponseNorm"},
|
||||
{tflite::BuiltinOperator_GATHER, "GatherV2"},
|
||||
{tflite::BuiltinOperator_GATHER, "Gather"},
|
||||
{tflite::BuiltinOperator_EXP, "Exp"},
|
||||
{tflite::BuiltinOperator_SPLIT_V, "SplitV"},
|
||||
{tflite::BuiltinOperator_SPLIT, "Split"},
|
||||
|
|
|
@ -112,7 +112,7 @@ class Embedding(Cell):
|
|||
self.expand = P.ExpandDims()
|
||||
self.reshape_flat = P.Reshape()
|
||||
self.shp_flat = (-1,)
|
||||
self.gather = P.GatherV2()
|
||||
self.gather = P.Gather()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, self.dtype)
|
||||
self.off_value = Tensor(0.0, self.dtype)
|
||||
|
@ -154,7 +154,7 @@ class EmbeddingLookup(Cell):
|
|||
When 'target' is set to 'CPU', this module will use
|
||||
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
|
||||
specified 'offset = 0' to lookup table.
|
||||
When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
|
||||
When 'target' is set to 'DEVICE', this module will use P.Gather() which
|
||||
specified 'axis = 0' to lookup table.
|
||||
In field slice mode, the manual_shapes must be given. It is a tuple ,where
|
||||
the element is vocab[i], vocab[i] is the row numbers for i-th part.
|
||||
|
@ -221,7 +221,7 @@ class EmbeddingLookup(Cell):
|
|||
if sparse:
|
||||
self.gatherv2 = P.SparseGatherV2()
|
||||
else:
|
||||
self.gatherv2 = P.GatherV2()
|
||||
self.gatherv2 = P.Gather()
|
||||
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
||||
enable_ps = _get_ps_context("enable_ps")
|
||||
if enable_ps:
|
||||
|
@ -231,7 +231,7 @@ class EmbeddingLookup(Cell):
|
|||
name='embedding_table')
|
||||
parallel_mode = _get_parallel_mode()
|
||||
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
self.gather_revert = P.GatherV2()
|
||||
self.gather_revert = P.Gather()
|
||||
self.reshape_first = P.Reshape()
|
||||
self.reshape = P.Reshape()
|
||||
self.unique = P.Unique()
|
||||
|
@ -379,7 +379,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|||
When 'target' is set to 'CPU', this module will use
|
||||
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
|
||||
specified 'offset = 0' to lookup table.
|
||||
When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
|
||||
When 'target' is set to 'DEVICE', this module will use P.Gather() which
|
||||
specified 'axis = 0' to lookup table.
|
||||
The vectors with the same field_ids will be combined by the 'operator', such as 'SUM', 'MAX' and
|
||||
'MEAN'. Ensure the input_values of the padded id is zero, so that they can be ignored. The final
|
||||
|
|
|
@ -440,7 +440,7 @@ class SampledSoftmaxLoss(_Loss):
|
|||
self.log = P.Log()
|
||||
self.slice_op = P.Slice()
|
||||
self.matmul = P.MatMul(False, True)
|
||||
self.gather_v2 = P.GatherV2()
|
||||
self.gather_v2 = P.Gather()
|
||||
self.reduce_max_true = P.ReduceMax(True)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reduce_sum_true = P.ReduceSum(True)
|
||||
|
|
|
@ -49,7 +49,7 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
|
|||
success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, values, indices))
|
||||
else:
|
||||
op_gather = P.GatherV2()
|
||||
op_gather = P.Gather()
|
||||
op_sqrt = P.Sqrt()
|
||||
scatter_add = P.ScatterAdd(use_locking)
|
||||
scatter_update = P.ScatterUpdate(use_locking)
|
||||
|
|
|
@ -537,7 +537,7 @@ class Optimizer(Cell):
|
|||
|
||||
|
||||
op_add = P.AddN()
|
||||
op_gather = P.GatherV2()
|
||||
op_gather = P.Gather()
|
||||
op_mul = P.Mul()
|
||||
|
||||
_apply_decay = C.MultitypeFuncGraph("apply_decay")
|
||||
|
@ -625,7 +625,7 @@ class _IteratorLearningRate(LearningRateSchedule):
|
|||
raise TypeError("Learning rate should be Tensor.")
|
||||
|
||||
self.learning_rate = Parameter(learning_rate, name)
|
||||
self.gather = P.GatherV2()
|
||||
self.gather = P.Gather()
|
||||
|
||||
def construct(self, global_step):
|
||||
return self.gather(self.learning_rate, global_step, 0)
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" Define constants"""
|
||||
|
||||
# Arithmetic
|
||||
kScalarAdd = "ScalarAdd"
|
||||
kScalarSub = "ScalarSub"
|
||||
kScalarMul = "ScalarMul"
|
||||
kScalarDiv = "ScalarDiv"
|
||||
kScalarFloordiv = "ScalarFloordiv"
|
||||
kScalarMod = "ScalarMod"
|
||||
kScalarPow = "ScalarPow"
|
||||
kScalarTrunc = "ScalarTrunc"
|
||||
kScalarFloor = "ScalarFloor"
|
||||
kScalarUadd = "ScalarUadd"
|
||||
kScalarUsub = "ScalarUsub"
|
||||
|
||||
kTupleGetItem = "TupleGetItem"
|
||||
kMakeTuple = "MakeTuple"
|
||||
|
||||
kGather = "Gather"
|
|
@ -382,7 +382,7 @@ def _regenerate_output_shape(x_shp, ind_shp, axis):
|
|||
return out_shape
|
||||
|
||||
|
||||
@bprop_getters.register(P.GatherV2)
|
||||
@bprop_getters.register(P.Gather)
|
||||
def get_bprop_gather_v2(self):
|
||||
"""Generate bprop for GatherV2"""
|
||||
|
||||
|
@ -738,7 +738,7 @@ def get_bprop_tensor_scatter_update(self):
|
|||
@bprop_getters.register(P.ScatterMax)
|
||||
def get_bprop_scatter_max(self):
|
||||
"""Generate bprop for ScatterMax"""
|
||||
gather = P.GatherV2()
|
||||
gather = P.Gather()
|
||||
|
||||
def bprop(x, indices, update, out, dout):
|
||||
return dout, zeros_like(indices), gather(dout, indices, 0)
|
||||
|
@ -816,7 +816,7 @@ def _gather_drop_negatives(params,
|
|||
is_positive=None):
|
||||
"""Helper function for unsorted segment ops."""
|
||||
maximum = P.Maximum()
|
||||
gather = P.GatherV2()
|
||||
gather = P.Gather()
|
||||
greater_equal = P.GreaterEqual()
|
||||
rank = P.Rank()
|
||||
fill = P.Fill()
|
||||
|
@ -895,7 +895,7 @@ def get_bprop_unsorted_segment_prod(self):
|
|||
equal = P.Equal()
|
||||
cast = P.Cast()
|
||||
select = P.Select()
|
||||
gather = P.GatherV2()
|
||||
gather = P.Gather()
|
||||
greater = P.Greater()
|
||||
ones_like = P.OnesLike()
|
||||
maximum = P.Maximum()
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
"""GatherV2 op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
gather_v2_op_info = TBERegOp("GatherV2") \
|
||||
gather_v2_op_info = TBERegOp("Gather") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("gather_v2_d.so") \
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
"""AddN op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
gather_v2_op_info = TBERegOp("GatherV2") \
|
||||
gather_v2_op_info = TBERegOp("Gather") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("gather_v2.so") \
|
||||
|
|
|
@ -81,7 +81,7 @@ expand_dims = P.ExpandDims()
|
|||
transpose = P.Transpose()
|
||||
squeeze = P.Squeeze()
|
||||
scatter_nd = P.ScatterNd()
|
||||
gather = P.GatherV2()
|
||||
gather = P.Gather()
|
||||
gather_nd = P.GatherNd()
|
||||
scatter_update = P.ScatterUpdate()
|
||||
scatter_nd_update = P.ScatterNdUpdate()
|
||||
|
|
|
@@ -22,7 +22,7 @@ A collection of operators to build neural networks or to compute functions.
 from .image_ops import (CropAndResize)
 from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Diag, DiagPart, DType, ExpandDims, Eye,
-                        Fill, Ones, Zeros, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
+                        Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation,
                         IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid,
                         SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
@@ -159,6 +159,7 @@ __all__ = [
     'Transpose',
     'OneHot',
     'GatherV2',
+    'Gather',
     'SparseGatherV2',
     'EmbeddingLookup',
     'Padding',
@@ -771,7 +771,7 @@ class Unique(Primitive):
         self.init_prim_io_names(inputs=['x'], outputs=['output'])


-class GatherV2(PrimitiveWithCheck):
+class Gather(PrimitiveWithCheck):
     """
     Returns a slice of the input tensor based on the specified indices and axis.
@@ -793,7 +793,7 @@ class GatherV2(PrimitiveWithCheck):
         >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
         >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
         >>> axis = 1
-        >>> output = ops.GatherV2()(input_params, input_indices, axis)
+        >>> output = ops.Gather()(input_params, input_indices, axis)
         >>> print(output)
         [[ 2. 7.]
          [ 4. 54.]
@@ -815,7 +815,12 @@ class GatherV2(PrimitiveWithCheck):
         validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)


-class SparseGatherV2(GatherV2):
+def GatherV2():
+    """Warning: This will be changed later"""
+    logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.")
+    return Gather()
+
+class SparseGatherV2(Gather):
     """
     Returns a slice of input tensor based on the specified indices and axis.
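Editor's note: an illustrative call of the shim above (variable names reuse the docstring example); legacy code keeps working but is steered to the new primitive.

gather_op = GatherV2()  # logs the WARN_DEPRECATED message, returns a Gather()
output = gather_op(input_params, input_indices, axis)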
@@ -23,7 +23,7 @@ from mindspore.common.tensor import Tensor

 class BboxAssignSampleForRcnn(nn.Cell):
     """
-    Bbox assigner and sampler defination.
+    Bbox assigner and sampler definition.

     Args:
         config (dict): Config.
@@ -71,7 +71,7 @@ class BboxAssignSampleForRcnn(nn.Cell):
         self.greater = P.Greater()
         self.select = P.Select()
         self.gatherND = P.GatherNd()
-        self.gatherV2 = P.GatherV2()
+        self.gatherV2 = P.Gather()
         self.squeeze = P.Squeeze()
         self.cast = P.Cast()
         self.logicaland = P.LogicalAnd()
@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor

 class BboxAssignSampleForRcnn(nn.Cell):
     """
-    Bbox assigner and sampler defination.
+    Bbox assigner and sampler definition.

     Args:
         config (dict): Config.
@@ -50,7 +50,7 @@ class DiceLoss(_Loss):
         self.equal = P.Equal()
         self.zeros_like = P.ZerosLike()
         self.add = P.TensorAdd()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()

     def ohem_batch(self, scores, gt_texts, training_masks):
         '''
@@ -187,7 +187,7 @@ class Conv2d_Thor_GPU(_Conv):
         self.batch_size = Tensor(batch_size, mstype.float16)
         self.transpose = P.Transpose()
         self.cast = P.Cast()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.freq = Tensor(frequency, mstype.int32)
         self.axis = 0
         self.sqrt = P.Sqrt()
@@ -330,7 +330,7 @@ class Dense_Thor_GPU(Cell):
         self.dampingA = Tensor(np.identity(in_channels), mstype.float32)
         self.dampingG = Tensor(np.identity(out_channels), mstype.float32)
         self.cast = P.Cast()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.freq = Tensor(frequency, mstype.int32)
         self.axis = 0
         self.add = P.TensorAdd()
@@ -496,7 +496,7 @@ class Conv2d_Thor(_Conv):
         self.device_shape_pad_flag = True
         self.device_shape_pad = P.Pad(((0, 0), (0, C0 - self.in_channels), (0, 0), (0, C0 - self.in_channels)))
         self.slice = P.Slice()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.freq = Tensor(frequency, mstype.int32)
         self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
         self.axis = 0
@@ -678,7 +678,7 @@ class Dense_Thor(Cell):
         self.pad = P.Pad(((0, 23), (0, 23)))
         self.pad1 = P.Pad(((0, 7), (0, 7)))
         self.slice = P.Slice()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.assignadd = P.AssignAdd()
         self.freq = Tensor(frequency, mstype.int32)
         self.axis = 0
@@ -149,7 +149,7 @@ class BGCF(nn.Cell):
         self.tanh = P.Tanh()
         self.shape = P.Shape()
         self.split = P.Split(0, 2)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.reshape = P.Reshape()
         self.concat_0 = P.Concat(0)
         self.concat_1 = P.Concat(1)
@@ -73,7 +73,7 @@ class GetMaskedLMOutput(nn.Cell):
         super(GetMaskedLMOutput, self).__init__()
         self.width = config.hidden_size
         self.reshape = P.Reshape()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()

         weight_init = TruncatedNormal(config.initializer_range)
         self.dense = nn.Dense(self.width,
@@ -113,7 +113,7 @@ class EmbeddingLookup(nn.Cell):
                                                  [vocab_size, embedding_size]))
         self.expand = P.ExpandDims()
         self.shape_flat = (-1,)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
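Editor's note: an illustrative numpy analogue of what this lookup cell does with the pieces above: flatten the ids, gather rows of the table along axis 0, then reshape back.

import numpy as np

table = np.random.randn(100, 16)   # [vocab_size, embedding_size]
ids = np.array([[3, 7], [0, 2]])   # token ids of any shape
flat = ids.reshape(-1)             # mirrors self.shape_flat = (-1,)
embeddings = table[flat].reshape(ids.shape + (16,))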
@@ -178,7 +178,7 @@ class EmbeddingPostprocessor(nn.Cell):
         self.reshape = P.Reshape()
         self.shape = tuple(embedding_shape)
         self.dropout = nn.Dropout(1 - dropout_prob)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.use_relative_positions = use_relative_positions
         self.slice = P.StridedSlice()
         _, seq, _ = self.shape
@@ -310,7 +310,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.reshape = P.Reshape()
         self.one_hot = nn.OneHot(depth=self.vocab_size)
         self.shape = P.Shape()
-        self.gather = P.GatherV2()  # index_select
+        self.gather = P.Gather()  # index_select
         self.matmul = P.BatchMatMul()

     def construct(self):
@@ -81,7 +81,7 @@ class GetMaskedLMOutput(nn.Cell):
         super(GetMaskedLMOutput, self).__init__()
         self.width = config.hidden_size
         self.reshape = P.Reshape()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()

         weight_init = TruncatedNormal(config.initializer_range)
         self.dense = Dense_Thor(in_channels=self.width,
@@ -138,7 +138,7 @@ class EmbeddingLookup(nn.Cell):
                                                  [vocab_size, embedding_size]))
         self.expand = P.ExpandDims()
         self.shape_flat = (-1,)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
@@ -210,7 +210,7 @@ class EmbeddingPostprocessor(nn.Cell):
         self.reshape = P.Reshape()
         self.shape = tuple(embedding_shape)
         self.dropout = nn.Dropout(1 - dropout_prob)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.use_relative_positions = use_relative_positions
         self.slice = P.StridedSlice()
         _, seq, width = self.shape
@@ -362,7 +362,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.reshape = P.Reshape()
         self.one_hot = nn.OneHot(depth=self.vocab_size)
         self.shape = P.Shape()
-        self.gather = P.GatherV2()  # index_select
+        self.gather = P.Gather()  # index_select
         self.matmul = P.BatchMatMul()

     def construct(self):
@@ -64,7 +64,7 @@ class THOR(Optimizer):
         self.shape = P.Shape()
         self.reshape = P.Reshape()
         self.mul = P.Mul()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.matrix_A_inv = ()
         self.matrix_G_inv = ()
         self.num_hidden_layers = num_hidden_layers
@@ -225,8 +225,8 @@ class THOR(Optimizer):
                 end_idx = mlm_fc_idx + 4
                 new_grads = new_grads + gradients[begin_idx: end_idx]

-            lenth = len(gradients)
-            new_grads = new_grads + gradients[lenth - 2: lenth]
+            length = len(gradients)
+            new_grads = new_grads + gradients[length - 2: length]
             gradients = new_grads
         else:
             new_grads = ()
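Editor's note: besides fixing the "lenth" typo, the rename touches a simple tuple slice; a quick illustrative check of what it selects: the last two entries of the gradient tuple are appended to new_grads.

gradients = ("g0", "g1", "g2", "g3")
length = len(gradients)
assert gradients[length - 2: length] == ("g2", "g3")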
@@ -350,8 +350,8 @@ class THOR(Optimizer):
                 end_idx = mlm_fc_idx + 4
                 new_grads = new_grads + gradients[begin_idx: end_idx]

-            lenth = len(gradients)
-            new_grads = new_grads + gradients[lenth - 2: lenth]
+            length = len(gradients)
+            new_grads = new_grads + gradients[length - 2: length]
             gradients = new_grads

         if self.weight_decay > 0:
@@ -66,7 +66,7 @@ class THOR(Optimizer):
         self.shape = P.Shape()
         self.reshape = P.Reshape()
         self.mul = P.Mul()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.matrix_A_inv = ()
         self.matrix_G_inv = ()
         self.num_hidden_layers = num_hidden_layers
@@ -230,8 +230,8 @@ class THOR(Optimizer):
                 end_idx = mlm_fc_idx + 4
                 new_grads = new_grads + gradients[begin_idx: end_idx]

-            lenth = len(gradients)
-            new_grads = new_grads + gradients[lenth - 2: lenth]
+            length = len(gradients)
+            new_grads = new_grads + gradients[length - 2: length]
             gradients = new_grads
             gradients = self.grad_reducer_g(gradients)
         else:
@@ -356,8 +356,8 @@ class THOR(Optimizer):
                 end_idx = mlm_fc_idx + 4
                 new_grads = new_grads + gradients[begin_idx: end_idx]

-            lenth = len(gradients)
-            new_grads = new_grads + gradients[lenth - 2: lenth]
+            length = len(gradients)
+            new_grads = new_grads + gradients[length - 2: length]
             gradients = new_grads
             gradients = self.grad_reducer_g(gradients)

@@ -55,7 +55,7 @@ class Embedding_Thor(Cell):
         self.thor = True
         self.expand = P.ExpandDims()
         self.shape_flat = (-1,)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
@@ -75,7 +75,7 @@ class Embedding_Thor(Cell):
         self.freq = Tensor(frequency, mstype.int32)
         self.axis = 0
         self.damping = damping
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.sqrt = P.Sqrt()
         self.mul = P.Mul()
         self.cast = P.Cast()
@@ -199,7 +199,7 @@ class Dense_Thor(Cell):
         self.damping = damping
         self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
         self.vector_matmul = P.CusBatchMatMul()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.assignadd = P.AssignAdd()
         self.freq = Tensor(frequency, mstype.int32)
         self.axis = 0
@@ -50,7 +50,7 @@ class EmbeddingLookup(nn.Cell):
         init_weight = np.random.normal(-initializer_range, initializer_range, size=[vocab_size, embed_dim])
         self.embedding_table = Parameter(Tensor(init_weight, mstype.float32))
         self.expand = P.ExpandDims()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
@@ -195,7 +195,7 @@ class EmbeddingLookup(nn.Cell):
         self.vocab_size = config.vocab_size
         self.embedding_size = config.embedding_size
         self.embedding_table = Parameter(initializer(TruncatedNormal(0.02), [self.vocab_size, self.embedding_size]))
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.shape = (-1, config.seq_length, config.embedding_size)
     def construct(self, input_ids):
         output = self.gather(self.embedding_table, input_ids, 0)
@@ -46,7 +46,7 @@ class EmbeddingLookup(nn.Cell):
         init_weight[0, :] = 0
         self.embedding_table = Parameter(Tensor(init_weight))
         self.expand = P.ExpandDims()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
@@ -70,7 +70,7 @@ class PositionalEmbedding(nn.Cell):
             position_encoding(max_position_embeddings, embedding_size),
             mstype.float32
         )
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.get_shape = P.Shape()

     def construct(self, word_embeddings):
@@ -46,7 +46,7 @@ class EmbeddingLookup(nn.Cell):
         init_weight[0, :] = 0
         self.embedding_table = Parameter(Tensor(init_weight))
         self.expand = P.ExpandDims()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
@@ -70,7 +70,7 @@ class PositionalEmbedding(nn.Cell):
             position_encoding(max_position_embeddings, embedding_size),
             mstype.float32
         )
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.get_shape = P.Shape()

     def construct(self, word_embeddings):
@@ -113,7 +113,7 @@ class EmbeddingLookup(nn.Cell):
                                                  [vocab_size, embedding_size]))
         self.expand = P.ExpandDims()
         self.shape_flat = (-1,)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
@@ -179,7 +179,7 @@ class EmbeddingPostprocessor(nn.Cell):
         self.shape = tuple(embedding_shape)
         self.layernorm = nn.LayerNorm((embedding_size,))
         self.dropout = nn.Dropout(1 - dropout_prob)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.use_relative_positions = use_relative_positions
         self.slice = P.StridedSlice()
         self.full_position_embeddings = Parameter(initializer
@@ -322,7 +322,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
         self.shape = P.Shape()
-        self.gather = P.GatherV2()  # index_select
+        self.gather = P.Gather()  # index_select
         self.matmul = P.BatchMatMul()

     def construct(self):
@@ -957,7 +957,7 @@ class BertModelCLS(nn.Cell):
     """
     This class is responsible for classification task evaluation,
     i.e. XNLI(num_labels=3), LCQMC(num_labels=2), Chnsenti(num_labels=2).
-    The returned output represents the final logits as the results of log_softmax is propotional to that of softmax.
+    The returned output represents the final logits as the results of log_softmax is proportional to that of softmax.
     """
     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0,
                  use_one_hot_embeddings=False, phase_type="student"):
@@ -118,7 +118,7 @@ class EmbeddingLookup(nn.Cell):
         self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size))
         self.expand = P.ExpandDims()
         self.shape_flat = (-1,)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
@@ -246,7 +246,7 @@ class LayerPreprocess(nn.Cell):

 class LayerPostprocess(nn.Cell):
     """
-    postprocess ouput of each layer.
+    postprocess output of each layer.
     """
     def __init__(self,
                  dropout_prob=0.1):
@@ -195,7 +195,7 @@ class DeepFMModel(nn.Cell):
         self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init,
                                         self.deep_layer_act, self.keep_prob, convert_dtype=convert_dtype, use_act=False)
         " FM, linear Layers "
-        self.Gatherv2 = P.GatherV2()
+        self.Gatherv2 = P.Gather()
         self.Mul = P.Mul()
         self.ReduceSum = P.ReduceSum(keep_dims=False)
         self.Reshape = P.Reshape()
@@ -277,7 +277,7 @@ class PredictWithSigmoid(nn.Cell):
         self.squeeze = P.Squeeze()
         self.k = k
         self.num_eval_neg = num_eval_neg
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.reshape = P.Reshape()
         self.reducesum = P.ReduceSum(keep_dims=False)
         self.notequal = P.NotEqual()
@@ -200,8 +200,8 @@ class WideDeepModel(nn.Cell):
         self.concat = P.Concat(axis=1)
         self.cast = P.Cast()
         self.unique = P.Unique().shard(((1,),))
-        self.wide_gatherv2 = P.GatherV2()
-        self.deep_gatherv2 = P.GatherV2()
+        self.wide_gatherv2 = P.Gather()
+        self.deep_gatherv2 = P.Gather()
         if is_auto_parallel and sparse and not is_field_slice and not parameter_server:
             target = 'DEVICE'
             if host_device_mix:
@@ -252,7 +252,7 @@ class WideDeepModel(nn.Cell):
                                           convert_dtype=True,
                                           use_activation=False)

-        self.gather_v2 = P.GatherV2()
+        self.gather_v2 = P.Gather()
         self.mul = P.Mul()
         self.reduce_sum_false = P.ReduceSum(keep_dims=False)
         self.reduce_sum_true = P.ReduceSum(keep_dims=True)
@@ -30,7 +30,7 @@ class CriterionsFaceAttri(nn.Cell):
         super(CriterionsFaceAttri, self).__init__()

         # label
-        self.gatherv2 = P.GatherV2()
+        self.gatherv2 = P.Gather()
         self.squeeze = P.Squeeze(axis=1)
         self.cast = P.Cast()
         self.reshape = P.Reshape()
@@ -71,7 +71,7 @@ class CriterionsFaceQA(nn.Cell):
     '''CriterionsFaceQA'''
     def __init__(self):
         super(CriterionsFaceQA, self).__init__()
-        self.gatherv2 = P.GatherV2()
+        self.gatherv2 = P.Gather()
         self.squeeze = P.Squeeze(axis=1)
         self.shape = P.Shape()
         self.reshape = P.Reshape()
@@ -30,7 +30,7 @@ class ComputeRij(nn.Cell):
         self.broadcastto1 = P.BroadcastTo((1, 192, 138, 3))
         self.expdims = P.ExpandDims()
         self.concat = P.Concat(axis=1)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.mul = P.Mul()
         self.slice = P.Slice()

@@ -89,7 +89,7 @@ class ComputeDescriptor(nn.Cell):

         self.expdims = P.ExpandDims()
         self.concat = P.Concat(axis=3)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.mul = P.Mul()
         self.slice = P.Slice()
         self.square = P.Square()
@@ -89,7 +89,7 @@ class GatherV2Quant(nn.Cell):

     def __init__(self, activation_init=6):
         super(GatherV2Quant, self).__init__()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()

         self.fake_quant_input = FakeQuantWithMinMax(min_init=-activation_init, max_init=activation_init, ema=True,
                                                     symmetric=False)
@@ -309,7 +309,7 @@ class EmbeddingLookup(nn.Cell):
         if do_quant:
             self.gather = GatherV2Quant(activation_init=activation_init)
         else:
-            self.gather = P.GatherV2()
+            self.gather = P.Gather()
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
@@ -376,7 +376,7 @@ class EmbeddingPostprocessor(nn.Cell):
         self.shape = tuple(embedding_shape)
         self.layernorm = nn.LayerNorm((embedding_size,))
         self.dropout = nn.Dropout(1 - dropout_prob)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.use_relative_positions = use_relative_positions
         self.slice = P.StridedSlice()
         self.full_position_embeddings = Parameter(initializer
@@ -532,7 +532,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
         self.shape = P.Shape()
-        self.gather = P.GatherV2()  # index_select
+        self.gather = P.Gather()  # index_select
         self.matmul = P.BatchMatMul()

     def construct(self):
@@ -215,7 +215,7 @@ class AutoDisModel(nn.Cell):
         self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
                                         self.weight_bias_init, self.deep_layer_act, self.keep_prob)
         # FM, linear Layers
-        self.Gatherv2 = P.GatherV2()
+        self.Gatherv2 = P.Gather()
         self.Mul = P.Mul()
         self.ReduceSum = P.ReduceSum(keep_dims=False)
         self.Reshape = P.Reshape()
@@ -135,7 +135,7 @@ class NetWithSparseGatherV2(nn.Cell):
             self.gather = P.SparseGatherV2()
         else:
             self.weight = Parameter(Tensor(np.ones([8, 8]).astype(np.float32)), name="weight")
-            self.gather = P.GatherV2()
+            self.gather = P.Gather()
         if strategy is not None:
             self.gather.shard(strategy)

@@ -24,4 +24,5 @@ import pytest
 def test_allreduce_sparsegatherv2_adam_auto_parallel():
     sh_path = os.path.split(os.path.realpath(__file__))[0]
     ret = os.system(f"sh {sh_path}/run_hcom_sparsetensor.sh")
+    os.system(f"grep -E 'ERROR|error' {sh_path}/hcom_sparsetensor*/test_hcom_sparsetensor_8p_log* -C 3")
     assert ret == 0
@@ -223,7 +223,7 @@ class DeepFMModel(nn.Cell):
         self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init,
                                         self.deep_layer_act, self.keep_prob, convert_dtype=True, use_act=False)
         " FM, linear Layers "
-        self.Gatherv2 = P.GatherV2()
+        self.Gatherv2 = P.Gather()
         self.Mul = P.Mul()
         self.ReduceSum = P.ReduceSum(keep_dims=False)
         self.Reshape = P.Reshape()
@@ -53,8 +53,8 @@ def init_var_dict(init_args, in_vars):
     '''
     var_map = {}
     _, _max_val = init_args
-    for _, iterm in enumerate(in_vars):
-        key, shape, method = iterm
+    for _, item in enumerate(in_vars):
+        key, shape, method = item
         if key not in var_map.keys():
             if method in ['random', 'uniform']:
                 var_map[key] = Parameter(initializer(
@@ -176,8 +176,8 @@ class WideDeepModel(nn.Cell):
                                        self.weight_bias_init,
                                        self.deep_layer_act, convert_dtype=True)

-        self.gather_v2 = P.GatherV2().shard(((1, 8), (1, 1)))
-        self.gather_v2_1 = P.GatherV2()
+        self.gather_v2 = P.Gather().shard(((1, 8), (1, 1)))
+        self.gather_v2_1 = P.Gather()
         self.mul = P.Mul()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.reshape = P.Reshape()
@@ -74,7 +74,7 @@ class GetMaskedLMOutput(nn.Cell):
         super(GetMaskedLMOutput, self).__init__()
         self.width = config.hidden_size
         self.reshape = P.Reshape()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()

         weight_init = TruncatedNormal(config.initializer_range)
         self.dense = nn.Dense(self.width,
@@ -127,7 +127,7 @@ class EmbeddingLookup(nn.Cell):
                                          name='embedding_table')
         self.expand = P.ExpandDims()
         self.shape_flat = (-1,)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
         self.off_value = Tensor(0.0, mstype.float32)
@@ -194,7 +194,7 @@ class EmbeddingPostprocessor(nn.Cell):
         self.shape = tuple(embedding_shape)
         self.layernorm = nn.LayerNorm((embedding_size,))
         self.dropout = nn.Dropout(1 - dropout_prob)
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.use_relative_positions = use_relative_positions
         self.slice = P.StridedSlice()
         self.full_position_embeddings = Parameter(initializer
@@ -333,7 +333,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.reshape = P.Reshape()
         self.one_hot = nn.OneHot(depth=self.vocab_size)
         self.shape = P.Shape()
-        self.gather = P.GatherV2()  # index_select
+        self.gather = P.Gather()  # index_select
         self.matmul = P.BatchMatMul()

     def construct(self):
@@ -200,7 +200,7 @@ class Conv2d_Thor(_Conv):
         self.device_shape_pad_flag = True
         self.device_shape_pad = P.Pad(((0, 0), (0, C0 - self.in_channels), (0, 0), (0, C0 - self.in_channels)))
         self.slice = P.Slice()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.freq = Tensor(frequency, mstype.int32)
         self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
         self.axis = 0
@@ -383,7 +383,7 @@ class Dense_Thor(Cell):
         self.pad = P.Pad(((0, 24), (0, 24)))
         self.pad1 = P.Pad(((0, 8), (0, 8)))
         self.slice = P.Slice()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.assignadd = P.AssignAdd()
         self.freq = Tensor(frequency, mstype.int32)
         self.axis = 0
@@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
 class NetGatherV2_axis0(nn.Cell):
     def __init__(self):
         super(NetGatherV2_axis0, self).__init__()
-        self.gatherv2 = P.GatherV2()
+        self.gatherv2 = P.Gather()

     def construct(self, params, indices):
         return self.gatherv2(params, indices, 0)
@@ -52,7 +52,7 @@ def test_gatherv2_axis0():
 class NetGatherV2_axis1(nn.Cell):
     def __init__(self):
         super(NetGatherV2_axis1, self).__init__()
-        self.gatherv2 = P.GatherV2()
+        self.gatherv2 = P.Gather()

     def construct(self, params, indices):
         return self.gatherv2(params, indices, 1)
@@ -78,7 +78,7 @@ def test_gatherv2_axis1():
 class NetGatherV2_axisN1(nn.Cell):
     def __init__(self):
         super(NetGatherV2_axisN1, self).__init__()
-        self.gatherv2 = P.GatherV2()
+        self.gatherv2 = P.Gather()

     def construct(self, params, indices):
         return self.gatherv2(params, indices, -1)
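Editor's note: an illustrative numpy analogue of the three test nets above, where np.take plays the role of Gather for a given axis.

import numpy as np

params = np.arange(12).reshape(3, 4)
indices = np.array([0, 2])
print(np.take(params, indices, axis=0))   # NetGatherV2_axis0
print(np.take(params, indices, axis=1))   # NetGatherV2_axis1
print(np.take(params, indices, axis=-1))  # NetGatherV2_axisN1 (same as axis=1 here)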
@@ -26,7 +26,7 @@ from mindspore.ops import operations as P
 class GatherNet(nn.Cell):
     def __init__(self):
         super(GatherNet, self).__init__()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()

     def construct(self, x, indices):
         return self.gather(x, indices, 1)
@@ -850,7 +850,7 @@ def test_gather0():
 class GatherNet1(nn.Cell):
     def __init__(self):
         super(GatherNet1, self).__init__()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()

     def construct(self, x, indices):
         return self.gather(x, indices, -1)
@@ -904,7 +904,7 @@ def test_gather1():
 class GatherNet2(nn.Cell):
     def __init__(self):
         super(GatherNet2, self).__init__()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()

     def construct(self, x, indices):
         return self.gather(x, indices, 0)
@@ -944,7 +944,7 @@ def test_gather2():
 class GatherNetDynamic(nn.Cell):
     def __init__(self, axis=0, dyn_a=True, dyn_b=True):
         super(GatherNetDynamic, self).__init__()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
         self.to_dyn_1 = dyn_a
         self.to_dyn_2 = dyn_b
@@ -367,7 +367,7 @@ TEST_F(TestConvert, TestConcat) {
 }

 TEST_F(TestConvert, TestGatherV2) {
-  auto prim = prim::kPrimGatherV2;
+  auto prim = prim::kPrimGather;

   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 3);
   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
@@ -27,7 +27,7 @@ from mindspore.nn import ReLU
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
 from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
-from mindspore.ops.operations.array_ops import GatherV2
+from mindspore.ops.operations.array_ops import Gather
 import mindspore

 # pylint: disable=W0212
@@ -130,7 +130,7 @@ class AllSwapNet(nn.Cell):
         part_slice = batch_size / 2
         self.send_size = Tensor([0, part_slice*out_channel, part_slice*out_channel], mindspore.int64)
         self.recv_size = Tensor([part_slice*out_channel, part_slice*out_channel, 0], mindspore.int64)
-        self.gatherv2 = GatherV2()
+        self.gatherv2 = Gather()
         self.input = Tensor(np.ones([1]), mindspore.int32)
     def construct(self, x):
         x = self.allswap(x, self.send_size, self.recv_size)
@@ -143,7 +143,7 @@ class DeepFMOpNet(nn.Cell):
     """Net definition with Gatherv2 and Tile and Square."""
     def __init__(self):
         super(DeepFMOpNet, self).__init__()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.square = P.Square()
         self.tile = P.Tile()

@@ -97,7 +97,7 @@ def test_gatherv2():
         def __init__(self):
             super(Net, self).__init__()
             self.unq = P.Unique()
-            self.gather = P.GatherV2()
+            self.gather = P.Gather()
             self.yy = Tensor(np.ones([8], dtype=np.int32))

         def construct(self, x, y):
@@ -1766,37 +1766,37 @@ test_case_nn_ops = [
         'desc_inputs': [[2, 3, 4]],
         'desc_bprop': [[2, 3, 4], ([2, 3, 4], {'dtype': np.int32})]}),
     ('GatherV2_0', {
-        'block': P.GatherV2(),
+        'block': P.Gather(),
         'desc_const': [0],
         'desc_inputs': [[3, 1, 2], Tensor(np.array([0, 1]).astype(np.int32))],
         'desc_bprop': [[2, 1, 2]]}),
     ('GatherV2_1', {
-        'block': P.GatherV2(),
+        'block': P.Gather(),
         'desc_const': [2],
         'desc_inputs': [[3, 1, 3], Tensor(np.array([0, 1]).astype(np.int32))],
         'desc_bprop': [[3, 1, 2]]}),
     ('GatherV2_2', {
-        'block': P.GatherV2(),
+        'block': P.Gather(),
         'desc_const': [0],
         'desc_inputs': [[3, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))],
         'desc_bprop': [[3, 2, 1, 3]]}),
     ('GatherV2_3', {
-        'block': P.GatherV2(),
+        'block': P.Gather(),
         'desc_const': [2],
         'desc_inputs': [[3, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))],
         'desc_bprop': [[3, 1, 3, 2]]}),
     ('GatherV2_4', {
-        'block': P.GatherV2(),
+        'block': P.Gather(),
         'desc_const': [1],
         'desc_inputs': [[32, 5, 1024], Tensor(np.array([3]).astype(np.int32))],
         'desc_bprop': [[32, 1, 1024]]}),
     ('GatherV2_5', {
-        'block': P.GatherV2(),
+        'block': P.Gather(),
         'desc_const': [-1],
         'desc_inputs': [[3, 1, 3], Tensor(np.array([0, 1]).astype(np.int32))],
         'desc_bprop': [[3, 1, 2]]}),
     ('GatherV2_6', {
-        'block': P.GatherV2(),
+        'block': P.Gather(),
         'desc_const': [0],
         'desc_inputs': [[1152], Tensor(np.array(10).astype(np.int32))],
         'desc_bprop': [Tensor(np.array(10).astype(np.float32))]}),
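Editor's note: an illustrative statement of the shape rule these cases exercise: gathering along axis replaces that dimension of the input with the shape of indices.

def gather_out_shape(x_shape, ind_shape, axis):
    axis = axis % len(x_shape)
    return x_shape[:axis] + ind_shape + x_shape[axis + 1:]

assert gather_out_shape([3, 1, 2], [2], 0) == [2, 1, 2]        # GatherV2_0
assert gather_out_shape([3, 1, 3], [3, 2], 0) == [3, 2, 1, 3]  # GatherV2_2
assert gather_out_shape([3, 1, 3], [2], -1) == [3, 1, 2]       # GatherV2_5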
@@ -56,10 +56,10 @@ def test_unique_column_split():
         self.unique = P.Unique().shard(((1,),))
         self.relu = P.ReLU()
         self.mul = P.Mul()
-        self.embedding_lookp = P.GatherV2().shard(((1, 8), (1,)))
+        self.embedding_lookp = P.Gather().shard(((1, 8), (1,)))
         self.embedding_table = Parameter(initializer('normal', [2000, 128]),
                                          name='embedding_table')
-        self.gatherv2 = P.GatherV2().shard(((1, 8), (1,)))
+        self.gatherv2 = P.Gather().shard(((1, 8), (1,)))
         self.reshape = P.Reshape()
         self.matmul = P.MatMul()
         self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight")
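Editor's note: a hedged reading of the shard strategies above: the strategy tuple holds one entry per operator input. For Gather, ((1, 8), (1,)) keeps the [2000, 128] table whole along rows, splits its columns across 8 devices, and leaves the 1-D indices unpartitioned.

from mindspore.ops import operations as P

gather = P.Gather().shard(((1, 8), (1,)))  # (table strategy, indices strategy)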
@@ -90,10 +90,10 @@ def test_unique_row_split():
         self.unique = P.Unique().shard(((1,),))
         self.relu = P.ReLU()
         self.mul = P.Mul()
-        self.embedding_lookp = P.GatherV2().shard(((8, 1), (1,)))
+        self.embedding_lookp = P.Gather().shard(((8, 1), (1,)))
         self.embedding_table = Parameter(initializer('normal', [2000, 128]),
                                          name='embedding_table')
-        self.gatherv2 = P.GatherV2().shard(((1, 1), (1,)))
+        self.gatherv2 = P.Gather().shard(((1, 1), (1,)))
         self.reshape = P.Reshape()
         self.matmul = P.MatMul()
         self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight")
@@ -51,7 +51,7 @@ class Net(nn.Cell):
         super().__init__()
         if shape is None:
             shape = [64, 64]
-        self.gatherv2 = P.GatherV2().shard(strategy1).add_prim_attr("primitive_target", target)
+        self.gatherv2 = P.Gather().shard(strategy1).add_prim_attr("primitive_target", target)
         self.mul = P.Mul().shard(strategy2)
         self.index = Tensor(np.ones(shape), dtype=ms.int32)
         self.axis = axis
@@ -79,7 +79,7 @@ class GatherV2(_Loss):
         emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), 16))
         self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
         self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
-        self.gatherv2 = P.GatherV2().shard(strategy).add_prim_attr("data_parallel", True)
+        self.gatherv2 = P.Gather().shard(strategy).add_prim_attr("data_parallel", True)

     def construct(self, nembeddings):
         emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)
@@ -208,7 +208,7 @@ class GatherV2Axis1(_Loss):
         emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), index_size))
         self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
         self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
-        self.gatherv2 = P.GatherV2().shard(strategy)
+        self.gatherv2 = P.Gather().shard(strategy)

     def construct(self, nembeddings):
         emb1 = self.gatherv2(nembeddings, self.emb1_param, 1)
@@ -33,7 +33,7 @@ class Net(Cell):
                  split_string="manual_split",
                  param_shape=(8, 8)):
         super().__init__()
-        self.gatherv2 = P.GatherV2().shard(strategy1)
+        self.gatherv2 = P.Gather().shard(strategy1)
         self.gatherv2.add_prim_attr(split_string, split_tuple)
         self.mul = P.Mul().shard(strategy2)
         self.reshape = P.Reshape()
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
 class Net(Cell):
     def __init__(self, matmul_weight, strategy1=None):
         super().__init__()
-        self.gatherv2 = P.GatherV2().shard(strategy1)
+        self.gatherv2 = P.Gather().shard(strategy1)
         self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
         self.matmul = P.MatMul(transpose_b=False)
         self.index = Tensor(np.ones([64, 64]), dtype=ms.int32)
@@ -32,7 +32,7 @@ class Net(nn.Cell):
         if strategy1:
             self.sampler.shard(strategy1)
         self.embedding_table = Parameter(embedding_weight, "embedding_weight")
-        self.gatherv2 = P.GatherV2()
+        self.gatherv2 = P.Gather()
         self.reduce_sum = P.ReduceSum()
         self.reduce_sum2 = P.ReduceSum()
         self.reduce_sum3 = P.ReduceSum()
@@ -261,7 +261,7 @@ class AssignWhenInsertGrad(nn.Cell):

     def __init__(self):
         super(AssignWhenInsertGrad, self).__init__()
-        self.gather = P.GatherV2()
+        self.gather = P.Gather()
         self.damping = Tensor(np.array([0.03, 0.03]).astype(np.float32))
         self.cov_step = ms.Parameter(0, name="cov_step", requires_grad=False)
         self.freq = Tensor(278, ms.int32)