forked from mindspore-Ecosystem/mindspore
!47997 [JIT Fallback] Don't release GIL lock during running in Python fallback routine.
Merge pull request !47997 from 张清华/opt_jit_fallback2
This commit is contained in:
commit
9e0108a31a
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -31,6 +31,7 @@
|
|||
#include "kernel/common_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt::dynamic_shape {
|
||||
|
@ -197,6 +198,12 @@ abstract::AbstractBasePtr MakeNewAbstract(const AnfNodePtr &input, const tensor:
|
|||
if (abs->isa<abstract::AbstractTensor>()) {
|
||||
new_abs = abs->Clone();
|
||||
new_abs->set_value(depended_value);
|
||||
|
||||
// Set user data for PyExecute infer.
|
||||
if (input->has_user_data<kernel::PyExecuteOutputData>()) {
|
||||
const auto &output_data = input->user_data<kernel::PyExecuteOutputData>();
|
||||
new_abs->set_user_data<kernel::PyExecuteOutputData>(output_data);
|
||||
}
|
||||
} else if (abs->isa<abstract::AbstractScalar>()) {
|
||||
auto type = depended_value->Dtype()->type_id();
|
||||
if (type == kNumberTypeInt32) {
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_FALLBACK_RUNNING_H_
|
||||
#define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_FALLBACK_RUNNING_H_
|
||||
|
||||
#include "include/common/visible.h"
|
||||
|
||||
namespace mindspore {
|
||||
class COMMON_EXPORT ScopedFallbackRunning final {
|
||||
public:
|
||||
ScopedFallbackRunning();
|
||||
~ScopedFallbackRunning();
|
||||
|
||||
inline static bool on() { return on_; }
|
||||
|
||||
private:
|
||||
inline static bool on_{false};
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_FALLBACK_RUNNING_H_
|
|
@ -42,6 +42,7 @@ AnfNodePtr ConvertInterpretedObjectToPyExecute(const FuncGraphPtr &fg, const Val
|
|||
// Set the value node into dict firstly.
|
||||
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||
constexpr auto set_local_variable = "set_local_variable";
|
||||
MS_LOG(DEBUG) << set_local_variable << "(" << value_node_key << ", " << value_node_value << ")";
|
||||
(void)python_adapter::CallPyModFn(mod, set_local_variable, value_node_key, value_node_value);
|
||||
|
||||
// Get the value node from the dict in IR.
|
||||
|
|
|
@ -632,7 +632,6 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
return std::make_shared<MixedPrecisionCastEvaluator>(prim);
|
||||
}
|
||||
if (IsPrimitiveEquals(prim, prim::kPrimPyExecute)) {
|
||||
prim::kPrimPyExecute->AddAttr("primitive_target", MakeValue("CPU"));
|
||||
return std::make_shared<PyExecuteEvaluator>();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
#include "pipeline/pynative/pynative_utils.h"
|
||||
#include "include/common/utils/scoped_long_running.h"
|
||||
#include "include/common/utils/python_fallback_running.h"
|
||||
#include "backend/graph_compiler/transform.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
|
@ -415,9 +416,16 @@ void ForwardExecutor::Sync() {
|
|||
}
|
||||
|
||||
ValuePtr ForwardExecutor::RunOpInMs(const FrontendOpRunInfoPtr &op_run_info) {
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
if (!ScopedFallbackRunning::on()) {
|
||||
mindspore::ScopedLongRunning long_running;
|
||||
return RunOpInMsInner(op_run_info);
|
||||
}
|
||||
return RunOpInMsInner(op_run_info);
|
||||
}
|
||||
|
||||
ValuePtr ForwardExecutor::RunOpInMsInner(const FrontendOpRunInfoPtr &op_run_info) {
|
||||
MS_LOG(DEBUG) << "RunOpInMs start";
|
||||
mindspore::ScopedLongRunning long_running;
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
device_id_ = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -87,6 +87,7 @@ class ForwardExecutor {
|
|||
}
|
||||
ValuePtr RunOpInVM(const FrontendOpRunInfoPtr &op_run_info) const;
|
||||
ValuePtr RunOpInMs(const FrontendOpRunInfoPtr &op_run_info);
|
||||
ValuePtr RunOpInMsInner(const FrontendOpRunInfoPtr &op_run_info);
|
||||
ValuePtr RunOpWithBackendPolicy(const FrontendOpRunInfoPtr &op_run_info);
|
||||
void GetOutput(const FrontendOpRunInfoPtr &op_run_info);
|
||||
// Mix precision and Implicit transform
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "abstract/utils.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_common.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
#include "include/common/utils/python_fallback_running.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "mindspore/ccsrc/pipeline/jit/parse/resolve.h"
|
||||
|
||||
|
@ -282,6 +283,15 @@ py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<Address
|
|||
return local_dict;
|
||||
}
|
||||
|
||||
void TensorToRawMemory(const tensor::TensorPtr &tensor, const AddressPtr &address) {
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
const auto &res = memcpy_s(address->addr, address->size, tensor->data_c(), tensor->Size());
|
||||
if (res != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy failed. res: " << res;
|
||||
}
|
||||
}
|
||||
|
||||
bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
MS_LOG(DEBUG) << "Launch PyExecute(), inputs.size: " << inputs.size() << ", outputs: " << outputs.size();
|
||||
|
@ -317,8 +327,13 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
params[0] = global_dict;
|
||||
params[1] = local_dict;
|
||||
MS_LOG(DEBUG) << "Python script: " << py_script << ", params: " << params;
|
||||
mindspore::ScopedFallbackRunning fallback_running;
|
||||
const auto &output = CallPythonScript(py_script, params);
|
||||
MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output;
|
||||
const auto &output_type = py::str(output.get_type());
|
||||
MS_LOG(DEBUG) << "Python output type: " << output_type << ", output: " << output;
|
||||
if (output_type.cast<std::string>() == "<class 'mindspore.common.tensor.Tensor'>") { // It's Python Tensor type.
|
||||
TensorToRawMemory(output.cast<tensor::TensorPtr>(), outputs[0]);
|
||||
}
|
||||
AttachPyOutputData(output);
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,9 +18,12 @@
|
|||
#include "pybind_api/pybind_patch.h"
|
||||
|
||||
#include "mindspore/core/ops/py_execute.h"
|
||||
#include "mindspore/ccsrc/include/common/utils/convert_utils_py.h"
|
||||
#include "mindspore/ccsrc/include/common/utils/python_adapter.h"
|
||||
#include "mindspore/ccsrc/include/common/utils/python_fallback_running.h"
|
||||
#include "mindspore/ccsrc/pipeline/jit/parse/data_converter.h"
|
||||
#include "mindspore/ccsrc/pybind_api/ir/tensor_py.h"
|
||||
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
|
@ -40,8 +43,7 @@ class PyExecuteInitializer {
|
|||
~PyExecuteInitializer() = default;
|
||||
|
||||
private:
|
||||
// TODO(zh_qh): Will check the abstract shape and type later.
|
||||
static void InferPy(const std::vector<AbstractBasePtr> &input_args) {
|
||||
static abstract::ShapePtr InferPy(const std::vector<AbstractBasePtr> &input_args) {
|
||||
const auto &script_abs = input_args[0];
|
||||
const auto &script = script_abs->BuildValue();
|
||||
const auto &script_str = dyn_cast<StringImm>(script);
|
||||
|
@ -49,12 +51,20 @@ class PyExecuteInitializer {
|
|||
const auto &keys_tuple_abs = input_args[1];
|
||||
const auto &keys_tuple = keys_tuple_abs->BuildValue();
|
||||
const auto &keys = dyn_cast<ValueSequence>(keys_tuple);
|
||||
if (keys == nullptr) {
|
||||
MS_LOG(DEBUG) << "The keys is not tuple value, but got " << keys_tuple->ToString();
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({1}));
|
||||
}
|
||||
const auto &values_tuple_abs = input_args[2];
|
||||
const auto &values_tuple = values_tuple_abs->BuildValue();
|
||||
if (values_tuple == kAnyValue) {
|
||||
MS_LOG(EXCEPTION) << "Value tuple should not be anyvalue.";
|
||||
}
|
||||
const auto &values = dyn_cast<ValueSequence>(values_tuple);
|
||||
if (values == nullptr) {
|
||||
MS_LOG(DEBUG) << "The values is not tuple value, but got " << keys_tuple->ToString();
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({1}));
|
||||
}
|
||||
MS_LOG(DEBUG) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
|
||||
<< ", values_tuple: " << values_tuple->ToString();
|
||||
|
||||
|
@ -68,9 +78,18 @@ class PyExecuteInitializer {
|
|||
const auto &tuple_abs = values_tuple_abs->cast<abstract::AbstractSequencePtr>();
|
||||
const auto &value_abs = (*tuple_abs)[i];
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
const auto &tensor = value->cast<tensor::TensorPtr>();
|
||||
const auto &py_array_value = python_adapter::PyAdapterCallback::TensorToNumpy(*tensor);
|
||||
local_dict[py::str(key_str->value())] = py_array_value;
|
||||
if (value_abs->has_user_data<kernel::PyExecuteOutputData>()) {
|
||||
const auto &output_data = value_abs->user_data<kernel::PyExecuteOutputData>();
|
||||
auto obj = output_data->obj;
|
||||
local_dict[py::str(key_str->value())] = obj;
|
||||
} else {
|
||||
const auto &py_tensor = ValueToPyData(value);
|
||||
local_dict[py::str(key_str->value())] = py_tensor;
|
||||
}
|
||||
continue;
|
||||
} else if (value->isa<StringImm>()) {
|
||||
const auto &str_imm = value->cast<StringImmPtr>();
|
||||
local_dict[py::str(key_str->value())] = py::str(str_imm->value());
|
||||
continue;
|
||||
}
|
||||
local_dict[py::str(key_str->value())] = value;
|
||||
|
@ -81,8 +100,15 @@ class PyExecuteInitializer {
|
|||
params[0] = global_dict;
|
||||
params[1] = local_dict;
|
||||
MS_LOG(DEBUG) << "Python script: " << py_script << ", params: " << params;
|
||||
mindspore::ScopedFallbackRunning fallback_running;
|
||||
const auto &output = parse::data_converter::CallPythonScript(py_script, params);
|
||||
MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output;
|
||||
if (py::isinstance<tensor::Tensor>(output)) {
|
||||
const auto &tensor = output.cast<tensor::TensorPtr>();
|
||||
return std::make_shared<abstract::Shape>(tensor->shape());
|
||||
}
|
||||
|
||||
return std::make_shared<abstract::Shape>(ShapeVector({1}));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/common/utils/python_fallback_running.h"
|
||||
|
||||
namespace mindspore {
|
||||
ScopedFallbackRunning::ScopedFallbackRunning() { on_ = true; }
|
||||
|
||||
ScopedFallbackRunning::~ScopedFallbackRunning() { on_ = false; }
|
||||
} // namespace mindspore
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -30,6 +30,7 @@
|
|||
#include "utils/any.h"
|
||||
#include "utils/hash_map.h"
|
||||
#include "base/base.h"
|
||||
#include "base/user_data.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/tensor.h"
|
||||
|
@ -199,6 +200,59 @@ class MS_CORE_API AbstractBase : public Base {
|
|||
/// \return A pointer to the broadened abstract.
|
||||
virtual AbstractBasePtr PartialBroaden() const;
|
||||
|
||||
/// \brief Set user data.
|
||||
///
|
||||
/// \param[in] key The key of user data.
|
||||
/// \param[in] value The value of user data.
|
||||
template <typename T>
|
||||
void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
|
||||
user_data_.set<T>(key, value);
|
||||
}
|
||||
|
||||
/// \brief Set user data.
|
||||
///
|
||||
/// \param[in] value The value of user data.
|
||||
template <typename T>
|
||||
void set_user_data(const std::shared_ptr<T> &value) {
|
||||
user_data_.set<T>(T::key, value);
|
||||
}
|
||||
|
||||
/// \brief Get user data.
|
||||
///
|
||||
/// \param[in] key The key of user data.
|
||||
/// \return Pointer to user data.
|
||||
template <typename T>
|
||||
std::shared_ptr<T> user_data(const std::string &key) const {
|
||||
return user_data_.get<T>(key);
|
||||
}
|
||||
|
||||
/// \brief Set user data.
|
||||
///
|
||||
/// \return Pointer to user data.
|
||||
template <typename T>
|
||||
std::shared_ptr<T> user_data() const {
|
||||
return user_data_.get<T>(T::key);
|
||||
}
|
||||
|
||||
/// \brief Check whether there is corresponding user data by the given key.
|
||||
///
|
||||
/// \param[in] key The key of user data.
|
||||
/// \return True if it exists, otherwise false.
|
||||
bool has_user_data(const std::string &key) const { return user_data_.has(key); }
|
||||
|
||||
/// \brief Check if there is user data.
|
||||
///
|
||||
/// \return True if it exists, otherwise false.
|
||||
template <typename T>
|
||||
bool has_user_data() const {
|
||||
return user_data_.has(T::key);
|
||||
}
|
||||
|
||||
/// \brief Clone user data.
|
||||
///
|
||||
/// \param[in] abstract Abstract used to copy user data.
|
||||
void CloneUserData(const AbstractBasePtr &abstract) { user_data_ = abstract->user_data_; }
|
||||
|
||||
/// \brief Process the abstract with InterpretedObject.
|
||||
using InterpretBoolChecker = std::pair<bool, bool> (*)(const AbstractBasePtr &cond);
|
||||
static inline InterpretBoolChecker interpret_bool_checker_ = nullptr;
|
||||
|
@ -221,6 +275,7 @@ class MS_CORE_API AbstractBase : public Base {
|
|||
TypePtr type_;
|
||||
BaseShapePtr shape_;
|
||||
std::string value_desc_; // store initial value description for error report
|
||||
UserData user_data_;
|
||||
};
|
||||
|
||||
/// \brief Class AbstractScalar describes a scalar's type and value.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -1604,13 +1604,12 @@ GVAR_DEF(PrimitivePtr, kPrimSetSize, std::make_shared<Primitive>(kSetSize));
|
|||
// JIT Fallback ops
|
||||
// We add IO side-effect for them in advance.
|
||||
GVAR_DEF(PrimitivePtr, kPrimPyInterpret,
|
||||
std::make_shared<Primitive>("PyInterpret",
|
||||
mindspore::HashMap<std::string, ValuePtr>({{std::string(GRAPH_FLAG_SIDE_EFFECT_IO),
|
||||
std::make_shared<BoolImm>(true)}})));
|
||||
std::make_shared<Primitive>("PyInterpret", mindspore::HashMap<std::string, ValuePtr>(
|
||||
{{std::string(GRAPH_FLAG_SIDE_EFFECT_IO), MakeValue(true)}})));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPyExecute,
|
||||
std::make_shared<Primitive>("PyExecute",
|
||||
mindspore::HashMap<std::string, ValuePtr>({{std::string(GRAPH_FLAG_SIDE_EFFECT_IO),
|
||||
std::make_shared<BoolImm>(true)}})));
|
||||
std::make_shared<Primitive>("PyExecute", mindspore::HashMap<std::string, ValuePtr>(
|
||||
{{std::string(GRAPH_FLAG_SIDE_EFFECT_IO), MakeValue(true)},
|
||||
{std::string("primitive_target"), MakeValue("CPU")}})));
|
||||
|
||||
// Other primitive not used by backend but used in core;
|
||||
GVAR_DEF(PrimitivePtr, kPrimStateSetItem, std::make_shared<Primitive>("state_setitem"));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,17 +28,6 @@ MIND_API_OPERATOR_IMPL(PyExecute, BaseOperator);
|
|||
|
||||
BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
ShapeVector out_shape = {1};
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr PyExecuteInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
|
||||
return kFloat64;
|
||||
}
|
||||
|
||||
AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
|
@ -48,12 +37,17 @@ AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEngine
|
|||
if (infer_handler_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "infer_handler_ should not be null.";
|
||||
}
|
||||
infer_handler_(input_args);
|
||||
return infer_handler_(input_args);
|
||||
}
|
||||
|
||||
const auto &type = InferType(primitive, input_args);
|
||||
const auto &shape = InferShape(primitive, input_args);
|
||||
const auto &abstract = MakeAbstract(shape, type);
|
||||
return abstract;
|
||||
TypePtr PyExecuteInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
|
||||
return kFloat64;
|
||||
}
|
||||
|
||||
AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
MS_LOG(EXCEPTION) << "Should not invoke InferShapeAndType.";
|
||||
}
|
||||
|
||||
std::set<int64_t> PyExecuteInfer::GetValueDependArgIndices() const { return {-1}; }
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -49,7 +49,7 @@ class MIND_API PyExecuteInfer : public abstract::OpInferBase {
|
|||
|
||||
std::set<int64_t> GetValueDependArgIndices() const override;
|
||||
|
||||
using InferHandler = void (*)(const std::vector<AbstractBasePtr> &);
|
||||
using InferHandler = abstract::ShapePtr (*)(const std::vector<AbstractBasePtr> &);
|
||||
static void set_infer_handler(const InferHandler &infer_handler) { infer_handler_ = infer_handler; }
|
||||
|
||||
private:
|
||||
|
|
|
@ -45,7 +45,7 @@ class Net(ms.nn.Cell):
|
|||
return self.np_function(a, b)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -59,9 +59,7 @@ def test_fallback_np():
|
|||
a = ms.Tensor(np.array(4), ms.int32)
|
||||
b = ms.Tensor(np.array(5), ms.int32)
|
||||
output = Net()(a, b)
|
||||
print(f'output: {output}')
|
||||
const_output = ConstNet()()
|
||||
print(f'const_output: {const_output}')
|
||||
np.testing.assert_almost_equal(output, const_output, 3)
|
||||
|
||||
|
||||
|
@ -75,7 +73,7 @@ class Net1(ms.nn.Cell):
|
|||
return self.np_function(a, b)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -89,9 +87,7 @@ def test_fallback_np_asnumpy():
|
|||
a = ms.Tensor(np.array(4), ms.int32)
|
||||
b = ms.Tensor(np.array(5), ms.int32)
|
||||
output = Net1()(a, b)
|
||||
print(f'output: {output}')
|
||||
const_output = ConstNet()()
|
||||
print(f'const_output: {const_output}')
|
||||
np.testing.assert_almost_equal(output, const_output, 3)
|
||||
|
||||
|
||||
|
@ -102,7 +98,7 @@ def tensor_asnumpy():
|
|||
return res
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -117,7 +113,7 @@ def test_jit_tensor_asnumpy():
|
|||
print(res)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -136,11 +132,10 @@ def test_dict_return_1():
|
|||
return z
|
||||
|
||||
out = dict_net_1()
|
||||
print(f'out: {out}')
|
||||
assert out == {'y': 'a'}
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -160,11 +155,11 @@ def test_dict_return_2():
|
|||
return z
|
||||
|
||||
out = dict_net_2()
|
||||
print(f'out: {out}')
|
||||
assert out == {'a': ms.Tensor(np.array(1), ms.int64)}
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support None and Scalar in dict.")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -183,11 +178,10 @@ def test_dict_return_3():
|
|||
return z
|
||||
|
||||
out = dict_net_3()
|
||||
print(f'out: {out}')
|
||||
assert out == {'y': 'a', 'u': 9, 'v': False, 'w': None}
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -207,10 +201,10 @@ def test_dict_get_2():
|
|||
return z
|
||||
|
||||
out = dict_net_2()
|
||||
print(f'out: {out}')
|
||||
assert out == {'a': ms.Tensor(np.array(1), ms.int64), 'b': 'hello', 'c': 'world'}
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -230,7 +224,7 @@ def test_dict_get_3():
|
|||
return z
|
||||
|
||||
out = dict_net_3()
|
||||
print(f'out: {out}')
|
||||
assert out == {'y': ms.Tensor(np.array(1), ms.int64), 'a': 'a', 'b': 'c'}
|
||||
|
||||
|
||||
def weight_variable():
|
||||
|
@ -253,7 +247,7 @@ def fc_with_initialize(input_channels, out_channels):
|
|||
return ms.nn.Dense(input_channels, out_channels, weight, bias)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -298,13 +292,12 @@ def test_net_dict_1():
|
|||
net = DictLeNetNet()
|
||||
x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32))
|
||||
outputs = net(x)
|
||||
print(f'outputs: {outputs}')
|
||||
assert outputs['conv1'].shape == (64, 6, 28, 28)
|
||||
assert outputs['conv2'].shape == (64, 16, 10, 10)
|
||||
assert outputs['fc'].shape == (64, 10)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -349,13 +342,12 @@ def test_net_dict_2():
|
|||
net = DictLeNetNet()
|
||||
x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32))
|
||||
outputs = net(x)
|
||||
print(f'outputs: {outputs}')
|
||||
assert outputs['conv1'].shape == (64, 6, 28, 28)
|
||||
assert outputs['conv2'].shape == (64, 16, 10, 10)
|
||||
assert outputs['fc'].shape == (64, 10)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -384,7 +376,6 @@ def test_getattr_cust_class():
|
|||
|
||||
net = GetattrClassNet()
|
||||
out = net()
|
||||
print(f'out: {out}')
|
||||
assert out == 100
|
||||
|
||||
|
||||
|
@ -411,8 +402,8 @@ class SelfObjectGetattrNet(ms.nn.Cell):
|
|||
|
||||
def __init__(self, v1, v2):
|
||||
super(SelfObjectGetattrNet, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.softmax = nn.Softmax(0)
|
||||
self.relu = ms.nn.ReLU()
|
||||
self.softmax = ms.nn.Softmax(0)
|
||||
self.axis = 0
|
||||
self.test_class = ClassTest("test_class", v1)
|
||||
self.value = v2
|
||||
|
@ -424,7 +415,7 @@ class SelfObjectGetattrNet(ms.nn.Cell):
|
|||
|
||||
|
||||
@pytest.mark.skip(reason="Stuck by ScopedLongRunning() invocation in forward.cc during JIT Fallback Python running.")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -443,7 +434,6 @@ def test_call_other_object_method_runtime():
|
|||
net = SelfObjectGetattrNet(y, y1)
|
||||
output = net.construct(x)
|
||||
result = output.asnumpy()
|
||||
print(result)
|
||||
assert np.all(result == z)
|
||||
|
||||
|
||||
|
@ -472,7 +462,7 @@ class GlobalObjectGetattrNet(ms.nn.Cell):
|
|||
|
||||
|
||||
@pytest.mark.skip(reason="Stuck by ScopedLongRunning() invocation in forward.cc during JIT Fallback Python running.")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -490,12 +480,11 @@ def test_call_no_self_other_object_method_runtime():
|
|||
net = GlobalObjectGetattrNet(y)
|
||||
output = net.construct(x)
|
||||
result = output.asnumpy()
|
||||
print(result)
|
||||
assert np.all(result == z)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -518,7 +507,7 @@ def test_getattr_tensor_with_wrong_attr():
|
|||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -541,7 +530,7 @@ def test_getattr_list_with_wrong_attr():
|
|||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -564,7 +553,7 @@ def test_getattr_tuple_with_wrong_attr():
|
|||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
Loading…
Reference in New Issue