support TensorShape ops to instead of DynamicShape ops

This commit is contained in:
lianliguang 2022-02-17 14:35:00 +08:00
parent 33800c7f31
commit 5c8af66e53
21 changed files with 324 additions and 109 deletions

View File

@ -2270,7 +2270,8 @@ void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::
bool AnfRuntimeAlgorithm::IsHostKernel(const CNodePtr &kernel_node) {
const std::set<std::string> host_kernel = {prim::kPrimDynamicShape->name(), prim::kPrimReshape->name(),
prim::kPrimDynamicBroadcastGradientArgs->name()};
prim::kPrimDynamicBroadcastGradientArgs->name(),
prim::kPrimTensorShape->name()};
auto op_name = AnfAlgo::GetCNodeName(kernel_node);
if (host_kernel.find(op_name) == host_kernel.end()) {
return false;

View File

@ -1138,12 +1138,11 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs);
runtime::KernelMapPosition outputs_order;
size_t outputs_num;
const auto &root_output =
AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
size_t position = 0;
auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
outputs_num = outputs.size();
size_t outputs_num = outputs.size();
for (const auto &output : outputs) {
if (outputs_order.count(output) == 0) {
outputs_order[output] = {position++};

View File

@ -224,8 +224,8 @@ BuiltInTypeMap &GetAttrMap() {
static BuiltInTypeMap attr_map = {
{kObjectTypeTensorType,
{
{"shape", std::string("shape_")}, // C.shape_
{"dtype", std::string("dtype_")}, // C.dtype_
{"shape", prim::kPrimShape}, // C.shape_
{"dtype", prim::kPrimDType}, // C.dtype_
{"size", std::string("size_")}, // C.size_
{"ndim", std::string("ndim_")}, // C.ndim_
{"T", std::string("T_")}, // C.T_

View File

@ -22,8 +22,8 @@
namespace mindspore {
namespace kernel {
void DynamicShapeKernel::Execute() {
MS_LOG(INFO) << "Execute DynamicShapeKernel Start";
void TensorShapeKernel::Execute() {
MS_LOG(INFO) << "Execute TensorShapeKernel Start";
auto cnode = cnode_ptr_.lock();
MS_EXCEPTION_IF_NULL(cnode);
auto input_num = AnfAlgo::GetInputTensorNum(cnode);
@ -52,7 +52,7 @@ void DynamicShapeKernel::Execute() {
auto ret = memcpy_s(const_cast<void *>(output_addr->GetPtr()), output_addr->GetSize(),
output_tensor_for_sync->data_c(), LongToSize(output_tensor_for_sync->data().nbytes()));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Execute DynamicShapeKernel memcpy_s failed!";
MS_LOG(EXCEPTION) << "Execute TensorShapeKernel memcpy_s failed!";
}
} else {
auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
@ -68,11 +68,11 @@ void DynamicShapeKernel::Execute() {
output_tensor_for_sync->device_info().host_format_);
}
MS_LOG(INFO) << "Execute DynamicShapeKernel End";
MS_LOG(INFO) << "Execute TensorShapeKernel End";
}
void DynamicShapeKernel::Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Execute DynamicShapeKernel Start";
void TensorShapeKernel::Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Execute TensorShapeKernel Start";
auto cnode = cnode_ptr_.lock();
MS_EXCEPTION_IF_NULL(cnode);
auto input_num = AnfAlgo::GetInputTensorNum(cnode);
@ -100,27 +100,27 @@ void DynamicShapeKernel::Execute(const std::vector<AddressPtr> &inputs, const st
auto status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, output_tensor_for_sync->data_c(),
LongToSize(output_tensor_for_sync->data().nbytes()), RT_MEMCPY_HOST_TO_DEVICE, stream_);
if (status != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Execute DynamicShapeKernel rtMemcpyAsync failed!";
MS_LOG(EXCEPTION) << "Execute TensorShapeKernel rtMemcpyAsync failed!";
}
MS_LOG(INFO) << "Execute DynamicShapeKernel End";
MS_LOG(INFO) << "Execute TensorShapeKernel End";
}
device::DynamicKernelPtr DynamicShapeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
return std::make_shared<DynamicShapeKernel>(stream_ptr, cnode_ptr);
device::DynamicKernelPtr TensorShapeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
return std::make_shared<TensorShapeKernel>(stream_ptr, cnode_ptr);
}
bool DynamicShapeKernelMod::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *stream_ptr) {
bool TensorShapeKernelMod::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *stream_ptr) {
auto node = anf_node_.lock();
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
stream_ = stream_ptr;
auto shape_kernel = std::make_shared<DynamicShapeKernel>(stream_ptr, cnode);
auto shape_kernel = std::make_shared<TensorShapeKernel>(stream_ptr, cnode);
try {
shape_kernel->Execute();
} catch (const std::exception &e) {
MS_LOG(ERROR) << "DynamicShapeKernelMod Launch failed. node: " << cnode->fullname_with_scope()
MS_LOG(ERROR) << "TensorShapeKernelMod Launch failed. node: " << cnode->fullname_with_scope()
<< ", Error message is " << e.what();
return false;
}

View File

@ -24,23 +24,24 @@
using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel;
namespace mindspore {
namespace kernel {
class DynamicShapeKernel : public HostDynamicKernel {
class TensorShapeKernel : public HostDynamicKernel {
public:
DynamicShapeKernel(void *stream, const CNodePtr &cnode_ptr) : HostDynamicKernel(stream, cnode_ptr) {}
~DynamicShapeKernel() override = default;
TensorShapeKernel(void *stream, const CNodePtr &cnode_ptr) : HostDynamicKernel(stream, cnode_ptr) {}
~TensorShapeKernel() override = default;
void Execute() override;
void Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
};
class DynamicShapeKernelMod : public HostKernelMod {
class TensorShapeKernelMod : public HostKernelMod {
public:
DynamicShapeKernelMod() = default;
~DynamicShapeKernelMod() override = default;
TensorShapeKernelMod() = default;
~TensorShapeKernelMod() override = default;
device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
};
MS_HOST_REG_KERNEL(DynamicShape, DynamicShapeKernelMod);
MS_HOST_REG_KERNEL(DynamicShape, TensorShapeKernelMod);
MS_HOST_REG_KERNEL(TensorShape, TensorShapeKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -24,7 +24,7 @@ constexpr size_t kDynamicShapeOutputNum = 1;
} // namespace
template <typename T>
void DynamicShapeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void TensorShapeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
cnode_ptr_ = kernel_node;
@ -35,9 +35,9 @@ void DynamicShapeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
}
template <typename T>
bool DynamicShapeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool TensorShapeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDynamicShapeOutputNum, kernel_name_);
auto node_ = cnode_ptr_.lock();
if (node_ == nullptr) {

View File

@ -25,10 +25,10 @@
namespace mindspore {
namespace kernel {
template <typename T>
class DynamicShapeCpuKernelMod : public NativeCpuKernelMod {
class TensorShapeCpuKernelMod : public NativeCpuKernelMod {
public:
DynamicShapeCpuKernelMod() = default;
~DynamicShapeCpuKernelMod() override = default;
TensorShapeCpuKernelMod() = default;
~TensorShapeCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
@ -36,16 +36,27 @@ class DynamicShapeCpuKernelMod : public NativeCpuKernelMod {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCpuKernelMod, int8_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCpuKernelMod, int16_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCpuKernelMod, int32_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCpuKernelMod, int64_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCpuKernelMod, uint8_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCpuKernelMod, uint16_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCpuKernelMod, uint32_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCpuKernelMod, uint64_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), DynamicShapeCpuKernelMod, bool)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int8_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int16_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int32_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int64_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint8_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint16_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint32_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint64_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, bool)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int8_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int16_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int32_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int64_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint8_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint16_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint32_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint64_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, bool)
} // namespace kernel
} // namespace mindspore

View File

@ -19,28 +19,52 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
TensorShapeGpuKernelMod, int32_t, int32_t)
MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
TensorShapeGpuKernelMod, half, int32_t)
MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
TensorShapeGpuKernelMod, float, int32_t)
MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
TensorShapeGpuKernelMod, bool, int32_t)
MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
TensorShapeGpuKernelMod, int32_t, int64_t)
MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
TensorShapeGpuKernelMod, half, int64_t)
MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
TensorShapeGpuKernelMod, float, int64_t)
MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
TensorShapeGpuKernelMod, bool, int64_t)
MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
DynamicShapeGpuKernelMod, int32_t, int32_t)
TensorShapeGpuKernelMod, int32_t, int32_t)
MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
DynamicShapeGpuKernelMod, half, int32_t)
TensorShapeGpuKernelMod, half, int32_t)
MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
DynamicShapeGpuKernelMod, float, int32_t)
TensorShapeGpuKernelMod, float, int32_t)
MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
DynamicShapeGpuKernelMod, bool, int32_t)
TensorShapeGpuKernelMod, bool, int32_t)
MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
DynamicShapeGpuKernelMod, int32_t, int64_t)
TensorShapeGpuKernelMod, int32_t, int64_t)
MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
DynamicShapeGpuKernelMod, half, int64_t)
TensorShapeGpuKernelMod, half, int64_t)
MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
DynamicShapeGpuKernelMod, float, int64_t)
TensorShapeGpuKernelMod, float, int64_t)
MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
DynamicShapeGpuKernelMod, bool, int64_t)
TensorShapeGpuKernelMod, bool, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -27,10 +27,10 @@
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class DynamicShapeGpuKernelMod : public NativeGpuKernelMod {
class TensorShapeGpuKernelMod : public NativeGpuKernelMod {
public:
DynamicShapeGpuKernelMod() { ResetResource(); }
~DynamicShapeGpuKernelMod() = default;
TensorShapeGpuKernelMod() { ResetResource(); }
~TensorShapeGpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {

View File

@ -226,8 +226,6 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -777,29 +777,6 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit
return ret;
}
AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input->shape());
auto shape = input->shape()->shape();
bool has_dyn_shape = std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
ShapeVector tensor_shp({static_cast<int64_t>(shape.size())});
if (has_dyn_shape) {
auto elem = std::make_shared<AbstractScalar>(std::make_shared<AnyValue>(), std::make_shared<Int>(64));
auto min_value = MakeValue(input->shape()->min_shape());
auto max_value = MakeValue(input->shape()->max_shape());
auto tensor = std::make_shared<AbstractTensor>(elem, std::make_shared<Shape>(tensor_shp));
tensor->set_value_range(min_value, max_value);
return tensor;
}
auto shp_buf_size = sizeof(int64_t) * shape.size();
auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, tensor_shp, shape.data(), shp_buf_size);
return tensor->ToAbstract();
}
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();

