Merge pull request !46303 from huangbingjian/adapter_tensor
This commit is contained in:
i-robot 2022-12-21 08:04:50 +00:00 committed by Gitee
commit 92bf98fa4e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
37 changed files with 2209 additions and 116 deletions

View File

@ -12,6 +12,7 @@
"mindspore/mindspore/python/mindspore/_check_version.py" "unused-import"
"mindspore/mindspore/python/mindspore/_check_version.py" "broad-except"
"mindspore/mindspore/python/mindspore/common/api.py" "protected-access"
"mindspore/mindspore/python/mindspore/common/api.py" "not-callable"
"mindspore/mindspore/python/mindspore/common/parameter.py" "protected-access"
"mindspore/mindspore/python/mindspore/common/parameter.py" "no-value-for-parameter"
"mindspore/mindspore/python/mindspore/common/hook_handle.py" "protected-access"
@ -72,6 +73,7 @@
"mindspore/mindspore/python/mindspore/ops/function/array_func.py" "redefined-builtin"
"mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "redefined-builtin"
"mindspore/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py" "unused-variable"
"mindspore/mindspore/python/mindspore/ops/operations/_inner_ops.py" "not-callable"
# MindData
"mindspore/mindspore/python/mindspore/dataset/__init__.py" "redefined-builtin"
@ -175,6 +177,11 @@
"mindspore/tests/" "c-extension-no-member"
"mindspore/tests/st/parameter/test_parameter_celllist.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_cellcontainer.py" "protected-access"
"mindspore/tests/st/ms_adapter/_register/utils.py" "protected-access"
"mindspore/tests/st/ms_adapter/_register/utils.py" "unused-variable"
"mindspore/tests/st/ms_adapter/test_ms_adapter_base.py" "unidiomatic-typecheck"
"mindspore/tests/st/ms_adapter/test_operator.py" "unidiomatic-typecheck"
"mindspore/tests/st/ms_adapter/test_python_builtins.py" "unidiomatic-typecheck"
#MindSpore Lite
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin"

View File

@ -214,6 +214,22 @@ AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
}
bool IsAdapterTensor(const AbstractBasePtr &x) {
if (!x->isa<abstract::AbstractTensor>()) {
return false;
}
return x->cast<abstract::AbstractTensorPtr>()->is_adapter();
}
bool IsAdapterTensorClassType(const AbstractBasePtr &cmp) {
auto cmp_value = cmp->BuildValue();
if (!cmp_value->isa<parse::ClassType>()) {
return false;
}
auto class_obj = cmp_value->cast<parse::ClassTypePtr>()->obj();
return py::hasattr(class_obj, PYTHON_ADAPTER_TENSOR);
}
bool CheckPythonIsInstance(const py::object &x, const AbstractBasePtr &cmp, const py::module &mod, bool is_const) {
if (cmp->isa<abstract::AbstractTuple>()) {
const auto &cmp_tuple_elements = cmp->cast<abstract::AbstractTuplePtr>()->elements();
@ -401,6 +417,11 @@ AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePt
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(result), kBool);
}
// x is adapter tensor.
if (IsAdapterTensor(x) && IsAdapterTensorClassType(cmp)) {
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(true), kBool);
}
auto x_value = x->BuildValue();
// x is variable built-in type.
if (x_value == kAnyValue) {
@ -1112,6 +1133,42 @@ std::optional<StandardPrimitiveImplReg> GetFrontendPrimitiveInferImpl(const Prim
return std::optional<StandardPrimitiveImplReg>();
}
AbstractBasePtr SetAdapterFlag(const std::string &op_name, const AbstractBasePtr &abs_input, bool adapter_flag) {
MS_EXCEPTION_IF_NULL(abs_input);
// Clone is needed here.
if (abs_input->isa<AbstractRefTensor>()) {
auto abs_ref = abs_input->Clone()->cast<AbstractRefPtr>();
abs_ref->set_is_adapter(adapter_flag);
return abs_ref;
}
if (abs_input->isa<AbstractTensor>()) {
auto abs_tensor = abs_input->Clone()->cast<AbstractTensorPtr>();
abs_tensor->set_is_adapter(adapter_flag);
return abs_tensor;
}
MS_LOG(EXCEPTION) << op_name << " requires a tensor as the first argument, but got " << abs_input->ToString();
}
AbstractBasePtr InferImplConvertToAdapterTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor.
constexpr size_t args_num = 1;
constexpr size_t input_index = 0;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, args_num);
return SetAdapterFlag(op_name, args_spec_list[input_index], true);
}
AbstractBasePtr InferImplConvertToMsTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor.
constexpr size_t args_num = 1;
constexpr size_t input_index = 0;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, args_num);
return SetAdapterFlag(op_name, args_spec_list[input_index], false);
}
#ifndef _MSC_VER
// String
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringMul, prim::kPrimStringMul, InferImplStringMul, nullptr);
@ -1149,6 +1206,10 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Taylor, prim::kPrimTaylor, InferImplTaylor, n
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Vmap, prim::kPrimVmap, InferImplVmap, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Lower, prim::kPrimLower, InferImplLower, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ConvertToAdapterTensor, prim::kPrimConvertToAdapterTensor,
InferImplConvertToAdapterTensor, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ConvertToMsTensor, prim::kPrimConvertToMsTensor, InferImplConvertToMsTensor,
nullptr);
#else
void RegPrimitiveFrontEval() {
// String
@ -1213,6 +1274,11 @@ void RegPrimitiveFrontEval() {
InferImplVmap, nullptr);
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), prim::kPrimLower,
InferImplLower, nullptr);
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(),
prim::kPrimConvertToAdapterTensor, InferImplConvertToAdapterTensor,
nullptr);
abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(),
prim::kPrimConvertToMsTensor, InferImplConvertToMsTensor, nullptr);
} // namespace abstract
#endif
} // namespace abstract

View File

@ -81,6 +81,10 @@ AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &pr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConvertToAdapterTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConvertToMsTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
const PrimitiveEvalImplMap &GetFrontendPrimitiveInferMap();
PrimitiveEvalImplMap *GetFrontendPrimitiveInferMapPtr();

View File

@ -51,6 +51,7 @@
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
#include "frontend/optimizer/irpass/recompute_prepare.h"
#include "frontend/optimizer/irpass/real_op_eliminate.h"
#include "frontend/optimizer/irpass/convert_tensor_eliminate.h"
#include "mindspore/ccsrc/frontend/optimizer/irpass/bprop_mindir/get_constexpr_ops.h"
#include "mindspore/ccsrc/frontend/optimizer/irpass/bprop_mindir/get_class_type.h"
#include "mindspore/ccsrc/frontend/optimizer/irpass/bprop_mindir/get_meta_fg.h"
@ -132,6 +133,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
all_reduce_const_elim_ =
MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);
real_op_eliminate_ = MakeSubstitution(std::make_shared<RealOpEliminate>(), "real_op_eliminate", prim::kPrimRealInner);
convert_tensor_eliminate_ = MakeSubstitution(std::make_shared<ConvertTensorEliminate>(), "convert_tensor_eliminate",
{prim::kPrimConvertToAdapterTensor, prim::kPrimConvertToMsTensor});
convert_tensor_all_eliminate_ =
MakeSubstitution(std::make_shared<ConvertTensorAllEliminate>(), "convert_tensor_all_eliminate",
{prim::kPrimConvertToAdapterTensor, prim::kPrimConvertToMsTensor});
// Environ Item Eliminate
environ_get_eliminate_ =

View File

@ -64,6 +64,8 @@ class OptimizeIRPassLib {
SubstitutionPtr mini_step_allgather_replace_;
SubstitutionPtr micro_step_allgather_replace_;
SubstitutionPtr real_op_eliminate_;
SubstitutionPtr convert_tensor_eliminate_;
SubstitutionPtr convert_tensor_all_eliminate_;
// Env Item Eliminate
SubstitutionPtr environ_get_eliminate_;

View File

@ -0,0 +1,101 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "pipeline/jit/static_analysis/prim.h"
namespace mindspore {
namespace opt {
namespace irpass {
class ConvertTensorEliminate : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
auto fg = node->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
constexpr size_t tensor_index = 1;
auto x = cnode->input(tensor_index);
bool is_adapter = IsAdapterTensor(x);
if (IsPrimitiveCNode(node, prim::kPrimConvertToAdapterTensor)) {
// {prim::kPrimConvertToAdapterTensor, x} -> x
if (is_adapter) {
return x;
}
// {prim::kPrimConvertToAdapterTensor, {prim::kPrimConvertToMsTensor, inp}} ->
// {prim::kPrimConvertToAdapterTensor, inp}
if (IsPrimitiveCNode(x, prim::kPrimConvertToMsTensor)) {
auto x_cnode = x->cast<CNodePtr>();
auto inp = x_cnode->input(tensor_index);
auto new_node = fg->NewCNode({NewValueNode(prim::kPrimConvertToAdapterTensor), inp});
new_node->set_abstract(node->abstract());
return new_node;
}
}
if (IsPrimitiveCNode(x, prim::kPrimConvertToMsTensor)) {
// {prim::kPrimConvertToMsTensor, x} -> x
if (!is_adapter) {
return x;
}
// {prim::kPrimConvertToMsTensor, {prim::kPrimConvertToAdapterTensor, inp}} ->
// {prim::kPrimConvertToMsTensor, inp}
if (IsPrimitiveCNode(x, prim::kPrimConvertToAdapterTensor)) {
auto x_cnode = x->cast<CNodePtr>();
auto inp = x_cnode->input(tensor_index);
auto new_node = fg->NewCNode({NewValueNode(prim::kPrimConvertToMsTensor), inp});
new_node->set_abstract(node->abstract());
return new_node;
}
}
return nullptr;
}
private:
bool IsAdapterTensor(const AnfNodePtr &node) {
auto abs = node->abstract();
MS_EXCEPTION_IF_NULL(abs);
auto abs_tensor = dyn_cast<abstract::AbstractTensor>(abs);
MS_EXCEPTION_IF_NULL(abs_tensor);
return abs_tensor->is_adapter();
}
};
class ConvertTensorAllEliminate : public AnfVisitor {
public:
// {prim::kPrimConvertToAdapterTensor, x} -> x
// {prim::kPrimConvertToMsTensor, x} -> x
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimConvertToAdapterTensor) &&
!IsPrimitiveCNode(node, prim::kPrimConvertToMsTensor)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
constexpr size_t tensor_index = 1;
auto tensor = cnode->input(tensor_index);
tensor->set_abstract(node->abstract());
return tensor;
}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONVERT_TENSOR_ELIMINATE_H_

View File

@ -31,9 +31,8 @@ namespace py = pybind11;
namespace mindspore {
py::object AnyToPyData(const Any &value);
COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value);
COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output);
COMMON_EXPORT py::object ValueToPyData(const ValuePtr &value);
COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &abs = nullptr);
COMMON_EXPORT py::object ValueToPyData(const ValuePtr &value, const AbstractBasePtr &abs = nullptr);
COMMON_EXPORT bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
const std::shared_ptr<py::object> &ret_val);

