forked from mindspore-Ecosystem/mindspore
!14267 add unet++ 310 mindir infer and update unet 310 infer code
From: @lihongkang1 Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34
This commit is contained in:
commit
155ac84906
|
@ -42,11 +42,17 @@ using mindspore::MSTensor;
|
|||
using mindspore::ModelType;
|
||||
using mindspore::GraphCell;
|
||||
using mindspore::kSuccess;
|
||||
using mindspore::dataset::vision::Decode;
|
||||
using mindspore::dataset::vision::SwapRedBlue;
|
||||
using mindspore::dataset::vision::Normalize;
|
||||
using mindspore::dataset::vision::Resize;
|
||||
using mindspore::dataset::vision::HWC2CHW;
|
||||
|
||||
|
||||
DEFINE_string(mindir_path, "", "mindir path");
|
||||
DEFINE_string(dataset_path, ".", "dataset path");
|
||||
DEFINE_int32(device_id, 0, "device id");
|
||||
DEFINE_string(need_preprocess, "n", "need preprocess or not");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
|
@ -78,6 +84,14 @@ int main(int argc, char **argv) {
|
|||
|
||||
std::map<double, double> costTime_map;
|
||||
size_t size = all_files.size();
|
||||
|
||||
auto decode(new Decode());
|
||||
auto swapredblue(new SwapRedBlue());
|
||||
auto resize(new Resize({96, 96}));
|
||||
auto normalize(new Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5}));
|
||||
auto hwc2chw(new HWC2CHW());
|
||||
Execute preprocess({decode, swapredblue, resize, normalize, hwc2chw});
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
struct timeval start = {0};
|
||||
struct timeval end = {0};
|
||||
|
@ -86,7 +100,17 @@ int main(int argc, char **argv) {
|
|||
std::vector<MSTensor> inputs;
|
||||
std::vector<MSTensor> outputs;
|
||||
std::cout << "Start predict input files:" << all_files[i] << std::endl;
|
||||
auto img = ReadFileToTensor(all_files[i]);
|
||||
|
||||
auto img = MSTensor();
|
||||
if (FLAGS_need_preprocess == "y") {
|
||||
ret = preprocess(ReadFileToTensor(all_files[i]), &img);
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "preprocess " << all_files[i] << " failed." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
img = ReadFileToTensor(all_files[i]);
|
||||
}
|
||||
|
||||
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
|
||||
img.Data().get(), img.DataSize());
|
||||
|
|
|
@ -21,6 +21,7 @@ from mindspore import Tensor, export, load_checkpoint, load_param_into_net, cont
|
|||
from src.unet_medical.unet_model import UNetMedical
|
||||
from src.unet_nested import NestedUNet, UNet
|
||||
from src.config import cfg_unet as cfg
|
||||
from src.utils import UnetEval
|
||||
|
||||
parser = argparse.ArgumentParser(description='unet export')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
|
@ -52,5 +53,6 @@ if __name__ == "__main__":
|
|||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
# load the parameter into net
|
||||
load_param_into_net(net, param_dict)
|
||||
net = UnetEval(net)
|
||||
input_data = Tensor(np.ones([args.batch_size, cfg["num_channels"], args.height, args.width]).astype(np.float32))
|
||||
export(net, input_data, file_name=args.file_name, file_format=args.file_format)
|
||||
|
|
|
@ -15,53 +15,77 @@
|
|||
"""unet 310 infer."""
|
||||
import os
|
||||
import argparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from src.data_loader import create_dataset
|
||||
from src.data_loader import create_dataset, create_cell_nuclei_dataset
|
||||
from src.config import cfg_unet
|
||||
from scipy.special import softmax
|
||||
|
||||
|
||||
class dice_coeff():
|
||||
def __init__(self):
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
self._dice_coeff_sum = 0
|
||||
self._iou_sum = 0
|
||||
self._samples_num = 0
|
||||
|
||||
def update(self, *inputs):
|
||||
if len(inputs) != 2:
|
||||
raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
|
||||
|
||||
y_pred = inputs[0]
|
||||
raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
|
||||
y = np.array(inputs[1])
|
||||
|
||||
self._samples_num += y.shape[0]
|
||||
y_pred = y_pred.transpose(0, 2, 3, 1)
|
||||
y = y.transpose(0, 2, 3, 1)
|
||||
y_pred = softmax(y_pred, axis=3)
|
||||
|
||||
b, h, w, c = y.shape
|
||||
if b != 1:
|
||||
raise ValueError('Batch size should be 1 when in evaluation.')
|
||||
y = y.reshape((h, w, c))
|
||||
if cfg_unet["eval_activate"].lower() == "softmax":
|
||||
y_softmax = np.squeeze(inputs[0][0], axis=0)
|
||||
if cfg_unet["eval_resize"]:
|
||||
y_pred = []
|
||||
for m in range(cfg_unet["num_classes"]):
|
||||
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255)
|
||||
y_pred = np.stack(y_pred, axis=-1)
|
||||
else:
|
||||
y_pred = y_softmax
|
||||
elif cfg_unet["eval_activate"].lower() == "argmax":
|
||||
y_argmax = np.squeeze(inputs[0][1], axis=0)
|
||||
y_pred = []
|
||||
for n in range(cfg_unet["num_classes"]):
|
||||
if cfg_unet["eval_resize"]:
|
||||
y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST))
|
||||
else:
|
||||
y_pred.append(np.float32(y_argmax == n))
|
||||
y_pred = np.stack(y_pred, axis=-1)
|
||||
else:
|
||||
raise ValueError('config eval_activate should be softmax or argmax.')
|
||||
y_pred = y_pred.astype(np.float32)
|
||||
inter = np.dot(y_pred.flatten(), y.flatten())
|
||||
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
|
||||
|
||||
single_dice_coeff = 2*float(inter)/float(union+1e-6)
|
||||
print("single dice coeff is:", single_dice_coeff)
|
||||
single_iou = single_dice_coeff / (2 - single_dice_coeff)
|
||||
print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
|
||||
self._dice_coeff_sum += single_dice_coeff
|
||||
self._iou_sum += single_iou
|
||||
|
||||
def eval(self):
|
||||
if self._samples_num == 0:
|
||||
raise RuntimeError('Total samples num must not be 0.')
|
||||
|
||||
return self._dice_coeff_sum / float(self._samples_num)
|
||||
return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
|
||||
|
||||
|
||||
def test_net(data_dir,
|
||||
cross_valid_ind=1,
|
||||
cfg=None):
|
||||
|
||||
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
|
||||
img_size=cfg['img_size'])
|
||||
if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
|
||||
valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
|
||||
eval_resize=cfg["eval_resize"], split=0.8)
|
||||
else:
|
||||
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
|
||||
img_size=cfg['img_size'])
|
||||
labels_list = []
|
||||
|
||||
for data in valid_dataset:
|
||||
|
@ -89,10 +113,25 @@ if __name__ == '__main__':
|
|||
rst_path = args.rst_path
|
||||
metrics = dice_coeff()
|
||||
|
||||
for j in range(len(os.listdir(rst_path))):
|
||||
file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin"
|
||||
output = np.fromfile(file_name, np.float32).reshape(1, 2, 576, 576)
|
||||
label = label_list[j]
|
||||
metrics.update(output, label)
|
||||
if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
|
||||
for i, bin_name in enumerate(os.listdir('./preprocess_Result/')):
|
||||
bin_name_softmax = bin_name.replace(".png", "") + "_0.bin"
|
||||
bin_name_argmax = bin_name.replace(".png", "") + "_1.bin"
|
||||
file_name_sof = rst_path + bin_name_softmax
|
||||
file_name_arg = rst_path + bin_name_argmax
|
||||
softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2)
|
||||
argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96)
|
||||
label = label_list[i]
|
||||
metrics.update((softmax_out, argmax_out), label)
|
||||
else:
|
||||
for j in range(len(os.listdir('./preprocess_Result/'))):
|
||||
file_name_sof = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin"
|
||||
file_name_arg = rst_path + "ISBI_test_bs_1_" + str(j) + "_1" + ".bin"
|
||||
softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 576, 576, 2)
|
||||
argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 576, 576)
|
||||
label = label_list[j]
|
||||
metrics.update((softmax_out, argmax_out), label)
|
||||
|
||||
print("Cross valid dice coeff is: ", metrics.eval())
|
||||
eval_score = metrics.eval()
|
||||
print("============== Cross valid dice coeff is:", eval_score[0])
|
||||
print("============== Cross valid IOU is:", eval_score[1])
|
||||
|
|
|
@ -14,6 +14,10 @@
|
|||
# ============================================================================
|
||||
"""unet 310 infer preprocess dataset"""
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from src.data_loader import create_dataset
|
||||
from src.config import cfg_unet
|
||||
|
||||
|
@ -29,6 +33,56 @@ def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None):
|
|||
data[0].asnumpy().tofile(file_path)
|
||||
|
||||
|
||||
class CellNucleiDataset:
|
||||
"""
|
||||
Cell nuclei dataset preprocess class.
|
||||
"""
|
||||
def __init__(self, data_dir, repeat, result_path, is_train=False, split=0.8):
|
||||
self.data_dir = data_dir
|
||||
self.img_ids = sorted(next(os.walk(self.data_dir))[1])
|
||||
self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
|
||||
np.random.shuffle(self.train_ids)
|
||||
self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
|
||||
self.is_train = is_train
|
||||
self.result_path = result_path
|
||||
self._preprocess_dataset()
|
||||
|
||||
def _preprocess_dataset(self):
|
||||
for img_id in self.val_ids:
|
||||
path = os.path.join(self.data_dir, img_id)
|
||||
img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
|
||||
if len(img.shape) == 2:
|
||||
img = np.expand_dims(img, axis=-1)
|
||||
img = np.concatenate([img, img, img], axis=-1)
|
||||
mask = []
|
||||
for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
|
||||
mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
|
||||
mask.append(mask_)
|
||||
mask = np.max(mask, axis=0)
|
||||
cv2.imwrite(os.path.join(self.result_path, img_id + ".png"), img)
|
||||
|
||||
def _read_img_mask(self, img_id):
|
||||
path = os.path.join(self.data_dir, img_id)
|
||||
img = cv2.imread(os.path.join(path, "image.png"))
|
||||
mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
|
||||
return img, mask
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.is_train:
|
||||
return self._read_img_mask(self.train_ids[index])
|
||||
return self._read_img_mask(self.val_ids[index])
|
||||
|
||||
@property
|
||||
def column_names(self):
|
||||
column_names = ['image', 'mask']
|
||||
return column_names
|
||||
|
||||
def __len__(self):
|
||||
if self.is_train:
|
||||
return len(self.train_ids)
|
||||
return len(self.val_ids)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='Preprocess the UNet dataset ',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
@ -42,5 +96,8 @@ def get_args():
|
|||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
|
||||
preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet, result_path=
|
||||
args.result_path)
|
||||
if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
|
||||
cell_dataset = CellNucleiDataset(args.data_url, 1, args.result_path, False, 0.8)
|
||||
else:
|
||||
preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet,
|
||||
result_path=args.result_path)
|
||||
|
|
|
@ -14,9 +14,10 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [[ $# -lt 2 || $# -gt 3 ]]; then
|
||||
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
|
||||
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
|
||||
if [[ $# -lt 3 || $# -gt 4 ]]; then
|
||||
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [NEED_PREPROCESS]
|
||||
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero.
|
||||
NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -29,7 +30,7 @@ get_real_path(){
|
|||
}
|
||||
model=$(get_real_path $1)
|
||||
data_path=$(get_real_path $2)
|
||||
if [ $# == 3 ]; then
|
||||
if [ $# == 4 ]; then
|
||||
device_id=$3
|
||||
if [ -z $device_id ]; then
|
||||
device_id=0
|
||||
|
@ -37,10 +38,12 @@ if [ $# == 3 ]; then
|
|||
device_id=$device_id
|
||||
fi
|
||||
fi
|
||||
need_preprocess=$4
|
||||
|
||||
echo "mindir name: "$model
|
||||
echo "dataset path: "$data_path
|
||||
echo "device id: "$device_id
|
||||
echo "need preprocess or not: "$need_preprocess
|
||||
|
||||
export ASCEND_HOME=/usr/local/Ascend/
|
||||
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
|
||||
|
@ -85,7 +88,7 @@ function infer()
|
|||
fi
|
||||
mkdir result_Files
|
||||
mkdir time_Result
|
||||
../ascend310_infer/src/main --mindir_path=$model --dataset_path=./preprocess_Result/ --device_id=$device_id &> infer.log
|
||||
../ascend310_infer/src/main --mindir_path=$model --dataset_path=./preprocess_Result/ --device_id=$device_id --need_preprocess=$need_preprocess &> infer.log
|
||||
}
|
||||
|
||||
function cal_acc()
|
||||
|
|
Loading…
Reference in New Issue