!1918 sparse grad for gatherv2
Merge pull request !1918 from riemann_penn/sparse_grad_for_gatherv2
Commit: 3b8edd5a5b
@@ -30,6 +30,7 @@
#include "pipeline/parse/python_adapter.h"
#include "pipeline/parse/resolve.h"
#include "operator/composite/composite.h"
#include "operator/composite/map.h"
#include "utils/ordered_map.h"
#include "utils/ordered_set.h"
#include "utils/utils.h"

@@ -190,6 +191,8 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
 * ├── MultitypeGraph
 * ├── HyperMap
 * │   └── HyperMapPy
 * ├── Map
 * │   └── MapPy
 * ├── Tail
 * ├── MakeTupleGradient
 * ├── GradOperation

@@ -208,17 +211,25 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_
    oss << GetMultitypeFuncGraphText(mt_func_graph);
  } else if (meta_func_graph->isa<prim::HyperMapPy>()) {  // this statement must be before 'meta_graph->isa<prim::HyperMap>()'
    prim::HyperMapPyPtr hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>();
    MS_EXCEPTION_IF_NULL(hyper_map);
    auto hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>();
    if (hyper_map->GetFnLeaf() != nullptr) {
      oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
    }
  } else if (meta_func_graph->isa<prim::HyperMap>()) {
    prim::HyperMapPtr hyper_map = meta_func_graph->cast<prim::HyperMapPtr>();
    MS_EXCEPTION_IF_NULL(hyper_map);
    auto hyper_map = meta_func_graph->cast<prim::HyperMapPtr>();
    if (hyper_map->GetFnLeaf() != nullptr) {
      oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
    }
  } else if (meta_func_graph->isa<prim::MapPy>()) {  // this statement must be before 'meta_graph->isa<prim::Map>()'
    auto map = meta_func_graph->cast<prim::MapPyPtr>();
    if (map->GetFnLeaf() != nullptr) {
      oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
    }
  } else if (meta_func_graph->isa<prim::Map>()) {
    auto map = meta_func_graph->cast<prim::MapPtr>();
    if (map->GetFnLeaf() != nullptr) {
      oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
    }
  } else if (meta_func_graph->isa<prim::GradOperation>()) {
    prim::GradOperationPtr grad_op = meta_func_graph->cast<prim::GradOperationPtr>();
    oss << "{get_all=" << grad_op->get_all_ << ", get_by_list=" << grad_op->get_by_list_

@@ -0,0 +1,289 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "operator/composite/map.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "pipeline/static_analysis/dshape.h"
#include "pybind_api/api_register.h"
#include "debug/trace.h"
#include "operator/ops.h"
#include "./common.h"

namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;

AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) {
  MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n";
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> inputs;
  if (fn_arg != nullptr) {
    inputs.emplace_back(fn_arg);
  } else {
    inputs.emplace_back(NewValueNode(fn_leaf_));
  }
  inputs.insert(inputs.end(), args.begin(), args.end());
  return func_graph->NewCNode(inputs);
}

FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
  // Generate func for leaf nodes
  FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  ptrGraph->debug_info()->set_name("map");
  AnfNodePtr ptrFnArg = nullptr;
  if (fn_leaf_ == nullptr) {
    ptrFnArg = ptrGraph->add_parameter();
  }
  AnfNodePtrList args;
  for (size_t i = 0; i < args_size; ++i) {
    args.emplace_back(ptrGraph->add_parameter());
  }
  ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args));
  return ptrGraph;
}

AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
                             const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  std::size_t size = type->elements().size();
  bool is_not_same =
    std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
      auto lhs = std::dynamic_pointer_cast<List>(item.second);
      MS_EXCEPTION_IF_NULL(lhs);
      return lhs->elements().size() != size;
    });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "List in Map should have same length";
  }

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeList));

  for (int i = 0; i < SizeToInt(size); ++i) {
    MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target";
    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(ptrGraph);

    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    (void)std::transform(
      arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
      [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
        return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
      });

    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}

AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  std::size_t size = type->elements().size();
  bool is_not_same =
    std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
      auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
      MS_EXCEPTION_IF_NULL(lhs);
      return lhs->elements().size() != size;
    });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "tuple in Map should have same length";
  }

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeTuple));

  for (int i = 0; i < SizeToInt(size); ++i) {
    MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs";
    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(ptrGraph);

    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    (void)std::transform(
      arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
      [&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
        return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
      });

    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}

AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(func_graph);

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  inputs.push_back(NewValueNode(type));

  std::size_t attrSize = type->GetAttributes().size();
  for (std::size_t i = 0; i < attrSize; ++i) {
    MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs";
    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(ptrGraph);

    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    int j = 0;
    for (auto item : arg_pairs) {
      inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
      j++;
    }

    inputs.push_back(func_graph->NewCNode(inputs2));
  }
  return func_graph->NewCNode(inputs);
}

AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  bool found = false;
  TypeId id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> pair;
  for (auto &item : arg_pairs) {
    pair = item;
    MS_LOG(DEBUG) << "Map " << pair.second->ToString();
    id = item.second->type_id();
    if (nonleaf_.count(id)) {
      found = true;
      break;
    }
  }

  if (found) {
    // In a nonleaf situation, all arguments must have the same generic.
    bool is_not_same =
      std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
        if (item.first != pair.first) {
          return item.second->type_id() != pair.second->type_id();
        }
        return false;
      });
    if (is_not_same) {
      std::ostringstream oss;
      oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
      int idx = 0;
      for (auto &item : arg_pairs) {
        oss << ++idx << ": " << item.second->ToString() << "\n";
      }
      MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n"
                        << oss.str() << pair.second->ToString() << "\n";
    }
  }

  switch (id) {
    case kObjectTypeList: {
      auto type = std::static_pointer_cast<List>(pair.second);
      return FullMakeList(type, func_graph, fn_arg, arg_pairs);
    }
    case kObjectTypeTuple: {
      auto type = std::static_pointer_cast<Tuple>(pair.second);
      return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
    }
    case kObjectTypeClass: {
      auto type = std::static_pointer_cast<Class>(pair.second);
      return FullMakeClass(type, func_graph, fn_arg, arg_pairs);
    }
    default:
      MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class "
                        << ", but got " << pair.second->ToString();
  }
}

FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) {
  FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
  ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
  ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  ptrGraph->debug_info()->set_name("map");

  AnfNodePtr ptrFnArg = nullptr;
  std::size_t i = 0;
  if (fn_leaf_ == nullptr) {
    ptrFnArg = ptrGraph->add_parameter();
    i = 1;
  }
  ArgsPairList arg_pairs;
  std::size_t size = args_spec_list.size();
  for (; i < size; ++i) {
    MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString();
    arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
  }

  ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs));
  return ptrGraph;
}

abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  if (fn_leaf_ == nullptr) {
    MS_EXCEPTION_IF_NULL(args_spec_list[0]);
    // Assert that map's function param does not contain free variables
    if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
      auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
      auto func_graph = graph_func->func_graph();
      if (func_graph->parent() != nullptr) {
        MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet.";
      }
    }
  }

  AbstractBasePtrList broadened;
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         return arg->Broaden();
                       });
  return broadened;
}

REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) {
                         (void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
                           .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
                           .def(py::init<>());
                       }));
}  // namespace prim
}  // namespace mindspore
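
For orientation, FullMakeList and FullMakeTuple above expand a Map call over same-length sequences into one leaf application per element, gathered back with make_list/make_tuple. A rough Python analogue of the computation the generated graph performs (an illustrative sketch only; the helper name is hypothetical and not a MindSpore API):

def map_over_tuples(fn, *tuples):
    # Mirrors the same-length check in FullMakeTuple.
    length = len(tuples[0])
    if any(len(t) != length for t in tuples):
        raise ValueError("tuple in Map should have same length")
    # One leaf application per index, gathered back into a tuple (kPrimMakeTuple).
    return tuple(fn(*(t[i] for t in tuples)) for i in range(length))

# map_over_tuples(lambda a, b: a + b, (1, 2, 3), (10, 20, 30)) returns (11, 22, 33)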
@@ -0,0 +1,98 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_
#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_

#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "ir/dtype.h"
#include "ir/meta_func_graph.h"
#include "operator/composite/multitype_funcgraph.h"

namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;

class Map : public MetaFuncGraph {
 public:
  explicit Map(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
      : MetaFuncGraph("map"),
        fn_leaf_(fn_leaf),
        broadcast_(false),
        nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
    Init();
  }
  Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
    Init();
  }
  Map &operator=(const Map &h) {
    if (this != &h) {
      fn_leaf_ = h.fn_leaf_;
      broadcast_ = h.broadcast_;
      nonleaf_ = h.nonleaf_;
      if (fn_leaf_) {
        name_ = "map[" + fn_leaf_->name() + "]";
      }
    }
    return *this;
  }
  ~Map() override = default;
  MS_DECLARE_PARENT(Map, MetaFuncGraph)
  abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
  FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
  MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }

 private:
  FuncGraphPtr GenerateLeafFunc(const size_t &args_size);
  AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args);
  AnfNodePtr FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                          const ArgsPairList &arg_pairs);
  AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                           const ArgsPairList &arg_pairs);
  AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                           const ArgsPairList &arg_pairs);
  AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs);
  void Init() {
    if (fn_leaf_ != nullptr) {
      name_ = "map[" + fn_leaf_->name() + "]";
    }
    signatures_ =
      // def map(func:read, *args:ref):
      std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
  }

  MultitypeFuncGraphPtr fn_leaf_;
  bool broadcast_;
  std::set<TypeId> nonleaf_;
};
using MapPtr = std::shared_ptr<Map>;
class MapPy : public Map {
 public:
  explicit MapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : Map(fn_leaf) {}
  ~MapPy() override = default;
  MS_DECLARE_PARENT(MapPy, Map)
};
using MapPyPtr = std::shared_ptr<MapPy>;
}  // namespace prim
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_

@@ -14,9 +14,14 @@
 * limitations under the License.
 */

#include <string>
#include <sstream>

#include "ir/dtype.h"
#include "common/utils.h"
#include "operator/ops.h"
#include "pipeline/static_analysis/param_validator.h"
#include "pipeline/static_analysis/prim.h"
#include "operator/ops.h"
#include "pipeline/static_analysis/utils.h"
#include "utils/symbolic.h"

@@ -50,6 +55,65 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
  return AbstractFunction::MakeAbstractFunction(jv);
}

class UndeterminedShapeType {
 public:
  explicit UndeterminedShapeType(const std::string &env_str) {
    // param_name indices_shape indices_type values_shape values_type dense_shape
    // export UNDETERMINED_SPARSE_SHAPE_TYPES="w1:2:Int32:2 1 2:Float32:3 1 2"
    std::vector<string> fields;
    string tmp;
    std::stringstream input(env_str);
    while (std::getline(input, tmp, ':')) {
      fields.push_back(tmp);
    }
    if (fields.size() != fields_num) {
      MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size();
    }

    param_name_ = fields[0];

    indices_shape_ = GetShape(fields[1]);
    indices_type_ = StringToType(fields[2]);

    values_shape_ = GetShape(fields[3]);
    values_type_ = StringToType(fields[4]);

    auto dense_shape_vec = GetShape(fields[5]);
    AbstractBasePtrList dense_shape_list;
    (void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list),
                         [](const auto &elem) { return FromValue(elem, false); });
    dense_shape_ = dense_shape_list;
  }
  const std::string &param_name() { return param_name_; }
  const std::vector<int> &indices_shape() { return indices_shape_; }
  const TypePtr &indices_type() { return indices_type_; }
  const std::vector<int> &values_shape() { return values_shape_; }
  const TypePtr &values_type() { return values_type_; }
  const AbstractBasePtrList &dense_shape() { return dense_shape_; }

 private:
  std::string param_name_;
  std::vector<int> indices_shape_;
  TypePtr indices_type_;
  std::vector<int> values_shape_;
  TypePtr values_type_;
  AbstractBasePtrList dense_shape_;
  static const size_t fields_num;

  std::vector<int> GetShape(const std::string &shape_str);
};
std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) {
  std::vector<int> ret;
  std::istringstream iss(shape_str);
  int elem;
  while (iss.good()) {
    iss >> elem;
    ret.emplace_back(elem);
  }
  return ret;
}
const size_t UndeterminedShapeType::fields_num = 6;

AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
  MS_EXCEPTION_IF_NULL(primitive);

@@ -62,6 +126,31 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
  if (type->type_id() != kObjectTypeSymbolicKeyType) {
    MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
  }

  if (key->sparse_grad()) {
    // Will be fixed once undetermined type ready
    auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES");
    if (sparse_shape_types.empty()) {
      sparse_shape_types = "w1:2:Int32:2 1 2:Float32:3 1 2";
    }
    MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString() << ", Undetermined shape is "
                  << sparse_shape_types;

    auto shape_types = UndeterminedShapeType(sparse_shape_types);
    AbstractBasePtrList sparse_list;
    // indices
    auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.indices_type());
    auto indices = std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types.indices_shape()));
    sparse_list.emplace_back(indices);
    // values
    auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.values_type());
    auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types.values_shape()));
    sparse_list.emplace_back(dout);
    // dense_shape
    sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types.dense_shape()));
    return std::make_shared<AbstractTuple>(sparse_list);
  }

  if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
    return dflt;
  }
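
The branch above reads UNDETERMINED_SPARSE_SHAPE_TYPES and turns it into the (indices, values, dense_shape) abstract tuple. A minimal Python sketch of the same colon-separated format, with the field order taken from the comment in the UndeterminedShapeType constructor (illustration only, not part of the patch):

# UNDETERMINED_SPARSE_SHAPE_TYPES format:
#   param_name : indices_shape : indices_type : values_shape : values_type : dense_shape
def parse_sparse_shape_types(env_str):
    fields = env_str.split(":")
    if len(fields) != 6:
        raise ValueError("Expect 6 fields, but got %d" % len(fields))

    def shape(field):
        return [int(x) for x in field.split()]

    return {"param_name": fields[0],
            "indices_shape": shape(fields[1]), "indices_type": fields[2],
            "values_shape": shape(fields[3]), "values_type": fields[4],
            "dense_shape": shape(fields[5])}

# parse_sparse_shape_types("w1:2:Int32:2 1 2:Float32:3 1 2") describes a sparse gradient for
# parameter w1: indices of shape [2] (Int32), values of shape [2, 1, 2] (Float32), dense_shape [3, 1, 2].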

@@ -80,8 +169,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt
  CheckArgsSize(primitive->name(), args_spec_list, 3);

  auto key = args_spec_list[1];
  auto value = args_spec_list[2];

  ValuePtr key_value_ptr = key->GetValueTrack();
  MS_EXCEPTION_IF_NULL(key_value_ptr);
  auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();

@@ -91,7 +178,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt
  }
  auto expected = key_value_track->abstract();
  MS_EXCEPTION_IF_NULL(expected);
  (void)expected->Join(value);
  return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
}

@@ -126,7 +212,9 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
  if (type->type_id() != kObjectTypeRefKey) {
    MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
  }
  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
  auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
  ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
  return ret;
}

AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,

@@ -38,6 +38,7 @@
#include "pipeline/remove_value_node_dup.h"
#include "optimizer/optimizer.h"
#include "vm/transform.h"
#include "parse/python_adapter.h"

namespace mindspore {
namespace pipeline {

@@ -228,6 +229,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
      if (param_node->has_default()) {
        auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
        AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
        auto sparse_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad"));
        ptr->set_sparse_grad(sparse_grad);

        parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
        args_spec.push_back(ptr);

@@ -51,6 +51,7 @@ ValuePtr AbstractBase::BuildValue() const {
AbstractBasePtr AbstractBase::Broaden() const {
  AbstractBasePtr clone = Clone();
  clone->set_value(kAnyValue);
  clone->set_sparse_grad(sparse_grad_);
  return clone;
}

@@ -63,7 +64,8 @@ std::string AbstractBase::ToString() const {
  MS_EXCEPTION_IF_NULL(type_);
  MS_EXCEPTION_IF_NULL(shape_);
  buffer << type_name() << "("
         << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")";
         << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
         << " sparse_grad: " << sparse_grad_ << ")";
  return buffer.str();
}

@@ -72,16 +74,22 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden()
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
    auto ret = shared_from_base<AbstractBase>();
    ret->set_sparse_grad(sparse_grad());
    return ret;
  }
  auto value_self = GetValueTrack();
  MS_EXCEPTION_IF_NULL(value_self);
  ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
  TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
  if (res_value == value_self) {
    return shared_from_base<AbstractBase>();
    auto ret = shared_from_base<AbstractBase>();
    ret->set_sparse_grad(sparse_grad());
    return ret;
  }
  return std::make_shared<AbstractScalar>(res_value, res_type);
  auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
  ret->set_sparse_grad(sparse_grad());
  return ret;
}

AbstractBasePtr AbstractType::Clone() const {

@@ -423,7 +431,9 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
  }
  auto element = element_->Join(other_tensor->element_);
  auto shape = ShapeJoin(this->shape(), other_tensor->shape());
  return std::make_shared<AbstractTensor>(element, shape);
  auto ret = std::make_shared<AbstractTensor>(element, shape);
  ret->set_sparse_grad(sparse_grad());
  return ret;
}

bool AbstractTensor::operator==(const AbstractTensor &other) const {

@@ -463,6 +473,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
  ShapePtr shp = shape();
  clone->set_shape(shp->Clone());
  clone->set_value(GetValueTrack());
  clone->set_sparse_grad(sparse_grad());
  return clone;
}

@@ -472,6 +483,7 @@ AbstractBasePtr AbstractTensor::Broaden() const {
  auto shp = shape();
  broaden->set_shape(shp->Clone());
  broaden->set_value(kAnyValue);
  broaden->set_sparse_grad(sparse_grad());
  return broaden;
}

@@ -482,6 +494,7 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
  shp->Broaden();
  broaden->set_shape(shp);
  broaden->set_value(kAnyValue);
  broaden->set_sparse_grad(sparse_grad());
  return broaden;
}

@@ -502,7 +515,8 @@ std::string AbstractTensor::ToString() const {
  MS_EXCEPTION_IF_NULL(value_track);
  buffer << type_name() << "("
         << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
         << ")";
  return buffer.str();
}

@@ -44,7 +44,7 @@ class AbstractBase : public Base {
 public:
  explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
                        const BaseShapePtr &shape = kNoShape)
      : value_(value), type_(type), shape_(shape) {}
      : value_(value), type_(type), shape_(shape), sparse_grad_(false) {}
  ~AbstractBase() override = default;
  MS_DECLARE_PARENT(AbstractBase, Base)

@@ -53,11 +53,13 @@ class AbstractBase : public Base {

  virtual bool operator==(const AbstractBase &other) const;
  void set_value(const ValuePtr &value) { value_ = value; }
  void set_sparse_grad(const bool &sparse_grad) { sparse_grad_ = sparse_grad; }
  void set_type(const TypePtr &type) { type_ = type; }
  void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
  void set_value_desc(const std::string &desc) { value_desc_ = desc; }
  const std::string &value_desc() const { return value_desc_; }
  ValuePtr GetValueTrack() const { return value_; }
  bool sparse_grad() const { return sparse_grad_; }
  TypePtr GetTypeTrack() const { return type_; }
  BaseShapePtr GetShapeTrack() const { return shape_; }

@@ -85,6 +87,7 @@ class AbstractBase : public Base {
  TypePtr type_;
  BaseShapePtr shape_;
  std::string value_desc_;  // store initial value description for error report
  bool sparse_grad_;
};

class AbstractScalar : public AbstractBase {

@@ -851,7 +851,11 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
    }
    auto refkey = key_value->cast<RefKeyPtr>();
    if (refkey == nullptr) {
      return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>());
      auto ret = std::make_shared<AbstractScalar>(type);
      auto ref_value = ref_abs->ref();
      MS_EXCEPTION_IF_NULL(ref_value);
      ret->set_sparse_grad(ref_value->sparse_grad());
      return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
    }

    std::string name = refkey->tag();

@@ -865,6 +869,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
    x = SensitivityTransform(x);
    std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
    std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
    abs_scalar->set_sparse_grad(x->sparse_grad());
    return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
  }
};

@@ -50,12 +50,14 @@ class Parameter:
        requires_grad (bool): True if the parameter requires gradient. Default: True.
        layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
            broadcast and gradients communication would not be applied on parameters. Default: False.
        sparse_grad (bool): True if the parameter's gradient is sparse. Default: False.
    """
    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=False):
        self.set_parameter_data(default_input)
        self.name = name
        self.requires_grad = requires_grad
        self.layerwise_parallel = layerwise_parallel
        self.sparse_grad = sparse_grad
        self._is_init = False
        self._sliced = False
        self.clone_info = _CloneInfo()

@@ -168,6 +170,17 @@ class Parameter:
            raise TypeError("`requires_grad` parameter must be bool type")
        self._requires_grad = value

    @property
    def sparse_grad(self):
        """Return whether the parameter's gradient is sparse."""
        return self._sparse_grad

    @sparse_grad.setter
    def sparse_grad(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("`sparse_grad` parameter must be bool type")
        self._sparse_grad = value

    @property
    def data(self):
        return self.default_input
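
A minimal usage sketch of the new flag, mirroring the test at the end of this change:

import numpy as np
from mindspore import Tensor, Parameter

# A parameter whose gradient is propagated as a sparse (indices, values, dense_shape)
# tuple instead of a dense tensor.
w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True)
# A regular parameter keeps a dense gradient, as before.
w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")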
@@ -30,6 +30,7 @@ unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
shape_op = P.Shape()
reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()

@@ -284,6 +285,37 @@ def get_bprop_gather_v2(self):
    return bprop


@bprop_getters.register(P.SparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Generate bprop for SparseGatherV2"""

    def bprop(x, indices, axis, out, dout):
        x_shp = shape_op(x)
        if axis == 0:
            indices_size = (size_op(indices),)
            x_tail_shp = x_shp[1:]
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices = reshape(indices, indices_size)
            return (indices, values, x_shp), zeros_like(indices), zeros_like(axis)
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(indices), zeros_like(axis)

    return bprop
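
# Illustration (editor's sketch, not part of the patch): for axis == 0 the bprop above
# returns the gradient of x as an (indices, values, dense_shape) triple instead of a
# dense tensor. The NumPy helper below shows how that triple relates to the dense
# gradient it replaces.
import numpy as np

def _densify_sparse_grad(indices, values, dense_shape):
    """Scatter-add the sparse rows back into a dense gradient of shape dense_shape."""
    grad = np.zeros(dense_shape, dtype=values.dtype)
    for i, idx in enumerate(indices):
        grad[idx] += values[i]
    return grad

# With x of shape (3, 1, 2) and indices = [0, 1] (as in the test added at the end of this
# change), dout has shape (2, 1, 2); the sparse gradient is indices = [0, 1],
# values = reshape(dout, (2, 1, 2)), dense_shape = (3, 1, 2), and the helper above
# reproduces the dense gradient that GatherV2's bprop would have produced.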


@bprop_getters.register(P.Range)
def get_bprop_range(self):
    """Generate bprop for Range"""

@@ -20,7 +20,7 @@ Pre-defined combination of operators.
"""

from .base import GradOperation, HyperMap, MultitypeFuncGraph, add_flags, \
from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \
    grad, grad_all, grad_all_with_sens, grad_by_list, grad_by_list_with_sens, grad_with_sens, \
    core, env_get, tail, zip_operation
from .clip_ops import clip_by_value

@@ -19,7 +19,7 @@
from functools import partial

from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_exec, _wrap_func

@@ -241,6 +241,69 @@ class HyperMap(HyperMap_):
                return func(*args_list)
        return tuple(map(hypermap, *args_list))


class Map(Map_):
    """
    Map will apply the set operation to input sequences.

    It applies the operation to every element of each sequence.

    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operation should be passed as the first input to the instance.

    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences of the same
          length, and the operation is applied row by row. e.g. if the length of args is 2, then for each `i`
          within the sequence length, `(args[0][i], args[1][i])` will be the input of the operation.

          If `ops` is `None`, the first input is the operation, and the others are the input sequences.

    Outputs:
        Sequence, of the same type and length as the input sequences; each element is the result of applying
        the operation to the corresponding row of elements, e.g. `operation(args[0][i], args[1][i])`.
    """

    def __init__(self, ops=None):
        self.ops = ops
        if ops:
            Map_.__init__(self, ops)
        else:
            Map_.__init__(self)

    def __call__(self, *args):
        func = args[0]
        count = 0
        count_max = 1
        args_list = args[1:]
        if self.ops is not None:
            func = self.ops
            args_list = args
        for item in args_list:
            if isinstance(item, (tuple, list)):
                count_max = len(item)
                break

        def get_item(x):
            nonlocal count
            if isinstance(x, (tuple, list)):
                return x[count]
            return x

        for i in range(count_max):
            true_args = tuple(map(get_item, args_list))
            func(*true_args)
            count = i + 1
        return True

    def register(self, *type_names):
        """Register a function for the given type string."""

        def deco(fn):
            self.register_fn(type_names, fn)
            return fn
        return deco
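
# Usage sketch (hypothetical example, not part of the patch): with no leaf bound, the
# operation is passed as the first call argument and applied element by element. In
# graph mode the C++ Map_ expands this into one application per element; the
# pure-Python __call__ above is only a fallback.
#
#     square = MultitypeFuncGraph("square")
#
#     @square.register("Number")
#     def _square_number(x):
#         return x * x
#
#     common_map = Map()
#     result = common_map(square, (1, 2, 3))   # expected (1, 4, 9) when compiled in graph mode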


class _ListAppend(ListAppend_):
    """
    A metafuncgraph class that appends one element to a list.

@@ -21,7 +21,7 @@ A collection of operators to build neural networks or computing functions.

from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        Diag, DiagPart, DType, ExpandDims, Eye,
                        Fill, GatherNd, GatherV2, InvertPermutation,
                        Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
                        IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
                        Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range,
                        SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,

@@ -122,6 +122,7 @@ __all__ = [
    'Transpose',
    'OneHot',
    'GatherV2',
    'SparseGatherV2',
    'Concat',
    'Pack',
    'Unpack',

@@ -526,6 +526,29 @@ class GatherV2(PrimitiveWithInfer):
        return out


class SparseGatherV2(GatherV2):
    """
    Returns a slice of input tensor based on the specified indices and axis.

    Inputs:
        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          The original Tensor.
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Must be in the range
          `[0, input_param.shape()[axis])`.
        - **axis** (int) - Specifies the dimension index to gather indices.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

    Examples:
        >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
        >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
        >>> axis = 1
        >>> out = P.SparseGatherV2()(input_params, input_indices, axis)
    """


class Range(PrimitiveWithInfer):
    r"""
    Creates a sequence of numbers.

@@ -373,6 +373,8 @@ class CheckBprop(PrimitiveWithInfer):

    def infer_shape(self, xshapes, yshapes):
        tips = f'Bprop of {self.prim_to_check}'
        validator.check_value_type('grads', xshapes, (tuple,), tips)
        validator.check_value_type('params', yshapes, (tuple,), tips)
        if len(xshapes) < len(yshapes):
            raise TypeError(f"{tips}, the size of output should be {len(yshapes)},"
                            f" but got {len(xshapes)}.")

@@ -389,6 +391,8 @@ class CheckBprop(PrimitiveWithInfer):

    def infer_dtype(self, xdtypes, ydtypes):
        tips = f'Bprop of {self.prim_to_check}'
        validator.check_value_type('grads', xdtypes, (tuple,), tips)
        validator.check_value_type('params', ydtypes, (tuple,), tips)
        if len(xdtypes) < len(ydtypes):
            raise TypeError(f"{tips}, the size of output should be {len(ydtypes)},"
                            f" but got {len(xdtypes)}.")

@@ -0,0 +1,173 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test adam """
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Optimizer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel


adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
    op_mul = P.Mul()
    op_square = P.Square()
    op_sqrt = P.Sqrt()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()

    param_fp32 = op_cast(param, mstype.float32)
    m_fp32 = op_cast(m, mstype.float32)
    v_fp32 = op_cast(v, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)

    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)

    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                            - beta2, op_square(gradient_fp32))

    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
        update = update + op_mul(weight_decay_tensor, param_fp32)

    update_with_lr = op_mul(lr, update)
    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))
    return next_v
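
# For reference, the dense branch above is the standard Adam-with-weight-decay update:
#   m <- beta1 * m + (1 - beta1) * g
#   v <- beta2 * v + (1 - beta2) * g ** 2
#   update = m / (sqrt(v) + eps), plus weight_decay * param when decay_flag is set
#   param <- param - lr * update
# It returns next_v, with the three F.assign writes chained through F.depend so the
# graph keeps them.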


@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tuple", "Bool")
def _update_run_op_sparse_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
    return gradient[2][2]

def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)


class AdamWeightDecaySparse(Optimizer):
    """
    Implements the Adam algorithm with weight decay fix.

    Args:
        params (list[Parameter]): A list of parameters, which will be updated. The elements in `params`
            should be of class mindspore.Parameter.
        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
            Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step
            will take the i-th value as the learning rate. When the learning_rate is float or learning_rate is a
            Tensor but the dims of the Tensor is 0, use fixed learning rate. Other cases are not supported.
            Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`,
          and might be in sparse format.

    Outputs:
        tuple[Parameter], the updated velocity value, the shape is the same as `params`.

    Examples:
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
        if self.is_group:
            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))

        self.params = self.parameters
        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
        self.decay_flag = tuple(decay_filter(x) for x in self.params)

        self.map = C.Map()

    def construct(self, gradients):
        lr = self.get_lr()
        updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
                                              self.weight_decay_tensor),
                                    self.params, self.moments1, self.moments2, gradients, self.decay_flag)

        return updated_velocity


def test_AdamWeightDecaySparse():
    """ test_AdamWeightDecaySparse """
    context.set_context(mode=context.GRAPH_MODE)
    class Loss(nn.Cell):
        def __init__(self):
            super(Loss, self).__init__()
        def construct(self, base, target):
            return base
    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True)
            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = P.SparseGatherV2()
            self.axis = 0
        def construct(self, indices):
            return self.gatherv2(self.w1, indices, self.axis) * self.w2

    inputs = Tensor(np.array([0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()
    net.set_train()
    loss = Loss()
    optimizer = AdamWeightDecaySparse(net.trainable_params())

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)