forked from mindspore-Ecosystem/mindspore
!42469 update grad for dynamic_shape
Merge pull request !42469 from chengbin/ds_r1.9
This commit is contained in:
commit
520dfd5c6e
|
@ -17,38 +17,12 @@
|
||||||
#include "backend/common/pass/reduce_sum_optimizer.h"
|
#include "backend/common/pass/reduce_sum_optimizer.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "include/common/utils/anfalgo.h"
|
#include "include/common/utils/anfalgo.h"
|
||||||
|
#include "utils/ms_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
const int axis_input_index = 2;
|
const int axis_input_index = 2;
|
||||||
|
|
||||||
bool IsNeedComputeRank(const CNodePtr &cnode) {
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
auto axis_input = cnode->input(axis_input_index);
|
|
||||||
MS_EXCEPTION_IF_NULL(axis_input);
|
|
||||||
if (!IsValueNode<ValueTuple>(axis_input)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto value_node = axis_input->cast<ValueNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(value_node);
|
|
||||||
auto value = value_node->value();
|
|
||||||
MS_EXCEPTION_IF_NULL(value);
|
|
||||||
if (value->isa<ValueTuple>()) {
|
|
||||||
auto value_tuple = value->cast<ValueTuplePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
|
||||||
if (value_tuple->value().empty()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
for (auto &iter : value_tuple->value()) {
|
|
||||||
auto item = GetValue<int64_t>(iter->cast<ScalarPtr>());
|
|
||||||
if (item < 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
AnfNodePtr ReduceSumOptimizer::NewRankOp(const AnfNodePtr &cnode, const KernelGraphPtr &kernel_graph) const {
|
AnfNodePtr ReduceSumOptimizer::NewRankOp(const AnfNodePtr &cnode, const KernelGraphPtr &kernel_graph) const {
|
||||||
|
@ -168,9 +142,6 @@ const AnfNodePtr ReduceSumOptimizer::Process(const FuncGraphPtr &func_graph, con
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||||
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
|
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
if (AnfUtils::IsDimUnknown(cnode) && IsNeedComputeRank(cnode)) {
|
|
||||||
return InsertAssistNode(cnode, kernel_graph);
|
|
||||||
}
|
|
||||||
return NewAssistValueNode(cnode, kernel_graph);
|
return NewAssistValueNode(cnode, kernel_graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,12 @@ constexpr size_t kBiasAddGradOutputsNum = 1;
|
||||||
void BiasAddGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
void BiasAddGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||||
input_shape_ = Convert2SizeTClipNeg(AnfAlgo::GetInputDeviceShape(kernel_node, 0));
|
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||||
|
if (IsDynamic(shape)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
input_shape_ = Convert2SizeT(shape);
|
||||||
if (input_shape_.size() < 2) {
|
if (input_shape_.size() < 2) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input tensor's dimension must be at least 2, but got "
|
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input tensor's dimension must be at least 2, but got "
|
||||||
<< input_shape_.size();
|
<< input_shape_.size();
|
||||||
|
|
|
@ -97,6 +97,7 @@ void TopKCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||||
<< "', the dimension of input must be greater than 0, but got empty input.";
|
<< "', the dimension of input must be greater than 0, but got empty input.";
|
||||||
}
|
}
|
||||||
|
outer_size_ = 1;
|
||||||
for (size_t i = 0; i < x_shape_.size() - 1; ++i) {
|
for (size_t i = 0; i < x_shape_.size() - 1; ++i) {
|
||||||
outer_size_ *= x_shape_[i];
|
outer_size_ *= x_shape_[i];
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,9 +50,9 @@ int LinSpaceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
|
||||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
auto input1_shape = Convert2SizeTClipNeg(inputs[kIndex0]->GetShapeVector());
|
auto input1_shape = inputs[kIndex0]->GetShapeVector();
|
||||||
auto input2_shape = Convert2SizeTClipNeg(inputs[kIndex1]->GetShapeVector());
|
auto input2_shape = inputs[kIndex1]->GetShapeVector();
|
||||||
auto output_shape = Convert2SizeTClipNeg(outputs[kIndex0]->GetShapeVector());
|
auto output_shape = outputs[kIndex0]->GetShapeVector();
|
||||||
is_null_input_ = CHECK_SHAPE_NULL(input1_shape, kernel_name_, "start") ||
|
is_null_input_ = CHECK_SHAPE_NULL(input1_shape, kernel_name_, "start") ||
|
||||||
CHECK_SHAPE_NULL(input2_shape, kernel_name_, "stop") ||
|
CHECK_SHAPE_NULL(input2_shape, kernel_name_, "stop") ||
|
||||||
CHECK_SHAPE_NULL(output_shape, kernel_name_, "output");
|
CHECK_SHAPE_NULL(output_shape, kernel_name_, "output");
|
||||||
|
|
|
@ -43,18 +43,22 @@ bool SvdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vect
|
||||||
int SvdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
int SvdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
const std::vector<KernelTensorPtr> &outputs,
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||||
|
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
auto input_shape = inputs[kIndex0]->GetShapeVector();
|
||||||
|
if (IsDynamicRank(input_shape)) {
|
||||||
|
return KRET_OK;
|
||||||
|
}
|
||||||
DestroyResource();
|
DestroyResource();
|
||||||
ResetResource();
|
ResetResource();
|
||||||
|
input_shape_ = Convert2SizeTClipNeg(input_shape);
|
||||||
input_shape_ = std::vector<size_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
|
|
||||||
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
|
|
||||||
total_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies<size_t>());
|
total_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies<size_t>());
|
||||||
is_null_input_ = (total_size_ == 0);
|
is_null_input_ = (total_size_ == 0);
|
||||||
if (is_null_input_) {
|
if (is_null_input_) {
|
||||||
init_size_lists_func_(this);
|
init_size_lists_func_(this);
|
||||||
return 0;
|
return KRET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
dims_ = input_shape_.size();
|
dims_ = input_shape_.size();
|
||||||
if (dims_ < kDim2) {
|
if (dims_ < kDim2) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dimensions must >= 2, but got [" << dims_;
|
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dimensions must >= 2, but got [" << dims_;
|
||||||
|
|
|
@ -26,5 +26,19 @@ MS_REG_GPU_KERNEL_ONE(
|
||||||
Conv3DBackpropFilter,
|
Conv3DBackpropFilter,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
|
||||||
Conv3dGradFilterGpuKernelMod, half)
|
Conv3dGradFilterGpuKernelMod, half)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(Conv3DBackpropFilter,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
Conv3dGradFilterGpuKernelMod, float)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(Conv3DBackpropFilter,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
Conv3dGradFilterGpuKernelMod, half)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -31,6 +31,9 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
constexpr int kStaticInputNum = 2;
|
||||||
|
constexpr int kDynamicInputNum = 3;
|
||||||
|
|
||||||
constexpr size_t kInputDimSize = 5;
|
constexpr size_t kInputDimSize = 5;
|
||||||
constexpr size_t kInDimIdxForN = 0;
|
constexpr size_t kInDimIdxForN = 0;
|
||||||
constexpr size_t kInDimIdxForC = 1;
|
constexpr size_t kInDimIdxForC = 1;
|
||||||
|
@ -127,8 +130,8 @@ class Conv3dGradFilterGpuKernelMod : public NativeGpuKernelMod {
|
||||||
InitResource();
|
InitResource();
|
||||||
|
|
||||||
size_t input_num = inputs.size();
|
size_t input_num = inputs.size();
|
||||||
if (input_num != 2) {
|
if (input_num != kStaticInputNum && input_num != kDynamicInputNum) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2, but got " << input_num;
|
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
|
||||||
}
|
}
|
||||||
size_t output_num = outputs.size();
|
size_t output_num = outputs.size();
|
||||||
if (output_num != 1) {
|
if (output_num != 1) {
|
||||||
|
|
|
@ -26,5 +26,19 @@ MS_REG_GPU_KERNEL_ONE(
|
||||||
Conv3DBackpropInput,
|
Conv3DBackpropInput,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||||
Conv3dGradInputGpuKernelMod, half)
|
Conv3dGradInputGpuKernelMod, half)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(Conv3DBackpropInput,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
Conv3dGradInputGpuKernelMod, float)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(Conv3DBackpropInput,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat16),
|
||||||
|
Conv3dGradInputGpuKernelMod, half)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -32,7 +32,8 @@ namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
constexpr int kNumDims = 5;
|
constexpr int kNumDims = 5;
|
||||||
constexpr int kConvDims = 3;
|
constexpr int kConvDims = 3;
|
||||||
constexpr int kInputNum = 2;
|
constexpr int kStaticInputNum = 2;
|
||||||
|
constexpr int kDynamicInputNum = 3;
|
||||||
constexpr size_t kInDimIdxForN = 0;
|
constexpr size_t kInDimIdxForN = 0;
|
||||||
constexpr size_t kInDimIdxForC = 1;
|
constexpr size_t kInDimIdxForC = 1;
|
||||||
constexpr size_t kInDimIdxForD = 2;
|
constexpr size_t kInDimIdxForD = 2;
|
||||||
|
@ -115,8 +116,8 @@ class Conv3dGradInputGpuKernelMod : public NativeGpuKernelMod {
|
||||||
InitResource();
|
InitResource();
|
||||||
|
|
||||||
size_t input_num = inputs.size();
|
size_t input_num = inputs.size();
|
||||||
if (input_num != kInputNum) {
|
if (input_num != kStaticInputNum && input_num != kDynamicInputNum) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2, but got " << input_num;
|
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << input_num;
|
||||||
}
|
}
|
||||||
size_t output_num = outputs.size();
|
size_t output_num = outputs.size();
|
||||||
if (output_num != 1) {
|
if (output_num != 1) {
|
||||||
|
|
|
@ -37,6 +37,7 @@ void InsertCast(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t i, con
|
||||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i)};
|
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i)};
|
||||||
auto cast = graph->NewCNode(inputs);
|
auto cast = graph->NewCNode(inputs);
|
||||||
MS_EXCEPTION_IF_NULL(cast);
|
MS_EXCEPTION_IF_NULL(cast);
|
||||||
|
common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(cast_type), cast);
|
||||||
auto cast_shape = {common::AnfAlgo::GetPrevNodeOutputDetailShape(node, i)};
|
auto cast_shape = {common::AnfAlgo::GetPrevNodeOutputDetailShape(node, i)};
|
||||||
common::AnfAlgo::SetOutputTypeAndDetailShape({cast_type}, cast_shape, cast.get());
|
common::AnfAlgo::SetOutputTypeAndDetailShape({cast_type}, cast_shape, cast.get());
|
||||||
FuncGraphManagerPtr manager = graph->manager();
|
FuncGraphManagerPtr manager = graph->manager();
|
||||||
|
|
|
@ -221,10 +221,20 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p
|
||||||
MS_EXCEPTION_IF_NULL(y->shape());
|
MS_EXCEPTION_IF_NULL(y->shape());
|
||||||
auto x_shp = x->shape()->shape();
|
auto x_shp = x->shape()->shape();
|
||||||
auto y_shp = y->shape()->shape();
|
auto y_shp = y->shape()->shape();
|
||||||
const size_t SHAPE_SIZE = 2;
|
|
||||||
if (x_shp.size() != SHAPE_SIZE || y_shp.size() != SHAPE_SIZE) {
|
TypePtr x_type = x->element()->GetTypeTrack();
|
||||||
MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
|
if (x_type->type_id() == TypeId::kNumberTypeInt8) {
|
||||||
|
x_type = kInt32;
|
||||||
}
|
}
|
||||||
|
if (primitive->HasAttr("cast_type")) {
|
||||||
|
auto out_type = primitive->GetAttr("cast_type");
|
||||||
|
MS_EXCEPTION_IF_NULL(out_type);
|
||||||
|
if (!out_type->isa<Type>()) {
|
||||||
|
MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
|
||||||
|
}
|
||||||
|
x_type = out_type->cast<TypePtr>();
|
||||||
|
}
|
||||||
|
|
||||||
ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
|
ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
|
||||||
ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
|
ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
|
||||||
bool transpose_a = GetValue<bool>(transpose_a_ptr);
|
bool transpose_a = GetValue<bool>(transpose_a_ptr);
|
||||||
|
@ -233,18 +243,23 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p
|
||||||
ShapeVector x_max_shape = x->shape()->max_shape();
|
ShapeVector x_max_shape = x->shape()->max_shape();
|
||||||
ShapeVector y_min_shape = y->shape()->min_shape();
|
ShapeVector y_min_shape = y->shape()->min_shape();
|
||||||
ShapeVector y_max_shape = y->shape()->max_shape();
|
ShapeVector y_max_shape = y->shape()->max_shape();
|
||||||
// Additional check for dynamic shape
|
|
||||||
// Last infer will be real shape values
|
if (IsDynamicRank(x_shp) || IsDynamicRank(y_shp)) {
|
||||||
bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
|
ShapeVector ret_shape{UNKNOWN_RANK};
|
||||||
bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
|
return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape));
|
||||||
if (x_not_dyn && y_not_dyn) {
|
|
||||||
auto x_col = x_shp[(transpose_a ? 0 : 1)];
|
|
||||||
auto y_row = y_shp[(transpose_b ? 1 : 0)];
|
|
||||||
if (x_col != y_row) {
|
|
||||||
MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
|
|
||||||
<< ". In MatMul x_col and y_row should be equal.";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const size_t SHAPE_SIZE = 2;
|
||||||
|
if (x_shp.size() != SHAPE_SIZE || y_shp.size() != SHAPE_SIZE) {
|
||||||
|
MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
|
||||||
|
}
|
||||||
|
auto x_col = x_shp[(transpose_a ? 0 : 1)];
|
||||||
|
auto y_row = y_shp[(transpose_b ? 1 : 0)];
|
||||||
|
if (x_col != y_row && x_col >= 0 && y_row >= 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
|
||||||
|
<< ". In MatMul x_col and y_row should be equal.";
|
||||||
|
}
|
||||||
|
|
||||||
ShapeVector ret_shape;
|
ShapeVector ret_shape;
|
||||||
ShapeVector ret_min_shape;
|
ShapeVector ret_min_shape;
|
||||||
ShapeVector ret_max_shape;
|
ShapeVector ret_max_shape;
|
||||||
|
@ -259,18 +274,6 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p
|
||||||
make_shape(ret_shape, x_shp, y_shp);
|
make_shape(ret_shape, x_shp, y_shp);
|
||||||
make_shape(ret_min_shape, x_min_shape, y_min_shape);
|
make_shape(ret_min_shape, x_min_shape, y_min_shape);
|
||||||
make_shape(ret_max_shape, x_max_shape, y_max_shape);
|
make_shape(ret_max_shape, x_max_shape, y_max_shape);
|
||||||
TypePtr x_type = x->element()->GetTypeTrack();
|
|
||||||
if (x_type->type_id() == TypeId::kNumberTypeInt8) {
|
|
||||||
x_type = kInt32;
|
|
||||||
}
|
|
||||||
if (primitive->HasAttr("cast_type")) {
|
|
||||||
auto out_type = primitive->GetAttr("cast_type");
|
|
||||||
MS_EXCEPTION_IF_NULL(out_type);
|
|
||||||
if (!out_type->isa<Type>()) {
|
|
||||||
MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
|
|
||||||
}
|
|
||||||
x_type = out_type->cast<TypePtr>();
|
|
||||||
}
|
|
||||||
return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
|
return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -83,16 +83,15 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive,
|
||||||
if (IsDynamicRank(x_shp) || IsDynamicRank(y_shp)) {
|
if (IsDynamicRank(x_shp) || IsDynamicRank(y_shp)) {
|
||||||
return std::make_shared<abstract::Shape>(ShapeVector({UNKNOWN_RANK}));
|
return std::make_shared<abstract::Shape>(ShapeVector({UNKNOWN_RANK}));
|
||||||
}
|
}
|
||||||
auto context = MsContext::GetInstance();
|
|
||||||
constexpr size_t x_dim_limit = 3;
|
constexpr size_t x_dim_limit = 3;
|
||||||
constexpr size_t y_dim_limit = 2;
|
constexpr size_t y_dim_limit = 2;
|
||||||
if ((!IsDynamic(x_shp)) && (!IsDynamic(y_shp))) {
|
|
||||||
if (x_shp.size() < x_dim_limit || y_shp.size() < y_dim_limit) {
|
bool not_dynamic_shape = (!IsDynamic(x_shp)) && !(IsDynamic(y_shp));
|
||||||
MS_EXCEPTION(ValueError)
|
if (not_dynamic_shape && (x_shp.size() < x_dim_limit || y_shp.size() < y_dim_limit)) {
|
||||||
<< "For '" << prim_name
|
MS_EXCEPTION(ValueError)
|
||||||
<< "', input 'x' must be greater or equal to 3, input 'y' must be greater or equal to 2. But got 'x': "
|
<< "For '" << prim_name
|
||||||
<< x_shp.size() << ", 'y': " << y_shp.size() << ".";
|
<< "', input 'x' must be greater or equal to 3, input 'y' must be greater or equal to 2. But got 'x': "
|
||||||
}
|
<< x_shp.size() << ", 'y': " << y_shp.size() << ".";
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr size_t offset = 2;
|
constexpr size_t offset = 2;
|
||||||
|
@ -106,7 +105,7 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive,
|
||||||
int64_t y_row = y_last[static_cast<size_t>(transpose_b)];
|
int64_t y_row = y_last[static_cast<size_t>(transpose_b)];
|
||||||
if (std::find(x_shp.begin(), x_shp.end(), -1) == x_shp.end() &&
|
if (std::find(x_shp.begin(), x_shp.end(), -1) == x_shp.end() &&
|
||||||
std::find(y_shp.begin(), y_shp.end(), -1) == y_shp.end()) {
|
std::find(y_shp.begin(), y_shp.end(), -1) == y_shp.end()) {
|
||||||
if (x_col != y_row) {
|
if (not_dynamic_shape && x_col != y_row) {
|
||||||
MS_EXCEPTION(ValueError) << "For " << prim_name << " evaluator shapes of inputs can not do this operator, "
|
MS_EXCEPTION(ValueError) << "For " << prim_name << " evaluator shapes of inputs can not do this operator, "
|
||||||
<< "got " << x_col << " and " << y_row << " , with x1 shape " << x_shp
|
<< "got " << x_col << " and " << y_row << " , with x1 shape " << x_shp
|
||||||
<< "(transpose_a=" << transpose_a << "})"
|
<< "(transpose_a=" << transpose_a << "})"
|
||||||
|
@ -117,11 +116,7 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive,
|
||||||
(void)primitive->AddAttr("transpose_x2", transpose_b_ptr);
|
(void)primitive->AddAttr("transpose_x2", transpose_b_ptr);
|
||||||
// Additional check for dynamic shape
|
// Additional check for dynamic shape
|
||||||
// Last infer will be real shape values
|
// Last infer will be real shape values
|
||||||
bool x_not_dyn =
|
if (not_dynamic_shape) {
|
||||||
std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != abstract::Shape::SHP_ANY; });
|
|
||||||
bool y_not_dyn =
|
|
||||||
std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != abstract::Shape::SHP_ANY; });
|
|
||||||
if (x_not_dyn && y_not_dyn) {
|
|
||||||
size_t x_offset = x_shp.size() - offset;
|
size_t x_offset = x_shp.size() - offset;
|
||||||
size_t y_offset = y_shp.size() - offset;
|
size_t y_offset = y_shp.size() - offset;
|
||||||
auto x_c = x_shp[x_offset + (transpose_a ? 0 : 1)];
|
auto x_c = x_shp[x_offset + (transpose_a ? 0 : 1)];
|
||||||
|
|
|
@ -40,11 +40,14 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::v
|
||||||
(void)CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||||
auto input_shape = shape_map[kShape];
|
auto input_shape = shape_map[kShape];
|
||||||
|
auto bias_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||||
|
if (IsDynamicRank(input_shape) || IsDynamicRank(bias_shape)) {
|
||||||
|
return std::make_shared<abstract::Shape>(ShapeVector{UNKNOWN_RANK});
|
||||||
|
}
|
||||||
const int64_t x_min_rank = 2;
|
const int64_t x_min_rank = 2;
|
||||||
const int64_t x_max_rank = 5;
|
const int64_t x_max_rank = 5;
|
||||||
CheckAndConvertUtils::CheckInRange("dims of input_x", input_shape.size(), kIncludeBoth, {x_min_rank, x_max_rank},
|
CheckAndConvertUtils::CheckInRange("dims of input_x", input_shape.size(), kIncludeBoth, {x_min_rank, x_max_rank},
|
||||||
prim_name);
|
prim_name);
|
||||||
auto bias_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
|
||||||
(void)CheckAndConvertUtils::CheckInteger("bias rank", SizeToLong(bias_shape.size()), kEqual, 1, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("bias rank", SizeToLong(bias_shape.size()), kEqual, 1, prim_name);
|
||||||
const int64_t x_size = 2;
|
const int64_t x_size = 2;
|
||||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, x_size, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, x_size, prim_name);
|
||||||
|
|
|
@ -52,7 +52,7 @@ void SetConv3DBackpropFilterPadList(const PrimitivePtr &primitive, const std::ve
|
||||||
CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
|
CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
|
||||||
// default pad mode is valid
|
// default pad mode is valid
|
||||||
int64_t pad_mode;
|
int64_t pad_mode;
|
||||||
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
|
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, false);
|
||||||
ShapeVector pad_list = {0, 0, 0, 0, 0, 0};
|
ShapeVector pad_list = {0, 0, 0, 0, 0, 0};
|
||||||
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
|
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
|
||||||
if ((attr_pad_list_prt != nullptr) && !attr_pad_list_prt->isa<None>()) {
|
if ((attr_pad_list_prt != nullptr) && !attr_pad_list_prt->isa<None>()) {
|
||||||
|
@ -261,6 +261,7 @@ class Conv3DBackpropFilterInfer : public abstract::OpInferBase {
|
||||||
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
|
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
|
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
|
||||||
}
|
}
|
||||||
|
std::set<int64_t> GetValueDependArgIndices() const override { return {kConv3DBackpropFilterFilterSizeIndex}; }
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(Conv3DBackpropFilter, prim::kPrimConv3DBackpropFilter, Conv3DBackpropFilterInfer,
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(Conv3DBackpropFilter, prim::kPrimConv3DBackpropFilter, Conv3DBackpropFilterInfer,
|
||||||
|
|
|
@ -45,7 +45,7 @@ void SetConv3DBackpropInputPadList(const PrimitivePtr &primitive, const std::vec
|
||||||
CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
|
CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
|
||||||
// default pad mode is valid
|
// default pad mode is valid
|
||||||
int64_t pad_mode;
|
int64_t pad_mode;
|
||||||
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
|
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, false);
|
||||||
ShapeVector pad_list = {0, 0, 0, 0, 0, 0};
|
ShapeVector pad_list = {0, 0, 0, 0, 0, 0};
|
||||||
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
|
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
|
||||||
if ((attr_pad_list_prt != nullptr) && (!attr_pad_list_prt->isa<None>())) {
|
if ((attr_pad_list_prt != nullptr) && (!attr_pad_list_prt->isa<None>())) {
|
||||||
|
@ -258,6 +258,7 @@ class Conv3DBackpropInputInfer : public abstract::OpInferBase {
|
||||||
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
|
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
|
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
|
||||||
}
|
}
|
||||||
|
std::set<int64_t> GetValueDependArgIndices() const override { return {kConv3DBackpropInputSizeIndex}; }
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(Conv3DBackpropInput, prim::kPrimConv3DBackpropInput, Conv3DBackpropInputInfer, false);
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(Conv3DBackpropInput, prim::kPrimConv3DBackpropInput, Conv3DBackpropInputInfer, false);
|
||||||
|
|
|
@ -121,6 +121,10 @@ abstract::ShapePtr MatrixDiagV3InferShape(const PrimitivePtr &primitive,
|
||||||
auto padding_value_rank = SizeToLong(padding_shape.size());
|
auto padding_value_rank = SizeToLong(padding_shape.size());
|
||||||
constexpr int64_t number_one = 1;
|
constexpr int64_t number_one = 1;
|
||||||
constexpr int64_t number_two = 2;
|
constexpr int64_t number_two = 2;
|
||||||
|
if (IsDynamicRank(x_shape)) {
|
||||||
|
ShapeVector out_shape = {UNKNOWN_RANK};
|
||||||
|
return std::make_shared<abstract::Shape>(out_shape);
|
||||||
|
}
|
||||||
CheckAndConvertUtils::CheckInRange<int64_t>("rank of 'k'", k_rank, kIncludeBoth, {0, number_one}, prim_name);
|
CheckAndConvertUtils::CheckInRange<int64_t>("rank of 'k'", k_rank, kIncludeBoth, {0, number_one}, prim_name);
|
||||||
(void)CheckAndConvertUtils::CheckInteger("rank of 'num_rows'", row_rank, kEqual, 0, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("rank of 'num_rows'", row_rank, kEqual, 0, prim_name);
|
||||||
(void)CheckAndConvertUtils::CheckInteger("rank of 'num_cols'", col_rank, kEqual, 0, prim_name);
|
(void)CheckAndConvertUtils::CheckInteger("rank of 'num_cols'", col_rank, kEqual, 0, prim_name);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -34,6 +34,14 @@ abstract::BaseShapePtr SvdInferShape(const PrimitivePtr &prim, const std::vector
|
||||||
auto full_matrices = GetValue<bool>(prim->GetAttr(kAttrFullMatrices));
|
auto full_matrices = GetValue<bool>(prim->GetAttr(kAttrFullMatrices));
|
||||||
|
|
||||||
auto a_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
auto a_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||||
|
if (IsDynamicRank(a_shape) || IsDynamic(a_shape)) {
|
||||||
|
ShapeVector dyn_shape{UNKNOWN_RANK};
|
||||||
|
std::vector<abstract::BaseShapePtr> shape_tuple;
|
||||||
|
(void)shape_tuple.emplace_back(std::make_shared<abstract::Shape>(dyn_shape));
|
||||||
|
(void)shape_tuple.emplace_back(std::make_shared<abstract::Shape>(dyn_shape));
|
||||||
|
(void)shape_tuple.emplace_back(std::make_shared<abstract::Shape>(dyn_shape));
|
||||||
|
return std::make_shared<abstract::TupleShape>(shape_tuple);
|
||||||
|
}
|
||||||
auto ndim = a_shape.size();
|
auto ndim = a_shape.size();
|
||||||
(void)CheckAndConvertUtils::CheckInteger("ndim", SizeToLong(ndim), kGreaterEqual, kSizeTwo, prim->name());
|
(void)CheckAndConvertUtils::CheckInteger("ndim", SizeToLong(ndim), kGreaterEqual, kSizeTwo, prim->name());
|
||||||
auto m = a_shape[ndim - kIndexTwo];
|
auto m = a_shape[ndim - kIndexTwo];
|
||||||
|
|
|
@ -36,16 +36,16 @@ abstract::TupleShapePtr TopKInferShape(const PrimitivePtr &primitive, const std:
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
|
||||||
auto x_shape = shape_map[kShape];
|
auto x_shape = shape_map[kShape];
|
||||||
|
if (IsDynamicRank(x_shape)) {
|
||||||
|
abstract::BaseShapePtr out_shape_ptr = std::make_shared<abstract::Shape>(ShapeVector{UNKNOWN_RANK});
|
||||||
|
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape_ptr, out_shape_ptr});
|
||||||
|
}
|
||||||
int64_t k_v = 0;
|
int64_t k_v = 0;
|
||||||
// 2rd input is a Tensor when TopK is a dynamic shape operator
|
// 2rd input is a Tensor when TopK is a dynamic shape operator
|
||||||
if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
|
if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
|
||||||
auto k_ptr = input_args[kInputIndex1]->BuildValue();
|
auto k_ptr = input_args[kInputIndex1]->BuildValue();
|
||||||
MS_EXCEPTION_IF_NULL(k_ptr);
|
MS_EXCEPTION_IF_NULL(k_ptr);
|
||||||
if (k_ptr->isa<tensor::Tensor>()) {
|
k_v = CheckAndConvertUtils::CheckTensorIntValue("k", k_ptr, prim_name)[0];
|
||||||
auto k_tensor_ptr = k_ptr->cast<tensor::TensorPtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(k_tensor_ptr);
|
|
||||||
k_v = *static_cast<int64_t *>(k_tensor_ptr->data_c());
|
|
||||||
}
|
|
||||||
} else if (input_args[kInputIndex1]->isa<abstract::AbstractScalar>()) {
|
} else if (input_args[kInputIndex1]->isa<abstract::AbstractScalar>()) {
|
||||||
k_v = GetValue<int64_t>(input_args[kInputIndex1]->BuildValue());
|
k_v = GetValue<int64_t>(input_args[kInputIndex1]->BuildValue());
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -29,6 +29,7 @@ from mindspore.ops.primitive import constexpr
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.common.tensor import RowTensor
|
from mindspore.common.tensor import RowTensor
|
||||||
from mindspore.ops._utils.utils import range_op, get_1d_shape, generate_shape_index, is_shape_unknown
|
from mindspore.ops._utils.utils import range_op, get_1d_shape, generate_shape_index, is_shape_unknown
|
||||||
|
from .._grad.grad_base import dyn_rank, convert_to_tensor, dyn_invert_permutation, dyn_size, dyn_ones, dyn_fill
|
||||||
|
|
||||||
reduce_sum = P.ReduceSum()
|
reduce_sum = P.ReduceSum()
|
||||||
unsorted_segment_sum = P.UnsortedSegmentSum()
|
unsorted_segment_sum = P.UnsortedSegmentSum()
|
||||||
|
@ -370,11 +371,20 @@ def _transpose_perm_positive(perm):
|
||||||
return tuple(res)
|
return tuple(res)
|
||||||
|
|
||||||
|
|
||||||
|
def _dyn_transpose_perm_positive(perm):
|
||||||
|
return (perm + dyn_size(perm)) % (dyn_size(perm))
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.Transpose)
|
@bprop_getters.register(P.Transpose)
|
||||||
def get_bprop_transpose(self):
|
def get_bprop_transpose(self):
|
||||||
"""Generate bprop for Transpose"""
|
"""Generate bprop for Transpose"""
|
||||||
|
|
||||||
def bprop(x, perm, out, dout):
|
def bprop(x, perm, out, dout):
|
||||||
|
is_mutable, perm = convert_to_tensor(perm)
|
||||||
|
if is_mutable:
|
||||||
|
perm = _dyn_transpose_perm_positive(perm)
|
||||||
|
return transpose(dout, dyn_invert_permutation(perm)), zeros_like(perm)
|
||||||
|
|
||||||
perm = _transpose_perm_positive(perm)
|
perm = _transpose_perm_positive(perm)
|
||||||
return transpose(dout, invert_permutation(perm)), zeros_like(perm)
|
return transpose(dout, invert_permutation(perm)), zeros_like(perm)
|
||||||
|
|
||||||
|
@ -1019,18 +1029,36 @@ def _gather_drop_negatives(params,
|
||||||
select = P.Select()
|
select = P.Select()
|
||||||
|
|
||||||
if zero_clipped_indices is None:
|
if zero_clipped_indices is None:
|
||||||
zero_clipped_indices = maximum(ids, zeros_like(ids))
|
if is_shape_unknown(shape_op(ids)):
|
||||||
|
zero_ids = dyn_fill(ids.dtype, dyn_shape_op(ids), 0)
|
||||||
|
else:
|
||||||
|
zero_ids = zeros_like(ids)
|
||||||
|
zero_clipped_indices = maximum(ids, zero_ids)
|
||||||
gathered = gather(params, zero_clipped_indices, 0)
|
gathered = gather(params, zero_clipped_indices, 0)
|
||||||
|
zero_slice = zeros_like(gathered)
|
||||||
if is_positive is None:
|
if is_positive is None:
|
||||||
is_positive = greater_equal(ids, 0)
|
is_positive = greater_equal(ids, 0)
|
||||||
is_positive_shape = shape_op(is_positive)
|
is_positive_shape = shape_op(is_positive)
|
||||||
broadcastable_shape = is_positive_shape
|
|
||||||
for _ in range(rank(gathered) - rank(is_positive)):
|
|
||||||
broadcastable_shape += (1,)
|
|
||||||
is_positive = reshape(is_positive, broadcastable_shape)
|
|
||||||
gathered_shape = shape_op(gathered)
|
gathered_shape = shape_op(gathered)
|
||||||
is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
|
if is_shape_unknown(gathered_shape) or is_shape_unknown(is_positive_shape):
|
||||||
zero_slice = zeros_like(gathered)
|
gathered_shape = dyn_shape_op(gathered)
|
||||||
|
rank_gathered = dyn_rank(gathered)
|
||||||
|
fill_gathered = dyn_fill(mstype.int64, gathered_shape, 1)
|
||||||
|
is_positive_shape = dyn_shape_op(is_positive)
|
||||||
|
rank_positive = dyn_rank(is_positive)
|
||||||
|
if rank_gathered - rank_positive > 0:
|
||||||
|
padded_size = F.expand_dims(rank_gathered - rank_positive, 0)
|
||||||
|
padded_shape = dyn_ones(padded_size, is_positive_shape.dtype)
|
||||||
|
is_positive_shape = P.Concat(-1)((is_positive_shape, padded_shape))
|
||||||
|
is_positive = reshape(is_positive, is_positive_shape)
|
||||||
|
is_positive = logical_and(is_positive, F.cast(fill_gathered, mstype.bool_))
|
||||||
|
zero_slice = dyn_fill(gathered.dtype, gathered_shape, 0)
|
||||||
|
else:
|
||||||
|
broadcastable_shape = is_positive_shape
|
||||||
|
for _ in range(rank(gathered) - rank(is_positive)):
|
||||||
|
broadcastable_shape += (1,)
|
||||||
|
is_positive = reshape(is_positive, broadcastable_shape)
|
||||||
|
is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
|
||||||
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
|
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1059,7 +1087,13 @@ def get_bprop_unsorted_segment_sum(self):
|
||||||
"""Generate bprop for UnsortedSegmentSum"""
|
"""Generate bprop for UnsortedSegmentSum"""
|
||||||
|
|
||||||
def bprop(x, segment_ids, num_segments, out, dout):
|
def bprop(x, segment_ids, num_segments, out, dout):
|
||||||
return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_like(segment_ids), \
|
segment_shape = shape_op(segment_ids)
|
||||||
|
if is_shape_unknown(segment_shape):
|
||||||
|
segment_shape = dyn_shape_op(segment_ids)
|
||||||
|
zeros_segment = dyn_fill(segment_ids.dtype, segment_shape, 0)
|
||||||
|
else:
|
||||||
|
zeros_segment = zeros_like(segment_ids)
|
||||||
|
return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_segment, \
|
||||||
zeros_like(num_segments)
|
zeros_like(num_segments)
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
|
@ -97,12 +97,12 @@ def dyn_rank(tensor):
|
||||||
return dyn_shape(dyn_shape(tensor))[0]
|
return dyn_shape(dyn_shape(tensor))[0]
|
||||||
|
|
||||||
|
|
||||||
def dyn_size(tensor):
|
def dyn_size(tensor, dtype=mstype.int64):
|
||||||
"""get the size of tensor"""
|
"""get the size of tensor"""
|
||||||
shape = dyn_shape(tensor)
|
shape = dyn_shape(tensor)
|
||||||
shape = cast(shape, mstype.float32)
|
shape = cast(shape, mstype.float32)
|
||||||
size = P.ReduceProd()(shape)
|
size = P.ReduceProd()(shape)
|
||||||
size = cast(size, mstype.int32)
|
size = cast(size, dtype)
|
||||||
return size
|
return size
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,7 +117,7 @@ def create_tensor_by_element(ori_tuple, data_type=mstype.int64):
|
||||||
return ori_tuple
|
return ori_tuple
|
||||||
|
|
||||||
|
|
||||||
def dyn_invert_premutation(prem):
|
def dyn_invert_permutation(prem):
|
||||||
"""get the invert premutation of tensor"""
|
"""get the invert premutation of tensor"""
|
||||||
indices = P.ExpandDims()(prem, -1)
|
indices = P.ExpandDims()(prem, -1)
|
||||||
end = dyn_size(prem)
|
end = dyn_size(prem)
|
||||||
|
|
|
@ -53,8 +53,20 @@ def dyn_binop_grad_common(x, y, dx, dy):
|
||||||
shape_of_x = dyn_shape_op(x)
|
shape_of_x = dyn_shape_op(x)
|
||||||
shape_of_y = dyn_shape_op(y)
|
shape_of_y = dyn_shape_op(y)
|
||||||
rx, ry = DynamicBroadcastGradientArgs()(shape_of_x, shape_of_y)
|
rx, ry = DynamicBroadcastGradientArgs()(shape_of_x, shape_of_y)
|
||||||
dx = reduce_sum(dx, rx)
|
dx_origin_dtype = dx.dtype
|
||||||
dy = reduce_sum(dy, ry)
|
if dx_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
|
||||||
|
dx = F.cast(dx, mstype.float32)
|
||||||
|
dx = reduce_sum(dx, rx)
|
||||||
|
dx = F.cast(dx, dx_origin_dtype)
|
||||||
|
else:
|
||||||
|
dx = reduce_sum(dx, rx)
|
||||||
|
dy_origin_dtype = dy.dtype
|
||||||
|
if dy_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
|
||||||
|
dy = F.cast(dy, mstype.float32)
|
||||||
|
dy = reduce_sum(dy, ry)
|
||||||
|
dy = F.cast(dy, dy_origin_dtype)
|
||||||
|
else:
|
||||||
|
dy = reduce_sum(dy, ry)
|
||||||
reduce_dx = reshape(dx, shape_of_x)
|
reduce_dx = reshape(dx, shape_of_x)
|
||||||
reduce_dy = reshape(dy, shape_of_y)
|
reduce_dy = reshape(dy, shape_of_y)
|
||||||
return reduce_dx, reduce_dy
|
return reduce_dx, reduce_dy
|
||||||
|
|
|
@ -18,7 +18,7 @@ import numpy as np
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
from mindspore.ops.operations import nn_ops as nps
|
from mindspore.ops.operations import nn_ops as nps
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from .grad_base import bprop_getters
|
from .grad_base import bprop_getters, dyn_size, create_tensor_by_element
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
|
@ -61,6 +61,17 @@ def bias_add_gradgrad_helper(shape, bias_shape, data_format):
|
||||||
return tuple(expanded_shape), tuple(tile_mults)
|
return tuple(expanded_shape), tuple(tile_mults)
|
||||||
|
|
||||||
|
|
||||||
|
def bias_add_gradgrad_helper_dynamic(shape, bias_shape, data_format):
|
||||||
|
"""Helper function of BiasGradGrad to calculate expanded shape(dynamic version)."""
|
||||||
|
if data_format == "NCHW":
|
||||||
|
expanded_shape = P.Concat(0)((P.OnesLike()(shape[:1]), bias_shape, P.OnesLike()(shape[2:])))
|
||||||
|
tile_mults = P.Concat(0)((shape[:1], [1], shape[2:]))
|
||||||
|
else:
|
||||||
|
expanded_shape = P.Concat(0)((P.OnesLike()(shape[:-1]), bias_shape))
|
||||||
|
tile_mults = P.Concat(0)((shape[:-1], [1]))
|
||||||
|
return tuple(expanded_shape), tuple(tile_mults)
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(G.BiasAddGrad)
|
@bprop_getters.register(G.BiasAddGrad)
|
||||||
def get_bprop_bias_add_grad(self):
|
def get_bprop_bias_add_grad(self):
|
||||||
"""Grad definition for `BiasAddGrad` operation."""
|
"""Grad definition for `BiasAddGrad` operation."""
|
||||||
|
@ -70,10 +81,19 @@ def get_bprop_bias_add_grad(self):
|
||||||
def bprop(dy, out, dout):
|
def bprop(dy, out, dout):
|
||||||
reshape = P.Reshape()
|
reshape = P.Reshape()
|
||||||
tile = P.Tile()
|
tile = P.Tile()
|
||||||
|
dyn_shape = P.TensorShape()
|
||||||
expanded_shape, tile_mults = bias_add_gradgrad_helper(dy.shape, dout.shape, data_format)
|
if is_shape_unknown(dy) or is_shape_unknown(dout):
|
||||||
expanded_grad = reshape(dout, expanded_shape)
|
dy_shape = dyn_shape(dy)
|
||||||
tiled_grad = tile(expanded_grad, tile_mults)
|
dout_shape = dyn_shape(dout)
|
||||||
|
expanded_shape, tile_mults = bias_add_gradgrad_helper_dynamic(dy_shape, dout_shape, data_format)
|
||||||
|
expanded_grad = reshape(dout, create_tensor_by_element(expanded_shape))
|
||||||
|
tiled_grad = tile(expanded_grad, create_tensor_by_element(tile_mults))
|
||||||
|
else:
|
||||||
|
dy_shape = dy.shape
|
||||||
|
dout_shape = dout.shape
|
||||||
|
expanded_shape, tile_mults = bias_add_gradgrad_helper(dy_shape, dout_shape, data_format)
|
||||||
|
expanded_grad = reshape(dout, expanded_shape)
|
||||||
|
tiled_grad = tile(expanded_grad, tile_mults)
|
||||||
return (tiled_grad,)
|
return (tiled_grad,)
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
@ -122,8 +142,14 @@ def get_bprop_conv3d(self):
|
||||||
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
|
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
|
||||||
)
|
)
|
||||||
get_shape = P.Shape()
|
get_shape = P.Shape()
|
||||||
|
get_dyn_shape = P.TensorShape()
|
||||||
|
|
||||||
def bprop(x, w, out, dout):
|
def bprop(x, w, out, dout):
|
||||||
|
if is_shape_unknown(get_shape(x)) or is_shape_unknown(get_shape(w)):
|
||||||
|
dx = input_grad(w, dout, get_dyn_shape(x))
|
||||||
|
dw = filter_grad(x, dout, get_dyn_shape(w))
|
||||||
|
return dx, dw
|
||||||
|
|
||||||
dx = input_grad(w, dout, get_shape(x))
|
dx = input_grad(w, dout, get_shape(x))
|
||||||
dw = filter_grad(x, dout, get_shape(w))
|
dw = filter_grad(x, dout, get_shape(w))
|
||||||
return dx, dw
|
return dx, dw
|
||||||
|
@ -145,8 +171,14 @@ def get_bprop_conv3d_transpose(self):
|
||||||
out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
|
out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
|
||||||
pad=pad_list, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
|
pad=pad_list, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
|
||||||
)
|
)
|
||||||
|
get_dyn_shape = P.TensorShape()
|
||||||
|
|
||||||
def bprop(x, w, out, dout):
|
def bprop(x, w, out, dout):
|
||||||
|
if is_shape_unknown(F.shape(w)):
|
||||||
|
dx = input_grad(dout, w)
|
||||||
|
dw = filter_grad(dout, x, get_dyn_shape(w))
|
||||||
|
return dx, dw
|
||||||
|
|
||||||
dx = input_grad(dout, w)
|
dx = input_grad(dout, w)
|
||||||
dw = filter_grad(dout, x, F.shape(w))
|
dw = filter_grad(dout, x, F.shape(w))
|
||||||
return dx, dw
|
return dx, dw
|
||||||
|
@ -928,10 +960,12 @@ def get_bprop_top_kv2(self):
|
||||||
scatter = P.ScatterNd()
|
scatter = P.ScatterNd()
|
||||||
expand_dims = P.ExpandDims()
|
expand_dims = P.ExpandDims()
|
||||||
shape_op = P.Shape()
|
shape_op = P.Shape()
|
||||||
|
dyn_shape = P.TensorShape()
|
||||||
reshape_op = P.Reshape()
|
reshape_op = P.Reshape()
|
||||||
dtype = P.DType()
|
dtype = P.DType()
|
||||||
|
cast = P.Cast()
|
||||||
|
|
||||||
def bprop(input_x, k, out, dout):
|
def _bprop_static(input_x, k, out, dout):
|
||||||
in_shape = shape_op(input_x)
|
in_shape = shape_op(input_x)
|
||||||
in_lastdim = in_shape[-1]
|
in_lastdim = in_shape[-1]
|
||||||
|
|
||||||
|
@ -958,6 +992,37 @@ def get_bprop_top_kv2(self):
|
||||||
in_shape)
|
in_shape)
|
||||||
return out_grad, zeros_like(k)
|
return out_grad, zeros_like(k)
|
||||||
|
|
||||||
|
def _bprop_dynshape(input_x, k, out, dout):
|
||||||
|
in_shape = dyn_shape(input_x)
|
||||||
|
in_lastdim = in_shape[-1]
|
||||||
|
|
||||||
|
indices = out[1]
|
||||||
|
ind_shape = dyn_shape(indices)
|
||||||
|
ind_lastdim = ind_shape[-1]
|
||||||
|
|
||||||
|
ind_2d = reshape_op(indices, create_tensor_by_element((-1, ind_lastdim)))
|
||||||
|
outerdim = dyn_shape(ind_2d)[0]
|
||||||
|
|
||||||
|
# [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
|
||||||
|
range_flatten_index = P.Range()(cast(0, mstype.int64), outerdim * in_lastdim, in_lastdim)
|
||||||
|
|
||||||
|
# expand_dims to (k, 1), then broadcast
|
||||||
|
ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), create_tensor_by_element((-1,)))
|
||||||
|
in_shape_1d = expand_dims(dyn_size(input_x, mstype.int64), -1)
|
||||||
|
|
||||||
|
out_grad = reshape_op(
|
||||||
|
scatter(
|
||||||
|
expand_dims(ind, -1),
|
||||||
|
reshape_op(dout[0], create_tensor_by_element((-1,))),
|
||||||
|
in_shape_1d),
|
||||||
|
in_shape)
|
||||||
|
return out_grad, zeros_like(k)
|
||||||
|
|
||||||
|
def bprop(input_x, k, out, dout):
|
||||||
|
if is_shape_unknown(shape_op(input_x)):
|
||||||
|
return _bprop_dynshape(input_x, k, out, dout)
|
||||||
|
return _bprop_static(input_x, k, out, dout)
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@ -1239,7 +1304,7 @@ def get_bprop_binary_cross_entropy(self):
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.BCEWithLogitsLoss)
|
@bprop_getters.register(P.BCEWithLogitsLoss)
|
||||||
def get_bprop_ce_with_logits_loss(self):
|
def get_bprop_bce_with_logits_loss(self):
|
||||||
"""Grad definition for `BCEWithLogitsLoss` operation."""
|
"""Grad definition for `BCEWithLogitsLoss` operation."""
|
||||||
reduction = self.reduction
|
reduction = self.reduction
|
||||||
mul = P.Mul()
|
mul = P.Mul()
|
||||||
|
@ -1249,6 +1314,7 @@ def get_bprop_ce_with_logits_loss(self):
|
||||||
size = P.Size()
|
size = P.Size()
|
||||||
neg = P.Neg()
|
neg = P.Neg()
|
||||||
log = P.Log()
|
log = P.Log()
|
||||||
|
shape = P.Shape()
|
||||||
|
|
||||||
def bprop(predict, target, weight, pos_weight, out, dout):
|
def bprop(predict, target, weight, pos_weight, out, dout):
|
||||||
sigmoid_input = sigmoid(predict)
|
sigmoid_input = sigmoid(predict)
|
||||||
|
@ -1263,8 +1329,10 @@ def get_bprop_ce_with_logits_loss(self):
|
||||||
dx = mul(dx, weight)
|
dx = mul(dx, weight)
|
||||||
grad_target = mul(grad_target, weight)
|
grad_target = mul(grad_target, weight)
|
||||||
if reduction == 'mean':
|
if reduction == 'mean':
|
||||||
dx = dx / size(dx)
|
dx_size = dyn_size(dx) if is_shape_unknown(shape(dx)) else size(dx)
|
||||||
grad_target = grad_target / size(target)
|
target_size = dyn_size(target) if is_shape_unknown(shape(target)) else size(target)
|
||||||
|
dx = dx / dx_size
|
||||||
|
grad_target = grad_target / target_size
|
||||||
return dx, grad_target, zeros_like(weight), zeros_like(pos_weight)
|
return dx, grad_target, zeros_like(weight), zeros_like(pos_weight)
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
@ -1414,7 +1482,7 @@ def get_bprop_conv2d_backprop_filter(self):
|
||||||
|
|
||||||
def bprop(dy, x, filter_size, out, dout):
|
def bprop(dy, x, filter_size, out, dout):
|
||||||
x_shape = get_shape(x)
|
x_shape = get_shape(x)
|
||||||
if -1 in x_shape:
|
if is_shape_unknown(x_shape):
|
||||||
x_shape = get_dyn_shape(x)
|
x_shape = get_dyn_shape(x)
|
||||||
dw_dx = input_grad(dy, dout, x_shape)
|
dw_dx = input_grad(dy, dout, x_shape)
|
||||||
dw_dy = filter_grad(x, dout)
|
dw_dy = filter_grad(x, dout)
|
||||||
|
|
|
@ -19,14 +19,19 @@ import mindspore
|
||||||
|
|
||||||
from .. import Tensor
|
from .. import Tensor
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
|
from .. import operations as P
|
||||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
from ..operations import math_ops as math
|
from ..operations import math_ops as math
|
||||||
from ..operations import linalg_ops as linalg
|
from ..operations import linalg_ops as linalg
|
||||||
from ..operations import array_ops as arrays
|
from ..operations import array_ops as arrays
|
||||||
from ..primitive import constexpr
|
from ..primitive import constexpr
|
||||||
from .._grad.grad_base import bprop_getters
|
from .._grad.grad_base import bprop_getters
|
||||||
|
from .._grad.grad_base import dyn_rank
|
||||||
|
from .._utils.utils import is_shape_unknown
|
||||||
|
|
||||||
_shape = arrays.Shape()
|
_shape = arrays.Shape()
|
||||||
|
_dyn_shape = arrays.TensorShape()
|
||||||
|
|
||||||
_dtype = arrays.DType()
|
_dtype = arrays.DType()
|
||||||
_cast = arrays.Cast()
|
_cast = arrays.Cast()
|
||||||
_transpose = arrays.Transpose()
|
_transpose = arrays.Transpose()
|
||||||
|
@ -47,12 +52,19 @@ def _raise_value_error(*info):
|
||||||
|
|
||||||
|
|
||||||
def _matrix_transpose(a):
|
def _matrix_transpose(a):
|
||||||
dims = a.ndim
|
"""Transpose last two axes"""
|
||||||
if dims < 2:
|
if is_shape_unknown(_shape(a)):
|
||||||
_raise_value_error(
|
dims = dyn_rank(a)
|
||||||
"To do _matrix_transpose for input a's ndim is not greater or equal to 2, which is invalid.")
|
axes = P.Range()(P.Cast()(0, mindspore.int64), dims, P.Cast()(1, mindspore.int64))
|
||||||
axes = F.make_range(0, dims)
|
axes = P.Concat(axis=-1)((axes[:-2], axes[-1:], axes[-2:-1]))
|
||||||
axes = axes[:-2] + (axes[-1],) + (axes[-2],)
|
else:
|
||||||
|
dims = a.ndim
|
||||||
|
if dims < 2:
|
||||||
|
_raise_value_error(
|
||||||
|
"To do _matrix_transpose for input a's ndim is not greater or equal to 2, which is invalid: {}."
|
||||||
|
.format(dims))
|
||||||
|
axes = F.make_range(0, dims)
|
||||||
|
axes = axes[:-2] + (axes[-1],) + (axes[-2],)
|
||||||
return _transpose(a, axes)
|
return _transpose(a, axes)
|
||||||
|
|
||||||
|
|
||||||
|
@ -76,14 +88,26 @@ def _make_zero_matrix(shape, dtype):
|
||||||
|
|
||||||
|
|
||||||
def _matrix_diag(diagonal):
|
def _matrix_diag(diagonal):
|
||||||
|
"""Do matrix diagnoal"""
|
||||||
diagonal_shape = _shape(diagonal)
|
diagonal_shape = _shape(diagonal)
|
||||||
|
if is_shape_unknown(diagonal_shape):
|
||||||
|
diagonal_shape = _dyn_shape(diagonal)
|
||||||
|
row = P.Cast()(diagonal_shape[-1], mindspore.int32)
|
||||||
|
return arrays.MatrixDiagV3()(diagonal, _k_0, row, row, P.Cast()(0, _dtype(diagonal)))
|
||||||
|
|
||||||
row = _make_tensor(diagonal_shape[-1], mindspore.int32)
|
row = _make_tensor(diagonal_shape[-1], mindspore.int32)
|
||||||
return arrays.MatrixDiagV3()(diagonal, _k_0, row, row, _make_tensor(0, _dtype(diagonal)))
|
return arrays.MatrixDiagV3()(diagonal, _k_0, row, row, _make_tensor(0, _dtype(diagonal)))
|
||||||
|
|
||||||
|
|
||||||
def _mat_mul(x, y):
|
def _mat_mul(x, y):
|
||||||
|
"""Do matmul"""
|
||||||
shape = _shape(x)
|
shape = _shape(x)
|
||||||
if len(shape) > 2:
|
if is_shape_unknown(shape):
|
||||||
|
shape = _dyn_shape(x)
|
||||||
|
tensor_rank = dyn_rank(x)
|
||||||
|
else:
|
||||||
|
tensor_rank = len(shape)
|
||||||
|
if tensor_rank > 2:
|
||||||
return math.BatchMatMul()(x, y)
|
return math.BatchMatMul()(x, y)
|
||||||
return math.MatMul()(x, y)
|
return math.MatMul()(x, y)
|
||||||
|
|
||||||
|
@ -106,12 +130,16 @@ def get_bprop_svd(self):
|
||||||
return (da,)
|
return (da,)
|
||||||
|
|
||||||
a_shape = _shape(a)
|
a_shape = _shape(a)
|
||||||
if len(a_shape) < 2:
|
if is_shape_unknown(a_shape):
|
||||||
|
a_shape = _dyn_shape(a)
|
||||||
|
tensor_rank = dyn_rank(a)
|
||||||
|
else:
|
||||||
|
tensor_rank = len(a_shape)
|
||||||
|
if tensor_rank < 2:
|
||||||
_raise_value_error(
|
_raise_value_error(
|
||||||
"For input a's ndim is not greater or equal to 2, which is invalid.")
|
"For input a's ndim is not greater or equal to 2, which is invalid.")
|
||||||
m = a_shape[-2]
|
m = a_shape[-2]
|
||||||
n = a_shape[-1]
|
n = a_shape[-1]
|
||||||
|
|
||||||
s, u, v = out
|
s, u, v = out
|
||||||
ds, du, dv = dout
|
ds, du, dv = dout
|
||||||
use_adjoint = False
|
use_adjoint = False
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.ops.operations.math_ops import Trace, Bernoulli, Renorm
|
from mindspore.ops.operations.math_ops import Trace, Bernoulli, Renorm
|
||||||
|
from mindspore.ops._utils.utils import is_shape_unknown
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
import mindspore.numpy as mnp
|
import mindspore.numpy as mnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -27,7 +28,7 @@ from ..operations.math_ops import Real, Imag, Complex, Angle
|
||||||
from ..operations.math_ops import ComplexAbs
|
from ..operations.math_ops import ComplexAbs
|
||||||
from ..operations.math_ops import Sinc
|
from ..operations.math_ops import Sinc
|
||||||
from ..functional import broadcast_gradient_args
|
from ..functional import broadcast_gradient_args
|
||||||
from .._grad.grad_base import bprop_getters
|
from .._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_rank
|
||||||
from .._grad.grad_math_ops import binop_grad_common
|
from .._grad.grad_math_ops import binop_grad_common
|
||||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
from ..operations import _grad_ops as G
|
from ..operations import _grad_ops as G
|
||||||
|
@ -55,9 +56,11 @@ from ..operations.math_ops import CholeskySolve
|
||||||
from ..operations.math_ops import AddV2
|
from ..operations.math_ops import AddV2
|
||||||
from ..operations.math_ops import TridiagonalMatMul
|
from ..operations.math_ops import TridiagonalMatMul
|
||||||
from ..operations.math_ops import Logit
|
from ..operations.math_ops import Logit
|
||||||
|
from .._utils.utils import is_shape_unknown
|
||||||
|
|
||||||
|
|
||||||
transpose = P.Transpose()
|
transpose = P.Transpose()
|
||||||
|
dyn_shape_op = P.TensorShape()
|
||||||
_conj = P.Conj()
|
_conj = P.Conj()
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,6 +70,11 @@ def _generate_perm(x_dim):
|
||||||
return perm
|
return perm
|
||||||
|
|
||||||
|
|
||||||
|
def _dyn_generate_perm(x_dim):
|
||||||
|
perm = P.Range()(P.Cast()(0, x_dim.dtype), x_dim - 2, P.Cast()(1, x_dim.dtype))
|
||||||
|
return perm
|
||||||
|
|
||||||
|
|
||||||
def _adjoint(a):
|
def _adjoint(a):
|
||||||
return cholesky_transpose(_conj(a))
|
return cholesky_transpose(_conj(a))
|
||||||
|
|
||||||
|
@ -110,10 +118,20 @@ def get_bprop_cdist(self):
|
||||||
|
|
||||||
def bprop(input_x, input_y, out, dout):
|
def bprop(input_x, input_y, out, dout):
|
||||||
dout_shape = F.shape(dout)
|
dout_shape = F.shape(dout)
|
||||||
dout_dim = len(dout_shape)
|
if is_shape_unknown(dout_shape):
|
||||||
dout_perm_part1 = _generate_perm(dout_dim)
|
dout_dim = dyn_rank(dout)
|
||||||
dout_perm_part2 = (dout_dim - 1, dout_dim - 2)
|
dout_perm_part2 = create_tensor_by_element(
|
||||||
dout_perm = dout_perm_part1 + dout_perm_part2
|
(dout_dim - 1, dout_dim - 2))
|
||||||
|
if dout_dim <= 2:
|
||||||
|
dout_perm = dout_perm_part2
|
||||||
|
else:
|
||||||
|
dout_perm_part1 = _dyn_generate_perm(dout_dim)
|
||||||
|
dout_perm = P.Concat(0)((dout_perm_part1, dout_perm_part2))
|
||||||
|
else:
|
||||||
|
dout_dim = len(dout_shape)
|
||||||
|
dout_perm_part1 = _generate_perm(dout_dim)
|
||||||
|
dout_perm_part2 = (dout_dim - 1, dout_dim - 2)
|
||||||
|
dout_perm = dout_perm_part1 + dout_perm_part2
|
||||||
out_perm = dout_perm
|
out_perm = dout_perm
|
||||||
dout_transpose = transpose(dout, dout_perm)
|
dout_transpose = transpose(dout, dout_perm)
|
||||||
out_transpose = transpose(out, out_perm)
|
out_transpose = transpose(out, out_perm)
|
||||||
|
@ -484,8 +502,16 @@ def get_bprop_matrix_determinant(self):
|
||||||
inverse_op = P.MatrixInverse(adjoint=True)
|
inverse_op = P.MatrixInverse(adjoint=True)
|
||||||
shape_op = P.Shape()
|
shape_op = P.Shape()
|
||||||
reshape = P.Reshape()
|
reshape = P.Reshape()
|
||||||
|
concat = P.Concat(0)
|
||||||
|
|
||||||
def bprop(x, out, dout):
|
def bprop(x, out, dout):
|
||||||
|
if is_shape_unknown(shape_op(x)):
|
||||||
|
x_adj_inv = inverse_op(x)
|
||||||
|
out_shape = dyn_shape_op(out)
|
||||||
|
ones = create_tensor_by_element((1, 1))
|
||||||
|
multipliers = reshape(dout * out, concat((out_shape, ones)))
|
||||||
|
dx = multipliers * x_adj_inv
|
||||||
|
return (dx,)
|
||||||
x_adj_inv = inverse_op(x)
|
x_adj_inv = inverse_op(x)
|
||||||
multipliers = reshape(dout * out, shape_op(out) + (1, 1))
|
multipliers = reshape(dout * out, shape_op(out) + (1, 1))
|
||||||
dx = multipliers * x_adj_inv
|
dx = multipliers * x_adj_inv
|
||||||
|
@ -902,7 +928,11 @@ def get_bprop_trace(self):
|
||||||
|
|
||||||
def bprop(x, out, dout):
|
def bprop(x, out, dout):
|
||||||
shape = shape_op(x)
|
shape = shape_op(x)
|
||||||
dx = input_grad(dout, cast(to_array(shape), mstype.int64))
|
if is_shape_unknown(shape):
|
||||||
|
shape = dyn_shape_op(x)
|
||||||
|
dx = input_grad(dout, shape)
|
||||||
|
else:
|
||||||
|
dx = input_grad(dout, cast(to_array(shape), mstype.int64))
|
||||||
return (dx,)
|
return (dx,)
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
@ -1041,7 +1071,7 @@ def get_bprop_tridiagonal_matmul(self):
|
||||||
maindiag_grad = reduce_sum(rhs_conj * grad, -1)
|
maindiag_grad = reduce_sum(rhs_conj * grad, -1)
|
||||||
subdiag_grad = reduce_sum(_rightshift(rhs_conj) * grad, -1)
|
subdiag_grad = reduce_sum(_rightshift(rhs_conj) * grad, -1)
|
||||||
rhs_grad = _rightshift(superdiag_conj * grad) + maindiag_conj * grad + \
|
rhs_grad = _rightshift(superdiag_conj * grad) + maindiag_conj * grad + \
|
||||||
_leftshift(subdiag_conj * grad)
|
_leftshift(subdiag_conj * grad)
|
||||||
superdiag_grad = expand_dims(superdiag_grad, -2)
|
superdiag_grad = expand_dims(superdiag_grad, -2)
|
||||||
maindiag_grad = expand_dims(maindiag_grad, -2)
|
maindiag_grad = expand_dims(maindiag_grad, -2)
|
||||||
subdiag_grad = expand_dims(subdiag_grad, -2)
|
subdiag_grad = expand_dims(subdiag_grad, -2)
|
||||||
|
|
|
@ -25,6 +25,7 @@ from mindspore.ops.operations.sparse_ops import SparseSegmentSumWithNumSegments
|
||||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
|
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
|
||||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
|
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
|
||||||
from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments
|
from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments
|
||||||
|
from mindspore.ops._utils.utils import is_shape_unknown
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
|
@ -36,6 +37,7 @@ from .._grad.grad_base import bprop_getters
|
||||||
from .._utils.utils import is_shape_unknown
|
from .._utils.utils import is_shape_unknown
|
||||||
|
|
||||||
# Unused parameters are placeholders.
|
# Unused parameters are placeholders.
|
||||||
|
dyn_shape_op = P.TensorShape()
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
@ -50,6 +52,7 @@ def get_bprop_sparse_softmax(self):
|
||||||
sparse_dense_cwise_add = SparseDenseCwiseAdd()
|
sparse_dense_cwise_add = SparseDenseCwiseAdd()
|
||||||
reduce_sum = P.ReduceSum(keep_dims=True)
|
reduce_sum = P.ReduceSum(keep_dims=True)
|
||||||
mul = P.Mul()
|
mul = P.Mul()
|
||||||
|
|
||||||
def bprop(indices, values, shape, out, dout):
|
def bprop(indices, values, shape, out, dout):
|
||||||
default_values = _create_tensor(0, values.dtype)
|
default_values = _create_tensor(0, values.dtype)
|
||||||
out_dout = mul(out, dout)
|
out_dout = mul(out, dout)
|
||||||
|
@ -106,7 +109,6 @@ def get_bprop_sparse_segment_sqrt_n(self):
|
||||||
"""Grad definition for `SparseSegmentSqrtN` operation."""
|
"""Grad definition for `SparseSegmentSqrtN` operation."""
|
||||||
input_grad = G.SparseSegmentSqrtNGrad()
|
input_grad = G.SparseSegmentSqrtNGrad()
|
||||||
shape = P.Shape()
|
shape = P.Shape()
|
||||||
dyn_shape_op = P.TensorShape()
|
|
||||||
|
|
||||||
def bprop(x, indices, segment_ids, out, dout):
|
def bprop(x, indices, segment_ids, out, dout):
|
||||||
shape_x = shape(x)
|
shape_x = shape(x)
|
||||||
|
@ -127,7 +129,6 @@ def get_bprop_sparse_segment_sqrt_n_with_num_segments(self):
|
||||||
"""Grad definition for `SparseSegmentSqrtNWithNumSegments` operation."""
|
"""Grad definition for `SparseSegmentSqrtNWithNumSegments` operation."""
|
||||||
input_grad = G.SparseSegmentSqrtNGrad()
|
input_grad = G.SparseSegmentSqrtNGrad()
|
||||||
shape = P.Shape()
|
shape = P.Shape()
|
||||||
dyn_shape_op = P.TensorShape()
|
|
||||||
|
|
||||||
def bprop(x, indices, segment_ids, num_segments, out, dout):
|
def bprop(x, indices, segment_ids, num_segments, out, dout):
|
||||||
shape_x = shape(x)
|
shape_x = shape(x)
|
||||||
|
@ -148,7 +149,6 @@ def get_bprop_sparse_segment_sum(self):
|
||||||
"""Grad definition for `SparseSegmentSum` operation."""
|
"""Grad definition for `SparseSegmentSum` operation."""
|
||||||
input_grad = G.SparseSegmentSumGrad()
|
input_grad = G.SparseSegmentSumGrad()
|
||||||
shape = P.Shape()
|
shape = P.Shape()
|
||||||
dyn_shape_op = P.TensorShape()
|
|
||||||
|
|
||||||
def bprop(x, indices, segment_ids, out, dout):
|
def bprop(x, indices, segment_ids, out, dout):
|
||||||
shape_x = shape(x)
|
shape_x = shape(x)
|
||||||
|
@ -169,7 +169,6 @@ def get_bprop_sparse_segment_sum_with_num_segments(self):
|
||||||
"""Grad definition for `SparseSegmentSumWithNumSegments` operation."""
|
"""Grad definition for `SparseSegmentSumWithNumSegments` operation."""
|
||||||
input_grad = G.SparseSegmentSumGrad()
|
input_grad = G.SparseSegmentSumGrad()
|
||||||
shape = P.Shape()
|
shape = P.Shape()
|
||||||
dyn_shape_op = P.TensorShape()
|
|
||||||
|
|
||||||
def bprop(x, indices, segment_ids, num_segments, out, dout):
|
def bprop(x, indices, segment_ids, num_segments, out, dout):
|
||||||
shape_x = shape(x)
|
shape_x = shape(x)
|
||||||
|
@ -192,7 +191,12 @@ def get_bprop_sparse_segment_mean_with_num_segments(self):
|
||||||
shape = P.Shape()
|
shape = P.Shape()
|
||||||
|
|
||||||
def bprop(x, indices, segment_ids, num_segments, out, dout):
|
def bprop(x, indices, segment_ids, num_segments, out, dout):
|
||||||
output_dim0 = F.scalar_to_tensor(shape(x)[0], mstype.int32)
|
x_shp = shape(x)
|
||||||
|
if is_shape_unknown(x_shp):
|
||||||
|
x_shp = dyn_shape_op(x)
|
||||||
|
output_dim0 = F.cast(x_shp[0], mstype.int32)
|
||||||
|
else:
|
||||||
|
output_dim0 = F.scalar_to_tensor(x_shp[0], mstype.int32)
|
||||||
indices = F.cast(indices, mstype.int32)
|
indices = F.cast(indices, mstype.int32)
|
||||||
segment_ids = F.cast(segment_ids, mstype.int32)
|
segment_ids = F.cast(segment_ids, mstype.int32)
|
||||||
dx = input_grad(dout, indices, segment_ids, output_dim0)
|
dx = input_grad(dout, indices, segment_ids, output_dim0)
|
||||||
|
@ -208,6 +212,7 @@ def get_bprop_sparse_reorder(self):
|
||||||
sparse_reorder_op = SparseReorder()
|
sparse_reorder_op = SparseReorder()
|
||||||
range_op = P.Range()
|
range_op = P.Range()
|
||||||
gather_op = P.Gather()
|
gather_op = P.Gather()
|
||||||
|
|
||||||
def bprop(indices, values, shape, out, dout):
|
def bprop(indices, values, shape, out, dout):
|
||||||
num_entries = F.shape(indices)[0]
|
num_entries = F.shape(indices)[0]
|
||||||
start = Tensor(0, dtype=mstype.int32)
|
start = Tensor(0, dtype=mstype.int32)
|
||||||
|
|
|
@ -1,20 +1,20 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:Ü
|
0.1.1 MindSpore*1.9.0:Þ
|
||||||
‘
|
’
|
||||||
|
|
||||||
bprop.12:xbprop.12:[CNode]13:1bprop.12:[CNode]13:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op8
|
bprop.33:xbprop.33:[CNode]34:1bprop.33:[CNode]34:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op23
|
||||||
‘
|
’
|
||||||
|
|
||||||
bprop.12:ybprop.12:[CNode]14:3bprop.12:[CNode]14:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op9
|
bprop.33:ybprop.33:[CNode]35:3bprop.33:[CNode]35:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op24
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.12:[CNode]13:1
|
bprop.33:[CNode]34:1
|
||||||
bprop.12:[CNode]14:3bprop.12:[CNode]15:4bprop.12:[CNode]15:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op10bprop.12*
|
bprop.33:[CNode]35:3bprop.33:[CNode]36:4bprop.33:[CNode]36:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op25bprop.33*
|
||||||
|
|
||||||
bprop.12:x*
|
bprop.33:x*
|
||||||
|
|
||||||
bprop.12:y*
|
bprop.33:y*
|
||||||
bprop.12:out*
|
bprop.33:out*
|
||||||
bprop.12:dout2
|
bprop.33:dout2
|
||||||
bprop.12:[CNode]15:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
|
bprop.33:[CNode]36:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -1,14 +1,14 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:£
|
0.1.1 MindSpore*1.9.0:¤
|
||||||
‘
|
’
|
||||||
|
|
||||||
bprop.14:xbprop.14:[CNode]15:1bprop.14:[CNode]15:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op9
|
bprop.16:xbprop.16:[CNode]17:1bprop.16:[CNode]17:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op11
|
||||||
z
|
z
|
||||||
bprop.14:[CNode]15:1bprop.14:[CNode]16:3bprop.14:[CNode]16:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op10bprop.14*
|
bprop.16:[CNode]17:1bprop.16:[CNode]18:3bprop.16:[CNode]18:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op12bprop.16*
|
||||||
|
|
||||||
bprop.14:x*
|
bprop.16:x*
|
||||||
bprop.14:out*
|
bprop.16:out*
|
||||||
bprop.14:dout2
|
bprop.16:dout2
|
||||||
bprop.14:[CNode]16:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPb&
|
bprop.16:[CNode]18:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPbH
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -2,13 +2,13 @@
|
||||||
0.1.1 MindSpore*1.9.0:¤
|
0.1.1 MindSpore*1.9.0:¤
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.17:xbprop.17:[CNode]18:1bprop.17:[CNode]18:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op11
|
bprop.19:xbprop.19:[CNode]20:1bprop.19:[CNode]20:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op13
|
||||||
z
|
z
|
||||||
bprop.17:[CNode]18:1bprop.17:[CNode]19:3bprop.17:[CNode]19:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op12bprop.17*
|
bprop.19:[CNode]20:1bprop.19:[CNode]21:3bprop.19:[CNode]21:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op14bprop.19*
|
||||||
|
|
||||||
bprop.17:x*
|
bprop.19:x*
|
||||||
bprop.17:out*
|
bprop.19:out*
|
||||||
bprop.17:dout2
|
bprop.19:dout2
|
||||||
bprop.17:[CNode]19:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPb&
|
bprop.19:[CNode]21:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPb&
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -1,16 +1,20 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:Â
|
0.1.1 MindSpore*1.9.0:Þ
|
||||||
Œ
|
’
|
||||||
bprop.1:xbprop.1:[CNode]2:1bprop.1:[CNode]2:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op0
|
|
||||||
Œ
|
bprop.22:xbprop.22:[CNode]23:1bprop.22:[CNode]23:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op15
|
||||||
bprop.1:ybprop.1:[CNode]3:3bprop.1:[CNode]3:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op1
|
’
|
||||||
‡
|
|
||||||
bprop.1:[CNode]2:1
|
bprop.22:ybprop.22:[CNode]24:3bprop.22:[CNode]24:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op16
|
||||||
bprop.1:[CNode]3:3bprop.1:[CNode]4:4bprop.1:[CNode]4:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op2bprop.1*
|
<EFBFBD>
|
||||||
bprop.1:x*
|
bprop.22:[CNode]23:1
|
||||||
bprop.1:y*
|
bprop.22:[CNode]24:3bprop.22:[CNode]25:4bprop.22:[CNode]25:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op17bprop.22*
|
||||||
bprop.1:out*
|
|
||||||
bprop.1:dout2
|
bprop.22:x*
|
||||||
bprop.1:[CNode]4:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
|
||||||
|
bprop.22:y*
|
||||||
|
bprop.22:out*
|
||||||
|
bprop.22:dout2
|
||||||
|
bprop.22:[CNode]25:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
|
@ -1,16 +1,20 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:Â
|
0.1.1 MindSpore*1.9.0:Þ
|
||||||
Œ
|
’
|
||||||
bprop.5:xbprop.5:[CNode]6:1bprop.5:[CNode]6:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op3
|
|
||||||
Œ
|
bprop.26:xbprop.26:[CNode]27:1bprop.26:[CNode]27:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op18
|
||||||
bprop.5:ybprop.5:[CNode]7:3bprop.5:[CNode]7:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op4
|
’
|
||||||
‡
|
|
||||||
bprop.5:[CNode]6:1
|
bprop.26:ybprop.26:[CNode]28:3bprop.26:[CNode]28:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op19
|
||||||
bprop.5:[CNode]7:3bprop.5:[CNode]8:4bprop.5:[CNode]8:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op5bprop.5*
|
<EFBFBD>
|
||||||
bprop.5:x*
|
bprop.26:[CNode]27:1
|
||||||
bprop.5:y*
|
bprop.26:[CNode]28:3bprop.26:[CNode]29:4bprop.26:[CNode]29:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op20bprop.26*
|
||||||
bprop.5:out*
|
|
||||||
bprop.5:dout2
|
bprop.26:x*
|
||||||
bprop.5:[CNode]8:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
|
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
bprop.26:y*
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
bprop.26:out*
|
||||||
|
bprop.26:dout2
|
||||||
|
bprop.26:[CNode]29:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
|
||||||
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
|
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
|
@ -2,13 +2,13 @@
|
||||||
0.1.1 MindSpore*1.9.0:¤
|
0.1.1 MindSpore*1.9.0:¤
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.20:xbprop.20:[CNode]21:1bprop.20:[CNode]21:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op13
|
bprop.37:xbprop.37:[CNode]38:1bprop.37:[CNode]38:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op27
|
||||||
z
|
z
|
||||||
bprop.20:[CNode]21:1bprop.20:[CNode]22:3bprop.20:[CNode]22:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op14bprop.20*
|
bprop.37:[CNode]38:1bprop.37:[CNode]39:3bprop.37:[CNode]39:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op28bprop.37*
|
||||||
|
|
||||||
bprop.20:x*
|
bprop.37:x*
|
||||||
bprop.20:out*
|
bprop.37:out*
|
||||||
bprop.20:dout2
|
bprop.37:dout2
|
||||||
bprop.20:[CNode]22:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPb&
|
bprop.37:[CNode]39:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPbH
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -1,27 +1,27 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:°
|
0.1.1 MindSpore*1.9.0:°
|
||||||
›
|
›
|
||||||
bprop.13:dout
|
bprop.32:dout
|
||||||
|
|
||||||
bprop.13:y
|
bprop.32:y
|
||||||
bprop.13:keep_probbprop.13:[CNode]14:1bprop.13:[CNode]14:1"REF::S-Prim-DropoutDoMask:2:!Default/S-Prim-DropoutDoMask-op10
|
bprop.32:keep_probbprop.32:[CNode]33:1bprop.32:[CNode]33:1"REF::S-Prim-DropoutDoMask:2:!Default/S-Prim-DropoutDoMask-op23
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.13:ybprop.13:[CNode]15:3bprop.13:[CNode]15:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:4:.Default/S-Prim-hyper_map[zeros_like_leaf]-op11
|
bprop.32:ybprop.32:[CNode]34:3bprop.32:[CNode]34:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:4:.Default/S-Prim-hyper_map[zeros_like_leaf]-op24
|
||||||
š
|
š
|
||||||
bprop.13:keep_probbprop.13:[CNode]16:5bprop.13:[CNode]16:5"(REF::S-Prim-hyper_map[zeros_like_leaf]:4:.Default/S-Prim-hyper_map[zeros_like_leaf]-op12
|
bprop.32:keep_probbprop.32:[CNode]35:5bprop.32:[CNode]35:5"(REF::S-Prim-hyper_map[zeros_like_leaf]:4:.Default/S-Prim-hyper_map[zeros_like_leaf]-op25
|
||||||
¦
|
¦
|
||||||
bprop.13:[CNode]14:1
|
bprop.32:[CNode]33:1
|
||||||
bprop.13:[CNode]15:3
|
bprop.32:[CNode]34:3
|
||||||
bprop.13:[CNode]16:5bprop.13:[CNode]17:6bprop.13:[CNode]17:6"REF::S-Prim-MakeTuple:7:Default/S-Prim-MakeTuple-op13bprop.13*
|
bprop.32:[CNode]35:5bprop.32:[CNode]36:6bprop.32:[CNode]36:6"REF::S-Prim-MakeTuple:7:Default/S-Prim-MakeTuple-op26bprop.32*
|
||||||
|
|
||||||
bprop.13:x*
|
bprop.32:x*
|
||||||
|
|
||||||
bprop.13:y*
|
bprop.32:y*
|
||||||
bprop.13:keep_prob*
|
bprop.32:keep_prob*
|
||||||
bprop.13:out*
|
bprop.32:out*
|
||||||
bprop.13:dout2
|
bprop.32:dout2
|
||||||
bprop.13:[CNode]17:6:@de4408af9f648bc43794b496de642f2ba8fef52099d138465267bb08b5fd7e8fPb&
|
bprop.32:[CNode]36:6:@320a603a26f1cc5c0174957a4eea3494bfd5f1c2f20ad64ef6e0ee49a91b62baPb&
|
||||||
S-Prim-MakeTuple:7S-Prim-MakeTupleb.
|
S-Prim-MakeTuple:7S-Prim-MakeTuplebH
|
||||||
S-Prim-DropoutDoMask:2S-Prim-DropoutDoMaskbH
|
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]b.
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]h
|
S-Prim-DropoutDoMask:2S-Prim-DropoutDoMaskh
|
|
@ -1,16 +1,16 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:Ú
|
0.1.1 MindSpore*1.9.0:ö
|
||||||
|
–
|
||||||
|
bprop.22:shapebprop.22:[CNode]23:1bprop.22:[CNode]23:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op15
|
||||||
|
š
|
||||||
|
bprop.22:keep_probbprop.22:[CNode]24:3bprop.22:[CNode]24:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op16
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.3:shapebprop.3:[CNode]4:1bprop.3:[CNode]4:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op2
|
bprop.22:[CNode]23:1
|
||||||
”
|
bprop.22:[CNode]24:3bprop.22:[CNode]25:4bprop.22:[CNode]25:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op17bprop.22*
|
||||||
bprop.3:keep_probbprop.3:[CNode]5:3bprop.3:[CNode]5:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op3
|
bprop.22:shape*
|
||||||
‡
|
bprop.22:keep_prob*
|
||||||
bprop.3:[CNode]4:1
|
bprop.22:out*
|
||||||
bprop.3:[CNode]5:3bprop.3:[CNode]6:4bprop.3:[CNode]6:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op4bprop.3*
|
bprop.22:dout2
|
||||||
bprop.3:shape*
|
bprop.22:[CNode]25:4:@320a603a26f1cc5c0174957a4eea3494bfd5f1c2f20ad64ef6e0ee49a91b62baPb&
|
||||||
bprop.3:keep_prob*
|
|
||||||
bprop.3:out*
|
|
||||||
bprop.3:dout2
|
|
||||||
bprop.3:[CNode]6:4:@de4408af9f648bc43794b496de642f2ba8fef52099d138465267bb08b5fd7e8fPb&
|
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -2,13 +2,13 @@
|
||||||
0.1.1 MindSpore*1.9.0:¤
|
0.1.1 MindSpore*1.9.0:¤
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.26:xbprop.26:[CNode]27:1bprop.26:[CNode]27:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op17
|
bprop.43:xbprop.43:[CNode]44:1bprop.43:[CNode]44:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op31
|
||||||
z
|
z
|
||||||
bprop.26:[CNode]27:1bprop.26:[CNode]28:3bprop.26:[CNode]28:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op18bprop.26*
|
bprop.43:[CNode]44:1bprop.43:[CNode]45:3bprop.43:[CNode]45:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op32bprop.43*
|
||||||
|
|
||||||
bprop.26:x*
|
bprop.43:x*
|
||||||
bprop.26:out*
|
bprop.43:out*
|
||||||
bprop.26:dout2
|
bprop.43:dout2
|
||||||
bprop.26:[CNode]28:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPbH
|
bprop.43:[CNode]45:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPb&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -2,19 +2,19 @@
|
||||||
0.1.1 MindSpore*1.9.0:Ţ
|
0.1.1 MindSpore*1.9.0:Ţ
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.30:xbprop.30:[CNode]31:1bprop.30:[CNode]31:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op21
|
bprop.61:xbprop.61:[CNode]62:1bprop.61:[CNode]62:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op44
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.30:ybprop.30:[CNode]32:3bprop.30:[CNode]32:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op22
|
bprop.61:ybprop.61:[CNode]63:3bprop.61:[CNode]63:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op45
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.30:[CNode]31:1
|
bprop.61:[CNode]62:1
|
||||||
bprop.30:[CNode]32:3bprop.30:[CNode]33:4bprop.30:[CNode]33:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op23bprop.30*
|
bprop.61:[CNode]63:3bprop.61:[CNode]64:4bprop.61:[CNode]64:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op46bprop.61*
|
||||||
|
|
||||||
bprop.30:x*
|
bprop.61:x*
|
||||||
|
|
||||||
bprop.30:y*
|
bprop.61:y*
|
||||||
bprop.30:out*
|
bprop.61:out*
|
||||||
bprop.30:dout2
|
bprop.61:dout2
|
||||||
bprop.30:[CNode]33:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
bprop.61:[CNode]64:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -1,20 +1,16 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:Þ
|
0.1.1 MindSpore*1.9.0:ú
|
||||||
’
|
˜
|
||||||
|
bprop.143:xbprop.143:[CNode]144:1bprop.143:[CNode]144:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:/Default/S-Prim-hyper_map[zeros_like_leaf]-op104
|
||||||
bprop.70:xbprop.70:[CNode]71:1bprop.70:[CNode]71:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op51
|
˜
|
||||||
’
|
bprop.143:ybprop.143:[CNode]145:3bprop.143:[CNode]145:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:/Default/S-Prim-hyper_map[zeros_like_leaf]-op105
|
||||||
|
™
|
||||||
bprop.70:ybprop.70:[CNode]72:3bprop.70:[CNode]72:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op52
|
bprop.143:[CNode]144:1
|
||||||
<EFBFBD>
|
bprop.143:[CNode]145:3bprop.143:[CNode]146:4bprop.143:[CNode]146:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op106 bprop.143*
|
||||||
bprop.70:[CNode]71:1
|
bprop.143:x*
|
||||||
bprop.70:[CNode]72:3bprop.70:[CNode]73:4bprop.70:[CNode]73:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op53bprop.70*
|
bprop.143:y*
|
||||||
|
bprop.143:out*
|
||||||
bprop.70:x*
|
bprop.143:dout2
|
||||||
|
bprop.143:[CNode]146:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
bprop.70:y*
|
|
||||||
bprop.70:out*
|
|
||||||
bprop.70:dout2
|
|
||||||
bprop.70:[CNode]73:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
|
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -2,19 +2,19 @@
|
||||||
0.1.1 MindSpore*1.9.0:Ţ
|
0.1.1 MindSpore*1.9.0:Ţ
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.42:xbprop.42:[CNode]43:1bprop.42:[CNode]43:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op30
|
bprop.73:xbprop.73:[CNode]74:1bprop.73:[CNode]74:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op53
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.42:ybprop.42:[CNode]44:3bprop.42:[CNode]44:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op31
|
bprop.73:ybprop.73:[CNode]75:3bprop.73:[CNode]75:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op54
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.42:[CNode]43:1
|
bprop.73:[CNode]74:1
|
||||||
bprop.42:[CNode]44:3bprop.42:[CNode]45:4bprop.42:[CNode]45:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op32bprop.42*
|
bprop.73:[CNode]75:3bprop.73:[CNode]76:4bprop.73:[CNode]76:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op55bprop.73*
|
||||||
|
|
||||||
bprop.42:x*
|
bprop.73:x*
|
||||||
|
|
||||||
bprop.42:y*
|
bprop.73:y*
|
||||||
bprop.42:out*
|
bprop.73:out*
|
||||||
bprop.42:dout2
|
bprop.73:dout2
|
||||||
bprop.42:[CNode]45:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
|
bprop.73:[CNode]76:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
|
@ -2,19 +2,19 @@
|
||||||
0.1.1 MindSpore*1.9.0:Ţ
|
0.1.1 MindSpore*1.9.0:Ţ
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.38:xbprop.38:[CNode]39:1bprop.38:[CNode]39:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op27
|
bprop.69:xbprop.69:[CNode]70:1bprop.69:[CNode]70:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op50
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.38:ybprop.38:[CNode]40:3bprop.38:[CNode]40:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op28
|
bprop.69:ybprop.69:[CNode]71:3bprop.69:[CNode]71:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op51
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.38:[CNode]39:1
|
bprop.69:[CNode]70:1
|
||||||
bprop.38:[CNode]40:3bprop.38:[CNode]41:4bprop.38:[CNode]41:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op29bprop.38*
|
bprop.69:[CNode]71:3bprop.69:[CNode]72:4bprop.69:[CNode]72:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op52bprop.69*
|
||||||
|
|
||||||
bprop.38:x*
|
bprop.69:x*
|
||||||
|
|
||||||
bprop.38:y*
|
bprop.69:y*
|
||||||
bprop.38:out*
|
bprop.69:out*
|
||||||
bprop.38:dout2
|
bprop.69:dout2
|
||||||
bprop.38:[CNode]41:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
bprop.69:[CNode]72:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -1,9 +1,9 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:ü
|
0.1.1 MindSpore*1.9.0:ü
|
||||||
m
|
m
|
||||||
bprop.1:doutbprop.1:[CNode]2:1bprop.1:[CNode]2:1"REF::S-Prim-MakeTuple:2:Default/S-Prim-MakeTuple-op0bprop.1*
|
bprop.3:doutbprop.3:[CNode]4:1bprop.3:[CNode]4:1"REF::S-Prim-MakeTuple:2:Default/S-Prim-MakeTuple-op2bprop.3*
|
||||||
bprop.1:x*
|
bprop.3:x*
|
||||||
bprop.1:out*
|
bprop.3:out*
|
||||||
bprop.1:dout2
|
bprop.3:dout2
|
||||||
bprop.1:[CNode]2:1:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPb&
|
bprop.3:[CNode]4:1:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPb&
|
||||||
S-Prim-MakeTuple:2S-Prim-MakeTupleh
|
S-Prim-MakeTuple:2S-Prim-MakeTupleh
|
|
@ -1,12 +1,14 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:—
|
0.1.1 MindSpore*1.9.0:¤
|
||||||
Ž
|
’
|
||||||
bprop.9:xbprop.9:[CNode]10:1bprop.9:[CNode]10:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op6
|
|
||||||
v
|
bprop.30:xbprop.30:[CNode]31:1bprop.30:[CNode]31:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op21
|
||||||
bprop.9:[CNode]10:1bprop.9:[CNode]11:3bprop.9:[CNode]11:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op7bprop.9*
|
z
|
||||||
bprop.9:x*
|
bprop.30:[CNode]31:1bprop.30:[CNode]32:3bprop.30:[CNode]32:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op22bprop.30*
|
||||||
bprop.9:out*
|
|
||||||
bprop.9:dout2
|
bprop.30:x*
|
||||||
bprop.9:[CNode]11:3:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
bprop.30:out*
|
||||||
|
bprop.30:dout2
|
||||||
|
bprop.30:[CNode]32:3:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -1,14 +1,12 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:¤
|
0.1.1 MindSpore*1.9.0:¸
|
||||||
’
|
˜
|
||||||
|
bprop.164:xbprop.164:[CNode]165:1bprop.164:[CNode]165:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:/Default/S-Prim-hyper_map[zeros_like_leaf]-op120
|
||||||
bprop.91:xbprop.91:[CNode]92:1bprop.91:[CNode]92:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op67
|
<EFBFBD>
|
||||||
z
|
bprop.164:[CNode]165:1bprop.164:[CNode]166:3bprop.164:[CNode]166:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op121 bprop.164*
|
||||||
bprop.91:[CNode]92:1bprop.91:[CNode]93:3bprop.91:[CNode]93:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op68bprop.91*
|
bprop.164:x*
|
||||||
|
bprop.164:out*
|
||||||
bprop.91:x*
|
bprop.164:dout2
|
||||||
bprop.91:out*
|
bprop.164:[CNode]166:3:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
|
||||||
bprop.91:dout2
|
|
||||||
bprop.91:[CNode]93:3:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -1,14 +1,12 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:¤
|
0.1.1 MindSpore*1.9.0:¸
|
||||||
’
|
˜
|
||||||
|
bprop.161:xbprop.161:[CNode]162:1bprop.161:[CNode]162:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:/Default/S-Prim-hyper_map[zeros_like_leaf]-op118
|
||||||
bprop.88:xbprop.88:[CNode]89:1bprop.88:[CNode]89:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op65
|
<EFBFBD>
|
||||||
z
|
bprop.161:[CNode]162:1bprop.161:[CNode]163:3bprop.161:[CNode]163:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op119 bprop.161*
|
||||||
bprop.88:[CNode]89:1bprop.88:[CNode]90:3bprop.88:[CNode]90:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op66bprop.88*
|
bprop.161:x*
|
||||||
|
bprop.161:out*
|
||||||
bprop.88:x*
|
bprop.161:dout2
|
||||||
bprop.88:out*
|
bprop.161:[CNode]163:3:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
bprop.88:dout2
|
|
||||||
bprop.88:[CNode]90:3:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
|
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -2,19 +2,19 @@
|
||||||
0.1.1 MindSpore*1.9.0:Ţ
|
0.1.1 MindSpore*1.9.0:Ţ
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.50:xbprop.50:[CNode]51:1bprop.50:[CNode]51:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op36
|
bprop.81:xbprop.81:[CNode]82:1bprop.81:[CNode]82:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op59
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.50:ybprop.50:[CNode]52:3bprop.50:[CNode]52:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op37
|
bprop.81:ybprop.81:[CNode]83:3bprop.81:[CNode]83:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op60
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.50:[CNode]51:1
|
bprop.81:[CNode]82:1
|
||||||
bprop.50:[CNode]52:3bprop.50:[CNode]53:4bprop.50:[CNode]53:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op38bprop.50*
|
bprop.81:[CNode]83:3bprop.81:[CNode]84:4bprop.81:[CNode]84:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op61bprop.81*
|
||||||
|
|
||||||
bprop.50:x*
|
bprop.81:x*
|
||||||
|
|
||||||
bprop.50:y*
|
bprop.81:y*
|
||||||
bprop.50:out*
|
bprop.81:out*
|
||||||
bprop.50:dout2
|
bprop.81:dout2
|
||||||
bprop.50:[CNode]53:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
bprop.81:[CNode]84:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -2,19 +2,19 @@
|
||||||
0.1.1 MindSpore*1.9.0:Ţ
|
0.1.1 MindSpore*1.9.0:Ţ
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.46:xbprop.46:[CNode]47:1bprop.46:[CNode]47:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op33
|
bprop.77:xbprop.77:[CNode]78:1bprop.77:[CNode]78:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op56
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.46:ybprop.46:[CNode]48:3bprop.46:[CNode]48:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op34
|
bprop.77:ybprop.77:[CNode]79:3bprop.77:[CNode]79:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op57
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.46:[CNode]47:1
|
bprop.77:[CNode]78:1
|
||||||
bprop.46:[CNode]48:3bprop.46:[CNode]49:4bprop.46:[CNode]49:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op35bprop.46*
|
bprop.77:[CNode]79:3bprop.77:[CNode]80:4bprop.77:[CNode]80:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op58bprop.77*
|
||||||
|
|
||||||
bprop.46:x*
|
bprop.77:x*
|
||||||
|
|
||||||
bprop.46:y*
|
bprop.77:y*
|
||||||
bprop.46:out*
|
bprop.77:out*
|
||||||
bprop.46:dout2
|
bprop.77:dout2
|
||||||
bprop.46:[CNode]49:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
|
bprop.77:[CNode]80:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
|
@ -1,20 +1,20 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:©
|
0.1.1 MindSpore*1.9.0:©
|
||||||
–
|
–
|
||||||
bprop.25:startbprop.25:[CNode]26:1bprop.25:[CNode]26:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op17
|
bprop.46:startbprop.46:[CNode]47:1bprop.46:[CNode]47:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op32
|
||||||
•
|
•
|
||||||
bprop.25:stopbprop.25:[CNode]27:3bprop.25:[CNode]27:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op18
|
bprop.46:stopbprop.46:[CNode]48:3bprop.46:[CNode]48:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op33
|
||||||
”
|
”
|
||||||
bprop.25:numbprop.25:[CNode]28:4bprop.25:[CNode]28:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op19
|
bprop.46:numbprop.46:[CNode]49:4bprop.46:[CNode]49:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op34
|
||||||
¦
|
¦
|
||||||
bprop.25:[CNode]26:1
|
bprop.46:[CNode]47:1
|
||||||
bprop.25:[CNode]27:3
|
bprop.46:[CNode]48:3
|
||||||
bprop.25:[CNode]28:4bprop.25:[CNode]29:5bprop.25:[CNode]29:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op20bprop.25*
|
bprop.46:[CNode]49:4bprop.46:[CNode]50:5bprop.46:[CNode]50:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op35bprop.46*
|
||||||
bprop.25:start*
|
bprop.46:start*
|
||||||
bprop.25:stop*
|
bprop.46:stop*
|
||||||
bprop.25:num*
|
bprop.46:num*
|
||||||
bprop.25:out*
|
bprop.46:out*
|
||||||
bprop.25:dout2
|
bprop.46:dout2
|
||||||
bprop.25:[CNode]29:5:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
bprop.46:[CNode]50:5:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
S-Prim-MakeTuple:6S-Prim-MakeTuplebH
|
||||||
S-Prim-MakeTuple:6S-Prim-MakeTupleh
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -2,19 +2,19 @@
|
||||||
0.1.1 MindSpore*1.9.0:Þ
|
0.1.1 MindSpore*1.9.0:Þ
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.54:xbprop.54:[CNode]55:1bprop.54:[CNode]55:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op39
|
bprop.85:xbprop.85:[CNode]86:1bprop.85:[CNode]86:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op62
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.54:ybprop.54:[CNode]56:3bprop.54:[CNode]56:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op40
|
bprop.85:ybprop.85:[CNode]87:3bprop.85:[CNode]87:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op63
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.54:[CNode]55:1
|
bprop.85:[CNode]86:1
|
||||||
bprop.54:[CNode]56:3bprop.54:[CNode]57:4bprop.54:[CNode]57:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op41bprop.54*
|
bprop.85:[CNode]87:3bprop.85:[CNode]88:4bprop.85:[CNode]88:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op64bprop.85*
|
||||||
|
|
||||||
bprop.54:x*
|
bprop.85:x*
|
||||||
|
|
||||||
bprop.54:y*
|
bprop.85:y*
|
||||||
bprop.54:out*
|
bprop.85:out*
|
||||||
bprop.54:dout2
|
bprop.85:dout2
|
||||||
bprop.54:[CNode]57:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
bprop.85:[CNode]88:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
|
@ -2,13 +2,13 @@
|
||||||
0.1.1 MindSpore*1.9.0:¤
|
0.1.1 MindSpore*1.9.0:¤
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.19:xbprop.19:[CNode]20:1bprop.19:[CNode]20:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op13
|
bprop.40:xbprop.40:[CNode]41:1bprop.40:[CNode]41:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op28
|
||||||
z
|
z
|
||||||
bprop.19:[CNode]20:1bprop.19:[CNode]21:3bprop.19:[CNode]21:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op14bprop.19*
|
bprop.40:[CNode]41:1bprop.40:[CNode]42:3bprop.40:[CNode]42:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op29bprop.40*
|
||||||
|
|
||||||
bprop.19:x*
|
bprop.40:x*
|
||||||
bprop.19:out*
|
bprop.40:out*
|
||||||
bprop.19:dout2
|
bprop.40:dout2
|
||||||
bprop.19:[CNode]21:3:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
|
bprop.40:[CNode]42:3:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -2,19 +2,19 @@
|
||||||
0.1.1 MindSpore*1.9.0:Ţ
|
0.1.1 MindSpore*1.9.0:Ţ
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.58:xbprop.58:[CNode]59:1bprop.58:[CNode]59:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op42
|
bprop.89:xbprop.89:[CNode]90:1bprop.89:[CNode]90:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op65
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.58:ybprop.58:[CNode]60:3bprop.58:[CNode]60:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op43
|
bprop.89:ybprop.89:[CNode]91:3bprop.89:[CNode]91:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op66
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.58:[CNode]59:1
|
bprop.89:[CNode]90:1
|
||||||
bprop.58:[CNode]60:3bprop.58:[CNode]61:4bprop.58:[CNode]61:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op44bprop.58*
|
bprop.89:[CNode]91:3bprop.89:[CNode]92:4bprop.89:[CNode]92:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op67bprop.89*
|
||||||
|
|
||||||
bprop.58:x*
|
bprop.89:x*
|
||||||
|
|
||||||
bprop.58:y*
|
bprop.89:y*
|
||||||
bprop.58:out*
|
bprop.89:out*
|
||||||
bprop.58:dout2
|
bprop.89:dout2
|
||||||
bprop.58:[CNode]61:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
bprop.89:[CNode]92:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
Binary file not shown.
Binary file not shown.
|
@ -2,19 +2,19 @@
|
||||||
0.1.1 MindSpore*1.9.0:Ţ
|
0.1.1 MindSpore*1.9.0:Ţ
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.34:xbprop.34:[CNode]35:1bprop.34:[CNode]35:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op24
|
bprop.65:xbprop.65:[CNode]66:1bprop.65:[CNode]66:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op47
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.34:ybprop.34:[CNode]36:3bprop.34:[CNode]36:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op25
|
bprop.65:ybprop.65:[CNode]67:3bprop.65:[CNode]67:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op48
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.34:[CNode]35:1
|
bprop.65:[CNode]66:1
|
||||||
bprop.34:[CNode]36:3bprop.34:[CNode]37:4bprop.34:[CNode]37:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op26bprop.34*
|
bprop.65:[CNode]67:3bprop.65:[CNode]68:4bprop.65:[CNode]68:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op49bprop.65*
|
||||||
|
|
||||||
bprop.34:x*
|
bprop.65:x*
|
||||||
|
|
||||||
bprop.34:y*
|
bprop.65:y*
|
||||||
bprop.34:out*
|
bprop.65:out*
|
||||||
bprop.34:dout2
|
bprop.65:dout2
|
||||||
bprop.34:[CNode]37:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
bprop.65:[CNode]68:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -1,24 +1,24 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:Ý
|
0.1.1 MindSpore*1.9.0:‚
|
||||||
’
|
˜
|
||||||
bprop.7:indicesbprop.7:[CNode]8:1bprop.7:[CNode]8:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op5
|
bprop.26:indicesbprop.26:[CNode]27:1bprop.26:[CNode]27:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op18
|
||||||
<EFBFBD>
|
|
||||||
bprop.7:depthbprop.7:[CNode]9:3bprop.7:[CNode]9:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op6
|
|
||||||
•
|
|
||||||
bprop.7:on_valuebprop.7:[CNode]10:4bprop.7:[CNode]10:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op7
|
|
||||||
–
|
–
|
||||||
bprop.7:off_valuebprop.7:[CNode]11:5bprop.7:[CNode]11:5"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op8
|
bprop.26:depthbprop.26:[CNode]28:3bprop.26:[CNode]28:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op19
|
||||||
³
|
™
|
||||||
bprop.7:[CNode]8:1
|
bprop.26:on_valuebprop.26:[CNode]29:4bprop.26:[CNode]29:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op20
|
||||||
bprop.7:[CNode]9:3
|
š
|
||||||
bprop.7:[CNode]10:4
|
bprop.26:off_valuebprop.26:[CNode]30:5bprop.26:[CNode]30:5"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op21
|
||||||
bprop.7:[CNode]11:5bprop.7:[CNode]12:6bprop.7:[CNode]12:6"REF::S-Prim-MakeTuple:7:Default/S-Prim-MakeTuple-op9bprop.7*
|
¼
|
||||||
bprop.7:indices*
|
bprop.26:[CNode]27:1
|
||||||
bprop.7:depth*
|
bprop.26:[CNode]28:3
|
||||||
bprop.7:on_value*
|
bprop.26:[CNode]29:4
|
||||||
bprop.7:off_value*
|
bprop.26:[CNode]30:5bprop.26:[CNode]31:6bprop.26:[CNode]31:6"REF::S-Prim-MakeTuple:7:Default/S-Prim-MakeTuple-op22bprop.26*
|
||||||
bprop.7:out*
|
bprop.26:indices*
|
||||||
bprop.7:dout2
|
bprop.26:depth*
|
||||||
bprop.7:[CNode]12:6:@de4408af9f648bc43794b496de642f2ba8fef52099d138465267bb08b5fd7e8fPb&
|
bprop.26:on_value*
|
||||||
|
bprop.26:off_value*
|
||||||
|
bprop.26:out*
|
||||||
|
bprop.26:dout2
|
||||||
|
bprop.26:[CNode]31:6:@320a603a26f1cc5c0174957a4eea3494bfd5f1c2f20ad64ef6e0ee49a91b62baPb&
|
||||||
S-Prim-MakeTuple:7S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:7S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -1,12 +1,14 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:”
|
0.1.1 MindSpore*1.9.0:¢
|
||||||
Œ
|
‘
|
||||||
bprop.8:xbprop.8:[CNode]9:1bprop.8:[CNode]9:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op5
|
|
||||||
u
|
bprop.10:xbprop.10:[CNode]11:1bprop.10:[CNode]11:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op7
|
||||||
bprop.8:[CNode]9:1bprop.8:[CNode]10:3bprop.8:[CNode]10:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op6bprop.8*
|
y
|
||||||
bprop.8:x*
|
bprop.10:[CNode]11:1bprop.10:[CNode]12:3bprop.10:[CNode]12:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op8bprop.10*
|
||||||
bprop.8:out*
|
|
||||||
bprop.8:dout2
|
bprop.10:x*
|
||||||
bprop.8:[CNode]10:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPb&
|
bprop.10:out*
|
||||||
|
bprop.10:dout2
|
||||||
|
bprop.10:[CNode]12:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPb&
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -1,20 +1,20 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:Š
|
0.1.1 MindSpore*1.9.0:Š
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.3:startbprop.3:[CNode]4:1bprop.3:[CNode]4:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op1
|
bprop.5:startbprop.5:[CNode]6:1bprop.5:[CNode]6:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op3
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.3:limitbprop.3:[CNode]5:3bprop.3:[CNode]5:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op2
|
bprop.5:limitbprop.5:[CNode]7:3bprop.5:[CNode]7:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op4
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.3:deltabprop.3:[CNode]6:4bprop.3:[CNode]6:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op3
|
bprop.5:deltabprop.5:[CNode]8:4bprop.5:[CNode]8:4"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op5
|
||||||
›
|
›
|
||||||
bprop.3:[CNode]4:1
|
bprop.5:[CNode]6:1
|
||||||
bprop.3:[CNode]5:3
|
bprop.5:[CNode]7:3
|
||||||
bprop.3:[CNode]6:4bprop.3:[CNode]7:5bprop.3:[CNode]7:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op4bprop.3*
|
bprop.5:[CNode]8:4bprop.5:[CNode]9:5bprop.5:[CNode]9:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op6bprop.5*
|
||||||
bprop.3:start*
|
bprop.5:start*
|
||||||
bprop.3:limit*
|
bprop.5:limit*
|
||||||
bprop.3:delta*
|
bprop.5:delta*
|
||||||
bprop.3:out*
|
bprop.5:out*
|
||||||
bprop.3:dout2
|
bprop.5:dout2
|
||||||
bprop.3:[CNode]7:5:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPb&
|
bprop.5:[CNode]9:5:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPb&
|
||||||
S-Prim-MakeTuple:6S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:6S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -2,13 +2,13 @@
|
||||||
0.1.1 MindSpore*1.9.0:¤
|
0.1.1 MindSpore*1.9.0:¤
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.29:xbprop.29:[CNode]30:1bprop.29:[CNode]30:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op19
|
bprop.46:xbprop.46:[CNode]47:1bprop.46:[CNode]47:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op33
|
||||||
z
|
z
|
||||||
bprop.29:[CNode]30:1bprop.29:[CNode]31:3bprop.29:[CNode]31:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op20bprop.29*
|
bprop.46:[CNode]47:1bprop.46:[CNode]48:3bprop.46:[CNode]48:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op34bprop.46*
|
||||||
|
|
||||||
bprop.29:x*
|
bprop.46:x*
|
||||||
bprop.29:out*
|
bprop.46:out*
|
||||||
bprop.29:dout2
|
bprop.46:dout2
|
||||||
bprop.29:[CNode]31:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPbH
|
bprop.46:[CNode]48:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPbH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
Binary file not shown.
|
@ -8,9 +8,9 @@ m
|
||||||
bprop.1:x*
|
bprop.1:x*
|
||||||
bprop.1:out*
|
bprop.1:out*
|
||||||
bprop.1:dout2
|
bprop.1:dout2
|
||||||
bprop.1:[CNode]2:3:@de4408af9f648bc43794b496de642f2ba8fef52099d138465267bb08b5fd7e8fPb&
|
bprop.1:[CNode]2:3:@320a603a26f1cc5c0174957a4eea3494bfd5f1c2f20ad64ef6e0ee49a91b62baPbr
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebr
|
|
||||||
S-Prim-ReluGrad:2S-Prim-ReluGrad
|
S-Prim-ReluGrad:2S-Prim-ReluGrad
|
||||||
output_names€ŠZoutput€+
|
output_names€ŠZoutput€+
|
||||||
input_names€ŠZ
|
input_names€ŠZ
|
||||||
y_backprop€ŠZx€h
|
y_backprop€ŠZx€b&
|
||||||
|
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -2,17 +2,17 @@
|
||||||
0.1.1 MindSpore*1.9.0:ä
|
0.1.1 MindSpore*1.9.0:ä
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.62:xbprop.62:[CNode]63:1bprop.62:[CNode]63:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op45
|
bprop.93:xbprop.93:[CNode]94:1bprop.93:[CNode]94:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op68
|
||||||
•
|
•
|
||||||
bprop.62:axisbprop.62:[CNode]64:3bprop.62:[CNode]64:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op46
|
bprop.93:axisbprop.93:[CNode]95:3bprop.93:[CNode]95:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op69
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.62:[CNode]63:1
|
bprop.93:[CNode]94:1
|
||||||
bprop.62:[CNode]64:3bprop.62:[CNode]65:4bprop.62:[CNode]65:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op47bprop.62*
|
bprop.93:[CNode]95:3bprop.93:[CNode]96:4bprop.93:[CNode]96:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op70bprop.93*
|
||||||
|
|
||||||
bprop.62:x*
|
bprop.93:x*
|
||||||
bprop.62:axis*
|
bprop.93:axis*
|
||||||
bprop.62:out*
|
bprop.93:out*
|
||||||
bprop.62:dout2
|
bprop.93:dout2
|
||||||
bprop.62:[CNode]65:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
|
bprop.93:[CNode]96:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -1,18 +1,18 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:ä
|
0.1.1 MindSpore*1.9.0:ç
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
|
|
||||||
bprop.66:xbprop.66:[CNode]67:1bprop.66:[CNode]67:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op48
|
bprop.97:xbprop.97:[CNode]98:1bprop.97:[CNode]98:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op71
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
bprop.66:axisbprop.66:[CNode]68:3bprop.66:[CNode]68:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op49
|
bprop.97:axisbprop.97:[CNode]99:3bprop.97:[CNode]99:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op72
|
||||||
<EFBFBD>
|
’
|
||||||
bprop.66:[CNode]67:1
|
bprop.97:[CNode]98:1
|
||||||
bprop.66:[CNode]68:3bprop.66:[CNode]69:4bprop.66:[CNode]69:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op50bprop.66*
|
bprop.97:[CNode]99:3bprop.97:[CNode]100:4bprop.97:[CNode]100:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op73bprop.97*
|
||||||
|
|
||||||
bprop.66:x*
|
bprop.97:x*
|
||||||
bprop.66:axis*
|
bprop.97:axis*
|
||||||
bprop.66:out*
|
bprop.97:out*
|
||||||
bprop.66:dout2
|
bprop.97:dout2
|
||||||
bprop.66:[CNode]69:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
bprop.97:[CNode]100:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -1,24 +1,24 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:¿
|
0.1.1 MindSpore*1.9.0:¿
|
||||||
u
|
u
|
||||||
bprop.18:dout
|
bprop.71:dout
|
||||||
|
|
||||||
bprop.18:ybprop.18:dgrad:1bprop.18:dgrad:1"REF::S-Prim-ReluGrad:2:Default/S-Prim-ReluGrad-op14
|
bprop.71:ybprop.71:dgrad:1bprop.71:dgrad:1"REF::S-Prim-ReluGrad:2:Default/S-Prim-ReluGrad-op50
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.18:ybprop.18:[CNode]19:3bprop.18:[CNode]19:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:4:.Default/S-Prim-hyper_map[zeros_like_leaf]-op15
|
bprop.71:ybprop.71:[CNode]72:3bprop.71:[CNode]72:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:4:.Default/S-Prim-hyper_map[zeros_like_leaf]-op51
|
||||||
Œ
|
Œ
|
||||||
bprop.18:dgrad:1
|
bprop.71:dgrad:1
|
||||||
bprop.18:[CNode]19:3bprop.18:[CNode]20:5bprop.18:[CNode]20:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op16bprop.18*
|
bprop.71:[CNode]72:3bprop.71:[CNode]73:5bprop.71:[CNode]73:5"REF::S-Prim-MakeTuple:6:Default/S-Prim-MakeTuple-op52bprop.71*
|
||||||
bprop.18:grad*
|
bprop.71:grad*
|
||||||
|
|
||||||
bprop.18:y*
|
bprop.71:y*
|
||||||
bprop.18:out*
|
bprop.71:out*
|
||||||
bprop.18:dout2
|
bprop.71:dout2
|
||||||
bprop.18:[CNode]20:5:@de4408af9f648bc43794b496de642f2ba8fef52099d138465267bb08b5fd7e8fPbr
|
bprop.71:[CNode]73:5:@320a603a26f1cc5c0174957a4eea3494bfd5f1c2f20ad64ef6e0ee49a91b62baPb&
|
||||||
|
S-Prim-MakeTuple:6S-Prim-MakeTuplebr
|
||||||
S-Prim-ReluGrad:2S-Prim-ReluGrad
|
S-Prim-ReluGrad:2S-Prim-ReluGrad
|
||||||
output_names€ŠZoutput€+
|
output_names€ŠZoutput€+
|
||||||
input_names€ŠZ
|
input_names€ŠZ
|
||||||
y_backprop€ŠZx€b&
|
y_backprop€ŠZx€bH
|
||||||
S-Prim-MakeTuple:6S-Prim-MakeTuplebH
|
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -2,13 +2,13 @@
|
||||||
0.1.1 MindSpore*1.9.0:¤
|
0.1.1 MindSpore*1.9.0:¤
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.22:xbprop.22:[CNode]23:1bprop.22:[CNode]23:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op15
|
bprop.43:xbprop.43:[CNode]44:1bprop.43:[CNode]44:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op30
|
||||||
z
|
z
|
||||||
bprop.22:[CNode]23:1bprop.22:[CNode]24:3bprop.22:[CNode]24:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op16bprop.22*
|
bprop.43:[CNode]44:1bprop.43:[CNode]45:3bprop.43:[CNode]45:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op31bprop.43*
|
||||||
|
|
||||||
bprop.22:x*
|
bprop.43:x*
|
||||||
bprop.22:out*
|
bprop.43:out*
|
||||||
bprop.22:dout2
|
bprop.43:dout2
|
||||||
bprop.22:[CNode]24:3:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3Pb&
|
bprop.43:[CNode]45:3:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,35 +1,35 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:Ç
|
0.1.1 MindSpore*1.9.0:Ç
|
||||||
•
|
•
|
||||||
bprop.32:condbprop.32:[CNode]33:1bprop.32:[CNode]33:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op21
|
bprop.49:condbprop.49:[CNode]50:1bprop.49:[CNode]50:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op35
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.32:xbprop.32:[CNode]34:3bprop.32:[CNode]34:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op22
|
bprop.49:xbprop.49:[CNode]51:3bprop.49:[CNode]51:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op36
|
||||||
’
|
’
|
||||||
bprop.32:cond
|
bprop.49:cond
|
||||||
bprop.32:dout
|
bprop.49:dout
|
||||||
bprop.32:[CNode]34:3bprop.32:[CNode]35:4bprop.32:[CNode]35:4"REF::S-Prim-Select:5:Default/S-Prim-Select-op23
|
bprop.49:[CNode]51:3bprop.49:[CNode]52:4bprop.49:[CNode]52:4"REF::S-Prim-Select:5:Default/S-Prim-Select-op37
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.32:ybprop.32:[CNode]36:6bprop.32:[CNode]36:6"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op24
|
bprop.49:ybprop.49:[CNode]53:6bprop.49:[CNode]53:6"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op38
|
||||||
’
|
’
|
||||||
bprop.32:cond
|
bprop.49:cond
|
||||||
bprop.32:[CNode]36:6
|
bprop.49:[CNode]53:6
|
||||||
bprop.32:doutbprop.32:[CNode]37:7bprop.32:[CNode]37:7"REF::S-Prim-Select:5:Default/S-Prim-Select-op25
|
bprop.49:doutbprop.49:[CNode]54:7bprop.49:[CNode]54:7"REF::S-Prim-Select:5:Default/S-Prim-Select-op39
|
||||||
¦
|
¦
|
||||||
bprop.32:[CNode]33:1
|
bprop.49:[CNode]50:1
|
||||||
bprop.32:[CNode]35:4
|
bprop.49:[CNode]52:4
|
||||||
bprop.32:[CNode]37:7bprop.32:[CNode]38:8bprop.32:[CNode]38:8"REF::S-Prim-MakeTuple:9:Default/S-Prim-MakeTuple-op26bprop.32*
|
bprop.49:[CNode]54:7bprop.49:[CNode]55:8bprop.49:[CNode]55:8"REF::S-Prim-MakeTuple:9:Default/S-Prim-MakeTuple-op40bprop.49*
|
||||||
bprop.32:cond*
|
bprop.49:cond*
|
||||||
|
|
||||||
bprop.32:x*
|
bprop.49:x*
|
||||||
|
|
||||||
bprop.32:y*
|
bprop.49:y*
|
||||||
bprop.32:out*
|
bprop.49:out*
|
||||||
bprop.32:dout2
|
bprop.49:dout2
|
||||||
bprop.32:[CNode]38:8:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPbv
|
bprop.49:[CNode]55:8:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPb&
|
||||||
|
S-Prim-MakeTuple:9S-Prim-MakeTuplebH
|
||||||
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]bv
|
||||||
S-Prim-Select:5
S-Prim-Select
|
S-Prim-Select:5
S-Prim-Select
|
||||||
output_names€ŠZoutput€3
|
output_names€ŠZoutput€3
|
||||||
input_names€ŠZ condition€ŠZx€ŠZy€bH
|
input_names€ŠZ condition€ŠZx€ŠZy€h
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
|
||||||
S-Prim-MakeTuple:9S-Prim-MakeTupleh
|
|
|
@ -2,13 +2,13 @@
|
||||||
0.1.1 MindSpore*1.9.0:¤
|
0.1.1 MindSpore*1.9.0:¤
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.23:xbprop.23:[CNode]24:1bprop.23:[CNode]24:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op15
|
bprop.40:xbprop.40:[CNode]41:1bprop.40:[CNode]41:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op29
|
||||||
z
|
z
|
||||||
bprop.23:[CNode]24:1bprop.23:[CNode]25:3bprop.23:[CNode]25:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op16bprop.23*
|
bprop.40:[CNode]41:1bprop.40:[CNode]42:3bprop.40:[CNode]42:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op30bprop.40*
|
||||||
|
|
||||||
bprop.23:x*
|
bprop.40:x*
|
||||||
bprop.23:out*
|
bprop.40:out*
|
||||||
bprop.23:dout2
|
bprop.40:dout2
|
||||||
bprop.23:[CNode]25:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPbH
|
bprop.40:[CNode]42:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPbH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -2,13 +2,13 @@
|
||||||
0.1.1 MindSpore*1.9.0:¤
|
0.1.1 MindSpore*1.9.0:¤
|
||||||
’
|
’
|
||||||
|
|
||||||
bprop.16:xbprop.16:[CNode]17:1bprop.16:[CNode]17:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op11
|
bprop.37:xbprop.37:[CNode]38:1bprop.37:[CNode]38:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op26
|
||||||
z
|
z
|
||||||
bprop.16:[CNode]17:1bprop.16:[CNode]18:3bprop.16:[CNode]18:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op12bprop.16*
|
bprop.37:[CNode]38:1bprop.37:[CNode]39:3bprop.37:[CNode]39:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op27bprop.37*
|
||||||
|
|
||||||
bprop.16:x*
|
bprop.37:x*
|
||||||
bprop.16:out*
|
bprop.37:out*
|
||||||
bprop.16:dout2
|
bprop.37:dout2
|
||||||
bprop.16:[CNode]18:3:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
bprop.37:[CNode]39:3:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0PbH
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -1,20 +1,16 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:Þ
|
0.1.1 MindSpore*1.9.0:ú
|
||||||
’
|
˜
|
||||||
|
bprop.147:xbprop.147:[CNode]148:1bprop.147:[CNode]148:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:/Default/S-Prim-hyper_map[zeros_like_leaf]-op107
|
||||||
bprop.74:xbprop.74:[CNode]75:1bprop.74:[CNode]75:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op54
|
˜
|
||||||
’
|
bprop.147:ybprop.147:[CNode]149:3bprop.147:[CNode]149:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:/Default/S-Prim-hyper_map[zeros_like_leaf]-op108
|
||||||
|
™
|
||||||
bprop.74:ybprop.74:[CNode]76:3bprop.74:[CNode]76:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op55
|
bprop.147:[CNode]148:1
|
||||||
<EFBFBD>
|
bprop.147:[CNode]149:3bprop.147:[CNode]150:4bprop.147:[CNode]150:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op109 bprop.147*
|
||||||
bprop.74:[CNode]75:1
|
bprop.147:x*
|
||||||
bprop.74:[CNode]76:3bprop.74:[CNode]77:4bprop.74:[CNode]77:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op56bprop.74*
|
bprop.147:y*
|
||||||
|
bprop.147:out*
|
||||||
bprop.74:x*
|
bprop.147:dout2
|
||||||
|
bprop.147:[CNode]150:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
|
||||||
bprop.74:y*
|
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
|
||||||
bprop.74:out*
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
||||||
bprop.74:dout2
|
|
||||||
bprop.74:[CNode]77:4:@51014516712342426367430e7abdbb8b5454385a02c91339e89ad18f19803ee3PbH
|
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
|
||||||
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|
|
|
@ -1,14 +1,14 @@
|
||||||
|
|
||||||
0.1.1 MindSpore*1.9.0:¢
|
0.1.1 MindSpore*1.9.0:£
|
||||||
<EFBFBD>
|
<EFBFBD>
|
||||||
|
|
||||||
bprop.11:xbprop.11:[CNode]12:1bprop.11:[CNode]12:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op7
|
bprop.13:xbprop.13:[CNode]14:1bprop.13:[CNode]14:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:-Default/S-Prim-hyper_map[zeros_like_leaf]-op9
|
||||||
y
|
z
|
||||||
bprop.11:[CNode]12:1bprop.11:[CNode]13:3bprop.11:[CNode]13:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op8bprop.11*
|
bprop.13:[CNode]14:1bprop.13:[CNode]15:3bprop.13:[CNode]15:3"REF::S-Prim-MakeTuple:4:Default/S-Prim-MakeTuple-op10bprop.13*
|
||||||
|
|
||||||
bprop.11:x*
|
bprop.13:x*
|
||||||
bprop.11:out*
|
bprop.13:out*
|
||||||
bprop.11:dout2
|
bprop.13:dout2
|
||||||
bprop.11:[CNode]13:3:@da20d6cb5b4c7c7cced90d19d6551099a88fa59be3bc1e215683bed4c1b029abPb&
|
bprop.13:[CNode]15:3:@030ffb7e684c8e2604e69bfc02a6e32af82543a97f4aa0a6bc7cab6275ee9b9dPbH
|
||||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -2503,7 +2503,6 @@ class HistogramFixedWidth(PrimitiveWithInfer):
|
||||||
self.add_prim_attr('dtype', 3)
|
self.add_prim_attr('dtype', 3)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Log(Primitive):
|
class Log(Primitive):
|
||||||
"""
|
"""
|
||||||
Returns the natural logarithm of a tensor element-wise.
|
Returns the natural logarithm of a tensor element-wise.
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import ops, nn, context, Tensor
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class NetAddcmul(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetAddcmul, self).__init__()
|
||||||
|
self.add = ops.Addcmul()
|
||||||
|
|
||||||
|
def construct(self, a, x1, x2, v):
|
||||||
|
return self.add(a, x1, x2, v)
|
||||||
|
|
||||||
|
|
||||||
|
def addcmul_test(is_dyn_rank):
|
||||||
|
a = Tensor(np.array([1, 1, 1]).astype(np.float32))
|
||||||
|
x1 = Tensor(np.array([[1], [2], [3]]).astype(np.float32))
|
||||||
|
x2 = Tensor(np.array([[1, 2, 3]]).astype(np.float32))
|
||||||
|
v = Tensor([1], ms.float32)
|
||||||
|
tester = TestDynamicGrad(NetAddcmul())
|
||||||
|
tester.test_dynamic_grad_net([a, x1, x2, v], is_dyn_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_addcmul_dyn_shape():
|
||||||
|
"""
|
||||||
|
Feature: Addcmul Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for Addcmul grad operator.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
addcmul_test(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_addcmul_dyn_rank():
|
||||||
|
"""
|
||||||
|
Feature: Addcmul Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for Addcmul grad operator.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
addcmul_test(True)
|
|
@ -0,0 +1,97 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor, context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class Atan2Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Atan2Net, self).__init__()
|
||||||
|
self.atan2 = P.Atan2()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.atan2(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_shape():
|
||||||
|
test_dynamic = TestDynamicGrad(Atan2Net())
|
||||||
|
x = Tensor(np.array([0, 1]).astype(np.float32))
|
||||||
|
y = Tensor(np.array([1, 1]).astype(np.float32))
|
||||||
|
test_dynamic.test_dynamic_grad_net((x, y))
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_rank():
|
||||||
|
test_dynamic = TestDynamicGrad(Atan2Net())
|
||||||
|
x = Tensor(np.array([0, 1]).astype(np.float32))
|
||||||
|
y = Tensor(np.array([1, 1]).astype(np.float32))
|
||||||
|
test_dynamic.test_dynamic_grad_net((x, y), True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_atan2_cpu():
|
||||||
|
"""
|
||||||
|
Feature: Atan2 Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for Atan2 grad operator on CPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
|
dynamic_shape()
|
||||||
|
dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
dynamic_shape()
|
||||||
|
dynamic_rank()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_dynamic_atan2_gpu():
|
||||||
|
"""
|
||||||
|
Feature: Atan2 Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for Atan2 grad operator on GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
dynamic_shape()
|
||||||
|
dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
dynamic_shape()
|
||||||
|
dynamic_rank()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
def test_dynamic_atan2_ascend():
|
||||||
|
"""
|
||||||
|
Feature: Atan2 Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for Atan2 grad operator on Ascend.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
dynamic_shape()
|
||||||
|
dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
dynamic_shape()
|
||||||
|
dynamic_rank()
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore
|
||||||
|
import mindspore.ops.operations.math_ops as M
|
||||||
|
from mindspore import nn, context, Tensor
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class NetBatchMatMul(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetBatchMatMul, self).__init__()
|
||||||
|
self.batchmatmul = M.BatchMatMul()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.batchmatmul(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_batch_matmul_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: BatchMatMul Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for BatchMatMul grad operator on GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
test_dynamic = TestDynamicGrad(NetBatchMatMul(), skip_convert_out_ids=[0])
|
||||||
|
x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
|
||||||
|
y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
|
||||||
|
inputs = [x, y]
|
||||||
|
test_dynamic.test_dynamic_grad_net(inputs, False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_batch_matmul_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: BatchMatMul Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for BatchMatMul grad operator on GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
test_dynamic = TestDynamicGrad(NetBatchMatMul(), skip_convert_out_ids=[0])
|
||||||
|
x = Tensor(np.ones(shape=[2, 4, 1, 3]), mindspore.float32)
|
||||||
|
y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
|
||||||
|
inputs = [x, y]
|
||||||
|
test_dynamic.test_dynamic_grad_net(inputs, True)
|
|
@ -0,0 +1,99 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor, context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class BiasAddNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(BiasAddNet, self).__init__()
|
||||||
|
self.bias_add = P.BiasAdd()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.bias_add(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dynamic_shape():
|
||||||
|
test_dynamic = TestDynamicGrad(BiasAddNet())
|
||||||
|
input_x = Tensor(np.arange(6).reshape((2, 3)), ms.float32)
|
||||||
|
bias = Tensor(np.random.random(3).reshape((3,)), ms.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net([input_x, bias])
|
||||||
|
|
||||||
|
|
||||||
|
def run_dynamic_rank():
|
||||||
|
test_dynamic = TestDynamicGrad(BiasAddNet())
|
||||||
|
input_x = Tensor(np.arange(6).reshape((2, 3)), ms.float32)
|
||||||
|
bias = Tensor(np.random.random(3).reshape((3,)), ms.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net([input_x, bias], True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
def test_dynamic_bias_add_cpu():
|
||||||
|
"""
|
||||||
|
Feature: BiasAdd Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for BiasAdd grad operator on CPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_dynamic_bias_add_gpu():
|
||||||
|
"""
|
||||||
|
Feature: BiasAdd Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for BiasAdd grad operator on GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
def test_dynamic_bias_add_ascend():
|
||||||
|
"""
|
||||||
|
Feature: BiasAdd Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for BiasAdd grad operator on Ascend.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class NetCdist(nn.Cell):
|
||||||
|
def __init__(self, p):
|
||||||
|
super(NetCdist, self).__init__()
|
||||||
|
self.cdist = P.Cdist(p)
|
||||||
|
|
||||||
|
def construct(self, x1, x2):
|
||||||
|
return self.cdist(x1, x2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
def test_dynamic_shape_cdist():
|
||||||
|
"""
|
||||||
|
Feature: Cdist Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for Cdist grad operator on CPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
test_dynamic = TestDynamicGrad(NetCdist(2.))
|
||||||
|
x1 = Tensor(np.array([[[1.0, 1.0], [2.0, 2.0]]]).astype(np.float32))
|
||||||
|
x2 = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
|
||||||
|
test_dynamic.test_dynamic_grad_net([x1, x2], False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
def test_dynamic_rank_cdist():
|
||||||
|
"""
|
||||||
|
Feature: Cdist Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for Cdist grad operator on CPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
test_dynamic = TestDynamicGrad(NetCdist(2.))
|
||||||
|
x1 = Tensor(np.array([[[1.0, 1.0], [2.0, 2.0]]]).astype(np.float32))
|
||||||
|
x2 = Tensor(np.array([[[3.0, 3.0], [3.0, 3.0]]]).astype(np.float32))
|
||||||
|
test_dynamic.test_dynamic_grad_net([x1, x2], True)
|
|
@ -0,0 +1,111 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import nn, context, Tensor
|
||||||
|
from mindspore.ops.operations import _grad_ops as G
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class NetConv2DBackpropFilter(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetConv2DBackpropFilter, self).__init__()
|
||||||
|
self.conv_filter = G.Conv2DBackpropFilter(1, 3, pad_mode="valid", pad=0, mode=1, stride=(1, 1),
|
||||||
|
dilation=(1, 1, 1, 1), group=1)
|
||||||
|
self.w_shape = (1, 1, 3, 3)
|
||||||
|
|
||||||
|
def construct(self, out, x):
|
||||||
|
return self.conv_filter(out, x, self.w_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def grad_dyn_case(is_dynamic_rank):
|
||||||
|
test_dynamic = TestDynamicGrad(NetConv2DBackpropFilter())
|
||||||
|
np.random.seed(1)
|
||||||
|
out = Tensor(np.random.normal(0, 1, (1, 1, 4, 4)).astype(np.float32))
|
||||||
|
x = Tensor(np.random.normal(0, 2, (1, 1, 6, 6)).astype(np.float32))
|
||||||
|
test_dynamic.test_dynamic_grad_net([out, x], is_dynamic_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test Conv2DBackpropFilter dynamic shape on GPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test Conv2DBackpropFilter dynamic rank on GPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test Conv2DBackpropFilter dynamic shape on CPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test Conv2DBackpropFilter dynamic rank on CPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ascend_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test Conv2DBackpropFilter dynamic shape on Ascend.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ascend_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test Conv2DBackpropFilter dynamic rank on Ascend.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
grad_dyn_case(True)
|
|
@ -0,0 +1,125 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import nn, context, Tensor
|
||||||
|
from mindspore.ops.operations import Conv3D
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class NetConv3d(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetConv3d, self).__init__()
|
||||||
|
out_channel = 4
|
||||||
|
kernel_size = 2
|
||||||
|
self.conv = Conv3D(out_channel,
|
||||||
|
kernel_size,
|
||||||
|
mode=1,
|
||||||
|
pad_mode="valid",
|
||||||
|
pad=0,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
group=1)
|
||||||
|
|
||||||
|
def construct(self, x, w):
|
||||||
|
return self.conv(x, w)
|
||||||
|
|
||||||
|
|
||||||
|
def grad_dyn_case(is_dynamic_rank):
|
||||||
|
test_dynamic = TestDynamicGrad(NetConv3d())
|
||||||
|
input_np = np.arange(1 * 3 * 3 * 3 * 3).reshape(1, 3, 3, 3, 3).astype(np.float32)
|
||||||
|
weight_np = np.arange(4 * 3 * 2 * 2 * 2).reshape(4, 3, 2, 2, 2).astype(np.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net([Tensor(input_np), Tensor(weight_np)], is_dynamic_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_shape_1():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3D grad dynamic shape on GPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_rank_1():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3D grad dynamic rank on GPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_shape_2():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3D grad dynamic shape on GPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_rank_2():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3D grad dynamic shape on GPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="CPU无Conv3DBackpropFilter, Conv3DBackpropInput, kernel实现")
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3D grad dynamic shape on CPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="CPU无Conv3DBackpropFilter, Conv3DBackpropInput, kernel实现")
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3D grad dynamic rank on CPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(True)
|
|
@ -0,0 +1,125 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import nn, context, Tensor
|
||||||
|
from mindspore.ops.operations import Conv3DTranspose
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class NetConv3dTranspose(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetConv3dTranspose, self).__init__()
|
||||||
|
in_channel = 2
|
||||||
|
out_channel = 2
|
||||||
|
kernel_size = 2
|
||||||
|
self.conv_trans = Conv3DTranspose(in_channel, out_channel,
|
||||||
|
kernel_size,
|
||||||
|
pad_mode="pad",
|
||||||
|
pad=1,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
group=1)
|
||||||
|
|
||||||
|
def construct(self, x, w):
|
||||||
|
return self.conv_trans(x, w)
|
||||||
|
|
||||||
|
|
||||||
|
def grad_dyn_case(is_dynamic_rank):
|
||||||
|
x = Tensor(np.arange(1 * 2 * 3 * 3 * 3).reshape(1, 2, 3, 3, 3).astype(np.float32))
|
||||||
|
w = Tensor(np.ones((2, 2, 2, 2, 2)).astype(np.float32))
|
||||||
|
test_dynamic = TestDynamicGrad(NetConv3dTranspose())
|
||||||
|
test_dynamic.test_dynamic_grad_net([x, w], is_dynamic_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_shape_1():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3DTranspose grad dynamic shape on GPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_rank_1():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3DTranspose grad dynamic rank on GPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_shape_2():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3DTranspose grad dynamic shape on GPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_rank_2():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3DTranspose grad dynamic shape on GPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="CPU无Conv3DBackpropFilter, Conv3DBackpropInput, kernel实现")
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3DTranspose grad dynamic shape on CPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="CPU无Conv3DBackpropFilter, Conv3DBackpropInput, kernel实现")
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test Conv3DTranspose grad dynamic rank on CPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(True)
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import ops, nn, context, Tensor
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class NetDivNoNan(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetDivNoNan, self).__init__()
|
||||||
|
self.div = ops.DivNoNan()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.div(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def divnonan_test(is_dyn_rank):
|
||||||
|
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
|
||||||
|
y = Tensor(np.array([[7, 8, 9]]).astype(np.float32))
|
||||||
|
tester = TestDynamicGrad(NetDivNoNan())
|
||||||
|
tester.test_dynamic_grad_net([x, y], is_dyn_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
def test_divnonan_dyn_shape():
|
||||||
|
"""
|
||||||
|
Feature: DivNoNan Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for DivNoNan grad operator.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
divnonan_test(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
def test_divnonan_dyn_rank():
|
||||||
|
"""
|
||||||
|
Feature: DivNoNan Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for DivNoNan grad operator.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
divnonan_test(True)
|
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor, context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class MaskedFillNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(MaskedFillNet, self).__init__()
|
||||||
|
self.maskedfill = P.MaskedFill()
|
||||||
|
|
||||||
|
def construct(self, input_0, mask, value):
|
||||||
|
return self.maskedfill(input_0, mask, value)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dynamic_shape():
|
||||||
|
test_dynamic = TestDynamicGrad(MaskedFillNet())
|
||||||
|
input_0 = Tensor(np.array([1., 2., 3., 4.]), ms.float32)
|
||||||
|
mask = Tensor(np.array([True, True, False, True]), ms.bool_)
|
||||||
|
value = Tensor(0.5, ms.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net([input_0, mask, value])
|
||||||
|
|
||||||
|
|
||||||
|
def run_dynamic_rank():
|
||||||
|
test_dynamic = TestDynamicGrad(MaskedFillNet())
|
||||||
|
input_0 = Tensor(np.array([1., 2., 3., 4.]), ms.float32)
|
||||||
|
mask = Tensor(np.array([True, True, False, True]), ms.bool_)
|
||||||
|
value = Tensor(0.5, ms.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net([input_0, mask, value], True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_dynamic_maskedfill_gpu():
|
||||||
|
"""
|
||||||
|
Feature: MaskedFill Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for MaskedFill grad operator on GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
def test_dynamic_maskedfill_ascend():
|
||||||
|
"""
|
||||||
|
Feature: MaskedFill Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for MaskedFill grad operator on Ascend.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
|
@ -0,0 +1,78 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor, context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixDeterminantNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(MatrixDeterminantNet, self).__init__()
|
||||||
|
self.matrix_determinant = P.MatrixDeterminant()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.matrix_determinant(x)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_shape():
|
||||||
|
test_dynamic = TestDynamicGrad(MatrixDeterminantNet())
|
||||||
|
x = Tensor(np.array([[[-4.5, -1.5], [7.0, 6.0]], [[2.5, 0.5], [3.0, 9.0]]]).astype(np.float32))
|
||||||
|
test_dynamic.test_dynamic_grad_net(x)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_rank():
|
||||||
|
test_dynamic = TestDynamicGrad(MatrixDeterminantNet())
|
||||||
|
x = Tensor(np.array([[[-4.5, -1.5], [7.0, 6.0]], [[2.5, 0.5], [3.0, 9.0]]]).astype(np.float32))
|
||||||
|
test_dynamic.test_dynamic_grad_net(x, True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
def test_dynamic_matrix_determinant_cpu():
|
||||||
|
"""
|
||||||
|
Feature: MatrixDeterminant Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for MatrixDeterminant grad operator on CPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
|
dynamic_shape()
|
||||||
|
dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
dynamic_shape()
|
||||||
|
dynamic_rank()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_dynamic_matrix_determinant_gpu():
|
||||||
|
"""
|
||||||
|
Feature: MatrixDeterminant Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for MatrixDeterminant grad operator on GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
dynamic_shape()
|
||||||
|
dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
dynamic_shape()
|
||||||
|
dynamic_rank()
|
|
@ -0,0 +1,80 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import dtype as mstype
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixDiagPartV3Net(nn.Cell):
|
||||||
|
def __init__(self, align='LEFT_RIGHT'):
|
||||||
|
super(MatrixDiagPartV3Net, self).__init__()
|
||||||
|
self.matrix_diag_dart_v3 = P.array_ops.MatrixDiagPartV3(align=align)
|
||||||
|
|
||||||
|
def construct(self, x, k, padding_value):
|
||||||
|
return self.matrix_diag_dart_v3(x, k, padding_value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_dynamic_shape_matrix_diag_partv3():
|
||||||
|
"""
|
||||||
|
Feature: MatrixDiagPartV3 Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for MatrixDiagPartV3 grad operator on CPU and GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
align = 'RIGHT_LEFT'
|
||||||
|
test_dynamic = TestDynamicGrad(MatrixDiagPartV3Net(align))
|
||||||
|
input_x = Tensor(np.array([[[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
[9, 8, 7, 6]],
|
||||||
|
[[5, 4, 3, 2],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8]]]), mstype.float32)
|
||||||
|
k = Tensor(1, mstype.int32)
|
||||||
|
padding_value = Tensor(0, mstype.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net([input_x, k, padding_value], False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_dynamic_rank_matrix_diag_partv3():
|
||||||
|
"""
|
||||||
|
Feature: MatrixDiagPartV3 Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for MatrixDiagPartV3 grad operator on CPU and GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
align = 'RIGHT_LEFT'
|
||||||
|
test_dynamic = TestDynamicGrad(MatrixDiagPartV3Net(align))
|
||||||
|
input_x = Tensor(np.array([[[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
[9, 8, 7, 6]],
|
||||||
|
[[5, 4, 3, 2],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8]]]), mstype.float32)
|
||||||
|
k = Tensor(1, mstype.int32)
|
||||||
|
padding_value = Tensor(0, mstype.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net([input_x, k, padding_value], True)
|
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor, context
|
||||||
|
from mindspore.ops.operations.array_ops import MatrixDiagV3
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixDiagV3Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(MatrixDiagV3Net, self).__init__()
|
||||||
|
self.matrix_diag_v3 = MatrixDiagV3(align='LEFT_RIGHT')
|
||||||
|
|
||||||
|
def construct(self, x, k, num_rows, num_cols, padding_value):
|
||||||
|
return self.matrix_diag_v3(x, k, num_rows, num_cols, padding_value)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dynamic_shape():
|
||||||
|
test_dynamic = TestDynamicGrad(MatrixDiagV3Net())
|
||||||
|
x = Tensor(np.array([[8, 9, 0], [1, 2, 3], [0, 4, 5]]), ms.float32)
|
||||||
|
k = Tensor(np.array([-1, 1]), ms.int32)
|
||||||
|
num_rows = Tensor(np.array(3), ms.int32)
|
||||||
|
num_cols = Tensor(np.array(3), ms.int32)
|
||||||
|
padding_value = Tensor(np.array(11), ms.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net(
|
||||||
|
[x, k, num_rows, num_cols, padding_value])
|
||||||
|
|
||||||
|
|
||||||
|
def run_dynamic_rank():
|
||||||
|
test_dynamic = TestDynamicGrad(MatrixDiagV3Net())
|
||||||
|
x = Tensor(np.array([[8, 9, 0],
|
||||||
|
[1, 2, 3],
|
||||||
|
[0, 4, 5]]), ms.float32)
|
||||||
|
k = Tensor(np.array([-1, 1]), ms.int32)
|
||||||
|
num_rows = Tensor(np.array(3), ms.int32)
|
||||||
|
num_cols = Tensor(np.array(3), ms.int32)
|
||||||
|
padding_value = Tensor(np.array(11), ms.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net(
|
||||||
|
[x, k, num_rows, num_cols, padding_value], True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
def test_dynamic_matrix_diag_v3_cpu():
|
||||||
|
"""
|
||||||
|
Feature: MatrixDiagV3 Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for MatrixDiagV3 grad operator on CPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_dynamic_matrix_diag_v3_gpu():
|
||||||
|
"""
|
||||||
|
Feature: MatrixDiagV3 Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for MatrixDiagV3 grad operator on GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
|
@ -0,0 +1,97 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import nn, context, Tensor
|
||||||
|
from mindspore.ops.operations.sparse_ops import SparseAdd
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class NetSparseAdd(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetSparseAdd, self).__init__()
|
||||||
|
self.sparse_add = SparseAdd()
|
||||||
|
|
||||||
|
def construct(self, a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh):
|
||||||
|
return self.sparse_add(a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh)
|
||||||
|
|
||||||
|
|
||||||
|
def grad_dyn_case(is_dynamic_rank):
|
||||||
|
test_dynamic = TestDynamicGrad(NetSparseAdd())
|
||||||
|
value_type = ms.float32
|
||||||
|
thresh_type = ms.float32
|
||||||
|
thresh_value = 0
|
||||||
|
a_indices = Tensor([[0, 1], [1, 2]], ms.int64)
|
||||||
|
a_values = Tensor([1, 2], value_type)
|
||||||
|
a_shape = Tensor([3, 4], ms.int64)
|
||||||
|
b_indices = Tensor([[0, 1], [1, 2]], ms.int64)
|
||||||
|
b_values = Tensor([1, 2], value_type)
|
||||||
|
b_shape = Tensor([3, 4], ms.int64)
|
||||||
|
thresh = Tensor(thresh_value, thresh_type)
|
||||||
|
test_dynamic.test_dynamic_grad_net([a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh],
|
||||||
|
is_dynamic_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test SparseAdd dynamic shape on GPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test SparseAdd dynamic rank on GPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test SparseAdd dynamic shape on CPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test SparseAdd dynamic rank on CPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(True)
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import nn, context, Tensor
|
||||||
|
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class NetSparseSegmentSqrtN(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetSparseSegmentSqrtN, self).__init__()
|
||||||
|
self.sparse_seg_sqrt_n = SparseSegmentSqrtN()
|
||||||
|
|
||||||
|
def construct(self, x, indices, seg):
|
||||||
|
return self.sparse_seg_sqrt_n(x, indices, seg)
|
||||||
|
|
||||||
|
|
||||||
|
def grad_dyn_case(is_dynamic_rank):
|
||||||
|
test_dynamic = TestDynamicGrad(NetSparseSegmentSqrtN())
|
||||||
|
x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]).astype(np.float32))
|
||||||
|
indices = Tensor(np.array([0, 1, 2]).astype(np.int32))
|
||||||
|
segment_ids = Tensor(np.array([0, 1, 2]).astype(np.int32))
|
||||||
|
test_dynamic.test_dynamic_grad_net((x, indices, segment_ids), is_dynamic_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test SparseSegmentSqrtN dynamic shape on CPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test SparseSegmentSqrtN dynamic rank on CPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(True)
|
|
@ -0,0 +1,64 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import nn, context, Tensor
|
||||||
|
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class NetSparseSegmentSqrtNWithNumSegments(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetSparseSegmentSqrtNWithNumSegments, self).__init__()
|
||||||
|
self.sparse_seg_sqrt_n_with_n_seg = SparseSegmentSqrtNWithNumSegments()
|
||||||
|
|
||||||
|
def construct(self, x, indices, seg, num):
|
||||||
|
return self.sparse_seg_sqrt_n_with_n_seg(x, indices, seg, num)
|
||||||
|
|
||||||
|
|
||||||
|
def grad_dyn_case(is_dynamic_rank):
|
||||||
|
test_dynamic = TestDynamicGrad(NetSparseSegmentSqrtNWithNumSegments(), skip_convert_out_ids=[0])
|
||||||
|
x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float16)
|
||||||
|
indices = Tensor([0, 2, 1], dtype=ms.int32)
|
||||||
|
segment_ids = Tensor([0, 1, 2], dtype=ms.int32)
|
||||||
|
num_segments = Tensor([4], dtype=ms.int32)
|
||||||
|
test_dynamic.test_dynamic_grad_net((x, indices, segment_ids, num_segments), is_dynamic_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test SparseSegmentSqrtNWithNumSegments dynamic shape on CPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_cpu_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test SparseSegmentSqrtNWithNumSegments dynamic rank on CPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(True)
|
|
@ -0,0 +1,67 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import mindspore
|
||||||
|
import mindspore.ops.operations.sparse_ops as S
|
||||||
|
from mindspore import nn, context, Tensor
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
|
||||||
|
|
||||||
|
class NetSparseSegmentMeanWithNumSegments(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetSparseSegmentMeanWithNumSegments, self).__init__()
|
||||||
|
self.sparse_segment_mean_with_num_segments = S.SparseSegmentMeanWithNumSegments()
|
||||||
|
|
||||||
|
def construct(self, x, indices, seg_ids, num_segments):
|
||||||
|
return self.sparse_segment_mean_with_num_segments(x, indices, seg_ids, num_segments)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
def test_sparse_segment_mean_with_num_segments_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: SparseSegmentMeanWithNumSegments Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for SparseSegmentMeanWithNumSegments grad operator on CPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
test_dynamic = TestDynamicGrad(NetSparseSegmentMeanWithNumSegments(), skip_convert_out_ids=[0])
|
||||||
|
x = Tensor([[0, 2, 0, 0], [0, 1, 1, 0], [2, 0, 2, 0]], dtype=mindspore.float32)
|
||||||
|
indices = Tensor([0, 2, 1], dtype=mindspore.int32)
|
||||||
|
segment_ids = Tensor([0, 0, 2], dtype=mindspore.int32)
|
||||||
|
num_segments = Tensor([4], dtype=mindspore.int32)
|
||||||
|
inputs = [x, indices, segment_ids, num_segments]
|
||||||
|
test_dynamic.test_dynamic_grad_net(inputs, False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
def test_sparse_segment_mean_with_num_segments_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: SparseSegmentMeanWithNumSegments Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for SparseSegmentMeanWithNumSegments grad operator on CPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
x = Tensor([[0, 2, 0, 0], [0, 1, 1, 0], [2, 0, 2, 0]], dtype=mindspore.float32)
|
||||||
|
indices = Tensor([0, 2, 1], dtype=mindspore.int32)
|
||||||
|
segment_ids = Tensor([0, 0, 2], dtype=mindspore.int32)
|
||||||
|
num_segments = Tensor([4], dtype=mindspore.int32)
|
||||||
|
test_dynamic = TestDynamicGrad(NetSparseSegmentMeanWithNumSegments(), skip_convert_out_ids=[0])
|
||||||
|
inputs = [x, indices, segment_ids, num_segments]
|
||||||
|
test_dynamic.test_dynamic_grad_net(inputs, True)
|
|
@ -0,0 +1,110 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import nn, ops, context, Tensor
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class NetSquaredDifference(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetSquaredDifference, self).__init__()
|
||||||
|
self.squared_diff = ops.SquaredDifference()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.squared_diff(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def grad_dyn_case(is_dynamic_rank):
|
||||||
|
test_dynamic = TestDynamicGrad(NetSquaredDifference())
|
||||||
|
np.random.seed(1)
|
||||||
|
x = Tensor(np.random.uniform(10, 20, (3, 4, 5, 2)).astype(np.float16))
|
||||||
|
y = Tensor(np.random.uniform(40, 50, (3, 4, 5, 2)).astype(np.float16))
|
||||||
|
test_dynamic.test_dynamic_grad_net([x, y], is_dynamic_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test SquaredDifference dynamic shape on GPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_gpu_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test SquaredDifference dynamic rank on GPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
grad_dyn_case(True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cpu_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test SquaredDifference dynamic shape on CPU.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cpu_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test SquaredDifference dynamic rank on CPU.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||||
|
grad_dyn_case(True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_ascend_grad_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: test SquaredDifference dynamic shape on Ascend.
|
||||||
|
Description: input is dynamic shape.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
grad_dyn_case(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_ascend_grad_dynamic_rank():
|
||||||
|
"""
|
||||||
|
Feature: test SquaredDifference dynamic rank on Ascend.
|
||||||
|
Description: input is dynamic rank.
|
||||||
|
Expectation: the result match with static shape
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
grad_dyn_case(True)
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import ops, nn, context, Tensor
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class NetSub(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetSub, self).__init__()
|
||||||
|
self.sub = ops.Sub()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.sub(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def sub_test(is_dyn_rank):
|
||||||
|
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
|
||||||
|
y = Tensor(np.array([[7, 8, 9]]).astype(np.float32))
|
||||||
|
tester = TestDynamicGrad(NetSub())
|
||||||
|
tester.test_dynamic_grad_net([x, y], is_dyn_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
def test_sub_dyn_shape():
|
||||||
|
"""
|
||||||
|
Feature: Sub Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for Sub grad operator.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
sub_test(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
def test_sub_dyn_rank():
|
||||||
|
"""
|
||||||
|
Feature: Sub Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for Sub grad operator.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
sub_test(True)
|
|
@ -0,0 +1,62 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import ops, nn, context, Tensor
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class NetTopK(nn.Cell):
|
||||||
|
def __init__(self, k):
|
||||||
|
super(NetTopK, self).__init__()
|
||||||
|
self.topk = ops.TopK()
|
||||||
|
self.k = k
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.topk(x, self.k)
|
||||||
|
|
||||||
|
|
||||||
|
def topk_test(is_dyn_rank):
|
||||||
|
x = Tensor(np.array([[1, 2, 3, 4, 5]]).astype(np.float32))
|
||||||
|
k = 3
|
||||||
|
tester = TestDynamicGrad(NetTopK(k))
|
||||||
|
tester.test_dynamic_grad_net([x], is_dyn_rank)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
def test_topk_dyn_shape():
|
||||||
|
"""
|
||||||
|
Feature: TopK Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for TopK grad operator.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
topk_test(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
def test_topk_dyn_rank():
|
||||||
|
"""
|
||||||
|
Feature: TopK Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for TopK grad operator.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
topk_test(True)
|
|
@ -0,0 +1,61 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore
|
||||||
|
import mindspore.ops.operations.math_ops as M
|
||||||
|
from mindspore import nn, context, Tensor
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class NetTrace(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetTrace, self).__init__()
|
||||||
|
self.trace = M.Trace()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.trace(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_trace_dynamic_shape():
|
||||||
|
"""
|
||||||
|
Feature: Trace Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for Trace grad operator on CPU and GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
test_dynamic = TestDynamicGrad(NetTrace())
|
||||||
|
x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net(x, False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_trace_dynamic_shape_rank():
|
||||||
|
"""
|
||||||
|
Feature: Trace Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for Trace grad operator on CPU and GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
test_dynamic = TestDynamicGrad(NetTrace())
|
||||||
|
x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
|
||||||
|
test_dynamic.test_dynamic_grad_net(x, True)
|
|
@ -0,0 +1,67 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class NetTranspose(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetTranspose, self).__init__()
|
||||||
|
self.transpose = P.Transpose()
|
||||||
|
|
||||||
|
def construct(self, x, perm):
|
||||||
|
return self.transpose(x, perm)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
def test_dynamic_shape_transpose():
|
||||||
|
"""
|
||||||
|
Feature: Transpose Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for Transpose grad operator on CPU, GPU and Ascend.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
test_dynamic = TestDynamicGrad(NetTranspose())
|
||||||
|
x = Tensor(np.array([[1, 2, 3], [3, 2, 1]]))
|
||||||
|
perm = (0, 1)
|
||||||
|
test_dynamic.test_dynamic_grad_net([x, perm], False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
def test_dynamic_rank_transpose():
|
||||||
|
"""
|
||||||
|
Feature: Transpose Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for Transpose grad operator on CPU, GPU and Ascend.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
test_dynamic = TestDynamicGrad(NetTranspose())
|
||||||
|
x = Tensor(np.array([[1, 2, 3], [3, 2, 1]]))
|
||||||
|
perm = (0, 1)
|
||||||
|
test_dynamic.test_dynamic_grad_net([x, perm], True)
|
|
@ -0,0 +1,81 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor, context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class TruncateModNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TruncateModNet, self).__init__()
|
||||||
|
self.truncate_mode = P.TruncateMod()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.truncate_mode(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dynamic_shape():
|
||||||
|
test_dynamic = TestDynamicGrad(TruncateModNet())
|
||||||
|
x = Tensor(np.array([2, 4, -1]), ms.int32)
|
||||||
|
y = Tensor(np.array([3, 3, 3]), ms.int32)
|
||||||
|
test_dynamic.test_dynamic_grad_net([x, y])
|
||||||
|
|
||||||
|
|
||||||
|
def run_dynamic_rank():
|
||||||
|
test_dynamic = TestDynamicGrad(TruncateModNet())
|
||||||
|
x = Tensor(np.array([2, 4, -1]), ms.int32)
|
||||||
|
y = Tensor(np.array([3, 3, 3]), ms.int32)
|
||||||
|
test_dynamic.test_dynamic_grad_net([x, y], True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_dynamic_truncate_mode_gpu():
|
||||||
|
"""
|
||||||
|
Feature: TruncateMod Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for TruncateMod grad operator on GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
def test_dynamic_truncate_mode_ascend():
|
||||||
|
"""
|
||||||
|
Feature: TruncateMod Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for TruncateMod grad operator on Ascend.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
# Graph mode
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
||||||
|
# PyNative mode
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
run_dynamic_shape()
|
||||||
|
run_dynamic_rank()
|
|
@ -0,0 +1,70 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from .test_grad_of_dynamic import TestDynamicGrad
|
||||||
|
|
||||||
|
|
||||||
|
class UnsortedSegmentMaxNet(nn.Cell):
|
||||||
|
def __init__(self, num_segments):
|
||||||
|
super(UnsortedSegmentMaxNet, self).__init__()
|
||||||
|
self.unsorted_segment_max = P.UnsortedSegmentMax()
|
||||||
|
self.num_segments = num_segments
|
||||||
|
|
||||||
|
def construct(self, data, ids):
|
||||||
|
return self.unsorted_segment_max(data, ids, self.num_segments)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_dynamic_shape_unsorted_segment_max():
|
||||||
|
"""
|
||||||
|
Feature: UnsortedSegmentMax Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic shape for UnsortedSegmentMax grad operator on CPU and GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
num_segments = 2
|
||||||
|
test_dynamic = TestDynamicGrad(UnsortedSegmentMaxNet(num_segments))
|
||||||
|
input_x = Tensor(
|
||||||
|
np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
|
||||||
|
segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
|
||||||
|
test_dynamic.test_dynamic_grad_net([input_x, segment_ids], False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
def test_dynamic_rank_unsorted_segment_max():
|
||||||
|
"""
|
||||||
|
Feature: UnsortedSegmentMax Grad DynamicShape.
|
||||||
|
Description: Test case of dynamic rank for UnsortedSegmentMax grad operator on CPU and GPU.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
num_segments = 2
|
||||||
|
test_dynamic = TestDynamicGrad(UnsortedSegmentMaxNet(num_segments))
|
||||||
|
input_x = Tensor(
|
||||||
|
np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
|
||||||
|
segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
|
||||||
|
test_dynamic.test_dynamic_grad_net([input_x, segment_ids], True)
|
Loading…
Reference in New Issue