forked from mindspore-Ecosystem/mindspore
Sync enterprise code to gitee.
This commit is contained in:
parent
c88480a804
commit
66b5416f88
|
@ -37,6 +37,7 @@ enum DeviceType {
|
||||||
kAscend,
|
kAscend,
|
||||||
kAscend910,
|
kAscend910,
|
||||||
kAscend310,
|
kAscend310,
|
||||||
|
kCustomDevice,
|
||||||
// add new type here
|
// add new type here
|
||||||
kInvalidDeviceType = 100,
|
kInvalidDeviceType = 100,
|
||||||
};
|
};
|
||||||
|
@ -146,7 +147,7 @@ class MS_API Context {
|
||||||
/// heterogeneous scenarios with multiple members in the vector.
|
/// heterogeneous scenarios with multiple members in the vector.
|
||||||
///
|
///
|
||||||
/// \return Mutable reference of DeviceInfoContext vector in this context.
|
/// \return Mutable reference of DeviceInfoContext vector in this context.
|
||||||
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
|
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Data> data_;
|
std::shared_ptr<Data> data_;
|
||||||
|
@ -182,16 +183,17 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoC
|
||||||
///
|
///
|
||||||
/// \return provider's name.
|
/// \return provider's name.
|
||||||
inline std::string GetProvider() const;
|
inline std::string GetProvider() const;
|
||||||
|
|
||||||
/// \brief set provider's name.
|
/// \brief set provider's name.
|
||||||
///
|
///
|
||||||
/// \param[in] provider define the provider's name.
|
/// \param[in] provider define the provider's name.
|
||||||
|
|
||||||
inline void SetProvider(const std::string &provider);
|
inline void SetProvider(const std::string &provider);
|
||||||
|
|
||||||
/// \brief obtain provider's device type.
|
/// \brief obtain provider's device type.
|
||||||
///
|
///
|
||||||
/// \return provider's device type.
|
/// \return provider's device type.
|
||||||
|
|
||||||
inline std::string GetProviderDevice() const;
|
inline std::string GetProviderDevice() const;
|
||||||
|
|
||||||
/// \brief set provider's device type.
|
/// \brief set provider's device type.
|
||||||
///
|
///
|
||||||
/// \param[in] device define the provider's device type.EG: CPU.
|
/// \param[in] device define the provider's device type.EG: CPU.
|
||||||
|
|
|
@ -95,7 +95,7 @@ std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
|
||||||
return data_->affinity_core_list_;
|
return data_->affinity_core_list_;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
|
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
return data_->device_info_list;
|
return data_->device_info_list;
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,11 +30,13 @@ namespace ops {
|
||||||
namespace {
|
namespace {
|
||||||
abstract::ShapePtr RintInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr RintInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
MS_EXCEPTION_IF_ZERO("Rint input number", input_args.size());
|
||||||
|
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||||
|
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
const int64_t input_num = 1;
|
const int64_t input_num = 1;
|
||||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||||
prim_name);
|
prim_name);
|
||||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
|
||||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||||
auto x = input_args[0]->BuildShape();
|
auto x = input_args[0]->BuildShape();
|
||||||
MS_EXCEPTION_IF_NULL(x);
|
MS_EXCEPTION_IF_NULL(x);
|
||||||
|
@ -44,6 +46,9 @@ abstract::ShapePtr RintInferShape(const PrimitivePtr &primitive, const std::vect
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr RintInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr RintInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_ZERO("Rint input number", input_args.size());
|
||||||
|
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||||
|
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,7 @@ typedef enum {
|
||||||
DT_GPU, /**< GPU device type */
|
DT_GPU, /**< GPU device type */
|
||||||
DT_NPU, /**< NPU device type */
|
DT_NPU, /**< NPU device type */
|
||||||
DT_ASCEND, /**< ASCEND device type */
|
DT_ASCEND, /**< ASCEND device type */
|
||||||
|
DT_CUSTOM, /**< EXTEND device type */
|
||||||
DT_END /**< NO device type */
|
DT_END /**< NO device type */
|
||||||
} DeviceType;
|
} DeviceType;
|
||||||
|
|
||||||
|
|
|
@ -107,6 +107,17 @@ std::shared_ptr<mindspore::AscendDeviceInfo> AscendDeviceInfoFromAscendDeviceCon
|
||||||
ascend_info->SetDynamicImageSize(ascend_context.device_info_.ascend_device_info_.image_size_);
|
ascend_info->SetDynamicImageSize(ascend_context.device_info_.ascend_device_info_.image_size_);
|
||||||
return ascend_info;
|
return ascend_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<mindspore::DeviceInfoContext> CustomDeviceInfoFromCustomDeviceContext(
|
||||||
|
const lite::DeviceContext &inner_context) {
|
||||||
|
if (inner_context.device_type_ != DT_CUSTOM) {
|
||||||
|
MS_LOG(ERROR) << "Function input parameter is not extended context.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto device_info = inner_context.device_info_.custom_device_info_.user_defined_device_info_;
|
||||||
|
MS_CHECK_TRUE_RET(device_info != nullptr, nullptr);
|
||||||
|
return device_info;
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &context) {
|
mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &context) {
|
||||||
|
@ -130,7 +141,8 @@ mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &co
|
||||||
transfer_funcs = {{DT_CPU, CPUDeviceInfoFromCPUDeviceContext},
|
transfer_funcs = {{DT_CPU, CPUDeviceInfoFromCPUDeviceContext},
|
||||||
{DT_GPU, GPUDeviceInfoFromGPUDeviceContext},
|
{DT_GPU, GPUDeviceInfoFromGPUDeviceContext},
|
||||||
{DT_NPU, NPUDeviceInfoFromNPUDeviceContext},
|
{DT_NPU, NPUDeviceInfoFromNPUDeviceContext},
|
||||||
{DT_ASCEND, AscendDeviceInfoFromAscendDeviceContext}};
|
{DT_ASCEND, AscendDeviceInfoFromAscendDeviceContext},
|
||||||
|
{DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext}};
|
||||||
for (auto &device_context : context->device_list_) {
|
for (auto &device_context : context->device_list_) {
|
||||||
auto device_type = device_context.device_type_;
|
auto device_type = device_context.device_type_;
|
||||||
if (transfer_funcs.find(device_type) == transfer_funcs.end()) {
|
if (transfer_funcs.find(device_type) == transfer_funcs.end()) {
|
||||||
|
|
|
@ -211,7 +211,7 @@ bool Context::GetMultiModalHW() const {
|
||||||
return data_->float_mode;
|
return data_->float_mode;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
|
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() const {
|
||||||
static std::vector<std::shared_ptr<DeviceInfoContext>> empty{};
|
static std::vector<std::shared_ptr<DeviceInfoContext>> empty{};
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
|
|
@ -231,7 +231,7 @@ bool Context::GetMultiModalHW() const {
|
||||||
return data_->float_mode;
|
return data_->float_mode;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
|
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() const {
|
||||||
static std::vector<std::shared_ptr<DeviceInfoContext>> empty{};
|
static std::vector<std::shared_ptr<DeviceInfoContext>> empty{};
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
|
|
@ -78,6 +78,14 @@ Status ContextUtils::AddAscendDevice(lite::InnerContext *inner_context, DeviceIn
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status ContextUtils::AddCustomDevice(lite::InnerContext *inner_context,
|
||||||
|
const std::shared_ptr<DeviceInfoContext> &device) {
|
||||||
|
lite::DeviceInfo device_info;
|
||||||
|
device_info.custom_device_info_ = {device};
|
||||||
|
inner_context->device_list_.push_back({lite::DeviceType::DT_CUSTOM, device_info});
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
void ContextUtils::ResetContextDefaultParam(Context *context) {
|
void ContextUtils::ResetContextDefaultParam(Context *context) {
|
||||||
if (context->GetInterOpParallelNum() == 0) {
|
if (context->GetInterOpParallelNum() == 0) {
|
||||||
context->SetInterOpParallelNum(kDefaultInterOpParallelNum);
|
context->SetInterOpParallelNum(kDefaultInterOpParallelNum);
|
||||||
|
@ -147,6 +155,8 @@ std::shared_ptr<lite::InnerContext> ContextUtils::Convert(Context *context) {
|
||||||
ret = AddNpuDevice(npu_context->GetEnableFP16(), npu_context->GetFrequency(), inner_context.get());
|
ret = AddNpuDevice(npu_context->GetEnableFP16(), npu_context->GetFrequency(), inner_context.get());
|
||||||
} else if (device->GetDeviceType() == kAscend) {
|
} else if (device->GetDeviceType() == kAscend) {
|
||||||
ret = AddAscendDevice(inner_context.get(), device.get());
|
ret = AddAscendDevice(inner_context.get(), device.get());
|
||||||
|
} else if (device->GetDeviceType() == kCustomDevice) {
|
||||||
|
ret = AddCustomDevice(inner_context.get(), device);
|
||||||
}
|
}
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Add device failed!";
|
MS_LOG(ERROR) << "Add device failed!";
|
||||||
|
|
|
@ -47,6 +47,7 @@ class ContextUtils {
|
||||||
lite::InnerContext *inner_context);
|
lite::InnerContext *inner_context);
|
||||||
static Status AddNpuDevice(bool enable_fp16, int frequency, lite::InnerContext *inner_context);
|
static Status AddNpuDevice(bool enable_fp16, int frequency, lite::InnerContext *inner_context);
|
||||||
static Status AddAscendDevice(lite::InnerContext *inner_context, DeviceInfoContext *device);
|
static Status AddAscendDevice(lite::InnerContext *inner_context, DeviceInfoContext *device);
|
||||||
|
static Status AddCustomDevice(lite::InnerContext *inner_context, const std::shared_ptr<DeviceInfoContext> &device);
|
||||||
static bool IsAffinityModeValid(int affinity_mode) {
|
static bool IsAffinityModeValid(int affinity_mode) {
|
||||||
return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU;
|
return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#ifdef BFC_MEMORY
|
#ifdef BFC_MEMORY
|
||||||
#include "src/extendrt/dynamic_mem_allocator.h"
|
#include "src/extendrt/dynamic_mem_allocator.h"
|
||||||
|
@ -32,6 +33,10 @@
|
||||||
#endif
|
#endif
|
||||||
#include "include/lite_types.h"
|
#include "include/lite_types.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class DeviceInfoContext;
|
||||||
|
}
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
typedef struct CpuDeviceInfo {
|
typedef struct CpuDeviceInfo {
|
||||||
bool enable_float16_ = false; /**< prior enable float16 inference */
|
bool enable_float16_ = false; /**< prior enable float16 inference */
|
||||||
|
@ -59,11 +64,17 @@ typedef struct AscendDeviceInfo {
|
||||||
std::string image_size_;
|
std::string image_size_;
|
||||||
} AscendDeviceInfo;
|
} AscendDeviceInfo;
|
||||||
|
|
||||||
|
/// \brief CustomDeviceInfo defined for user defined device configuration information.
|
||||||
|
typedef struct CustomDeviceInfo {
|
||||||
|
std::shared_ptr<DeviceInfoContext> user_defined_device_info_;
|
||||||
|
} CustomDeviceInfo;
|
||||||
|
|
||||||
struct DeviceInfo {
|
struct DeviceInfo {
|
||||||
CpuDeviceInfo cpu_device_info_;
|
CpuDeviceInfo cpu_device_info_;
|
||||||
GpuDeviceInfo gpu_device_info_;
|
GpuDeviceInfo gpu_device_info_;
|
||||||
NpuDeviceInfo npu_device_info_;
|
NpuDeviceInfo npu_device_info_;
|
||||||
AscendDeviceInfo ascend_device_info_;
|
AscendDeviceInfo ascend_device_info_;
|
||||||
|
CustomDeviceInfo custom_device_info_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DeviceContext {
|
struct DeviceContext {
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
CUR_DIR=$(cd "$(dirname $0)"; pwd)
|
CUR_DIR=$(
|
||||||
|
cd "$(dirname $0)"
|
||||||
|
pwd
|
||||||
|
)
|
||||||
BUILD_DIR=${CUR_DIR}/../build
|
BUILD_DIR=${CUR_DIR}/../build
|
||||||
|
|
||||||
export GLOG_v=2
|
export GLOG_v=2
|
||||||
|
@ -58,10 +61,13 @@ echo 'run common ut tests'
|
||||||
./lite-test --gtest_filter=TestDeConvolutionFp32*
|
./lite-test --gtest_filter=TestDeConvolutionFp32*
|
||||||
./lite-test --gtest_filter=TestLogicalOrFp32*
|
./lite-test --gtest_filter=TestLogicalOrFp32*
|
||||||
|
|
||||||
|
# test cases of generic api
|
||||||
|
./lite-test --gtest_filter="GenericApiTest*"
|
||||||
|
|
||||||
# test cases of INT8 OP
|
# test cases of INT8 OP
|
||||||
## ./lite-test --gtest_filter=TestPadInt8.*
|
## ./lite-test --gtest_filter=TestPadInt8.*
|
||||||
./lite-test --gtest_filter=TestDeconvInt8.*
|
./lite-test --gtest_filter=TestDeconvInt8.*
|
||||||
if [ "$ENABLE_CONVERTER_TEST" = true ];then
|
if [ "$ENABLE_CONVERTER_TEST" = true ]; then
|
||||||
./lite-test-converter --gtest_filter="ModelParserRegistryTest.TestRegistry"
|
./lite-test-converter --gtest_filter="ModelParserRegistryTest.TestRegistry"
|
||||||
./lite-test-converter --gtest_filter="NodeParserRegistryTest.TestRegistry"
|
./lite-test-converter --gtest_filter="NodeParserRegistryTest.TestRegistry"
|
||||||
./lite-test-converter --gtest_filter="PassRegistryTest.TestRegistry"
|
./lite-test-converter --gtest_filter="PassRegistryTest.TestRegistry"
|
||||||
|
@ -87,7 +93,7 @@ echo 'run inference ut tests'
|
||||||
./lite-test --gtest_filter="ControlFlowTest.TestMergeWhileModel"
|
./lite-test --gtest_filter="ControlFlowTest.TestMergeWhileModel"
|
||||||
|
|
||||||
echo 'run mindrt parallel ut test'
|
echo 'run mindrt parallel ut test'
|
||||||
if [ "$ENABLE_CONVERTER_TEST" = true ];then
|
if [ "$ENABLE_CONVERTER_TEST" = true ]; then
|
||||||
./lite-test-converter --gtest_filter="MindrtParallelTest.*"
|
./lite-test-converter --gtest_filter="MindrtParallelTest.*"
|
||||||
echo 'user set output tensors st test'
|
echo 'user set output tensors st test'
|
||||||
./lite-test --gtest_filter="GraphTest.UserSetGraphOutput*"
|
./lite-test --gtest_filter="GraphTest.UserSetGraphOutput*"
|
||||||
|
@ -117,8 +123,8 @@ echo 'run c api ut test'
|
||||||
echo 'run bfc memory ut test'
|
echo 'run bfc memory ut test'
|
||||||
./lite-test --gtest_filter="DynamicMemManagerTest.*"
|
./lite-test --gtest_filter="DynamicMemManagerTest.*"
|
||||||
|
|
||||||
mindspore_lite_whl=`ls ${CUR_DIR}/../../../output/*.whl`
|
mindspore_lite_whl=$(ls ${CUR_DIR}/../../../output/*.whl)
|
||||||
if [[ -f "${mindspore_lite_whl}" || "$MSLITE_ENABLE_SERVER_INFERENCE" = on ]]; then
|
if [[ -f "${mindspore_lite_whl}" || "$MSLITE_ENABLE_SERVER_INFERENCE" == on ]]; then
|
||||||
# prepare model and inputdata for Python-API ut test
|
# prepare model and inputdata for Python-API ut test
|
||||||
if [ ! -e mobilenetv2.ms ]; then
|
if [ ! -e mobilenetv2.ms ]; then
|
||||||
MODEL_DOWNLOAD_URL="https://download.mindspore.cn/model_zoo/official/lite/quick_start/mobilenetv2.ms"
|
MODEL_DOWNLOAD_URL="https://download.mindspore.cn/model_zoo/official/lite/quick_start/mobilenetv2.ms"
|
||||||
|
@ -148,7 +154,7 @@ else
|
||||||
pytest ${CUR_DIR}/ut/python/test_converter_api.py -s
|
pytest ${CUR_DIR}/ut/python/test_converter_api.py -s
|
||||||
RET=$?
|
RET=$?
|
||||||
if [ ${RET} -ne 0 ]; then
|
if [ ${RET} -ne 0 ]; then
|
||||||
exit ${RET}
|
exit ${RET}
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -157,7 +163,7 @@ else
|
||||||
pytest ${CUR_DIR}/ut/python/test_inference_api.py -s
|
pytest ${CUR_DIR}/ut/python/test_inference_api.py -s
|
||||||
RET=$?
|
RET=$?
|
||||||
if [ ${RET} -ne 0 ]; then
|
if [ ${RET} -ne 0 ]; then
|
||||||
exit ${RET}
|
exit ${RET}
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# run inference CPU Python-API st test
|
# run inference CPU Python-API st test
|
||||||
|
@ -165,16 +171,16 @@ else
|
||||||
pytest ${CUR_DIR}/st/python/test_inference.py::test_cpu_inference_01 -s
|
pytest ${CUR_DIR}/st/python/test_inference.py::test_cpu_inference_01 -s
|
||||||
RET=$?
|
RET=$?
|
||||||
if [ ${RET} -ne 0 ]; then
|
if [ ${RET} -ne 0 ]; then
|
||||||
exit ${RET}
|
exit ${RET}
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$MSLITE_ENABLE_SERVER_INFERENCE" = on ];then
|
if [ "$MSLITE_ENABLE_SERVER_INFERENCE" = on ]; then
|
||||||
echo 'run ModelParallelRunner api ut test'
|
echo 'run ModelParallelRunner api ut test'
|
||||||
./lite-test --gtest_filter="ModelParallelRunnerTest.*"
|
./lite-test --gtest_filter="ModelParallelRunnerTest.*"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$MSLITE_ENABLE_KERNEL_EXECUTOR" = on ];then
|
if [ "$MSLITE_ENABLE_KERNEL_EXECUTOR" = on ]; then
|
||||||
echo 'run kernel executor api ut test'
|
echo 'run kernel executor api ut test'
|
||||||
./lite-test --gtest_filter="KernelExecutorTest.*"
|
./lite-test --gtest_filter="KernelExecutorTest.*"
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <memory>
|
||||||
|
#include "include/api/types.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "common/common_test.h"
|
||||||
|
#include "src/runtime/cxx_api/converters.h"
|
||||||
|
#include "src/common/context_util.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class GenericApiTest : public mindspore::CommonTest {
|
||||||
|
public:
|
||||||
|
GenericApiTest() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
class Tda4DeviceInfo : public mindspore::DeviceInfoContext {
|
||||||
|
public:
|
||||||
|
mindspore::DeviceType GetDeviceType() const override { return mindspore::DeviceType::kCustomDevice; };
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(GenericApiTest, TestConvertContextToInnerContext) {
|
||||||
|
mindspore::Context *context = new (std::nothrow) mindspore::Context();
|
||||||
|
auto &device_list = context->MutableDeviceInfo();
|
||||||
|
auto device_info = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||||
|
auto tda4_device_info = std::make_shared<mindspore::Tda4DeviceInfo>();
|
||||||
|
device_list.push_back(device_info);
|
||||||
|
device_list.push_back(tda4_device_info);
|
||||||
|
|
||||||
|
lite::InnerContext *inner_ctx = ContextUtils::Convert(context);
|
||||||
|
|
||||||
|
ASSERT_EQ(inner_ctx->device_list_.size(), device_list.size());
|
||||||
|
ASSERT_EQ(inner_ctx->device_list_[0].device_type_, mindspore::lite::DT_CPU);
|
||||||
|
ASSERT_EQ(inner_ctx->device_list_[1].device_type_, mindspore::lite::DT_CUSTOM);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GenericApiTest, TestConvertInnerContextToContext) {
|
||||||
|
mindspore::Context *context = new (std::nothrow) mindspore::Context();
|
||||||
|
auto &device_list = context->MutableDeviceInfo();
|
||||||
|
auto device_info = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||||
|
auto tda4_device_info = std::make_shared<mindspore::Tda4DeviceInfo>();
|
||||||
|
device_list.push_back(device_info);
|
||||||
|
device_list.push_back(tda4_device_info);
|
||||||
|
|
||||||
|
lite::InnerContext *inner_ctx = ContextUtils::Convert(context);
|
||||||
|
mindspore::Context *ctx = MSContextFromContext(inner_ctx);
|
||||||
|
auto &new_device_list = ctx->MutableDeviceInfo();
|
||||||
|
|
||||||
|
ASSERT_EQ(new_device_list.size(), device_list.size());
|
||||||
|
ASSERT_EQ(new_device_list[0]->GetDeviceType(), mindspore::DeviceType::kCPU);
|
||||||
|
ASSERT_EQ(new_device_list[1]->GetDeviceType(), mindspore::DeviceType::kCustomDevice);
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
|
@ -22,7 +22,8 @@ int EliminateRedundantCastPass::RemoveCastOp(const AnfNodePtr &anf_node, const F
|
||||||
const int expected_cast_input_count = 3;
|
const int expected_cast_input_count = 3;
|
||||||
auto cast_cnode = anf_node->cast<CNodePtr>();
|
auto cast_cnode = anf_node->cast<CNodePtr>();
|
||||||
MS_CHECK_TRUE_RET(cast_cnode->inputs().size() == expected_cast_input_count, lite::RET_NO_CHANGE);
|
MS_CHECK_TRUE_RET(cast_cnode->inputs().size() == expected_cast_input_count, lite::RET_NO_CHANGE);
|
||||||
TypeId first_type, second_type;
|
TypeId first_type;
|
||||||
|
TypeId second_type;
|
||||||
if (opt::GetDataTypeFromAnfNode(cast_cnode->input(1), &first_type) != RET_OK) {
|
if (opt::GetDataTypeFromAnfNode(cast_cnode->input(1), &first_type) != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Failed to get " << anf_node->fullname_with_scope() << " output tensor data type.";
|
MS_LOG(ERROR) << "Failed to get " << anf_node->fullname_with_scope() << " output tensor data type.";
|
||||||
return lite::RET_NO_CHANGE;
|
return lite::RET_NO_CHANGE;
|
||||||
|
|
Loading…
Reference in New Issue