!12728 fix precision error after cache modification

From: @simson_wu
Reviewed-by: @chujinjin,@zhoufeng54
Signed-off-by: @chujinjin
This commit is contained in:
mindspore-ci-bot 2021-03-02 13:54:30 +08:00 committed by Gitee
commit 00f25c8409
3 changed files with 6 additions and 3 deletions

View File

@ -66,6 +66,7 @@
#include "utils/ms_utils.h"
#include "utils/config_manager.h"
#include "utils/ms_context.h"
#include "utils/utils.h"
#if ENABLE_CPU && ENABLE_GPU
#include "ps/util.h"
#include "ps/ps_cache/ps_cache_manager.h"
@ -448,7 +449,8 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
// Check if the graph cache exists.
if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
if (run_op_graphs_.find(graph_info) != run_op_graphs_.end() &&
kOpCacheAllowList.find(op_run_info.op_name) == kOpCacheAllowList.end()) {
return;
}
// Prepare the graph

View File

@ -387,10 +387,8 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr
} else if (py::isinstance<py::float_>(input_object)) {
double input_value = py::cast<py::float_>(input_object);
tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
*tensor_mask = kValueNodeTensorMask;
} else if (py::isinstance<py::int_>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
*tensor_mask = kValueNodeTensorMask;
} else if (py::isinstance<py::array>(input_object)) {
tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
} else if (py::isinstance<py::list>(input_object)) {

View File

@ -271,6 +271,7 @@ constexpr auto kPadAndShiftOpName = "PadAndShift";
constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits";
constexpr auto kOneHotOpName = "OneHot";
constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits";
constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler";
// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
@ -492,6 +493,8 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
const std::set<std::string> kPosteriorOperatorSet = {kPullOpName};
const std::set<std::string> kOpCacheAllowList = {kUniformCandidateSamplerOpName};
const std::set<std::string> kHWSpecialFormatSet = {
kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};