!8673 Add min/max_shape to ScatterAdd/Update and Transpose and add new dynamic shape testcases

From: @TFbunny
Reviewed-by: @robingrosman
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-27 22:13:32 +08:00 committed by Gitee
commit b0496aaa10
10 changed files with 555 additions and 171 deletions

View File

@ -1,114 +1,126 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class TransposeGpuFwdKernel : public GpuKernel {
public:
TransposeGpuFwdKernel() : shape_size_(0), input_size_(0), output_size_(0), workspace_size_(0) {}
~TransposeGpuFwdKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1);
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_shape failed");
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis failed");
size_t size = input_size_ / sizeof(T);
CalTranspose(size, input, input_shape, input_axis, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
shape_size_ = input_shape.size();
if (shape_size_ > TRANSPOSE_MAX_DIMENSION) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION
<< "-D inputs.";
}
input_size_ = 1;
for (size_t i = 0; i < shape_size_; i++) {
input_size_ *= input_shape[i];
input_shape_.push_back(input_shape[i]);
}
input_size_ *= sizeof(T);
output_size_ = input_size_;
std::vector<int> perm;
std::vector<int64_t> perm_me = GetAttr<std::vector<int64_t>>(kernel_node, "perm");
(void)std::transform(perm_me.begin(), perm_me.end(), std::back_inserter(perm),
[](const int64_t &value) { return static_cast<int>(value); });
for (size_t j = 0; j < perm.size(); j++) {
input_axis_.push_back(perm[j]);
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
workspace_size_ = shape_size_ * sizeof(size_t);
workspace_size_list_.push_back(workspace_size_);
workspace_size_list_.push_back(workspace_size_);
return;
}
private:
std::vector<size_t> input_shape_;
std::vector<size_t> input_axis_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t shape_size_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class TransposeGpuFwdKernel : public GpuKernel {
public:
TransposeGpuFwdKernel() { ResetResource(); }
~TransposeGpuFwdKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1);
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_shape failed");
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis failed");
size_t size = input_size_ / sizeof(T);
CalTranspose(size, input, input_shape, input_axis, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
shape_size_ = input_shape.size();
if (shape_size_ > TRANSPOSE_MAX_DIMENSION) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION
<< "-D inputs.";
}
input_size_ = 1;
for (size_t i = 0; i < shape_size_; i++) {
input_size_ *= input_shape[i];
input_shape_.push_back(input_shape[i]);
}
input_size_ *= sizeof(T);
output_size_ = input_size_;
std::vector<int> perm;
std::vector<int64_t> perm_me = GetAttr<std::vector<int64_t>>(kernel_node, "perm");
(void)std::transform(perm_me.begin(), perm_me.end(), std::back_inserter(perm),
[](const int64_t &value) { return static_cast<int>(value); });
for (size_t j = 0; j < perm.size(); j++) {
input_axis_.push_back(perm[j]);
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
shape_size_ = 0;
input_size_ = 0;
output_size_ = 0;
workspace_size_ = 0;
input_shape_.clear();
input_axis_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
workspace_size_ = shape_size_ * sizeof(size_t);
workspace_size_list_.push_back(workspace_size_);
workspace_size_list_.push_back(workspace_size_);
return;
}
private:
std::vector<size_t> input_shape_;
std::vector<size_t> input_axis_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t shape_size_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_

View File

