metrax.PSNR
- class metrax.PSNR(total: Array, count: Array)
Bases: Average
PSNR (Peak Signal-to-Noise Ratio) Metric.
This class calculates the Peak Signal-to-Noise Ratio (PSNR), in decibels, between two images to measure the quality of a reconstructed image relative to a reference image. Higher values indicate a closer match.
\[\text{PSNR}(I, J) = 10 \cdot \log_{10} \left( \frac{\max(I)^2}{\text{MSE}(I, J)} \right)\]
- Where:
\(\max(I)\) is the maximum possible pixel value of the input image (supplied as max_val), not the largest value actually present in \(I\).
\(\text{MSE}(I, J)\) is the mean squared error between images \(I\) and \(J\).
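To make the formula concrete, here is a minimal sketch in jax.numpy for a single image pair. The function name psnr is hypothetical (not part of metrax), and it assumes the MSE is averaged over every pixel and channel of the image.

```python
import jax.numpy as jnp


def psnr(image_i, image_j, max_val):
    """Hypothetical per-image PSNR in dB, following the formula above.

    max_val plays the role of max(I): the maximum *possible* pixel value,
    not the largest value actually present in the image.
    """
    # Mean squared error over every pixel and channel.
    mse = jnp.mean((image_i - image_j) ** 2)
    return 10.0 * jnp.log10(max_val**2 / mse)


# Worked example: every pixel differs by 0.1, so MSE = 0.01 and
# PSNR = 10 * log10(1.0 / 0.01) = 20 dB.
reference = jnp.zeros((4, 4, 1))
reconstruction = jnp.full((4, 4, 1), 0.1)
print(psnr(reconstruction, reference, max_val=1.0))  # approximately 20.0
```

If the two images are identical, the MSE is zero and this sketch returns inf.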
Methods
__init__(total, count)
compute(): Computes final metrics from intermediate values.
compute_value(): Wraps compute() and returns a values.Value.
empty(): Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
from_fun(fun): Calls cls.from_model_output with the return value from fun.
from_model_output(predictions, targets, max_val): Computes PSNR for a batch of images and creates a PSNR metric instance.
from_output(name): Calls cls.from_model_output with the model output named name.
merge(other): Returns a Metric that is the accumulation of self and other.
reduce(): Reduces the metric along its first axis by calling _reduce_merge().
replace(**updates): Returns a new object replacing the specified fields with new values.
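Assuming the usual semantics of an Average metric, the methods above follow an accumulate-then-compute pattern: each batch is summarized by total and count, merge adds those fields, empty is the identity for merge, and compute divides. The class below is a hypothetical minimal sketch of that pattern; AverageSketch and from_values are illustrative names, not part of metrax or its base class, and this is not the library's implementation.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class AverageSketch:
    """Hypothetical stand-in for an Average-style metric (not metrax code).

    total holds the running sum of per-example values and count holds how
    many values were summed, so merging is just element-wise addition.
    """

    total: jnp.ndarray
    count: jnp.ndarray

    @classmethod
    def empty(cls):
        # Identity element: merging with it leaves the other metric unchanged.
        return cls(total=jnp.array(0.0), count=jnp.array(0.0))

    @classmethod
    def from_values(cls, values):
        # Summarize one batch of per-example values (e.g. per-image PSNR).
        return cls(
            total=jnp.sum(values),
            count=jnp.array(values.size, dtype=jnp.float32),
        )

    def merge(self, other):
        # Accumulate sums and counts from two partial results.
        return AverageSketch(self.total + other.total, self.count + other.count)

    def compute(self):
        # Final metric: mean of every value seen so far.
        return self.total / self.count


batch_1 = AverageSketch.from_values(jnp.array([20.0, 30.0]))
batch_2 = AverageSketch.from_values(jnp.array([40.0]))
combined = AverageSketch.empty().merge(batch_1).merge(batch_2)
print(combined.compute())  # 30.0
```

Because merge adds counts, the final value is the mean over all images rather than the mean of per-batch means: here (20 + 30 + 40) / 3 = 30.0, whereas averaging the two batch means would give 32.5.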
Attributes
total
count
- classmethod from_model_output(predictions: Array, targets: Array, max_val: float) → PSNR
Computes PSNR for a batch of images and creates a PSNR metric instance.
- Parameters:
predictions – A JAX array of predicted images, with shape (batch, H, W, C).
targets – A JAX array of ground truth images, with shape (batch, H, W, C).
max_val – The maximum possible pixel value (dynamic range) of the images (e.g., 1.0 for float images in [0, 1], 255 for uint8 images).
- Returns:
A PSNR instance containing the per-image PSNR values, accumulated into its total and count fields.
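With metrax installed, the documented entry point is metrax.PSNR.from_model_output(predictions, targets, max_val=1.0), followed by compute() on the result (or on several results combined with merge). The sketch below does not call metrax; it only mirrors the batched computation in jax.numpy. The function name batched_psnr is hypothetical, and reducing each image's MSE over its H, W and C axes is an assumption about how the per-image values are formed.

```python
import jax.numpy as jnp


def batched_psnr(predictions, targets, max_val):
    """Hypothetical per-image PSNR for arrays shaped (batch, H, W, C).

    Assumption: each image's MSE is averaged over its H, W and C axes,
    giving one PSNR value per image in the batch.
    """
    mse = jnp.mean((predictions - targets) ** 2, axis=(1, 2, 3))
    return 10.0 * jnp.log10(max_val**2 / mse)


targets = jnp.zeros((2, 8, 8, 3))
# Image 0 is off by 0.1 everywhere (20 dB); image 1 by 0.01 (40 dB).
predictions = jnp.stack([
    jnp.full((8, 8, 3), 0.1),
    jnp.full((8, 8, 3), 0.01),
])
per_image = batched_psnr(predictions, targets, max_val=1.0)

# An Average-style metric keeps only the sum and the number of values.
total, count = jnp.sum(per_image), per_image.shape[0]
print(per_image)      # approximately [20. 40.]
print(total / count)  # approximately 30.0
```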
- replace(**updates)
Returns a new object replacing the specified fields with new values.