!45147 [Expander] Add interfaces to get shape and dtype of Node.

Merge pull request !45147 from DeshiChen/1104_shapetype
This commit is contained in:
i-robot 2022-11-09 13:34:23 +00:00 committed by Gitee
commit ed0266ee25
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 216 additions and 93 deletions

View File

@ -21,19 +21,13 @@
#include <limits> #include <limits>
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
namespace mindspore { namespace mindspore {
namespace expander { namespace expander {
namespace bprop { namespace bprop {
namespace { namespace {
constexpr size_t kMaxDims = 8; constexpr size_t kMaxDims = 8;
// Normalizes an index against a dimension of size `dim_size`.
// Accepts indices in [-dim_size, dim_size); any other value raises IndexError.
// Returns the equivalent non-negative index (negative indices count from the end).
int64_t CheckRange(int64_t idx, int64_t dim_size) {
  const bool in_bounds = (idx >= -dim_size) && (idx < dim_size);
  if (!in_bounds) {
    MS_EXCEPTION(IndexError) << "index {" << idx << "} is out of bounds for dimension with size {" << dim_size << "}";
  }
  if (idx >= 0) {
    return idx;
  }
  return idx + dim_size;
}
} // namespace } // namespace
bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, CNodePtrList *outputs) { bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, CNodePtrList *outputs) {
@ -72,64 +66,6 @@ NodePtr BpropIRBuilder::GetInput(size_t i) const {
return (*inputs_ptr_)[i]; return (*inputs_ptr_)[i];
} }
// Returns the output shape of a single-output node.
// Raises when the node has no abstract, when the abstract yields no shape, or when the
// shape is a sequence (tuple) shape. Any other non-tensor shape kind yields an empty vector.
ShapeVector BpropIRBuilder::GetShape(const NodePtr &node) const {
  auto node_abs = node->get()->abstract();
  MS_EXCEPTION_IF_NULL(node_abs);
  auto base_shape = node_abs->BuildShape();
  MS_EXCEPTION_IF_NULL(base_shape);
  // Plain tensor shape: hand back its dimensions.
  if (base_shape->isa<abstract::Shape>()) {
    return base_shape->cast<abstract::ShapePtr>()->shape();
  }
  // Multi-output nodes are not supported by this accessor.
  if (base_shape->isa<abstract::SequenceShape>()) {
    MS_LOG(EXCEPTION) << "The output of node " << node->get()->ToString() << " is a tuple.";
  }
  return {};
}
std::vector<ShapeVector> BpropIRBuilder::GetShapes(const NodePtr &node) const {
auto abs = node->get()->abstract();
MS_EXCEPTION_IF_NULL(abs);
auto shape = abs->BuildShape();
MS_EXCEPTION_IF_NULL(shape);
if (shape->isa<abstract::SequenceShape>()) {
auto seq_shape_ptr = shape->cast<abstract::SequenceShapePtr>();
MS_EXCEPTION_IF_NULL(seq_shape_ptr);
const auto &shape_list = seq_shape_ptr->shape();
if (shape_list.empty()) {
return {};
}
std::vector<ShapeVector> res;
res.reserve(shape_list.size());
for (const auto &item : shape_list) {
MS_EXCEPTION_IF_NULL(item);
if (item->isa<abstract::NoShape>()) {
res.push_back({});
} else if (!item->isa<abstract::Shape>()) {
MS_LOG(EXCEPTION) << "Invalid Shape Type(" << item->ToString() << ") In Shape List";
}
auto shape_ptr = item->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_ptr);
res.push_back(shape_ptr->shape());
}
return res;
} else {
MS_LOG(EXCEPTION) << "The output of node " << node->get()->ToString() << " is not a tuple.";
}
return {};
}
// Returns the data type of a single-output node.
// For a tensor type the element type is returned; a tuple type raises; any other type is
// returned unchanged. Raises when the node has no abstract or the abstract yields no type.
TypePtr BpropIRBuilder::GetDtype(const NodePtr &node) const {
  auto node_abs = node->get()->abstract();
  MS_EXCEPTION_IF_NULL(node_abs);
  auto built_type = node_abs->BuildType();
  MS_EXCEPTION_IF_NULL(built_type);
  // Tensor: unwrap to the element type.
  if (built_type->isa<TensorType>()) {
    return built_type->cast<TensorTypePtr>()->element();
  }
  // Multi-output nodes are not supported by this accessor.
  if (built_type->isa<Tuple>()) {
    MS_LOG(EXCEPTION) << "The output of node " << node->get()->ToString() << " is a tuple.";
  }
  return built_type;
}
ValuePtr BpropIRBuilder::GetAttr(const NodePtr &node, const std::string &attr) const { ValuePtr BpropIRBuilder::GetAttr(const NodePtr &node, const std::string &attr) const {
auto p = GetCNodePrimitive(node->get()); auto p = GetCNodePrimitive(node->get());
MS_EXCEPTION_IF_NULL(p); MS_EXCEPTION_IF_NULL(p);

View File

@ -47,10 +47,10 @@ class BpropIRBuilder : public Emitter {
const NodePtrList &GetInputs() const { return *inputs_ptr_; } const NodePtrList &GetInputs() const { return *inputs_ptr_; }
// For node that has single output // For node that has single output
ShapeVector GetShape(const NodePtr &node) const; ShapeVector GetShape(const NodePtr &node) const { return node->shape(); }
// For node that has multiple outputs // For node that has multiple outputs
std::vector<ShapeVector> GetShapes(const NodePtr &node) const; std::vector<ShapeVector> GetShapes(const NodePtr &node) const { return node->shapes(); }
TypePtr GetDtype(const NodePtr &node) const; TypePtr GetDtype(const NodePtr &node) const { return node->dtype(); }
TypeId GetDtypeId(const NodePtr &node) const { return GetDtype(node)->type_id(); } TypeId GetDtypeId(const NodePtr &node) const { return GetDtype(node)->type_id(); }
ValuePtr GetAttr(const NodePtr &node, const std::string &attr) const; ValuePtr GetAttr(const NodePtr &node, const std::string &attr) const;
int64_t GetSize(const NodePtr &node) const; int64_t GetSize(const NodePtr &node) const;

View File

@ -31,6 +31,10 @@ namespace mindspore::expander::bprop {
namespace { namespace {
NodePtr ReduceSumWithCast(const BpropIRBuilder *ib, const NodePtr &dx, const std::vector<int64_t> &axis) { NodePtr ReduceSumWithCast(const BpropIRBuilder *ib, const NodePtr &dx, const std::vector<int64_t> &axis) {
auto dx_origin_dtypeptr = ib->GetDtype(dx); auto dx_origin_dtypeptr = ib->GetDtype(dx);
auto need_reduce = ib->NeedReduce(ib->GetShape(dx), axis, false);
if (!need_reduce.first) {
return ib->Reshape(dx, need_reduce.second);
}
auto dx_origin_dtype = dx_origin_dtypeptr->type_id(); auto dx_origin_dtype = dx_origin_dtypeptr->type_id();
if (dx_origin_dtype == TypeId::kNumberTypeInt16 || dx_origin_dtype == TypeId::kNumberTypeInt32 || if (dx_origin_dtype == TypeId::kNumberTypeInt16 || dx_origin_dtype == TypeId::kNumberTypeInt32 ||
dx_origin_dtype == TypeId::kNumberTypeInt64) { dx_origin_dtype == TypeId::kNumberTypeInt64) {
@ -285,12 +289,12 @@ NodePtrList BinopGradCommonWithShift(const BpropIRBuilder *ib, const NodePtr &x,
} }
std::vector<int64_t> Range(int64_t start, int64_t stop, int64_t step) { std::vector<int64_t> Range(int64_t start, int64_t stop, int64_t step) {
auto size = (stop - start) / step; int64_t size = (step != 0) ? ((stop - start) / step) : 0;
if (size <= 0) { if (size <= 0) {
return {}; return {};
} }
size = ((stop - start) % step == 0) ? size : size + 1; size = ((stop - start) % step == 0) ? size : size + 1;
std::vector<int64_t> range(size); std::vector<int64_t> range(LongToSize(size));
std::generate(range.begin(), range.end(), [n = start - step, step]() mutable { std::generate(range.begin(), range.end(), [n = start - step, step]() mutable {
n = n + step; n = n + step;
return n; return n;
@ -314,6 +318,13 @@ std::vector<int64_t> GetTransposeAxis(const std::vector<int64_t> &x_shape, int64
return reverse_axis; return reverse_axis;
} }
// Validates `idx` against a dimension of size `dim_size` and converts it to a
// non-negative index. Indices outside [-dim_size, dim_size) raise IndexError.
int64_t CheckRange(int64_t idx, int64_t dim_size) {
  if (idx >= dim_size || idx < -dim_size) {
    MS_EXCEPTION(IndexError) << "index {" << idx << "} is out of bounds for dimension with size {" << dim_size << "}";
  }
  // Negative indices count from the end, so shift them by the dimension size.
  const int64_t wrap_offset = (idx < 0) ? dim_size : 0;
  return idx + wrap_offset;
}
NodePtr GetEps(const BpropIRBuilder *ib, const TypePtr &type) { NodePtr GetEps(const BpropIRBuilder *ib, const TypePtr &type) {
switch (type->type_id()) { switch (type->type_id()) {
case kNumberTypeFloat16: case kNumberTypeFloat16:
@ -336,7 +347,7 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
auto x_shp = ib->GetShape(x); auto x_shp = ib->GetShape(x);
auto out_shp = ib->GetShape(dout); auto out_shp = ib->GetShape(dout);
auto ind_shp = ib->GetShape(indices); auto ind_shp = ib->GetShape(indices);
auto axis_v = GetIntFromValueNode(axis); auto axis_v = CheckRange(GetIntFromValueNode(axis), SizeToLong(x_shp.size()));
if (out_shp.empty()) { if (out_shp.empty()) {
dout = ib->Emit("ExpandDims", {dout, ib->Tensor(-1)}); dout = ib->Emit("ExpandDims", {dout, ib->Tensor(-1)});
} }
@ -515,10 +526,7 @@ NodePtr MinOrMaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const std::vect
auto indicators = ib->Cast(ib->Emit("Equal", {y, x}), ib->GetDtype(grad)); auto indicators = ib->Cast(ib->Emit("Equal", {y, x}), ib->GetDtype(grad));
auto minn = 1e-24; auto minn = 1e-24;
auto min_num = ib->Tensor(minn, ib->GetDtype(grad)); auto min_num = ib->Tensor(minn, ib->GetDtype(grad));
auto num_selected = auto num_selected = ib->Reshape(ib->ReduceSum(indicators, axis, false), output_shape_kept_dims) + min_num;
ib->Reshape(ib->Emit("ReduceSum", {indicators, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}}),
output_shape_kept_dims) +
min_num;
return indicators / num_selected * grad; return indicators / num_selected * grad;
} }

View File

@ -38,6 +38,8 @@ std::vector<int64_t> ReduceShape(const std::vector<int64_t> &x, const std::vecto
std::vector<int64_t> GetAxisList(const ValuePtr &value); std::vector<int64_t> GetAxisList(const ValuePtr &value);
int64_t CheckRange(int64_t idx, int64_t dim_size);
NodePtrList BinopGradCommon(const BpropIRBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx, NodePtrList BinopGradCommon(const BpropIRBuilder *ib, const NodePtr &x, const NodePtr &y, const NodePtr &dx,
const NodePtr &dy); const NodePtr &dy);

View File

@ -17,6 +17,9 @@
#include "common/graph_kernel/bprop/expander/emitter.h" #include "common/graph_kernel/bprop/expander/emitter.h"
#include <algorithm> #include <algorithm>
#include <functional>
#include <unordered_set>
#include <utility>
#include "ops/primitive_c.h" #include "ops/primitive_c.h"
#include "utils/anf_utils.h" #include "utils/anf_utils.h"
@ -69,6 +72,28 @@ NodePtr Emitter::Log(const NodePtr &x) const {
{"cust_aicpu", MakeValue(kLogOpName)}}); {"cust_aicpu", MakeValue(kLogOpName)}});
} }
// Casts `node` to `type`.
// Returns `node` itself, without emitting anything, when its dtype already has the same
// type id as `type`; otherwise emits a "Cast" node and returns it.
// Raises when `node` or `type` is null.
NodePtr Emitter::Cast(const NodePtr &node, const TypePtr &type) const {
  // Guard the dereferences below, as Reshape/ReduceSum do for their inputs.
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(type);
  // do not emit a node when the dst type is the same as src type
  if (node->dtype()->type_id() == type->type_id()) {
    return node;
  }
  return Emit("Cast", {node, EmitValue(type)});
}
// Reshapes `node` to `shape`.
// No node is emitted (and `node` is returned as-is) when `shape` has the same rank as the
// node's current shape and every entry either equals the current dimension or is -1.
// Otherwise a Reshape node is emitted. Raises when `node` is null.
NodePtr Emitter::Reshape(const NodePtr &node, const ShapeVector &shape) const {
  MS_EXCEPTION_IF_NULL(node);
  const auto current_shape = node->shape();
  // A rank mismatch always requires a real reshape.
  bool is_noop = (shape.size() == current_shape.size());
  for (size_t dim = 0; is_noop && dim < shape.size(); ++dim) {
    is_noop = (shape[dim] == current_shape[dim]) || (shape[dim] == -1);
  }
  if (is_noop) {
    return node;
  }
  return Emit(prim::kReshape, {node, Value(shape)});
}
NodePtr Emitter::MatMul(const NodePtr &a, const NodePtr &b, bool transpose_a, bool transpose_b) const { NodePtr Emitter::MatMul(const NodePtr &a, const NodePtr &b, bool transpose_a, bool transpose_b) const {
return Emit(prim::kPrimMatMul->name(), {a, b}, return Emit(prim::kPrimMatMul->name(), {a, b},
{{"transpose_x1", MakeValue(transpose_a)}, {{"transpose_x1", MakeValue(transpose_a)},
@ -114,7 +139,54 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
return Emit(prim::kZerosLike, {node}); return Emit(prim::kZerosLike, {node});
} }
std::pair<bool, ShapeVector> Emitter::NeedReduce(const ShapeVector &shape, const std::vector<int64_t> &axis,
bool keep_dim) const {
if (shape.empty()) {
return std::make_pair(false, shape);
}
auto rank = SizeToLong(shape.size());
auto real_axis = axis;
if (real_axis.empty()) {
// all reduce
for (int64_t i = 0; i < rank; ++i) {
real_axis.push_back(i);
}
}
std::unordered_set<size_t> uniq_axis;
for (size_t i = 0; i < real_axis.size(); ++i) {
if (real_axis[i] < -rank || real_axis[i] >= rank) {
MS_EXCEPTION(ValueError) << "Reduce axis[" << i << "] is " << real_axis[i] << ", which is out of range [-" << rank
<< ", " << rank << ") for shape: " << shape;
}
auto axis_i = real_axis[i] < 0 ? real_axis[i] + rank : real_axis[i];
(void)uniq_axis.insert(LongToSize(axis_i));
}
// Calc reduce output shape
ShapeVector out_shape;
bool need_reduce = false;
for (size_t i = 0; i < shape.size(); ++i) {
if (uniq_axis.find(i) == uniq_axis.end()) {
// not reduce axis
out_shape.push_back(shape[i]);
} else {
// reduce axis
if (shape[i] != 1) {
need_reduce = true;
}
if (keep_dim) {
out_shape.push_back(1);
}
}
}
return std::make_pair(need_reduce, out_shape);
}
NodePtr Emitter::ReduceSum(const NodePtr &x, const ShapeVector &axis, bool keep_dims) const { NodePtr Emitter::ReduceSum(const NodePtr &x, const ShapeVector &axis, bool keep_dims) const {
MS_EXCEPTION_IF_NULL(x);
auto need_reduce = NeedReduce(x->shape(), axis, keep_dims);
if (!need_reduce.first) {
return Reshape(x, need_reduce.second);
}
return Emit(prim::kPrimReduceSum->name(), {x, Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(keep_dims)}}); return Emit(prim::kPrimReduceSum->name(), {x, Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(keep_dims)}});
} }

View File

@ -19,6 +19,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <string> #include <string>
#include <utility>
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ops/core_ops.h" #include "ops/core_ops.h"
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
@ -44,12 +45,10 @@ class Emitter {
return Emit(prim::kTupleGetItem, {input, Value(static_cast<int64_t>(i))}); return Emit(prim::kTupleGetItem, {input, Value(static_cast<int64_t>(i))});
} }
NodePtr Cast(const NodePtr &node, const TypePtr &type) const { return Emit("Cast", {node, EmitValue(type)}); } NodePtr Cast(const NodePtr &node, const TypePtr &type) const;
NodePtr Cast(const NodePtr &node, TypeId type_id) const { return Cast(node, TypeIdToType(type_id)); } NodePtr Cast(const NodePtr &node, TypeId type_id) const { return Cast(node, TypeIdToType(type_id)); }
NodePtr Reshape(const NodePtr &node, const ShapeVector &shape) const { NodePtr Reshape(const NodePtr &node, const ShapeVector &shape) const;
return Emit(prim::kReshape, {node, Tensor(shape)});
}
NodePtr ExpandDims(const NodePtr &node, int64_t axis) const { return Emit(kExpandDimsOpName, {node, Value(axis)}); } NodePtr ExpandDims(const NodePtr &node, int64_t axis) const { return Emit(kExpandDimsOpName, {node, Value(axis)}); }
NodePtr Abs(const NodePtr &node) const { return Emit(prim::kAbs, {node}); } NodePtr Abs(const NodePtr &node) const { return Emit(prim::kAbs, {node}); }
NodePtr Neg(const NodePtr &node) const { return Emit(prim::kNeg, {node}); } NodePtr Neg(const NodePtr &node) const { return Emit(prim::kNeg, {node}); }
@ -97,6 +96,8 @@ class Emitter {
} }
NodePtr LogicalAnd(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LogicalAnd", {lhs, rhs}); } NodePtr LogicalAnd(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LogicalAnd", {lhs, rhs}); }
NodePtr LogicalOr(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LogicalOr", {lhs, rhs}); } NodePtr LogicalOr(const NodePtr &lhs, const NodePtr &rhs) const { return Emit("LogicalOr", {lhs, rhs}); }
std::pair<bool, ShapeVector> NeedReduce(const ShapeVector &shape, const std::vector<int64_t> &axis,
bool keep_dim) const;
NodePtr ReduceSum(const NodePtr &x, const ShapeVector &axis = {}, bool keep_dims = false) const; NodePtr ReduceSum(const NodePtr &x, const ShapeVector &axis = {}, bool keep_dims = false) const;
NodePtr ZerosLike(const NodePtr &node) const; NodePtr ZerosLike(const NodePtr &node) const;
@ -120,6 +121,8 @@ class Emitter {
return EmitValue(tensor_ptr); return EmitValue(tensor_ptr);
} }
ExpanderInferPtr infer() const { return infer_; }
protected: protected:
NodePtr NewNode(const AnfNodePtr &anfnode) const { return std::make_shared<Node>(anfnode, this); } NodePtr NewNode(const AnfNodePtr &anfnode) const { return std::make_shared<Node>(anfnode, this); }
NodePtr CmpOpWithCast(const std::string &op, const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) const { NodePtr CmpOpWithCast(const std::string &op, const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) const {

View File

@ -63,5 +63,17 @@ void CppInfer::Infer(const NodePtr &node) {
} }
cnode->set_abstract(result); cnode->set_abstract(result);
} }
// Builds and returns the shape from the abstract attached to the wrapped anf node.
// Raises when the node has no abstract.
BaseShapePtr CppInfer::GetShape(const NodePtr &node) {
  auto anf_node = node->get();
  auto node_abs = anf_node->abstract();
  MS_EXCEPTION_IF_NULL(node_abs);
  auto built_shape = node_abs->BuildShape();
  return built_shape;
}
// Builds and returns the type from the abstract attached to the wrapped anf node.
// Raises when the node has no abstract.
TypePtr CppInfer::GetDtype(const NodePtr &node) {
  auto anf_node = node->get();
  auto node_abs = anf_node->abstract();
  MS_EXCEPTION_IF_NULL(node_abs);
  auto built_type = node_abs->BuildType();
  return built_type;
}
} // namespace expander } // namespace expander
} // namespace mindspore } // namespace mindspore

View File

@ -24,7 +24,11 @@ namespace expander {
/// \brief ExpanderInfer is the adapter for inferring functions that is called in emitter. /// \brief ExpanderInfer is the adapter for inferring functions that is called in emitter.
class ExpanderInfer { class ExpanderInfer {
public: public:
/// \brief Infer shape and dtype for node
virtual void Infer(const NodePtr &node) = 0; virtual void Infer(const NodePtr &node) = 0;
virtual BaseShapePtr GetShape(const NodePtr &node) = 0;
virtual TypePtr GetDtype(const NodePtr &node) = 0;
}; };
using ExpanderInferPtr = std::shared_ptr<ExpanderInfer>; using ExpanderInferPtr = std::shared_ptr<ExpanderInfer>;
@ -32,6 +36,8 @@ using ExpanderInferPtr = std::shared_ptr<ExpanderInfer>;
class CppInfer : public ExpanderInfer { class CppInfer : public ExpanderInfer {
public: public:
void Infer(const NodePtr &node) override; void Infer(const NodePtr &node) override;
BaseShapePtr GetShape(const NodePtr &node) override;
TypePtr GetDtype(const NodePtr &node) override;
}; };
} // namespace expander } // namespace expander
} // namespace mindspore } // namespace mindspore

View File

@ -15,6 +15,9 @@
*/ */
#include "common/graph_kernel/bprop/expander/node.h" #include "common/graph_kernel/bprop/expander/node.h"
#include <algorithm>
#include "common/graph_kernel/bprop/expander/emitter.h"
#include "common/graph_kernel/bprop/expander/infer.h"
namespace mindspore { namespace mindspore {
namespace expander { namespace expander {
@ -22,5 +25,64 @@ Node::Node(const AnfNodePtr &node, const Emitter *emitter) : anf_node_(node), em
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(emitter); MS_EXCEPTION_IF_NULL(emitter);
} }
// Returns the output shape of a single-output node.
// The shape object is fetched from the emitter's infer adapter on first use and cached in
// shape_. A NoShape yields an empty vector; a shape that cannot be cast to
// abstract::ShapePtr raises.
std::vector<int64_t> Node::shape() {
  if (shape_ == nullptr) {
    // First query: ask the infer adapter and remember the answer.
    shape_ = emitter()->infer()->GetShape(shared_from_this());
    MS_EXCEPTION_IF_NULL(shape_);
  }
  if (!shape_->isa<abstract::NoShape>()) {
    auto tensor_shape = shape_->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(tensor_shape);
    return tensor_shape->shape();
  }
  return {};
}
std::vector<std::vector<int64_t>> Node::shapes() {
if (shape_ == nullptr) {
shape_ = emitter()->infer()->GetShape(shared_from_this());
MS_EXCEPTION_IF_NULL(shape_);
}
auto tuple_shape = shape_->cast<abstract::SequenceShapePtr>();
MS_EXCEPTION_IF_NULL(tuple_shape);
auto &shape_list = tuple_shape->shape();
std::vector<ShapeVector> shapes(shape_list.size());
(void)std::transform(shape_list.cbegin(), shape_list.cend(), shapes.begin(), [](const BaseShapePtr &bs) {
if (bs->isa<abstract::NoShape>()) {
return ShapeVector();
}
auto shape = bs->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
return shape->shape();
});
return shapes;
}
// Returns the data type of a single-output node.
// The type is fetched from the emitter's infer adapter on first use and cached in type_;
// a tensor type is unwrapped to its element type before caching.
TypePtr Node::dtype() {
  if (type_ != nullptr) {
    // Already resolved by an earlier query.
    return type_;
  }
  type_ = emitter()->infer()->GetDtype(shared_from_this());
  MS_EXCEPTION_IF_NULL(type_);
  if (type_->isa<TensorType>()) {
    type_ = type_->cast<TensorTypePtr>()->element();
    MS_EXCEPTION_IF_NULL(type_);
  }
  return type_;
}
std::vector<TypePtr> Node::dtypes() {
if (type_ == nullptr) {
type_ = emitter()->infer()->GetDtype(shared_from_this());
MS_EXCEPTION_IF_NULL(type_);
}
auto tuple = type_->cast<TuplePtr>();
MS_EXCEPTION_IF_NULL(tuple);
std::vector<TypePtr> result(tuple->size());
auto elements = tuple->elements();
(void)std::transform(elements.cbegin(), elements.cend(), result.begin(),
[](const TypePtr &t) { return t->isa<TensorType>() ? t->cast<TensorTypePtr>()->element() : t; });
return result;
}
} // namespace expander } // namespace expander
} // namespace mindspore } // namespace mindspore

View File

@ -20,6 +20,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/dtype.h"
namespace mindspore { namespace mindspore {
namespace expander { namespace expander {
@ -42,12 +43,24 @@ class Node : public std::enable_shared_from_this<Node> {
return anf_node_->cast<T>(); return anf_node_->cast<T>();
} }
std::vector<int64_t> shape();
std::vector<std::vector<int64_t>> shapes();
TypePtr dtype();
std::vector<TypePtr> dtypes();
const Emitter *emitter() const { return emitter_; } const Emitter *emitter() const { return emitter_; }
protected: protected:
// the wrapped anfnode.
AnfNodePtr anf_node_{nullptr}; AnfNodePtr anf_node_{nullptr};
// hold an emitter for operator overloading. // hold the emitter who created this node.
const Emitter *emitter_{nullptr}; const Emitter *emitter_{nullptr};
// cache the output shape after first query
BaseShapePtr shape_{nullptr};
// cache the output dtype after first query
TypePtr type_{nullptr};
}; };
using NodePtr = std::shared_ptr<Node>; using NodePtr = std::shared_ptr<Node>;
using NodePtrList = std::vector<NodePtr>; using NodePtrList = std::vector<NodePtr>;

View File

@ -37,7 +37,7 @@ NodePtrList GatherDropNegatives(const BpropIRBuilder *ib, const NodePtr &params,
for (size_t i = 0; i < back_size; ++i) { for (size_t i = 0; i < back_size; ++i) {
broadcastable_shape.push_back(1); broadcastable_shape.push_back(1);
} }
is_positive = ib->Emit("Reshape", {is_positive, ib->Value<ShapeVector>(broadcastable_shape)}); is_positive = ib->Reshape(is_positive, broadcastable_shape);
auto gathered_shape = ib->GetShape(gathered); auto gathered_shape = ib->GetShape(gathered);
is_positive = ib->Emit("LogicalAnd", is_positive = ib->Emit("LogicalAnd",
{is_positive, ib->Emit("Fill", {ib->EmitValue(kBool), ib->Value<ShapeVector>(gathered_shape), {is_positive, ib->Emit("Fill", {ib->EmitValue(kBool), ib->Value<ShapeVector>(gathered_shape),
@ -58,8 +58,8 @@ NodePtrList UnsortedSegmentMinOrMaxGrad(const BpropIRBuilder *ib, const NodePtr
auto tmp = ib->Emit("Equal", {x, gathered_outputs}); auto tmp = ib->Emit("Equal", {x, gathered_outputs});
auto is_selected = ib->Emit("LogicalAnd", {tmp, is_positive}); auto is_selected = ib->Emit("LogicalAnd", {tmp, is_positive});
auto num_selected = ib->Emit( auto num_selected =
"UnsortedSegmentSum", {ib->Emit("Cast", {is_selected, ib->Value(ib->GetDtype(dout))}), segment_ids, num_segments}); ib->Emit("UnsortedSegmentSum", {ib->Cast(is_selected, ib->GetDtype(dout)), segment_ids, num_segments});
auto weighted_grads = ib->Emit("RealDiv", {dout, num_selected}); auto weighted_grads = ib->Emit("RealDiv", {dout, num_selected});
auto temp_outs_2 = GatherDropNegatives(ib, weighted_grads, nullptr, zero_clipped_indices, is_positive); auto temp_outs_2 = GatherDropNegatives(ib, weighted_grads, nullptr, zero_clipped_indices, is_positive);
MS_EXCEPTION_IF_CHECK_FAIL(temp_outs.size() > 0, "Outputs should not be empty."); MS_EXCEPTION_IF_CHECK_FAIL(temp_outs.size() > 0, "Outputs should not be empty.");
@ -158,7 +158,7 @@ REG_BPROP_BUILDER("SparseGatherV2").SetBody([](const BpropIRBuilder *ib) -> Node
auto axis = ib->GetInput(kIndex2); auto axis = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex4); auto dout = ib->GetInput(kIndex4);
auto x_shp = ib->GetShape(x); auto x_shp = ib->GetShape(x);
auto axis_int = GetValue<int64_t>(axis->get<ValueNodePtr>()->value()); auto axis_int = CheckRange(GetValue<int64_t>(axis->get<ValueNodePtr>()->value()), SizeToLong(x_shp.size()));
if (axis_int == 0) { if (axis_int == 0) {
ShapeVector values_shape{ib->GetSize(indices)}; ShapeVector values_shape{ib->GetSize(indices)};
if (x_shp.size() > 1) { if (x_shp.size() > 1) {
@ -730,7 +730,7 @@ REG_BPROP_BUILDER("BroadcastTo").SetBody([](const BpropIRBuilder *ib) -> NodePtr
auto tuple_out = BroadcastGradientArgs(broadcast_shape, x_shape); auto tuple_out = BroadcastGradientArgs(broadcast_shape, x_shape);
MS_EXCEPTION_IF_CHECK_FAIL(!tuple_out.empty(), "BroadcastGradientArgs out should not be empty!"); MS_EXCEPTION_IF_CHECK_FAIL(!tuple_out.empty(), "BroadcastGradientArgs out should not be empty!");
auto reduction_axes = tuple_out[kIndex1]; auto reduction_axes = tuple_out[kIndex1];
auto reduced_grad = ib->Emit("ReduceSum", {dout, ib->Value(reduction_axes)}, {{"keep_dims", MakeValue(true)}}); auto reduced_grad = ib->ReduceSum(dout, reduction_axes, true);
auto dx = ib->Reshape(reduced_grad, x_shape); auto dx = ib->Reshape(reduced_grad, x_shape);
return {dx}; return {dx};
}); });
@ -846,12 +846,17 @@ REG_BPROP_BUILDER("Tile").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
auto dout_reshaped = ib->Reshape(dout, r_shape); auto dout_reshaped = ib->Reshape(dout, r_shape);
auto dout_dtype = ib->GetDtype(dout_reshaped)->type_id(); auto dout_dtype = ib->GetDtype(dout_reshaped)->type_id();
NodePtr dx; NodePtr dx;
if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) { auto need_reduce = ib->NeedReduce(r_shape, axis, false);
dout_reshaped = ib->Cast(dout_reshaped, kFloat32); if (need_reduce.first) {
dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}}); if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
dx = ib->Cast(dx, dout_dtype); dout_reshaped = ib->Cast(dout_reshaped, kFloat32);
dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}});
dx = ib->Cast(dx, dout_dtype);
} else {
dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}});
}
} else { } else {
dx = ib->Emit("ReduceSum", {dout_reshaped, ib->Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(false)}}); dx = ib->Reshape(dout_reshaped, need_reduce.second);
} }
dx = ib->Reshape(dx, shapex); dx = ib->Reshape(dx, shapex);
return {dx, ib->ZerosLike(input_multiples)}; return {dx, ib->ZerosLike(input_multiples)};

View File

@ -914,7 +914,7 @@ REG_BPROP_BUILDER("ReduceProd").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
auto dout = ib->GetInput(kIndex3); auto dout = ib->GetInput(kIndex3);
auto input_shape = ib->GetShape(x); auto input_shape = ib->GetShape(x);
auto output_shape_kept_dims = ReduceShape(input_shape, GetAxisValue(axis)); auto output_shape_kept_dims = ReduceShape(input_shape, GetAxisValue(axis));
dout = ib->Emit("Reshape", {dout, ib->Value<ShapeVector>(output_shape_kept_dims)}); dout = ib->Reshape(dout, output_shape_kept_dims);
auto tile_scaling = TupleDiv(input_shape, output_shape_kept_dims); auto tile_scaling = TupleDiv(input_shape, output_shape_kept_dims);
auto grad = ib->Emit("Tile", {dout, ib->Value<ShapeVector>(tile_scaling)}); auto grad = ib->Emit("Tile", {dout, ib->Value<ShapeVector>(tile_scaling)});
auto [pack_shape, perm] = SplitShapeIndex(input_shape, GetAxisValue(axis)); auto [pack_shape, perm] = SplitShapeIndex(input_shape, GetAxisValue(axis));
@ -964,7 +964,11 @@ REG_BPROP_BUILDER("ReduceMean").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
} }
return size; return size;
}; };
auto div_shape = getSize(shape_x) / getSize(shape_out); auto shape_out_sz = getSize(shape_out);
if (shape_out_sz == 0) {
MS_EXCEPTION(ValueError) << "out shape size can not be 0";
}
auto div_shape = getSize(shape_x) / shape_out_sz;
auto dx = ib->RealDiv(grad, ib->Tensor(div_shape, ib->GetDtype(grad))); auto dx = ib->RealDiv(grad, ib->Tensor(div_shape, ib->GetDtype(grad)));
return {dx, ib->ZerosLike(axis)}; return {dx, ib->ZerosLike(axis)};
}); });

View File

@ -899,7 +899,7 @@ REG_BPROP_BUILDER("Tanh").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto x_dtype_id = ib->GetDtypeId(x); auto x_dtype_id = ib->GetDtypeId(x);
NodePtr dx; NodePtr dx;
if (x_dtype_id == 46 || x_dtype_id == 47) { if (x_dtype_id == kNumberTypeComplex64 || x_dtype_id == kNumberTypeComplex128) {
dout = ib->Emit("Conj", {dout}); dout = ib->Emit("Conj", {dout});
dx = ib->Emit("TanhGrad", {out, dout}); dx = ib->Emit("TanhGrad", {out, dout});
dx = ib->Emit("Conj", {dx}); dx = ib->Emit("Conj", {dx});