metrax.Average
- class metrax.Average(total: Array, count: Array)
Bases: clu.metrics.Average
Average metric that inherits from clu.metrics.Average and performs safe division.
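Safe division means that computing the mean does not yield NaN or inf when count is zero. The sketch below assumes the fallback result is 0 wherever the denominator is 0. This page does not confirm that value. safe_divide is a hypothetical helper, and numpy stands in for jax.numpy.

```python
import numpy as np

def safe_divide(total, count):
    """Return total / count, or 0 wherever count == 0 (assumed fallback value)."""
    total = np.asarray(total, dtype=np.float32)
    count = np.asarray(count, dtype=np.float32)
    # Swap zero denominators for 1 so the division itself never produces NaN/inf.
    safe_count = np.where(count == 0, 1.0, count)
    # Mask the positions that had a zero count back to the 0 fallback.
    return np.where(count == 0, 0.0, total / safe_count)

print(safe_divide(6.0, 3.0))  # 2.0
print(safe_divide(0.0, 0.0))  # 0.0
```

The "double where" pattern matters in JAX as well. Under jax.grad, a plain division followed by a single where can still produce NaN gradients. Sanitizing the denominator first avoids that.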
Methods
__init__(total, count)
compute(): Computes final metrics from intermediate values.
compute_value(): Wraps compute() and returns a values.Value.
empty(): Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
from_fun(fun): Calls cls.from_model_output with the return value from fun.
from_model_output(values[, sample_weights]): Updates the metric.
from_output(name): Calls cls.from_model_output with the model output named name.
merge(other): Returns a Metric that is the accumulation of self and other.
reduce(): Reduces the metric along its first axis by calling _reduce_merge().
replace(**updates): Returns a new object replacing the specified fields with new values.
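The empty/merge/compute contract can be illustrated with a small stand-in. AverageSketch below is hypothetical and is not metrax code. It assumes that merge adds the total and count fields. Under that assumption, merging with empty() is a no-op, and the merged result equals the mean over all pooled samples rather than a mean of per-batch means.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AverageSketch:
    """Hypothetical stand-in mirroring the total/count fields of metrax.Average."""
    total: float
    count: float

    @classmethod
    def empty(cls) -> "AverageSketch":
        # Zero total and zero count: the identity element for merge().
        return cls(total=0.0, count=0.0)

    @classmethod
    def from_model_output(cls, values) -> "AverageSketch":
        values = list(values)
        return cls(total=float(sum(values)), count=float(len(values)))

    def merge(self, other: "AverageSketch") -> "AverageSketch":
        # Accumulate the sufficient statistics (sums), not the per-batch means.
        return AverageSketch(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        # Safe division with an assumed 0.0 fallback for an empty metric.
        return self.total / self.count if self.count else 0.0

batch_a = AverageSketch.from_model_output([1.0, 2.0, 3.0])
batch_b = AverageSketch.from_model_output([10.0])
merged = batch_a.merge(batch_b)
print(merged.compute())  # 4.0, i.e. 16 / 4
print(batch_a.merge(AverageSketch.empty()) == batch_a)  # True
```

Averaging the two batch means would give (2 + 10) / 2 = 6.0. That is why the sketch carries total and count separately.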
Attributes
total
count
- classmethod from_model_output(values: Array, sample_weights: Array | None = None) → Average
Updates the metric.
- Parameters:
values – A floating point 1D vector representing the values. The shape should be (batch_size,).
sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).
- Returns:
Updated Average metric.
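This page does not spell out how sample_weights enter the two fields. The sketch below assumes one plausible scheme:

- total = sum of values[i] * sample_weights[i]
- count = sum of sample_weights[i]
- count = batch_size when no weights are given

Under this scheme, compute() returns a weighted mean. average_statistics is a hypothetical helper, and numpy stands in for jax.numpy.

```python
import numpy as np

def average_statistics(values, sample_weights=None):
    """Hypothetical helper: (total, count) for a 1D batch of shape (batch_size,)."""
    values = np.asarray(values, dtype=np.float32)
    if sample_weights is None:
        # Unweighted: every sample counts once.
        return values.sum(), np.float32(values.shape[0])
    sample_weights = np.asarray(sample_weights, dtype=np.float32)
    if sample_weights.shape != values.shape:
        raise ValueError("sample_weights must have the same shape as values")
    # Weighted: scale each value by its weight; the weights form the denominator.
    return (values * sample_weights).sum(), sample_weights.sum()

total, count = average_statistics([1.0, 2.0, 3.0], sample_weights=[1.0, 1.0, 0.0])
print(total / count)  # 1.5
```

Under this assumption, a weight of 0 drops a sample entirely, for example a padded example at the end of a batch.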
- compute() → Array
Computes final metrics from intermediate values.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
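replace returns a modified copy rather than mutating the metric in place. The sketch below shows those semantics using the standard library's dataclasses.replace on a frozen dataclass. It assumes metrax.Average behaves like a frozen dataclass in this respect. AverageFields is a hypothetical stand-in.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class AverageFields:
    """Hypothetical stand-in with the same two fields as metrax.Average."""
    total: float
    count: float

original = AverageFields(total=10.0, count=4.0)
# Build a new instance with count overridden; the original object is unchanged.
updated = dataclasses.replace(original, count=5.0)
print(updated)   # AverageFields(total=10.0, count=5.0)
print(original)  # AverageFields(total=10.0, count=4.0)
```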