!38970 Distinguish the data types supported by different hardware platforms for Softsign.

Merge pull request !38970 from liqiliang/padding-gpu
This commit is contained in:
i-robot 2022-07-28 02:25:42 +00:00 committed by Gitee
commit 229ed3b7a6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 11 additions and 1 deletions

View File

@ -24,6 +24,7 @@
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace ops {
@ -45,7 +46,16 @@ TypePtr SoftsignInferType(const PrimitivePtr &prim, const std::vector<AbstractBa
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto x_type = input_args[0]->BuildType();
const std::set valid_types = {kFloat16, kFloat32, kFloat64};
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
bool is_cpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice);
std::set<TypePtr> valid_types{};
if (is_gpu || is_cpu) {
valid_types = {kFloat16, kFloat32, kFloat64};
} else {
valid_types = {kFloat16, kFloat32};
}
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
return x_type;
}