metrax.SNR
- class metrax.SNR(total: Array, count: Array)
Bases: Average
SNR (Signal-to-Noise Ratio) metric for audio.
This class calculates the Signal-to-Noise Ratio (SNR) in decibels (dB) between a predicted audio signal and a ground truth audio signal, and averages it over a dataset.
The SNR is defined as:
\[SNR_{dB} = 10 \cdot \log_{10} \left( \frac{P_{signal}}{P_{noise}} \right)\]
- Where:
\(P_{signal}\) is the power of the ground truth signal (target). By default (zero_mean=False), this is the mean of the squared target values. If zero_mean=True, it’s the variance of the target values.
\(P_{noise}\) is the power of the noise component, defined as the difference between the target and the prediction (target - preds). By default (zero_mean=False), this is the mean of the squared noise values. If zero_mean=True, it’s the variance of the noise values.
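A minimal NumPy sketch of the definition above (jax.numpy offers the same calls). It is not the metrax source: it assumes the last axis is the time axis, and the small `eps` added to both powers is an assumption made here so that a silent target or a perfect prediction does not cause a division by zero or log10(0). `snr_db` is a hypothetical name.

```python
import numpy as np


def snr_db(preds: np.ndarray, target: np.ndarray, zero_mean: bool = False) -> np.ndarray:
    """Hypothetical per-signal SNR in dB; the last axis is assumed to be time."""
    preds = np.asarray(preds, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if zero_mean:
        # Centering each signal turns the mean-of-squares power into the variance.
        target = target - target.mean(axis=-1, keepdims=True)
        preds = preds - preds.mean(axis=-1, keepdims=True)
    noise = target - preds
    # Assumed guard against division by zero and log10(0).
    eps = np.finfo(np.float64).eps
    p_signal = np.mean(target**2, axis=-1)
    p_noise = np.mean(noise**2, axis=-1)
    return 10.0 * np.log10((p_signal + eps) / (p_noise + eps))


target = np.array([1.0, -1.0, 1.0, -1.0])
print(snr_db(0.9 * target, target))  # about 20.0: the noise power is 1% of the signal power
```

Because the noise is the target minus the prediction, a constant offset in the prediction counts as noise unless zero_mean=True, in which case both signals are centered first and the powers become variances.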
Methods
__init__(total, count)
compute() – Computes final metrics from intermediate values.
compute_value() – Wraps compute() and returns a values.Value.
empty() – Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
from_fun(fun) – Calls cls.from_model_output with the return value from fun.
from_model_output(predictions, targets[, ...]) – Computes SNR for a batch of audio signals and creates an SNR metric instance.
from_output(name) – Calls cls.from_model_output with the model output named name.
merge(other) – Returns a Metric that is the accumulation of self and other.
reduce() – Reduces the metric along its first axis by calling _reduce_merge().
replace(**updates) – Returns a new object replacing the specified fields with new values.
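The methods above follow an averaging pattern over the two fields total and count. The sketch below is a hypothetical stand-in (`AverageSketch` is not a metrax class) that assumes `compute()` returns total / count and that `merge` adds the fields; under those assumptions it shows why merging with `empty()` is a no-op and why the result is the mean over all signals rather than a mean of per-batch means.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class AverageSketch:
    """Hypothetical running average stored as a sum and a count."""
    total: float
    count: float

    @classmethod
    def empty(cls) -> "AverageSketch":
        # Zero total and zero count: the identity element for merge().
        return cls(total=0.0, count=0.0)

    @classmethod
    def from_values(cls, values) -> "AverageSketch":
        # One batch of per-signal scores -> (sum of scores, number of scores).
        values = list(values)
        return cls(total=float(sum(values)), count=float(len(values)))

    def merge(self, other: "AverageSketch") -> "AverageSketch":
        # Accumulate both fields; the order of merges does not matter.
        return replace(self, total=self.total + other.total,
                       count=self.count + other.count)

    def compute(self) -> float:
        return self.total / self.count


batch_a = AverageSketch.from_values([10.0, 20.0])  # batch mean 15
batch_b = AverageSketch.from_values([40.0])        # batch mean 40
overall = batch_a.merge(batch_b).compute()         # 70 / 3, about 23.33, not (15 + 40) / 2
```

Because merge only adds fields, every signal carries equal weight in the final value, however the data were split into batches.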
Attributes
total
count
- classmethod from_model_output(predictions: Array, targets: Array, zero_mean: bool = False) → SNR
Computes SNR for a batch of audio signals and creates an SNR metric instance.
- Parameters:
predictions – A JAX array of predicted audio signals.
targets – A JAX array of ground truth audio signals.
zero_mean – If True, subtracts the mean from the signal and noise before calculating their respective powers.
- Returns:
An SNR instance containing the SNR value for the current batch, ready for averaging.
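A minimal sketch of what "ready for averaging" can mean in an evaluation loop, under the assumption (not stated on this page) that each batch contributes the sum of its per-signal SNR values as total and the number of signals as count, with arrays shaped (batch, time). `batch_snr_state` is a hypothetical NumPy helper, not part of metrax; with metrax itself one would call SNR.from_model_output per batch and combine the results with merge.

```python
import numpy as np


def batch_snr_state(predictions, targets, zero_mean=False):
    """Hypothetical stand-in for SNR.from_model_output, returning (total, count).

    Assumes arrays of shape (batch, time) and one SNR value per signal.
    """
    predictions = np.asarray(predictions, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    if zero_mean:
        targets = targets - targets.mean(axis=-1, keepdims=True)
        predictions = predictions - predictions.mean(axis=-1, keepdims=True)
    noise = targets - predictions
    eps = np.finfo(np.float64).eps  # assumed guard against log10(0)
    per_signal = 10.0 * np.log10(
        (np.mean(targets**2, axis=-1) + eps) / (np.mean(noise**2, axis=-1) + eps)
    )
    return float(per_signal.sum()), int(per_signal.size)


# Evaluation loop: accumulate (total, count) over batches of different sizes.
rng = np.random.default_rng(0)
total, count = 0.0, 0
for batch_size in (4, 2):
    clean = rng.standard_normal((batch_size, 16000))
    noisy = clean + 0.1 * rng.standard_normal(clean.shape)  # noise std is 10% of signal std
    t, c = batch_snr_state(noisy, clean)
    total, count = total + t, count + c
mean_snr_db = total / count  # average over all 6 signals, close to 20 dB here
```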
- replace(**updates)
Returns a new object replacing the specified fields with new values.