@ -44,7 +44,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
}
std::set<string> DynamicShapeConstInputToAttr = {kCastOpName, kExpandDimsOpName, kReshapeOpName,
kEmbeddingLookupOpName};
kEmbeddingLookupOpName, kTransposeOpName};
for (auto &t : todos) {
CNodePtr cnode = t->cast<CNodePtr>();
ConstInputToAttrInfoRegister reg;

View File

@ -82,6 +82,7 @@ constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
constexpr auto kFlattenGradOpName = "FlattenGrad";
constexpr auto kExpandDimsOpName = "ExpandDims";
constexpr auto kReshapeOpName = "Reshape";
constexpr auto kTransposeOpName = "Transpose";
constexpr auto kSplitOpName = "Split";
constexpr auto kSplitVOpName = "SplitV";
constexpr auto kSparseApplyAdagradOpName = "SparseApplyAdagrad";

View File

@ -14,7 +14,6 @@
* limitations under the License.
*/
#include <set>
#include <algorithm>
#include <iterator>
#include "abstract/infer_functions.h"
@ -260,7 +259,11 @@ AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePt
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
return std::make_shared<AbstractTensor>(x->element(), x->shape());
ShapeVector shape = x->shape()->shape();
ShapeVector min_shape = x->shape()->min_shape();
ShapeVector max_shape = x->shape()->max_shape();
(void)CheckMinMaxShape(shape, &min_shape, &max_shape);
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
}
AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -270,7 +273,11 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
return std::make_shared<AbstractTensor>(x->element(), x->shape());
ShapeVector shape = x->shape()->shape();
ShapeVector min_shape = x->shape()->min_shape();
ShapeVector max_shape = x->shape()->max_shape();
(void)CheckMinMaxShape(shape, &min_shape, &max_shape);
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
}
AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -542,43 +549,28 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto perm = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
auto input_shp = input->shape()->shape();
auto perm_val = perm->BuildValue();
if (perm_val->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "Perm can't be anything: " << args_spec_list[1]->ToString();
}
auto perm_val_data = perm_val->cast<ValueTuplePtr>()->value();
ValuePtr perm = primitive->GetAttr("perm");
auto perm_val = perm->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(perm_val);
auto perm_val_data = perm_val->value();
ShapeVector perm_vec;
(void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(perm_vec),
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
ShapeVector result_shp;
std::set<size_t> indices;
for (size_t i = 0; i < perm_vec.size(); i++) {
size_t idx = static_cast<size_t>(perm_vec[i]);
if (indices.find(idx) != indices.end()) {
MS_LOG(EXCEPTION) << "Perm values must be unique";
}
if (idx >= perm_vec.size()) {
MS_LOG(EXCEPTION) << "One value in perm is " << idx << ", not in range [0, " << perm_vec.size() << ")";
}
result_shp.push_back(input_shp[idx]);
indices.insert(idx);
}
ShapeVector max_shp;
ShapeVector min_shp;
if (input->shape()->max_shape().size() == input_shp.size() &&
input->shape()->min_shape().size() == input_shp.size()) {
for (size_t i = 0; i < perm_vec.size(); i++) {
size_t idx = static_cast<size_t>(perm_vec[i]);
max_shp.push_back(input->shape()->max_shape()[idx]);
min_shp.push_back(input->shape()->min_shape()[idx]);
}
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
ShapeVector x_max_shp = input->shape()->max_shape();
ShapeVector x_min_shp = input->shape()->min_shape();
(void)CheckMinMaxShape(input_shp, &x_min_shp, &x_max_shp);
for (size_t i = 0; i < perm_vec.size(); i++) {
size_t idx = static_cast<size_t>(perm_vec[i]);
result_shp.push_back(input_shp[idx]);
max_shp.push_back(x_max_shp[idx]);
min_shp.push_back(x_min_shp[idx]);
}
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
}
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -310,5 +310,10 @@ size_t TypeIdSize(const TypeId data_type) {
// Multiplies all dimensions of `shape` together, starting from IntToSize(1).
// For an empty shape the starting value is returned unchanged.
size_t ShapeSize(const std::vector<size_t> &shape) {
  size_t total = IntToSize(1);
  for (const size_t dim : shape) {
    total = total * dim;
  }
  return total;
}
// Fills in missing shape bounds: an empty *min_shape or *max_shape is overwritten with a copy
// of `shape`, while a non-empty bound keeps its current contents.
// Both pointers are dereferenced unconditionally.
void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) {
  if (min_shape->empty()) {
    *min_shape = shape;
  }
  if (max_shape->empty()) {
    *max_shape = shape;
  }
}
} // namespace abstract
} // namespace mindspore

View File

@ -56,6 +56,10 @@ size_t ShapeSize(const std::vector<size_t> &shape);
// Get broadcasted shape for binary element-wise operation
ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y);
// Check dynamic shape routine: an empty *min_shape or *max_shape is replaced by a copy of
// shape; a non-empty bound is left as it is. Both pointers are dereferenced, so neither may
// be null.
void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_UTILS_H_

