!28891 Add UInt8Imm to MindAPI

Merge pull request !28891 from hewei/fix_r1.6
This commit is contained in:
i-robot 2022-01-12 07:15:55 +00:00 committed by Gitee
commit 5a788fde8e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 43 additions and 7 deletions

View File

@ -100,13 +100,13 @@ using StringImmPtr = SharedPtr<StringImm>;
MIND_API_IMM_TRAIT(StringImm, std::string);
/// \beief Scalar defines interface for scalar data.
/// \brief Scalar defines interface for scalar data.
class MIND_API Scalar : public Value {
public:
MIND_API_BASE_MEMBER(Scalar);
};
/// \beief BoolImm defines interface for bool data.
/// \brief BoolImm defines interface for bool data.
class MIND_API BoolImm : public Scalar {
public:
MIND_API_BASE_MEMBER(BoolImm);
@ -126,20 +126,20 @@ using BoolImmPtr = SharedPtr<BoolImm>;
MIND_API_IMM_TRAIT(BoolImm, bool);
/// \beief IntegerImm defines interface for integer data.
/// \brief IntegerImm defines interface for integer data.
class MIND_API IntegerImm : public Scalar {
public:
MIND_API_BASE_MEMBER(IntegerImm);
};
/// \beief Int64Imm defines interface for int64 data.
/// \brief Int64Imm defines interface for int64 data.
class MIND_API Int64Imm : public IntegerImm {
public:
MIND_API_BASE_MEMBER(Int64Imm);
/// \brief Create Int64Imm with the given int64 value.
///
/// \param[in] value The given bool value.
/// \param[in] value The given int64 value.
explicit Int64Imm(int64_t value);
/// \brief Get the int64 value of this Int64Imm.
@ -152,13 +152,33 @@ using Int64ImmPtr = SharedPtr<Int64Imm>;
MIND_API_IMM_TRAIT(Int64Imm, int64_t);
/// \beief FloatImm defines interface for float data.
/// \brief UInt8Imm defines interface for uint8 data.
class MIND_API UInt8Imm : public IntegerImm {
public:
MIND_API_BASE_MEMBER(UInt8Imm);
/// \brief Create UInt8Imm with the given uint8 value.
///
/// \param[in] value The given uint8 value.
explicit UInt8Imm(uint8_t value);
/// \brief Get the uint8 value of this UInt8Imm.
///
/// \return The uint8 value of this UInt8Imm.
uint8_t value() const;
};
using UInt8ImmPtr = SharedPtr<UInt8Imm>;
MIND_API_IMM_TRAIT(UInt8Imm, uint8_t);
/// \brief FloatImm defines interface for float data.
class MIND_API FloatImm : public Scalar {
public:
MIND_API_BASE_MEMBER(FloatImm);
};
/// \beief FP32Imm defines interface for float32 data.
/// \brief FP32Imm defines interface for float32 data.
class MIND_API FP32Imm : public FloatImm {
public:
MIND_API_BASE_MEMBER(FP32Imm);

View File

@ -33,6 +33,7 @@ using ScalarImpl = mindspore::Scalar;
using BoolImmImpl = mindspore::BoolImm;
using IntegerImmImpl = mindspore::IntegerImm;
using Int64ImmImpl = mindspore::Int64Imm;
using UInt8ImmImpl = mindspore::UInt8Imm;
using FloatImmImpl = mindspore::FloatImm;
using FP32ImmImpl = mindspore::FP32Imm;
@ -84,6 +85,12 @@ Int64Imm::Int64Imm(int64_t value) : IntegerImm(std::make_shared<Int64ImmImpl>(va
int64_t Int64Imm::value() const { return ToRef<Int64ImmImpl>(impl_).value(); }
MIND_API_BASE_IMPL(UInt8Imm, UInt8ImmImpl, IntegerImm);
UInt8Imm::UInt8Imm(uint8_t value) : UInt8Imm(std::make_shared<UInt8ImmImpl>(value)) {}
uint8_t UInt8Imm::value() const { return ToRef<UInt8ImmImpl>(impl_).value(); }
MIND_API_BASE_IMPL(FloatImm, FloatImmImpl, Scalar);
MIND_API_BASE_IMPL(FP32Imm, FP32ImmImpl, FloatImm);

View File

@ -141,6 +141,15 @@ TEST_F(TestMindApi, test_values) {
ASSERT_EQ(utils::cast<int64_t>(value_list[0]), 3);
ASSERT_EQ(utils::cast<int64_t>(value_list[1]), 4);
ASSERT_EQ(utils::cast<int64_t>(value_list[2]), 5);
std::vector<uint8_t> vec_uint8{5, 6, 7};
auto uint8_seq = MakeValue<std::vector<uint8_t>>(vec_uint8);
ASSERT_TRUE(uint8_seq->isa<ValueSequence>());
auto uint8_values = GetValue<std::vector<uint8_t>>(uint8_seq);
ASSERT_EQ(uint8_values.size(), 3);
ASSERT_EQ(uint8_values[0], 5);
ASSERT_EQ(uint8_values[1], 6);
ASSERT_EQ(uint8_values[2], 7);
}
/// Feature: MindAPI