forked from mindspore-Ecosystem/mindspore
!16028 Add identity op for cpu and gpu
From: @xcnick Reviewed-by: Signed-off-by:
This commit is contained in:
commit
abb6192daa
|
@ -13,6 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
@ -251,6 +252,11 @@ void Atanh(const T *in, T *out, size_t size) {
|
|||
};
|
||||
CPUKernelUtils::ParallelFor(task, size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Identity(const T *in, T *out, size_t size) {
|
||||
std::copy(in, in + size, out);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::kPrimNeg->name(), NEG},
|
||||
|
@ -274,7 +280,8 @@ static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::k
|
|||
{prim::kPrimCosh->name(), COSH},
|
||||
{prim::kPrimAsinh->name(), ASINH},
|
||||
{prim::kPrimAcosh->name(), ACOSH},
|
||||
{prim::kPrimAtanh->name(), ATANH}};
|
||||
{prim::kPrimAtanh->name(), ATANH},
|
||||
{prim::kPrimIdentityMath->name(), IDENTITY}};
|
||||
|
||||
void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
|
@ -335,5 +342,16 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
|
|||
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IdentityCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
T *input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
T *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
|
||||
Identity<T>(input, output, lens);
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -43,6 +43,16 @@ class ArithmeticSelfCPUKernel : public CPUKernel {
|
|||
TypeId target_dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class IdentityCPUKernel : public ArithmeticSelfCPUKernel {
|
||||
public:
|
||||
IdentityCPUKernel() = default;
|
||||
~IdentityCPUKernel() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticSelfCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
@ -97,6 +107,31 @@ MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput
|
|||
ArithmeticSelfCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
IdentityCPUKernel, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
IdentityCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
IdentityCPUKernel, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
IdentityCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
IdentityCPUKernel, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
IdentityCPUKernel, int16_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
IdentityCPUKernel, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
IdentityCPUKernel, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
IdentityCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
IdentityCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
IdentityCPUKernel, float16);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
IdentityCPUKernel, bool);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -118,6 +118,7 @@ enum OperateType {
|
|||
ATAN2,
|
||||
RINT,
|
||||
ROUND,
|
||||
IDENTITY,
|
||||
};
|
||||
|
||||
class CPUKernel : public kernel::KernelMod {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,10 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "identity_impl.cuh"
|
||||
#include "eye_impl.cuh"
|
||||
#include <iostream>
|
||||
template <typename T>
|
||||
__global__ void IdentityKernel(const size_t size, const size_t dim, T *output_addr) {
|
||||
__global__ void EyeKernel(const size_t size, const size_t dim, T *output_addr) {
|
||||
for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
|
||||
size_t batchIdx = pointIdx / (dim * dim);
|
||||
size_t dst_x = (pointIdx - batchIdx * dim * dim) / dim;
|
||||
|
@ -31,10 +31,9 @@ __global__ void IdentityKernel(const size_t size, const size_t dim, T *output_ad
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream) {
|
||||
IdentityKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dim, output_addr);
|
||||
void Eye(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream) {
|
||||
EyeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dim, output_addr);
|
||||
return;
|
||||
}
|
||||
|
||||
template void Identity<float>(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream);
|
||||
|
||||
template void Eye<float>(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream);
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,11 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_EYE_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_EYE_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream);
|
||||
void Eye(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_EYE_H_
|
|
@ -20,7 +20,7 @@
|
|||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
|
@ -79,7 +79,7 @@ class CholeskyGpuKernel : public GpuKernel {
|
|||
h_array[i] = d_batch_input_addr + i * lda_ * m_;
|
||||
h_identity[i] = output_addr + i * ldb_ * m_;
|
||||
}
|
||||
Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
Eye(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/eye_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
@ -149,7 +149,7 @@ class CholeskyTrsmGpuKernel : public GpuKernel {
|
|||
h_array[i] = d_batch_input_addr + i * lda_ * m_;
|
||||
h_identity[i] = output_addr + i * ldb_ * m_;
|
||||
}
|
||||
Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
Eye(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/identity_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
IdentityGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
IdentityGpuKernel, float);
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
IdentityGpuKernel, half);
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
IdentityGpuKernel, uint64_t);
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
IdentityGpuKernel, int64_t);
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
IdentityGpuKernel, uint32_t);
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
IdentityGpuKernel, int32_t);
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
IdentityGpuKernel, uint16_t);
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
IdentityGpuKernel, int16_t);
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
IdentityGpuKernel, uint8_t);
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
IdentityGpuKernel, int8_t);
|
||||
MS_REG_GPU_KERNEL_ONE(Identity, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
IdentityGpuKernel, bool);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IDENTITY_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IDENTITY_GPU_KERNEL_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
template <typename T>
|
||||
class IdentityGpuKernel : public GpuKernel {
|
||||
public:
|
||||
IdentityGpuKernel() { ResetResource(); }
|
||||
~IdentityGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync failed in IdentityGpuKernel::Lanuch");
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but identity needs 1 inputs.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but identity needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
is_null_input_ = CHECK_NULL_INPUT(input_shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "IdentityGpuKernel input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
}
|
||||
output_size_ = input_size_;
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
void ResetResource() noexcept override {
|
||||
input_size_ = sizeof(T);
|
||||
output_size_ = sizeof(T);
|
||||
workspace_size_ = 0;
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
bool is_null_input_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IDENTITY_GPU_KERNEL_H_
|
|
@ -435,6 +435,7 @@ inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad"
|
|||
inline const PrimitivePtr kPrimAtanGrad = std::make_shared<Primitive>("AtanGrad");
|
||||
inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
|
||||
inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
|
||||
inline const PrimitivePtr kPrimIdentityMath = std::make_shared<Primitive>("Identity", kSideEffectPropagate);
|
||||
|
||||
// Statements
|
||||
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
|
||||
|
|
|
@ -5150,7 +5150,7 @@ class Identity(PrimitiveWithInfer):
|
|||
TypeError: If `x` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
|
||||
|
|
|
@ -68,6 +68,15 @@ class RintNet(nn.Cell):
|
|||
return self.rint(x)
|
||||
|
||||
|
||||
class IdentityNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(IdentityNet, self).__init__()
|
||||
self.identity = P.Identity()
|
||||
|
||||
def construct(self, x):
|
||||
return self.identity(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -189,3 +198,163 @@ def test_reciprocal():
|
|||
diff = output.asnumpy() - expect_output
|
||||
error = np.ones(shape=expect_output.shape) * 1.0e-5
|
||||
assert np.all(np.abs(diff) < error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_pynative():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
net = IdentityNet()
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.float64)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.float32)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.float16)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.uint64)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.int64)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.uint32)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.int32)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.uint16)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.int16)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.uint8)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.int8)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.bool)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_graph():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
net = IdentityNet()
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.float64)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.float32)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.float16)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.uint64)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.int64)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.uint32)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.int32)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.uint16)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.int16)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.uint8)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.int8)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.bool)
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.identity = ops.Identity()
|
||||
|
||||
def construct(self, x):
|
||||
return self.identity(x)
|
||||
|
||||
|
||||
def generate_testcases(nptype):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x = np.random.randn(3, 4, 5, 6).astype(nptype)
|
||||
net = Net()
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
x = np.random.randn(3, 4, 5, 6).astype(nptype)
|
||||
net = Net()
|
||||
input_tensor = Tensor(x)
|
||||
output = net(input_tensor)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
assert id(input_tensor) != id(output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_float64():
|
||||
generate_testcases(np.float64)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_float32():
|
||||
generate_testcases(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_float16():
|
||||
generate_testcases(np.float16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_uint64():
|
||||
generate_testcases(np.uint64)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_int64():
|
||||
generate_testcases(np.int64)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_uint32():
|
||||
generate_testcases(np.uint32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_int32():
|
||||
generate_testcases(np.int32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_uint16():
|
||||
generate_testcases(np.uint16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_int16():
|
||||
generate_testcases(np.int16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_uint8():
|
||||
generate_testcases(np.uint8)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_int8():
|
||||
generate_testcases(np.int8)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity_bool():
|
||||
generate_testcases(np.bool)
|
Loading…
Reference in New Issue