forked from mindspore-Ecosystem/mindspore
!16157 Add complex data type for Tensor
From: @liangzhibo Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
14332cb6c9
|
@ -47,12 +47,6 @@ Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) {
|
|||
}
|
||||
}
|
||||
|
||||
Complex::Complex(const int nbits) : Number(TypeId::kNumberTypeComplex64, nbits, false) {
|
||||
if (nbits != 64) {
|
||||
MS_LOG(EXCEPTION) << "Wrong number of bits.";
|
||||
}
|
||||
}
|
||||
|
||||
const TypePtr kBool = std::make_shared<Bool>();
|
||||
const TypePtr kInt8 = std::make_shared<Int>(8);
|
||||
const TypePtr kInt16 = std::make_shared<Int>(16);
|
||||
|
@ -69,5 +63,5 @@ const TypePtr kInt = std::make_shared<Int>();
|
|||
const TypePtr kUInt = std::make_shared<UInt>();
|
||||
const TypePtr kFloat = std::make_shared<Float>();
|
||||
const TypePtr kNumber = std::make_shared<Number>();
|
||||
const TypePtr kComplex64 = std::make_shared<Complex>(64);
|
||||
const TypePtr kComplex64 = std::make_shared<Complex64>();
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -150,21 +150,15 @@ class Float : public Number {
|
|||
}
|
||||
};
|
||||
|
||||
// Complex
|
||||
class Complex : public Number {
|
||||
// Complex64
|
||||
class Complex64 : public Number {
|
||||
public:
|
||||
Complex() : Number(kNumberTypeComplex64, 0) {}
|
||||
explicit Complex(const int nbits);
|
||||
~Complex() override {}
|
||||
MS_DECLARE_PARENT(Complex, Number)
|
||||
Complex64() : Number(kNumberTypeComplex64, 64, false) {}
|
||||
~Complex64() override {}
|
||||
MS_DECLARE_PARENT(Complex64, Number)
|
||||
|
||||
TypeId generic_type_id() const override { return kNumberTypeComplex64; }
|
||||
TypePtr DeepCopy() const override {
|
||||
if (nbits() == 0) {
|
||||
return std::make_shared<Complex>();
|
||||
}
|
||||
return std::make_shared<Complex>(nbits());
|
||||
}
|
||||
TypePtr DeepCopy() const override { return std::make_shared<Complex64>(); }
|
||||
std::string ToString() const override { return GetTypeName("Complex64"); }
|
||||
std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); }
|
||||
std::string DumpText() const override {
|
||||
|
|
|
@ -145,6 +145,10 @@ std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId
|
|||
auto buf = static_cast<double *>(data);
|
||||
return NewData<T>(buf, size);
|
||||
}
|
||||
case kNumberTypeComplex64: {
|
||||
auto buf = static_cast<double *>(data);
|
||||
return NewData<T>(buf, size);
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -447,6 +451,8 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A
|
|||
return std::make_shared<TensorDataImpl<float>>(shape, args...);
|
||||
case kNumberTypeFloat64:
|
||||
return std::make_shared<TensorDataImpl<double>>(shape, args...);
|
||||
case kNumberTypeComplex64:
|
||||
return std::make_shared<TensorDataImpl<double>>(shape, args...);
|
||||
case kObjectTypeString:
|
||||
return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
|
||||
case kObjectTypeTensorType:
|
||||
|
|
Loading…
Reference in New Issue