
Practical 360° Rotated Text Region Detection, Part 4: Loss Functions and Rotated-Box Code Explained

This article was written at 11 p.m. on April 21, 2026.

1. Rotated-Box NMS and IoU Computation

Rotated-box NMS essentially comes down to computing the IoU of two rotated rectangles. A rotated rectangle is a convex quadrilateral, so the problem reduces to computing the IoU of arbitrary convex quadrilaterals. I wrote a dedicated article on this earlier; see "Computing the IoU of Arbitrary Convex Quadrilaterals".

2. Rotated-Box Regression Losses

For the rotated-box regression loss we can choose between IoU Loss and GWD Loss; below we walk through the code for each.

2.1 IoU Loss

If you read "Computing the IoU of Arbitrary Convex Quadrilaterals" carefully, you will have noticed that the IoU of arbitrary convex quadrilaterals does not appear to be automatically differentiable, because three operations are involved:

  1. Checking whether the intersection point of two segments lies inside both rectangles: the comparison operation cannot backpropagate gradients.
  2. Computing the intersection area involves a convex hull, which must decide which points are hull vertices: that decision cannot backpropagate gradients.
  3. Computing the intersection area requires sorting the vertices: sorting cannot backpropagate gradients.

Let's answer these one by one:

  1. The check of whether a segment intersection lies inside both rectangles produces a mask per candidate point, and that mask, built from comparison ops, carries no gradient. In other words: the comparison itself does not backpropagate, but the mask is multiplied with the intersection coordinates, so the points where the mask is 1 still receive gradients.
  2. Strictly speaking, an arbitrary point set would need a convex hull, but here we only intersect two rotated rectangles, and that intersection is always a convex polygon (with at most 8 vertices). No convex-hull computation is needed, so that non-differentiable decision step disappears entirely.
  3. The vertex sorting needed to compute the intersection area is handled like the comparison in point 1: it propagates no gradients and is only responsible for putting the intersection points into the correct order.
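To make the mask trick in point 1 concrete, here is a minimal pure-Python sketch (not the mmcv code itself): the comparison that builds the mask is non-differentiable, but multiplying by the mask leaves a well-defined gradient wherever the mask is 1, which we can confirm with finite differences.

```python
def masked_value(x, threshold=0.0):
    """The pattern used by the differentiable IoU: a comparison builds a
    gradient-free mask, then the mask multiplies a continuous value."""
    mask = 1.0 if x > threshold else 0.0  # comparison: no gradient flows here
    return x * mask                       # kept values still pass gradients

def numerical_grad(f, x, eps=1e-6):
    """Central finite difference, standing in for autograd."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

g_kept = numerical_grad(masked_value, 2.0)     # mask == 1, gradient is ~1
g_dropped = numerical_grad(masked_value, -2.0)  # mask == 0, gradient is ~0
```

Away from the decision boundary the function is smooth, which is exactly why the masked-multiply formulation trains fine in practice.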

The key code is in mmcv/ops/diff_iou_rotated.py; the implementation path is as follows:

mmrotate's loss does not use the ordinary box_iou_rotated but the differentiable version in rotated_iou_loss.py; the core of the differentiable IoU is mmcv.ops.diff_iou_rotated_2d.

  1. Map (x, y, w, h, θ) continuously to the 4 corner points (via sin/cos, which is differentiable):

    def diff_iou_rotated_2d(box1: Tensor, box2: Tensor) -> Tensor:
        """Calculate differentiable iou of rotated 2d boxes.
    
        Args:
            box1 (Tensor): (B, N, 5) First box.
            box2 (Tensor): (B, N, 5) Second box.
    
        Returns:
            Tensor: (B, N) IoU.
        """
        corners1 = box2corners(box1)
        corners2 = box2corners(box2)
        intersection, _ = oriented_box_intersection_2d(corners1,
                                                    corners2)  # (B, N)
        area1 = box1[:, :, 2] * box1[:, :, 3]
        area2 = box2[:, :, 2] * box2[:, :, 3]
        union = area1 + area2 - intersection
        iou = intersection / union
        return iou
    

    def box2corners(box: Tensor) -> Tensor:
        """Convert rotated 2d box coordinate to corners.

        Args:
            box (Tensor): (B, N, 5) with x, y, w, h, alpha.

        Returns:
            Tensor: (B, N, 4, 2) Corners.
        """
        B = box.size()[0]
        x, y, w, h, alpha = box.split([1, 1, 1, 1, 1], dim=-1)
        x4 = box.new_tensor([0.5, -0.5, -0.5, 0.5]).to(box.device)
        x4 = x4 * w  # (B, N, 4)
        y4 = box.new_tensor([0.5, 0.5, -0.5, -0.5]).to(box.device)
        y4 = y4 * h  # (B, N, 4)
        corners = torch.stack([x4, y4], dim=-1)  # (B, N, 4, 2)
        sin = torch.sin(alpha)
        cos = torch.cos(alpha)
        row1 = torch.cat([cos, sin], dim=-1)
        row2 = torch.cat([-sin, cos], dim=-1)  # (B, N, 2)
        rot_T = torch.stack([row1, row2], dim=-2)  # (B, N, 2, 2)
        rotated = torch.bmm(corners.view([-1, 4, 2]), rot_T.view([-1, 2, 2]))
        rotated = rotated.view([B, -1, 4, 2])  # (B * N, 4, 2) -> (B, N, 4, 2)
        rotated[..., 0] += x
        rotated[..., 1] += y
        return rotated
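As a sanity check on the corner mapping above, here is a hypothetical single-box, pure-Python version of the same arithmetic (local corners (±w/2, ±h/2), rotated by α, then shifted to the center); `box2corners_single` is my own name, not an mmcv function.

```python
import math

def box2corners_single(x, y, w, h, alpha):
    """Single-box sketch of the corner mapping in box2corners above."""
    # Local corner offsets, same ordering as the tensors x4/y4 above.
    xs = [0.5 * w, -0.5 * w, -0.5 * w, 0.5 * w]
    ys = [0.5 * h, 0.5 * h, -0.5 * h, -0.5 * h]
    s, c = math.sin(alpha), math.cos(alpha)
    # Matches rotated = corners @ [[cos, sin], [-sin, cos]] in the code above.
    return [(x + cx * c - cy * s, y + cx * s + cy * c)
            for cx, cy in zip(xs, ys)]
```

For α = 0 the corners are simply the axis-aligned box around (x, y); for α = π/2 the width and height trade places, which is a quick way to convince yourself the rotation convention is right.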
    
  2. Compute intersection points with the segment-intersection formula (plain algebra, differentiable), and keep only valid intersections via a mask:

    def box_intersection(corners1: Tensor,
                         corners2: Tensor) -> Tuple[Tensor, Tensor]:
        """Find intersection points of rectangles.
        Convention: if two edges are collinear, there is no intersection point.

        Args:
            corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
            corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.

        Returns:
            Tuple:
             - Tensor: (B, N, 4, 4, 2) Intersections.
             - Tensor: (B, N, 4, 4) Valid intersections mask.
        """
        # build edges from corners
        # B, N, 4, 4: Batch, Box, edge, point
        line1 = torch.cat([corners1, corners1[:, :, [1, 2, 3, 0], :]], dim=3)
        line2 = torch.cat([corners2, corners2[:, :, [1, 2, 3, 0], :]], dim=3)
        # duplicate data to pair each edges from the boxes
        # (B, N, 4, 4) -> (B, N, 4, 4, 4) : Batch, Box, edge1, edge2, point
        line1_ext = line1.unsqueeze(3)
        line2_ext = line2.unsqueeze(2)
        x1, y1, x2, y2 = line1_ext.split([1, 1, 1, 1], dim=-1)
        x3, y3, x4, y4 = line2_ext.split([1, 1, 1, 1], dim=-1)
        # math: https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection
        numerator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
        denumerator_t = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)
        t = denumerator_t / numerator
        t[numerator == .0] = -1.
        mask_t = (t > 0) & (t < 1)  # intersection on line segment 1
        denumerator_u = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)
        u = -denumerator_u / numerator
        u[numerator == .0] = -1.
        mask_u = (u > 0) & (u < 1)  # intersection on line segment 2
        mask = mask_t * mask_u
        # overwrite with EPSILON. otherwise numerically unstable
        t = denumerator_t / (numerator + EPSILON)
        intersections = torch.stack([x1 + t * (x2 - x1), y1 + t * (y2 - y1)],
                                    dim=-1)
        intersections = intersections * mask.float().unsqueeze(-1)
        return intersections, mask
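The same line–line intersection formula can be sketched for a single pair of segments in plain Python (the function name `seg_intersection` is mine, not mmcv's); the `valid` flag plays the role of `mask_t & mask_u` above.

```python
def seg_intersection(p1, p2, p3, p4):
    """Intersection of segment p1-p2 with segment p3-p4, using the same
    parametric formula as box_intersection above. Returns ((ix, iy), valid);
    valid is True only if the crossing point lies strictly inside both
    segments."""
    (x1, y1), (x2, y2), (x3, y3), (x4, y4) = p1, p2, p3, p4
    den = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    if den == 0.0:  # parallel or collinear edges: no intersection by convention
        return (0.0, 0.0), False
    t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / den
    u = -((x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)) / den
    valid = 0.0 < t < 1.0 and 0.0 < u < 1.0
    # The point itself is plain algebra in t, hence differentiable.
    return (x1 + t * (x2 - x1), y1 + t * (y2 - y1)), valid
```

Note that `t` and the resulting point are pure arithmetic in the endpoint coordinates, which is why gradients flow through the kept intersections.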
    
  3. Combine the intersection points with the corners of each box that fall inside the other box to obtain the vertex set of the intersection polygon:

    def box_in_box(corners1: Tensor, corners2: Tensor) -> Tuple[Tensor, Tensor]:
        """Check if corners of two boxes lie in each other.

        Args:
            corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
            corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.

        Returns:
            Tuple:
             - Tensor: (B, N, 4) True if i-th corner of box1 is in box2.
             - Tensor: (B, N, 4) True if i-th corner of box2 is in box1.
        """
        c1_in_2 = box1_in_box2(corners1, corners2)
        c2_in_1 = box1_in_box2(corners2, corners1)
        return c1_in_2, c2_in_1
    

    def build_vertices(corners1: Tensor, corners2: Tensor, c1_in_2: Tensor,
                       c2_in_1: Tensor, intersections: Tensor,
                       valid_mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Find vertices of intersection area.

        Args:
            corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
            corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.
            c1_in_2 (Tensor): (B, N, 4) True if i-th corner of box1 is in box2.
            c2_in_1 (Tensor): (B, N, 4) True if i-th corner of box2 is in box1.
            intersections (Tensor): (B, N, 4, 4, 2) Intersections.
            valid_mask (Tensor): (B, N, 4, 4) Valid intersections mask.

        Returns:
            Tuple:
             - Tensor: (B, N, 24, 2) Vertices of intersection area;
                   only some elements are valid.
             - Tensor: (B, N, 24) Mask of valid elements in vertices.
        """
        # NOTE: inter has elements equals zero and has zeros gradient
        # (masked by multiplying with 0); can be used as trick
        B = corners1.size()[0]
        N = corners1.size()[1]
        # (B, N, 4 + 4 + 16, 2)
        vertices = torch.cat(
            [corners1, corners2,
             intersections.view([B, N, -1, 2])], dim=2)
        # Bool (B, N, 4 + 4 + 16)
        mask = torch.cat([c1_in_2, c2_in_1, valid_mask.view([B, N, -1])], dim=2)
        return vertices, mask
    
  4. The vertex-sorting step is a discrete operation; its index output is explicitly marked as non-differentiable:

    class SortVertices(Function):

        @staticmethod
        def forward(ctx, vertices, mask, num_valid):
            idx = ext_module.diff_iou_rotated_sort_vertices_forward(
                vertices, mask, num_valid)
            if torch.__version__ != 'parrots':
                ctx.mark_non_differentiable(idx)
            return idx

        @staticmethod
        def backward(ctx, gradout):
            return ()
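The actual sorting here is done by a CUDA kernel. As a simplified, hypothetical stand-in, the same effect for a convex intersection polygon can be sketched by sorting vertices counter-clockwise by their angle around the centroid; the point is that sorting produces an ordering, not continuous values, so nothing differentiable comes out of this step.

```python
import math

def sort_vertices_ccw(points):
    """Simplified stand-in for the CUDA sorting kernel above: order the
    vertices of a convex polygon counter-clockwise by their angle around
    the centroid. The ordering itself is discrete, so no gradient can flow
    through it; only the coordinates it rearranges stay differentiable."""
    cx = sum(p[0] for p in points) / len(points)
    cy = sum(p[1] for p in points) / len(points)
    return sorted(points, key=lambda p: math.atan2(p[1] - cy, p[0] - cx))
```

For a convex polygon this angular sort recovers a consistent traversal order, which is all the shoelace formula in the next step needs.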
    

  5. Within the region where the sorted order stays fixed, compute the area with the shoelace formula, which is differentiable:

    def calculate_area(idx_sorted: Tensor,
                       vertices: Tensor) -> Tuple[Tensor, Tensor]:
        """Calculate area of intersection.

        Args:
            idx_sorted (Tensor): (B, N, 9) Sorted vertex ids.
            vertices (Tensor): (B, N, 24, 2) Vertices.

        Returns:
            Tuple:
             - Tensor (B, N): Area of intersection.
             - Tensor: (B, N, 9, 2) Vertices of polygon with zero padding.
        """
        idx_ext = idx_sorted.unsqueeze(-1).repeat([1, 1, 1, 2])
        selected = torch.gather(vertices, 2, idx_ext)
        total = selected[:, :, 0:-1, 0] * selected[:, :, 1:, 1] \
            - selected[:, :, 0:-1, 1] * selected[:, :, 1:, 0]
        total = torch.sum(total, dim=2)
        area = torch.abs(total) / 2
        return area, selected
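The computation above walks consecutive vertex pairs of the sorted, padded vertex list; a minimal standalone sketch of the same shoelace formula, closing the polygon with an explicit index wrap instead of padding, looks like this:

```python
def shoelace_area(points):
    """Shoelace formula for a polygon whose vertices are given in order.
    Each term is a product/difference of coordinates, so once the vertex
    order is fixed the area is differentiable in the coordinates."""
    n = len(points)
    total = 0.0
    for i in range(n):
        x0, y0 = points[i]
        x1, y1 = points[(i + 1) % n]  # wrap around to close the polygon
        total += x0 * y1 - y0 * x1
    return abs(total) / 2.0
```

The unit square gives area 1 and the right triangle with legs of length 2 gives area 2, which is an easy way to check the sign convention.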
    

  6. Finally, IoU = inter / union, which is turned into a loss (1 − IoU, 1 − IoU², or −log(IoU)).

So in essence, a trainable IoU is assembled from continuous geometric formulas, mask-based piecewise logic, and non-differentiable sorting indices.
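A minimal sketch of the three IoU-to-loss conversions mentioned in step 6 (the mode names follow common usage, e.g. in mmrotate's RotatedIoULoss, but check the actual config keys before relying on them):

```python
import math

def iou_to_loss(iou, mode='log'):
    """Turn an IoU value in (0, 1] into a loss to be minimized."""
    if mode == 'linear':
        return 1.0 - iou
    if mode == 'square':
        return 1.0 - iou ** 2
    if mode == 'log':
        return -math.log(max(iou, 1e-7))  # clamp to avoid log(0)
    raise ValueError(f'unknown mode: {mode}')
```

All three are monotonically decreasing in IoU and reach 0 at perfect overlap; they differ in how strongly they penalize low-overlap predictions.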

2.2 GWD Loss

GWD Loss can be understood as follows: treat the rotated box as a 2-D Gaussian, then regress with the Wasserstein-2 distance between the Gaussians.

The concrete implementation is in mmrotate/models/losses/gaussian_dist_loss.py.

  1. First convert the rotated box to a Gaussian, using the mapping:

    \[ \mu=(x,y) \]
    \[ \Sigma=R\cdot\mathrm{diag}((w/2)^2,(h/2)^2)\cdot R^\top \]
    def xy_wh_r_2_xy_sigma(xywhr):
        """Convert oriented bounding box to 2-D Gaussian distribution.

        Args:
            xywhr (torch.Tensor): rbboxes with shape (N, 5).

        Returns:
            xy (torch.Tensor): center point of 2-D Gaussian distribution
                with shape (N, 2).
            sigma (torch.Tensor): covariance matrix of 2-D Gaussian distribution
                with shape (N, 2, 2).
        """
        _shape = xywhr.shape
        assert _shape[-1] == 5
        xy = xywhr[..., :2]
        wh = xywhr[..., 2:4].clamp(min=1e-7, max=1e7).reshape(-1, 2)
        r = xywhr[..., 4]
        cos_r = torch.cos(r)
        sin_r = torch.sin(r)
        R = torch.stack((cos_r, -sin_r, sin_r, cos_r), dim=-1).reshape(-1, 2, 2)
        S = 0.5 * torch.diag_embed(wh)

        sigma = R.bmm(S.square()).bmm(R.permute(0, 2,
                                                1)).reshape(_shape[:-1] + (2, 2))

        return xy, sigma
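For a single box, the mapping above collapses to closed-form 2×2 arithmetic. Here is a hypothetical pure-Python version (`rbox_to_gaussian` is my name), obtained by expanding Σ = R·diag((w/2)², (h/2)²)·Rᵀ entry by entry:

```python
import math

def rbox_to_gaussian(x, y, w, h, alpha):
    """Single-box sketch of xy_wh_r_2_xy_sigma above:
    mu = (x, y), Sigma = R @ diag((w/2)^2, (h/2)^2) @ R^T."""
    c, s = math.cos(alpha), math.sin(alpha)
    a, b = (0.5 * w) ** 2, (0.5 * h) ** 2
    # Expanding R @ diag(a, b) @ R^T for R = [[c, -s], [s, c]]:
    sigma = [[c * c * a + s * s * b, c * s * (a - b)],
             [c * s * (a - b), s * s * a + c * c * b]]
    return (x, y), sigma
```

At α = 0 the covariance is simply diag((w/2)², (h/2)²); rotating by π/2 swaps the two variances, matching the intuition that the Gaussian's axes follow the box.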
    
  2. The GWD body implements the formula below.

    The distance is essentially:

    \[ d=\sqrt{\|\mu_p-\mu_t\|^2+\alpha^2\cdot d_{\mathrm{Bures}}(\Sigma_p,\Sigma_t)} \]

    where the covariance term \(d_\mathrm{Bures}\) requires computing \(\operatorname{Tr}((\Sigma_p^{1/2}\Sigma_t\Sigma_p^{1/2})^{1/2})\):

    @weighted_loss
    def gwd_loss(pred, target, fun='log1p', tau=1.0, alpha=1.0, normalize=True):
        """Gaussian Wasserstein distance loss.
        Derivation and simplification:
            Given any positive-definite symmetrical 2*2 matrix Z:
                :math:`Tr(Z^{1/2}) = λ_1^{1/2} + λ_2^{1/2}`
            where :math:`λ_1` and :math:`λ_2` are the eigen values of Z
            Meanwhile we have:
                :math:`Tr(Z) = λ_1 + λ_2`
    
                :math:`det(Z) = λ_1 * λ_2`
            Combination with following formula:
                :math:`(λ_1^{1/2}+λ_2^{1/2})^2 = λ_1+λ_2+2 *(λ_1 * λ_2)^{1/2}`
            Yield:
                :math:`Tr(Z^{1/2}) = (Tr(Z) + 2 * (det(Z))^{1/2})^{1/2}`
            For gwd loss the frustrating coupling part is:
                :math:`Tr((Σ_p^{1/2} * Σ_t * Σp^{1/2})^{1/2})`
            Assuming :math:`Z = Σ_p^{1/2} * Σ_t * Σ_p^{1/2}` then:
                :math:`Tr(Z) = Tr(Σ_p^{1/2} * Σ_t * Σ_p^{1/2})
                = Tr(Σ_p^{1/2} * Σ_p^{1/2} * Σ_t)
                = Tr(Σ_p * Σ_t)`
                :math:`det(Z) = det(Σ_p^{1/2} * Σ_t * Σ_p^{1/2})
                = det(Σ_p^{1/2}) * det(Σ_t) * det(Σ_p^{1/2})
                = det(Σ_p * Σ_t)`
            and thus we can rewrite the coupling part as:
                :math:`Tr(Z^{1/2}) = (Tr(Z) + 2 * (det(Z))^{1/2})^{1/2}`
                :math:`Tr((Σ_p^{1/2} * Σ_t * Σ_p^{1/2})^{1/2})
                = (Tr(Σ_p * Σ_t) + 2 * (det(Σ_p * Σ_t))^{1/2})^{1/2}`
    
        Args:
            pred (torch.Tensor): Predicted bboxes.
            target (torch.Tensor): Corresponding gt bboxes.
            fun (str): The function applied to distance. Defaults to 'log1p'.
            tau (float): Defaults to 1.0.
            alpha (float): Defaults to 1.0.
            normalize (bool): Whether to normalize the distance. Defaults to True.
    
        Returns:
            loss (torch.Tensor)
    
        """
        xy_p, Sigma_p = pred
        xy_t, Sigma_t = target
    
        xy_distance = (xy_p - xy_t).square().sum(dim=-1)
    
        whr_distance = Sigma_p.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
        whr_distance = whr_distance + Sigma_t.diagonal(
            dim1=-2, dim2=-1).sum(dim=-1)
    
        _t_tr = (Sigma_p.bmm(Sigma_t)).diagonal(dim1=-2, dim2=-1).sum(dim=-1)
        _t_det_sqrt = (Sigma_p.det() * Sigma_t.det()).clamp(1e-7).sqrt()
        whr_distance = whr_distance + (-2) * (
            (_t_tr + 2 * _t_det_sqrt).clamp(1e-7).sqrt())
    
        distance = (xy_distance + alpha * alpha * whr_distance).clamp(1e-7).sqrt()
    
        if normalize:
            scale = 2 * (
                _t_det_sqrt.clamp(1e-7).sqrt().clamp(1e-7).sqrt()).clamp(1e-7)
            distance = distance / scale
    
        return postprocess(distance, fun=fun, tau=tau)
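Using the trace/determinant simplification from the docstring, the distance for a single pair of Gaussians can be sketched in closed form with no eigendecomposition. This is a hypothetical pure-Python version (helper names are mine) without the clamps and normalization of the real implementation:

```python
import math

def tr2(m):
    """Trace of a 2x2 matrix given as nested lists."""
    return m[0][0] + m[1][1]

def det2(m):
    """Determinant of a 2x2 matrix."""
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]

def matmul2(a, b):
    """Product of two 2x2 matrices."""
    return [[a[0][0] * b[0][0] + a[0][1] * b[1][0],
             a[0][0] * b[0][1] + a[0][1] * b[1][1]],
            [a[1][0] * b[0][0] + a[1][1] * b[1][0],
             a[1][0] * b[0][1] + a[1][1] * b[1][1]]]

def gwd_distance(mu_p, sigma_p, mu_t, sigma_t, alpha=1.0):
    """Closed-form Wasserstein-2 distance between two 2-D Gaussians,
    using Tr(Z^{1/2}) = sqrt(Tr(Sigma_p @ Sigma_t)
                             + 2 * sqrt(det(Sigma_p) * det(Sigma_t)))
    as derived in the docstring above."""
    xy_dist = (mu_p[0] - mu_t[0]) ** 2 + (mu_p[1] - mu_t[1]) ** 2
    coupling = math.sqrt(max(
        tr2(matmul2(sigma_p, sigma_t))
        + 2.0 * math.sqrt(max(det2(sigma_p) * det2(sigma_t), 0.0)), 0.0))
    whr_dist = tr2(sigma_p) + tr2(sigma_t) - 2.0 * coupling
    return math.sqrt(max(xy_dist + alpha * alpha * whr_dist, 0.0))
```

Identical Gaussians give distance 0, and for equal covariances the distance reduces to the plain Euclidean distance between centers, which makes the formula easy to sanity-check.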
    

GWD also yields smooth gradients for samples with no overlap, making it more stable than IoU-style losses: when IoU is 0, GWD penalizes two boxes that are far apart more heavily than two that are close together.

3. New Problems Introduced by 360° Rotation

Although the IoU Loss above handles rotated-box regression over 0–180°, it does not resolve the angle ambiguity between 0° and 180°: a box rotated by 0° and the same box rotated by 180° produce exactly the same IoU against any other box. This is the angle-periodicity problem mentioned repeatedly before, and it matters because we are doing 360° rotated object detection!
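The ambiguity is easy to verify numerically: rotating a box by θ and by θ + π yields exactly the same set of corner points, so any purely geometric loss sees the two as identical. A small hypothetical check (`rbox_corners` is my helper, using the same corner mapping as the IoU code):

```python
import math

def rbox_corners(x, y, w, h, alpha):
    """Corner set of a rotated box, as in the differentiable-IoU code.
    Returned as a set of rounded tuples so two boxes covering the same
    region compare equal despite floating-point noise."""
    xs = [0.5 * w, -0.5 * w, -0.5 * w, 0.5 * w]
    ys = [0.5 * h, 0.5 * h, -0.5 * h, -0.5 * h]
    s, c = math.sin(alpha), math.cos(alpha)
    return {(round(x + cx * c - cy * s, 6), round(y + cx * s + cy * c, 6))
            for cx, cy in zip(xs, ys)}

# A box rotated by theta and by theta + pi covers exactly the same region,
# so IoU (and GWD) cannot tell the two orientations apart.
same = rbox_corners(0, 0, 4, 2, 0.3) == rbox_corners(0, 0, 4, 2, 0.3 + math.pi)
```

Rotating by θ + π maps each local corner (cx, cy) to (−cx, −cy), and since the corner list is symmetric about the center, the point set is unchanged.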

So how do we solve this?
Our baseline takes the most direct approach: add an extra L1 loss on the angle.

Is there a more elegant loss that unifies angle regression and box regression?
I believe there is; it remains to be studied.

4. Summary

This brings the practical 360° rotated text region detection series to a close. I had planned to end with some reflections, but on reaching this point I no longer feel like it, so let this article stand as a purely technical share!

Good morning/afternoon/evening, thank you for reading this far, my dear readers.
