forked from mindspore-Ecosystem/mindspore
dynamic data sink on Ascend
This commit is contained in:
parent
7de71630d9
commit
57cb72e2b7
|
@ -818,6 +818,13 @@ void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GRAPH | Dynamic Shape : KernelByKernel path in MindRT.
|
||||||
|
if (IsDynamicShapeGraph(func_graph)) {
|
||||||
|
MS_LOG(INFO) << "Run Graph mode with kernelbykernel(Dynamic Shape).";
|
||||||
|
set_ctx(false, false, false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// GRAPH | Closure\ENV\While scenario : KernelByKernel path in MindRT.
|
// GRAPH | Closure\ENV\While scenario : KernelByKernel path in MindRT.
|
||||||
auto graphs = func_graph->func_graphs_used_total();
|
auto graphs = func_graph->func_graphs_used_total();
|
||||||
(void)graphs.insert(func_graph);
|
(void)graphs.insert(func_graph);
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "acl/acl_rt.h"
|
#include "acl/acl_rt.h"
|
||||||
#include "utils/convert_utils.h"
|
#include "utils/convert_utils.h"
|
||||||
#include "plugin/device/ascend/kernel/aicpu/aicpu_util.h"
|
#include "plugin/device/ascend/kernel/aicpu/aicpu_util.h"
|
||||||
|
#include "plugin/device/ascend/hal/device/ascend_memory_manager.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#include "runtime/device/kernel_runtime.h"
|
#include "runtime/device/kernel_runtime.h"
|
||||||
#include "runtime/kernel.h"
|
#include "runtime/kernel.h"
|
||||||
|
@ -42,6 +43,13 @@ DynamicAicpuOpKernelMod::DynamicAicpuOpKernelMod(const AnfNodePtr &anf_node_ptr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
DynamicAicpuOpKernelMod::~DynamicAicpuOpKernelMod() {
|
||||||
|
// free dev ptr
|
||||||
|
if (ext_info_addr_dev_ != nullptr) {
|
||||||
|
auto mem_manager = std::make_shared<device::ascend::AscendMemoryManager>();
|
||||||
|
mem_manager->FreeMemFromMemPool(ext_info_addr_dev_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void DynamicAicpuOpKernelMod::InferOp() {
|
void DynamicAicpuOpKernelMod::InferOp() {
|
||||||
auto node = anf_node_.lock();
|
auto node = anf_node_.lock();
|
||||||
|
@ -104,9 +112,11 @@ void DynamicAicpuOpKernelMod::AllocateExtInfoDeviceAddr(const CNodePtr &cnode) {
|
||||||
}
|
}
|
||||||
// Allocate ext info addr in device
|
// Allocate ext info addr in device
|
||||||
if (!ext_info_.empty()) {
|
if (!ext_info_.empty()) {
|
||||||
auto ret = rtMalloc(&ext_info_addr_dev_, ext_info_.size(), RT_MEMORY_HBM);
|
auto mem_manager = std::make_shared<device::ascend::AscendMemoryManager>();
|
||||||
if (ret != RT_ERROR_NONE) {
|
ext_info_addr_dev_ = mem_manager->MallocMemFromMemPool(ext_info_.size(), false);
|
||||||
MS_LOG(EXCEPTION) << "Call rtMalloc ext_info_addr_dev_ failed. Op name: " << cnode->fullname_with_scope();
|
if (ext_info_addr_dev_ == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Call MemoryPool to allocate ext_info_addr_dev_ failed. Op name: "
|
||||||
|
<< cnode->fullname_with_scope();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ext_info_size_ = ext_info_.size();
|
ext_info_size_ = ext_info_.size();
|
||||||
|
|
|
@ -27,6 +27,7 @@ class DynamicAicpuOpKernelMod : public AicpuOpKernelMod {
|
||||||
public:
|
public:
|
||||||
DynamicAicpuOpKernelMod() : unknow_type_(device::ascend::UnknowShapeOpType::DEPEND_IN_SHAPE) {}
|
DynamicAicpuOpKernelMod() : unknow_type_(device::ascend::UnknowShapeOpType::DEPEND_IN_SHAPE) {}
|
||||||
explicit DynamicAicpuOpKernelMod(const AnfNodePtr &anf_node_ptr);
|
explicit DynamicAicpuOpKernelMod(const AnfNodePtr &anf_node_ptr);
|
||||||
|
~DynamicAicpuOpKernelMod() override;
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#include "runtime/mem.h"
|
#include "runtime/mem.h"
|
||||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||||
#include "plugin/device/ascend/hal/device/executor/tiling/op_tiling_adapter.h"
|
#include "plugin/device/ascend/hal/device/executor/tiling/op_tiling_adapter.h"
|
||||||
|
#include "plugin/device/ascend/hal/device/ascend_memory_manager.h"
|
||||||
#include "utils/ms_device_shape_transfer.h"
|
#include "utils/ms_device_shape_transfer.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "register/op_tiling.h"
|
#include "register/op_tiling.h"
|
||||||
|
@ -51,6 +52,13 @@ DynamicTbeKernelMod::DynamicTbeKernelMod(KernelPackPtr kernel_pack, const AnfNod
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DynamicTbeKernelMod::~DynamicTbeKernelMod() {
|
||||||
|
if (tiling_data_ptr_ != nullptr) {
|
||||||
|
auto mem_manager = std::make_shared<device::ascend::AscendMemoryManager>();
|
||||||
|
mem_manager->FreeMemFromMemPool(tiling_data_ptr_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void DynamicTbeKernelMod::InferOp() {
|
void DynamicTbeKernelMod::InferOp() {
|
||||||
if (AnfAlgo::IsDynamicShape(anf_node_.lock())) {
|
if (AnfAlgo::IsDynamicShape(anf_node_.lock())) {
|
||||||
auto node = anf_node_.lock();
|
auto node = anf_node_.lock();
|
||||||
|
@ -142,8 +150,9 @@ void DynamicTbeKernelMod::InitTilingDataPtr() {
|
||||||
auto kernel_json_info = kernel_pack_->kernel_json_info();
|
auto kernel_json_info = kernel_pack_->kernel_json_info();
|
||||||
auto op_para_size = kernel_json_info.op_para_size;
|
auto op_para_size = kernel_json_info.op_para_size;
|
||||||
if (op_para_size > 0) {
|
if (op_para_size > 0) {
|
||||||
auto ret = rtMalloc(&tiling_data_ptr_, op_para_size, RT_MEMORY_HBM);
|
auto mem_manager = std::make_shared<device::ascend::AscendMemoryManager>();
|
||||||
if (ret != RT_ERROR_NONE) {
|
tiling_data_ptr_ = mem_manager->MallocMemFromMemPool(op_para_size, false);
|
||||||
|
if (tiling_data_ptr_ == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "RtMalloc tiling data failed.";
|
MS_LOG(EXCEPTION) << "RtMalloc tiling data failed.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,7 @@ class DynamicTbeKernelMod : public TbeKernelMod {
|
||||||
public:
|
public:
|
||||||
explicit DynamicTbeKernelMod(KernelPackPtr kernel_pack) : TbeKernelMod(kernel_pack) {} // maybe delete later
|
explicit DynamicTbeKernelMod(KernelPackPtr kernel_pack) : TbeKernelMod(kernel_pack) {} // maybe delete later
|
||||||
DynamicTbeKernelMod(KernelPackPtr kernel_pack, const AnfNodePtr &anf_node_ptr);
|
DynamicTbeKernelMod(KernelPackPtr kernel_pack, const AnfNodePtr &anf_node_ptr);
|
||||||
|
~DynamicTbeKernelMod() override;
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||||
|
|
|
@ -27,36 +27,50 @@ namespace opt::dynamic_shape {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t kTupleFirstItemIndex = 0;
|
constexpr size_t kTupleFirstItemIndex = 0;
|
||||||
constexpr size_t kFirstDataInputIndex = 1;
|
constexpr size_t kFirstDataInputIndex = 1;
|
||||||
|
using DependPair = std::pair<AnfNodePtr, AnfNodePtr>;
|
||||||
AnfNodePtr InsertDepend(const FuncGraphPtr &g, const AnfNodePtr &prev, const AnfNodePtr &next) {
|
struct DependPairCmp {
|
||||||
|
bool operator()(const DependPair &lhs, const DependPair &rhs) const {
|
||||||
|
if (lhs.first != rhs.first) {
|
||||||
|
return lhs.first > rhs.first;
|
||||||
|
}
|
||||||
|
return lhs.second > rhs.second;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
void InsertDepend(const FuncGraphPtr &g, const AnfNodePtr &prev, const AnfNodePtr &next, AnfNodePtrList *depend_nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(g);
|
MS_EXCEPTION_IF_NULL(g);
|
||||||
MS_EXCEPTION_IF_NULL(prev);
|
MS_EXCEPTION_IF_NULL(prev);
|
||||||
MS_EXCEPTION_IF_NULL(next);
|
MS_EXCEPTION_IF_NULL(next);
|
||||||
|
MS_EXCEPTION_IF_NULL(depend_nodes);
|
||||||
|
static std::set<DependPair, DependPairCmp> added_set;
|
||||||
|
|
||||||
|
DependPair cur_pair = std::make_pair(prev, next);
|
||||||
|
if (added_set.count(cur_pair) > 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// add depend from prev to next
|
// add depend from prev to next
|
||||||
auto depend_node = g->NewCNode(
|
auto depend_node = g->NewCNode(
|
||||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), next, prev});
|
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), next, prev});
|
||||||
MS_EXCEPTION_IF_NULL(depend_node);
|
MS_EXCEPTION_IF_NULL(depend_node);
|
||||||
return depend_node;
|
depend_nodes->push_back(depend_node);
|
||||||
|
added_set.insert(cur_pair);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool LinkInternalOp(const FuncGraphPtr &g, const AnfNodePtr &node, AnfNodePtrList *depend_nodes) {
|
bool LinkInternalOp(const FuncGraphPtr &g, const AnfNodePtr &node, AnfNodePtrList *depend_nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(g);
|
MS_EXCEPTION_IF_NULL(g);
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(depend_nodes);
|
MS_EXCEPTION_IF_NULL(depend_nodes);
|
||||||
|
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(node);
|
auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(node);
|
||||||
if (custom_nodes.infer_node != nullptr) {
|
if (custom_nodes.infer_node != nullptr && custom_nodes.init_node != nullptr) {
|
||||||
if (custom_nodes.init_node == nullptr) {
|
InsertDepend(g, custom_nodes.infer_node, custom_nodes.init_node, depend_nodes); // link infer => init
|
||||||
MS_LOG(WARNING) << "Node " << node->DebugString() << " has infer node but init node is null.";
|
InsertDepend(g, custom_nodes.init_node, node, depend_nodes); // link init => launch
|
||||||
} else {
|
changed = true;
|
||||||
depend_nodes->push_back(InsertDepend(g, custom_nodes.infer_node, custom_nodes.init_node)); // link infer => init
|
|
||||||
depend_nodes->push_back(InsertDepend(g, custom_nodes.init_node, node)); // link init => launch
|
|
||||||
changed = true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (IsDynUpdate(custom_nodes.update_node)) {
|
if (IsDynUpdate(custom_nodes.update_node)) {
|
||||||
depend_nodes->push_back(InsertDepend(g, node, custom_nodes.update_node)); // link launch => update
|
InsertDepend(g, node, custom_nodes.update_node, depend_nodes); // link launch => update
|
||||||
changed = true;
|
changed = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,19 +95,13 @@ bool LinkInputOp(const FuncGraphPtr &g, const CNodePtr &cnode, AnfNodePtrList *d
|
||||||
}
|
}
|
||||||
auto prev_custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(prev_node);
|
auto prev_custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(prev_node);
|
||||||
if (prev_custom_nodes.infer_node != nullptr) {
|
if (prev_custom_nodes.infer_node != nullptr) {
|
||||||
depend_nodes->push_back(
|
InsertDepend(g, prev_custom_nodes.infer_node, custom_nodes.infer_node,
|
||||||
InsertDepend(g, prev_custom_nodes.infer_node, custom_nodes.infer_node)); // link prev.infer => curr.infer
|
depend_nodes); // link prev.infer => curr.infer
|
||||||
MS_LOG(DEBUG) << "Link from " << prev_node->fullname_with_scope() << " infer "
|
|
||||||
<< prev_custom_nodes.infer_node->fullname_with_scope() << " to " << cnode->fullname_with_scope()
|
|
||||||
<< " infer " << custom_nodes.infer_node->fullname_with_scope();
|
|
||||||
changed = true;
|
changed = true;
|
||||||
}
|
}
|
||||||
if (IsDynUpdate(prev_custom_nodes.update_node)) {
|
if (IsDynUpdate(prev_custom_nodes.update_node)) {
|
||||||
depend_nodes->push_back(
|
InsertDepend(g, prev_custom_nodes.update_node, custom_nodes.infer_node,
|
||||||
InsertDepend(g, prev_custom_nodes.update_node, custom_nodes.infer_node)); // link prev.update => curr.infer
|
depend_nodes); // link prev.update => curr.infer
|
||||||
MS_LOG(DEBUG) << "Link from " << prev_node->fullname_with_scope() << " update "
|
|
||||||
<< prev_custom_nodes.update_node->fullname_with_scope() << " to " << cnode->fullname_with_scope()
|
|
||||||
<< " infer " << custom_nodes.infer_node->fullname_with_scope();
|
|
||||||
changed = true;
|
changed = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -127,11 +135,11 @@ bool LinkDependSync(const FuncGraphPtr &g, const CNodePtr &cnode, AnfNodePtrList
|
||||||
if (IsDynUpdate(prev_custom_nodes.update_node)) {
|
if (IsDynUpdate(prev_custom_nodes.update_node)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. Link prev_node => prev_node.update if its update is just sync.
|
// 1. Link prev_node => prev_node.update if its update is just sync.
|
||||||
depend_nodes->push_back(InsertDepend(g, prev_node, prev_custom_nodes.update_node));
|
InsertDepend(g, prev_node, prev_custom_nodes.update_node, depend_nodes);
|
||||||
|
// 1. Link prev_node => prev_node.update if its update is just sync.
|
||||||
// 2. Link prev_node.update => cur_node.infer.
|
// 2. Link prev_node.update => cur_node.infer.
|
||||||
depend_nodes->push_back(InsertDepend(g, prev_custom_nodes.update_node, custom_nodes.infer_node));
|
InsertDepend(g, prev_custom_nodes.update_node, custom_nodes.infer_node, depend_nodes);
|
||||||
changed = true;
|
changed = true;
|
||||||
}
|
}
|
||||||
return changed;
|
return changed;
|
||||||
|
|
|
@ -36,6 +36,7 @@ void CustomActor::Run(OpContext<DeviceTensor> *const ctx) {
|
||||||
std::string error_info = "Launch custom kernel exception: " + node->fullname_with_scope();
|
std::string error_info = "Launch custom kernel exception: " + node->fullname_with_scope();
|
||||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*ctx), error_info);
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*ctx), error_info);
|
||||||
}
|
}
|
||||||
|
EraseInput(ctx);
|
||||||
SendOutput(ctx);
|
SendOutput(ctx);
|
||||||
}
|
}
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
|
|
|
@ -20,12 +20,12 @@ from mindspore.train import DatasetHelper, connect_network_with_dataset
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
def _exec_preprocess(network, is_train, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1, dataset_helper=None):
|
|
||||||
if dataset_sink_mode and not is_train:
|
def _exec_preprocess(network, is_train, dataset, dataset_sink_mode, sink_size=1, epoch_num=1, dataset_helper=None):
|
||||||
dataset.__loop_size__ = 1
|
|
||||||
|
|
||||||
if dataset_helper is None:
|
if dataset_helper is None:
|
||||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
|
dataset_helper = DatasetHelper(
|
||||||
|
dataset, dataset_sink_mode, sink_size, epoch_num)
|
||||||
|
|
||||||
if dataset_sink_mode:
|
if dataset_sink_mode:
|
||||||
network = connect_network_with_dataset(network, dataset_helper)
|
network = connect_network_with_dataset(network, dataset_helper)
|
||||||
|
@ -43,13 +43,16 @@ def _eval_dataset_sink_process(network, valid_dataset):
|
||||||
for elem1, (_, elem2) in zip(outputs, inputs2.items()):
|
for elem1, (_, elem2) in zip(outputs, inputs2.items()):
|
||||||
assert elem1.shape == elem2.shape
|
assert elem1.shape == elem2.shape
|
||||||
|
|
||||||
|
|
||||||
def dataset_generator():
|
def dataset_generator():
|
||||||
for i in range(1, 10):
|
for i in range(1, 10):
|
||||||
yield (
|
yield (
|
||||||
np.ones((32, i), dtype=np.float32), np.zeros((32, i, i, 3), dtype=np.int32),
|
np.ones((32, i), dtype=np.float32), np.zeros(
|
||||||
|
(32, i, i, 3), dtype=np.int32),
|
||||||
np.ones((32,), dtype=np.float32),
|
np.ones((32,), dtype=np.float32),
|
||||||
np.ones((32, i, 8), dtype=np.float32), np.ones((32, 8, 8), dtype=np.float32))
|
np.ones((32, i, 8), dtype=np.float32), np.ones((32, 8, 8), dtype=np.float32))
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -69,13 +72,41 @@ class Net(nn.Cell):
|
||||||
x5 = self.relu(x5)
|
x5 = self.relu(x5)
|
||||||
return x1, x2, x3, x4, x5
|
return x1, x2, x3, x4, x5
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_arm_ascend_training
|
@pytest.mark.platform_arm_ascend_training
|
||||||
@pytest.mark.platform_x86_ascend_training
|
@pytest.mark.platform_x86_ascend_training
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
def test_getnext_dynamic_pipeline():
|
def test_getnext_dynamic_pipeline():
|
||||||
network = Net()
|
network = Net()
|
||||||
dataset = ds.GeneratorDataset(dataset_generator, ["data1", "data2", "data3", "data4", "data5"])
|
dataset = ds.GeneratorDataset(
|
||||||
|
dataset_generator, ["data1", "data2", "data3", "data4", "data5"])
|
||||||
dataset.set_dynamic_columns(columns={"data1": [32, None], "data2": [32, None, None, 3],
|
dataset.set_dynamic_columns(columns={"data1": [32, None], "data2": [32, None, None, 3],
|
||||||
"data3": [32], "data4": [32, None, 8], "data5": [32, 8, 8]})
|
"data3": [32], "data4": [32, None, 8], "data5": [32, 8, 8]})
|
||||||
_eval_dataset_sink_process(network, dataset)
|
_eval_dataset_sink_process(network, dataset)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_getnext_sink_size_dynamic_pipeline():
|
||||||
|
"""
|
||||||
|
Feature: arbitrary sink size of dynamic data sink.
|
||||||
|
Description: datasets with dynamic shape as input.
|
||||||
|
Expectation: success without assert exception.
|
||||||
|
"""
|
||||||
|
network = Net()
|
||||||
|
dataset = ds.GeneratorDataset(
|
||||||
|
dataset_generator, ["data1", "data2", "data3", "data4", "data5"])
|
||||||
|
dataset.set_dynamic_columns(columns={"data1": [32, None], "data2": [32, None, None, 3],
|
||||||
|
"data3": [32], "data4": [32, None, 8], "data5": [32, 8, 8]})
|
||||||
|
|
||||||
|
dataset_helper, eval_network = _exec_preprocess(
|
||||||
|
network, is_train=False, dataset=dataset, dataset_sink_mode=True, sink_size=-1)
|
||||||
|
for inputs in dataset_helper:
|
||||||
|
outputs = eval_network(*inputs)
|
||||||
|
for data_item in dataset.create_dict_iterator():
|
||||||
|
last_inputs = data_item.items()
|
||||||
|
for output, (_, last_input) in zip(outputs, last_inputs):
|
||||||
|
assert output.shape == last_input.shape
|
||||||
|
|
Loading…
Reference in New Issue