View File

@ -157,7 +157,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimComputeAccidentalHits, R{InferImplComputeAccidentalHits, nullptr, true}},
{prim::kPrimDynamicStitch, R{InferImplDynamicStitch, nullptr, true}},
{prim::kPrimPadAndShift, R{InferImplPadAndShift, nullptr, true}},
{prim::kPrimDynamicShape, R{InferImplDynamicShape, nullptr, true}},
{prim::kPrimMapUniform, R{InferImplMapUniform, nullptr, true}},
{prim::kPrimSplit, R{InferImplSplit, nullptr, true}},
{prim::kPrimSequenceMask, R{InferImplSequenceMask, nullptr, true}},

View File

@ -89,6 +89,7 @@ constexpr auto kCross = "Cross";
// Arrays
constexpr auto kDynamicShape = "DynamicShape";
constexpr auto kTensorShape = "TensorShape";
constexpr auto kStack = "Stack";
constexpr auto kUnstack = "Unstack";
constexpr auto kTupleGetItem = "TupleGetItem";
@ -263,6 +264,7 @@ MS_CORE_API inline const PrimitivePtr kPrimSparseToDense = std::make_shared<Prim
MS_CORE_API inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
MS_CORE_API inline const PrimitivePtr kPrimStridedSlice = std::make_shared<Primitive>(kStridedSlice);
MS_CORE_API inline const PrimitivePtr kPrimStridedSliceGrad = std::make_shared<Primitive>(kStridedSliceGrad);
MS_CORE_API inline const PrimitivePtr kPrimTensorShape = std::make_shared<Primitive>(kTensorShape);
MS_CORE_API inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>(kDynamicShape);
MS_CORE_API inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
MS_CORE_API inline const PrimitivePtr kPrimEmbeddingLookupCommGrad =

