forked from mindspore-Ecosystem/mindspore
!10690 add CPU ops: Greater/GreaterEqual/Range/GatherNd for center net
From: @caojian05 Reviewed-by: @wuxuejian,@oacjiewen Signed-off-by: @wuxuejian
This commit is contained in:
commit
ae0ea279f5
|
@ -167,6 +167,24 @@ void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCPUKernel::Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
std::vector<size_t> idx;
|
||||
GenIndex(i, &idx);
|
||||
out[i] = input1[idx[0]] > input2[idx[1]];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
std::vector<size_t> idx;
|
||||
GenIndex(i, &idx);
|
||||
out[i] = input1[idx[0]] >= input2[idx[1]];
|
||||
}
|
||||
}
|
||||
|
||||
void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
|
@ -190,6 +208,10 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
operate_type_ = EQUAL;
|
||||
} else if (kernel_name == prim::kPrimNotEqual->name()) {
|
||||
operate_type_ = NOTEQUAL;
|
||||
} else if (kernel_name == prim::kPrimGreater->name()) {
|
||||
operate_type_ = GREATER;
|
||||
} else if (kernel_name == prim::kPrimGreaterEqual->name()) {
|
||||
operate_type_ = GREATEREQUAL;
|
||||
} else if (kernel_name == prim::kPrimAssignAdd->name()) {
|
||||
operate_type_ = ASSIGNADD;
|
||||
} else if (kernel_name == prim::kPrimSquaredDifference->name()) {
|
||||
|
@ -301,6 +323,11 @@ void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &input
|
|||
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Equal<T>, this, input1, input2, output, start, end));
|
||||
} else if (operate_type_ == NOTEQUAL) {
|
||||
threads.emplace_back(std::thread(&ArithmeticCPUKernel::NotEqual<T>, this, input1, input2, output, start, end));
|
||||
} else if (operate_type_ == GREATER) {
|
||||
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Greater<T>, this, input1, input2, output, start, end));
|
||||
} else if (operate_type_ == GREATEREQUAL) {
|
||||
threads.emplace_back(
|
||||
std::thread(&ArithmeticCPUKernel::GreaterEqual<T>, this, input1, input2, output, start, end));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
|
||||
}
|
||||
|
|
|
@ -63,6 +63,10 @@ class ArithmeticCPUKernel : public CPUKernel {
|
|||
void NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
|
||||
template <typename T>
|
||||
void SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end);
|
||||
template <typename T>
|
||||
void Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end);
|
||||
template <typename T>
|
||||
void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
|
||||
std::vector<size_t> input_shape0_;
|
||||
std::vector<size_t> input_shape1_;
|
||||
std::vector<size_t> input_element_num0_;
|
||||
|
@ -213,6 +217,28 @@ MS_REG_CPU_KERNEL(
|
|||
SquaredDifference,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
Greater, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
Greater,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -53,6 +53,9 @@ const char END[] = "end";
|
|||
const char SIZE[] = "size";
|
||||
const char USE_NESTEROV[] = "use_nesterov";
|
||||
const char GROUP[] = "group";
|
||||
const char START[] = "start";
|
||||
const char LIMIT[] = "limit";
|
||||
const char DELTA[] = "delta";
|
||||
|
||||
enum OperateType {
|
||||
ADD = 0,
|
||||
|
@ -79,7 +82,9 @@ enum OperateType {
|
|||
EQUAL,
|
||||
NOTEQUAL,
|
||||
FLOOR,
|
||||
SQUAREDDIFFERENCE
|
||||
SQUAREDDIFFERENCE,
|
||||
GREATER,
|
||||
GREATEREQUAL,
|
||||
};
|
||||
|
||||
class CPUKernel : public kernel::KernelMod {
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/gathernd_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
void GatherNdCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
|
||||
// ReShape()
|
||||
size_t dim_of_indices = 1;
|
||||
for (size_t i = 0; i < indices_shapes_.size() - IntToSize(1); ++i) {
|
||||
dim_of_indices *= indices_shapes_[i];
|
||||
}
|
||||
|
||||
size_t dim_after_indices = 1;
|
||||
size_t dim_indices_last = indices_shapes_[indices_shapes_.size() - IntToSize(1)];
|
||||
for (size_t i = dim_indices_last; i < input_shapes_.size(); i++) {
|
||||
dim_after_indices *= input_shapes_[i];
|
||||
}
|
||||
|
||||
dims_.emplace_back(dim_of_indices);
|
||||
dims_.emplace_back(dim_after_indices);
|
||||
dims_.emplace_back(dim_indices_last);
|
||||
|
||||
batch_strides_.resize(dim_indices_last, 0);
|
||||
batch_indices_.resize(dim_indices_last, 0);
|
||||
|
||||
if (dim_indices_last > 0) {
|
||||
batch_strides_[dim_indices_last - 1] = input_shapes_[dim_indices_last - 1];
|
||||
batch_indices_[dim_indices_last - 1] = dims_[1];
|
||||
}
|
||||
|
||||
for (size_t i = dim_indices_last - 1; i > 0; --i) {
|
||||
batch_strides_[i - 1] = input_shapes_[i - 1];
|
||||
batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool GatherNdCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeInt32) {
|
||||
return LaunchKernel<int32_t>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt64) {
|
||||
return LaunchKernel<int64_t>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
return LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat64) {
|
||||
return LaunchKernel<double>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool GatherNdCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
|
||||
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
//
|
||||
size_t output_dim0 = dims_[0];
|
||||
size_t output_dim1 = dims_[1];
|
||||
size_t indices_dim1 = dims_[2];
|
||||
|
||||
int num = output_dim0 * output_dim1;
|
||||
|
||||
for (int write_index = 0; write_index < num; write_index++) {
|
||||
int i = write_index / output_dim1 % output_dim0;
|
||||
int j = write_index % output_dim1;
|
||||
|
||||
int read_index = 0;
|
||||
for (size_t k = 0; k < indices_dim1; k++) {
|
||||
size_t ind = indices_dim1 * i + k;
|
||||
int indices_i = indices_addr[ind];
|
||||
read_index += indices_i * batch_indices_[k];
|
||||
}
|
||||
read_index += j;
|
||||
output_addr[write_index] = input_addr[read_index];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class GatherNdCPUKernel : public CPUKernel {
|
||||
public:
|
||||
GatherNdCPUKernel() = default;
|
||||
~GatherNdCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_shapes_;
|
||||
std::vector<size_t> indices_shapes_;
|
||||
std::vector<size_t> output_shapes_;
|
||||
|
||||
std::vector<size_t> dims_;
|
||||
std::vector<int> batch_indices_;
|
||||
std::vector<int> batch_strides_;
|
||||
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherNdCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
GatherNdCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
GatherNd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherNdCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
GatherNd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherNdCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/range_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void RangeCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
|
||||
start_ = AnfAlgo::GetNodeAttr<float>(kernel_node, START);
|
||||
limit_ = AnfAlgo::GetNodeAttr<float>(kernel_node, LIMIT);
|
||||
delta_ = AnfAlgo::GetNodeAttr<float>(kernel_node, DELTA);
|
||||
}
|
||||
|
||||
bool RangeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeInt32) {
|
||||
return LaunchKernel<int32_t>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt64) {
|
||||
return LaunchKernel<int64_t>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
return LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat64) {
|
||||
return LaunchKernel<double>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool RangeCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
size_t elem_num = outputs[0]->size / sizeof(T);
|
||||
for (size_t i = 0; i < elem_num; i++) {
|
||||
output_addr[i] = start_ + i * delta_;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class RangeCPUKernel : public CPUKernel {
|
||||
public:
|
||||
RangeCPUKernel() = default;
|
||||
~RangeCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
int64_t start_;
|
||||
int64_t limit_;
|
||||
int64_t delta_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), RangeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), RangeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
RangeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
RangeCPUKernel);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_
|
|
@ -116,7 +116,7 @@ class Range(Cell):
|
|||
Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Range(1, 8, 2)
|
||||
|
|
|
@ -3078,7 +3078,7 @@ class GatherNd(PrimitiveWithInfer):
|
|||
Tensor, has the same type as `input_x` and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:].
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
|
||||
|
|
|
@ -2698,7 +2698,7 @@ class Greater(_LogicBinaryOp):
|
|||
Tensor, the shape is the same as the one after broadcasting,and the data type is bool.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
|
||||
|
@ -2739,7 +2739,7 @@ class GreaterEqual(_LogicBinaryOp):
|
|||
Tensor, the shape is the same as the one after broadcasting,and the data type is bool.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32)
|
||||
|
|
|
@ -0,0 +1,188 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
class OpNetWrapper(nn.Cell):
|
||||
def __init__(self, op):
|
||||
super(OpNetWrapper, self).__init__()
|
||||
self.op = op
|
||||
|
||||
def construct(self, *inputs):
|
||||
return self.op(*inputs)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_case1_basic_func():
|
||||
op = P.GatherNd()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
||||
params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32)
|
||||
outputs = op_wrapper(params, indices)
|
||||
print(outputs)
|
||||
expected = [0, 3]
|
||||
assert np.allclose(outputs.asnumpy(), np.array(expected))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_case2_indices_to_matrix():
|
||||
op = P.GatherNd()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
indices = Tensor(np.array([[1], [0]]), mindspore.int32)
|
||||
params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32)
|
||||
outputs = op_wrapper(params, indices)
|
||||
print(outputs)
|
||||
expected = [[2, 3], [0, 1]]
|
||||
assert np.allclose(outputs.asnumpy(), np.array(expected))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_case3_indices_to_3d_tensor():
|
||||
op = P.GatherNd()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
indices = Tensor(np.array([[1]]), mindspore.int32) # (1, 1)
|
||||
params = Tensor(np.array([[[0, 1], [2, 3]],
|
||||
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
|
||||
outputs = op_wrapper(params, indices)
|
||||
print(outputs)
|
||||
expected = [[[4, 5], [6, 7]]] # (1, 2, 2)
|
||||
assert np.allclose(outputs.asnumpy(), np.array(expected))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_case4():
|
||||
op = P.GatherNd()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
indices = Tensor(np.array([[0, 1], [1, 0]]), mindspore.int32) # (2, 2)
|
||||
params = Tensor(np.array([[[0, 1], [2, 3]],
|
||||
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
|
||||
outputs = op_wrapper(params, indices)
|
||||
print(outputs)
|
||||
expected = [[2, 3], [4, 5]] # (2, 2)
|
||||
assert np.allclose(outputs.asnumpy(), np.array(expected))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_case5():
|
||||
op = P.GatherNd()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
indices = Tensor(np.array([[0, 0, 1], [1, 0, 1]]), mindspore.int32) # (2, 3)
|
||||
params = Tensor(np.array([[[0, 1], [2, 3]],
|
||||
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
|
||||
outputs = op_wrapper(params, indices)
|
||||
print(outputs)
|
||||
expected = [1, 5] # (2,)
|
||||
assert np.allclose(outputs.asnumpy(), np.array(expected))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_case6():
|
||||
op = P.GatherNd()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
indices = Tensor(np.array([[[0, 0]], [[0, 1]]]), mindspore.int32) # (2, 1, 2)
|
||||
params = Tensor(np.array([[[0, 1], [2, 3]],
|
||||
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
|
||||
outputs = op_wrapper(params, indices)
|
||||
print(outputs)
|
||||
expected = [[[0, 1]], [[2, 3]]] # (2, 1, 2)
|
||||
assert np.allclose(outputs.asnumpy(), np.array(expected))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_case7():
|
||||
op = P.GatherNd()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
indices = Tensor(np.array([[[1]], [[0]]]), mindspore.int32) # (2, 1, 1)
|
||||
params = Tensor(np.array([[[0, 1], [2, 3]],
|
||||
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
|
||||
outputs = op_wrapper(params, indices)
|
||||
print(outputs)
|
||||
expected = [[[[4, 5], [6, 7]]], [[[0, 1], [2, 3]]]] # (2, 1, 2, 2)
|
||||
assert np.allclose(outputs.asnumpy(), np.array(expected))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_case8():
|
||||
op = P.GatherNd()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
indices = Tensor(np.array([[[0, 1], [1, 0]], [[0, 0], [1, 1]]]), mindspore.int32) # (2, 2, 2)
|
||||
params = Tensor(np.array([[[0, 1], [2, 3]],
|
||||
[[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2)
|
||||
outputs = op_wrapper(params, indices)
|
||||
print(outputs)
|
||||
expected = [[[2, 3], [4, 5]], [[0, 1], [6, 7]]] # (2, 2, 2)
|
||||
assert np.allclose(outputs.asnumpy(), np.array(expected))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_case9():
|
||||
op = P.GatherNd()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
indices = Tensor(np.array([[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]), mindspore.int32) # (2, 2, 3)
|
||||
params = Tensor(np.array([[[0, 1], [2, 3]],
|
||||
[[4, 5], [6, 7]]]), mindspore.int64) # (2, 2, 2)
|
||||
outputs = op_wrapper(params, indices)
|
||||
print(outputs)
|
||||
expected = [[1, 5], [3, 6]] # (2, 2, 2)
|
||||
assert np.allclose(outputs.asnumpy(), np.array(expected))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_case1_basic_func()
|
||||
test_case2_indices_to_matrix()
|
||||
test_case3_indices_to_3d_tensor()
|
||||
test_case4()
|
||||
test_case5()
|
||||
test_case6()
|
||||
test_case7()
|
||||
test_case8()
|
||||
test_case9()
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
class OpNetWrapper(nn.Cell):
|
||||
def __init__(self, op):
|
||||
super(OpNetWrapper, self).__init__()
|
||||
self.op = op
|
||||
|
||||
def construct(self, *inputs):
|
||||
return self.op(*inputs)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_int32():
|
||||
op = P.GreaterEqual()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
input_x = Tensor(np.array([1, 2, 3]).astype(np.int32))
|
||||
input_y = Tensor(np.array([3, 2, 1]).astype(np.int32))
|
||||
outputs = op_wrapper(input_x, input_y)
|
||||
|
||||
print(outputs)
|
||||
assert outputs.shape == (3,)
|
||||
assert np.allclose(outputs.asnumpy(), [False, True, True])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_float32():
|
||||
op = P.GreaterEqual()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
input_x = Tensor(np.array([1, 2, -1]).astype(np.float32))
|
||||
input_y = Tensor(np.array([-3, 2, -1]).astype(np.float32))
|
||||
outputs = op_wrapper(input_x, input_y)
|
||||
|
||||
print(outputs)
|
||||
assert outputs.shape == (3,)
|
||||
assert np.allclose(outputs.asnumpy(), [True, True, True])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_int32()
|
||||
test_float32()
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
class OpNetWrapper(nn.Cell):
|
||||
def __init__(self, op):
|
||||
super(OpNetWrapper, self).__init__()
|
||||
self.op = op
|
||||
|
||||
def construct(self, *inputs):
|
||||
return self.op(*inputs)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_int32():
|
||||
op = P.Greater()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
input_x = Tensor(np.array([1, 2, 3]).astype(np.int32))
|
||||
input_y = Tensor(np.array([3, 2, 1]).astype(np.int32))
|
||||
outputs = op_wrapper(input_x, input_y)
|
||||
|
||||
print(outputs)
|
||||
assert outputs.shape == (3,)
|
||||
assert np.allclose(outputs.asnumpy(), [False, False, True])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_float32():
|
||||
op = P.Greater()
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
input_x = Tensor(np.array([1, 2, -1]).astype(np.float32))
|
||||
input_y = Tensor(np.array([-3, 2, -1]).astype(np.float32))
|
||||
outputs = op_wrapper(input_x, input_y)
|
||||
|
||||
print(outputs)
|
||||
assert outputs.shape == (3,)
|
||||
assert np.allclose(outputs.asnumpy(), [True, False, False])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_int32()
|
||||
test_float32()
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
class OpNetWrapper(nn.Cell):
|
||||
def __init__(self, op):
|
||||
super(OpNetWrapper, self).__init__()
|
||||
self.op = op
|
||||
|
||||
def construct(self, *inputs):
|
||||
return self.op(*inputs)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_int():
|
||||
op = nn.Range(0, 100, 10)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
outputs = op_wrapper()
|
||||
print(outputs)
|
||||
assert outputs.shape == (10,)
|
||||
assert np.allclose(outputs.asnumpy(), range(0, 100, 10))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_float():
|
||||
op = nn.Range(10., 100., 20.)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
outputs = op_wrapper()
|
||||
print(outputs)
|
||||
assert outputs.shape == (5,)
|
||||
assert np.allclose(outputs.asnumpy(), [10., 30., 50., 70., 90.])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_int()
|
||||
test_float()
|
Loading…
Reference in New Issue