forked from mindspore-Ecosystem/mindspore
!46303 MSAdapter
Merge pull request !46303 from huangbingjian/adapter_tensor
This commit is contained in:
commit
92bf98fa4e
|
@ -12,6 +12,7 @@
|
|||
"mindspore/mindspore/python/mindspore/_check_version.py" "unused-import"
|
||||
"mindspore/mindspore/python/mindspore/_check_version.py" "broad-except"
|
||||
"mindspore/mindspore/python/mindspore/common/api.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/common/api.py" "not-callable"
|
||||
"mindspore/mindspore/python/mindspore/common/parameter.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/common/parameter.py" "no-value-for-parameter"
|
||||
"mindspore/mindspore/python/mindspore/common/hook_handle.py" "protected-access"
|
||||
|
@ -72,6 +73,7 @@
|
|||
"mindspore/mindspore/python/mindspore/ops/function/array_func.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py" "unused-variable"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/_inner_ops.py" "not-callable"
|
||||
|
||||
# MindData
|
||||
"mindspore/mindspore/python/mindspore/dataset/__init__.py" "redefined-builtin"
|
||||
|
@ -175,6 +177,11 @@
|
|||
"mindspore/tests/" "c-extension-no-member"
|
||||
"mindspore/tests/st/parameter/test_parameter_celllist.py" "protected-access"
|
||||
"mindspore/tests/ut/python/rewrite/test_cellcontainer.py" "protected-access"
|
||||
"mindspore/tests/st/ms_adapter/_register/utils.py" "protected-access"
|
||||
"mindspore/tests/st/ms_adapter/_register/utils.py" "unused-variable"
|
||||
"mindspore/tests/st/ms_adapter/test_ms_adapter_base.py" "unidiomatic-typecheck"
|
||||
"mindspore/tests/st/ms_adapter/test_operator.py" "unidiomatic-typecheck"
|
||||
"mindspore/tests/st/ms_adapter/test_python_builtins.py" "unidiomatic-typecheck"
|
||||
|
||||
#MindSpore Lite
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin"
|
||||
|
|
|
@ -214,6 +214,22 @@ AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
|
||||
}
|
||||
|
||||
bool IsAdapterTensor(const AbstractBasePtr &x) {
|
||||
if (!x->isa<abstract::AbstractTensor>()) {
|
||||
return false;
|
||||
}
|
||||
return x->cast<abstract::AbstractTensorPtr>()->is_adapter();
|
||||
}
|
||||
|
||||
bool IsAdapterTensorClassType(const AbstractBasePtr &cmp) {
|
||||
auto cmp_value = cmp->BuildValue();
|
||||
if (!cmp_value->isa<parse::ClassType>()) {
|
||||
return false;
|
||||
}
|
||||
auto class_obj = cmp_value->cast<parse::ClassTypePtr>()->obj();
|
||||
return py::hasattr(class_obj, PYTHON_ADAPTER_TENSOR);
|
||||
}
|
||||
|
||||
bool CheckPythonIsInstance(const py::object &x, const AbstractBasePtr &cmp, const py::module &mod, bool is_const) {
|
||||
if (cmp->isa<abstract::AbstractTuple>()) {
|
||||
const auto &cmp_tuple_elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
|
||||
|
@ -401,6 +417,11 @@ AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePt
|
|||
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(result), kBool);
|
||||
}
|
||||
|
||||
// x is adapter tensor.
|
||||
if (IsAdapterTensor(x) && IsAdapterTensorClassType(cmp)) {
|
||||
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(true), kBool);
|
||||
}
|
||||
|
||||
auto x_value = x->BuildValue();
|
||||
// x is variable built-in type.
|
||||
if (x_value == kAnyValue) {
|
||||
|
@ -1112,6 +1133,42 @@ std::optional<StandardPrimitiveImplReg> GetFrontendPrimitiveInferImpl(const Prim
|
|||
return std::optional<StandardPrimitiveImplReg>();
|
||||
}
|
||||
|
||||
AbstractBasePtr SetAdapterFlag(const std::string &op_name, const AbstractBasePtr &abs_input, bool adapter_flag) {
|
||||
MS_EXCEPTION_IF_NULL(abs_input);
|
||||
// Clone is needed here.
|
||||
if (abs_input->isa<AbstractRefTensor>()) {
|
||||
auto abs_ref = abs_input->Clone()->cast<AbstractRefPtr>();
|
||||
abs_ref->set_is_adapter(adapter_flag);
|
||||
return abs_ref;
|
||||
}
|
||||
if (abs_input->isa<AbstractTensor>()) {
|
||||
auto abs_tensor = abs_input->Clone()->cast<AbstractTensorPtr>();
|
||||
abs_tensor->set_is_adapter(adapter_flag);
|
||||
return abs_tensor;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << op_name << " requires a tensor as the first argument, but got " << abs_input->ToString();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplConvertToAdapterTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tensor.
|
||||
constexpr size_t args_num = 1;
|
||||
constexpr size_t input_index = 0;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, args_num);
|
||||
return SetAdapterFlag(op_name, args_spec_list[input_index], true);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplConvertToMsTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tensor.
|
||||
constexpr size_t args_num = 1;
|
||||
constexpr size_t input_index = 0;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, args_num);
|
||||
return SetAdapterFlag(op_name, args_spec_list[input_index], false);
|
||||
}
|
||||
|
||||
#ifndef _MSC_VER
|
||||
// String
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringMul, prim::kPrimStringMul, InferImplStringMul, nullptr);
|
||||
|
@ -1149,6 +1206,10 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Taylor, prim::kPrimTaylor, InferImplTaylor, n
|
|||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Vmap, prim::kPrimVmap, InferImplVmap, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Lower, prim::kPrimLower, InferImplLower, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ConvertToAdapterTensor, prim::kPrimConvertToAdapterTensor,
|
||||
InferImplConvertToAdapterTensor, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ConvertToMsTensor, prim::kPrimConvertToMsTensor, InferImplConvertToMsTensor,
|
||||
nullptr);
|
||||
#else
|
||||
void RegPrimitiveFrontEval() {
|
||||
// String
|
||||
|
@ -1213,6 +1274,11 @@ void RegPrimitiveFrontEval() {
|
|||
InferImplVmap, nullptr);
|
||||
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimLower,
|
||||
InferImplLower, nullptr);
|
||||
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(),
|
||||
prim::kPrimConvertToAdapterTensor, InferImplConvertToAdapterTensor,
|
||||
nullptr);
|
||||
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(),
|
||||
prim::kPrimConvertToMsTensor, InferImplConvertToMsTensor, nullptr);
|
||||
} // namespace abstract
|
||||
#endif
|
||||
} // namespace abstract
|
||||
|
|
|
@ -81,6 +81,10 @@ AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplConvertToAdapterTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplConvertToMsTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
const PrimitiveEvalImplMap &GetFrontendPrimitiveInferMap();
|
||||
PrimitiveEvalImplMap *GetFrontendPrimitiveInferMapPtr();
|
||||
|
|
|
@ -51,6 +51,7 @@
|
|||
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
|
||||
#include "frontend/optimizer/irpass/recompute_prepare.h"
|
||||
#include "frontend/optimizer/irpass/real_op_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/convert_tensor_eliminate.h"
|
||||
#include "mindspore/ccsrc/frontend/optimizer/irpass/bprop_mindir/get_constexpr_ops.h"
|
||||
#include "mindspore/ccsrc/frontend/optimizer/irpass/bprop_mindir/get_class_type.h"
|
||||
#include "mindspore/ccsrc/frontend/optimizer/irpass/bprop_mindir/get_meta_fg.h"
|
||||
|
@ -132,6 +133,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
all_reduce_const_elim_ =
|
||||
MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);
|
||||
real_op_eliminate_ = MakeSubstitution(std::make_shared<RealOpEliminate>(), "real_op_eliminate", prim::kPrimRealInner);
|
||||
convert_tensor_eliminate_ = MakeSubstitution(std::make_shared<ConvertTensorEliminate>(), "convert_tensor_eliminate",
|
||||
{prim::kPrimConvertToAdapterTensor, prim::kPrimConvertToMsTensor});
|
||||
convert_tensor_all_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<ConvertTensorAllEliminate>(), "convert_tensor_all_eliminate",
|
||||
{prim::kPrimConvertToAdapterTensor, prim::kPrimConvertToMsTensor});
|
||||
|
||||
// Environ Item Eliminate
|
||||
environ_get_eliminate_ =
|
||||
|
|
|
@ -64,6 +64,8 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr mini_step_allgather_replace_;
|
||||
SubstitutionPtr micro_step_allgather_replace_;
|
||||
SubstitutionPtr real_op_eliminate_;
|
||||
SubstitutionPtr convert_tensor_eliminate_;
|
||||
SubstitutionPtr convert_tensor_all_eliminate_;
|
||||
|
||||
// Env Item Eliminate
|
||||
SubstitutionPtr environ_get_eliminate_;
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_
|
||||
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
class ConvertTensorEliminate : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
auto fg = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
constexpr size_t tensor_index = 1;
|
||||
auto x = cnode->input(tensor_index);
|
||||
bool is_adapter = IsAdapterTensor(x);
|
||||
if (IsPrimitiveCNode(node, prim::kPrimConvertToAdapterTensor)) {
|
||||
// {prim::kPrimConvertToAdapterTensor, x} -> x
|
||||
if (is_adapter) {
|
||||
return x;
|
||||
}
|
||||
// {prim::kPrimConvertToAdapterTensor, {prim::kPrimConvertToMsTensor, inp}} ->
|
||||
// {prim::kPrimConvertToAdapterTensor, inp}
|
||||
if (IsPrimitiveCNode(x, prim::kPrimConvertToMsTensor)) {
|
||||
auto x_cnode = x->cast<CNodePtr>();
|
||||
auto inp = x_cnode->input(tensor_index);
|
||||
auto new_node = fg->NewCNode({NewValueNode(prim::kPrimConvertToAdapterTensor), inp});
|
||||
new_node->set_abstract(node->abstract());
|
||||
return new_node;
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(x, prim::kPrimConvertToMsTensor)) {
|
||||
// {prim::kPrimConvertToMsTensor, x} -> x
|
||||
if (!is_adapter) {
|
||||
return x;
|
||||
}
|
||||
// {prim::kPrimConvertToMsTensor, {prim::kPrimConvertToAdapterTensor, inp}} ->
|
||||
// {prim::kPrimConvertToMsTensor, inp}
|
||||
if (IsPrimitiveCNode(x, prim::kPrimConvertToAdapterTensor)) {
|
||||
auto x_cnode = x->cast<CNodePtr>();
|
||||
auto inp = x_cnode->input(tensor_index);
|
||||
auto new_node = fg->NewCNode({NewValueNode(prim::kPrimConvertToMsTensor), inp});
|
||||
new_node->set_abstract(node->abstract());
|
||||
return new_node;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsAdapterTensor(const AnfNodePtr &node) {
|
||||
auto abs = node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
auto abs_tensor = dyn_cast<abstract::AbstractTensor>(abs);
|
||||
MS_EXCEPTION_IF_NULL(abs_tensor);
|
||||
return abs_tensor->is_adapter();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertTensorAllEliminate : public AnfVisitor {
|
||||
public:
|
||||
// {prim::kPrimConvertToAdapterTensor, x} -> x
|
||||
// {prim::kPrimConvertToMsTensor, x} -> x
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimConvertToAdapterTensor) &&
|
||||
!IsPrimitiveCNode(node, prim::kPrimConvertToMsTensor)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
constexpr size_t tensor_index = 1;
|
||||
auto tensor = cnode->input(tensor_index);
|
||||
tensor->set_abstract(node->abstract());
|
||||
return tensor;
|
||||
}
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_
|
|
@ -31,9 +31,8 @@ namespace py = pybind11;
|
|||
|
||||
namespace mindspore {
|
||||
py::object AnyToPyData(const Any &value);
|
||||
COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value);
|
||||
COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output);
|
||||
COMMON_EXPORT py::object ValueToPyData(const ValuePtr &value);
|
||||
COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &abs = nullptr);
|
||||
COMMON_EXPORT py::object ValueToPyData(const ValuePtr &value, const AbstractBasePtr &abs = nullptr);
|
||||
|
||||
COMMON_EXPORT bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
|
||||
const std::shared_ptr<py::object> &ret_val);
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "frontend/operator/composite/composite.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/cell.h"
|
||||
#include "ir/adapter_tensor.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
@ -47,6 +48,8 @@ using COOTensor = mindspore::tensor::COOTensor;
|
|||
using COOTensorPtr = mindspore::tensor::COOTensorPtr;
|
||||
using MapTensor = mindspore::tensor::MapTensor;
|
||||
using MapTensorPtr = mindspore::tensor::MapTensorPtr;
|
||||
using AdapterTensor = mindspore::tensor::AdapterTensor;
|
||||
using AdapterTensorPtr = mindspore::tensor::AdapterTensorPtr;
|
||||
|
||||
using InstanceCheckFunc = std::function<bool(const py::object &)>;
|
||||
using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
|
||||
|
@ -80,7 +83,7 @@ using ArgsObjConvertFunc = std::function<ValuePtr(const py::object &)>;
|
|||
using ArgsObjSigConvertFunc = std::function<ValuePtr(const py::object &, bool)>;
|
||||
using ArgsOjbTypeConvertFunc = std::function<ValuePtr(const py::object &, const TypePtr &)>;
|
||||
|
||||
// Convert the data according instance type
|
||||
// Convert the data according to instance type
|
||||
template <typename T>
|
||||
class ByTypeDataConverter : public DataConverter {
|
||||
public:
|
||||
|
@ -117,7 +120,7 @@ class ByTypeDataConverter : public DataConverter {
|
|||
InstanceCheckFunc check_func_ = nullptr;
|
||||
};
|
||||
|
||||
// Convert the data according object attribute.
|
||||
// Convert the data according to object attribute.
|
||||
class ByAttrDataConverter : public DataConverter {
|
||||
public:
|
||||
ByAttrDataConverter(const std::string &attr_name, const ArgsObjConvertFunc &convert_func)
|
||||
|
@ -139,6 +142,28 @@ class ByAttrDataConverter : public DataConverter {
|
|||
std::string attr_name_;
|
||||
};
|
||||
|
||||
// Convert the data according to match function.
|
||||
class ByFuncDataConverter : public DataConverter {
|
||||
public:
|
||||
ByFuncDataConverter(const InstanceCheckFunc &match_func, const ArgsObjConvertFunc &convert_func)
|
||||
: DataConverter(
|
||||
[convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
|
||||
match_func_(match_func) {}
|
||||
|
||||
ByFuncDataConverter(const InstanceCheckFunc &match_func, const ArgsObjSigConvertFunc &convert_func)
|
||||
: DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
|
||||
return convert_func(obj, use_sig);
|
||||
}),
|
||||
match_func_(match_func) {}
|
||||
|
||||
~ByFuncDataConverter() override = default;
|
||||
|
||||
bool Matched(const py::object &obj) override { return match_func_ != nullptr ? match_func_(obj) : false; }
|
||||
|
||||
private:
|
||||
InstanceCheckFunc match_func_ = nullptr;
|
||||
};
|
||||
|
||||
FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
|
||||
std::vector<std::string> results = data_converter::GetObjKey(obj);
|
||||
std::string obj_key = results[0];
|
||||
|
@ -171,6 +196,25 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
|
|||
}
|
||||
|
||||
namespace {
|
||||
bool IsAdapterTensor(const py::object &obj) {
|
||||
if (!py::hasattr(obj, PYTHON_ADAPTER_TENSOR)) {
|
||||
return false;
|
||||
}
|
||||
// Only class instances are considered.
|
||||
if (data_converter::IsClassType(obj)) {
|
||||
return false;
|
||||
}
|
||||
// Check if the attribute is true.
|
||||
return py::getattr(obj, PYTHON_ADAPTER_TENSOR).cast<bool>();
|
||||
}
|
||||
|
||||
ValuePtr ConvertAdapterTensor(const py::object &obj) {
|
||||
// Use class AdapterTensor instead of Tensor to avoid circular dependencies.
|
||||
MS_LOG(DEBUG) << "Converting adapter tensor";
|
||||
auto tensor = obj.cast<TensorPtr>();
|
||||
return std::make_shared<AdapterTensor>(tensor);
|
||||
}
|
||||
|
||||
ValuePtr ConvertTuple(const py::object &obj, bool use_signature) {
|
||||
MS_LOG(DEBUG) << "Converting python tuple";
|
||||
auto tuple = obj.cast<py::tuple>();
|
||||
|
@ -512,8 +556,10 @@ ValuePtr ObjCast(const py::object &obj) {
|
|||
}
|
||||
|
||||
static const std::vector<DataConverterPtr> &GetDataConverters() {
|
||||
// Convert data by python object type.
|
||||
static const std::vector<DataConverterPtr> data_converters{
|
||||
// Convert data by python object type.
|
||||
// AdapterTensor needs to be processed before Tensor because it inherits from Tensor.
|
||||
std::make_shared<ByFuncDataConverter>(IsAdapterTensor, ConvertAdapterTensor),
|
||||
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
|
||||
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
|
||||
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
|
||||
|
|
|
@ -67,6 +67,7 @@ const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type";
|
|||
const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance";
|
||||
const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type";
|
||||
const char PYTHON_MOD_IS_CLASS_TYPE[] = "is_class_type";
|
||||
const char PYTHON_MOD_GET_ADAPTER_TENSOR_ATTR[] = "get_adapter_tensor_attr";
|
||||
const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name";
|
||||
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
|
||||
const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol";
|
||||
|
|
|
@ -307,6 +307,7 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.stopgrad_eliminater_,
|
||||
irpass.partial_eliminate_,
|
||||
irpass.replace_applicator_,
|
||||
irpass.convert_tensor_eliminate_,
|
||||
|
||||
// Miscellaneous
|
||||
irpass.tuple_list_get_item_eliminator_,
|
||||
|
@ -761,13 +762,16 @@ bool EliminateSpecialOpOptPass(const ResourcePtr &resource) {
|
|||
opt::OptPassConfig mutable_op_eliminate = opt::OptPassConfig({
|
||||
irpass.mutable_op_eliminate_,
|
||||
});
|
||||
opt::OptPassConfig convert_tensor_op_eliminate = opt::OptPassConfig({
|
||||
irpass.convert_tensor_all_eliminate_,
|
||||
});
|
||||
OptPassGroupMap map({
|
||||
{"ad_related_special_op_eliminate", ad_related_special_op_eliminate},
|
||||
{"mutable_op_eliminate", mutable_op_eliminate},
|
||||
{"convert_tensor_op_eliminate", convert_tensor_op_eliminate},
|
||||
});
|
||||
auto ad_related_special_op_eliminate_opt =
|
||||
opt::Optimizer::MakeOptimizer("ad_related_special_op_eliminate", resource, map);
|
||||
(void)ad_related_special_op_eliminate_opt->step(func_graph, false);
|
||||
auto special_op_eliminate_opt = opt::Optimizer::MakeOptimizer("special_op_eliminate", resource, map);
|
||||
(void)special_op_eliminate_opt->step(func_graph, false);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -1511,6 +1511,53 @@ EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AbstractBasePtrList &ar
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEnginePtr &engine,
|
||||
const AbstractBasePtr &data_args,
|
||||
const AbstractBasePtr &item_args,
|
||||
const ConfigPtr &data_conf,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
MS_EXCEPTION_IF_NULL(data_args);
|
||||
MS_EXCEPTION_IF_NULL(item_args);
|
||||
// Check whether it is AdapterTensor or AdapterParameter.
|
||||
auto abs = data_args->cast_ptr<abstract::AbstractTensor>();
|
||||
if (abs == nullptr || !abs->is_adapter()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Get the name of attr/method.
|
||||
ValuePtr item_value = item_args->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
if (!item_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
|
||||
}
|
||||
std::string item_name = item_value->cast_ptr<StringImm>()->value();
|
||||
|
||||
constexpr size_t attr_index = 0;
|
||||
constexpr size_t flag_index = 1;
|
||||
constexpr size_t info_required_size = 2;
|
||||
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||
py::tuple attr_info = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_ADAPTER_TENSOR_ATTR, py::str(item_name));
|
||||
if (attr_info.size() != info_required_size) {
|
||||
MS_EXCEPTION(NameError) << "attr info size should be 2, but got " << attr_info.size();
|
||||
}
|
||||
// If func is none, it means there is no such attr or method.
|
||||
py::object func = attr_info[attr_index];
|
||||
if (py::isinstance<py::none>(func)) {
|
||||
return nullptr;
|
||||
}
|
||||
ValuePtr converted_value = nullptr;
|
||||
bool success = parse::ConvertData(func, &converted_value);
|
||||
if (!success || converted_value == nullptr || !converted_value->isa<FuncGraph>()) {
|
||||
return nullptr;
|
||||
}
|
||||
AddToManager(engine, converted_value->cast<FuncGraphPtr>());
|
||||
|
||||
// Check whether it is an attribute or a method.
|
||||
bool is_attr = attr_info[flag_index].cast<bool>();
|
||||
REQUIRE_TYPE require_type = is_attr ? REQUIRE_TYPE::ATTR : REQUIRE_TYPE::METHOD;
|
||||
return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine,
|
||||
const AbstractBasePtrList &args_abs_list,
|
||||
const ConfigPtr &data_conf,
|
||||
|
@ -1676,6 +1723,11 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
return res;
|
||||
}
|
||||
}
|
||||
// Get attribute or method of AdapterTensor object.
|
||||
auto res = GetEvaluatedValueForAdapterTensorAttrOrMethod(engine, data_args, item_args, data_conf, out_conf);
|
||||
if (res != nullptr) {
|
||||
return res;
|
||||
}
|
||||
// Try to search method map, if not found, the data_type should be External type.
|
||||
TypePtr data_type = data_args->BuildType();
|
||||
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
|
||||
|
|
|
@ -254,6 +254,10 @@ void PyParser::ParseOpInputByPythonObj(const FrontendOpRunInfoPtr &op_run_info,
|
|||
py::object DataConvert::ValueToPyObj(const ValuePtr &v) { return ValueToPyData(v); }
|
||||
|
||||
ValuePtr DataConvert::PyObjToValue(const py::object &obj) {
|
||||
// In PyNative mode, AdapterTensor is treated as ms.Tensor.
|
||||
if (py::hasattr(obj, PYTHON_ADAPTER_TENSOR) && py::getattr(obj, PYTHON_ADAPTER_TENSOR).cast<bool>()) {
|
||||
py::setattr(obj, PYTHON_ADAPTER_TENSOR, py::bool_(false));
|
||||
}
|
||||
ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
|
||||
if (converted_ret == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
|
||||
|
|
|
@ -19,6 +19,7 @@ namespace mindspore {
|
|||
const char PYTHON_PRIMITIVE_FLAG[] = "__primitive_flag__";
|
||||
const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
|
||||
const char PYTHON_MS_CLASS[] = "__ms_class__";
|
||||
const char PYTHON_ADAPTER_TENSOR[] = "__adapter_tensor__";
|
||||
const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__";
|
||||
const char PYTHON_FUNCTION_FORBID_REUSE[] = "__function_forbid_reuse__";
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,6 +21,7 @@ namespace mindspore {
|
|||
extern const char PYTHON_PRIMITIVE_FLAG[];
|
||||
extern const char PYTHON_CELL_AS_LIST[];
|
||||
extern const char PYTHON_MS_CLASS[];
|
||||
extern const char PYTHON_ADAPTER_TENSOR[];
|
||||
extern const char PYTHON_CLASS_MEMBER_NAMESPACE[];
|
||||
extern const char PYTHON_FUNCTION_FORBID_REUSE[];
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -562,6 +562,7 @@ void RegMetaTensor(py::module *m) {
|
|||
}),
|
||||
py::arg("input"), py::arg("dtype") = nullptr)
|
||||
.def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
|
||||
.def_property("adapter_flag", &Tensor::is_adapter, &Tensor::set_adapter_flag)
|
||||
.def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter(
|
||||
Get the tensor's data type.
|
||||
|
||||
|
|
|
@ -41,20 +41,45 @@ namespace mindspore {
|
|||
py::object BuiltinsToPyData(const Any &value);
|
||||
py::object BuiltinsToPyData(const BaseRef &value);
|
||||
py::object VectorToPyData(const Any &value);
|
||||
py::object VectorRefToPyData(const VectorRef &value_list);
|
||||
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output);
|
||||
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &abs = nullptr);
|
||||
py::object MakeCSRTensor(const VectorRef &value_list);
|
||||
py::object MakeCSRTensor(const ValuePtr &value);
|
||||
py::object MakeCOOTensor(const VectorRef &value_list);
|
||||
py::object MakeCOOTensor(const ValuePtr &value);
|
||||
ShapeVector ConvertShapeTupleToShapeVector(const ValueTuplePtr &shape_tuple);
|
||||
ShapeVector ConvertToShapeVector(const VectorRef &value_list, size_t shape_idx);
|
||||
|
||||
// For AbstractSequence and AbstractDictionary.
|
||||
template <typename T>
|
||||
T CheckAbstractElementsSize(const AbstractBasePtr &abs_value, size_t value_size) {
|
||||
if (abs_value == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto abs = abs_value->cast<T>();
|
||||
if (abs != nullptr && value_size != abs->size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of elements should be equal to " << value_size << ", but got " << abs->size();
|
||||
}
|
||||
return abs;
|
||||
}
|
||||
|
||||
py::object SetAdaptedAttrToTensor(const py::object &tensor, const AbstractBasePtr &abs) {
|
||||
if (abs == nullptr || !abs->isa<abstract::AbstractTensor>()) {
|
||||
return tensor;
|
||||
}
|
||||
auto tensor_abs = abs->cast<abstract::AbstractTensorPtr>();
|
||||
if (tensor_abs->is_adapter()) {
|
||||
py::setattr(tensor, "adapter_flag", py::bool_(true));
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
py::object CSRTensorToPyData(const tensor::CSRTensorPtr &csr_tensor) {
|
||||
auto ref = py::tuple(1);
|
||||
ref[0] = csr_tensor;
|
||||
return ref[0];
|
||||
}
|
||||
py::object TensorToPyData(const tensor::TensorPtr &tensor) {
|
||||
|
||||
py::object TensorToPyData(const tensor::TensorPtr &tensor, const AbstractBasePtr &abs) {
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
if (tensor->NeedWait()) {
|
||||
py::gil_scoped_release release;
|
||||
|
@ -62,6 +87,7 @@ py::object TensorToPyData(const tensor::TensorPtr &tensor) {
|
|||
}
|
||||
py::tuple v(1);
|
||||
v[0] = tensor;
|
||||
v[0] = SetAdaptedAttrToTensor(v[0], abs);
|
||||
return v[0];
|
||||
}
|
||||
|
||||
|
@ -120,79 +146,110 @@ py::object ScalarPtrToPyData(const ScalarPtr &value) {
|
|||
}
|
||||
}
|
||||
|
||||
using ConverterFunction = std::function<py::object(const ValuePtr &value)>;
|
||||
py::object ValueSequenceToPyData(const ValueSequencePtr &value, const AbstractBasePtr &abs) {
|
||||
auto value_sequeue = value->value();
|
||||
auto value_size = value_sequeue.size();
|
||||
py::tuple res_sequeue(value_size);
|
||||
auto abs_sequeue = CheckAbstractElementsSize<abstract::AbstractSequencePtr>(abs, value_size);
|
||||
if (abs_sequeue == nullptr) {
|
||||
for (size_t i = 0; i < value_size; i++) {
|
||||
res_sequeue[i] = ValueToPyData(value_sequeue[i]);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < value_size; i++) {
|
||||
res_sequeue[i] = ValueToPyData(value_sequeue[i], abs_sequeue->elements()[i]);
|
||||
}
|
||||
}
|
||||
if (value->isa<ValueTuple>()) {
|
||||
return res_sequeue;
|
||||
}
|
||||
return res_sequeue.cast<py::list>();
|
||||
}
|
||||
|
||||
py::object ValueDictionaryToPyData(const ValueDictionaryPtr &value, const AbstractBasePtr &abs) {
|
||||
auto value_dict = value->value();
|
||||
auto value_size = value_dict.size();
|
||||
py::dict res_dict;
|
||||
auto abs_dict = CheckAbstractElementsSize<abstract::AbstractDictionaryPtr>(abs, value_size);
|
||||
if (abs_dict == nullptr) {
|
||||
for (const auto &v : value_dict) {
|
||||
res_dict[ValueToPyData(v.first)] = ValueToPyData(v.second);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < value_size; i++) {
|
||||
auto v = value_dict[i];
|
||||
auto abs_elem = abs_dict->elements()[i];
|
||||
res_dict[ValueToPyData(v.first, abs_elem.first)] = ValueToPyData(v.second, abs_elem.second);
|
||||
}
|
||||
}
|
||||
return res_dict;
|
||||
}
|
||||
|
||||
using ConverterFunction = std::function<py::object(const ValuePtr &value, const AbstractBasePtr &abs)>;
|
||||
using ValueNameToConverterVector = std::vector<std::pair<uint32_t, ConverterFunction>>;
|
||||
|
||||
// (Value Type Name) -> (Converter Function)
|
||||
// The converter function is used to convert Value object to Python data object.
|
||||
static ValueNameToConverterVector value_name_to_converter = {
|
||||
// Scalar
|
||||
{Scalar::kTypeId, [](const ValuePtr &value) -> py::object { return ScalarPtrToPyData(value->cast<ScalarPtr>()); }},
|
||||
{Scalar::kTypeId,
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
return ScalarPtrToPyData(value->cast<ScalarPtr>());
|
||||
}},
|
||||
// Tensor
|
||||
{tensor::Tensor::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &abs) -> py::object {
|
||||
auto tensor_ptr = value->cast<tensor::TensorPtr>();
|
||||
return TensorToPyData(tensor_ptr);
|
||||
return TensorToPyData(tensor_ptr, abs);
|
||||
}},
|
||||
// MetaTenser
|
||||
{tensor::MetaTensor::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
py::tuple tuple_container(1);
|
||||
tuple_container[0] = value->cast<tensor::MetaTensorPtr>();
|
||||
return tuple_container[0];
|
||||
}},
|
||||
// CSRTensor
|
||||
{tensor::CSRTensor::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
auto csr_tensor_ptr = value->cast<tensor::CSRTensorPtr>();
|
||||
return CSRTensorToPyData(csr_tensor_ptr);
|
||||
}},
|
||||
// RefKey
|
||||
{RefKey::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
py::tuple tuple_container(1);
|
||||
tuple_container[0] = value->cast<RefKeyPtr>();
|
||||
return tuple_container[0];
|
||||
}},
|
||||
// Type
|
||||
{Type::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
py::tuple tuple_container(1);
|
||||
tuple_container[0] = value->cast<TypePtr>();
|
||||
return tuple_container[0];
|
||||
}},
|
||||
// StringImm
|
||||
{StringImm::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
py::str res = value->cast<StringImmPtr>()->value();
|
||||
return res;
|
||||
}},
|
||||
// ValueSequence
|
||||
{ValueSequence::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto value_sequeue = value->cast<ValueSequencePtr>()->value();
|
||||
py::tuple res_sequeue(value_sequeue.size());
|
||||
for (size_t i = 0; i < value_sequeue.size(); i++) {
|
||||
res_sequeue[i] = ValueToPyData(value_sequeue[i]);
|
||||
}
|
||||
if (value->isa<ValueTuple>()) {
|
||||
return res_sequeue;
|
||||
}
|
||||
return res_sequeue.cast<py::list>();
|
||||
[](const ValuePtr &value, const AbstractBasePtr &abs) -> py::object {
|
||||
auto value_sequeue = value->cast<ValueSequencePtr>();
|
||||
return ValueSequenceToPyData(value_sequeue, abs);
|
||||
}},
|
||||
// ValueDictionary
|
||||
{ValueDictionary::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto value_list = value->cast<ValueDictionaryPtr>()->value();
|
||||
py::dict res_dict;
|
||||
for (const auto &v : value_list) {
|
||||
res_dict[ValueToPyData(v.first)] = ValueToPyData(v.second);
|
||||
}
|
||||
return res_dict;
|
||||
[](const ValuePtr &value, const AbstractBasePtr &abs) -> py::object {
|
||||
auto value_dict = value->cast<ValueDictionaryPtr>();
|
||||
return ValueDictionaryToPyData(value_dict, abs);
|
||||
}},
|
||||
// ValueSlice
|
||||
{ValueSlice::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
auto slice = value->cast<ValueSlicePtr>();
|
||||
auto start = ValueToPyData(slice->start());
|
||||
auto end = ValueToPyData(slice->stop());
|
||||
|
@ -201,7 +258,7 @@ static ValueNameToConverterVector value_name_to_converter = {
|
|||
}},
|
||||
// KeywordArg
|
||||
{KeywordArg::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
|
||||
auto key = abs_keyword_arg->get_key();
|
||||
auto val = abs_keyword_arg->get_arg()->BuildValue();
|
||||
|
@ -212,53 +269,53 @@ static ValueNameToConverterVector value_name_to_converter = {
|
|||
}},
|
||||
// parse::NameSpace
|
||||
{parse::NameSpace::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
auto ns = value->cast<parse::NameSpacePtr>();
|
||||
return ns->module_obj();
|
||||
}},
|
||||
// parse::ClassType
|
||||
{parse::ClassType::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
auto class_type = value->cast<parse::ClassTypePtr>();
|
||||
return class_type->obj();
|
||||
}},
|
||||
// parse::MsClassObject
|
||||
{parse::MsClassObject::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
auto ms_class_object = value->cast<parse::MsClassObjectPtr>();
|
||||
return ms_class_object->obj();
|
||||
}},
|
||||
// parse::InterpretedObject
|
||||
{parse::InterpretedObject::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
|
||||
auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
|
||||
return interpreted_object->obj();
|
||||
}},
|
||||
// None
|
||||
{None::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
|
||||
{None::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
|
||||
// AnyValue
|
||||
{AnyValue::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
|
||||
{AnyValue::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
|
||||
// ErrorValue
|
||||
{ErrorValue::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
|
||||
{ErrorValue::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
|
||||
// FuncGraph
|
||||
{FuncGraph::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
|
||||
{FuncGraph::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
|
||||
// Primitive
|
||||
{Primitive::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
|
||||
{Primitive::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
|
||||
// Monad
|
||||
{Monad::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
|
||||
{Monad::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
|
||||
// Ellipsis
|
||||
{Ellipsis::kTypeId, [](const ValuePtr &) -> py::object { return py::ellipsis(); }}};
|
||||
{Ellipsis::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::ellipsis(); }}};
|
||||
|
||||
// When converting data to tensor, ValueToPyData will only return _c_expression Tensor,
|
||||
// but not python tensor. If python tensor is needed, call _convert_python_data to the output.
|
||||
py::object ValueToPyData(const ValuePtr &value) {
|
||||
py::object ValueToPyData(const ValuePtr &value, const AbstractBasePtr &abs) {
|
||||
if (value == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The `value` should not be null";
|
||||
}
|
||||
py::gil_scoped_acquire gil;
|
||||
for (auto &iter : value_name_to_converter) {
|
||||
if (value->IsFromTypeId(iter.first)) {
|
||||
return iter.second(value);
|
||||
return iter.second(value, abs);
|
||||
}
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unsupported to convert " << value->ToString() << "[" << value->type_name() << "] to a PyData";
|
||||
|
@ -273,10 +330,6 @@ py::object AnyToPyData(const Any &value) {
|
|||
MS_LOG(DEBUG) << "ValuePtr";
|
||||
ValuePtr v = value.cast<ValuePtr>();
|
||||
ret = ValueToPyData(v);
|
||||
} else if (value.is<tensor::TensorPtr>()) {
|
||||
MS_LOG(DEBUG) << "tensor";
|
||||
auto tensor_ptr = value.cast<tensor::TensorPtr>();
|
||||
ret = TensorToPyData(tensor_ptr);
|
||||
} else if (value.is<py::object>()) {
|
||||
MS_LOG(DEBUG) << "py obj";
|
||||
ret = value.cast<py::object>();
|
||||
|
@ -307,24 +360,7 @@ py::object AnyToPyData(const Any &value) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output) {
|
||||
py::object ret;
|
||||
// If output value is a tuple, check if abstract is a COOTensor in funcgraph output
|
||||
if (utils::isa<VectorRef>(value)) {
|
||||
MS_LOG(DEBUG) << "BaseRefToPyData, value is tuple: " << value.ToString();
|
||||
auto vec_ref = utils::cast<VectorRef>(value);
|
||||
if (output != nullptr) {
|
||||
ret = VectorRefToPyData(vec_ref, output);
|
||||
} else {
|
||||
ret = VectorRefToPyData(vec_ref);
|
||||
}
|
||||
} else {
|
||||
ret = BaseRefToPyData(value);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
py::object BaseRefToPyData(const BaseRef &value) {
|
||||
py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &abs) {
|
||||
py::object ret;
|
||||
MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
|
||||
if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
|
||||
|
@ -332,18 +368,14 @@ py::object BaseRefToPyData(const BaseRef &value) {
|
|||
} else if (utils::isa<ValuePtr>(value)) {
|
||||
MS_LOG(DEBUG) << "ValuePtr";
|
||||
ValuePtr v = utils::cast<ValuePtr>(value);
|
||||
ret = ValueToPyData(v);
|
||||
} else if (utils::isa<tensor::TensorPtr>(value)) {
|
||||
MS_LOG(DEBUG) << "tensor";
|
||||
auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
|
||||
ret = TensorToPyData(tensor_ptr);
|
||||
ret = ValueToPyData(v, abs);
|
||||
} else if (utils::isa<PyObjectRef>(value)) {
|
||||
MS_LOG(DEBUG) << "py obj";
|
||||
PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
|
||||
ret = py_ref.object_;
|
||||
} else if (utils::isa<VectorRef>(value)) {
|
||||
auto vec_ref = utils::cast<VectorRef>(value);
|
||||
ret = VectorRefToPyData(vec_ref);
|
||||
ret = VectorRefToPyData(vec_ref, abs);
|
||||
} else if (utils::isa<TypePtr>(value)) {
|
||||
py::tuple v(1);
|
||||
v[0] = utils::cast<TypePtr>(value);
|
||||
|
@ -419,42 +451,29 @@ py::object VectorToPyData(const Any &value) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
py::object VectorRefToPyData(const VectorRef &value_list) {
|
||||
py::object ret;
|
||||
MS_LOG(DEBUG) << "vector_ref";
|
||||
size_t value_size = value_list.size();
|
||||
auto ref_tuple = py::tuple(value_size);
|
||||
for (size_t i = 0; i < value_size; i++) {
|
||||
ref_tuple[i] = BaseRefToPyData(value_list[i]);
|
||||
}
|
||||
ret = ref_tuple;
|
||||
return ret;
|
||||
}
|
||||
|
||||
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output) {
|
||||
MS_LOG(DEBUG) << "vector_ref";
|
||||
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &abs) {
|
||||
// Current VectorRef reflects a COOTensor type
|
||||
if (output->isa<abstract::AbstractCSRTensor>()) {
|
||||
if (abs != nullptr && abs->isa<abstract::AbstractCSRTensor>()) {
|
||||
return MakeCSRTensor(value_list);
|
||||
}
|
||||
if (output->isa<abstract::AbstractCOOTensor>()) {
|
||||
if (abs != nullptr && abs->isa<abstract::AbstractCOOTensor>()) {
|
||||
return MakeCOOTensor(value_list);
|
||||
}
|
||||
|
||||
py::object ret;
|
||||
size_t value_size = value_list.size();
|
||||
auto ref_tuple = py::tuple(value_size);
|
||||
abstract::AbstractTuplePtr tuple_output = output->cast<abstract::AbstractTuplePtr>();
|
||||
bool is_abstract_tuple = tuple_output != nullptr;
|
||||
for (size_t i = 0; i < value_size; i++) {
|
||||
if (!is_abstract_tuple || i >= tuple_output->size()) {
|
||||
// Fall back to original process
|
||||
auto seq_abs = CheckAbstractElementsSize<abstract::AbstractSequencePtr>(abs, value_size);
|
||||
if (seq_abs == nullptr) {
|
||||
for (size_t i = 0; i < value_size; i++) {
|
||||
ref_tuple[i] = BaseRefToPyData(value_list[i]);
|
||||
} else {
|
||||
ref_tuple[i] = BaseRefToPyData(value_list[i], (*tuple_output)[i]);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < value_size; i++) {
|
||||
ref_tuple[i] = BaseRefToPyData(value_list[i], seq_abs->elements()[i]);
|
||||
}
|
||||
}
|
||||
ret = ref_tuple;
|
||||
return ret;
|
||||
return ref_tuple;
|
||||
}
|
||||
|
||||
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
|
||||
|
@ -462,12 +481,14 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
|
|||
if (output->isa<ValueNode>()) {
|
||||
MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
|
||||
ValuePtr value = GetValueNode(output);
|
||||
if (output->abstract()->isa<abstract::AbstractCSRTensor>()) {
|
||||
auto abs = output->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
if (abs->isa<abstract::AbstractCSRTensor>()) {
|
||||
*ret_val = MakeCSRTensor(value);
|
||||
} else if (output->abstract()->isa<abstract::AbstractCOOTensor>()) {
|
||||
} else if (abs->isa<abstract::AbstractCOOTensor>()) {
|
||||
*ret_val = MakeCOOTensor(value);
|
||||
} else {
|
||||
*ret_val = ValueToPyData(value);
|
||||
*ret_val = ValueToPyData(value, abs);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -505,6 +526,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
|
|||
auto tensor = param->default_param();
|
||||
*ret_val = py::cast(tensor);
|
||||
}
|
||||
*ret_val = SetAdaptedAttrToTensor(*ret_val, output->abstract());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
|
@ -1213,13 +1213,19 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
|
|||
// Check element
|
||||
auto element = element_->Join(other_tensor->element_);
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
return std::make_shared<AbstractTensor>(element, res_shape);
|
||||
auto ret = std::make_shared<AbstractTensor>(element, res_shape);
|
||||
ret->set_is_adapter(is_adapter_);
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool AbstractTensor::equal_to(const AbstractTensor &other) const {
|
||||
if (this == &other) {
|
||||
return true;
|
||||
}
|
||||
// Check if both Tensor or both AdapterTensor.
|
||||
if (is_adapter() != other.is_adapter()) {
|
||||
return false;
|
||||
}
|
||||
const auto &v1 = GetValueTrack();
|
||||
const auto &v2 = other.GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(v1);
|
||||
|
@ -1268,6 +1274,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
|
|||
clone->set_value(GetValueTrack());
|
||||
clone->set_value_range(get_min_value(), get_max_value());
|
||||
clone->set_shape_value(get_shape_value());
|
||||
clone->set_is_adapter(is_adapter());
|
||||
return clone;
|
||||
}
|
||||
|
||||
|
@ -1278,6 +1285,7 @@ AbstractBasePtr AbstractTensor::Broaden() const {
|
|||
MS_EXCEPTION_IF_NULL(shp);
|
||||
broaden->set_shape(shp->Clone());
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_is_adapter(is_adapter());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
|
@ -1289,6 +1297,7 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
|
|||
shp->Broaden();
|
||||
broaden->set_shape(shp);
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_is_adapter(is_adapter());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
|
@ -1441,6 +1450,7 @@ std::string AbstractJTagged::ToString() const {
|
|||
AbstractRefTensor::AbstractRefTensor(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value)
|
||||
: AbstractTensor(*ref_value), ref_key_value_(ref_key_value) {
|
||||
set_type(std::make_shared<RefType>());
|
||||
set_is_adapter(ref_value->is_adapter());
|
||||
MS_EXCEPTION_IF_NULL(ref_key_value);
|
||||
if (ref_key_value != kAnyValue && !ref_key_value->isa<RefKey>()) {
|
||||
MS_LOG(EXCEPTION) << "ref_key_value must be kAnyValue or RefKey, but got:" << ref_key_value->ToString();
|
||||
|
|
|
@ -730,11 +730,15 @@ class MS_CORE_API AbstractTensor : public AbstractUndetermined {
|
|||
|
||||
AbstractBasePtr PartialBroaden() const override;
|
||||
|
||||
bool is_adapter() const { return is_adapter_; }
|
||||
void set_is_adapter(bool is_adapter) { is_adapter_ = is_adapter; }
|
||||
|
||||
protected:
|
||||
bool equal_to(const AbstractTensor &other) const;
|
||||
ValuePtr min_value_ = nullptr;
|
||||
ValuePtr max_value_ = nullptr;
|
||||
ValuePtr shape_value_ = nullptr;
|
||||
bool is_adapter_ = false;
|
||||
};
|
||||
using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
|
||||
using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/adapter_tensor.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include "abstract/utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace tensor {
|
||||
bool AdapterTensor::operator==(const AdapterTensor &other) const { return this == &other; }
|
||||
|
||||
abstract::AbstractBasePtr AdapterTensor::ToAbstract() {
|
||||
auto abs = origin_tensor_->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
auto tensor_abs = abs->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_abs);
|
||||
tensor_abs->set_is_adapter(true);
|
||||
return tensor_abs;
|
||||
}
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_IR_ADAPTER_TENSOR_H_
|
||||
#define MINDSPORE_CORE_IR_ADAPTER_TENSOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace tensor {
|
||||
class AdapterTensor;
|
||||
// Smart pointer for AdapterTensor.
|
||||
using AdapterTensorPtr = std::shared_ptr<AdapterTensor>;
|
||||
///
|
||||
/// \brief AdapterTensor is used to map the Tensor of other frameworks.
|
||||
///
|
||||
class MS_CORE_API AdapterTensor final : public Tensor {
|
||||
public:
|
||||
/// \brief Create AdapterTensor from tensor.
|
||||
///
|
||||
/// \param[in] tensor The input tensor.
|
||||
explicit AdapterTensor(const TensorPtr &tensor) : Tensor(*tensor), origin_tensor_(tensor) {}
|
||||
|
||||
AdapterTensor() = default;
|
||||
|
||||
/// Destructor of AdapterTensor.
|
||||
~AdapterTensor() override = default;
|
||||
|
||||
MS_DECLARE_PARENT(AdapterTensor, Tensor);
|
||||
|
||||
bool operator==(const AdapterTensor &other) const;
|
||||
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
|
||||
private:
|
||||
TensorPtr origin_tensor_{nullptr};
|
||||
};
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_IR_ADAPTER_TENSOR_H_
|
|
@ -466,6 +466,16 @@ class MS_CORE_API Tensor : public MetaTensor {
|
|||
/// \param[in] flag Whether this Tensor is initialized.
|
||||
void set_init_flag(bool flag) { init_flag_ = flag; }
|
||||
|
||||
/// \brief Check whether this Tensor needs to be converted.
|
||||
///
|
||||
/// \return Whether this Tensor needs to be converted.
|
||||
bool is_adapter() const { return adapter_flag_; }
|
||||
|
||||
/// \brief Set the adapter flag of this Tensor.
|
||||
///
|
||||
/// \param[in] flag Whether this Tensor needs to be converted.
|
||||
void set_adapter_flag(bool flag) { adapter_flag_ = flag; }
|
||||
|
||||
/// \brief Check if this Tensor is forward output.
|
||||
///
|
||||
/// \return Whether this Tensor is forward output.
|
||||
|
@ -772,6 +782,7 @@ class MS_CORE_API Tensor : public MetaTensor {
|
|||
void ExecuteLazyTask() const;
|
||||
|
||||
bool init_flag_{false};
|
||||
bool adapter_flag_{false};
|
||||
bool is_forward_output_{false};
|
||||
TensorDataPtr data_{nullptr};
|
||||
std::string id_{""};
|
||||
|
|
|
@ -1693,6 +1693,8 @@ GVAR_DEF(PrimitivePtr, kPrimDictLen, std::make_shared<Primitive>("dict_len"));
|
|||
GVAR_DEF(PrimitivePtr, kPrimFakeBprop, std::make_shared<Primitive>("fake_bprop"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimBroadcastGradientArgs, std::make_shared<Primitive>("BroadcastGradientArgs"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimDynamicBroadcastGradientArgs, std::make_shared<Primitive>(kDynamicBroadcastGradientArgs));
|
||||
GVAR_DEF(PrimitivePtr, kPrimConvertToAdapterTensor, std::make_shared<Primitive>("ConvertToAdapterTensor"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimConvertToMsTensor, std::make_shared<Primitive>("ConvertToMsTensor"));
|
||||
|
||||
// Random
|
||||
GVAR_DEF(PrimitivePtr, kPrimStandardLaplace, std::make_shared<Primitive>("StandardLaplace"));
|
||||
|
|
|
@ -26,7 +26,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
|||
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
|
||||
convert_to_ms_cootensor, convert_class_to_function, convert_cell_list_to_sequence, is_cell_list,
|
||||
get_obj_from_sequence, get_type, is_class_member_recursive, merge_global_params, get_global_params)
|
||||
get_obj_from_sequence, get_type, is_class_member_recursive, merge_global_params, get_global_params,
|
||||
get_adapter_tensor_attr)
|
||||
|
||||
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
||||
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
||||
|
@ -37,4 +38,4 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
|
|||
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor',
|
||||
'convert_to_ms_cootensor', 'convert_class_to_function', 'convert_cell_list_to_sequence', 'is_cell_list',
|
||||
'get_obj_from_sequence', 'get_type', 'is_class_member_recursive']
|
||||
'get_obj_from_sequence', 'get_type', 'is_class_member_recursive', 'get_adapter_tensor_attr']
|
||||
|
|
|
@ -40,6 +40,7 @@ from mindspore.common.api import _MindsporeFunctionExecutor
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import mutable
|
||||
from mindspore.common._register_for_adapter import ms_adapter_registry
|
||||
from .namespace import Namespace, CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
|
||||
from .resources import parse_object_map, ops_symbol_map, convert_object_map, convert_class_to_function_map, trope_ns
|
||||
from .resources import SYMBOL_UNDEFINE, NO_IMPLEMENT
|
||||
|
@ -127,6 +128,12 @@ _unsupported_convert_data_type = (
|
|||
_global_params = {}
|
||||
|
||||
|
||||
def _convert_map():
|
||||
"""Get convert object map"""
|
||||
adapter_convert_map = ms_adapter_registry.convert_map
|
||||
return adapter_convert_map if adapter_convert_map else convert_object_map
|
||||
|
||||
|
||||
def create_slice_obj(start, end, step):
|
||||
"""Create slice object"""
|
||||
return slice(start, end, step)
|
||||
|
@ -245,8 +252,9 @@ def resolve_symbol(namespace, symbol):
|
|||
"within the construct() or @jit decorated function in graph mode.")
|
||||
|
||||
# If need trope the obj
|
||||
if resolve_ in convert_object_map:
|
||||
resolve_ = convert_object_map.get(resolve_)
|
||||
convert_map = _convert_map()
|
||||
if resolve_ in convert_map:
|
||||
resolve_ = convert_map.get(resolve_)
|
||||
logger.debug("Convert resolve: %r", resolve_)
|
||||
if resolve_ == NO_IMPLEMENT:
|
||||
raise NotImplementedError(f"Not support for '{symbol}'.")
|
||||
|
@ -549,6 +557,18 @@ def is_class_type(cls):
|
|||
return isinstance(cls, type)
|
||||
|
||||
|
||||
def get_adapter_tensor_attr(name):
|
||||
"""Get the method or @property modified function of the class, excluding those inherited from parent class."""
|
||||
cls = ms_adapter_registry.tensor
|
||||
properties = [key for key, value in vars(cls).items() if isinstance(value, property)]
|
||||
if name in properties:
|
||||
return getattr(cls, name).fget, True
|
||||
methods = [key for key, value in vars(cls).items() if inspect.isfunction(value)]
|
||||
if name in methods:
|
||||
return getattr(cls, name), False
|
||||
return None, False
|
||||
|
||||
|
||||
def get_ms_class_name(cls):
|
||||
"""Get the name of the class instance decorated with jit_class."""
|
||||
if isinstance(cls, type):
|
||||
|
@ -827,7 +847,7 @@ class Parser:
|
|||
@staticmethod
|
||||
def is_unsupported_namespace(value):
|
||||
"""To check if not supported for namespace"""
|
||||
unsupported = isinstance(value, _builtin_function_or_method_type) and value not in convert_object_map
|
||||
unsupported = isinstance(value, _builtin_function_or_method_type) and value not in _convert_map()
|
||||
logger.debug(f"'{value}' unsupported: {unsupported}.")
|
||||
if unsupported and value in _fallback_unsupported_python_builtin_type:
|
||||
raise TypeError(f"'{value}' is not supported both in JIT Fallback and graph mode.")
|
||||
|
@ -868,7 +888,7 @@ class Parser:
|
|||
"""Get the convert object for value which don't support to be converted in C++."""
|
||||
if not self.is_unsupported_convert_data_type(value):
|
||||
return value
|
||||
return convert_object_map.get(value)
|
||||
return _convert_map().get(value)
|
||||
|
||||
def parse(self):
|
||||
"""Parse the function or method."""
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Registry MSAdapter config."""
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
class Registry:
|
||||
"""Registry class for ms adapter."""
|
||||
|
||||
def __init__(self):
|
||||
self._tensor = None
|
||||
self._convert_map = {}
|
||||
|
||||
@property
|
||||
def tensor(self):
|
||||
if self._tensor is None:
|
||||
raise ValueError("Before using Tensor in MSAdapter, please call 'set_adapter_config'.")
|
||||
return self._tensor
|
||||
|
||||
@property
|
||||
def convert_map(self):
|
||||
return self._convert_map
|
||||
|
||||
def register_tensor(self, value):
|
||||
if self._tensor is not None:
|
||||
raise ValueError("Repeated registration of tensor in ms adapter config.")
|
||||
if not issubclass(value, Tensor):
|
||||
raise ValueError(f"The tensor definition here should be a subclass of ms.Tensor, but got {value}.")
|
||||
self._tensor = value
|
||||
|
||||
def register_convert_map(self, value):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"Expect a dict type, but got {type(value)}.")
|
||||
self._convert_map = value
|
||||
|
||||
|
||||
ms_adapter_registry = Registry()
|
|
@ -45,6 +45,7 @@ from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcas
|
|||
from mindspore._checkparam import Validator
|
||||
from mindspore.common._utils import is_shape_unknown
|
||||
from mindspore.common.mutable import mutable
|
||||
from mindspore.common._register_for_adapter import ms_adapter_registry
|
||||
|
||||
# store ms_function class compiled pipeline cache
|
||||
ms_compile_cache = set()
|
||||
|
@ -65,6 +66,8 @@ def _convert_python_data(data):
|
|||
Returns:
|
||||
data, a data convert C++ to python
|
||||
"""
|
||||
if isinstance(data, Tensor) and data.adapter_flag:
|
||||
return ms_adapter_registry.tensor(data)
|
||||
if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
|
||||
return PythonTensor(data, internal=True)
|
||||
if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
|
||||
|
@ -878,6 +881,25 @@ def jit_class(cls):
|
|||
return cls
|
||||
|
||||
|
||||
def set_adapter_config(config):
|
||||
"""
|
||||
Register configuration information for MSAdapter.
|
||||
|
||||
Args:
|
||||
config (dict): Configuration information.
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
|
||||
for key, value in config.items():
|
||||
if key == "Tensor":
|
||||
setattr(value, "__adapter_tensor__", True)
|
||||
ms_adapter_registry.register_tensor(value)
|
||||
elif key == "convert_object_map":
|
||||
ms_adapter_registry.register_convert_map(value)
|
||||
else:
|
||||
raise ValueError(f"Unsupported key in adapter config: {key}")
|
||||
|
||||
|
||||
def _function_forbid_reuse(func):
|
||||
if not inspect.isfunction(func):
|
||||
raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
|
||||
|
|
|
@ -197,3 +197,23 @@ def get_bprop_resize_bilinear(self):
|
|||
return dx, zeros_like(size)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(inner.ConvertToAdapterTensor)
|
||||
def get_bprop_convert_to_adapter_tensor(self):
|
||||
"""Generate bprop for ConvertToAdapterTensor"""
|
||||
|
||||
def bprop(x, out, dout):
|
||||
return (dout,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(inner.ConvertToMsTensor)
|
||||
def get_bprop_convert_to_ms_tensor(self):
|
||||
"""Generate bprop for ConvertToMsTensor"""
|
||||
|
||||
def bprop(x, out, dout):
|
||||
return (dout,)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -31,6 +31,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.communication.management import GlobalComm
|
||||
from mindspore.common.api import _pynative_executor
|
||||
from mindspore.common._register_for_adapter import ms_adapter_registry
|
||||
|
||||
|
||||
# Bit operation
|
||||
|
@ -2321,3 +2322,65 @@ class IsInstance(PrimitiveWithInfer):
|
|||
'dtype': mstype.type_type,
|
||||
'value': value}
|
||||
return out
|
||||
|
||||
|
||||
class ConvertToAdapterTensor(Primitive):
|
||||
"""
|
||||
Convert a tensor from MindSpore's Tensor type to MSAdapter's Tensor type,
|
||||
where MSAdapter's Tensor is a subclass of MindSpore's Tensor.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input tensor.
|
||||
|
||||
Outputs:
|
||||
A tensor, whose type is MSAdapter's Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor([1, 2 ,3])
|
||||
>>> x = ops.ConvertToAdapterTensor()(x)
|
||||
>>> print(x)
|
||||
[1 2 3]
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize"""
|
||||
|
||||
def __call__(self, x):
|
||||
"""run in PyNative mode"""
|
||||
return ms_adapter_registry.tensor(x, inner=True)
|
||||
|
||||
convert_to_adapter_tensor = ConvertToAdapterTensor()
|
||||
|
||||
|
||||
class ConvertToMsTensor(Primitive):
|
||||
"""
|
||||
Convert a tensor from MSAdapter's Tensor type to MindSpore's Tensor type,
|
||||
where MSAdapter's Tensor is a subclass of MindSpore's Tensor.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input tensor.
|
||||
|
||||
Outputs:
|
||||
A tensor, whose type is MindSpore's Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor([1, 2 ,3])
|
||||
>>> x = ops.ConvertToMsTensor()(x)
|
||||
>>> print(x)
|
||||
[1 2 3]
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize"""
|
||||
|
||||
def __call__(self, x):
|
||||
"""run in PyNative mode"""
|
||||
return Tensor(x)
|
||||
|
||||
convert_to_ms_tensor = ConvertToMsTensor()
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
from mindspore.common.api import set_adapter_config
|
||||
from mindspore._extends.parse import trope as T
|
||||
from mindspore._extends.parse.resources import convert_object_map
|
||||
from ._register.ms_adapter_api import Tensor, Parameter
|
||||
from ._register import multitype_ops
|
||||
from ._register import standard_method as S
|
||||
|
||||
# Update convert_object_map
|
||||
convert_object_map[T.add] = multitype_ops.add
|
||||
convert_object_map[T.sub] = multitype_ops.sub
|
||||
convert_object_map[T.mul] = multitype_ops.mul
|
||||
convert_object_map[T.truediv] = multitype_ops.div
|
||||
convert_object_map[T.getitem] = multitype_ops.getitem
|
||||
convert_object_map[T.setitem] = multitype_ops.setitem
|
||||
convert_object_map[T.floordiv] = multitype_ops.floordiv
|
||||
convert_object_map[T.mod] = multitype_ops.mod
|
||||
convert_object_map[T.pow] = multitype_ops.pow_
|
||||
convert_object_map[T.and_] = multitype_ops.bitwise_and
|
||||
convert_object_map[T.or_] = multitype_ops.bitwise_or
|
||||
convert_object_map[T.xor] = multitype_ops.bitwise_xor
|
||||
convert_object_map[T.neg] = multitype_ops.negative
|
||||
convert_object_map[T.not_] = multitype_ops.logical_not
|
||||
convert_object_map[T.eq] = multitype_ops.equal
|
||||
convert_object_map[T.ne] = multitype_ops.not_equal
|
||||
convert_object_map[T.lt] = multitype_ops.less
|
||||
convert_object_map[T.gt] = multitype_ops.greater
|
||||
convert_object_map[T.le] = multitype_ops.less_equal
|
||||
convert_object_map[T.ge] = multitype_ops.greater_equal
|
||||
convert_object_map[T.contains] = multitype_ops.in_
|
||||
convert_object_map[T.not_contains] = multitype_ops.not_in_
|
||||
convert_object_map[T.matmul] = S.adapter_matmul
|
||||
convert_object_map[T.invert] = S.adapter_invert
|
||||
convert_object_map[T.abs] = S.adapter_abs
|
||||
convert_object_map[T.round] = S.adapter_round
|
||||
convert_object_map[T.max] = S.adapter_max
|
||||
convert_object_map[T.min] = S.adapter_min
|
||||
convert_object_map[T.sum] = S.adapter_sum
|
||||
|
||||
|
||||
adapter_config = {"Tensor": Tensor, "convert_object_map": convert_object_map}
|
||||
set_adapter_config(adapter_config)
|
||||
|
||||
|
||||
__all__ = ["Tensor", "Parameter"]
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" MSAdapter api. """
|
||||
|
||||
import sys
|
||||
import mindspore as ms
|
||||
|
||||
|
||||
class Tensor(ms.Tensor):
|
||||
def __init__(self, input_data=None, dtype=None, shape=None, init=None, inner=False):
|
||||
super(Tensor, self).__init__(input_data=input_data, dtype=dtype, shape=shape, init=init)
|
||||
|
||||
@property
|
||||
def attr(self):
|
||||
return 10
|
||||
|
||||
def method(self, x):
|
||||
return x + self.attr
|
||||
|
||||
def size(self, dim=None):
|
||||
if dim is None:
|
||||
return self.shape
|
||||
return self.shape[dim]
|
||||
|
||||
|
||||
class Parameter(ms.Parameter):
|
||||
def __new__(cls, default_input, *args, **kwargs):
|
||||
init_data_flag = bool(isinstance(default_input, ms.Tensor) and default_input.has_init)
|
||||
rc = sys.getrefcount(default_input)
|
||||
_, *class_init_args = Parameter._get_parameter_new_args(default_input, rc)
|
||||
new_type = Parameter._get_base_class(Tensor)
|
||||
obj = Tensor.__new__(new_type)
|
||||
Tensor.__init__(obj, *class_init_args)
|
||||
obj.init_mode = None
|
||||
obj.is_default_input_init = init_data_flag
|
||||
if obj.has_init:
|
||||
obj.init_mode = default_input
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def _get_base_class(input_class):
|
||||
input_class_name = Parameter.__name__
|
||||
if input_class_name in Parameter._base_type:
|
||||
new_type = Parameter._base_type.get(input_class_name)
|
||||
else:
|
||||
new_type = type(input_class_name, (Parameter, input_class), {})
|
||||
Parameter._base_type[input_class_name] = new_type
|
||||
return new_type
|
|
@ -0,0 +1,172 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore.ops.composite.multitype_ops.add_impl import add
|
||||
from mindspore.ops.composite.multitype_ops.sub_impl import sub
|
||||
from mindspore.ops.composite.multitype_ops.mul_impl import mul
|
||||
from mindspore.ops.composite.multitype_ops.div_impl import div
|
||||
from mindspore.ops.composite.multitype_ops.floordiv_impl import floordiv
|
||||
from mindspore.ops.composite.multitype_ops.mod_impl import mod
|
||||
from mindspore.ops.composite.multitype_ops.pow_impl import pow_
|
||||
from mindspore.ops.composite.multitype_ops.bitwise_and_impl import bitwise_and
|
||||
from mindspore.ops.composite.multitype_ops.bitwise_or_impl import bitwise_or
|
||||
from mindspore.ops.composite.multitype_ops.bitwise_xor_impl import bitwise_xor
|
||||
from mindspore.ops.composite.multitype_ops.negative_impl import negative
|
||||
from mindspore.ops.composite.multitype_ops.logic_not_impl import logical_not
|
||||
from mindspore.ops.composite.multitype_ops.equal_impl import equal
|
||||
from mindspore.ops.composite.multitype_ops.not_equal_impl import not_equal
|
||||
from mindspore.ops.composite.multitype_ops.less_impl import less
|
||||
from mindspore.ops.composite.multitype_ops.greater_impl import greater
|
||||
from mindspore.ops.composite.multitype_ops.less_equal_impl import less_equal
|
||||
from mindspore.ops.composite.multitype_ops.greater_equal_impl import greater_equal
|
||||
from mindspore.ops.composite.multitype_ops.in_impl import in_
|
||||
from mindspore.ops.composite.multitype_ops.not_in_impl import not_in_
|
||||
from mindspore.ops.composite.multitype_ops.getitem_impl import getitem
|
||||
from mindspore.ops.composite.multitype_ops.setitem_impl import setitem
|
||||
from tests.st.ms_adapter._register import utils
|
||||
|
||||
|
||||
# multitype_ops.add
|
||||
utils.update_multitype_ops_tensor_tensor(add)
|
||||
utils.update_multitype_ops_number_tensor(add)
|
||||
utils.update_multitype_ops_tensor_number(add)
|
||||
utils.update_multitype_ops_tuple_tensor(add)
|
||||
utils.update_multitype_ops_tensor_tuple(add)
|
||||
utils.update_multitype_ops_list_tensor(add)
|
||||
utils.update_multitype_ops_tensor_list(add)
|
||||
|
||||
# multitype_ops.sub
|
||||
utils.update_multitype_ops_tensor_tensor(sub)
|
||||
utils.update_multitype_ops_number_tensor(sub)
|
||||
utils.update_multitype_ops_tensor_number(sub)
|
||||
utils.update_multitype_ops_tuple_tensor(sub)
|
||||
utils.update_multitype_ops_tensor_tuple(sub)
|
||||
utils.update_multitype_ops_list_tensor(sub)
|
||||
utils.update_multitype_ops_tensor_list(sub)
|
||||
|
||||
# multitype_ops.mul
|
||||
utils.update_multitype_ops_tensor_tensor(mul)
|
||||
utils.update_multitype_ops_number_tensor(mul)
|
||||
utils.update_multitype_ops_tensor_number(mul)
|
||||
utils.update_multitype_ops_tuple_tensor(mul)
|
||||
utils.update_multitype_ops_tensor_tuple(mul)
|
||||
utils.update_multitype_ops_list_tensor(mul)
|
||||
utils.update_multitype_ops_tensor_list(mul)
|
||||
|
||||
# multitype_ops.div
|
||||
utils.update_multitype_ops_tensor_tensor(div)
|
||||
utils.update_multitype_ops_number_tensor(div)
|
||||
utils.update_multitype_ops_tensor_number(div)
|
||||
utils.update_multitype_ops_tuple_tensor(div)
|
||||
utils.update_multitype_ops_tensor_tuple(div)
|
||||
utils.update_multitype_ops_list_tensor(div)
|
||||
utils.update_multitype_ops_tensor_list(div)
|
||||
|
||||
# multitype_ops.floordiv
|
||||
utils.update_multitype_ops_tensor_tensor(floordiv)
|
||||
utils.update_multitype_ops_number_tensor(floordiv)
|
||||
utils.update_multitype_ops_tensor_number(floordiv)
|
||||
utils.update_multitype_ops_tuple_tensor(floordiv)
|
||||
utils.update_multitype_ops_tensor_tuple(floordiv)
|
||||
utils.update_multitype_ops_list_tensor(floordiv)
|
||||
utils.update_multitype_ops_tensor_list(floordiv)
|
||||
|
||||
# multitype_ops.mod
|
||||
utils.update_multitype_ops_tensor_tensor(mod)
|
||||
utils.update_multitype_ops_number_tensor(mod)
|
||||
utils.update_multitype_ops_tensor_number(mod)
|
||||
utils.update_multitype_ops_tuple_tensor(mod)
|
||||
utils.update_multitype_ops_tensor_tuple(mod)
|
||||
utils.update_multitype_ops_list_tensor(mod)
|
||||
utils.update_multitype_ops_tensor_list(mod)
|
||||
|
||||
# multitype_ops.pow_
|
||||
utils.update_multitype_ops_tensor_tensor(pow_)
|
||||
utils.update_multitype_ops_number_tensor(pow_)
|
||||
utils.update_multitype_ops_tensor_number(pow_)
|
||||
utils.update_multitype_ops_tuple_tensor(pow_)
|
||||
utils.update_multitype_ops_tensor_tuple(pow_)
|
||||
utils.update_multitype_ops_list_tensor(pow_)
|
||||
utils.update_multitype_ops_tensor_list(pow_)
|
||||
|
||||
# multitype_ops.bitwise_and
|
||||
utils.update_multitype_ops_tensor_tensor(bitwise_and)
|
||||
utils.update_multitype_ops_number_tensor(bitwise_and)
|
||||
utils.update_multitype_ops_tensor_number(bitwise_and)
|
||||
|
||||
# multitype_ops.bitwise_or
|
||||
utils.update_multitype_ops_tensor_tensor(bitwise_or)
|
||||
utils.update_multitype_ops_number_tensor(bitwise_or)
|
||||
utils.update_multitype_ops_tensor_number(bitwise_or)
|
||||
|
||||
# multitype_ops.bitwise_xor
|
||||
utils.update_multitype_ops_tensor_tensor(bitwise_xor)
|
||||
utils.update_multitype_ops_number_tensor(bitwise_xor)
|
||||
utils.update_multitype_ops_tensor_number(bitwise_xor)
|
||||
|
||||
# multitype_ops.negative
|
||||
utils.update_multitype_ops_tensor(negative)
|
||||
|
||||
# multitype_ops.logical_not
|
||||
utils.update_multitype_ops_tensor(logical_not)
|
||||
|
||||
# multitype_ops.equal
|
||||
utils.update_multitype_ops_tensor_tensor(equal)
|
||||
utils.update_multitype_ops_number_tensor(equal)
|
||||
utils.update_multitype_ops_tensor_number(equal)
|
||||
|
||||
# multitype_ops.not_equal
|
||||
utils.update_multitype_ops_tensor_tensor(not_equal)
|
||||
utils.update_multitype_ops_number_tensor(not_equal)
|
||||
utils.update_multitype_ops_tensor_number(not_equal)
|
||||
|
||||
# multitype_ops.less
|
||||
utils.update_multitype_ops_tensor_tensor(less)
|
||||
utils.update_multitype_ops_number_tensor(less)
|
||||
utils.update_multitype_ops_tensor_number(less)
|
||||
|
||||
# multitype_ops.greater
|
||||
utils.update_multitype_ops_tensor_tensor(greater)
|
||||
utils.update_multitype_ops_number_tensor(greater)
|
||||
utils.update_multitype_ops_tensor_number(greater)
|
||||
|
||||
# multitype_ops.less_equal
|
||||
utils.update_multitype_ops_tensor_tensor(less_equal)
|
||||
utils.update_multitype_ops_number_tensor(less_equal)
|
||||
utils.update_multitype_ops_tensor_number(less_equal)
|
||||
|
||||
# multitype_ops.greater_equal
|
||||
utils.update_multitype_ops_tensor_tensor(greater_equal)
|
||||
utils.update_multitype_ops_number_tensor(greater_equal)
|
||||
utils.update_multitype_ops_tensor_number(greater_equal)
|
||||
|
||||
# multitype_ops.in_
|
||||
utils.update_multitype_ops_tensor_tuple(in_)
|
||||
utils.update_multitype_ops_tensor_list(in_)
|
||||
|
||||
# multitype_ops.not_in_
|
||||
utils.update_multitype_ops_tensor_tuple(not_in_)
|
||||
utils.update_multitype_ops_tensor_list(not_in_)
|
||||
|
||||
# multitype_ops.getitem
|
||||
utils.update_multitype_ops_tensor_tensor(getitem)
|
||||
utils.update_multitype_ops_tensor_number(getitem)
|
||||
utils.update_multitype_ops_tensor_tuple(getitem)
|
||||
utils.update_multitype_ops_tensor_list(getitem)
|
||||
utils.update_multitype_ops_tensor_none(getitem)
|
||||
utils.update_multitype_ops_tensor_slice(getitem)
|
||||
|
||||
# multitype_ops.setitem
|
||||
utils.update_multitype_ops_setitem_tensor(setitem)
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore._extends.parse import trope as T
|
||||
from mindspore._extends.parse.resources import convert_object_map
|
||||
from tests.st.ms_adapter._register.ms_adapter_api import Tensor as adapter_Tensor
|
||||
from tests.st.ms_adapter._register.utils import convert_to_ms_tensor, convert_to_adapter_tensor
|
||||
|
||||
|
||||
matmul_fn = convert_object_map.get(T.matmul)
|
||||
invert_fn = convert_object_map.get(T.invert)
|
||||
abs_fn = convert_object_map.get(T.abs)
|
||||
round_fn = convert_object_map.get(T.round)
|
||||
max_fn = convert_object_map.get(T.max)
|
||||
min_fn = convert_object_map.get(T.min)
|
||||
sum_fn = convert_object_map.get(T.sum)
|
||||
|
||||
|
||||
def adapter_matmul(x, y):
|
||||
if isinstance(x, adapter_Tensor) and isinstance(y, adapter_Tensor):
|
||||
x = convert_to_ms_tensor(x)
|
||||
y = convert_to_ms_tensor(y)
|
||||
out = matmul_fn(x, y)
|
||||
out = convert_to_adapter_tensor(out)
|
||||
else:
|
||||
out = matmul_fn(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def adapter_invert(x):
|
||||
if isinstance(x, adapter_Tensor):
|
||||
x = convert_to_ms_tensor(x)
|
||||
out = invert_fn(x)
|
||||
out = convert_to_adapter_tensor(out)
|
||||
else:
|
||||
out = invert_fn(x)
|
||||
return out
|
||||
|
||||
|
||||
def adapter_abs(x):
|
||||
if isinstance(x, adapter_Tensor):
|
||||
x = convert_to_ms_tensor(x)
|
||||
out = abs_fn(x)
|
||||
out = convert_to_adapter_tensor(out)
|
||||
else:
|
||||
out = abs_fn(x)
|
||||
return out
|
||||
|
||||
|
||||
def adapter_round(*data):
|
||||
if (len(data) == 1 and isinstance(data[0], adapter_Tensor)) or \
|
||||
(len(data) == 2 and isinstance(data[0], adapter_Tensor) and isinstance(data[1], None)):
|
||||
x = data[0]
|
||||
x = convert_to_ms_tensor(x)
|
||||
out = round_fn(x)
|
||||
out = convert_to_adapter_tensor(out)
|
||||
else:
|
||||
out = round_fn(*data)
|
||||
return out
|
||||
|
||||
|
||||
def _has_adapter_tensor(*data):
|
||||
if len(data) == 1 and isinstance(data[0], adapter_Tensor):
|
||||
return True
|
||||
for elem in data:
|
||||
if isinstance(elem, adapter_Tensor):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def adapter_max(*data):
|
||||
if _has_adapter_tensor(*data):
|
||||
out = max_fn(*data)
|
||||
out = convert_to_adapter_tensor(out)
|
||||
else:
|
||||
out = max_fn(*data)
|
||||
return out
|
||||
|
||||
|
||||
def adapter_min(*data):
|
||||
if _has_adapter_tensor(*data):
|
||||
out = min_fn(*data)
|
||||
out = convert_to_adapter_tensor(out)
|
||||
else:
|
||||
out = min_fn(*data)
|
||||
return out
|
||||
|
||||
|
||||
def adapter_sum(*data):
|
||||
if _has_adapter_tensor(*data):
|
||||
out = sum_fn(*data)
|
||||
out = convert_to_adapter_tensor(out)
|
||||
else:
|
||||
out = sum_fn(*data)
|
||||
return out
|
|
@ -0,0 +1,204 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore._c_expression import typing
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from tests.st.ms_adapter._register.ms_adapter_api import Tensor as adapter_Tensor
|
||||
|
||||
|
||||
def convert_to_ms_tensor(x):
|
||||
return inner.convert_to_ms_tensor(x)
|
||||
|
||||
|
||||
def convert_to_adapter_tensor(x):
|
||||
return inner.convert_to_adapter_tensor(x)
|
||||
|
||||
|
||||
def get_registed_fn(ops, *type_names):
|
||||
types = tuple(map(mstype.typing.str_to_type, type_names))
|
||||
for sigs, fn in ops.entries:
|
||||
if len(sigs) != len(types):
|
||||
continue
|
||||
if any(not typing.is_subclass(type_, sig) for sig, type_ in zip(sigs, types)):
|
||||
continue
|
||||
return fn
|
||||
raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given types: {types}.")
|
||||
|
||||
|
||||
def convert_output(out):
|
||||
if isinstance(out, ms.Tensor):
|
||||
out = convert_to_adapter_tensor(out)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_tensor(ops):
|
||||
func = get_registed_fn(ops, "Tensor")
|
||||
|
||||
@ops.register("Tensor")
|
||||
def _tensor(x):
|
||||
if isinstance(x, adapter_Tensor):
|
||||
x = convert_to_ms_tensor(x)
|
||||
out = func(x)
|
||||
out = convert_output(out)
|
||||
else:
|
||||
out = func(x)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_tensor_tensor(ops):
|
||||
func = get_registed_fn(ops, "Tensor", "Tensor")
|
||||
|
||||
@ops.register("Tensor", "Tensor")
|
||||
def _tensor_and_tensor(x, y):
|
||||
if isinstance(x, adapter_Tensor) and isinstance(y, adapter_Tensor):
|
||||
x = convert_to_ms_tensor(x)
|
||||
y = convert_to_ms_tensor(y)
|
||||
out = func(x, y)
|
||||
out = convert_output(out)
|
||||
else:
|
||||
out = func(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_number_tensor(ops):
|
||||
func = get_registed_fn(ops, "Number", "Tensor")
|
||||
|
||||
@ops.register("Number", "Tensor")
|
||||
def _number_and_tensor(x, y):
|
||||
if isinstance(y, adapter_Tensor):
|
||||
y = convert_to_ms_tensor(y)
|
||||
out = func(x, y)
|
||||
out = convert_output(out)
|
||||
else:
|
||||
out = func(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_tensor_number(ops):
|
||||
func = get_registed_fn(ops, "Tensor", "Number")
|
||||
|
||||
@ops.register("Tensor", "Number")
|
||||
def _tensor_and_number(x, y):
|
||||
if isinstance(x, adapter_Tensor):
|
||||
x = convert_to_ms_tensor(x)
|
||||
out = func(x, y)
|
||||
out = convert_output(out)
|
||||
else:
|
||||
out = func(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_tuple_tensor(ops):
|
||||
func = get_registed_fn(ops, "Tuple", "Tensor")
|
||||
|
||||
@ops.register("Tuple", "Tensor")
|
||||
def _tuple_and_tensor(x, y):
|
||||
if isinstance(y, adapter_Tensor):
|
||||
y = convert_to_ms_tensor(y)
|
||||
out = func(x, y)
|
||||
out = convert_output(out)
|
||||
else:
|
||||
out = func(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_tensor_tuple(ops):
|
||||
func = get_registed_fn(ops, "Tensor", "Tuple")
|
||||
|
||||
@ops.register("Tensor", "Tuple")
|
||||
def _tensor_and_tuple(x, y):
|
||||
if isinstance(x, adapter_Tensor):
|
||||
x = convert_to_ms_tensor(x)
|
||||
out = func(x, y)
|
||||
out = convert_output(out)
|
||||
else:
|
||||
out = func(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_list_tensor(ops):
|
||||
func = get_registed_fn(ops, "List", "Tensor")
|
||||
|
||||
@ops.register("List", "Tensor")
|
||||
def _list_and_tensor(x, y):
|
||||
if isinstance(y, adapter_Tensor):
|
||||
y = convert_to_ms_tensor(y)
|
||||
out = func(x, y)
|
||||
out = convert_output(out)
|
||||
else:
|
||||
out = func(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_tensor_list(ops):
|
||||
func = get_registed_fn(ops, "Tensor", "List")
|
||||
|
||||
@ops.register("Tensor", "List")
|
||||
def _tensor_and_list(x, y):
|
||||
if isinstance(x, adapter_Tensor):
|
||||
x = convert_to_ms_tensor(x)
|
||||
out = func(x, y)
|
||||
out = convert_output(out)
|
||||
else:
|
||||
out = func(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_tensor_none(ops):
|
||||
func = get_registed_fn(ops, "Tensor", "None")
|
||||
|
||||
@ops.register("Tensor", "None")
|
||||
def _tensor_and_none(x, y):
|
||||
if isinstance(x, adapter_Tensor):
|
||||
x = convert_to_ms_tensor(x)
|
||||
out = func(x, y)
|
||||
out = convert_output(out)
|
||||
else:
|
||||
out = func(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_tensor_slice(ops):
|
||||
func = get_registed_fn(ops, "Tensor", "Slice")
|
||||
|
||||
@ops.register("Tensor", "Slice")
|
||||
def _tensor_and_slice(x, y):
|
||||
if isinstance(x, adapter_Tensor):
|
||||
x = convert_to_ms_tensor(x)
|
||||
out = func(x, y)
|
||||
out = convert_output(out)
|
||||
else:
|
||||
out = func(x, y)
|
||||
return out
|
||||
|
||||
|
||||
def update_multitype_ops_setitem_tensor(ops):
|
||||
def register_for_setitem(sigs, fn):
|
||||
@ops.register(*sigs)
|
||||
def _tensor_setitem(data, index, value):
|
||||
if isinstance(data, adapter_Tensor):
|
||||
data = convert_to_ms_tensor(data)
|
||||
out = fn(data, index, value)
|
||||
out = convert_to_adapter_tensor(out)
|
||||
else:
|
||||
out = fn(data, index, value)
|
||||
return out
|
||||
|
||||
entries = ops.entries.copy()
|
||||
for sigs, fn in entries:
|
||||
if typing.is_subclass(sigs[0], mstype.tensor):
|
||||
register_for_setitem(sigs, fn)
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test MSAdapter. """
|
||||
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from tests.st.ms_adapter import Tensor, Parameter
|
||||
from tests.st.ms_adapter._register.utils import convert_to_ms_tensor, convert_to_adapter_tensor
|
||||
|
||||
|
||||
ms.set_context(mode=ms.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_tensor_attr():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Get the properties of MSAdapter.Tensor
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return x.attr
|
||||
|
||||
x = Tensor(1)
|
||||
assert func(x) == 10
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_tensor_method():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Get the methods of MSAdapter.Tensor
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return x.method(10)
|
||||
|
||||
x = Tensor(1)
|
||||
assert func(x) == 20
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_attr():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Get the properties of MSAdapter.Parameter
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return x.attr
|
||||
|
||||
x = Parameter(Tensor(1))
|
||||
assert func(x) == 10
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_method():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Get the methods of MSAdapter.Parameter
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return x.method(10)
|
||||
|
||||
x = Parameter(Tensor(1))
|
||||
assert func(x) == 20
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_tensor_convert_type():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test type conversion
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x, y):
|
||||
a = x.size(0)
|
||||
b = y.size
|
||||
x = convert_to_ms_tensor(x)
|
||||
y = convert_to_adapter_tensor(y)
|
||||
c = x.size
|
||||
d = y.size(0)
|
||||
return x, y, (a, b, c, d)
|
||||
|
||||
x = Tensor([1, 2, 3])
|
||||
y = ms.Tensor([1, 2, 3])
|
||||
out = func(x, y)
|
||||
assert type(out[0]) is ms.Tensor
|
||||
assert type(out[1]) is Tensor
|
||||
assert out[2] == (3, 3, 3, 3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_tensor_isinstance():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test isinstance syntax
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
a = isinstance(x, Tensor)
|
||||
x = convert_to_ms_tensor(x)
|
||||
b = isinstance(x, Tensor)
|
||||
x = convert_to_adapter_tensor(x)
|
||||
c = isinstance(x, Tensor)
|
||||
return a, b, c
|
||||
|
||||
x = Tensor(1)
|
||||
out = func(x)
|
||||
assert out[0] and not out[1] and out[2]
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test grad in MSAdapter. """
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import grad
|
||||
from tests.st.ms_adapter import Tensor
|
||||
|
||||
|
||||
ms.set_context(mode=ms.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_ms_adapter_grad():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test grad scenario of MSAdapter
|
||||
Expectation: No exception
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, y, z):
|
||||
return x * y * z
|
||||
|
||||
x = Tensor([1, 2], mstype.int32)
|
||||
y = Tensor([-2, 3], mstype.int32)
|
||||
z = Tensor([0, 3], mstype.int32)
|
||||
net = Net()
|
||||
output = grad(net, grad_position=(1, 2))(x, y, z)
|
||||
|
||||
grad_y = Tensor([0, 6], mstype.int32)
|
||||
grad_z = Tensor([-2, 6], mstype.int32)
|
||||
assert np.all(output[0].asnumpy() == grad_y.asnumpy())
|
||||
assert np.all(output[1].asnumpy() == grad_z.asnumpy())
|
|
@ -0,0 +1,307 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import tests.st.ms_adapter as adapter
|
||||
|
||||
ms.set_context(mode=ms.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_arithmetic_operator():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test arithmetic operators
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def add_fn(x, y):
|
||||
return x + y
|
||||
|
||||
@ms.jit
|
||||
def sub_fn(x, y):
|
||||
return x - y
|
||||
|
||||
@ms.jit
|
||||
def mul_fn(x, y):
|
||||
return x * y
|
||||
|
||||
@ms.jit
|
||||
def div_fn(x, y):
|
||||
return x / y
|
||||
|
||||
@ms.jit
|
||||
def floordiv_fn(x, y):
|
||||
return x // y
|
||||
|
||||
@ms.jit
|
||||
def mod_fn(x, y):
|
||||
return x % y
|
||||
|
||||
@ms.jit
|
||||
def pow_fn(x, y):
|
||||
return x ** y
|
||||
|
||||
def check_output_type(func):
|
||||
ms_x = ms.Tensor(1)
|
||||
adapter_x = adapter.Tensor(1)
|
||||
assert type(func(ms_x, ms_x)) is ms.Tensor
|
||||
assert type(func(adapter_x, adapter_x)) is adapter.Tensor # "Tensor", "Tensor"
|
||||
assert type(func(adapter_x, 1)) is adapter.Tensor # "Tensor", "Number"
|
||||
assert type(func(1, adapter_x)) is adapter.Tensor # "Number", "Tensor"
|
||||
assert type(func(adapter_x, (adapter_x,))) is adapter.Tensor # "Tensor", "Tuple"
|
||||
assert type(func((adapter_x,), adapter_x)) is adapter.Tensor # "Tuple", "Tensor"
|
||||
assert type(func(adapter_x, [adapter_x,])) is adapter.Tensor # "Tensor", "List"
|
||||
assert type(func([adapter_x,], adapter_x)) is adapter.Tensor # "List", "Tensor"
|
||||
|
||||
check_output_type(add_fn)
|
||||
check_output_type(sub_fn)
|
||||
check_output_type(mul_fn)
|
||||
check_output_type(div_fn)
|
||||
check_output_type(floordiv_fn)
|
||||
check_output_type(mod_fn)
|
||||
check_output_type(pow_fn)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_binary_operator():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test binary operators
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def equal_fn(x, y):
|
||||
return x == y
|
||||
|
||||
@ms.jit
|
||||
def not_equal_fn(x, y):
|
||||
return x != y
|
||||
|
||||
@ms.jit
|
||||
def less_fn(x, y):
|
||||
return x < y
|
||||
|
||||
@ms.jit
|
||||
def greater_fn(x, y):
|
||||
return x > y
|
||||
|
||||
@ms.jit
|
||||
def less_equal_fn(x, y):
|
||||
return x <= y
|
||||
|
||||
@ms.jit
|
||||
def greater_equal_fn(x, y):
|
||||
return x >= y
|
||||
|
||||
@ms.jit
|
||||
def bitwise_and_fn(x, y):
|
||||
return x & y
|
||||
|
||||
@ms.jit
|
||||
def bitwise_or_fn(x, y):
|
||||
return x | y
|
||||
|
||||
@ms.jit
|
||||
def bitwise_xor_fn(x, y):
|
||||
return x ^ y
|
||||
|
||||
def check_output_type(func):
|
||||
ms_x = ms.Tensor([1, 2, 3])
|
||||
ms_y = ms.Tensor([3, 2, 1])
|
||||
adapter_x = adapter.Tensor([1, 2, 3])
|
||||
adapter_y = adapter.Tensor([3, 2, 1])
|
||||
assert type(func(ms_x, ms_y)) is ms.Tensor
|
||||
assert type(func(adapter_x, adapter_y)) is adapter.Tensor # "Tensor", "Tensor"
|
||||
assert type(func(adapter_x, 1)) is adapter.Tensor # "Tensor", "Number"
|
||||
assert type(func(1, adapter_x)) is adapter.Tensor # "Number", "Tensor"
|
||||
|
||||
check_output_type(equal_fn)
|
||||
check_output_type(not_equal_fn)
|
||||
check_output_type(less_fn)
|
||||
check_output_type(greater_fn)
|
||||
check_output_type(less_equal_fn)
|
||||
check_output_type(greater_equal_fn)
|
||||
check_output_type(bitwise_and_fn)
|
||||
check_output_type(bitwise_or_fn)
|
||||
check_output_type(bitwise_xor_fn)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_unary_operator():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test unary operators
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def positive_fn(x):
|
||||
return +x
|
||||
|
||||
@ms.jit
|
||||
def negative_fn(x):
|
||||
return -x
|
||||
|
||||
ms_x = ms.Tensor([1, -2, 3])
|
||||
adapter_x = adapter.Tensor([1, -2, 3])
|
||||
assert type(positive_fn(ms_x)) is ms.Tensor
|
||||
assert type(negative_fn(ms_x)) is ms.Tensor
|
||||
assert type(positive_fn(adapter_x)) is adapter.Tensor
|
||||
assert type(negative_fn(adapter_x)) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_logical_operator():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test logical operators
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def is_fn(x):
|
||||
return x is None
|
||||
|
||||
@ms.jit
|
||||
def is_not_fn(x):
|
||||
return x is not None
|
||||
|
||||
@ms.jit
|
||||
def invert_fn(x):
|
||||
return ~x
|
||||
|
||||
@ms.jit
|
||||
def logical_not_fn(x):
|
||||
return not x
|
||||
|
||||
ms_x = ms.Tensor(True)
|
||||
adapter_x = adapter.Tensor(True)
|
||||
assert not is_fn(adapter_x)
|
||||
assert is_not_fn(adapter_x)
|
||||
assert type(invert_fn(ms_x)) is ms.Tensor
|
||||
assert type(logical_not_fn(ms_x)) is ms.Tensor
|
||||
assert type(invert_fn(adapter_x)) is adapter.Tensor
|
||||
assert type(logical_not_fn(adapter_x)) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_contain_operator():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test in / not in
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def in_fn(x, y, z):
|
||||
return x in (x, y, z)
|
||||
|
||||
@ms.jit
|
||||
def not_in_fn(x, y, z):
|
||||
return x not in (x, y, z)
|
||||
|
||||
ms_x = ms.Tensor(2)
|
||||
ms_y = ms.Tensor(2)
|
||||
ms_z = ms.Tensor(3)
|
||||
adapter_x = adapter.Tensor(1)
|
||||
adapter_y = adapter.Tensor(2)
|
||||
adapter_z = adapter.Tensor(3)
|
||||
assert type(in_fn(ms_x, ms_y, ms_z)) is ms.Tensor
|
||||
assert type(not_in_fn(ms_x, ms_y, ms_z)) is ms.Tensor
|
||||
assert type(in_fn(adapter_x, adapter_y, adapter_z)) is adapter.Tensor
|
||||
assert type(not_in_fn(adapter_x, adapter_y, adapter_z)) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_matmul():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test matmul operator
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x, y):
|
||||
return x @ y
|
||||
|
||||
ms_x = ms.Tensor([1, 2], ms.float32)
|
||||
ms_y = ms.Tensor([3, 4], ms.float32)
|
||||
adapter_x = adapter.Tensor([1, 2], ms.float32)
|
||||
adapter_y = adapter.Tensor([3, 4], ms.float32)
|
||||
assert type(func(ms_x, ms_y)) is ms.Tensor
|
||||
assert type(func(adapter_x, adapter_y)) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_getitem():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test getietm operator
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def getitem_fn(x, index):
|
||||
return x[index]
|
||||
|
||||
@ms.jit
|
||||
def getitem_slice_fn(x):
|
||||
return x[1:]
|
||||
|
||||
ms_x = ms.Tensor([[1, 2, 3], [4, 5, 6]])
|
||||
adapter_x = adapter.Tensor([[1, 2, 3], [4, 5, 6]])
|
||||
assert type(getitem_fn(ms_x, 0)) is ms.Tensor
|
||||
assert type(getitem_fn(ms_x, None)) is ms.Tensor
|
||||
assert type(getitem_fn(ms_x, [0, 1])) is ms.Tensor
|
||||
assert type(getitem_fn(ms_x, (0, 1))) is ms.Tensor
|
||||
assert type(getitem_slice_fn(ms_x)) is ms.Tensor
|
||||
assert type(getitem_fn(adapter_x, 0)) is adapter.Tensor
|
||||
assert type(getitem_fn(adapter_x, None)) is adapter.Tensor
|
||||
assert type(getitem_fn(adapter_x, [0, 1])) is adapter.Tensor
|
||||
assert type(getitem_fn(adapter_x, (0, 1))) is adapter.Tensor
|
||||
assert type(getitem_slice_fn(adapter_x)) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_setitem():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test setitem operator
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def setitem_fn(x, index, value):
|
||||
x[index] = value
|
||||
return x
|
||||
|
||||
ms_x = ms.Tensor([[1, 2, 3], [4, 5, 6]])
|
||||
adapter_x = adapter.Tensor([[1, 2, 3], [4, 5, 6]])
|
||||
adapter_index = adapter.Tensor(0)
|
||||
adapter_value = adapter.Tensor([7, 8, 9])
|
||||
assert type(setitem_fn(adapter_x, adapter_index, adapter_value)) is adapter.Tensor
|
||||
assert type(setitem_fn(ms_x, adapter_index, adapter_value)) is ms.Tensor
|
|
@ -0,0 +1,391 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from functools import partial
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import tests.st.ms_adapter as adapter
|
||||
|
||||
ms.set_context(mode=ms.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_abs():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function abs()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return abs(x)
|
||||
|
||||
assert type(func(ms.Tensor(-5))) is ms.Tensor
|
||||
assert type(func(adapter.Tensor(-5))) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_round():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function round()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return round(x)
|
||||
|
||||
assert type(func(ms.Tensor(1.55))) is ms.Tensor
|
||||
assert type(func(adapter.Tensor(1.55))) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_map():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function map()
|
||||
Expectation: No exception
|
||||
"""
|
||||
def add(x, y):
|
||||
return x + y
|
||||
|
||||
@ms.jit
|
||||
def func(x, y):
|
||||
return map(add, x, y)
|
||||
|
||||
x = (adapter.Tensor(1), 2)
|
||||
y = (adapter.Tensor(2), 4)
|
||||
out = func(x, y)
|
||||
assert type(out[0]) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_filter():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function filter()
|
||||
Expectation: No exception
|
||||
"""
|
||||
def select_fn(x):
|
||||
return True
|
||||
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return filter(select_fn, x)
|
||||
|
||||
x = (adapter.Tensor(2), 1, 2, 3)
|
||||
out = func(x)
|
||||
assert type(out[0]) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_partial():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function partial()
|
||||
Expectation: No exception
|
||||
"""
|
||||
def add(x, y):
|
||||
return x + y
|
||||
|
||||
@ms.jit
|
||||
def func(data):
|
||||
add_ = partial(add, x=2)
|
||||
return add_(y=data)
|
||||
|
||||
out = func(adapter.Tensor(1))
|
||||
assert type(out) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_zip():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function zip()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x, y):
|
||||
return zip(x, y)
|
||||
|
||||
x = (adapter.Tensor(1), 2)
|
||||
y = (adapter.Tensor(2), 4)
|
||||
out = func(x, y)
|
||||
assert type(out[0][0]) is adapter.Tensor
|
||||
assert type(out[0][1]) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_enumerate():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function enumerate()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return enumerate(x)
|
||||
|
||||
x = adapter.Tensor([[1, 2], [3, 4], [5, 6]])
|
||||
out = func(x)
|
||||
assert out[0][0] == 0
|
||||
assert type(out[0][1]) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_instance():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function isinstance()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x, y):
|
||||
a = isinstance(x, ms.Tensor) and not isinstance(x, adapter.Tensor)
|
||||
b = isinstance(y, ms.Tensor) and isinstance(y, adapter.Tensor)
|
||||
return a, b
|
||||
|
||||
x = ms.Tensor(1)
|
||||
y = adapter.Tensor(1)
|
||||
a, b = func(x, y)
|
||||
assert a and b
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_max():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function max()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x, y, z):
|
||||
return max(x), max(y, z)
|
||||
|
||||
x = adapter.Tensor([1, 2], ms.float32)
|
||||
y = adapter.Tensor([1], ms.float32)
|
||||
z = adapter.Tensor([2], ms.float32)
|
||||
out = func(x, y, z)
|
||||
assert type(out[0]) is adapter.Tensor
|
||||
assert type(out[1]) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_min():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function min()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x, y, z):
|
||||
return min(x), min(y, z)
|
||||
|
||||
x = adapter.Tensor([1, 2], ms.float32)
|
||||
y = adapter.Tensor([1], ms.float32)
|
||||
z = adapter.Tensor([2], ms.float32)
|
||||
out = func(x, y, z)
|
||||
assert type(out[0]) is adapter.Tensor
|
||||
assert type(out[1]) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_sum():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function sum()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x, y, z):
|
||||
return sum(x), sum(y, z)
|
||||
|
||||
x = adapter.Tensor([[1, 2], [3, 4]], ms.float32)
|
||||
y = adapter.Tensor([1, 2, 3], ms.float32)
|
||||
z = adapter.Tensor([4, 5, 6], ms.float32)
|
||||
out = func(x, y, z)
|
||||
assert type(out[0]) is adapter.Tensor
|
||||
assert type(out[1]) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_getattr():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function getattr()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return getattr(x, "attr")
|
||||
|
||||
x = adapter.Tensor([1, 2, 3])
|
||||
assert func(x) == 10
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_hasattr():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function hasattr()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x, y):
|
||||
return hasattr(x, "method"), hasattr(y, "method")
|
||||
|
||||
x = adapter.Tensor([1, 2, 3])
|
||||
y = ms.Tensor([1, 2, 3])
|
||||
out = func(x, y)
|
||||
assert out[0] and not out[1]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_iter():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function iter()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return iter(x)[0]
|
||||
|
||||
x = adapter.Tensor([1, 2, 3])
|
||||
out = func(x)
|
||||
assert type(out) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_next():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function next()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
it = iter(x)
|
||||
return next(it)
|
||||
|
||||
x = adapter.Tensor([1, 2, 3])
|
||||
out = func(x)
|
||||
assert type(out[0]) is adapter.Tensor
|
||||
assert type(out[1]) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_print():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function print()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
print(x)
|
||||
return x
|
||||
|
||||
func(adapter.Tensor([1, 2, 3]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_tuple():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function tuple()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return tuple(x)
|
||||
|
||||
x = adapter.Tensor([1, 2, 3])
|
||||
out = func(x)
|
||||
assert type(out[0]) is adapter.Tensor
|
||||
assert type(out[1]) is adapter.Tensor
|
||||
assert type(out[2]) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_list():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function list()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return list(x)
|
||||
|
||||
x = adapter.Tensor([1, 2, 3])
|
||||
out = func(x)
|
||||
assert type(out[0]) is adapter.Tensor
|
||||
assert type(out[1]) is adapter.Tensor
|
||||
assert type(out[2]) is adapter.Tensor
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bool():
|
||||
"""
|
||||
Feature: MSAdapter
|
||||
Description: Test python built-in function bool()
|
||||
Expectation: No exception
|
||||
"""
|
||||
@ms.jit
|
||||
def func(x):
|
||||
return bool(x)
|
||||
|
||||
x = adapter.Tensor([10])
|
||||
out = func(x)
|
||||
assert type(out) is adapter.Tensor
|
Loading…
Reference in New Issue