# Stereo Vision Transformer - Inference on KITTI 2015 Stereo Vision Dataset

#### Code written by Pranav Durai

In [1]:
from PIL import Image
import torch
import numpy as np
import cv2
import glob
import os


import argparse
import matplotlib.pyplot as plt
import sys
sys.path.append('../') # add relative path

from module.sttr import STTR
from dataset.preprocess import normalization, compute_left_occ_region
from utilities.misc import NestedTensor

In [2]:
# Function to load images
def load_images(image_dir, pattern, limit=500):
    """Load the images in *image_dir* matching *pattern* as NumPy arrays.

    Files are processed in sorted path order, so two directories holding
    identically named files yield lists that line up index-by-index.

    Args:
        image_dir: Directory to search.
        pattern: glob pattern relative to ``image_dir``, e.g. ``'*.png'``.
        limit: Maximum number of files to load (taken after sorting).
            ``None`` loads every match. Defaults to 500, the previous
            hard-coded cap.

    Returns:
        A list of ``np.ndarray`` (one per file, in sorted filename order);
        empty if nothing matches.
    """
    # Slicing with limit=None keeps the whole list.
    filenames = sorted(glob.glob(os.path.join(image_dir, pattern)))[:limit]
    images = []
    for filename in filenames:
        # The context manager closes the file handle immediately instead of
        # leaving it open until garbage collection; np.array() forces the
        # pixel data to be read before the file is closed.
        with Image.open(filename) as img:
            images.append(np.array(img))
    return images

In [3]:
# Default STTR hyper-parameters, passed to STTR(args) below.
# These should match the configuration used to train the checkpoint that is
# loaded later; note that with strict=False, differently named layers would be
# skipped during loading rather than reported as errors.
# argparse.Namespace is the standard attribute container for such settings
# (and has a readable repr), replacing an ad-hoc anonymous class.
args = argparse.Namespace(
    channel_dim=128,
    position_encoding='sine1d_rel',
    num_attn_layers=6,
    nheads=8,
    regression_head='ot',
    context_adjustment_layer='cal',
    cal_num_blocks=8,
    cal_feat_dim=16,
    cal_expansion_ratio=4,
)

In [4]:
# Build the network from the hyper-parameters above, switch it to inference
# mode (eval() changes layers such as dropout / batch-norm to their evaluation
# behaviour) and move its parameters to the GPU.
model = STTR(args)
model.eval()
model = model.cuda()



In [5]:
# Load the pretrained KITTI-finetuned weights into `model`.
model_file_name = "../kitti_finetuned_model.pth.tar"
# map_location='cpu' keeps whatever else the checkpoint may hold off the GPU;
# load_state_dict copies the weights into the (already CUDA) model anyway.
checkpoint = torch.load(model_file_name, map_location='cpu')
pretrained_dict = checkpoint['state_dict']
# strict=False tolerates key mismatches (the original intent: prevent BN
# parameters from breaking the model loading), but it also hides genuinely
# missing weights. Report anything that was skipped instead of ignoring it.
load_result = model.load_state_dict(pretrained_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} model keys missing from checkpoint "
          f"(first few: {load_result.missing_keys[:10]})")
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} checkpoint keys not used by the model "
          f"(first few: {load_result.unexpected_keys[:10]})")
print("Pre-trained model successfully loaded.")

Pre-trained model successfully loaded.


In [6]:
# Load the left (image_2) and right (image_3) views of each stereo pair.
left_image_dir = '../sample_data/KITTI_2015/2015/training/image_2'
right_image_dir = '../sample_data/KITTI_2015/2015/training/image_3'
left_images = load_images(left_image_dir, '*.png')
right_images = load_images(right_image_dir, '*.png')

# Fail early with a clear message: zip() in the inference loop would silently
# drop or mispair frames if the counts differ, and an empty list would only
# surface later as an IndexError when reading left_images[0].
if not left_images:
    raise FileNotFoundError(f"No '*.png' images found in {left_image_dir}")
