metrax.IoU

class metrax.IoU(total: Array, count: Array)

Bases: Average

Measures Intersection over Union (IoU) for semantic segmentation.

The general formula for IoU for a single class is: $\mathrm{IoU}_{\text{class}} = \frac{TP}{TP + FP + FN}$, where TP, FP, and FN are the numbers of True Positives, False Positives, and False Negatives.
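As a small illustration of this formula, the sketch below computes the IoU of one class from two integer label masks using jax.numpy. The function name `class_iou` is hypothetical and not part of metrax. It applies the formula literally, without an epsilon term, so a class that is absent from both masks would yield NaN (0/0).

```python
import jax.numpy as jnp


def class_iou(pred: jnp.ndarray, target: jnp.ndarray, class_id: int) -> jnp.ndarray:
    """IoU = TP / (TP + FP + FN) for a single class (hypothetical helper)."""
    pred_c = pred == class_id      # pixels predicted as class_id
    target_c = target == class_id  # pixels labelled as class_id
    tp = jnp.sum(pred_c & target_c)   # predicted and labelled
    fp = jnp.sum(pred_c & ~target_c)  # predicted but not labelled
    fn = jnp.sum(~pred_c & target_c)  # labelled but not predicted
    return tp / (tp + fp + fn)


pred = jnp.array([[0, 1],
                  [1, 1]])
target = jnp.array([[0, 1],
                    [0, 1]])
iou_1 = class_iou(pred, target, 1)  # TP=2, FP=1, FN=0, so IoU = 2/3
iou_0 = class_iou(pred, target, 0)  # TP=1, FP=0, FN=1, so IoU = 1/2
```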

Per-Batch Processing: For each input batch, a mean IoU is calculated. This involves:

  1. Aggregating TP, FP, and FN pixel counts for each specified target class (from the required target_class_ids list) across all samples within the batch.

  2. Computing IoU for each of these classes using the batch-aggregated counts: $\mathrm{IoU}_{\text{class}} = \frac{TP}{TP + FP + FN + \epsilon}$.

  3. Averaging these per-class IoU scores to get a single value for the batch.

     - If target_class_ids is empty, _calculate_iou produces an array of zeros of shape (B,), where B is the batch size.

     - Otherwise, it produces a scalar jnp.ndarray (shape ()) representing the mean IoU.
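The per-batch steps can be sketched in jax.numpy as follows. This is an illustration under stated assumptions, not metrax's `_calculate_iou`: the name `batch_mean_iou` is hypothetical, the inputs are assumed to already be integer label masks of shape (B, H, W), and the empty-`target_class_ids` branch simply mirrors the documented behavior.

```python
import jax.numpy as jnp


def batch_mean_iou(pred, target, target_class_ids, epsilon=1e-7):
    """Hypothetical sketch of the per-batch mean IoU; not metrax's _calculate_iou.

    pred, target: integer label masks of shape (B, H, W).
    """
    ids = jnp.asarray(target_class_ids, dtype=jnp.int32)
    if ids.size == 0:
        # Documented behavior for an empty target_class_ids: zeros of shape (B,).
        return jnp.zeros(pred.shape[0])
    ids = ids.reshape(-1, 1, 1, 1)    # (K, 1, 1, 1) for broadcasting
    pred_k = pred[None] == ids        # (K, B, H, W): one boolean mask per class
    target_k = target[None] == ids
    # Step 1: aggregate pixel counts over the whole batch (axes B, H, W).
    tp = jnp.sum(pred_k & target_k, axis=(1, 2, 3))
    fp = jnp.sum(pred_k & ~target_k, axis=(1, 2, 3))
    fn = jnp.sum(~pred_k & target_k, axis=(1, 2, 3))
    # Step 2: per-class IoU; epsilon keeps the denominator nonzero.
    iou = tp / (tp + fp + fn + epsilon)
    # Step 3: average over the requested classes to get a scalar.
    return jnp.mean(iou)


pred = jnp.array([[[0, 1], [1, 1]],
                  [[2, 2], [0, 1]]])    # shape (2, 2, 2)
target = jnp.array([[[0, 1], [0, 1]],
                    [[2, 0], [0, 1]]])
print(batch_mean_iou(pred, target, [0, 1, 2]))  # approx. (0.5 + 0.75 + 0.5) / 3
print(batch_mean_iou(pred, target, [1, 3]))     # class 3 absent: approx. (0.75 + 0) / 2
print(batch_mean_iou(pred, target, []).shape)   # (2,)
```

Because epsilon keeps the denominator nonzero, a requested class that appears in neither mask contributes an IoU of 0 instead of NaN, which lowers the batch mean.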

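Accumulation & Final Metric: This class inherits from base.Average. It accumulates the per-batch mean IoU values in its total and count fields, and compute() returns the final mean IoU as a scalar jnp.ndarray (shape ()). Because each batch contributes an already-averaged score, the final value is a mean of per-batch mean IoUs rather than an IoU computed from TP, FP, and FN counts pooled across all batches.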
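The accumulation behavior can be illustrated with a hypothetical `RunningAverage` dataclass that keeps a running total and count, in the spirit of the total and count fields above. It is not metrax's base.Average. It assumes that each scalar per-batch mean IoU is added with a count of one and that compute() divides total by count.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class RunningAverage:
    """Hypothetical total/count accumulator; not metrax.base.Average."""

    total: jnp.ndarray
    count: jnp.ndarray

    @classmethod
    def empty(cls) -> "RunningAverage":
        # Merging with an empty accumulator leaves the other unchanged.
        return cls(total=jnp.array(0.0), count=jnp.array(0.0))

    @classmethod
    def from_batch_value(cls, value) -> "RunningAverage":
        # Assumption: a scalar per-batch mean IoU is added once (count 1).
        value = jnp.asarray(value, dtype=jnp.float32)
        return cls(total=jnp.sum(value), count=jnp.array(float(value.size)))

    def merge(self, other: "RunningAverage") -> "RunningAverage":
        return RunningAverage(self.total + other.total, self.count + other.count)

    def compute(self) -> jnp.ndarray:
        # total / count; the maximum() guard makes an empty accumulator return 0.0.
        return self.total / jnp.maximum(self.count, 1.0)


acc = RunningAverage.empty()
for batch_iou in [0.5, 0.75, 0.25]:  # illustrative per-batch mean IoUs
    acc = acc.merge(RunningAverage.from_batch_value(batch_iou))
print(acc.compute())  # (0.5 + 0.75 + 0.25) / 3 = 0.5
```

Under these assumptions every batch carries equal weight in the final value, regardless of how many pixels it contains.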

__init__(total: Array, count: Array) -> None

Methods

__init__(total, count)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, targets, ...)

Creates an IoU instance from a batch of model outputs.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.

Attributes

total

count

classmethod from_model_output(predictions: Array, targets: Array, num_classes: int, target_class_ids: Array, from_logits: bool = False, epsilon: float = 1e-07) -> IoU

Creates an IoU instance from a batch of model outputs.

Per-batch processing:

  1. Preprocesses predictions and targets into integer label masks of shape (B, H, W) (B: batch size, H: height, W: width).

  2. Calls _calculate_iou using the provided target_class_ids to compute the batch's mean IoU.
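Step 1 can be sketched with a hypothetical helper, `to_label_mask`. It is an assumption about how such preprocessing might look, not metrax's own code: logits of shape (B, H, W, C) are reduced with argmax over the class axis, and label arrays of shape (B, H, W, 1) have their trailing axis squeezed.

```python
import jax.numpy as jnp


def to_label_mask(x, from_logits: bool = False) -> jnp.ndarray:
    """Hypothetical preprocessing sketch: return an integer mask of shape (B, H, W)."""
    x = jnp.asarray(x)
    if from_logits:
        # (B, H, W, C) scores -> index of the highest-scoring class per pixel.
        return jnp.argmax(x, axis=-1).astype(jnp.int32)
    if x.ndim == 4 and x.shape[-1] == 1:
        # (B, H, W, 1) labels -> (B, H, W)
        x = jnp.squeeze(x, axis=-1)
    return x.astype(jnp.int32)


logits = jnp.array([[[[2.0, 0.1, 0.3], [0.0, 1.0, 0.5]],
                     [[0.2, 0.1, 3.0], [0.9, 0.8, 0.7]]]])  # (1, 2, 2, 3)
labels = jnp.array([[[[0], [1]],
                     [[2], [0]]]])                          # (1, 2, 2, 1)
pred_mask = to_label_mask(logits, from_logits=True)  # shape (1, 2, 2)
target_mask = to_label_mask(labels)                  # shape (1, 2, 2)
```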

Parameters:
  • predictions – jax.Array. Model predictions.

      - If from_logits is True: shape (B, H, W, C) (C: num_classes).

      - If from_logits is False: shape (B, H, W) or (B, H, W, 1).

  • targets – jax.Array. Ground truth segmentation masks. Shape (B, H, W) or (B, H, W, 1), integer class labels.

  • num_classes – Total number of distinct classes (C). Integer.

  • target_class_ids – An array of integer class IDs for which to compute IoU.

  • from_logits – bool. If True, predictions are treated as logits and argmax is applied. Default is False.

  • epsilon – float. Small value added to the IoU denominator for numerical stability. Default is 1e-7.

Returns:

An IoU metric instance updated with the IoU score from this batch.

replace(**updates)

Returns a new object replacing the specified fields with new values.