metrax.Dice
- class metrax.Dice(intersection: Array, sum_pred: Array, sum_true: Array)
Bases: Metric
Computes the Dice coefficient between y_true and y_pred.
Dice measures the overlap between two samples. A Dice score of 1 indicates perfect overlap; 0 indicates no overlap.
The formula is:
\[\text{Dice} = \frac{2 \cdot \sum (y_{\text{true}} \cdot y_{\text{pred}})}{\sum y_{\text{true}} + \sum y_{\text{pred}} + \epsilon}\]
where \(\epsilon\) is a small constant that prevents division by zero when both sums are zero.
- intersection
Sum of the element-wise product of y_true and y_pred.
- Type:
jax.Array
- sum_true
Sum of y_true across all examples.
- Type:
jax.Array
- sum_pred
Sum of y_pred across all examples.
- Type:
jax.Array
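As a worked instance of the formula, the sketch below computes the three intermediate sums for a small batch and plugs them in. It uses jax.numpy; the helper `dice_from_sums` and the epsilon value `EPS = 1e-7` are illustrative assumptions, not metrax names or defaults.

```python
import jax.numpy as jnp

# Hypothetical epsilon; the value metrax actually uses is not stated on this page.
EPS = 1e-7


def dice_from_sums(intersection, sum_pred, sum_true, eps=EPS):
    """Apply the Dice formula to the three accumulated sums."""
    return 2.0 * intersection / (sum_true + sum_pred + eps)


y_true = jnp.array([1.0, 0.0, 1.0, 0.0])
y_pred = jnp.array([1.0, 0.0, 0.5, 0.5])  # soft predictions in [0, 1]

intersection = jnp.sum(y_true * y_pred)  # 1.0 + 0.5 = 1.5
sum_true = jnp.sum(y_true)               # 2.0
sum_pred = jnp.sum(y_pred)               # 2.0

score = dice_from_sums(intersection, sum_pred, sum_true)
print(float(score))  # ≈ 0.75, i.e. 2 * 1.5 / (2.0 + 2.0)

# With all-zero inputs, epsilon keeps the result finite: 0 / eps = 0.0, not NaN.
zeros = jnp.zeros(4)
empty_score = dice_from_sums(jnp.sum(zeros * zeros), jnp.sum(zeros), jnp.sum(zeros))
print(float(empty_score))  # 0.0
```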
Methods
- __init__(intersection, sum_pred, sum_true)
- compute() – Computes final metrics from intermediate values.
- compute_value() – Wraps compute() and returns a values.Value.
- empty() – Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
- from_fun(fun) – Calls cls.from_model_output with the return value from fun.
- from_model_output(predictions, labels) – Creates the metric from a batch of predictions and labels.
- from_output(name) – Calls cls.from_model_output with the model output named name.
- merge(other) – Returns a Metric that is the accumulation of self and other.
- reduce() – Reduces the metric along its first axis by calling _reduce_merge().
- replace(**updates) – Returns a new object replacing the specified fields with new values.
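The methods listed above combine into an accumulate-then-compute loop: build a metric per batch with from_model_output, fold batches together with merge, and call compute once. The sketch below shows this with a hypothetical `DiceSketch` frozen dataclass that mirrors the three documented fields. It is not the metrax implementation, and `EPS` is an assumed value.

```python
import dataclasses

import jax
import jax.numpy as jnp

EPS = 1e-7  # hypothetical epsilon, not taken from metrax


@dataclasses.dataclass(frozen=True)
class DiceSketch:
    """Hypothetical stand-in with the same three fields as metrax.Dice."""

    intersection: jax.Array
    sum_pred: jax.Array
    sum_true: jax.Array

    @classmethod
    def empty(cls) -> "DiceSketch":
        # All-zero sums: merging with this instance changes nothing.
        zero = jnp.array(0.0)
        return cls(zero, zero, zero)

    @classmethod
    def from_model_output(cls, predictions: jax.Array, labels: jax.Array) -> "DiceSketch":
        # Per-batch sums; labels are cast so the product stays floating point.
        labels = labels.astype(predictions.dtype)
        return cls(
            intersection=jnp.sum(labels * predictions),
            sum_pred=jnp.sum(predictions),
            sum_true=jnp.sum(labels),
        )

    def merge(self, other: "DiceSketch") -> "DiceSketch":
        # Every field is a plain sum, so accumulation is field-wise addition.
        return DiceSketch(
            self.intersection + other.intersection,
            self.sum_pred + other.sum_pred,
            self.sum_true + other.sum_true,
        )

    def compute(self) -> jax.Array:
        return 2.0 * self.intersection / (self.sum_true + self.sum_pred + EPS)


batches = [
    (jnp.array([0.9, 0.1, 0.8]), jnp.array([1, 0, 1])),
    (jnp.array([0.2, 0.7]), jnp.array([0, 1])),
]

# Accumulate batch by batch, then compute once at the end.
metric = DiceSketch.empty()
for preds, labels in batches:
    metric = metric.merge(DiceSketch.from_model_output(preds, labels))

# Reference: a single pass over the concatenated data.
whole = DiceSketch.from_model_output(
    jnp.concatenate([p for p, _ in batches]),
    jnp.concatenate([l for _, l in batches]),
)

print(float(metric.compute()), float(whole.compute()))  # both ≈ 0.842 (4.8 / 5.7)
```

Because intersection, sum_pred, and sum_true are plain sums, merging batches and then computing gives the same score as one pass over all the data; averaging per-batch Dice scores would not, in general.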
Attributes
- intersection: Array
- sum_pred: Array
- sum_true: Array
- classmethod from_model_output(predictions: Array, labels: Array) -> Dice
Creates a Dice metric from one batch of predictions and labels, storing the batch's intersection, sum_pred, and sum_true.
- Parameters:
predictions – A floating-point vector with values in the range [0, 1], of shape (batch_size,).
labels – Ground-truth values, each expected to be 0 or 1, of shape (batch_size,).
- Returns:
A Dice metric holding the intermediate values for the batch.
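The parameter expectations above translate into simple checks. The hypothetical helper `dice_batch_stats` below validates a batch and returns the three sums. Whether metrax itself validates its inputs this way is not stated here, so treat the checks as an illustration.

```python
import jax.numpy as jnp


def dice_batch_stats(predictions, labels):
    """Hypothetical helper: validate one batch and return its three sums."""
    # Both inputs must be 1-D vectors of the same shape (batch_size,).
    if predictions.ndim != 1 or predictions.shape != labels.shape:
        raise ValueError(
            f"expected matching (batch_size,) shapes, "
            f"got {predictions.shape} and {labels.shape}"
        )
    # Predictions must lie in [0, 1]; labels must be exactly 0 or 1.
    if bool(jnp.any((predictions < 0) | (predictions > 1))):
        raise ValueError("predictions must lie in [0, 1]")
    if not bool(jnp.all((labels == 0) | (labels == 1))):
        raise ValueError("labels must be 0 or 1")
    labels = labels.astype(predictions.dtype)
    return {
        "intersection": jnp.sum(labels * predictions),
        "sum_pred": jnp.sum(predictions),
        "sum_true": jnp.sum(labels),
    }


stats = dice_batch_stats(jnp.array([0.9, 0.1, 0.8]), jnp.array([1, 0, 1]))
# intersection = 0.9 + 0.8 = 1.7, sum_pred = 1.8, sum_true = 2.0
```

The shape check is safe inside jax.jit because shapes are static, but the value checks call bool() on array contents and fail under tracing, so they belong in eager code such as an input pipeline.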
- merge(other: Dice) -> Dice
Returns a Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the values from both self and other.
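The note about distributed settings explains why the intermediate values are sums. In the sketch below (hypothetical data, hypothetical `EPS`), each field carries an extra axis of size 4, one entry per device. Assuming merge adds the three fields, as the accumulation described above suggests, reducing along that axis is a sum, and the pooled score differs from the mean of per-device scores.

```python
import jax.numpy as jnp

EPS = 1e-7  # hypothetical epsilon

# Hypothetical intermediate values gathered from 4 devices: each field has an
# added axis of size num_devices.
intersection = jnp.array([1.5, 0.0, 2.0, 0.5])
sum_pred = jnp.array([2.0, 0.5, 2.5, 1.0])
sum_true = jnp.array([2.0, 0.0, 3.0, 1.0])


def dice(i, p, t):
    return 2.0 * i / (t + p + EPS)


# Accumulate first (sum along the device axis), then compute once.
pooled = dice(intersection.sum(axis=0), sum_pred.sum(axis=0), sum_true.sum(axis=0))

# Contrast: computing Dice per device and averaging gives a different number.
per_device_mean = jnp.mean(dice(intersection, sum_pred, sum_true))

print(float(pooled))           # ≈ 0.667  (8 / 12)
print(float(per_device_mean))  # ≈ 0.494
```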
- compute() -> Array
Computes the final Dice score from the intermediate values.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
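replace performs a functional update: the original object is left unchanged. The sketch below reproduces that behavior with the standard library's `dataclasses.replace` on a hypothetical frozen `DiceState` dataclass. It illustrates the semantics rather than metrax's own class.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class DiceState:
    """Hypothetical frozen dataclass with the same fields as metrax.Dice."""

    intersection: jax.Array
    sum_pred: jax.Array
    sum_true: jax.Array


state = DiceState(jnp.array(1.5), jnp.array(2.0), jnp.array(2.0))

# Functional update: a new object with sum_pred changed; the original is untouched
# and the unchanged fields are carried over as-is.
updated = dataclasses.replace(state, sum_pred=jnp.array(3.0))

print(float(state.sum_pred), float(updated.sum_pred))  # 2.0 3.0
```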