metrax.RSQUARED

class metrax.RSQUARED(total: Array, count: Array, sum_of_squared_error: Array, sum_of_squared_label: Array)

Bases: Metric

Computes the r-squared score of a scalar or a batch of tensors.

R-squared is a measure of how well the regression model fits the data. It measures the proportion of the variance in the dependent variable that is explained by the independent variable(s). It is defined as 1 - SSE / SST, where SSE is the sum of squared errors and SST is the total sum of squares.

\[R^2 = 1 - \frac{SSE}{SST}\]
where:
\[SSE = \sum_{i=1}^{N} (y_i - \hat{y}_i)^2\]
\[SST = \sum_{i=1}^{N} (y_i - \bar{y})^2\]

When sample weights \(w_i\) are provided:

\[R^2 = 1 - \frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i(y_i - \bar{y})^2}\]
where:
  • \(y_i\) are true values

  • \(\hat{y}_i\) are predictions

  • \(\bar{y}\) is the mean of true values

  • \(w_i\) are sample weights

  • \(N\) is the number of samples

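The score ranges from -∞ to 1. A score of 1 indicates perfect prediction, and 0 indicates that the model performs no better than always predicting the mean of the true values (a horizontal line). Negative scores indicate a fit worse than that baseline.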
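The definition above can be checked with a small pure-Python sketch. This is not the metrax implementation; `r_squared` is a hypothetical helper, and it assumes that when weights are given, \(\bar{y}\) is the weighted mean of the true values.

```python
def r_squared(labels, predictions, sample_weights=None):
    """Hypothetical helper: R^2 = 1 - SSE / SST, optionally weighted."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    if sample_weights is None:
        # Unweighted case: every sample has weight 1.
        sample_weights = [1.0] * len(labels)
    # Assumption: with weights, y-bar is the weighted mean of the labels.
    weight_sum = sum(sample_weights)
    mean = sum(w * y for w, y in zip(sample_weights, labels)) / weight_sum
    sse = sum(w * (y - p) ** 2
              for w, y, p in zip(sample_weights, labels, predictions))
    sst = sum(w * (y - mean) ** 2 for w, y in zip(sample_weights, labels))
    return 1.0 - sse / sst


labels = [1.0, 2.0, 3.0, 4.0]
perfect = r_squared(labels, labels)                  # every prediction exact
baseline = r_squared(labels, [2.5, 2.5, 2.5, 2.5])   # always predict the mean
worse = r_squared(labels, [4.0, 3.0, 2.0, 1.0])      # reversed predictions
print(perfect, baseline, worse)  # 1.0 0.0 -3.0
```

The three cases cover the range described above: a perfect fit scores 1, the mean-only baseline scores 0, and a fit worse than the baseline goes negative.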

__init__(total: Array, count: Array, sum_of_squared_error: Array, sum_of_squared_label: Array) → None

Methods

__init__(total, count, sum_of_squared_error, ...)

compute()

Computes the r-squared score.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, labels[, ...])

Updates the metric.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.

Attributes

total: Array
count: Array
sum_of_squared_error: Array
sum_of_squared_label: Array

classmethod empty() → RSQUARED

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) → RSQUARED

Updates the metric.

Parameters:
  • predictions – A floating point 1D vector representing the prediction generated from the model. The shape should be (batch_size,).

  • labels – True value. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

Returns:

The updated RSQUARED metric. Its accumulated values are scalars.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: RSQUARED) → RSQUARED

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.
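A minimal sketch of why accumulation works, written in plain Python rather than JAX. `RSquaredStats` is a hypothetical stand-in that mirrors the four field names of RSQUARED. The meaning given to each field (sum of labels, sample count, and so on) is an assumption inferred from the field names, not taken from the metrax source.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RSquaredStats:
    """Hypothetical stand-in mirroring the four fields of RSQUARED."""
    total: float                 # assumed: sum of labels
    count: float                 # assumed: number of samples
    sum_of_squared_error: float  # assumed: sum of (label - prediction)^2
    sum_of_squared_label: float  # assumed: sum of label^2

    @classmethod
    def from_batch(cls, predictions, labels):
        # Reduce one batch to four running sums.
        return cls(
            total=sum(labels),
            count=float(len(labels)),
            sum_of_squared_error=sum(
                (y - p) ** 2 for y, p in zip(labels, predictions)),
            sum_of_squared_label=sum(y * y for y in labels),
        )

    def merge(self, other):
        # Each field is a plain sum over samples, so merging two
        # instances is field-wise addition.
        return RSquaredStats(
            self.total + other.total,
            self.count + other.count,
            self.sum_of_squared_error + other.sum_of_squared_error,
            self.sum_of_squared_label + other.sum_of_squared_label,
        )


batch_a = RSquaredStats.from_batch([1.5, 2.0], [1.0, 2.0])
batch_b = RSquaredStats.from_batch([2.5, 4.5], [3.0, 4.0])
merged = batch_a.merge(batch_b)

# Processing all four samples at once gives the same statistics.
full = RSquaredStats.from_batch([1.5, 2.0, 2.5, 4.5], [1.0, 2.0, 3.0, 4.0])

# An all-zero instance plays the role of empty(): merging it is a no-op.
empty = RSquaredStats(0.0, 0.0, 0.0, 0.0)
```

Because every field is additive, the order in which batches are merged does not change the accumulated statistics in exact arithmetic.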

compute() → Array

Computes the r-squared score.

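Since we don’t know the mean of the labels before we aggregate all of the data, we manipulate the formula for SST so that it depends only on accumulated sums, where \(y_i\) are the labels and \(\bar{y}\) is their mean:

\[\begin{aligned}
SST &= \sum_{i=1}^{N} (y_i - \bar{y})^2 \\
&= \sum_{i=1}^{N} (y_i^2 - 2 y_i \bar{y} + \bar{y}^2) \\
&= \sum_{i=1}^{N} y_i^2 - 2 \bar{y} \sum_{i=1}^{N} y_i + N \bar{y}^2 \\
&= \sum_{i=1}^{N} y_i^2 - 2 \bar{y} \cdot N \bar{y} + N \bar{y}^2 \\
&= \sum_{i=1}^{N} y_i^2 - N \bar{y}^2
\end{aligned}\]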

Returns:

The r-squared score.
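The rearranged SST can be checked numerically with a short pure-Python sketch. `r_squared_from_stats` is a hypothetical function, not the metrax code. Its four arguments reuse the attribute names of RSQUARED, under the assumption that `total` is the sum of labels and `count` is the number of samples. The sketch does not guard against SST being zero.

```python
def r_squared_from_stats(total, count, sum_of_squared_error,
                         sum_of_squared_label):
    """Hypothetical: R^2 from running sums, with no second pass over data."""
    mean = total / count
    # SST = sum_i y_i^2 - N * mean^2, the rearranged form of
    # sum_i (y_i - mean)^2.
    sst = sum_of_squared_label - count * mean ** 2
    return 1.0 - sum_of_squared_error / sst


labels = [1.0, 2.0, 3.0, 4.0]
predictions = [1.5, 2.0, 2.5, 4.5]
sse = sum((y - p) ** 2 for y, p in zip(labels, predictions))

# Single-pass version: only sums that can be accumulated batch by batch.
streamed = r_squared_from_stats(
    total=sum(labels),
    count=len(labels),
    sum_of_squared_error=sse,
    sum_of_squared_label=sum(y * y for y in labels),
)

# Reference version: compute the mean first, then SST from its definition.
mean = sum(labels) / len(labels)
sst_direct = sum((y - mean) ** 2 for y in labels)
direct = 1.0 - sse / sst_direct
```

Both routes give approximately 0.85 for this data. The single-pass form trades a second pass over the labels for a subtraction of two potentially large sums, which can lose precision when the labels have a large mean relative to their spread.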

replace(**updates)

Returns a new object replacing the specified fields with new values.