!16157 Add complex data type for Tensor

From: @liangzhibo
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-05-15 09:50:26 +08:00 committed by Gitee
commit 14332cb6c9
3 changed files with 13 additions and 19 deletions

View File

@ -47,12 +47,6 @@ Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) {
}
}
Complex::Complex(const int nbits) : Number(TypeId::kNumberTypeComplex64, nbits, false) {
if (nbits != 64) {
MS_LOG(EXCEPTION) << "Wrong number of bits.";
}
}
const TypePtr kBool = std::make_shared<Bool>();
const TypePtr kInt8 = std::make_shared<Int>(8);
const TypePtr kInt16 = std::make_shared<Int>(16);
@ -69,5 +63,5 @@ const TypePtr kInt = std::make_shared<Int>();
const TypePtr kUInt = std::make_shared<UInt>();
const TypePtr kFloat = std::make_shared<Float>();
const TypePtr kNumber = std::make_shared<Number>();
const TypePtr kComplex64 = std::make_shared<Complex>(64);
const TypePtr kComplex64 = std::make_shared<Complex64>();
} // namespace mindspore

View File

@ -150,21 +150,15 @@ class Float : public Number {
}
};
// Complex
class Complex : public Number {
// Complex64
class Complex64 : public Number {
public:
Complex() : Number(kNumberTypeComplex64, 0) {}
explicit Complex(const int nbits);
~Complex() override {}
MS_DECLARE_PARENT(Complex, Number)
Complex64() : Number(kNumberTypeComplex64, 64, false) {}
~Complex64() override {}
MS_DECLARE_PARENT(Complex64, Number)
TypeId generic_type_id() const override { return kNumberTypeComplex64; }
TypePtr DeepCopy() const override {
if (nbits() == 0) {
return std::make_shared<Complex>();
}
return std::make_shared<Complex>(nbits());
}
TypePtr DeepCopy() const override { return std::make_shared<Complex64>(); }
std::string ToString() const override { return GetTypeName("Complex64"); }
std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); }
std::string DumpText() const override {

View File

@ -145,6 +145,10 @@ std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId
auto buf = static_cast<double *>(data);
return NewData<T>(buf, size);
}
case kNumberTypeComplex64: {
auto buf = static_cast<double *>(data);
return NewData<T>(buf, size);
}
default:
break;
}
@ -447,6 +451,8 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A
return std::make_shared<TensorDataImpl<float>>(shape, args...);
case kNumberTypeFloat64:
return std::make_shared<TensorDataImpl<double>>(shape, args...);
case kNumberTypeComplex64:
return std::make_shared<TensorDataImpl<double>>(shape, args...);
case kObjectTypeString:
return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
case kObjectTypeTensorType: