
360° Rotated Text Region Detection in Practice, Part 3: Algorithm Tuning and Industrial-Grade Acceleration Tips

This article was written on April 6, 2026.

1. Algorithm Evaluation

After receiving a requirement, we usually start with solution selection and review, and in that process I consider algorithm evaluation to be especially important.

A sound algorithm evaluation must be aligned with business metrics; ideally it can adopt the business metrics directly, or report model metrics and business metrics side by side. In addition, the evaluation set must be built to a stricter standard than the training set: its annotations must not only be accurate, but also reflect the true distribution of production data.

Back to our requirement: this is a rotated object detection task, and specifically 360° oriented detection. Rotated detection is typically evaluated with mAP (based on rotated IoU) and IoU-based precision/recall. In 360° oriented detection, however, 0° and 180° are no longer equivalent: the model must distinguish "forward" from "reversed", and IoU alone cannot tell them apart, since a box rotated by 180° can still have a very high rotated IoU. For example, an airplane facing left vs. right, or text upright vs. upside-down: the IoU can be close to 1, yet the prediction is completely wrong for the task. Evaluating this kind of task therefore also requires a constraint on angle. projects/RR360 offers a dedicated solution: it adds an angle gate to the matching, so a TP must satisfy both IoU >= iou_thr and angle_distance < angle_thr.
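The angle-gated TP criterion can be sketched as a small predicate. The threshold values below are illustrative, not the defaults RR360 uses:

```python
def is_true_positive(iou: float, pred_angle_deg: float, gt_angle_deg: float,
                     iou_thr: float = 0.5, angle_thr_deg: float = 30.0,
                     period_deg: float = 360.0) -> bool:
    """A prediction counts as TP only if it passes BOTH the IoU gate
    and the angle gate."""
    diff = abs(pred_angle_deg - gt_angle_deg) % period_deg
    angle_dist = min(diff, period_deg - diff)  # shortest distance on the circle
    return iou >= iou_thr and angle_dist < angle_thr_deg

# A box rotated by 180 deg can still have rotated IoU close to 1,
# but under a 360-deg period it fails the angle gate:
print(is_true_positive(0.95, 0.0, 180.0))  # False: orientation is reversed
print(is_true_positive(0.95, 0.0, 10.0))   # True
```

Note that with `period_deg=180.0` the first case would pass, which is exactly the ambiguity that 360° evaluation is meant to eliminate.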

2. Model Selection

First, note that this task is relatively simple, because it has the following properties:

  1. Each image contains at most one box to detect.
  2. A single box is usually large, covering at least 20% of the image and up to 90%.

This is large-object detection with clear boundaries, so a one-stage method is our first choice and is already sufficient.

An object detection model is generally decomposed into four parts: backbone, neck, head, and loss, plus the data and the sample-assignment strategy.
The mmrotate framework lets us combine these components very conveniently through config files:

  • Backbone mmdet.ResNet/ResNet, mmdet.CSPNeXt, mmcls.ConvNeXt, RTMDet, RotatedRetinaNet

  • Neck mmdet.FPN/FPN, mmdet.CSPNeXtPAFPN

  • Head RotatedRTMDetSepBNHead, RotatedRetinaHead, mmdet.RetinaHead, RotatedFCOSHead, RotatedATSSHead, R3Head, R3RefineHead, S2AHead, S2ARefineHead, CFAHead, RotatedRepPointsHead, OrientedRepPointsHead, SAMRepPointsHead, H2RBoxHead, AngleBranchRetinaHead

  • Loss mmdet.QualityFocalLoss, mmdet.FocalLoss/FocalLoss, mmdet.CrossEntropyLoss, mmdet.IoULoss, RotatedIoULoss, mmdet.L1Loss/L1Loss, mmdet.SmoothL1Loss/SmoothL1Loss, ConvexGIoULoss, BCConvexGIoULoss, GDLoss, GDLoss_v1, KFLoss, SmoothFocalLoss, SpatialBorderLoss, H2RBoxConsistencyLoss

To reuse pretrained weights, and based on our experiments, we adopted the rotated-detection variant of RTMDet and changed the loss to GDLoss with type=GWD, i.e. the GWD loss.

A key consideration is that our deployment hardware is Ascend NPUs, which support few rotated-detection operators, so NPU-accelerated training shows little speedup.
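In mmrotate, such a loss swap is just a config override. A hypothetical sketch follows; the base config path, key names, and loss weight are assumptions, so check them against the rotated RTMDet configs shipped with your mmrotate version:

```python
# Hypothetical config fragment: inherit a rotated RTMDet config and
# override only the regression loss of the bbox head.
_base_ = ['./rotated_rtmdet_l-3x-dota.py']  # base config path is illustrative

model = dict(
    bbox_head=dict(
        loss_bbox=dict(
            type='GDLoss',     # Gaussian-distance loss family
            loss_type='gwd',   # Gaussian Wasserstein distance variant
            loss_weight=5.0)))  # weight value is illustrative
```

The config system merges this dict into the base model definition, so backbone, neck, and head stay untouched while the loss changes.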

3. Bad Case Analysis

After a model is trained, we still need to analyze its bad cases and corner cases. This usually means writing a script to pull out the misdetected samples and checking whether each failure is a model error or an annotation error.

Annotation errors fall into two categories:

  1. The labeling rules themselves could be defined better, which calls for an overhaul of the whole dataset.
  2. Under the existing rules there are genuine labeling mistakes, often caused by annotators interpreting the labeling guide inconsistently.

Fixing the data and redefining the labeling rules is often the most direct and fastest way to improve accuracy.

For model errors, we need to determine whether the model's capability is insufficient or a better labeling rule exists, analyze along both the data and algorithm dimensions, and find ways to improve accuracy.

In a typical workflow, we categorize the bad cases, design a fix for each category, assign priorities (P0, P1, P2), and execute in priority order.

The bad-case mining script used in this project, built on the mmrotate framework, is shown below (for reference only):

import argparse
import csv
import os
import os.path as osp
from dataclasses import dataclass
from typing import Dict, List, Tuple

import cv2
import numpy as np
import torch
from mmcv.ops import box_iou_rotated
from mmdet.apis import inference_detector, init_detector
from mmdet.utils import register_all_modules as register_all_modules_mmdet
from projects.RR360.structures.bbox import RotatedBoxes

import mmrotate.structures
from mmrotate.structures import qbox2rbox, rbox2qbox
from mmrotate.utils import register_all_modules

# TODO : Refactoring with registry build
mmrotate.structures.bbox.RotatedBoxes = RotatedBoxes


@dataclass
class PredInfo:
    rbox: np.ndarray
    score: float
    label: int
    matched: bool
    match_iou: float
    match_angle_deg: float


@dataclass
class ImageStat:
    image_name: str
    image_path: str
    ann_path: str
    tp: int
    fp: int
    fn: int
    precision: float
    recall: float
    f1: float
    metric_value: float
    mean_tp_iou: float
    mean_tp_angle_err_deg: float
    pred_infos: List[PredInfo]
    gt_qboxes: np.ndarray
    gt_labels: np.ndarray


def parse_args():
    parser = argparse.ArgumentParser(
        description='Mine bad cases for rotated text detection model')
    parser.add_argument('config', help='Config file path')
    parser.add_argument('checkpoint', help='Checkpoint file path')
    parser.add_argument(
        '--img-dir',
        default='data/TRR360D/img_test_obbox',
        help='Directory of evaluation images')
    parser.add_argument(
        '--ann-dir',
        default='data/TRR360D/ann_test_obbox',
        help='Directory of DOTA-style annotations (.txt)')
    parser.add_argument(
        '--score-thr',
        type=float,
        default=0.3,
        help='Prediction score threshold')
    parser.add_argument(
        '--iou-thr',
        type=float,
        default=0.5,
        help='IoU threshold for TP matching')
    parser.add_argument(
        '--angle-thr-deg',
        type=float,
        default=-1.0,
        help='Optional angle threshold in degree for TP matching; <0 disables')
    parser.add_argument(
        '--angle-period-deg',
        type=float,
        default=180.0,
        choices=[180.0, 360.0],
        help='Angle periodicity for angle error calculation')
    parser.add_argument(
        '--metric',
        default='f1',
        choices=['f1', 'recall', 'precision'],
        help='Metric for ranking bad cases')
    parser.add_argument(
        '--metric-thr',
        type=float,
        default=0.7,
        help='Select images with metric value < metric-thr')
    parser.add_argument(
        '--max-badcases',
        type=int,
        default=50,
        help='Maximum number of bad cases to export')
    parser.add_argument(
        '--device',
        default='cpu',
        help='Device for model inference, e.g. cpu / cuda:0')
    parser.add_argument(
        '--out-dir',
        default='work_dirs/RR360/badcases_f1',
        help='Output directory')
    return parser.parse_args()


def _collect_images(img_dir: str) -> List[str]:
    if not osp.isdir(img_dir):
        raise FileNotFoundError(f'Image directory not found: {img_dir}')
    exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    paths = []
    for name in sorted(os.listdir(img_dir)):
        path = osp.join(img_dir, name)
        if osp.isfile(path) and osp.splitext(name.lower())[1] in exts:
            paths.append(path)
    if not paths:
        raise RuntimeError(f'No images found in {img_dir}')
    return paths


def _parse_dota_ann(ann_path: str, class_to_id: Dict[str, int]) -> Tuple[np.ndarray, np.ndarray]:
    if not osp.exists(ann_path):
        return np.empty((0, 8), dtype=np.float32), np.empty((0,), dtype=np.int64)

    qboxes = []
    labels = []
    with open(ann_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split()
            if len(parts) < 9:
                continue
            coords = list(map(float, parts[:8]))
            cls_name = parts[8]
            if cls_name not in class_to_id:
                continue
            qboxes.append(coords)
            labels.append(class_to_id[cls_name])

    if not qboxes:
        return np.empty((0, 8), dtype=np.float32), np.empty((0,), dtype=np.int64)
    return np.asarray(qboxes, dtype=np.float32), np.asarray(labels, dtype=np.int64)


def _to_numpy_pred(pred_instances, score_thr: float):
    if hasattr(pred_instances.bboxes, 'tensor'):
        bboxes = pred_instances.bboxes.tensor
    else:
        bboxes = pred_instances.bboxes
    scores = pred_instances.scores
    labels = pred_instances.labels

    keep = scores >= score_thr
    bboxes = bboxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    if bboxes.numel() == 0:
        return (
            np.empty((0, 5), dtype=np.float32),
            np.empty((0,), dtype=np.float32),
            np.empty((0,), dtype=np.int64),
        )

    return (
        bboxes.detach().cpu().numpy().astype(np.float32),
        scores.detach().cpu().numpy().astype(np.float32),
        labels.detach().cpu().numpy().astype(np.int64),
    )


def _long_side_angle_rad(rboxes: np.ndarray) -> np.ndarray:
    angles = rboxes[:, 4].copy()
    w = rboxes[:, 2]
    h = rboxes[:, 3]
    # Canonicalize to long-side orientation.
    angles[w < h] += np.pi / 2.0
    return angles


def _angular_diff_deg(pred_rboxes: np.ndarray, gt_rboxes: np.ndarray,
                      period_deg: float) -> np.ndarray:
    period_rad = np.deg2rad(period_deg)
    pred_angles_rad = _long_side_angle_rad(pred_rboxes)[:, None]
    gt_angles_rad = _long_side_angle_rad(gt_rboxes)[None, :]
    diff = np.abs(pred_angles_rad - gt_angles_rad)
    diff = np.mod(diff, period_rad)
    diff = np.minimum(diff, period_rad - diff)
    return np.rad2deg(diff)


def _match_single_class(pred_rboxes: np.ndarray, pred_scores: np.ndarray,
                        gt_rboxes: np.ndarray, iou_thr: float,
                        angle_thr_deg: float,
                        angle_period_deg: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    matched = np.zeros((pred_rboxes.shape[0],), dtype=np.bool_)
    match_iou = np.zeros((pred_rboxes.shape[0],), dtype=np.float32)
    match_angle_deg = np.zeros((pred_rboxes.shape[0],), dtype=np.float32)
    if pred_rboxes.shape[0] == 0 or gt_rboxes.shape[0] == 0:
        return matched, match_iou, match_angle_deg

    iou_mat = box_iou_rotated(
        torch.from_numpy(pred_rboxes), torch.from_numpy(gt_rboxes)).cpu().numpy()
    ang_mat = _angular_diff_deg(pred_rboxes, gt_rboxes, period_deg=angle_period_deg)
    gt_taken = np.zeros((gt_rboxes.shape[0],), dtype=np.bool_)

    order = np.argsort(-pred_scores)
    for pidx in order:
        best_gidx = -1
        best_iou = 0.0
        best_ang = 0.0
        for gidx in range(gt_rboxes.shape[0]):
            if gt_taken[gidx]:
                continue
            cur_iou = float(iou_mat[pidx, gidx])
            cur_ang = float(ang_mat[pidx, gidx])
            if angle_thr_deg >= 0.0 and cur_ang > angle_thr_deg:
                continue
            if cur_iou > best_iou:
                best_iou = cur_iou
                best_ang = cur_ang
                best_gidx = gidx
        if best_gidx >= 0 and best_iou >= iou_thr:
            matched[pidx] = True
            match_iou[pidx] = best_iou
            match_angle_deg[pidx] = best_ang
            gt_taken[best_gidx] = True
    return matched, match_iou, match_angle_deg


def _compute_stat(img_path: str, ann_dir: str, class_to_id: Dict[str, int], model,
                  score_thr: float, iou_thr: float, angle_thr_deg: float,
                  angle_period_deg: float, metric: str) -> ImageStat:
    img_name = osp.basename(img_path)
    stem = osp.splitext(img_name)[0]
    ann_path = osp.join(ann_dir, f'{stem}.txt')

    gt_qboxes, gt_labels = _parse_dota_ann(ann_path, class_to_id)
    if gt_qboxes.shape[0] > 0:
        gt_rboxes = qbox2rbox(torch.from_numpy(gt_qboxes)).cpu().numpy().astype(np.float32)
    else:
        gt_rboxes = np.empty((0, 5), dtype=np.float32)

    det = inference_detector(model, img_path)
    pred = det.pred_instances
    pred_rboxes, pred_scores, pred_labels = _to_numpy_pred(pred, score_thr=score_thr)

    matched = np.zeros((pred_rboxes.shape[0],), dtype=np.bool_)
    match_iou = np.zeros((pred_rboxes.shape[0],), dtype=np.float32)
    match_angle_deg = np.zeros((pred_rboxes.shape[0],), dtype=np.float32)

    for cls_id in np.unique(np.concatenate([pred_labels, gt_labels], axis=0)
                            if pred_labels.size + gt_labels.size > 0 else np.array([], dtype=np.int64)):
        p_mask = pred_labels == cls_id
        g_mask = gt_labels == cls_id
        cls_pred = pred_rboxes[p_mask]
        cls_score = pred_scores[p_mask]
        cls_gt = gt_rboxes[g_mask]

        cls_match, cls_iou, cls_ang = _match_single_class(
            cls_pred,
            cls_score,
            cls_gt,
            iou_thr=iou_thr,
            angle_thr_deg=angle_thr_deg,
            angle_period_deg=angle_period_deg)
        matched[p_mask] = cls_match
        match_iou[p_mask] = cls_iou
        match_angle_deg[p_mask] = cls_ang

    tp = int(matched.sum())
    fp = int(pred_rboxes.shape[0] - tp)
    fn = int(gt_rboxes.shape[0] - tp)

    precision = tp / (tp + fp) if (tp + fp) > 0 else (1.0 if fn == 0 else 0.0)
    recall = tp / (tp + fn) if (tp + fn) > 0 else 1.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    metric_value = dict(f1=f1, recall=recall, precision=precision)[metric]

    tp_ious = match_iou[matched]
    mean_tp_iou = float(tp_ious.mean()) if tp_ious.size > 0 else 0.0
    tp_ang = match_angle_deg[matched]
    mean_tp_angle_err_deg = float(tp_ang.mean()) if tp_ang.size > 0 else 0.0

    pred_infos = [
        PredInfo(
            rbox=pred_rboxes[i],
            score=float(pred_scores[i]),
            label=int(pred_labels[i]),
            matched=bool(matched[i]),
            match_iou=float(match_iou[i]),
            match_angle_deg=float(match_angle_deg[i])) for i in range(pred_rboxes.shape[0])
    ]

    return ImageStat(
        image_name=img_name,
        image_path=img_path,
        ann_path=ann_path,
        tp=tp,
        fp=fp,
        fn=fn,
        precision=float(precision),
        recall=float(recall),
        f1=float(f1),
        metric_value=float(metric_value),
        mean_tp_iou=mean_tp_iou,
        mean_tp_angle_err_deg=mean_tp_angle_err_deg,
        pred_infos=pred_infos,
        gt_qboxes=gt_qboxes,
        gt_labels=gt_labels)


def _draw_badcase(stat: ImageStat, id_to_class: Dict[int, str], out_path: str):
    img = cv2.imread(stat.image_path)
    if img is None:
        raise RuntimeError(f'Failed to read image: {stat.image_path}')

    # GT: green
    for i in range(stat.gt_qboxes.shape[0]):
        poly = stat.gt_qboxes[i].reshape(4, 2)
        pts = np.round(poly).astype(np.int32)
        cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
        cls_text = id_to_class.get(int(stat.gt_labels[i]), str(int(stat.gt_labels[i])))
        cv2.putText(
            img,
            f'GT:{cls_text}',
            (int(pts[0, 0]), max(10, int(pts[0, 1]) - 6)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.45,
            (0, 255, 0),
            1,
            cv2.LINE_AA)

    # Pred: TP cyan, FP red
    if stat.pred_infos:
        pred_rboxes = np.stack([x.rbox for x in stat.pred_infos], axis=0).astype(np.float32)
        pred_qboxes = rbox2qbox(torch.from_numpy(pred_rboxes)).cpu().numpy().reshape(-1, 4, 2)
        for i, info in enumerate(stat.pred_infos):
            pts = np.round(pred_qboxes[i]).astype(np.int32)
            color = (255, 255, 0) if info.matched else (0, 0, 255)
            cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
            cls_text = id_to_class.get(info.label, str(info.label))
            tag = 'TP' if info.matched else 'FP'
            cv2.putText(
                img,
                f'{tag}:{cls_text} {info.score:.2f} iou:{info.match_iou:.2f} ang:{info.match_angle_deg:.1f}',
                (int(pts[1, 0]), max(10, int(pts[1, 1]) - 6)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.45,
                color,
                1,
                cv2.LINE_AA)

    header = (
        f'TP={stat.tp} FP={stat.fp} FN={stat.fn}  '
        f'P={stat.precision:.3f} R={stat.recall:.3f} F1={stat.f1:.3f} '
        f'mTPIoU={stat.mean_tp_iou:.3f} mAngErr={stat.mean_tp_angle_err_deg:.1f}')
    cv2.rectangle(img, (0, 0), (img.shape[1], 28), (0, 0, 0), -1)
    cv2.putText(img, header, (8, 19), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 1, cv2.LINE_AA)

    os.makedirs(osp.dirname(out_path), exist_ok=True)
    cv2.imwrite(out_path, img)


def main():
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    vis_dir = osp.join(args.out_dir, 'vis')
    os.makedirs(vis_dir, exist_ok=True)

    register_all_modules_mmdet(init_default_scope=False)
    register_all_modules(init_default_scope=True)

    model = init_detector(args.config, args.checkpoint, device=args.device)
    model.eval()

    classes = model.dataset_meta.get('classes', None)
    if classes is None:
        raise RuntimeError('Cannot find class names from model.dataset_meta["classes"]')
    class_to_id = {name: idx for idx, name in enumerate(classes)}
    id_to_class = {idx: name for idx, name in enumerate(classes)}

    image_paths = _collect_images(args.img_dir)
    all_stats: List[ImageStat] = []
    for i, img_path in enumerate(image_paths, 1):
        stat = _compute_stat(
            img_path=img_path,
            ann_dir=args.ann_dir,
            class_to_id=class_to_id,
            model=model,
            score_thr=args.score_thr,
            iou_thr=args.iou_thr,
            angle_thr_deg=args.angle_thr_deg,
            angle_period_deg=args.angle_period_deg,
            metric=args.metric)
        all_stats.append(stat)
        if i % 20 == 0 or i == len(image_paths):
            print(f'Processed {i}/{len(image_paths)}')

    all_stats.sort(key=lambda x: x.metric_value)
    bad = [x for x in all_stats if x.metric_value < args.metric_thr]
    bad = bad[:args.max_badcases]
    if len(bad) == 0:
        bad = all_stats[:min(args.max_badcases, len(all_stats))]

    csv_path = osp.join(args.out_dir, 'all_metrics.csv')
    with open(csv_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            'image_name', 'tp', 'fp', 'fn', 'precision', 'recall', 'f1',
            'mean_tp_iou', 'mean_tp_angle_err_deg', 'metric', 'metric_value',
            'selected_badcase'
        ])
        selected = set(x.image_name for x in bad)
        for s in all_stats:
            writer.writerow([
                s.image_name, s.tp, s.fp, s.fn, f'{s.precision:.6f}',
                f'{s.recall:.6f}', f'{s.f1:.6f}', f'{s.mean_tp_iou:.6f}',
                f'{s.mean_tp_angle_err_deg:.6f}',
                args.metric, f'{s.metric_value:.6f}', int(s.image_name in selected)
            ])

    for rank, stat in enumerate(bad, 1):
        out_name = f'{rank:03d}_{stat.image_name}'
        _draw_badcase(stat, id_to_class=id_to_class, out_path=osp.join(vis_dir, out_name))

    summary_path = osp.join(args.out_dir, 'summary.txt')
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write(f'num_images={len(all_stats)}\n')
        f.write(f'metric={args.metric}\n')
        f.write(f'metric_thr={args.metric_thr}\n')
        f.write(f'score_thr={args.score_thr}\n')
        f.write(f'iou_thr={args.iou_thr}\n')
        f.write(f'angle_thr_deg={args.angle_thr_deg}\n')
        f.write(f'angle_period_deg={args.angle_period_deg}\n')
        f.write(f'num_badcases={len(bad)}\n')
        if bad:
            metric_vals = [x.metric_value for x in bad]
            f.write(f'badcase_metric_min={min(metric_vals):.6f}\n')
            f.write(f'badcase_metric_max={max(metric_vals):.6f}\n')
            f.write(f'badcase_metric_mean={float(np.mean(metric_vals)):.6f}\n')

    print('\n=== Bad Case Analysis Done ===')
    print(f'Output dir: {args.out_dir}')
    print(f'CSV: {csv_path}')
    print(f'Visualization dir: {vis_dir}')
    print(f'Bad cases exported: {len(bad)}')


if __name__ == '__main__':
    main()

4. Inference Optimization

In day-to-day algorithm work, inference optimization is the step most easily overlooked by engineers who focus on training. The optimizations come from two sources:

  1. A deep understanding of the business that lets us simplify the algorithm pipeline.
  2. Model acceleration techniques (running on an inference framework with operator fusion, plus pruning, distillation, quantization, etc.).

Here is an example. In our task, each image contains at most one box to detect. Do we still need the NMS operator, which parallelizes poorly and requires computing rotated IoU?

Clearly not. Instead of NMS, we simply take the max over the scores to get max_index, then return that score together with det_boxes[max_index] to our system.
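The replacement is a few lines of NumPy. A minimal sketch for the single-target case; the score threshold and the box layout `(cx, cy, w, h, angle)` are illustrative:

```python
import numpy as np

def pick_single_box(det_boxes: np.ndarray, scores: np.ndarray,
                    score_thr: float = 0.3):
    """Single-target shortcut: instead of rotated NMS, keep only the
    highest-scoring box, or nothing if it falls below the threshold."""
    if scores.size == 0:
        return None, None
    max_index = int(np.argmax(scores))
    if scores[max_index] < score_thr:
        return None, None
    return det_boxes[max_index], float(scores[max_index])

boxes = np.array([[50, 50, 80, 30, 0.10],   # (cx, cy, w, h, angle_rad)
                  [52, 48, 82, 28, 0.12]], dtype=np.float32)
scores = np.array([0.42, 0.91], dtype=np.float32)
box, score = pick_single_box(boxes, scores)
print(round(score, 2))  # 0.91
```

This is O(n) over the candidates, needs no rotated-IoU kernel at all, and runs fine on hardware where that operator is unsupported.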

Summary

This article walks through an end-to-end engineering playbook for 360° rotated text detection. Starting with algorithm evaluation, it stresses that metrics must align with the business; in rotated detection in particular, IoU alone is not enough and an angle constraint is needed, because 0° and 180° are semantically different. For model selection, given the task characteristics (a single target per image, large targets, clear boundaries), a one-stage approach is preferred: the rotated variant of RTMDet from the mmrotate framework paired with a suitable loss, chosen with the deployment constraints of Ascend hardware in mind.

After training, bad case analysis locates the source of errors, separating annotation mistakes from model weaknesses; fixing the data is often the fastest path to better accuracy. Finally, in the inference stage the point is not to keep stacking model complexity but to simplify based on the business: in a single-target scenario, an argmax over scores can replace NMS outright, cutting compute and improving end-to-end inference efficiency.
