metrax.Average

class metrax.Average(total: Array, count: Array)

Bases: clu.metrics.Average

Average metric that inherits from clu.metrics.Average and performs safe division, so computing the average guards against dividing by a zero count.
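To make "safe division" concrete, here is a minimal sketch of a division helper that returns 0 instead of NaN or inf when the count is zero. The name `safe_divide` and the choice of 0 as the fallback value are assumptions for illustration, not taken from metrax; NumPy stands in for jax.numpy, which provides the same `asarray` and `where` functions.

```python
import numpy as np


def safe_divide(total, count):
    """Hypothetical helper: total / count, or 0 wherever count == 0."""
    total = np.asarray(total, dtype=np.float32)
    count = np.asarray(count, dtype=np.float32)
    nonzero = count != 0
    # Swap zero denominators for 1 before dividing so the division itself
    # never produces inf or NaN; the outer where then masks those entries to 0.
    denom = np.where(nonzero, count, 1.0)
    return np.where(nonzero, total / denom, 0.0)


print(safe_divide(6.0, 3.0))  # 2.0
print(safe_divide(5.0, 0.0))  # 0.0
```

The inner `where` matters in JAX as well: gradients flow through both branches of `jnp.where`, so dividing by an unguarded zero can still produce NaN gradients even when the forward value is masked.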

__init__(total: Array, count: Array) → None

Methods

__init__(total, count)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(values[, sample_weights])

Creates an Average metric from a batch of values and optional per-sample weights.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns a Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.
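The methods above follow clu's pattern of accumulating sufficient statistics. The sketch below mimics that pattern with a hypothetical frozen dataclass, `AverageSketch`, that stores only `total` and `count`: `empty()` is the identity for `merge()`, `merge()` adds the fields, `reduce()` merges along a leading axis (assumed here to be equivalent to summing over axis 0), and `compute()` divides safely. It is an illustration of the semantics, not the metrax or clu implementation.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class AverageSketch:
    """Hypothetical stand-in for metrax.Average: stores only total and count."""
    total: np.ndarray
    count: np.ndarray

    @classmethod
    def empty(cls):
        # Zero total and zero count, so merging with this is a no-op.
        return cls(total=np.array(0.0), count=np.array(0.0))

    def merge(self, other):
        # Accumulate the sufficient statistics of both metrics.
        return AverageSketch(self.total + other.total, self.count + other.count)

    def reduce(self):
        # Assumes total and count carry a leading axis (e.g. one entry per
        # device); merging along it amounts to summing over axis 0.
        return AverageSketch(self.total.sum(axis=0), self.count.sum(axis=0))

    def compute(self):
        # Safe division: 0.0 when nothing has been accumulated.
        if self.count == 0:
            return 0.0
        return float(self.total / self.count)


batch_a = AverageSketch(total=np.array(6.0), count=np.array(3.0))  # values 1, 2, 3
batch_b = AverageSketch(total=np.array(4.0), count=np.array(1.0))  # value 4
merged = AverageSketch.empty().merge(batch_a).merge(batch_b)
print(merged.compute())  # 2.5

stacked = AverageSketch(total=np.array([6.0, 4.0]), count=np.array([3.0, 1.0]))
print(stacked.reduce().compute())  # 2.5
print(AverageSketch.empty().compute())  # 0.0
```

Storing total and count rather than a running mean is what makes `merge()` exact: two partial averages cannot be combined correctly without knowing how many samples each one covers.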

Attributes

total

count

classmethod from_model_output(values: Array, sample_weights: Array | None = None) → Average

Creates an Average metric from a batch of values and optional per-sample weights.

Parameters:
  • values – A floating point 1D vector representing the values. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

Returns:

Updated Average metric.
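As a sketch of what the parameters above mean, the hypothetical function `average_stats` below turns a batch of `values` and optional `sample_weights`, both of shape (batch_size,), into the (total, count) pair the metric stores. It assumes a weighted mean, with total = sum(values * sample_weights), count = sum(sample_weights), and unit weights when none are given; this is an assumption, not taken from the metrax source.

```python
import numpy as np


def average_stats(values, sample_weights=None):
    """Hypothetical: returns (total, count) for a weighted average."""
    values = np.asarray(values, dtype=np.float32)
    if sample_weights is None:
        # Assumption: missing weights mean every sample counts once.
        sample_weights = np.ones_like(values)
    sample_weights = np.asarray(sample_weights, dtype=np.float32)
    # Sketch-only check that both inputs have the same (batch_size,) shape.
    if values.shape != sample_weights.shape:
        raise ValueError(
            f"shape mismatch: values {values.shape}, "
            f"sample_weights {sample_weights.shape}"
        )
    total = np.sum(values * sample_weights)
    count = np.sum(sample_weights)
    return total, count


total, count = average_stats([1.0, 2.0, 3.0])
print(total / count)  # 2.0

total, count = average_stats([1.0, 2.0, 3.0], sample_weights=[1.0, 1.0, 2.0])
print(total / count)  # 2.25
```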

compute() → Array

Computes final metrics from intermediate values.

replace(**updates)

Returns a new object replacing the specified fields with new values.
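`replace` follows the functional-update style of frozen dataclasses: it returns a new object and leaves the original unchanged. The sketch below shows that behavior with the standard library's `dataclasses.replace` on a hypothetical `AverageState` class, assuming metrax.Average behaves like a frozen dataclass (this page does not state how it is implemented).

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class AverageState:
    """Hypothetical stand-in holding the two fields of metrax.Average."""
    total: float
    count: float

    def replace(self, **updates):
        # Returns a new instance; the original object is not modified.
        return dataclasses.replace(self, **updates)


state = AverageState(total=6.0, count=3.0)
reset_total = state.replace(total=0.0)
print(reset_total)  # AverageState(total=0.0, count=3.0)
print(state)        # AverageState(total=6.0, count=3.0)
```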