forked from mindspore-Ecosystem/mindspore

support csr cpu

This commit is contained in:
parent b4f00fba4b
commit 893587ae55
@@ -0,0 +1,48 @@
+# akg-third-party
+# file directory: mindspore/akg/third_party/incubator-tvm/
+
+https://docs.tvm.ai/vta/tutorials/index.html
+http://raw.githubusercontent.com/uwsaml/web-data/master/vta/blogpost/vta_stack.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/passupdomain_problem.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/stage_graph.png
+https://docs.tvm.ai/tutorials/get_started.html#sphx-glr-tutorials-get-started-py
+http://docs.tvm.ai/tutorials/deployment/cross_compilation_and_rpc.html#sphx-glr-tutorials-deployment-cross-compilation-and-rpc-py
+https://docs.tvm.ai/api/python/module.html#tvm.module.Module.time_evaluator
+https://arxiv.org/abs/1509.09308
+https://github.com/andravin/wincnn
+https://homes.cs.washington.edu/~cyulin/media/gnn_model/gcn_%s.torch
+https://docs.tvm.ai/api/python/schedule.html#tvm.schedule.Stage.storage_align
+https://homes.cs.washington.edu/~haichen/
+https://docs.tvm.ai/dev/relay_pass_infra.html
+https://en.cppreference.com/w/cpp/numeric/math/round
+https://en.cppreference.com/w/cpp/numeric/math/nearbyint
+https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
+https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.bsr_matrix.html
+https://arxiv.org/abs/1706.02515
+https://tinyurl.com/y5k6fz5w
+https://tinyurl.com/y4d7hrzf
+https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub
+https://arxiv.org/abs/1409.3215
+https://arxiv.org/abs/1706.02515
+https://discuss.tvm.ai/t/pool2d-gives-bad-output-for-integer-inputs/3377
+http://www.apache.org/licenses/LICENSE-2.0
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/gatherbound.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/inferbound_traversal.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/relay/let_scope.png
+http://raw.githubusercontent.com/uwsaml/web-data/master/vta/blogpost/vta_overview.png
+http://raw.githubusercontent.com/uwsaml/web-data/master/vta/developer/dataflow.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/inferbound_phases.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/passupdomain_div.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/passupdomain_min.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/relay/dataflow.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/relay/dataflow_vs_func.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/gatherbound_problem.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/relations.png
+http://raw.githubusercontent.com/uwsaml/web-data/master/vta/developer/vta_instructions.png
+http://raw.githubusercontent.com/uwsaml/web-data/master/vta/developer/gemm_core.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/union.png
+https://raw.githubusercontent.com/tvmai/tvmai.github.io/master/images/docs/inferbound/passupdomain_nodiv.png
+http://raw.githubusercontent.com/uwsaml/web-data/master/vta/developer/alu_core.png
+http://raw.githubusercontent.com/uwsaml/web-data/master/vta/developer/2d_dma.png
+https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28_nopp.tgz
+https://github.com/google/googletest
+https://github.com/dmlc/dgl/blob/master/examples/pytorch/gcn/train.py
+https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py3-none-any.whl
+https://github.com/siju-samuel/darknet/blob/master/
@@ -3,7 +3,10 @@ mindspore.COOTensor

.. py:class:: mindspore.COOTensor(indices=None, values=None, shape=None)

-    Represents the set of non-zero elements of a tensor at the given indices.
+    Represents the set of non-zero elements of a tensor at the given indices, where `indices`
+    marks the position of each non-zero element.

    .. note::
        - This is an experimental feature and the API may change in the future.

    **Parameters:**
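The COO layout documented here is easiest to see in a few lines of code. A minimal NumPy sketch (illustrative only, not MindSpore code) of what `indices` and `values` encode:

    import numpy as np

    # COO stores one (row, col) pair per non-zero, plus its value.
    indices = np.array([[0, 1], [1, 2]])                 # positions of the non-zeros
    values = np.array([1.0, 2.0], dtype=np.float32)      # the non-zero values
    shape = (3, 4)                                       # dense shape

    dense = np.zeros(shape, dtype=values.dtype)
    dense[indices[:, 0], indices[:, 1]] = values
    # dense is now
    # [[0. 1. 0. 0.]
    #  [0. 0. 2. 0.]
    #  [0. 0. 0. 0.]]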
@@ -29,7 +32,7 @@ mindspore.COOTensor

.. py:method:: shape
    :property:

-    The dense shape of the sparse matrix.
+    Returns the dense shape of the sparse matrix.

.. py:method:: dtype
    :property:

@@ -49,7 +52,7 @@ mindspore.COOTensor

.. py:method:: ndim
    :property:

-    The dense dimension of the sparse matrix.
+    Returns the dense dimension of the sparse matrix.

.. py:method:: to_csr()
@@ -3,7 +3,11 @@ mindspore.CSRTensor

.. py:class:: mindspore.CSRTensor(indptr=None, indices=None, values=None, shape=None)

-    Represents the set of non-zero elements of a tensor at the given indices.
+    Represents the set of non-zero elements of a tensor at the given indices, where the row
+    indices are represented by `indptr`, the column indices by `indices`, and the non-zero
+    values by `values`.

    .. note::
        - This is an experimental feature and the API may change in the future.

    **Parameters:**
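The `indptr`/`indices`/`values` triple is the standard CSR encoding: `indptr` has one entry per row plus one, and `indptr[i]:indptr[i+1]` slices out row `i`'s non-zeros. A minimal NumPy sketch (illustrative only):

    import numpy as np

    indptr = np.array([0, 1, 2])                     # length == shape[0] + 1
    indices = np.array([0, 1])                       # column index of each non-zero
    values = np.array([1.0, 2.0], dtype=np.float32)  # the non-zero values
    shape = (2, 4)

    dense = np.zeros(shape, dtype=values.dtype)
    for row in range(shape[0]):
        for k in range(indptr[row], indptr[row + 1]):
            dense[row, indices[k]] = values[k]
    # dense is now
    # [[1. 0. 0. 0.]
    #  [0. 2. 0. 0.]]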
@@ -35,7 +39,7 @@ mindspore.CSRTensor

.. py:method:: shape
    :property:

-    The dense shape of the sparse matrix.
+    Returns the dense shape of the sparse matrix.

.. py:method:: dtype
    :property:
@@ -173,7 +173,7 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
    auto new_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), inputs[sparse_index], cons_node}, func_graph);
    new_node->set_abstract(node->abstract());
    return new_node;
-    // ComputeSparse node: SparseTensorDenseMatmul, CSRDenseMul, CSRReduceSum
+    // ComputeSparse node: SparseTensorDenseMatmul, CSRMul, CSRReduceSum
  } else if (sparse_op_set.find(prim_name) != sparse_op_set.end()) {
    const auto &inputs = cnode->inputs();
    std::vector<AnfNodePtr> new_inputs;
@@ -96,7 +96,6 @@ const mindspore::HashSet<std::string> make_sparse_set = {{prim::kMakeCSRTensor},
// and (possibly) dense tensors, used in backend common optimization pass:
// sparse_process.cc
const mindspore::HashSet<std::string> sparse_op_set = {{prim::kSparseTensorDenseMatmul},
-                                                       {prim::kCSRDenseMul},
                                                       {prim::kCSRReduceSum},
                                                       {prim::kCSRMV},
                                                       {prim::kCSRMul},
@@ -104,7 +103,7 @@ const mindspore::HashSet<std::string> sparse_op_set = {{prim::kSparseTensorDenseMatmul},
                                                       {prim::kCSR2COO},
                                                       {prim::kCSRDiv}};

-COMMON_EXPORT bool IsCustomCSROP(const AnfNodePtr &cnode);
+COMMON_EXPORT bool IsAKGSparseOP(const AnfNodePtr &cnode);
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONVERT_UTILS_H_
@@ -586,7 +586,7 @@ bool AkgKernelBuilder::AkgKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
    AkgKernelJsonGenerator akg_kernel_json_generator(option);
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
-    bool is_custom_node = IsPrimitiveCNode(cnode, prim::kPrimCustom) || IsCustomCSROP(cnode);
+    bool is_custom_node = IsPrimitiveCNode(cnode, prim::kPrimCustom) || IsAKGSparseOP(cnode);
    // Graph kernel node and Custom node need to generate composite json
    if (common::AnfAlgo::IsGraphKernel(cnode) || is_custom_node) {
      FuncGraphPtr func_graph = is_custom_node ? cnode->func_graph() : common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
@@ -1059,14 +1059,14 @@ void ComputeCapability::GetComputeCapability() {
#ifdef ENABLE_GPU
  int a, b;
  auto ret = cuDeviceGetAttribute(&a, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, 0);
-  if (ret != CUDA_SUCCESS) {
+  if (ret != CUDA_SUCCESS && Callback::Instance()->GetTargetFromContext() == kGPUDevice) {
    const char *msg = nullptr;
    cuGetErrorName(ret, &msg);
    MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR fail, error message: " << msg;
    return;
  }
  ret = cuDeviceGetAttribute(&b, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, 0);
-  if (ret != CUDA_SUCCESS) {
+  if (ret != CUDA_SUCCESS && Callback::Instance()->GetTargetFromContext() == kGPUDevice) {
    const char *msg = nullptr;
    cuGetErrorName(ret, &msg);
    MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR fail, error message: " << msg;
@@ -423,7 +423,8 @@ py::list FilterTensorArgs(const py::args &args, bool has_sens = false) {
  py::list only_tensors;
  size_t forward_args_size = has_sens ? size - 1 : size;
  for (size_t i = 0; i < forward_args_size; ++i) {
-    if (py::isinstance<tensor::Tensor>(args[i])) {
+    if (py::isinstance<tensor::Tensor>(args[i]) || py::isinstance<tensor::CSRTensor>(args[i]) ||
+        py::isinstance<tensor::COOTensor>(args[i])) {
      only_tensors.append(args[i]);
    }
  }
@@ -3054,8 +3055,11 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args) {
  top_cell()->set_k_pynative_cell_ptr(nullptr);
  BaseRef value = (*run)(arg_list);
  grad_is_running_ = false;
+  FuncGraphPtr fg = resource->func_graph();
+  MS_EXCEPTION_IF_NULL(fg);
+  auto output_abs = fg->output()->abstract();
  MS_LOG(DEBUG) << "Eval run end " << value.ToString();
-  *ret = BaseRefToPyData(value);
+  *ret = BaseRefToPyData(value, output_abs);
  // Clear device memory resource of top cell when it has been run.
  auto has_higher_order = std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
                                      [](const TopCellInfoPtr &value) { return !value->is_topest(); });
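The FilterTensorArgs change above is what lets gradients flow to sparse inputs in PyNative mode: previously only plain Tensor arguments survived the filter, so a CSRTensor input could never be differentiated against. A sketch of the now-possible call pattern, assuming the public mindspore API (values are illustrative):

    import mindspore as ms
    from mindspore import Tensor, CSRTensor, ops

    indptr = Tensor([0, 1, 2], dtype=ms.int32)
    indices = Tensor([0, 1], dtype=ms.int32)
    values = Tensor([1.0, 2.0], dtype=ms.float32)
    csr = CSRTensor(indptr, indices, values, (2, 4))
    dense = Tensor([[2.0], [3.0]], dtype=ms.float32)

    def net(csr_tensor, dense_tensor):
        return csr_tensor * dense_tensor      # CSRTensor * Tensor -> CSRTensor

    # grads[0] is the gradient w.r.t. the CSRTensor input; before this change
    # the sparse argument was dropped by the filter.
    grads = ops.GradOperation(get_all=True)(net)(csr, dense)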
@@ -27,6 +27,7 @@
#include "plugin/device/cpu/kernel/custom/custom_aot_cpu_kernel.h"
#include "plugin/device/cpu/kernel/custom/custom_julia_cpu_kernel.h"
#include "utils/trace_base.h"
+#include "include/common/utils/convert_utils.h"

namespace mindspore {
namespace device {
@@ -259,11 +260,14 @@ void AddKernelAttr(const CNodePtr &kernel_node, const KernelAttr &kernel_attr) {
                  kernel_attrs);
}

-void UpdateCustomKernelBuildInfoAndAttrs(const CNodePtr &kernel_node) {
+void UpdateCustomKernelBuildInfoAndAttrs(const CNodePtr &kernel_node, bool is_akg_op) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  // Custom op's kernel type can only be CPU_KERNEL on CPU
+  if (is_akg_op) {
+    builder->SetKernelType(KernelType::AKG_KERNEL);
+  } else {
    builder->SetKernelType(KernelType::CPU_KERNEL);
+  }
  builder->SetProcessor(kernel::Processor::CPU);
  // Set inputs info
  std::vector<TypeId> input_types;
@@ -281,6 +285,7 @@ void UpdateCustomKernelBuildInfoAndAttrs(const CNodePtr &kernel_node) {
  builder->SetOutputsDeviceType(output_types);
  builder->SetOutputsFormat(output_formats);
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
+  if (!is_akg_op) {
  // Update kernel attrs
  KernelAttr attr;
  if (input_types.size() != input_formats.size()) {
@@ -298,6 +303,7 @@ void UpdateCustomKernelBuildInfoAndAttrs(const CNodePtr &kernel_node) {
    attr.AddOutputAttr(output_types[i], output_formats[i]);
  }
  AddKernelAttr(kernel_node, attr);
+  }
}

KernelAttr FillNoneInKernelAttr(const CNodePtr &kernel_node, const std::vector<TypeId> &input_types,
@@ -428,11 +434,13 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
      MS_LOG(WARNING) << "Not find operator information for Custom operator[" << op_name << "]. "
                      << "Infer operator information from inputs. For more details, "
                      << "please refer to 'mindspore.ops.Custom' at https://www.mindspore.cn.";
-      return UpdateCustomKernelBuildInfoAndAttrs(kernel_node);
+      return UpdateCustomKernelBuildInfoAndAttrs(kernel_node, false);
    }
  } else if (IsDynamicParamKernel(op_name)) {
    // Select for dynamic kernel(both the number and data type are undetermined).
    return UpdateDynamicKernelBuildInfoAndAttrs(kernel_node);
+  } else if (IsAKGSparseOP(kernel_node)) {
+    return UpdateCustomKernelBuildInfoAndAttrs(kernel_node, true);
  }

  std::vector<std::string> input_formats;
@@ -318,7 +318,7 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) {
  return cnt;
}

-bool IsCustomCSROP(const AnfNodePtr &cnode) {
+bool IsAKGSparseOP(const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV, prim::kPrimCSRGather,
                           prim::kPrimCSR2COO, prim::kPrimCOO2CSR, prim::kPrimCSRDiv};
@@ -159,9 +159,7 @@ AbstractBasePtr InferImplCOOTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCOOTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplCSRDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
@@ -29,7 +29,43 @@

namespace {
constexpr auto kRankSize = "rank_size";
+inline void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
+  constexpr auto kCSRMulBatchPos = 2;
+  int dlen = mindspore::SizeToInt(sparse_shp.size()) - mindspore::SizeToInt(dense_shp.size());
+  if (dlen < 0) {
+    MS_EXCEPTION(mindspore::ValueError) << "Currently, only dense tensor broadcast to sparse tensor is supported, "
+                                        << "but sparse tensor has " << sparse_shp.size() << " dimensions "
+                                        << "and dense tensor has " << dense_shp.size() << " dimensions.";
+  }
+  for (int i = 0; i < dlen; i++) {
+    (void)dense_shp.insert(dense_shp.begin(), 1);
+  }
+  if (sparse_shp.size() != dense_shp.size()) {
+    MS_LOG(EXCEPTION) << "Failure: sparse_shp.size() != dense_shp.size().";
+  }
+  if (sparse_shp.size() < 1) {
+    MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero.";
+  }
+  if (dense_shp[0] != sparse_shp[0]) {
+    MS_EXCEPTION(mindspore::ValueError)
+      << "Currently, dense tensor and sparse tensor shapes must be equal in the first dimension.";
+  }
+  for (size_t i = 0; i < sparse_shp.size(); i++) {
+    auto s = sparse_shp[i];
+    auto d = dense_shp[i];
+    if (i < kCSRMulBatchPos) {
+      if (d != s && d != 1) {
+        MS_EXCEPTION(mindspore::ValueError) << "Dense shape cannot broadcast to sparse shape.";
+      }
+    } else {
+      if (d != s) {
+        MS_EXCEPTION(mindspore::ValueError)
+          << "Currently, sparse shape and dense shape must be equal in the feature dimensions.";
+      }
+    }
+  }
+}
}  // namespace

namespace mindspore {
namespace abstract {
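For readers who prefer Python, the broadcast rule that CheckSparseShape enforces can be restated as the following sketch (illustrative only; `batch_pos` mirrors kCSRMulBatchPos):

    def check_sparse_shape(sparse_shp, dense_shp, batch_pos=2):
        if len(dense_shp) > len(sparse_shp):
            raise ValueError("only dense-to-sparse broadcast is supported")
        # Left-pad the dense shape with 1s so both ranks match.
        dense_shp = [1] * (len(sparse_shp) - len(dense_shp)) + list(dense_shp)
        if dense_shp[0] != sparse_shp[0]:
            raise ValueError("shapes must be equal in the first dimension")
        for i, (s, d) in enumerate(zip(sparse_shp, dense_shp)):
            if i < batch_pos:
                if d not in (1, s):      # leading dims may broadcast
                    raise ValueError("dense shape cannot broadcast to sparse shape")
            elif d != s:                 # feature dims must match exactly
                raise ValueError("feature dimensions must be equal")

    check_sparse_shape([3, 4], [3, 1])   # ok: a column vector broadcasts per row
    check_sparse_shape([3, 4], [3, 4])   # ok: exact match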
@@ -37,6 +73,7 @@ constexpr auto kCSRDenseShape = "dense_shape";
constexpr auto kCSRAxis = "axis";
constexpr auto kCSRAvgRows = "csr_avg_rows";
+constexpr auto kIsCSR = "is_csr";

AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
  // An object of a subclass of AbstractBase
@@ -344,12 +381,12 @@ AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  }
  auto indices_shp = indices->shape()->shape();
  if (indices_shp.size() != 2) {
-    MS_EXCEPTION(TypeError) << "Indices must be a 2 dimension tensor, but got a " << indices_shp.size()
+    MS_EXCEPTION(TypeError) << "Indices must be a 2 dimensional tensor, but got a " << indices_shp.size()
                            << " dimension tensor";
  }
  auto values_shp = values->shape()->shape();
  if (values_shp.size() != 1) {
-    MS_EXCEPTION(TypeError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
+    MS_EXCEPTION(TypeError) << "Values must be a 1 dimensional tensor, but got a " << values_shp.size()
                            << " dimension tensor";
  }
  if (indices_shp[0] != values_shp[0]) {
@@ -418,13 +455,12 @@ AbstractBasePtr InferImplCOOTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  return sparse_tensor->dense_shape();
}

-AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                const AbstractBasePtrList &args_spec_list) {
+AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const AbstractBasePtrList &args_spec_list) {
  // Inputs: a sparse tensor and a dense tensor.
-  constexpr auto kCSRMulInputsNum = 2;
-  constexpr auto kCSRMulShapeSize = 2;
+  constexpr auto kCSRElementwiseInputsNum = 2;
  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, kCSRMulInputsNum);
+  CheckArgsSize(op_name, args_spec_list, kCSRElementwiseInputsNum);
  auto sparse = CheckArg<AbstractCSRTensor>(op_name, args_spec_list, 0);
  auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(sparse);

@@ -435,14 +471,7 @@ AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

  auto sparse_shape = sparse->shape()->shape();
  auto dense_shape = dense->shape()->shape();
-  if (sparse_shape.size() != kCSRMulShapeSize || dense_shape.size() != kCSRMulShapeSize) {
-    MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMulShapeSize << "-D inputs!"
-                             << "but sparse tensor has " << sparse_shape.size() << " dimensions, "
-                             << "and dense tensor has " << dense_shape.size() << " dimensions, ";
-  }
-  if (dense_shape[0] != sparse_shape[0]) {
-    MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast with last dim!";
-  }
+  CheckSparseShape(sparse_shape, dense_shape);
  auto ret = sparse->values()->Broaden();

  MS_EXCEPTION_IF_NULL(sparse->indices()->shape());
@@ -454,40 +483,6 @@ AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  return ret;
}

-AbstractBasePtr InferImplCSRDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                const AbstractBasePtrList &args_spec_list) {
-  // Inputs: a sparse tensor and a dense tensor.
-  constexpr auto kCSRDivInputsNum = 2;
-  constexpr auto kCSRDivShapeSize = 2;
-  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, kCSRDivInputsNum);
-  auto sparse = CheckArg<AbstractCSRTensor>(op_name, args_spec_list, 0);
-  auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
-  MS_EXCEPTION_IF_NULL(sparse);
-  MS_EXCEPTION_IF_NULL(sparse->shape());
-  MS_EXCEPTION_IF_NULL(sparse->values());
-  MS_EXCEPTION_IF_NULL(sparse->indices());
-  MS_EXCEPTION_IF_NULL(dense);
-
-  auto sparse_shape = sparse->shape()->shape();
-  auto dense_shape = dense->shape()->shape();
-  if (sparse_shape.size() != kCSRDivShapeSize || dense_shape.size() != kCSRDivShapeSize) {
-    MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRDivShapeSize << "-D inputs!"
-                             << "but sparse tensor has " << sparse_shape.size() << " dimensions, "
-                             << "and dense tensor has " << dense_shape.size() << " dimensions, ";
-  }
-  auto ret = sparse->values()->Broaden();
-
-  MS_EXCEPTION_IF_NULL(sparse->indices()->shape());
-  auto nnz_vec = sparse->indices()->shape()->shape();
-  int csr_avg_rows = nnz_vec[0] / dense_shape[0];
-  primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
-  primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
-  primitive->set_attr(kIsCSR, MakeValue(true));
-
-  return ret;
-}
-
AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
  // Inputs: a sparse tensor and a dense tensor.
@@ -544,7 +539,6 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
    MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRReduceSumShapeSize << "-D inputs!"
                             << "but sparse tensor has " << sparse_shape.size() << " dimensions.";
  }

  ShapeVector out_shape = sparse_shape;
  MS_EXCEPTION_IF_NULL(axis->BuildValue());
  if (axis->BuildValue()->isa<Int32Imm>() || axis->BuildValue()->isa<Int64Imm>()) {
@@ -682,21 +676,21 @@ AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  }
  auto indptr_shp = indptr->shape()->shape();
  if (indptr_shp.size() != 1) {
-    MS_EXCEPTION(ValueError) << "Indptr must be a 1 dimension tensor, but got a " << indptr_shp.size()
-                             << " dimension tensor";
+    MS_EXCEPTION(ValueError) << "Indptr must be a 1-dimensional tensor, but got a " << indptr_shp.size()
+                             << "-dimensional tensor";
  }
  auto indices_shp = indices->shape()->shape();
  if (indices_shp.size() != 1) {
-    MS_EXCEPTION(ValueError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
-                             << " dimension tensor";
+    MS_EXCEPTION(ValueError) << "Indices must be a 1-dimensional tensor, but got a " << indices_shp.size()
+                             << "-dimensional tensor";
  }
  auto values_shp = values->shape()->shape();
  if (values_shp.size() != 1) {
-    MS_EXCEPTION(ValueError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
-                             << " dimension tensor";
+    MS_EXCEPTION(ValueError) << "Values must be a 1-dimensional tensor, but got a " << values_shp.size()
+                             << "-dimensional tensor";
  }
  if (indices_shp[0] != values_shp[0]) {
-    MS_EXCEPTION(ValueError) << "indices and values must have same size, but got: values length: " << values_shp[0]
+    MS_EXCEPTION(ValueError) << "Indices and values must have same size, but got: values length: " << values_shp[0]
                              << ", indices length " << indices_shp[0];
  }
  for (const auto &elem_type : shape->ElementsType()) {
@@ -712,14 +706,19 @@ AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                         auto elem = GetValue<int64_t>(e);
                         return elem;
                       });
-  for (auto shape_elem : shape_vec) {
-    if (shape_elem < 0) {
-      MS_EXCEPTION(TypeError) << "The element of shape must be positive, but got " << shape_value->ToString();
-    }
-  }
+  if (values_shp.size() + 1 != shape_vec.size()) {
+    MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1.";
+  }
  if (shape_vec[0] + 1 != indptr_shp[0]) {
-    MS_EXCEPTION(ValueError) << "indptr must have length (1 + shape[0]), but got: " << indptr_shp[0];
+    MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[0];
  }
+  for (size_t i = 0; i < shape_vec.size(); ++i) {
+    if (shape_vec[i] < 0) {
+      MS_EXCEPTION(TypeError) << "The element of shape must be positive, but got " << shape_value->ToString();
+    }
+    if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) {
+      MS_EXCEPTION(ValueError) << "csr_tensor's shape should match with values' shape.";
+    }
+  }
  auto ret = std::make_shared<AbstractCSRTensor>(values->element()->BuildType(), shape_vec);
  ret->set_indptr(indptr);
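Concretely, the reworked checks tie the constituent shapes of a CSRTensor together as follows (a worked instance, not code from the commit):

    # For MakeCSRTensor(indptr, indices, values, shape) the checks require:
    #   values.ndim + 1 == len(shape)        one row of feature values per non-zero
    #   indptr.shape[0] == shape[0] + 1      one row offset per row, plus the end offset
    #   shape[i] == values.shape[i - 1]      for every i > 1 (feature dimensions)
    shape = (3, 4)
    indptr_len = shape[0] + 1                # == 4
    values_ndim = len(shape) - 1             # == 1, i.e. a flat values vector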
@@ -234,8 +234,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimCSRTensorGetIndptr, R{InferImplCSRTensorGetIndptr, nullptr, true}},
    {prim::kPrimCSRTensorGetIndices, R{InferImplCSRTensorGetIndices, nullptr, true}},
    {prim::kPrimCSRTensorGetDenseShape, R{InferImplCSRTensorGetDenseShape, nullptr, true}},
-    {prim::kPrimCSRMul, R{InferImplCSRMul, nullptr, true}},
-    {prim::kPrimCSRDiv, R{InferImplCSRDiv, nullptr, true}},
+    {prim::kPrimCSRMul, R{InferImplCSRElementWise, nullptr, true}},
+    {prim::kPrimCSRDiv, R{InferImplCSRElementWise, nullptr, true}},
    {prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}},
    {prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}},
    {prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}},
@@ -46,7 +46,6 @@ MS_CORE_API AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2);
MS_CORE_API AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec);

ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy);

MS_CORE_API size_t TypeIdSize(const TypeId data_type);
template <typename T>
T ShapeSize(const std::vector<T> &shape) {
@@ -160,7 +160,6 @@ constexpr auto kCOOTensorDenseMatmul = "COOTensorDenseMatmul";

// Sparse ops
constexpr auto kSparseTensorDenseMatmul = "SparseTensorDenseMatmul";
-constexpr auto kCSRDenseMul = "CSRDenseMul";
constexpr auto kCSRReduceSum = "CSRReduceSum";
constexpr auto kCSRMV = "CSRMV";
constexpr auto kCSRMul = "CSRMul";

@@ -607,7 +606,6 @@ GVAR_DEF(PrimitivePtr, kPrimIsCSRFunc, std::make_shared<Primitive>(kIsCSRFunc));
// Sparse ops
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul));
GVAR_DEF(PrimitivePtr, kPrimCOOTensorDenseMatmul, std::make_shared<Primitive>(kCOOTensorDenseMatmul));
-GVAR_DEF(PrimitivePtr, kPrimCSRDenseMul, std::make_shared<Primitive>(kCSRDenseMul));
GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared<Primitive>(kCSRReduceSum));
GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared<Primitive>(kCSRMV));
GVAR_DEF(PrimitivePtr, kPrimCSRMul, std::make_shared<Primitive>(kCSRMul));
@@ -859,6 +859,46 @@ class Validator:
    def check_type_support(dtype, device, supported_dtypes):
        return dtype in supported_dtypes or not context.get_context('device_target') == device

+    @staticmethod
+    def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
+        """Checks input tensors' shapes for CSRTensor."""
+        if len(csr_shp) != 2:
+            raise ValueError("Currently only supports 2-dimensional csr tensor.")
+        if len(values_shp) != 1:
+            raise ValueError(f"Values must be a 1-dimensional tensor, but got a {len(values_shp)}-dimensional tensor.")
+        if len(indices_shp) != 1:
+            raise ValueError(f"Indices must be a 1-dimensional tensor, but got a {len(indices_shp)}-dimensional tensor.")
+        if len(indptr_shp) != 1:
+            raise ValueError(f"Indptr must be a 1-dimensional tensor, but got a {len(indptr_shp)}-dimensional tensor.")
+        if csr_shp[0] + 1 != indptr_shp[0]:
+            raise ValueError(f"Indptr must have length (1 + shape[0]), but got: {indptr_shp[0]}")
+
+    @staticmethod
+    def check_csr_tensor_dtype(indptr_dtype, indices_dtype):
+        """Checks input tensors' data types for CSRTensor."""
+        if indptr_dtype not in (mstype.int16, mstype.int32, mstype.int64):
+            raise TypeError("Indptr must have integer data type.")
+        if indices_dtype not in (mstype.int16, mstype.int32, mstype.int64):
+            raise TypeError("Indices must have integer data type.")
+
+    @staticmethod
+    def check_coo_tensor_shape(indices_shp, values_shp, coo_shp):
+        """Checks input tensors' shapes for COOTensor."""
+        if len(coo_shp) != 2:
+            raise ValueError("Currently only supports 2-dimensional coo tensor.")
+        if len(indices_shp) != 2:
+            raise ValueError(f"Indices must be a 2-dimensional tensor, but got a {len(indices_shp)}-dimensional tensor.")
+        if len(values_shp) != 1:
+            raise ValueError(f"Values must be a 1-dimensional tensor, but got a {len(values_shp)}-dimensional tensor.")
+        if indices_shp[0] != values_shp[0]:
+            raise ValueError("Indices.shape must be (N, 2), where N equals the number of non-zero values in the coo tensor.")
+
+    @staticmethod
+    def check_coo_tensor_dtype(indices_dtype):
+        """Checks input tensors' data types for COOTensor."""
+        if indices_dtype not in (mstype.int16, mstype.int32, mstype.int64):
+            raise TypeError("Indices must have integer data type.")
+

def check_input_format(input_param):
    """Judge input format."""
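How the new helpers behave on a small input, assuming they are called as static methods of Validator (values are illustrative):

    indptr_shp, indices_shp, values_shp, csr_shp = (3,), (2,), (2,), (2, 4)
    Validator.check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp)  # passes

    try:
        Validator.check_csr_tensor_shape((2,), indices_shp, values_shp, csr_shp)
    except ValueError as e:
        print(e)   # Indptr must have length (1 + shape[0]), but got: 2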
@@ -2416,6 +2416,9 @@ class COOTensor(COOTensor_):
        [0, 0, 2, 0],
        [0, 0, 0, 0]]

+    Note:
+        This is an experimental feature and is subject to change.
+
    Args:
        indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`,
            where N and ndims are the number of `values` and number of dimensions in
@@ -2433,7 +2436,7 @@ class COOTensor(COOTensor_):
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, COOTensor
-        >>> indices = Tensor([[0, 1], [1, 2]])
+        >>> indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)
        >>> values = Tensor([1, 2], dtype=ms.float32)
        >>> shape = (3, 4)
        >>> x = COOTensor(indices, values, shape)
@@ -2448,29 +2451,52 @@ class COOTensor(COOTensor_):

    def __init__(self, indices=None, values=None, shape=None, coo_tensor=None):
        "Init COOTensor"
+        self.init_finished = False
+        # Directly init a COOTensor from another COOTensor
+        if indices is None and values is None and shape is None and coo_tensor is not None:
+            if not isinstance(coo_tensor, (COOTensor, COOTensor_)):
+                raise TypeError("If only one input provided, it must be a COOTensor.")
+            COOTensor_.__init__(self, coo_tensor)
+        # Init a COOTensor from indices, values and shape
+        else:
+            if not (isinstance(indices, Tensor) and isinstance(values, Tensor) and isinstance(shape, tuple)):
+                raise TypeError("Inputs must follow: COOTensor(indices, values, shape).")
+            validator.check_coo_tensor_shape(indices.shape, values.shape, shape)
+            validator.check_coo_tensor_dtype(indices.dtype)
+            COOTensor_.__init__(self, indices, values, shape)
+        self.init_finished = True
+
+    def __repr__(self):
+        """Avoid a PyTest segfault when COOTensor is not initialized."""
+        if self.init_finished:
+            return COOTensor_.__repr__(self)
+        return ''

    @property
    def indices(self):
        """Return COOTensor's indices."""
        return Tensor(self._indices)

    @property
    def values(self):
        """Return COOTensor's non-zero values."""
        return Tensor(self._values)

    @property
    def shape(self):
        """Return COOTensor's shape."""
        return self._shape

    def to_csr(self):
-        "Converts COOTensor to CSRTensor."
+        """
+        Converts COOTensor to CSRTensor.
+
+        Returns:
+            CSRTensor.
+
+        Supported Platforms:
+            ``GPU`` ``CPU``
+        """
        row_indices = self.indices[:, 0]
        col_indices = self.indices[:, 1]
        idx_dtype = self.indices.dtype
@@ -2483,8 +2509,17 @@ class COOTensor(COOTensor_):
        return CSRTensor(indptr, col_indices, values, self.shape)

    def to_dense(self):
+        """
+        Converts COOTensor to Dense Tensor.
+
+        Returns:
+            Tensor.
+
+        Supported Platforms:
+            ``GPU`` ``CPU``
+        """
        zeros_tensor = tensor_operator_registry.get("zeros")(self.shape, self.values.dtype)
-        return tensor_operator_registry.get("tensor_scatter_update")(
+        return tensor_operator_registry.get("tensor_scatter_add")(
            zeros_tensor, self.indices, self.values)

    @property
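Note the semantic shift in to_dense above: with tensor_scatter_add, duplicate COO indices now accumulate instead of overwriting each other. A NumPy sketch of the difference (illustrative only):

    import numpy as np

    indices = np.array([[0, 1], [0, 1]])               # duplicated position
    values = np.array([1.0, 2.0], dtype=np.float32)
    dense = np.zeros((2, 2), dtype=np.float32)

    # Scatter-add semantics (what to_dense now uses): duplicates sum up.
    np.add.at(dense, (indices[:, 0], indices[:, 1]), values)
    assert dense[0, 1] == 3.0    # a scatter-update would have kept only 2.0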
@@ -2526,8 +2561,8 @@ class COOTensor(COOTensor_):
        >>> indices = Tensor([[0, 1], [1, 2]])
        >>> values = Tensor([1, 2], dtype=ms.float32)
        >>> shape = (3, 4)
-        >>> x = COOTensor(indices, values, shape)
-        >>> print(x.astype(ms.float64).dtype)
+        >>> coo_tensor = COOTensor(indices, values, shape)
+        >>> print(coo_tensor.astype(ms.float64).dtype)
        Float64
        """
        data = self.values.astype(dtype)
@@ -2549,9 +2584,6 @@ class CSRTensor(CSRTensor_):
    values indicated by `values` and row and column positions indicated by `indptr`
    and `indices`.

-    Alternatively, CSRTensor can be initialized by passing another CSRTensor as input.
-    Currently this constructor can only be supported in PyNative Mode.
-
    Note:
        This is an experimental feature and is subject to change.
@@ -2576,13 +2608,11 @@ class CSRTensor(CSRTensor_):
        >>> import mindspore as ms
        >>> from mindspore import Tensor, CSRTensor
        >>> # initialize a csr_tensor with indptr, indices, values and shape
-        >>> indptr = Tensor([0, 1, 2])
-        >>> indices = Tensor([0, 1])
+        >>> indptr = Tensor([0, 1, 2], dtype=ms.int32)
+        >>> indices = Tensor([0, 1], dtype=ms.int32)
        >>> values = Tensor([1, 2], dtype=ms.float32)
        >>> shape = (2, 4)
        >>> csr_tensor = CSRTensor(indptr, indices, values, shape)
        >>> # initialize a csr_tensor from another csr_tensor
        >>> csr_tensor_2 = CSRTensor(csr_tensor=csr_tensor)
        >>> # access a data member of CSRTensor
        >>> print(indptr == csr_tensor.indptr)
        [ True  True  True]
@@ -2590,24 +2620,18 @@ class CSRTensor(CSRTensor_):

    def __init__(self, indptr=None, indices=None, values=None, shape=None, csr_tensor=None):
        self.init_finished = False
-        # Case 1: directly init a CSRTensor from another CSRTensor
+        # Directly init a CSRTensor from another CSRTensor
        if indptr is None and indices is None and values is None and shape is None:
            if not isinstance(csr_tensor, (CSRTensor, CSRTensor_)):
                raise TypeError("If only one input provided, it must be a CSRTensor.")
            CSRTensor_.__init__(self, csr_tensor)
-        # Case 2: init a CSRTensor from indptr, indices, values and shape
+        # Init a CSRTensor from indptr, indices, values and shape
        else:
            if (indptr is None or indices is None or values is None or shape is None):
                raise TypeError("Inputs must follow: CSRTensor(indptr, indices, values, shape).")
            if not (isinstance(indptr, Tensor) and isinstance(indices, Tensor) \
                    and isinstance(values, Tensor) and isinstance(shape, tuple)):
                raise TypeError("Inputs must follow: CSRTensor(tensor, tensor, tensor, tuple).")
-            if len(shape) != 2 or shape[0] + 1 != indptr.shape[0] or shape[1] <= 0:
-                raise ValueError("Shape length should be 2, shape[0] should equal to indptr.shape[0] - 1")
-            if indptr.dtype not in (mstype.int16, mstype.int32, mstype.int64):
-                raise TypeError("indptr must have integer data type.")
-            if indices.dtype not in (mstype.int16, mstype.int32, mstype.int64):
-                raise TypeError("indices must have integer data type.")
+            validator.check_csr_tensor_shape(indptr.shape, indices.shape, values.shape, shape)
+            validator.check_csr_tensor_dtype(indptr.dtype, indices.dtype)
            CSRTensor_.__init__(self, indptr, indices, values, shape)
        self.init_finished = True
@@ -2630,18 +2654,22 @@ class CSRTensor(CSRTensor_):

    @property
    def indptr(self):
+        """Return CSRTensor's row indices pointers."""
        return Tensor(self._indptr)

    @property
    def indices(self):
+        """Return CSRTensor's column indices."""
        return Tensor(self._indices)

    @property
    def values(self):
+        """Return CSRTensor's non-zero values."""
        return Tensor(self._values)

    @property
    def shape(self):
+        """Return CSRTensor's shape."""
        return self._shape

    @property
@@ -2700,7 +2728,7 @@ class CSRTensor(CSRTensor_):
        >>> values = Tensor([1, 2], dtype=ms.float32)
        >>> shape = (2, 4)
        >>> csr_tensor = CSRTensor(indptr, indices, values, shape)
-        >>> print(x.astype(ms.float64).dtype)
+        >>> print(csr_tensor.astype(ms.float64).dtype)
        Float64
        """
        data = self.values.astype(dtype)
@@ -2717,7 +2745,7 @@ class CSRTensor(CSRTensor_):
            Tensor.

        Supported Platforms:
-            ``GPU``
+            ``GPU`` ``CPU``

        Examples:
            >>> from mindspore import Tensor, CSRTensor

@@ -2745,7 +2773,7 @@ class CSRTensor(CSRTensor_):
            Tensor, the dtype is the same as `sparse_tensor.values`.

        Supported Platforms:
-            ``GPU``
+            ``GPU`` ``CPU``

        Examples:
            >>> from mindspore import Tensor, CSRTensor
@@ -2762,7 +2790,15 @@ class CSRTensor(CSRTensor_):
        return tensor_operator_registry.get("csr_reduce_sum")(self, axis)

    def abs(self):
-        """Return absolute value element-wisely."""
+        """
+        Return absolute value element-wise.
+
+        Returns:
+            CSRTensor, with all values being non-negative.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+        """
        data = self.values.abs()
        return CSRTensor(self.indptr, self.indices, data, self.shape)
@@ -25,6 +25,13 @@ from .grad_base import bprops, bprop_getters
# Unused parameters are placeholders.


+@bprops.register("MakeCSRTensor")
+def bprop_make_csr_tensor(indptr, indices, values, dense_shape, out, dout):
+    """Backpropagator for primitive `MakeCSRTensor`."""
+    res = (zeros_like(indptr), zeros_like(indices), F.csr_tensor_get_values(dout), ())
+    return res
+
+
@bprops.register("MakeCOOTensor")
def bprop_make_coo_tensor(indices, values, dense_shape, out, dout):
    """Backpropagator for primitive `MakeCOOTensor`."""
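The gradient registered here follows directly from the structure of the op: only `values` is differentiable. A shape-level sketch of what bprop_make_csr_tensor returns (names as registered above):

    # forward:  out = MakeCSRTensor(indptr, indices, values, dense_shape)
    # backward: dout shares the forward sparsity pattern, so
    #   d(indptr)      -> zeros_like(indptr)             integer offsets, no gradient
    #   d(indices)     -> zeros_like(indices)            integer columns, no gradient
    #   d(values)      -> F.csr_tensor_get_values(dout)  one-to-one with forward values
    #   d(dense_shape) -> ()                             a tuple, not a tensor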
@@ -0,0 +1,22 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""__init__"""
+from .coo2csr import _coo2csr_akg
+from .csr2coo import _csr2coo_akg
+from .csr_gather import _csr_gather_akg
+from .csr_mul import _csr_mul_akg
+from .csr_mv import _csr_mv_akg
+from .csr_reduce_sum import _csr_reduce_sum_akg
+# Please insert op register in lexicographical order of the filename.
@@ -0,0 +1,29 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""COO2CSR op"""
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
+
+coo2csr_op_info = AkgGpuRegOp("COO2CSR") \
+    .fusion_type("OPAQUE") \
+    .input(0, "row_indices") \
+    .output(0, "output") \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .get_op_info()
+
+
+@op_info_register(coo2csr_op_info)
+def _coo2csr_akg():
+    """COO2CSR AutoDiff register"""
+    return
@@ -0,0 +1,29 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CSR2COO op"""
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
+
+csr2coo_op_info = AkgGpuRegOp("CSR2COO") \
+    .fusion_type("OPAQUE") \
+    .input(0, "indptr") \
+    .output(0, "output") \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .get_op_info()
+
+
+@op_info_register(csr2coo_op_info)
+def _csr2coo_akg():
+    """CSR2COO AutoDiff register"""
+    return
@@ -0,0 +1,33 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CSRGather op"""
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
+
+csr_gather_op_info = AkgGpuRegOp("CSRGather") \
+    .fusion_type("OPAQUE") \
+    .input(0, "indptr") \
+    .input(1, "indices") \
+    .input(2, "dense") \
+    .output(0, "output") \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
+                  DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
+                  DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(csr_gather_op_info)
+def _csr_gather_akg():
+    """CSRGather AutoDiff register"""
+    return
@@ -0,0 +1,33 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CSRMul op"""
+from mindspore.ops.op_info_register import op_info_register, AkgCpuRegOp, DataType
+
+csr_mul_op_info = AkgCpuRegOp("CSRMul") \
+    .fusion_type("OPAQUE") \
+    .input(0, "indptr") \
+    .input(1, "indices") \
+    .input(2, "values") \
+    .input(4, "dense_tensor") \
+    .output(0, "output0") \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
+                  DataType.F32_Default, \
+                  DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(csr_mul_op_info)
+def _csr_mul_akg():
+    """CSRMul AutoDiff register"""
+    return
@@ -0,0 +1,33 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CSRMV op"""
+from mindspore.ops.op_info_register import op_info_register, AkgCpuRegOp, DataType
+
+csr_mv_op_info = AkgCpuRegOp("CSRMV") \
+    .fusion_type("OPAQUE") \
+    .input(0, "indptr") \
+    .input(1, "indices") \
+    .input(2, "values") \
+    .input(4, "dense_tensor") \
+    .output(0, "output") \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
+                  DataType.F32_Default, \
+                  DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(csr_mv_op_info)
+def _csr_mv_akg():
+    """CSRMV AutoDiff register"""
+    return
@@ -0,0 +1,31 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CSRReduceSum op"""
+from mindspore.ops.op_info_register import op_info_register, AkgCpuRegOp, DataType
+
+csr_reduce_sum_op_info = AkgCpuRegOp("CSRReduceSum") \
+    .fusion_type("OPAQUE") \
+    .input(0, "indptr") \
+    .input(1, "indices") \
+    .input(2, "values") \
+    .output(0, "output") \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
+                  DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(csr_reduce_sum_op_info)
+def _csr_reduce_sum_akg():
+    """CSRReduceSum AutoDiff register"""
+    return
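All of the registration files above follow one template, so adding another AKG CPU kernel is mechanical. A hedged sketch (the op name, input names and dtypes below are placeholders, not ops defined by this commit):

    from mindspore.ops.op_info_register import op_info_register, AkgCpuRegOp, DataType

    my_op_info = AkgCpuRegOp("MyNewOp") \
        .fusion_type("OPAQUE") \
        .input(0, "x") \
        .output(0, "output") \
        .dtype_format(DataType.F32_Default, DataType.F32_Default) \
        .get_op_info()

    @op_info_register(my_op_info)
    def _my_new_op_akg():
        """MyNewOp AutoDiff register"""
        return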
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,13 @@
# limitations under the License.

"""__init__"""
+from .coo2csr import _coo2csr_akg
+from .csr2coo import _csr2coo_akg
+from .csr_gather import _csr_gather_akg
+from .csr_div import _csr_div_akg
+from .csr_mul import _csr_mul_akg
+from .csr_mv import _csr_mv_akg
+from .csr_reduce_sum import _csr_reduce_sum_akg
from .equal import _equal_akg
from .greater_equal import _greater_equal_akg
from .lessequal import _lessequal_akg
@@ -22,11 +29,4 @@ from .logical_or import _logical_or_akg
from .mean import _simple_mean_akg
from .mean_grad import _simple_mean_grad_akg
from .notequal import _notequal_akg
-from .csr_reduce_sum import _csr_reduce_sum_akg
-from .csr_mv import _csr_mv_akg
-from .csr_mul import _csr_mul_akg
-from .csr_gather import _csr_gather_akg
-from .csr2coo import _csr2coo_akg
-from .coo2csr import _coo2csr_akg
-from .csr_div import _csr_div_akg
# Please insert op register in lexicographical order of the filename.
@@ -27,30 +27,6 @@ using ".register" decorator.
"""


-@mul.register("CSRTensor", "Tensor")
-def _mul_csrtensor_tensor(x, y):
-    """
-    Returns x * y where x is CSRTensor and y is Tensor.
-
-    Outputs:
-        CSRTensor, equal to x * y.
-    """
-    data = F.csr_mul(x, y)
-    return CSRTensor(x.indptr, x.indices, data, x.shape)
-
-
-@mul.register("Tensor", "CSRTensor")
-def _mul_tensor_csrtensor(x, y):
-    """
-    Returns x * y where x is Tensor and y is CSRTensor.
-
-    Outputs:
-        CSRTensor, equal to x * y.
-    """
-    data = F.csr_mul(y, x)
-    return CSRTensor(y.indptr, y.indices, data, y.shape)
-
-
@mul.register("Number", "Number")
def _mul_scalar(x, y):
    """
@@ -221,3 +197,27 @@ def _list_mul_tensor(x, y):
    """
    x = utils.sequence_to_tensor(x, y.dtype)
    return F.tensor_mul(x, y)
+
+
+@mul.register("CSRTensor", "Tensor")
+def _csrtensor_mul_tensor(x, y):
+    """
+    Returns x * y where x is CSRTensor and y is Tensor.
+
+    Outputs:
+        CSRTensor, equal to x * y.
+    """
+    data = F.csr_mul(x, y)
+    return CSRTensor(x.indptr, x.indices, data, x.shape)
+
+
+@mul.register("Tensor", "CSRTensor")
+def _tensor_mul_csrtensor(x, y):
+    """
+    Returns x * y where x is Tensor and y is CSRTensor.
+
+    Outputs:
+        CSRTensor, equal to x * y.
+    """
+    data = F.csr_mul(y, x)
+    return CSRTensor(y.indptr, y.indices, data, y.shape)
@@ -149,8 +149,54 @@ tensor_scatter_update = P.TensorScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
stack = P.Stack()

-csr_mul = _csr_ops.CSRMul()
-csr_div = _csr_ops.CSRDiv()
+def csr_mul(x, y):
+    """
+    Returns x * y where x is CSRTensor and y is Tensor.
+
+    Note:
+        This function returns a dense Tensor that holds the non-zero values of
+        the result. If you expect a CSRTensor as output, use the `*` operator
+        directly. Only dense-to-sparse broadcasting is supported at the moment.
+
+    Args:
+        x (CSRTensor): Sparse CSR Tensor.
+        y (Tensor): Dense Tensor, its shape must be able to broadcast to x.
+
+    Returns:
+        Dense Tensor, represents the non-zero values of the result.
+
+    Supported Platforms:
+        ``GPU`` ``CPU``
+    """
+    if x.shape[0] != 1 and y.shape[0] == 1:
+        y = y.expand_as(x)
+    return _csr_ops.CSRMul()(x, y)
+
+def csr_div(x, y):
+    """
+    Returns x / y where x is CSRTensor and y is Tensor.
+
+    Note:
+        This function returns a dense Tensor that holds the non-zero values of
+        the result. If you expect a CSRTensor as output, use the `/` operator
+        directly. Only dense-to-sparse broadcasting is supported at the moment.
+
+    Args:
+        x (CSRTensor): Sparse CSR Tensor.
+        y (Tensor): Dense Tensor, its shape must be able to broadcast to x.
+
+    Returns:
+        Dense Tensor, represents the non-zero values of the result.
+
+    Supported Platforms:
+        ``GPU`` ``CPU``
+    """
+    if x.shape[0] != 1 and y.shape[0] == 1:
+        y = y.expand_as(x)
+    return _csr_ops.CSRDiv()(x, y)
+
csr_mv = _csr_ops.CSRMV()
csr_reduce_sum = _csr_ops.CSRReduceSum()
csr_gather = _csr_ops.CSRGather()
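Usage of the new Python-level wrappers, assuming they are imported from this module (printed values are worked out by hand):

    import mindspore as ms
    from mindspore import Tensor, CSRTensor
    from mindspore.ops.functional import csr_mul

    indptr = Tensor([0, 1, 2], dtype=ms.int32)
    indices = Tensor([0, 1], dtype=ms.int32)
    values = Tensor([1.0, 2.0], dtype=ms.float32)
    csr = CSRTensor(indptr, indices, values, (2, 4))

    dense = Tensor([[2.0], [3.0]], dtype=ms.float32)   # one scalar per row
    out = csr_mul(csr, dense)   # dense Tensor holding the non-zero values
    # out == [2. 6.]            # 1*2 for row 0, 2*3 for row 1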
@@ -749,5 +795,6 @@ tensor_operator_registry.register('narrow', narrow)
tensor_operator_registry.register('sort', sort)
tensor_operator_registry.register('zeros', zeros)
tensor_operator_registry.register('tensor_scatter_update', tensor_scatter_update)
+tensor_operator_registry.register('tensor_scatter_add', tensor_scatter_add)
__all__ = [name for name in dir() if name[0] != "_"]
__all__.remove('Primitive')
@@ -410,6 +410,13 @@ class AkgAscendRegOp(AkgRegOp):
        super(AkgAscendRegOp, self).__init__(op_name, "AiCore")


+class AkgCpuRegOp(AkgRegOp):
+    """Class for AkgCpu op info register"""
+
+    def __init__(self, op_name):
+        super(AkgCpuRegOp, self).__init__(op_name, "LLVM")
+
+
class AiCPURegOp(CpuRegOp):
    r"""
    Class for AiCPU operator information register.
@@ -31,7 +31,7 @@ class CSRReduceSum(PrimitiveWithInfer):
        Tensor, the dtype is the same as `sparse_tensor.values`.

    Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore

@@ -79,7 +79,7 @@ class CSRMV(PrimitiveWithInfer):
        Tensor.

    Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore

@@ -115,7 +115,7 @@ class CSRMV(PrimitiveWithInfer):

class CSRMul(PrimitiveWithInfer):
    """
-    Elemwise multiplication on a CSRTensor and a dense tensor.
+    Elemwise multiplication of a CSRTensor and a dense tensor.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

@@ -132,7 +132,7 @@ class CSRMul(PrimitiveWithInfer):
        Tensor, the dtype and shape are the same as `sparse_tensor.values`.

    Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore

@@ -183,7 +183,7 @@ class CSRGather(PrimitiveWithInfer):
        dimensions are the same as ``dense[2:]``.

    Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn

@@ -228,7 +228,7 @@ class CSR2COO(PrimitiveWithInfer):
        Tensor, the dtype is the same as `indptr` and has shape (`nnz`,).

    Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn

@@ -268,7 +268,7 @@ class COO2CSR(PrimitiveWithInfer):
        Tensor, the dtype is the same as `row_indices` and has shape ('height' + 1,).

    Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.nn as nn

@@ -310,7 +310,7 @@ class CSRDiv(PrimitiveWithInfer):
        Tensor, the dtype and shape are the same as `sparse_tensor.values`.

    Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
@@ -37,6 +37,7 @@ def compare_coo(coo1, coo2):
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_make_coo():
    """
@@ -96,6 +97,7 @@ def test_coo_tensor_in_while():

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_coo_method():
    """

@@ -253,6 +253,7 @@ def test_csr_tensor_in_while_cpu():

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_ops():
    """
@@ -448,6 +449,7 @@ def test_isinstance_csr_tensor():

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dtype_csr_tensor():
    """
@@ -475,6 +477,7 @@ def test_dtype_csr_tensor():

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_bprop():
    """
@@ -482,38 +485,25 @@ def test_csr_bprop():
    Description: Test CSRReduceSum, CSRMul, CSRMV, CSRTensor.to_coo(), CSRTensor.to_dense().
    Expectation: Success.
    """
-    class CSRMulNet(nn.Cell):
-        def __init__(self):
-            super(CSRMulNet, self).__init__()
-            self.op = _csr_ops.CSRMul()
-
-        def construct(self, csr_tensor, dense):
-            return self.op(csr_tensor, dense)
-
-    class CSRReduceSumNet(nn.Cell):
-        def __init__(self):
-            super(CSRReduceSumNet, self).__init__()
-            self.op = _csr_ops.CSRReduceSum()
-
-        def construct(self, csr_tensor, axis):
-            return self.op(csr_tensor, axis)
-
-    class CSRMVNet(nn.Cell):
-        def __init__(self):
-            super(CSRMVNet, self).__init__()
-            self.op = _csr_ops.CSRMV()
-
-        def construct(self, csr_tensor, dense):
-            return self.op(csr_tensor, dense)
-
-    class BpropNet(nn.Cell):
-        def __init__(self, net):
-            super(BpropNet, self).__init__()
-            self.net = net
-            self.grad_op = ops.GradOperation(get_all=True)
-
-        def construct(self, *inputs):
-            return self.grad_op(self.net)(*inputs)
+    csr_reduce_sum = _csr_ops.CSRReduceSum()
+    csrmv = _csr_ops.CSRMV()
+    grad_op = ops.GradOperation(get_all=True)
+
+    def test_csr_mul(csr_tensor, dense):
+        return csr_tensor * dense
+
+    def test_csr_reduce_sum(csr_tensor, axis):
+        return csr_reduce_sum(csr_tensor, axis)
+
+    def test_csrmv(csr_tensor, dense):
+        return csrmv(csr_tensor, dense)
+
+    test_csr_mul_grad_pynative = grad_op(test_csr_mul)
+    test_csr_mul_grad_graph = ms_function(test_csr_mul_grad_pynative)
+    test_csr_reduce_sum_grad_pynative = grad_op(test_csr_reduce_sum)
+    test_csr_reduce_sum_grad_graph = ms_function(test_csr_reduce_sum_grad_pynative)
+    test_csrmv_grad_pynative = grad_op(test_csrmv)
+    test_csrmv_grad_graph = ms_function(test_csrmv_grad_pynative)

    indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
    indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)

@@ -522,33 +512,45 @@ def test_csr_bprop():
    csr_tensor = CSRTensor(indptr, indices, values, dense_shape)

    csr_mv_arg = Tensor([[1], [2], [3], [4]], dtype=mstype.float32)
    csr_mv_expect_1 = np.array([4, 1, 2, 3, 2, 4], dtype=np.float32)
    csr_mv_expect_2 = np.array([[1], [6], [3], [5]], dtype=np.float32)
-    csr_mv_output_1, csr_mv_output_2 = BpropNet(CSRMVNet())(csr_tensor, csr_mv_arg)
+    csr_mv_output_1, csr_mv_output_2 = test_csrmv_grad_pynative(csr_tensor, csr_mv_arg)
    assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
    assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)
+    csr_mv_output_1, csr_mv_output_2 = test_csrmv_grad_graph(csr_tensor, csr_mv_arg)
+    assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
+    assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)

    csr_reduce_sum_expect = np.ones(6, dtype=np.float32)
-    csr_reduce_sum_output = BpropNet(CSRReduceSumNet())(csr_tensor, 1)
+    csr_reduce_sum_output = test_csr_reduce_sum_grad_pynative(csr_tensor, 1)
    assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)
+    csr_reduce_sum_output = test_csr_reduce_sum_grad_graph(csr_tensor, 1)
+    assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)

    csr_mul_arg_1 = Tensor([[1], [2], [3]], dtype=mstype.float32)
    csr_mul_expect_1_1 = np.array([1, 2, 2, 2, 3, 3], dtype=np.float32)
    csr_mul_expect_1_2 = np.array([[0], [6], [9]], dtype=np.float32)
-    csr_mul_output_1_1, csr_mul_output_1_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_1)
+    csr_mul_output_1_1, csr_mul_output_1_2 = test_csr_mul_grad_pynative(csr_tensor, csr_mul_arg_1)
    assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
    assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)
+    csr_mul_output_1_1, csr_mul_output_1_2 = test_csr_mul_grad_graph(csr_tensor, csr_mul_arg_1)
+    assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
+    assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)

    csr_mul_arg_2 = Tensor(np.arange(12).reshape(3, 4), dtype=mstype.float32)
    csr_mul_expect_2_1 = np.array([3, 4, 5, 6, 9, 11], dtype=np.float32)
    csr_mul_expect_2_2 = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
-    csr_mul_output_2_1, csr_mul_output_2_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_2)
+    csr_mul_output_2_1, csr_mul_output_2_2 = test_csr_mul_grad_pynative(csr_tensor, csr_mul_arg_2)
    assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
    assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)
+    csr_mul_output_2_1, csr_mul_output_2_2 = test_csr_mul_grad_graph(csr_tensor, csr_mul_arg_2)
+    assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
+    assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_method():
    """