View File

@ -24,6 +24,7 @@
#include "frontend/operator/composite/composite.h"
#include "ir/func_graph_cloner.h"
#include "ir/cell.h"
#include "ir/adapter_tensor.h"
#include "utils/symbolic.h"
#include "utils/ms_context.h"
#include "include/common/utils/utils.h"
@ -47,6 +48,8 @@ using COOTensor = mindspore::tensor::COOTensor;
using COOTensorPtr = mindspore::tensor::COOTensorPtr;
using MapTensor = mindspore::tensor::MapTensor;
using MapTensorPtr = mindspore::tensor::MapTensorPtr;
using AdapterTensor = mindspore::tensor::AdapterTensor;
using AdapterTensorPtr = mindspore::tensor::AdapterTensorPtr;
using InstanceCheckFunc = std::function<bool(const py::object &)>;
using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
@ -80,7 +83,7 @@ using ArgsObjConvertFunc = std::function<ValuePtr(const py::object &)>;
using ArgsObjSigConvertFunc = std::function<ValuePtr(const py::object &, bool)>;
using ArgsOjbTypeConvertFunc = std::function<ValuePtr(const py::object &, const TypePtr &)>;
// Convert the data according instance type
// Convert the data according to instance type
template <typename T>
class ByTypeDataConverter : public DataConverter {
public:
@ -117,7 +120,7 @@ class ByTypeDataConverter : public DataConverter {
InstanceCheckFunc check_func_ = nullptr;
};
// Convert the data according object attribute.
// Convert the data according to object attribute.
class ByAttrDataConverter : public DataConverter {
public:
ByAttrDataConverter(const std::string &attr_name, const ArgsObjConvertFunc &convert_func)
@ -139,6 +142,28 @@ class ByAttrDataConverter : public DataConverter {
std::string attr_name_;
};
// Convert the data according to match function.
class ByFuncDataConverter : public DataConverter {
public:
ByFuncDataConverter(const InstanceCheckFunc &match_func, const ArgsObjConvertFunc &convert_func)
: DataConverter(
[convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
match_func_(match_func) {}
ByFuncDataConverter(const InstanceCheckFunc &match_func, const ArgsObjSigConvertFunc &convert_func)
: DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
return convert_func(obj, use_sig);
}),
match_func_(match_func) {}
~ByFuncDataConverter() override = default;
bool Matched(const py::object &obj) override { return match_func_ != nullptr ? match_func_(obj) : false; }
private:
InstanceCheckFunc match_func_ = nullptr;
};
FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
std::vector<std::string> results = data_converter::GetObjKey(obj);
std::string obj_key = results[0];
@ -171,6 +196,25 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
}
namespace {
bool IsAdapterTensor(const py::object &obj) {
if (!py::hasattr(obj, PYTHON_ADAPTER_TENSOR)) {
return false;
}
// Only class instances are considered.
if (data_converter::IsClassType(obj)) {
return false;
}
// Check if the attribute is true.
return py::getattr(obj, PYTHON_ADAPTER_TENSOR).cast<bool>();
}
ValuePtr ConvertAdapterTensor(const py::object &obj) {
// Use class AdapterTensor instead of Tensor to avoid circular dependencies.
MS_LOG(DEBUG) << "Converting adapter tensor";
auto tensor = obj.cast<TensorPtr>();
return std::make_shared<AdapterTensor>(tensor);
}
ValuePtr ConvertTuple(const py::object &obj, bool use_signature) {
MS_LOG(DEBUG) << "Converting python tuple";
auto tuple = obj.cast<py::tuple>();
@ -512,8 +556,10 @@ ValuePtr ObjCast(const py::object &obj) {
}
static const std::vector<DataConverterPtr> &GetDataConverters() {
// Convert data by python object type.
static const std::vector<DataConverterPtr> data_converters{
// Convert data by python object type.
// AdapterTensor needs to be processed before Tensor because it inherits from Tensor.
std::make_shared<ByFuncDataConverter>(IsAdapterTensor, ConvertAdapterTensor),
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),

View File

@ -67,6 +67,7 @@ const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type";
const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance";
const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type";
const char PYTHON_MOD_IS_CLASS_TYPE[] = "is_class_type";
const char PYTHON_MOD_GET_ADAPTER_TENSOR_ATTR[] = "get_adapter_tensor_attr";
const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name";
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol";

View File

@ -307,6 +307,7 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.stopgrad_eliminater_,
irpass.partial_eliminate_,
irpass.replace_applicator_,
irpass.convert_tensor_eliminate_,
// Miscellaneous
irpass.tuple_list_get_item_eliminator_,
@ -761,13 +762,16 @@ bool EliminateSpecialOpOptPass(const ResourcePtr &resource) {
opt::OptPassConfig mutable_op_eliminate = opt::OptPassConfig({
irpass.mutable_op_eliminate_,
});
opt::OptPassConfig convert_tensor_op_eliminate = opt::OptPassConfig({
irpass.convert_tensor_all_eliminate_,
});
OptPassGroupMap map({
{"ad_related_special_op_eliminate", ad_related_special_op_eliminate},
{"mutable_op_eliminate", mutable_op_eliminate},
{"convert_tensor_op_eliminate", convert_tensor_op_eliminate},
});
auto ad_related_special_op_eliminate_opt =
opt::Optimizer::MakeOptimizer("ad_related_special_op_eliminate", resource, map);
(void)ad_related_special_op_eliminate_opt->step(func_graph, false);
auto special_op_eliminate_opt = opt::Optimizer::MakeOptimizer("special_op_eliminate", resource, map);
(void)special_op_eliminate_opt->step(func_graph, false);
return true;
}

View File

@ -1511,6 +1511,53 @@ EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AbstractBasePtrList &ar
return nullptr;
}
EvalResultPtr GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEnginePtr &engine,
const AbstractBasePtr &data_args,
const AbstractBasePtr &item_args,
const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(data_args);
MS_EXCEPTION_IF_NULL(item_args);
// Check whether it is AdapterTensor or AdapterParameter.
auto abs = data_args->cast_ptr<abstract::AbstractTensor>();
if (abs == nullptr || !abs->is_adapter()) {
return nullptr;
}
// Get the name of attr/method.
ValuePtr item_value = item_args->BuildValue();
MS_EXCEPTION_IF_NULL(item_value);
if (!item_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
}
std::string item_name = item_value->cast_ptr<StringImm>()->value();
constexpr size_t attr_index = 0;
constexpr size_t flag_index = 1;
constexpr size_t info_required_size = 2;
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::tuple attr_info = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_ADAPTER_TENSOR_ATTR, py::str(item_name));
if (attr_info.size() != info_required_size) {
MS_EXCEPTION(NameError) << "attr info size should be 2, but got " << attr_info.size();
}
// If func is none, it means there is no such attr or method.
py::object func = attr_info[attr_index];
if (py::isinstance<py::none>(func)) {
return nullptr;
}
ValuePtr converted_value = nullptr;
bool success = parse::ConvertData(func, &converted_value);
if (!success || converted_value == nullptr || !converted_value->isa<FuncGraph>()) {
return nullptr;
}
AddToManager(engine, converted_value->cast<FuncGraphPtr>());
// Check whether it is an attribute or a method.
bool is_attr = attr_info[flag_index].cast<bool>();
REQUIRE_TYPE require_type = is_attr ? REQUIRE_TYPE::ATTR : REQUIRE_TYPE::METHOD;
return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
}
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_abs_list,
const ConfigPtr &data_conf,
@ -1676,6 +1723,11 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
return res;
}
}
// Get attribute or method of AdapterTensor object.
auto res = GetEvaluatedValueForAdapterTensorAttrOrMethod(engine, data_args, item_args, data_conf, out_conf);
if (res != nullptr) {
return res;
}
// Try to search method map, if not found, the data_type should be External type.
TypePtr data_type = data_args->BuildType();
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {

View File

@ -254,6 +254,10 @@ void PyParser::ParseOpInputByPythonObj(const FrontendOpRunInfoPtr &op_run_info,
py::object DataConvert::ValueToPyObj(const ValuePtr &v) { return ValueToPyData(v); }
ValuePtr DataConvert::PyObjToValue(const py::object &obj) {
// In PyNative mode, AdapterTensor is treated as ms.Tensor.
if (py::hasattr(obj, PYTHON_ADAPTER_TENSOR) && py::getattr(obj, PYTHON_ADAPTER_TENSOR).cast<bool>()) {
py::setattr(obj, PYTHON_ADAPTER_TENSOR, py::bool_(false));
}
ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
if (converted_ret == nullptr) {
MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));

View File

@ -19,6 +19,7 @@ namespace mindspore {
const char PYTHON_PRIMITIVE_FLAG[] = "__primitive_flag__";
const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
const char PYTHON_MS_CLASS[] = "__ms_class__";
const char PYTHON_ADAPTER_TENSOR[] = "__adapter_tensor__";
const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__";
const char PYTHON_FUNCTION_FORBID_REUSE[] = "__function_forbid_reuse__";
} // namespace mindspore

View File

@ -21,6 +21,7 @@ namespace mindspore {
extern const char PYTHON_PRIMITIVE_FLAG[];
extern const char PYTHON_CELL_AS_LIST[];
extern const char PYTHON_MS_CLASS[];
extern const char PYTHON_ADAPTER_TENSOR[];
extern const char PYTHON_CLASS_MEMBER_NAMESPACE[];
extern const char PYTHON_FUNCTION_FORBID_REUSE[];
} // namespace mindspore

View File

@ -562,6 +562,7 @@ void RegMetaTensor(py::module *m) {
}),
py::arg("input"), py::arg("dtype") = nullptr)
.def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
.def_property("adapter_flag", &Tensor::is_adapter, &Tensor::set_adapter_flag)
.def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter(
Get the tensor's data type.

View File

@ -41,20 +41,45 @@ namespace mindspore {
py::object BuiltinsToPyData(const Any &value);
py::object BuiltinsToPyData(const BaseRef &value);
py::object VectorToPyData(const Any &value);
py::object VectorRefToPyData(const VectorRef &value_list);
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output);
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &abs = nullptr);
py::object MakeCSRTensor(const VectorRef &value_list);
py::object MakeCSRTensor(const ValuePtr &value);
py::object MakeCOOTensor(const VectorRef &value_list);
py::object MakeCOOTensor(const ValuePtr &value);
ShapeVector ConvertShapeTupleToShapeVector(const ValueTuplePtr &shape_tuple);
ShapeVector ConvertToShapeVector(const VectorRef &value_list, size_t shape_idx);
// For AbstractSequence and AbstractDictionary.
template <typename T>
T CheckAbstractElementsSize(const AbstractBasePtr &abs_value, size_t value_size) {
if (abs_value == nullptr) {
return nullptr;
}
auto abs = abs_value->cast<T>();
if (abs != nullptr && value_size != abs->size()) {
MS_LOG(EXCEPTION) << "The size of elements should be equal to " << value_size << ", but got " << abs->size();
}
return abs;
}
py::object SetAdaptedAttrToTensor(const py::object &tensor, const AbstractBasePtr &abs) {
if (abs == nullptr || !abs->isa<abstract::AbstractTensor>()) {
return tensor;
}
auto tensor_abs = abs->cast<abstract::AbstractTensorPtr>();
if (tensor_abs->is_adapter()) {
py::setattr(tensor, "adapter_flag", py::bool_(true));
}
return tensor;
}
py::object CSRTensorToPyData(const tensor::CSRTensorPtr &csr_tensor) {
auto ref = py::tuple(1);
ref[0] = csr_tensor;
return ref[0];
}
py::object TensorToPyData(const tensor::TensorPtr &tensor) {
py::object TensorToPyData(const tensor::TensorPtr &tensor, const AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(tensor);
if (tensor->NeedWait()) {
py::gil_scoped_release release;
@ -62,6 +87,7 @@ py::object TensorToPyData(const tensor::TensorPtr &tensor) {
}
py::tuple v(1);
v[0] = tensor;
v[0] = SetAdaptedAttrToTensor(v[0], abs);
return v[0];
}
@ -120,79 +146,110 @@ py::object ScalarPtrToPyData(const ScalarPtr &value) {
}
}
using ConverterFunction = std::function<py::object(const ValuePtr &value)>;
py::object ValueSequenceToPyData(const ValueSequencePtr &value, const AbstractBasePtr &abs) {
auto value_sequeue = value->value();
auto value_size = value_sequeue.size();
py::tuple res_sequeue(value_size);
auto abs_sequeue = CheckAbstractElementsSize<abstract::AbstractSequencePtr>(abs, value_size);
if (abs_sequeue == nullptr) {
for (size_t i = 0; i < value_size; i++) {
res_sequeue[i] = ValueToPyData(value_sequeue[i]);
}
} else {
for (size_t i = 0; i < value_size; i++) {
res_sequeue[i] = ValueToPyData(value_sequeue[i], abs_sequeue->elements()[i]);
}
}
if (value->isa<ValueTuple>()) {
return res_sequeue;
}
return res_sequeue.cast<py::list>();
}
py::object ValueDictionaryToPyData(const ValueDictionaryPtr &value, const AbstractBasePtr &abs) {
auto value_dict = value->value();
auto value_size = value_dict.size();
py::dict res_dict;
auto abs_dict = CheckAbstractElementsSize<abstract::AbstractDictionaryPtr>(abs, value_size);
if (abs_dict == nullptr) {
for (const auto &v : value_dict) {
res_dict[ValueToPyData(v.first)] = ValueToPyData(v.second);
}
} else {
for (size_t i = 0; i < value_size; i++) {
auto v = value_dict[i];
auto abs_elem = abs_dict->elements()[i];
res_dict[ValueToPyData(v.first, abs_elem.first)] = ValueToPyData(v.second, abs_elem.second);
}
}
return res_dict;
}
using ConverterFunction = std::function<py::object(const ValuePtr &value, const AbstractBasePtr &abs)>;
using ValueNameToConverterVector = std::vector<std::pair<uint32_t, ConverterFunction>>;
// (Value Type Name) -> (Converter Function)
// The converter function is used to convert Value object to Python data object.
static ValueNameToConverterVector value_name_to_converter = {
// Scalar
{Scalar::kTypeId, [](const ValuePtr &value) -> py::object { return ScalarPtrToPyData(value->cast<ScalarPtr>()); }},
{Scalar::kTypeId,
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
return ScalarPtrToPyData(value->cast<ScalarPtr>());
}},
// Tensor
{tensor::Tensor::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &abs) -> py::object {
auto tensor_ptr = value->cast<tensor::TensorPtr>();
return TensorToPyData(tensor_ptr);
return TensorToPyData(tensor_ptr, abs);
}},
// MetaTenser
{tensor::MetaTensor::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
py::tuple tuple_container(1);
tuple_container[0] = value->cast<tensor::MetaTensorPtr>();
return tuple_container[0];
}},
// CSRTensor
{tensor::CSRTensor::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
auto csr_tensor_ptr = value->cast<tensor::CSRTensorPtr>();
return CSRTensorToPyData(csr_tensor_ptr);
}},
// RefKey
{RefKey::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
py::tuple tuple_container(1);
tuple_container[0] = value->cast<RefKeyPtr>();
return tuple_container[0];
}},
// Type
{Type::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
py::tuple tuple_container(1);
tuple_container[0] = value->cast<TypePtr>();
return tuple_container[0];
}},
// StringImm
{StringImm::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
py::str res = value->cast<StringImmPtr>()->value();
return res;
}},
// ValueSequence
{ValueSequence::kTypeId,
[](const ValuePtr &value) -> py::object {
auto value_sequeue = value->cast<ValueSequencePtr>()->value();
py::tuple res_sequeue(value_sequeue.size());
for (size_t i = 0; i < value_sequeue.size(); i++) {
res_sequeue[i] = ValueToPyData(value_sequeue[i]);
}
if (value->isa<ValueTuple>()) {
return res_sequeue;
}
return res_sequeue.cast<py::list>();
[](const ValuePtr &value, const AbstractBasePtr &abs) -> py::object {
auto value_sequeue = value->cast<ValueSequencePtr>();
return ValueSequenceToPyData(value_sequeue, abs);
}},
// ValueDictionary
{ValueDictionary::kTypeId,
[](const ValuePtr &value) -> py::object {
auto value_list = value->cast<ValueDictionaryPtr>()->value();
py::dict res_dict;
for (const auto &v : value_list) {
res_dict[ValueToPyData(v.first)] = ValueToPyData(v.second);
}
return res_dict;
[](const ValuePtr &value, const AbstractBasePtr &abs) -> py::object {
auto value_dict = value->cast<ValueDictionaryPtr>();
return ValueDictionaryToPyData(value_dict, abs);
}},
// ValueSlice
{ValueSlice::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
auto slice = value->cast<ValueSlicePtr>();
auto start = ValueToPyData(slice->start());
auto end = ValueToPyData(slice->stop());
@ -201,7 +258,7 @@ static ValueNameToConverterVector value_name_to_converter = {
}},
// KeywordArg
{KeywordArg::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
auto key = abs_keyword_arg->get_key();
auto val = abs_keyword_arg->get_arg()->BuildValue();
@ -212,53 +269,53 @@ static ValueNameToConverterVector value_name_to_converter = {
}},
// parse::NameSpace
{parse::NameSpace::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
auto ns = value->cast<parse::NameSpacePtr>();
return ns->module_obj();
}},
// parse::ClassType
{parse::ClassType::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
auto class_type = value->cast<parse::ClassTypePtr>();
return class_type->obj();
}},
// parse::MsClassObject
{parse::MsClassObject::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
auto ms_class_object = value->cast<parse::MsClassObjectPtr>();
return ms_class_object->obj();
}},
// parse::InterpretedObject
{parse::InterpretedObject::kTypeId,
[](const ValuePtr &value) -> py::object {
[](const ValuePtr &value, const AbstractBasePtr &) -> py::object {
auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
return interpreted_object->obj();
}},
// None
{None::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
{None::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
// AnyValue
{AnyValue::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
{AnyValue::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
// ErrorValue
{ErrorValue::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
{ErrorValue::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
// FuncGraph
{FuncGraph::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
{FuncGraph::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
// Primitive
{Primitive::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
{Primitive::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
// Monad
{Monad::kTypeId, [](const ValuePtr &) -> py::object { return py::none(); }},
{Monad::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::none(); }},
// Ellipsis
{Ellipsis::kTypeId, [](const ValuePtr &) -> py::object { return py::ellipsis(); }}};
{Ellipsis::kTypeId, [](const ValuePtr &, const AbstractBasePtr &) -> py::object { return py::ellipsis(); }}};
// When converting data to tensor, ValueToPyData will only return _c_expression Tensor,
// but not python tensor. If python tensor is needed, call _convert_python_data to the output.
py::object ValueToPyData(const ValuePtr &value) {
py::object ValueToPyData(const ValuePtr &value, const AbstractBasePtr &abs) {
if (value == nullptr) {
MS_LOG(EXCEPTION) << "The `value` should not be null";
}
py::gil_scoped_acquire gil;
for (auto &iter : value_name_to_converter) {
if (value->IsFromTypeId(iter.first)) {
return iter.second(value);
return iter.second(value, abs);
}
}
MS_LOG(EXCEPTION) << "Unsupported to convert " << value->ToString() << "[" << value->type_name() << "] to a PyData";
@ -273,10 +330,6 @@ py::object AnyToPyData(const Any &value) {
MS_LOG(DEBUG) << "ValuePtr";
ValuePtr v = value.cast<ValuePtr>();
ret = ValueToPyData(v);
} else if (value.is<tensor::TensorPtr>()) {
MS_LOG(DEBUG) << "tensor";
auto tensor_ptr = value.cast<tensor::TensorPtr>();
ret = TensorToPyData(tensor_ptr);
} else if (value.is<py::object>()) {
MS_LOG(DEBUG) << "py obj";
ret = value.cast<py::object>();
@ -307,24 +360,7 @@ py::object AnyToPyData(const Any &value) {
return ret;
}
py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &output) {
py::object ret;
// If output value is a tuple, check if abstract is a COOTensor in funcgraph output
if (utils::isa<VectorRef>(value)) {
MS_LOG(DEBUG) << "BaseRefToPyData, value is tuple: " << value.ToString();
auto vec_ref = utils::cast<VectorRef>(value);
if (output != nullptr) {
ret = VectorRefToPyData(vec_ref, output);
} else {
ret = VectorRefToPyData(vec_ref);
}
} else {
ret = BaseRefToPyData(value);
}
return ret;
}
py::object BaseRefToPyData(const BaseRef &value) {
py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &abs) {
py::object ret;
MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString();
if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) {
@ -332,18 +368,14 @@ py::object BaseRefToPyData(const BaseRef &value) {
} else if (utils::isa<ValuePtr>(value)) {
MS_LOG(DEBUG) << "ValuePtr";
ValuePtr v = utils::cast<ValuePtr>(value);
ret = ValueToPyData(v);
} else if (utils::isa<tensor::TensorPtr>(value)) {
MS_LOG(DEBUG) << "tensor";
auto tensor_ptr = utils::cast<tensor::TensorPtr>(value);
ret = TensorToPyData(tensor_ptr);
ret = ValueToPyData(v, abs);
} else if (utils::isa<PyObjectRef>(value)) {
MS_LOG(DEBUG) << "py obj";
PyObjectRef py_ref = utils::cast<PyObjectRef>(value);
ret = py_ref.object_;
} else if (utils::isa<VectorRef>(value)) {
auto vec_ref = utils::cast<VectorRef>(value);
ret = VectorRefToPyData(vec_ref);
ret = VectorRefToPyData(vec_ref, abs);
} else if (utils::isa<TypePtr>(value)) {
py::tuple v(1);
v[0] = utils::cast<TypePtr>(value);
@ -419,42 +451,29 @@ py::object VectorToPyData(const Any &value) {
return ret;
}
py::object VectorRefToPyData(const VectorRef &value_list) {
py::object ret;
MS_LOG(DEBUG) << "vector_ref";
size_t value_size = value_list.size();
auto ref_tuple = py::tuple(value_size);
for (size_t i = 0; i < value_size; i++) {
ref_tuple[i] = BaseRefToPyData(value_list[i]);
}
ret = ref_tuple;
return ret;
}
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &output) {
MS_LOG(DEBUG) << "vector_ref";
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &abs) {
// Current VectorRef reflects a COOTensor type
if (output->isa<abstract::AbstractCSRTensor>()) {
if (abs != nullptr && abs->isa<abstract::AbstractCSRTensor>()) {
return MakeCSRTensor(value_list);
}
if (output->isa<abstract::AbstractCOOTensor>()) {
if (abs != nullptr && abs->isa<abstract::AbstractCOOTensor>()) {
return MakeCOOTensor(value_list);
}
py::object ret;
size_t value_size = value_list.size();
auto ref_tuple = py::tuple(value_size);
abstract::AbstractTuplePtr tuple_output = output->cast<abstract::AbstractTuplePtr>();
bool is_abstract_tuple = tuple_output != nullptr;
for (size_t i = 0; i < value_size; i++) {
if (!is_abstract_tuple || i >= tuple_output->size()) {
// Fall back to original process
auto seq_abs = CheckAbstractElementsSize<abstract::AbstractSequencePtr>(abs, value_size);
if (seq_abs == nullptr) {
for (size_t i = 0; i < value_size; i++) {
ref_tuple[i] = BaseRefToPyData(value_list[i]);
} else {
ref_tuple[i] = BaseRefToPyData(value_list[i], (*tuple_output)[i]);
}
} else {
for (size_t i = 0; i < value_size; i++) {
ref_tuple[i] = BaseRefToPyData(value_list[i], seq_abs->elements()[i]);
}
}
ret = ref_tuple;
return ret;
return ref_tuple;
}
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
@ -462,12 +481,14 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
if (output->isa<ValueNode>()) {
MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
ValuePtr value = GetValueNode(output);
if (output->abstract()->isa<abstract::AbstractCSRTensor>()) {
auto abs = output->abstract();
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractCSRTensor>()) {
*ret_val = MakeCSRTensor(value);
} else if (output->abstract()->isa<abstract::AbstractCOOTensor>()) {
} else if (abs->isa<abstract::AbstractCOOTensor>()) {
*ret_val = MakeCOOTensor(value);
} else {
*ret_val = ValueToPyData(value);
*ret_val = ValueToPyData(value, abs);
}
return true;
}
@ -505,6 +526,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
auto tensor = param->default_param();
*ret_val = py::cast(tensor);
}
*ret_val = SetAdaptedAttrToTensor(*ret_val, output->abstract());
return true;
}
return false;

View File

@ -1213,13 +1213,19 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
// Check element
auto element = element_->Join(other_tensor->element_);
MS_EXCEPTION_IF_NULL(element);
return std::make_shared<AbstractTensor>(element, res_shape);
auto ret = std::make_shared<AbstractTensor>(element, res_shape);
ret->set_is_adapter(is_adapter_);
return ret;
}
bool AbstractTensor::equal_to(const AbstractTensor &other) const {
if (this == &other) {
return true;
}
// Check if both Tensor or both AdapterTensor.
if (is_adapter() != other.is_adapter()) {
return false;
}
const auto &v1 = GetValueTrack();
const auto &v2 = other.GetValueTrack();
MS_EXCEPTION_IF_NULL(v1);
@ -1268,6 +1274,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
clone->set_value(GetValueTrack());
clone->set_value_range(get_min_value(), get_max_value());
clone->set_shape_value(get_shape_value());
clone->set_is_adapter(is_adapter());
return clone;
}
@ -1278,6 +1285,7 @@ AbstractBasePtr AbstractTensor::Broaden() const {
MS_EXCEPTION_IF_NULL(shp);
broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
broaden->set_is_adapter(is_adapter());
return broaden;
}
@ -1289,6 +1297,7 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
shp->Broaden();
broaden->set_shape(shp);
broaden->set_value(kAnyValue);
broaden->set_is_adapter(is_adapter());
return broaden;
}
@ -1441,6 +1450,7 @@ std::string AbstractJTagged::ToString() const {
AbstractRefTensor::AbstractRefTensor(const AbstractTensorPtr &ref_value, const ValuePtr &ref_key_value)
: AbstractTensor(*ref_value), ref_key_value_(ref_key_value) {
set_type(std::make_shared<RefType>());
set_is_adapter(ref_value->is_adapter());
MS_EXCEPTION_IF_NULL(ref_key_value);
if (ref_key_value != kAnyValue && !ref_key_value->isa<RefKey>()) {
MS_LOG(EXCEPTION) << "ref_key_value must be kAnyValue or RefKey, but got:" << ref_key_value->ToString();

View File

@ -730,11 +730,15 @@ class MS_CORE_API AbstractTensor : public AbstractUndetermined {
AbstractBasePtr PartialBroaden() const override;
bool is_adapter() const { return is_adapter_; }
void set_is_adapter(bool is_adapter) { is_adapter_ = is_adapter; }
protected:
bool equal_to(const AbstractTensor &other) const;
ValuePtr min_value_ = nullptr;
ValuePtr max_value_ = nullptr;
ValuePtr shape_value_ = nullptr;
bool is_adapter_ = false;
};
using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;

View File

@ -0,0 +1,38 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/adapter_tensor.h"
#include <memory>
#include <utility>
#include <algorithm>
#include "abstract/utils.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace tensor {
bool AdapterTensor::operator==(const AdapterTensor &other) const { return this == &other; }
abstract::AbstractBasePtr AdapterTensor::ToAbstract() {
auto abs = origin_tensor_->ToAbstract();
MS_EXCEPTION_IF_NULL(abs);
auto tensor_abs = abs->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(tensor_abs);
tensor_abs->set_is_adapter(true);
return tensor_abs;
}
} // namespace tensor
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_IR_ADAPTER_TENSOR_H_
#define MINDSPORE_CORE_IR_ADAPTER_TENSOR_H_
#include <memory>
#include <string>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/tensor.h"
namespace mindspore {
namespace tensor {
class AdapterTensor;
// Smart pointer for AdapterTensor.
using AdapterTensorPtr = std::shared_ptr<AdapterTensor>;
///
/// \brief AdapterTensor is used to map the Tensor of other frameworks.
///
class MS_CORE_API AdapterTensor final : public Tensor {
public:
/// \brief Create AdapterTensor from tensor.
///
/// \param[in] tensor The input tensor.
explicit AdapterTensor(const TensorPtr &tensor) : Tensor(*tensor), origin_tensor_(tensor) {}
AdapterTensor() = default;
/// Destructor of AdapterTensor.
~AdapterTensor() override = default;
MS_DECLARE_PARENT(AdapterTensor, Tensor);
bool operator==(const AdapterTensor &other) const;
abstract::AbstractBasePtr ToAbstract() override;
private:
TensorPtr origin_tensor_{nullptr};
};
} // namespace tensor
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_ADAPTER_TENSOR_H_

View File

@ -466,6 +466,16 @@ class MS_CORE_API Tensor : public MetaTensor {
/// \param[in] flag Whether this Tensor is initialized.
void set_init_flag(bool flag) { init_flag_ = flag; }
/// \brief Check whether this Tensor needs to be converted.
///
/// \return Whether this Tensor needs to be converted.
bool is_adapter() const { return adapter_flag_; }
/// \brief Set the adapter flag of this Tensor.
///
/// \param[in] flag Whether this Tensor needs to be converted.
void set_adapter_flag(bool flag) { adapter_flag_ = flag; }
/// \brief Check if this Tensor is forward output.
///
/// \return Whether this Tensor is forward output.
@ -772,6 +782,7 @@ class MS_CORE_API Tensor : public MetaTensor {
void ExecuteLazyTask() const;
bool init_flag_{false};
bool adapter_flag_{false};
bool is_forward_output_{false};
TensorDataPtr data_{nullptr};
std::string id_{""};

View File

@ -1693,6 +1693,8 @@ GVAR_DEF(PrimitivePtr, kPrimDictLen, std::make_shared<Primitive>("dict_len"));
GVAR_DEF(PrimitivePtr, kPrimFakeBprop, std::make_shared<Primitive>("fake_bprop"));
GVAR_DEF(PrimitivePtr, kPrimBroadcastGradientArgs, std::make_shared<Primitive>("BroadcastGradientArgs"));
GVAR_DEF(PrimitivePtr, kPrimDynamicBroadcastGradientArgs, std::make_shared<Primitive>(kDynamicBroadcastGradientArgs));
GVAR_DEF(PrimitivePtr, kPrimConvertToAdapterTensor, std::make_shared<Primitive>("ConvertToAdapterTensor"));
GVAR_DEF(PrimitivePtr, kPrimConvertToMsTensor, std::make_shared<Primitive>("ConvertToMsTensor"));
// Random
GVAR_DEF(PrimitivePtr, kPrimStandardLaplace, std::make_shared<Primitive>("StandardLaplace"));

View File

@ -26,7 +26,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
convert_to_ms_cootensor, convert_class_to_function, convert_cell_list_to_sequence, is_cell_list,
get_obj_from_sequence, get_type, is_class_member_recursive, merge_global_params, get_global_params)
get_obj_from_sequence, get_type, is_class_member_recursive, merge_global_params, get_global_params,
get_adapter_tensor_attr)
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
@ -37,4 +38,4 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor',
'convert_to_ms_cootensor', 'convert_class_to_function', 'convert_cell_list_to_sequence', 'is_cell_list',
'get_obj_from_sequence', 'get_type', 'is_class_member_recursive']
'get_obj_from_sequence', 'get_type', 'is_class_member_recursive', 'get_adapter_tensor_attr']

View File

@ -40,6 +40,7 @@ from mindspore.common.api import _MindsporeFunctionExecutor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.common import mutable
from mindspore.common._register_for_adapter import ms_adapter_registry
from .namespace import Namespace, CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
from .resources import parse_object_map, ops_symbol_map, convert_object_map, convert_class_to_function_map, trope_ns
from .resources import SYMBOL_UNDEFINE, NO_IMPLEMENT
@ -127,6 +128,12 @@ _unsupported_convert_data_type = (
_global_params = {}
def _convert_map():
"""Get convert object map"""
adapter_convert_map = ms_adapter_registry.convert_map
return adapter_convert_map if adapter_convert_map else convert_object_map
def create_slice_obj(start, end, step):
"""Create slice object"""
return slice(start, end, step)
@ -245,8 +252,9 @@ def resolve_symbol(namespace, symbol):
"within the construct() or @jit decorated function in graph mode.")
# If need trope the obj
if resolve_ in convert_object_map:
resolve_ = convert_object_map.get(resolve_)
convert_map = _convert_map()
if resolve_ in convert_map:
resolve_ = convert_map.get(resolve_)
logger.debug("Convert resolve: %r", resolve_)
if resolve_ == NO_IMPLEMENT:
raise NotImplementedError(f"Not support for '{symbol}'.")
@ -549,6 +557,18 @@ def is_class_type(cls):
return isinstance(cls, type)
def get_adapter_tensor_attr(name):
"""Get the method or @property modified function of the class, excluding those inherited from parent class."""
cls = ms_adapter_registry.tensor
properties = [key for key, value in vars(cls).items() if isinstance(value, property)]
if name in properties:
return getattr(cls, name).fget, True
methods = [key for key, value in vars(cls).items() if inspect.isfunction(value)]
if name in methods:
return getattr(cls, name), False
return None, False
def get_ms_class_name(cls):
"""Get the name of the class instance decorated with jit_class."""
if isinstance(cls, type):
@ -827,7 +847,7 @@ class Parser:
@staticmethod
def is_unsupported_namespace(value):
"""To check if not supported for namespace"""
unsupported = isinstance(value, _builtin_function_or_method_type) and value not in convert_object_map
unsupported = isinstance(value, _builtin_function_or_method_type) and value not in _convert_map()
logger.debug(f"'{value}' unsupported: {unsupported}.")
if unsupported and value in _fallback_unsupported_python_builtin_type:
raise TypeError(f"'{value}' is not supported both in JIT Fallback and graph mode.")
@ -868,7 +888,7 @@ class Parser:
"""Get the convert object for value which don't support to be converted in C++."""
if not self.is_unsupported_convert_data_type(value):
return value
return convert_object_map.get(value)
return _convert_map().get(value)
def parse(self):
"""Parse the function or method."""

View File

@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Registry MSAdapter config."""
from mindspore.common.tensor import Tensor
class Registry:
"""Registry class for ms adapter."""
def __init__(self):
self._tensor = None
self._convert_map = {}
@property
def tensor(self):
if self._tensor is None:
raise ValueError("Before using Tensor in MSAdapter, please call 'set_adapter_config'.")
return self._tensor
@property
def convert_map(self):
return self._convert_map
def register_tensor(self, value):
if self._tensor is not None:
raise ValueError("Repeated registration of tensor in ms adapter config.")
if not issubclass(value, Tensor):
raise ValueError(f"The tensor definition here should be a subclass of ms.Tensor, but got {value}.")
self._tensor = value
def register_convert_map(self, value):
if not isinstance(value, dict):
raise ValueError(f"Expect a dict type, but got {type(value)}.")
self._convert_map = value
ms_adapter_registry = Registry()

View File

@ -45,6 +45,7 @@ from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcas
from mindspore._checkparam import Validator
from mindspore.common._utils import is_shape_unknown
from mindspore.common.mutable import mutable
from mindspore.common._register_for_adapter import ms_adapter_registry
# store ms_function class compiled pipeline cache
ms_compile_cache = set()
@ -65,6 +66,8 @@ def _convert_python_data(data):
Returns:
data, a data convert C++ to python
"""
if isinstance(data, Tensor) and data.adapter_flag:
return ms_adapter_registry.tensor(data)
if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
return PythonTensor(data, internal=True)
if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
@ -878,6 +881,25 @@ def jit_class(cls):
return cls
def set_adapter_config(config):
"""
Register configuration information for MSAdapter.
Args:
config (dict): Configuration information.
"""
if not isinstance(config, dict):
raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
for key, value in config.items():
if key == "Tensor":
setattr(value, "__adapter_tensor__", True)
ms_adapter_registry.register_tensor(value)
elif key == "convert_object_map":
ms_adapter_registry.register_convert_map(value)
else:
raise ValueError(f"Unsupported key in adapter config: {key}")
def _function_forbid_reuse(func):
if not inspect.isfunction(func):
raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')

View File

@ -197,3 +197,23 @@ def get_bprop_resize_bilinear(self):
return dx, zeros_like(size)
return bprop
@bprop_getters.register(inner.ConvertToAdapterTensor)
def get_bprop_convert_to_adapter_tensor(self):
"""Generate bprop for ConvertToAdapterTensor"""
def bprop(x, out, dout):
return (dout,)
return bprop
@bprop_getters.register(inner.ConvertToMsTensor)
def get_bprop_convert_to_ms_tensor(self):
"""Generate bprop for ConvertToMsTensor"""
def bprop(x, out, dout):
return (dout,)
return bprop

View File

@ -31,6 +31,7 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.communication.management import GlobalComm
from mindspore.common.api import _pynative_executor
from mindspore.common._register_for_adapter import ms_adapter_registry
# Bit operation
@ -2321,3 +2322,65 @@ class IsInstance(PrimitiveWithInfer):
'dtype': mstype.type_type,
'value': value}
return out
class ConvertToAdapterTensor(Primitive):
"""
Convert a tensor from MindSpore's Tensor type to MSAdapter's Tensor type,
where MSAdapter's Tensor is a subclass of MindSpore's Tensor.
Inputs:
- **x** (Tensor) - The input tensor.
Outputs:
A tensor, whose type is MSAdapter's Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor([1, 2 ,3])
>>> x = ops.ConvertToAdapterTensor()(x)
>>> print(x)
[1 2 3]
"""
@prim_attr_register
def __init__(self):
"""Initialize"""
def __call__(self, x):
"""run in PyNative mode"""
return ms_adapter_registry.tensor(x, inner=True)
convert_to_adapter_tensor = ConvertToAdapterTensor()
class ConvertToMsTensor(Primitive):
"""
Convert a tensor from MSAdapter's Tensor type to MindSpore's Tensor type,
where MSAdapter's Tensor is a subclass of MindSpore's Tensor.
Inputs:
- **x** (Tensor) - The input tensor.
Outputs:
A tensor, whose type is MindSpore's Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor([1, 2 ,3])
>>> x = ops.ConvertToMsTensor()(x)
>>> print(x)
[1 2 3]
"""
@prim_attr_register
def __init__(self):
"""Initialize"""
def __call__(self, x):
"""run in PyNative mode"""
return Tensor(x)
convert_to_ms_tensor = ConvertToMsTensor()

View File

@ -0,0 +1,44 @@
from mindspore.common.api import set_adapter_config
from mindspore._extends.parse import trope as T
from mindspore._extends.parse.resources import convert_object_map
from ._register.ms_adapter_api import Tensor, Parameter
from ._register import multitype_ops
from ._register import standard_method as S
# Update convert_object_map
convert_object_map[T.add] = multitype_ops.add
convert_object_map[T.sub] = multitype_ops.sub
convert_object_map[T.mul] = multitype_ops.mul
convert_object_map[T.truediv] = multitype_ops.div
convert_object_map[T.getitem] = multitype_ops.getitem
convert_object_map[T.setitem] = multitype_ops.setitem
convert_object_map[T.floordiv] = multitype_ops.floordiv
convert_object_map[T.mod] = multitype_ops.mod
convert_object_map[T.pow] = multitype_ops.pow_
convert_object_map[T.and_] = multitype_ops.bitwise_and
convert_object_map[T.or_] = multitype_ops.bitwise_or
convert_object_map[T.xor] = multitype_ops.bitwise_xor
convert_object_map[T.neg] = multitype_ops.negative
convert_object_map[T.not_] = multitype_ops.logical_not
convert_object_map[T.eq] = multitype_ops.equal
convert_object_map[T.ne] = multitype_ops.not_equal
convert_object_map[T.lt] = multitype_ops.less
convert_object_map[T.gt] = multitype_ops.greater
convert_object_map[T.le] = multitype_ops.less_equal
convert_object_map[T.ge] = multitype_ops.greater_equal
convert_object_map[T.contains] = multitype_ops.in_
convert_object_map[T.not_contains] = multitype_ops.not_in_
convert_object_map[T.matmul] = S.adapter_matmul
convert_object_map[T.invert] = S.adapter_invert
convert_object_map[T.abs] = S.adapter_abs
convert_object_map[T.round] = S.adapter_round
convert_object_map[T.max] = S.adapter_max
convert_object_map[T.min] = S.adapter_min
convert_object_map[T.sum] = S.adapter_sum
adapter_config = {"Tensor": Tensor, "convert_object_map": convert_object_map}
set_adapter_config(adapter_config)
__all__ = ["Tensor", "Parameter"]

View File

@ -0,0 +1,60 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" MSAdapter api. """
import sys
import mindspore as ms
class Tensor(ms.Tensor):
def __init__(self, input_data=None, dtype=None, shape=None, init=None, inner=False):
super(Tensor, self).__init__(input_data=input_data, dtype=dtype, shape=shape, init=init)
@property
def attr(self):
return 10
def method(self, x):
return x + self.attr
def size(self, dim=None):
if dim is None:
return self.shape
return self.shape[dim]
class Parameter(ms.Parameter):
def __new__(cls, default_input, *args, **kwargs):
init_data_flag = bool(isinstance(default_input, ms.Tensor) and default_input.has_init)
rc = sys.getrefcount(default_input)
_, *class_init_args = Parameter._get_parameter_new_args(default_input, rc)
new_type = Parameter._get_base_class(Tensor)
obj = Tensor.__new__(new_type)
Tensor.__init__(obj, *class_init_args)
obj.init_mode = None
obj.is_default_input_init = init_data_flag
if obj.has_init:
obj.init_mode = default_input
return obj
@staticmethod
def _get_base_class(input_class):
input_class_name = Parameter.__name__
if input_class_name in Parameter._base_type:
new_type = Parameter._base_type.get(input_class_name)
else:
new_type = type(input_class_name, (Parameter, input_class), {})
Parameter._base_type[input_class_name] = new_type
return new_type

View File

@ -0,0 +1,172 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops.composite.multitype_ops.add_impl import add
from mindspore.ops.composite.multitype_ops.sub_impl import sub
from mindspore.ops.composite.multitype_ops.mul_impl import mul
from mindspore.ops.composite.multitype_ops.div_impl import div
from mindspore.ops.composite.multitype_ops.floordiv_impl import floordiv
from mindspore.ops.composite.multitype_ops.mod_impl import mod
from mindspore.ops.composite.multitype_ops.pow_impl import pow_
from mindspore.ops.composite.multitype_ops.bitwise_and_impl import bitwise_and
from mindspore.ops.composite.multitype_ops.bitwise_or_impl import bitwise_or
from mindspore.ops.composite.multitype_ops.bitwise_xor_impl import bitwise_xor
from mindspore.ops.composite.multitype_ops.negative_impl import negative
from mindspore.ops.composite.multitype_ops.logic_not_impl import logical_not
from mindspore.ops.composite.multitype_ops.equal_impl import equal
from mindspore.ops.composite.multitype_ops.not_equal_impl import not_equal
from mindspore.ops.composite.multitype_ops.less_impl import less
from mindspore.ops.composite.multitype_ops.greater_impl import greater
from mindspore.ops.composite.multitype_ops.less_equal_impl import less_equal
from mindspore.ops.composite.multitype_ops.greater_equal_impl import greater_equal
from mindspore.ops.composite.multitype_ops.in_impl import in_
from mindspore.ops.composite.multitype_ops.not_in_impl import not_in_
from mindspore.ops.composite.multitype_ops.getitem_impl import getitem
from mindspore.ops.composite.multitype_ops.setitem_impl import setitem
from tests.st.ms_adapter._register import utils
# multitype_ops.add
utils.update_multitype_ops_tensor_tensor(add)
utils.update_multitype_ops_number_tensor(add)
utils.update_multitype_ops_tensor_number(add)
utils.update_multitype_ops_tuple_tensor(add)
utils.update_multitype_ops_tensor_tuple(add)
utils.update_multitype_ops_list_tensor(add)
utils.update_multitype_ops_tensor_list(add)
# multitype_ops.sub
utils.update_multitype_ops_tensor_tensor(sub)
utils.update_multitype_ops_number_tensor(sub)
utils.update_multitype_ops_tensor_number(sub)
utils.update_multitype_ops_tuple_tensor(sub)
utils.update_multitype_ops_tensor_tuple(sub)
utils.update_multitype_ops_list_tensor(sub)
utils.update_multitype_ops_tensor_list(sub)
# multitype_ops.mul
utils.update_multitype_ops_tensor_tensor(mul)
utils.update_multitype_ops_number_tensor(mul)
utils.update_multitype_ops_tensor_number(mul)
utils.update_multitype_ops_tuple_tensor(mul)
utils.update_multitype_ops_tensor_tuple(mul)
utils.update_multitype_ops_list_tensor(mul)
utils.update_multitype_ops_tensor_list(mul)
# multitype_ops.div
utils.update_multitype_ops_tensor_tensor(div)
utils.update_multitype_ops_number_tensor(div)
utils.update_multitype_ops_tensor_number(div)
utils.update_multitype_ops_tuple_tensor(div)
utils.update_multitype_ops_tensor_tuple(div)
utils.update_multitype_ops_list_tensor(div)
utils.update_multitype_ops_tensor_list(div)
# multitype_ops.floordiv
utils.update_multitype_ops_tensor_tensor(floordiv)
utils.update_multitype_ops_number_tensor(floordiv)
utils.update_multitype_ops_tensor_number(floordiv)
utils.update_multitype_ops_tuple_tensor(floordiv)
utils.update_multitype_ops_tensor_tuple(floordiv)
utils.update_multitype_ops_list_tensor(floordiv)
utils.update_multitype_ops_tensor_list(floordiv)
# multitype_ops.mod
utils.update_multitype_ops_tensor_tensor(mod)
utils.update_multitype_ops_number_tensor(mod)
utils.update_multitype_ops_tensor_number(mod)
utils.update_multitype_ops_tuple_tensor(mod)
utils.update_multitype_ops_tensor_tuple(mod)
utils.update_multitype_ops_list_tensor(mod)
utils.update_multitype_ops_tensor_list(mod)
# multitype_ops.pow_
utils.update_multitype_ops_tensor_tensor(pow_)
utils.update_multitype_ops_number_tensor(pow_)
utils.update_multitype_ops_tensor_number(pow_)
utils.update_multitype_ops_tuple_tensor(pow_)
utils.update_multitype_ops_tensor_tuple(pow_)
utils.update_multitype_ops_list_tensor(pow_)
utils.update_multitype_ops_tensor_list(pow_)
# multitype_ops.bitwise_and
utils.update_multitype_ops_tensor_tensor(bitwise_and)
utils.update_multitype_ops_number_tensor(bitwise_and)
utils.update_multitype_ops_tensor_number(bitwise_and)
# multitype_ops.bitwise_or
utils.update_multitype_ops_tensor_tensor(bitwise_or)
utils.update_multitype_ops_number_tensor(bitwise_or)
utils.update_multitype_ops_tensor_number(bitwise_or)
# multitype_ops.bitwise_xor
utils.update_multitype_ops_tensor_tensor(bitwise_xor)
utils.update_multitype_ops_number_tensor(bitwise_xor)
utils.update_multitype_ops_tensor_number(bitwise_xor)
# multitype_ops.negative
utils.update_multitype_ops_tensor(negative)
# multitype_ops.logical_not
utils.update_multitype_ops_tensor(logical_not)
# multitype_ops.equal
utils.update_multitype_ops_tensor_tensor(equal)
utils.update_multitype_ops_number_tensor(equal)
utils.update_multitype_ops_tensor_number(equal)
# multitype_ops.not_equal
utils.update_multitype_ops_tensor_tensor(not_equal)
utils.update_multitype_ops_number_tensor(not_equal)
utils.update_multitype_ops_tensor_number(not_equal)
# multitype_ops.less
utils.update_multitype_ops_tensor_tensor(less)
utils.update_multitype_ops_number_tensor(less)
utils.update_multitype_ops_tensor_number(less)
# multitype_ops.greater
utils.update_multitype_ops_tensor_tensor(greater)
utils.update_multitype_ops_number_tensor(greater)
utils.update_multitype_ops_tensor_number(greater)
# multitype_ops.less_equal
utils.update_multitype_ops_tensor_tensor(less_equal)
utils.update_multitype_ops_number_tensor(less_equal)
utils.update_multitype_ops_tensor_number(less_equal)
# multitype_ops.greater_equal
utils.update_multitype_ops_tensor_tensor(greater_equal)
utils.update_multitype_ops_number_tensor(greater_equal)
utils.update_multitype_ops_tensor_number(greater_equal)
# multitype_ops.in_
utils.update_multitype_ops_tensor_tuple(in_)
utils.update_multitype_ops_tensor_list(in_)
# multitype_ops.not_in_
utils.update_multitype_ops_tensor_tuple(not_in_)
utils.update_multitype_ops_tensor_list(not_in_)
# multitype_ops.getitem
utils.update_multitype_ops_tensor_tensor(getitem)
utils.update_multitype_ops_tensor_number(getitem)
utils.update_multitype_ops_tensor_tuple(getitem)
utils.update_multitype_ops_tensor_list(getitem)
utils.update_multitype_ops_tensor_none(getitem)
utils.update_multitype_ops_tensor_slice(getitem)
# multitype_ops.setitem
utils.update_multitype_ops_setitem_tensor(setitem)

View File

@ -0,0 +1,107 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore._extends.parse import trope as T
from mindspore._extends.parse.resources import convert_object_map
from tests.st.ms_adapter._register.ms_adapter_api import Tensor as adapter_Tensor
from tests.st.ms_adapter._register.utils import convert_to_ms_tensor, convert_to_adapter_tensor
matmul_fn = convert_object_map.get(T.matmul)
invert_fn = convert_object_map.get(T.invert)
abs_fn = convert_object_map.get(T.abs)
round_fn = convert_object_map.get(T.round)
max_fn = convert_object_map.get(T.max)
min_fn = convert_object_map.get(T.min)
sum_fn = convert_object_map.get(T.sum)
def adapter_matmul(x, y):
if isinstance(x, adapter_Tensor) and isinstance(y, adapter_Tensor):
x = convert_to_ms_tensor(x)
y = convert_to_ms_tensor(y)
out = matmul_fn(x, y)
out = convert_to_adapter_tensor(out)
else:
out = matmul_fn(x, y)
return out
def adapter_invert(x):
if isinstance(x, adapter_Tensor):
x = convert_to_ms_tensor(x)
out = invert_fn(x)
out = convert_to_adapter_tensor(out)
else:
out = invert_fn(x)
return out
def adapter_abs(x):
if isinstance(x, adapter_Tensor):
x = convert_to_ms_tensor(x)
out = abs_fn(x)
out = convert_to_adapter_tensor(out)
else:
out = abs_fn(x)
return out
def adapter_round(*data):
if (len(data) == 1 and isinstance(data[0], adapter_Tensor)) or \
(len(data) == 2 and isinstance(data[0], adapter_Tensor) and isinstance(data[1], None)):
x = data[0]
x = convert_to_ms_tensor(x)
out = round_fn(x)
out = convert_to_adapter_tensor(out)
else:
out = round_fn(*data)
return out
def _has_adapter_tensor(*data):
if len(data) == 1 and isinstance(data[0], adapter_Tensor):
return True
for elem in data:
if isinstance(elem, adapter_Tensor):
return True
return False
def adapter_max(*data):
if _has_adapter_tensor(*data):
out = max_fn(*data)
out = convert_to_adapter_tensor(out)
else:
out = max_fn(*data)
return out
def adapter_min(*data):
if _has_adapter_tensor(*data):
out = min_fn(*data)
out = convert_to_adapter_tensor(out)
else:
out = min_fn(*data)
return out
def adapter_sum(*data):
if _has_adapter_tensor(*data):
out = sum_fn(*data)
out = convert_to_adapter_tensor(out)
else:
out = sum_fn(*data)
return out

View File

@ -0,0 +1,204 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore as ms
from mindspore import dtype as mstype
from mindspore._c_expression import typing
from mindspore.ops.operations import _inner_ops as inner
from tests.st.ms_adapter._register.ms_adapter_api import Tensor as adapter_Tensor
def convert_to_ms_tensor(x):
return inner.convert_to_ms_tensor(x)
def convert_to_adapter_tensor(x):
return inner.convert_to_adapter_tensor(x)
def get_registed_fn(ops, *type_names):
types = tuple(map(mstype.typing.str_to_type, type_names))
for sigs, fn in ops.entries:
if len(sigs) != len(types):
continue
if any(not typing.is_subclass(type_, sig) for sig, type_ in zip(sigs, types)):
continue
return fn
raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given types: {types}.")
def convert_output(out):
if isinstance(out, ms.Tensor):
out = convert_to_adapter_tensor(out)
return out
def update_multitype_ops_tensor(ops):
func = get_registed_fn(ops, "Tensor")
@ops.register("Tensor")
def _tensor(x):
if isinstance(x, adapter_Tensor):
x = convert_to_ms_tensor(x)
out = func(x)
out = convert_output(out)
else:
out = func(x)
return out
def update_multitype_ops_tensor_tensor(ops):
func = get_registed_fn(ops, "Tensor", "Tensor")
@ops.register("Tensor", "Tensor")
def _tensor_and_tensor(x, y):
if isinstance(x, adapter_Tensor) and isinstance(y, adapter_Tensor):
x = convert_to_ms_tensor(x)
y = convert_to_ms_tensor(y)
out = func(x, y)
out = convert_output(out)
else:
out = func(x, y)
return out
def update_multitype_ops_number_tensor(ops):
func = get_registed_fn(ops, "Number", "Tensor")
@ops.register("Number", "Tensor")
def _number_and_tensor(x, y):
if isinstance(y, adapter_Tensor):
y = convert_to_ms_tensor(y)
out = func(x, y)
out = convert_output(out)
else:
out = func(x, y)
return out
def update_multitype_ops_tensor_number(ops):
func = get_registed_fn(ops, "Tensor", "Number")
@ops.register("Tensor", "Number")
def _tensor_and_number(x, y):
if isinstance(x, adapter_Tensor):
x = convert_to_ms_tensor(x)
out = func(x, y)
out = convert_output(out)
else:
out = func(x, y)
return out
def update_multitype_ops_tuple_tensor(ops):
func = get_registed_fn(ops, "Tuple", "Tensor")
@ops.register("Tuple", "Tensor")
def _tuple_and_tensor(x, y):
if isinstance(y, adapter_Tensor):
y = convert_to_ms_tensor(y)
out = func(x, y)
out = convert_output(out)
else:
out = func(x, y)
return out
def update_multitype_ops_tensor_tuple(ops):
func = get_registed_fn(ops, "Tensor", "Tuple")
@ops.register("Tensor", "Tuple")
def _tensor_and_tuple(x, y):
if isinstance(x, adapter_Tensor):
x = convert_to_ms_tensor(x)
out = func(x, y)
out = convert_output(out)
else:
out = func(x, y)
return out
def update_multitype_ops_list_tensor(ops):
func = get_registed_fn(ops, "List", "Tensor")
@ops.register("List", "Tensor")
def _list_and_tensor(x, y):
if isinstance(y, adapter_Tensor):
y = convert_to_ms_tensor(y)
out = func(x, y)
out = convert_output(out)
else:
out = func(x, y)
return out
def update_multitype_ops_tensor_list(ops):
func = get_registed_fn(ops, "Tensor", "List")
@ops.register("Tensor", "List")
def _tensor_and_list(x, y):
if isinstance(x, adapter_Tensor):
x = convert_to_ms_tensor(x)
out = func(x, y)
out = convert_output(out)
else:
out = func(x, y)
return out
def update_multitype_ops_tensor_none(ops):
func = get_registed_fn(ops, "Tensor", "None")
@ops.register("Tensor", "None")
def _tensor_and_none(x, y):
if isinstance(x, adapter_Tensor):
x = convert_to_ms_tensor(x)
out = func(x, y)
out = convert_output(out)
else:
out = func(x, y)
return out
def update_multitype_ops_tensor_slice(ops):
func = get_registed_fn(ops, "Tensor", "Slice")
@ops.register("Tensor", "Slice")
def _tensor_and_slice(x, y):
if isinstance(x, adapter_Tensor):
x = convert_to_ms_tensor(x)
out = func(x, y)
out = convert_output(out)
else:
out = func(x, y)
return out
def update_multitype_ops_setitem_tensor(ops):
def register_for_setitem(sigs, fn):
@ops.register(*sigs)
def _tensor_setitem(data, index, value):
if isinstance(data, adapter_Tensor):
data = convert_to_ms_tensor(data)
out = fn(data, index, value)
out = convert_to_adapter_tensor(out)
else:
out = fn(data, index, value)
return out
entries = ops.entries.copy()
for sigs, fn in entries:
if typing.is_subclass(sigs[0], mstype.tensor):
register_for_setitem(sigs, fn)

View File

@ -0,0 +1,141 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test MSAdapter. """
import pytest
import mindspore as ms
from tests.st.ms_adapter import Tensor, Parameter
from tests.st.ms_adapter._register.utils import convert_to_ms_tensor, convert_to_adapter_tensor
ms.set_context(mode=ms.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tensor_attr():
"""
Feature: MSAdapter
Description: Get the properties of MSAdapter.Tensor
Expectation: No exception
"""
@ms.jit
def func(x):
return x.attr
x = Tensor(1)
assert func(x) == 10
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tensor_method():
"""
Feature: MSAdapter
Description: Get the methods of MSAdapter.Tensor
Expectation: No exception
"""
@ms.jit
def func(x):
return x.method(10)
x = Tensor(1)
assert func(x) == 20
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parameter_attr():
"""
Feature: MSAdapter
Description: Get the properties of MSAdapter.Parameter
Expectation: No exception
"""
@ms.jit
def func(x):
return x.attr
x = Parameter(Tensor(1))
assert func(x) == 10
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parameter_method():
"""
Feature: MSAdapter
Description: Get the methods of MSAdapter.Parameter
Expectation: No exception
"""
@ms.jit
def func(x):
return x.method(10)
x = Parameter(Tensor(1))
assert func(x) == 20
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tensor_convert_type():
"""
Feature: MSAdapter
Description: Test type conversion
Expectation: No exception
"""
@ms.jit
def func(x, y):
a = x.size(0)
b = y.size
x = convert_to_ms_tensor(x)
y = convert_to_adapter_tensor(y)
c = x.size
d = y.size(0)
return x, y, (a, b, c, d)
x = Tensor([1, 2, 3])
y = ms.Tensor([1, 2, 3])
out = func(x, y)
assert type(out[0]) is ms.Tensor
assert type(out[1]) is Tensor
assert out[2] == (3, 3, 3, 3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tensor_isinstance():
"""
Feature: MSAdapter
Description: Test isinstance syntax
Expectation: No exception
"""
@ms.jit
def func(x):
a = isinstance(x, Tensor)
x = convert_to_ms_tensor(x)
b = isinstance(x, Tensor)
x = convert_to_adapter_tensor(x)
c = isinstance(x, Tensor)
return a, b, c
x = Tensor(1)
out = func(x)
assert out[0] and not out[1] and out[2]

View File

@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test grad in MSAdapter. """
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.ops import grad
from tests.st.ms_adapter import Tensor
ms.set_context(mode=ms.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ms_adapter_grad():
"""
Feature: MSAdapter
Description: Test grad scenario of MSAdapter
Expectation: No exception
"""
class Net(nn.Cell):
def construct(self, x, y, z):
return x * y * z
x = Tensor([1, 2], mstype.int32)
y = Tensor([-2, 3], mstype.int32)
z = Tensor([0, 3], mstype.int32)
net = Net()
output = grad(net, grad_position=(1, 2))(x, y, z)
grad_y = Tensor([0, 6], mstype.int32)
grad_z = Tensor([-2, 6], mstype.int32)
assert np.all(output[0].asnumpy() == grad_y.asnumpy())
assert np.all(output[1].asnumpy() == grad_z.asnumpy())

View File

@ -0,0 +1,307 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
import tests.st.ms_adapter as adapter
ms.set_context(mode=ms.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_arithmetic_operator():
"""
Feature: MSAdapter
Description: Test arithmetic operators
Expectation: No exception
"""
@ms.jit
def add_fn(x, y):
return x + y
@ms.jit
def sub_fn(x, y):
return x - y
@ms.jit
def mul_fn(x, y):
return x * y
@ms.jit
def div_fn(x, y):
return x / y
@ms.jit
def floordiv_fn(x, y):
return x // y
@ms.jit
def mod_fn(x, y):
return x % y
@ms.jit
def pow_fn(x, y):
return x ** y
def check_output_type(func):
ms_x = ms.Tensor(1)
adapter_x = adapter.Tensor(1)
assert type(func(ms_x, ms_x)) is ms.Tensor
assert type(func(adapter_x, adapter_x)) is adapter.Tensor # "Tensor", "Tensor"
assert type(func(adapter_x, 1)) is adapter.Tensor # "Tensor", "Number"
assert type(func(1, adapter_x)) is adapter.Tensor # "Number", "Tensor"
assert type(func(adapter_x, (adapter_x,))) is adapter.Tensor # "Tensor", "Tuple"
assert type(func((adapter_x,), adapter_x)) is adapter.Tensor # "Tuple", "Tensor"
assert type(func(adapter_x, [adapter_x,])) is adapter.Tensor # "Tensor", "List"
assert type(func([adapter_x,], adapter_x)) is adapter.Tensor # "List", "Tensor"
check_output_type(add_fn)
check_output_type(sub_fn)
check_output_type(mul_fn)
check_output_type(div_fn)
check_output_type(floordiv_fn)
check_output_type(mod_fn)
check_output_type(pow_fn)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_binary_operator():
"""
Feature: MSAdapter
Description: Test binary operators
Expectation: No exception
"""
@ms.jit
def equal_fn(x, y):
return x == y
@ms.jit
def not_equal_fn(x, y):
return x != y
@ms.jit
def less_fn(x, y):
return x < y
@ms.jit
def greater_fn(x, y):
return x > y
@ms.jit
def less_equal_fn(x, y):
return x <= y
@ms.jit
def greater_equal_fn(x, y):
return x >= y
@ms.jit
def bitwise_and_fn(x, y):
return x & y
@ms.jit
def bitwise_or_fn(x, y):
return x | y
@ms.jit
def bitwise_xor_fn(x, y):
return x ^ y
def check_output_type(func):
ms_x = ms.Tensor([1, 2, 3])
ms_y = ms.Tensor([3, 2, 1])
adapter_x = adapter.Tensor([1, 2, 3])
adapter_y = adapter.Tensor([3, 2, 1])
assert type(func(ms_x, ms_y)) is ms.Tensor
assert type(func(adapter_x, adapter_y)) is adapter.Tensor # "Tensor", "Tensor"
assert type(func(adapter_x, 1)) is adapter.Tensor # "Tensor", "Number"
assert type(func(1, adapter_x)) is adapter.Tensor # "Number", "Tensor"
check_output_type(equal_fn)
check_output_type(not_equal_fn)
check_output_type(less_fn)
check_output_type(greater_fn)
check_output_type(less_equal_fn)
check_output_type(greater_equal_fn)
check_output_type(bitwise_and_fn)
check_output_type(bitwise_or_fn)
check_output_type(bitwise_xor_fn)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unary_operator():
"""
Feature: MSAdapter
Description: Test unary operators
Expectation: No exception
"""
@ms.jit
def positive_fn(x):
return +x
@ms.jit
def negative_fn(x):
return -x
ms_x = ms.Tensor([1, -2, 3])
adapter_x = adapter.Tensor([1, -2, 3])
assert type(positive_fn(ms_x)) is ms.Tensor
assert type(negative_fn(ms_x)) is ms.Tensor
assert type(positive_fn(adapter_x)) is adapter.Tensor
assert type(negative_fn(adapter_x)) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logical_operator():
"""
Feature: MSAdapter
Description: Test logical operators
Expectation: No exception
"""
@ms.jit
def is_fn(x):
return x is None
@ms.jit
def is_not_fn(x):
return x is not None
@ms.jit
def invert_fn(x):
return ~x
@ms.jit
def logical_not_fn(x):
return not x
ms_x = ms.Tensor(True)
adapter_x = adapter.Tensor(True)
assert not is_fn(adapter_x)
assert is_not_fn(adapter_x)
assert type(invert_fn(ms_x)) is ms.Tensor
assert type(logical_not_fn(ms_x)) is ms.Tensor
assert type(invert_fn(adapter_x)) is adapter.Tensor
assert type(logical_not_fn(adapter_x)) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_contain_operator():
"""
Feature: MSAdapter
Description: Test in / not in
Expectation: No exception
"""
@ms.jit
def in_fn(x, y, z):
return x in (x, y, z)
@ms.jit
def not_in_fn(x, y, z):
return x not in (x, y, z)
ms_x = ms.Tensor(2)
ms_y = ms.Tensor(2)
ms_z = ms.Tensor(3)
adapter_x = adapter.Tensor(1)
adapter_y = adapter.Tensor(2)
adapter_z = adapter.Tensor(3)
assert type(in_fn(ms_x, ms_y, ms_z)) is ms.Tensor
assert type(not_in_fn(ms_x, ms_y, ms_z)) is ms.Tensor
assert type(in_fn(adapter_x, adapter_y, adapter_z)) is adapter.Tensor
assert type(not_in_fn(adapter_x, adapter_y, adapter_z)) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matmul():
"""
Feature: MSAdapter
Description: Test matmul operator
Expectation: No exception
"""
@ms.jit
def func(x, y):
return x @ y
ms_x = ms.Tensor([1, 2], ms.float32)
ms_y = ms.Tensor([3, 4], ms.float32)
adapter_x = adapter.Tensor([1, 2], ms.float32)
adapter_y = adapter.Tensor([3, 4], ms.float32)
assert type(func(ms_x, ms_y)) is ms.Tensor
assert type(func(adapter_x, adapter_y)) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_getitem():
"""
Feature: MSAdapter
Description: Test getietm operator
Expectation: No exception
"""
@ms.jit
def getitem_fn(x, index):
return x[index]
@ms.jit
def getitem_slice_fn(x):
return x[1:]
ms_x = ms.Tensor([[1, 2, 3], [4, 5, 6]])
adapter_x = adapter.Tensor([[1, 2, 3], [4, 5, 6]])
assert type(getitem_fn(ms_x, 0)) is ms.Tensor
assert type(getitem_fn(ms_x, None)) is ms.Tensor
assert type(getitem_fn(ms_x, [0, 1])) is ms.Tensor
assert type(getitem_fn(ms_x, (0, 1))) is ms.Tensor
assert type(getitem_slice_fn(ms_x)) is ms.Tensor
assert type(getitem_fn(adapter_x, 0)) is adapter.Tensor
assert type(getitem_fn(adapter_x, None)) is adapter.Tensor
assert type(getitem_fn(adapter_x, [0, 1])) is adapter.Tensor
assert type(getitem_fn(adapter_x, (0, 1))) is adapter.Tensor
assert type(getitem_slice_fn(adapter_x)) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_setitem():
"""
Feature: MSAdapter
Description: Test setitem operator
Expectation: No exception
"""
@ms.jit
def setitem_fn(x, index, value):
x[index] = value
return x
ms_x = ms.Tensor([[1, 2, 3], [4, 5, 6]])
adapter_x = adapter.Tensor([[1, 2, 3], [4, 5, 6]])
adapter_index = adapter.Tensor(0)
adapter_value = adapter.Tensor([7, 8, 9])
assert type(setitem_fn(adapter_x, adapter_index, adapter_value)) is adapter.Tensor
assert type(setitem_fn(ms_x, adapter_index, adapter_value)) is ms.Tensor

View File

@ -0,0 +1,391 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from functools import partial
import pytest
import mindspore as ms
import tests.st.ms_adapter as adapter
ms.set_context(mode=ms.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_abs():
"""
Feature: MSAdapter
Description: Test python built-in function abs()
Expectation: No exception
"""
@ms.jit
def func(x):
return abs(x)
assert type(func(ms.Tensor(-5))) is ms.Tensor
assert type(func(adapter.Tensor(-5))) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_round():
"""
Feature: MSAdapter
Description: Test python built-in function round()
Expectation: No exception
"""
@ms.jit
def func(x):
return round(x)
assert type(func(ms.Tensor(1.55))) is ms.Tensor
assert type(func(adapter.Tensor(1.55))) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_map():
"""
Feature: MSAdapter
Description: Test python built-in function map()
Expectation: No exception
"""
def add(x, y):
return x + y
@ms.jit
def func(x, y):
return map(add, x, y)
x = (adapter.Tensor(1), 2)
y = (adapter.Tensor(2), 4)
out = func(x, y)
assert type(out[0]) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_filter():
"""
Feature: MSAdapter
Description: Test python built-in function filter()
Expectation: No exception
"""
def select_fn(x):
return True
@ms.jit
def func(x):
return filter(select_fn, x)
x = (adapter.Tensor(2), 1, 2, 3)
out = func(x)
assert type(out[0]) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_partial():
"""
Feature: MSAdapter
Description: Test python built-in function partial()
Expectation: No exception
"""
def add(x, y):
return x + y
@ms.jit
def func(data):
add_ = partial(add, x=2)
return add_(y=data)
out = func(adapter.Tensor(1))
assert type(out) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_zip():
"""
Feature: MSAdapter
Description: Test python built-in function zip()
Expectation: No exception
"""
@ms.jit
def func(x, y):
return zip(x, y)
x = (adapter.Tensor(1), 2)
y = (adapter.Tensor(2), 4)
out = func(x, y)
assert type(out[0][0]) is adapter.Tensor
assert type(out[0][1]) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_enumerate():
"""
Feature: MSAdapter
Description: Test python built-in function enumerate()
Expectation: No exception
"""
@ms.jit
def func(x):
return enumerate(x)
x = adapter.Tensor([[1, 2], [3, 4], [5, 6]])
out = func(x)
assert out[0][0] == 0
assert type(out[0][1]) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_instance():
"""
Feature: MSAdapter
Description: Test python built-in function isinstance()
Expectation: No exception
"""
@ms.jit
def func(x, y):
a = isinstance(x, ms.Tensor) and not isinstance(x, adapter.Tensor)
b = isinstance(y, ms.Tensor) and isinstance(y, adapter.Tensor)
return a, b
x = ms.Tensor(1)
y = adapter.Tensor(1)
a, b = func(x, y)
assert a and b
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_max():
"""
Feature: MSAdapter
Description: Test python built-in function max()
Expectation: No exception
"""
@ms.jit
def func(x, y, z):
return max(x), max(y, z)
x = adapter.Tensor([1, 2], ms.float32)
y = adapter.Tensor([1], ms.float32)
z = adapter.Tensor([2], ms.float32)
out = func(x, y, z)
assert type(out[0]) is adapter.Tensor
assert type(out[1]) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_min():
"""
Feature: MSAdapter
Description: Test python built-in function min()
Expectation: No exception
"""
@ms.jit
def func(x, y, z):
return min(x), min(y, z)
x = adapter.Tensor([1, 2], ms.float32)
y = adapter.Tensor([1], ms.float32)
z = adapter.Tensor([2], ms.float32)
out = func(x, y, z)
assert type(out[0]) is adapter.Tensor
assert type(out[1]) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sum():
"""
Feature: MSAdapter
Description: Test python built-in function sum()
Expectation: No exception
"""
@ms.jit
def func(x, y, z):
return sum(x), sum(y, z)
x = adapter.Tensor([[1, 2], [3, 4]], ms.float32)
y = adapter.Tensor([1, 2, 3], ms.float32)
z = adapter.Tensor([4, 5, 6], ms.float32)
out = func(x, y, z)
assert type(out[0]) is adapter.Tensor
assert type(out[1]) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_getattr():
"""
Feature: MSAdapter
Description: Test python built-in function getattr()
Expectation: No exception
"""
@ms.jit
def func(x):
return getattr(x, "attr")
x = adapter.Tensor([1, 2, 3])
assert func(x) == 10
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_hasattr():
"""
Feature: MSAdapter
Description: Test python built-in function hasattr()
Expectation: No exception
"""
@ms.jit
def func(x, y):
return hasattr(x, "method"), hasattr(y, "method")
x = adapter.Tensor([1, 2, 3])
y = ms.Tensor([1, 2, 3])
out = func(x, y)
assert out[0] and not out[1]
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_iter():
"""
Feature: MSAdapter
Description: Test python built-in function iter()
Expectation: No exception
"""
@ms.jit
def func(x):
return iter(x)[0]
x = adapter.Tensor([1, 2, 3])
out = func(x)
assert type(out) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_next():
"""
Feature: MSAdapter
Description: Test python built-in function next()
Expectation: No exception
"""
@ms.jit
def func(x):
it = iter(x)
return next(it)
x = adapter.Tensor([1, 2, 3])
out = func(x)
assert type(out[0]) is adapter.Tensor
assert type(out[1]) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_print():
"""
Feature: MSAdapter
Description: Test python built-in function print()
Expectation: No exception
"""
@ms.jit
def func(x):
print(x)
return x
func(adapter.Tensor([1, 2, 3]))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tuple():
"""
Feature: MSAdapter
Description: Test python built-in function tuple()
Expectation: No exception
"""
@ms.jit
def func(x):
return tuple(x)
x = adapter.Tensor([1, 2, 3])
out = func(x)
assert type(out[0]) is adapter.Tensor
assert type(out[1]) is adapter.Tensor
assert type(out[2]) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_list():
"""
Feature: MSAdapter
Description: Test python built-in function list()
Expectation: No exception
"""
@ms.jit
def func(x):
return list(x)
x = adapter.Tensor([1, 2, 3])
out = func(x)
assert type(out[0]) is adapter.Tensor
assert type(out[1]) is adapter.Tensor
assert type(out[2]) is adapter.Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_bool():
"""
Feature: MSAdapter
Description: Test python built-in function bool()
Expectation: No exception
"""
@ms.jit
def func(x):
return bool(x)
x = adapter.Tensor([10])
out = func(x)
assert type(out) is adapter.Tensor