View File

@ -73,17 +73,13 @@ class _ScatterOp_Dynamic(PrimitiveWithCheck):
"""
Defines Scatter operators with dynamic shape
"""
__mindspore_signature__ = (
sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
sig.make_sig('updates', dtype=sig.sig_dtype.T)
)
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', "
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
if np.all(np.array(x_shape) != -1):
if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', "
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
@prim_attr_register
def __init__(self, use_locking=False):
@ -649,7 +645,7 @@ class Squeeze(PrimitiveWithInfer):
return x_dtype
class Transpose(PrimitiveWithCheck):
class Transpose(PrimitiveWithInfer):
"""
Permutes the dimensions of the input tensor according to input permutation.
@ -685,14 +681,36 @@ class Transpose(PrimitiveWithCheck):
"""Initialize Transpose"""
self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])
def check_shape(self, x, perm):
validator.check_value_type("perm", perm, [tuple], self.name)
if len(x) != len(perm):
def __infer__(self, x, perm):
x_shape = x['shape']
p_value = perm['value']
x_type = x['dtype']
validator.check_value_type("p_value", p_value, [tuple], self.name)
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
if len(x_shape) != len(p_value):
raise ValueError('The dimension of x and perm must be equal.')
def check_dtype(self, x, perm):
validator.check_subclass("x", x, mstype.tensor, self.name)
tmp = list(p_value)
for i, dim in enumerate(p_value):
validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name)
validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name)
tmp.remove(dim)
if dim in tmp:
raise ValueError('The value of perm is wrong.')
out_shapes = []
for i in p_value:
out_shapes.append(x_shape[i])
out = {'shape': tuple(out_shapes),
'dtype': x['dtype'],
'value': None}
if 'min_shape' in x and 'max_shape' in x:
min_vec = []
max_vec = []
for i in p_value:
min_vec.append(x['min_shape'][i])
max_vec.append(x['max_shape'][i])
out['min_shape'] = tuple(min_vec)
out['max_shape'] = tuple(max_vec)
return out
class Unique(Primitive):
"""

View File

@ -19,6 +19,7 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# all cases tested against dchip
@ -45,6 +46,44 @@ def scatter_add_use_locking_false_net(inputx, indices, updates):
net = TestScatterAddNet(lock, inputx, indices, updates)
return net()
class TestScatterAddDynamicNet(nn.Cell):
    """Cell holding inputx/indices/updates as Parameters and applying ScatterAdd.

    ``inputx`` is passed through ``inner.GpuConvertToDynamicShape`` before it
    reaches ``P.ScatterAdd``.
    """

    def __init__(self, inputx, indices, updates):
        super().__init__()
        self.scatter_add = P.ScatterAdd()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        self.inputx = Parameter(inputx, name="inputx")
        self.indices = Parameter(indices, name="indices")
        self.updates = Parameter(updates, name="updates")

    def construct(self):
        # Convert first, then scatter into the converted tensor.
        dyn_x = self.test_dynamic(self.inputx)
        return self.scatter_add(dyn_x, self.indices, self.updates)
def scatter_add_d_net(inputx, indices, updates):
    """Build a TestScatterAddDynamicNet in GPU graph mode and return its output."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dynamic_net = TestScatterAddDynamicNet(inputx, indices, updates)
    result = dynamic_net()
    return result
class TestScatterAddDynamicNet2(nn.Cell):
    """Cell applying ScatterAdd to tensors supplied at call time.

    ``inputx`` is passed through ``inner.GpuConvertToDynamicShape`` before it
    reaches ``P.ScatterAdd``.
    """

    def __init__(self):
        super().__init__()
        self.scatter_add = P.ScatterAdd()
        self.test_dynamic = inner.GpuConvertToDynamicShape()

    def construct(self, inputx, indices, updates):
        # Convert first, then scatter into the converted tensor.
        dyn_x = self.test_dynamic(inputx)
        return self.scatter_add(dyn_x, indices, updates)
def scatter_add_d2_net(inputx_1, indices_1, updates_1, inputx_2,
                       indices_2, updates_2):
    """Run one TestScatterAddDynamicNet2 instance on two input sets, in order.

    Returns a tuple ``(first_output, second_output)``.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shared_net = TestScatterAddDynamicNet2()
    first = shared_net(inputx_1, indices_1, updates_1)
    second = shared_net(inputx_2, indices_2, updates_2)
    return first, second
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -196,3 +235,78 @@ def test_scatter_add_disordered_int32():
[187., 188., 189., 190.],
[492., 496., 500., 504.]]).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_disordered_dynamic_int32():
    """ScatterAdd via scatter_add_d_net: flipped int32 input, unordered repeated indices."""
    x_np = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
    idx_np = np.array([[[0, 1, 2],
                        [2, 1, 0]],
                       [[0, 0, 0],
                        [2, 2, 2]]]).astype(np.int32)
    upd_np = np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32)
    result = scatter_add_d_net(Tensor(x_np), Tensor(idx_np), Tensor(upd_np))
    expected = np.array([[464., 468., 472., 476.],
                         [187., 188., 189., 190.],
                         [492., 496., 500., 504.]]).astype(np.int32)
    np.testing.assert_array_almost_equal(result.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_input_less_than_1_dynamic_float32():
    """ScatterAdd via scatter_add_d_net: float32 input whose entries are all below 1."""
    x_np = np.array([[0.214141, 0.415151, 0.51516],
                     [0.876542, 0.451611, 0.55112],
                     [0.111244, 0.633333, 0.34444]]).astype(np.float32)
    idx_np = np.array([[[1, 0, 2],
                        [2, 2, 0]],
                       [[1, 0, 1],
                        [2, 1, 2]]]).astype(np.int32)
    upd_np = np.arange(34, 70).reshape((2, 2, 3, 3)).astype(np.float32)
    result = scatter_add_d_net(Tensor(x_np), Tensor(idx_np), Tensor(upd_np))
    expected = np.array([[141.21414, 144.41515, 147.51517],
                         [208.87654, 212.45161, 216.55112],
                         [257.11124, 262.63333, 267.34442]], dtype=np.float32)
    np.testing.assert_array_almost_equal(result.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_dynamic_two_inputs():
    """ScatterAdd via scatter_add_d2_net: one net run on a 2-D and then a 4-D input."""
    # First call: 2x3 zeros.
    first_x = Tensor(np.zeros((2, 3)).astype(np.float32))
    first_idx = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
    first_upd = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32))
    # Second call: 4x2x3x4 ones.
    second_x = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
    second_idx = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
    second_upd = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
    result_1, result_2 = scatter_add_d2_net(first_x, first_idx, first_upd,
                                            second_x, second_idx, second_upd)
    expected_1 = np.array([[6., 8., 10.],
                           [12., 14., 16.]])
    expected_2 = np.array([[[[1., 2., 3., 4.],
                             [5., 6., 7., 8.],
                             [9., 10., 11., 12.]],
                            [[13., 14., 15., 16.],
                             [17., 18., 19., 20.],
                             [21., 22., 23., 24.]]],
                           [[[73., 74., 75., 76.],
                             [77., 78., 79., 80.],
                             [81., 82., 83., 84.]],
                            [[85., 86., 87., 88.],
                             [89., 90., 91., 92.],
                             [93., 94., 95., 96.]]],
                           [[[25., 26., 27., 28.],
                             [29., 30., 31., 32.],
                             [33., 34., 35., 36.]],
                            [[37., 38., 39., 40.],
                             [41., 42., 43., 44.],
                             [45., 46., 47., 48.]]],
                           [[[49., 50., 51., 52.],
                             [53., 54., 55., 56.],
                             [57., 58., 59., 60.]],
                            [[61., 62., 63., 64.],
                             [65., 66., 67., 68.],
                             [69., 70., 71., 72.]]]])
    np.testing.assert_array_almost_equal(result_1.asnumpy(), expected_1)
    np.testing.assert_array_almost_equal(result_2.asnumpy(), expected_2)

View File