View File

@ -0,0 +1,31 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_DYNAMIC_SHAPE_H_
#define MINDSPORE_CORE_OPS_DYNAMIC_SHAPE_H_
#include "ops/primitive_c.h"
#include "base/core_ops.h"
namespace mindspore {
namespace ops {
class DynamicShape : public PrimitiveC {
public:
DynamicShape() : PrimitiveC(prim::kPrimDynamicShape->name()) {}
~DynamicShape() = default;
MS_DECLARE_PARENT(DynamicShape, PrimitiveC);
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_DYNAMIC_SHAPE_H_

View File

@ -0,0 +1,52 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/tensor_shape.h"
#include <vector>
#include <memory>
#include "ops/dynamic_shape.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
abstract::AbstractBasePtr TensorShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1, primitive->name());
auto input = input_args[0]->cast<abstract::AbstractTensorPtr>();
if (input == nullptr) {
MS_EXCEPTION(TypeError) << "For the input['input_x'] of primitive[" << primitive->name() << "] should be a Tensor"
<< ", but got " << input_args[0]->BuildType()->ToString() << ".";
}
MS_EXCEPTION_IF_NULL(input->shape());
auto shape = input->shape()->shape();
bool has_dyn_shape =
std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
ShapeVector tensor_shp({static_cast<int64_t>(shape.size())});
if (has_dyn_shape) {
auto elem = std::make_shared<abstract::AbstractScalar>(std::make_shared<AnyValue>(), std::make_shared<Int>(64));
auto min_value = MakeValue(input->shape()->min_shape());
auto max_value = MakeValue(input->shape()->max_shape());
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(elem, std::make_shared<abstract::Shape>(tensor_shp));
abs_tensor->set_value_range(min_value, max_value);
return abs_tensor;
}
auto shp_buf_size = sizeof(int64_t) * shape.size();
auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, tensor_shp, shape.data(), shp_buf_size);
return tensor->ToAbstract();
}
REGISTER_PRIMITIVE_EVAL_IMPL(TensorShape, prim::kPrimTensorShape, TensorShapeInfer, nullptr, true);
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicShape, prim::kPrimDynamicShape, TensorShapeInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,31 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TENSOR_SHAPE_H_
#define MINDSPORE_CORE_OPS_TENSOR_SHAPE_H_
#include "ops/primitive_c.h"
#include "base/core_ops.h"
namespace mindspore {
namespace ops {
class TensorShape : public PrimitiveC {
public:
TensorShape() : PrimitiveC(prim::kPrimTensorShape->name()) {}
~TensorShape() = default;
MS_DECLARE_PARENT(TensorShape, PrimitiveC);
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TENSOR_SHAPE_H_

View File

@ -34,6 +34,7 @@ from .bias_add_grad import _bias_add_grad_cpu
from .dropout import _dropout_cpu
from .dropout_grad import _dropout_grad_cpu
from .gather_d import _gather_cpu
from .tensor_shape import _tensor_shape_cpu
from .gather_d_grad import _gather_d_grad_cpu
from .gather_v2 import _gather_v2_cpu
from .gather_nd import _gather_nd_cpu

View File

@ -0,0 +1,38 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DynamicShape op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
tensor_shape_op_info = CpuRegOp("TensorShape") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.I64_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default) \
.dtype_format(DataType.I16_Default, DataType.I64_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default) \
.dtype_format(DataType.U16_Default, DataType.I64_Default) \
.dtype_format(DataType.U32_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.I64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(tensor_shape_op_info)
def _tensor_shape_cpu():
"""DynamicShape cpu register"""
return

View File

@ -19,7 +19,16 @@ Primitive operator classes.
A collection of operators to build neural networks or to compute functions.
"""
from .image_ops import (CropAndResize, NonMaxSuppressionV3, HSVToRGB)
from . import _quant_ops
from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter,
MapUniform, DynamicAssign, PadAndShift)
from ._inner_ops import (MatmulDDS, DSDMatmul, NonZero)
from ._quant_ops import *
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col,
LoadIm2Col, UpdateThorGradient, Cholesky, CholeskyTrsm,
DetTriangle, ProdForceSeA)
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation,
@ -27,7 +36,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid, Lstsq,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, DynamicShape, Size, Slice, Split, SplitV, TransShape, ParallelConcat, Padding,
Shape, DynamicShape, TensorShape, Size, Slice, Split, SplitV, TransShape, ParallelConcat,
Padding,
UniqueWithPad, ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
Squeeze, StridedSlice, Tile, TensorScatterUpdate, TensorScatterAdd, EditDistance, Sort,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
@ -36,18 +46,20 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches,
LowerBound, UpperBound, Cummax)
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, Broadcast,
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter,
Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
from .control_ops import GeSwitch, Merge
from .custom_ops import (Custom)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)
from .control_ops import GeSwitch, Merge
from .image_ops import (CropAndResize, NonMaxSuppressionV3, HSVToRGB)
from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign,
MakeRefKey,
FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay, FusedAdaFactor,
FusedAdaFactorWithGlobalNorm)
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
BitwiseAnd, BitwiseOr, Ger,
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
@ -63,11 +75,6 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Addcmul, Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps,
Tan, MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose, LuSolve,
CholeskyInverse)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
LogUniformCandidateSampler)
from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam,
ApplyMomentum, BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
DepthwiseConv2dNative,
@ -91,20 +98,15 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink,
ApplyAdamWithAmsgrad, GridSampler3D)
from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
ConfusionMatrix, PopulationCount, UpdateState, Load,
CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PullWeight, PushWeight,
PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc, ExchangeKeys, GetKeys)
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col,
LoadIm2Col, UpdateThorGradient, Cholesky, CholeskyTrsm,
DetTriangle, ProdForceSeA)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
LogUniformCandidateSampler)
from .rl_ops import (BufferAppend, BufferGetItem, BufferSample)
from .sparse_ops import (SparseToDense, SparseTensorDenseMatmul)
from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter,
MapUniform, DynamicAssign, PadAndShift)
from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAtomEnergy, BondForceWithAtomVirial,
DihedralForce, DihedralEnergy, DihedralAtomEnergy, DihedralForceWithAtomEnergy, AngleForce,
AngleEnergy, AngleAtomEnergy, AngleForceWithAtomEnergy, PMEReciprocalForce,
@ -122,9 +124,6 @@ from .sponge_update_ops import (ConstrainForceCycleWithVirial, RefreshUintCrd, L
PMEExcludedForceUpdate, LJForceWithVirialEnergyUpdate,
Dihedral14ForceWithAtomEnergyVirial, PMEEnergyUpdate,
ConstrainForceVirial, ConstrainForce, Constrain)
from .rl_ops import (BufferAppend, BufferGetItem, BufferSample)
from ._inner_ops import (MatmulDDS, DSDMatmul, NonZero)
from .custom_ops import (Custom)
__all__ = [
'HSVToRGB',
@ -146,6 +145,7 @@ __all__ = [
'ArgMinWithValue',
'AddN',
'AccumulateNV2',
'TensorShape',
'Sub',
'Coalesce',
'CumSum',

View File

@ -1,5 +1,3 @@
# coding: utf-8
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -637,6 +635,39 @@ class Shape(Primitive):
def __init__(self):
"""Initialize Shape"""
class TensorShape(Primitive):
"""
Returns the shape of the input tensor. And it used to be dynamic shape.
Note:
Dynamic shape: After the graph is running, as the tensor flows in the graph, the specific shape of the tensor
on each node on the graph can be inferred according to the structure of the graph.
This shape is called a dynamic shape. As the input shape of the graph is different,
the dynamic shape of the tensor in the graph will change.
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor[int], 1-dim Tensor of type int32
Raises:
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
>>> shape = ops.TensorShape()
>>> output = shape(input_x)
>>> print(output)
[3 2 1]
"""
@prim_attr_register
def __init__(self):
"""init Shape"""
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
class DynamicShape(Primitive):
"""
@ -668,6 +699,7 @@ class DynamicShape(Primitive):
[3 2 1]
"""
@deprecated("1.7", "TensorShape", True)
@prim_attr_register
def __init__(self):
"""init Shape"""

View File

@ -16,17 +16,18 @@
import functools
import numpy as np
import pytest
from mindspore.ops.signature import sig_rw, sig_dtype, make_sig
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import prim_attr_register
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.primitive import PrimitiveWithInfer
import mindspore.context as context
from mindspore.ops.signature import sig_rw, sig_dtype, make_sig
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
@ -273,6 +274,8 @@ class UnpackNet(Cell):
def construct(self, x):
return self.unstack(x)
class SpaceToDepthNet(Cell):
def __init__(self):
super(SpaceToDepthNet, self).__init__()
@ -315,6 +318,17 @@ class SpaceToBatchNDNet(Cell):
return self.space_to_batch_nd(x)
class TensorShapeNet(Cell):
def __init__(self):
super(TensorShapeNet, self).__init__()
self.shape = P.TensorShape()
self.unique = P.Unique()
def construct(self, x):
x, _ = self.unique(x)
return self.shape(x)
class RangeNet(Cell):
def __init__(self):
super(RangeNet, self).__init__()
@ -367,15 +381,17 @@ test_case_array_ops = [
('RangeNet', {
'block': RangeNet(),
'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}),
('TensorShapeNet', {'block': TensorShapeNet(), 'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]})
]
test_case_lists = [test_case_array_ops]
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# use -k to select certain testcast
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm
@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
@ -393,6 +409,8 @@ raise_set = [
('ReduceSum_Error', {
'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}),
'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
('TensorShapeNet_Error', {'block': (lambda x: P.TensorSHape(), {'exception': TypeError}),
'desc_inputs': [(Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[3, 1, 5])))]})
]