!28891 Add UInt8Imm to MindAPI
Merge pull request !28891 from hewei/fix_r1.6
This commit is contained in:
commit
5a788fde8e
|
@ -100,13 +100,13 @@ using StringImmPtr = SharedPtr<StringImm>;
|
|||
|
||||
MIND_API_IMM_TRAIT(StringImm, std::string);
|
||||
|
||||
/// \beief Scalar defines interface for scalar data.
|
||||
/// \brief Scalar defines interface for scalar data.
|
||||
class MIND_API Scalar : public Value {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Scalar);
|
||||
};
|
||||
|
||||
/// \beief BoolImm defines interface for bool data.
|
||||
/// \brief BoolImm defines interface for bool data.
|
||||
class MIND_API BoolImm : public Scalar {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(BoolImm);
|
||||
|
@ -126,20 +126,20 @@ using BoolImmPtr = SharedPtr<BoolImm>;
|
|||
|
||||
MIND_API_IMM_TRAIT(BoolImm, bool);
|
||||
|
||||
/// \beief IntegerImm defines interface for integer data.
|
||||
/// \brief IntegerImm defines interface for integer data.
|
||||
class MIND_API IntegerImm : public Scalar {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(IntegerImm);
|
||||
};
|
||||
|
||||
/// \beief Int64Imm defines interface for int64 data.
|
||||
/// \brief Int64Imm defines interface for int64 data.
|
||||
class MIND_API Int64Imm : public IntegerImm {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Int64Imm);
|
||||
|
||||
/// \brief Create Int64Imm with the given int64 value.
|
||||
///
|
||||
/// \param[in] value The given bool value.
|
||||
/// \param[in] value The given int64 value.
|
||||
explicit Int64Imm(int64_t value);
|
||||
|
||||
/// \brief Get the int64 value of this Int64Imm.
|
||||
|
@ -152,13 +152,33 @@ using Int64ImmPtr = SharedPtr<Int64Imm>;
|
|||
|
||||
MIND_API_IMM_TRAIT(Int64Imm, int64_t);
|
||||
|
||||
/// \beief FloatImm defines interface for float data.
|
||||
/// \brief UInt8Imm defines interface for uint8 data.
|
||||
class MIND_API UInt8Imm : public IntegerImm {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(UInt8Imm);
|
||||
|
||||
/// \brief Create UInt8Imm with the given uint8 value.
|
||||
///
|
||||
/// \param[in] value The given uint8 value.
|
||||
explicit UInt8Imm(uint8_t value);
|
||||
|
||||
/// \brief Get the uint8 value of this UInt8Imm.
|
||||
///
|
||||
/// \return The uint8 value of this UInt8Imm.
|
||||
uint8_t value() const;
|
||||
};
|
||||
|
||||
using UInt8ImmPtr = SharedPtr<UInt8Imm>;
|
||||
|
||||
MIND_API_IMM_TRAIT(UInt8Imm, uint8_t);
|
||||
|
||||
/// \brief FloatImm defines interface for float data.
|
||||
class MIND_API FloatImm : public Scalar {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(FloatImm);
|
||||
};
|
||||
|
||||
/// \beief FP32Imm defines interface for float32 data.
|
||||
/// \brief FP32Imm defines interface for float32 data.
|
||||
class MIND_API FP32Imm : public FloatImm {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(FP32Imm);
|
||||
|
|
|
@ -33,6 +33,7 @@ using ScalarImpl = mindspore::Scalar;
|
|||
using BoolImmImpl = mindspore::BoolImm;
|
||||
using IntegerImmImpl = mindspore::IntegerImm;
|
||||
using Int64ImmImpl = mindspore::Int64Imm;
|
||||
using UInt8ImmImpl = mindspore::UInt8Imm;
|
||||
using FloatImmImpl = mindspore::FloatImm;
|
||||
using FP32ImmImpl = mindspore::FP32Imm;
|
||||
|
||||
|
@ -84,6 +85,12 @@ Int64Imm::Int64Imm(int64_t value) : IntegerImm(std::make_shared<Int64ImmImpl>(va
|
|||
|
||||
int64_t Int64Imm::value() const { return ToRef<Int64ImmImpl>(impl_).value(); }
|
||||
|
||||
MIND_API_BASE_IMPL(UInt8Imm, UInt8ImmImpl, IntegerImm);
|
||||
|
||||
UInt8Imm::UInt8Imm(uint8_t value) : UInt8Imm(std::make_shared<UInt8ImmImpl>(value)) {}
|
||||
|
||||
uint8_t UInt8Imm::value() const { return ToRef<UInt8ImmImpl>(impl_).value(); }
|
||||
|
||||
MIND_API_BASE_IMPL(FloatImm, FloatImmImpl, Scalar);
|
||||
|
||||
MIND_API_BASE_IMPL(FP32Imm, FP32ImmImpl, FloatImm);
|
||||
|
|
|
@ -141,6 +141,15 @@ TEST_F(TestMindApi, test_values) {
|
|||
ASSERT_EQ(utils::cast<int64_t>(value_list[0]), 3);
|
||||
ASSERT_EQ(utils::cast<int64_t>(value_list[1]), 4);
|
||||
ASSERT_EQ(utils::cast<int64_t>(value_list[2]), 5);
|
||||
|
||||
std::vector<uint8_t> vec_uint8{5, 6, 7};
|
||||
auto uint8_seq = MakeValue<std::vector<uint8_t>>(vec_uint8);
|
||||
ASSERT_TRUE(uint8_seq->isa<ValueSequence>());
|
||||
auto uint8_values = GetValue<std::vector<uint8_t>>(uint8_seq);
|
||||
ASSERT_EQ(uint8_values.size(), 3);
|
||||
ASSERT_EQ(uint8_values[0], 5);
|
||||
ASSERT_EQ(uint8_values[1], 6);
|
||||
ASSERT_EQ(uint8_values[2], 7);
|
||||
}
|
||||
|
||||
/// Feature: MindAPI
|
||||
|
|
Loading…
Reference in New Issue