!42028 fix MirrorPadGrad bad_function_call bug

Merge pull request !42028 from panshaowu/master
i-robot 2022-09-19 22:55:58 +00:00, committed by Gitee
commit d105842398
17 changed files with 444 additions and 132 deletions
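Editor's note on the bug named in the title: std::bad_function_call is the exception a default-constructed std::function throws when it is invoked while still empty. Judging from the MirrorPadGradGpuKernelMod::Init diff below, the fix adds the MatchKernelAttr lookup that assigns kernel_func_ from func_list_, so the launch functor is populated before the kernel runs. A minimal sketch of that failure mode follows, with hypothetical names rather than actual MindSpore code:

#include <functional>
#include <iostream>

// Hypothetical stand-in for a GPU kernel mod whose launch functor is normally assigned in Init().
struct KernelModSketch {
  std::function<bool(float *)> kernel_func_;  // stays empty until an Init() step assigns it

  bool Launch(float *out) {
    return kernel_func_(out);  // invoking an empty std::function throws std::bad_function_call
  }
};

int main() {
  KernelModSketch mod;  // the Init() step that would assign kernel_func_ was skipped
  float out = 0.0f;
  try {
    mod.Launch(&out);
  } catch (const std::bad_function_call &e) {
    std::cout << "caught: " << e.what() << std::endl;  // the crash class this MR addresses
  }
  return 0;
}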

View File

@ -73,6 +73,54 @@ bool GatherGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const st
return true;
}
void GatherGradGpuKernelMod::CalculateDim(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs) {
if (grad_shapes_.size() != index_shapes_.size() || grad_shapes_.size() != output_shapes_.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimension of grad, index and output must be the same, but got the dimension of "
<< "grad: " << grad_shapes_.size() << ", the dimension of index: " << index_shapes_.size()
<< ", the dimension of output: " << output_shapes_.size();
}
int dims = SizeToInt(grad_shapes_.size());
size_t input_num = inputs.size();
constexpr size_t kStaticSize = 2;
constexpr size_t kDynamicSize = 3;
if (input_num == kStaticSize) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::GatherDGrad>(base_operator);
axis_ = static_cast<int>(kernel_ptr->get_dim());
} else if (input_num == kDynamicSize) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::GatherDGradV2>(base_operator);
axis_ = static_cast<int>(kernel_ptr->get_dim());
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
}
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
<< "), but got " << axis_;
}
if (axis_ < 0) {
axis_ += dims;
}
int64_t dim_before_axis = 1;
for (size_t i = 0; i < IntToSize(axis_); i++) {
dim_before_axis *= output_shapes_[i];
}
size_t dim_at_axis_index = LongToSizeClipNeg(index_shapes_[IntToSize(axis_)]);
size_t dim_at_axis_output = LongToSizeClipNeg(output_shapes_[IntToSize(axis_)]);
int64_t dim_after_axis = 1;
for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
dim_after_axis *= output_shapes_[i];
}
dims_[kIndex0] = LongToSize(dim_before_axis);
dims_[kIndex1] = dim_at_axis_index;
dims_[kIndex2] = dim_at_axis_output;
dims_[kIndex3] = LongToSize(dim_after_axis);
}
int GatherGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
@ -106,51 +154,8 @@ int GatherGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
}
if (ret == KRET_OK) {
if (grad_shapes_.size() != index_shapes_.size() || grad_shapes_.size() != output_shapes_.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimension of grad, index and output must be the same, but got the dimension of "
<< "grad: " << grad_shapes_.size() << ", the dimension of index: " << index_shapes_.size()
<< ", the dimension of output: " << output_shapes_.size();
}
int dims = SizeToInt(grad_shapes_.size());
MS_EXCEPTION_IF_NULL(base_operator);
size_t input_num = inputs.size();
constexpr size_t kStaticSize = 2;
constexpr size_t kDynamicSize = 3;
if (input_num == kStaticSize) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::GatherDGrad>(base_operator);
axis_ = static_cast<int>(kernel_ptr->get_dim());
} else if (input_num == kDynamicSize) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::GatherDGradV2>(base_operator);
axis_ = static_cast<int>(kernel_ptr->get_dim());
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
}
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
<< "), but got " << axis_;
}
if (axis_ < 0) {
axis_ += dims;
}
int64_t dim_before_axis = 1;
for (size_t i = 0; i < IntToSize(axis_); i++) {
dim_before_axis *= output_shapes_[i];
}
size_t dim_at_axis_index = LongToSizeClipNeg(index_shapes_[IntToSize(axis_)]);
size_t dim_at_axis_output = LongToSizeClipNeg(output_shapes_[IntToSize(axis_)]);
int64_t dim_after_axis = 1;
for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
dim_after_axis *= output_shapes_[i];
}
dims_[kIndex0] = LongToSize(dim_before_axis);
dims_[kIndex1] = dim_at_axis_index;
dims_[kIndex2] = dim_at_axis_output;
dims_[kIndex3] = LongToSize(dim_after_axis);
CalculateDim(base_operator, inputs);
}
return static_cast<int>(ret);

View File

@ -52,6 +52,8 @@ class GatherGradGpuKernelMod : public NativeGpuKernelMod {
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs, void *stream_ptr);
private:
void CalculateDim(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs);
using GatherGradOpFunc = std::function<bool(GatherGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, void *)>;
static std::map<std::string, std::vector<std::pair<KernelAttr, GatherGradGpuKernelMod::GatherGradOpFunc>>>

View File

@ -73,9 +73,40 @@ bool MirrorPadGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
}
input_type_size_ = abstract::TypeIdSize(inputs.at(kIndex0)->GetDtype());
padding_type_size_ = abstract::TypeIdSize(inputs.at(kIndex1)->GetDtype());
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
void MirrorPadGradGpuKernelMod::CalculateWorkspace(const ShapeVector &input_shape,
const std::vector<size_t> &output_shape) {
workspace_size_ = input_type_size_;
for (int i = 0; i < SizeToInt(kOutputDimLowerLimit); i++) {
workspace_size_ *= output_shape[i]; // BATCH, CHANNEL -> Output size
workspace_size_ *= input_shape[i + kOutputDimLowerLimit]; // WIDTH, HEIGHT -> Input Size
}
workspace_size_list_.push_back(workspace_size_);
int max_width = input_shape_[kIndexForMaxWidth];
// basic error check for padding value
if (mode_ == 1) { // symmetric
max_width = max_width + (kSymmetricCoef * max_width);
} else { // reflect
max_width = max_width + (kSymmetricCoef * (max_width - 1));
}
if (output_shape_[(output_shape_.size() - kMaxIndexOffset) + 0] > max_width ||
output_shape_[(output_shape_.size() - kMaxIndexOffset) + 1] > max_width) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the output.shape[-1] and output.shape[-2] cannot be greater "
<< "than input_x.shape[-1], but got output.shape: " << CONVERT_VECTOR_TO_STRING(output_shape_)
<< ", input_x.shape: " << CONVERT_VECTOR_TO_STRING(input_shape_);
}
}
int MirrorPadGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
@ -149,26 +180,7 @@ int MirrorPadGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
// calc workspace size
// store dy values with accumulation across batch and channel only
if (ret == KRET_OK) {
workspace_size_ = input_type_size_;
for (int i = 0; i < SizeToInt(kOutputDimLowerLimit); i++) {
workspace_size_ *= output_shape[i]; // BATCH, CHANNEL -> Output size
workspace_size_ *= input_shape[i + kOutputDimLowerLimit]; // WIDTH, HEIGHT -> Input Size
}
workspace_size_list_.push_back(workspace_size_);
int max_width = input_shape_[kIndexForMaxWidth];
// basic error check for padding value
if (mode_ == 1) { // symmetric
max_width = max_width + (kSymmetricCoef * max_width);
} else { // reflect
max_width = max_width + (kSymmetricCoef * (max_width - 1));
}
if (output_shape_[(output_shape_.size() - kMaxIndexOffset) + 0] > max_width ||
output_shape_[(output_shape_.size() - kMaxIndexOffset) + 1] > max_width) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the output.shape[-1] and output.shape[-2] cannot be greater "
<< "than input_x.shape[-1], but got output.shape: " << CONVERT_VECTOR_TO_STRING(output_shape_)
<< ", input_x.shape: " << CONVERT_VECTOR_TO_STRING(input_shape_);
}
CalculateWorkspace(input_shape, output_shape);
}
return static_cast<int>(ret);
}

View File

@ -57,6 +57,8 @@ class MirrorPadGradGpuKernelMod : public NativeGpuKernelMod {
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
private:
void CalculateWorkspace(const ShapeVector &input_shape, const std::vector<size_t> &output_shape);
MirrorPadGradLaunchFunc kernel_func_;
static std::vector<std::pair<KernelAttr, MirrorPadGradLaunchFunc>> func_list_;

View File

@ -32,21 +32,15 @@ class FloatStatusInfer : public abstract::OpInferBase {
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1L, primitive->name());
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
ShapeVector shape = {1};
return std::make_shared<abstract::Shape>(shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim->name());
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1L, prim->name());
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), {kFloat16, kFloat32, kFloat64},
prim->name());
return std::make_shared<TensorType>(kFloat32);

View File

@ -26,23 +26,25 @@ namespace mindspore {
namespace ops {
namespace {
constexpr int64_t kPaddingsSecondDim = 2;
constexpr int64_t kMaxPaddings = 5;
constexpr size_t kMaxPaddings = 5;
void verify_padding_range(const std::string &mode, int64_t out_size, std::pair<int64_t, int64_t> padding_attr,
void verify_padding_range(const std::string &mode, int64_t out_size, const std::pair<int64_t, int64_t> padding_attr,
const std::string &prim_name) {
if (padding_attr.first < 0 || padding_attr.second < 0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', all elements of paddings must be >= 0.";
}
if (mode == "SYMMETRIC") {
if (padding_attr.first > out_size || padding_attr.second > out_size)
if (padding_attr.first > out_size || padding_attr.second > out_size) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings must be no greater than the dimension size: ["
<< padding_attr.first << "], [" << padding_attr.second << "] greater than [" << out_size
<< "]";
}
} else if (mode == "REFLECT") {
if (padding_attr.first >= out_size || padding_attr.second >= out_size)
if (padding_attr.first >= out_size || padding_attr.second >= out_size) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', paddings must be no greater than the dimension size: ["
<< padding_attr.first << "], [" << padding_attr.second << "] not less than [" << out_size
<< "]";
}
}
}
@ -85,15 +87,15 @@ abstract::ShapePtr MirrorPadGradInferShape(const PrimitivePtr &primitive,
paddings_attr.push_back(std::make_pair(paddings_arg[i], paddings_arg[i + 1]));
}
(void)CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, x_shape.size(), prim_name);
int64_t size = x_shape.size();
if (size < 0 || size > kMaxPaddings) {
size_t size = x_shape.size();
if (size > kMaxPaddings) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', the dimension of input only supports less than or equal to 5 dims, but got " << size
<< " dims";
}
std::string mode = GetValue<std::string>(primitive->GetAttr(kMode));
for (int64_t i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++) {
int64_t out_size = x_shape[i] - (paddings_attr[i].first + paddings_attr[i].second);
verify_padding_range(mode, out_size, paddings_attr[i], prim_name);
}

View File

@ -27,21 +27,14 @@ class IsFiniteInfer : public abstract::OpInferBase {
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1L, primitive->name());
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1L, primitive->name());
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
CheckAndConvertUtils::CheckTensorTypeValid(
"x", input_args[0]->BuildType(),

View File

@ -32,21 +32,15 @@ class IsInfInfer : public abstract::OpInferBase {
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1L, primitive->name());
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim->name());
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1L, prim->name());
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), {kFloat16, kFloat32, kFloat64},
prim->name());
return std::make_shared<TensorType>(kBool);

View File

@ -31,19 +31,15 @@ class IsNanInfer : public abstract::OpInferBase {
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1L, primitive->name());
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1L, primitive->name());
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
(void)CheckAndConvertUtils::CheckTensorTypeValid(
"x", input_args[0]->BuildType(),
{kBool, kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64, kUInt8, kUInt16, kUInt32, kUInt64},

View File

@ -16,6 +16,7 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@ -35,6 +36,7 @@ class NetGatherD(nn.Cell):
def construct(self, x, index):
return self.gatherd(x, self.dim, index)
class NetGatherDGrad(nn.Cell):
def __init__(self, network):
super(NetGatherDGrad, self).__init__()
@ -119,3 +121,28 @@ def test_gatherd_grad_checkresult():
expect = np.array([[[89.99606, -145.67], [0., 119.84]], [[138.56479, -8.696029], [0., -23.369316]]], np.float32)
error = np.ones(shape=expect.shape) * 1.0e-6
assert np.all(np.abs(output.asnumpy() - expect) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gatherd_grad_dynamic_shape():
"""
Feature: dynamic shape support of GatherDGrad.
Description: input Tensor with dynamic shape.
Expectation: output shape coincides with the expected shape.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
x_dyn = Tensor(shape=[2, None], dtype=ms.float16)
x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), dtype=ms.float16)
dim = 0
index_dyn = Tensor(shape=[None, 5], dtype=ms.int64)
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), dtype=ms.int64)
grad_dyn = Tensor(shape=[2, None], dtype=ms.float16)
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), dtype=ms.float16)
except_shape = (2, 5)
grad_net = NetGatherDGrad(NetGatherD(dim))
grad_net.set_inputs(x_dyn, index_dyn, grad_dyn)
output = grad_net(x, index, grad)
assert output[0].asnumpy().shape == except_shape

View File

@ -16,6 +16,7 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@ -98,3 +99,22 @@ def test_net():
out = net(x11).asnumpy()
expect = [True, True, True]
assert np.all(out == expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_finite_cpu_dynamic_shape():
"""
Feature: test IsFinite op on CPU.
Description: test the ops in dynamic shape.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net()
x_dyn = Tensor(shape=[1, 32, 9, None], dtype=ms.float32)
net.set_inputs(x_dyn)
x = np.random.randn(1, 32, 9, 9)
output = net(Tensor(x, ms.float32))
except_shape = (1, 32, 9, 9)
assert output.asnumpy().shape == except_shape

View File

@ -0,0 +1,76 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class NetIsInf(nn.Cell):
def __init__(self):
super(NetIsInf, self).__init__()
self.isinf = P.IsInf()
def construct(self, x):
return self.isinf(x)
x1 = Tensor(np.array([3, np.log(0), 1, np.log(0)]), ms.float32)
x2 = Tensor(np.array([np.log(0), 1, np.log(0), 3]), ms.float32)
x3 = Tensor(np.array([[np.log(0), 2], [np.log(0), np.log(0)]]), ms.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_inf():
ms_isinf = NetIsInf()
output1 = ms_isinf(Tensor(x1))
expect1 = [[False, True, False, True]]
assert (output1.asnumpy() == expect1).all()
output2 = ms_isinf(Tensor(x2))
expect2 = [[True, False, True, False]]
assert (output2.asnumpy() == expect2).all()
output3 = ms_isinf(Tensor(x3))
expect3 = [[True, False], [True, True]]
assert (output3.asnumpy() == expect3).all()
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_inf_cpu_dynamic_shape():
"""
Feature: test IsInf op on CPU.
Description: test the ops in dynamic shape.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = NetIsInf()
x_dyn = Tensor(shape=[1, 32, 9, None], dtype=ms.float32)
net.set_inputs(x_dyn)
x = np.random.randn(1, 32, 9, 9)
output = net(Tensor(x, ms.float32))
except_shape = (1, 32, 9, 9)
assert output.asnumpy().shape == except_shape

View File

@ -16,6 +16,7 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@ -24,9 +25,9 @@ from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class Netnan(nn.Cell):
class NetIsNan(nn.Cell):
def __init__(self):
super(Netnan, self).__init__()
super(NetIsNan, self).__init__()
self.isnan = P.IsNan()
def construct(self, x):
@ -42,7 +43,7 @@ x3 = np.array([[1, 2], [3, 4], [5.0, 88.0]]).astype(np.float32)
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_nan():
ms_isnan = Netnan()
ms_isnan = NetIsNan()
output1 = ms_isnan(Tensor(x1))
expect1 = [[False, False, True, False]]
assert (output1.asnumpy() == expect1).all()
@ -54,3 +55,22 @@ def test_nan():
output3 = ms_isnan(Tensor(x3))
expect3 = [[False, False], [False, False], [False, False]]
assert (output3.asnumpy() == expect3).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_is_nan_cpu_dynamic_shape():
"""
Feature: test IsNan op on CPU.
Description: test the ops in dynamic shape.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = NetIsNan()
x_dyn = Tensor(shape=[1, 32, 9, None], dtype=ms.float32)
net.set_inputs(x_dyn)
x = np.random.randn(1, 32, 9, 9)
output = net(Tensor(x, ms.float32))
except_shape = (1, 32, 9, 9)
assert output.asnumpy().shape == except_shape

View File

@ -23,6 +23,7 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.composite import GradOperation
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@ -39,14 +40,14 @@ def test_mirror_pad():
test2_arr_exp = [[[[2, 1, 1, 2, 3, 3, 2], [2, 1, 1, 2, 3, 3, 2], [5, 4, 4, 5, 6, 6, 5],
[8, 7, 7, 8, 9, 9, 8], [8, 7, 7, 8, 9, 9, 8]]]]
reflectOp = nn.Pad(mode='REFLECT', paddings=test_1_paddings)
symmOp = nn.Pad(mode='SYMMETRIC', paddings=test_2_paddings)
reflect_op = nn.Pad(mode='REFLECT', paddings=test_1_paddings)
symm_op = nn.Pad(mode='SYMMETRIC', paddings=test_2_paddings)
x_test_1 = Tensor(np.array(test1_arr_in), dtype=mindspore.float32)
x_test_2 = Tensor(np.array(test2_arr_in), dtype=mindspore.float32)
y_test_1 = reflectOp(x_test_1).asnumpy()
y_test_2 = symmOp(x_test_2).asnumpy()
y_test_1 = reflect_op(x_test_1).asnumpy()
y_test_2 = symm_op(x_test_2).asnumpy()
print(np.array(test1_arr_in))
print(y_test_1)
@ -60,13 +61,16 @@ class Grad(nn.Cell):
super(Grad, self).__init__()
self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network
def construct(self, input_, output_grad):
return self.grad(self.network)(input_, output_grad)
class Net(nn.Cell):
def __init__(self, pads, mode_):
super(Net, self).__init__()
self.pad = nn.Pad(mode=mode_, paddings=pads)
def construct(self, x):
return self.pad(x)
@ -87,6 +91,7 @@ def test_mirror_pad_backprop():
dx = dx[0].asnumpy()
np.testing.assert_array_almost_equal(dx, expected_dx)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@ -104,9 +109,9 @@ def test_mirror_pad_fwd_back_4d_int32_reflect():
obtained_ms_res = op(test_arr_ms).asnumpy()
np.testing.assert_array_equal(expected_np_result, obtained_ms_res)
# backwards pass check
GradNet = Grad(Net(pads, "REFLECT"))
grad_net = Grad(Net(pads, "REFLECT"))
dy_value = Tensor(np.ones(obtained_ms_res.shape), dtype=mindspore.int32)
dx_value_obtained = GradNet(test_arr_ms, dy_value)[0].asnumpy()
dx_value_obtained = grad_net(test_arr_ms, dy_value)[0].asnumpy()
dx_value_expected = np.array([[[[4, 6, 6, 6, 2],
[6, 9, 9, 9, 3],
[2, 3, 3, 3, 1]],
@ -145,9 +150,9 @@ def test_mirror_pad_fwd_back_4d_int32_symm():
obtained_ms_res = op(test_arr_ms).asnumpy()
np.testing.assert_array_equal(expected_np_result, obtained_ms_res)
# backwards pass check
GradNet = Grad(Net(pads, "SYMMETRIC"))
grad_net = Grad(Net(pads, "SYMMETRIC"))
dy_value = Tensor(np.ones(obtained_ms_res.shape), dtype=mindspore.int32)
dx_value_obtained = GradNet(test_arr_ms, dy_value)[0].asnumpy()
dx_value_obtained = grad_net(test_arr_ms, dy_value)[0].asnumpy()
dx_value_expected = np.array([[[[16, 24, 24, 16, 16],
[16, 24, 24, 16, 16],
[16, 24, 24, 16, 16]],
@ -167,3 +172,25 @@ def test_mirror_pad_fwd_back_4d_int32_symm():
[4, 6, 6, 4, 4],
[4, 6, 6, 4, 4]]]], dtype=np.int32)
np.testing.assert_array_equal(dx_value_expected, dx_value_obtained)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mirror_pad_grad_dynamic_shape():
"""
Feature: dynamic shape support of MirrorPadGrad.
Description: input Tensor with dynamic shape.
Expectation: output shape coincides with the expected shape.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
test_arr_in_dyn = Tensor(shape=[1, 1, None, 3], dtype=mindspore.float32)
test_arr_in = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]
test_arr_in = Tensor(test_arr_in, dtype=mindspore.float32)
dy_dyn = Tensor(shape=[1, None, None, 5], dtype=mindspore.float32)
dy = (np.ones((1, 1, 4, 5)) * 0.1).astype(np.float32)
expected_shape = (1, 1, 3, 3)
net = Grad(Net(((0, 0), (0, 0), (1, 0), (0, 2)), "REFLECT"))
net.set_inputs(test_arr_in_dyn, dy_dyn)
dx = net(test_arr_in, Tensor(dy))
assert dx[0].asnumpy().shape == expected_shape

View File

@ -16,42 +16,43 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class Net(nn.Cell):
class NetFloatStatus(nn.Cell):
def __init__(self):
super(Net, self).__init__()
super(NetFloatStatus, self).__init__()
self.status = P.FloatStatus()
def construct(self, x):
return self.status(x)
class Netnan(nn.Cell):
class NetIsNan(nn.Cell):
def __init__(self):
super(Netnan, self).__init__()
super(NetIsNan, self).__init__()
self.isnan = P.IsNan()
def construct(self, x):
return self.isnan(x)
class Netinf(nn.Cell):
class NetIsInf(nn.Cell):
def __init__(self):
super(Netinf, self).__init__()
super(NetIsInf, self).__init__()
self.isinf = P.IsInf()
def construct(self, x):
return self.isinf(x)
class Netfinite(nn.Cell):
class NetIsFinite(nn.Cell):
def __init__(self):
super(Netfinite, self).__init__()
super(NetIsFinite, self).__init__()
self.isfinite = P.IsFinite()
def construct(self, x):
@ -74,7 +75,7 @@ def test_status(dtype):
Description: test cases for FloatStatus
Expectation: the result match to expectation
"""
ms_status = Net()
ms_status = NetFloatStatus()
output1 = ms_status(Tensor(x1.astype(dtype)))
expect1 = 1
assert output1.asnumpy()[0] == expect1
@ -98,7 +99,7 @@ def test_nan(dtype):
Description: test cases for IsNan
Expectation: the result match to expectation
"""
ms_isnan = Netnan()
ms_isnan = NetIsNan()
output1 = ms_isnan(Tensor(x1.astype(dtype)))
expect1 = [[False, False, True, False]]
assert (output1.asnumpy() == expect1).all()
@ -122,7 +123,7 @@ def test_inf(dtype):
Description: test cases for IsInf
Expectation: the result match to expectation
"""
ms_isinf = Netinf()
ms_isinf = NetIsInf()
output1 = ms_isinf(Tensor(x1.astype(dtype)))
expect1 = [[False, False, False, False]]
assert (output1.asnumpy() == expect1).all()
@ -146,7 +147,7 @@ def test_finite(dtype):
Description: test cases for Netfinite
Expectation: the result match to expectation
"""
ms_isfinite = Netfinite()
ms_isfinite = NetIsFinite()
output1 = ms_isfinite(Tensor(x1.astype(dtype)))
expect1 = [[True, True, False, True]]
assert (output1.asnumpy() == expect1).all()
@ -158,3 +159,79 @@ def test_finite(dtype):
output3 = ms_isfinite(Tensor(x3.astype(dtype)))
expect3 = [[True, True], [True, True], [True, True]]
assert (output3.asnumpy() == expect3).all()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_float_status_gpu_dynamic_shape():
"""
Feature: test FloatStatus op on GPU.
Description: test the ops in dynamic shape.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = NetFloatStatus()
x_dyn = Tensor(shape=[1, 32, 9, None], dtype=ms.float32)
net.set_inputs(x_dyn)
x = np.random.randn(1, 32, 9, 9)
output = net(Tensor(x, ms.float32))
except_shape = (1,)
assert output.asnumpy().shape == except_shape
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_is_nan_gpu_dynamic_shape():
"""
Feature: test IsNan op on GPU.
Description: test the ops in dynamic shape.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = NetIsNan()
x_dyn = Tensor(shape=[1, 32, 9, None], dtype=ms.float32)
net.set_inputs(x_dyn)
x = np.random.randn(1, 32, 9, 9)
output = net(Tensor(x, ms.float32))
except_shape = (1, 32, 9, 9)
assert output.asnumpy().shape == except_shape
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_is_inf_gpu_dynamic_shape():
"""
Feature: test IsInf op on GPU.
Description: test the ops in dynamic shape.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = NetIsInf()
x_dyn = Tensor(shape=[1, 32, 9, None], dtype=ms.float32)
net.set_inputs(x_dyn)
x = np.random.randn(1, 32, 9, 9)
output = net(Tensor(x, ms.float32))
except_shape = (1, 32, 9, 9)
assert output.asnumpy().shape == except_shape
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_is_finite_gpu_dynamic_shape():
"""
Feature: test IsFinite op on GPU.
Description: test the ops in dynamic shape.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = NetIsFinite()
x_dyn = Tensor(shape=[1, 32, 9, None], dtype=ms.float32)
net.set_inputs(x_dyn)
x = np.random.randn(1, 32, 9, 9)
output = net(Tensor(x, ms.float32))
except_shape = (1, 32, 9, 9)
assert output.asnumpy().shape == except_shape

View File

@ -33,6 +33,17 @@ class GatherDNet(nn.Cell):
def construct(self, x, index):
return self.gather_d(x, self.dim, index)
class GatherDGradNet(nn.Cell):
def __init__(self, net):
super(GatherDGradNet, self).__init__()
self.net = net
self.grad = GradOperation(get_all=True, sens_param=True)(self.net)
def construct(self, x, index, grad):
return self.grad(x, index, grad)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -52,6 +63,7 @@ def test_gather_grad_graph_int32_fp32():
diff = output[0].asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -71,6 +83,7 @@ def test_gather_grad_graph_int64_fp32():
diff = output[0].asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -90,6 +103,7 @@ def test_gather_grad_graph_int32_fp16():
diff = output[0].asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -109,6 +123,7 @@ def test_gather_grad_graph_int64_fp16():
diff = output[0].asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -126,6 +141,7 @@ def test_gather_grad_pynative_int32_fp32():
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -143,6 +159,7 @@ def test_gather_grad_pynative_int64_fp32():
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -160,6 +177,7 @@ def test_gather_grad_pynative_int32_fp16():
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -176,3 +194,28 @@ def test_gather_grad_pynative_int64_fp16():
error = 1e-4
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gatherd_grad_dynamic_shape():
"""
Feature: dynamic shape support of GatherDGrad.
Description: input Tensor with dynamic shape.
Expectation: output shape coincides with the expected shape.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
x_dyn = Tensor(shape=[2, None], dtype=ms.float16)
x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), dtype=ms.float16)
dim = 0
index_dyn = Tensor(shape=[None, 5], dtype=ms.int64)
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), dtype=ms.int64)
grad_dyn = Tensor(shape=[2, None], dtype=ms.float16)
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), dtype=ms.float16)
except_shape = (2, 5)
grad_net = GatherDGradNet(GatherDNet(dim))
grad_net.set_inputs(x_dyn, index_dyn, grad_dyn)
output = grad_net(x, index, grad)
assert output[0].asnumpy().shape == except_shape

View File

@ -211,3 +211,25 @@ def test_mirror_pad_dynamic():
np.testing.assert_equal(np.array(test1_arr_exp), y_test_1)
np.testing.assert_equal(np.array(test2_arr_exp), y_test_2)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mirror_pad_grad_dynamic_shape():
"""
Feature: dynamic shape support of MirrorPadGrad.
Description: input Tensor with dynamic shape.
Expectation: output shape coincides with the expected shape.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_arr_in_dyn = Tensor(shape=[1, 1, None, 3], dtype=mindspore.float32)
test_arr_in = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]
test_arr_in = Tensor(test_arr_in, dtype=mindspore.float32)
dy_dyn = Tensor(shape=[1, None, None, 5], dtype=mindspore.float32)
dy = (np.ones((1, 1, 4, 5)) * 0.1).astype(np.float32)
expected_shape = (1, 1, 3, 3)
net = Grad(Net(((0, 0), (0, 0), (1, 0), (0, 2)), "REFLECT"))
net.set_inputs(test_arr_in_dyn, dy_dyn)
dx = net(test_arr_in, Tensor(dy))
assert dx[0].asnumpy().shape == expected_shape