forked from mindspore-Ecosystem/mindspore
!18709 Add EPP-MVSNet to model zoo
Merge pull request !18709 from NewMesc/master
This commit is contained in:
commit
8a417aa456
|
@ -0,0 +1,148 @@
|
||||||
|
# Contents
|
||||||
|
|
||||||
|
- [EPP-MVSNet](#thinking-path-re-ranker)
|
||||||
|
- [Model Architecture](#model-architecture)
|
||||||
|
- [Dataset](#dataset)
|
||||||
|
- [Features](#features)
|
||||||
|
- [Mixed Precision](#mixed-precision)
|
||||||
|
- [Environment Requirements](#environment-requirements)
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
|
- [Script Description](#script-description)
|
||||||
|
- [Script and Sample Code](#script-and-sample-code)
|
||||||
|
- [Script Parameters](#script-parameters)
|
||||||
|
- [Training Process](#training-process)
|
||||||
|
- [Training](#training)
|
||||||
|
- [Evaluation Process](#evaluation-process)
|
||||||
|
- [Evaluation](#evaluation)
|
||||||
|
- [Model Description](#model-description)
|
||||||
|
- [Performance](#performance)
|
||||||
|
- [Description of random situation](#description-of-random-situation)
|
||||||
|
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||||
|
|
||||||
|
# [EPP-MVSNet](#contents)
|
||||||
|
|
||||||
|
EPP-MVSNet was proposed in 2021 by Parallel Distributed Computing Lab & Huawei Riemann Lab. By aggregating features at high resolution to a
|
||||||
|
limited cost volume with an optimal depth range, thus, EPP-MVSNet leads to effective and efficient 3D construction. Moreover, EPP-MVSNet achieved
|
||||||
|
highest F-Score on the online TNT intermediate benchmark. This is a example of evaluation of EPP-MVSNet with BlendedMVS dataset in MindSpore. More
|
||||||
|
importantly, this is the first open source version for EPP-MVSNet.
|
||||||
|
|
||||||
|
# [Model Architecture](#contents)
|
||||||
|
|
||||||
|
Specially, EPP-MVSNet contains two main modules. The first part is feature extraction network, which extracts 2D features of a group of pictures(one reference
|
||||||
|
view and some source views) iteratively. The second part which contains three stages, iteratively regularizes the 3D cost volume composed of 2D features
|
||||||
|
using homography transformation, and finally predicts depth map.
|
||||||
|
|
||||||
|
# [Dataset](#contents)
|
||||||
|
|
||||||
|
The dataset used in this example is BlendedMVS, which is a large-scale MVS dataset for generalized multi-view stereo networks. The dataset contains
|
||||||
|
17k MVS training samples covering a variety of 113 scenes, including architectures, sculptures and small objects.
|
||||||
|
|
||||||
|
# [Features](#contents)
|
||||||
|
|
||||||
|
## [3D Feature](#contents)
|
||||||
|
|
||||||
|
This implementation version of EPP-MVSNet utilizes the newest 3D features of MindSpore.
|
||||||
|
|
||||||
|
# [Environment Requirements](#contents)
|
||||||
|
|
||||||
|
- Hardware (GPU)
|
||||||
|
- Framework
|
||||||
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
|
- For more information, please check the resources below:
|
||||||
|
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||||
|
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||||
|
|
||||||
|
# [Quick Start](#contents)
|
||||||
|
|
||||||
|
After installing MindSpore via the official website and Dataset is correctly generated, you can start training and evaluation as follows.
|
||||||
|
|
||||||
|
- running on GPU
|
||||||
|
|
||||||
|
```python
|
||||||
|
# run evaluation example with BlendedMVS dataset
|
||||||
|
sh eval.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
# [Script Description](#contents)
|
||||||
|
|
||||||
|
## [Script and Sample Code](#contents)
|
||||||
|
|
||||||
|
```shell
|
||||||
|
.
|
||||||
|
└─eppmvsnet
|
||||||
|
├─README.md
|
||||||
|
├─scripts
|
||||||
|
| └─run_eval.sh # Launch evaluation in gpu
|
||||||
|
|
|
||||||
|
├─src
|
||||||
|
| ├─blendedmvs.py # build blendedmvs data
|
||||||
|
| ├─eppmvsnet.py # main architecture of EPP-MVSNet
|
||||||
|
| ├─modules.py # math operations used in EPP-MVSNet
|
||||||
|
| ├─networks.py # sub-networks of EPP-MVSNet
|
||||||
|
| └─utils.py # other operations used for evaluation
|
||||||
|
|
|
||||||
|
├─validate.py # Evaluation process on blendedmvs
|
||||||
|
```
|
||||||
|
|
||||||
|
## [Script Parameters](#contents)
|
||||||
|
|
||||||
|
Parameters for EPP-MVSNet evaluation can be set in validate.py.
|
||||||
|
|
||||||
|
- config for EPP-MVSNet
|
||||||
|
|
||||||
|
```python
|
||||||
|
"n_views": 5, # Num of views used in a depth prediction
|
||||||
|
"depth_interval": 128, # Init depth numbers
|
||||||
|
"n_depths": [32, 16, 8], # Depth numbers of three stages
|
||||||
|
"interval_ratios": [4.0, 2.0, 1.0], # Depth interval's expanding ratios of three stages
|
||||||
|
"img_wh": [768, 512], # Image resolution of evaluation
|
||||||
|
```
|
||||||
|
|
||||||
|
validate.py for more configuration.
|
||||||
|
|
||||||
|
## [Evaluation Process](#contents)
|
||||||
|
|
||||||
|
### Evaluation
|
||||||
|
|
||||||
|
- EPP-MVSNet evaluation on GPU
|
||||||
|
|
||||||
|
```python
|
||||||
|
sh eval.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
Evaluation result will be stored in "./results/blendedmvs/val/metrics.txt". You can find the result like the
|
||||||
|
followings in log.
|
||||||
|
|
||||||
|
```python
|
||||||
|
stage3_l1_loss:1.1738
|
||||||
|
stage3_less1_acc:0.8734
|
||||||
|
stage3_less3_acc:0.938
|
||||||
|
mean forward time(s/pic):0.2697
|
||||||
|
```
|
||||||
|
|
||||||
|
# [Model Description](#contents)
|
||||||
|
|
||||||
|
## [Performance](#contents)
|
||||||
|
|
||||||
|
### Inference Performance
|
||||||
|
|
||||||
|
| Parameter | EPP-MVSNet GPU |
|
||||||
|
| ------------------------------ | ---------------------------- |
|
||||||
|
| Model Version | Inception V1 |
|
||||||
|
| Resource | Tesla V100 16GB; Ubuntu16.04 |
|
||||||
|
| uploaded Date | 06/22/2021(month/day/year) |
|
||||||
|
| MindSpore Version | 1.3.0 |
|
||||||
|
| Dataset | BlendedMVS |
|
||||||
|
| Batch_size | 1 |
|
||||||
|
| Output | ./results/blendedmvs/val |
|
||||||
|
| Acc_less_1mm | 0.8734 |
|
||||||
|
| Acc_less_3mm | 0.938 |
|
||||||
|
| mean_time(s/pic) | 0.2697 |
|
||||||
|
|
||||||
|
# [Description of random situation](#contents)
|
||||||
|
|
||||||
|
No random situation for evaluation.
|
||||||
|
|
||||||
|
# [ModelZoo Homepage](#contents)
|
||||||
|
|
||||||
|
Please check the official [homepage](http://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,24 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# eval script
|
||||||
|
SRC_NUM=4
|
||||||
|
VIEW_NUM=$[${SRC_NUM}+1]
|
||||||
|
DATAPATH="./data/blendedmvs/dataset_low_res"
|
||||||
|
|
||||||
|
python -u validate.py --root_dir ${DATAPATH} --dataset_name blendedmvs --save_visual --img_wh 768 576 --n_views ${VIEW_NUM} --n_depths 32 16 8 --interval_ratios 4.0 2.0 1.0 --levels 3 --split val --gpu_id 0 > log.txt 2>&1 &
|
||||||
|
|
||||||
|
cd ..
|
|
@ -0,0 +1,276 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""blendedmvs dataset"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import mindspore.dataset.vision.py_transforms as py_vision
|
||||||
|
|
||||||
|
from src.utils import read_pfm
|
||||||
|
|
||||||
|
|
||||||
|
class Compose:
|
||||||
|
"""Composes several transforms together.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transforms (list of ``Transform`` objects): list of transforms to compose.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> transforms.Compose([
|
||||||
|
>>> transforms.CenterCrop(10),
|
||||||
|
>>> transforms.ToTensor(),
|
||||||
|
>>> ])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, transforms):
|
||||||
|
self.transforms = transforms
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
for t in self.transforms:
|
||||||
|
img = t(img)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
format_string = self.__class__.__name__ + '('
|
||||||
|
for t in self.transforms:
|
||||||
|
format_string += '\n'
|
||||||
|
format_string += ' {0}'.format(t)
|
||||||
|
format_string += '\n)'
|
||||||
|
return format_string
|
||||||
|
|
||||||
|
|
||||||
|
class BlendedMVSDataset:
|
||||||
|
"""blendedmvs dataset"""
|
||||||
|
|
||||||
|
def __init__(self, root_dir, split, n_views=3, levels=3, depth_interval=128.0, img_wh=(768, 576),
|
||||||
|
crop_wh=(640, 512), scale=False, scan=None, training_tag=False):
|
||||||
|
"""
|
||||||
|
img_wh should be set to a tuple ex: (1152, 864) to enable test mode!
|
||||||
|
"""
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.split = split
|
||||||
|
self.scale = scale
|
||||||
|
self.training_tag = training_tag
|
||||||
|
assert self.split in ['train', 'val', 'all'], \
|
||||||
|
'split must be either "train", "val" or "all"!'
|
||||||
|
self.img_wh = img_wh
|
||||||
|
if img_wh is not None:
|
||||||
|
assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
|
||||||
|
'img_wh must both be multiples of 32!'
|
||||||
|
self.single_scan = scan
|
||||||
|
self.crop_wh = crop_wh
|
||||||
|
if crop_wh is not None:
|
||||||
|
assert crop_wh[0] % 32 == 0 and crop_wh[1] % 32 == 0, \
|
||||||
|
'img_wh must both be multiples of 32!'
|
||||||
|
self.n_views = n_views
|
||||||
|
self.levels = levels # FPN levels
|
||||||
|
self.n_depths = depth_interval
|
||||||
|
|
||||||
|
self.build_metas()
|
||||||
|
self.cal_crop_factors()
|
||||||
|
self.build_proj_mats()
|
||||||
|
self.define_transforms()
|
||||||
|
|
||||||
|
def cal_crop_factors(self):
|
||||||
|
""""calculate crop factors"""
|
||||||
|
self.start_w = (self.img_wh[0] - self.crop_wh[0]) // 2
|
||||||
|
self.start_h = (self.img_wh[1] - self.crop_wh[1]) // 2
|
||||||
|
self.finish_w = self.start_w + self.crop_wh[0]
|
||||||
|
self.finish_h = self.start_h + self.crop_wh[1]
|
||||||
|
|
||||||
|
def build_metas(self):
|
||||||
|
""""build meta information"""
|
||||||
|
self.metas = []
|
||||||
|
self.ref_views_per_scan = defaultdict(list)
|
||||||
|
if self.split == 'train':
|
||||||
|
list_txt = os.path.join(self.root_dir, 'training_list.txt')
|
||||||
|
elif self.split == 'val':
|
||||||
|
list_txt = os.path.join(self.root_dir, 'validation_list.txt')
|
||||||
|
else:
|
||||||
|
list_txt = os.path.join(self.root_dir, 'all_list.txt')
|
||||||
|
|
||||||
|
if self.single_scan is not None:
|
||||||
|
self.scans = self.single_scan if isinstance(self.single_scan, list) else [self.single_scan]
|
||||||
|
else:
|
||||||
|
with open(list_txt) as f:
|
||||||
|
self.scans = [line.rstrip() for line in f.readlines()]
|
||||||
|
|
||||||
|
for scan in self.scans:
|
||||||
|
with open(os.path.join(self.root_dir, scan, "cams/pair.txt")) as f:
|
||||||
|
num_viewpoint = int(f.readline())
|
||||||
|
for _ in range(num_viewpoint):
|
||||||
|
ref_view = int(f.readline().rstrip())
|
||||||
|
self.ref_views_per_scan[scan] += [ref_view]
|
||||||
|
line = f.readline().rstrip().split()
|
||||||
|
n_views_valid = int(line[0]) # valid views
|
||||||
|
if n_views_valid < self.n_views: # skip no enough valid views
|
||||||
|
continue
|
||||||
|
src_views = [int(x) for x in line[1::2]]
|
||||||
|
self.metas += [(scan, -1, ref_view, src_views)]
|
||||||
|
|
||||||
|
def build_proj_mats(self):
|
||||||
|
""""build projection matrix"""
|
||||||
|
self.proj_mats = {} # proj mats for each scan
|
||||||
|
if self.root_dir.endswith('dataset_low_res') \
|
||||||
|
or self.root_dir.endswith('dataset_low_res/'):
|
||||||
|
img_w, img_h = 768, 576
|
||||||
|
else:
|
||||||
|
img_w, img_h = 2048, 1536
|
||||||
|
for scan in self.scans:
|
||||||
|
self.proj_mats[scan] = {}
|
||||||
|
for vid in self.ref_views_per_scan[scan]:
|
||||||
|
proj_mat_filename = os.path.join(self.root_dir, scan,
|
||||||
|
f'cams/{vid:08d}_cam.txt')
|
||||||
|
intrinsics, extrinsics, depth_min, depth_max = \
|
||||||
|
self.read_cam_file(scan, proj_mat_filename)
|
||||||
|
intrinsics[0] *= self.img_wh[0] / img_w / 8
|
||||||
|
intrinsics[1] *= self.img_wh[1] / img_h / 8
|
||||||
|
# center crop
|
||||||
|
if self.training_tag:
|
||||||
|
intrinsics[0, 2] = intrinsics[0, 2] - self.start_w / 8
|
||||||
|
intrinsics[1, 2] = intrinsics[1, 2] - self.start_h / 8
|
||||||
|
|
||||||
|
# multiply intrinsics and extrinsics to get projection matrix
|
||||||
|
proj_mat_ls = []
|
||||||
|
for _ in reversed(range(self.levels)):
|
||||||
|
proj_mat_l = np.eye(4)
|
||||||
|
proj_mat_l[:3, :4] = intrinsics @ extrinsics[:3, :4]
|
||||||
|
intrinsics[:2] *= 2 # 1/8->1/4->1/2
|
||||||
|
proj_mat_ls += [proj_mat_l]
|
||||||
|
proj_mat_ls = np.stack(proj_mat_ls[::-1]).astype(dtype=np.float32)
|
||||||
|
self.proj_mats[scan][vid] = (proj_mat_ls, depth_min, depth_max)
|
||||||
|
|
||||||
|
def read_cam_file(self, scan, filename):
|
||||||
|
""""read camera file"""
|
||||||
|
with open(filename) as f:
|
||||||
|
lines = [line.rstrip() for line in f.readlines()]
|
||||||
|
# extrinsics: line [1,5), 4x4 matrix
|
||||||
|
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
|
||||||
|
extrinsics = extrinsics.reshape((4, 4))
|
||||||
|
# intrinsics: line [7-10), 3x3 matrix
|
||||||
|
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
|
||||||
|
intrinsics = intrinsics.reshape((3, 3))
|
||||||
|
# depth_min & depth_interval: line 11
|
||||||
|
depth_min = float(lines[11].split()[0])
|
||||||
|
depth_max = float(lines[11].split()[3])
|
||||||
|
return intrinsics, extrinsics, depth_min, depth_max
|
||||||
|
|
||||||
|
def read_depth_and_mask(self, scan, filename, depth_min):
|
||||||
|
""""read depth and mask"""
|
||||||
|
depth = np.array(read_pfm(filename)[0], dtype=np.float32)
|
||||||
|
h, w = depth.shape
|
||||||
|
if (h, w) != self.img_wh:
|
||||||
|
depth_0 = cv2.resize(depth, self.img_wh,
|
||||||
|
interpolation=cv2.INTER_LINEAR)
|
||||||
|
if self.training_tag:
|
||||||
|
depth_0 = depth_0[self.start_h:self.finish_h, self.start_w:self.finish_w]
|
||||||
|
depth_0 = cv2.resize(depth_0, None, fx=0.5, fy=0.5,
|
||||||
|
interpolation=cv2.INTER_LINEAR)
|
||||||
|
depth_1 = cv2.resize(depth_0, None, fx=0.5, fy=0.5,
|
||||||
|
interpolation=cv2.INTER_LINEAR)
|
||||||
|
depth_2 = cv2.resize(depth_1, None, fx=0.5, fy=0.5,
|
||||||
|
interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
depths = {"level_0": depth_0,
|
||||||
|
"level_1": depth_1,
|
||||||
|
"level_2": depth_2}
|
||||||
|
|
||||||
|
masks = {"level_0": depth_0 > depth_min,
|
||||||
|
"level_1": depth_1 > depth_min,
|
||||||
|
"level_2": depth_2 > depth_min}
|
||||||
|
depth_max = depth_0.max()
|
||||||
|
return depths, masks, depth_max
|
||||||
|
|
||||||
|
def define_transforms(self):
|
||||||
|
if self.training_tag and self.split == 'train': # you can add augmentation here
|
||||||
|
self.transform = Compose([
|
||||||
|
py_vision.ToTensor(),
|
||||||
|
py_vision.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.transform = Compose([
|
||||||
|
py_vision.ToTensor(),
|
||||||
|
py_vision.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.metas)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
sample = {}
|
||||||
|
scan, _, ref_view, src_views = self.metas[idx]
|
||||||
|
# use only the reference view and first nviews-1 source views
|
||||||
|
view_ids = [ref_view] + src_views[:self.n_views - 1]
|
||||||
|
|
||||||
|
imgs = []
|
||||||
|
proj_mats = [] # record proj mats between views
|
||||||
|
for i, vid in enumerate(view_ids):
|
||||||
|
img_filename = os.path.join(self.root_dir, f'{scan}/blended_images/{vid:08d}.jpg')
|
||||||
|
depth_filename = os.path.join(self.root_dir, f'{scan}/rendered_depth_maps/{vid:08d}.pfm')
|
||||||
|
|
||||||
|
img = Image.open(img_filename)
|
||||||
|
w, h = img.size
|
||||||
|
if (h, w) != self.img_wh:
|
||||||
|
img = img.resize(self.img_wh, Image.BILINEAR)
|
||||||
|
if self.training_tag:
|
||||||
|
img = img.crop((self.start_w, self.start_h, self.finish_w, self.finish_h))
|
||||||
|
img = self.transform(img)
|
||||||
|
imgs += [img]
|
||||||
|
|
||||||
|
proj_mat_ls, depth_min, depth_max = deepcopy(self.proj_mats[scan][vid])
|
||||||
|
|
||||||
|
if i == 0: # reference view
|
||||||
|
if self.split == 'train':
|
||||||
|
depths, masks, depth_max = self.read_depth_and_mask(scan, depth_filename, depth_min)
|
||||||
|
elif self.split == 'val':
|
||||||
|
if self.training_tag:
|
||||||
|
depths, masks, depth_max = self.read_depth_and_mask(scan, depth_filename, depth_min)
|
||||||
|
else:
|
||||||
|
depths, masks, _ = self.read_depth_and_mask(scan, depth_filename, depth_min)
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
fix_depth_interval = (depth_max - depth_min) / self.n_depths
|
||||||
|
depth_interval = fix_depth_interval
|
||||||
|
sample['init_depth_min'] = [depth_min]
|
||||||
|
sample['depth_interval'] = [depth_interval]
|
||||||
|
sample['fix_depth_interval'] = [fix_depth_interval]
|
||||||
|
ref_proj_inv = np.asarray(proj_mat_ls)
|
||||||
|
for j in range(proj_mat_ls.shape[0]):
|
||||||
|
ref_proj_inv[j] = np.mat(proj_mat_ls[j]).I
|
||||||
|
else:
|
||||||
|
proj_mats += [proj_mat_ls @ ref_proj_inv]
|
||||||
|
|
||||||
|
imgs = np.stack(imgs)
|
||||||
|
proj_mats = np.stack(proj_mats)[:, :, :3] # (V-1, self.levels, 3, 4) from fine to coarse
|
||||||
|
depth_0 = depths['level_0']
|
||||||
|
mask_0 = masks["level_0"]
|
||||||
|
|
||||||
|
sample['imgs'] = imgs
|
||||||
|
sample['proj_mats'] = proj_mats
|
||||||
|
sample['depths'] = depths
|
||||||
|
sample['masks'] = masks
|
||||||
|
sample['scan_vid'] = (scan, ref_view)
|
||||||
|
|
||||||
|
return imgs, proj_mats, np.array(sample['init_depth_min'], dtype=np.float32), \
|
||||||
|
np.array(sample['depth_interval'], dtype=np.float32), np.fromstring(scan, dtype=np.uint8), \
|
||||||
|
np.array(ref_view), depth_0, mask_0, np.array(sample['fix_depth_interval'], dtype=np.float32)
|
|
@ -0,0 +1,501 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""main architecture of EPP-MVSNet"""
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common as mstype
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from src.modules import get_depth_values, determine_center_pixel_interval, groupwise_correlation, entropy_num_based, \
|
||||||
|
HomoWarp
|
||||||
|
from src.networks import UNet2D, CostCompression, CoarseStageRegPair, CoarseStageRegFuse, StageRegFuse
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStage(nn.Cell):
|
||||||
|
"""single stage"""
|
||||||
|
|
||||||
|
def __init__(self, fuse_reg, height, width, entropy_range=False):
|
||||||
|
super(SingleStage, self).__init__()
|
||||||
|
self.fuse_reg = fuse_reg
|
||||||
|
self.entropy_range = entropy_range
|
||||||
|
|
||||||
|
self.transpose = P.Transpose()
|
||||||
|
self.zeros = P.Zeros()
|
||||||
|
self.linspace = P.LinSpace()
|
||||||
|
self.exp = P.Exp()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.squeeze_1 = P.Squeeze(1)
|
||||||
|
self.pow = P.Pow()
|
||||||
|
self.tile = P.Tile()
|
||||||
|
self.homo_warp = HomoWarp(height, width)
|
||||||
|
|
||||||
|
def construct(self, sample, depth_num, depth_start_override=None, depth_interval_override=None,
|
||||||
|
uncertainty_maps=None):
|
||||||
|
"""construct function of single stage"""
|
||||||
|
ref_feat, src_feats, proj_mats = sample
|
||||||
|
depth_start = depth_start_override # n111 or n1hw
|
||||||
|
depth_interval = depth_interval_override # n111
|
||||||
|
|
||||||
|
D = depth_num
|
||||||
|
B, C, H, W = ref_feat.shape
|
||||||
|
|
||||||
|
depth_interval = depth_interval.view(B, 1, 1, 1)
|
||||||
|
interim_scale = 1
|
||||||
|
|
||||||
|
ref_ncdhw = self.expand_dims(ref_feat, 2).view(B, C, 1, -1)
|
||||||
|
ref_ncdhw = self.tile(ref_ncdhw, (1, 1, D, 1)).view(B, C, D, H, W)
|
||||||
|
|
||||||
|
pair_results = [] # MVS
|
||||||
|
|
||||||
|
weight_sum = self.zeros((ref_ncdhw.shape[0], 1, 1, ref_ncdhw.shape[3] // interim_scale,
|
||||||
|
ref_ncdhw.shape[4] // interim_scale), mstype.float32)
|
||||||
|
fused_interim = self.zeros((ref_ncdhw.shape[0], 8, ref_ncdhw.shape[2] // interim_scale, ref_ncdhw.shape[3] //
|
||||||
|
interim_scale, ref_ncdhw.shape[4] // interim_scale), mstype.float32)
|
||||||
|
|
||||||
|
depth_values = get_depth_values(depth_start, D, depth_interval, False)
|
||||||
|
|
||||||
|
for i in range(src_feats.shape[1]):
|
||||||
|
src_feat = src_feats[:, i]
|
||||||
|
proj_mat = proj_mats[:, i]
|
||||||
|
uncertainty_map = uncertainty_maps[i]
|
||||||
|
warped_src = self.homo_warp(src_feat, proj_mat, depth_values)
|
||||||
|
cost_volume = groupwise_correlation(ref_ncdhw, warped_src, 8, 1)
|
||||||
|
|
||||||
|
interim = cost_volume
|
||||||
|
heads = [uncertainty_map]
|
||||||
|
|
||||||
|
weight = self.expand_dims(self.exp(-heads[0]), 2)
|
||||||
|
weight_sum = weight_sum + weight
|
||||||
|
fused_interim = fused_interim + interim * weight
|
||||||
|
|
||||||
|
fused_interim /= weight_sum
|
||||||
|
est_depth, prob_map, prob_volume = self.fuse_reg(fused_interim, depth_values)
|
||||||
|
|
||||||
|
if self.entropy_range:
|
||||||
|
# mean entropy
|
||||||
|
entropy_num = entropy_num_based(prob_volume, dim=1, depth_num=D, keepdim=True)
|
||||||
|
conf_range = self.pow(D, entropy_num).view(B, -1).mean()
|
||||||
|
return est_depth, prob_map, pair_results, conf_range # MVS
|
||||||
|
return est_depth, prob_map, pair_results # MVS
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStageP1(nn.Cell):
|
||||||
|
"""part1 of single stage 1"""
|
||||||
|
|
||||||
|
def __init__(self, depth_number=32):
|
||||||
|
super(SingleStageP1, self).__init__()
|
||||||
|
self.transpose = P.Transpose()
|
||||||
|
self.zeros = P.Zeros()
|
||||||
|
self.linspace = P.LinSpace()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.squeeze_1 = P.Squeeze(1)
|
||||||
|
self.squeeze_last = P.Squeeze(-1)
|
||||||
|
self.tile = P.Tile()
|
||||||
|
|
||||||
|
self.zero = Tensor(0, mstype.float32)
|
||||||
|
self.compression_ratio = 1
|
||||||
|
self.depth_number = depth_number - self.compression_ratio + 1
|
||||||
|
self.D = Tensor(self.depth_number - 1, mstype.float32)
|
||||||
|
|
||||||
|
def construct(self, sample, depth_num, depth_start_override=None, depth_interval_override=None, timing=True):
|
||||||
|
"""construct function of part1 of single stage 1"""
|
||||||
|
ref_feat, _, _ = sample
|
||||||
|
depth_start = depth_start_override # n111 or n1hw
|
||||||
|
depth_interval = depth_interval_override # n111
|
||||||
|
|
||||||
|
depth_num *= self.compression_ratio
|
||||||
|
depth_interval /= self.compression_ratio
|
||||||
|
|
||||||
|
D = depth_num
|
||||||
|
B = ref_feat.shape[0]
|
||||||
|
|
||||||
|
depth_interval = depth_interval.view(B, 1, 1, 1)
|
||||||
|
depth_start = depth_start.reshape(B, 1, 1, 1)
|
||||||
|
|
||||||
|
h, w = ref_feat.shape[-2:]
|
||||||
|
depth_end = depth_start + (D - 1) * depth_interval
|
||||||
|
inverse_depth_interval = (1 / depth_start - 1 / depth_end) / (D - self.compression_ratio)
|
||||||
|
depth_values = 1 / depth_end + inverse_depth_interval * \
|
||||||
|
self.linspace(self.zero, self.D, self.depth_number) # (D)
|
||||||
|
depth_values = 1.0 / depth_values
|
||||||
|
single_depth_value = depth_values.reshape(B, self.depth_number, 1, 1)
|
||||||
|
depth_values = self.tile(single_depth_value, (1, 1, h, w))
|
||||||
|
single_depth_value = single_depth_value[:, ::self.compression_ratio, :, :]
|
||||||
|
|
||||||
|
return depth_values, single_depth_value
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStageP3(nn.Cell):
|
||||||
|
"""part3 of single stage 1"""
|
||||||
|
|
||||||
|
def __init__(self, entropy_range=False, compression_ratio=5):
|
||||||
|
super(SingleStageP3, self).__init__()
|
||||||
|
self.pair_reg = CoarseStageRegPair()
|
||||||
|
self.fuse_reg = CoarseStageRegFuse()
|
||||||
|
self.entropy_range = entropy_range
|
||||||
|
self.compression_ratio = compression_ratio
|
||||||
|
|
||||||
|
self.transpose = P.Transpose()
|
||||||
|
self.zeros = P.Zeros()
|
||||||
|
self.linspace = P.LinSpace()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.squeeze_1 = P.Squeeze(1)
|
||||||
|
self.squeeze_last = P.Squeeze(-1)
|
||||||
|
self.tile = P.Tile()
|
||||||
|
self.exp = P.Exp()
|
||||||
|
self.pow = P.Pow()
|
||||||
|
|
||||||
|
def construct(self, cost_volume_list, depth_values, sample, depth_num):
|
||||||
|
"""construct function"""
|
||||||
|
ref_feat, _, _ = sample
|
||||||
|
|
||||||
|
B, _, _, _ = ref_feat.shape
|
||||||
|
|
||||||
|
d_scale = 1
|
||||||
|
interim_scale = 1
|
||||||
|
|
||||||
|
ref_ncdhw = self.expand_dims(ref_feat, 2)
|
||||||
|
pair_results = [] # MVS
|
||||||
|
|
||||||
|
weight_sum = self.zeros((ref_ncdhw.shape[0], 1, 1, ref_ncdhw.shape[3] // interim_scale,
|
||||||
|
ref_ncdhw.shape[4] // interim_scale), mstype.float32)
|
||||||
|
fused_interim = self.zeros((ref_ncdhw.shape[0], 8, depth_num // d_scale,
|
||||||
|
ref_ncdhw.shape[3] // interim_scale, ref_ncdhw.shape[4] // interim_scale),
|
||||||
|
mstype.float32)
|
||||||
|
|
||||||
|
for i in range(cost_volume_list.shape[1]):
|
||||||
|
cost_volume = cost_volume_list[:, i]
|
||||||
|
|
||||||
|
interim, est_depth, uncertainty_map, occ = self.pair_reg(cost_volume, depth_values)
|
||||||
|
pair_results.append([est_depth, [uncertainty_map, occ]])
|
||||||
|
|
||||||
|
weight = self.expand_dims(self.exp(-uncertainty_map), 2)
|
||||||
|
weight_sum = weight_sum + weight
|
||||||
|
fused_interim = fused_interim + interim * weight
|
||||||
|
fused_interim /= weight_sum
|
||||||
|
est_depth, prob_map, prob_volume = self.fuse_reg(fused_interim, depth_values)
|
||||||
|
if self.entropy_range:
|
||||||
|
# mean entropy
|
||||||
|
entropy_num = entropy_num_based(prob_volume, dim=1, depth_num=depth_num, keepdim=True)
|
||||||
|
conf_range = self.pow(depth_num, entropy_num).view(B, -1).mean()
|
||||||
|
return est_depth, prob_map, pair_results, conf_range # MVS
|
||||||
|
return est_depth, prob_map, pair_results # MVS
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStageP2_S1(nn.Cell):
|
||||||
|
"""0 interpolation, part2 of single stage 1"""
|
||||||
|
|
||||||
|
def __init__(self, cost_compression, height=32, width=40):
|
||||||
|
super(SingleStageP2_S1, self).__init__()
|
||||||
|
self.cost_compression = cost_compression
|
||||||
|
|
||||||
|
self.transpose = P.Transpose()
|
||||||
|
self.zeros = P.Zeros()
|
||||||
|
self.linspace = P.LinSpace()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.squeeze_1 = P.Squeeze(1)
|
||||||
|
self.squeeze_last = P.Squeeze(-1)
|
||||||
|
self.tile = P.Tile()
|
||||||
|
self.homo_warp = HomoWarp(height, width)
|
||||||
|
self.stack = P.Stack(1)
|
||||||
|
|
||||||
|
def construct(self, sample, depth_num, depth_start_override=None, depth_interval_override=None, depth_values=None,
|
||||||
|
idx=None):
|
||||||
|
"""construct function"""
|
||||||
|
ref_feat, src_feats, proj_mats = sample
|
||||||
|
|
||||||
|
compression_ratio = 1
|
||||||
|
depth_num *= compression_ratio
|
||||||
|
|
||||||
|
B, C, H, W = ref_feat.shape
|
||||||
|
|
||||||
|
ref_ncdhw = self.expand_dims(ref_feat, 2).view(B, C, 1, -1)
|
||||||
|
ref_ncdhw = self.tile(ref_ncdhw, (1, 1, 32, 1)).view(B, C, 32, H, W)
|
||||||
|
|
||||||
|
src_feat = src_feats[:, idx]
|
||||||
|
proj_mat = proj_mats[:, idx]
|
||||||
|
src_depth_values = depth_values
|
||||||
|
|
||||||
|
warped_src = self.homo_warp(src_feat, proj_mat, src_depth_values)
|
||||||
|
cost_volume = groupwise_correlation(ref_ncdhw, warped_src, 8, 1)
|
||||||
|
# dynamic max pool
|
||||||
|
cost_volume = self.cost_compression(cost_volume)
|
||||||
|
return cost_volume
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStageP2_S3(nn.Cell):
|
||||||
|
"""2 interpolation, part2 of single stage 1"""
|
||||||
|
|
||||||
|
def __init__(self, cost_compression, height=64, width=80, depth_number=96, depth_ratio=3, compression_ratio=5):
|
||||||
|
super(SingleStageP2_S3, self).__init__()
|
||||||
|
self.cost_compression = cost_compression
|
||||||
|
|
||||||
|
self.transpose = P.Transpose()
|
||||||
|
self.zeros = P.Zeros()
|
||||||
|
self.linspace = P.LinSpace()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.squeeze_1 = P.Squeeze(1)
|
||||||
|
self.squeeze_last = P.Squeeze(-1)
|
||||||
|
self.tile = P.Tile()
|
||||||
|
self.linspace = P.LinSpace()
|
||||||
|
self.cat1 = P.Concat(axis=1)
|
||||||
|
self.stack = P.Stack(1)
|
||||||
|
|
||||||
|
self.homo_warp = HomoWarp(height, width)
|
||||||
|
self.max_pool = P.MaxPool3D(kernel_size=(3, 1, 1), strides=(3, 1, 1))
|
||||||
|
|
||||||
|
self.zero = Tensor(0, mstype.float32)
|
||||||
|
self.src_compression_ratio = depth_ratio
|
||||||
|
self.depth_number = depth_number - depth_ratio + 1
|
||||||
|
self.src_D = Tensor(self.depth_number, mstype.float32)
|
||||||
|
self.compression_ratio = compression_ratio
|
||||||
|
|
||||||
|
def construct(self, sample, depth_num, depth_start_override=None, depth_interval_override=None, depth_values=None,
|
||||||
|
idx=None):
|
||||||
|
"""construct function"""
|
||||||
|
ref_feat, src_feats, proj_mats = sample
|
||||||
|
depth_start = depth_start_override # n111 or n1hw
|
||||||
|
depth_interval = depth_interval_override # n111
|
||||||
|
|
||||||
|
depth_num *= self.compression_ratio
|
||||||
|
depth_interval /= self.compression_ratio
|
||||||
|
|
||||||
|
D = depth_num
|
||||||
|
B, C, H, W = ref_feat.shape
|
||||||
|
|
||||||
|
depth_interval = depth_interval.view(B, 1, 1, 1)
|
||||||
|
depth_start = depth_start.reshape(B, 1, 1, 1)
|
||||||
|
depth_end = depth_start + (D - 1) * depth_interval
|
||||||
|
|
||||||
|
ref_ncdhw = self.expand_dims(ref_feat, 2).view(B, C, 1, -1)
|
||||||
|
ref_ncdhw = self.tile(ref_ncdhw, (1, 1, 96, 1)).view(B, C, 96, H, W)
|
||||||
|
|
||||||
|
src_feat = src_feats[:, idx]
|
||||||
|
proj_mat = proj_mats[:, idx]
|
||||||
|
src_D = D // self.compression_ratio * self.src_compression_ratio
|
||||||
|
src_inverse_depth_interval = (1 / depth_start - 1 / depth_end) / (src_D - self.src_compression_ratio)
|
||||||
|
src_depth_values = 1 / depth_end + src_inverse_depth_interval * \
|
||||||
|
self.linspace(self.zero, self.src_D, self.depth_number) # (D)
|
||||||
|
src_depth_values = 1.0 / src_depth_values
|
||||||
|
|
||||||
|
src_depth_values = src_depth_values.view(B, -1, 1, 1)
|
||||||
|
|
||||||
|
end_interval = self.expand_dims(src_depth_values[:, 1, :, :], 1) - self.expand_dims(
|
||||||
|
src_depth_values[:, 0, :, :], 1)
|
||||||
|
end_interpolation = self.expand_dims(src_depth_values[:, 0, :, :], 1) - end_interval
|
||||||
|
start_interval = self.expand_dims(src_depth_values[:, -2, :, :], 1) - self.expand_dims(
|
||||||
|
src_depth_values[:, -1, :, :], 1)
|
||||||
|
start_interpolation = self.expand_dims(src_depth_values[:, -1, :, :], 1) - start_interval
|
||||||
|
|
||||||
|
src_depth_values = self.cat1((end_interpolation, src_depth_values, start_interpolation))
|
||||||
|
src_depth_values = self.tile(src_depth_values, (1, 1, H, W))
|
||||||
|
warped_src = self.homo_warp(src_feat, proj_mat, src_depth_values)
|
||||||
|
cost_volume = groupwise_correlation(ref_ncdhw, warped_src, 8, 1)
|
||||||
|
|
||||||
|
# dynamic max pool
|
||||||
|
cost_volume = self.cost_compression(cost_volume)
|
||||||
|
cost_volume = self.max_pool(cost_volume)
|
||||||
|
return cost_volume
|
||||||
|
|
||||||
|
|
||||||
|
class EPPMVSNet(nn.Cell):
|
||||||
|
"""EPP-MVSNet"""
|
||||||
|
|
||||||
|
def __init__(self, n_depths, interval_ratios, entropy_range=False, shrink_ratio=1,
|
||||||
|
height=None, width=None, distance=0.5):
|
||||||
|
super(EPPMVSNet, self).__init__()
|
||||||
|
|
||||||
|
self.feat_ext = UNet2D()
|
||||||
|
self.n_depths = n_depths
|
||||||
|
self.interval_ratios = interval_ratios
|
||||||
|
self.entropy_range = entropy_range
|
||||||
|
# hyper parameter used in entropy-based adjustment
|
||||||
|
self.shrink_ratio = shrink_ratio
|
||||||
|
self.distance = distance
|
||||||
|
|
||||||
|
self.stage1_p1 = SingleStageP1()
|
||||||
|
self.stage1_p3 = SingleStageP3(entropy_range=self.entropy_range)
|
||||||
|
|
||||||
|
self.cost_compression = CostCompression()
|
||||||
|
|
||||||
|
self.stage1_p2_s1 = SingleStageP2_S1(self.cost_compression, height=height // 8, width=width // 8)
|
||||||
|
self.stage1_p2_s3 = SingleStageP2_S3(self.cost_compression, height=height // 8, width=width // 8)
|
||||||
|
|
||||||
|
self.fuse_reg_2 = StageRegFuse("./ckpts/stage2_reg_fuse.ckpt")
|
||||||
|
self.fuse_reg_3 = StageRegFuse("./ckpts/stage3_reg_fuse.ckpt")
|
||||||
|
|
||||||
|
self.stage2 = SingleStage(self.fuse_reg_2, height // 4, width // 4, entropy_range=self.entropy_range)
|
||||||
|
self.stage3 = SingleStage(self.fuse_reg_3, height // 2, width // 2, entropy_range=self.entropy_range)
|
||||||
|
|
||||||
|
self.eppmvsnet_p1 = EPPMVSNetP1(self.feat_ext, self.stage1_p1, n_depths=n_depths,
|
||||||
|
interval_ratios=interval_ratios)
|
||||||
|
self.eppmvsnet_p3 = EPPMVSNetP3(self.stage1_p3, self.stage2, self.stage3, n_depths=n_depths,
|
||||||
|
interval_ratios=interval_ratios, height=height, width=width,
|
||||||
|
entropy_range=self.entropy_range, shrink_ratio=self.shrink_ratio)
|
||||||
|
|
||||||
|
def construct(self, imgs, proj_mats=None, depth_start=None, depth_interval=None):
|
||||||
|
"""construct function"""
|
||||||
|
feat_pack_1, feat_pack_2, feat_pack_3, depth_values_stage1, pixel_distances \
|
||||||
|
= self.eppmvsnet_p1(imgs, proj_mats, depth_start, depth_interval)
|
||||||
|
|
||||||
|
cost_volume_list = []
|
||||||
|
ref_feat_1, srcs_feat_1 = feat_pack_1[:, 0], feat_pack_1[:, 1:]
|
||||||
|
|
||||||
|
for i in range(feat_pack_1.shape[1] - 1):
|
||||||
|
pixel_distance = P.Squeeze()(pixel_distances[i])
|
||||||
|
|
||||||
|
if pixel_distance < self.distance * 3:
|
||||||
|
cost_volume_1 = self.stage1_p2_s1([ref_feat_1, srcs_feat_1, proj_mats[:, :, 2]],
|
||||||
|
self.n_depths[0], depth_start,
|
||||||
|
depth_interval * self.interval_ratios[0],
|
||||||
|
depth_values_stage1, i)
|
||||||
|
|
||||||
|
cost_volume_list.append(cost_volume_1)
|
||||||
|
else:
|
||||||
|
cost_volume_3 = self.stage1_p2_s3([ref_feat_1, srcs_feat_1, proj_mats[:, :, 2]],
|
||||||
|
self.n_depths[0], depth_start,
|
||||||
|
depth_interval * self.interval_ratios[0],
|
||||||
|
depth_values_stage1, i)
|
||||||
|
cost_volume_list.append(cost_volume_3)
|
||||||
|
cost_volume_list = P.Stack(1)(cost_volume_list)
|
||||||
|
results = self.eppmvsnet_p3(feat_pack_1, feat_pack_2, feat_pack_3, depth_values_stage1, cost_volume_list,
|
||||||
|
proj_mats, depth_start, depth_interval)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class EPPMVSNetP1(nn.Cell):
|
||||||
|
"""EPPMVSNet part1"""
|
||||||
|
|
||||||
|
def __init__(self, feat_ext, stage1_p1, n_depths, interval_ratios, entropy_range=False):
|
||||||
|
super(EPPMVSNetP1, self).__init__()
|
||||||
|
self.n_depths = n_depths
|
||||||
|
self.interval_ratios = interval_ratios
|
||||||
|
self.entropy_range = entropy_range
|
||||||
|
# hyper parameter used in entropy-based adjustment
|
||||||
|
self.shrink_ratio = 1
|
||||||
|
|
||||||
|
self.feat_ext = feat_ext
|
||||||
|
self.stage1_p1 = stage1_p1
|
||||||
|
|
||||||
|
self.tile = P.Tile()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
|
||||||
|
def construct(self, imgs, proj_mats=None, depth_start=None, depth_interval=None):
|
||||||
|
"""construct function"""
|
||||||
|
B, V, _, H, W = imgs.shape
|
||||||
|
imgs = imgs.reshape(B * V, 3, H, W)
|
||||||
|
feat_pack_1, feat_pack_2, feat_pack_3 = self.feat_ext(imgs)
|
||||||
|
feat_pack_1 = feat_pack_1.view(B, V, *feat_pack_1.shape[1:]) # (B, V, C, h, w)
|
||||||
|
feat_pack_2 = feat_pack_2.view(B, V, *feat_pack_2.shape[1:]) # (B, V, C, h, w)
|
||||||
|
feat_pack_3 = feat_pack_3.view(B, V, *feat_pack_3.shape[1:]) # (B, V, C, h, w)
|
||||||
|
|
||||||
|
ref_feat_1, srcs_feat_1 = feat_pack_1[:, 0], feat_pack_1[:, 1:]
|
||||||
|
|
||||||
|
depth_values_stage1, single_depth_values_stage1 = self.stage1_p1([ref_feat_1, srcs_feat_1, proj_mats[:, :, 2]],
|
||||||
|
depth_num=self.n_depths[0],
|
||||||
|
depth_start_override=depth_start,
|
||||||
|
depth_interval_override=depth_interval *
|
||||||
|
self.interval_ratios[0])
|
||||||
|
|
||||||
|
_, src_feats, proj_mats = [ref_feat_1, srcs_feat_1, proj_mats[:, :, 2]]
|
||||||
|
pixel_distances = []
|
||||||
|
for i in range(src_feats.shape[1]):
|
||||||
|
src_feat = src_feats[:, i]
|
||||||
|
proj_mat = proj_mats[:, i]
|
||||||
|
pixel_distance = determine_center_pixel_interval(src_feat, proj_mat, single_depth_values_stage1)
|
||||||
|
pixel_distances.append(pixel_distance)
|
||||||
|
return feat_pack_1, feat_pack_2, feat_pack_3, depth_values_stage1, pixel_distances
|
||||||
|
|
||||||
|
|
||||||
|
class EPPMVSNetP3(nn.Cell):
|
||||||
|
"""EPPMVSNet part3"""
|
||||||
|
|
||||||
|
def __init__(self, stage1_p3, stage2, stage3, n_depths, interval_ratios, entropy_range=False,
|
||||||
|
shrink_ratio=1, height=None, width=None):
|
||||||
|
super(EPPMVSNetP3, self).__init__()
|
||||||
|
self.n_depths = n_depths
|
||||||
|
self.interval_ratios = interval_ratios
|
||||||
|
self.entropy_range = entropy_range
|
||||||
|
self.shrink_ratio = shrink_ratio
|
||||||
|
|
||||||
|
self.stage1_p3 = stage1_p3
|
||||||
|
self.stage2 = stage2
|
||||||
|
self.stage3 = stage3
|
||||||
|
|
||||||
|
self.height = height
|
||||||
|
self.width = width
|
||||||
|
|
||||||
|
def construct(self, feat_pack_1, feat_pack_2, feat_pack_3, depth_values_stage1, cost_volume_list_stage1,
|
||||||
|
proj_mats=None, depth_start=None, depth_interval=None):
|
||||||
|
"""construct function"""
|
||||||
|
H = self.height
|
||||||
|
W = self.width
|
||||||
|
|
||||||
|
ref_feat_1, srcs_feat_1 = feat_pack_1[:, 0], feat_pack_1[:, 1:]
|
||||||
|
if self.entropy_range:
|
||||||
|
est_depth_1, _, pair_results_1, conf_range_1 = self.stage1_p3(cost_volume_list_stage1, depth_values_stage1,
|
||||||
|
[ref_feat_1, srcs_feat_1, proj_mats[:, :, 2]],
|
||||||
|
self.n_depths[0])
|
||||||
|
stage2_conf_interval = self.shrink_ratio * conf_range_1 / self.n_depths[0] * (
|
||||||
|
depth_interval * self.interval_ratios[0] * self.n_depths[0]) / self.n_depths[1]
|
||||||
|
else:
|
||||||
|
est_depth_1, _, pair_results_1 = self.stage1_p3(cost_volume_list_stage1, depth_values_stage1,
|
||||||
|
[ref_feat_1, srcs_feat_1, proj_mats[:, :, 2]],
|
||||||
|
self.n_depths[0])
|
||||||
|
stage2_conf_interval = None
|
||||||
|
uncertainty_maps_1, uncertainty_maps_2 = [], []
|
||||||
|
for pair_result in pair_results_1:
|
||||||
|
uncertainty_maps_1.append(pair_result[1][0])
|
||||||
|
for uncertainty_map in uncertainty_maps_1:
|
||||||
|
uncertainty_maps_2.append(P.ResizeBilinear((H // 4, W // 4), False)(uncertainty_map))
|
||||||
|
|
||||||
|
ref_feat_2, srcs_feat_2 = feat_pack_2[:, 0], feat_pack_2[:, 1:]
|
||||||
|
depth_start_2 = P.ResizeBilinear((H // 4, W // 4), False)(est_depth_1)
|
||||||
|
|
||||||
|
if self.entropy_range:
|
||||||
|
est_depth_2, _, _, conf_range_2 = self.stage2([ref_feat_2, srcs_feat_2, proj_mats[:, :, 1]],
|
||||||
|
depth_num=self.n_depths[1],
|
||||||
|
depth_start_override=depth_start_2,
|
||||||
|
depth_interval_override=stage2_conf_interval,
|
||||||
|
uncertainty_maps=uncertainty_maps_2)
|
||||||
|
stage3_conf_interval = self.shrink_ratio * conf_range_2 / self.n_depths[1] * (
|
||||||
|
stage2_conf_interval * self.n_depths[1]) / self.n_depths[2]
|
||||||
|
else:
|
||||||
|
est_depth_2, _, _ = self.stage2([ref_feat_2, srcs_feat_2, proj_mats[:, :, 1]],
|
||||||
|
depth_num=self.n_depths[1], depth_start_override=depth_start_2,
|
||||||
|
depth_interval_override=depth_interval * self.interval_ratios[1],
|
||||||
|
uncertainty_maps=uncertainty_maps_2)
|
||||||
|
stage3_conf_interval = None
|
||||||
|
uncertainty_maps_3 = []
|
||||||
|
for uncertainty_map in uncertainty_maps_2:
|
||||||
|
uncertainty_maps_3.append(P.ResizeBilinear((H // 2, W // 2), False)(uncertainty_map))
|
||||||
|
|
||||||
|
ref_feat_3, srcs_feat_3 = feat_pack_3[:, 0], feat_pack_3[:, 1:]
|
||||||
|
depth_start_3 = P.ResizeBilinear((H // 2, W // 2), False)(est_depth_2)
|
||||||
|
|
||||||
|
if self.entropy_range:
|
||||||
|
est_depth_3, prob_map_3, _, _ = self.stage3([ref_feat_3, srcs_feat_3, proj_mats[:, :, 0]],
|
||||||
|
depth_num=self.n_depths[2],
|
||||||
|
depth_start_override=depth_start_3,
|
||||||
|
depth_interval_override=stage3_conf_interval,
|
||||||
|
uncertainty_maps=uncertainty_maps_3)
|
||||||
|
else:
|
||||||
|
est_depth_3, prob_map_3, _ = self.stage3([ref_feat_3, srcs_feat_3, proj_mats[:, :, 0]],
|
||||||
|
depth_num=self.n_depths[2],
|
||||||
|
depth_start_override=depth_start_3,
|
||||||
|
depth_interval_override=depth_interval * self.interval_ratios[2],
|
||||||
|
uncertainty_maps=uncertainty_maps_3)
|
||||||
|
refined_depth = est_depth_3
|
||||||
|
return refined_depth, prob_map_3
|
|
@ -0,0 +1,437 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""math operations of EPP-MVSNet"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common as mstype
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import constexpr
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def generate_FloatTensor(x):
|
||||||
|
return mindspore.Tensor(x, dtype=mstype.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def get_depth_values(current_depth, n_depths, depth_interval, inverse_depth=False):
|
||||||
|
"""
|
||||||
|
get the depth values of each pixel : [depth_min, depth_max) step is depth_interval
|
||||||
|
current_depth: (B, 1, H, W), current depth map
|
||||||
|
n_depth: int, number of channels of depth
|
||||||
|
depth_interval: (B) or float, interval between each depth channel
|
||||||
|
return: (B, D, H, W)
|
||||||
|
"""
|
||||||
|
linspace = P.LinSpace()
|
||||||
|
if not isinstance(depth_interval, float) and depth_interval.shape != current_depth.shape:
|
||||||
|
depth_interval = depth_interval.reshape(-1, 1, 1, 1)
|
||||||
|
depth_min = C.clip_by_value(current_depth - n_depths / 2 * depth_interval, generate_FloatTensor(1e-7),
|
||||||
|
generate_FloatTensor(58682))
|
||||||
|
if inverse_depth:
|
||||||
|
depth_end = depth_min + (n_depths - 1) * depth_interval
|
||||||
|
inverse_depth_interval = (1 / depth_min - 1 / depth_end) / (n_depths - 1)
|
||||||
|
depth_values = 1 / depth_end + inverse_depth_interval * \
|
||||||
|
linspace(generate_FloatTensor(0), generate_FloatTensor(n_depths - 1), n_depths).reshape(1, -1, 1,
|
||||||
|
1)
|
||||||
|
depth_values = 1.0 / depth_values
|
||||||
|
else:
|
||||||
|
depth_values = depth_min + depth_interval * \
|
||||||
|
linspace(generate_FloatTensor(0), generate_FloatTensor(n_depths - 1), n_depths).reshape(1, -1, 1,
|
||||||
|
1)
|
||||||
|
return depth_values
|
||||||
|
|
||||||
|
|
||||||
|
class HomoWarp(nn.Cell):
|
||||||
|
'''STN'''
|
||||||
|
|
||||||
|
def __init__(self, H, W):
|
||||||
|
super(HomoWarp, self).__init__()
|
||||||
|
# batch_size = 1
|
||||||
|
x = np.linspace(0, W - 1, W)
|
||||||
|
y = np.linspace(0, H - 1, H)
|
||||||
|
x_t, y_t = np.meshgrid(x, y)
|
||||||
|
x_t = Tensor(x_t, mstype.float32)
|
||||||
|
y_t = Tensor(y_t, mstype.float32)
|
||||||
|
expand_dims = P.ExpandDims()
|
||||||
|
x_t = expand_dims(x_t, 0)
|
||||||
|
y_t = expand_dims(y_t, 0)
|
||||||
|
flatten = P.Flatten()
|
||||||
|
x_t_flat = flatten(x_t)
|
||||||
|
y_t_flat = flatten(y_t)
|
||||||
|
oneslike = P.OnesLike()
|
||||||
|
ones = oneslike(x_t_flat)
|
||||||
|
concat = P.Concat()
|
||||||
|
sampling_grid = concat((x_t_flat, y_t_flat, ones))
|
||||||
|
self.sampling_grid = expand_dims(sampling_grid, 0) # (1, 3, D*H*W)
|
||||||
|
c = np.linspace(0, 31, 32)
|
||||||
|
self.channel = Tensor(c, mstype.float32).view(1, 1, 1, 1, -1)
|
||||||
|
|
||||||
|
batch_size = 128
|
||||||
|
batch_idx = np.arange(batch_size)
|
||||||
|
batch_idx = batch_idx.reshape((batch_size, 1, 1, 1))
|
||||||
|
self.batch_idx = Tensor(batch_idx, mstype.float32)
|
||||||
|
self.zero = Tensor(np.zeros([]), mstype.float32)
|
||||||
|
|
||||||
|
def get_pixel_value(self, img, x, y):
|
||||||
|
"""
|
||||||
|
Utility function to get pixel value for coordinate
|
||||||
|
vectors x and y from a 4D tensor image.
|
||||||
|
|
||||||
|
Input
|
||||||
|
-----
|
||||||
|
- img: tensor of shape (B, H, W, C)
|
||||||
|
- x: flattened tensor of shape (B*D*H*W,)
|
||||||
|
- y: flattened tensor of shape (B*D*H*W,)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
- output: tensor of shape (B, D, H, W, C)
|
||||||
|
"""
|
||||||
|
shape = P.Shape()
|
||||||
|
img_shape = shape(x)
|
||||||
|
batch_size = img_shape[0]
|
||||||
|
D = img_shape[1]
|
||||||
|
H = img_shape[2]
|
||||||
|
W = img_shape[3]
|
||||||
|
img[:, 0, :, :] = self.zero
|
||||||
|
img[:, H - 1, :, :] = self.zero
|
||||||
|
img[:, :, 0, :] = self.zero
|
||||||
|
img[:, :, W - 1, :] = self.zero
|
||||||
|
|
||||||
|
tile = P.Tile()
|
||||||
|
batch_idx = P.Slice()(self.batch_idx, (0, 0, 0, 0), (batch_size, 1, 1, 1))
|
||||||
|
b = tile(batch_idx, (1, D, H, W))
|
||||||
|
|
||||||
|
expand_dims = P.ExpandDims()
|
||||||
|
b = expand_dims(b, 4)
|
||||||
|
x = expand_dims(x, 4)
|
||||||
|
y = expand_dims(y, 4)
|
||||||
|
|
||||||
|
concat = P.Concat(4)
|
||||||
|
indices = concat((b, y, x))
|
||||||
|
|
||||||
|
cast = P.Cast()
|
||||||
|
indices = cast(indices, mstype.int32)
|
||||||
|
gather_nd = P.GatherNd()
|
||||||
|
|
||||||
|
return cast(gather_nd(img, indices), mstype.float32)
|
||||||
|
|
||||||
|
def homo_warp(self, height, width, proj_mat, depth_values):
|
||||||
|
"""`
|
||||||
|
This function returns a sampling grid, which when
|
||||||
|
used with the bilinear sampler on the input feature
|
||||||
|
map, will create an output feature map that is an
|
||||||
|
affine transformation [1] of the input feature map.
|
||||||
|
|
||||||
|
zero = Tensor(np.zeros([]), mstype.float32)
|
||||||
|
Input
|
||||||
|
-----
|
||||||
|
- height: desired height of grid/output. Used
|
||||||
|
to downsample or upsample.
|
||||||
|
|
||||||
|
- width: desired width of grid/output. Used
|
||||||
|
to downsample or upsample.
|
||||||
|
|
||||||
|
- proj_mat: (B, 3, 4) equal to "src_proj @ ref_proj_inv"
|
||||||
|
|
||||||
|
- depth_values: (B, D, H, W)
|
||||||
|
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
- normalized grid (-1, 1) of shape (num_batch, 2, H, W).
|
||||||
|
The 2nd dimension has 2 components: (x, y) which are the
|
||||||
|
sampling points of the original image for each point in the
|
||||||
|
target image.
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
[1]: the affine transformation allows cropping, translation,
|
||||||
|
and isotropic scaling.
|
||||||
|
"""
|
||||||
|
shape = P.Shape()
|
||||||
|
B = shape(depth_values)[0]
|
||||||
|
D = shape(depth_values)[1]
|
||||||
|
H = height
|
||||||
|
W = width
|
||||||
|
|
||||||
|
R = proj_mat[:, :, :3] # (B, 3, 3)
|
||||||
|
T = proj_mat[:, :, 3:] # (B, 3, 1)
|
||||||
|
|
||||||
|
cast = P.Cast()
|
||||||
|
depth_values = cast(depth_values, mstype.float32)
|
||||||
|
|
||||||
|
# transform the sampling grid - batch multiply
|
||||||
|
matmul = P.BatchMatMul()
|
||||||
|
tile = P.Tile()
|
||||||
|
ref_grid_d = tile(self.sampling_grid, (B, 1, 1)) # (B, 3, H*W)
|
||||||
|
cast = P.Cast()
|
||||||
|
ref_grid_d = cast(ref_grid_d, mstype.float32)
|
||||||
|
|
||||||
|
# repeat_elements has problem, can not be used
|
||||||
|
ref_grid_d = P.Tile()(ref_grid_d, (1, 1, D))
|
||||||
|
src_grid_d = matmul(R, ref_grid_d) + T / depth_values.view(B, 1, D * H * W)
|
||||||
|
|
||||||
|
# project negative depth pixels to somewhere outside the image
|
||||||
|
negative_depth_mask = src_grid_d[:, 2:] <= 1e-7
|
||||||
|
src_grid_d[:, 0:1][negative_depth_mask] = W
|
||||||
|
src_grid_d[:, 1:2][negative_depth_mask] = H
|
||||||
|
src_grid_d[:, 2:3][negative_depth_mask] = 1
|
||||||
|
|
||||||
|
src_grid = src_grid_d[:, :2] / src_grid_d[:, 2:] # divide by depth (B, 2, D*H*W)
|
||||||
|
|
||||||
|
reshape = P.Reshape()
|
||||||
|
src_grid = reshape(src_grid, (B, 2, D, H, W))
|
||||||
|
return src_grid
|
||||||
|
|
||||||
|
def bilinear_sampler(self, img, x, y):
|
||||||
|
"""
|
||||||
|
Performs bilinear sampling of the input images according to the
|
||||||
|
normalized coordinates provided by the sampling grid. Note that
|
||||||
|
the sampling is done identically for each channel of the input.
|
||||||
|
|
||||||
|
To test if the function works properly, output image should be
|
||||||
|
identical to input image when theta is initialized to identity
|
||||||
|
transform.
|
||||||
|
|
||||||
|
Input
|
||||||
|
-----
|
||||||
|
- img: batch of images in (B, H, W, C) layout.
|
||||||
|
- grid: x, y which is the output of affine_grid_generator.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
- out: interpolated images according to grids. Same size as grid.
|
||||||
|
"""
|
||||||
|
shape = P.Shape()
|
||||||
|
H = shape(img)[1]
|
||||||
|
W = shape(img)[2]
|
||||||
|
cast = P.Cast()
|
||||||
|
max_y = cast(H - 1, mstype.float32)
|
||||||
|
max_x = cast(W - 1, mstype.float32)
|
||||||
|
zero = self.zero
|
||||||
|
|
||||||
|
# grab 4 nearest corner points for each (x_i, y_i)
|
||||||
|
floor = P.Floor()
|
||||||
|
x0 = floor(x)
|
||||||
|
x1 = x0 + 1
|
||||||
|
y0 = floor(y)
|
||||||
|
y1 = y0 + 1
|
||||||
|
|
||||||
|
# clip to range [0, H-1/W-1] to not violate img boundaries
|
||||||
|
x0 = C.clip_by_value(x0, zero, max_x)
|
||||||
|
x1 = C.clip_by_value(x1, zero, max_x)
|
||||||
|
y0 = C.clip_by_value(y0, zero, max_y)
|
||||||
|
y1 = C.clip_by_value(y1, zero, max_y)
|
||||||
|
|
||||||
|
# get pixel value at corner coords
|
||||||
|
Ia = self.get_pixel_value(img, x0, y0)
|
||||||
|
Ib = self.get_pixel_value(img, x0, y1)
|
||||||
|
Ic = self.get_pixel_value(img, x1, y0)
|
||||||
|
Id = self.get_pixel_value(img, x1, y1)
|
||||||
|
|
||||||
|
# recast as float for delta calculation
|
||||||
|
x0 = cast(x0, mstype.float32)
|
||||||
|
x1 = cast(x1, mstype.float32)
|
||||||
|
y0 = cast(y0, mstype.float32)
|
||||||
|
y1 = cast(y1, mstype.float32)
|
||||||
|
|
||||||
|
# calculate deltas
|
||||||
|
wa = (x1 - x) * (y1 - y)
|
||||||
|
wb = (x1 - x) * (y - y0)
|
||||||
|
wc = (x - x0) * (y1 - y)
|
||||||
|
wd = (x - x0) * (y - y0)
|
||||||
|
|
||||||
|
# add dimension for addition
|
||||||
|
expand_dims = P.ExpandDims()
|
||||||
|
wa = expand_dims(wa, 4)
|
||||||
|
wb = expand_dims(wb, 4)
|
||||||
|
wc = expand_dims(wc, 4)
|
||||||
|
wd = expand_dims(wd, 4)
|
||||||
|
|
||||||
|
# compute output
|
||||||
|
add_n = P.AddN()
|
||||||
|
out = add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
|
||||||
|
return out
|
||||||
|
|
||||||
|
def construct(self, input_fmap, proj_mat, depth_values, out_dims=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Spatial Transformer Network layer implementation as described in [1].
|
||||||
|
|
||||||
|
The layer is composed of 3 elements:
|
||||||
|
|
||||||
|
- localization_net: takes the original image as input and outputs
|
||||||
|
the parameters of the affine transformation that should be applied
|
||||||
|
to the input image.
|
||||||
|
|
||||||
|
- affine_grid_generator: generates a grid of (x,y) coordinates that
|
||||||
|
correspond to a set of points where the input should be sampled
|
||||||
|
to produce the transformed output.
|
||||||
|
|
||||||
|
- bilinear_sampler: takes as input the original image and the grid
|
||||||
|
and produces the output image using bilinear interpolation.
|
||||||
|
|
||||||
|
Input
|
||||||
|
-----
|
||||||
|
- input_fmap: output of the previous layer. Can be input if spatial
|
||||||
|
transformer layer is at the beginning of architecture. Should be
|
||||||
|
a tensor of shape (B, H, W, C)->(B, C, H, W).
|
||||||
|
|
||||||
|
- theta: affine transform tensor of shape (B, 6). Permits cropping,
|
||||||
|
translation and isotropic scaling. Initialize to identity matrix.
|
||||||
|
It is the output of the localization network.
|
||||||
|
|
||||||
|
- proj_mat: (B, 3, 4) equal to "src_proj @ ref_proj_inv"
|
||||||
|
|
||||||
|
- depth_values: (B, D, H, W)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
- out_fmap: transformed input feature map. Tensor of size (B, C, H, W)-->(B, H, W, C).
|
||||||
|
- out: (B, C, D, H, W)
|
||||||
|
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
[1]: 'Spatial Transformer Networks', Jaderberg et. al,
|
||||||
|
(https://arxiv.org/abs/1506.02025)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# grab input dimensions
|
||||||
|
trans = P.Transpose()
|
||||||
|
input_fmap = trans(input_fmap, (0, 2, 3, 1))
|
||||||
|
shape = P.Shape()
|
||||||
|
input_size = shape(input_fmap)
|
||||||
|
H = input_size[1]
|
||||||
|
W = input_size[2]
|
||||||
|
|
||||||
|
# generate grids of same size or upsample/downsample if specified
|
||||||
|
if out_dims:
|
||||||
|
out_H = out_dims[0]
|
||||||
|
out_W = out_dims[1]
|
||||||
|
batch_grids = self.homo_warp(out_H, out_W, proj_mat, depth_values)
|
||||||
|
else:
|
||||||
|
batch_grids = self.homo_warp(H, W, proj_mat, depth_values)
|
||||||
|
|
||||||
|
x_s, y_s = P.Split(1, 2)(batch_grids)
|
||||||
|
squeeze = P.Squeeze(1)
|
||||||
|
x_s = squeeze(x_s)
|
||||||
|
y_s = squeeze(y_s)
|
||||||
|
out_fmap = self.bilinear_sampler(input_fmap, x_s, y_s)
|
||||||
|
out_fmap = trans(out_fmap, (0, 4, 1, 2, 3))
|
||||||
|
|
||||||
|
return out_fmap
|
||||||
|
|
||||||
|
|
||||||
|
def determine_center_pixel_interval(src_feat, proj_mat, depth_values):
|
||||||
|
"""
|
||||||
|
src_feat: (B, C, H, W)
|
||||||
|
proj_mat: (B, 3, 4) equal to "src_proj @ ref_proj_inv"
|
||||||
|
depth_values: (B, D, H, W)
|
||||||
|
out: (B, C, D, H, W)
|
||||||
|
"""
|
||||||
|
B, _, H, W = src_feat.shape
|
||||||
|
D = depth_values.shape[1]
|
||||||
|
|
||||||
|
R = proj_mat[:, :, :3] # (B, 3, 3)
|
||||||
|
T = proj_mat[:, :, 3:] # (B, 3, 1)
|
||||||
|
|
||||||
|
concat = P.Concat(axis=1)
|
||||||
|
ref_center = generate_FloatTensor([H / 2, W / 2]).view(1, 2, 1, 1)
|
||||||
|
ref_center = ref_center.reshape(1, 2, -1)
|
||||||
|
ref_center = P.Tile()(ref_center, (B, 1, 1))
|
||||||
|
ref_center = concat((ref_center, C.ones_like(ref_center[:, :1]))) # (B, 3, H*W)
|
||||||
|
ref_center_d = C.repeat_elements(ref_center, rep=D, axis=2) # (B, 3, D*H*W)
|
||||||
|
src_center_d = C.matmul(R, ref_center_d) + T / depth_values.view(B, 1, D * 1 * 1)
|
||||||
|
|
||||||
|
negative_depth_mask = src_center_d[:, 2:] <= 1e-7
|
||||||
|
src_center_d[:, 0:1][negative_depth_mask] = W
|
||||||
|
src_center_d[:, 1:2][negative_depth_mask] = H
|
||||||
|
src_center_d[:, 2:3][negative_depth_mask] = 1
|
||||||
|
|
||||||
|
transpose = P.Transpose()
|
||||||
|
sqrt = P.Sqrt()
|
||||||
|
pow_ms = P.Pow()
|
||||||
|
src_center = src_center_d[:, :2] / src_center_d[:, 2:] # divide by depth (B, 2, D*H*W)
|
||||||
|
src_grid_valid = transpose(src_center, (0, 2, 1)).view(B, D, 1, 2) # (B, D*H*W, 2)
|
||||||
|
delta_p = src_grid_valid[:, 1:, :, :] - src_grid_valid[:, :-1, :, :]
|
||||||
|
epipolar_pixel = sqrt(pow_ms(delta_p[:, :, :, 0], 2) + pow_ms(delta_p[:, :, :, 1], 2))
|
||||||
|
epipolar_pixel = epipolar_pixel.mean(1)
|
||||||
|
|
||||||
|
return epipolar_pixel
|
||||||
|
|
||||||
|
|
||||||
|
def depth_regression(p, depth_values, keep_dim=False):
|
||||||
|
"""
|
||||||
|
p: probability volume (B, D, H, W)
|
||||||
|
depth_values: discrete depth values (B, D, H, W) or (D)
|
||||||
|
inverse: depth_values is inverse depth or not
|
||||||
|
"""
|
||||||
|
if depth_values.ndim <= 2:
|
||||||
|
depth_values = depth_values.view(*depth_values.shape, 1, 1)
|
||||||
|
cumsum = P.ReduceSum(keep_dim)
|
||||||
|
depth = cumsum(p * depth_values, 1)
|
||||||
|
return depth
|
||||||
|
|
||||||
|
|
||||||
|
def soft_argmin(volume, dim, keepdim=False, window=None):
|
||||||
|
"""soft argmin"""
|
||||||
|
softmax = nn.Softmax(1)
|
||||||
|
prob_vol = softmax(volume)
|
||||||
|
length = volume.shape[dim]
|
||||||
|
index = nn.Range(0, length)()
|
||||||
|
index_shape = []
|
||||||
|
for i in range(len(volume.shape)):
|
||||||
|
if i == dim:
|
||||||
|
index_shape.append(length)
|
||||||
|
else:
|
||||||
|
index_shape.append(1)
|
||||||
|
index = index.reshape(index_shape)
|
||||||
|
out = P.ReduceSum(True)(index * prob_vol, dim)
|
||||||
|
squeeze = P.Squeeze(axis=dim)
|
||||||
|
out_sq = squeeze(out) if not keepdim else out
|
||||||
|
if window is None:
|
||||||
|
return prob_vol, out_sq
|
||||||
|
# |depth hypothesis - predicted depth|, assemble to UCSNet
|
||||||
|
# 1d11 n1hw
|
||||||
|
mask = ((index - out).abs() <= window)
|
||||||
|
mask = mask.astype(mstype.float32)
|
||||||
|
prob_map = P.ReduceSum(keepdim)(prob_vol * mask, dim)
|
||||||
|
return prob_vol, out_sq, prob_map
|
||||||
|
|
||||||
|
|
||||||
|
def entropy(volume, dim, keepdim=False):
|
||||||
|
return P.ReduceSum(keepdim)(
|
||||||
|
-volume * P.Log()(C.clip_by_value(volume, generate_FloatTensor(1e-9), generate_FloatTensor(1.))), dim)
|
||||||
|
|
||||||
|
|
||||||
|
def entropy_num_based(volume, dim, depth_num, keepdim=False):
|
||||||
|
return P.ReduceSum(keepdim)(
|
||||||
|
-volume * P.Log()(C.clip_by_value(volume, generate_FloatTensor(1e-9), generate_FloatTensor(1.))) / math.log(
|
||||||
|
math.e, depth_num), dim)
|
||||||
|
|
||||||
|
|
||||||
|
def groupwise_correlation(v1, v2, groups, dim):
|
||||||
|
n, c, d, h, w = v1.shape
|
||||||
|
reshaped_size = (n, groups, c // groups, d, h, w)
|
||||||
|
v1_reshaped = v1.view(*reshaped_size)
|
||||||
|
v2_reshaped = v2.view(*reshaped_size)
|
||||||
|
vc = P.ReduceSum()(v1_reshaped * v2_reshaped, dim + 1)
|
||||||
|
return vc
|
|
@ -0,0 +1,402 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""sub-networks of EPP-MVSNet"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import mindspore
|
||||||
|
import mindspore.ops as P
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore import Tensor, Parameter
|
||||||
|
from src.modules import depth_regression, soft_argmin, entropy
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlockA(nn.Cell):
|
||||||
|
"""BasicBlockA"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, stride):
|
||||||
|
super(BasicBlockA, self).__init__()
|
||||||
|
self.conv2d_0 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, pad_mode="pad")
|
||||||
|
self.conv2d_1 = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0, pad_mode="valid")
|
||||||
|
self.batchnorm2d_2 = nn.BatchNorm2d(out_channels, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.batchnorm2d_3 = nn.BatchNorm2d(out_channels, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.relu_4 = nn.ReLU()
|
||||||
|
self.conv2d_5 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=(1, 1, 1, 1), pad_mode="pad")
|
||||||
|
self.batchnorm2d_6 = nn.BatchNorm2d(out_channels, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.relu_8 = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x1 = self.conv2d_0(x)
|
||||||
|
x1 = self.batchnorm2d_2(x1)
|
||||||
|
x1 = self.relu_4(x1)
|
||||||
|
x1 = self.conv2d_5(x1)
|
||||||
|
x1 = self.batchnorm2d_6(x1)
|
||||||
|
|
||||||
|
res = self.conv2d_1(x)
|
||||||
|
res = self.batchnorm2d_3(res)
|
||||||
|
|
||||||
|
out = P.Add()(x1, res)
|
||||||
|
out = self.relu_8(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlockB(nn.Cell):
|
||||||
|
"""BasicBlockB"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super(BasicBlockB, self).__init__()
|
||||||
|
self.conv2d_0 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
self.batchnorm2d_1 = nn.BatchNorm2d(out_channels, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.relu_2 = nn.ReLU()
|
||||||
|
self.conv2d_3 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
self.batchnorm2d_4 = nn.BatchNorm2d(out_channels, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.relu_6 = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x1 = self.conv2d_0(x)
|
||||||
|
x1 = self.batchnorm2d_1(x1)
|
||||||
|
x1 = self.relu_2(x1)
|
||||||
|
x1 = self.conv2d_3(x1)
|
||||||
|
x1 = self.batchnorm2d_4(x1)
|
||||||
|
|
||||||
|
res = x
|
||||||
|
|
||||||
|
out = P.Add()(x1, res)
|
||||||
|
out = self.relu_6(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class UNet2D(nn.Cell):
|
||||||
|
"""UNet2D"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(UNet2D, self).__init__()
|
||||||
|
|
||||||
|
self.conv2d_0 = nn.Conv2d(3, 16, 5, stride=2, padding=2, pad_mode="pad")
|
||||||
|
self.batchnorm2d_1 = nn.BatchNorm2d(16, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.leakyrelu_2 = nn.LeakyReLU(alpha=0.009999999776482582)
|
||||||
|
|
||||||
|
self.convblocka_0 = BasicBlockA(16, 32, 1)
|
||||||
|
self.convblockb_0 = BasicBlockB(32, 32)
|
||||||
|
|
||||||
|
self.convblocka_1 = BasicBlockA(32, 64, 2)
|
||||||
|
self.convblockb_1 = BasicBlockB(64, 64)
|
||||||
|
|
||||||
|
self.convblocka_2 = BasicBlockA(64, 128, 2)
|
||||||
|
self.convblockb_2 = BasicBlockB(128, 128)
|
||||||
|
|
||||||
|
self.conv2dbackpropinput_51 = P.Conv2DBackpropInput(64, 3, stride=2, pad=1, pad_mode="pad")
|
||||||
|
self.conv2dbackpropinput_51_weight = Parameter(Tensor(
|
||||||
|
np.random.uniform(0, 1, (128, 64, 3, 3)).astype(np.float32)))
|
||||||
|
self.conv2d_54 = nn.Conv2d(128, 64, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
self.convblockb_3 = BasicBlockB(64, 64)
|
||||||
|
|
||||||
|
self.conv2dbackpropinput_62 = P.Conv2DBackpropInput(32, 3, stride=2, pad=1, pad_mode="pad")
|
||||||
|
self.conv2dbackpropinput_62_weight = Parameter(Tensor(
|
||||||
|
np.random.uniform(0, 1, (64, 32, 3, 3)).astype(np.float32)))
|
||||||
|
self.conv2d_65 = nn.Conv2d(64, 32, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
self.convblockb_4 = BasicBlockB(32, 32)
|
||||||
|
|
||||||
|
self.conv2d_52 = nn.Conv2d(128, 32, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
self.conv2d_63 = nn.Conv2d(64, 32, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
self.conv2d_73 = nn.Conv2d(32, 32, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
|
||||||
|
self.concat = P.Concat(axis=1)
|
||||||
|
|
||||||
|
param_dict = mindspore.load_checkpoint("./ckpts/feat_ext.ckpt")
|
||||||
|
params_not_loaded = mindspore.load_param_into_net(self, param_dict, strict_load=True)
|
||||||
|
print(params_not_loaded)
|
||||||
|
|
||||||
|
def construct(self, imgs):
|
||||||
|
"""construct"""
|
||||||
|
_, _, h, w = imgs.shape
|
||||||
|
|
||||||
|
x = self.conv2d_0(imgs)
|
||||||
|
x = self.batchnorm2d_1(x)
|
||||||
|
x = self.leakyrelu_2(x)
|
||||||
|
|
||||||
|
x1 = self.convblocka_0(x)
|
||||||
|
x1 = self.convblockb_0(x1)
|
||||||
|
x2 = self.convblocka_1(x1)
|
||||||
|
x2 = self.convblockb_1(x2)
|
||||||
|
x3 = self.convblocka_2(x2)
|
||||||
|
x3 = self.convblockb_2(x3)
|
||||||
|
|
||||||
|
x2_upsample = self.conv2dbackpropinput_51(x3, self.conv2dbackpropinput_51_weight,
|
||||||
|
(x2.shape[0], x2.shape[1], h // 4, w // 4))
|
||||||
|
x2_upsample = self.concat((x2_upsample, x2,))
|
||||||
|
x2_upsample = self.conv2d_54(x2_upsample)
|
||||||
|
x2_upsample = self.convblockb_3(x2_upsample)
|
||||||
|
|
||||||
|
x1_upsample = self.conv2dbackpropinput_62(x2_upsample, self.conv2dbackpropinput_62_weight,
|
||||||
|
(x1.shape[0], x1.shape[1], h // 2, w // 2))
|
||||||
|
x1_upsample = self.concat((x1_upsample, x1,))
|
||||||
|
x1_upsample = self.conv2d_65(x1_upsample)
|
||||||
|
x1_upsample = self.convblockb_4(x1_upsample)
|
||||||
|
|
||||||
|
x3_final = self.conv2d_52(x3)
|
||||||
|
x2_final = self.conv2d_63(x2_upsample)
|
||||||
|
x1_final = self.conv2d_73(x1_upsample)
|
||||||
|
return x3_final, x2_final, x1_final
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBnReLu(nn.Cell):
|
||||||
|
"""ConvBnReLu"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super(ConvBnReLu, self).__init__()
|
||||||
|
self.conv3d_0 = nn.Conv3d(in_channels, out_channels, (3, 1, 1), stride=1, padding=(1, 1, 0, 0, 0, 0),
|
||||||
|
pad_mode="pad")
|
||||||
|
self.batchnorm3d_1 = nn.BatchNorm3d(out_channels, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.leakyrelu_2 = nn.LeakyReLU(alpha=0.009999999776482582)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.conv3d_0(x)
|
||||||
|
x = self.batchnorm3d_1(x)
|
||||||
|
x = self.leakyrelu_2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CostCompression(nn.Cell):
|
||||||
|
"""CostCompression"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CostCompression, self).__init__()
|
||||||
|
self.basicblock_0 = ConvBnReLu(8, 64)
|
||||||
|
self.basicblock_1 = ConvBnReLu(64, 64)
|
||||||
|
self.basicblock_2 = ConvBnReLu(64, 8)
|
||||||
|
|
||||||
|
param_dict = mindspore.load_checkpoint("./ckpts/stage1_cost_compression.ckpt")
|
||||||
|
params_not_loaded = mindspore.load_param_into_net(self, param_dict, strict_load=True)
|
||||||
|
print(params_not_loaded)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.basicblock_0(x)
|
||||||
|
x = self.basicblock_1(x)
|
||||||
|
x = self.basicblock_2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Pseudo3DBlock_A(nn.Cell):
|
||||||
|
"""Pseudo3DBlock_A"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super(Pseudo3DBlock_A, self).__init__()
|
||||||
|
self.conv3d_0 = nn.Conv3d(in_channels, out_channels, (1, 3, 3), stride=1, padding=(0, 0, 1, 1, 1, 1),
|
||||||
|
pad_mode="pad")
|
||||||
|
self.conv3d_1 = nn.Conv3d(out_channels, out_channels, (3, 1, 1), stride=1, padding=(1, 1, 0, 0, 0, 0),
|
||||||
|
pad_mode="pad")
|
||||||
|
self.batchnorm3d_2 = nn.BatchNorm3d(out_channels, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.relu_3 = nn.ReLU()
|
||||||
|
self.conv3d_4 = nn.Conv3d(out_channels, out_channels, (1, 3, 3), stride=1, padding=(0, 0, 1, 1, 1, 1),
|
||||||
|
pad_mode="pad")
|
||||||
|
self.conv3d_5 = nn.Conv3d(out_channels, out_channels, (3, 1, 1), stride=1, padding=(1, 1, 0, 0, 0, 0),
|
||||||
|
pad_mode="pad")
|
||||||
|
self.batchnorm3d_6 = nn.BatchNorm3d(out_channels, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.relu_8 = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x1 = self.conv3d_0(x)
|
||||||
|
x1 = self.conv3d_1(x1)
|
||||||
|
x1 = self.batchnorm3d_2(x1)
|
||||||
|
x1 = self.relu_3(x1)
|
||||||
|
x1 = self.conv3d_4(x1)
|
||||||
|
x1 = self.conv3d_5(x1)
|
||||||
|
x1 = self.batchnorm3d_6(x1)
|
||||||
|
|
||||||
|
res = x
|
||||||
|
|
||||||
|
out = P.Add()(x1, res)
|
||||||
|
out = self.relu_8(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Pseudo3DBlock_B(nn.Cell):
|
||||||
|
"""Pseudo3DBlock_B"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Pseudo3DBlock_B, self).__init__()
|
||||||
|
self.conv3d_0 = nn.Conv3d(8, 8, (1, 3, 3), stride=(1, 2, 2), padding=(0, 0, 1, 1, 1, 1), pad_mode="pad")
|
||||||
|
self.conv3d_1 = nn.Conv3d(8, 16, (1, 1, 1), stride=2, padding=0, pad_mode="valid")
|
||||||
|
self.conv3d_2 = nn.Conv3d(8, 16, (3, 1, 1), stride=(2, 1, 1), padding=(1, 1, 0, 0, 0, 0), pad_mode="pad")
|
||||||
|
self.batchnorm3d_3 = nn.BatchNorm3d(16, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.batchnorm3d_4 = nn.BatchNorm3d(16, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.relu_5 = nn.ReLU()
|
||||||
|
self.conv3d_6 = nn.Conv3d(16, 16, (1, 3, 3), stride=1, padding=(0, 0, 1, 1, 1, 1), pad_mode="pad")
|
||||||
|
self.conv3d_7 = nn.Conv3d(16, 16, (3, 1, 1), stride=1, padding=(1, 1, 0, 0, 0, 0), pad_mode="pad")
|
||||||
|
self.batchnorm3d_8 = nn.BatchNorm3d(16, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.relu_10 = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x1 = self.conv3d_0(x)
|
||||||
|
x1 = self.conv3d_2(x1)
|
||||||
|
x1 = self.batchnorm3d_4(x1)
|
||||||
|
x1 = self.relu_5(x1)
|
||||||
|
x1 = self.conv3d_6(x1)
|
||||||
|
x1 = self.conv3d_7(x1)
|
||||||
|
x1 = self.batchnorm3d_8(x1)
|
||||||
|
|
||||||
|
res = self.conv3d_1(x)
|
||||||
|
res = self.batchnorm3d_3(res)
|
||||||
|
|
||||||
|
out = P.Add()(x1, res)
|
||||||
|
out = self.relu_10(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class CoarseStageRegFuse(nn.Cell):
|
||||||
|
"""CoarseStageRegFuse"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CoarseStageRegFuse, self).__init__()
|
||||||
|
self.basicblocka_0 = Pseudo3DBlock_A(8, 8)
|
||||||
|
self.basicblockb_0 = Pseudo3DBlock_B()
|
||||||
|
self.conv3dtranspose_21 = nn.Conv3dTranspose(16, 8, 3, stride=2, padding=1, pad_mode="pad", output_padding=1)
|
||||||
|
|
||||||
|
self.conv3d_23 = nn.Conv3d(16, 8, (1, 3, 3), stride=1, padding=(0, 0, 1, 1, 1, 1), pad_mode="pad")
|
||||||
|
self.conv3d_24 = nn.Conv3d(8, 8, (3, 1, 1), stride=1, padding=(1, 1, 0, 0, 0, 0), pad_mode="pad")
|
||||||
|
self.conv3d_25 = nn.Conv3d(8, 1, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
|
||||||
|
self.concat_1 = P.Concat(axis=1)
|
||||||
|
self.squeeze_1 = P.Squeeze(axis=1)
|
||||||
|
|
||||||
|
param_dict = mindspore.load_checkpoint("./ckpts/stage1_reg_fuse.ckpt")
|
||||||
|
params_not_loaded = mindspore.load_param_into_net(self, param_dict, strict_load=True)
|
||||||
|
print(params_not_loaded)
|
||||||
|
|
||||||
|
def construct(self, fused_interim, depth_values):
|
||||||
|
"""construct"""
|
||||||
|
x1 = self.basicblocka_0(fused_interim)
|
||||||
|
x2 = self.basicblockb_0(x1)
|
||||||
|
x1_upsample = self.conv3dtranspose_21(x2)
|
||||||
|
|
||||||
|
cost_volume = self.concat_1((x1_upsample, x1))
|
||||||
|
cost_volume = self.conv3d_23(cost_volume)
|
||||||
|
cost_volume = self.conv3d_24(cost_volume)
|
||||||
|
score_volume = self.conv3d_25(cost_volume)
|
||||||
|
|
||||||
|
score_volume = self.squeeze_1(score_volume)
|
||||||
|
|
||||||
|
prob_volume, _, prob_map = soft_argmin(score_volume, dim=1, keepdim=True, window=2)
|
||||||
|
est_depth = depth_regression(prob_volume, depth_values, keep_dim=True)
|
||||||
|
return est_depth, prob_map, prob_volume
|
||||||
|
|
||||||
|
|
||||||
|
class CoarseStageRegPair(nn.Cell):
|
||||||
|
"""CoarseStageRegPair"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CoarseStageRegPair, self).__init__()
|
||||||
|
self.basicblocka_0 = Pseudo3DBlock_A(8, 8)
|
||||||
|
self.basicblockb_0 = Pseudo3DBlock_B()
|
||||||
|
self.conv3dtranspose_21 = nn.Conv3dTranspose(16, 8, 3, stride=2, padding=1, pad_mode="pad", output_padding=1)
|
||||||
|
|
||||||
|
self.concat_22 = P.Concat(axis=1)
|
||||||
|
self.conv3d_23 = nn.Conv3d(16, 8, (1, 3, 3), stride=1, padding=(0, 0, 1, 1, 1, 1), pad_mode="pad")
|
||||||
|
self.conv3d_24 = nn.Conv3d(8, 8, (3, 1, 1), stride=1, padding=(1, 1, 0, 0, 0, 0), pad_mode="pad")
|
||||||
|
self.conv3d_25 = nn.Conv3d(8, 1, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
|
||||||
|
self.conv2d_38 = nn.Conv2d(1, 8, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
self.batchnorm2d_39 = nn.BatchNorm2d(num_features=8, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.leakyrelu_40 = nn.LeakyReLU(alpha=0.009999999776482582)
|
||||||
|
self.conv2d_41 = nn.Conv2d(8, 8, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
self.batchnorm2d_42 = nn.BatchNorm2d(num_features=8, eps=9.999999747378752e-06, momentum=0.8999999761581421)
|
||||||
|
self.leakyrelu_43 = nn.LeakyReLU(alpha=0.009999999776482582)
|
||||||
|
self.conv2d_45 = nn.Conv2d(8, 1, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
self.conv2d_46 = nn.Conv2d(8, 1, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
|
||||||
|
self.concat_1 = P.Concat(axis=1)
|
||||||
|
self.squeeze_1 = P.Squeeze(axis=1)
|
||||||
|
|
||||||
|
param_dict = mindspore.load_checkpoint("./ckpts/stage1_reg_pair.ckpt")
|
||||||
|
params_not_loaded = mindspore.load_param_into_net(self, param_dict, strict_load=True)
|
||||||
|
print(params_not_loaded)
|
||||||
|
|
||||||
|
def construct(self, cost_volume, depth_values):
|
||||||
|
"""construct"""
|
||||||
|
x1 = self.basicblocka_0(cost_volume)
|
||||||
|
x2 = self.basicblockb_0(x1)
|
||||||
|
x1_upsample = self.conv3dtranspose_21(x2)
|
||||||
|
|
||||||
|
interim = self.concat_1((x1_upsample, x1))
|
||||||
|
interim = self.conv3d_23(interim)
|
||||||
|
interim = self.conv3d_24(interim)
|
||||||
|
score_volume = self.conv3d_25(interim)
|
||||||
|
|
||||||
|
score_volume = self.squeeze_1(score_volume)
|
||||||
|
prob_volume, _ = soft_argmin(score_volume, dim=1, keepdim=True)
|
||||||
|
est_depth = depth_regression(prob_volume, depth_values, keep_dim=True)
|
||||||
|
entropy_ = entropy(prob_volume, dim=1, keepdim=True)
|
||||||
|
|
||||||
|
x = self.conv2d_38(entropy_)
|
||||||
|
x = self.batchnorm2d_39(x)
|
||||||
|
x = self.leakyrelu_40(x)
|
||||||
|
x = self.conv2d_41(x)
|
||||||
|
x = self.batchnorm2d_42(x)
|
||||||
|
x = self.leakyrelu_43(x)
|
||||||
|
|
||||||
|
out = P.Add()(x, entropy_)
|
||||||
|
uncertainty_map = self.conv2d_45(out)
|
||||||
|
occ = self.conv2d_46(out)
|
||||||
|
|
||||||
|
return interim, est_depth, uncertainty_map, occ
|
||||||
|
|
||||||
|
|
||||||
|
class StageRegFuse(nn.Cell):
|
||||||
|
"""StageRegFuse"""
|
||||||
|
|
||||||
|
def __init__(self, ckpt_path):
|
||||||
|
super(StageRegFuse, self).__init__()
|
||||||
|
self.basicblocka_0 = Pseudo3DBlock_A(8, 8)
|
||||||
|
self.basicblocka_1 = Pseudo3DBlock_A(8, 8)
|
||||||
|
self.basicblockb_0 = Pseudo3DBlock_B()
|
||||||
|
self.basicblocka_2 = Pseudo3DBlock_A(16, 16)
|
||||||
|
self.conv3dtranspose_38 = nn.Conv3dTranspose(16, 8, 3, stride=2, padding=1, pad_mode="pad", output_padding=1)
|
||||||
|
|
||||||
|
self.concat_39 = P.Concat(axis=1)
|
||||||
|
self.conv3d_40 = nn.Conv3d(16, 8, (1, 3, 3), stride=1, padding=(0, 0, 1, 1, 1, 1), pad_mode="pad")
|
||||||
|
self.conv3d_41 = nn.Conv3d(8, 8, (3, 1, 1), stride=1, padding=(1, 1, 0, 0, 0, 0), pad_mode="pad")
|
||||||
|
self.conv3d_42 = nn.Conv3d(8, 1, 3, stride=1, padding=1, pad_mode="pad")
|
||||||
|
|
||||||
|
self.concat_1 = P.Concat(axis=1)
|
||||||
|
self.squeeze_1 = P.Squeeze(axis=1)
|
||||||
|
|
||||||
|
param_dict = mindspore.load_checkpoint(ckpt_path)
|
||||||
|
params_not_loaded = mindspore.load_param_into_net(self, param_dict, strict_load=True)
|
||||||
|
print(params_not_loaded)
|
||||||
|
|
||||||
|
def construct(self, fused_interim, depth_values):
|
||||||
|
"""construct"""
|
||||||
|
x1 = self.basicblocka_0(fused_interim)
|
||||||
|
x1 = self.basicblocka_1(x1)
|
||||||
|
x2 = self.basicblockb_0(x1)
|
||||||
|
x2 = self.basicblocka_2(x2)
|
||||||
|
x1_upsample = self.conv3dtranspose_38(x2)
|
||||||
|
|
||||||
|
cost_volume = self.concat_1((x1_upsample, x1))
|
||||||
|
cost_volume = self.conv3d_40(cost_volume)
|
||||||
|
cost_volume = self.conv3d_41(cost_volume)
|
||||||
|
score_volume = self.conv3d_42(cost_volume)
|
||||||
|
|
||||||
|
score_volume = self.squeeze_1(score_volume)
|
||||||
|
|
||||||
|
prob_volume, _, prob_map = soft_argmin(score_volume, dim=1, keepdim=True, window=2)
|
||||||
|
est_depth = depth_regression(prob_volume, depth_values, keep_dim=True)
|
||||||
|
return est_depth, prob_map, prob_volume
|
|
@ -0,0 +1,116 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""other functions of EPP-MVSNet"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def read_pfm(filename):
|
||||||
|
"""read pfm files"""
|
||||||
|
file = open(filename, 'rb')
|
||||||
|
color = None
|
||||||
|
width = None
|
||||||
|
height = None
|
||||||
|
scale = None
|
||||||
|
endian = None
|
||||||
|
|
||||||
|
header = file.readline().decode('utf-8').rstrip()
|
||||||
|
if header == 'PF':
|
||||||
|
color = True
|
||||||
|
elif header == 'Pf':
|
||||||
|
color = False
|
||||||
|
else:
|
||||||
|
raise Exception('Not a PFM file.')
|
||||||
|
|
||||||
|
dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
|
||||||
|
if dim_match:
|
||||||
|
width, height = map(int, dim_match.groups())
|
||||||
|
else:
|
||||||
|
raise Exception('Malformed PFM header.')
|
||||||
|
|
||||||
|
scale = float(file.readline().rstrip())
|
||||||
|
if scale < 0: # little-endian
|
||||||
|
endian = '<'
|
||||||
|
scale = -scale
|
||||||
|
else:
|
||||||
|
endian = '>' # big-endian
|
||||||
|
|
||||||
|
data = np.fromfile(file, endian + 'f')
|
||||||
|
shape = (height, width, 3) if color else (height, width)
|
||||||
|
|
||||||
|
data = np.reshape(data, shape)
|
||||||
|
data = np.flipud(data)
|
||||||
|
file.close()
|
||||||
|
return data, scale
|
||||||
|
|
||||||
|
|
||||||
|
def save_pfm(filename, image, scale=1):
|
||||||
|
"""save pfm files"""
|
||||||
|
file = open(filename, "wb")
|
||||||
|
color = None
|
||||||
|
|
||||||
|
image = np.flipud(image)
|
||||||
|
|
||||||
|
if image.dtype.name != 'float32':
|
||||||
|
raise Exception('Image dtype must be float32.')
|
||||||
|
|
||||||
|
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
||||||
|
color = True
|
||||||
|
elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
|
||||||
|
color = False
|
||||||
|
else:
|
||||||
|
raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
|
||||||
|
|
||||||
|
file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
|
||||||
|
file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8'))
|
||||||
|
|
||||||
|
endian = image.dtype.byteorder
|
||||||
|
|
||||||
|
if endian == '<' or endian == '=' and sys.byteorder == 'little':
|
||||||
|
scale = -scale
|
||||||
|
|
||||||
|
file.write(('%f\n' % scale).encode('utf-8'))
|
||||||
|
|
||||||
|
image.tofile(file)
|
||||||
|
file.close()
|
||||||
|
|
||||||
|
|
||||||
|
class AverageMeter:
|
||||||
|
"""Computes and stores the average and current value"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""reset averagemeter"""
|
||||||
|
self.val = 0
|
||||||
|
self.avg = 0
|
||||||
|
self.sum = 0
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
def update(self, val, n=1):
|
||||||
|
"""update averagemeter"""
|
||||||
|
self.val = val
|
||||||
|
self.sum += val * n
|
||||||
|
self.count += n
|
||||||
|
if self.count > 0:
|
||||||
|
self.avg = self.sum / self.count
|
|
@ -0,0 +1,180 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""EPP-MVSNet's validation process on BlendedMVS dataset"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
from src.eppmvsnet import EPPMVSNet
|
||||||
|
from src.blendedmvs import BlendedMVSDataset
|
||||||
|
from src.utils import save_pfm, AverageMeter
|
||||||
|
|
||||||
|
|
||||||
|
def get_opts():
|
||||||
|
"""set options"""
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument('--gpu_id', type=int, default=0, choices=[0, 1, 2, 3, 4, 5, 6, 7],
|
||||||
|
help='which gpu used to inference')
|
||||||
|
## vis
|
||||||
|
parser.add_argument('--root_dir', type=str,
|
||||||
|
default='/home/ubuntu/data/DTU/mvs_training/dtu/',
|
||||||
|
help='root directory of dtu dataset')
|
||||||
|
parser.add_argument('--dataset_name', type=str, default='blendedmvs',
|
||||||
|
choices=['blendedmvs'],
|
||||||
|
help='which dataset to train/val')
|
||||||
|
parser.add_argument('--split', type=str, default=None,
|
||||||
|
help='which split to evaluate')
|
||||||
|
parser.add_argument('--scan', type=str, default=None, nargs='+',
|
||||||
|
help='specify scan to evaluate (must be in the split)')
|
||||||
|
# for depth prediction
|
||||||
|
parser.add_argument('--n_views', type=int, default=5,
|
||||||
|
help='number of views (including ref) to be used in testing')
|
||||||
|
parser.add_argument('--depth_interval', type=float, default=128,
|
||||||
|
help='depth interval unit in mm')
|
||||||
|
parser.add_argument('--n_depths', nargs='+', type=int, default=[32, 16, 8],
|
||||||
|
help='number of depths in each level')
|
||||||
|
parser.add_argument('--interval_ratios', nargs='+', type=float, default=[4.0, 2.0, 1.0],
|
||||||
|
help='depth interval ratio to multiply with --depth_interval in each level')
|
||||||
|
parser.add_argument('--img_wh', nargs="+", type=int, default=[1152, 864],
|
||||||
|
help='resolution (img_w, img_h) of the image, must be multiples of 32')
|
||||||
|
parser.add_argument('--ckpt_path', type=str, default='ckpts/exp2/_ckpt_epoch_10.ckpt',
|
||||||
|
help='pretrained checkpoint path to load')
|
||||||
|
parser.add_argument('--save_visual', default=False, action='store_true',
|
||||||
|
help='save depth and proba visualization or not')
|
||||||
|
parser.add_argument('--entropy_range', action='store_true', default=False,
|
||||||
|
help='whether to use entropy range method')
|
||||||
|
parser.add_argument('--conf', type=float, default=0.9,
|
||||||
|
help='min confidence for pixel to be valid')
|
||||||
|
parser.add_argument('--levels', type=int, default=3, choices=[3, 4, 5],
|
||||||
|
help='number of FPN levels (fixed to be 3!)')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = get_opts()
|
||||||
|
context.set_context(mode=0, device_target='GPU', device_id=args.gpu_id, save_graphs=False)
|
||||||
|
|
||||||
|
dataset = BlendedMVSDataset(args.root_dir, args.split, n_views=args.n_views, depth_interval=args.depth_interval,
|
||||||
|
img_wh=tuple(args.img_wh), levels=args.levels, scan=args.scan)
|
||||||
|
img_wh = args.img_wh
|
||||||
|
scans = dataset.scans
|
||||||
|
|
||||||
|
print(args.n_depths)
|
||||||
|
print(args.interval_ratios)
|
||||||
|
# Step 1. Create depth estimation and probability for each scan
|
||||||
|
EPPMVSNet_eval = EPPMVSNet(n_depths=args.n_depths, interval_ratios=args.interval_ratios,
|
||||||
|
entropy_range=args.entropy_range, height=args.img_wh[1], width=args.img_wh[0])
|
||||||
|
EPPMVSNet_eval.set_train(False)
|
||||||
|
|
||||||
|
depth_dir = f'results/{args.dataset_name}/{args.split}/depth'
|
||||||
|
print('Creating depth and confidence predictions...')
|
||||||
|
if args.scan:
|
||||||
|
data_range = [i for i, x in enumerate(dataset.metas) if x[0] == args.scan]
|
||||||
|
else:
|
||||||
|
data_range = range(len(dataset))
|
||||||
|
test_loader = ds.GeneratorDataset(dataset, column_names=["imgs", "proj_mats", "init_depth_min", "depth_interval",
|
||||||
|
"scan", "vid", "depth_0", "mask_0", "fix_depth_interval"],
|
||||||
|
num_parallel_workers=1, shuffle=False)
|
||||||
|
test_loader = test_loader.batch(batch_size=1)
|
||||||
|
test_data_size = test_loader.get_dataset_size()
|
||||||
|
print("train dataset length is:", test_data_size)
|
||||||
|
|
||||||
|
pbar = tqdm(enumerate(test_loader.create_tuple_iterator()), dynamic_ncols=True, total=test_data_size)
|
||||||
|
|
||||||
|
metrics = ['stage3_l1_loss', 'stage3_less1_acc', 'stage3_less3_acc']
|
||||||
|
avg_metrics = {t: AverageMeter() for t in metrics}
|
||||||
|
|
||||||
|
forward_time_avg = AverageMeter()
|
||||||
|
|
||||||
|
scan_list, vid_list = [], []
|
||||||
|
|
||||||
|
depth_folder = f'{img_wh[0]}_{img_wh[1]}_{args.n_views - 1}'
|
||||||
|
|
||||||
|
for i, sample in pbar:
|
||||||
|
imgs, proj_mats, init_depth_min, depth_interval, scan, vid, depth_0, mask_0, fix_depth_interval = sample
|
||||||
|
scan = scan[0].asnumpy()
|
||||||
|
scan_str = ""
|
||||||
|
for num in scan:
|
||||||
|
scan_str += chr(num)
|
||||||
|
scan = scan_str
|
||||||
|
vid = vid[0].asnumpy()
|
||||||
|
|
||||||
|
depth_file_dir = os.path.join(depth_dir, scan, depth_folder)
|
||||||
|
if not os.path.exists(depth_file_dir):
|
||||||
|
os.makedirs(depth_file_dir, exist_ok=True)
|
||||||
|
|
||||||
|
begin = time.time()
|
||||||
|
|
||||||
|
results = EPPMVSNet_eval(imgs, proj_mats, init_depth_min, depth_interval)
|
||||||
|
|
||||||
|
forward_time = time.time() - begin
|
||||||
|
if i != 0:
|
||||||
|
forward_time_avg.update(forward_time)
|
||||||
|
|
||||||
|
depth, proba = results
|
||||||
|
depth = P.Squeeze()(depth).asnumpy()
|
||||||
|
depth = np.nan_to_num(depth) # change nan to 0
|
||||||
|
proba = P.Squeeze()(proba).asnumpy()
|
||||||
|
proba = np.nan_to_num(proba) # change nan to 0
|
||||||
|
|
||||||
|
save_pfm(os.path.join(depth_dir, f'{scan}/{depth_folder}/depth_{vid:04d}.pfm'), depth)
|
||||||
|
save_pfm(os.path.join(depth_dir, f'{scan}/{depth_folder}/proba_{vid:04d}.pfm'), proba)
|
||||||
|
|
||||||
|
# record l1 loss of each image
|
||||||
|
scan_list.append(scan)
|
||||||
|
vid_list.append(vid)
|
||||||
|
|
||||||
|
pred_depth = depth
|
||||||
|
gt = P.Squeeze()(depth_0).asnumpy()
|
||||||
|
mask = P.Squeeze()(mask_0).asnumpy()
|
||||||
|
|
||||||
|
abs_err = np.abs(pred_depth - gt)
|
||||||
|
abs_err_scaled = abs_err / fix_depth_interval.asnumpy()
|
||||||
|
|
||||||
|
l1 = abs_err_scaled[mask].mean()
|
||||||
|
less1 = (abs_err_scaled[mask] < 1.).astype(np.float32).mean()
|
||||||
|
less3 = (abs_err_scaled[mask] < 3.).astype(np.float32).mean()
|
||||||
|
|
||||||
|
avg_metrics[f'stage3_l1_loss'].update(l1)
|
||||||
|
avg_metrics[f'stage3_less1_acc'].update(less1)
|
||||||
|
avg_metrics[f'stage3_less3_acc'].update(less3)
|
||||||
|
|
||||||
|
if args.save_visual:
|
||||||
|
mi = np.min(depth[depth > 0])
|
||||||
|
ma = np.max(depth)
|
||||||
|
depth = (depth - mi) / (ma - mi + 1e-8)
|
||||||
|
depth = (255 * depth).astype(np.uint8)
|
||||||
|
depth_img = cv2.applyColorMap(depth, cv2.COLORMAP_JET)
|
||||||
|
cv2.imwrite(os.path.join(depth_dir, f'{scan}/{depth_folder}/depth_visual_{vid:04d}.jpg'), depth_img)
|
||||||
|
cv2.imwrite(os.path.join(depth_dir, f'{scan}/{depth_folder}/proba_visual_{vid:04d}.jpg'),
|
||||||
|
(255 * (proba > args.conf)).astype(np.uint8))
|
||||||
|
print(f'step {i} time: {forward_time}s')
|
||||||
|
print(f'mean forward time: {forward_time_avg.avg}')
|
||||||
|
|
||||||
|
with open(f'results/{args.dataset_name}/{args.split}/metrics.txt', 'w') as f:
|
||||||
|
for i in avg_metrics.items():
|
||||||
|
f.writelines((i[0]) + ':' + str(np.round(i[1].avg, 4)) + '\n')
|
||||||
|
f.writelines('mean forward time(s/pic):' + str(np.round(forward_time_avg.avg, 4)) + '\n')
|
||||||
|
f.close()
|
||||||
|
print('Done!')
|
Loading…
Reference in New Issue