@ -19,6 +19,7 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# all cases tested against dchip
@ -39,6 +40,44 @@ def scatter_update_net(inputx, indices, updates):
net = TestScatterUpdateNet(inputx, indices, updates)
return net()
class TestScatterUpdateDynamicNet(nn.Cell):
    """Cell holding inputx/indices/updates as Parameters and applying ScatterUpdate.

    ``inputx`` is passed through ``inner.GpuConvertToDynamicShape`` before it
    reaches ``P.ScatterUpdate``.
    """

    def __init__(self, inputx, indices, updates):
        super().__init__()
        self.scatter_update = P.ScatterUpdate()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        self.inputx = Parameter(inputx, name="inputx")
        self.indices = Parameter(indices, name="indices")
        self.updates = Parameter(updates, name="updates")

    def construct(self):
        # Convert first, then update the converted tensor.
        dyn_x = self.test_dynamic(self.inputx)
        return self.scatter_update(dyn_x, self.indices, self.updates)
def scatter_update_d_net(inputx, indices, updates):
    """Build a TestScatterUpdateDynamicNet in GPU graph mode and return its output."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dynamic_net = TestScatterUpdateDynamicNet(inputx, indices, updates)
    result = dynamic_net()
    return result
class TestScatterUpdateDynamicNet2(nn.Cell):
    """Cell applying ScatterUpdate to tensors supplied at call time.

    ``inputx`` is passed through ``inner.GpuConvertToDynamicShape`` before it
    reaches ``P.ScatterUpdate``.
    """

    def __init__(self):
        super().__init__()
        self.scatter_update = P.ScatterUpdate()
        self.test_dynamic = inner.GpuConvertToDynamicShape()

    def construct(self, inputx, indices, updates):
        # Convert first, then update the converted tensor.
        dyn_x = self.test_dynamic(inputx)
        return self.scatter_update(dyn_x, indices, updates)
def scatter_update_d2_net(inputx_1, indices_1, updates_1, inputx_2,
                          indices_2, updates_2):
    """Run one TestScatterUpdateDynamicNet2 instance on two input sets, in order.

    Returns a tuple ``(first_output, second_output)``.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shared_net = TestScatterUpdateDynamicNet2()
    first = shared_net(inputx_1, indices_1, updates_1)
    second = shared_net(inputx_2, indices_2, updates_2)
    return first, second
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -237,3 +276,72 @@ def test_scatter_update_disordered_uint8():
[63., 64., 65., 66.],
[67., 68., 69., 70.]]).astype(np.uint8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_large_shape_dynamic_int8():
    """ScatterUpdate via scatter_update_d_net: 4-D int8 input, first two slices replaced."""
    x_np = np.arange(96).reshape((4, 2, 3, 4)).astype(np.int8)
    idx_np = np.array([1, 0]).astype(np.int32)
    upd_np = np.flip(np.arange(48).reshape((2, 2, 3, 4)).astype(np.int8))
    result = scatter_update_d_net(Tensor(x_np), Tensor(idx_np), Tensor(upd_np))
    expected = np.array([[[[23., 22., 21., 20.],
                           [19., 18., 17., 16.],
                           [15., 14., 13., 12.]],
                          [[11., 10., 9., 8.],
                           [7., 6., 5., 4.],
                           [3., 2., 1., 0.]]],
                         [[[47., 46., 45., 44.],
                           [43., 42., 41., 40.],
                           [39., 38., 37., 36.]],
                          [[35., 34., 33., 32.],
                           [31., 30., 29., 28.],
                           [27., 26., 25., 24.]]],
                         [[[48., 49., 50., 51.],
                           [52., 53., 54., 55.],
                           [56., 57., 58., 59.]],
                          [[60., 61., 62., 63.],
                           [64., 65., 66., 67.],
                           [68., 69., 70., 71.]]],
                         [[[72., 73., 74., 75.],
                           [76., 77., 78., 79.],
                           [80., 81., 82., 83.]],
                          [[84., 85., 86., 87.],
                           [88., 89., 90., 91.],
                           [92., 93., 94., 95.]]]]).astype(np.int8)
    np.testing.assert_array_almost_equal(result.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_disordered_dynamic_int32():
    """ScatterUpdate via scatter_update_d_net: flipped int32 input, rows 1 and 2 replaced."""
    x_np = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
    idx_np = np.array([1, 2]).astype(np.int32)
    upd_np = np.arange(63, 71).reshape((2, 4)).astype(np.int32)
    result = scatter_update_d_net(Tensor(x_np), Tensor(idx_np), Tensor(upd_np))
    expected = np.array([[45., 44., 43., 42.],
                         [63., 64., 65., 66.],
                         [67., 68., 69., 70.]]).astype(np.int32)
    np.testing.assert_array_almost_equal(result.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_two_inputs():
    """ScatterUpdate via scatter_update_d2_net: one net run on a 2x3 and then a 3x3 input."""
    # First call: 2x3 zeros.
    first_x = Tensor(np.zeros((2, 3)).astype(np.float32))
    first_idx = Tensor(np.array([0, 1]).astype(np.int32))
    first_upd = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32))
    # Second call: 3x3 input with all rows replaced in a shuffled order.
    second_x = Tensor(np.array([[0.214141, 0.415151, 0.51516],
                                [0.876542, 0.451611, 0.55112],
                                [0.111244, 0.633333, 0.34444]]).astype(np.float32))
    second_idx = Tensor(np.array([1, 0, 2]).astype(np.int32))
    second_upd = Tensor(np.arange(34, 43).reshape((3, 3)).astype(np.float32))
    result_1, result_2 = scatter_update_d2_net(first_x, first_idx, first_upd,
                                               second_x, second_idx, second_upd)
    expected_1 = np.array([[0., 1., 2.],
                           [3., 4., 5.]])
    expected_2 = np.array([[37., 38., 39.],
                           [34., 35., 36.],
                           [40., 41., 42.]], dtype=np.float32)
    np.testing.assert_array_almost_equal(result_1.asnumpy(), expected_1)
    np.testing.assert_array_almost_equal(result_2.asnumpy(), expected_2)

View File

@ -23,28 +23,24 @@ from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
context.set_context(device_target='GPU')
class Transpose(nn.Cell):
def __init__(self, nptype):
super(Transpose, self).__init__()
self.transpose = P.Transpose()
self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(nptype)), [5, 6]),
name='x_2D')
self.perm_2D = (1, 0)
self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(nptype)), [2, 2, 4]),
name='x_3D')
self.perm_3D = (1, 0, 2)
self.x_4D = Parameter(
initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(nptype)), [2, 3, 4, 5]),
name='x_4D')
self.perm_4D = (0, 1, 2, 3)
self.x_5D = Parameter(
initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(nptype)),
[1, 2, 3, 4, 5]), name='x_5D')
@ -55,11 +51,42 @@ class Transpose(nn.Cell):
return (self.transpose(self.x_2D, self.perm_2D), self.transpose(self.x_3D, self.perm_3D),
self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D))
class Transpose_dynamic(nn.Cell):
    """Cell that transposes a fixed 5-D Parameter with perm (1, 0, 3, 4, 2).

    The Parameter is passed through ``inner.GpuConvertToDynamicShape`` before
    it reaches ``P.Transpose``.
    """

    def __init__(self, nptype):
        super().__init__()
        self.transpose = P.Transpose()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        # 0..119 laid out as (1, 2, 3, 4, 5) in the requested dtype.
        data = np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(nptype)
        self.x = Parameter(initializer(Tensor(data), [1, 2, 3, 4, 5]), name='5D')
        self.perm = (1, 0, 3, 4, 2)

    @ms_function
    def construct(self):
        dyn_x = self.test_dynamic(self.x)
        return self.transpose(dyn_x, self.perm)
