forked from mindspore-Ecosystem/mindspore
!18257 ops clean code master
Merge pull request !18257 from wanyiming/clean_ops_train
This commit is contained in:
commit
7c1ea7e3fa
|
@ -33,8 +33,8 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
|
|||
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
|
||||
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
|
||||
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
|
||||
CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1,
|
||||
primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1,
|
||||
primitive->name());
|
||||
primitive->AddAttr("n", MakeValue(SizeToLong(elements.size())));
|
||||
auto shape_0 = elements[0]->BuildShape();
|
||||
auto element0_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape_0);
|
||||
|
|
|
@ -33,14 +33,14 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto input_tuple = input_args[0]->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_tuple);
|
||||
auto elements = input_tuple->elements();
|
||||
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
|
||||
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(element0);
|
||||
auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
|
||||
|
@ -56,8 +56,8 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
for (size_t i = 1; i < elements.size(); ++i) {
|
||||
std::string elementi = "element" + std::to_string(i);
|
||||
auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", SizeToLong(elementi_shape.size()), kEqual,
|
||||
element0_shape.size(), prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger(elementi + " shape rank", SizeToLong(elementi_shape.size()), kEqual,
|
||||
SizeToLong(element0_shape.size()), prim_name);
|
||||
for (size_t j = 0; j < element0_rank; ++j) {
|
||||
if (j != axis && elementi_shape[j] != element0_shape[j]) {
|
||||
MS_LOG(EXCEPTION) << "element " << i << " shape in input can not concat with first element.";
|
||||
|
|
|
@ -32,9 +32,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kGreaterEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kGreaterEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("min_shape", SizeToLong(min_shape.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("min_shape", SizeToLong(min_shape.size()), kEqual, 1, prim_name);
|
||||
int64_t shape_val = 1;
|
||||
for (size_t i = 0; i < in_shape.size(); i++) {
|
||||
shape_val = shape_val * in_shape[i];
|
||||
|
|
|
@ -30,8 +30,8 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", SizeToLong(input_args.size()), kEqual,
|
||||
2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", SizeToLong(input_args.size()),
|
||||
kEqual, 2, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
|
|
@ -37,7 +37,7 @@ AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
|
|
@ -32,7 +32,7 @@ AbstractBasePtr TopKInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("top_k_infer", input_args.size(), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("top_k_infer", input_args.size(), kEqual, 2, prim_name);
|
||||
|
||||
// Infer dtype
|
||||
auto output1_type = kInt32;
|
||||
|
|
Loading…
Reference in New Issue