forked from mindspore-Ecosystem/mindspore
!38970 Distinguish the data types supported by different hardware platforms for Softsign.
Merge pull request !38970 from liqiliang/padding-gpu
This commit is contained in:
commit
229ed3b7a6
|
@ -24,6 +24,7 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -45,7 +46,16 @@ TypePtr SoftsignInferType(const PrimitivePtr &prim, const std::vector<AbstractBa
|
|||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
const std::set valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
||||
bool is_cpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice);
|
||||
std::set<TypePtr> valid_types{};
|
||||
if (is_gpu || is_cpu) {
|
||||
valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
} else {
|
||||
valid_types = {kFloat16, kFloat32};
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
|
||||
return x_type;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue