metrax.Dice

class metrax.Dice(intersection: Array, sum_pred: Array, sum_true: Array)

Bases: Metric

Computes the Dice coefficient between y_true and y_pred.

The Dice coefficient measures the overlap between two samples. A score of 1 indicates perfect overlap; 0 indicates no overlap.

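The formula is as follows, where each sum runs over all elements and ε is a small constant that guards against division by zero: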

\[\text{Dice} = \frac{2 \cdot \sum (y_{\text{true}} \cdot y_{\text{pred}})} {\sum y_{\text{true}} + \sum y_{\text{pred}} + \epsilon}\]
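To make the formula concrete, here is a worked instance on a small, made-up batch with y_true = (1, 0, 1, 1) and y_pred = (0.5, 0, 0.5, 1). The numbers are illustrative only:

\[\sum (y_{\text{true}} \cdot y_{\text{pred}}) = 0.5 + 0 + 0.5 + 1 = 2, \quad \sum y_{\text{true}} = 3, \quad \sum y_{\text{pred}} = 2\]

\[\text{Dice} = \frac{2 \cdot 2}{3 + 2 + \epsilon} \approx \frac{4}{5} = 0.8\]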

intersection

Sum of the element-wise product of y_true and y_pred.

Type:

jax.Array

sum_true

Sum of y_true across all examples.

Type:

jax.Array

sum_pred

Sum of y_pred across all examples.

Type:

jax.Array

__init__(intersection: Array, sum_pred: Array, sum_true: Array) → None

Methods

  • __init__(intersection, sum_pred, sum_true)
  • compute() – Computes final metrics from intermediate values.
  • compute_value() – Wraps compute() and returns a values.Value.
  • empty() – Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
  • from_fun(fun) – Calls cls.from_model_output with the return value from fun.
  • from_model_output(predictions, labels) – Creates a Dice metric from a batch of predictions and labels.
  • from_output(name) – Calls cls.from_model_output with the model output named name.
  • merge(other) – Returns a Metric that is the accumulation of self and other.
  • reduce() – Reduces the metric along its first axis by calling _reduce_merge().
  • replace(**updates) – Returns a new object replacing the specified fields with new values.
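The reduce() entry above merges states that are stacked along their first axis, as happens in a vmapped or distributed setting. The following is a minimal sketch in plain JAX, assuming that merging two Dice states amounts to adding their three fields, so that reducing along the first axis becomes a sum over axis 0. dice_fields is a hypothetical helper written for illustration, not part of metrax, and the final line applies the formula without ε:

```python
import jax
import jax.numpy as jnp


def dice_fields(predictions, labels):
  # Hypothetical helper (not a metrax function): the three Dice sums for one shard.
  return {
      "intersection": jnp.sum(predictions * labels),
      "sum_pred": jnp.sum(predictions),
      "sum_true": jnp.sum(labels),
  }


# Two shards of shape (2,) stacked along a new leading axis: shape (2, 2).
predictions = jnp.array([[1.0, 1.0], [1.0, 0.0]])
labels = jnp.array([[1.0, 1.0], [0.0, 1.0]])

# vmap gives every field a leading axis of length 2 (one entry per shard),
# similar to what a jax.lax collective such as all_gather would produce.
stacked = jax.vmap(dice_fields)(predictions, labels)

# Assuming merge adds fields, merging along the first axis is a sum over axis 0.
reduced = jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0), stacked)

dice = 2.0 * reduced["intersection"] / (reduced["sum_pred"] + reduced["sum_true"])
print(stacked["intersection"].shape, round(float(dice), 3))  # (2,) 0.667
```

Summing the sums before dividing keeps the result identical to computing Dice on the concatenated data.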

Attributes

  • intersection: Array
  • sum_pred: Array
  • sum_true: Array

classmethod empty() → Dice

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array) → Dice

Creates a Dice metric from a batch of predictions and labels.

Parameters:
  • predictions – A floating-point vector whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – Ground-truth values, each expected to be 0 or 1. The shape should be (batch_size,).

Returns:

A Dice metric holding the intermediate values for this batch.
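The input contract above can be illustrated with a minimal sketch. DiceSketch is a hypothetical stand-in written for this page, not metrax's implementation; it assumes that from_model_output simply sums the element-wise product and each input, and its compute() uses an assumed ε of 1e-7:

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class DiceSketch:
  """Hypothetical stand-in for metrax.Dice; not the library's code."""

  intersection: jax.Array
  sum_pred: jax.Array
  sum_true: jax.Array

  @classmethod
  def from_model_output(cls, predictions: jax.Array, labels: jax.Array) -> "DiceSketch":
    # Cast both inputs to float32 so 0/1 integer labels multiply cleanly.
    predictions = jnp.asarray(predictions, jnp.float32)
    labels = jnp.asarray(labels, jnp.float32)
    return cls(
        intersection=jnp.sum(predictions * labels),
        sum_pred=jnp.sum(predictions),
        sum_true=jnp.sum(labels),
    )

  def compute(self, eps: float = 1e-7) -> jax.Array:
    # eps is an assumed smoothing constant; the library's value may differ.
    return 2.0 * self.intersection / (self.sum_pred + self.sum_true + eps)


predictions = jnp.array([0.5, 0.0, 0.5, 1.0])  # shape (batch_size,), values in [0, 1]
labels = jnp.array([1, 0, 1, 1])               # shape (batch_size,), values 0 or 1
metric = DiceSketch.from_model_output(predictions, labels)
print(round(float(metric.compute()), 4))  # 0.8
```

The documented class also provides empty(), merge(), and reduce(), which this sketch leaves out.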

merge(other: Dice) → Dice

Returns a Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and will therefore have an additional dimension compared with the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the values from both self and other.
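Because merge() accumulates the intermediate sums rather than the final scores, the Dice value of a merged metric is generally not the mean of the per-batch Dice values. The sketch below illustrates this in plain JAX, assuming that merging adds the three fields and using an assumed ε of 1e-7. DiceState, from_batch, empty, merge, and compute are hypothetical helpers written for illustration, not metrax functions:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class DiceState(NamedTuple):
  """Hypothetical mirror of the three fields stored by metrax.Dice."""

  intersection: jax.Array
  sum_pred: jax.Array
  sum_true: jax.Array


def from_batch(predictions, labels) -> DiceState:
  predictions = jnp.asarray(predictions, jnp.float32)
  labels = jnp.asarray(labels, jnp.float32)
  return DiceState(jnp.sum(predictions * labels), jnp.sum(predictions), jnp.sum(labels))


def empty() -> DiceState:
  # All-zero sums: merging with this state leaves the other state unchanged.
  zero = jnp.array(0.0, jnp.float32)
  return DiceState(zero, zero, zero)


def merge(a: DiceState, b: DiceState) -> DiceState:
  # Accumulate by adding the intermediate sums field by field.
  return jax.tree_util.tree_map(lambda x, y: x + y, a, b)


def compute(s: DiceState, eps: float = 1e-7) -> jax.Array:
  return 2.0 * s.intersection / (s.sum_pred + s.sum_true + eps)


batch_a = from_batch([1.0, 1.0], [1, 1])  # perfect overlap: Dice = 1
batch_b = from_batch([1.0, 0.0], [0, 1])  # no overlap: Dice = 0

total = merge(merge(empty(), batch_a), batch_b)
merged_dice = float(compute(total))  # 2 * 2 / (3 + 3)
mean_of_batches = (float(compute(batch_a)) + float(compute(batch_b))) / 2
print(round(merged_dice, 3), round(mean_of_batches, 3))  # 0.667 0.5
```

Here the merged metric yields 4/6 ≈ 0.667, while averaging the two batch scores gives 0.5, because the accumulated form pools all elements before dividing.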

compute() → Array

Computes the Dice coefficient from the accumulated intermediate values.

replace(**updates)

Returns a new object replacing the specified fields with new values.