forked from mindspore-Ecosystem/mindspore
!29764 Add more required apis in MindAPI
Merge pull request !29764 from hewei/core_api
This commit is contained in:
commit
a42bc503f0
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -46,6 +46,11 @@ class MIND_API AbstractBase : public Base {
|
|||
/// \return A pointer to the Value.
|
||||
ValuePtr value() const;
|
||||
|
||||
/// \brief Get the shape of the abstract.
|
||||
///
|
||||
/// \return A pointer to the shape.
|
||||
ShapePtr shape() const;
|
||||
|
||||
/// \brief Set the type for this abstract.
|
||||
///
|
||||
/// \param[in] type The type to be set.
|
||||
|
@ -55,6 +60,11 @@ class MIND_API AbstractBase : public Base {
|
|||
///
|
||||
/// \param[in] value The value to be set.
|
||||
void set_value(const ValuePtr &value);
|
||||
|
||||
/// \brief Set the shape for this abstract.
|
||||
///
|
||||
/// \param[in] shape The shape to be set.
|
||||
void set_shape(const ShapePtr &shape);
|
||||
};
|
||||
|
||||
/// \brief AbstractScalar describes a scalar's type and value.
|
||||
|
@ -116,11 +126,6 @@ class MIND_API AbstractTensor : public AbstractBase {
|
|||
///
|
||||
/// \return A pointer to the element abstract.
|
||||
AbstractBasePtr element() const;
|
||||
|
||||
/// \brief Get the shape of the abstract.
|
||||
///
|
||||
/// \return A pointer to the shape.
|
||||
ShapePtr shape() const;
|
||||
};
|
||||
|
||||
using AbstractTensorPtr = SharedPtr<AbstractTensor>;
|
||||
|
@ -142,6 +147,11 @@ using AbstractSequencePtr = SharedPtr<AbstractSequence>;
|
|||
class MIND_API AbstractTuple : public AbstractSequence {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(AbstractTuple);
|
||||
|
||||
/// \brief Create AbstractTuple from a list of element abstracts.
|
||||
///
|
||||
/// \param[in] elements A list of abstracts.
|
||||
explicit AbstractTuple(const AbstractBasePtrList &elements);
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_MINDAPI_IR_ABSTRACT_H_
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,6 +27,11 @@ class MIND_API Shape : public Base {
|
|||
public:
|
||||
MIND_API_BASE_MEMBER(Shape);
|
||||
|
||||
/// \brief Create Shape with the given shape dimensions.
|
||||
///
|
||||
/// \param[in] shape The shape dimensions.
|
||||
explicit Shape(const ShapeVector &shape);
|
||||
|
||||
/// \brief Get the shape dimensions.
|
||||
///
|
||||
/// \return The shape dimensions.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,11 +17,13 @@
|
|||
#include "mindapi/ir/abstract.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "ir/dtype.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
using TypeImpl = mindspore::Type;
|
||||
using ValueImpl = mindspore::Value;
|
||||
using ShapeImpl = mindspore::abstract::Shape;
|
||||
using AbstractBaseImpl = mindspore::abstract::AbstractBase;
|
||||
|
||||
MIND_API_BASE_IMPL(AbstractBase, AbstractBaseImpl, Base);
|
||||
|
@ -41,6 +43,11 @@ ValuePtr AbstractBase::value() const {
|
|||
return ToWrapper<Value>(v);
|
||||
}
|
||||
|
||||
ShapePtr AbstractBase::shape() const {
|
||||
auto s = ToRef<AbstractBaseImpl>(impl_).GetShapeTrack();
|
||||
return ToWrapper<Shape>(s);
|
||||
}
|
||||
|
||||
void AbstractBase::set_type(const TypePtr &type) {
|
||||
auto type_impl = ToImpl<TypeImpl>(type);
|
||||
ToRef<AbstractBaseImpl>(impl_).set_type(type_impl);
|
||||
|
@ -51,6 +58,11 @@ void AbstractBase::set_value(const ValuePtr &value) {
|
|||
ToRef<AbstractBaseImpl>(impl_).set_value(value_impl);
|
||||
}
|
||||
|
||||
void AbstractBase::set_shape(const ShapePtr &shape) {
|
||||
auto shape_impl = ToImpl<ShapeImpl>(shape);
|
||||
ToRef<AbstractBaseImpl>(impl_).set_shape(shape_impl);
|
||||
}
|
||||
|
||||
using AbstractScalarImpl = mindspore::abstract::AbstractScalar;
|
||||
|
||||
MIND_API_BASE_IMPL(AbstractScalar, AbstractScalarImpl, AbstractBase);
|
||||
|
@ -84,11 +96,6 @@ AbstractBasePtr AbstractTensor::element() const {
|
|||
return ToWrapper<AbstractBase>(abs);
|
||||
}
|
||||
|
||||
ShapePtr AbstractTensor::shape() const {
|
||||
auto s = ToRef<AbstractTensorImpl>(impl_).shape();
|
||||
return ToWrapper<Shape>(s);
|
||||
}
|
||||
|
||||
using AbstractSequenceImpl = mindspore::abstract::AbstractSequence;
|
||||
|
||||
MIND_API_BASE_IMPL(AbstractSequence, AbstractSequenceImpl, AbstractBase);
|
||||
|
@ -101,4 +108,7 @@ AbstractBasePtrList AbstractSequence::elements() const {
|
|||
using AbstractTupleImpl = mindspore::abstract::AbstractTuple;
|
||||
|
||||
MIND_API_BASE_IMPL(AbstractTuple, AbstractTupleImpl, AbstractSequence);
|
||||
|
||||
AbstractTuple::AbstractTuple(const AbstractBasePtrList &elements)
|
||||
: AbstractSequence(std::make_shared<AbstractTupleImpl>(ToImplVector<AbstractBaseImpl>(elements))) {}
|
||||
} // namespace mindspore::api
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,5 +23,7 @@ using ShapeImpl = mindspore::abstract::Shape;
|
|||
|
||||
MIND_API_BASE_IMPL(Shape, ShapeImpl, Base);
|
||||
|
||||
Shape::Shape(const ShapeVector &shape) : Base(std::make_shared<ShapeImpl>(shape)) {}
|
||||
|
||||
const ShapeVector &Shape::shape() const { return ToRef<ShapeImpl>(impl_).shape(); }
|
||||
} // namespace mindspore::api
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -418,4 +418,31 @@ TEST_F(TestMindApi, test_api_logging) {
|
|||
}
|
||||
ASSERT_TRUE(true);
|
||||
}
|
||||
|
||||
/// Feature: MindAPI
|
||||
/// Description: test AbstractSequence API.
|
||||
/// Expectation: AbstractSequence work as expected.
|
||||
TEST_F(TestMindApi, test_abstract_sequence) {
|
||||
AbstractBasePtrList abs_list;
|
||||
abs_list.emplace_back(MakeShared<AbstractScalar>(int64_t(1)));
|
||||
abs_list.emplace_back(MakeShared<AbstractScalar>(float(1.2f)));
|
||||
abs_list.emplace_back(MakeShared<AbstractScalar>(true));
|
||||
abs_list.emplace_back(MakeShared<AbstractScalar>(std::string("hello")));
|
||||
ShapeVector shape{1, 2, 3};
|
||||
abs_list.emplace_back(MakeShared<AbstractTensor>(TypeId::kNumberTypeFloat32, shape));
|
||||
auto abs_tuple = MakeShared<AbstractTuple>(abs_list);
|
||||
ASSERT_EQ(abs_tuple->elements().size(), abs_list.size());
|
||||
ASSERT_EQ(GetValue<int64_t>(abs_tuple->elements()[0]->value()), 1);
|
||||
ASSERT_TRUE(abs_tuple->elements()[1]->value()->isa<FP32Imm>());
|
||||
ASSERT_TRUE(GetValue<bool>(abs_tuple->elements()[2]->value()));
|
||||
ASSERT_EQ(GetValue<std::string>(abs_tuple->elements()[3]->value()), "hello");
|
||||
ASSERT_TRUE(abs_tuple->elements()[4]->isa<AbstractTensor>());
|
||||
ASSERT_EQ(abs_tuple->elements()[4]->type()->type_id(), TypeId::kObjectTypeTensorType);
|
||||
ASSERT_EQ(abs_tuple->elements()[4]->shape()->shape(), shape);
|
||||
ASSERT_EQ(abs_tuple->elements()[4]->cast<AbstractTensorPtr>()->element()->type()->type_id(),
|
||||
TypeId::kNumberTypeFloat32);
|
||||
ShapeVector shape2{2, 3, 4};
|
||||
abs_tuple->elements()[4]->set_shape(MakeShared<Shape>(shape2));
|
||||
ASSERT_EQ(abs_tuple->elements()[4]->shape()->shape(), shape2);
|
||||
}
|
||||
} // namespace mindspore::api
|
||||
|
|
Loading…
Reference in New Issue