### Code written by Pranav Durai 
### Fine-Tuning SegFormer for Improved Lane Detection 

In [1]:


import copy
import os

import cv2
import numpy as np
import requests
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as TF
from torchvision.datasets.folder import pil_loader
from torchvision.transforms.functional import to_tensor

from transformers import SegformerForSemanticSegmentation
from transformers import get_scheduler

from sklearn.metrics import jaccard_score

In [None]:
def download_dataset(url, local_filename, timeout=60):
    """Download ``url`` and save the response body to ``local_filename``.

    The body is streamed to disk in chunks rather than held in memory, and is
    written to ``<local_filename>.part`` first so that an interrupted transfer
    never leaves a truncated file under the final name.

    Args:
        url: Address of the file to fetch. For ``www.dropbox.com`` links a
            ``dl=0`` query parameter is rewritten to ``dl=1``.
        local_filename: Path the downloaded file is saved to.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``.

    Returns:
        None. Success or a non-200 status code is reported with ``print``.

    Raises:
        requests.RequestException: Propagated from ``requests`` on network
            errors (including timeouts).
        OSError: If the destination cannot be written.
    """
    # Update Dropbox link to force download. The flag may be the first query
    # parameter ("?dl=0") or a later one ("&dl=0").
    if "www.dropbox.com" in url:
        url = url.replace("?dl=0", "?dl=1").replace("&dl=0", "&dl=1")

    partial_filename = f"{local_filename}.part"

    # stream=True defers reading the body until iter_content is consumed.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Failed to download file. Status code: {response.status_code}")
            return

        try:
            with open(partial_filename, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        except Exception:
            # Do not leave a half-written file behind.
            if os.path.exists(partial_filename):
                os.remove(partial_filename)
            raise

    # Only a complete download is moved to the final name.
    os.replace(partial_filename, local_filename)
    print(f"File downloaded and saved as {local_filename}")

In [None]:
# Fetch the BDD100K subset archive (a 10% sample, per the original note) and
# store it locally as BDD.zip
download_dataset(
    url='https://www.dropbox.com/scl/fi/40onxgztkbtqxvsg2d6fk/deep_drive_10K.zip?rlkey=8h098tbe9dry81jidtte1d9j5&dl=1',
    local_filename='BDD.zip',
)

In [2]:
class BDDDataset(Dataset):
    """Paired image / binary lane-mask dataset read from two directories.

    Every ``.jpg`` file found in ``images_dir`` is paired with the ``.png``
    file of the same stem in ``masks_dir``.

    Args:
        images_dir: Directory containing the ``.jpg`` input images.
        masks_dir: Directory containing the matching ``.png`` masks.
        transform: Optional callable applied to the RGB PIL image only; the
            mask is never passed through it.
        mask_size: ``(height, width)`` the mask is resized to. It should match
            the spatial size ``transform`` produces for the image.
    """

    def __init__(self, images_dir, masks_dir, transform=None, mask_size=(360, 640)):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.mask_size = list(mask_size)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible between runs and machines.
        self.images = sorted(
            name for name in os.listdir(images_dir) if name.endswith('.jpg')
        )
        # Swap only the extension: str.replace would also rewrite a '.jpg'
        # that happens to occur inside the file stem.
        self.masks = [os.path.splitext(name)[0] + '.png' for name in self.images]

    def __len__(self):
        """Return the number of image / mask pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A tuple ``(image, mask)``. ``image`` is the output of
            ``transform`` (or the RGB PIL image when no transform is set).
            ``mask`` is a ``torch.long`` tensor of shape ``[1, H, W]`` holding
            1 where the stored mask is non-zero and 0 elsewhere.
        """
        image_path = os.path.join(self.images_dir, self.images[idx])
        mask_path = os.path.join(self.masks_dir, self.masks[idx])

        # torchvision's PIL loader opens the file, converts it to RGB and
        # closes the handle; the bare `Image` name is not imported in this file.
        image = pil_loader(image_path)
        mask = pil_loader(mask_path).convert('L')  # Convert mask to grayscale

        if self.transform:
            image = self.transform(image)

        # Nearest-neighbour resizing never blends pixels, so resizing first
        # and thresholding afterwards gives the same 0/1 mask as the reverse.
        mask = TF.functional.resize(
            mask,
            size=self.mask_size,
            interpolation=TF.InterpolationMode.NEAREST,
        )
        # Assuming non-zero pixels are lanes
        mask = torch.from_numpy(np.array(mask) > 0).long().unsqueeze(0)

        return image, mask

def mean_iou(preds, labels, num_classes):
    """Return the mean per-class IoU (Jaccard score) of ``preds`` vs ``labels``.

    Args:
        preds: Tensor of predicted class indices, any shape.
        labels: Tensor of ground-truth class indices with the same number of
            elements as ``preds``.
        num_classes: Number of classes; the score is computed for the class
            ids ``0 .. num_classes - 1`` and averaged over them.

    Returns:
        The unweighted mean of the per-class Jaccard scores.

    Raises:
        ValueError: If ``preds`` and ``labels`` differ in element count.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose or a strided slice.
    preds_flat = preds.reshape(-1)
    labels_flat = labels.reshape(-1)

    # Check that the number of elements in the flattened predictions
    # and labels are equal
    if preds_flat.shape[0] != labels_flat.shape[0]:
        raise ValueError(f"Predictions and labels have mismatched shapes: "
                         f"{preds_flat.shape} vs {labels_flat.shape}")

    # scikit-learn works on host NumPy arrays; detach() makes the conversion
    # valid even when a tensor is still attached to an autograd graph.
    iou = jaccard_score(labels_flat.detach().cpu().numpy(),
                        preds_flat.detach().cpu().numpy(),
                        average=None, labels=range(num_classes))

    # Return the mean IoU
    return np.mean(iou)

In [3]:
# Image preprocessing: resize to 360x640, convert to a tensor, then normalise
# each channel with the mean / std below (these look like the usual ImageNet
# statistics - confirm they match the checkpoint's preprocessing).
transform = TF.Compose(
    [
        TF.Resize((360, 640)),
        TF.ToTensor(),
        TF.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Training and validation splits share the same image preprocessing
train_dataset = BDDDataset(
    'deep_drive_10K/train/images',
    'deep_drive_10K/train/masks',
    transform,
)
valid_dataset = BDDDataset(
    'deep_drive_10K/valid/images',
    'deep_drive_10K/valid/masks',
    transform,
)

# Batches of 4 loaded by 6 worker processes; only the training split is shuffled
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=6,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=6,
)

In [4]:
# Load the pre-trained checkpoint with a 2-class segmentation head.
# Assigning model.config.num_labels after loading only edits the config; the
# classifier layer keeps the checkpoint's original number of output channels.
# Passing num_labels here builds the head with 2 outputs, and
# ignore_mismatched_sizes lets the checkpoint's differently-shaped head
# weights be skipped (the new head starts randomly initialised).
model = SegformerForSemanticSegmentation.from_pretrained(
    'nvidia/segformer-b2-finetuned-ade-512-512',
    num_labels=2,  # Replace with the actual number of classes
    ignore_mismatched_sizes=True,
)

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move the model's parameters and buffers onto the selected device
model.to(device)

In [None]:
# Define the optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Define the learning rate scheduler: "linear" schedule with no warm-up steps,
# spanning every optimisation step of the whole run
num_epochs = 30
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

# Placeholder for best mean IoU and best model weights
best_iou = 0.0
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    train_iterator = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch")
    for batch in train_iterator:
        images, masks = batch
        images = images.to(device)
        masks = masks.to(device).long()  # Ensure masks are LongTensors

        # Remove the channel dimension from the masks tensor
        masks = masks.squeeze(1)  # This changes the shape from [batch, 1, H, W] to [batch, H, W]
        optimizer.zero_grad()

        # Pass pixel_values and labels to the model so it returns the loss
        outputs = model(pixel_values=images, labels=masks, return_dict=True)

        loss = outputs["loss"]
        loss.backward()

        optimizer.step()
        lr_scheduler.step()  # The schedule is stepped once per batch, not per epoch

        train_iterator.set_postfix(loss=loss.item())

    # ---- Evaluation pass for this epoch ----
    model.eval()
    total_iou = 0
    num_batches = 0
    valid_iterator = tqdm(valid_loader, desc="Validation", unit="batch")
    # No gradients are needed anywhere in the validation pass
    with torch.no_grad():
        for batch in valid_iterator:
            images, masks = batch
            images = images.to(device)
            masks = masks.to(device).long()

            # Resize the logits to the mask resolution, then take the
            # per-pixel argmax over the class dimension
            outputs = model(pixel_values=images, return_dict=True)
            logits = F.interpolate(outputs["logits"], size=masks.shape[-2:], mode="bilinear", align_corners=False)
            preds = torch.argmax(logits, dim=1)

            # Compute IoU (mean_iou flattens both tensors itself)
            iou = mean_iou(preds, masks, model.config.num_labels)
            total_iou += iou
            num_batches += 1
            valid_iterator.set_postfix(mean_iou=iou)

    # Average of the per-batch scores; guard against an empty validation loader
    epoch_iou = total_iou / num_batches if num_batches else 0.0
    print(f"Epoch {epoch+1}/{num_epochs} - Mean IoU: {epoch_iou:.4f}")

    # Check for improvement
    if epoch_iou > best_iou:
        print(f"Validation IoU improved from {best_iou:.4f} to {epoch_iou:.4f}")
        best_iou = epoch_iou
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(best_model_wts, 'best_model.pth')

# After all epochs, restore the best weights. The in-memory copy is identical
# to what was last written to best_model.pth, and it also exists when no epoch
# improved on the initial score and the file was therefore never created.
model.load_state_dict(best_model_wts)
print("Loaded the best model weights!")