class ReID:
"""Unified ReID runtime that also exposes overrideable public stage hooks."""
def __init__(
self,
path: str | Path | list[str | Path] | tuple[str | Path, ...] | None = None,
*,
weights: str | Path | list[str | Path] | tuple[str | Path, ...] | None = None,
device: str | torch.device = "cpu",
half: bool = False,
) -> None:
model_ref = path if path is not None else weights
if model_ref is None:
model_ref = WEIGHTS / "osnet_x0_25_msmt17.pt"
primary_weight = model_ref[0] if isinstance(model_ref, (list, tuple)) else model_ref
self.path = Path(primary_weight)
self.weights = model_ref
self.device = device if isinstance(device, torch.device) else select_device(device)
self.half = bool(half)
(
self.pt,
self.jit,
self.onnx,
self.xml,
self.engine,
self.tflite,
) = self.model_type(self.path)
self.backend = self
self.model = self.get_backend()
def get_backend(self):
if hasattr(self, "_backend_model"):
return self._backend_model
backend_map = (
(self.pt, PyTorchBackend),
(self.jit, TorchscriptBackend),
(self.onnx, ONNXBackend),
(self.engine, TensorRTBackend),
(self.xml, OpenVinoBackend),
(self.tflite, TFLiteBackend),
)
for enabled, backend_class in backend_map:
if enabled:
self._backend_model = backend_class(self.weights, self.device, self.half)
return self._backend_model
LOGGER.error("This model framework is not supported yet!")
raise SystemExit(1)
def check_suffix(
self,
file: Path | str = "osnet_x0_25_msmt17.pt",
suffix: str | Tuple[str, ...] = (".pt",),
msg: str = "",
) -> None:
suffixes = [suffix] if isinstance(suffix, str) else list(suffix)
files = [file] if isinstance(file, (str, Path)) else list(file)
for candidate in files:
file_suffix = Path(candidate).suffix.lower()
if file_suffix and file_suffix not in suffixes:
LOGGER.error(
f"File {candidate} does not have an acceptable suffix. Expected: {suffixes}{msg}"
)
def model_type(self, path: Path) -> Tuple[bool, ...]:
suffixes = list(export_formats().Suffix)
self.check_suffix(path, suffixes)
types = [suffix in Path(path).name for suffix in suffixes]
if Path(path).suffix in {".xml", ".bin"}:
try:
openvino_index = suffixes.index("_openvino_model")
types[openvino_index] = True
except ValueError:
pass
return tuple(types)
@staticmethod
def _coerce_boxes(boxes: Any) -> np.ndarray:
arr = np.asarray(boxes, dtype=np.float32)
if arr.size == 0:
cols = arr.shape[1] if arr.ndim == 2 else 4
return np.empty((0, cols), dtype=np.float32)
if arr.ndim == 1:
arr = arr.reshape(1, -1)
return arr.astype(np.float32, copy=False)
@staticmethod
def _coerce_crops(crops: Any) -> list[np.ndarray]:
if isinstance(crops, (str, Path)):
return [resolve_image(crops)]
if isinstance(crops, np.ndarray):
if crops.ndim == 4:
return [np.asarray(crop) for crop in crops]
if crops.ndim == 3:
return [crops]
raise ValueError(f"Unsupported crop tensor shape: {crops.shape}")
if isinstance(crops, (list, tuple)):
return [
resolve_image(crop) if isinstance(crop, (str, Path)) else np.asarray(crop)
for crop in crops
]
raise ValueError(f"Unsupported ReID input type: {type(crops)}")
def _prepare_crop_batch(self, crops: list[np.ndarray]) -> torch.Tensor:
if not crops:
return torch.empty(
(0, 3, *self.model.input_shape),
dtype=torch.float32,
device=self.model.device,
)
batch = torch.empty(
(len(crops), 3, *self.model.input_shape),
dtype=torch.float16 if self.model.half else torch.float32,
device=self.model.device,
)
for index, crop in enumerate(crops):
if crop.size == 0:
crop = np.zeros((*self.model.input_shape, 3), dtype=np.uint8)
resized = cv2.resize(
crop,
(self.model.input_shape[1], self.model.input_shape[0]),
interpolation=cv2.INTER_LINEAR,
)
resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
tensor = torch.from_numpy(resized).to(batch.device, dtype=batch.dtype)
batch[index] = tensor.permute(2, 0, 1)
batch = batch / 255.0
batch = (batch - self.model.mean_array) / self.model.std_array
return batch
def preprocess(self, inputs, boxes=None, **kwargs):
if boxes is not None:
return {
"mode": "image_boxes",
"image": resolve_image(inputs),
"boxes": self._coerce_boxes(boxes),
}
return {
"mode": "crops",
"crops": self._coerce_crops(inputs),
}
def process(self, payload, **kwargs) -> np.ndarray:
if payload["mode"] == "image_boxes":
boxes = payload["boxes"]
if boxes.size == 0:
return np.empty((0, 0), dtype=np.float32)
return self.model.get_features(boxes, payload["image"])
crops = payload["crops"]
if not crops:
return np.empty((0, 0), dtype=np.float32)
batch = self._prepare_crop_batch(crops)
batch = self.model.inference_preprocess(batch)
features = self.model.forward(batch)
features = np.asarray(self.model.inference_postprocess(features), dtype=np.float32)
if features.size == 0:
return np.empty((0, 0), dtype=np.float32)
norms = np.linalg.norm(features, axis=-1, keepdims=True)
norms[norms == 0] = 1.0
return features / norms
def postprocess(self, features: np.ndarray, **kwargs) -> np.ndarray:
return features
def __call__(self, inputs, boxes=None, **kwargs) -> np.ndarray:
payload = self.preprocess(inputs, boxes=boxes, **kwargs)
features = self.process(payload, boxes=boxes, **kwargs)
return self.postprocess(features, boxes=boxes, **kwargs)