forked from mindspore-Ecosystem/mindspore
Adapt GatherV2 for dynamic shape
This commit is contained in:
parent
8f106d685a
commit
144a35b17e
|
@ -49,22 +49,6 @@ using mindspore::parse::PyObjectWrapper;
|
||||||
std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem",
|
std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem",
|
||||||
"env_getitem"};
|
"env_getitem"};
|
||||||
|
|
||||||
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
|
||||||
if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
|
|
||||||
auto ret_abstract = AbstractEval(args);
|
|
||||||
if (ret_abstract != nullptr) {
|
|
||||||
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
|
|
||||||
return ret_abstract;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
prim_->BeginRecordAddAttr();
|
|
||||||
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
|
|
||||||
prim_->EndRecordAddAttr();
|
|
||||||
auto added_attrs = prim_->evaluate_added_attrs();
|
|
||||||
auto infer_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
|
||||||
return infer_result;
|
|
||||||
}
|
|
||||||
|
|
||||||
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||||
AnfNodeConfigPtr out_conf) {
|
AnfNodeConfigPtr out_conf) {
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
|
@ -289,45 +273,45 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
||||||
py::dict dic;
|
py::dict dic;
|
||||||
if (abs_base->isa<AbstractTensor>()) {
|
if (abs_base->isa<AbstractTensor>()) {
|
||||||
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
|
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
|
||||||
dic["shape"] = arg_tensor->shape()->shape();
|
dic[ATTR_SHAPE] = arg_tensor->shape()->shape();
|
||||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||||
const auto &min_shape = arg_tensor->shape()->min_shape();
|
const auto &min_shape = arg_tensor->shape()->min_shape();
|
||||||
const auto &max_shape = arg_tensor->shape()->max_shape();
|
const auto &max_shape = arg_tensor->shape()->max_shape();
|
||||||
if (!min_shape.empty() && !max_shape.empty()) {
|
if (!min_shape.empty() && !max_shape.empty()) {
|
||||||
dic["min_shape"] = min_shape;
|
dic[ATTR_MIN_SHAPE] = min_shape;
|
||||||
dic["max_shape"] = max_shape;
|
dic[ATTR_MAX_SHAPE] = max_shape;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dic["dtype"] = arg_tensor->BuildType();
|
dic[ATTR_DTYPE] = arg_tensor->BuildType();
|
||||||
dic["value"] = BuildValue(arg_tensor->BuildValue());
|
dic[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
|
||||||
} else if (abs_base->isa<AbstractRowTensor>()) {
|
} else if (abs_base->isa<AbstractRowTensor>()) {
|
||||||
auto arg = dyn_cast<AbstractRowTensor>(abs_base);
|
auto arg = dyn_cast<AbstractRowTensor>(abs_base);
|
||||||
dic["shape"] = arg->shape()->shape();
|
dic[ATTR_SHAPE] = arg->shape()->shape();
|
||||||
dic["dtype"] = arg->BuildType();
|
dic[ATTR_DTYPE] = arg->BuildType();
|
||||||
dic["value"] = BuildValue(arg->BuildValue());
|
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
|
||||||
} else if (abs_base->isa<AbstractSparseTensor>()) {
|
} else if (abs_base->isa<AbstractSparseTensor>()) {
|
||||||
auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
|
auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
|
||||||
dic["shape"] = arg->shape()->shape();
|
dic[ATTR_SHAPE] = arg->shape()->shape();
|
||||||
dic["dtype"] = arg->BuildType();
|
dic[ATTR_DTYPE] = arg->BuildType();
|
||||||
dic["value"] = BuildValue(arg->BuildValue());
|
dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
|
||||||
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
|
} else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
|
||||||
ShapeVector shape;
|
ShapeVector shape;
|
||||||
dic["shape"] = shape;
|
dic[ATTR_SHAPE] = shape;
|
||||||
dic["dtype"] = abs_base->BuildType();
|
dic[ATTR_DTYPE] = abs_base->BuildType();
|
||||||
dic["value"] = BuildValue(abs_base->BuildValue());
|
dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
|
||||||
} else if (abs_base->isa<AbstractSlice>()) {
|
} else if (abs_base->isa<AbstractSlice>()) {
|
||||||
auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
|
auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
|
||||||
ShapeVector shape;
|
ShapeVector shape;
|
||||||
dic["shape"] = shape;
|
dic[ATTR_SHAPE] = shape;
|
||||||
dic["dtype"] = arg_slice->BuildType();
|
dic[ATTR_DTYPE] = arg_slice->BuildType();
|
||||||
dic["value"] = BuildValue(arg_slice->BuildValue());
|
dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
|
||||||
} else if (abs_base->isa<AbstractRef>()) {
|
} else if (abs_base->isa<AbstractRef>()) {
|
||||||
auto value = abs_base->cast<AbstractRefPtr>()->ref();
|
auto value = abs_base->cast<AbstractRefPtr>()->ref();
|
||||||
dic = ConvertAbstractToPython(value);
|
dic = ConvertAbstractToPython(value);
|
||||||
} else if (abs_base->isa<AbstractEllipsis>()) {
|
} else if (abs_base->isa<AbstractEllipsis>()) {
|
||||||
dic["shape"] = py::none();
|
dic[ATTR_SHAPE] = py::none();
|
||||||
dic["dtype"] = py::ellipsis();
|
dic[ATTR_DTYPE] = py::ellipsis();
|
||||||
dic["value"] = py::ellipsis();
|
dic[ATTR_VALUE] = py::ellipsis();
|
||||||
} else if (abs_base->isa<AbstractTuple>()) {
|
} else if (abs_base->isa<AbstractTuple>()) {
|
||||||
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
|
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
|
||||||
size_t len = arg_tuple->size();
|
size_t len = arg_tuple->size();
|
||||||
|
@ -336,12 +320,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
||||||
|
|
||||||
for (size_t i = 0; i < len; i++) {
|
for (size_t i = 0; i < len; i++) {
|
||||||
py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
|
py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
|
||||||
shape_tuple[i] = out["shape"];
|
shape_tuple[i] = out[ATTR_SHAPE];
|
||||||
dtype_tuple[i] = out["dtype"];
|
dtype_tuple[i] = out[ATTR_DTYPE];
|
||||||
}
|
}
|
||||||
dic["shape"] = shape_tuple;
|
dic[ATTR_SHAPE] = shape_tuple;
|
||||||
dic["dtype"] = dtype_tuple;
|
dic[ATTR_DTYPE] = dtype_tuple;
|
||||||
dic["value"] = BuildValue(arg_tuple->BuildValue());
|
dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue());
|
||||||
} else if (abs_base->isa<AbstractList>()) {
|
} else if (abs_base->isa<AbstractList>()) {
|
||||||
auto arg_list = dyn_cast<AbstractList>(abs_base);
|
auto arg_list = dyn_cast<AbstractList>(abs_base);
|
||||||
size_t len = arg_list->size();
|
size_t len = arg_list->size();
|
||||||
|
@ -350,25 +334,25 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
||||||
|
|
||||||
for (size_t i = 0; i < len; i++) {
|
for (size_t i = 0; i < len; i++) {
|
||||||
py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
|
py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
|
||||||
shape_list[i] = out["shape"];
|
shape_list[i] = out[ATTR_SHAPE];
|
||||||
dtype_list[i] = out["dtype"];
|
dtype_list[i] = out[ATTR_DTYPE];
|
||||||
}
|
}
|
||||||
dic["shape"] = shape_list;
|
dic[ATTR_SHAPE] = shape_list;
|
||||||
dic["dtype"] = dtype_list;
|
dic[ATTR_DTYPE] = dtype_list;
|
||||||
dic["value"] = BuildValue(arg_list->BuildValue());
|
dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue());
|
||||||
} else if (abs_base->isa<AbstractNone>()) {
|
} else if (abs_base->isa<AbstractNone>()) {
|
||||||
dic["shape"] = py::none();
|
dic[ATTR_SHAPE] = py::none();
|
||||||
dic["dtype"] = py::none();
|
dic[ATTR_DTYPE] = py::none();
|
||||||
dic["value"] = py::none();
|
dic[ATTR_VALUE] = py::none();
|
||||||
} else if (abs_base->isa<AbstractFunction>()) {
|
} else if (abs_base->isa<AbstractFunction>()) {
|
||||||
dic["shape"] = py::none();
|
dic[ATTR_SHAPE] = py::none();
|
||||||
dic["dtype"] = abs_base->BuildType();
|
dic[ATTR_DTYPE] = abs_base->BuildType();
|
||||||
dic["value"] = py::none();
|
dic[ATTR_VALUE] = py::none();
|
||||||
} else if (abs_base->isa<AbstractUndetermined>()) {
|
} else if (abs_base->isa<AbstractUndetermined>()) {
|
||||||
auto arg = dyn_cast<AbstractUndetermined>(abs_base);
|
auto arg = dyn_cast<AbstractUndetermined>(abs_base);
|
||||||
dic["shape"] = py::none();
|
dic[ATTR_SHAPE] = py::none();
|
||||||
dic["dtype"] = arg->BuildType();
|
dic[ATTR_DTYPE] = arg->BuildType();
|
||||||
dic["value"] = py::none();
|
dic[ATTR_VALUE] = py::none();
|
||||||
} else {
|
} else {
|
||||||
auto value = abs_base->BuildValue();
|
auto value = abs_base->BuildValue();
|
||||||
if ((*value == *kAnyValue)) {
|
if ((*value == *kAnyValue)) {
|
||||||
|
@ -409,18 +393,20 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
|
||||||
|
|
||||||
AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
|
AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
|
||||||
// Convert to AbstractValue based on type and shape
|
// Convert to AbstractValue based on type and shape
|
||||||
auto out_dtype = output["dtype"];
|
auto out_dtype = output[ATTR_DTYPE];
|
||||||
if (output["value"].is_none()) {
|
if (output[ATTR_VALUE].is_none()) {
|
||||||
auto out_shape = output["shape"];
|
auto out_shape = output[ATTR_SHAPE];
|
||||||
py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
|
py::object min_shape =
|
||||||
py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
|
output.contains(py::str(ATTR_MIN_SHAPE)) ? (py::object)output[ATTR_MIN_SHAPE] : (py::object)py::none();
|
||||||
|
py::object max_shape =
|
||||||
|
output.contains(py::str(ATTR_MAX_SHAPE)) ? (py::object)output[ATTR_MAX_SHAPE] : (py::object)py::none();
|
||||||
|
|
||||||
return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
|
return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
|
||||||
}
|
}
|
||||||
// Convert pyobject to Value, then to AbstractValue
|
// Convert pyobject to Value, then to AbstractValue
|
||||||
ValuePtr converted_ret = nullptr;
|
ValuePtr converted_ret = nullptr;
|
||||||
TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
|
TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
|
||||||
bool converted = parse::ConvertData(output["value"], &converted_ret, false, dtype);
|
bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
|
||||||
if (!converted) {
|
if (!converted) {
|
||||||
MS_LOG(EXCEPTION) << "Convert data failed";
|
MS_LOG(EXCEPTION) << "Convert data failed";
|
||||||
}
|
}
|
||||||
|
@ -447,6 +433,73 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
|
||||||
}
|
}
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||||
|
auto prim_py = dyn_cast<PrimitivePy>(prim_);
|
||||||
|
if (prim_py == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
|
||||||
|
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
|
||||||
|
auto py_args = PreparePyInputs(prim_py, args);
|
||||||
|
prim_py->RunCheck(py_args);
|
||||||
|
|
||||||
|
prim_->BeginRecordAddAttr();
|
||||||
|
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
|
||||||
|
prim_->EndRecordAddAttr();
|
||||||
|
auto added_attrs = prim_->evaluate_added_attrs();
|
||||||
|
|
||||||
|
if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
|
||||||
|
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call method 'infer_value' for primitive with this method for constant propagation
|
||||||
|
py::tuple py_vals(py_args.size());
|
||||||
|
for (size_t i = 0; i < py_args.size(); ++i) {
|
||||||
|
py_vals[i] = py_args[i][ATTR_VALUE];
|
||||||
|
}
|
||||||
|
py::object py_ret = prim_py->RunInferValue(py_vals);
|
||||||
|
if (py::isinstance<py::none>(py_ret)) {
|
||||||
|
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert pyobject to Value, then to AbstractValue
|
||||||
|
ValuePtr converted_ret = nullptr;
|
||||||
|
TypePtr dtype = abs_base->BuildType();
|
||||||
|
bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
|
||||||
|
if (!converted) {
|
||||||
|
MS_LOG(EXCEPTION) << "Convert data failed";
|
||||||
|
}
|
||||||
|
auto res_spec = FromValue(converted_ret);
|
||||||
|
MS_EXCEPTION_IF_NULL(res_spec);
|
||||||
|
if (res_spec->isa<AbstractTensor>()) {
|
||||||
|
// Replace to tensor constant node in specialize
|
||||||
|
auto res_tensor = res_spec->cast<AbstractTensorPtr>();
|
||||||
|
res_tensor->set_value(converted_ret);
|
||||||
|
}
|
||||||
|
return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
|
||||||
|
}
|
||||||
|
|
||||||
|
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||||
|
if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
|
||||||
|
auto ret_abstract = AbstractEval(args);
|
||||||
|
if (ret_abstract != nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
|
||||||
|
return ret_abstract;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
|
||||||
|
return EvalPyCheckPrim(engine, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
prim_->BeginRecordAddAttr();
|
||||||
|
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
|
||||||
|
prim_->EndRecordAddAttr();
|
||||||
|
auto added_attrs = prim_->evaluate_added_attrs();
|
||||||
|
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
||||||
|
}
|
||||||
|
|
||||||
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
||||||
auto ret_abstract = AbstractEval(args);
|
auto ret_abstract = AbstractEval(args);
|
||||||
if (ret_abstract != nullptr) {
|
if (ret_abstract != nullptr) {
|
||||||
|
|
|
@ -42,6 +42,8 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
|
||||||
std::string ToString() const override { return identifier_ + prim_->name(); }
|
std::string ToString() const override { return identifier_ + prim_->name(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
EvalResultPtr EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args);
|
||||||
|
|
||||||
PrimitivePtr prim_;
|
PrimitivePtr prim_;
|
||||||
const StandardPrimitiveEvalImpl eval_impl_;
|
const StandardPrimitiveEvalImpl eval_impl_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -308,20 +308,18 @@ void AnalysisEngine::Clear() {
|
||||||
namespace {
|
namespace {
|
||||||
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
|
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
|
||||||
// Custom Primitive with python infer_shape, infer_type
|
// Custom Primitive with python infer_shape, infer_type
|
||||||
EvaluatorPtr evaluator = nullptr;
|
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
if (prim->isa<prim::DoSignaturePrimitive>()) {
|
if (prim->isa<prim::DoSignaturePrimitive>()) {
|
||||||
evaluator = std::make_shared<DoSignatureEvaluator>(prim);
|
return std::make_shared<DoSignatureEvaluator>(prim);
|
||||||
return evaluator;
|
|
||||||
}
|
}
|
||||||
if (prim->isa<prim::UnpackGraphPrimitive>()) {
|
if (prim->isa<prim::UnpackGraphPrimitive>()) {
|
||||||
evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
|
return std::make_shared<UnpackGraphEvaluator>(prim);
|
||||||
return evaluator;
|
|
||||||
}
|
}
|
||||||
if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
|
if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
|
||||||
evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
|
return std::make_shared<MixedPrecisionCastEvaluator>(prim);
|
||||||
return evaluator;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EvaluatorPtr evaluator = nullptr;
|
||||||
if (prim->HasPyEvaluator()) {
|
if (prim->HasPyEvaluator()) {
|
||||||
auto prim_py = dyn_cast<PrimitivePy>(prim);
|
auto prim_py = dyn_cast<PrimitivePy>(prim);
|
||||||
if (prim_py != nullptr) {
|
if (prim_py != nullptr) {
|
||||||
|
|
|
@ -55,6 +55,10 @@ void ValidateOperation(const AnfNodePtr &node) {
|
||||||
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
|
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) {
|
||||||
|
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (prim->name() == "fake_bprop") {
|
if (prim->name() == "fake_bprop") {
|
||||||
MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"));
|
MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"));
|
||||||
}
|
}
|
||||||
|
|
|
@ -254,16 +254,33 @@ py::dict PrimitivePy::RunInfer(const py::tuple &args) {
|
||||||
if (!HasPyObj()) {
|
if (!HasPyObj()) {
|
||||||
MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
|
MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
|
||||||
}
|
}
|
||||||
auto infer_fuc = python_obj_.attr("__infer__");
|
auto infer_fuc = python_obj_.attr(PY_PRIM_METHOD_INFER);
|
||||||
return infer_fuc(*args);
|
return infer_fuc(*args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PrimitivePy::RunCheck(const py::tuple &args) {
|
||||||
|
if (!HasPyObj()) {
|
||||||
|
MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
|
||||||
|
}
|
||||||
|
auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
|
||||||
|
(void)check_func(*args);
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object PrimitivePy::RunInferValue(const py::tuple &args) {
|
||||||
|
if (!HasPyObj()) {
|
||||||
|
MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
|
||||||
|
}
|
||||||
|
auto infer_value = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
|
||||||
|
return infer_value(*args);
|
||||||
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
||||||
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
||||||
.value("unknown", PrimType::kPrimTypeUnknown)
|
.value("unknown", PrimType::kPrimTypeUnknown)
|
||||||
.value("builtin", PrimType::kPrimTypeBuiltIn)
|
.value("builtin", PrimType::kPrimTypeBuiltIn)
|
||||||
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
|
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
|
||||||
.value("user_custom", PrimType::kPrimTypeUserCustom);
|
.value("user_custom", PrimType::kPrimTypeUserCustom)
|
||||||
|
.value("py_infer_check", PrimType::kPrimTypePyInferCheck);
|
||||||
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
|
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
|
||||||
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
|
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
|
||||||
.def(py::init<py::str &, py::object>())
|
.def(py::init<py::str &, py::object>())
|
||||||
|
|
|
@ -62,6 +62,8 @@ class PrimitivePy : public Primitive {
|
||||||
const bool parse_info_ = true;
|
const bool parse_info_ = true;
|
||||||
const py::object &GetPyObj() const { return python_obj_; }
|
const py::object &GetPyObj() const { return python_obj_; }
|
||||||
py::dict RunInfer(const py::tuple &args);
|
py::dict RunInfer(const py::tuple &args);
|
||||||
|
void RunCheck(const py::tuple &args);
|
||||||
|
py::object RunInferValue(const py::tuple &args);
|
||||||
bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); }
|
bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); }
|
||||||
bool HasPyObj() { return python_obj_.operator bool(); }
|
bool HasPyObj() { return python_obj_.operator bool(); }
|
||||||
PrimitivePtr Clone() override;
|
PrimitivePtr Clone() override;
|
||||||
|
|
|
@ -81,6 +81,9 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
|
||||||
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
|
||||||
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
@ -176,6 +179,14 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <iterator>
|
||||||
#include "abstract/infer_functions.h"
|
#include "abstract/infer_functions.h"
|
||||||
#include "abstract/utils.h"
|
#include "abstract/utils.h"
|
||||||
#include "abstract/param_validator.h"
|
#include "abstract/param_validator.h"
|
||||||
|
@ -226,5 +228,60 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
|
||||||
// outputs: dx
|
// outputs: dx
|
||||||
return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
|
return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
const std::string &op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 3);
|
||||||
|
AbstractTensorPtr params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||||
|
AbstractTensorPtr indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||||
|
AbstractScalarPtr axis = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
|
||||||
|
|
||||||
|
auto params_shp = params->shape()->shape();
|
||||||
|
auto indices_shp = indices->shape()->shape();
|
||||||
|
auto axis_val = GetValue<int>(axis->BuildValue());
|
||||||
|
|
||||||
|
auto params_rank = static_cast<int>(params_shp.size());
|
||||||
|
if (axis_val < 0) {
|
||||||
|
axis_val += params_rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto calc_shape = [axis_val, ¶ms_shp](const ShapeVector &inp_vec) -> ShapeVector {
|
||||||
|
ShapeVector out_vec;
|
||||||
|
std::copy(params_shp.begin(), params_shp.begin() + axis_val, std::back_inserter(out_vec));
|
||||||
|
copy(inp_vec.begin(), inp_vec.end(), std::back_inserter(out_vec));
|
||||||
|
copy(params_shp.begin() + axis_val + 1, params_shp.end(), std::back_inserter(out_vec));
|
||||||
|
return out_vec;
|
||||||
|
};
|
||||||
|
|
||||||
|
ShapeVector out_shape = calc_shape(indices_shp);
|
||||||
|
if (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()) {
|
||||||
|
ShapeVector min_shape = calc_shape(indices->shape()->min_shape());
|
||||||
|
ShapeVector max_shape = calc_shape(indices->shape()->max_shape());
|
||||||
|
return std::make_shared<AbstractTensor>(params->element(),
|
||||||
|
std::make_shared<Shape>(out_shape, min_shape, max_shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
const std::string &op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 1);
|
||||||
|
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||||
|
auto shape = input->shape()->shape();
|
||||||
|
|
||||||
|
bool has_dyn_shape = std::any_of(shape.begin(), shape.end(), [](int dim) { return dim == Shape::SHP_ANY; });
|
||||||
|
std::vector<int> tensor_shp({static_cast<int>(shape.size())});
|
||||||
|
if (has_dyn_shape) {
|
||||||
|
auto elem = std::make_shared<AbstractScalar>(std::make_shared<AnyValue>(), std::make_shared<Int>(32));
|
||||||
|
return std::make_shared<AbstractTensor>(elem, std::make_shared<Shape>(tensor_shp));
|
||||||
|
}
|
||||||
|
auto shp_buf_size = sizeof(int) * shape.size();
|
||||||
|
auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, tensor_shp, shape.data(), shp_buf_size);
|
||||||
|
|
||||||
|
return tensor->ToAbstract();
|
||||||
|
}
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -37,5 +37,14 @@ AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const Primitive
|
||||||
|
|
||||||
return std::make_shared<AbstractTuple>(AbstractBasePtrList({dx, dy}));
|
return std::make_shared<AbstractTuple>(AbstractBasePtrList({dx, dy}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: three tensors.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 1);
|
||||||
|
auto inp = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||||
|
return inp->Clone()->Broaden();
|
||||||
|
}
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -445,5 +445,25 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
|
||||||
return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
|
return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
|
||||||
std::make_shared<Shape>(std::vector<int64_t>{shape_y}));
|
std::make_shared<Shape>(std::vector<int64_t>{shape_y}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
CheckArgsSize(primitive->name(), args_spec_list, 5);
|
||||||
|
AbstractBasePtrList elements;
|
||||||
|
for (size_t i = 0; i < 3; ++i) {
|
||||||
|
elements.push_back(args_spec_list[i]->Clone()->Broaden());
|
||||||
|
}
|
||||||
|
return std::make_shared<AbstractTuple>(elements);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
CheckArgsSize(primitive->name(), args_spec_list, 7);
|
||||||
|
AbstractBasePtrList elements;
|
||||||
|
for (size_t i = 0; i < 2; ++i) {
|
||||||
|
elements.push_back(args_spec_list[i]->Clone()->Broaden());
|
||||||
|
}
|
||||||
|
return std::make_shared<AbstractTuple>(elements);
|
||||||
|
}
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -37,6 +37,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
||||||
// Maths
|
// Maths
|
||||||
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
|
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
|
||||||
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
|
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
|
||||||
|
{prim::kPrimSqrt, {InferImplSqrt, true}},
|
||||||
// Array
|
// Array
|
||||||
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
|
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
|
||||||
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
|
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
|
||||||
|
@ -44,6 +45,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
||||||
{prim::kPrimPack, {InferImplPack, true}},
|
{prim::kPrimPack, {InferImplPack, true}},
|
||||||
{prim::kPrimUnique, {InferImplUnique, true}},
|
{prim::kPrimUnique, {InferImplUnique, true}},
|
||||||
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
||||||
|
{prim::kPrimGatherV2, {InferImplGatherV2, true}},
|
||||||
|
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
|
||||||
|
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
|
||||||
// Structure
|
// Structure
|
||||||
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
|
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
|
||||||
{prim::kPrimMakeList, {InferImplMakeList, true}},
|
{prim::kPrimMakeList, {InferImplMakeList, true}},
|
||||||
|
@ -77,6 +81,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
||||||
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
|
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
|
||||||
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
|
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
|
||||||
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
|
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
|
||||||
|
{prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
|
||||||
|
{prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},
|
||||||
// Others
|
// Others
|
||||||
{prim::kPrimIdentity, {InferImplIdentity, true}},
|
{prim::kPrimIdentity, {InferImplIdentity, true}},
|
||||||
// Set impl to null as it will use PartialEvaluator;
|
// Set impl to null as it will use PartialEvaluator;
|
||||||
|
|
|
@ -84,6 +84,9 @@ inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
|
||||||
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
|
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
|
||||||
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
|
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
|
||||||
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
|
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
|
||||||
|
inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
|
||||||
|
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
|
||||||
|
inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");
|
||||||
inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
|
inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
|
||||||
inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
|
inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
|
||||||
inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
|
inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
|
||||||
|
@ -154,6 +157,8 @@ inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut
|
||||||
inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
|
inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
|
||||||
inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
|
inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
|
||||||
inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
|
inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
|
||||||
|
inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl");
|
||||||
|
inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
|
||||||
|
|
||||||
// Comm ops
|
// Comm ops
|
||||||
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
||||||
|
|
|
@ -35,7 +35,8 @@ enum PrimType {
|
||||||
kPrimTypeBuiltIn, // Built-in primitive operator
|
kPrimTypeBuiltIn, // Built-in primitive operator
|
||||||
kPrimTypePyInferShape, // Primitive operator defined by custom
|
kPrimTypePyInferShape, // Primitive operator defined by custom
|
||||||
kPrimTypePyInferTensor, // Primitive operator defined by custom
|
kPrimTypePyInferTensor, // Primitive operator defined by custom
|
||||||
kPrimTypeUserCustom
|
kPrimTypeUserCustom,
|
||||||
|
kPrimTypePyInferCheck // Primitive operator with input args checking method
|
||||||
};
|
};
|
||||||
|
|
||||||
class Primitive : public Named {
|
class Primitive : public Named {
|
||||||
|
|
|
@ -23,4 +23,19 @@ const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
|
||||||
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
|
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
|
||||||
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
|
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
|
||||||
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
|
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
|
||||||
|
|
||||||
|
// method names of python primitive called from c++ source code
|
||||||
|
// 1. infer method name of class 'PrimitiveWithInfer'
|
||||||
|
const char PY_PRIM_METHOD_INFER[] = "__infer__";
|
||||||
|
// 2. check method name of class 'PrimitiveWithCheck'
|
||||||
|
const char PY_PRIM_METHOD_CHECK[] = "__check__";
|
||||||
|
// 3. method name of class 'PrimitivePy' for constant propagation
|
||||||
|
const char PY_PRIM_METHOD_INFER_VALUE[] = "infer_value";
|
||||||
|
|
||||||
|
// type inference related attributes
|
||||||
|
const char ATTR_VALUE[] = "value";
|
||||||
|
const char ATTR_DTYPE[] = "dtype";
|
||||||
|
const char ATTR_SHAPE[] = "shape";
|
||||||
|
const char ATTR_MIN_SHAPE[] = "min_shape";
|
||||||
|
const char ATTR_MAX_SHAPE[] = "max_shape";
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,6 +23,16 @@ extern const char GRAPH_FLAG_HAS_EFFECT[];
|
||||||
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
|
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
|
||||||
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
|
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
|
||||||
extern const char GRAPH_FLAG_SIDE_EFFECT[];
|
extern const char GRAPH_FLAG_SIDE_EFFECT[];
|
||||||
|
|
||||||
|
extern const char PY_PRIM_METHOD_INFER[];
|
||||||
|
extern const char PY_PRIM_METHOD_CHECK[];
|
||||||
|
extern const char PY_PRIM_METHOD_INFER_VALUE[];
|
||||||
|
|
||||||
|
extern const char ATTR_VALUE[];
|
||||||
|
extern const char ATTR_DTYPE[];
|
||||||
|
extern const char ATTR_SHAPE[];
|
||||||
|
extern const char ATTR_MIN_SHAPE[];
|
||||||
|
extern const char ATTR_MAX_SHAPE[];
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_UTILS_FLAGS_H
|
#endif // MINDSPORE_CORE_UTILS_FLAGS_H
|
||||||
|
|
|
@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||||
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
|
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
|
||||||
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||||
Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
|
Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
|
||||||
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
|
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
|
||||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance,
|
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance,
|
||||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
|
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
|
||||||
|
@ -206,6 +206,7 @@ __all__ = [
|
||||||
'HookBackward',
|
'HookBackward',
|
||||||
'InvertPermutation',
|
'InvertPermutation',
|
||||||
'Shape',
|
'Shape',
|
||||||
|
'DynamicShape',
|
||||||
'DropoutDoMask',
|
'DropoutDoMask',
|
||||||
'DropoutGenMask',
|
'DropoutGenMask',
|
||||||
'DropoutGrad',
|
'DropoutGrad',
|
||||||
|
|
|
@ -27,7 +27,7 @@ import numpy as np
|
||||||
|
|
||||||
from .._utils import get_concat_offset
|
from .._utils import get_concat_offset
|
||||||
from ..operations.math_ops import _infer_shape_reduce
|
from ..operations.math_ops import _infer_shape_reduce
|
||||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
|
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
||||||
from ..._c_expression import signature_dtype as sig_dtype
|
from ..._c_expression import signature_dtype as sig_dtype
|
||||||
from ..._c_expression import signature_kind as sig_kind
|
from ..._c_expression import signature_kind as sig_kind
|
||||||
from ..._c_expression import signature_rw as sig_rw
|
from ..._c_expression import signature_rw as sig_rw
|
||||||
|
@ -142,6 +142,11 @@ class ExpandDims(PrimitiveWithInfer):
|
||||||
out = {'shape': x_shape,
|
out = {'shape': x_shape,
|
||||||
'dtype': x['dtype'],
|
'dtype': x['dtype'],
|
||||||
'value': value}
|
'value': value}
|
||||||
|
if 'min_shape' in x and 'max_shape' in x:
|
||||||
|
out['min_shape'] = x['min_shape']
|
||||||
|
out['min_shape'].insert(axis_v, 1)
|
||||||
|
out['max_shape'] = x['max_shape']
|
||||||
|
out['max_shape'].insert(axis_v, 1)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@ -277,6 +282,9 @@ class Cast(PrimitiveWithInfer):
|
||||||
out = {'shape': x['shape'],
|
out = {'shape': x['shape'],
|
||||||
'dtype': mstype.tensor_type(t['value']),
|
'dtype': mstype.tensor_type(t['value']),
|
||||||
'value': value}
|
'value': value}
|
||||||
|
if 'min_shape' in x and 'max_shape' in x:
|
||||||
|
out['min_shape'] = x['min_shape']
|
||||||
|
out['max_shape'] = x['max_shape']
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@ -445,6 +453,27 @@ class Shape(PrimitiveWithInfer):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicShape(Primitive):
|
||||||
|
"""
|
||||||
|
Returns the shape of input tensor.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor[int], 1-dim Tensor of type int32
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
|
||||||
|
>>> shape = P.DynamicShape()
|
||||||
|
>>> output = shape(input_tensor)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init Shape"""
|
||||||
|
|
||||||
|
|
||||||
class Squeeze(PrimitiveWithInfer):
|
class Squeeze(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Returns a tensor with the same type but dimensions of 1 being removed based on axis.
|
Returns a tensor with the same type but dimensions of 1 being removed based on axis.
|
||||||
|
@ -578,7 +607,7 @@ class Unique(Primitive):
|
||||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||||
|
|
||||||
|
|
||||||
class GatherV2(PrimitiveWithInfer):
|
class GatherV2(PrimitiveWithCheck):
|
||||||
"""
|
"""
|
||||||
Returns a slice of input tensor based on the specified indices and axis.
|
Returns a slice of input tensor based on the specified indices and axis.
|
||||||
|
|
||||||
|
@ -605,7 +634,7 @@ class GatherV2(PrimitiveWithInfer):
|
||||||
"""init index_select"""
|
"""init index_select"""
|
||||||
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
||||||
|
|
||||||
def __infer__(self, params, indices, axis):
|
def __check__(self, params, indices, axis):
|
||||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||||
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
|
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
|
||||||
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
|
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
|
||||||
|
@ -613,13 +642,6 @@ class GatherV2(PrimitiveWithInfer):
|
||||||
params_shp = params['shape']
|
params_shp = params['shape']
|
||||||
rank = len(params_shp)
|
rank = len(params_shp)
|
||||||
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
|
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
|
||||||
if axis_v < 0:
|
|
||||||
axis_v += rank
|
|
||||||
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
|
|
||||||
out = {'shape': out_shape,
|
|
||||||
'dtype': params['dtype'],
|
|
||||||
'value': None}
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class SparseGatherV2(GatherV2):
|
class SparseGatherV2(GatherV2):
|
||||||
|
|
|
@ -26,7 +26,7 @@ from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ...common.tensor import Tensor
|
from ...common.tensor import Tensor
|
||||||
from .._utils import get_broadcast_shape
|
from .._utils import get_broadcast_shape
|
||||||
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
|
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
||||||
|
|
||||||
|
|
||||||
def _infer_shape_reduce(x, axis, keep_dims, prim_name):
|
def _infer_shape_reduce(x, axis, keep_dims, prim_name):
|
||||||
|
@ -1257,7 +1257,7 @@ class Rsqrt(PrimitiveWithInfer):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Sqrt(PrimitiveWithInfer):
|
class Sqrt(PrimitiveWithCheck):
|
||||||
"""
|
"""
|
||||||
Returns square root of a tensor element-wise.
|
Returns square root of a tensor element-wise.
|
||||||
|
|
||||||
|
@ -1279,12 +1279,8 @@ class Sqrt(PrimitiveWithInfer):
|
||||||
"""init Sqrt"""
|
"""init Sqrt"""
|
||||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def check_dtype(self, x_type):
|
||||||
return x_shape
|
|
||||||
|
|
||||||
def infer_dtype(self, x_type):
|
|
||||||
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
|
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
|
||||||
return x_type
|
|
||||||
|
|
||||||
def infer_value(self, x):
|
def infer_value(self, x):
|
||||||
if x is not None:
|
if x is not None:
|
||||||
|
|
|
@ -28,7 +28,7 @@ from ..._c_expression import signature_dtype as sig_dtype
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
from ..._checkparam import Rel
|
from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
|
||||||
from ..operations.math_ops import _infer_shape_reduce
|
from ..operations.math_ops import _infer_shape_reduce
|
||||||
|
|
||||||
|
|
||||||
|
@ -4354,7 +4354,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
|
||||||
return var_dtype, accum_dtype
|
return var_dtype, accum_dtype
|
||||||
|
|
||||||
|
|
||||||
class SparseApplyProximalAdagrad(PrimitiveWithInfer):
|
class SparseApplyProximalAdagrad(PrimitiveWithCheck):
|
||||||
r"""
|
r"""
|
||||||
Update relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad,
|
Update relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad,
|
||||||
an additional index tensor is input.
|
an additional index tensor is input.
|
||||||
|
@ -4433,11 +4433,10 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
|
||||||
outputs=['var', 'accum'])
|
outputs=['var', 'accum'])
|
||||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||||
|
|
||||||
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
|
def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
|
||||||
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
||||||
return var_shape, accum_shape
|
|
||||||
|
|
||||||
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
|
def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
|
||||||
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
|
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
|
||||||
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
|
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
|
||||||
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
|
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
|
||||||
|
@ -4446,7 +4445,6 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
|
||||||
valid_types = [mstype.int16, mstype.int32, mstype.int64,
|
valid_types = [mstype.int16, mstype.int32, mstype.int64,
|
||||||
mstype.uint16, mstype.uint32, mstype.uint64]
|
mstype.uint16, mstype.uint32, mstype.uint64]
|
||||||
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
|
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
|
||||||
return var_dtype, accum_dtype
|
|
||||||
|
|
||||||
|
|
||||||
class ApplyAddSign(PrimitiveWithInfer):
|
class ApplyAddSign(PrimitiveWithInfer):
|
||||||
|
@ -4978,7 +4976,7 @@ class ApplyFtrl(PrimitiveWithInfer):
|
||||||
return var_type
|
return var_type
|
||||||
|
|
||||||
|
|
||||||
class SparseApplyFtrl(PrimitiveWithInfer):
|
class SparseApplyFtrl(PrimitiveWithCheck):
|
||||||
"""
|
"""
|
||||||
Update relevant entries according to the FTRL-proximal scheme.
|
Update relevant entries according to the FTRL-proximal scheme.
|
||||||
|
|
||||||
|
@ -5053,21 +5051,19 @@ class SparseApplyFtrl(PrimitiveWithInfer):
|
||||||
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
|
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
|
||||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||||
|
|
||||||
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
|
def check_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
|
||||||
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
||||||
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
|
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
|
||||||
if len(var_shape) > 1:
|
if len(var_shape) > 1:
|
||||||
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
|
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
|
||||||
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
||||||
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
|
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
|
||||||
return var_shape, accum_shape, linear_shape
|
|
||||||
|
|
||||||
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
|
def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
|
||||||
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
|
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
|
||||||
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
|
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
|
||||||
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
|
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
|
||||||
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
|
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
|
||||||
return var_dtype, accum_dtype, linear_dtype
|
|
||||||
|
|
||||||
|
|
||||||
class SparseApplyFtrlV2(PrimitiveWithInfer):
|
class SparseApplyFtrlV2(PrimitiveWithInfer):
|
||||||
|
|
|
@ -200,6 +200,84 @@ class Primitive(Primitive_):
|
||||||
return self._update_parameter
|
return self._update_parameter
|
||||||
|
|
||||||
|
|
||||||
|
class PrimitiveWithCheck(Primitive):
|
||||||
|
"""
|
||||||
|
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments
|
||||||
|
but used the infer method registed in c++ source codes.
|
||||||
|
|
||||||
|
There are three methods can be overide to define the check logic of the primitive: __check__(), check_shape(),
|
||||||
|
check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called.
|
||||||
|
If __check__() is not defined, infer_shape() and infer_dtype() can be defined to describe the check logic of
|
||||||
|
the shape and type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): Name of the current Primitive.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # init a Primitive class with check
|
||||||
|
>>> class Flatten(PrimitiveWithCheck):
|
||||||
|
>>> @prim_attr_register
|
||||||
|
>>> def __init__(self):
|
||||||
|
>>> pass
|
||||||
|
>>> def check_shape(self, input_x):
|
||||||
|
>>> validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
|
||||||
|
>>>
|
||||||
|
>>> def check_dtype(self, input_x):
|
||||||
|
>>> validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
|
||||||
|
>>>
|
||||||
|
>>> # init a Primitive obj
|
||||||
|
>>> add = Flatten()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
Primitive.__init__(self, name)
|
||||||
|
self.set_prim_type(prim_type.py_infer_check)
|
||||||
|
|
||||||
|
def _clone(self):
|
||||||
|
"""
|
||||||
|
Deeply clones the primitive object.
|
||||||
|
|
||||||
|
Calls the __init__() method with the same arguments. This method is called in parser if the
|
||||||
|
flag self.__setattr_flag__ is True.
|
||||||
|
"""
|
||||||
|
cloned_prim = Primitive._clone(self)
|
||||||
|
return cloned_prim
|
||||||
|
|
||||||
|
def check_shape(self, *args):
|
||||||
|
"""
|
||||||
|
Check shapes of input args.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The shape of scalar is an empty tuple.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (tuple(int)): shapes of input tensors.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
None.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_dtype(self, *args):
|
||||||
|
"""
|
||||||
|
Check data types of input args.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (:class:`mindspore.dtype`): data type of inputs.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
None.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __check__(self, *args):
|
||||||
|
"""Check shape, type, and value at the same time by using dictionary as arguments."""
|
||||||
|
tracks = ['dtype', 'shape']
|
||||||
|
for track in tracks:
|
||||||
|
fn = getattr(self, 'check_' + track)
|
||||||
|
fn(*(x[track] for x in args))
|
||||||
|
|
||||||
|
|
||||||
class PrimitiveWithInfer(Primitive):
|
class PrimitiveWithInfer(Primitive):
|
||||||
"""
|
"""
|
||||||
PrimitiveWithInfer is the base class of primitives in python defines functions for tracking inference in python.
|
PrimitiveWithInfer is the base class of primitives in python defines functions for tracking inference in python.
|
||||||
|
@ -306,6 +384,18 @@ class PrimitiveWithInfer(Primitive):
|
||||||
if not is_graph_mode:
|
if not is_graph_mode:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
# output does not contain dynamic shape, no need to calculate min/max shape
|
||||||
|
def has_dynamic_shape(shp):
|
||||||
|
if isinstance(shp, int):
|
||||||
|
return shp < 0
|
||||||
|
if isinstance(shp, (list, tuple)):
|
||||||
|
return any(has_dynamic_shape(e) for e in shp)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not has_dynamic_shape(out['shape']):
|
||||||
|
return out
|
||||||
|
|
||||||
|
# calculate min/max shape for output
|
||||||
def get_specified_shape(elems, attr):
|
def get_specified_shape(elems, attr):
|
||||||
has_specified_shape = False
|
has_specified_shape = False
|
||||||
ret_vals = []
|
ret_vals = []
|
||||||
|
@ -345,6 +435,8 @@ def prim_attr_register(fn):
|
||||||
def deco(self, *args, **kwargs):
|
def deco(self, *args, **kwargs):
|
||||||
if isinstance(self, PrimitiveWithInfer):
|
if isinstance(self, PrimitiveWithInfer):
|
||||||
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
|
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
|
||||||
|
elif isinstance(self, PrimitiveWithCheck):
|
||||||
|
PrimitiveWithCheck.__init__(self, self.__class__.__name__)
|
||||||
else:
|
else:
|
||||||
Primitive.__init__(self, self.__class__.__name__)
|
Primitive.__init__(self, self.__class__.__name__)
|
||||||
bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
|
bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
|
||||||
|
|
|
@ -27,7 +27,7 @@ from mindspore.ops import composite as C
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
|
||||||
from mindspore.ops._grad.grad_base import bprop_getters
|
from mindspore.ops._grad.grad_base import bprop_getters
|
||||||
from mindspore import Tensor, RowTensor, context
|
from mindspore import Tensor, RowTensor, context
|
||||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||||
|
@ -105,10 +105,31 @@ def _generate_inverse_index(x_shape, axis):
|
||||||
perm = index[1:1 + axis] + (0,) + index[1 + axis:]
|
perm = index[1:1 + axis] + (0,) + index[1 + axis:]
|
||||||
return perm
|
return perm
|
||||||
|
|
||||||
class MySparseGatherV2(P.GatherV2):
|
# pylint: disable=W0231
|
||||||
|
class MySparseGatherV2(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
For test
|
For test
|
||||||
"""
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init index_select"""
|
||||||
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
||||||
|
|
||||||
|
def __infer__(self, params, indices, axis):
|
||||||
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||||
|
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
|
||||||
|
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
|
||||||
|
axis_v = axis['value']
|
||||||
|
params_shp = params['shape']
|
||||||
|
rank = len(params_shp)
|
||||||
|
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
|
||||||
|
if axis_v < 0:
|
||||||
|
axis_v += rank
|
||||||
|
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
|
||||||
|
out = {'shape': out_shape,
|
||||||
|
'dtype': params['dtype'],
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
||||||
@bprop_getters.register(MySparseGatherV2)
|
@bprop_getters.register(MySparseGatherV2)
|
||||||
def get_bprop_sparse_gather_v2(self):
|
def get_bprop_sparse_gather_v2(self):
|
||||||
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" test dynamic shape """
|
||||||
|
from mindspore import Tensor, context, nn, Parameter
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore import dtype as mstype
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparse_apply_proximal_ada_grad():
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
|
||||||
|
self.var = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="var")
|
||||||
|
self.accum = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="accum")
|
||||||
|
self.lr = 0.01
|
||||||
|
self.l1 = 0.0
|
||||||
|
self.l2 = 0.0
|
||||||
|
def construct(self, grad, indices):
|
||||||
|
out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad, indices)
|
||||||
|
return out[0]
|
||||||
|
|
||||||
|
class NetWrapper(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetWrapper, self).__init__()
|
||||||
|
self.unq = P.Unique()
|
||||||
|
self.add = P.TensorAdd()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.net = Net()
|
||||||
|
|
||||||
|
def construct(self, grad, inp):
|
||||||
|
ids, _ = self.unq(inp)
|
||||||
|
new_grad = self.expand_dims(ids, 1)
|
||||||
|
new_grad = self.cast(new_grad, mstype.float32) + grad
|
||||||
|
return self.net(new_grad, ids)
|
||||||
|
|
||||||
|
net = NetWrapper()
|
||||||
|
grad = Tensor(np.random.rand(1, 80).astype(np.float32))
|
||||||
|
indices = Tensor(np.ones([7800]), mstype.int32)
|
||||||
|
net(grad, indices)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparse_apply_ftrl():
|
||||||
|
class SparseApplyFtrlNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SparseApplyFtrlNet, self).__init__()
|
||||||
|
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
|
||||||
|
self.var = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="var")
|
||||||
|
self.accum = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="accum")
|
||||||
|
self.linear = Parameter(Tensor(np.random.rand(7800, 80).astype(np.float32)), name="linear")
|
||||||
|
|
||||||
|
def construct(self, grad, indices):
|
||||||
|
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
|
||||||
|
return out[0]
|
||||||
|
|
||||||
|
class NetWrapper(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetWrapper, self).__init__()
|
||||||
|
self.unq = P.Unique()
|
||||||
|
self.add = P.TensorAdd()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.net = SparseApplyFtrlNet()
|
||||||
|
|
||||||
|
def construct(self, grad, inp):
|
||||||
|
ids, _ = self.unq(inp)
|
||||||
|
new_grad = self.expand_dims(ids, 1)
|
||||||
|
new_grad = self.cast(new_grad, mstype.float32) + grad
|
||||||
|
return self.net(new_grad, ids)
|
||||||
|
|
||||||
|
net = NetWrapper()
|
||||||
|
grad = Tensor(np.random.rand(1, 80).astype(np.float32))
|
||||||
|
indices = Tensor(np.ones([7800]), mstype.int32)
|
||||||
|
net(grad, indices)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gatherv2():
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.unq = P.Unique()
|
||||||
|
self.gather = P.GatherV2()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
u, _ = self.unq(y)
|
||||||
|
z = self.gather(x, u, 0)
|
||||||
|
return z
|
||||||
|
|
||||||
|
x = Tensor(np.ones([20, 12], dtype=np.float32))
|
||||||
|
y = Tensor(np.ones([8], dtype=np.int32))
|
||||||
|
net = Net()
|
||||||
|
net(x, y)
|
Loading…
Reference in New Issue