metrax.Recall

class metrax.Recall(true_positives: Array, false_negatives: Array)

Bases: Metric

Computes recall for binary classification given predictions and labels.

It is calculated as:

\[\text{Recall} = \frac{TP}{TP + FN}\]
where:
  • TP (True Positives): Number of actual positive cases that were predicted as positive

  • FN (False Negatives): Number of actual positive cases that were incorrectly predicted as negative

The threshold argument of from_model_output() (default 0.5) converts probability predictions into binary predictions.
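As a minimal sketch of the formula and thresholding described above, the hypothetical helper recall_from_batch below binarizes a batch of scores, counts TP and FN, and divides. It uses jax.numpy only and is not the metrax implementation; it assumes a score greater than or equal to the threshold counts as a positive prediction, and it does not guard against TP + FN = 0.

```python
import jax.numpy as jnp


def recall_from_batch(predictions, labels, threshold=0.5):
    """Hypothetical helper: binarize scores, count TP and FN, return recall."""
    # Assumption: scores >= threshold are treated as positive predictions.
    predicted_positive = predictions >= threshold
    actual_positive = labels == 1
    # TP: actual positives that were predicted positive.
    tp = jnp.sum(predicted_positive & actual_positive)
    # FN: actual positives that were predicted negative.
    fn = jnp.sum(~predicted_positive & actual_positive)
    return tp, fn, tp / (tp + fn)


predictions = jnp.array([0.9, 0.2, 0.7, 0.4, 0.6])
labels = jnp.array([1, 1, 1, 0, 0])
tp, fn, recall = recall_from_batch(predictions, labels)
print(int(tp), int(fn), float(recall))  # TP=2, FN=1, recall ≈ 0.667
```

The false positive at index 4 (score 0.6, label 0) does not affect recall, since only actual positives enter the denominator.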

true_positives

The count of true positive instances computed from the given predictions, labels, and threshold.

Type:

jax.Array

false_negatives

The count of false negative instances computed from the given predictions, labels, and threshold.

Type:

jax.Array

__init__(true_positives: Array, false_negatives: Array) → None

Methods

__init__(true_positives, false_negatives)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, labels[, ...])

Creates a Recall metric from a batch of model outputs.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.

Attributes

true_positives: Array
false_negatives: Array
classmethod empty() → Recall

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, threshold: float = 0.5) → Recall

Creates a Recall metric from a batch of model outputs.

Parameters:
  • predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – Ground-truth labels. Each value is expected to be 0 or 1. The shape should be (batch_size,).

  • threshold – The threshold used to convert predictions into binary predictions. Defaults to 0.5.

Returns:

A Recall metric whose true_positives and false_negatives fields are scalars.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
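The documented methods fit together as an accumulate-then-compute loop: build one metric per batch with from_model_output(), fold the results together with merge() starting from empty(), and call compute() once at the end. The class SketchRecall below is a hypothetical stand-in written as a plain frozen dataclass so that it runs without metrax; it is not the library's source, it assumes the >= threshold rule, and it skips the input validation that raises ValueError.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class SketchRecall:
    """Hypothetical stand-in mirroring the Recall fields and methods documented above."""

    true_positives: jax.Array
    false_negatives: jax.Array

    @classmethod
    def empty(cls):
        # Zero counts, so merging with an empty instance changes nothing.
        return cls(jnp.array(0, jnp.int32), jnp.array(0, jnp.int32))

    @classmethod
    def from_model_output(cls, predictions, labels, threshold=0.5):
        # Assumption: scores >= threshold count as positive predictions.
        predicted_positive = predictions >= threshold
        actual_positive = labels == 1
        return cls(
            true_positives=jnp.sum(predicted_positive & actual_positive),
            false_negatives=jnp.sum(~predicted_positive & actual_positive),
        )

    def merge(self, other):
        # Counts are additive across batches, so merging is field-wise addition.
        return SketchRecall(
            self.true_positives + other.true_positives,
            self.false_negatives + other.false_negatives,
        )

    def compute(self):
        return self.true_positives / (self.true_positives + self.false_negatives)


batches = [
    (jnp.array([0.9, 0.2]), jnp.array([1, 1])),
    (jnp.array([0.7, 0.4, 0.6]), jnp.array([1, 0, 0])),
]
metric = SketchRecall.empty()
for predictions, labels in batches:
    metric = metric.merge(SketchRecall.from_model_output(predictions, labels))
print(float(metric.compute()))  # TP=2, FN=1 over both batches, recall ≈ 0.667
```

Keeping the raw counts and dividing only in compute() is what makes per-batch results mergeable; averaging per-batch recall values would weight batches incorrectly.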

merge(other: Recall) → Recall

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.
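In the distributed case mentioned above, each field gains an added dimension with one entry per device. Because true_positives and false_negatives are additive counts, merging those slices one at a time gives the same totals as summing over the first axis. The sketch below simulates this with made-up per-device counts in plain arrays; it does not call any collective and makes no claim about how reduce() is implemented internally.

```python
import jax.numpy as jnp

# Hypothetical per-device counts stacked along a leading device axis,
# shape (num_devices,). Values are illustrative only.
true_positives = jnp.array([3, 5, 2, 4])
false_negatives = jnp.array([1, 0, 2, 1])

# Merge the device slices one by one, as repeated merge() calls would.
tp_merged, fn_merged = 0, 0
for tp, fn in zip(true_positives, false_negatives):
    tp_merged = tp_merged + tp
    fn_merged = fn_merged + fn

# For additive counts this equals a sum over the leading axis.
assert int(tp_merged) == int(jnp.sum(true_positives, axis=0))
assert int(fn_merged) == int(jnp.sum(false_negatives, axis=0))

recall = tp_merged / (tp_merged + fn_merged)
print(float(recall))  # 14 / 18 ≈ 0.778
```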

compute() → Array

Computes final metrics from intermediate values.
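This page does not state what compute() returns when TP + FN = 0, that is, when no positive labels have been seen. A bare TP / (TP + FN) yields NaN in that case. The hypothetical helper safe_recall below shows one common guard that returns 0.0 instead; treating that case as 0.0 is an assumption for illustration, not documented metrax behavior.

```python
import jax.numpy as jnp


def safe_recall(true_positives, false_negatives):
    """Hypothetical helper: recall that returns 0.0 instead of NaN when TP + FN == 0."""
    denominator = true_positives + false_negatives
    # Swap a zero denominator for 1 so the division itself never produces NaN,
    # then select 0.0 for those entries.
    safe_denominator = jnp.where(denominator == 0, 1, denominator)
    return jnp.where(denominator == 0, 0.0, true_positives / safe_denominator)


print(float(safe_recall(jnp.array(2), jnp.array(1))))  # ≈ 0.667
print(float(safe_recall(jnp.array(0), jnp.array(0))))  # 0.0
```

Replacing the denominator before dividing, rather than only masking the result, also keeps gradients free of NaN if the expression is ever differentiated.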

replace(**updates)

Returns a new object replacing the specified fields with new values.
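As documented, replace() returns a new object rather than modifying the existing one. The sketch below imitates that copy-on-update behavior with the standard library's dataclasses.replace on a hypothetical frozen Counts class that has the same two fields as Recall; it is not the Recall class itself.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Counts:
    """Hypothetical frozen dataclass with the same two fields as Recall."""

    true_positives: jax.Array
    false_negatives: jax.Array

    def replace(self, **updates):
        # Build a new instance with the given fields swapped in;
        # the original object is left unchanged.
        return dataclasses.replace(self, **updates)


original = Counts(jnp.array(2), jnp.array(1))
updated = original.replace(false_negatives=jnp.array(0))
print(int(original.false_negatives), int(updated.false_negatives))  # 1 0
```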