metrax.MSE

class metrax.MSE(total: Array, count: Array)

Bases: Average

Computes the mean squared error for regression problems given predictions and labels.

The mean squared error without sample weights is defined as:

\[MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2\]

When sample weights \(w_i\) are provided, the weighted mean squared error is:

\[MSE = \frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i}\]
where:
  • \(y_i\) are the true values (labels)

  • \(\hat{y}_i\) are the predictions

  • \(w_i\) are sample weights

  • \(N\) is the number of samples
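A small worked example shows how the weights change the result. The numbers are illustrative only and do not come from the metrax documentation. Take three samples with labels \(y = (1, 2, 3)\), predictions \(\hat{y} = (1, 2, 5)\), and weights \(w = (1, 1, 2)\):

\[MSE = \frac{(1-1)^2 + (2-2)^2 + (3-5)^2}{3} = \frac{4}{3}\]

\[MSE = \frac{1 \cdot (1-1)^2 + 1 \cdot (2-2)^2 + 2 \cdot (3-5)^2}{1 + 1 + 2} = \frac{8}{4} = 2\]

Doubling the weight on the only mispredicted sample raises the error from 4/3 to 2.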

__init__(total: Array, count: Array) → None

Methods

__init__(total, count)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, labels[, ...])

Creates an MSE metric from a batch of predictions and labels.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns a Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.
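The methods above fit together as an accumulate-then-compute cycle. The following pure-Python sketch models that cycle under stated assumptions: `AverageState` and `reduce_states` are hypothetical names, not metrax or CLU API, and the sketch assumes that `merge()` sums `total` and `count`, that `empty()` holds zeros, and that `compute()` returns `total / count`, which is what the `Average` base class and the weighted formula suggest. `reduce_states` stands in for `reduce()` by treating the "first axis" as a Python list of states.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class AverageState:
    """Hypothetical stand-in for an Average-style metric holding total and count."""

    total: float
    count: float

    @classmethod
    def empty(cls) -> "AverageState":
        # Zero total and zero count, so merging with this state changes nothing.
        return cls(total=0.0, count=0.0)

    def merge(self, other: "AverageState") -> "AverageState":
        # Accumulate by summing both fields; no averaging happens at this stage.
        return AverageState(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        # Final value: accumulated total divided by accumulated count.
        return self.total / self.count


def reduce_states(states: Sequence[AverageState]) -> AverageState:
    """Fold a stack of states (for example, one per device) into a single state."""
    result = AverageState.empty()
    for state in states:
        result = result.merge(state)
    return result
```

For example, one batch with squared errors (1, 4) gives `AverageState(5.0, 2.0)` and another with squared error (9,) gives `AverageState(9.0, 1.0)`; merging them and calling `compute()` yields 14/3, the mean over all three samples. Averaging the two per-batch results (2.5 and 9.0) would instead give 5.75, which is why the metric carries `total` and `count` rather than a running mean.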

Attributes

total

count

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) → MSE

Creates an MSE metric from a batch of predictions and labels.

Parameters:
  • predictions – A floating-point 1D vector of the predictions generated by the model. The shape should be (batch_size,).

  • labels – The true values. The shape should be (batch_size,).

  • sample_weights – An optional floating-point 1D vector giving the weight of each sample. The shape should be (batch_size,).

Returns:

The updated MSE metric. Its computed value is a single scalar.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
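How a single batch might be turned into the two stored fields is sketched below in pure Python. `batch_total_and_count` is a hypothetical helper, not the metrax implementation; it assumes, consistent with the weighted formula, that `total` holds the (weighted) sum of squared errors and `count` holds the number of samples or the sum of the weights. Lists stand in for JAX arrays of shape `(batch_size,)`, and the length checks only mimic the documented `ValueError`.

```python
from typing import Optional, Sequence, Tuple


def batch_total_and_count(
    predictions: Sequence[float],
    labels: Sequence[float],
    sample_weights: Optional[Sequence[float]] = None,
) -> Tuple[float, float]:
    """Hypothetical per-batch reduction into the (total, count) pair stored by MSE."""
    if len(predictions) != len(labels):
        # Mimics the documented ValueError for incompatible shapes.
        raise ValueError(
            f"shape mismatch: {len(predictions)} predictions vs {len(labels)} labels"
        )
    squared = [(p - y) ** 2 for p, y in zip(predictions, labels)]
    if sample_weights is None:
        # Unweighted: total is the plain sum of squared errors, count is N.
        return float(sum(squared)), float(len(squared))
    if len(sample_weights) != len(labels):
        # Assumed check: weights must also have shape (batch_size,).
        raise ValueError("sample_weights must have shape (batch_size,)")
    # Weighted: total is sum(w_i * (y_i - y_hat_i)^2), count is sum(w_i).
    total = sum(w * s for w, s in zip(sample_weights, squared))
    return float(total), float(sum(sample_weights))
```

With predictions (1, 2, 5) and labels (1, 2, 3), the unweighted call returns (4.0, 3.0) and the call with weights (1, 1, 2) returns (8.0, 4.0); dividing total by count gives 4/3 and 2.0, matching the two formulas for MSE.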

replace(**updates)

Returns a new object replacing the specified fields with new values.