Skip to content

ReID

ReID is the unified appearance-model runtime used by the tracker backends and the public API. It also exposes overrideable public stage hooks (preprocess, process, postprocess) so subclasses can customize individual pipeline stages.

Source code in boxmot/reid/core/reid.py
class ReID:
    """Unified ReID runtime that also exposes overrideable public stage hooks."""

    def __init__(
        self,
        path: str | Path | list[str | Path] | tuple[str | Path, ...] | None = None,
        *,
        weights: str | Path | list[str | Path] | tuple[str | Path, ...] | None = None,
        device: str | torch.device = "cpu",
        half: bool = False,
    ) -> None:
        model_ref = path if path is not None else weights
        if model_ref is None:
            model_ref = WEIGHTS / "osnet_x0_25_msmt17.pt"

        primary_weight = model_ref[0] if isinstance(model_ref, (list, tuple)) else model_ref
        self.path = Path(primary_weight)
        self.weights = model_ref
        self.device = device if isinstance(device, torch.device) else select_device(device)
        self.half = bool(half)
        (
            self.pt,
            self.jit,
            self.onnx,
            self.xml,
            self.engine,
            self.tflite,
        ) = self.model_type(self.path)
        self.backend = self
        self.model = self.get_backend()

    def get_backend(self):
        if hasattr(self, "_backend_model"):
            return self._backend_model

        backend_map = (
            (self.pt, PyTorchBackend),
            (self.jit, TorchscriptBackend),
            (self.onnx, ONNXBackend),
            (self.engine, TensorRTBackend),
            (self.xml, OpenVinoBackend),
            (self.tflite, TFLiteBackend),
        )

        for enabled, backend_class in backend_map:
            if enabled:
                self._backend_model = backend_class(self.weights, self.device, self.half)
                return self._backend_model

        LOGGER.error("This model framework is not supported yet!")
        raise SystemExit(1)

    def check_suffix(
        self,
        file: Path | str = "osnet_x0_25_msmt17.pt",
        suffix: str | Tuple[str, ...] = (".pt",),
        msg: str = "",
    ) -> None:
        suffixes = [suffix] if isinstance(suffix, str) else list(suffix)
        files = [file] if isinstance(file, (str, Path)) else list(file)

        for candidate in files:
            file_suffix = Path(candidate).suffix.lower()
            if file_suffix and file_suffix not in suffixes:
                LOGGER.error(
                    f"File {candidate} does not have an acceptable suffix. Expected: {suffixes}{msg}"
                )

    def model_type(self, path: Path) -> Tuple[bool, ...]:
        suffixes = list(export_formats().Suffix)
        self.check_suffix(path, suffixes)
        types = [suffix in Path(path).name for suffix in suffixes]

        if Path(path).suffix in {".xml", ".bin"}:
            try:
                openvino_index = suffixes.index("_openvino_model")
                types[openvino_index] = True
            except ValueError:
                pass

        return tuple(types)

    @staticmethod
    def _coerce_boxes(boxes: Any) -> np.ndarray:
        arr = np.asarray(boxes, dtype=np.float32)
        if arr.size == 0:
            cols = arr.shape[1] if arr.ndim == 2 else 4
            return np.empty((0, cols), dtype=np.float32)
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        return arr.astype(np.float32, copy=False)

    @staticmethod
    def _coerce_crops(crops: Any) -> list[np.ndarray]:
        if isinstance(crops, (str, Path)):
            return [resolve_image(crops)]

        if isinstance(crops, np.ndarray):
            if crops.ndim == 4:
                return [np.asarray(crop) for crop in crops]
            if crops.ndim == 3:
                return [crops]
            raise ValueError(f"Unsupported crop tensor shape: {crops.shape}")

        if isinstance(crops, (list, tuple)):
            return [
                resolve_image(crop) if isinstance(crop, (str, Path)) else np.asarray(crop)
                for crop in crops
            ]

        raise ValueError(f"Unsupported ReID input type: {type(crops)}")

    def _prepare_crop_batch(self, crops: list[np.ndarray]) -> torch.Tensor:
        if not crops:
            return torch.empty(
                (0, 3, *self.model.input_shape),
                dtype=torch.float32,
                device=self.model.device,
            )

        batch = torch.empty(
            (len(crops), 3, *self.model.input_shape),
            dtype=torch.float16 if self.model.half else torch.float32,
            device=self.model.device,
        )

        for index, crop in enumerate(crops):
            if crop.size == 0:
                crop = np.zeros((*self.model.input_shape, 3), dtype=np.uint8)

            resized = cv2.resize(
                crop,
                (self.model.input_shape[1], self.model.input_shape[0]),
                interpolation=cv2.INTER_LINEAR,
            )
            resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
            tensor = torch.from_numpy(resized).to(batch.device, dtype=batch.dtype)
            batch[index] = tensor.permute(2, 0, 1)

        batch = batch / 255.0
        batch = (batch - self.model.mean_array) / self.model.std_array
        return batch

    def preprocess(self, inputs, boxes=None, **kwargs):
        if boxes is not None:
            return {
                "mode": "image_boxes",
                "image": resolve_image(inputs),
                "boxes": self._coerce_boxes(boxes),
            }

        return {
            "mode": "crops",
            "crops": self._coerce_crops(inputs),
        }

    def process(self, payload, **kwargs) -> np.ndarray:
        if payload["mode"] == "image_boxes":
            boxes = payload["boxes"]
            if boxes.size == 0:
                return np.empty((0, 0), dtype=np.float32)
            return self.model.get_features(boxes, payload["image"])

        crops = payload["crops"]
        if not crops:
            return np.empty((0, 0), dtype=np.float32)

        batch = self._prepare_crop_batch(crops)
        batch = self.model.inference_preprocess(batch)
        features = self.model.forward(batch)
        features = np.asarray(self.model.inference_postprocess(features), dtype=np.float32)
        if features.size == 0:
            return np.empty((0, 0), dtype=np.float32)

        norms = np.linalg.norm(features, axis=-1, keepdims=True)
        norms[norms == 0] = 1.0
        return features / norms

    def postprocess(self, features: np.ndarray, **kwargs) -> np.ndarray:
        return features

    def __call__(self, inputs, boxes=None, **kwargs) -> np.ndarray:
        payload = self.preprocess(inputs, boxes=boxes, **kwargs)
        features = self.process(payload, boxes=boxes, **kwargs)
        return self.postprocess(features, boxes=boxes, **kwargs)