forked from mindspore-Ecosystem/mindspore
complex test
This commit is contained in:
parent
82c0c2477e
commit
d2dfc66113
|
@ -93,18 +93,6 @@ class MedianGradGpuKernelMod : public NativeGpuKernelMod {
|
|||
input_shape_ = inputs[1]->GetShapeVector();
|
||||
input1_dim_ = input_shape_.size();
|
||||
std::vector<int64_t> input0_shape = inputs[0]->GetShapeVector();
|
||||
|
||||
if (axis_ < -input1_dim_ || axis_ >= input1_dim_) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << input1_dim_ << ","
|
||||
<< input1_dim_ << "), but got " << axis_;
|
||||
}
|
||||
if (axis_ < 0) {
|
||||
if (input1_dim_ == 0) {
|
||||
axis_ = 0;
|
||||
} else {
|
||||
axis_ += input1_dim_;
|
||||
}
|
||||
}
|
||||
input1_size_ = 1;
|
||||
input0_size_ = 1;
|
||||
for (size_t i = 0; i < input_shape_.size(); i++) {
|
||||
|
@ -152,6 +140,23 @@ class MedianGradGpuKernelMod : public NativeGpuKernelMod {
|
|||
input1_dim_ = input_shape_.size();
|
||||
input1_size_ = 1;
|
||||
input0_size_ = 1;
|
||||
if (input1_dim_ == 0) {
|
||||
if (axis_ < -1 || axis_ > 0) {
|
||||
MS_LOG(EXCEPTION) << "For 'MedianGrad'"
|
||||
<< "', the 'axis' must be in the range [-1,1), but got " << axis_;
|
||||
}
|
||||
} else if (axis_ < -input1_dim_ || axis_ >= input1_dim_) {
|
||||
MS_LOG(EXCEPTION) << "For 'MedianGrad'"
|
||||
<< "', the 'axis' must be in the range [-" << input1_dim_ << "," << input1_dim_ << "), but got "
|
||||
<< axis_;
|
||||
}
|
||||
if (axis_ < 0) {
|
||||
if (input1_dim_ == 0) {
|
||||
axis_ = 0;
|
||||
} else {
|
||||
axis_ += input1_dim_;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < input_shape_.size(); i++) {
|
||||
input1_size_ *= input_shape_[i];
|
||||
}
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
from mindspore import ops
|
||||
from mindspore.ops import Complex as ComplexOp
|
||||
|
||||
shape_2d = (7, 6)
|
||||
shape_1d = (6,)
|
||||
rand_shape_1d_1 = np.random.rand(*shape_1d).astype(np.float32)
|
||||
rand_shape_1d_2 = np.random.rand(*shape_1d).astype(np.float32)
|
||||
rand_shape_2d_1 = np.random.rand(*shape_2d).astype(np.float32)
|
||||
rand_shape_2d_2 = np.random.rand(*shape_2d).astype(np.float32)
|
||||
real_op = ops.Real()
|
||||
imag_op = ops.Imag()
|
||||
|
||||
|
||||
class Complex(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.complex = ComplexOp()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.complex(x, y)
|
||||
|
||||
|
||||
def complex_compare(complex1, complex2):
|
||||
real1 = real_op(Tensor(complex1)).asnumpy()
|
||||
real2 = np.real(complex2)
|
||||
imag1 = imag_op(Tensor(complex1)).asnumpy()
|
||||
imag2 = np.imag(complex2)
|
||||
return np.allclose(real1, real2, rtol=5e-03, atol=5e-03) and np.allclose(imag1, imag2, rtol=5e-03, atol=5e-03)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_complex_elemwise():
|
||||
"""
|
||||
Feature: complex basic Operation.
|
||||
Description: Test complex basic Operation.
|
||||
Expectation: the result match given one.
|
||||
"""
|
||||
|
||||
real_ms = Tensor(rand_shape_2d_1)
|
||||
imag_ms = Tensor(rand_shape_2d_2)
|
||||
real_to = rand_shape_2d_1
|
||||
imag_to = rand_shape_2d_2
|
||||
|
||||
complex1 = Complex()(real_ms, imag_ms)
|
||||
complex2 = Complex()(imag_ms, real_ms)
|
||||
complex_1 = np.vectorize(complex)(real_to, imag_to)
|
||||
complex_2 = np.vectorize(complex)(imag_to, real_to)
|
||||
assert complex_compare(complex1, complex_1)
|
||||
|
||||
res_ms = ops.Add()(complex1, complex2)
|
||||
res_to = np.add(complex_1, complex_2)
|
||||
assert complex_compare(res_ms, res_to)
|
||||
|
||||
res_ms = ops.Mul()(complex1, complex2)
|
||||
res_to = np.multiply(complex_1, complex_2)
|
||||
assert complex_compare(res_ms, res_to)
|
||||
|
||||
res_ms = ops.Sub()(complex1, complex2)
|
||||
res_to = np.subtract(complex_1, complex_2)
|
||||
assert complex_compare(res_ms, res_to)
|
||||
|
||||
res_ms = ops.Div()(complex1, complex2)
|
||||
res_to = np.divide(complex_1, complex_2)
|
||||
assert complex_compare(res_ms, res_to)
|
||||
|
||||
res_ms = complex1 / complex2
|
||||
res_to = np.divide(complex_1, complex_2)
|
||||
assert complex_compare(res_ms, res_to)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_complex_broadcast():
|
||||
"""
|
||||
Feature: complex broadcast Operation.
|
||||
Description: Test complex broadcast Operation.
|
||||
Expectation: the result match given one.
|
||||
"""
|
||||
|
||||
real_ms_1 = Tensor(rand_shape_2d_1)
|
||||
imag_ms_1 = Tensor(rand_shape_2d_2)
|
||||
real_ms_2 = Tensor(rand_shape_1d_1)
|
||||
imag_ms_2 = Tensor(rand_shape_1d_2)
|
||||
real_to_1 = rand_shape_2d_1
|
||||
imag_to_1 = rand_shape_2d_2
|
||||
real_to_2 = rand_shape_1d_1
|
||||
imag_to_2 = rand_shape_1d_2
|
||||
|
||||
complex1 = Complex()(real_ms_1, imag_ms_1)
|
||||
complex2 = Complex()(real_ms_2, imag_ms_2)
|
||||
complex_1 = np.vectorize(complex)(real_to_1, imag_to_1)
|
||||
complex_2 = np.vectorize(complex)(real_to_2, imag_to_2)
|
||||
assert complex_compare(complex1, complex_1)
|
||||
|
||||
res_ms = ops.Add()(complex1, complex2)
|
||||
res_to = np.add(complex_1, complex_2)
|
||||
assert complex_compare(res_ms, res_to)
|
||||
|
||||
res_ms = ops.Mul()(complex1, complex2)
|
||||
res_to = np.multiply(complex_1, complex_2)
|
||||
assert complex_compare(res_ms, res_to)
|
||||
|
||||
res_ms = ops.Sub()(complex1, complex2)
|
||||
res_to = np.subtract(complex_1, complex_2)
|
||||
assert complex_compare(res_ms, res_to)
|
||||
|
||||
res_ms = ops.Div()(complex1, complex2)
|
||||
res_to = np.divide(complex_1, complex_2)
|
||||
assert complex_compare(res_ms, res_to)
|
||||
|
||||
res_ms = complex1 / complex2
|
||||
res_to = np.divide(complex_1, complex_2)
|
||||
assert complex_compare(res_ms, res_to)
|
Loading…
Reference in New Issue