forked from mindspore-Ecosystem/mindspore
!21351 complex support
Merge pull request !21351 from zhouyaqiang0/complex_support
This commit is contained in:
commit
2650aae9ba
|
@ -109,6 +109,22 @@ REGISTER_PYBIND_DEFINE(
|
|||
Float data(t[0].cast<py::int_>());
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<Complex, Number, std::shared_ptr<Complex>>(m_sub, "Complex")
|
||||
.def(py::init())
|
||||
.def(py::init<int>(), py::arg("nbits"))
|
||||
.def(py::pickle(
|
||||
[](const Complex &t) { // __getstate__
|
||||
/* Return a tuple that fully encodes the state of the object */
|
||||
return py::make_tuple(py::int_(t.nbits()));
|
||||
},
|
||||
[](const py::tuple &t) { // __setstate__
|
||||
if (t.size() != 1) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
/* Create a new C++ instance */
|
||||
Complex data(t[0].cast<py::int_>());
|
||||
return data;
|
||||
}));
|
||||
(void)py::class_<List, Type, std::shared_ptr<List>>(m_sub, "List")
|
||||
.def(py::init())
|
||||
.def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <complex>
|
||||
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
@ -78,9 +79,15 @@ static TypeId GetDataType(const py::buffer_info &buf) {
|
|||
case '?':
|
||||
return TypeId::kNumberTypeBool;
|
||||
}
|
||||
} else if (buf.format.size() >= 2 && buf.format.back() == 'w') {
|
||||
} else if (buf.format.size() >= 2) {
|
||||
// Support np.str_ dtype, format: {x}w. {x} is a number that means the maximum length of the string items.
|
||||
return TypeId::kObjectTypeString;
|
||||
if (buf.format.back() == 'w') {
|
||||
return TypeId::kObjectTypeString;
|
||||
} else if (buf.format == "Zf") {
|
||||
return TypeId::kNumberTypeComplex64;
|
||||
} else if (buf.format == "Zd") {
|
||||
return TypeId::kNumberTypeComplex128;
|
||||
}
|
||||
}
|
||||
MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << ", item size " << buf.itemsize;
|
||||
return TypeId::kTypeUnknown;
|
||||
|
@ -114,6 +121,10 @@ static std::string GetPyTypeFormat(TypeId data_type) {
|
|||
return py::format_descriptor<bool>::format();
|
||||
case TypeId::kObjectTypeString:
|
||||
return py::format_descriptor<uint8_t>::format();
|
||||
case TypeId::kNumberTypeComplex64:
|
||||
return py::format_descriptor<std::complex<float>>::format();
|
||||
case TypeId::kNumberTypeComplex128:
|
||||
return py::format_descriptor<std::complex<double>>::format();
|
||||
default:
|
||||
MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
|
||||
return "";
|
||||
|
|
|
@ -0,0 +1,323 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_UTILS_COPLEX_H_
|
||||
#define MINDSPORE_CCSRC_UTILS_COPLEX_H_
|
||||
|
||||
#include <complex>
|
||||
#include <limits>
|
||||
#ifdef ENABLE_GPU
|
||||
#include <thrust/complex.h>
|
||||
#endif
|
||||
#include "base/float16.h"
|
||||
#if defined(__CUDACC__)
|
||||
#define HOST_DEVICE __host__ __device__
|
||||
#else
|
||||
#define HOST_DEVICE
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace utils {
|
||||
// Implement Complex for mindspore, inspired by std::complex.
|
||||
template <typename T>
|
||||
struct alignas(sizeof(T) * 2) Complex {
|
||||
Complex() = default;
|
||||
~Complex() = default;
|
||||
|
||||
Complex(const Complex<T> &other) noexcept = default;
|
||||
Complex(Complex<T> &&other) noexcept = default;
|
||||
|
||||
Complex &operator=(const Complex<T> &other) noexcept = default;
|
||||
Complex &operator=(Complex<T> &&other) noexcept = default;
|
||||
|
||||
HOST_DEVICE inline constexpr Complex(const T &real, const T &imag = T()) : real_(real), imag_(imag) {}
|
||||
|
||||
template <typename U>
|
||||
inline explicit constexpr Complex(const std::complex<U> &other) : Complex(other.real(), other.imag()) {}
|
||||
template <typename U>
|
||||
inline explicit constexpr operator std::complex<U>() const {
|
||||
return std::complex<U>(std::complex<T>(real(), imag()));
|
||||
}
|
||||
|
||||
HOST_DEVICE inline explicit constexpr Complex(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {}
|
||||
#if defined(__CUDACC__)
|
||||
template <typename U>
|
||||
HOST_DEVICE inline explicit Complex(const thrust::complex<U> &other) : real_(other.real()), imag_(other.imag()) {}
|
||||
|
||||
template <typename U>
|
||||
HOST_DEVICE inline HOST_DEVICE explicit operator thrust::complex<U>() const {
|
||||
return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
|
||||
}
|
||||
#endif
|
||||
template <typename U = T>
|
||||
HOST_DEVICE explicit Complex(const std::enable_if_t<std::is_same<U, float>::value, Complex<double>> &other)
|
||||
: real_(other.real()), imag_(other.imag()) {}
|
||||
|
||||
template <typename U = T>
|
||||
HOST_DEVICE explicit Complex(const std::enable_if_t<std::is_same<U, double>::value, Complex<float>> &other)
|
||||
: real_(other.real()), imag_(other.imag()) {}
|
||||
|
||||
HOST_DEVICE inline explicit operator bool() const { return static_cast<bool>(real_) || static_cast<bool>(imag_); }
|
||||
HOST_DEVICE inline explicit operator signed char() const { return static_cast<signed char>(real_); }
|
||||
HOST_DEVICE inline explicit operator unsigned char() const { return static_cast<unsigned char>(real_); }
|
||||
HOST_DEVICE inline explicit operator double() const { return static_cast<double>(real_); }
|
||||
HOST_DEVICE inline explicit operator float() const { return static_cast<float>(real_); }
|
||||
HOST_DEVICE inline explicit operator int16_t() const { return static_cast<int16_t>(real_); }
|
||||
HOST_DEVICE inline explicit operator uint16_t() const { return static_cast<uint16_t>(real_); }
|
||||
HOST_DEVICE inline explicit operator int32_t() const { return static_cast<int32_t>(real_); }
|
||||
HOST_DEVICE inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); }
|
||||
HOST_DEVICE inline explicit operator int64_t() const { return static_cast<int64_t>(real_); }
|
||||
HOST_DEVICE inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); }
|
||||
HOST_DEVICE inline explicit operator float16() const { return static_cast<float16>(real_); }
|
||||
|
||||
HOST_DEVICE inline constexpr Complex<T> &operator=(const T &real) {
|
||||
real_ = real;
|
||||
imag_ = T();
|
||||
return *this;
|
||||
}
|
||||
|
||||
HOST_DEVICE inline Complex<T> &operator+=(const T &real) {
|
||||
real_ += real;
|
||||
return *this;
|
||||
}
|
||||
|
||||
HOST_DEVICE inline Complex<T> &operator-=(const T &real) {
|
||||
real_ -= real;
|
||||
return *this;
|
||||
}
|
||||
|
||||
HOST_DEVICE inline Complex<T> &operator*=(const T &real) {
|
||||
real_ *= real;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Note: check division by zero before use it.
|
||||
HOST_DEVICE inline Complex<T> &operator/=(const T &real) {
|
||||
real_ /= real;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
HOST_DEVICE inline Complex<T> &operator=(const Complex<U> &z) {
|
||||
real_ = z.real();
|
||||
imag_ = z.imag();
|
||||
return *this;
|
||||
}
|
||||
template <typename U>
|
||||
HOST_DEVICE inline Complex<T> &operator+=(const Complex<U> &z) {
|
||||
real_ += z.real();
|
||||
imag_ += z.imag();
|
||||
return *this;
|
||||
}
|
||||
template <typename U>
|
||||
HOST_DEVICE inline Complex<T> &operator-=(const Complex<U> &z) {
|
||||
real_ -= z.real();
|
||||
imag_ -= z.imag();
|
||||
return *this;
|
||||
}
|
||||
template <typename U>
|
||||
HOST_DEVICE inline Complex<T> &operator*=(const Complex<U> &z);
|
||||
|
||||
// Note: check division by zero before use it.
|
||||
template <typename U>
|
||||
HOST_DEVICE inline Complex<T> &operator/=(const Complex<U> &z);
|
||||
|
||||
HOST_DEVICE inline constexpr T real() const { return real_; }
|
||||
HOST_DEVICE inline constexpr T imag() const { return imag_; }
|
||||
HOST_DEVICE inline void real(T val) { real_ = val; }
|
||||
HOST_DEVICE inline void imag(T val) { imag_ = val; }
|
||||
|
||||
private:
|
||||
T real_;
|
||||
T imag_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename U>
|
||||
HOST_DEVICE inline Complex<T> &Complex<T>::operator*=(const Complex<U> &z) {
|
||||
const T real = real_ * z.real() - imag_ * z.imag();
|
||||
imag_ = real_ * z.imag() + imag_ * z.real();
|
||||
real_ = real;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Note: check division by zero before use it.
|
||||
template <typename T>
|
||||
template <typename U>
|
||||
HOST_DEVICE inline Complex<T> &Complex<T>::operator/=(const Complex<U> &z) {
|
||||
T a = real_;
|
||||
T b = imag_;
|
||||
U c = z.real();
|
||||
U d = z.imag();
|
||||
auto denominator = c * c + d * d;
|
||||
real_ = (a * c + b * d) / denominator;
|
||||
imag_ = (b * c - a * d) / denominator;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const Complex<T> &rhs) {
|
||||
Complex<T> result = lhs;
|
||||
result += rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const T &rhs) {
|
||||
Complex<T> result = lhs;
|
||||
result += rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator+(const T &lhs, const Complex<T> &rhs) {
|
||||
Complex<T> result = rhs;
|
||||
result += lhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const Complex<T> &rhs) {
|
||||
Complex<T> result = lhs;
|
||||
result -= rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const T &rhs) {
|
||||
Complex<T> result = lhs;
|
||||
result -= rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator-(const T &lhs, const Complex<T> &rhs) {
|
||||
Complex<T> result(lhs, -rhs.imag());
|
||||
result -= rhs.real();
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const Complex<T> &rhs) {
|
||||
Complex<T> result = lhs;
|
||||
result *= rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const T &rhs) {
|
||||
Complex<T> result = lhs;
|
||||
result *= rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator*(const T &lhs, const Complex<T> &rhs) {
|
||||
Complex<T> result = rhs;
|
||||
result *= lhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Note: check division by zero before use it.
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const Complex<T> &rhs) {
|
||||
Complex<T> result = lhs;
|
||||
result /= rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Note: check division by zero before use it.
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const T &rhs) {
|
||||
Complex<T> result = lhs;
|
||||
result /= rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Note: check division by zero before use it.
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator/(const T &lhs, const Complex<T> &rhs) {
|
||||
Complex<T> result = lhs;
|
||||
result /= rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator+(const Complex<T> &z) {
|
||||
return z;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline Complex<T> operator-(const Complex<T> &z) {
|
||||
return Complex<T>(-z.real(), -z.imag());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const Complex<T> &rhs) {
|
||||
return lhs.real() == rhs.real() && lhs.imag() == rhs.imag();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline bool operator==(const T &lhs, const Complex<T> &rhs) {
|
||||
return lhs == rhs.real() && rhs.imag() == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const T &rhs) {
|
||||
return lhs.real() == rhs && lhs.imag() == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const Complex<T> &rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline bool operator!=(const T &lhs, const Complex<T> &rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const T &rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::ostream &operator<<(std::ostream &os, const Complex<T> &v) {
|
||||
return (os << std::noshowpos << v.real() << std::showpos << v.imag() << 'j');
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE inline T abs(const Complex<T> &z) {
|
||||
#if defined(__CUDACC__)
|
||||
return thrust::abs(thrust::complex<T>(z));
|
||||
#else
|
||||
return std::abs(std::complex<T>(z));
|
||||
#endif
|
||||
}
|
||||
} // namespace utils
|
||||
} // namespace mindspore
|
||||
|
||||
template <typename T>
|
||||
using Complex = mindspore::utils::Complex<T>;
|
||||
|
||||
namespace std {
|
||||
|
||||
template <typename T>
|
||||
class numeric_limits<mindspore::utils::Complex<T>> : public numeric_limits<T> {};
|
||||
|
||||
} // namespace std
|
||||
|
||||
#endif // MINDSPORE_CCSRC_UTILS_COPLEX_H_
|
|
@ -38,7 +38,8 @@ __dtype__ = [
|
|||
"number", "tensor",
|
||||
"string", "type_none",
|
||||
"tensor_type",
|
||||
"Type", "Int"
|
||||
"Type", "Int",
|
||||
"complex64", "complex128"
|
||||
]
|
||||
|
||||
__method__ = [
|
||||
|
@ -77,6 +78,8 @@ float32 = typing.Float(32)
|
|||
single = float32
|
||||
float64 = typing.Float(64)
|
||||
double = float64
|
||||
complex64 = typing.Complex(64)
|
||||
complex128 = typing.Complex(128)
|
||||
|
||||
number = typing.Number()
|
||||
int_ = typing.Int()
|
||||
|
@ -124,14 +127,16 @@ number_type = (int8,
|
|||
uint64,
|
||||
float16,
|
||||
float32,
|
||||
float64,)
|
||||
float64,
|
||||
complex64,
|
||||
complex128,)
|
||||
|
||||
int_type = (int8, int16, int32, int64,)
|
||||
uint_type = (uint8, uint16, uint32, uint64,)
|
||||
float_type = (float16, float32, float64,)
|
||||
|
||||
implicit_conversion_seq = {t: idx for idx, t in enumerate((
|
||||
bool_, int8, uint8, int16, int32, int64, float16, float32, float64))}
|
||||
bool_, int8, uint8, int16, int32, int64, float16, float32, float64, complex64, complex128))}
|
||||
|
||||
_simple_types = {
|
||||
list: list_,
|
||||
|
@ -140,6 +145,7 @@ _simple_types = {
|
|||
bool: bool_,
|
||||
int: int64,
|
||||
float: float64,
|
||||
complex: complex128,
|
||||
str: string,
|
||||
np.bool_: bool_,
|
||||
np.str: string,
|
||||
|
@ -228,6 +234,8 @@ def dtype_to_nptype(type_):
|
|||
float16: np.float16,
|
||||
float32: np.float32,
|
||||
float64: np.float64,
|
||||
complex64: np.complex64,
|
||||
complex128: np.complex128,
|
||||
}[type_]
|
||||
|
||||
|
||||
|
@ -260,6 +268,8 @@ def dtype_to_pytype(type_):
|
|||
list_: list,
|
||||
tuple_: tuple,
|
||||
string: str,
|
||||
complex64: complex,
|
||||
complex128: complex,
|
||||
type_none: type(None)
|
||||
}[type_]
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ from .._checkparam import Validator as validator
|
|||
__all__ = ['Tensor', 'RowTensor', 'SparseTensor']
|
||||
np_types = (np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
|
||||
np.float32, np.float64, np.bool_)
|
||||
np.float32, np.float64, np.bool_, np.complex64, np.complex128)
|
||||
|
||||
|
||||
class Tensor(Tensor_):
|
||||
|
@ -91,7 +91,7 @@ class Tensor(Tensor_):
|
|||
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
|
||||
'Tensor')
|
||||
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
np.float16, np.float32, np.float64, np.bool_, np.str_)
|
||||
np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
|
||||
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \
|
||||
input_data.dtype.kind != 'U': # Support dtype np.str_
|
||||
raise TypeError(f"For Tensor, the input_data is a numpy array, "
|
||||
|
|
|
@ -27,11 +27,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1},
|
||||
{kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
|
||||
{kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2},
|
||||
{kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4},
|
||||
{kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
|
||||
const std::map<TypeId, size_t> type_map = {
|
||||
{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1}, {kNumberTypeInt16, 2},
|
||||
{kNumberTypeInt32, 4}, {kNumberTypeInt64, 8}, {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1},
|
||||
{kNumberTypeUInt16, 2}, {kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4},
|
||||
{kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}, {kNumberTypeComplex64, 8},
|
||||
{kNumberTypeComplex128, 16}};
|
||||
|
||||
ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) {
|
||||
MS_EXCEPTION_IF_NULL(value1);
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_
|
||||
#define MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_
|
||||
|
||||
#include "base/float16.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
template <typename T>
|
||||
struct alignas(sizeof(T) * 2) ComplexStorage {
|
||||
T real_;
|
||||
T imag_;
|
||||
|
||||
ComplexStorage() = default;
|
||||
~ComplexStorage() = default;
|
||||
|
||||
ComplexStorage(const ComplexStorage<T> &other) noexcept = default;
|
||||
ComplexStorage(ComplexStorage<T> &&other) noexcept = default;
|
||||
|
||||
ComplexStorage &operator=(const ComplexStorage<T> &other) noexcept = default;
|
||||
ComplexStorage &operator=(ComplexStorage<T> &&other) noexcept = default;
|
||||
|
||||
inline constexpr ComplexStorage(const T &real, const T &imag = T()) : real_(real), imag_(imag) {}
|
||||
|
||||
inline explicit constexpr ComplexStorage(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {}
|
||||
|
||||
template <typename U = T>
|
||||
explicit ComplexStorage(const std::enable_if_t<std::is_same<U, float>::value, ComplexStorage<double>> &other)
|
||||
: real_(other.real_), imag_(other.imag_) {}
|
||||
|
||||
template <typename U = T>
|
||||
explicit ComplexStorage(const std::enable_if_t<std::is_same<U, double>::value, ComplexStorage<float>> &other)
|
||||
: real_(other.real_), imag_(other.imag_) {}
|
||||
|
||||
inline explicit operator bool() const { return static_cast<bool>(real_) || static_cast<bool>(imag_); }
|
||||
inline explicit operator signed char() const { return static_cast<signed char>(real_); }
|
||||
inline explicit operator unsigned char() const { return static_cast<unsigned char>(real_); }
|
||||
inline explicit operator double() const { return static_cast<double>(real_); }
|
||||
inline explicit operator float() const { return static_cast<float>(real_); }
|
||||
inline explicit operator int16_t() const { return static_cast<int16_t>(real_); }
|
||||
inline explicit operator uint16_t() const { return static_cast<uint16_t>(real_); }
|
||||
inline explicit operator int32_t() const { return static_cast<int32_t>(real_); }
|
||||
inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); }
|
||||
inline explicit operator int64_t() const { return static_cast<int64_t>(real_); }
|
||||
inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); }
|
||||
inline explicit operator float16() const { return static_cast<float16>(real_); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline bool operator==(const ComplexStorage<T> &lhs, const ComplexStorage<T> &rhs) {
|
||||
return lhs.real_ == rhs.real_ && lhs.imag_ == rhs.imag_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::ostream &operator<<(std::ostream &os, const ComplexStorage<T> &v) {
|
||||
return (os << std::noshowpos << v.real_ << std::showpos << v.imag_ << 'j');
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_
|
|
@ -46,4 +46,10 @@ Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) {
|
|||
MS_LOG(EXCEPTION) << "Wrong number of bits.";
|
||||
}
|
||||
}
|
||||
|
||||
Complex::Complex(const int nbits) : Number(ComplexBitsToTypeId(nbits), nbits, false) {
|
||||
if (nbits != 64 && nbits != 128) {
|
||||
MS_LOG(EXCEPTION) << "Wrong number of bits.";
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -150,20 +150,19 @@ class MS_CORE_API Float : public Number {
|
|||
}
|
||||
};
|
||||
|
||||
// Complex64
|
||||
class MS_CORE_API Complex64 : public Number {
|
||||
// Complex
|
||||
class MS_CORE_API Complex : public Number {
|
||||
public:
|
||||
Complex64() : Number(kNumberTypeComplex64, 64, false) {}
|
||||
~Complex64() override {}
|
||||
MS_DECLARE_PARENT(Complex64, Number)
|
||||
Complex() : Number(kNumberTypeComplex64, 64, false) {}
|
||||
explicit Complex(const int nbits);
|
||||
~Complex() override {}
|
||||
MS_DECLARE_PARENT(Complex, Number)
|
||||
|
||||
TypeId generic_type_id() const override { return kNumberTypeComplex64; }
|
||||
TypePtr DeepCopy() const override { return std::make_shared<Complex64>(); }
|
||||
TypePtr DeepCopy() const override { return std::make_shared<Complex>(nbits()); }
|
||||
std::string ToString() const override { return GetTypeName("Complex"); }
|
||||
std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); }
|
||||
std::string DumpText() const override {
|
||||
return nbits() == 0 ? std::string("Complex64") : std::string("C") + std::to_string(nbits());
|
||||
}
|
||||
std::string ToReprString() const override { return GetTypeName("complex"); }
|
||||
std::string DumpText() const override { return std::string("C") + std::to_string(nbits()); }
|
||||
};
|
||||
|
||||
inline const TypePtr kBool = std::make_shared<Bool>();
|
||||
|
@ -182,7 +181,8 @@ inline const TypePtr kInt = std::make_shared<Int>();
|
|||
inline const TypePtr kUInt = std::make_shared<UInt>();
|
||||
inline const TypePtr kFloat = std::make_shared<Float>();
|
||||
inline const TypePtr kNumber = std::make_shared<Number>();
|
||||
inline const TypePtr kComplex64 = std::make_shared<Complex64>();
|
||||
inline const TypePtr kComplex64 = std::make_shared<Complex>(64);
|
||||
inline const TypePtr kComplex128 = std::make_shared<Complex>(128);
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_
|
||||
|
|
|
@ -87,6 +87,7 @@ enum class BitsNum : int {
|
|||
eBits16 = 16,
|
||||
eBits32 = 32,
|
||||
eBits64 = 64,
|
||||
eBits128 = 128,
|
||||
};
|
||||
TypeId IntBitsToTypeId(const int nbits) {
|
||||
switch (nbits) {
|
||||
|
@ -131,6 +132,17 @@ TypeId FloatBitsToTypeId(const int nbits) {
|
|||
}
|
||||
}
|
||||
|
||||
TypeId ComplexBitsToTypeId(const int nbits) {
|
||||
switch (nbits) {
|
||||
case static_cast<int>(BitsNum::eBits64):
|
||||
return kNumberTypeComplex64;
|
||||
case static_cast<int>(BitsNum::eBits128):
|
||||
return kNumberTypeComplex128;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Wrong number of bits:" << nbits;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string &TypeIdLabel(const TypeId &v) {
|
||||
static const std::string unknown("[Unknown Type Id]");
|
||||
auto iter = g_type_2_lable.find(v);
|
||||
|
|
|
@ -41,6 +41,7 @@ namespace mindspore {
|
|||
TypeId IntBitsToTypeId(const int nbits);
|
||||
TypeId UIntBitsToTypeId(const int nbits);
|
||||
TypeId FloatBitsToTypeId(const int nbits);
|
||||
TypeId ComplexBitsToTypeId(const int nbits);
|
||||
const std::string &TypeIdLabel(const TypeId &v);
|
||||
TypeId NormalizeTypeId(const TypeId type_id);
|
||||
bool IsSameObjectType(const Type &lhs, const Type &rhs);
|
||||
|
|
|
@ -79,6 +79,7 @@ enum TypeId : int {
|
|||
kNumberTypeFloat32,
|
||||
kNumberTypeFloat64,
|
||||
kNumberTypeComplex64,
|
||||
kNumberTypeComplex128,
|
||||
kNumberTypeEnd,
|
||||
//
|
||||
// Monad Types
|
||||
|
|
|
@ -61,41 +61,20 @@ bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) c
|
|||
}
|
||||
|
||||
TypePtr TypeIdToType(TypeId id) {
|
||||
static std::unordered_map<TypeId, TypePtr> type_id_to_type = {{kNumberTypeFloat16, kFloat16},
|
||||
{kNumberTypeFloat, kFloat32},
|
||||
{kNumberTypeFloat32, kFloat32},
|
||||
{kNumberTypeFloat64, kFloat64},
|
||||
{kNumberTypeComplex64, kComplex64},
|
||||
{kNumberTypeInt8, kInt8},
|
||||
{kNumberTypeInt16, kInt16},
|
||||
{kNumberTypeInt32, kInt32},
|
||||
{kNumberTypeInt, kInt32},
|
||||
{kNumberTypeInt64, kInt64},
|
||||
{kNumberTypeUInt8, kUInt8},
|
||||
{kNumberTypeUInt16, kUInt16},
|
||||
{kNumberTypeUInt32, kUInt32},
|
||||
{kNumberTypeUInt64, kUInt64},
|
||||
{kNumberTypeBool, kBool},
|
||||
{kMetaTypeExternal, kTypeExternal},
|
||||
{kMetaTypeAnything, kAnyType},
|
||||
{kMetaTypeNone, kTypeNone},
|
||||
{kMetaTypeNull, kTypeNull},
|
||||
{kMetaTypeEllipsis, kTypeEllipsis},
|
||||
{kObjectTypeEnvType, kTypeEnv},
|
||||
{kObjectTypeRefKey, kRefKeyType},
|
||||
{kObjectTypeRef, kRefType},
|
||||
{kMetaTypeTypeType, kTypeType},
|
||||
{kObjectTypeString, kString},
|
||||
{kObjectTypeList, kList},
|
||||
{kObjectTypeTuple, kTuple},
|
||||
{kObjectTypeDictionary, kDict},
|
||||
{kObjectTypeSlice, kSlice},
|
||||
{kObjectTypeKeyword, kKeyword},
|
||||
{kObjectTypeTensorType, kTensorType},
|
||||
{kObjectTypeUMonad, kUMonadType},
|
||||
{kObjectTypeIOMonad, kIOMonadType},
|
||||
{kTypeUnknown, kTypeNone},
|
||||
{kMetaTypeProblem, kTypeNone}};
|
||||
static std::unordered_map<TypeId, TypePtr> type_id_to_type = {
|
||||
{kNumberTypeFloat16, kFloat16}, {kNumberTypeFloat, kFloat32}, {kNumberTypeFloat32, kFloat32},
|
||||
{kNumberTypeFloat64, kFloat64}, {kNumberTypeComplex64, kComplex64}, {kNumberTypeInt8, kInt8},
|
||||
{kNumberTypeInt16, kInt16}, {kNumberTypeInt32, kInt32}, {kNumberTypeInt, kInt32},
|
||||
{kNumberTypeInt64, kInt64}, {kNumberTypeUInt8, kUInt8}, {kNumberTypeUInt16, kUInt16},
|
||||
{kNumberTypeUInt32, kUInt32}, {kNumberTypeUInt64, kUInt64}, {kNumberTypeBool, kBool},
|
||||
{kNumberTypeComplex64, kComplex64}, {kNumberTypeComplex128, kComplex128}, {kMetaTypeExternal, kTypeExternal},
|
||||
{kMetaTypeAnything, kAnyType}, {kMetaTypeNone, kTypeNone}, {kMetaTypeNull, kTypeNull},
|
||||
{kMetaTypeEllipsis, kTypeEllipsis}, {kObjectTypeEnvType, kTypeEnv}, {kObjectTypeRefKey, kRefKeyType},
|
||||
{kObjectTypeRef, kRefType}, {kMetaTypeTypeType, kTypeType}, {kObjectTypeString, kString},
|
||||
{kObjectTypeList, kList}, {kObjectTypeTuple, kTuple}, {kObjectTypeDictionary, kDict},
|
||||
{kObjectTypeSlice, kSlice}, {kObjectTypeKeyword, kKeyword}, {kObjectTypeTensorType, kTensorType},
|
||||
{kObjectTypeUMonad, kUMonadType}, {kObjectTypeIOMonad, kIOMonadType}, {kTypeUnknown, kTypeNone},
|
||||
{kMetaTypeProblem, kTypeNone}};
|
||||
const auto &it = type_id_to_type.find(id);
|
||||
if (it == type_id_to_type.end()) {
|
||||
MS_LOG(EXCEPTION) << "Not support the type: " << id;
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
|
||||
#include "abstract/utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "base/complex_storage.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace tensor {
|
||||
|
@ -73,7 +74,10 @@ std::unique_ptr<T[]> NewData(const U *input, size_t size) {
|
|||
return nullptr;
|
||||
}
|
||||
auto data = std::make_unique<T[]>(size);
|
||||
if constexpr (!std::is_same<T, U>::value && (std::is_same<T, float16>::value || std::is_same<U, float16>::value)) {
|
||||
if constexpr (!std::is_same<T, U>::value &&
|
||||
(std::is_same<T, float16>::value || std::is_same<U, float16>::value ||
|
||||
std::is_same<T, ComplexStorage<float>>::value || std::is_same<U, ComplexStorage<float>>::value ||
|
||||
std::is_same<T, ComplexStorage<double>>::value || std::is_same<U, ComplexStorage<double>>::value)) {
|
||||
// Because float16 do not support implicit cast from/to other types,
|
||||
// We can not use std::copy() on array of float16, use a loop here.
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
|
@ -146,7 +150,11 @@ std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId
|
|||
return NewData<T>(buf, size);
|
||||
}
|
||||
case kNumberTypeComplex64: {
|
||||
auto buf = static_cast<double *>(data);
|
||||
auto buf = static_cast<ComplexStorage<float> *>(data);
|
||||
return NewData<T>(buf, size);
|
||||
}
|
||||
case kNumberTypeComplex128: {
|
||||
auto buf = static_cast<ComplexStorage<double> *>(data);
|
||||
return NewData<T>(buf, size);
|
||||
}
|
||||
case kObjectTypeString: {
|
||||
|
@ -233,7 +241,8 @@ class TensorDataImpl : public TensorData {
|
|||
std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
|
||||
std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
|
||||
std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
|
||||
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
|
||||
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value ||
|
||||
std::is_same<T, ComplexStorage<float>>::value || std::is_same<T, ComplexStorage<double>>::value;
|
||||
static_assert(valid, "Type is invalid");
|
||||
if (data_size_ == 0) {
|
||||
return "";
|
||||
|
@ -302,10 +311,14 @@ class TensorDataImpl : public TensorData {
|
|||
constexpr auto isBool = std::is_same<T, bool>::value;
|
||||
constexpr auto isFloat =
|
||||
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
|
||||
constexpr auto isComplex =
|
||||
std::is_same<T, ComplexStorage<float>>::value || std::is_same<T, ComplexStorage<double>>::value;
|
||||
constexpr int linefeedThreshold = isFloat ? kThreshold1DFloat : (isBool ? kThreshold1DBool : kThreshold1DInt);
|
||||
for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
|
||||
const auto value = data_[cursor + i];
|
||||
if constexpr (isFloat) {
|
||||
if constexpr (isComplex) {
|
||||
ss << value;
|
||||
} else if constexpr (isFloat) {
|
||||
OutputFloatDataString(ss, isScalar, value);
|
||||
} else if (isBool) {
|
||||
OutputBoolDataString(ss, isScalar, value);
|
||||
|
@ -458,7 +471,9 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A
|
|||
case kNumberTypeFloat64:
|
||||
return std::make_shared<TensorDataImpl<double>>(shape, args...);
|
||||
case kNumberTypeComplex64:
|
||||
return std::make_shared<TensorDataImpl<double>>(shape, args...);
|
||||
return std::make_shared<TensorDataImpl<ComplexStorage<float>>>(shape, args...);
|
||||
case kNumberTypeComplex128:
|
||||
return std::make_shared<TensorDataImpl<ComplexStorage<double>>>(shape, args...);
|
||||
case kObjectTypeString:
|
||||
return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
|
||||
case kObjectTypeTensorType:
|
||||
|
|
|
@ -908,3 +908,6 @@ class DataType:
|
|||
F64_HWCN = ("float64", "HWCN")
|
||||
F64_NDHWC = ("float64", "NDHWC")
|
||||
F64_ChannelLast = ("float64", "ChannelLast")
|
||||
|
||||
C64_Default = ("complex64", "DefaultFormat")
|
||||
C128_Default = ("complex128", "DefaultFormat")
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "utils/complex.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
class TestComplex : public UT::Common {
|
||||
public:
|
||||
TestComplex() {}
|
||||
};
|
||||
|
||||
TEST_F(TestComplex, test_size) {
|
||||
ASSERT_EQ(sizeof(Complex<float>), 2 * sizeof(float));
|
||||
ASSERT_EQ(sizeof(Complex<double>), 2 * sizeof(double));
|
||||
ASSERT_EQ(alignof(Complex<float>), 2 * sizeof(float));
|
||||
ASSERT_EQ(alignof(Complex<double>), 2 * sizeof(double));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void test_construct() {
|
||||
constexpr T real = T(1.11f);
|
||||
constexpr T imag = T(2.22f);
|
||||
ASSERT_EQ(Complex<T>().real(), T());
|
||||
ASSERT_EQ(Complex<T>().imag(), T());
|
||||
ASSERT_EQ(Complex<T>(real, imag).real(), real);
|
||||
ASSERT_EQ(Complex<T>(real, imag).imag(), imag);
|
||||
ASSERT_EQ(Complex<T>(real).real(), real);
|
||||
ASSERT_EQ(Complex<T>(real).imag(), T());
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
void test_conver_construct() {
|
||||
ASSERT_EQ(Complex<T1>(Complex<T2>(T2(1.11f), T2(2.22f))).real(), T1(1.11f));
|
||||
ASSERT_EQ(Complex<T1>(Complex<T2>(T2(1.11f), T2(2.22f))).imag(), T1(2.22f));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void test_conver_std_construct() {
|
||||
ASSERT_EQ(Complex<T>(std::complex<T>(T(1.11f), T(2.22f))).real(), T(1.11f));
|
||||
ASSERT_EQ(Complex<T>(std::complex<T>(T(1.11f), T(2.22f))).imag(), T(2.22f));
|
||||
}
|
||||
|
||||
TEST_F(TestComplex, test_construct) {
|
||||
test_construct<float>();
|
||||
test_construct<double>();
|
||||
test_conver_construct<float, float>();
|
||||
test_conver_construct<double, double>();
|
||||
test_conver_construct<float, double>();
|
||||
test_conver_construct<double, float>();
|
||||
test_conver_std_construct<float>();
|
||||
test_conver_std_construct<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void test_convert_operator(T &&a) {
|
||||
ASSERT_EQ(static_cast<T>(Complex<float>(a)), a);
|
||||
}
|
||||
|
||||
TEST_F(TestComplex, test_convert_operator) {
|
||||
test_convert_operator<bool>(true);
|
||||
test_convert_operator<signed char>(1);
|
||||
test_convert_operator<unsigned char>(1);
|
||||
ASSERT_NEAR(static_cast<double>(Complex<float>(1.11)), 1.11, 0.001);
|
||||
test_convert_operator<float>(1.11f);
|
||||
test_convert_operator<int16_t>(1);
|
||||
test_convert_operator<uint16_t>(1);
|
||||
test_convert_operator<int32_t>(1);
|
||||
test_convert_operator<uint32_t>(1);
|
||||
test_convert_operator<int64_t>(1);
|
||||
test_convert_operator<uint64_t>(1);
|
||||
float16 a(1.11f);
|
||||
ASSERT_EQ(static_cast<float16>(Complex<float>(a)), a);
|
||||
}
|
||||
|
||||
TEST_F(TestComplex, test_assign_operator) {
|
||||
Complex<float> a = 1.11f;
|
||||
std::cout << a << std::endl;
|
||||
ASSERT_EQ(a.real(), 1.11f);
|
||||
ASSERT_EQ(a.imag(), float());
|
||||
a = Complex<double>(2.22f, 1.11f);
|
||||
ASSERT_EQ(a.real(), 2.22f);
|
||||
ASSERT_EQ(a.imag(), 1.11f);
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
void test_arithmetic_add(T1 lhs, T2 rhs, T3 r) {
|
||||
ASSERT_EQ(lhs + rhs, r);
|
||||
if constexpr (!(std::is_same<T1, float>::value || std::is_same<T1, double>::value)) {
|
||||
ASSERT_EQ(lhs += rhs, r);
|
||||
}
|
||||
}
|
||||
template <typename T1, typename T2, typename T3>
|
||||
void test_arithmetic_sub(T1 lhs, T2 rhs, T3 r) {
|
||||
ASSERT_EQ(lhs - rhs, r);
|
||||
if constexpr (!(std::is_same<T1, float>::value || std::is_same<T1, double>::value)) {
|
||||
ASSERT_EQ(lhs -= rhs, r);
|
||||
}
|
||||
}
|
||||
template <typename T1, typename T2, typename T3>
|
||||
void test_arithmetic_mul(T1 lhs, T2 rhs, T3 r) {
|
||||
ASSERT_EQ(lhs * rhs, r);
|
||||
if constexpr (!(std::is_same<T1, float>::value || std::is_same<T1, double>::value)) {
|
||||
ASSERT_EQ(lhs *= rhs, r);
|
||||
}
|
||||
}
|
||||
template <typename T1, typename T2, typename T3>
|
||||
void test_arithmetic_div(T1 lhs, T2 rhs, T3 r) {
|
||||
ASSERT_EQ(lhs / rhs, r);
|
||||
if constexpr (!(std::is_same<T1, float>::value || std::is_same<T1, double>::value)) {
|
||||
ASSERT_EQ(lhs /= rhs, r);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TestComplex, test_arithmetic) {
|
||||
test_arithmetic_add<Complex<float>, Complex<float>, Complex<float>>(
|
||||
Complex<float>(1.11, 2.22), Complex<float>(1.11, 2.22), Complex<float>(2.22, 4.44));
|
||||
test_arithmetic_add<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11,
|
||||
Complex<float>(2.22, 2.22));
|
||||
test_arithmetic_add<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
|
||||
Complex<float>(2.22, 2.22));
|
||||
|
||||
test_arithmetic_sub<Complex<float>, Complex<float>, Complex<float>>(Complex<float>(1.11, 2.22),
|
||||
Complex<float>(1.11, 2.22), Complex<float>(0, 0));
|
||||
test_arithmetic_sub<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, Complex<float>(0, 2.22));
|
||||
test_arithmetic_sub<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
|
||||
Complex<float>(0, -2.22));
|
||||
|
||||
test_arithmetic_mul<Complex<float>, Complex<float>, Complex<float>>(
|
||||
Complex<float>(1.11, 2.22), Complex<float>(1.11, 2.22), Complex<float>(-3.6963, 4.9284));
|
||||
test_arithmetic_mul<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11,
|
||||
Complex<float>(1.2321, 2.22));
|
||||
test_arithmetic_mul<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
|
||||
Complex<float>(1.2321, 2.22));
|
||||
|
||||
test_arithmetic_div<Complex<float>, Complex<float>, Complex<float>>(Complex<float>(1.11, 2.22),
|
||||
Complex<float>(1.11, 2.22), Complex<float>(1, 0));
|
||||
test_arithmetic_div<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, Complex<float>(1, 2.22));
|
||||
test_arithmetic_div<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
|
||||
Complex<float>(0.2, -0.4));
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -35,6 +35,8 @@ def test_dtype_to_nptype():
|
|||
assert ms.dtype_to_nptype(ms.float16) == np.float16
|
||||
assert ms.dtype_to_nptype(ms.float32) == np.float32
|
||||
assert ms.dtype_to_nptype(ms.float64) == np.float64
|
||||
assert ms.dtype_to_nptype(ms.complex64) == np.complex64
|
||||
assert ms.dtype_to_nptype(ms.complex128) == np.complex128
|
||||
|
||||
|
||||
def test_dtype_to_pytype():
|
||||
|
@ -51,6 +53,8 @@ def test_dtype_to_pytype():
|
|||
assert ms.dtype_to_pytype(ms.float16) == float
|
||||
assert ms.dtype_to_pytype(ms.float32) == float
|
||||
assert ms.dtype_to_pytype(ms.float64) == float
|
||||
assert ms.dtype_to_pytype(ms.complex64) == complex
|
||||
assert ms.dtype_to_pytype(ms.complex128) == complex
|
||||
assert ms.dtype_to_pytype(ms.list_) == list
|
||||
assert ms.dtype_to_pytype(ms.tuple_) == tuple
|
||||
assert ms.dtype_to_pytype(ms.string) == str
|
||||
|
@ -94,6 +98,12 @@ def test_dtype():
|
|||
me_type = dtype.get_py_obj_dtype(x)
|
||||
assert me_type == ms.bool_
|
||||
|
||||
x = 0.1+3j
|
||||
me_type = dtype.get_py_obj_dtype(type(x))
|
||||
assert me_type == ms.complex128
|
||||
me_type = dtype.get_py_obj_dtype(x)
|
||||
assert me_type == ms.complex128
|
||||
|
||||
# support str
|
||||
# x = "string type"
|
||||
|
||||
|
|
|
@ -74,6 +74,45 @@ def test_tensor_type_float16():
|
|||
assert t_float16.shape == (2, 3)
|
||||
assert t_float16.dtype == ms.float16
|
||||
|
||||
def test_tensor_type_complex64():
|
||||
np_input = np.array(
|
||||
[[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]], dtype=np.complex64)
|
||||
t_complex64 = ms.Tensor(np_input)
|
||||
assert isinstance(t_complex64, ms.Tensor)
|
||||
assert t_complex64.shape == (2, 3)
|
||||
assert t_complex64.dtype == ms.complex64
|
||||
assert np.all(t_complex64.asnumpy() == np_input)
|
||||
|
||||
|
||||
def test_tensor_type_complex64_user_define():
|
||||
np_input = np.zeros([1, 2, 3])
|
||||
t_complex64 = ms.Tensor(np_input, ms.complex64)
|
||||
assert isinstance(t_complex64, ms.Tensor)
|
||||
assert t_complex64.shape == (1, 2, 3)
|
||||
assert t_complex64.dtype == ms.complex64
|
||||
assert np.all(t_complex64.asnumpy() == np_input)
|
||||
|
||||
|
||||
def test_tensor_type_complex128():
|
||||
np_input = np.array(
|
||||
[[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]], dtype=np.complex128)
|
||||
t_complex128 = ms.Tensor(np_input)
|
||||
assert isinstance(t_complex128, ms.Tensor)
|
||||
assert t_complex128.shape == (2, 3)
|
||||
assert t_complex128.dtype == ms.complex128
|
||||
assert np.all(t_complex128.asnumpy() == np_input)
|
||||
np_input = (1, 2.22222222j, 3)
|
||||
t_complex128 = ms.Tensor(np_input)
|
||||
assert np.all(t_complex128.asnumpy() == np_input)
|
||||
|
||||
|
||||
def test_tensor_type_complex128_user_define():
|
||||
np_input = np.zeros([1, 2, 3])
|
||||
t_complex128 = ms.Tensor(np_input, ms.complex128)
|
||||
assert isinstance(t_complex128, ms.Tensor)
|
||||
assert t_complex128.shape == (1, 2, 3)
|
||||
assert t_complex128.dtype == ms.complex128
|
||||
assert np.all(t_complex128.asnumpy() == np_input)
|
||||
|
||||
def test_tensor_type_float32():
|
||||
t_float32 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))
|
||||
|
@ -332,13 +371,6 @@ def test_tensor_input_ndarray_bool():
|
|||
inp = np.array([False, 2, 4])
|
||||
ms.Tensor(inp)
|
||||
|
||||
|
||||
def test_tensor_input_ndarray_complex():
|
||||
with pytest.raises(TypeError):
|
||||
inp = np.array([20j, 2, 4])
|
||||
ms.Tensor(inp)
|
||||
|
||||
|
||||
def test_tensor_input_ndarray_none():
|
||||
with pytest.raises(TypeError):
|
||||
inp = np.array([None, 2, 4])
|
||||
|
@ -445,6 +477,19 @@ def test_tensor_dtype_fp64_to_uint8():
|
|||
assert t.shape == (2, 3)
|
||||
assert t.dtype == ms.uint8
|
||||
|
||||
def test_tensor_dtype_complex64_to_float32():
|
||||
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.complex64)
|
||||
t = ms.Tensor(array, ms.float32)
|
||||
assert isinstance(t, ms.Tensor)
|
||||
assert t.shape == (2, 3)
|
||||
assert t.dtype == ms.float32
|
||||
|
||||
def test_tensor_dtype_float32_to_complex64():
|
||||
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
|
||||
t = ms.Tensor(array, ms.complex64)
|
||||
assert isinstance(t, ms.Tensor)
|
||||
assert t.shape == (2, 3)
|
||||
assert t.dtype == ms.complex64
|
||||
|
||||
def test_tensor_operation():
|
||||
x = Tensor(np.ones((3, 3)) * 4)
|
||||
|
|
|
@ -200,6 +200,12 @@ def test_parameter_lazy_init():
|
|||
assert isinstance(para.data, Tensor)
|
||||
assert np.array_equal(para.data.asnumpy(), np.ones((1, 2, 3)))
|
||||
|
||||
para = Parameter(initializer('ones', [1, 2, 3], mstype.complex64), 'test1')
|
||||
assert isinstance(para.data, Tensor)
|
||||
para = para.init_data()
|
||||
assert isinstance(para.data, Tensor)
|
||||
assert np.array_equal(para.data.asnumpy(), np.ones((1, 2, 3)))
|
||||
|
||||
# Call init_data() after set_data is set.
|
||||
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
|
||||
assert isinstance(para.data, Tensor)
|
||||
|
|
Loading…
Reference in New Issue