diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc index 69c3bbec874..138dedeb668 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc @@ -856,5 +856,78 @@ bool GetShapeSize(const std::vector &shape, const TypePtr &type_ptr, int size_i[0] = LongMulWithOverflowCheck(size_i[0], SizeToInt(type_byte)); return true; } + +void CastShapeSizeToLong(const std::vector &shape, std::vector *long_shape) { + MS_EXCEPTION_IF_NULL(long_shape); + std::transform(shape.begin(), shape.end(), std::back_inserter(*long_shape), SizeToLong); +} + +void CheckSliceValid(const std::vector &start, const std::vector &stop, + const std::vector &step, const std::vector &input_shape) { + if (start.empty() || start.size() != stop.size() || start.size() != step.size() || start.size() > input_shape.size()) { // an empty slice would make GetCopySize index dim_offset[size - 1] out of range + MS_LOG(EXCEPTION) + << "TensorCopySlices requires the length of begin, stride and end must be equal, non-zero and no greater than input dimension."; + } + + size_t size = start.size(); + for (size_t i = 0; i < size; ++i) { + if (start[i] < 0 || stop[i] <= start[i] || stop[i] > input_shape[i]) { // bounds must lie inside dimension i, otherwise the computed offset/copy size leaves the buffer + MS_LOG(EXCEPTION) << "Invalid slice: (" << start[i] << ", " << stop[i] << " ," << step[i] << ")"; + } + // Operator need to be generalized in the future. Only support to copy continuous memory now. + if (step[i] != 1) { + MS_LOG(EXCEPTION) << "The element in step only support 1, but got:" << step; + } + } + + size_t slice_pos = size; + for (size_t i = 0; i < size; ++i) { + if (stop[i] - start[i] > 1) { + slice_pos = i; + break; + } + } + + for (size_t i = slice_pos + 1; i < size; ++i) { + if (stop[i] - start[i] != input_shape[i]) { + MS_LOG(EXCEPTION) << "Only support copy continuous memory now. 
For example tensor[0, 0:100] is fine, " + "but tensor[0:100, 0] is not supported."; + } + } +} + +size_t GetCopySize(const std::vector &dim_offset, const std::vector &start, + const std::vector &stop) { + for (size_t i = 0; i < start.size(); ++i) { + if (stop[i] - start[i] != 1) { + return (stop[i] - start[i]) * dim_offset[i]; + } + } + return dim_offset[start.size() - 1]; +} + +std::vector CalDimOffset(const std::vector &input_shape) { + std::vector dim_offset; + int64_t offset = 1; + for (auto iter = input_shape.rbegin(); iter != input_shape.rend(); ++iter) { + dim_offset.push_back(offset); + offset = offset * (*iter); + } + std::reverse(dim_offset.begin(), dim_offset.end()); + return dim_offset; +} + +size_t CalOffset(const std::vector &start, const std::vector &stop, const std::vector &step, + const std::vector &dim_offset) { + size_t size = start.size(); + size_t offset = 0; + for (size_t i = 0; i < size; ++i) { + offset += dim_offset[i] * start[i]; + if (stop[i] - start[i] != 1) { + break; + } + } + return offset; +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h index 07a7fa071aa..49fdc72df0e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h @@ -133,6 +133,15 @@ inline T ComputeLerp(T top_left, T top_right, T bottom_left, T bottom_right, T x T bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; return top + (bottom - top) * y_lerp; } + +void CastShapeSizeToLong(const std::vector &shape, std::vector *long_shape); +void CheckSliceValid(const std::vector &start, const std::vector &stop, + const std::vector &step, const std::vector &input_shape); +size_t CalOffset(const std::vector &start, const std::vector &stop, const std::vector &step, + const std::vector &dim_offset); +std::vector CalDimOffset(const std::vector &input_shape); +size_t 
GetCopySize(const std::vector &dim_offset, const std::vector &start, + const std::vector &stop); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/tensor_copy_slices_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/tensor_copy_slices_cpu_kernel.cc new file mode 100644 index 00000000000..2d6a02ce504 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/tensor_copy_slices_cpu_kernel.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/cpu/tensor_copy_slices_cpu_kernel.h" + +#include +#include +#include "abstract/utils.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void TensorCopySlicesCPUKernel::InitKernel(const CNodePtr &kernel_node) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + CastShapeSizeToLong(input_shape, &input_shape_); + CastShapeSizeToLong(update_shape, &update_shape_); + CastShapeSizeToLong(output_shape, &output_shape_); + + auto begin = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); + auto end = AnfAlgo::GetNodeAttr>(kernel_node, END); + auto stride = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); + CheckSliceValid(begin, end, stride, input_shape_); + + data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + auto dim_offset = CalDimOffset(input_shape_); + auto type_size = abstract::TypeIdSize(data_type_); + offset_ = CalOffset(begin, end, stride, dim_offset) * type_size; + copy_size_ = GetCopySize(dim_offset, begin, end) * type_size; +} + +bool TensorCopySlicesCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() != 2 || outputs.size() != 1) { + MS_LOG(ERROR) << "TensorCopySlices requires 2 inputs and 1 output, but got " << inputs.size() << " input and " + << outputs.size() << " output."; + return false; + } + + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto update_addr = reinterpret_cast(inputs[1]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + if (memcpy_s(output_addr, outputs[0]->size, input_addr, inputs[0]->size) != EOK) { + MS_LOG(EXCEPTION) << "TensorCopySlices memcpy input failed"; + } + if (memcpy_s(output_addr + offset_, outputs[0]->size - offset_, 
update_addr, copy_size_) != EOK) { // destMax is the space left after offset_, so an oversized slice is rejected instead of overrunning the output + MS_LOG(EXCEPTION) << "TensorCopySlices memcpy update failed"; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/tensor_copy_slices_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/tensor_copy_slices_cpu_kernel.h new file mode 100644 index 00000000000..3a1e2c41b75 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/tensor_copy_slices_cpu_kernel.h @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TENSOR_COPY_SLICES_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TENSOR_COPY_SLICES_CPU_KERNEL_H_ + +#include +#include + +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "nnacl/fp32/strided_slice_fp32.h" + +namespace mindspore { +namespace kernel { +class TensorCopySlicesCPUKernel : public CPUKernel { + public: + TensorCopySlicesCPUKernel() = default; + ~TensorCopySlicesCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + TypeId data_type_; + size_t offset_; + size_t copy_size_; + std::vector input_shape_; + std::vector update_shape_; + std::vector output_shape_; +}; + +MS_REG_CPU_KERNEL(TensorCopySlices, KernelAttr(), TensorCopySlicesCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TENSOR_COPY_SLICES_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_copy_slices_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_copy_slices_gpu_kernel.cc new file mode 100644 index 00000000000..87e036936b4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_copy_slices_gpu_kernel.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/tensor_copy_slices_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + TensorCopySlices, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + TensorCopySlicesGpuKernel, double) +MS_REG_GPU_KERNEL_ONE( + TensorCopySlices, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + TensorCopySlicesGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + TensorCopySlices, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + TensorCopySlicesGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + TensorCopySlices, + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + TensorCopySlicesGpuKernel, int64_t) +MS_REG_GPU_KERNEL_ONE( + TensorCopySlices, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + TensorCopySlicesGpuKernel, int) +MS_REG_GPU_KERNEL_ONE( + TensorCopySlices, + KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + TensorCopySlicesGpuKernel, char) +MS_REG_GPU_KERNEL_ONE( + TensorCopySlices, + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + TensorCopySlicesGpuKernel, uchar) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_copy_slices_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_copy_slices_gpu_kernel.h new file mode 100644 index 00000000000..feed60dbbf7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_copy_slices_gpu_kernel.h @@ -0,0 +1,140 @@ +/** + * 
Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_STRIDE_UPDATE_GPU_KERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_STRIDE_UPDATE_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { +template +class TensorCopySlicesGpuKernel : public GpuKernel { + public: + TensorCopySlicesGpuKernel() : input_size_(0), update_size_(0), output_size_(0), offset_(0), copy_size_(0) {} + ~TensorCopySlicesGpuKernel() {} + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + T *update_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + + if (inputs[1]->size != copy_size_) { + MS_LOG(EXCEPTION) << "Invalid update size:" << inputs[1]->size << " copy_size_:" << copy_size_; + } + + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "TensorCopySlices cudaMemcpyAsync outputs failed"); + + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(output_addr + 
offset_, update_addr, inputs[1]->size, + cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), + "TensorCopySlices cudaMemcpyAsync outputs failed"); + + return true; + } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Init(const CNodePtr &kernel_node) override { + kernel_node_ = kernel_node; + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but TensorCopySlices needs 2 inputs."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but TensorCopySlices has 1 output."; + return false; + } + + auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto update_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + CastShapeSizeToLong(input_shapes, &input_shapes_); + CastShapeSizeToLong(update_shapes, &update_shapes_); + CastShapeSizeToLong(output_shapes, &output_shapes_); + + GetSize(); + InitSizeLists(); + + auto begin = GetAttr>(kernel_node, kAttrBegin); + auto end = GetAttr>(kernel_node, kAttrEnd); + auto strides = GetAttr>(kernel_node, kAttrStrides); + + CheckSliceValid(begin, end, strides, input_shapes_); + auto dim_offset = CalDimOffset(input_shapes_); + offset_ = CalOffset(begin, end, strides, dim_offset); + copy_size_ = GetCopySize(dim_offset, begin, end) * sizeof(T); + return true; + } + + protected: + void GetSize() { + input_size_ = sizeof(T); + for (size_t i = 0; i < input_shapes_.size(); i++) { + input_size_ *= LongToSize(input_shapes_[i]); + } + + update_size_ = sizeof(T); + for (size_t i = 0; i < update_shapes_.size(); 
i++) { + update_size_ *= LongToSize(update_shapes_[i]); + } + output_size_ = sizeof(T); + for (size_t i = 0; i < output_shapes_.size(); i++) { + output_size_ *= LongToSize(output_shapes_[i]); + } + } + + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + input_size_list_.push_back(update_size_); + output_size_list_.push_back(output_size_); + return; + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + std::vector input_shapes_; + std::vector update_shapes_; + std::vector output_shapes_; + + size_t input_size_; + size_t update_size_; + size_t output_size_; + + size_t offset_; + size_t copy_size_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_STRIDE_UPDATE_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc index 63df97b76d5..a7c0c927bc6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc @@ -154,10 +154,10 @@ device::DynamicKernelPtr MemCpyAsyncKernel::GenDynamicKernel(const CNodePtr &cno kernel_inputs[0]->size); } -const std::vector data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, - kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16, - kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat, kNumberTypeFloat16, - kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool}; +const std::vector data_type_list = { + kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, + kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64, + kNumberTypeFloat, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool}; const std::vector format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, 
kOpFormat_NHWC, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_C1HWNCoC0}; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/tensor_copy_slices.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/tensor_copy_slices.cc new file mode 100644 index 00000000000..c4f606f9221 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/tensor_copy_slices.cc @@ -0,0 +1,187 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/rts/tensor_copy_slices.h" +#include +#include +#include +#include +#include "abstract/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/common_utils.h" +#include "common/trans.h" +#include "runtime/mem.h" +#include "runtime/device/kernel_runtime.h" +#include "utils/ms_context.h" + +using mindspore::ge::model_runner::MemcpyAsyncTaskInfo; +namespace mindspore { +namespace kernel { +TensorCopySlices::TensorCopySlices() {} + +TensorCopySlices::~TensorCopySlices() {} + +bool TensorCopySlices::Launch(const std::vector &inputs, const std::vector & /*workspace*/, + const std::vector &outputs, void *stream_ptr) { + if (inputs.size() != 2) { + MS_LOG(ERROR) << "inputs size is not 2"; + return false; + } + if (outputs.size() != 1) { + MS_LOG(ERROR) << "outputs size is not 1"; + return false; + } + if (outputs[0]->size != inputs[0]->size) { + MS_LOG(ERROR) << "TensorCopySlices input size and output size not equal."; + return false; + } + + auto status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, + RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr); + if (status != RT_ERROR_NONE) { + MS_LOG(ERROR) << "TensorCopySlices rtMemcpyAsync input failed!"; + return false; + } + status = rtMemcpyAsync(VoidPointerOffset(outputs[0]->addr, offset_), copy_size_, inputs[1]->addr, copy_size_, + RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr); + if (status != RT_ERROR_NONE) { + MS_LOG(ERROR) << "TensorCopySlices rtMemcpyAsync update failed!"; + return false; + } + return true; +} + +bool TensorCopySlices::Init(const mindspore::AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + GetInputOutputInfo(anf_node); + GetInputOutputTotalCount(anf_node); + + auto begin = AnfAlgo::GetNodeAttr>(anf_node, kAttrBegin); + auto end = AnfAlgo::GetNodeAttr>(anf_node, kAttrEnd); + auto strides = AnfAlgo::GetNodeAttr>(anf_node, kAttrStrides); + + CheckSliceValid(begin, end, strides, input_shape_); + auto dim_offset = 
CalDimOffset(input_shape_); + offset_ = CalOffset(begin, end, strides, dim_offset) * abstract::TypeIdSize(input_type_id_); + copy_size_ = GetCopySize(dim_offset, begin, end) * abstract::TypeIdSize(input_type_id_); + return true; +} + +void TensorCopySlices::GetInputOutputInfo(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + size_t input_size = AnfAlgo::GetInputTensorNum(anf_node); + if (input_size != 2) { + MS_LOG(EXCEPTION) << "TensorCopySlices input size is not 2"; + } + input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0); + update_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 1); // the update tensor is input 1; index 0 made the dtype check below vacuous + output_type_id_ = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); + if (input_type_id_ != output_type_id_ || input_type_id_ != update_type_id_) { + MS_LOG(EXCEPTION) << "Input and output of TensorCopySlices is not same, input type:" << input_type_id_ + << " output_type_id_:" << output_type_id_ << " update_type_id_:" << update_type_id_; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, 0); + auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, 1); + auto output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, 0); + CastShapeSizeToLong(input_shape, &input_shape_); + CastShapeSizeToLong(update_shape, &update_shape_); + CastShapeSizeToLong(output_shape, &output_shape_); +} + +void *TensorCopySlices::VoidPointerOffset(void *ptr, size_t offset) { + return reinterpret_cast(ptr) + offset; +} + +void TensorCopySlices::GetInputOutputTotalCount(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + size_t input_size = AnfAlgo::GetInputTensorNum(anf_node); + if (input_size != 2) { + MS_LOG(EXCEPTION) << "TensorCopySlices input size is not 2"; + } + + auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, 0); + size_t total_size = std::accumulate(input_shape.begin(), input_shape.end(), size_t{1}, std::multiplies<>()); // size_t seed: an int seed makes the running product an int and truncates large shapes + total_size *= abstract::TypeIdSize(input_type_id_); + MS_LOG(INFO) << "TensorCopySlices size[" << total_size << "]"; + // Shape 
and DType of input0 and output0 are same. + input_size_list_.emplace_back(total_size); + output_size_list_.emplace_back(total_size); + + auto update_shape = AnfAlgo::GetInputDeviceShape(anf_node, 1); + size_t update_size = std::accumulate(update_shape.begin(), update_shape.end(), size_t{1}, std::multiplies<>()); // size_t seed, same reason as total_size + update_size *= abstract::TypeIdSize(update_type_id_); + input_size_list_.emplace_back(update_size); +} + +std::vector TensorCopySlices::GenTask(const std::vector &inputs, + const std::vector &, + const std::vector &outputs, uint32_t stream_id) { + if (inputs.size() != 2) { + MS_LOG(EXCEPTION) << "inputs size is not 2."; + } + if (outputs.size() != 1) { + MS_LOG(EXCEPTION) << "outputs size is not 1."; + } + if (outputs[0]->size != inputs[0]->size) { + MS_LOG(EXCEPTION) << "TensorCopySlices input size and output size not equal."; + } + + stream_id_ = stream_id; + std::shared_ptr task_info_ptr1 = + std::make_shared(kernel_name_, stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, + inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE, NeedDump()); + std::shared_ptr task_info_ptr2 = std::make_shared( + kernel_name_, stream_id, VoidPointerOffset(outputs[0]->addr, offset_), copy_size_, inputs[1]->addr, copy_size_, + RT_MEMCPY_DEVICE_TO_DEVICE, NeedDump()); + return {task_info_ptr1, task_info_ptr2}; +} + +const std::vector data_type_list = { + kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, + kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64, + kNumberTypeFloat, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool}; +// If input's format is 5D, we will insert TransData before TensorCopySlices. 
+const std::vector format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC}; + +TensorCopySlicesDesc::TensorCopySlicesDesc() {} + +TensorCopySlicesDesc::~TensorCopySlicesDesc() {} + +// TensorCopySlices Register +std::vector> TensorCopySlicesDesc::GetKernelInfo() { + std::vector> tensor_copy_slices_build_info{}; + for (const auto &format : format_list) { + for (const auto &type : data_type_list) { + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + vector input_format{format, format}; + vector input_type{type, type}; + vector output_format{format}; + vector output_type{type}; + builder.SetInputsFormat(input_format); + builder.SetInputsDeviceType(input_type); + builder.SetOutputsFormat(output_format); + builder.SetOutputsDeviceType(output_type); + builder.SetProcessor(AICORE); + builder.SetKernelType(RT_KERNEL); + builder.SetFusionType(OPAQUE); + tensor_copy_slices_build_info.emplace_back(builder.Build()); + } + } + return tensor_copy_slices_build_info; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/tensor_copy_slices.h b/mindspore/ccsrc/backend/kernel_compiler/rts/tensor_copy_slices.h new file mode 100644 index 00000000000..0ebd9a0f254 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/tensor_copy_slices.h @@ -0,0 +1,66 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_TENSOR_COPY_SLICES_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_TENSOR_COPY_SLICES_H + +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class TensorCopySlices : public RtKernel { + public: + TensorCopySlices(); + ~TensorCopySlices() override; + + bool Init(const AnfNodePtr &anf_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + void GetInputOutputInfo(const AnfNodePtr &anf_node); + void GetInputOutputTotalCount(const AnfNodePtr &anf_node); + void *VoidPointerOffset(void *ptr, size_t offset); + + std::vector input_shape_; + std::vector update_shape_; + std::vector output_shape_; + TypeId input_type_id_{}; + TypeId output_type_id_{}; + TypeId update_type_id_{}; + + size_t offset_{0}; + size_t copy_size_{0}; +}; + +class TensorCopySlicesDesc : public RtKerDesc { + public: + TensorCopySlicesDesc(); + ~TensorCopySlicesDesc() override; + std::vector> GetKernelInfo() override; +}; + +MS_REG_RTKERNEL_DESC(tensorcopyslices, TensorCopySlicesDesc); +MS_REG_RTKERNEL(tensorcopyslices, TensorCopySlices); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_TENSOR_COPY_SLICES_H diff --git a/mindspore/ccsrc/backend/optimizer/common/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/common/const_input_to_attr_registry.cc index fe661ee46a0..01353b42aa1 100644 --- a/mindspore/ccsrc/backend/optimizer/common/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/backend/optimizer/common/const_input_to_attr_registry.cc @@ -66,6 +66,7 @@ 
ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(kStridedSliceAssignOpName, {1, 2, 3}); Register(kStridedSliceOpName, {1, 2, 3}); Register(kStridedSliceGradOpName, {1, 2, 3, 4}); + Register(kTensorCopySlicesOpName, {2, 3, 4}); Register(kFlattenGradOpName, {1}); Register(kExpandDimsOpName, {1}); Register(kSplitOpName, {0}); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 4d6838a0327..a78a875325e 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -215,6 +215,7 @@ constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp"; constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta"; constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad"; constexpr auto kTensorMoveOpName = "TensorMove"; +constexpr auto kTensorCopySlicesOpName = "TensorCopySlices"; constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate"; constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate"; constexpr auto kPushOpName = "Push"; @@ -406,6 +407,7 @@ constexpr auto kAttrPaddings = "paddings"; constexpr auto kAttrNumSegments = "num_segments"; constexpr auto kAttrStackOpName = "stack_op_name"; constexpr auto kAttrBegin = "begin"; +constexpr auto kAttrEnd = "end"; constexpr auto kAttrSize = "size"; constexpr auto kAttrIsDynamicShape = "is_dynamic_shape"; constexpr auto kAttrInputIsDynamicShape = "input_is_dynamic_shape"; diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 41f3482cd46..5f06d31b8bf 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -279,6 +279,8 @@ AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const Primitive const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplTransData(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTensorCopySlices(const 
AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); template AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { // Inputs: a tuple or list or dict. diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 3d36528a584..2ebd4fb39cb 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -1214,5 +1214,14 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv return std::make_shared(infer_type, std::make_shared(out_shape, min_shape, max_shape)); } + +AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + auto &op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 5); + AbstractTensorPtr input = CheckArg(op_name, args_spec_list, 0); + return std::make_shared(input->element(), input->shape()); +} + } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 0986c497c7d..dc3525276ee 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -107,6 +107,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}}, {prim::kPrimSort, {InferImplSort, nullptr, true}}, {prim::kPrimMaskedSelect, {InferImplMaskedSelect, nullptr, true}}, + {prim::kPrimTensorCopySlices, {InferImplTensorCopySlices, nullptr, true}}, // Structure {prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}}, {prim::kPrimMakeList, {InferImplMakeList, nullptr, true}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 9dfd38cc9ee..9b5a3dcbd10 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ 
-212,6 +212,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared<Primitive>("Dynam inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("DynamicGRUV2Grad"); inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd"); inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate"); +inline const PrimitivePtr kPrimTensorCopySlices = std::make_shared<Primitive>("TensorCopySlices"); inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform"); inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split"); inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask"); diff --git a/mindspore/ops/_grad_experimental/__init__.py b/mindspore/ops/_grad_experimental/__init__.py index 5fabf43c0bf..b5cee4fb0bf 100644 --- a/mindspore/ops/_grad_experimental/__init__.py +++ b/mindspore/ops/_grad_experimental/__init__.py @@ -15,5 +15,6 @@ """grad experimental impl.""" from .._grad.grad_base import get_bprop_fn +from . import grad_inner_ops __all__ = ['get_bprop_fn'] diff --git a/mindspore/ops/_grad_experimental/grad_inner_ops.py b/mindspore/ops/_grad_experimental/grad_inner_ops.py new file mode 100644 index 00000000000..9ad74598b0f --- /dev/null +++ b/mindspore/ops/_grad_experimental/grad_inner_ops.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ + +"""inner_ops""" + +from .._grad.grad_base import bprop_getters +from ..operations import _inner_ops as inner +from .. import functional as F + +@bprop_getters.register(inner.TensorCopySlices) +def get_bprop_tensor_copy_slices(self): + """Generate bprop for TensorCopySlices""" + tensor_copy_slices = inner.TensorCopySlices() + + def bprop(x, update, begin, end, stride, out, dout): + x_grad = tensor_copy_slices(dout, F.zeros_like(update), begin, end, stride) + update_grad = F.strided_slice(dout, begin, end, stride) + return x_grad, update_grad, F.zeros_like(begin), F.zeros_like(end), F.zeros_like(stride) + + return bprop diff --git a/mindspore/ops/_op_impl/cpu/__init__.py b/mindspore/ops/_op_impl/cpu/__init__.py index 4031ccd1380..e7bf7936773 100644 --- a/mindspore/ops/_op_impl/cpu/__init__.py +++ b/mindspore/ops/_op_impl/cpu/__init__.py @@ -61,3 +61,4 @@ from .add import _add_cpu from .one_hot import _one_hot_cpu from .pad import _pad_cpu from .range import _range_cpu +from .tensor_copy_slices import _tensor_copy_slices_cpu diff --git a/mindspore/ops/_op_impl/cpu/tensor_copy_slices.py b/mindspore/ops/_op_impl/cpu/tensor_copy_slices.py new file mode 100644 index 00000000000..454c73d63fd --- /dev/null +++ b/mindspore/ops/_op_impl/cpu/tensor_copy_slices.py @@ -0,0 +1,41 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ + +"""TensorCopySlices op""" +from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType + +tensor_copy_slices_op_info = CpuRegOp("TensorCopySlices") \ + .input(0, "x", "required") \ + .input(1, "value", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ + .get_op_info() + + +@op_info_register(tensor_copy_slices_op_info) +def _tensor_copy_slices_cpu(): + """TensorCopySlices cpu register""" + return diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 2b8e76003d2..ce7e3fbe6fc 100755 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -1154,3 +1154,42 @@ class DynamicBroadcastGradientArgs(Primitive): @prim_attr_register def __init__(self): """Init BroadcastGradientArgs""" + + +class TensorCopySlices(Primitive): + """ + Copy continuous memory. + + Inputs: + - **x** (Tensor) - The target Tensor.
+ - **value** (Tensor) - The tensor to update x. + - **begin** (tuple[int]) - A tuple which represents the location where to start. Only + constant value is allowed. + - **end** (tuple[int]) - A tuple which represents the maximum location where to end. + Only constant value is allowed. + - **strides** (tuple[int]) - A tuple which represents the stride is continuously added + before reaching the maximum location. Only constant value is allowed. + + Outputs: + - **y** (Tensor), has the same shape and data type of x. + + Examples: + >>> import numpy as np + >>> from mindspore.ops.operations import _inner_ops + >>> copy_slices = _inner_ops.TensorCopySlices() + >>> out = copy_slices(Tensor(np.zeros((5, 5))), Tensor(np.ones((2, 5))), (3, 0), (5, 5), (1, 1)) + >>> print(out) + [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.]] + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + """ + + @prim_attr_register + def __init__(self): + """Initialize TensorCopySlices""" + self.init_prim_io_names(inputs=['x', 'value', 'begin', 'end', 'strides'], outputs=['y']) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_tensor_copy_slices.py b/tests/st/ops/ascend/test_aicpu_ops/test_tensor_copy_slices.py new file mode 100644 index 00000000000..305f7ec50e3 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_tensor_copy_slices.py @@ -0,0 +1,105 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _inner_ops + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.copy_slices = _inner_ops.TensorCopySlices() + + def construct(self, input_x, update, begin, end, strides): + return self.copy_slices(input_x, update, begin, end, strides) + +def convert_begin_end_strides_to_slice(begin, end, strides): + result = [] + for x, y, z in zip(begin, end, strides): + result.append(slice(x, y, z)) + return tuple(result) + +def test_tensor_copy_slices_net(input_shape, update_shape, begin, end, strides, dtype): + input_np = np.zeros(input_shape, dtype) + update_np = np.ones(update_shape, dtype) + input_tensor = Tensor(input_np) + update = Tensor(update_np) + net = Net() + output = net(input_tensor, update, begin, end, strides) + slices = convert_begin_end_strides_to_slice(begin, end, strides) + input_np[slices] = update_np + assert (output.asnumpy() == input_np).all() + +def test_tensor_copy_slices_net_many_dtype(input_shape, update_shape, begin, end, strides, dtypes): + for dtype in dtypes: + test_tensor_copy_slices_net(input_shape, update_shape, begin, end, strides, dtype) + +support_dtype = (np.int64, np.int32, np.float64, np.float32) + +def test_tensor_copy_slices(): + test_tensor_copy_slices_net_many_dtype((10,), (5,), (0,), (5,), (1,), support_dtype) + test_tensor_copy_slices_net_many_dtype((10,), (5,), (5,), (10,), (1,), support_dtype) + test_tensor_copy_slices_net_many_dtype((10, 10), (5, 10), (0,), (5,), (1,), support_dtype) + test_tensor_copy_slices_net_many_dtype((10, 10), (5, 10), (5,), (10,), (1,), support_dtype) + test_tensor_copy_slices_net_many_dtype((10, 10), (5,), (9, 5), (10, 10), (1, 
1), support_dtype) + test_tensor_copy_slices_net_many_dtype((10, 10, 10), (5, 10), (0, 5, 0), (1, 10, 10), (1, 1, 1,), support_dtype) + test_tensor_copy_slices_net_many_dtype((10, 10, 10), (5, 10), (9, 5,), (10, 10,), (1, 1,), support_dtype) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_tensor_copy_slices_ascend_graph(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_tensor_copy_slices() + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_tensor_copy_slices_ascend_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + test_tensor_copy_slices() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tensor_copy_slices_gpu_graph(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_tensor_copy_slices() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tensor_copy_slices_gpu_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + test_tensor_copy_slices() + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_tensor_copy_slices_cpu_graph(): + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + test_tensor_copy_slices() + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_tensor_copy_slices_cpu_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + test_tensor_copy_slices() diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_tensor_setitem.py b/tests/st/ops/ascend/test_aicpu_ops/test_tensor_setitem.py new file mode 100644 index 00000000000..c00f11d7f5a --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_tensor_setitem.py @@ -0,0 
+1,84 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor + + +class Net(nn.Cell): + def construct(self, input_x, update, slice_tuple): + input_x[slice_tuple] = update + return input_x + + +def test_tensor_setitem_net(input_shape, update_shape, slice_tuple, dtype): + input_np = np.zeros(input_shape, dtype) + update_np = np.ones(update_shape, dtype) + input_tensor = Tensor(input_np) + update = Tensor(update_np) + net = Net() + output = net(input_tensor, update, slice_tuple) + input_np[slice_tuple] = update_np + assert (output.asnumpy() == input_np).all() + + +def test_tensor_setitem_net_many_dtype(input_shape, update_shape, slice_tuple, dtypes): + for dtype in dtypes: + test_tensor_setitem_net(input_shape, update_shape, slice_tuple, dtype) + + +support_dtype = (np.int64, np.int32, np.float64, np.float32) + + +def test_tensor_setitem_all(): + test_tensor_setitem_net_many_dtype((10,), (5,), (slice(0, 5),), support_dtype) + test_tensor_setitem_net_many_dtype((10,), (5,), (slice(5, 10),), support_dtype) + test_tensor_setitem_net_many_dtype((10, 10), (5, 10), (slice(0, 5),), support_dtype) + test_tensor_setitem_net_many_dtype((10, 10), (5, 10), (slice(5, 10),), support_dtype) + test_tensor_setitem_net_many_dtype((10, 10), (5,), (9, 
slice(5, 10)), support_dtype) + test_tensor_setitem_net_many_dtype((10, 10, 10), (5, 10), (0, slice(5, 10)), support_dtype) + test_tensor_setitem_net_many_dtype((10, 10, 10), (5, 10), (9, slice(5, 10)), support_dtype) + + +def test_tensor_copy_slices_ascend_graph(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_tensor_setitem_all() + + +def test_tensor_copy_slices_ascend_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + test_tensor_setitem_all() + + +def test_tensor_copy_slices_gpu_graph(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_tensor_setitem_all() + + +def test_tensor_copy_slices_gpu_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + test_tensor_setitem_all() + + +def test_tensor_copy_slices_cpu_graph(): + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + test_tensor_setitem_all() + + +def test_tensor_copy_slices_cpu_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + test_tensor_setitem_all()