!42469 update grad for dynamic_shape

Merge pull request !42469 from chengbin/ds_r1.9
i-robot 2022-09-21 12:51:23 +00:00 committed by Gitee
commit 520dfd5c6e
94 changed files with 2696 additions and 552 deletions

View File

@@ -17,38 +17,12 @@
#include "backend/common/pass/reduce_sum_optimizer.h"
#include <vector>
#include "include/common/utils/anfalgo.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace opt {
namespace {
const int axis_input_index = 2;
bool IsNeedComputeRank(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto axis_input = cnode->input(axis_input_index);
MS_EXCEPTION_IF_NULL(axis_input);
if (!IsValueNode<ValueTuple>(axis_input)) {
return false;
}
auto value_node = axis_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
if (value_tuple->value().empty()) {
return true;
}
for (auto &iter : value_tuple->value()) {
auto item = GetValue<int64_t>(iter->cast<ScalarPtr>());
if (item < 0) {
return true;
}
}
}
return false;
}
} // namespace
AnfNodePtr ReduceSumOptimizer::NewRankOp(const AnfNodePtr &cnode, const KernelGraphPtr &kernel_graph) const {
@@ -168,9 +142,6 @@ const AnfNodePtr ReduceSumOptimizer::Process(const FuncGraphPtr &func_graph, con
common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);
if (AnfUtils::IsDimUnknown(cnode) && IsNeedComputeRank(cnode)) {
return InsertAssistNode(cnode, kernel_graph);
}
return NewAssistValueNode(cnode, kernel_graph);
}
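
The deleted helper is worth a gloss: it requested a runtime Rank computation only when ReduceSum's axis input was a constant tuple that was either empty (reduce over all dimensions) or contained a negative index. A minimal pure-Python sketch of that predicate, with hypothetical names (the real pass walks ANF value nodes):

# Pure-Python sketch of the removed IsNeedComputeRank predicate.
def need_compute_rank(axis):
    """Rank is needed at runtime when the constant axis is an empty tuple
    (reduce all dims) or contains a negative index to normalize."""
    if not isinstance(axis, tuple):
        return False  # non-constant axis is handled by another branch
    if len(axis) == 0:
        return True
    return any(a < 0 for a in axis)

assert need_compute_rank(()) is True
assert need_compute_rank((0, 1)) is False
assert need_compute_rank((-1,)) is True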

View File

@@ -28,7 +28,12 @@ constexpr size_t kBiasAddGradOutputsNum = 1;
void BiasAddGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = Convert2SizeTClipNeg(AnfAlgo::GetInputDeviceShape(kernel_node, 0));
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (IsDynamic(shape)) {
return;
}
input_shape_ = Convert2SizeT(shape);
if (input_shape_.size() < 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input tensor's dimension must be at least 2, but got "
<< input_shape_.size();
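
This is the recurring CPU-kernel pattern in this commit: if the device shape still contains an unknown dimension at InitKernel time, skip the shape-dependent setup and let a later resize redo it with real values. A minimal sketch, assuming MindSpore's convention that -1 encodes an unknown dimension:

# Sketch of the InitKernel guard (hypothetical helper names).
def is_dynamic(shape):
    return any(d < 0 for d in shape)

def init_kernel(shape):
    if is_dynamic(shape):
        return None  # defer: real sizes arrive at resize time
    if len(shape) < 2:
        raise ValueError(f"input tensor's dimension must be at least 2, but got {len(shape)}")
    return [max(d, 0) for d in shape]  # Convert2SizeT analogue

assert init_kernel([-1, 16]) is None
assert init_kernel([8, 16]) == [8, 16]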

View File

@@ -97,6 +97,7 @@ void TopKCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimension of input must be greater than 0, but got empty input.";
}
outer_size_ = 1;
for (size_t i = 0; i < x_shape_.size() - 1; ++i) {
outer_size_ *= x_shape_[i];
}
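
The added outer_size_ = 1 matters because InitKernel can run once per step under dynamic shape; a member that is only multiplied into would keep the previous launch's value. A small sketch of the failure mode:

# Why the reset matters for a re-entrant InitKernel/Resize.
class TopKSizes:
    def __init__(self):
        self.outer_size = 1

    def resize(self, x_shape):
        self.outer_size = 1  # the fix: reset before accumulating
        for d in x_shape[:-1]:
            self.outer_size *= d
        return self.outer_size

sizes = TopKSizes()
assert sizes.resize([2, 3, 8]) == 6
assert sizes.resize([4, 8]) == 4  # without the reset this would be 24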

View File

@@ -50,9 +50,9 @@ int LinSpaceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto input1_shape = Convert2SizeTClipNeg(inputs[kIndex0]->GetShapeVector());
auto input2_shape = Convert2SizeTClipNeg(inputs[kIndex1]->GetShapeVector());
auto output_shape = Convert2SizeTClipNeg(outputs[kIndex0]->GetShapeVector());
auto input1_shape = inputs[kIndex0]->GetShapeVector();
auto input2_shape = inputs[kIndex1]->GetShapeVector();
auto output_shape = outputs[kIndex0]->GetShapeVector();
is_null_input_ = CHECK_SHAPE_NULL(input1_shape, kernel_name_, "start") ||
CHECK_SHAPE_NULL(input2_shape, kernel_name_, "stop") ||
CHECK_SHAPE_NULL(output_shape, kernel_name_, "output");
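
Dropping Convert2SizeTClipNeg here is deliberate: clipping -1 to 0 makes a still-unknown dimension look like a genuinely empty tensor, so CHECK_SHAPE_NULL would fire on valid dynamic inputs. A sketch of the confusion, with hypothetical helper names:

# Clipping -1 to 0 before the null check conflates "unknown" with "empty".
def clip_neg(shape):  # Convert2SizeTClipNeg analogue
    return [max(d, 0) for d in shape]

def looks_null(shape):  # CHECK_SHAPE_NULL analogue
    return any(d == 0 for d in shape)

dyn = [-1]  # dynamic length, known rank
assert looks_null(clip_neg(dyn))  # wrongly flagged as null after clipping
assert not looks_null(dyn)  # correct when the raw signed shape is kept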

View File

@@ -43,18 +43,22 @@ bool SvdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vect
int SvdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto input_shape = inputs[kIndex0]->GetShapeVector();
if (IsDynamicRank(input_shape)) {
return KRET_OK;
}
DestroyResource();
ResetResource();
input_shape_ = std::vector<size_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
input_shape_ = Convert2SizeTClipNeg(input_shape);
total_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies<size_t>());
is_null_input_ = (total_size_ == 0);
if (is_null_input_) {
init_size_lists_func_(this);
return 0;
return KRET_OK;
}
dims_ = input_shape_.size();
if (dims_ < kDim2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dimensions must >= 2, but got [" << dims_;
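
The Resize above follows the common GPU pattern in this commit: call the base KernelMod::Resize first, return KRET_OK early while the rank is still unknown, and only then rebuild cached sizes. A pure-Python sketch, assuming the usual encodings (-1 unknown dim, [-2] unknown rank, KRET_OK == 0):

# Sketch of the Resize flow under dynamic shape.
KRET_OK = 0

def svd_resize(shape):
    if shape == [-2]:  # dynamic rank: nothing can be computed yet
        return KRET_OK, None
    sizes = [max(d, 0) for d in shape]  # Convert2SizeTClipNeg analogue
    total = 1
    for d in sizes:
        total *= d
    if total == 0:  # null input is still a successful resize
        return KRET_OK, sizes
    if len(sizes) < 2:
        raise ValueError(f"dimensions must be >= 2, but got {len(sizes)}")
    return KRET_OK, sizes

assert svd_resize([-2]) == (KRET_OK, None)
assert svd_resize([4, 3]) == (KRET_OK, [4, 3])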

View File

@@ -26,5 +26,19 @@ MS_REG_GPU_KERNEL_ONE(
Conv3DBackpropFilter,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
Conv3dGradFilterGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(Conv3DBackpropFilter,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
Conv3dGradFilterGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Conv3DBackpropFilter,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
Conv3dGradFilterGpuKernelMod, half)
} // namespace kernel
} // namespace mindspore
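
The two new registrations take a third int64 input: under dynamic shape the filter size arrives as a runtime tensor instead of an attribute, which is also why the input-count check below (and its Conv3DBackpropInput mirror) now accepts 2 or 3 inputs. A one-line sketch of the relaxed check:

# Sketch: 2 static inputs, or 3 when the size comes in as an int64 tensor.
K_STATIC_INPUT_NUM, K_DYNAMIC_INPUT_NUM = 2, 3

def check_input_num(input_num, kernel_name="Conv3DBackpropFilter"):
    if input_num not in (K_STATIC_INPUT_NUM, K_DYNAMIC_INPUT_NUM):
        raise ValueError(f"For '{kernel_name}', the number of inputs must be 2 or 3, but got {input_num}")

check_input_num(3)  # dynamic-shape call site: x, dout, filter_size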

View File

@@ -31,6 +31,9 @@
namespace mindspore {
namespace kernel {
constexpr int kStaticInputNum = 2;
constexpr int kDynamicInputNum = 3;
constexpr size_t kInputDimSize = 5;
constexpr size_t kInDimIdxForN = 0;
constexpr size_t kInDimIdxForC = 1;
@@ -127,8 +130,8 @@ class Conv3dGradFilterGpuKernelMod : public NativeGpuKernelMod {
InitResource();
size_t input_num = inputs.size();
if (input_num != 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2, but got " << input_num;
if (input_num != kStaticInputNum && input_num != kDynamicInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
}
size_t output_num = outputs.size();
if (output_num != 1) {

View File

@@ -26,5 +26,19 @@ MS_REG_GPU_KERNEL_ONE(
Conv3DBackpropInput,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
Conv3dGradInputGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(Conv3DBackpropInput,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
Conv3dGradInputGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Conv3DBackpropInput,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
Conv3dGradInputGpuKernelMod, half)
} // namespace kernel
} // namespace mindspore

View File

@@ -32,7 +32,8 @@ namespace mindspore {
namespace kernel {
constexpr int kNumDims = 5;
constexpr int kConvDims = 3;
constexpr int kInputNum = 2;
constexpr int kStaticInputNum = 2;
constexpr int kDynamicInputNum = 3;
constexpr size_t kInDimIdxForN = 0;
constexpr size_t kInDimIdxForC = 1;
constexpr size_t kInDimIdxForD = 2;
@@ -115,8 +116,8 @@ class Conv3dGradInputGpuKernelMod : public NativeGpuKernelMod {
InitResource();
size_t input_num = inputs.size();
if (input_num != kInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2, but got " << input_num;
if (input_num != kStaticInputNum && input_num != kDynamicInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
}
size_t output_num = outputs.size();
if (output_num != 1) {

View File

@@ -37,6 +37,7 @@ void InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t i, con
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i)};
auto cast = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cast);
common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(cast_type), cast);
auto cast_shape = {common::AnfAlgo::GetPrevNodeOutputDetailShape(node, i)};
common::AnfAlgo::SetOutputTypeAndDetailShape({cast_type}, cast_shape, cast.get());
FuncGraphManagerPtr manager = graph->manager();

View File

@@ -221,10 +221,20 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p
MS_EXCEPTION_IF_NULL(y->shape());
auto x_shp = x->shape()->shape();
auto y_shp = y->shape()->shape();
const size_t SHAPE_SIZE = 2;
if (x_shp.size() != SHAPE_SIZE || y_shp.size() != SHAPE_SIZE) {
MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
TypePtr x_type = x->element()->GetTypeTrack();
if (x_type->type_id() == TypeId::kNumberTypeInt8) {
x_type = kInt32;
}
if (primitive->HasAttr("cast_type")) {
auto out_type = primitive->GetAttr("cast_type");
MS_EXCEPTION_IF_NULL(out_type);
if (!out_type->isa<Type>()) {
MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
}
x_type = out_type->cast<TypePtr>();
}
ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
bool transpose_a = GetValue<bool>(transpose_a_ptr);
@@ -233,18 +243,23 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p
ShapeVector x_max_shape = x->shape()->max_shape();
ShapeVector y_min_shape = y->shape()->min_shape();
ShapeVector y_max_shape = y->shape()->max_shape();
// Additional check for dynamic shape
// Last infer will be real shape values
bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
if (x_not_dyn && y_not_dyn) {
auto x_col = x_shp[(transpose_a ? 0 : 1)];
auto y_row = y_shp[(transpose_b ? 1 : 0)];
if (x_col != y_row) {
MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
<< ". In MatMul x_col and y_row should be equal.";
}
if (IsDynamicRank(x_shp) || IsDynamicRank(y_shp)) {
ShapeVector ret_shape{UNKNOWN_RANK};
return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape));
}
const size_t SHAPE_SIZE = 2;
if (x_shp.size() != SHAPE_SIZE || y_shp.size() != SHAPE_SIZE) {
MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
}
auto x_col = x_shp[(transpose_a ? 0 : 1)];
auto y_row = y_shp[(transpose_b ? 1 : 0)];
if (x_col != y_row && x_col >= 0 && y_row >= 0) {
MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
<< ". In MatMul x_col and y_row should be equal.";
}
ShapeVector ret_shape;
ShapeVector ret_min_shape;
ShapeVector ret_max_shape;
@@ -259,18 +274,6 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p
make_shape(ret_shape, x_shp, y_shp);
make_shape(ret_min_shape, x_min_shape, y_min_shape);
make_shape(ret_max_shape, x_max_shape, y_max_shape);
TypePtr x_type = x->element()->GetTypeTrack();
if (x_type->type_id() == TypeId::kNumberTypeInt8) {
x_type = kInt32;
}
if (primitive->HasAttr("cast_type")) {
auto out_type = primitive->GetAttr("cast_type");
MS_EXCEPTION_IF_NULL(out_type);
if (!out_type->isa<Type>()) {
MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
}
x_type = out_type->cast<TypePtr>();
}
return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
}
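
The reordering is the point of this hunk: the dtype logic (int8 promoted to int32, optional cast_type override) now runs before the dynamic-rank early return, so even an unknown-rank result carries the right element type, and the x_col/y_row check is skipped whenever either value is still unknown. A pure-Python sketch of the shape part, assuming the usual -1/-2 encodings:

# Sketch of the reworked MatMul shape inference.
UNKNOWN_RANK = -2

def infer_matmul_shape(x_shp, y_shp, transpose_a=False, transpose_b=False):
    if UNKNOWN_RANK in x_shp or UNKNOWN_RANK in y_shp:
        return [UNKNOWN_RANK]  # rank unknown: give up early
    if len(x_shp) != 2 or len(y_shp) != 2:
        raise ValueError("MatMul inputs must both be 2-dimensional")
    x_col = x_shp[0 if transpose_a else 1]
    y_row = y_shp[1 if transpose_b else 0]
    if x_col != y_row and x_col >= 0 and y_row >= 0:
        raise ValueError(f"x_col {x_col} and y_row {y_row} must be equal")
    return [x_shp[1 if transpose_a else 0], y_shp[0 if transpose_b else 1]]

assert infer_matmul_shape([-2], [4, 5]) == [UNKNOWN_RANK]
assert infer_matmul_shape([3, -1], [4, 5]) == [3, 5]  # unknown col: check skipped
assert infer_matmul_shape([3, 4], [4, 5]) == [3, 5]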

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -83,16 +83,15 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive,
if (IsDynamicRank(x_shp) || IsDynamicRank(y_shp)) {
return std::make_shared<abstract::Shape>(ShapeVector({UNKNOWN_RANK}));
}
auto context = MsContext::GetInstance();
constexpr size_t x_dim_limit = 3;
constexpr size_t y_dim_limit = 2;
if ((!IsDynamic(x_shp)) && (!IsDynamic(y_shp))) {
if (x_shp.size() < x_dim_limit || y_shp.size() < y_dim_limit) {
MS_EXCEPTION(ValueError)
<< "For '" << prim_name
<< "', input 'x' must be greater or equal to 3, input 'y' must be greater or equal to 2. But got 'x': "
<< x_shp.size() << ", 'y': " << y_shp.size() << ".";
}
bool not_dynamic_shape = (!IsDynamic(x_shp)) && !(IsDynamic(y_shp));
if (not_dynamic_shape && (x_shp.size() < x_dim_limit || y_shp.size() < y_dim_limit)) {
MS_EXCEPTION(ValueError)
<< "For '" << prim_name
<< "', input 'x' must be greater or equal to 3, input 'y' must be greater or equal to 2. But got 'x': "
<< x_shp.size() << ", 'y': " << y_shp.size() << ".";
}
constexpr size_t offset = 2;
@@ -106,7 +105,7 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive,
int64_t y_row = y_last[static_cast<size_t>(transpose_b)];
if (std::find(x_shp.begin(), x_shp.end(), -1) == x_shp.end() &&
std::find(y_shp.begin(), y_shp.end(), -1) == y_shp.end()) {
if (x_col != y_row) {
if (not_dynamic_shape && x_col != y_row) {
MS_EXCEPTION(ValueError) << "For " << prim_name << " evaluator shapes of inputs can not do this operator, "
<< "got " << x_col << " and " << y_row << " , with x1 shape " << x_shp
<< "(transpose_a=" << transpose_a << "})"
@@ -117,11 +116,7 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive,
(void)primitive->AddAttr("transpose_x2", transpose_b_ptr);
// Additional check for dynamic shape
// Last infer will be real shape values
bool x_not_dyn =
std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != abstract::Shape::SHP_ANY; });
bool y_not_dyn =
std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != abstract::Shape::SHP_ANY; });
if (x_not_dyn && y_not_dyn) {
if (not_dynamic_shape) {
size_t x_offset = x_shp.size() - offset;
size_t y_offset = y_shp.size() - offset;
auto x_c = x_shp[x_offset + (transpose_a ? 0 : 1)];

View File

@@ -40,11 +40,14 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::v
(void)CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto input_shape = shape_map[kShape];
auto bias_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
if (IsDynamicRank(input_shape) || IsDynamicRank(bias_shape)) {
return std::make_shared<abstract::Shape>(ShapeVector{UNKNOWN_RANK});
}
const int64_t x_min_rank = 2;
const int64_t x_max_rank = 5;
CheckAndConvertUtils::CheckInRange("dims of input_x", input_shape.size(), kIncludeBoth, {x_min_rank, x_max_rank},
prim_name);
auto bias_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("bias rank", SizeToLong(bias_shape.size()), kEqual, 1, prim_name);
const int64_t x_size = 2;
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, x_size, prim_name);

View File

@@ -52,7 +52,7 @@ void SetConv3DBackpropFilterPadList(const PrimitivePtr &primitive, const std::ve
CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
// default pad mode is valid
int64_t pad_mode;
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, false);
ShapeVector pad_list = {0, 0, 0, 0, 0, 0};
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
if ((attr_pad_list_prt != nullptr) && !attr_pad_list_prt->isa<None>()) {
@@ -261,6 +261,7 @@ class Conv3DBackpropFilterInfer : public abstract::OpInferBase {
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
}
std::set<int64_t> GetValueDependArgIndices() const override { return {kConv3DBackpropFilterFilterSizeIndex}; }
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(Conv3DBackpropFilter, prim::kPrimConv3DBackpropFilter, Conv3DBackpropFilterInfer,

View File

@@ -45,7 +45,7 @@ void SetConv3DBackpropInputPadList(const PrimitivePtr &primitive, const std::vec
CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
// default pad mode is valid
int64_t pad_mode;
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, false);
ShapeVector pad_list = {0, 0, 0, 0, 0, 0};
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
if ((attr_pad_list_prt != nullptr) && (!attr_pad_list_prt->isa<None>())) {
@@ -258,6 +258,7 @@ class Conv3DBackpropInputInfer : public abstract::OpInferBase {
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
}
std::set<int64_t> GetValueDependArgIndices() const override { return {kConv3DBackpropInputSizeIndex}; }
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(Conv3DBackpropInput, prim::kPrimConv3DBackpropInput, Conv3DBackpropInputInfer, false);

View File

@@ -121,6 +121,10 @@ abstract::ShapePtr MatrixDiagV3InferShape(const PrimitivePtr &primitive,
auto padding_value_rank = SizeToLong(padding_shape.size());
constexpr int64_t number_one = 1;
constexpr int64_t number_two = 2;
if (IsDynamicRank(x_shape)) {
ShapeVector out_shape = {UNKNOWN_RANK};
return std::make_shared<abstract::Shape>(out_shape);
}
CheckAndConvertUtils::CheckInRange<int64_t>("rank of 'k'", k_rank, kIncludeBoth, {0, number_one}, prim_name);
(void)CheckAndConvertUtils::CheckInteger("rank of 'num_rows'", row_rank, kEqual, 0, prim_name);
(void)CheckAndConvertUtils::CheckInteger("rank of 'num_cols'", col_rank, kEqual, 0, prim_name);

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -34,6 +34,14 @@ abstract::BaseShapePtr SvdInferShape(const PrimitivePtr &prim, const std::vector
auto full_matrices = GetValue<bool>(prim->GetAttr(kAttrFullMatrices));
auto a_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
if (IsDynamicRank(a_shape) || IsDynamic(a_shape)) {
ShapeVector dyn_shape{UNKNOWN_RANK};
std::vector<abstract::BaseShapePtr> shape_tuple;
(void)shape_tuple.emplace_back(std::make_shared<abstract::Shape>(dyn_shape));
(void)shape_tuple.emplace_back(std::make_shared<abstract::Shape>(dyn_shape));
(void)shape_tuple.emplace_back(std::make_shared<abstract::Shape>(dyn_shape));
return std::make_shared<abstract::TupleShape>(shape_tuple);
}
auto ndim = a_shape.size();
(void)CheckAndConvertUtils::CheckInteger("ndim", SizeToLong(ndim), kGreaterEqual, kSizeTwo, prim->name());
auto m = a_shape[ndim - kIndexTwo];

View File

@@ -36,16 +36,16 @@ abstract::TupleShapePtr TopKInferShape(const PrimitivePtr &primitive, const std:
auto prim_name = primitive->name();
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
auto x_shape = shape_map[kShape];
if (IsDynamicRank(x_shape)) {
abstract::BaseShapePtr out_shape_ptr = std::make_shared<abstract::Shape>(ShapeVector{UNKNOWN_RANK});
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape_ptr, out_shape_ptr});
}
int64_t k_v = 0;
// 2nd input is a Tensor when TopK is a dynamic shape operator
if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
auto k_ptr = input_args[kInputIndex1]->BuildValue();
MS_EXCEPTION_IF_NULL(k_ptr);
if (k_ptr->isa<tensor::Tensor>()) {
auto k_tensor_ptr = k_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(k_tensor_ptr);
k_v = *static_cast<int64_t *>(k_tensor_ptr->data_c());
}
k_v = CheckAndConvertUtils::CheckTensorIntValue("k", k_ptr, prim_name)[0];
} else if (input_args[kInputIndex1]->isa<abstract::AbstractScalar>()) {
k_v = GetValue<int64_t>(input_args[kInputIndex1]->BuildValue());
} else {
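
CheckTensorIntValue replaces a raw *static_cast<int64_t *>(data_c()) read, which silently assumed k was stored as int64 and would misread an int32 tensor. A hedged NumPy analogue of the dtype-aware extraction:

import numpy as np

# NumPy analogue of CheckTensorIntValue (hypothetical simplification).
def check_tensor_int_value(name, tensor):
    if tensor.dtype not in (np.int32, np.int64):
        raise TypeError(f"'{name}' must be an int32/int64 tensor")
    return [int(v) for v in tensor.reshape(-1)]

k = np.array(3, dtype=np.int32)
assert check_tensor_int_value("k", k)[0] == 3  # a raw int64 read would garble this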

View File

@@ -29,6 +29,7 @@ from mindspore.ops.primitive import constexpr
from mindspore.common import dtype as mstype
from mindspore.common.tensor import RowTensor
from mindspore.ops._utils.utils import range_op, get_1d_shape, generate_shape_index, is_shape_unknown
from .._grad.grad_base import dyn_rank, convert_to_tensor, dyn_invert_permutation, dyn_size, dyn_ones, dyn_fill
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
@@ -370,11 +371,20 @@ def _transpose_perm_positive(perm):
return tuple(res)
def _dyn_transpose_perm_positive(perm):
return (perm + dyn_size(perm)) % (dyn_size(perm))
@bprop_getters.register(P.Transpose)
def get_bprop_transpose(self):
"""Generate bprop for Transpose"""
def bprop(x, perm, out, dout):
is_mutable, perm = convert_to_tensor(perm)
if is_mutable:
perm = _dyn_transpose_perm_positive(perm)
return transpose(dout, dyn_invert_permutation(perm)), zeros_like(perm)
perm = _transpose_perm_positive(perm)
return transpose(dout, invert_permutation(perm)), zeros_like(perm)
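
Both branches share the same arithmetic: negative axes are normalized with (perm + rank) % rank, and dout is transposed by the inverse permutation; the dynamic branch merely does it with tensor ops (dyn_size, dyn_invert_permutation). A pure-Python check of the math:

# The perm normalization and inversion used by the Transpose bprop.
def perm_positive(perm):
    n = len(perm)
    return tuple((p + n) % n for p in perm)

def invert_permutation(perm):
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i  # output axis p came from input axis i
    return tuple(inv)

assert perm_positive((0, -1, 1)) == (0, 2, 1)
assert invert_permutation((0, 2, 1)) == (0, 2, 1)
assert invert_permutation((2, 0, 1)) == (1, 2, 0)
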
@@ -1019,18 +1029,36 @@ def _gather_drop_negatives(params,
select = P.Select()
if zero_clipped_indices is None:
zero_clipped_indices = maximum(ids, zeros_like(ids))
if is_shape_unknown(shape_op(ids)):
zero_ids = dyn_fill(ids.dtype, dyn_shape_op(ids), 0)
else:
zero_ids = zeros_like(ids)
zero_clipped_indices = maximum(ids, zero_ids)
gathered = gather(params, zero_clipped_indices, 0)
zero_slice = zeros_like(gathered)
if is_positive is None:
is_positive = greater_equal(ids, 0)
is_positive_shape = shape_op(is_positive)
broadcastable_shape = is_positive_shape
for _ in range(rank(gathered) - rank(is_positive)):
broadcastable_shape += (1,)
is_positive = reshape(is_positive, broadcastable_shape)
gathered_shape = shape_op(gathered)
is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
zero_slice = zeros_like(gathered)
if is_shape_unknown(gathered_shape) or is_shape_unknown(is_positive_shape):
gathered_shape = dyn_shape_op(gathered)
rank_gathered = dyn_rank(gathered)
fill_gathered = dyn_fill(mstype.int64, gathered_shape, 1)
is_positive_shape = dyn_shape_op(is_positive)
rank_positive = dyn_rank(is_positive)
if rank_gathered - rank_positive > 0:
padded_size = F.expand_dims(rank_gathered - rank_positive, 0)
padded_shape = dyn_ones(padded_size, is_positive_shape.dtype)
is_positive_shape = P.Concat(-1)((is_positive_shape, padded_shape))
is_positive = reshape(is_positive, is_positive_shape)
is_positive = logical_and(is_positive, F.cast(fill_gathered, mstype.bool_))
zero_slice = dyn_fill(gathered.dtype, gathered_shape, 0)
else:
broadcastable_shape = is_positive_shape
for _ in range(rank(gathered) - rank(is_positive)):
broadcastable_shape += (1,)
is_positive = reshape(is_positive, broadcastable_shape)
is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
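
The dynamic branch reproduces what the static tuple code did: pad is_positive's shape with trailing 1s up to gathered's rank so the boolean mask broadcasts, then select. A NumPy sketch of the whole helper:

import numpy as np

# NumPy sketch of _gather_drop_negatives (static and dynamic paths agree).
def gather_drop_negatives(params, ids):
    zero_clipped = np.maximum(ids, np.zeros_like(ids))
    gathered = params[zero_clipped]
    is_positive = ids >= 0
    pad = gathered.ndim - is_positive.ndim
    is_positive = is_positive.reshape(is_positive.shape + (1,) * pad)
    is_positive = np.logical_and(is_positive, np.ones(gathered.shape, bool))
    return np.where(is_positive, gathered, np.zeros_like(gathered))

params = np.arange(6.0).reshape(3, 2)
out = gather_drop_negatives(params, np.array([2, -1]))
assert (out == np.array([[4.0, 5.0], [0.0, 0.0]])).all()  # row for id -1 zeroed
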
@@ -1059,7 +1087,13 @@ def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""
def bprop(x, segment_ids, num_segments, out, dout):
return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_like(segment_ids), \
segment_shape = shape_op(segment_ids)
if is_shape_unknown(segment_shape):
segment_shape = dyn_shape_op(segment_ids)
zeros_segment = dyn_fill(segment_ids.dtype, segment_shape, 0)
else:
zeros_segment = zeros_like(segment_ids)
return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_segment, \
zeros_like(num_segments)
return bprop

View File

@@ -97,12 +97,12 @@ def dyn_rank(tensor):
return dyn_shape(dyn_shape(tensor))[0]
def dyn_size(tensor):
def dyn_size(tensor, dtype=mstype.int64):
"""get the size of tensor"""
shape = dyn_shape(tensor)
shape = cast(shape, mstype.float32)
size = P.ReduceProd()(shape)
size = cast(size, mstype.int32)
size = cast(size, dtype)
return size
@@ -117,7 +117,7 @@ def create_tensor_by_element(ori_tuple, data_type=mstype.int64):
return ori_tuple
def dyn_invert_premutation(prem):
def dyn_invert_permutation(prem):
"""get the invert premutation of tensor"""
indices = P.ExpandDims()(prem, -1)
end = dyn_size(prem)

View File

@@ -53,8 +53,20 @@ def dyn_binop_grad_common(x, y, dx, dy):
shape_of_x = dyn_shape_op(x)
shape_of_y = dyn_shape_op(y)
rx, ry = DynamicBroadcastGradientArgs()(shape_of_x, shape_of_y)
dx = reduce_sum(dx, rx)
dy = reduce_sum(dy, ry)
dx_origin_dtype = dx.dtype
if dx_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
dx = F.cast(dx, mstype.float32)
dx = reduce_sum(dx, rx)
dx = F.cast(dx, dx_origin_dtype)
else:
dx = reduce_sum(dx, rx)
dy_origin_dtype = dy.dtype
if dy_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
dy = F.cast(dy, mstype.float32)
dy = reduce_sum(dy, ry)
dy = F.cast(dy, dy_origin_dtype)
else:
dy = reduce_sum(dy, ry)
reduce_dx = reshape(dx, shape_of_x)
reduce_dy = reshape(dy, shape_of_y)
return reduce_dx, reduce_dy
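
The cast-reduce-cast detour suggests integer ReduceSum is not available (or not reliable) on some backends under dynamic shape — that motivation is an assumption; the code itself just reduces in float32 and casts back. A NumPy sketch of the pattern:

import numpy as np

# Reduce integer gradients in float32, then restore the original dtype
# (assumed workaround for missing integer ReduceSum support).
def reduce_sum_safe(grad, axes):
    if grad.dtype in (np.int16, np.int32, np.int64):
        return grad.astype(np.float32).sum(axis=axes).astype(grad.dtype)
    return grad.sum(axis=axes)

g = np.ones((2, 3), dtype=np.int32)
assert reduce_sum_safe(g, (0,)).dtype == np.int32
assert (reduce_sum_safe(g, (0,)) == np.array([2, 2, 2])).all()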

View File

@@ -18,7 +18,7 @@ import numpy as np
from mindspore.ops.primitive import constexpr
from mindspore.ops.operations import nn_ops as nps
from mindspore.common import dtype as mstype
from .grad_base import bprop_getters
from .grad_base import bprop_getters, dyn_size, create_tensor_by_element
from .. import functional as F
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@@ -61,6 +61,17 @@ def bias_add_gradgrad_helper(shape, bias_shape, data_format):
return tuple(expanded_shape), tuple(tile_mults)
def bias_add_gradgrad_helper_dynamic(shape, bias_shape, data_format):
"""Helper function of BiasGradGrad to calculate expanded shape(dynamic version)."""
if data_format == "NCHW":
expanded_shape = P.Concat(0)((P.OnesLike()(shape[:1]), bias_shape, P.OnesLike()(shape[2:])))
tile_mults = P.Concat(0)((shape[:1], [1], shape[2:]))
else:
expanded_shape = P.Concat(0)((P.OnesLike()(shape[:-1]), bias_shape))
tile_mults = P.Concat(0)((shape[:-1], [1]))
return tuple(expanded_shape), tuple(tile_mults)
@bprop_getters.register(G.BiasAddGrad)
def get_bprop_bias_add_grad(self):
"""Grad definition for `BiasAddGrad` operation."""
@@ -70,10 +81,19 @@ def get_bprop_bias_add_grad(self):
def bprop(dy, out, dout):
reshape = P.Reshape()
tile = P.Tile()
expanded_shape, tile_mults = bias_add_gradgrad_helper(dy.shape, dout.shape, data_format)
expanded_grad = reshape(dout, expanded_shape)
tiled_grad = tile(expanded_grad, tile_mults)
dyn_shape = P.TensorShape()
if is_shape_unknown(dy.shape) or is_shape_unknown(dout.shape):
dy_shape = dyn_shape(dy)
dout_shape = dyn_shape(dout)
expanded_shape, tile_mults = bias_add_gradgrad_helper_dynamic(dy_shape, dout_shape, data_format)
expanded_grad = reshape(dout, create_tensor_by_element(expanded_shape))
tiled_grad = tile(expanded_grad, create_tensor_by_element(tile_mults))
else:
dy_shape = dy.shape
dout_shape = dout.shape
expanded_shape, tile_mults = bias_add_gradgrad_helper(dy_shape, dout_shape, data_format)
expanded_grad = reshape(dout, expanded_shape)
tiled_grad = tile(expanded_grad, tile_mults)
return (tiled_grad,)
return bprop
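
bias_add_gradgrad_helper_dynamic mirrors the static helper with Concat on shape tensors: dout is expanded to a shape that is 1 everywhere except the bias axis, then tiled back out to dy's full shape. A NumPy sketch of the shapes it produces:

import numpy as np

# Shapes computed by the BiasAddGrad grad-grad helper, both layouts.
def gradgrad_shapes(dy_shape, bias_shape, data_format):
    if data_format == "NCHW":
        expanded = (1,) + tuple(bias_shape) + (1,) * (len(dy_shape) - 2)
        tile = (dy_shape[0],) + (1,) + tuple(dy_shape[2:])
    else:  # NHWC: the bias sits on the last axis
        expanded = (1,) * (len(dy_shape) - 1) + tuple(bias_shape)
        tile = tuple(dy_shape[:-1]) + (1,)
    return expanded, tile

exp, mults = gradgrad_shapes((8, 4, 5, 5), (4,), "NCHW")
assert exp == (1, 4, 1, 1) and mults == (8, 1, 5, 5)
assert np.tile(np.zeros(exp), mults).shape == (8, 4, 5, 5)
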
@@ -122,8 +142,14 @@ def get_bprop_conv3d(self):
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
)
get_shape = P.Shape()
get_dyn_shape = P.TensorShape()
def bprop(x, w, out, dout):
if is_shape_unknown(get_shape(x)) or is_shape_unknown(get_shape(w)):
dx = input_grad(w, dout, get_dyn_shape(x))
dw = filter_grad(x, dout, get_dyn_shape(w))
return dx, dw
dx = input_grad(w, dout, get_shape(x))
dw = filter_grad(x, dout, get_shape(w))
return dx, dw
@@ -145,8 +171,14 @@ def get_bprop_conv3d_transpose(self):
out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
pad=pad_list, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
)
get_dyn_shape = P.TensorShape()
def bprop(x, w, out, dout):
if is_shape_unknown(F.shape(w)):
dx = input_grad(dout, w)
dw = filter_grad(dout, x, get_dyn_shape(w))
return dx, dw
dx = input_grad(dout, w)
dw = filter_grad(dout, x, F.shape(w))
return dx, dw
@@ -928,10 +960,12 @@ def get_bprop_top_kv2(self):
scatter = P.ScatterNd()
expand_dims = P.ExpandDims()
shape_op = P.Shape()
dyn_shape = P.TensorShape()
reshape_op = P.Reshape()
dtype = P.DType()
cast = P.Cast()
def bprop(input_x, k, out, dout):
def _bprop_static(input_x, k, out, dout):
in_shape = shape_op(input_x)
in_lastdim = in_shape[-1]
@@ -958,6 +992,37 @@ def get_bprop_top_kv2(self):
in_shape)
return out_grad, zeros_like(k)
def _bprop_dynshape(input_x, k, out, dout):
in_shape = dyn_shape(input_x)
in_lastdim = in_shape[-1]
indices = out[1]
ind_shape = dyn_shape(indices)
ind_lastdim = ind_shape[-1]
ind_2d = reshape_op(indices, create_tensor_by_element((-1, ind_lastdim)))
outerdim = dyn_shape(ind_2d)[0]
# [0, in_lastdim, 2 * in_lastdim, ..., (outerdim - 1) * in_lastdim]
range_flatten_index = P.Range()(cast(0, mstype.int64), outerdim * in_lastdim, in_lastdim)
# expand_dims to (outerdim, 1), then broadcast-add to ind_2d
ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), create_tensor_by_element((-1,)))
in_shape_1d = expand_dims(dyn_size(input_x, mstype.int64), -1)
out_grad = reshape_op(
scatter(
expand_dims(ind, -1),
reshape_op(dout[0], create_tensor_by_element((-1,))),
in_shape_1d),
in_shape)
return out_grad, zeros_like(k)
def bprop(input_x, k, out, dout):
if is_shape_unknown(shape_op(input_x)):
return _bprop_dynshape(input_x, k, out, dout)
return _bprop_static(input_x, k, out, dout)
return bprop
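
_bprop_dynshape is the same flat-scatter trick as the static branch, rebuilt from tensor-valued shapes: offset each row's top-k indices into the flattened input, scatter dout there, and reshape back. A NumPy sketch:

import numpy as np

# NumPy sketch of the TopK bprop scatter.
def topk_bprop(x_shape, indices, dout):
    in_lastdim = x_shape[-1]
    ind_2d = indices.reshape(-1, indices.shape[-1])
    outerdim = ind_2d.shape[0]
    # [0, in_lastdim, 2 * in_lastdim, ..., (outerdim - 1) * in_lastdim]
    offsets = np.arange(0, outerdim * in_lastdim, in_lastdim)
    flat_ind = (ind_2d + offsets[:, None]).reshape(-1)
    grad = np.zeros(int(np.prod(x_shape)), dout.dtype)
    grad[flat_ind] = dout.reshape(-1)
    return grad.reshape(x_shape)

g = topk_bprop((2, 4), np.array([[3], [0]]), np.array([[1.0], [2.0]]))
assert (g == np.array([[0, 0, 0, 1.0], [2.0, 0, 0, 0]])).all()
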
@@ -1239,7 +1304,7 @@ def get_bprop_binary_cross_entropy(self):
@bprop_getters.register(P.BCEWithLogitsLoss)
def get_bprop_ce_with_logits_loss(self):
def get_bprop_bce_with_logits_loss(self):
"""Grad definition for `BCEWithLogitsLoss` operation."""
reduction = self.reduction
mul = P.Mul()
@@ -1249,6 +1314,7 @@ def get_bprop_ce_with_logits_loss(self):
size = P.Size()
neg = P.Neg()
log = P.Log()
shape = P.Shape()
def bprop(predict, target, weight, pos_weight, out, dout):
sigmoid_input = sigmoid(predict)
@@ -1263,8 +1329,10 @@ def get_bprop_ce_with_logits_loss(self):
dx = mul(dx, weight)
grad_target = mul(grad_target, weight)
if reduction == 'mean':
dx = dx / size(dx)
grad_target = grad_target / size(target)
dx_size = dyn_size(dx) if is_shape_unknown(shape(dx)) else size(dx)
target_size = dyn_size(target) if is_shape_unknown(shape(target)) else size(target)
dx = dx / dx_size
grad_target = grad_target / target_size
return dx, grad_target, zeros_like(weight), zeros_like(pos_weight)
return bprop
@@ -1414,7 +1482,7 @@ def get_bprop_conv2d_backprop_filter(self):
def bprop(dy, x, filter_size, out, dout):
x_shape = get_shape(x)
if -1 in x_shape:
if is_shape_unknown(x_shape):
x_shape = get_dyn_shape(x)
dw_dx = input_grad(dy, dout, x_shape)
dw_dy = filter_grad(x, dout)

View File

@@ -19,14 +19,19 @@ import mindspore
from .. import Tensor
from .. import functional as F
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import math_ops as math
from ..operations import linalg_ops as linalg
from ..operations import array_ops as arrays
from ..primitive import constexpr
from .._grad.grad_base import bprop_getters
from .._grad.grad_base import dyn_rank
from .._utils.utils import is_shape_unknown
_shape = arrays.Shape()
_dyn_shape = arrays.TensorShape()
_dtype = arrays.DType()
_cast = arrays.Cast()
_transpose = arrays.Transpose()
@@ -47,12 +52,19 @@ def _raise_value_error(*info):
def _matrix_transpose(a):
dims = a.ndim
if dims < 2:
_raise_value_error(
"To do _matrix_transpose for input a's ndim is not greater or equal to 2, which is invalid.")
axes = F.make_range(0, dims)
axes = axes[:-2] + (axes[-1],) + (axes[-2],)
"""Transpose last two axes"""
if is_shape_unknown(_shape(a)):
dims = dyn_rank(a)
axes = P.Range()(P.Cast()(0, mindspore.int64), dims, P.Cast()(1, mindspore.int64))
axes = P.Concat(axis=-1)((axes[:-2], axes[-1:], axes[-2:-1]))
else:
dims = a.ndim
if dims < 2:
_raise_value_error(
"To do _matrix_transpose for input a's ndim is not greater or equal to 2, which is invalid: {}."
.format(dims))
axes = F.make_range(0, dims)
axes = axes[:-2] + (axes[-1],) + (axes[-2],)
return _transpose(a, axes)
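
Whichever branch runs, _matrix_transpose just swaps the last two axes; the dynamic branch assembles the same axes order with Range and Concat on the runtime rank. The NumPy equivalent:

import numpy as np

# Swap the last two axes for any rank >= 2.
def matrix_transpose(a):
    if a.ndim < 2:
        raise ValueError("input a's ndim must be greater or equal to 2")
    axes = tuple(range(a.ndim - 2)) + (a.ndim - 1, a.ndim - 2)
    return np.transpose(a, axes)

a = np.arange(24).reshape(2, 3, 4)
assert matrix_transpose(a).shape == (2, 4, 3)
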
@@ -76,14 +88,26 @@ def _make_zero_matrix(shape, dtype):
def _matrix_diag(diagonal):
"""Do matrix diagnoal"""
diagonal_shape = _shape(diagonal)
if is_shape_unknown(diagonal_shape):
diagonal_shape = _dyn_shape(diagonal)
row = P.Cast()(diagonal_shape[-1], mindspore.int32)
return arrays.MatrixDiagV3()(diagonal, _k_0, row, row, P.Cast()(0, _dtype(diagonal)))
row = _make_tensor(diagonal_shape[-1], mindspore.int32)
return arrays.MatrixDiagV3()(diagonal, _k_0, row, row, _make_tensor(0, _dtype(diagonal)))
def _mat_mul(x, y):
"""Do matmul"""
shape = _shape(x)
if len(shape) > 2:
if is_shape_unknown(shape):
shape = _dyn_shape(x)
tensor_rank = dyn_rank(x)
else:
tensor_rank = len(shape)
if tensor_rank > 2:
return math.BatchMatMul()(x, y)
return math.MatMul()(x, y)
@@ -106,12 +130,16 @@ def get_bprop_svd(self):
return (da,)
a_shape = _shape(a)
if len(a_shape) < 2:
if is_shape_unknown(a_shape):
a_shape = _dyn_shape(a)
tensor_rank = dyn_rank(a)
else:
tensor_rank = len(a_shape)
if tensor_rank < 2:
_raise_value_error(
"For input a's ndim is not greater or equal to 2, which is invalid.")
m = a_shape[-2]
n = a_shape[-1]
s, u, v = out
ds, du, dv = dout
use_adjoint = False

View File

@@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
from mindspore.common import dtype as mstype
from mindspore.ops.operations.math_ops import Trace, Bernoulli, Renorm
from mindspore.ops._utils.utils import is_shape_unknown
from mindspore import nn
import mindspore.numpy as mnp
import numpy as np
@@ -27,7 +28,7 @@ from ..operations.math_ops import Real, Imag, Complex, Angle
from ..operations.math_ops import ComplexAbs
from ..operations.math_ops import Sinc
from ..functional import broadcast_gradient_args
from .._grad.grad_base import bprop_getters
from .._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_rank
from .._grad.grad_math_ops import binop_grad_common
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
@@ -55,9 +56,11 @@ from ..operations.math_ops import CholeskySolve
from ..operations.math_ops import AddV2
from ..operations.math_ops import TridiagonalMatMul
from ..operations.math_ops import Logit
from .._utils.utils import is_shape_unknown
transpose = P.Transpose()
dyn_shape_op = P.TensorShape()
_conj = P.Conj()
@@ -67,6 +70,11 @@ def _generate_perm(x_dim):
return perm
def _dyn_generate_perm(x_dim):
perm = P.Range()(P.Cast()(0, x_dim.dtype), x_dim - 2, P.Cast()(1, x_dim.dtype))
return perm
def _adjoint(a):
return cholesky_transpose(_conj(a))
@@ -110,10 +118,20 @@ def get_bprop_cdist(self):
def bprop(input_x, input_y, out, dout):
dout_shape = F.shape(dout)
dout_dim = len(dout_shape)
dout_perm_part1 = _generate_perm(dout_dim)
dout_perm_part2 = (dout_dim - 1, dout_dim - 2)
dout_perm = dout_perm_part1 + dout_perm_part2
if is_shape_unknown(dout_shape):
dout_dim = dyn_rank(dout)
dout_perm_part2 = create_tensor_by_element(
(dout_dim - 1, dout_dim - 2))
if dout_dim <= 2:
dout_perm = dout_perm_part2
else:
dout_perm_part1 = _dyn_generate_perm(dout_dim)
dout_perm = P.Concat(0)((dout_perm_part1, dout_perm_part2))
else:
dout_dim = len(dout_shape)
dout_perm_part1 = _generate_perm(dout_dim)
dout_perm_part2 = (dout_dim - 1, dout_dim - 2)
dout_perm = dout_perm_part1 + dout_perm_part2
out_perm = dout_perm
dout_transpose = transpose(dout, dout_perm)
out_transpose = transpose(out, out_perm)
@@ -484,8 +502,16 @@ def get_bprop_matrix_determinant(self):
inverse_op = P.MatrixInverse(adjoint=True)
shape_op = P.Shape()
reshape = P.Reshape()
concat = P.Concat(0)
def bprop(x, out, dout):
if is_shape_unknown(shape_op(x)):
x_adj_inv = inverse_op(x)
out_shape = dyn_shape_op(out)
ones = create_tensor_by_element((1, 1))
multipliers = reshape(dout * out, concat((out_shape, ones)))
dx = multipliers * x_adj_inv
return (dx,)
x_adj_inv = inverse_op(x)
multipliers = reshape(dout * out, shape_op(out) + (1, 1))
dx = multipliers * x_adj_inv
@@ -902,7 +928,11 @@ def get_bprop_trace(self):
def bprop(x, out, dout):
shape = shape_op(x)
dx = input_grad(dout, cast(to_array(shape), mstype.int64))
if is_shape_unknown(shape):
shape = dyn_shape_op(x)
dx = input_grad(dout, shape)
else:
dx = input_grad(dout, cast(to_array(shape), mstype.int64))
return (dx,)
return bprop
@@ -1041,7 +1071,7 @@ def get_bprop_tridiagonal_matmul(self):
maindiag_grad = reduce_sum(rhs_conj * grad, -1)
subdiag_grad = reduce_sum(_rightshift(rhs_conj) * grad, -1)
rhs_grad = _rightshift(superdiag_conj * grad) + maindiag_conj * grad + \
_leftshift(subdiag_conj * grad)
_leftshift(subdiag_conj * grad)
superdiag_grad = expand_dims(superdiag_grad, -2)
maindiag_grad = expand_dims(maindiag_grad, -2)
subdiag_grad = expand_dims(subdiag_grad, -2)

View File

@@ -25,6 +25,7 @@ from mindspore.ops.operations.sparse_ops import SparseSegmentSumWithNumSegments
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments
from mindspore.ops._utils.utils import is_shape_unknown
from mindspore.common import dtype as mstype
from mindspore import Tensor
from mindspore.ops.primitive import constexpr
@@ -36,6 +37,7 @@ from .._grad.grad_base import bprop_getters
from .._utils.utils import is_shape_unknown
# Unused parameters are placeholders.
dyn_shape_op = P.TensorShape()
@constexpr
@@ -50,6 +52,7 @@ def get_bprop_sparse_softmax(self):
sparse_dense_cwise_add = SparseDenseCwiseAdd()
reduce_sum = P.ReduceSum(keep_dims=True)
mul = P.Mul()
def bprop(indices, values, shape, out, dout):
default_values = _create_tensor(0, values.dtype)
out_dout = mul(out, dout)
@@ -106,7 +109,6 @@ def get_bprop_sparse_segment_sqrt_n(self):
"""Grad definition for `SparseSegmentSqrtN` operation."""
input_grad = G.SparseSegmentSqrtNGrad()
shape = P.Shape()
dyn_shape_op = P.TensorShape()
def bprop(x, indices, segment_ids, out, dout):
shape_x = shape(x)
@@ -127,7 +129,6 @@ def get_bprop_sparse_segment_sqrt_n_with_num_segments(self):
"""Grad definition for `SparseSegmentSqrtNWithNumSegments` operation."""
input_grad = G.SparseSegmentSqrtNGrad()
shape = P.Shape()
dyn_shape_op = P.TensorShape()
def bprop(x, indices, segment_ids, num_segments, out, dout):
shape_x = shape(x)
@@ -148,7 +149,6 @@ def get_bprop_sparse_segment_sum(self):
"""Grad definition for `SparseSegmentSum` operation."""
input_grad = G.SparseSegmentSumGrad()
shape = P.Shape()
dyn_shape_op = P.TensorShape()
def bprop(x, indices, segment_ids, out, dout):
shape_x = shape(x)
@@ -169,7 +169,6 @@ def get_bprop_sparse_segment_sum_with_num_segments(self):
"""Grad definition for `SparseSegmentSumWithNumSegments` operation."""
input_grad = G.SparseSegmentSumGrad()
shape = P.Shape()
dyn_shape_op = P.TensorShape()
def bprop(x, indices, segment_ids, num_segments, out, dout):
shape_x = shape(x)
@@ -192,7 +191,12 @@ def get_bprop_sparse_segment_mean_with_num_segments(self):
shape = P.Shape()
def bprop(x, indices, segment_ids, num_segments, out, dout):
output_dim0 = F.scalar_to_tensor(shape(x)[0], mstype.int32)
x_shp = shape(x)
if is_shape_unknown(x_shp):
x_shp = dyn_shape_op(x)
output_dim0 = F.cast(x_shp[0], mstype.int32)
else:
output_dim0 = F.scalar_to_tensor(x_shp[0], mstype.int32)
indices = F.cast(indices, mstype.int32)
segment_ids = F.cast(segment_ids, mstype.int32)
dx = input_grad(dout, indices, segment_ids, output_dim0)
@@ -208,6 +212,7 @@ def get_bprop_sparse_reorder(self):
sparse_reorder_op = SparseReorder()
range_op = P.Range()
gather_op = P.Gather()
def bprop(indices, values, shape, out, dout):
num_entries = F.shape(indices)[0]
start = Tensor(0, dtype=mstype.int32)

View File

(The remaining changed files are serialized bprop MindIR graphs — "0.1.1 MindSpore*1.9.0" dumps for the affected operators. Each of these diffs only renumbers the bprop node labels (e.g. bprop.12 -> bprop.33) and updates the embedded graph hash; the graph structure of zeros_like leaves combined by S-Prim-MakeTuple (plus S-Prim-DropoutDoMask where applicable) is unchanged. The binary content is not reproduced here.)
<EFBFBD>
bprop.161:[CNode]162:1bprop.161:[CNode]163:3bprop.161:[CNode]163:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op119 bprop.161*
bprop.161:x*
bprop.161:out*
bprop.161:dout2
bprop.161:[CNode]163:3:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -2,19 +2,19 @@
0.1.1 MindSpore*1.9.0:Ţ

bprop.50:xbprop.50:[CNode]51:1bprop.50:[CNode]51:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op36
bprop.81:xbprop.81:[CNode]82:1bprop.81:[CNode]82:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op59

bprop.50:ybprop.50:[CNode]52:3bprop.50:[CNode]52:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op37
bprop.81:ybprop.81:[CNode]83:3bprop.81:[CNode]83:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op60
<EFBFBD>
bprop.50:[CNode]51:1
bprop.50:[CNode]52:3bprop.50:[CNode]53:4bprop.50:[CNode]53:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op38bprop.50*
bprop.81:[CNode]82:1
bprop.81:[CNode]83:3bprop.81:[CNode]84:4bprop.81:[CNode]84:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op61bprop.81*
bprop.50:x*
bprop.81:x*
bprop.50:y*
bprop.50:out*
bprop.50:dout2
bprop.50:[CNode]53:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:5S-Prim-MakeTupleh
bprop.81:y*
bprop.81:out*
bprop.81:dout2
bprop.81:[CNode]84:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -2,19 +2,19 @@
0.1.1 MindSpore*1.9.0:Ţ

bprop.46:xbprop.46:[CNode]47:1bprop.46:[CNode]47:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op33
bprop.77:xbprop.77:[CNode]78:1bprop.77:[CNode]78:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op56

bprop.46:ybprop.46:[CNode]48:3bprop.46:[CNode]48:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op34
bprop.77:ybprop.77:[CNode]79:3bprop.77:[CNode]79:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op57
<EFBFBD>
bprop.46:[CNode]47:1
bprop.46:[CNode]48:3bprop.46:[CNode]49:4bprop.46:[CNode]49:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op35bprop.46*
bprop.77:[CNode]78:1
bprop.77:[CNode]79:3bprop.77:[CNode]80:4bprop.77:[CNode]80:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op58bprop.77*
bprop.46:x*
bprop.77:x*
bprop.46:y*
bprop.46:out*
bprop.46:dout2
bprop.46:[CNode]49:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
bprop.77:y*
bprop.77:out*
bprop.77:dout2
bprop.77:[CNode]80:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:5S-Prim-MakeTupleh

View File

@ -1,20 +1,20 @@
0.1.1 MindSpore*1.9.0:©

bprop.25:startbprop.25:[CNode]26:1bprop.25:[CNode]26:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op17
bprop.46:startbprop.46:[CNode]47:1bprop.46:[CNode]47:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op32

bprop.25:stopbprop.25:[CNode]27:3bprop.25:[CNode]27:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op18
bprop.46:stopbprop.46:[CNode]48:3bprop.46:[CNode]48:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op33

bprop.25:numbprop.25:[CNode]28:4bprop.25:[CNode]28:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op19
bprop.46:numbprop.46:[CNode]49:4bprop.46:[CNode]49:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op34
¦
bprop.25:[CNode]26:1
bprop.25:[CNode]27:3
bprop.25:[CNode]28:4bprop.25:[CNode]29:5bprop.25:[CNode]29:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op20bprop.25*
bprop.25:start*
bprop.25:stop*
bprop.25:num*
bprop.25:out*
bprop.25:dout2
bprop.25:[CNode]29:5:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:6S-Prim-MakeTupleh
bprop.46:[CNode]47:1
bprop.46:[CNode]48:3
bprop.46:[CNode]49:4bprop.46:[CNode]50:5bprop.46:[CNode]50:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op35bprop.46*
bprop.46:start*
bprop.46:stop*
bprop.46:num*
bprop.46:out*
bprop.46:dout2
bprop.46:[CNode]50:5:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
S-Prim-MakeTuple:6S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -2,19 +2,19 @@
0.1.1 MindSpore*1.9.0:Þ

bprop.54:xbprop.54:[CNode]55:1bprop.54:[CNode]55:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op39
bprop.85:xbprop.85:[CNode]86:1bprop.85:[CNode]86:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op62

bprop.54:ybprop.54:[CNode]56:3bprop.54:[CNode]56:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op40
bprop.85:ybprop.85:[CNode]87:3bprop.85:[CNode]87:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op63
<EFBFBD>
bprop.54:[CNode]55:1
bprop.54:[CNode]56:3bprop.54:[CNode]57:4bprop.54:[CNode]57:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op41bprop.54*
bprop.85:[CNode]86:1
bprop.85:[CNode]87:3bprop.85:[CNode]88:4bprop.85:[CNode]88:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op64bprop.85*
bprop.54:x*
bprop.85:x*
bprop.54:y*
bprop.54:out*
bprop.54:dout2
bprop.54:[CNode]57:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
bprop.85:y*
bprop.85:out*
bprop.85:dout2
bprop.85:[CNode]88:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:5S-Prim-MakeTupleh

View File

@ -2,13 +2,13 @@
0.1.1 MindSpore*1.9.0:¤

bprop.19:xbprop.19:[CNode]20:1bprop.19:[CNode]20:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op13
bprop.40:xbprop.40:[CNode]41:1bprop.40:[CNode]41:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op28
z
bprop.19:[CNode]20:1bprop.19:[CNode]21:3bprop.19:[CNode]21:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op14bprop.19*
bprop.40:[CNode]41:1bprop.40:[CNode]42:3bprop.40:[CNode]42:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op29bprop.40*
bprop.19:x*
bprop.19:out*
bprop.19:dout2
bprop.19:[CNode]21:3:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
bprop.40:x*
bprop.40:out*
bprop.40:dout2
bprop.40:[CNode]42:3:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -2,19 +2,19 @@
0.1.1 MindSpore*1.9.0:Ţ

bprop.58:xbprop.58:[CNode]59:1bprop.58:[CNode]59:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op42
bprop.89:xbprop.89:[CNode]90:1bprop.89:[CNode]90:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op65

bprop.58:ybprop.58:[CNode]60:3bprop.58:[CNode]60:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op43
bprop.89:ybprop.89:[CNode]91:3bprop.89:[CNode]91:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op66
<EFBFBD>
bprop.58:[CNode]59:1
bprop.58:[CNode]60:3bprop.58:[CNode]61:4bprop.58:[CNode]61:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op44bprop.58*
bprop.89:[CNode]90:1
bprop.89:[CNode]91:3bprop.89:[CNode]92:4bprop.89:[CNode]92:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op67bprop.89*
bprop.58:x*
bprop.89:x*
bprop.58:y*
bprop.58:out*
bprop.58:dout2
bprop.58:[CNode]61:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:5S-Prim-MakeTupleh
bprop.89:y*
bprop.89:out*
bprop.89:dout2
bprop.89:[CNode]92:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -2,19 +2,19 @@
0.1.1 MindSpore*1.9.0:Ţ

bprop.34:xbprop.34:[CNode]35:1bprop.34:[CNode]35:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op24
bprop.65:xbprop.65:[CNode]66:1bprop.65:[CNode]66:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op47

bprop.34:ybprop.34:[CNode]36:3bprop.34:[CNode]36:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op25
bprop.65:ybprop.65:[CNode]67:3bprop.65:[CNode]67:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op48
<EFBFBD>
bprop.34:[CNode]35:1
bprop.34:[CNode]36:3bprop.34:[CNode]37:4bprop.34:[CNode]37:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op26bprop.34*
bprop.65:[CNode]66:1
bprop.65:[CNode]67:3bprop.65:[CNode]68:4bprop.65:[CNode]68:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op49bprop.65*
bprop.34:x*
bprop.65:x*
bprop.34:y*
bprop.34:out*
bprop.34:dout2
bprop.34:[CNode]37:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:5S-Prim-MakeTupleh
bprop.65:y*
bprop.65:out*
bprop.65:dout2
bprop.65:[CNode]68:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -1,24 +1,24 @@
0.1.1 MindSpore*1.9.0:Ý

bprop.7:indicesbprop.7:[CNode]8:1bprop.7:[CNode]8:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op5
<EFBFBD>
bprop.7:depthbprop.7:[CNode]9:3bprop.7:[CNode]9:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op6

bprop.7:on_valuebprop.7:[CNode]10:4bprop.7:[CNode]10:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op7
0.1.1 MindSpore*1.9.0:
˜
bprop.26:indicesbprop.26:[CNode]27:1bprop.26:[CNode]27:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op18

bprop.7:off_valuebprop.7:[CNode]11:5bprop.7:[CNode]11:5"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op8
³
bprop.7:[CNode]8:1
bprop.7:[CNode]9:3
bprop.7:[CNode]10:4
bprop.7:[CNode]11:5bprop.7:[CNode]12:6bprop.7:[CNode]12:6"REF::S-Prim-MakeTuple:7:Default/S-Prim-MakeTuple-op9bprop.7*
bprop.7:indices*
bprop.7:depth*
bprop.7:on_value*
bprop.7:off_value*
bprop.7:out*
bprop.7:dout2
bprop.7:[CNode]12:6:@de4408af9f648bc43794b496de642f2ba8fef52099d138465267bb08b5fd7e8fPb&
bprop.26:depthbprop.26:[CNode]28:3bprop.26:[CNode]28:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op19

bprop.26:on_valuebprop.26:[CNode]29:4bprop.26:[CNode]29:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op20
š
bprop.26:off_valuebprop.26:[CNode]30:5bprop.26:[CNode]30:5"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op21
¼
bprop.26:[CNode]27:1
bprop.26:[CNode]28:3
bprop.26:[CNode]29:4
bprop.26:[CNode]30:5bprop.26:[CNode]31:6bprop.26:[CNode]31:6"REF::S-Prim-MakeTuple:7:Default/S-Prim-MakeTuple-op22bprop.26*
bprop.26:indices*
bprop.26:depth*
bprop.26:on_value*
bprop.26:off_value*
bprop.26:out*
bprop.26:dout2
bprop.26:[CNode]31:6:@320a603a26f1cc5c0174957a4eea3494bfd5f1c2f20ad64ef6e0ee49a91b62baPb&
S-Prim-MakeTuple:7S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -1,12 +1,14 @@
0.1.1 MindSpore*1.9.0:”
Œ
bprop.8:xbprop.8:[CNode]9:1bprop.8:[CNode]9:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op5
u
bprop.8:[CNode]9:1bprop.8:[CNode]10:3bprop.8:[CNode]10:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op6bprop.8*
bprop.8:x*
bprop.8:out*
bprop.8:dout2
bprop.8:[CNode]10:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPb&
0.1.1 MindSpore*1.9.0:¢

bprop.10:xbprop.10:[CNode]11:1bprop.10:[CNode]11:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op7
y
bprop.10:[CNode]11:1bprop.10:[CNode]12:3bprop.10:[CNode]12:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op8bprop.10*
bprop.10:x*
bprop.10:out*
bprop.10:dout2
bprop.10:[CNode]12:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -1,20 +1,20 @@
0.1.1 MindSpore*1.9.0:Š
<EFBFBD>
bprop.3:startbprop.3:[CNode]4:1bprop.3:[CNode]4:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op1
bprop.5:startbprop.5:[CNode]6:1bprop.5:[CNode]6:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op3
<EFBFBD>
bprop.3:limitbprop.3:[CNode]5:3bprop.3:[CNode]5:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op2
bprop.5:limitbprop.5:[CNode]7:3bprop.5:[CNode]7:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op4
<EFBFBD>
bprop.3:deltabprop.3:[CNode]6:4bprop.3:[CNode]6:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op3
bprop.5:deltabprop.5:[CNode]8:4bprop.5:[CNode]8:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op5

bprop.3:[CNode]4:1
bprop.3:[CNode]5:3
bprop.3:[CNode]6:4bprop.3:[CNode]7:5bprop.3:[CNode]7:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op4bprop.3*
bprop.3:start*
bprop.3:limit*
bprop.3:delta*
bprop.3:out*
bprop.3:dout2
bprop.3:[CNode]7:5:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPb&
bprop.5:[CNode]6:1
bprop.5:[CNode]7:3
bprop.5:[CNode]8:4bprop.5:[CNode]9:5bprop.5:[CNode]9:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op6bprop.5*
bprop.5:start*
bprop.5:limit*
bprop.5:delta*
bprop.5:out*
bprop.5:dout2
bprop.5:[CNode]9:5:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPb&
S-Prim-MakeTuple:6S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -2,13 +2,13 @@
0.1.1 MindSpore*1.9.0:¤

bprop.29:xbprop.29:[CNode]30:1bprop.29:[CNode]30:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op19
bprop.46:xbprop.46:[CNode]47:1bprop.46:[CNode]47:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op33
z
bprop.29:[CNode]30:1bprop.29:[CNode]31:3bprop.29:[CNode]31:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op20bprop.29*
bprop.46:[CNode]47:1bprop.46:[CNode]48:3bprop.46:[CNode]48:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op34bprop.46*
bprop.29:x*
bprop.29:out*
bprop.29:dout2
bprop.29:[CNode]31:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPbH
bprop.46:x*
bprop.46:out*
bprop.46:dout2
bprop.46:[CNode]48:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -8,9 +8,9 @@ m
bprop.1:x*
bprop.1:out*
bprop.1:dout2
bprop.1:[CNode]2:3:@de4408af9f648bc43794b496de642f2ba8fef52099d138465267bb08b5fd7e8fPb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebr
bprop.1:[CNode]2:3:@320a603a26f1cc5c0174957a4eea3494bfd5f1c2f20ad64ef6e0ee49a91b62baPbr
S-Prim-ReluGrad:2S-Prim-ReluGrad
output_names€Š Zoutput€+
input_names€ŠZ
y_backprop€ŠZx€h
y_backprop€ŠZx€b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -2,17 +2,17 @@
0.1.1 MindSpore*1.9.0:ä

bprop.62:xbprop.62:[CNode]63:1bprop.62:[CNode]63:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op45
bprop.93:xbprop.93:[CNode]94:1bprop.93:[CNode]94:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op68

bprop.62:axisbprop.62:[CNode]64:3bprop.62:[CNode]64:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op46
bprop.93:axisbprop.93:[CNode]95:3bprop.93:[CNode]95:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op69
<EFBFBD>
bprop.62:[CNode]63:1
bprop.62:[CNode]64:3bprop.62:[CNode]65:4bprop.62:[CNode]65:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op47bprop.62*
bprop.93:[CNode]94:1
bprop.93:[CNode]95:3bprop.93:[CNode]96:4bprop.93:[CNode]96:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op70bprop.93*
bprop.62:x*
bprop.62:axis*
bprop.62:out*
bprop.62:dout2
bprop.62:[CNode]65:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
bprop.93:x*
bprop.93:axis*
bprop.93:out*
bprop.93:dout2
bprop.93:[CNode]96:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -1,18 +1,18 @@
0.1.1 MindSpore*1.9.0:ä
0.1.1 MindSpore*1.9.0:ç
<EFBFBD>
bprop.66:xbprop.66:[CNode]67:1bprop.66:[CNode]67:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op48
bprop.97:xbprop.97:[CNode]98:1bprop.97:[CNode]98:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op71
<EFBFBD>
bprop.66:axisbprop.66:[CNode]68:3bprop.66:[CNode]68:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op49
<EFBFBD>
bprop.66:[CNode]67:1
bprop.66:[CNode]68:3bprop.66:[CNode]69:4bprop.66:[CNode]69:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op50bprop.66*
bprop.97:axisbprop.97:[CNode]99:3bprop.97:[CNode]99:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op72

bprop.97:[CNode]98:1
bprop.97:[CNode]99:3bprop.97:[CNode]100:4bprop.97:[CNode]100:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op73bprop.97*
bprop.66:x*
bprop.66:axis*
bprop.66:out*
bprop.66:dout2
bprop.66:[CNode]69:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:5S-Prim-MakeTupleh
bprop.97:x*
bprop.97:axis*
bprop.97:out*
bprop.97:dout2
bprop.97:[CNode]100:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -1,24 +1,24 @@
0.1.1 MindSpore*1.9.0:¿
u
bprop.18:dout
bprop.71:dout
bprop.18:ybprop.18:dgrad:1bprop.18:dgrad:1"REF::S-Prim-ReluGrad:2:Default/S-Prim-ReluGrad-op14
bprop.71:ybprop.71:dgrad:1bprop.71:dgrad:1"REF::S-Prim-ReluGrad:2:Default/S-Prim-ReluGrad-op50

bprop.18:ybprop.18:[CNode]19:3bprop.18:[CNode]19:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:4:.Default/S-Prim-hyper_map[zeros_like_leaf]-op15
bprop.71:ybprop.71:[CNode]72:3bprop.71:[CNode]72:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:4:.Default/S-Prim-hyper_map[zeros_like_leaf]-op51
Œ
bprop.18:dgrad:1
bprop.18:[CNode]19:3bprop.18:[CNode]20:5bprop.18:[CNode]20:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op16bprop.18*
bprop.18:grad*
bprop.71:dgrad:1
bprop.71:[CNode]72:3bprop.71:[CNode]73:5bprop.71:[CNode]73:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op52bprop.71*
bprop.71:grad*
bprop.18:y*
bprop.18:out*
bprop.18:dout2
bprop.18:[CNode]20:5:@de4408af9f648bc43794b496de642f2ba8fef52099d138465267bb08b5fd7e8fPbr
bprop.71:y*
bprop.71:out*
bprop.71:dout2
bprop.71:[CNode]73:5:@320a603a26f1cc5c0174957a4eea3494bfd5f1c2f20ad64ef6e0ee49a91b62baPb&
S-Prim-MakeTuple:6S-Prim-MakeTuplebr
S-Prim-ReluGrad:2S-Prim-ReluGrad
output_names€Š Zoutput€+
input_names€ŠZ
y_backprop€ŠZx€b&
S-Prim-MakeTuple:6S-Prim-MakeTuplebH
y_backprop€ŠZx€bH
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -2,13 +2,13 @@
0.1.1 MindSpore*1.9.0:¤

bprop.22:xbprop.22:[CNode]23:1bprop.22:[CNode]23:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op15
bprop.43:xbprop.43:[CNode]44:1bprop.43:[CNode]44:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op30
z
bprop.22:[CNode]23:1bprop.22:[CNode]24:3bprop.22:[CNode]24:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op16bprop.22*
bprop.43:[CNode]44:1bprop.43:[CNode]45:3bprop.43:[CNode]45:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op31bprop.43*
bprop.22:x*
bprop.22:out*
bprop.22:dout2
bprop.22:[CNode]24:3:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
bprop.43:x*
bprop.43:out*
bprop.43:dout2
bprop.43:[CNode]45:3:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -1,35 +1,35 @@
0.1.1 MindSpore*1.9.0:Ç

bprop.32:condbprop.32:[CNode]33:1bprop.32:[CNode]33:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op21
bprop.49:condbprop.49:[CNode]50:1bprop.49:[CNode]50:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op35

bprop.32:xbprop.32:[CNode]34:3bprop.32:[CNode]34:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op22
bprop.49:xbprop.49:[CNode]51:3bprop.49:[CNode]51:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op36

bprop.32:cond
bprop.32:dout
bprop.32:[CNode]34:3bprop.32:[CNode]35:4bprop.32:[CNode]35:4"REF::S-Prim-Select:5:Default/S-Prim-Select-op23
bprop.49:cond
bprop.49:dout
bprop.49:[CNode]51:3bprop.49:[CNode]52:4bprop.49:[CNode]52:4"REF::S-Prim-Select:5:Default/S-Prim-Select-op37

bprop.32:ybprop.32:[CNode]36:6bprop.32:[CNode]36:6"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op24
bprop.49:ybprop.49:[CNode]53:6bprop.49:[CNode]53:6"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op38

bprop.32:cond
bprop.32:[CNode]36:6
bprop.32:doutbprop.32:[CNode]37:7bprop.32:[CNode]37:7"REF::S-Prim-Select:5:Default/S-Prim-Select-op25
bprop.49:cond
bprop.49:[CNode]53:6
bprop.49:doutbprop.49:[CNode]54:7bprop.49:[CNode]54:7"REF::S-Prim-Select:5:Default/S-Prim-Select-op39
¦
bprop.32:[CNode]33:1
bprop.32:[CNode]35:4
bprop.32:[CNode]37:7bprop.32:[CNode]38:8bprop.32:[CNode]38:8"REF::S-Prim-MakeTuple:9:Default/S-Prim-MakeTuple-op26bprop.32*
bprop.32:cond*
bprop.49:[CNode]50:1
bprop.49:[CNode]52:4
bprop.49:[CNode]54:7bprop.49:[CNode]55:8bprop.49:[CNode]55:8"REF::S-Prim-MakeTuple:9:Default/S-Prim-MakeTuple-op40bprop.49*
bprop.49:cond*
bprop.32:x*
bprop.49:x*
bprop.32:y*
bprop.32:out*
bprop.32:dout2
bprop.32:[CNode]38:8:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPbv
bprop.49:y*
bprop.49:out*
bprop.49:dout2
bprop.49:[CNode]55:8:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPb&
S-Prim-MakeTuple:9S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]bv
S-Prim-Select:5 S-Prim-Select
output_names€Š Zoutput€3
input_names€ŠZ condition€ŠZx€ŠZy€bH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:9S-Prim-MakeTupleh
input_names€ŠZ condition€ŠZx€ŠZy€h

View File

@ -2,13 +2,13 @@
0.1.1 MindSpore*1.9.0:¤

bprop.23:xbprop.23:[CNode]24:1bprop.23:[CNode]24:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op15
bprop.40:xbprop.40:[CNode]41:1bprop.40:[CNode]41:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op29
z
bprop.23:[CNode]24:1bprop.23:[CNode]25:3bprop.23:[CNode]25:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op16bprop.23*
bprop.40:[CNode]41:1bprop.40:[CNode]42:3bprop.40:[CNode]42:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op30bprop.40*
bprop.23:x*
bprop.23:out*
bprop.23:dout2
bprop.23:[CNode]25:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPbH
bprop.40:x*
bprop.40:out*
bprop.40:dout2
bprop.40:[CNode]42:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -2,13 +2,13 @@
0.1.1 MindSpore*1.9.0:¤

bprop.16:xbprop.16:[CNode]17:1bprop.16:[CNode]17:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op11
bprop.37:xbprop.37:[CNode]38:1bprop.37:[CNode]38:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op26
z
bprop.16:[CNode]17:1bprop.16:[CNode]18:3bprop.16:[CNode]18:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op12bprop.16*
bprop.37:[CNode]38:1bprop.37:[CNode]39:3bprop.37:[CNode]39:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op27bprop.37*
bprop.16:x*
bprop.16:out*
bprop.16:dout2
bprop.16:[CNode]18:3:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
bprop.37:x*
bprop.37:out*
bprop.37:dout2
bprop.37:[CNode]39:3:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -1,20 +1,16 @@
0.1.1 MindSpore*1.9.0:Þ

bprop.74:xbprop.74:[CNode]75:1bprop.74:[CNode]75:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op54

bprop.74:ybprop.74:[CNode]76:3bprop.74:[CNode]76:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op55
<EFBFBD>
bprop.74:[CNode]75:1
bprop.74:[CNode]76:3bprop.74:[CNode]77:4bprop.74:[CNode]77:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op56bprop.74*
bprop.74:x*
bprop.74:y*
bprop.74:out*
bprop.74:dout2
bprop.74:[CNode]77:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:5S-Prim-MakeTupleh
0.1.1 MindSpore*1.9.0:ú
˜
bprop.147:xbprop.147:[CNode]148:1bprop.147:[CNode]148:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:/Default/S-Prim-hyper_map[zeros_like_leaf]-op107
˜
bprop.147:ybprop.147:[CNode]149:3bprop.147:[CNode]149:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:/Default/S-Prim-hyper_map[zeros_like_leaf]-op108

bprop.147:[CNode]148:1
bprop.147:[CNode]149:3bprop.147:[CNode]150:4bprop.147:[CNode]150:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op109 bprop.147*
bprop.147:x*
bprop.147:y*
bprop.147:out*
bprop.147:dout2
bprop.147:[CNode]150:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -1,14 +1,14 @@
0.1.1 MindSpore*1.9.0:¢
0.1.1 MindSpore*1.9.0:£
<EFBFBD>
bprop.11:xbprop.11:[CNode]12:1bprop.11:[CNode]12:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op7
y
bprop.11:[CNode]12:1bprop.11:[CNode]13:3bprop.11:[CNode]13:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op8bprop.11*
bprop.13:xbprop.13:[CNode]14:1bprop.13:[CNode]14:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op9
z
bprop.13:[CNode]14:1bprop.13:[CNode]15:3bprop.13:[CNode]15:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op10bprop.13*
bprop.11:x*
bprop.11:out*
bprop.11:dout2
bprop.11:[CNode]13:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
bprop.13:x*
bprop.13:out*
bprop.13:dout2
bprop.13:[CNode]15:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -2503,7 +2503,6 @@ class HistogramFixedWidth(PrimitiveWithInfer):
self.add_prim_attr('dtype', 3)
class Log(Primitive):
"""
Returns the natural logarithm of a tensor element-wise.
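A minimal usage sketch for context (illustrative; assumes only the standard mindspore.ops API):
    >>> import numpy as np
    >>> import mindspore as ms
    >>> from mindspore import Tensor, ops
    >>> log = ops.Log()
    >>> x = Tensor(np.array([1.0, 2.0, 4.0]), ms.float32)
    >>> out = log(x)  # approx. [0., 0.6931, 1.3863]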

View File

@ -0,0 +1,64 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
from mindspore import ops, nn, context, Tensor
from .test_grad_of_dynamic import TestDynamicGrad
context.set_context(mode=context.PYNATIVE_MODE)
class NetAddcmul(nn.Cell):
def __init__(self):
super(NetAddcmul, self).__init__()
self.add = ops.Addcmul()
def construct(self, a, x1, x2, v):
return self.add(a, x1, x2, v)
def addcmul_test(is_dyn_rank):
a = Tensor(np.array([1, 1, 1]).astype(np.float32))
x1 = Tensor(np.array([[1], [2], [3]]).astype(np.float32))
x2 = Tensor(np.array([[1, 2, 3]]).astype(np.float32))
v = Tensor([1], ms.float32)
tester = TestDynamicGrad(NetAddcmul())
tester.test_dynamic_grad_net([a, x1, x2, v], is_dyn_rank)
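# Note: Addcmul computes a + v * (x1 * x2) elementwise, so the (3,)-shaped a,
# the (3, 1)-shaped x1 and the (1, 3)-shaped x2 above broadcast to a (3, 3) output.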
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
def test_addcmul_dyn_shape():
"""
Feature: Addcmul Grad DynamicShape.
Description: Test case of dynamic shape for Addcmul grad operator.
Expectation: success.
"""
addcmul_test(False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
def test_addcmul_dyn_rank():
"""
Feature: Addcmul Grad DynamicShape.
Description: Test case of dynamic rank for Addcmul grad operator.
Expectation: success.
"""
addcmul_test(True)
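The TestDynamicGrad helper used throughout these new tests lives in test_grad_of_dynamic.py, which is not part of this diff. Conceptually, such a harness computes the gradients once with fully known input shapes, re-registers the inputs as dynamic, and checks that the two runs agree. Below is a minimal sketch of that idea, assuming only the public MindSpore 1.9 API (ops.GradOperation, Cell.set_inputs, and abstract Tensor(shape=[None, ...]) placeholders); GradNet and check_dynamic_grad are illustrative names, not the helper's actual implementation:
import numpy as np
from mindspore import nn, ops, Tensor
# Illustrative sketch only; not the repo's TestDynamicGrad implementation.
class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation(get_all=True)
    def construct(self, *inputs):
        return self.grad_op(self.net)(*inputs)
def check_dynamic_grad(net, inputs):
    grad_net = GradNet(net)
    static_grads = grad_net(*inputs)  # reference run with fully known shapes
    # Re-register each input with every dimension unknown (None) so the graph
    # is compiled down the dynamic-shape path, then feed the same concrete data.
    dyn_inputs = [Tensor(shape=[None] * len(t.shape), dtype=t.dtype) for t in inputs]
    grad_net.set_inputs(*dyn_inputs)
    dyn_grads = grad_net(*inputs)
    for s, d in zip(static_grads, dyn_grads):
        assert np.allclose(s.asnumpy(), d.asnumpy(), 1e-4, 1e-4)
A dynamic-rank check would follow the same pattern, re-registering inputs whose rank itself is unknown rather than just their dimensions.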

View File

@ -0,0 +1,97 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from .test_grad_of_dynamic import TestDynamicGrad
class Atan2Net(nn.Cell):
def __init__(self):
super(Atan2Net, self).__init__()
self.atan2 = P.Atan2()
def construct(self, x, y):
return self.atan2(x, y)
def dynamic_shape():
test_dynamic = TestDynamicGrad(Atan2Net())
x = Tensor(np.array([0, 1]).astype(np.float32))
y = Tensor(np.array([1, 1]).astype(np.float32))
test_dynamic.test_dynamic_grad_net((x, y))
def dynamic_rank():
test_dynamic = TestDynamicGrad(Atan2Net())
x = Tensor(np.array([0, 1]).astype(np.float32))
y = Tensor(np.array([1, 1]).astype(np.float32))
test_dynamic.test_dynamic_grad_net((x, y), True)
def test_dynamic_atan2_cpu():
"""
Feature: Atan2 Grad DynamicShape.
Description: Test case of dynamic shape for Atan2 grad operator on CPU.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
dynamic_shape()
dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
dynamic_shape()
dynamic_rank()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
def test_dynamic_atan2_gpu():
"""
Feature: Atan2 Grad DynamicShape.
Description: Test case of dynamic shape for Atan2 grad operator on GPU.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dynamic_shape()
dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dynamic_shape()
dynamic_rank()
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
def test_dynamic_atan2_ascend():
"""
Feature: Atan2 Grad DynamicShape.
Description: Test case of dynamic shape for Atan2 grad operator on Ascend.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
dynamic_shape()
dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
dynamic_shape()
dynamic_rank()

View File

@ -0,0 +1,63 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.ops.operations.math_ops as M
from mindspore import nn, context, Tensor
from .test_grad_of_dynamic import TestDynamicGrad
context.set_context(mode=context.PYNATIVE_MODE)
class NetBatchMatMul(nn.Cell):
def __init__(self):
super(NetBatchMatMul, self).__init__()
self.batchmatmul = M.BatchMatMul()
def construct(self, x, y):
return self.batchmatmul(x, y)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
def test_batch_matmul_dynamic_shape():
"""
Feature: BatchMatMul Grad DynamicShape.
Description: Test case of dynamic shape for BatchMatMul grad operator on GPU.
Expectation: success.
"""
test_dynamic = TestDynamicGrad(NetBatchMatMul(), skip_convert_out_ids=[0])
x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
inputs = [x, y]
test_dynamic.test_dynamic_grad_net(inputs, False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
def test_batch_matmul_dynamic_rank():
"""
Feature: BatchMatMul Grad DynamicShape.
Description: Test case of dynamic rank for BatchMatMul grad operator on GPU.
Expectation: success.
"""
test_dynamic = TestDynamicGrad(NetBatchMatMul(), skip_convert_out_ids=[0])
x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
inputs = [x, y]
test_dynamic.test_dynamic_grad_net(inputs, True)
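# Note: BatchMatMul contracts the last two axes and keeps the leading batch axes,
# so x of shape (2, 4, 1, 3) times y of shape (2, 4, 3, 4) produces (2, 4, 1, 4).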

View File

@ -0,0 +1,99 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from .test_grad_of_dynamic import TestDynamicGrad
class BiasAddNet(nn.Cell):
def __init__(self):
super(BiasAddNet, self).__init__()
self.bias_add = P.BiasAdd()
def construct(self, x, y):
return self.bias_add(x, y)
def run_dynamic_shape():
test_dynamic = TestDynamicGrad(BiasAddNet())
input_x = Tensor(np.arange(6).reshape((2, 3)), ms.float32)
bias = Tensor(np.random.random(3).reshape((3,)), ms.float32)
test_dynamic.test_dynamic_grad_net([input_x, bias])
def run_dynamic_rank():
test_dynamic = TestDynamicGrad(BiasAddNet())
input_x = Tensor(np.arange(6).reshape((2, 3)), ms.float32)
bias = Tensor(np.random.random(3).reshape((3,)), ms.float32)
test_dynamic.test_dynamic_grad_net([input_x, bias], True)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_dynamic_bias_add_cpu():
"""
Feature: BiasAdd Grad DynamicShape.
Description: Test case of dynamic shape for BiasAdd grad operator on CPU.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
run_dynamic_shape()
run_dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
run_dynamic_shape()
run_dynamic_rank()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
def test_dynamic_bias_add_gpu():
"""
Feature: BiasAdd Grad DynamicShape.
Description: Test case of dynamic shape for BiasAdd grad operator on GPU.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
run_dynamic_shape()
run_dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
run_dynamic_shape()
run_dynamic_rank()
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
def test_dynamic_bias_add_ascend():
"""
Feature: BiasAdd Grad DynamicShape.
Description: Test case of dynamic shape for BiasAdd grad operator on Ascend.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
run_dynamic_shape()
run_dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
run_dynamic_shape()
run_dynamic_rank()

View File

@ -0,0 +1,63 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore import context
from .test_grad_of_dynamic import TestDynamicGrad
class NetCdist(nn.Cell):
def __init__(self, p):
super(NetCdist, self).__init__()
self.cdist = P.Cdist(p)
def construct(self, x1, x2):
return self.cdist(x1, x2)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_dynamic_shape_cdist():
"""
Feature: Cdist Grad DynamicShape.
Description: Test case of dynamic shape for Cdist grad operator on CPU.
Expectation: success.
"""
context.set_context(mode=context.PYNATIVE_MODE)
test_dynamic = TestDynamicGrad(NetCdist(2.))
x1 = Tensor(np.array([[[1.0, 1.0], [2.0, 2.0]]]).astype(np.float32))
x2 = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
test_dynamic.test_dynamic_grad_net([x1, x2], False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_dynamic_rank_cdist():
"""
Feature: Cdist Grad DynamicShape.
Description: Test case of dynamic rank for Cdist grad operator on CPU.
Expectation: success.
"""
context.set_context(mode=context.PYNATIVE_MODE)
test_dynamic = TestDynamicGrad(NetCdist(2.))
x1 = Tensor(np.array([[[1.0, 1.0], [2.0, 2.0]]]).astype(np.float32))
x2 = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
test_dynamic.test_dynamic_grad_net([x1, x2], True)
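# For reference, P.Cdist(2.) returns the batched pairwise Euclidean distances; with
# the inputs above the forward output is [[[2.8284271, 2.8284271], [1.4142135,
# 1.4142135]]] (sqrt(8) for [1, 1] vs [3, 3], sqrt(2) for [2, 2] vs [3, 3]).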

View File

@ -0,0 +1,111 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import nn, context, Tensor
from mindspore.ops.operations import _grad_ops as G
from .test_grad_of_dynamic import TestDynamicGrad
class NetConv2DBackpropFilter(nn.Cell):
def __init__(self):
super(NetConv2DBackpropFilter, self).__init__()
self.conv_filter = G.Conv2DBackpropFilter(1, 3, pad_mode="valid", pad=0, mode=1, stride=(1, 1),
dilation=(1, 1, 1, 1), group=1)
self.w_shape = (1, 1, 3, 3)
def construct(self, out, x):
return self.conv_filter(out, x, self.w_shape)
def grad_dyn_case(is_dynamic_rank):
test_dynamic = TestDynamicGrad(NetConv2DBackpropFilter())
np.random.seed(1)
out = Tensor(np.random.normal(0, 1, (1, 1, 4, 4)).astype(np.float32))
x = Tensor(np.random.normal(0, 2, (1, 1, 6, 6)).astype(np.float32))
test_dynamic.test_dynamic_grad_net([out, x], is_dynamic_rank)
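# Shape check: a 3x3 valid-mode convolution over the 6x6 input x yields the 4x4 out,
# and the recovered filter gradient matches w_shape = (1, 1, 3, 3).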
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_shape():
"""
Feature: test Conv2DBackpropFilter dynamic shape on GPU.
Description: input is dynamic shape.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_rank():
"""
Feature: test Conv2DBackpropFilter dynamic rank on GPU.
Description: input is dynamic rank.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_dyn_case(True)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_shape():
"""
Feature: test Conv2DBackpropFilter dynamic shape on CPU.
Description: input is dynamic shape.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_rank():
"""
Feature: test Conv2DBackpropFilter dynamic rank on CPU.
Description: input is dynamic rank.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(True)
def test_ascend_grad_dynamic_shape():
"""
Feature: test Conv2DBackpropFilter dynamic shape on Ascend.
Description: input is dynamic shape.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
grad_dyn_case(False)
def test_ascend_grad_dynamic_rank():
"""
Feature: test Conv2DBackpropFilter dynamic rank on Ascend.
Description: input is dynamic rank.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
grad_dyn_case(True)

View File

@ -0,0 +1,125 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import nn, context, Tensor
from mindspore.ops.operations import Conv3D
from .test_grad_of_dynamic import TestDynamicGrad
class NetConv3d(nn.Cell):
def __init__(self):
super(NetConv3d, self).__init__()
out_channel = 4
kernel_size = 2
self.conv = Conv3D(out_channel,
kernel_size,
mode=1,
pad_mode="valid",
pad=0,
stride=1,
dilation=1,
group=1)
def construct(self, x, w):
return self.conv(x, w)
def grad_dyn_case(is_dynamic_rank):
test_dynamic = TestDynamicGrad(NetConv3d())
input_np = np.arange(1 * 3 * 3 * 3 * 3).reshape(1, 3, 3, 3, 3).astype(np.float32)
weight_np = np.arange(4 * 3 * 2 * 2 * 2).reshape(4, 3, 2, 2, 2).astype(np.float32)
test_dynamic.test_dynamic_grad_net([Tensor(input_np), Tensor(weight_np)], is_dynamic_rank)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_shape_1():
"""
Feature: test Conv3D grad dynamic shape on GPU.
Description: input is dynamic shape.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_rank_1():
"""
Feature: test Conv3D grad dynamic rank on GPU.
Description: input is dynamic rank.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_dyn_case(True)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_shape_2():
"""
Feature: test Conv3D grad dynamic shape on GPU.
Description: input is dynamic shape.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_rank_2():
"""
Feature: test Conv3D grad dynamic rank on GPU.
Description: input is dynamic rank.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
grad_dyn_case(True)
@pytest.mark.skip(reason="CPU无Conv3DBackpropFilter, Conv3DBackpropInput, kernel实现")
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_shape():
"""
Feature: test Conv3D grad dynamic shape on CPU.
Description: input is dynamic shape.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(False)
@pytest.mark.skip(reason="CPU无Conv3DBackpropFilter, Conv3DBackpropInput, kernel实现")
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_rank():
"""
Feature: test Conv3D grad dynamic rank on CPU.
Description: input is dynamic rank.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(True)

View File

@ -0,0 +1,125 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import nn, context, Tensor
from mindspore.ops.operations import Conv3DTranspose
from .test_grad_of_dynamic import TestDynamicGrad
class NetConv3dTranspose(nn.Cell):
def __init__(self):
super(NetConv3dTranspose, self).__init__()
in_channel = 2
out_channel = 2
kernel_size = 2
self.conv_trans = Conv3DTranspose(in_channel, out_channel,
kernel_size,
pad_mode="pad",
pad=1,
stride=1,
dilation=1,
group=1)
def construct(self, x, w):
return self.conv_trans(x, w)
def grad_dyn_case(is_dynamic_rank):
x = Tensor(np.arange(1 * 2 * 3 * 3 * 3).reshape(1, 2, 3, 3, 3).astype(np.float32))
w = Tensor(np.ones((2, 2, 2, 2, 2)).astype(np.float32))
test_dynamic = TestDynamicGrad(NetConv3dTranspose())
test_dynamic.test_dynamic_grad_net([x, w], is_dynamic_rank)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_shape_1():
"""
Feature: test Conv3DTranspose grad dynamic shape on GPU.
Description: input is dynamic shape.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_rank_1():
"""
Feature: test Conv3DTranspose grad dynamic rank on GPU.
Description: input is dynamic rank.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_dyn_case(True)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_shape_2():
"""
Feature: test Conv3DTranspose grad dynamic shape on GPU.
Description: input is dynamic shape.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_rank_2():
"""
Feature: test Conv3DTranspose grad dynamic rank on GPU.
Description: input is dynamic rank.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
grad_dyn_case(True)
@pytest.mark.skip(reason="CPU无Conv3DBackpropFilter, Conv3DBackpropInput, kernel实现")
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_shape():
"""
Feature: test Conv3DTranspose grad dynamic shape on CPU.
Description: input is dynamic shape.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(False)
@pytest.mark.skip(reason="CPU无Conv3DBackpropFilter, Conv3DBackpropInput, kernel实现")
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_rank():
"""
Feature: test Conv3DTranspose grad dynamic rank on CPU.
Description: input is dynamic rank.
Expectation: the result matches the static shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(True)

View File

@ -0,0 +1,63 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import ops, nn, context, Tensor
from .test_grad_of_dynamic import TestDynamicGrad
context.set_context(mode=context.PYNATIVE_MODE)
class NetDivNoNan(nn.Cell):
def __init__(self):
super(NetDivNoNan, self).__init__()
self.div = ops.DivNoNan()
def construct(self, x, y):
return self.div(x, y)
def divnonan_test(is_dyn_rank):
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
y = Tensor(np.array([[7, 8, 9]]).astype(np.float32))
tester = TestDynamicGrad(NetDivNoNan())
tester.test_dynamic_grad_net([x, y], is_dyn_rank)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
def test_divnonan_dyn_shape():
"""
Feature: DivNoNan Grad DynamicShape.
Description: Test case of dynamic shape for DivNoNan grad operator.
Expectation: success.
"""
divnonan_test(False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
def test_divnonan_dyn_rank():
"""
Feature: DivNoNan Grad DynamicShape.
Description: Test case of dynamic rank for DivNoNan grad operator.
Expectation: success.
"""
divnonan_test(True)
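# A hedged sanity check of the forward semantics the grad relies on: DivNoNan returns 0
# where the divisor is 0 (instead of inf/nan), so the gradient there is 0 as well.
# Illustrative helper, not part of the original suite:
def _divnonan_semantics_sketch():
    x = Tensor(np.array([2.0, 4.0, 3.0]).astype(np.float32))
    y = Tensor(np.array([2.0, 0.0, 3.0]).astype(np.float32))
    out = ops.DivNoNan()(x, y)
    assert np.allclose(out.asnumpy(), [1.0, 0.0, 1.0])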

View File

@ -0,0 +1,83 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from .test_grad_of_dynamic import TestDynamicGrad
class MaskedFillNet(nn.Cell):
def __init__(self):
super(MaskedFillNet, self).__init__()
self.maskedfill = P.MaskedFill()
def construct(self, input_0, mask, value):
return self.maskedfill(input_0, mask, value)
def run_dynamic_shape():
test_dynamic = TestDynamicGrad(MaskedFillNet())
input_0 = Tensor(np.array([1., 2., 3., 4.]), ms.float32)
mask = Tensor(np.array([True, True, False, True]), ms.bool_)
value = Tensor(0.5, ms.float32)
test_dynamic.test_dynamic_grad_net([input_0, mask, value])
def run_dynamic_rank():
test_dynamic = TestDynamicGrad(MaskedFillNet())
input_0 = Tensor(np.array([1., 2., 3., 4.]), ms.float32)
mask = Tensor(np.array([True, True, False, True]), ms.bool_)
value = Tensor(0.5, ms.float32)
test_dynamic.test_dynamic_grad_net([input_0, mask, value], True)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
def test_dynamic_maskedfill_gpu():
"""
Feature: MaskedFill Grad DynamicShape.
Description: Test cases of dynamic shape and dynamic rank for MaskedFill grad operator on GPU.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
run_dynamic_shape()
run_dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
run_dynamic_shape()
run_dynamic_rank()
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
def test_dynamic_maskedfill_ascend():
"""
Feature: MaskedFill Grad DynamicShape.
Description: Test cases of dynamic shape and dynamic rank for MaskedFill grad operator on Ascend.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
run_dynamic_shape()
run_dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
run_dynamic_shape()
run_dynamic_rank()
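# A hedged forward check of the semantics under test: positions where mask is True take
# `value`, so the output here is [0.5, 0.5, 3.0, 0.5]; in the backward pass the incoming
# gradient flows to `value` at those positions and to `input_0` elsewhere. Illustrative
# helper, not part of the original suite:
def _maskedfill_forward_sketch():
    out = MaskedFillNet()(Tensor(np.array([1., 2., 3., 4.]), ms.float32),
                          Tensor(np.array([True, True, False, True]), ms.bool_),
                          Tensor(0.5, ms.float32))
    assert np.allclose(out.asnumpy(), [0.5, 0.5, 3.0, 0.5])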

View File

@ -0,0 +1,78 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from .test_grad_of_dynamic import TestDynamicGrad
class MatrixDeterminantNet(nn.Cell):
def __init__(self):
super(MatrixDeterminantNet, self).__init__()
self.matrix_determinant = P.MatrixDeterminant()
def construct(self, x):
return self.matrix_determinant(x)
def dynamic_shape():
test_dynamic = TestDynamicGrad(MatrixDeterminantNet())
x = Tensor(np.array([[[-4.5, -1.5], [7.0, 6.0]], [[2.5, 0.5], [3.0, 9.0]]]).astype(np.float32))
test_dynamic.test_dynamic_grad_net(x)
def dynamic_rank():
test_dynamic = TestDynamicGrad(MatrixDeterminantNet())
x = Tensor(np.array([[[-4.5, -1.5], [7.0, 6.0]], [[2.5, 0.5], [3.0, 9.0]]]).astype(np.float32))
test_dynamic.test_dynamic_grad_net(x, True)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_dynamic_matrix_determinant_cpu():
"""
Feature: MatrixDeterminant Grad DynamicShape.
Description: Test cases of dynamic shape and dynamic rank for MatrixDeterminant grad operator on CPU.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
dynamic_shape()
dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
dynamic_shape()
dynamic_rank()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
def test_dynamic_matrix_determinant_gpu():
"""
Feature: MatrixDeterminant Grad DynamicShape.
Description: Test cases of dynamic shape and dynamic rank for MatrixDeterminant grad operator on GPU.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dynamic_shape()
dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dynamic_shape()
dynamic_rank()
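# The analytic gradient that both runs above must reproduce is the classic identity
# d(det A)/dA = det(A) * inv(A)^T. NumPy cross-check for one batch matrix
# (illustrative only):
def _det_grad_reference():
    a = np.array([[-4.5, -1.5], [7.0, 6.0]], dtype=np.float32)
    return np.linalg.det(a) * np.linalg.inv(a).T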

View File

@ -0,0 +1,80 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from .test_grad_of_dynamic import TestDynamicGrad
class MatrixDiagPartV3Net(nn.Cell):
def __init__(self, align='LEFT_RIGHT'):
super(MatrixDiagPartV3Net, self).__init__()
self.matrix_diag_part_v3 = P.array_ops.MatrixDiagPartV3(align=align)
def construct(self, x, k, padding_value):
return self.matrix_diag_part_v3(x, k, padding_value)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
def test_dynamic_shape_matrix_diag_partv3():
"""
Feature: MatrixDiagPartV3 Grad DynamicShape.
Description: Test case of dynamic shape for MatrixDiagPartV3 grad operator on CPU and GPU.
Expectation: success.
"""
context.set_context(mode=context.PYNATIVE_MODE)
align = 'RIGHT_LEFT'
test_dynamic = TestDynamicGrad(MatrixDiagPartV3Net(align))
input_x = Tensor(np.array([[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 8, 7, 6]],
[[5, 4, 3, 2],
[1, 2, 3, 4],
[5, 6, 7, 8]]]), mstype.float32)
k = Tensor(1, mstype.int32)
padding_value = Tensor(0, mstype.float32)
test_dynamic.test_dynamic_grad_net([input_x, k, padding_value], False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
def test_dynamic_rank_matrix_diag_partv3():
"""
Feature: MatrixDiagPartV3 Grad DynamicShape.
Description: Test case of dynamic rank for MatrixDiagPartV3 grad operator on CPU and GPU.
Expectation: success.
"""
context.set_context(mode=context.PYNATIVE_MODE)
align = 'RIGHT_LEFT'
test_dynamic = TestDynamicGrad(MatrixDiagPartV3Net(align))
input_x = Tensor(np.array([[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 8, 7, 6]],
[[5, 4, 3, 2],
[1, 2, 3, 4],
[5, 6, 7, 8]]]), mstype.float32)
k = Tensor(1, mstype.int32)
padding_value = Tensor(0, mstype.float32)
test_dynamic.test_dynamic_grad_net([input_x, k, padding_value], True)
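# What k = 1 selects in the cases above: the first superdiagonal of each batch matrix.
# NumPy cross-check (illustrative only):
def _diag_part_reference():
    batch = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 8, 7, 6]],
                      [[5, 4, 3, 2], [1, 2, 3, 4], [5, 6, 7, 8]]], np.float32)
    # expected: [[2., 7., 6.], [4., 3., 8.]]
    return np.stack([np.diagonal(m, offset=1) for m in batch])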

View File

@ -0,0 +1,83 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops.operations.array_ops import MatrixDiagV3
from .test_grad_of_dynamic import TestDynamicGrad
class MatrixDiagV3Net(nn.Cell):
def __init__(self):
super(MatrixDiagV3Net, self).__init__()
self.matrix_diag_v3 = MatrixDiagV3(align='LEFT_RIGHT')
def construct(self, x, k, num_rows, num_cols, padding_value):
return self.matrix_diag_v3(x, k, num_rows, num_cols, padding_value)
def run_dynamic_shape():
test_dynamic = TestDynamicGrad(MatrixDiagV3Net())
x = Tensor(np.array([[8, 9, 0], [1, 2, 3], [0, 4, 5]]), ms.float32)
k = Tensor(np.array([-1, 1]), ms.int32)
num_rows = Tensor(np.array(3), ms.int32)
num_cols = Tensor(np.array(3), ms.int32)
padding_value = Tensor(np.array(11), ms.float32)
test_dynamic.test_dynamic_grad_net(
[x, k, num_rows, num_cols, padding_value])
def run_dynamic_rank():
test_dynamic = TestDynamicGrad(MatrixDiagV3Net())
x = Tensor(np.array([[8, 9, 0],
[1, 2, 3],
[0, 4, 5]]), ms.float32)
k = Tensor(np.array([-1, 1]), ms.int32)
num_rows = Tensor(np.array(3), ms.int32)
num_cols = Tensor(np.array(3), ms.int32)
padding_value = Tensor(np.array(11), ms.float32)
test_dynamic.test_dynamic_grad_net(
[x, k, num_rows, num_cols, padding_value], True)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_dynamic_matrix_diag_v3_cpu():
"""
Feature: MatrixDiagV3 Grad DynamicShape.
Description: Test cases of dynamic shape and dynamic rank for MatrixDiagV3 grad operator on CPU.
Expectation: success.
"""
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
run_dynamic_shape()
run_dynamic_rank()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
def test_dynamic_matrix_diag_v3_gpu():
"""
Feature: MatrixDiagV3 Grad DynamicShape.
Description: Test cases of dynamic shape and dynamic rank for MatrixDiagV3 grad operator on GPU.
Expectation: success.
"""
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
run_dynamic_shape()
run_dynamic_rank()

View File

@ -0,0 +1,97 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
from mindspore import nn, context, Tensor
from mindspore.ops.operations.sparse_ops import SparseAdd
from .test_grad_of_dynamic import TestDynamicGrad
class NetSparseAdd(nn.Cell):
def __init__(self):
super(NetSparseAdd, self).__init__()
self.sparse_add = SparseAdd()
def construct(self, a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh):
return self.sparse_add(a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh)
def grad_dyn_case(is_dynamic_rank):
test_dynamic = TestDynamicGrad(NetSparseAdd())
value_type = ms.float32
thresh_type = ms.float32
thresh_value = 0
a_indices = Tensor([[0, 1], [1, 2]], ms.int64)
a_values = Tensor([1, 2], value_type)
a_shape = Tensor([3, 4], ms.int64)
b_indices = Tensor([[0, 1], [1, 2]], ms.int64)
b_values = Tensor([1, 2], value_type)
b_shape = Tensor([3, 4], ms.int64)
thresh = Tensor(thresh_value, thresh_type)
test_dynamic.test_dynamic_grad_net([a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh],
is_dynamic_rank)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_shape():
"""
Feature: test SparseAdd dynamic shape on GPU.
Description: input is dynamic shape.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_rank():
"""
Feature: test SparseAdd dynamic rank on GPU.
Description: input is dynamic rank.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_dyn_case(True)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_shape():
"""
Feature: test SparseAdd dynamic shape on CPU.
Description: input is dynamic shape.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_rank():
"""
Feature: test SparseAdd dynamic rank on CPU.
Description: input is dynamic rank.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(True)

View File

@ -0,0 +1,63 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import nn, context, Tensor
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
from .test_grad_of_dynamic import TestDynamicGrad
class NetSparseSegmentSqrtN(nn.Cell):
def __init__(self):
super(NetSparseSegmentSqrtN, self).__init__()
self.sparse_seg_sqrt_n = SparseSegmentSqrtN()
def construct(self, x, indices, seg):
return self.sparse_seg_sqrt_n(x, indices, seg)
def grad_dyn_case(is_dynamic_rank):
test_dynamic = TestDynamicGrad(NetSparseSegmentSqrtN())
x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]).astype(np.float32))
indices = Tensor(np.array([0, 1, 2]).astype(np.int32))
segment_ids = Tensor(np.array([0, 1, 2]).astype(np.int32))
test_dynamic.test_dynamic_grad_net((x, indices, segment_ids), is_dynamic_rank)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_shape():
"""
Feature: test SparseSegmentSqrtN dynamic shape on CPU.
Description: input is dynamic shape.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_rank():
"""
Feature: test SparseSegmentSqrtN dynamic rank on CPU.
Description: input is dynamic rank.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(True)
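# The rule being differentiated: for each output segment, sum the selected rows of x and
# divide by sqrt of the segment size. NumPy sketch (illustrative only; assumes every
# segment is non-empty):
def _sparse_segment_sqrt_n_reference(x, indices, seg_ids):
    out = np.zeros((seg_ids.max() + 1, x.shape[1]), x.dtype)
    for i, s in zip(indices, seg_ids):
        out[s] += x[i]
    return out / np.sqrt(np.bincount(seg_ids))[:, None]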

View File

@ -0,0 +1,64 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
from mindspore import nn, context, Tensor
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
from .test_grad_of_dynamic import TestDynamicGrad
class NetSparseSegmentSqrtNWithNumSegments(nn.Cell):
def __init__(self):
super(NetSparseSegmentSqrtNWithNumSegments, self).__init__()
self.sparse_seg_sqrt_n_with_n_seg = SparseSegmentSqrtNWithNumSegments()
def construct(self, x, indices, seg, num):
return self.sparse_seg_sqrt_n_with_n_seg(x, indices, seg, num)
def grad_dyn_case(is_dynamic_rank):
test_dynamic = TestDynamicGrad(NetSparseSegmentSqrtNWithNumSegments(), skip_convert_out_ids=[0])
x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float16)
indices = Tensor([0, 2, 1], dtype=ms.int32)
segment_ids = Tensor([0, 1, 2], dtype=ms.int32)
num_segments = Tensor([4], dtype=ms.int32)
test_dynamic.test_dynamic_grad_net((x, indices, segment_ids, num_segments), is_dynamic_rank)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_shape():
"""
Feature: test SparseSegmentSqrtNWithNumSegments dynamic shape on CPU.
Description: input is dynamic shape.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_rank():
"""
Feature: test SparseSegmentSqrtNWithNumSegments dynamic rank on CPU.
Description: input is dynamic rank.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(True)

View File

@ -0,0 +1,67 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore
import mindspore.ops.operations.sparse_ops as S
from mindspore import nn, context, Tensor
from .test_grad_of_dynamic import TestDynamicGrad
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
class NetSparseSegmentMeanWithNumSegments(nn.Cell):
def __init__(self):
super(NetSparseSegmentMeanWithNumSegments, self).__init__()
self.sparse_segment_mean_with_num_segments = S.SparseSegmentMeanWithNumSegments()
def construct(self, x, indices, seg_ids, num_segments):
return self.sparse_segment_mean_with_num_segments(x, indices, seg_ids, num_segments)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_sparse_segment_mean_with_num_segments_dynamic_shape():
"""
Feature: SparseSegmentMeanWithNumSegments Grad DynamicShape.
Description: Test case of dynamic shape for SparseSegmentMeanWithNumSegments grad operator on CPU.
Expectation: success.
"""
test_dynamic = TestDynamicGrad(NetSparseSegmentMeanWithNumSegments(), skip_convert_out_ids=[0])
x = Tensor([[0, 2, 0, 0], [0, 1, 1, 0], [2, 0, 2, 0]], dtype=mindspore.float32)
indices = Tensor([0, 2, 1], dtype=mindspore.int32)
segment_ids = Tensor([0, 0, 2], dtype=mindspore.int32)
num_segments = Tensor([4], dtype=mindspore.int32)
inputs = [x, indices, segment_ids, num_segments]
test_dynamic.test_dynamic_grad_net(inputs, False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_sparse_segment_mean_with_num_segments_dynamic_rank():
"""
Feature: SparseSegmentMeanWithNumSegments Grad DynamicShape.
Description: Test case of dynamic rank for SparseSegmentMeanWithNumSegments grad operator on CPU.
Expectation: success.
"""
x = Tensor([[0, 2, 0, 0], [0, 1, 1, 0], [2, 0, 2, 0]], dtype=mindspore.float32)
indices = Tensor([0, 2, 1], dtype=mindspore.int32)
segment_ids = Tensor([0, 0, 2], dtype=mindspore.int32)
num_segments = Tensor([4], dtype=mindspore.int32)
test_dynamic = TestDynamicGrad(NetSparseSegmentMeanWithNumSegments(), skip_convert_out_ids=[0])
inputs = [x, indices, segment_ids, num_segments]
test_dynamic.test_dynamic_grad_net(inputs, True)

View File

@ -0,0 +1,110 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import nn, ops, context, Tensor
from .test_grad_of_dynamic import TestDynamicGrad
class NetSquaredDifference(nn.Cell):
def __init__(self):
super(NetSquaredDifference, self).__init__()
self.squared_diff = ops.SquaredDifference()
def construct(self, x, y):
return self.squared_diff(x, y)
def grad_dyn_case(is_dynamic_rank):
test_dynamic = TestDynamicGrad(NetSquaredDifference())
np.random.seed(1)
x = Tensor(np.random.uniform(10, 20, (3, 4, 5, 2)).astype(np.float16))
y = Tensor(np.random.uniform(40, 50, (3, 4, 5, 2)).astype(np.float16))
test_dynamic.test_dynamic_grad_net([x, y], is_dynamic_rank)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_shape():
"""
Feature: test SquaredDifference dynamic shape on GPU.
Description: input is dynamic shape.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_grad_dynamic_rank():
"""
Feature: test SquaredDifference dynamic rank on GPU.
Description: input is dynamic rank.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_dyn_case(True)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_shape():
"""
Feature: test SquaredDifference dynamic shape on CPU.
Description: input is dynamic shape.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_grad_dynamic_rank():
"""
Feature: test SquaredDifference dynamic rank on CPU.
Description: input is dynamic rank.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_dyn_case(True)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_grad_dynamic_shape():
"""
Feature: test SquaredDifference dynamic shape on Ascend.
Description: input is dynamic shape.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
grad_dyn_case(False)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_grad_dynamic_rank():
"""
Feature: test SquaredDifference dynamic rank on Ascend.
Description: input is dynamic rank.
Expectation: the result matches the static-shape result.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
grad_dyn_case(True)
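# The analytic gradients here are d/dx (x - y)^2 = 2*(x - y) and d/dy = -2*(x - y); the
# dynamic-shape and dynamic-rank runs are expected to match them. NumPy cross-check for
# the inputs used above (illustrative only):
def _squared_difference_grad_reference():
    np.random.seed(1)
    x = np.random.uniform(10, 20, (3, 4, 5, 2)).astype(np.float16)
    y = np.random.uniform(40, 50, (3, 4, 5, 2)).astype(np.float16)
    return 2 * (x - y), -2 * (x - y)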

View File

@ -0,0 +1,63 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import ops, nn, context, Tensor
from .test_grad_of_dynamic import TestDynamicGrad
context.set_context(mode=context.PYNATIVE_MODE)
class NetSub(nn.Cell):
def __init__(self):
super(NetSub, self).__init__()
self.sub = ops.Sub()
def construct(self, x, y):
return self.sub(x, y)
def sub_test(is_dyn_rank):
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
y = Tensor(np.array([[7, 8, 9]]).astype(np.float32))
tester = TestDynamicGrad(NetSub())
tester.test_dynamic_grad_net([x, y], is_dyn_rank)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
def test_sub_dyn_shape():
"""
Feature: Sub Grad DynamicShape.
Description: Test case of dynamic shape for Sub grad operator.
Expectation: success.
"""
sub_test(False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
def test_sub_dyn_rank():
"""
Feature: Sub Grad DynamicShape.
Description: Test case of dynamic rank for Sub grad operator.
Expectation: success.
"""
sub_test(True)
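# Because y broadcasts from (1, 3) to x's (2, 3), its gradient must be reduced back over
# the broadcast axis: with an all-ones sens, dx is all ones and dy == [[-2, -2, -2]].
# NumPy sketch of that reduction (illustrative only):
def _sub_broadcast_grad_reference():
    sens = np.ones((2, 3), np.float32)
    return sens, -sens.sum(axis=0, keepdims=True)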

View File

@ -0,0 +1,62 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import ops, nn, context, Tensor
from .test_grad_of_dynamic import TestDynamicGrad
context.set_context(mode=context.PYNATIVE_MODE)
class NetTopK(nn.Cell):
def __init__(self, k):
super(NetTopK, self).__init__()
self.topk = ops.TopK()
self.k = k
def construct(self, x):
return self.topk(x, self.k)
def topk_test(is_dyn_rank):
x = Tensor(np.array([[1, 2, 3, 4, 5]]).astype(np.float32))
k = 3
tester = TestDynamicGrad(NetTopK(k))
tester.test_dynamic_grad_net([x], is_dyn_rank)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_topk_dyn_shape():
"""
Feature: TopK Grad DynamicShape.
Description: Test case of dynamic shape for TopK grad operator.
Expectation: success.
"""
topk_test(False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_topk_dyn_rank():
"""
Feature: TopK Grad DynamicShape.
Description: Test case of dynamic rank for TopK grad operator.
Expectation: success.
"""
topk_test(True)
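# TopK's backward scatters the gradient of `values` back to the positions of the selected
# elements and leaves the rest zero. NumPy sketch for the input above (illustrative only):
def _topk_grad_reference():
    x = np.array([[1, 2, 3, 4, 5]], np.float32)
    idx = np.argsort(-x, axis=-1)[:, :3]  # indices of the top 3 values
    dx = np.zeros_like(x)
    np.put_along_axis(dx, idx, 1.0, axis=-1)  # with an all-ones sens
    return dx  # [[0., 0., 1., 1., 1.]]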

View File

@ -0,0 +1,61 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.ops.operations.math_ops as M
from mindspore import nn, context, Tensor
from .test_grad_of_dynamic import TestDynamicGrad
context.set_context(mode=context.PYNATIVE_MODE)
class NetTrace(nn.Cell):
def __init__(self):
super(NetTrace, self).__init__()
self.trace = M.Trace()
def construct(self, x):
return self.trace(x)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
def test_trace_dynamic_shape():
"""
Feature: Trace Grad DynamicShape.
Description: Test case of dynamic shape for Trace grad operator on CPU and GPU.
Expectation: success.
"""
test_dynamic = TestDynamicGrad(NetTrace())
x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
test_dynamic.test_dynamic_grad_net(x, False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
def test_trace_dynamic_rank():
"""
Feature: Trace Grad DynamicShape.
Description: Test case of dynamic rank for Trace grad operator on CPU and GPU.
Expectation: success.
"""
test_dynamic = TestDynamicGrad(NetTrace())
x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
test_dynamic.test_dynamic_grad_net(x, True)
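# The trace gradient is the identity scaled by the incoming sens: d(tr X)/dX = I. For the
# 3x3 input above both runs should therefore return np.eye(3) (illustrative expectation):
def _trace_grad_reference():
    return np.eye(3, dtype=np.float32)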

View File

@ -0,0 +1,67 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore import context
from .test_grad_of_dynamic import TestDynamicGrad
class NetTranspose(nn.Cell):
def __init__(self):
super(NetTranspose, self).__init__()
self.transpose = P.Transpose()
def construct(self, x, perm):
return self.transpose(x, perm)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
def test_dynamic_shape_transpose():
"""
Feature: Transpose Grad DynamicShape.
Description: Test case of dynamic shape for Transpose grad operator on CPU, GPU and Ascend.
Expectation: success.
"""
context.set_context(mode=context.PYNATIVE_MODE)
test_dynamic = TestDynamicGrad(NetTranspose())
x = Tensor(np.array([[1, 2, 3], [3, 2, 1]]))
perm = (0, 1)
test_dynamic.test_dynamic_grad_net([x, perm], False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
def test_dynamic_rank_transpose():
"""
Feature: Transpose Grad DynamicShape.
Description: Test case of dynamic rank for Transpose grad operator on CPU, GPU and Ascend.
Expectation: success.
"""
context.set_context(mode=context.PYNATIVE_MODE)
test_dynamic = TestDynamicGrad(NetTranspose())
x = Tensor(np.array([[1, 2, 3], [3, 2, 1]]))
perm = (0, 1)
test_dynamic.test_dynamic_grad_net([x, perm], True)
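# Transpose's backward applies the inverse permutation to the incoming gradient; with
# perm = (0, 1) that is the identity. NumPy sketch of the inverse-perm rule
# (illustrative only):
def _inverse_perm(perm):
    return tuple(np.argsort(perm))  # e.g. the inverse of (1, 0) is (1, 0) itself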

View File

@ -0,0 +1,81 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from .test_grad_of_dynamic import TestDynamicGrad
class TruncateModNet(nn.Cell):
def __init__(self):
super(TruncateModNet, self).__init__()
self.truncate_mod = P.TruncateMod()
def construct(self, x, y):
return self.truncate_mod(x, y)
def run_dynamic_shape():
test_dynamic = TestDynamicGrad(TruncateModNet())
x = Tensor(np.array([2, 4, -1]), ms.int32)
y = Tensor(np.array([3, 3, 3]), ms.int32)
test_dynamic.test_dynamic_grad_net([x, y])
def run_dynamic_rank():
test_dynamic = TestDynamicGrad(TruncateModNet())
x = Tensor(np.array([2, 4, -1]), ms.int32)
y = Tensor(np.array([3, 3, 3]), ms.int32)
test_dynamic.test_dynamic_grad_net([x, y], True)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
def test_dynamic_truncate_mod_gpu():
"""
Feature: TruncateMod Grad DynamicShape.
Description: Test cases of dynamic shape and dynamic rank for TruncateMod grad operator on GPU.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
run_dynamic_shape()
run_dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
run_dynamic_shape()
run_dynamic_rank()
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
def test_dynamic_truncate_mod_ascend():
"""
Feature: TruncateMod Grad DynamicShape.
Description: Test cases of dynamic shape and dynamic rank for TruncateMod grad operator on Ascend.
Expectation: success.
"""
# Graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
run_dynamic_shape()
run_dynamic_rank()
# PyNative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
run_dynamic_shape()
run_dynamic_rank()
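# TruncateMod is the C-style remainder: the result keeps the sign of the dividend, so for
# x = [2, 4, -1] and y = [3, 3, 3] it returns [2, 1, -1], where Python's % would give 2
# for the last element. The NumPy equivalent is np.fmod (illustrative cross-check):
def _truncate_mod_reference():
    return np.fmod(np.array([2, 4, -1]), np.array([3, 3, 3]))  # [ 2  1 -1]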

View File

@ -0,0 +1,70 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore import context
from .test_grad_of_dynamic import TestDynamicGrad
class UnsortedSegmentMaxNet(nn.Cell):
def __init__(self, num_segments):
super(UnsortedSegmentMaxNet, self).__init__()
self.unsorted_segment_max = P.UnsortedSegmentMax()
self.num_segments = num_segments
def construct(self, data, ids):
return self.unsorted_segment_max(data, ids, self.num_segments)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
def test_dynamic_shape_unsorted_segment_max():
"""
Feature: UnsortedSegmentMax Grad DynamicShape.
Description: Test case of dynamic shape for UnsortedSegmentMax grad operator on CPU and GPU.
Expectation: success.
"""
context.set_context(mode=context.PYNATIVE_MODE)
num_segments = 2
test_dynamic = TestDynamicGrad(UnsortedSegmentMaxNet(num_segments))
input_x = Tensor(
np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
test_dynamic.test_dynamic_grad_net([input_x, segment_ids], False)
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
def test_dynamic_rank_unsorted_segment_max():
"""
Feature: UnsortedSegmentMax Grad DynamicShape.
Description: Test case of dynamic rank for UnsortedSegmentMax grad operator on CPU and GPU.
Expectation: success.
"""
context.set_context(mode=context.PYNATIVE_MODE)
num_segments = 2
test_dynamic = TestDynamicGrad(UnsortedSegmentMaxNet(num_segments))
input_x = Tensor(
np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
test_dynamic.test_dynamic_grad_net([input_x, segment_ids], True)
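# UnsortedSegmentMax reduces rows sharing a segment id with an elementwise max: here
# segment 0 is row 0 and segment 1 is max(row 1, row 2) == [4., 5., 6.]. NumPy sketch
# (illustrative only):
def _unsorted_segment_max_reference():
    data = np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]], np.float32)
    ids = np.array([0, 1, 1])
    return np.stack([data[ids == s].max(axis=0) for s in range(2)])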