metrax.PSNR

class metrax.PSNR(total: Array, count: Array)

Bases: Average

PSNR (Peak Signal-to-Noise Ratio) Metric.

This class calculates the Peak Signal-to-Noise Ratio (PSNR) between two images to measure the quality of a reconstructed image compared to a reference.

\[\text{PSNR}(I, J) = 10 \cdot \log_{10} \left( \frac{\max(I)^2}{\text{MSE}(I, J)} \right)\]
Where:
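  • \(\max(I)\) is the maximum possible pixel value of the images (the dynamic range passed as max_val), not the largest value actually present in \(I\).

  • \(\text{MSE}(I, J)\) is the mean squared error between images \(I\) and \(J\), averaged over all pixels and channels.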
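As a worked instance of the formula (illustrative numbers, not taken from metrax), suppose two 8-bit images have \(\max(I) = 255\) and \(\text{MSE}(I, J) = 65.025\):

\[\text{PSNR}(I, J) = 10 \cdot \log_{10} \left( \frac{255^2}{65.025} \right) = 10 \cdot \log_{10} \left( \frac{65025}{65.025} \right) = 10 \cdot \log_{10}(1000) = 30\ \text{dB}\]

Rescaling both images to [0, 1] divides the MSE by \(255^2 = 65025\), so MSE = 0.001 with a maximum value of 1.0 gives the same 30 dB. Passing 255 as the maximum for such [0, 1] images would instead overstate PSNR by \(10 \cdot \log_{10}(255^2) \approx 48.13\) dB, which is why max_val must match the images' actual range.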

__init__(total: Array, count: Array) → None

Methods

__init__(total, count)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, targets, max_val)

Computes PSNR for a batch of images and creates a PSNR metric instance.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.
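Because PSNR subclasses Average, empty(), merge() and compute() follow an accumulate-then-divide pattern over the total and count attributes. The sketch below is a hypothetical, simplified stand-in (AverageSketch and its from_values helper are illustrative names, not metrax or CLU API), assuming total is a running sum and count the number of values summed:

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class AverageSketch:
    """Hypothetical Average-style metric; not the actual CLU/metrax class."""

    total: jnp.ndarray  # running sum of per-example values (e.g. per-image PSNR)
    count: jnp.ndarray  # number of values summed so far

    @classmethod
    def empty(cls) -> "AverageSketch":
        # Zero total and zero count, so merging with this instance is a no-op.
        return cls(total=jnp.array(0.0), count=jnp.array(0))

    @classmethod
    def from_values(cls, values: jnp.ndarray) -> "AverageSketch":
        # One batch of per-example values becomes a (sum, count) pair.
        return cls(total=jnp.sum(values), count=jnp.array(values.size))

    def merge(self, other: "AverageSketch") -> "AverageSketch":
        # Accumulate by summing both fields rather than averaging averages.
        return AverageSketch(self.total + other.total, self.count + other.count)

    def compute(self) -> jnp.ndarray:
        # Final value: mean over every example seen across all merged batches.
        return self.total / self.count


batch1 = AverageSketch.from_values(jnp.array([20.0, 30.0]))
batch2 = AverageSketch.from_values(jnp.array([40.0]))
result = AverageSketch.empty().merge(batch1).merge(batch2)
print(result.compute())  # mean of all three values: 30.0
```

Summing total and count across batches keeps the result exact when batches differ in size: averaging the two batch means here (25.0 and 40.0) would give 32.5 instead of 30.0.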

Attributes

total

count

classmethod from_model_output(predictions: Array, targets: Array, max_val: float) → PSNR

Computes PSNR for a batch of images and creates a PSNR metric instance.

Parameters:
  • predictions – A JAX array of predicted images, with shape (batch, H, W, C).

  • targets – A JAX array of ground truth images, with shape (batch, H, W, C).

  • max_val – The maximum possible pixel value (dynamic range) of the images (e.g., 1.0 for float images in [0,1], 255 for uint8 images).

Returns:

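A PSNR instance that accumulates the per-image PSNR values: total holds their sum and count their number, so compute() returns the mean PSNR over the batch.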
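A minimal sketch of what from_model_output computes for a (batch, H, W, C) batch, assuming the per-image MSE is taken over the H, W and C axes and the resulting per-image values are accumulated Average-style. psnr_per_image and accumulate are hypothetical helpers; metrax's actual implementation may differ in details such as how a zero MSE (identical images) is handled:

```python
import jax.numpy as jnp


def psnr_per_image(predictions: jnp.ndarray, targets: jnp.ndarray, max_val: float) -> jnp.ndarray:
    """Hypothetical helper: one PSNR value (dB) per image in a (batch, H, W, C) batch."""
    predictions = predictions.astype(jnp.float32)
    targets = targets.astype(jnp.float32)
    # Average the squared error over H, W and C only, keeping the batch axis.
    mse = jnp.mean((predictions - targets) ** 2, axis=(1, 2, 3))
    # 10 * log10(max_val^2 / MSE) per image; MSE = 0 yields +inf in this sketch.
    return 10.0 * jnp.log10(max_val**2 / mse)


def accumulate(values: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    # Average-style state (total, count) over the per-image values.
    return jnp.sum(values), jnp.array(values.shape[0])


targets = jnp.zeros((2, 8, 8, 3))
predictions = jnp.stack([
    jnp.full((8, 8, 3), 0.1),   # MSE 0.01 -> about 20 dB
    jnp.full((8, 8, 3), 0.01),  # MSE 1e-4 -> about 40 dB
])
values = psnr_per_image(predictions, targets, max_val=1.0)
total, count = accumulate(values)
print(values)         # approximately [20. 40.]
print(total / count)  # approximately 30.0
```

With the library itself, the documented entry point is metrax.PSNR.from_model_output(predictions, targets, max_val=1.0), and calling compute() on the returned instance gives the mean over the batch.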

replace(**updates)

Returns a new object replacing the specified fields with new values.