if len(left_images) != len(right_images):
    raise ValueError(
        f"Found {len(left_images)} left images in {left_image_dir} but "
        f"{len(right_images)} right images in {right_image_dir}; "
        "stereo pairs cannot be matched."
    )

In [7]:
# Record the size of the first left frame and prepare the directory that the
# inference loop writes its PNG visualisations into. (No video is produced.)
# The loop reads each image's own shape rather than relying on height/width.
height, width, _ = left_images[0].shape
output_dir = '../inference_output/'
os.makedirs(output_dir, exist_ok=True)  # exist_ok: re-running the cell is a no-op

In [8]:
# Run STTR on every stereo pair and save "disparity | occlusion" side by side.
#
# sampled_cols / sampled_rows select every `downsample`-th column / row,
# starting at the centre offset of each stride window; presumably STTR uses
# them to subsample its attention (confirm in module.sttr). These settings do
# not depend on the image, so they are set once outside the loop.
bs = 1
downsample = 3
col_offset = downsample // 2
row_offset = downsample // 2

# Pure inference: no_grad stops autograd from recording each forward pass,
# which would otherwise keep intermediate activations alive on the GPU.
with torch.no_grad():
    for i, (left, right) in enumerate(zip(left_images, right_images)):
        # Normalize the pair and wrap it, with the sampling indices, in a NestedTensor.
        input_data = normalization(left=left, right=right)
        h, w, _ = left.shape  # per-image shape: frames need not share one size
        sampled_cols = torch.arange(col_offset, w, downsample)[None,].expand(bs, -1).cuda()
        sampled_rows = torch.arange(row_offset, h, downsample)[None,].expand(bs, -1).cuda()
        input_data = NestedTensor(
            input_data['left'].cuda()[None,],
            input_data['right'].cuda()[None,],
            sampled_cols=sampled_cols,
            sampled_rows=sampled_rows,
        )

        # Perform inference, drop the batch dimension, and zero the disparity
        # wherever occ_pred exceeds 0.5 (presumably an occlusion probability).
        output = model(input_data)
        disp_pred = output['disp_pred'].data.cpu().numpy()[0]
        occ_pred = output['occ_pred'].data.cpu().numpy()[0] > 0.5
        disp_pred[occ_pred] = 0.0

        # Stretch disparity to 0-255 per image for display; mask becomes 0/255.
        disp_pred_norm = cv2.normalize(disp_pred, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        occ_pred_uint8 = np.uint8(occ_pred * 255)

        # Combine predicted disparity and occlusion map horizontally.
        combined_output = np.hstack((disp_pred_norm, occ_pred_uint8))

        # cv2.imwrite signals failure through its return value, not an
        # exception, so check it rather than claiming success below.
        output_filename = os.path.join(output_dir, f'inference_{i:03d}.png')
        if not cv2.imwrite(output_filename, combined_output):
            raise OSError(f"Failed to write {output_filename}")
        print(f"Saved: {output_filename}")

print("All inferences saved as PNG files.")



Saved: ../inference_output/inference_000.png
Saved: ../inference_output/inference_001.png
Saved: ../inference_output/inference_002.png
Saved: ../inference_output/inference_003.png
Saved: ../inference_output/inference_004.png
Saved: ../inference_output/inference_005.png
Saved: ../inference_output/inference_006.png
Saved: ../inference_output/inference_007.png
Saved: ../inference_output/inference_008.png
Saved: ../inference_output/inference_009.png
Saved: ../inference_output/inference_010.png
Saved: ../inference_output/inference_011.png
Saved: ../inference_output/inference_012.png
Saved: ../inference_output/inference_013.png
Saved: ../inference_output/inference_014.png
Saved: ../inference_output/inference_015.png
Saved: ../inference_output/inference_016.png
Saved: ../inference_output/inference_017.png
Saved: ../inference_output/inference_018.png
Saved: ../inference_output/inference_019.png
Saved: ../inference_output/inference_020.png
Saved: ../inference_output/inference_021.png
Saved: ../