forked from mindspore-Ecosystem/mindspore
!17297 dynamic_shape_pipeline
Merge pull request !17297 from wanyiming/dynamic_shape_pipeline
This commit is contained in:
commit
311e59c190
|
@ -411,6 +411,11 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
|
||||||
MS_LOG(INFO) << "UBFusion is not enable, skip";
|
MS_LOG(INFO) << "UBFusion is not enable, skip";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (kernel_graph->is_dynamic_shape()) {
|
||||||
|
MS_LOG(WARNING) << "Dynamic shape skip fusion";
|
||||||
|
return;
|
||||||
|
}
|
||||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||||
if (save_graphs) {
|
if (save_graphs) {
|
||||||
std::string file_name = "hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
|
std::string file_name = "hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||||
|
|
|
@ -54,7 +54,11 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
|
||||||
auto imm = std::make_shared<Int64Imm>(output_idx);
|
auto imm = std::make_shared<Int64Imm>(output_idx);
|
||||||
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
|
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
|
||||||
auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
|
auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
|
||||||
AnfAlgo::SetOutputTypeAndDetailShape({origin_type}, {origin_shape}, getitem.get());
|
auto abs = cnode->abstract()->cast<abstract::AbstractTuplePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(abs);
|
||||||
|
auto abs_i = abs->elements()[output_idx];
|
||||||
|
MS_EXCEPTION_IF_NULL(abs_i);
|
||||||
|
getitem->set_abstract(abs_i);
|
||||||
const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
|
const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
|
||||||
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
|
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
|
||||||
if (origin_type != device_type) {
|
if (origin_type != device_type) {
|
||||||
|
|
|
@ -492,9 +492,11 @@ CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr
|
||||||
CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
|
CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
|
||||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||||
tuple_getitem->set_scope(node->scope());
|
tuple_getitem->set_scope(node->scope());
|
||||||
auto origin_shape = AnfAlgo::GetOutputDetailShape(node, output_idx);
|
auto abs = node->abstract()->cast<abstract::AbstractTuplePtr>();
|
||||||
TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
|
MS_EXCEPTION_IF_NULL(abs);
|
||||||
AnfAlgo::SetOutputTypeAndDetailShape({origin_type}, {origin_shape}, tuple_getitem.get());
|
auto abs_i = abs->elements()[output_idx];
|
||||||
|
MS_EXCEPTION_IF_NULL(abs_i);
|
||||||
|
tuple_getitem->set_abstract(abs_i);
|
||||||
return tuple_getitem;
|
return tuple_getitem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -165,10 +165,6 @@ bool AiCpuDynamicKernel::UpdateExtInfo() {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AiCpuDynamicKernel::UpdateOutputShapeFromExtInfo() {
|
bool AiCpuDynamicKernel::UpdateOutputShapeFromExtInfo() {
|
||||||
if (input_num_ == 0) {
|
|
||||||
MS_LOG(WARNING) << "input num is 0";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "UpdateOutputShapeFromExtInfo start";
|
MS_LOG(INFO) << "UpdateOutputShapeFromExtInfo start";
|
||||||
auto ret = rtMemcpy(ext_info_handler_->GetExtInfo(), ext_info_handler_->GetExtInfoLen(), ext_info_addr_dev_,
|
auto ret = rtMemcpy(ext_info_handler_->GetExtInfo(), ext_info_handler_->GetExtInfoLen(), ext_info_addr_dev_,
|
||||||
ext_info_size_, RT_MEMCPY_DEVICE_TO_HOST);
|
ext_info_size_, RT_MEMCPY_DEVICE_TO_HOST);
|
||||||
|
|
|
@ -73,9 +73,6 @@ void DynamicKernel::RebuildDependTensor() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void DynamicKernel::InferShape() {
|
void DynamicKernel::InferShape() {
|
||||||
if (!is_input_dynamic_shape_ && is_output_dynamic_shape_ && !have_depends()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto cnode = cnode_ptr_.lock();
|
auto cnode = cnode_ptr_.lock();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope();
|
MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope();
|
||||||
|
|
|
@ -602,7 +602,7 @@ const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat
|
||||||
|
|
||||||
const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
|
const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
|
||||||
kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName,
|
kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName,
|
||||||
kMaskedSelectOpName, kDynamicStitchOpName};
|
kMaskedSelectOpName, kDynamicStitchOpName, kGetNextOpName};
|
||||||
|
|
||||||
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
|
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
|
||||||
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
|
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
|
||||||
|
|
|
@ -361,14 +361,12 @@ AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
|
||||||
ShapeVector min_shape_vec;
|
ShapeVector min_shape_vec;
|
||||||
ShapeVector max_shape_vec;
|
ShapeVector max_shape_vec;
|
||||||
|
|
||||||
if (shape->IsDynamic()) {
|
|
||||||
if (!shape->min_shape().empty()) {
|
if (!shape->min_shape().empty()) {
|
||||||
min_shape_vec = shape->min_shape();
|
min_shape_vec = shape->min_shape();
|
||||||
}
|
}
|
||||||
if (!shape->max_shape().empty()) {
|
if (!shape->max_shape().empty()) {
|
||||||
max_shape_vec = shape->max_shape();
|
max_shape_vec = shape->max_shape();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
|
auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
|
||||||
if (type->isa<TensorType>()) {
|
if (type->isa<TensorType>()) {
|
||||||
|
|
|
@ -32,6 +32,7 @@ inline const std::unordered_map<std::string, ValuePtr> kSideEffectPropagate = {
|
||||||
{mindspore::GRAPH_FLAG_SIDE_EFFECT_PROPAGATE, kValueOne},
|
{mindspore::GRAPH_FLAG_SIDE_EFFECT_PROPAGATE, kValueOne},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
constexpr auto kGetNext = "GetNext";
|
||||||
constexpr auto kGather = "Gather";
|
constexpr auto kGather = "Gather";
|
||||||
// Arithmetic
|
// Arithmetic
|
||||||
constexpr auto kScalarAdd = "ScalarAdd";
|
constexpr auto kScalarAdd = "ScalarAdd";
|
||||||
|
@ -92,6 +93,9 @@ constexpr auto kDropoutGrad = "DropoutGrad";
|
||||||
constexpr auto kConv2DTranspose = "Conv2DTranspose";
|
constexpr auto kConv2DTranspose = "Conv2DTranspose";
|
||||||
|
|
||||||
// Here list all primitives used in backend or some special primitives used by core.
|
// Here list all primitives used in backend or some special primitives used by core.
|
||||||
|
// GetNext
|
||||||
|
inline const PrimitivePtr kPrimGetNext = std::make_shared<Primitive>(kGetNext);
|
||||||
|
|
||||||
// Arithmetic
|
// Arithmetic
|
||||||
inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>(kScalarAdd);
|
inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>(kScalarAdd);
|
||||||
inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>(kScalarSub);
|
inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>(kScalarSub);
|
||||||
|
|
|
@ -0,0 +1,96 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "ops/getnext.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void GetShapeVector(const ValuePtr &shape_attr, std::vector<std::vector<int64_t>> *shape_vec) {
|
||||||
|
if (shape_attr == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::vector<ValuePtr> shape = shape_attr->isa<ValueTuple>() ? shape_attr->cast<ValueTuplePtr>()->value()
|
||||||
|
: shape_attr->cast<ValueListPtr>()->value();
|
||||||
|
for (ValuePtr shape_elements : shape) {
|
||||||
|
std::vector<ValuePtr> shape_elements_list = shape_elements->isa<ValueTuple>()
|
||||||
|
? shape_elements->cast<ValueTuplePtr>()->value()
|
||||||
|
: shape_elements->cast<ValueListPtr>()->value();
|
||||||
|
std::vector<int64_t> shape_vec_item;
|
||||||
|
(void)std::transform(std::begin(shape_elements_list), std::end(shape_elements_list),
|
||||||
|
std::back_inserter(shape_vec_item),
|
||||||
|
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
|
||||||
|
shape_vec->push_back(shape_vec_item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsDynamic(const std::vector<ShapeVector> &shape) {
|
||||||
|
for (auto shape_vec : shape) {
|
||||||
|
if (std::find(shape_vec.begin(), shape_vec.end(), -1) != shape_vec.end()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr GetnextInferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto types = GetValue<std::vector<TypePtr>>(primitive->GetAttr("types"));
|
||||||
|
ValuePtr shape_attr = primitive->GetAttr("shapes");
|
||||||
|
ValuePtr min_shape_attr = primitive->GetAttr("min_shapes");
|
||||||
|
ValuePtr max_shape_attr = primitive->GetAttr("max_shapes");
|
||||||
|
|
||||||
|
std::vector<ShapeVector> shape;
|
||||||
|
std::vector<ShapeVector> min_shape;
|
||||||
|
std::vector<ShapeVector> max_shape;
|
||||||
|
|
||||||
|
GetShapeVector(shape_attr, &shape);
|
||||||
|
GetShapeVector(min_shape_attr, &min_shape);
|
||||||
|
GetShapeVector(max_shape_attr, &max_shape);
|
||||||
|
|
||||||
|
bool is_dynamic = IsDynamic(shape);
|
||||||
|
|
||||||
|
AbstractBasePtrList output;
|
||||||
|
for (size_t i = 0; i < shape.size(); ++i) {
|
||||||
|
auto ret_shape = !min_shape.empty() && !max_shape.empty() && is_dynamic
|
||||||
|
? std::make_shared<abstract::Shape>(shape[i], min_shape[i], max_shape[i])
|
||||||
|
: std::make_shared<abstract::Shape>(shape[i]);
|
||||||
|
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, types[i]);
|
||||||
|
auto tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
|
||||||
|
output.push_back(tensor);
|
||||||
|
}
|
||||||
|
return std::make_shared<abstract::AbstractTuple>(output);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
AbstractBasePtr GetNextInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
return GetnextInferShape(primitive, input_args);
|
||||||
|
}
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(GetNext, prim::kPrimGetNext, GetNextInfer, nullptr, true);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_GETNEXT_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_GETNEXT_H_
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
#include "abstract/abstract_value.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
constexpr auto kNameGetNext = prim::kGetNext;
|
||||||
|
class GetNext : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
GetNext() : PrimitiveC(prim::kPrimGetNext->name()) {}
|
||||||
|
~GetNext() = default;
|
||||||
|
MS_DECLARE_PARENT(GetNext, PrimitiveC);
|
||||||
|
void Init() {}
|
||||||
|
};
|
||||||
|
AbstractBasePtr GetNextInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args);
|
||||||
|
using PrimMulPtr = std::shared_ptr<GetNext>;
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CORE_OPS_GETNEXT_H_
|
|
@ -3498,7 +3498,7 @@ class FastGeLU(PrimitiveWithInfer):
|
||||||
return input_x
|
return input_x
|
||||||
|
|
||||||
|
|
||||||
class GetNext(PrimitiveWithInfer):
|
class GetNext(Primitive):
|
||||||
"""
|
"""
|
||||||
Returns the next element in the dataset queue.
|
Returns the next element in the dataset queue.
|
||||||
|
|
||||||
|
@ -3545,12 +3545,6 @@ class GetNext(PrimitiveWithInfer):
|
||||||
validator.check("types length", len(types), "shapes length", len(shapes), Rel.EQ, self.name)
|
validator.check("types length", len(types), "shapes length", len(shapes), Rel.EQ, self.name)
|
||||||
validator.check_value_type("output_num", output_num, [int], self.name)
|
validator.check_value_type("output_num", output_num, [int], self.name)
|
||||||
|
|
||||||
def infer_shape(self):
|
|
||||||
return tuple(self.shapes)
|
|
||||||
|
|
||||||
def infer_dtype(self):
|
|
||||||
return tuple(self.types)
|
|
||||||
|
|
||||||
|
|
||||||
class PReLU(PrimitiveWithInfer):
|
class PReLU(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
|
|
|
@ -60,13 +60,18 @@ class _DataWrapper(nn.Cell):
|
||||||
dataset channel 'queue_name' and performs the forward computation.
|
dataset channel 'queue_name' and performs the forward computation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
|
def __init__(self, network, dataset_types, dataset_shapes, queue_name, min_shapes=None, max_shapes=None):
|
||||||
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
||||||
# Also copy the flag in `network` construct
|
# Also copy the flag in `network` construct
|
||||||
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
||||||
self.info = (dataset_types, dataset_shapes)
|
self.info = (dataset_types, dataset_shapes)
|
||||||
self.add_flags(**flags)
|
self.add_flags(**flags)
|
||||||
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
||||||
|
if min_shapes is not None and max_shapes is not None:
|
||||||
|
Validator.check_value_type("min_shapes", min_shapes, [list, tuple])
|
||||||
|
Validator.check_value_type("max_shapes", max_shapes, [list, tuple])
|
||||||
|
self.get_next.add_prim_attr("min_shapes", min_shapes)
|
||||||
|
self.get_next.add_prim_attr("max_shapes", max_shapes)
|
||||||
self.network = network
|
self.network = network
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
|
@ -74,9 +79,24 @@ class _DataWrapper(nn.Cell):
|
||||||
return self.network(*outputs)
|
return self.network(*outputs)
|
||||||
|
|
||||||
|
|
||||||
def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name):
|
def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name,
|
||||||
|
min_shapes=None, max_shapes=None):
|
||||||
if not isinstance(network, _DataWrapper):
|
if not isinstance(network, _DataWrapper):
|
||||||
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
|
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name, min_shapes, max_shapes)
|
||||||
|
return network
|
||||||
|
|
||||||
|
def has_dynamic_shape(dataset_shapes):
|
||||||
|
for shape in dataset_shapes:
|
||||||
|
if -1 in shape:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _generate_network_with_dataset(network, dataset_helper, queue_name):
|
||||||
|
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
||||||
|
(min_shapes, max_shapes) = (None, None) if not has_dynamic_shape(dataset_shapes) \
|
||||||
|
else dataset_helper.dynamic_min_max_shapes()
|
||||||
|
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types,
|
||||||
|
queue_name, min_shapes, max_shapes)
|
||||||
return network
|
return network
|
||||||
|
|
||||||
|
|
||||||
|
@ -151,8 +171,7 @@ def connect_network_with_dataset(network, dataset_helper):
|
||||||
not context.get_context("enable_ge") and \
|
not context.get_context("enable_ge") and \
|
||||||
context.get_context("device_target") in ("Ascend", "GPU"):
|
context.get_context("device_target") in ("Ascend", "GPU"):
|
||||||
dataset.__me_inited__ = True
|
dataset.__me_inited__ = True
|
||||||
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
network = _generate_network_with_dataset(network, dataset_helper, queue_name)
|
||||||
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
|
|
||||||
|
|
||||||
if hasattr(dataset_iter, "sink_size") and \
|
if hasattr(dataset_iter, "sink_size") and \
|
||||||
dataset_iter.sink_size == 1 and \
|
dataset_iter.sink_size == 1 and \
|
||||||
|
@ -286,7 +305,7 @@ class _DatasetIter:
|
||||||
self.release = dataset.__transfer_dataset__.release
|
self.release = dataset.__transfer_dataset__.release
|
||||||
self.continue_send = dataset.__transfer_dataset__.continue_send
|
self.continue_send = dataset.__transfer_dataset__.continue_send
|
||||||
self.get_data_info = dataset.__transfer_dataset__.get_data_info
|
self.get_data_info = dataset.__transfer_dataset__.get_data_info
|
||||||
self.dynamic_min_max_shapes = dataset.__transfer_dataset__.dynamic_min_max_shapes
|
self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
|
||||||
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import nn, context
|
||||||
|
from mindspore import ops as P
|
||||||
|
from mindspore.train import DatasetHelper, connect_network_with_dataset
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
def _exec_preprocess(network, is_train, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1, dataset_helper=None):
|
||||||
|
if dataset_sink_mode and not is_train:
|
||||||
|
dataset.__loop_size__ = 1
|
||||||
|
|
||||||
|
if dataset_helper is None:
|
||||||
|
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
|
||||||
|
|
||||||
|
if dataset_sink_mode:
|
||||||
|
network = connect_network_with_dataset(network, dataset_helper)
|
||||||
|
|
||||||
|
network.set_train(is_train)
|
||||||
|
|
||||||
|
return dataset_helper, network
|
||||||
|
|
||||||
|
|
||||||
|
def _eval_dataset_sink_process(network, valid_dataset):
|
||||||
|
dataset_helper, eval_network = _exec_preprocess(network, is_train=False, dataset=valid_dataset,
|
||||||
|
dataset_sink_mode=True)
|
||||||
|
for inputs1, inputs2 in zip(dataset_helper, valid_dataset.create_dict_iterator()):
|
||||||
|
outputs = eval_network(*inputs1)
|
||||||
|
for elem1, (_, elem2) in zip(outputs, inputs2.items()):
|
||||||
|
assert elem1.shape == elem2.shape
|
||||||
|
|
||||||
|
def dataset_generator():
|
||||||
|
for i in range(1, 10):
|
||||||
|
yield (
|
||||||
|
np.ones((32, i), dtype=np.float32), np.zeros((32, i, i, 3), dtype=np.int32),
|
||||||
|
np.ones((32,), dtype=np.float32),
|
||||||
|
np.ones((32, i, 8), dtype=np.float32), np.ones((32, 8, 8), dtype=np.float32))
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.relu = P.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x1, x2, x3, x4, x5):
|
||||||
|
x1 = self.relu(x1)
|
||||||
|
x1 = self.relu(x1)
|
||||||
|
|
||||||
|
x2 = self.relu(x2)
|
||||||
|
|
||||||
|
x3 = self.relu(x3)
|
||||||
|
x3 = self.relu(x3)
|
||||||
|
|
||||||
|
x4 = self.relu(x4)
|
||||||
|
|
||||||
|
x5 = self.relu(x5)
|
||||||
|
return x1, x2, x3, x4, x5
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_getnext_dynamic_pipeline():
|
||||||
|
network = Net()
|
||||||
|
dataset = ds.GeneratorDataset(dataset_generator, ["data1", "data2", "data3", "data4", "data5"])
|
||||||
|
dataset.set_dynamic_columns(columns={"data1": [32, None], "data2": [32, None, None, 3],
|
||||||
|
"data3": [32], "data4": [32, None, 8], "data5": [32, 8, 8]})
|
||||||
|
_eval_dataset_sink_process(network, dataset)
|
Loading…
Reference in New Issue