forked from mindspore-Ecosystem/mindspore
!7030 add Meshgrid ops for aicpu
Merge pull request !7030 from yanzhenxiang2020/br_meshgrid
This commit is contained in:
commit
cfb131b844
|
@ -38,10 +38,10 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
|
|||
return;
|
||||
}
|
||||
// For compatibility with the current framework
|
||||
if (op_name == kPrint || op_name == kGetNext || op_name == kPack) {
|
||||
if (op_name == kPrint || op_name == kGetNext || op_name == kPack || op_name == kMeshgrid) {
|
||||
std::vector<std::string> inputs_format{};
|
||||
std::vector<TypeId> inputs_type{};
|
||||
if (op_name == kPrint || op_name == kPack) {
|
||||
if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid) {
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
inputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
|
||||
|
|
|
@ -29,6 +29,7 @@ constexpr auto kInitData = "InitData";
|
|||
constexpr auto kGetNext = "GetNext";
|
||||
constexpr auto kPrint = "Print";
|
||||
constexpr auto kPack = "Pack";
|
||||
constexpr auto kMeshgrid = "Meshgrid";
|
||||
constexpr auto kOutputTypes = "output_types";
|
||||
constexpr auto kOutputShapes = "output_shapes";
|
||||
constexpr auto kChannelName = "channel_name";
|
||||
|
@ -46,7 +47,7 @@ constexpr auto kEditDistance = "EditDistance";
|
|||
constexpr auto kGatherD = "GatherD";
|
||||
constexpr auto kIdentity = "Identity";
|
||||
constexpr auto kCustRunApi = "RunCpuKernel";
|
||||
const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kGatherD, kIdentity};
|
||||
const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kGatherD, kIdentity, kMeshgrid};
|
||||
|
||||
struct AicpuParamHead {
|
||||
uint32_t length; // Total length: include cunstom message
|
||||
|
|
|
@ -55,3 +55,4 @@ from .fused_sparse_adam import _fused_sparse_adam_aicpu
|
|||
from .fused_sparse_lazy_adam import _fused_sparse_lazy_adam_aicpu
|
||||
from .fused_sparse_ftrl import _fused_sparse_ftrl_aicpu
|
||||
from .fused_sparse_proximal_adagrad import _fused_sparse_proximal_adagrad_aicpu
|
||||
from .meshgrid import _meshgrid_aicpu
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Meshgrid op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
meshgrid_op_info = AiCPURegOp("Meshgrid") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("indexing", "str") \
|
||||
.input(0, "x", "dynamic") \
|
||||
.output(0, "y", "dynamic") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(meshgrid_op_info)
|
||||
def _meshgrid_aicpu():
|
||||
"""Meshgrid AiCPU register"""
|
||||
return
|
|
@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
Diag, DiagPart, DType, ExpandDims, Eye,
|
||||
Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
|
||||
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid,
|
||||
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
|
||||
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||
Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
|
||||
|
@ -110,6 +110,7 @@ __all__ = [
|
|||
'MatMul',
|
||||
'BatchMatMul',
|
||||
'Mul',
|
||||
'Meshgrid',
|
||||
'Pow',
|
||||
'Exp',
|
||||
'Expm1',
|
||||
|
|
|
@ -3509,6 +3509,103 @@ class BroadcastTo(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class Meshgrid(PrimitiveWithInfer):
|
||||
"""
|
||||
Generates coordinate matrices from given coordinate tensors.
|
||||
|
||||
Given N one-dimensional coordinate tensors, returns a list outputs of N N-D
|
||||
coordinate tensors for evaluating expressions on an N-D grid.
|
||||
|
||||
|
||||
Args:
|
||||
indexing (str): Either 'xy' or 'ij'. Default: 'xy'.
|
||||
When the indexing argument is set to 'xy' (the default),
|
||||
the broadcasting instructions for the first two dimensions are swapped.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Union[tuple, list]) - A Tuple or list of N 1-D Tensor objects.
|
||||
The length of input_x should be greater than 1
|
||||
|
||||
Outputs:
|
||||
Tensors, A Tuple of N N-D Tensor objects.
|
||||
|
||||
Examples:
|
||||
>>> x = np.array([1, 2, 3, 4]).astype(np.int32)
|
||||
>>> y = np.array([5, 6, 7]).astype(np.int32)
|
||||
>>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32)
|
||||
>>> inputs = (x, y, z)
|
||||
>>> meshgrid = P.Meshgrid(indexing="xy")
|
||||
>>> meshgrid(inputs)
|
||||
(Tensor(shape=[3, 4, 6], dtype=UInt32, value=
|
||||
[[[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
[3, 3, 3, 3, 3],
|
||||
[4, 4, 4, 4, 4]],
|
||||
[[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
[3, 3, 3, 3, 3],
|
||||
[4, 4, 4, 4, 4]],
|
||||
[[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
[3, 3, 3, 3, 3],
|
||||
[4, 4, 4, 4, 4]]]),
|
||||
Tensor(shape=[3, 4, 6], dtype=UInt32, value=
|
||||
[[[5, 5, 5, 5, 5],
|
||||
[5, 5, 5, 5, 5],
|
||||
[5, 5, 5, 5, 5],
|
||||
[5, 5, 5, 5, 5]],
|
||||
[[6, 6, 6, 6, 6],
|
||||
[6, 6, 6, 6, 6],
|
||||
[6, 6, 6, 6, 6],
|
||||
[6, 6, 6, 6, 6]],
|
||||
[[7, 7, 7, 7, 7],
|
||||
[7, 7, 7, 7, 7],
|
||||
[7, 7, 7, 7, 7],
|
||||
[7, 7, 7, 7, 7]]]),
|
||||
Tensor(shape=[3, 4, 6], dtype=UInt32, value=
|
||||
[[[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2]],
|
||||
[[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2]],
|
||||
[[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2]]]))
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, indexing="xy"):
|
||||
"""Init Meshgrid"""
|
||||
validator.check_value_type("indexing", indexing, (str), self.name)
|
||||
if indexing not in ("xy", "ij"):
|
||||
raise ValueError("indexing parameter must be either 'xy' or 'ij'")
|
||||
self.indexing = indexing
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_value_type("shape", x_shape, [tuple, list], self.name)
|
||||
validator.check_integer("len of input_x", len(x_shape), 2, Rel.GE, self.name)
|
||||
n = len(x_shape)
|
||||
shape_0 = []
|
||||
for s in x_shape:
|
||||
validator.check_integer('each_input_rank', len(s), 1, Rel.EQ, self.name)
|
||||
shape_0.append(s[0])
|
||||
if self.indexing == "xy":
|
||||
shape_0[0], shape_0[1] = shape_0[1], shape_0[0]
|
||||
out_shape = tuple(tuple(shape_0) for _ in range(n))
|
||||
return out_shape
|
||||
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name)
|
||||
n = len(x_type)
|
||||
for i in range(1, n):
|
||||
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError)
|
||||
return x_type
|
||||
|
||||
class InplaceUpdate(PrimitiveWithInfer):
|
||||
r"""
|
||||
Updates specified rows with values in `v`.
|
||||
|
|
|
@ -0,0 +1,265 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, x, indexing):
|
||||
super(Net, self).__init__()
|
||||
self.meshgrid = P.Meshgrid(indexing)
|
||||
self.x = x
|
||||
|
||||
def construct(self):
|
||||
return self.meshgrid(self.x)
|
||||
|
||||
|
||||
def test_net_bool():
|
||||
x = np.random.randn(4,) > 0
|
||||
y = np.random.randn(3,) > 0
|
||||
z = np.random.randn(6,) > 0
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_int8():
|
||||
x = np.random.randn(4,).astype(np.int8)
|
||||
y = np.random.randn(3,).astype(np.int8)
|
||||
z = np.random.randn(6,).astype(np.int8)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_uint8():
|
||||
x = np.random.randn(4,).astype(np.uint8)
|
||||
y = np.random.randn(3,).astype(np.uint8)
|
||||
z = np.random.randn(6,).astype(np.uint8)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_int16():
|
||||
x = np.random.randn(4,).astype(np.int16)
|
||||
y = np.random.randn(3,).astype(np.int16)
|
||||
z = np.random.randn(6,).astype(np.int16)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_uint16():
|
||||
x = np.random.randn(4,).astype(np.uint16)
|
||||
y = np.random.randn(3,).astype(np.uint16)
|
||||
z = np.random.randn(6,).astype(np.uint16)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_int32():
|
||||
x = np.random.randn(4,).astype(np.int32)
|
||||
y = np.random.randn(3,).astype(np.int32)
|
||||
z = np.random.randn(6,).astype(np.int32)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_uint32():
|
||||
x = np.random.randn(4,).astype(np.uint32)
|
||||
y = np.random.randn(3,).astype(np.uint32)
|
||||
z = np.random.randn(6,).astype(np.uint32)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_int64():
|
||||
x = np.random.randn(4,).astype(np.int64)
|
||||
y = np.random.randn(3,).astype(np.int64)
|
||||
z = np.random.randn(6,).astype(np.int64)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
def test_net_uint64():
|
||||
x = np.random.randn(4,).astype(np.uint64)
|
||||
y = np.random.randn(3,).astype(np.uint64)
|
||||
z = np.random.randn(6,).astype(np.uint64)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_float16():
|
||||
x = np.random.randn(4,).astype(np.float16)
|
||||
y = np.random.randn(3,).astype(np.float16)
|
||||
z = np.random.randn(6,).astype(np.float16)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_float32():
|
||||
x = np.random.randn(4,).astype(np.float32)
|
||||
y = np.random.randn(3,).astype(np.float32)
|
||||
z = np.random.randn(6,).astype(np.float32)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_float64():
|
||||
x = np.random.randn(4,).astype(np.float64)
|
||||
y = np.random.randn(3,).astype(np.float64)
|
||||
z = np.random.randn(6,).astype(np.float64)
|
||||
indexing = "xy"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
||||
|
||||
|
||||
def test_net_float64_ij():
|
||||
x = np.random.randn(4,).astype(np.float64)
|
||||
y = np.random.randn(3,).astype(np.float64)
|
||||
z = np.random.randn(6,).astype(np.float64)
|
||||
indexing = "ij"
|
||||
|
||||
net = Net((Tensor(x), Tensor(y), Tensor(z)), indexing)
|
||||
output = net()
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
print(output)
|
||||
np_output = np.meshgrid(x, y, z, indexing=indexing)
|
||||
assert np.array_equal(output[0].asnumpy(), np_output[0])
|
||||
assert np.array_equal(output[1].asnumpy(), np_output[1])
|
||||
assert np.array_equal(output[2].asnumpy(), np_output[2])
|
Loading…
Reference in New Issue