metrax.Recall
- class metrax.Recall(true_positives: Array, false_negatives: Array)
Bases: Metric
Computes recall for binary classification given predictions and labels.
It is calculated as:
\[\text{Recall} = \frac{TP}{TP + FN}\]
where:
TP (True Positives): Number of correctly predicted positive cases
FN (False Negatives): Number of positive cases incorrectly predicted as negative
A threshold parameter (default 0.5) is used to convert probability predictions to binary predictions.
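To make the formula and the threshold concrete, here is a minimal JAX sketch that computes recall directly from probabilities and 0/1 labels. The helper name `binary_recall` is hypothetical and not part of metrax, and the sketch assumes a prediction counts as positive when it is greater than or equal to the threshold; the exact comparison metrax uses is not stated on this page.

```python
import jax.numpy as jnp


def binary_recall(predictions, labels, threshold=0.5):
    """Hypothetical helper (not part of metrax): Recall = TP / (TP + FN).

    Assumption: a prediction is treated as positive when it is >= threshold.
    """
    predictions = jnp.asarray(predictions)
    labels = jnp.asarray(labels)
    # Convert probabilities into hard positive/negative predictions.
    predicted_positive = predictions >= threshold
    actual_positive = labels == 1
    # TP: actual positives that were predicted positive.
    tp = jnp.sum(predicted_positive & actual_positive)
    # FN: actual positives that were predicted negative.
    fn = jnp.sum(~predicted_positive & actual_positive)
    return tp / (tp + fn)


preds = jnp.array([0.9, 0.2, 0.7, 0.4, 0.1])
labels = jnp.array([1, 1, 1, 0, 0])
print(binary_recall(preds, labels))  # TP = 2, FN = 1, so recall = 2/3
```

Lowering the threshold turns more predictions positive, so recall can only stay the same or rise; with `threshold=0.15` every positive example in this batch is caught and recall becomes 1.0.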
- true_positives
The count of true positive instances from the given data, label, and threshold.
- Type:
jax.Array
- false_negatives
The count of false negative instances from the given data, label, and threshold.
- Type:
jax.Array
Methods
__init__(true_positives, false_negatives)
compute() – Computes final metrics from intermediate values.
compute_value() – Wraps compute() and returns a values.Value.
empty() – Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
from_fun(fun) – Calls cls.from_model_output with the return value from fun.
from_model_output(predictions, labels[, ...]) – Updates the metric.
from_output(name) – Calls cls.from_model_output with the model output named name.
merge(other) – Returns a Metric that is the accumulation of self and other.
reduce() – Reduces the metric along its first axis by calling _reduce_merge().
replace(**updates) – Returns a new object replacing the specified fields with new values.
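The methods above follow an accumulate-then-compute pattern: build per-batch counts with from_model_output, fold them together with merge starting from empty(), and call compute() once at the end. The sketch below illustrates that pattern with a hypothetical stand-in class, `RecallSketch`, that mirrors the field and method names listed here; it is not the metrax implementation and does not import metrax. It assumes predictions at or above the threshold count as positive.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class RecallSketch:
    """Hypothetical stand-in mirroring the fields and methods of metrax.Recall.

    Illustrates the accumulate-then-compute pattern; not the library code.
    """

    true_positives: jax.Array
    false_negatives: jax.Array

    @classmethod
    def empty(cls):
        # Zero counts, so merging with an empty metric changes nothing.
        return cls(true_positives=jnp.array(0), false_negatives=jnp.array(0))

    @classmethod
    def from_model_output(cls, predictions, labels, threshold=0.5):
        # Assumption: a prediction >= threshold counts as positive.
        predicted_positive = predictions >= threshold
        actual_positive = labels == 1
        return cls(
            true_positives=jnp.sum(predicted_positive & actual_positive),
            false_negatives=jnp.sum(~predicted_positive & actual_positive),
        )

    def merge(self, other):
        # Counts from different batches simply add up.
        return type(self)(
            true_positives=self.true_positives + other.true_positives,
            false_negatives=self.false_negatives + other.false_negatives,
        )

    def compute(self):
        return self.true_positives / (self.true_positives + self.false_negatives)


batches = [
    (jnp.array([0.9, 0.2, 0.7]), jnp.array([1, 1, 1])),
    (jnp.array([0.4, 0.8, 0.1]), jnp.array([0, 1, 1])),
]
metric = RecallSketch.empty()
for predictions, labels in batches:
    metric = metric.merge(RecallSketch.from_model_output(predictions, labels))
print(metric.compute())  # TP = 3, FN = 2, so recall = 0.6
```

Because only the raw counts are merged, the ratio is computed once over all batches, which gives the same result as evaluating the concatenated data in one go (unlike averaging per-batch recall values).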
Attributes
- true_positives: Array
- false_negatives: Array
- classmethod from_model_output(predictions: Array, labels: Array, threshold: float = 0.5) → Recall
Updates the metric.
- Parameters:
predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).
labels – Ground-truth labels. Each value is expected to be 0 or 1. The shape should be (batch_size,).
threshold – The threshold used to convert predictions into binary values.
- Returns:
Updated Recall metric. Its fields (true_positives and false_negatives) are scalars.
- Raises:
ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
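The checks described under Raises can be pictured with a small hypothetical helper, `checked_recall_counts`, that rejects mismatched shapes before counting and returns scalar counts regardless of batch size. This page does not say which label types metrax rejects, so the sketch only validates the documented (batch_size,) shapes; treat it as an illustration, not the library's validation logic.

```python
import jax.numpy as jnp


def checked_recall_counts(predictions, labels, threshold=0.5):
    """Hypothetical helper: validate shapes, then return scalar (TP, FN)."""
    predictions = jnp.asarray(predictions)
    labels = jnp.asarray(labels)
    # Both inputs are documented as 1D vectors of shape (batch_size,).
    if predictions.ndim != 1 or predictions.shape != labels.shape:
        raise ValueError(
            f"Expected matching (batch_size,) shapes, got predictions "
            f"{predictions.shape} and labels {labels.shape}."
        )
    # Assumption: values >= threshold are treated as positive predictions.
    predicted_positive = predictions >= threshold
    actual_positive = labels == 1
    tp = jnp.sum(predicted_positive & actual_positive)  # 0-d array
    fn = jnp.sum(~predicted_positive & actual_positive)  # 0-d array
    return tp, fn


tp, fn = checked_recall_counts(jnp.array([0.6, 0.3, 0.8, 0.1]), jnp.array([1, 1, 0, 0]))
print(tp.shape, int(tp), int(fn))  # () 1 1
```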
- merge(other: Recall) → Recall
Returns Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
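In a distributed setting each device produces its own counts, and gathering them adds a leading axis; accumulating them then amounts to summing along that axis, which corresponds to what reduce() describes (reducing along the first axis). The sketch below imitates this on a single host by using jax.vmap over a leading shard axis instead of a real jax.lax collective. The helper `shard_counts` and the >= threshold comparison are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def shard_counts(predictions, labels, threshold=0.5):
    """Hypothetical helper: scalar (TP, FN) for one shard of data."""
    predicted_positive = predictions >= threshold
    actual_positive = labels == 1
    tp = jnp.sum(predicted_positive & actual_positive)
    fn = jnp.sum(~predicted_positive & actual_positive)
    return tp, fn


# Two shards (e.g. two devices) with three examples each: shape (2, 3).
predictions = jnp.array([[0.9, 0.2, 0.7], [0.4, 0.8, 0.1]])
labels = jnp.array([[1, 1, 1], [0, 1, 1]])

# vmap maps over the leading shard axis, so each count gets shape (2,).
tp_per_shard, fn_per_shard = jax.vmap(shard_counts)(predictions, labels)

# Accumulating the per-shard metrics is a sum along that leading axis.
tp_total = tp_per_shard.sum(axis=0)
fn_total = fn_per_shard.sum(axis=0)
print(tp_total / (tp_total + fn_total))  # TP = 3, FN = 2, so recall = 0.6
```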
- compute() → Array
Computes final metrics from intermediate values.
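compute() turns the accumulated counts into the ratio TP / (TP + FN). One case the signature does not settle is a dataset with no positive labels, where TP + FN is 0. The hypothetical `compute_recall` below returns 0.0 in that case; this is an assumed convention for illustration, not documented metrax behavior, and plain division in JAX would produce nan instead.

```python
import jax.numpy as jnp


def compute_recall(true_positives, false_negatives):
    """Hypothetical: TP / (TP + FN), returning 0.0 when TP + FN == 0."""
    denominator = true_positives + false_negatives
    # Substitute 1 for a zero denominator so the division itself never yields nan.
    safe_denominator = jnp.where(denominator == 0, 1, denominator)
    return jnp.where(denominator == 0, 0.0, true_positives / safe_denominator)


print(compute_recall(jnp.array(3), jnp.array(2)))  # recall = 3 / 5 = 0.6
print(compute_recall(jnp.array(0), jnp.array(0)))  # no positives -> 0.0 under this convention
```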
- replace(**updates)
Returns a new object replacing the specified fields with new values.
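replace(**updates) returns a new instance rather than mutating the metric, which fits JAX's preference for immutable values. As a rough illustration of those semantics, the sketch uses the standard library's dataclasses.replace on a hypothetical frozen dataclass, `RecallCounts`, with the same field names as metrax.Recall; it does not call metrax itself.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class RecallCounts:
    """Hypothetical stand-in with the same fields as metrax.Recall."""

    true_positives: jax.Array
    false_negatives: jax.Array


original = RecallCounts(true_positives=jnp.array(4), false_negatives=jnp.array(1))
# Build a new object with one field swapped; the original is left unchanged.
updated = dataclasses.replace(original, false_negatives=jnp.array(0))

print(int(original.false_negatives), int(updated.false_negatives))  # 1 0
print(int(updated.true_positives))  # 4
```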