complex test

This commit is contained in:
muxiyin 2022-08-25 10:47:33 +08:00
parent 82c0c2477e
commit d2dfc66113
2 changed files with 152 additions and 12 deletions

View File

@ -93,18 +93,6 @@ class MedianGradGpuKernelMod : public NativeGpuKernelMod {
input_shape_ = inputs[1]->GetShapeVector();
input1_dim_ = input_shape_.size();
std::vector<int64_t> input0_shape = inputs[0]->GetShapeVector();
if (axis_ < -input1_dim_ || axis_ >= input1_dim_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << input1_dim_ << ","
<< input1_dim_ << "), but got " << axis_;
}
if (axis_ < 0) {
if (input1_dim_ == 0) {
axis_ = 0;
} else {
axis_ += input1_dim_;
}
}
input1_size_ = 1;
input0_size_ = 1;
for (size_t i = 0; i < input_shape_.size(); i++) {
@ -152,6 +140,23 @@ class MedianGradGpuKernelMod : public NativeGpuKernelMod {
input1_dim_ = input_shape_.size();
input1_size_ = 1;
input0_size_ = 1;
if (input1_dim_ == 0) {
if (axis_ < -1 || axis_ > 0) {
MS_LOG(EXCEPTION) << "For 'MedianGrad'"
<< "', the 'axis' must be in the range [-1,1), but got " << axis_;
}
} else if (axis_ < -input1_dim_ || axis_ >= input1_dim_) {
MS_LOG(EXCEPTION) << "For 'MedianGrad'"
<< "', the 'axis' must be in the range [-" << input1_dim_ << "," << input1_dim_ << "), but got "
<< axis_;
}
if (axis_ < 0) {
if (input1_dim_ == 0) {
axis_ = 0;
} else {
axis_ += input1_dim_;
}
}
for (size_t i = 0; i < input_shape_.size(); i++) {
input1_size_ *= input_shape_[i];
}

View File

@ -0,0 +1,135 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.nn import Cell
from mindspore import ops
from mindspore.ops import Complex as ComplexOp
shape_2d = (7, 6)
shape_1d = (6,)
rand_shape_1d_1 = np.random.rand(*shape_1d).astype(np.float32)
rand_shape_1d_2 = np.random.rand(*shape_1d).astype(np.float32)
rand_shape_2d_1 = np.random.rand(*shape_2d).astype(np.float32)
rand_shape_2d_2 = np.random.rand(*shape_2d).astype(np.float32)
real_op = ops.Real()
imag_op = ops.Imag()
class Complex(Cell):
def __init__(self):
super().__init__()
self.complex = ComplexOp()
def construct(self, x, y):
return self.complex(x, y)
def complex_compare(complex1, complex2):
real1 = real_op(Tensor(complex1)).asnumpy()
real2 = np.real(complex2)
imag1 = imag_op(Tensor(complex1)).asnumpy()
imag2 = np.imag(complex2)
return np.allclose(real1, real2, rtol=5e-03, atol=5e-03) and np.allclose(imag1, imag2, rtol=5e-03, atol=5e-03)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_complex_elemwise():
"""
Feature: complex basic Operation.
Description: Test complex basic Operation.
Expectation: the result match given one.
"""
real_ms = Tensor(rand_shape_2d_1)
imag_ms = Tensor(rand_shape_2d_2)
real_to = rand_shape_2d_1
imag_to = rand_shape_2d_2
complex1 = Complex()(real_ms, imag_ms)
complex2 = Complex()(imag_ms, real_ms)
complex_1 = np.vectorize(complex)(real_to, imag_to)
complex_2 = np.vectorize(complex)(imag_to, real_to)
assert complex_compare(complex1, complex_1)
res_ms = ops.Add()(complex1, complex2)
res_to = np.add(complex_1, complex_2)
assert complex_compare(res_ms, res_to)
res_ms = ops.Mul()(complex1, complex2)
res_to = np.multiply(complex_1, complex_2)
assert complex_compare(res_ms, res_to)
res_ms = ops.Sub()(complex1, complex2)
res_to = np.subtract(complex_1, complex_2)
assert complex_compare(res_ms, res_to)
res_ms = ops.Div()(complex1, complex2)
res_to = np.divide(complex_1, complex_2)
assert complex_compare(res_ms, res_to)
res_ms = complex1 / complex2
res_to = np.divide(complex_1, complex_2)
assert complex_compare(res_ms, res_to)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_complex_broadcast():
"""
Feature: complex broadcast Operation.
Description: Test complex broadcast Operation.
Expectation: the result match given one.
"""
real_ms_1 = Tensor(rand_shape_2d_1)
imag_ms_1 = Tensor(rand_shape_2d_2)
real_ms_2 = Tensor(rand_shape_1d_1)
imag_ms_2 = Tensor(rand_shape_1d_2)
real_to_1 = rand_shape_2d_1
imag_to_1 = rand_shape_2d_2
real_to_2 = rand_shape_1d_1
imag_to_2 = rand_shape_1d_2
complex1 = Complex()(real_ms_1, imag_ms_1)
complex2 = Complex()(real_ms_2, imag_ms_2)
complex_1 = np.vectorize(complex)(real_to_1, imag_to_1)
complex_2 = np.vectorize(complex)(real_to_2, imag_to_2)
assert complex_compare(complex1, complex_1)
res_ms = ops.Add()(complex1, complex2)
res_to = np.add(complex_1, complex_2)
assert complex_compare(res_ms, res_to)
res_ms = ops.Mul()(complex1, complex2)
res_to = np.multiply(complex_1, complex_2)
assert complex_compare(res_ms, res_to)
res_ms = ops.Sub()(complex1, complex2)
res_to = np.subtract(complex_1, complex_2)
assert complex_compare(res_ms, res_to)
res_ms = ops.Div()(complex1, complex2)
res_to = np.divide(complex_1, complex_2)
assert complex_compare(res_ms, res_to)
res_ms = complex1 / complex2
res_to = np.divide(complex_1, complex_2)
assert complex_compare(res_ms, res_to)