metrax.SSIM

class metrax.SSIM(total: Array, count: Array)

Bases: Average

SSIM (Structural Similarity Index Measure) Metric.

This class calculates the structural similarity between predicted and target images and averages it over a dataset. SSIM is a perception-based model that considers changes in structural information, luminance, and contrast.

The general SSIM formula considers three components: luminance (l), contrast (c), and structure (s):

\[SSIM(x, y) = [l(x, y)]^\alpha \cdot [c(x, y)]^\beta \cdot [s(x, y)]^\gamma\]
Where:
  • Luminance comparison: \(l(x, y) = \frac{2\mu_x\mu_y + c_1}{\mu_x^2 + \mu_y^2 + c_1}\)

  • Contrast comparison: \(c(x, y) = \frac{2\sigma_x\sigma_y + c_2}{\sigma_x^2 + \sigma_y^2 + c_2}\)

  • Structure comparison: \(s(x, y) = \frac{\sigma_{xy} + c_3}{\sigma_x\sigma_y + c_3}\)

This implementation uses a common simplified form where \(\alpha = \beta = \gamma = 1\) and \(c_3 = c_2 / 2\).

This leads to the combined formula:

\[SSIM(x, y) = \frac{(2\mu_x\mu_y + c_1)(2\sigma_{xy} + c_2)}{(\mu_x^2 + \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)}\]
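With \(\alpha = \beta = \gamma = 1\) and \(c_3 = c_2 / 2\), multiplying the numerator and denominator of the structure term by 2 gives \(s(x, y) = \frac{2\sigma_{xy} + c_2}{2\sigma_x\sigma_y + c_2}\). Its denominator then cancels the numerator of the contrast term:

\[c(x, y)\, s(x, y) = \frac{2\sigma_x\sigma_y + c_2}{\sigma_x^2 + \sigma_y^2 + c_2} \cdot \frac{2\sigma_{xy} + c_2}{2\sigma_x\sigma_y + c_2} = \frac{2\sigma_{xy} + c_2}{\sigma_x^2 + \sigma_y^2 + c_2}\]

Multiplying this product by the luminance term \(l(x, y)\) yields the combined formula above.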
In these formulas:
  • \(\mu_x\) and \(\mu_y\) are the local means of \(x\) and \(y\), computed over a Gaussian window (see filter_size and filter_sigma below).

  • \(\sigma_x^2\) and \(\sigma_y^2\) are the local variances of \(x\) and \(y\).

  • \(\sigma_{xy}\) is the local covariance of \(x\) and \(y\).

  • \(c_1 = (K_1 L)^2\) and \(c_2 = (K_2 L)^2\) are stabilization constants, where \(L\) is the dynamic range of pixel values and \(K_1, K_2\) are small constants (e.g., 0.01 and 0.03). In from_model_output, these correspond to the max_val, k1, and k2 arguments.
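The definitions above can be turned into a short JAX sketch for one single-channel image pair. It is illustrative only and is not metrax's implementation: the helper names gaussian_window and ssim_single_channel are hypothetical, and building the window as an outer product of a normalized 1-D Gaussian, keeping only 'valid' positions (where the whole window fits inside the image), and taking a plain mean over the resulting SSIM map are all assumptions. The keyword defaults mirror the filter_size, filter_sigma, k1, and k2 defaults documented for from_model_output.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_window(size: int, sigma: float) -> jnp.ndarray:
    """Normalized (size, size) Gaussian window built from a 1-D Gaussian (assumed design)."""
    coords = jnp.arange(size, dtype=jnp.float32) - (size - 1) / 2.0
    g = jnp.exp(-0.5 * (coords / sigma) ** 2)
    g = g / jnp.sum(g)        # 1-D weights sum to 1
    return jnp.outer(g, g)    # so the 2-D weights also sum to 1


def ssim_single_channel(x, y, max_val, filter_size=11, filter_sigma=1.5,
                        k1=0.01, k2=0.03):
    """Mean SSIM between two (H, W) images from Gaussian-weighted local statistics."""
    window = gaussian_window(filter_size, filter_sigma)

    def local_mean(z):
        # Weighted average under the window at every position where it fits
        # entirely inside the image ('valid'). The window is symmetric, so
        # convolution and correlation give the same result.
        return convolve2d(z, window, mode="valid")

    mu_x, mu_y = local_mean(x), local_mean(y)
    sigma_x2 = local_mean(x * x) - mu_x**2      # local variance of x
    sigma_y2 = local_mean(y * y) - mu_y**2      # local variance of y
    sigma_xy = local_mean(x * y) - mu_x * mu_y  # local covariance of x and y

    c1 = (k1 * max_val) ** 2                    # c_1 = (K_1 L)^2
    c2 = (k2 * max_val) ** 2                    # c_2 = (K_2 L)^2
    numerator = (2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)
    denominator = (mu_x**2 + mu_y**2 + c1) * (sigma_x2 + sigma_y2 + c2)
    return jnp.mean(numerator / denominator)    # average over the SSIM map


# Example: a random image compared with itself and with a noisy copy.
key_img, key_noise = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_img, (32, 32))
noisy = jnp.clip(x + 0.1 * jax.random.normal(key_noise, (32, 32)), 0.0, 1.0)
print(ssim_single_channel(x, x, max_val=1.0))      # 1.0 up to rounding
print(ssim_single_channel(x, noisy, max_val=1.0))  # strictly below 1.0
```

For inputs shaped (batch, height, width, channels), one option (again an assumption, not necessarily what metrax does) is to apply such a function per channel with jax.vmap and average over channels to obtain one score per image.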

__init__(total: Array, count: Array) → None

Methods

__init__(total, count)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, targets, max_val)

Computes SSIM for a batch of images and creates an SSIM metric instance.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.
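The empty, merge, and compute methods listed above follow the usual pattern of an averaging metric. The sketch below uses a hypothetical AverageSketch class; it is not metrax or Average source code. It assumes, as the field names suggest, that total is a running sum of per-example scores and count is the number of examples, so that merging adds both fields and compute() divides them.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AverageSketch:
    """Hypothetical stand-in for an Average-style metric (not metrax code).

    Assumes `total` is a running sum of per-example scores and `count`
    is the number of examples seen so far.
    """
    total: float
    count: float

    @classmethod
    def empty(cls) -> "AverageSketch":
        # Identity element: merging with it leaves the other metric unchanged.
        return cls(total=0.0, count=0.0)

    @classmethod
    def from_values(cls, values: list) -> "AverageSketch":
        # One batch of per-example scores -> (sum of scores, number of scores).
        return cls(total=float(sum(values)), count=float(len(values)))

    def merge(self, other: "AverageSketch") -> "AverageSketch":
        # Accumulate by adding both fields; the order of merging does not matter.
        return AverageSketch(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        # Final average over every example merged so far.
        return self.total / self.count


batch_a = AverageSketch.from_values([0.9, 0.7])  # e.g. two per-image scores
batch_b = AverageSketch.from_values([0.5])
overall = AverageSketch.empty().merge(batch_a).merge(batch_b)
print(overall.compute())  # ≈ 0.7
```

Under this assumption, batches are weighted by their size: the result is the mean over all three examples (0.7), not the mean of the two batch means (0.65).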


classmethod from_model_output(predictions: Array, targets: Array, max_val: float, filter_size: int = 11, filter_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03) → SSIM

Computes SSIM for a batch of images and creates an SSIM metric instance.

This method takes batches of predicted and target images, calculates their SSIM values, and then initializes an SSIM metric object suitable for aggregation across multiple batches.

Parameters:
  • predictions – A JAX array of predicted images, with shape (batch, height, width, channels).

  • targets – A JAX array of ground truth images, with shape (batch, height, width, channels).

  • max_val – The maximum possible pixel value (dynamic range) of the images (e.g., 1.0 for float images in [0,1], 255 for uint8 images).

  • filter_size – The size of the Gaussian filter window used in SSIM calculation (default is 11).

  • filter_sigma – The standard deviation of the Gaussian filter (default is 1.5).

  • k1 – SSIM stability constant for the luminance term (default is 0.01).

  • k2 – SSIM stability constant for the contrast/structure term (default is 0.03).
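Because max_val sets \(L\) in \(c_1 = (K_1 L)^2\) and \(c_2 = (K_2 L)^2\), it should match the actual range of the inputs. The plain-Python arithmetic below (the helper names are hypothetical) evaluates the luminance term \(l(x, y)\) for two flat patches with means 0.2 and 0.8 stored as floats in [0, 1]. With max_val=1.0 the term reflects the difference between the means; with max_val=255 the inflated \(c_1\) pushes the term toward 1 and masks that difference.

```python
def stabilization_constants(max_val: float, k1: float = 0.01, k2: float = 0.03):
    """Return (c1, c2) = ((k1 * L)^2, (k2 * L)^2) with L = max_val."""
    return (k1 * max_val) ** 2, (k2 * max_val) ** 2


def luminance_term(mu_x: float, mu_y: float, c1: float) -> float:
    """l(x, y) = (2 mu_x mu_y + c1) / (mu_x^2 + mu_y^2 + c1)."""
    return (2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)


# Two flat patches with means 0.2 and 0.8, stored as floats in [0, 1].
c1_right, _ = stabilization_constants(max_val=1.0)    # c1 ≈ 1e-4
c1_wrong, _ = stabilization_constants(max_val=255.0)  # c1 ≈ 6.5025
print(luminance_term(0.2, 0.8, c1_right))  # ≈ 0.471: the means clearly differ
print(luminance_term(0.2, 0.8, c1_wrong))  # ≈ 0.950: the difference is masked
```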

Returns:

An SSIM instance containing the SSIM values for the current batch, ready for averaging.

replace(**updates)

Returns a new object replacing the specified fields with new values.