!66230 Adapt summary ops to dynamic shape
Merge pull request !66230 from DavidFFFan/feature-2.3-summary-iter4
commit e8a1441483
@@ -81,8 +81,9 @@ void Summary::SummaryTensor(KernelGraph *graph) {
     auto node = output_item.second.first;
     size_t index = IntToSize(output_item.second.second);
     auto address = AnfAlgo::GetOutputAddr(node, index, false);
-    auto shape = common::AnfAlgo::GetOutputInferShape(node, index);
-    TypeId type_id = common::AnfAlgo::GetOutputInferDataType(node, index);
+    auto kt = AnfAlgo::GetOutputKernelTensor(node, index);
+    auto shape = kt->GetShapeVector();
+    TypeId type_id = kt->dtype_id();
     tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
     MS_EXCEPTION_IF_NULL(address);
     if (!address->GetPtr()) {
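The hunk above is the core of the dynamic-shape adaptation: the summary kernel now reads shape and dtype from the output KernelTensor, which carries the concrete runtime shape, rather than from the compile-time inferred shape, which can contain unknown dimensions in dynamic-shape graphs. A minimal Python sketch of the scenario this enables (illustrative only, not part of the patch; it assumes a configured device backend and mirrors the test added at the end of this pull request):

import numpy as np
import mindspore as ms
from mindspore import Tensor, context, nn, ops

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.tensor_summary = ops.TensorSummary()

    def construct(self, x):
        # At compile time the shape of x is unknown; the kernel reads the concrete
        # shape from the KernelTensor only when real data arrives.
        self.tensor_summary('x', x)
        return x

context.set_context(mode=context.GRAPH_MODE)
net = Net()
net.set_inputs(Tensor(shape=[None, None], dtype=ms.float32))  # declare a dynamic rank-2 input
net(Tensor(np.ones((3, 4), np.float32)))                      # concrete shape is only known here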
@@ -0,0 +1,84 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/histogram_summary.h"

#include <memory>
#include <vector>

#include "abstract/abstract_value.h"
#include "abstract/dshape.h"
#include "abstract/ops/op_infer.h"
#include "abstract/ops/primitive_infer_map.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/dtype/number.h"
#include "ir/primitive.h"
#include "mindapi/base/shape_vector.h"
#include "mindapi/base/shared_ptr.h"
#include "mindapi/ir/value.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/ops/structure_ops.h"
#include "ops/op_name.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace ops {
namespace {
constexpr int BASE_SIZE = 1;
abstract::ShapePtr HistogramSummaryInferShape(const PrimitivePtr &primitive,
                                              const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  // check
  MS_EXCEPTION_IF_NULL(input_args[1]);
  auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("v rank", int64_t(v_shape.size()), kGreaterEqual, BASE_SIZE, prim_name);
  return std::make_shared<abstract::Shape>(ShapeVector(1));
}
}  // namespace

MIND_API_OPERATOR_IMPL(HistogramSummary, BaseOperator);
void HistogramSummary::set_side_effect_io() { (void)this->AddAttr(kSideEffectIO, api::MakeValue(true)); }

bool HistogramSummary::get_side_effect_io() const {
  auto value_ptr = GetAttr(kSideEffectIO);
  return GetValue<bool>(value_ptr);
}

void HistogramSummary::Init() { this->set_side_effect_io(); }

class MIND_API HistogramSummaryInfer : public abstract::OpInferBase {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    primitive->AddAttr("dyn_input_sizes", MakeValue(std::vector<int64_t>{-1, 1}));
    return HistogramSummaryInferShape(primitive, input_args);
  }

  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    MS_EXCEPTION_IF_NULL(primitive);
    // check
    CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name());
    return kInt32;
  }
};

REGISTER_PRIMITIVE_OP_INFER_IMPL(HistogramSummary, prim::kPrimHistogramSummary, HistogramSummaryInfer, false);
}  // namespace ops
}  // namespace mindspore
@@ -0,0 +1,49 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_HISTOGRAM_SUMMARY_H_
#define MINDSPORE_CORE_OPS_HISTOGRAM_SUMMARY_H_
#include <map>
#include <memory>
#include <string>

#include "mindapi/base/types.h"
#include "ops/base_operator.h"

namespace mindspore {
namespace ops {
constexpr auto kNameHistogramSummary = "HistogramSummary";

/// \brief Outputs a tensor to a protocol buffer through a tensor summary operator.
/// Refer to Python API @ref mindspore.ops.HistogramSummary for more details.
class MIND_API HistogramSummary : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(HistogramSummary);
  /// \brief Constructor.
  HistogramSummary() : BaseOperator(kNameHistogramSummary) {}
  /// \brief Init.
  void Init();
  /// \brief Set side_effect_io.
  void set_side_effect_io();
  /// \brief Get side_effect_io.
  ///
  /// \return side_effect_io.
  bool get_side_effect_io() const;
};
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_HISTOGRAM_SUMMARY_H_
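The pair of files above (and the matching ImageSummary files below) give these summary ops a C++-side infer implementation, so their rank checks and the dyn_input_sizes attribute no longer depend on the Python __infer__ methods that this patch removes. User-facing Python usage is unchanged; a small illustrative sketch (an assumption-laden example, not part of the patch; it needs a configured backend with summary security checks disabled):

import numpy as np
from mindspore import Tensor, nn, ops

class HistNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.histogram_summary = ops.HistogramSummary()

    def construct(self, x):
        # The summarized value must have rank >= 1, matching the "v rank" check above.
        self.histogram_summary('x_hist', x)
        return x

out = HistNet()(Tensor(np.random.randn(2, 3).astype(np.float32)))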
@@ -0,0 +1,84 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/image_summary.h"

#include <memory>
#include <vector>

#include "abstract/abstract_value.h"
#include "abstract/dshape.h"
#include "abstract/ops/op_infer.h"
#include "abstract/ops/primitive_infer_map.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/dtype/number.h"
#include "ir/primitive.h"
#include "mindapi/base/shape_vector.h"
#include "mindapi/base/shared_ptr.h"
#include "mindapi/ir/value.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/ops/structure_ops.h"
#include "ops/op_name.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace ops {
namespace {
constexpr int IMAGE_RANK = 4;
abstract::ShapePtr ImageSummaryInferShape(const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  // check
  MS_EXCEPTION_IF_NULL(input_args[1]);
  auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("v rank", int64_t(v_shape.size()), kEqual, IMAGE_RANK, prim_name);
  return std::make_shared<abstract::Shape>(ShapeVector(1));
}
}  // namespace

MIND_API_OPERATOR_IMPL(ImageSummary, BaseOperator);
void ImageSummary::set_side_effect_io() { (void)this->AddAttr(kSideEffectIO, api::MakeValue(true)); }

bool ImageSummary::get_side_effect_io() const {
  auto value_ptr = GetAttr(kSideEffectIO);
  return GetValue<bool>(value_ptr);
}

void ImageSummary::Init() { this->set_side_effect_io(); }

class MIND_API ImageSummaryInfer : public abstract::OpInferBase {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    primitive->AddAttr("dyn_input_sizes", MakeValue(std::vector<int64_t>{-1, 1}));
    return ImageSummaryInferShape(primitive, input_args);
  }

  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    MS_EXCEPTION_IF_NULL(primitive);
    // check
    CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name());
    return kInt32;
  }
};

REGISTER_PRIMITIVE_OP_INFER_IMPL(ImageSummary, prim::kPrimImageSummary, ImageSummaryInfer, false);
}  // namespace ops
}  // namespace mindspore
@@ -0,0 +1,49 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_IMAGE_SUMMARY_H_
#define MINDSPORE_CORE_OPS_IMAGE_SUMMARY_H_
#include <map>
#include <memory>
#include <string>

#include "mindapi/base/types.h"
#include "ops/base_operator.h"

namespace mindspore {
namespace ops {
constexpr auto kNameImageSummary = "ImageSummary";

/// \brief Outputs a tensor to a protocol buffer through a tensor summary operator.
/// Refer to Python API @ref mindspore.ops.ImageSummary for more details.
class MIND_API ImageSummary : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ImageSummary);
  /// \brief Constructor.
  ImageSummary() : BaseOperator(kNameImageSummary) {}
  /// \brief Init.
  void Init();
  /// \brief Set side_effect_io.
  void set_side_effect_io();
  /// \brief Get side_effect_io.
  ///
  /// \return side_effect_io.
  bool get_side_effect_io() const;
};
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_IMAGE_SUMMARY_H_
@@ -116,7 +116,7 @@ class ScalarSummary(Primitive):
         _cache_summary_data(self.name, args[0], args[1])


-class ImageSummary(PrimitiveWithInfer):
+class ImageSummary(Primitive):
     """
     This operator will put an image tensor to a summary file with protocol buffer format. It must be used with
     SummaryRecord or SummaryCollector, which specify the directory of the summary file. The summary file can

@@ -163,18 +163,6 @@ class ImageSummary(PrimitiveWithInfer):
         self.add_prim_attr("channel_name", "ms_image_summary")
         self.add_prim_attr("dyn_input_sizes", [-1, 1])

-    def __infer__(self, name, value):
-        _check_summary_param(name, value, self.__class__.__name__)
-
-        # The shape dim of image should be 4.
-        v_shape = value['shape']
-        image_dim = 4
-        if len(v_shape) != image_dim:
-            raise ValueError(f"For '{self.name}', the dimension of 'value' must be {image_dim},"
-                             f" but got {len(v_shape)}.")
-
-        return SUMMARY_RETURN_VALUE
-
     def __call__(self, *args):
         _cache_summary_data(self.name, args[0], args[1])

@@ -320,7 +308,7 @@ class TensorDump(Primitive):
         TENSORDUMP_ID += 1


-class HistogramSummary(PrimitiveWithInfer):
+class HistogramSummary(Primitive):
     """
     This operator will calculate the histogram of a tensor and put it to a summary file with protocol buffer format.
     It must be used with SummaryRecord or SummaryCollector, which specify the directory of the summary file.

@@ -373,17 +361,6 @@ class HistogramSummary(PrimitiveWithInfer):
         self.add_prim_attr("channel_name", "ms_histogram_summary")
         self.add_prim_attr("dyn_input_sizes", [-1, 1])

-    def __infer__(self, name, value):
-        _check_summary_param(name, value, self.__class__.__name__)
-
-        v_shape = value['shape']
-        # In the summary, the histogram value should be a tensor whose shape is not [].
-        if not v_shape:
-            raise ValueError(f"For '{self.name}', the type of 'value' must be tensor, "
-                             f"its shape should not be [], but got {v_shape}.")
-
-        return SUMMARY_RETURN_VALUE
-
     def __call__(self, *args):
         _cache_summary_data(self.name, args[0], args[1])
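With the Python-side __infer__ methods removed, ImageSummary and HistogramSummary are plain Primitives and their validation runs through the C++ infer implementations added earlier in this patch. A small illustrative sketch of the unchanged user-facing behaviour (hypothetical shapes; assumes a configured backend with summary security checks disabled):

import numpy as np
from mindspore import Tensor, context, nn, ops

class ImgNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.image_summary = ops.ImageSummary()

    def construct(self, img):
        # ImageSummary still expects a rank-4 (NCHW) tensor; the check now comes from the
        # C++ ImageSummaryInferShape above rather than the removed Python __infer__.
        self.image_summary('img', img)
        return img

context.set_context(mode=context.GRAPH_MODE)
ImgNet()(Tensor(np.zeros((1, 1, 32, 32), np.float32)))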
@@ -21,6 +21,7 @@ import time
 import numpy as np
 import pytest

+import mindspore as ms
 from mindspore import nn, Tensor, context
 from mindspore.common.initializer import Normal
 from mindspore.train import Loss

@@ -50,6 +51,7 @@ class LeNet5(nn.Cell):

         self.scalar_summary = P.ScalarSummary()
         self.image_summary = P.ImageSummary()
+        self.histogram_summary = P.HistogramSummary()
         self.tensor_summary = P.TensorSummary()
         self.channel = Tensor(num_channel)

@@ -57,6 +59,7 @@ class LeNet5(nn.Cell):
         """construct"""
         self.image_summary('x', x)
         self.tensor_summary('x', x)
+        self.histogram_summary('x', x)
         x = self.conv1(x)
         x = self.relu(x)
         x = self.max_pool2d(x)

@@ -112,10 +115,12 @@ class TestSummaryOps:
         summary_data = _get_summary_tensor_data()
         image_data = summary_data.get('x[:Image]').asnumpy()
         tensor_data = summary_data.get('x[:Tensor]').asnumpy()
+        histogram_data = summary_data.get('x[:Histogram]').asnumpy()
         x_fc3 = summary_data.get('x_fc3[:Scalar]').asnumpy()

         assert np.allclose(expected_data, image_data)
         assert np.allclose(expected_data, tensor_data)
+        assert np.allclose(expected_data, histogram_data)
         assert not np.allclose(0, x_fc3)

     @pytest.mark.level0

@@ -140,10 +145,12 @@ class TestSummaryOps:
         summary_data = _get_summary_tensor_data()
         image_data = summary_data.get('x[:Image]').asnumpy()
         tensor_data = summary_data.get('x[:Tensor]').asnumpy()
+        histogram_data = summary_data.get('x[:Histogram]').asnumpy()
         x_fc3 = summary_data.get('x_fc3[:Scalar]').asnumpy()

         assert np.allclose(expected_data, image_data)
         assert np.allclose(expected_data, tensor_data)
+        assert np.allclose(expected_data, histogram_data)
         assert not np.allclose(0, x_fc3)

     @pytest.mark.level0

@@ -169,10 +176,41 @@ class TestSummaryOps:
         summary_data = _get_summary_tensor_data()
         image_data = summary_data.get('x[:Image]').asnumpy()
         tensor_data = summary_data.get('x[:Tensor]').asnumpy()
+        histogram_data = summary_data.get('x[:Histogram]').asnumpy()
         x_fc3 = summary_data.get('x_fc3[:Scalar]').asnumpy()

         assert np.allclose(expected_data, image_data)
         assert np.allclose(expected_data, tensor_data)
+        assert np.allclose(expected_data, histogram_data)
         assert not np.allclose(0, x_fc3)

         del os.environ['GRAPH_OP_RUN']
+
+    @pytest.mark.level0
+    @pytest.mark.platform_x86_ascend_training
+    @pytest.mark.platform_arm_ascend_training
+    @pytest.mark.platform_x86_gpu_training
+    @pytest.mark.env_onecard
+    @security_off_wrap
+    def test_dynamic_shape_summary_ops(self):
+        context.set_context(mode=context.GRAPH_MODE, device_id=self.device_id)
+        ds_train = create_mnist_dataset('train', num_samples=1, batch_size=1)
+        ds_train_iter = ds_train.create_dict_iterator()
+        expected_data = next(ds_train_iter)['image'].asnumpy()
+
+        net = LeNet5()
+        dynamic_shape = Tensor(shape=[None, None, None, None], dtype=ms.float32)
+        net.set_inputs(dynamic_shape)
+        net(Tensor(expected_data))
+
+        time.sleep(0.5)
+        summary_data = _get_summary_tensor_data()
+        image_data = summary_data.get('x[:Image]').asnumpy()
+        tensor_data = summary_data.get('x[:Tensor]').asnumpy()
+        histogram_data = summary_data.get('x[:Histogram]').asnumpy()
+        x_fc3 = summary_data.get('x_fc3[:Scalar]').asnumpy()
+
+        assert np.allclose(expected_data, image_data)
+        assert np.allclose(expected_data, tensor_data)
+        assert np.allclose(expected_data, histogram_data)
+        assert not np.allclose(0, x_fc3)