metrax.IoU
- class metrax.IoU(total: Array, count: Array)
Bases: Average
Measures Intersection over Union (IoU) for semantic segmentation.
The general formula for IoU for a single class is: $IoU_{class} = \frac{TP}{TP + FP + FN}$ where TP, FP, and FN are True Positives, False Positives, and False Negatives.
Per-Batch Processing: For each input batch, a mean IoU is calculated. This involves:
1. Aggregating TP, FP, and FN pixel counts for each specified target class (from the required target_class_ids list) across all samples within the batch.
2. Computing IoU for each of these classes using the batch-aggregated counts: $IoU_{class} = \frac{TP}{TP + FP + FN + \epsilon}$.
3. Averaging these per-class IoU scores to get a single value for the batch.
   - If target_class_ids is empty, an array of zeros of shape (B,) (where B is batch size) is produced by _calculate_iou.
   - Otherwise, a scalar jnp.ndarray (shape ()) representing the mean IoU is produced.
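The per-batch steps above can be illustrated with a small sketch. It uses plain Python lists rather than JAX arrays, and `mean_iou_sketch` is a hypothetical helper written for this example, not metrax's `_calculate_iou`. It assumes the masks are already flattened integer labels and that target_class_ids is non-empty.

```python
def mean_iou_sketch(pred, target, target_class_ids, epsilon=1e-7):
    """Mean IoU over target_class_ids for flat lists of integer labels.

    Hypothetical illustration only; assumes target_class_ids is non-empty.
    """
    scores = []
    for c in target_class_ids:
        # Step 1: aggregate TP, FP, FN for class c over every pixel in the batch.
        tp = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        fp = sum(1 for p, t in zip(pred, target) if p == c and t != c)
        fn = sum(1 for p, t in zip(pred, target) if p != c and t == c)
        # Step 2: per-class IoU with epsilon in the denominator.
        scores.append(tp / (tp + fp + fn + epsilon))
    # Step 3: average the per-class scores into one value for the batch.
    return sum(scores) / len(scores)


# A batch of two 2x2 masks, flattened to 8 pixels.
pred = [0, 1, 1, 1, 0, 0, 1, 2]
target = [0, 1, 0, 1, 0, 2, 1, 2]
batch_iou = mean_iou_sketch(pred, target, target_class_ids=[0, 1, 2])
print(batch_iou)
```

Here the per-class scores are 2/4, 3/4, and 1/2 (each reduced very slightly by epsilon), so the batch value is close to 0.583.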
Accumulation & Final Metric: This class inherits from base.Average. It accumulates the results from per-batch processing and compute() returns the final mean IoU as a scalar jnp.ndarray (shape ()).
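As a rough picture of the accumulation, the sketch below keeps a running total and count and divides them at the end. `RunningAverage` is a hypothetical stand-in written for this example; it assumes that total sums the per-batch scores and count tracks how many were added, which this page does not spell out for base.Average.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunningAverage:
    """Hypothetical stand-in for an averaging metric; not metrax's class."""

    total: float = 0.0
    count: float = 0.0

    @classmethod
    def from_value(cls, value):
        # One per-batch score contributes its value and a count of 1.
        return cls(total=value, count=1.0)

    def merge(self, other):
        # Accumulate by adding totals and counts; merging an empty state is a no-op.
        return RunningAverage(self.total + other.total, self.count + other.count)

    def compute(self):
        # Final metric: mean of everything accumulated so far.
        return self.total / self.count


batch_scores = [0.5, 0.75, 0.25]  # assumed per-batch mean IoU values
state = RunningAverage()
for score in batch_scores:
    state = state.merge(RunningAverage.from_value(score))
final_iou = state.compute()
print(final_iou)
```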
Methods
__init__(total, count)
compute() – Computes final metrics from intermediate values.
compute_value() – Wraps compute() and returns a values.Value.
empty() – Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
from_fun(fun) – Calls cls.from_model_output with the return value from fun.
from_model_output(predictions, targets, ...) – Creates an IoU instance from a batch of model outputs.
from_output(name) – Calls cls.from_model_output with model output named name.
merge(other) – Returns Metric that is the accumulation of self and other.
reduce() – Reduces the metric along its first axis by calling _reduce_merge().
replace(**updates) – Returns a new object replacing the specified fields with new values.
Attributes
total
count
- classmethod from_model_output(predictions: Array, targets: Array, num_classes: int, target_class_ids: Array, from_logits: bool = False, epsilon: float = 1e-07) → IoU
Creates an IoU instance from a batch of model outputs.
Per-batch processing:
1. Preprocesses predictions and targets into integer label masks of shape (B, H, W) (B: batch size, H: height, W: width).
2. Calls _calculate_iou using the provided target_class_ids to compute the batch’s mean IoU.
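The preprocessing step can be pictured with the following sketch for a single sample. It uses nested Python lists instead of JAX arrays, and `to_label_mask` is a hypothetical helper written for this example, not a metrax function. It assumes that logits are reduced with an argmax over the class axis and that a trailing axis of size 1 is simply dropped.

```python
def argmax(scores):
    # Index of the largest score; ties resolve to the first index.
    return max(range(len(scores)), key=scores.__getitem__)


def to_label_mask(image, from_logits=False):
    """Convert one sample to an (H, W) mask of integer labels.

    image: nested lists shaped (H, W, C) if from_logits is True,
    otherwise (H, W) or (H, W, 1). Hypothetical illustration only.
    """
    mask = []
    for row in image:
        labels = []
        for pixel in row:
            if from_logits:
                labels.append(argmax(pixel))  # class with the highest logit
            elif isinstance(pixel, list):
                labels.append(int(pixel[0]))  # drop the trailing size-1 axis
            else:
                labels.append(int(pixel))  # already a plain label
        mask.append(labels)
    return mask


logits = [
    [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]],
    [[0.2, 0.1, 0.9], [0.0, 1.5, 1.0]],
]
mask_from_logits = to_label_mask(logits, from_logits=True)
mask_from_labels = to_label_mask([[[1], [0]], [[2], [1]]])
print(mask_from_logits, mask_from_labels)
```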
- Parameters:
predictions – jax.Array. Model predictions.
  - If from_logits is True: shape (B, H, W, C) (C: num_classes).
  - If from_logits is False: shape (B, H, W) or (B, H, W, 1).
targets – jax.Array. Ground truth segmentation masks. Shape (B, H, W) or (B, H, W, 1), integer class labels.
num_classes – int. Total number of distinct classes (C).
target_class_ids – An array of integer class IDs for which to compute IoU.
from_logits – bool. If True, predictions are logits and argmax is applied. Default is False.
epsilon – float. Small value added to the IoU denominator for numerical stability. Default is 1e-7.
- Returns:
An IoU metric instance updated with the IoU score from this batch.
- replace(**updates)
Returns a new object replacing the specified fields with new values.