forked from mindspore-Ecosystem/mindspore
commit 8c3e0d6b54
@@ -66,7 +66,7 @@ class UnravelIndexHelperGpuKernel : public GpuKernelHelperBase {
     if (is_null_input_) {
       return 0;
     }
-    size_t indices_size = input_indices_shape_[0];
+    size_t indices_size = input_indices_shape_.size() == 0 ? 1 : input_indices_shape_[0];
     size_t dims_size = input_dims_shape_[0];
 
     T *input_indices_ptr = nullptr;
@@ -74,25 +74,10 @@ class UnravelIndexHelperGpuKernel : public GpuKernelHelperBase {
     T *output_ptr = nullptr;
     T *check_dims_ptr = nullptr;
 
-    int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &input_indices_ptr);
-    if (flag != 0) {
-      return flag;
-    }
-
-    flag = GetDeviceAddress<T>(input_ptrs, 1, kernel_name_, &input_dims_ptr);
-    if (flag != 0) {
-      return flag;
-    }
-
-    flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_ptr);
-    if (flag != 0) {
-      return flag;
-    }
-
-    flag = GetDeviceAddress<T>(work_ptrs, 0, kernel_name_, &check_dims_ptr);
-    if (flag != 0) {
-      return flag;
-    }
+    (void)GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &input_indices_ptr);
+    (void)GetDeviceAddress<T>(input_ptrs, 1, kernel_name_, &input_dims_ptr);
+    (void)GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_ptr);
+    (void)GetDeviceAddress<T>(work_ptrs, 0, kernel_name_, &check_dims_ptr);
 
     // call cuda kernel
     CalUnravelIndex(input_indices_ptr, input_dims_ptr, indices_size, dims_size, output_ptr, device_id_,
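Below is a minimal usage sketch, not part of the diff, exercising the case the new `indices_size` fallback targets: a 0-D `indices` tensor, for which `input_indices_shape_` is empty and the count is treated as 1. The import path for `UnravelIndex` and the acceptance of 0-D input are assumptions.

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.ops.operations.array_ops as A  # assumed location of UnravelIndex

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# A scalar (0-D) index into a flattened 2x3 grid; with the fallback above,
# the helper treats the number of indices as 1 instead of reading shape[0].
unravel_index = A.UnravelIndex()
indices = Tensor(np.array(5), mstype.int64)
dims = Tensor(np.array([2, 3]), mstype.int64)
print(unravel_index(indices, dims))  # index 5 in a 2x3 grid corresponds to (1, 2)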
@@ -310,6 +310,13 @@ __global__ void SinhKernel(const double *input, double *output, const size_t count) {
   }
   return;
 }
+template <typename T>
+__global__ void SinhKernel(const Complex<T> *input, Complex<T> *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = sinh(input[i]);
+  }
+  return;
+}
 template <>
 __global__ void SinhKernel(const half *input, half *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
@@ -1847,6 +1854,8 @@ template CUDA_LIB_EXPORT void Sign<Complex<float>>(const Complex<float> *input,
                                                    const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cosh<Complex<float>>(const Complex<float> *input, Complex<float> *output,
                                                    const size_t count, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<Complex<float>>(const Complex<float> *input, Complex<float> *output,
+                                                   const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Atan<Complex<float>>(const Complex<float> *input, Complex<float> *output,
                                                    const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Atanh<Complex<float>>(const Complex<float> *input, Complex<float> *output,
@@ -1881,6 +1890,8 @@ template CUDA_LIB_EXPORT void Sin<Complex<double>>(const Complex<double> *input,
                                                    const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cos<Complex<double>>(const Complex<double> *input, Complex<double> *output,
                                                    const size_t count, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<Complex<double>>(const Complex<double> *input, Complex<double> *output,
+                                                    const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void ACos<Complex<double>>(const Complex<double> *input, Complex<double> *output,
                                                     const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Acosh<Complex<double>>(const Complex<double> *input, Complex<double> *output,
@@ -241,7 +241,11 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpGpuKernelMod::Una
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
     &UnaryOpGpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-    &UnaryOpGpuKernelMod::LaunchKernel<half>}}},
+    &UnaryOpGpuKernelMod::LaunchKernel<half>},
+   {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+    &UnaryOpGpuKernelMod::LaunchKernel<utils::Complex<float>>},
+   {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
+    &UnaryOpGpuKernelMod::LaunchKernel<utils::Complex<double>>}}},
   {kTan,
    {{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
     &UnaryOpGpuKernelMod::LaunchKernel<double>},
@@ -479,7 +483,7 @@ bool UnaryOpGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
       {kInv, Inv<T>}, {kLog, Logarithm<T>}, {kExp, Exponential<T>}, {kNeg, Negative<T>},
       {kSin, Sin<T>}, {kCos, Cos<T>}, {kACos, ACos<T>}, {kAcosh, Acosh<T>},
       {kAsin, Asin<T>}, {kAsinh, Asinh<T>}, {kSquare, Square<T>}, {kReciprocal, Reciprocal<T>},
-      {kRsqrt, Rsqrt<T>}, {kSign, Sign<T>}, {kAtan, Atan<T>}};
+      {kRsqrt, Rsqrt<T>}, {kSign, Sign<T>}, {kAtan, Atan<T>}, {kSinh, Sinh<T>}};
     copy(func_map_complex.begin(), func_map_complex.end(), inserter(func_map, func_map.begin()));
   } else {
     std::map<std::string, std::function<void(const T *, T *, const size_t, cudaStream_t)>> func_map_normal = {
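As a usage illustration, here is a minimal sketch, not part of the diff, of what the complex SinhKernel, its Complex<float>/Complex<double> instantiations, and the Complex64/Complex128 registrations above enable at the Python level. The import style follows the test file removed later in this diff; the tolerance is an assumption, not a recorded result.

import numpy as np
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations.math_ops as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# Element-wise hyperbolic sine on a complex64 input, dispatched to the new
# Complex<T> SinhKernel through the Complex64 kernel attribute added above.
sinh = P.Sinh()
x_np = np.array([1 + 2j, 0.5 - 0.3j], dtype=np.complex64)
out = sinh(Tensor(x_np))
np.testing.assert_allclose(out.asnumpy(), np.sinh(x_np), rtol=1e-5, atol=1e-5)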
@@ -7661,7 +7661,7 @@ class Qr(Primitive):
         ValueError: If the dimension of `x` is less than 2.
 
     Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
+        ``Ascend`` ``CPU``
 
     Examples:
         >>> qr_op = ops.Qr(full_matrices=False)
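A minimal sketch, not part of the diff: with ``GPU`` dropped from the supported platforms above, the decomposition can still be sanity-checked on CPU by reconstructing the input from Q and R (the input matrix matches the GPU test removed below; the tolerances are assumptions).

import numpy as np
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations.math_ops as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

# Q @ R should reproduce the input regardless of the sign convention
# chosen by the backend.
x = np.array([[12., -51, 4], [6, 167, -68], [-4, 24, -41]], dtype=np.float32)
q, r = P.Qr(full_matrices=False)(Tensor(x))
np.testing.assert_allclose(np.matmul(q.asnumpy(), r.asnumpy()), x, rtol=1e-3, atol=1e-3)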
@@ -1,70 +0,0 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-import numpy as np
-import pytest
-from mindspore import Tensor
-import mindspore.context as context
-from mindspore.common import dtype as mstype
-import mindspore.ops.operations.math_ops as P
-
-
-def my_method(input_x, full_matrices):
-    qr_op = P.Qr(full_matrices=full_matrices)
-    out = qr_op(Tensor(input_x))
-    res = [out[0].asnumpy(), out[1].asnumpy()]
-    return res
-
-
-def qr_cmp(input_x, full_matrices):
-    out_me = my_method(input_x, full_matrices)
-    _out_q = Tensor([[-0.857143, 0.394286, 0.331429],
-                     [-0.428571, -0.902857, -0.034286],
-                     [0.285714, -0.171429, 0.942857]],
-                    dtype=mstype.float32).asnumpy()
-    _out_r = Tensor([[-14.000001, -21.00001, 14],
-                     [0, -175, 70.000015],
-                     [0, 0, -34.999996]],
-                    dtype=mstype.float32).asnumpy()
-    np.testing.assert_allclose(out_me[0], _out_q, rtol=1e-3)
-    np.testing.assert_allclose(out_me[1], _out_r, rtol=1e-3)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_qr_pynative():
-    """
-    Feature: Qr_pynative
-    Description: test cases for qr: m >= n and full_matrices=True
-    Expectation: the result match to tf
-    """
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    x = np.array([[12., -51, 4], [6, 167, -68], [-4, 24, -41]])
-    qr_cmp(input_x=x, full_matrices=True)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_qr_graph():
-    """
-    Feature: Qr_graph
-    Description: test cases for qr: m < n and full_matrices=False
-    Expectation: the result match to tf
-    """
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    x = np.array([[12., -51, 4], [6, 167, -68], [-4, 24, -41]])
-    qr_cmp(input_x=x, full_matrices=False)