diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
index b3bc662c822..55cd03fdf3f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <algorithm>
 #include
 #include
 #include
@@ -251,6 +252,11 @@ void Atanh(const T *in, T *out, size_t size) {
   };
   CPUKernelUtils::ParallelFor(task, size);
 }
+
+template <typename T>
+void Identity(const T *in, T *out, size_t size) {
+  std::copy(in, in + size, out);
+}
 }  // namespace
 
 static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::kPrimNeg->name(), NEG},
@@ -274,7 +280,8 @@ static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::k
                                                                         {prim::kPrimCosh->name(), COSH},
                                                                         {prim::kPrimAsinh->name(), ASINH},
                                                                         {prim::kPrimAcosh->name(), ACOSH},
-                                                                        {prim::kPrimAtanh->name(), ATANH}};
+                                                                        {prim::kPrimAtanh->name(), ATANH},
+                                                                        {prim::kPrimIdentityMath->name(), IDENTITY}};
 
 void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
@@ -335,5 +342,16 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
     MS_LOG(EXCEPTION) << "Not support " << operate_type_;
   }
 }
+
+template <typename T>
+bool IdentityCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs,
+                                  const std::vector<AddressPtr> &,
+                                  const std::vector<AddressPtr> &outputs) {
+  T *input = reinterpret_cast<T *>(inputs[0]->addr);
+  T *output = reinterpret_cast<T *>(outputs[0]->addr);
+  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
+  Identity(input, output, lens);
+  return true;
+}
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
index d46f9336e8d..1acdb9f27ae 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
@@ -43,6 +43,16 @@ class ArithmeticSelfCPUKernel : public CPUKernel {
   TypeId target_dtype_{kTypeUnknown};
 };
+template <typename T>
+class IdentityCPUKernel : public ArithmeticSelfCPUKernel {
+ public:
+  IdentityCPUKernel() = default;
+  ~IdentityCPUKernel() override = default;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+};
+
 MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
@@ -97,6 +107,31 @@ MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
+
+MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
+                    IdentityCPUKernel, uint64_t);
+MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                    IdentityCPUKernel, int64_t);
+MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
+                    IdentityCPUKernel, uint32_t);
+MS_REG_CPU_KERNEL_T(Identity,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + IdentityCPUKernel, int32_t); +MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + IdentityCPUKernel, uint16_t); +MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + IdentityCPUKernel, int16_t); +MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + IdentityCPUKernel, uint8_t); +MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + IdentityCPUKernel, int8_t); +MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + IdentityCPUKernel, double); +MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + IdentityCPUKernel, float); +MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + IdentityCPUKernel, float16); +MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + IdentityCPUKernel, bool); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index 8ec8e5e4d91..ae13c58fb5b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -118,6 +118,7 @@ enum OperateType { ATAN2, RINT, ROUND, + IDENTITY, }; class CPUKernel : public kernel::KernelMod { diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/eye_impl.cu similarity index 69% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/eye_impl.cu index ecb44c45acf..fdeaaa58a62 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/eye_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,10 @@ * limitations under the License. 
  */
-#include "identity_impl.cuh"
+#include "eye_impl.cuh"
 #include
 template <typename T>
-__global__ void IdentityKernel(const size_t size, const size_t dim, T *output_addr) {
+__global__ void EyeKernel(const size_t size, const size_t dim, T *output_addr) {
   for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size);
        pointIdx += blockDim.x * gridDim.x) {
     size_t batchIdx = pointIdx / (dim * dim);
     size_t dst_x = (pointIdx - batchIdx * dim * dim) / dim;
@@ -31,10 +31,9 @@ __global__ void IdentityKernel(const size_t size, const size_t dim, T *output_ad
 }
 
 template <typename T>
-void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream) {
-  IdentityKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dim, output_addr);
+void Eye(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream) {
+  EyeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dim, output_addr);
   return;
 }
-template void Identity<float>(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream);
-
+template void Eye<float>(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh
similarity index 66%
rename from mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh
index b8fd4a0be3f..e76635ba7e9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,11 +14,11 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_EYE_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_EYE_H_
 
 #include "runtime/device/gpu/cuda_common.h"
 
 template <typename T>
-void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream);
+void Eye(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream);
 
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_EYE_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h
index a0356ffac45..f3169165b35 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include
-#include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh"
 #include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
 #include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
@@ -79,7 +79,7 @@ class CholeskyGpuKernel : public GpuKernel {
       h_array[i] = d_batch_input_addr + i * lda_ * m_;
       h_identity[i] = output_addr + i * ldb_ * m_;
     }
-    Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    Eye(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
     MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr,
                 reinterpret_cast<cudaStream_t>(stream_ptr));
     CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h
index 323c4befeb1..ef57d782711 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include
-#include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh"
 #include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
@@ -149,7 +149,7 @@ class CholeskyTrsmGpuKernel : public GpuKernel {
       h_array[i] = d_batch_input_addr + i * lda_ * m_;
       h_identity[i] = output_addr + i * ldb_ * m_;
     }
-    Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    Eye(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
     MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr,
                 reinterpret_cast<cudaStream_t>(stream_ptr));
     CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/identity_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/identity_gpu_kernel.cc
new file mode 100644
index 00000000000..6c398e7499d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/identity_gpu_kernel.cc
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/identity_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + IdentityGpuKernel, double) +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + IdentityGpuKernel, float); +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + IdentityGpuKernel, half); +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + IdentityGpuKernel, uint64_t); +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + IdentityGpuKernel, int64_t); +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + IdentityGpuKernel, uint32_t); +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + IdentityGpuKernel, int32_t); +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + IdentityGpuKernel, uint16_t); +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + IdentityGpuKernel, int16_t); +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + IdentityGpuKernel, uint8_t); +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + IdentityGpuKernel, int8_t); +MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + IdentityGpuKernel, bool); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/identity_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/identity_gpu_kernel.h new file mode 100644 index 00000000000..b0f762f6908 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/identity_gpu_kernel.h @@ -0,0 +1,104 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IDENTITY_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IDENTITY_GPU_KERNEL_H_
+
+#include
+#include
+#include
+#include
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+
+template <typename T>
+class IdentityGpuKernel : public GpuKernel {
+ public:
+  IdentityGpuKernel() { ResetResource(); }
+  ~IdentityGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync failed in IdentityGpuKernel::Launch");
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but identity needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but identity needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "IdentityGpuKernel input is null";
+      InitSizeLists();
+      return true;
+    }
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      input_size_ *= input_shape[i];
+    }
+    output_size_ = input_size_;
+    InitSizeLists();
+    return true;
+  }
+  void ResetResource() noexcept override {
+    input_size_ = sizeof(T);
+    output_size_ = sizeof(T);
+    workspace_size_ = 0;
+    is_null_input_ = false;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+  }
+
+ private:
+  size_t input_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+  bool is_null_input_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IDENTITY_GPU_KERNEL_H_
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index e4bb7b94710..e63c7ececef 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -435,6 +435,7 @@ inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad"
 inline const PrimitivePtr kPrimAtanGrad = std::make_shared<Primitive>("AtanGrad");
 inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
 inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
+inline const PrimitivePtr kPrimIdentityMath = std::make_shared<Primitive>("Identity", kSideEffectPropagate);
 
 // Statements
 inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
diff --git a/mindspore/ops/operations/array_ops.py
b/mindspore/ops/operations/array_ops.py index d53a9d98dc6..b23f8a777d0 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -5150,7 +5150,7 @@ class Identity(PrimitiveWithInfer): TypeError: If `x` is not a Tensor. Supported Platforms: - ``Ascend`` + ``Ascend`` ``CPU`` ``GPU`` Examples: >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64) diff --git a/tests/st/ops/cpu/test_arithmetic_self_op.py b/tests/st/ops/cpu/test_arithmetic_self_op.py index 3a5bef3594c..5b2a87f4c64 100644 --- a/tests/st/ops/cpu/test_arithmetic_self_op.py +++ b/tests/st/ops/cpu/test_arithmetic_self_op.py @@ -68,6 +68,15 @@ class RintNet(nn.Cell): return self.rint(x) +class IdentityNet(nn.Cell): + def __init__(self): + super(IdentityNet, self).__init__() + self.identity = P.Identity() + + def construct(self, x): + return self.identity(x) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -189,3 +198,163 @@ def test_reciprocal(): diff = output.asnumpy() - expect_output error = np.ones(shape=expect_output.shape) * 1.0e-5 assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_identity_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + net = IdentityNet() + + x = np.random.randn(3, 4, 5, 6).astype(np.float64) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.float32) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.float16) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.uint64) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.int64) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.uint32) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.int32) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.uint16) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.int16) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.uint8) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert 
id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.int8) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.bool) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_identity_graph(): + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + net = IdentityNet() + + x = np.random.randn(3, 4, 5, 6).astype(np.float64) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.float32) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.float16) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.uint64) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.int64) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.uint32) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.int32) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.uint16) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.int16) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.uint8) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.int8) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + x = np.random.randn(3, 4, 5, 6).astype(np.bool) + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) diff --git a/tests/st/ops/gpu/test_identity_op.py b/tests/st/ops/gpu/test_identity_op.py new file mode 100644 index 00000000000..d4239ae44a5 --- /dev/null +++ 
b/tests/st/ops/gpu/test_identity_op.py @@ -0,0 +1,132 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, ops + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.identity = ops.Identity() + + def construct(self, x): + return self.identity(x) + + +def generate_testcases(nptype): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + x = np.random.randn(3, 4, 5, 6).astype(nptype) + net = Net() + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + x = np.random.randn(3, 4, 5, 6).astype(nptype) + net = Net() + input_tensor = Tensor(x) + output = net(input_tensor) + np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy()) + assert id(input_tensor) != id(output) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_float64(): + generate_testcases(np.float64) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_float32(): + generate_testcases(np.float32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_float16(): + generate_testcases(np.float16) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_uint64(): + generate_testcases(np.uint64) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_int64(): + generate_testcases(np.int64) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_uint32(): + generate_testcases(np.uint32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_int32(): + generate_testcases(np.int32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_uint16(): + generate_testcases(np.uint16) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_int16(): + generate_testcases(np.int16) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_uint8(): + generate_testcases(np.uint8) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_int8(): + generate_testcases(np.int8) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_identity_bool(): + generate_testcases(np.bool)
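
For context, the kernels above plug the existing ops.Identity primitive into the CPU and GPU backends, so the op is no longer Ascend-only. A minimal usage sketch, mirroring the new tests in this patch (the device_target and dtype below are illustrative choices, not part of the patch itself):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, ops

# "GPU" works the same way after this change; Ascend was already supported.
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

class IdentityNet(nn.Cell):
    def __init__(self):
        super(IdentityNet, self).__init__()
        self.identity = ops.Identity()

    def construct(self, x):
        return self.identity(x)

x = Tensor(np.array([1, 2, 3, 4]).astype(np.int32))
output = IdentityNet()(x)
# Identity returns a new tensor with the same values, shape and dtype as the input.
print(output)  # [1 2 3 4]

On CPU the kernel copies the buffer with std::copy, and on GPU it issues a device-to-device cudaMemcpyAsync, so the output is a distinct tensor rather than an alias of the input; the tests assert this via id(input_tensor) != id(output).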