# Varifocal Loss 

# In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


# Source: https://github.com/hyz-xmaster/VarifocalNet
def varifocal_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    alpha: float = 0.75,
    gamma: float = 2.0,
    iou_weighted: bool = True,
):
    """Compute the element-wise Varifocal Loss (https://arxiv.org/abs/2008.13367).

    Each element's binary cross-entropy (computed from raw logits) is scaled
    by a focal weight:

    * positive entries (``labels > 0``) are scaled by their target score when
      ``iou_weighted`` is True, or by 1 otherwise;
    * all other entries are scaled by
      ``alpha * |sigmoid(logits) - labels| ** gamma``.

    No reduction is applied.

    Args:
        logits: Raw (pre-sigmoid) predictions, e.g. shape (N, C) where C is
            the number of classes.
        labels: IoU-aware classification targets; must have exactly the same
            size as ``logits``. Converted to the dtype/device of ``logits``.
        weight: Optional tensor multiplied element-wise into the loss
            (standard torch broadcasting applies). Defaults to None.
        alpha: Scale of the negative-entry term; this is not the same alpha
            as in Focal Loss. Defaults to 0.75.
        gamma: Exponent of the negative-entry modulating factor.
            Defaults to 2.0.
        iou_weighted: Whether positive entries are weighted by their target
            score (True) or by a constant 1 (False). Defaults to True.

    Returns:
        torch.Tensor: Per-element loss with the same shape as ``logits``.

    Raises:
        AssertionError: If ``logits`` and ``labels`` differ in size.
    """
    assert logits.size() == labels.size()
    targets = labels.type_as(logits)
    probs = logits.sigmoid()

    # Float 0/1 masks (rather than torch.where) so the weight is built with
    # plain arithmetic on both branches.
    positive = (targets > 0.0).float()
    negative = (targets <= 0.0).float()

    # Positives keep their IoU-aware target as weight, or a flat 1.
    pos_term = targets * positive if iou_weighted else positive
    # Negatives are down-weighted more strongly the closer the prediction
    # already is to the target.
    neg_term = alpha * (probs - targets).abs().pow(gamma) * negative
    focal_weight = pos_term + neg_term

    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    loss = bce * focal_weight
    if weight is not None:
        loss = loss * weight
    return loss



class VariFocalLoss(nn.Module):
    """``nn.Module`` wrapper around ``varifocal_loss`` with a configurable reduction.

    Args:
        alpha: Scale of the negative-entry term (must be >= 0). Defaults to 0.75.
        gamma: Exponent of the negative-entry modulating factor. Defaults to 2.0.
        iou_weighted: Whether positive entries are weighted by their target
            score. Defaults to True.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``. Defaults to ``'mean'``.

    Raises:
        AssertionError: If ``reduction`` is not one of the allowed values or
            ``alpha`` is negative.
    """

    def __init__(
        self,
        alpha: float = 0.75,
        gamma: float = 2.0,
        iou_weighted: bool = True,
        reduction: str = 'mean',
    ):
        # VariFocal Implementation: https://github.com/hyz-xmaster/VarifocalNet/blob/master/mmdet/models/losses/varifocal_loss.py
        super().__init__()
        assert reduction in ('mean', 'sum', 'none')
        assert alpha >= 0.0
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.reduction = reduction

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the Varifocal Loss and apply ``self.reduction``.

        Args:
            logits: Raw (pre-sigmoid) predictions.
            labels: IoU-aware targets with the same size as ``logits``.
            weight: Optional element-wise loss weight. Defaults to None.

        Returns:
            torch.Tensor: A scalar for ``'mean'``/``'sum'``, otherwise the
            per-element loss.
        """
        # Hyper-parameters must be passed by keyword: the third positional
        # parameter of varifocal_loss is `weight`, so positional passing would
        # shift alpha into `weight`, gamma into `alpha` and iou_weighted into
        # `gamma`.
        loss = varifocal_loss(
            logits,
            labels,
            weight=weight,
            alpha=self.alpha,
            gamma=self.gamma,
            iou_weighted=self.iou_weighted,
        )

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss


if __name__ == "__main__":
    # Small demo; guarded so importing this module has no side effects.
    torch.manual_seed(0)  # reproducible demo output

    N, C = 5, 4  # Number of samples N and number of classes C
    logits = torch.randn(N, C, requires_grad=True)  # Example logits
    # VFL targets are 0 for negatives and an IoU-like score in (0, 1) for
    # positives. Mixing both exercises the positive AND the negative
    # (alpha * p**gamma) branch of the loss; all-positive labels would not.
    is_positive = (torch.rand(N, C) < 0.3).float()
    labels = torch.rand(N, C) * is_positive

    vari_focal_loss_instance = VariFocalLoss()
    loss_output_corrected = vari_focal_loss_instance(logits, labels)

    print(loss_output_corrected)

# Out[3] (sample output recorded from an earlier run): tensor(0.2350, grad_fn=<MeanBackward0>)
