!45147 [Expander] Add interfaces to get shape and dtype of Node.
Merge pull request !45147 from DeshiChen/1104_shapetype
This commit is contained in:
commit
ed0266ee25
|
@ -21,19 +21,13 @@
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
|
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace expander {
|
namespace expander {
|
||||||
namespace bprop {
|
namespace bprop {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t kMaxDims = 8;
|
constexpr size_t kMaxDims = 8;
|
||||||
|
|
||||||
int64_t CheckRange(int64_t idx, int64_t dim_size) {
|
|
||||||
if (idx < -dim_size || idx >= dim_size) {
|
|
||||||
MS_EXCEPTION(IndexError) << "index {" << idx << "} is out of bounds for dimension with size {" << dim_size << "}";
|
|
||||||
}
|
|
||||||
return idx < 0 ? (idx + dim_size) : idx;
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, CNodePtrList *outputs) {
|
bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, CNodePtrList *outputs) {
|
||||||
|
@ -72,64 +66,6 @@ NodePtr BpropIRBuilder::GetInput(size_t i) const {
|
||||||
return (*inputs_ptr_)[i];
|
return (*inputs_ptr_)[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapeVector BpropIRBuilder::GetShape(const NodePtr &node) const {
|
|
||||||
auto abs = node->get()->abstract();
|
|
||||||
MS_EXCEPTION_IF_NULL(abs);
|
|
||||||
auto shape = abs->BuildShape();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape);
|
|
||||||
if (shape->isa<abstract::Shape>()) {
|
|
||||||
return shape->cast<abstract::ShapePtr>()->shape();
|
|
||||||
} else if (shape->isa<abstract::SequenceShape>()) {
|
|
||||||
MS_LOG(EXCEPTION) << "The output of node " << node->get()->ToString() << " is a tuple.";
|
|
||||||
}
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<ShapeVector> BpropIRBuilder::GetShapes(const NodePtr &node) const {
|
|
||||||
auto abs = node->get()->abstract();
|
|
||||||
MS_EXCEPTION_IF_NULL(abs);
|
|
||||||
auto shape = abs->BuildShape();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape);
|
|
||||||
if (shape->isa<abstract::SequenceShape>()) {
|
|
||||||
auto seq_shape_ptr = shape->cast<abstract::SequenceShapePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(seq_shape_ptr);
|
|
||||||
const auto &shape_list = seq_shape_ptr->shape();
|
|
||||||
if (shape_list.empty()) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
std::vector<ShapeVector> res;
|
|
||||||
res.reserve(shape_list.size());
|
|
||||||
for (const auto &item : shape_list) {
|
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
|
||||||
if (item->isa<abstract::NoShape>()) {
|
|
||||||
res.push_back({});
|
|
||||||
} else if (!item->isa<abstract::Shape>()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Invalid Shape Type(" << item->ToString() << ") In Shape List";
|
|
||||||
}
|
|
||||||
auto shape_ptr = item->cast<abstract::ShapePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
|
||||||
res.push_back(shape_ptr->shape());
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
} else {
|
|
||||||
MS_LOG(EXCEPTION) << "The output of node " << node->get()->ToString() << " is not a tuple.";
|
|
||||||
}
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
TypePtr BpropIRBuilder::GetDtype(const NodePtr &node) const {
|
|
||||||
auto abs = node->get()->abstract();
|
|
||||||
MS_EXCEPTION_IF_NULL(abs);
|
|
||||||
auto dtype = abs->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(dtype);
|
|
||||||
if (dtype->isa<TensorType>()) {
|
|
||||||
return dtype->cast<TensorTypePtr>()->element();
|
|
||||||
} else if (dtype->isa<Tuple>()) {
|
|
||||||
MS_LOG(EXCEPTION) << "The output of node " << node->get()->ToString() << " is a tuple.";
|
|
||||||
}
|
|
||||||
return dtype;
|
|
||||||
}
|
|
||||||
|
|
||||||
ValuePtr BpropIRBuilder::GetAttr(const NodePtr &node, const std::string &attr) const {
|
ValuePtr BpropIRBuilder::GetAttr(const NodePtr &node, const std::string &attr) const {
|
||||||
auto p = GetCNodePrimitive(node->get());
|
auto p = GetCNodePrimitive(node->get());
|
||||||
MS_EXCEPTION_IF_NULL(p);
|
MS_EXCEPTION_IF_NULL(p);
|
||||||
|
|
|
@ -47,10 +47,10 @@ class BpropIRBuilder : public Emitter {
|
||||||
const NodePtrList &GetInputs() const { return *inputs_ptr_; }
|
const NodePtrList &GetInputs() const { return *inputs_ptr_; }
|
||||||
|
|
||||||
// For node that has single output
|
// For node that has single output
|
||||||
ShapeVector GetShape(const NodePtr &node) const;
|
ShapeVector GetShape(const NodePtr &node) const { return node->shape(); }
|
||||||
// For node that has multiple outputs
|
// For node that has multiple outputs
|
||||||
std::vector<ShapeVector> GetShapes(const NodePtr &node) const;
|
std::vector<ShapeVector> GetShapes(const NodePtr &node) const { return node->shapes(); }
|
||||||
TypePtr GetDtype(const NodePtr &node) const;
|
TypePtr GetDtype(const NodePtr &node) const { return node->dtype(); }
|
||||||
TypeId GetDtypeId(const NodePtr &node) const { return GetDtype(node)->type_id(); }
|
TypeId GetDtypeId(const NodePtr &node) const { return GetDtype(node)->type_id(); }
|
||||||
ValuePtr GetAttr(const NodePtr &node, const std::string &attr) const;
|
ValuePtr GetAttr(const NodePtr &node, const std::string &attr) const;
|
||||||
int64_t GetSize(const NodePtr &node) const;
|
int64_t GetSize(const NodePtr &node) const;
|
||||||
|
|
|
@ -31,6 +31,10 @@ namespace mindspore::expander::bprop {
|
||||||
namespace {
|
namespace {
|
||||||
NodePtr ReduceSumWithCast(const BpropIRBuilder *ib, const NodePtr &dx, const std::vector<int64_t> &axis) {
|
NodePtr ReduceSumWithCast(const BpropIRBuilder *ib, const NodePtr &dx, const std::vector<int64_t> &axis) {
|
||||||
auto dx_origin_dtypeptr = ib->GetDtype(dx);
|
auto dx_origin_dtypeptr = ib->GetDtype(dx);
|
||||||
|
auto need_reduce = ib->NeedReduce(ib->GetShape(dx), axis, false);
|
||||||
|
if (!need_reduce.first) {
|
||||||
|
return ib->Reshape(dx, need_reduce.second);
|
||||||
|
}
|
||||||
auto dx_origin_dtype = dx_origin_dtypeptr->type_id();
|
auto dx_origin_dtype = dx_origin_dtypeptr->type_id();
|
||||||
if (dx_origin_dtype == TypeId::kNumberTypeInt16 || dx_origin_dtype == TypeId::kNumberTypeInt32 ||
|
if (dx_origin_dtype == TypeId::kNumberTypeInt16 || dx_origin_dtype == TypeId::kNumberTypeInt32 ||
|
||||||
dx_origin_dtype == TypeId::kNumberTypeInt64) {
|
dx_origin_dtype == TypeId::kNumberTypeInt64) {
|
||||||
|
@ -285,12 +289,12 @@ NodePtrList BinopGradCommonWithShift(const BpropIRBuilder *ib, const NodePtr &x,
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> Range(int64_t start, int64_t stop, int64_t step) {
|
std::vector<int64_t> Range(int64_t start, int64_t stop, int64_t step) {
|
||||||
auto size = (stop - start) / step;
|
int64_t size = (step != 0) ? ((stop - start) / step) : 0;
|
||||||
if (size <= 0) {
|
if (size <= 0) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
size = ((stop - start) % step == 0) ? size : size + 1;
|
size = ((stop - start) % step == 0) ? size : size + 1;
|
||||||
std::vector<int64_t> range(size);
|
std::vector<int64_t> range(LongToSize(size));
|
||||||
std::generate(range.begin(), range.end(), [n = start - step, step]() mutable {
|
std::generate(range.begin(), range.end(), [n = start - step, step]() mutable {
|
||||||
n = n + step;
|
n = n + step;
|
||||||
return n;
|
return n;
|
||||||
|
@ -314,6 +318,13 @@ std::vector<int64_t> GetTransposeAxis(const std::vector<int64_t> &x_shape, int64
|
||||||
return reverse_axis;
|
return reverse_axis;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t CheckRange(int64_t idx, int64_t dim_size) {
|
||||||
|
if (idx < -dim_size || idx >= dim_size) {
|
||||||
|
MS_EXCEPTION(IndexError) << "index {" << idx << "} is out of bounds for dimension with size {" << dim_size << "}";
|
||||||
|
}
|
||||||
|
return idx < 0 ? (idx + dim_size) : idx;
|
||||||
|
}
|
||||||
|
|
||||||
NodePtr GetEps(const BpropIRBuilder *ib, const TypePtr &type) {
|
NodePtr GetEps(const BpropIRBuilder *ib, const TypePtr &type) {
|
||||||
switch (type->type_id()) {
|
switch (type->type_id()) {
|
||||||
case kNumberTypeFloat16:
|
case kNumberTypeFloat16:
|
||||||
|
@ -336,7 +347,7 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
|
||||||
auto x_shp = ib->GetShape(x);
|
auto x_shp = ib->GetShape(x);
|
||||||
auto out_shp = ib->GetShape(dout);
|
auto out_shp = ib->GetShape(dout);
|
||||||
auto ind_shp = ib->GetShape(indices);
|
auto ind_shp = ib->GetShape(indices);
|
||||||
auto axis_v = GetIntFromValueNode(axis);
|
auto axis_v = CheckRange(GetIntFromValueNode(axis), SizeToLong(x_shp.size()));
|
||||||
if (out_shp.empty()) {
|
if (out_shp.empty()) {
|
||||||
dout = ib->Emit("ExpandDims", {dout, ib->Tensor(-1)});
|
dout = ib->Emit("ExpandDims", {dout, ib->Tensor(-1)});
|
||||||
}
|
}
|
||||||
|
@ -515,10 +526,7 @@ NodePtr MinOrMaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const std::vect
|
||||||
auto indicators = ib->Cast(ib->Emit("Equal", {y, x}), ib->GetDtype(grad));
|
auto indicators = ib->Cast(ib->Emit("Equal", {y, x}), ib->GetDtype(grad));
|
||||||
auto minn = 1e-24;
|
auto minn = 1e-24;
|
||||||
auto min_num = ib->Tensor(minn, ib->GetDtype(grad));
|
auto min_num = ib->Tensor(minn, ib->GetDtype(grad));
|
||||||
auto num_selected =
|
auto num_selected = ib->Reshape(ib->ReduceSum(indicators, axis, false), output_shape_kept_dims) + min_num;
|
||||||
ib->Reshape(ib->Emit("ReduceSum", {indicators, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}}),
|
|
||||||
output_shape_kept_dims) +
|
|
||||||
min_num;
|
|
||||||
return indicators / num_selected * grad;
|
return indicators / num_selected * grad;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,6 +38,8 @@ std::vector<int64_t> ReduceShape(const std::vector<int64_t> &x, const std::vecto
|
||||||
|
|
||||||
std::vector<int64_t> GetAxisList(const ValuePtr &value);
|
std::vector<int64_t> GetAxisList(const ValuePtr &value);
|
||||||
|
|
||||||
|
int64_t CheckRange(int64_t idx, int64_t dim_size);
|
||||||
|
|
||||||
NodePtrList BinopGradCommon(const BpropIRBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx,
|
NodePtrList BinopGradCommon(const BpropIRBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx,
|
||||||
const NodePtr &dy);
|
const NodePtr &dy);
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,9 @@
|
||||||
#include "common/graph_kernel/bprop/expander/emitter.h"
|
#include "common/graph_kernel/bprop/expander/emitter.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <utility>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
#include "utils/anf_utils.h"
|
#include "utils/anf_utils.h"
|
||||||
|
|
||||||
|
@ -69,6 +72,28 @@ NodePtr Emitter::Log(const NodePtr &x) const {
|
||||||
{"cust_aicpu", MakeValue(kLogOpName)}});
|
{"cust_aicpu", MakeValue(kLogOpName)}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NodePtr Emitter::Cast(const NodePtr &node, const TypePtr &type) const {
|
||||||
|
// do not emit a node when the dst type is the same as src type
|
||||||
|
if (node->dtype()->type_id() == type->type_id()) {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
return Emit("Cast", {node, EmitValue(type)});
|
||||||
|
}
|
||||||
|
|
||||||
|
NodePtr Emitter::Reshape(const NodePtr &node, const ShapeVector &shape) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto node_shape = node->shape();
|
||||||
|
if (shape.size() != node_shape.size()) {
|
||||||
|
return Emit(prim::kReshape, {node, Value(shape)});
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < shape.size(); ++i) {
|
||||||
|
if (shape[i] != node_shape[i] && shape[i] != -1) {
|
||||||
|
return Emit(prim::kReshape, {node, Value(shape)});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
NodePtr Emitter::MatMul(const NodePtr &a, const NodePtr &b, bool transpose_a, bool transpose_b) const {
|
NodePtr Emitter::MatMul(const NodePtr &a, const NodePtr &b, bool transpose_a, bool transpose_b) const {
|
||||||
return Emit(prim::kPrimMatMul->name(), {a, b},
|
return Emit(prim::kPrimMatMul->name(), {a, b},
|
||||||
{{"transpose_x1", MakeValue(transpose_a)},
|
{{"transpose_x1", MakeValue(transpose_a)},
|
||||||
|
@ -114,7 +139,54 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
|
||||||
return Emit(prim::kZerosLike, {node});
|
return Emit(prim::kZerosLike, {node});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<bool, ShapeVector> Emitter::NeedReduce(const ShapeVector &shape, const std::vector<int64_t> &axis,
|
||||||
|
bool keep_dim) const {
|
||||||
|
if (shape.empty()) {
|
||||||
|
return std::make_pair(false, shape);
|
||||||
|
}
|
||||||
|
auto rank = SizeToLong(shape.size());
|
||||||
|
auto real_axis = axis;
|
||||||
|
if (real_axis.empty()) {
|
||||||
|
// all reduce
|
||||||
|
for (int64_t i = 0; i < rank; ++i) {
|
||||||
|
real_axis.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::unordered_set<size_t> uniq_axis;
|
||||||
|
for (size_t i = 0; i < real_axis.size(); ++i) {
|
||||||
|
if (real_axis[i] < -rank || real_axis[i] >= rank) {
|
||||||
|
MS_EXCEPTION(ValueError) << "Reduce axis[" << i << "] is " << real_axis[i] << ", which is out of range [-" << rank
|
||||||
|
<< ", " << rank << ") for shape: " << shape;
|
||||||
|
}
|
||||||
|
auto axis_i = real_axis[i] < 0 ? real_axis[i] + rank : real_axis[i];
|
||||||
|
(void)uniq_axis.insert(LongToSize(axis_i));
|
||||||
|
}
|
||||||
|
// Calc reduce output shape
|
||||||
|
ShapeVector out_shape;
|
||||||
|
bool need_reduce = false;
|
||||||
|
for (size_t i = 0; i < shape.size(); ++i) {
|
||||||
|
if (uniq_axis.find(i) == uniq_axis.end()) {
|
||||||
|
// not reduce axis
|
||||||
|
out_shape.push_back(shape[i]);
|
||||||
|
} else {
|
||||||
|
// reduce axis
|
||||||
|
if (shape[i] != 1) {
|
||||||
|
need_reduce = true;
|
||||||
|
}
|
||||||
|
if (keep_dim) {
|
||||||
|
out_shape.push_back(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_pair(need_reduce, out_shape);
|
||||||
|
}
|
||||||
|
|
||||||
NodePtr Emitter::ReduceSum(const NodePtr &x, const ShapeVector &axis, bool keep_dims) const {
|
NodePtr Emitter::ReduceSum(const NodePtr &x, const ShapeVector &axis, bool keep_dims) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(x);
|
||||||
|
auto need_reduce = NeedReduce(x->shape(), axis, keep_dims);
|
||||||
|
if (!need_reduce.first) {
|
||||||
|
return Reshape(x, need_reduce.second);
|
||||||
|
}
|
||||||
return Emit(prim::kPrimReduceSum->name(), {x, Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(keep_dims)}});
|
return Emit(prim::kPrimReduceSum->name(), {x, Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(keep_dims)}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "ops/core_ops.h"
|
#include "ops/core_ops.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
|
@ -44,12 +45,10 @@ class Emitter {
|
||||||
return Emit(prim::kTupleGetItem, {input, Value(static_cast<int64_t>(i))});
|
return Emit(prim::kTupleGetItem, {input, Value(static_cast<int64_t>(i))});
|
||||||
}
|
}
|
||||||
|
|
||||||
NodePtr Cast(const NodePtr &node, const TypePtr &type) const { return Emit("Cast", {node, EmitValue(type)}); }
|
NodePtr Cast(const NodePtr &node, const TypePtr &type) const;
|
||||||
NodePtr Cast(const NodePtr &node, TypeId type_id) const { return Cast(node, TypeIdToType(type_id)); }
|
NodePtr Cast(const NodePtr &node, TypeId type_id) const { return Cast(node, TypeIdToType(type_id)); }
|
||||||
|
|
||||||
NodePtr Reshape(const NodePtr &node, const ShapeVector &shape) const {
|
NodePtr Reshape(const NodePtr &node, const ShapeVector &shape) const;
|
||||||
return Emit(prim::kReshape, {node, Tensor(shape)});
|
|
||||||
}
|
|
||||||
NodePtr ExpandDims(const NodePtr &node, int64_t axis) const { return Emit(kExpandDimsOpName, {node, Value(axis)}); }
|
NodePtr ExpandDims(const NodePtr &node, int64_t axis) const { return Emit(kExpandDimsOpName, {node, Value(axis)}); }
|
||||||
NodePtr Abs(const NodePtr &node) const { return Emit(prim::kAbs, {node}); }
|
NodePtr Abs(const NodePtr &node) const { return Emit(prim::kAbs, {node}); }
|
||||||
NodePtr Neg(const NodePtr &node) const { return Emit(prim::kNeg, {node}); }
|
NodePtr Neg(const NodePtr &node) const { return Emit(prim::kNeg, {node}); }
|
||||||
|
@ -97,6 +96,8 @@ class Emitter {
|
||||||
}
|
}
|
||||||
NodePtr LogicalAnd(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LogicalAnd", {lhs, rhs}); }
|
NodePtr LogicalAnd(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LogicalAnd", {lhs, rhs}); }
|
||||||
NodePtr LogicalOr(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LogicalOr", {lhs, rhs}); }
|
NodePtr LogicalOr(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LogicalOr", {lhs, rhs}); }
|
||||||
|
std::pair<bool, ShapeVector> NeedReduce(const ShapeVector &shape, const std::vector<int64_t> &axis,
|
||||||
|
bool keep_dim) const;
|
||||||
NodePtr ReduceSum(const NodePtr &x, const ShapeVector &axis = {}, bool keep_dims = false) const;
|
NodePtr ReduceSum(const NodePtr &x, const ShapeVector &axis = {}, bool keep_dims = false) const;
|
||||||
|
|
||||||
NodePtr ZerosLike(const NodePtr &node) const;
|
NodePtr ZerosLike(const NodePtr &node) const;
|
||||||
|
@ -120,6 +121,8 @@ class Emitter {
|
||||||
return EmitValue(tensor_ptr);
|
return EmitValue(tensor_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ExpanderInferPtr infer() const { return infer_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
NodePtr NewNode(const AnfNodePtr &anfnode) const { return std::make_shared<Node>(anfnode, this); }
|
NodePtr NewNode(const AnfNodePtr &anfnode) const { return std::make_shared<Node>(anfnode, this); }
|
||||||
NodePtr CmpOpWithCast(const std::string &op, const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) const {
|
NodePtr CmpOpWithCast(const std::string &op, const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) const {
|
||||||
|
|
|
@ -63,5 +63,17 @@ void CppInfer::Infer(const NodePtr &node) {
|
||||||
}
|
}
|
||||||
cnode->set_abstract(result);
|
cnode->set_abstract(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BaseShapePtr CppInfer::GetShape(const NodePtr &node) {
|
||||||
|
auto abs = node->get()->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(abs);
|
||||||
|
return abs->BuildShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr CppInfer::GetDtype(const NodePtr &node) {
|
||||||
|
auto abs = node->get()->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(abs);
|
||||||
|
return abs->BuildType();
|
||||||
|
}
|
||||||
} // namespace expander
|
} // namespace expander
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -24,7 +24,11 @@ namespace expander {
|
||||||
/// \brief ExpanderInfer is the adapter for inferring functions that is called in emitter.
|
/// \brief ExpanderInfer is the adapter for inferring functions that is called in emitter.
|
||||||
class ExpanderInfer {
|
class ExpanderInfer {
|
||||||
public:
|
public:
|
||||||
|
/// \brief Infer shape and dtype for node
|
||||||
virtual void Infer(const NodePtr &node) = 0;
|
virtual void Infer(const NodePtr &node) = 0;
|
||||||
|
|
||||||
|
virtual BaseShapePtr GetShape(const NodePtr &node) = 0;
|
||||||
|
virtual TypePtr GetDtype(const NodePtr &node) = 0;
|
||||||
};
|
};
|
||||||
using ExpanderInferPtr = std::shared_ptr<ExpanderInfer>;
|
using ExpanderInferPtr = std::shared_ptr<ExpanderInfer>;
|
||||||
|
|
||||||
|
@ -32,6 +36,8 @@ using ExpanderInferPtr = std::shared_ptr<ExpanderInfer>;
|
||||||
class CppInfer : public ExpanderInfer {
|
class CppInfer : public ExpanderInfer {
|
||||||
public:
|
public:
|
||||||
void Infer(const NodePtr &node) override;
|
void Infer(const NodePtr &node) override;
|
||||||
|
BaseShapePtr GetShape(const NodePtr &node) override;
|
||||||
|
TypePtr GetDtype(const NodePtr &node) override;
|
||||||
};
|
};
|
||||||
} // namespace expander
|
} // namespace expander
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -15,6 +15,9 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "common/graph_kernel/bprop/expander/node.h"
|
#include "common/graph_kernel/bprop/expander/node.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include "common/graph_kernel/bprop/expander/emitter.h"
|
||||||
|
#include "common/graph_kernel/bprop/expander/infer.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace expander {
|
namespace expander {
|
||||||
|
@ -22,5 +25,64 @@ Node::Node(const AnfNodePtr &node, const Emitter *emitter) : anf_node_(node), em
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(emitter);
|
MS_EXCEPTION_IF_NULL(emitter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> Node::shape() {
|
||||||
|
if (shape_ == nullptr) {
|
||||||
|
shape_ = emitter()->infer()->GetShape(shared_from_this());
|
||||||
|
MS_EXCEPTION_IF_NULL(shape_);
|
||||||
|
}
|
||||||
|
if (shape_->isa<abstract::NoShape>()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
auto shape = shape_->cast<abstract::ShapePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(shape);
|
||||||
|
return shape->shape();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int64_t>> Node::shapes() {
|
||||||
|
if (shape_ == nullptr) {
|
||||||
|
shape_ = emitter()->infer()->GetShape(shared_from_this());
|
||||||
|
MS_EXCEPTION_IF_NULL(shape_);
|
||||||
|
}
|
||||||
|
auto tuple_shape = shape_->cast<abstract::SequenceShapePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tuple_shape);
|
||||||
|
auto &shape_list = tuple_shape->shape();
|
||||||
|
std::vector<ShapeVector> shapes(shape_list.size());
|
||||||
|
(void)std::transform(shape_list.cbegin(), shape_list.cend(), shapes.begin(), [](const BaseShapePtr &bs) {
|
||||||
|
if (bs->isa<abstract::NoShape>()) {
|
||||||
|
return ShapeVector();
|
||||||
|
}
|
||||||
|
auto shape = bs->cast<abstract::ShapePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(shape);
|
||||||
|
return shape->shape();
|
||||||
|
});
|
||||||
|
return shapes;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr Node::dtype() {
|
||||||
|
if (type_ == nullptr) {
|
||||||
|
type_ = emitter()->infer()->GetDtype(shared_from_this());
|
||||||
|
MS_EXCEPTION_IF_NULL(type_);
|
||||||
|
if (type_->isa<TensorType>()) {
|
||||||
|
type_ = type_->cast<TensorTypePtr>()->element();
|
||||||
|
MS_EXCEPTION_IF_NULL(type_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return type_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<TypePtr> Node::dtypes() {
|
||||||
|
if (type_ == nullptr) {
|
||||||
|
type_ = emitter()->infer()->GetDtype(shared_from_this());
|
||||||
|
MS_EXCEPTION_IF_NULL(type_);
|
||||||
|
}
|
||||||
|
auto tuple = type_->cast<TuplePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tuple);
|
||||||
|
std::vector<TypePtr> result(tuple->size());
|
||||||
|
auto elements = tuple->elements();
|
||||||
|
(void)std::transform(elements.cbegin(), elements.cend(), result.begin(),
|
||||||
|
[](const TypePtr &t) { return t->isa<TensorType>() ? t->cast<TensorTypePtr>()->element() : t; });
|
||||||
|
return result;
|
||||||
|
}
|
||||||
} // namespace expander
|
} // namespace expander
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
|
#include "ir/dtype.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace expander {
|
namespace expander {
|
||||||
|
@ -42,12 +43,24 @@ class Node : public std::enable_shared_from_this<Node> {
|
||||||
return anf_node_->cast<T>();
|
return anf_node_->cast<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> shape();
|
||||||
|
std::vector<std::vector<int64_t>> shapes();
|
||||||
|
|
||||||
|
TypePtr dtype();
|
||||||
|
std::vector<TypePtr> dtypes();
|
||||||
|
|
||||||
const Emitter *emitter() const { return emitter_; }
|
const Emitter *emitter() const { return emitter_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
// the wrapped anfnode.
|
||||||
AnfNodePtr anf_node_{nullptr};
|
AnfNodePtr anf_node_{nullptr};
|
||||||
// hold an emitter for operator overloading.
|
// hold the emitter who created this node.
|
||||||
const Emitter *emitter_{nullptr};
|
const Emitter *emitter_{nullptr};
|
||||||
|
|
||||||
|
// cache the output shape after first query
|
||||||
|
BaseShapePtr shape_{nullptr};
|
||||||
|
// cache the output dtype after first query
|
||||||
|
TypePtr type_{nullptr};
|
||||||
};
|
};
|
||||||
using NodePtr = std::shared_ptr<Node>;
|
using NodePtr = std::shared_ptr<Node>;
|
||||||
using NodePtrList = std::vector<NodePtr>;
|
using NodePtrList = std::vector<NodePtr>;
|
||||||
|
|
|
@ -37,7 +37,7 @@ NodePtrList GatherDropNegatives(const BpropIRBuilder *ib, const NodePtr ¶ms,
|
||||||
for (size_t i = 0; i < back_size; ++i) {
|
for (size_t i = 0; i < back_size; ++i) {
|
||||||
broadcastable_shape.push_back(1);
|
broadcastable_shape.push_back(1);
|
||||||
}
|
}
|
||||||
is_positive = ib->Emit("Reshape", {is_positive, ib->Value<ShapeVector>(broadcastable_shape)});
|
is_positive = ib->Reshape(is_positive, broadcastable_shape);
|
||||||
auto gathered_shape = ib->GetShape(gathered);
|
auto gathered_shape = ib->GetShape(gathered);
|
||||||
is_positive = ib->Emit("LogicalAnd",
|
is_positive = ib->Emit("LogicalAnd",
|
||||||
{is_positive, ib->Emit("Fill", {ib->EmitValue(kBool), ib->Value<ShapeVector>(gathered_shape),
|
{is_positive, ib->Emit("Fill", {ib->EmitValue(kBool), ib->Value<ShapeVector>(gathered_shape),
|
||||||
|
@ -58,8 +58,8 @@ NodePtrList UnsortedSegmentMinOrMaxGrad(const BpropIRBuilder *ib, const NodePtr
|
||||||
|
|
||||||
auto tmp = ib->Emit("Equal", {x, gathered_outputs});
|
auto tmp = ib->Emit("Equal", {x, gathered_outputs});
|
||||||
auto is_selected = ib->Emit("LogicalAnd", {tmp, is_positive});
|
auto is_selected = ib->Emit("LogicalAnd", {tmp, is_positive});
|
||||||
auto num_selected = ib->Emit(
|
auto num_selected =
|
||||||
"UnsortedSegmentSum", {ib->Emit("Cast", {is_selected, ib->Value(ib->GetDtype(dout))}), segment_ids, num_segments});
|
ib->Emit("UnsortedSegmentSum", {ib->Cast(is_selected, ib->GetDtype(dout)), segment_ids, num_segments});
|
||||||
auto weighted_grads = ib->Emit("RealDiv", {dout, num_selected});
|
auto weighted_grads = ib->Emit("RealDiv", {dout, num_selected});
|
||||||
auto temp_outs_2 = GatherDropNegatives(ib, weighted_grads, nullptr, zero_clipped_indices, is_positive);
|
auto temp_outs_2 = GatherDropNegatives(ib, weighted_grads, nullptr, zero_clipped_indices, is_positive);
|
||||||
MS_EXCEPTION_IF_CHECK_FAIL(temp_outs.size() > 0, "Outputs should not be empty.");
|
MS_EXCEPTION_IF_CHECK_FAIL(temp_outs.size() > 0, "Outputs should not be empty.");
|
||||||
|
@ -158,7 +158,7 @@ REG_BPROP_BUILDER("SparseGatherV2").SetBody([](const BpropIRBuilder *ib) -> Node
|
||||||
auto axis = ib->GetInput(kIndex2);
|
auto axis = ib->GetInput(kIndex2);
|
||||||
auto dout = ib->GetInput(kIndex4);
|
auto dout = ib->GetInput(kIndex4);
|
||||||
auto x_shp = ib->GetShape(x);
|
auto x_shp = ib->GetShape(x);
|
||||||
auto axis_int = GetValue<int64_t>(axis->get<ValueNodePtr>()->value());
|
auto axis_int = CheckRange(GetValue<int64_t>(axis->get<ValueNodePtr>()->value()), SizeToLong(x_shp.size()));
|
||||||
if (axis_int == 0) {
|
if (axis_int == 0) {
|
||||||
ShapeVector values_shape{ib->GetSize(indices)};
|
ShapeVector values_shape{ib->GetSize(indices)};
|
||||||
if (x_shp.size() > 1) {
|
if (x_shp.size() > 1) {
|
||||||
|
@ -730,7 +730,7 @@ REG_BPROP_BUILDER("BroadcastTo").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
||||||
auto tuple_out = BroadcastGradientArgs(broadcast_shape, x_shape);
|
auto tuple_out = BroadcastGradientArgs(broadcast_shape, x_shape);
|
||||||
MS_EXCEPTION_IF_CHECK_FAIL(!tuple_out.empty(), "BroadcastGradientArgs out should not be empty!");
|
MS_EXCEPTION_IF_CHECK_FAIL(!tuple_out.empty(), "BroadcastGradientArgs out should not be empty!");
|
||||||
auto reduction_axes = tuple_out[kIndex1];
|
auto reduction_axes = tuple_out[kIndex1];
|
||||||
auto reduced_grad = ib->Emit("ReduceSum", {dout, ib->Value(reduction_axes)}, {{"keep_dims", MakeValue(true)}});
|
auto reduced_grad = ib->ReduceSum(dout, reduction_axes, true);
|
||||||
auto dx = ib->Reshape(reduced_grad, x_shape);
|
auto dx = ib->Reshape(reduced_grad, x_shape);
|
||||||
return {dx};
|
return {dx};
|
||||||
});
|
});
|
||||||
|
@ -846,12 +846,17 @@ REG_BPROP_BUILDER("Tile").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||||
auto dout_reshaped = ib->Reshape(dout, r_shape);
|
auto dout_reshaped = ib->Reshape(dout, r_shape);
|
||||||
auto dout_dtype = ib->GetDtype(dout_reshaped)->type_id();
|
auto dout_dtype = ib->GetDtype(dout_reshaped)->type_id();
|
||||||
NodePtr dx;
|
NodePtr dx;
|
||||||
if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
|
auto need_reduce = ib->NeedReduce(r_shape, axis, false);
|
||||||
dout_reshaped = ib->Cast(dout_reshaped, kFloat32);
|
if (need_reduce.first) {
|
||||||
dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}});
|
if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
|
||||||
dx = ib->Cast(dx, dout_dtype);
|
dout_reshaped = ib->Cast(dout_reshaped, kFloat32);
|
||||||
|
dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}});
|
||||||
|
dx = ib->Cast(dx, dout_dtype);
|
||||||
|
} else {
|
||||||
|
dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}});
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}});
|
dx = ib->Reshape(dout_reshaped, need_reduce.second);
|
||||||
}
|
}
|
||||||
dx = ib->Reshape(dx, shapex);
|
dx = ib->Reshape(dx, shapex);
|
||||||
return {dx, ib->ZerosLike(input_multiples)};
|
return {dx, ib->ZerosLike(input_multiples)};
|
||||||
|
|
|
@ -914,7 +914,7 @@ REG_BPROP_BUILDER("ReduceProd").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
||||||
auto dout = ib->GetInput(kIndex3);
|
auto dout = ib->GetInput(kIndex3);
|
||||||
auto input_shape = ib->GetShape(x);
|
auto input_shape = ib->GetShape(x);
|
||||||
auto output_shape_kept_dims = ReduceShape(input_shape, GetAxisValue(axis));
|
auto output_shape_kept_dims = ReduceShape(input_shape, GetAxisValue(axis));
|
||||||
dout = ib->Emit("Reshape", {dout, ib->Value<ShapeVector>(output_shape_kept_dims)});
|
dout = ib->Reshape(dout, output_shape_kept_dims);
|
||||||
auto tile_scaling = TupleDiv(input_shape, output_shape_kept_dims);
|
auto tile_scaling = TupleDiv(input_shape, output_shape_kept_dims);
|
||||||
auto grad = ib->Emit("Tile", {dout, ib->Value<ShapeVector>(tile_scaling)});
|
auto grad = ib->Emit("Tile", {dout, ib->Value<ShapeVector>(tile_scaling)});
|
||||||
auto [pack_shape, perm] = SplitShapeIndex(input_shape, GetAxisValue(axis));
|
auto [pack_shape, perm] = SplitShapeIndex(input_shape, GetAxisValue(axis));
|
||||||
|
@ -964,7 +964,11 @@ REG_BPROP_BUILDER("ReduceMean").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
||||||
}
|
}
|
||||||
return size;
|
return size;
|
||||||
};
|
};
|
||||||
auto div_shape = getSize(shape_x) / getSize(shape_out);
|
auto shape_out_sz = getSize(shape_out);
|
||||||
|
if (shape_out_sz == 0) {
|
||||||
|
MS_EXCEPTION(ValueError) << "out shape size can not be 0";
|
||||||
|
}
|
||||||
|
auto div_shape = getSize(shape_x) / shape_out_sz;
|
||||||
auto dx = ib->RealDiv(grad, ib->Tensor(div_shape, ib->GetDtype(grad)));
|
auto dx = ib->RealDiv(grad, ib->Tensor(div_shape, ib->GetDtype(grad)));
|
||||||
return {dx, ib->ZerosLike(axis)};
|
return {dx, ib->ZerosLike(axis)};
|
||||||
});
|
});
|
||||||
|
|
|
@ -899,7 +899,7 @@ REG_BPROP_BUILDER("Tanh").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||||
auto dout = ib->GetInput(kIndex2);
|
auto dout = ib->GetInput(kIndex2);
|
||||||
auto x_dtype_id = ib->GetDtypeId(x);
|
auto x_dtype_id = ib->GetDtypeId(x);
|
||||||
NodePtr dx;
|
NodePtr dx;
|
||||||
if (x_dtype_id == 46 || x_dtype_id == 47) {
|
if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
|
||||||
dout = ib->Emit("Conj", {dout});
|
dout = ib->Emit("Conj", {dout});
|
||||||
dx = ib->Emit("TanhGrad", {out, dout});
|
dx = ib->Emit("TanhGrad", {out, dout});
|
||||||
dx = ib->Emit("Conj", {dx});
|
dx = ib->Emit("Conj", {dx});
|
||||||
|
|
Loading…
Reference in New Issue