class Transpose_dynamic2(nn.Cell):
    """Cell that transposes two stored inputs, each with its own permutation.

    Each input is passed through ``inner.GpuConvertToDynamicShape`` before it
    reaches ``P.Transpose``; construct returns both results as a tuple.
    """

    def __init__(self, input_1, input_2, perm_1, perm_2):
        super().__init__()
        self.transpose = P.Transpose()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        self.x_1 = input_1
        self.x_2 = input_2
        self.perm_1 = perm_1
        self.perm_2 = perm_2

    @ms_function
    def construct(self):
        # First input, first permutation.
        dyn_1 = self.test_dynamic(self.x_1)
        res_1 = self.transpose(dyn_1, self.perm_1)
        # Second input, second permutation.
        dyn_2 = self.test_dynamic(self.x_2)
        res_2 = self.transpose(dyn_2, self.perm_2)
        return (res_1, res_2)
def transpose1(nptype):
transpose = Transpose(nptype)
output = transpose()
expect0 = np.array([[[0, 6, 12, 18, 24],
[1, 7, 13, 19, 25],
[2, 8, 14, 20, 26],
@ -82,7 +109,6 @@ def transpose1(nptype):
[45, 46, 47, 48, 49],
[50, 51, 52, 53, 54],
[55, 56, 57, 58, 59]]],
[[[60, 61, 62, 63, 64],
[65, 66, 67, 68, 69],
[70, 71, 72, 73, 74],
@ -115,7 +141,6 @@ def transpose1(nptype):
[17, 37, 57],
[18, 38, 58],
[19, 39, 59]]]],
[[[[60, 80, 100],
[61, 81, 101],
[62, 82, 102],
@ -141,6 +166,75 @@ def transpose1(nptype):
assert (output[2].asnumpy() == expect2).all()
assert (output[3].asnumpy() == expect3).all()
def transpose_d(nptype):
    """Run Transpose_dynamic in GPU graph mode and compare against NumPy's transpose.

    The reference is built from the same data and permutation that
    Transpose_dynamic hard-codes: arange(120) shaped (1, 2, 3, 4, 5) and
    perm (1, 0, 3, 4, 2). The shape is asserted separately because an
    element-wise ``==`` broadcasts and would otherwise let an output with the
    wrong number of dimensions pass.

    Args:
        nptype: NumPy dtype used for both the network input and the reference.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    transpose = Transpose_dynamic(nptype)
    output = transpose()
    source = np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(nptype)
    expect = source.transpose(1, 0, 3, 4, 2)
    result = output.asnumpy()
    assert result.shape == expect.shape
    assert (result == expect).all()
def transpose_d2(nptype):
    """Run Transpose_dynamic2 on a 2-D and a 3-D input and compare against NumPy.

    Each reference is the NumPy transpose of the same data under the same
    permutation. Shapes are asserted separately because an element-wise ``==``
    broadcasts and would otherwise let an output with the wrong number of
    dimensions pass.

    Args:
        nptype: NumPy dtype used for both network inputs and the references.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    data_1 = np.arange(5 * 6).reshape(5, 6).astype(nptype)
    data_2 = np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(nptype)
    input_1 = Parameter(Tensor(data_1), name="input_1")
    input_2 = Parameter(Tensor(data_2), name="input_2")
    perm_1 = (1, 0)
    perm_2 = (1, 0, 2)
    expect_1 = data_1.transpose(perm_1)
    expect_2 = data_2.transpose(perm_2)
    net = Transpose_dynamic2(input_1, input_2, perm_1, perm_2)
    output_1, output_2 = net()
    result_1 = output_1.asnumpy()
    result_2 = output_2.asnumpy()
    assert result_1.shape == expect_1.shape
    assert result_2.shape == expect_2.shape
    assert (result_1 == expect_1).all()
    assert (result_2 == expect_2).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -158,3 +252,39 @@ def test_transpose_float16():
@pytest.mark.env_onecard
def test_transpose_int32():
transpose1(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_float32():
    """Single-input dynamic transpose check with float32 data."""
    transpose_d(nptype=np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_float16():
    """Single-input dynamic transpose check with float16 data."""
    transpose_d(nptype=np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_int32():
    """Single-input dynamic transpose check with int32 data."""
    transpose_d(nptype=np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_two_inputs_float32():
    """Two-input dynamic transpose check with float32 data."""
    transpose_d2(nptype=np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_two_inputs_float16():
    """Two-input dynamic transpose check with float16 data."""
    transpose_d2(nptype=np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_two_inputs_int32():
    """Two-input dynamic transpose check with int32 data."""
    transpose_d2(nptype=np.int32)