!33268 Add GPU implementation of NonZero operator.
Merge pull request !33268 from hezhenhao1/add_nonzero
commit 818e1a2b6e
@@ -0,0 +1,152 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/arrays/non_zero_gpu_kernel.h"
#include <functional>
#include <utility>
#include <string>
#include <algorithm>
#include "mindspore/core/abstract/utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/non_zero_impl.cuh"
#include "kernel/common_utils.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr int kNonZeroInputsNum = 1;
constexpr int kNonZeroOutputsNum = 1;
}  // namespace

bool NonZeroGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->GetPrim()->name();
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kNonZeroInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kNonZeroOutputsNum, kernel_name_);
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
    return false;
  }
  is_need_retrieve_output_shape_ = true;  // NonZero is a dynamic-shape operator.
  kernel_func_ = func_list_[index].second;
  data_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
  index_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).first);
  return true;
}

void NonZeroGpuKernelMod::ResetResource() noexcept {
  real_output_size_ = 0;
  input_shape_.clear();
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();
}

int NonZeroGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs,
                                const std::map<uint32_t, tensor::TensorPtr> &) {
  ResetResource();
  outputs_ = outputs;
  auto shape = inputs.at(kIndex0)->GetShapeVector();
  (void)std::transform(shape.begin(), shape.end(), std::back_inserter(input_shape_),
                       [](int64_t x) { return x < 0 ? 0 : LongToSize(x); });
  rank_ = input_shape_.size();
  input_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies{});
  if (input_size_ == 0) {
    return KRET_INVALID_SHAPE;
  }

  input_size_list_.push_back(input_size_ * data_size_);
  workspace_size_list_.push_back(input_size_ * sizeof(size_t));
  workspace_size_list_.push_back(rank_ * sizeof(size_t));
  output_size_list_.push_back(input_size_ * rank_ * index_size_);
  return KRET_OK;
}

template <typename DataType, typename IndexType>
bool NonZeroGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                       const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
  if (input_size_ == 0) {
    return true;
  }
  auto input_ptr = GetDeviceAddress<DataType>(inputs, kIndex0);
  MS_EXCEPTION_IF_NULL(input_ptr);
  auto index_ptr = GetDeviceAddress<size_t>(workspace, kIndex0);
  MS_EXCEPTION_IF_NULL(index_ptr);
  auto shape_ptr = GetDeviceAddress<size_t>(workspace, kIndex1);
  MS_EXCEPTION_IF_NULL(shape_ptr);
  auto output_ptr = GetDeviceAddress<IndexType>(outputs, kIndex0);
  MS_EXCEPTION_IF_NULL(output_ptr);

  // Copy the input shape to the device.
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(shape_ptr, input_shape_.data(), rank_ * sizeof(size_t), cudaMemcpyHostToDevice, cuda_stream_),
    "NonZero cudaMemcpyAsync failed.");

  NonZero(input_ptr, index_ptr, shape_ptr, output_ptr, input_size_, rank_, cuda_stream_);

  // The last element of index_ptr is the final output size of NonZero.
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&real_output_size_, index_ptr + input_size_ - 1, sizeof(size_t),
                                                     cudaMemcpyDeviceToHost, cuda_stream_),
                                     "NonZero cudaMemcpyAsync failed.");
  return true;
}

void NonZeroGpuKernelMod::SyncData() {
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), "NonZero cudaStreamSynchronize failed.");
  std::vector<int64_t> new_output_shape = {SizeToLong(real_output_size_), SizeToLong(rank_)};
  outputs_[kIndex0]->SetShapeVector(new_output_shape);
}

std::vector<std::pair<KernelAttr, NonZeroGpuKernelMod::NonZeroLaunchFunc>> NonZeroGpuKernelMod::func_list_ = {
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<bool, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<int8_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<int16_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<int32_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<int64_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<half, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<float, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<double, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<uint8_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64),
   &NonZeroGpuKernelMod::LaunchKernel<uint64_t, int64_t>}};

std::vector<KernelAttr> NonZeroGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(
    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
    [](const std::pair<KernelAttr, NonZeroGpuKernelMod::NonZeroLaunchFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, NonZero, NonZeroGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
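The dynamic-shape handshake in this kernel is worth spelling out: Resize allocates the output at its worst-case size (input_size rows of rank indices), LaunchKernel asynchronously copies the last element of the inclusive scan back to the host as the real row count, and SyncData shrinks the reported output shape to [real_output_size, rank] once the stream has synchronized. Below is a minimal NumPy sketch of that host-side flow, not MindSpore API; the variable names mirror the kernel's members but the function itself is hypothetical, and it assumes a non-empty input (the kernel rejects input_size == 0 in Resize).

import numpy as np

def nonzero_host_flow(x):
    # Resize: worst case every element is non-zero, so reserve input_size rows.
    input_size, rank = x.size, x.ndim
    output = np.zeros((input_size, rank), dtype=np.int64)  # max-size output buffer

    # LaunchKernel: 0/1 flags, then inclusive scan; the scan's last element is
    # the real row count that the kernel copies back to the host.
    flags = (x.reshape(-1) != 0).astype(np.int64)
    scan = np.cumsum(flags)
    real_output_size = int(scan[-1])

    # SyncData: shrink the reported shape to [real_output_size, rank].
    output[:real_output_size] = np.transpose(np.nonzero(x))
    return output[:real_output_size]

assert np.array_equal(nonzero_host_flow(np.array([[1, 0], [0, 3]])),
                      np.array([[0, 0], [1, 1]]))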
@@ -0,0 +1,73 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_NON_ZERO_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_NON_ZERO_GPU_KERNEL_H_

#include <vector>
#include <memory>
#include <utility>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class NonZeroGpuKernelMod : public NativeGpuKernelMod {
 public:
  NonZeroGpuKernelMod() = default;
  ~NonZeroGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

 protected:
  void SyncData() override;
  std::vector<KernelAttr> GetOpSupport() override;
  std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; }

 private:
  void ResetResource() noexcept;
  template <typename DataType, typename IndexType>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs, void *stream_ptr);

  using NonZeroLaunchFunc =
    std::function<bool(NonZeroGpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
                       const std::vector<AddressPtr> &, void *)>;
  static std::vector<std::pair<KernelAttr, NonZeroLaunchFunc>> func_list_;
  NonZeroLaunchFunc kernel_func_;
  cudaStream_t cuda_stream_{nullptr};
  size_t rank_{0};
  size_t input_size_{0};
  size_t data_size_{0};         // That is, sizeof(DataType).
  size_t index_size_{0};        // That is, sizeof(IndexType).
  size_t real_output_size_{0};  // Dynamic shape related.
  std::vector<size_t> input_shape_{};
  std::vector<KernelTensorPtr> outputs_{};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_NON_ZERO_GPU_KERNEL_H_
@@ -0,0 +1,103 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/non_zero_impl.cuh"
#include <thrust/transform.h>
#include <thrust/scan.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"

template <typename T>
struct is_nonzero {
  typedef T data_type;
  typedef size_t index_type;

  __device__ index_type operator()(const data_type &x) const { return x == T(0) ? 0 : 1; }
};

template <typename IndexType>
__global__ void NonZeroKernel(const size_t *index_ptr, const size_t *shape_ptr, IndexType *output_ptr,
                              size_t input_size, size_t rank) {
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < input_size; tid += blockDim.x * gridDim.x) {
    bool is_write = (tid != 0 && index_ptr[tid] != index_ptr[tid - 1]) || (tid == 0 && index_ptr[tid]);
    if (is_write) {
      size_t fill_index = index_ptr[tid] * rank - 1;
      size_t fill_value = tid;
      for (size_t i = 0; i < rank; i++) {
        size_t base = shape_ptr[rank - 1 - i];
        output_ptr[fill_index] = fill_value % base;
        fill_index--;
        fill_value /= base;
      }
    }
  }
}

template <typename DataType, typename IndexType>
CUDA_LIB_EXPORT void NonZero(const DataType *input_ptr, size_t *index_ptr, size_t *shape_ptr, IndexType *output_ptr,
                             size_t input_size, size_t rank, cudaStream_t cuda_stream) {
  auto device = thrust::cuda::par.on(cuda_stream);
  auto thrust_input_ptr = thrust::device_pointer_cast(input_ptr);
  auto thrust_index_ptr = thrust::device_pointer_cast(index_ptr);

  // Map each element to 0 if it is zero and to 1 otherwise, then compute the prefix sum
  // of the resulting 0/1 sequence with an inclusive scan.
  thrust::transform(device, thrust_input_ptr, thrust_input_ptr + input_size, thrust_index_ptr, is_nonzero<DataType>());
  thrust::inclusive_scan(device, thrust_index_ptr, thrust_index_ptr + input_size, thrust_index_ptr);

  // Extract the first position at which each scan value appears and transform it into an output index,
  // e.g., [0, 0, 1, 2, 2, 2] -> [(1, 2), (2, 3)] -> [(0, 0, 2), (1, 0, 0)] when shape is (2, 1, 3).
  NonZeroKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(index_ptr, shape_ptr, output_ptr, input_size,
                                                                         rank);
}

template CUDA_LIB_EXPORT void NonZero<bool, int64_t>(const bool *input_ptr, size_t *index_ptr, size_t *shape_ptr,
                                                     int64_t *output_ptr, size_t input_size, size_t rank,
                                                     cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<uint8_t, int64_t>(const uint8_t *input_ptr, size_t *index_ptr, size_t *shape_ptr,
                                                        int64_t *output_ptr, size_t input_size, size_t rank,
                                                        cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<uint16_t, int64_t>(const uint16_t *input_ptr, size_t *index_ptr,
                                                         size_t *shape_ptr, int64_t *output_ptr, size_t input_size,
                                                         size_t rank, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<uint32_t, int64_t>(const uint32_t *input_ptr, size_t *index_ptr,
                                                         size_t *shape_ptr, int64_t *output_ptr, size_t input_size,
                                                         size_t rank, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<uint64_t, int64_t>(const uint64_t *input_ptr, size_t *index_ptr,
                                                         size_t *shape_ptr, int64_t *output_ptr, size_t input_size,
                                                         size_t rank, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<int8_t, int64_t>(const int8_t *input_ptr, size_t *index_ptr, size_t *shape_ptr,
                                                       int64_t *output_ptr, size_t input_size, size_t rank,
                                                       cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<int16_t, int64_t>(const int16_t *input_ptr, size_t *index_ptr, size_t *shape_ptr,
                                                        int64_t *output_ptr, size_t input_size, size_t rank,
                                                        cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<int32_t, int64_t>(const int32_t *input_ptr, size_t *index_ptr, size_t *shape_ptr,
                                                        int64_t *output_ptr, size_t input_size, size_t rank,
                                                        cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<int64_t, int64_t>(const int64_t *input_ptr, size_t *index_ptr, size_t *shape_ptr,
                                                        int64_t *output_ptr, size_t input_size, size_t rank,
                                                        cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<half, int64_t>(const half *input_ptr, size_t *index_ptr, size_t *shape_ptr,
                                                     int64_t *output_ptr, size_t input_size, size_t rank,
                                                     cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<float, int64_t>(const float *input_ptr, size_t *index_ptr, size_t *shape_ptr,
                                                      int64_t *output_ptr, size_t input_size, size_t rank,
                                                      cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void NonZero<double, int64_t>(const double *input_ptr, size_t *index_ptr, size_t *shape_ptr,
                                                       int64_t *output_ptr, size_t input_size, size_t rank,
                                                       cudaStream_t cuda_stream);
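To make the index arithmetic in NonZeroKernel concrete: after the inclusive scan, the first position tid at which each scan value k appears marks the k-th non-zero element, and the kernel decomposes the flat index tid into per-dimension coordinates by repeated modulo/divide against the shape, writing them right to left starting at fill_index = k * rank - 1. The following NumPy sketch reproduces the same per-thread arithmetic on the host; it is an illustration of the algorithm, not the kernel itself, and the function name is hypothetical.

import numpy as np

def unravel_like_kernel(x):
    shape, rank = x.shape, x.ndim
    flags = (x.reshape(-1) != 0).astype(np.int64)
    scan = np.cumsum(flags)                      # inclusive scan of the 0/1 flags
    out = np.zeros((int(scan[-1]), rank), dtype=np.int64)
    for tid in range(x.size):                    # one iteration per "thread"
        is_write = scan[tid] != (scan[tid - 1] if tid > 0 else 0)
        if is_write:                             # tid holds the k-th non-zero, k = scan[tid]
            fill_index = scan[tid] * rank - 1
            fill_value = tid
            for i in range(rank):                # peel coordinates off right to left
                base = shape[rank - 1 - i]
                out[fill_index // rank, fill_index % rank] = fill_value % base
                fill_index -= 1
                fill_value //= base
    return out

x = np.array([[[1, 0, 0], [0, 0, 1]]])           # shape (1, 2, 3)
assert np.array_equal(unravel_like_kernel(x), np.transpose(np.nonzero(x)))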
@@ -0,0 +1,26 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_NON_ZERO_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_NON_ZERO_IMPL_CUH_

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"

template <typename DataType, typename IndexType>
CUDA_LIB_EXPORT void NonZero(const DataType *input_ptr, size_t *index_ptr, size_t *shape_ptr, IndexType *output_ptr,
                             size_t input_size, size_t rank, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_NON_ZERO_IMPL_CUH_
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,14 +14,58 @@
 * limitations under the License.
 */

#include "ops/non_zero.h"

#include <functional>
#include <memory>
#include <numeric>
#include <set>

#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "ops/primitive_c.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
constexpr size_t kNonZeroInputMinDim = 1;
constexpr int64_t kNonZeroInputNum = 1;

abstract::ShapePtr NonZeroInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  if (x_shape.size() < kNonZeroInputMinDim) {
    MS_EXCEPTION(ValueError) << "For NonZero, the dimension of input argument [x] must be greater than or equal to "
                             << kNonZeroInputMinDim << ", but got " << x_shape.size() << ".";
  }

  auto x_num = std::accumulate(x_shape.begin(), x_shape.end(), int64_t(1), std::multiplies<int64_t>());

  int64_t x_rank = SizeToLong(x_shape.size());
  ShapeVector output_shape = {abstract::Shape::SHP_ANY, x_rank};
  ShapeVector min_shape = {0, x_rank};
  ShapeVector max_shape = {x_num, x_rank};
  return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}

TypePtr NonZeroInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  const std::set valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
                                kUInt32, kUInt64, kFloat16, kFloat, kFloat64, kComplex64, kComplex128};
  auto x_type = input_args[0]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
  return std::make_shared<TensorType>(kInt64);
}
}  // namespace

AbstractBasePtr NonZeroInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kNonZeroInputNum, primitive->name());
  auto infer_type = NonZeroInferType(primitive, input_args);
  auto infer_shape = NonZeroInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}

MIND_API_OPERATOR_IMPL(NonZero, BaseOperator);
REGISTER_PRIMITIVE_C(kNameNonZero, NonZero);
REGISTER_PRIMITIVE_EVAL_IMPL(NonZero, prim::kPrimNonZero, NonZeroInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
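NonZeroInferShape expresses the dynamic output as a (-1, rank) shape bracketed by a lower bound of 0 rows (an all-zero input) and an upper bound of numel(x) rows (no zeros at all). A small Python sketch of the bounds it computes, with the (2, 1, 3) shape chosen purely as an example:

import numpy as np

x_shape = (2, 1, 3)                        # example input shape
x_rank = len(x_shape)
x_num = int(np.prod(x_shape))              # 6: the most rows NonZero can emit
output_shape = (-1, x_rank)                # -1 plays the role of abstract::Shape::SHP_ANY
min_shape = (0, x_rank)                    # every element of x is zero
max_shape = (x_num, x_rank)                # no element of x is zero
print(output_shape, min_shape, max_shape)  # (-1, 3) (0, 3) (6, 3)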
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,13 @@

#ifndef MINDSPORE_CORE_OPS_NON_ZERO_H_
#define MINDSPORE_CORE_OPS_NON_ZERO_H_

#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"

namespace mindspore {
namespace ops {

@@ -29,8 +33,11 @@ class MIND_API NonZero : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(NonZero);
  /// \brief Constructor.
-  NonZero() : BaseOperator(kNameNonZero) {}
+  NonZero() : BaseOperator(kNameNonZero) { InitIOName({"x"}, {"output"}); }
};
AbstractBasePtr NonZeroInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args);
using PrimNonZeroPtr = std::shared_ptr<NonZero>;
}  // namespace ops
}  // namespace mindspore
@@ -7636,6 +7636,41 @@ class RightShift(Primitive):
        self.init_prim_io_names(inputs=['input_x', 'input_y'], outputs=['output'])


class NonZero(Primitive):
    """
    Returns a tensor of the positions of all non-zero values.

    Inputs:
        - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number or Bool.

    Outputs:
        - **y** (Tensor) - A 2-D Tensor whose data type is int64. Each row is the index of a non-zero
          element of `x`.

    Raises:
        TypeError: If `x` is not a Tensor.
        ValueError: If the dimension of `x` is equal to 0.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> import mindspore.ops as ops
        >>> x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
        >>> nonzero = ops.NonZero()
        >>> output = nonzero(x)
        >>> print(output)
        [[0 0 0]
         [0 1 0]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x'], outputs=['y'])


class Tril(Primitive):
    """
    Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input,
@@ -0,0 +1,62 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.ops.operations.array_ops as P


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.ops = P.NonZero()

    def construct(self, x):
        return self.ops(x)


def compare_with_numpy(x):
    net = Net()
    # Graph Mode
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    ms_result_graph = net(Tensor(x))
    # PyNative Mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    ms_result_pynative = net(Tensor(x))

    np_result = np.transpose(np.nonzero(x))
    return np.array_equal(ms_result_graph, np_result) and np.array_equal(ms_result_pynative, np_result)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('data_shape', [(10, 10), (3, 4, 5)])
@pytest.mark.parametrize('data_type',
                         [np.int8, np.int16, np.int32, np.int64, np.float16,
                          np.float32, np.float64, np.uint8, np.uint16])
def test_net(data_shape, data_type):
    """
    Feature: NonZero
    Description: test cases for the NonZero operator.
    Expectation: the results match numpy nonzero.
    """
    np.random.seed(1)
    x = np.random.randint(low=-1, high=2, size=data_shape).astype(data_type)
    assert compare_with_numpy(x)