!27037 【cpu】Move kernel register into cpp

Merge pull request !27037 from VectorSL/update-tensor-array
i-robot 2021-12-02 03:44:21 +00:00 committed by Gitee
commit becf381908
17 changed files with 251 additions and 265 deletions
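
Summary of the change: dtype registration for the TensorArray CPU kernels moves out of the Python op-info files and into the C++ kernel headers. Each kernel's single catch-all MS_REG_CPU_KERNEL(..., KernelAttr(), ...) entry is replaced by registrations whose KernelAttr chains spell out the supported input and output dtypes; the seven tensor_array_* Python registration modules and the imports that pulled them in are deleted; and an already commented-out FillZeros call is removed from the CPU and GPU TensorArray Write implementations.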

@@ -41,7 +41,8 @@ class TensorArrayCPUClearKernel : public CPUKernel {
std::vector<size_t> workspace_size_list_;
};
MS_REG_CPU_KERNEL(TensorArrayClear, KernelAttr(), TensorArrayCPUClearKernel)
MS_REG_CPU_KERNEL(TensorArrayClear, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUClearKernel);
} // namespace kernel
} // namespace mindspore
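
For readers who have not met MS_REG_CPU_KERNEL before, the sketch below shows the general shape of what such a registration macro usually expands to: a static object whose constructor, at program start-up, adds a kernel creator to a table keyed by op name and type signature. Every name in it (the TypeId values, Attr, Kernel, Creator, Registrar, Registry, DemoClearKernel) is invented for illustration; this is a generic sketch of the pattern, not MindSpore's actual factory code.

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Stand-ins for kNumberTypeInt64 and friends.
enum TypeId { kInt32, kInt64, kFloat32 };

// Stand-in for KernelAttr's builder interface.
struct Attr {
  std::vector<TypeId> inputs, outputs;
  Attr &AddInputAttr(TypeId t) { inputs.push_back(t); return *this; }
  Attr &AddOutputAttr(TypeId t) { outputs.push_back(t); return *this; }
};

struct Kernel { virtual ~Kernel() = default; };
using Creator = std::function<std::unique_ptr<Kernel>()>;

// Global table: op name -> list of (type signature, creator).
std::map<std::string, std::vector<std::pair<Attr, Creator>>> &Registry() {
  static std::map<std::string, std::vector<std::pair<Attr, Creator>>> table;
  return table;
}

// A static Registrar is what a MS_REG_CPU_KERNEL-style macro typically declares.
struct Registrar {
  Registrar(const std::string &op, Attr attr, Creator create) {
    Registry()[op].emplace_back(std::move(attr), std::move(create));
  }
};

// Equivalent in spirit to one of the TensorArrayClear registrations above.
struct DemoClearKernel : Kernel {};
static Registrar g_clear_reg("TensorArrayClear",
                             Attr().AddInputAttr(kInt64).AddOutputAttr(kInt64),
                             [] { return std::make_unique<DemoClearKernel>(); });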

@@ -41,7 +41,8 @@ class TensorArrayCPUCloseKernel : public CPUKernel {
std::vector<size_t> workspace_size_list_;
};
MS_REG_CPU_KERNEL(TensorArrayClose, KernelAttr(), TensorArrayCPUCloseKernel)
MS_REG_CPU_KERNEL(TensorArrayClose, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUCloseKernel);
} // namespace kernel
} // namespace mindspore

@@ -46,7 +46,7 @@ class TensorArrayCPUCreateKernel : public CPUKernel {
std::vector<size_t> workspace_size_list_;
};
MS_REG_CPU_KERNEL(TensorArray, KernelAttr(), TensorArrayCPUCreateKernel)
MS_REG_CPU_KERNEL(TensorArray, KernelAttr().AddOutputAttr(kNumberTypeInt64), TensorArrayCPUCreateKernel);
} // namespace kernel
} // namespace mindspore

@@ -45,8 +45,88 @@ class TensorArrayCPUReadKernel : public CPUKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
MS_REG_CPU_KERNEL(TensorArrayRead, KernelAttr(), TensorArrayCPUReadKernel)
// index int64
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
TensorArrayCPUReadKernel);
// index int32
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
TensorArrayCPUReadKernel);
MS_REG_CPU_KERNEL(
TensorArrayRead,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
TensorArrayCPUReadKernel);
} // namespace kernel
} // namespace mindspore
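
The twenty TensorArrayRead entries above spell out two index dtypes times ten value dtypes by hand. Purely as a hypothetical tidy-up, not part of this PR and valid only if MS_REG_CPU_KERNEL tolerates being expanded from another macro, the grid could be generated along these lines:

// Hypothetical helper: one MS_REG_CPU_KERNEL expansion per invocation, one
// invocation per line. INDEX_T is the index input dtype, VALUE_T the output dtype.
#define REG_TENSOR_ARRAY_READ(INDEX_T, VALUE_T)                                               \
  MS_REG_CPU_KERNEL(                                                                          \
    TensorArrayRead,                                                                          \
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(INDEX_T).AddOutputAttr(VALUE_T), \
    TensorArrayCPUReadKernel)

REG_TENSOR_ARRAY_READ(kNumberTypeInt64, kNumberTypeFloat32);
REG_TENSOR_ARRAY_READ(kNumberTypeInt32, kNumberTypeFloat32);
// ... one invocation per remaining (index, value) combination
#undef REG_TENSOR_ARRAY_READ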

@@ -41,7 +41,8 @@ class TensorArrayCPUSizeKernel : public CPUKernel {
std::vector<size_t> workspace_size_list_;
};
MS_REG_CPU_KERNEL(TensorArraySize, KernelAttr(), TensorArrayCPUSizeKernel)
MS_REG_CPU_KERNEL(TensorArraySize, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUSizeKernel);
} // namespace kernel
} // namespace mindspore

@@ -53,7 +53,26 @@ class TensorArrayCPUStackKernel : public CPUKernel {
std::vector<size_t> workspace_size_list_;
};
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr(), TensorArrayCPUStackKernel)
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUStackKernel);
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
TensorArrayCPUStackKernel);
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
TensorArrayCPUStackKernel);
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
TensorArrayCPUStackKernel);
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
TensorArrayCPUStackKernel);
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
TensorArrayCPUStackKernel);
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
TensorArrayCPUStackKernel);
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
TensorArrayCPUStackKernel);
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
TensorArrayCPUStackKernel);
MS_REG_CPU_KERNEL(TensorArrayStack, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
TensorArrayCPUStackKernel);
} // namespace kernel
} // namespace mindspore

@@ -45,7 +45,148 @@ class TensorArrayCPUWriteKernel : public CPUKernel {
std::vector<size_t> workspace_size_list_;
};
MS_REG_CPU_KERNEL(TensorArrayWrite, KernelAttr(), TensorArrayCPUWriteKernel)
// index int64
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
// index int32
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
MS_REG_CPU_KERNEL(TensorArrayWrite,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeInt64),
TensorArrayCPUWriteKernel);
} // namespace kernel
} // namespace mindspore
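
The TensorArrayWrite block above repeats the same two-by-ten grid, with the value dtype as a third input and kNumberTypeInt64 as the output in every entry; the hypothetical helper sketched after the TensorArrayRead registrations would cover it by taking the value dtype as an extra parameter.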

@@ -49,7 +49,6 @@ bool CPUTensorArray::Write(const int64_t index, const mindspore::kernel::Address
tensors_.push_back(create_dev);
}
tensors_.push_back(dev_value);
// FillZeros(valid_size_, index);
for (size_t i = valid_size_; i < LongToSize(index); i++) {
auto tensor_size = tensors_[i]->size;
(void)memset_s(tensors_[i]->addr, tensor_size, 0, tensors_[i]->size);
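
Judging by the hunk counts (-49,7 +49,6 here and -51,7 +51,6 in the GPU file below), the only change in these two functions is the removal of the stale commented-out FillZeros call; the loop that follows already does the zero-fill inline, with memset_s on CPU and cudaMemsetAsync on GPU. A minimal, self-contained sketch of that behaviour, with Slot and ZeroFillGap as illustrative stand-ins rather than the real types:

#include <cstddef>
#include <cstring>
#include <vector>

// Stand-in for the (addr, size) device buffers held in tensors_ above.
struct Slot {
  void *addr;
  std::size_t size;
};

// Writing at 'index' past the currently valid region leaves slots that were
// never written; they are filled with zeros so a later read of such a slot
// yields zeros, mirroring the loop over tensors_[valid_size_ .. index) above.
void ZeroFillGap(std::vector<Slot> *tensors, std::size_t valid_size, std::size_t index) {
  for (std::size_t i = valid_size; i < index; ++i) {
    std::memset((*tensors)[i].addr, 0, (*tensors)[i].size);
  }
}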

@@ -51,7 +51,6 @@ bool GPUTensorArray::Write(const int64_t index, const mindspore::kernel::Address
tensors_.push_back(create_dev);
}
tensors_.push_back(dev_value);
// FillZeros(valid_size_, index);
for (size_t i = valid_size_; i < LongToSize(index); i++) {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(tensors_[i]->addr, 0, tensors_[i]->size),
"failed to set cuda memory with zeros.")

@@ -74,10 +74,3 @@ from .pyfunc import _pyfunc_cpu
from .buffer_append import _buffer_append_cpu
from .buffer_get import _buffer_get_cpu
from .buffer_sample import _buffer_sample_cpu
from .tensor_array_clear import _tensor_array_clear_cpu
from .tensor_array_close import _tensor_array_close_cpu
from .tensor_array_create import _tensor_array_create_cpu
from .tensor_array_read import _tensor_array_read_cpu
from .tensor_array_size import _tensor_array_size_cpu
from .tensor_array_stack import _tensor_array_stack_cpu
from .tensor_array_write import _tensor_array_write_cpu

@@ -1,29 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorArrayClear op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
tensor_array_clear_op_info = CpuRegOp("TensorArrayClear") \
.input(0, "handle", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(tensor_array_clear_op_info)
def _tensor_array_clear_cpu():
"""TensorArrayClear cpu register"""
return

@@ -1,29 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorArrayClose op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
tensor_array_close_op_info = CpuRegOp("TensorArrayClose") \
.input(0, "handle", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(tensor_array_close_op_info)
def _tensor_array_close_cpu():
"""TensorArrayClose cpu register"""
return

@@ -1,28 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorArrayCreate op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
tensor_array_create_op_info = CpuRegOp("TensorArray") \
.output(0, "handle", "required") \
.dtype_format(DataType.I64_Default) \
.get_op_info()
@op_info_register(tensor_array_create_op_info)
def _tensor_array_create_cpu():
"""TensorArrayCreate cpu register"""
return

@@ -1,48 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorArrayRead op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
tensor_array_read_op_info = CpuRegOp("TensorArrayRead") \
.input(0, "handle", "required") \
.input(1, "index", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U16_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.U64_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(tensor_array_read_op_info)
def _tensor_array_read_cpu():
"""TensorArrayRead cpu register"""
return

@@ -1,28 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorArraySize op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
tensor_array_size_op_info = CpuRegOp("TensorArraySize") \
.input(0, "handle", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(tensor_array_size_op_info)
def _tensor_array_size_cpu():
"""TensorArraySize cpu register"""
return

@@ -1,37 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorArrayStack op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
tensor_array_stack_op_info = CpuRegOp("TensorArrayStack") \
.input(0, "handle", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.I64_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_Default, DataType.U16_Default) \
.dtype_format(DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.I64_Default, DataType.U64_Default) \
.dtype_format(DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(tensor_array_stack_op_info)
def _tensor_array_stack_cpu():
"""TensorArrayStack cpu register"""
return

@@ -1,49 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorArrayWrite op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
tensor_array_write_op_info = CpuRegOp("TensorArrayWrite") \
.input(0, "handle", "required") \
.input(1, "index", "required") \
.input(2, "value", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I16_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U32_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U16_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U8_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.U64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F16_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I16_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.U32_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.U16_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.U8_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.U64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.F16_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.F32_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(tensor_array_write_op_info)
def _tensor_array_write_cpu():
"""TensorArrayWrite cpu register"""
return