metrax

class metrax.MSE(total: Array, count: Array)

Bases: Average

Computes the mean squared error for regression problems given predictions and labels.

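classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) → MSE

Creates an MSE metric from a single batch of model outputs. Use merge() to accumulate it with metrics from other batches.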

Parameters:
  • predictions – A floating-point 1D array of model predictions. The shape should be (batch_size,).

  • labels – The ground-truth values. The shape should be (batch_size,).

  • sample_weights – An optional floating-point 1D array holding the weight of each sample. The shape should be (batch_size,).

Returns:

The MSE metric for this batch. Calling compute() on it returns a scalar.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.

__init__(total: Array, count: Array) → None
replace(**updates)

Returns a new object replacing the specified fields with new values.
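The accumulate-then-divide pattern behind MSE can be sketched in NumPy. This is not metrax's code. It assumes that, with sample_weights, total is the weighted sum of squared errors and count is the sum of the weights. The helper names mse_batch_stats, merge_stats, and compute_mse are hypothetical.

```python
import numpy as np


def mse_batch_stats(predictions, labels, sample_weights=None):
    """Return (total, count) for one batch (hypothetical helper).

    Assumption: total = sum(w * (p - y)^2) and count = sum(w).
    """
    predictions = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    if predictions.shape != labels.shape:
        raise ValueError("predictions and labels must have the same shape")
    if sample_weights is None:
        weights = np.ones_like(labels)
    else:
        weights = np.asarray(sample_weights, dtype=np.float64)
    squared_error = (predictions - labels) ** 2
    return float(np.sum(weights * squared_error)), float(np.sum(weights))


def merge_stats(a, b):
    # Merging two metrics adds their intermediate values field by field.
    return a[0] + b[0], a[1] + b[1]


def compute_mse(stats):
    total, count = stats
    return total / count


batch1 = mse_batch_stats([1.0, 2.0], [1.0, 0.0])  # squared errors 0 and 4
batch2 = mse_batch_stats([3.0], [0.0])            # squared error 9
overall = compute_mse(merge_stats(batch1, batch2))  # 13 / 3
weighted = compute_mse(mse_batch_stats([1.0, 2.0], [1.0, 0.0], [1.0, 3.0]))  # 12 / 4
```

Averaging the two per-batch MSEs would give (2 + 9) / 2 = 5.5 rather than 13 / 3. That is why merging sums total and count instead of averaging finished values.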

class metrax.RMSE(total: Array, count: Array)

Bases: MSE

Computes the root mean squared error for regression problems given predictions and labels.

compute() → Array

Computes the RMSE as the square root of the accumulated MSE (total / count).

__init__(total: Array, count: Array) → None
replace(**updates)

Returns a new object replacing the specified fields with new values.
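RMSE is the square root of the merged MSE, so the root is taken once at the end rather than per batch. The small pure-Python sketch below uses hypothetical per-batch squared errors and a hypothetical helper rmse_from_stats to show the difference.

```python
import math


def rmse_from_stats(total, count):
    # RMSE is the square root of the mean squared error total / count.
    return math.sqrt(total / count)


# Hypothetical squared errors from two batches.
batch1 = [0.0, 4.0]
batch2 = [9.0]

# Merge first (sum totals and counts), then take the root once.
overall = rmse_from_stats(sum(batch1) + sum(batch2), len(batch1) + len(batch2))

# Averaging per-batch RMSEs does not give the RMSE over all samples.
per_batch = [rmse_from_stats(sum(b), len(b)) for b in (batch1, batch2)]
naive = sum(per_batch) / len(per_batch)
```

Here overall is sqrt(13 / 3), about 2.08, while the naive average of per-batch RMSEs is (sqrt(2) + 3) / 2, about 2.21.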

class metrax.RSQUARED(total: Array, count: Array, sum_of_squared_error: Array, sum_of_squared_label: Array)

Bases: Metric

Computes the r-squared score of a scalar or a batch of tensors.

R-squared is a measure of how well the regression model fits the data. It measures the proportion of the variance in the dependent variable that is explained by the independent variable(s). It is defined as 1 - SSE / SST, where SSE is the sum of squared errors and SST is the total sum of squares.

total: Array
count: Array
sum_of_squared_error: Array
sum_of_squared_label: Array
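classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) → RSQUARED

Creates an RSQUARED metric from a single batch of model outputs. Use merge() to accumulate it with metrics from other batches.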

Parameters:
  • predictions – A floating-point 1D array of model predictions. The shape should be (batch_size,).

  • labels – The ground-truth values. The shape should be (batch_size,).

  • sample_weights – An optional floating-point 1D array holding the weight of each sample. The shape should be (batch_size,).

Returns:

The RSQUARED metric for this batch. Calling compute() on it returns a scalar.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: RSQUARED) → RSQUARED

Returns a Metric that accumulates the values of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() → Array

Computes the r-squared score.

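Since we don’t know the mean of the labels before we aggregate all of the data, we rewrite the total sum of squares so that it depends only on running sums. With x_i the labels, N the number of samples, and \bar{x} their mean:

$$
\begin{aligned}
\mathrm{SST} &= \sum_i (x_i - \bar{x})^2 \\
&= \sum_i \left(x_i^2 - 2 x_i \bar{x} + \bar{x}^2\right) \\
&= \sum_i x_i^2 - 2\bar{x} \sum_i x_i + N \bar{x}^2 \\
&= \sum_i x_i^2 - 2\bar{x} \cdot N \bar{x} + N \bar{x}^2 \\
&= \sum_i x_i^2 - N \bar{x}^2
\end{aligned}
$$

The fourth line uses \sum_i x_i = N \bar{x}. The final form needs only \sum_i x_i^2, \sum_i x_i, and N, all of which can be accumulated batch by batch and merged by addition.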

Returns:

The r-squared score.
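Because SST reduces to running sums, R-squared can be computed from four accumulated fields. The NumPy sketch below mirrors the RSQUARED field names. It assumes that total is the sum of the labels and count is the number of samples, and it omits sample weights. The helpers are hypothetical, not metrax's implementation. The sketch checks the streamed result against a direct computation over all samples.

```python
import numpy as np


def r2_batch_stats(predictions, labels):
    """Per-batch sums named after the RSQUARED fields (assumed meanings)."""
    p = np.asarray(predictions, dtype=np.float64)
    y = np.asarray(labels, dtype=np.float64)
    return {
        "total": float(y.sum()),                              # sum_i x_i
        "count": float(y.size),                               # N
        "sum_of_squared_error": float(((y - p) ** 2).sum()),  # SSE
        "sum_of_squared_label": float((y ** 2).sum()),        # sum_i x_i^2
    }


def merge_r2(a, b):
    # Every field is a plain sum, so merging is field-wise addition.
    return {key: a[key] + b[key] for key in a}


def compute_r2(stats):
    mean = stats["total"] / stats["count"]
    # SST = sum_i x_i^2 - N * mean^2
    sst = stats["sum_of_squared_label"] - stats["count"] * mean ** 2
    return 1.0 - stats["sum_of_squared_error"] / sst


streamed = compute_r2(merge_r2(r2_batch_stats([2.5, 0.0], [3.0, -0.5]),
                               r2_batch_stats([2.0, 8.0], [2.0, 7.0])))

# Direct computation over all samples at once, for comparison.
y_all = np.array([3.0, -0.5, 2.0, 7.0])
p_all = np.array([2.5, 0.0, 2.0, 8.0])
direct = 1.0 - np.sum((y_all - p_all) ** 2) / np.sum((y_all - y_all.mean()) ** 2)
```

The sum_i x_i^2 - N * mean^2 form can lose precision through cancellation when the labels have a large mean relative to their spread. Low-precision accumulators (for example float32) may therefore deserve care.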

__init__(total: Array, count: Array, sum_of_squared_error: Array, sum_of_squared_label: Array) → None
replace(**updates)

Returns a new object replacing the specified fields with new values.

class metrax.Precision(true_positives: Array, false_positives: Array)

Bases: Metric

Computes precision for binary classification given predictions and labels.

true_positives

The count of true positive instances from the given data, label, and threshold.

Type:

jax.Array

false_positives

The count of false positive instances from the given data, label, and threshold.

Type:

jax.Array

true_positives: Array
false_positives: Array
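classmethod from_model_output(predictions: Array, labels: Array, threshold: float = 0.5) → Precision

Creates a Precision metric from a single batch of model outputs. Use merge() to accumulate it with metrics from other batches.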

Parameters:
  • predictions – A floating-point 1D array whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – The ground-truth labels, each expected to be 0 or 1. The shape should be (batch_size,).

  • threshold – The decision threshold used to convert predictions into binary classes. Default is 0.5.

Returns:

The Precision metric for this batch. Calling compute() on it returns a scalar.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: Precision) → Precision

Returns a Metric that accumulates the values of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() → Array

Computes precision as true_positives / (true_positives + false_positives).

__init__(true_positives: Array, false_positives: Array) → None
replace(**updates)

Returns a new object replacing the specified fields with new values.
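The NumPy sketch below shows how true and false positives could be counted at a threshold and turned into precision. Two details are assumptions rather than documented behavior. First, a prediction counts as positive when it is strictly greater than threshold. Second, precision is reported as 0.0 when nothing is predicted positive. The helper names are hypothetical.

```python
import numpy as np


def precision_counts(predictions, labels, threshold=0.5):
    """Return (true_positives, false_positives) for one batch."""
    preds = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(labels)
    if preds.shape != labels.shape:
        raise ValueError("predictions and labels must have the same shape")
    predicted_positive = preds > threshold  # assumed strict comparison
    tp = int(np.sum(predicted_positive & (labels == 1)))
    fp = int(np.sum(predicted_positive & (labels == 0)))
    return tp, fp


def precision(tp, fp):
    # Guard against 0 / 0 when nothing is predicted positive (assumed behavior).
    return tp / (tp + fp) if tp + fp > 0 else 0.0


preds = [0.9, 0.6, 0.4, 0.2]
labels = [1, 0, 1, 0]
p_default = precision(*precision_counts(preds, labels))                # TP=1, FP=1
p_strict = precision(*precision_counts(preds, labels, threshold=0.7))  # TP=1, FP=0
```

Raising threshold from 0.5 to 0.7 removes the false positive at 0.6, so precision rises from 0.5 to 1.0.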

class metrax.Recall(true_positives: Array, false_negatives: Array)

Bases: Metric

Computes recall for binary classification given predictions and labels.

true_positives

The count of true positive instances from the given data, label, and threshold.

Type:

jax.Array

false_negatives

The count of false negative instances from the given data, label, and threshold.

Type:

jax.Array

true_positives: Array
false_negatives: Array
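classmethod from_model_output(predictions: Array, labels: Array, threshold: float = 0.5) → Recall

Creates a Recall metric from a single batch of model outputs. Use merge() to accumulate it with metrics from other batches.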

Parameters:
  • predictions – A floating-point 1D array whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – The ground-truth labels, each expected to be 0 or 1. The shape should be (batch_size,).

  • threshold – The decision threshold used to convert predictions into binary classes. Default is 0.5.

Returns:

The Recall metric for this batch. Calling compute() on it returns a scalar.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: Recall) → Recall

Returns a Metric that accumulates the values of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() → Array

Computes recall as true_positives / (true_positives + false_negatives).

__init__(true_positives: Array, false_negatives: Array) → None
replace(**updates)

Returns a new object replacing the specified fields with new values.
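Recall follows the same count-then-divide pattern, with false negatives in place of false positives. The NumPy sketch below uses hypothetical helpers and assumes a strict > comparison. It shows why merging should add counts rather than average per-batch recall values.

```python
import numpy as np


def recall_counts(predictions, labels, threshold=0.5):
    """Return (true_positives, false_negatives) for one batch."""
    preds = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(labels)
    predicted_positive = preds > threshold  # assumed strict comparison
    tp = int(np.sum(predicted_positive & (labels == 1)))
    fn = int(np.sum(~predicted_positive & (labels == 1)))
    return tp, fn


def recall(tp, fn):
    # Guard against 0 / 0 when a batch has no positive labels (assumed behavior).
    return tp / (tp + fn) if tp + fn > 0 else 0.0


tp1, fn1 = recall_counts([0.9, 0.2], [1, 1])                  # recall 0.5
tp2, fn2 = recall_counts([0.8, 0.7, 0.6, 0.1], [1, 1, 1, 0])  # recall 1.0

merged = recall(tp1 + tp2, fn1 + fn2)                 # 4 / 5 over all samples
averaged = (recall(tp1, fn1) + recall(tp2, fn2)) / 2  # mean of batch recalls
```

The merged value (0.8) weights every positive sample equally. The average of batch recalls (0.75) over-weights the smaller batch.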

class metrax.AUCPR(true_positives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int)

Bases: Metric

Computes area under the precision-recall curve for binary classification given predictions and labels.

AUC-PR has a number of known issues, so use it with caution:

  • PR curves are highly sensitive to class balance.

  • PR is a non-monotonic function, so its “area” is not directly proportional to performance.

  • PR-AUC has no standard implementation, and different libraries give different results. Some libraries interpolate between points, while others assume a step function (or trapezoidal, as sklearn does). Some libraries compute the convex hull of the PR curve, while others do not. Because PR is non-monotonic, its value is sensitive to the number of samples along the curve (more so than ROC-AUC).

true_positives

The count of true positive instances from the given data and label at each threshold.

Type:

jax.Array

false_positives

The count of false positive instances from the given data and label at each threshold.

Type:

jax.Array

false_negatives

The count of false negative instances from the given data and label at each threshold.

Type:

jax.Array

true_positives: Array
false_positives: Array
false_negatives: Array
num_thresholds: int
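classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, num_thresholds: int = 200) → AUCPR

Creates an AUCPR metric from a single batch of model outputs. Use merge() to accumulate it with metrics from other batches.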

Parameters:
  • predictions – A floating-point 1D array whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – The ground-truth labels, each expected to be 0 or 1. The shape should be (batch_size,).

  • sample_weights – An optional floating-point 1D array holding the weight of each sample. The shape should be (batch_size,).

  • num_thresholds – The number of thresholds to use. Default is 200.

Returns:

An AUCPR metric holding the per-threshold counts for this batch. Calling compute() on it returns the area under the precision-recall curve as a scalar.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: AUCPR) → AUCPR

Returns a Metric that accumulates the values of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

interpolate_pr_auc() → Array

Interpolation formula inspired by section 4 of Davis & Goadrich 2006.

https://minds.wisconsin.edu/handle/1793/60482

Note that here we derive and use a closed-form formula, not present in the paper, as follows:

$$\mathrm{Precision} = \frac{TP}{TP + FP} = \frac{TP}{P}$$

Modeling all of TP (true positives), FP (false positives), and their sum P = TP + FP (predicted positives) as varying linearly within each interval [A, B] between successive thresholds, we get

$$
\begin{aligned}
\mathrm{slope} &= \frac{dTP}{dP} = \frac{TP_B - TP_A}{P_B - P_A} = \frac{TP - TP_A}{P - P_A} \\
\mathrm{Precision} &= \frac{TP_A + \mathrm{slope}\,(P - P_A)}{P}
\end{aligned}
$$

Because recall is TP / total_pos_weight, d(Recall) = slope * dP / total_pos_weight. The area within the interval is therefore (slope / total_pos_weight) times

$$
\int_A^B \mathrm{Precision}\,dP = \int_A^B \frac{TP_A + \mathrm{slope}\,(P - P_A)}{P}\,dP = \int_A^B \left(\mathrm{slope} + \frac{\mathrm{intercept}}{P}\right) dP
$$

where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in

$$
\int_A^B \mathrm{Precision}\,dP = TP_B - TP_A + \mathrm{intercept} \cdot \log\frac{P_B}{P_A}
$$

Bringing back the factor (slope / total_pos_weight) that we put aside, we get

$$
\frac{\mathrm{slope}\,\left[\Delta TP + \mathrm{intercept} \cdot \log(P_B / P_A)\right]}{\mathrm{total\_pos\_weight}}
$$

where ΔTP = TP_B - TP_A.

Note that when P_A = 0, we also have TP_A = 0 (since TP_A ≤ P_A) and hence intercept = 0, so the above calculation simplifies to

$$
\int_A^B \mathrm{Precision}\,dTP = \int_A^B \mathrm{slope}\,dTP = \mathrm{slope}\,(TP_B - TP_A)
$$

This is equivalent to imputing constant precision throughout the first bucket that has more than zero true positives.

Returns:

pr_auc – A float scalar jax.Array that approximates the area under the P-R curve.

Return type:

Array
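The closed-form interval formula can be re-derived in NumPy to show how the pieces fit. This sketch is not metrax's implementation. It assumes thresholds evenly spaced over [0, 1] with the endpoints nudged by a small epsilon. It counts a prediction as positive when it is strictly greater than the threshold. Counts are ordered by ascending threshold, so index i + 1 is interval end A and index i is end B. All function names are hypothetical.

```python
import numpy as np


def counts_at_thresholds(predictions, labels, num_thresholds=200, sample_weights=None):
    """Weighted TP, FP, FN at each threshold, ordered by ascending threshold."""
    eps = 1e-7
    thresholds = np.linspace(0.0, 1.0, num_thresholds)
    thresholds[0] -= eps   # a prediction of exactly 0 is positive at the first threshold
    thresholds[-1] += eps  # a prediction of exactly 1 is negative at the last threshold
    preds = np.asarray(predictions, dtype=np.float64)
    y = np.asarray(labels, dtype=np.float64)
    w = np.ones_like(preds) if sample_weights is None else np.asarray(sample_weights, dtype=np.float64)
    # pred_pos[t, i] is True when sample i is predicted positive at threshold t.
    pred_pos = preds[None, :] > thresholds[:, None]
    tp = np.sum(pred_pos * (y * w), axis=1)
    fp = np.sum(pred_pos * ((1.0 - y) * w), axis=1)
    fn = np.sum(~pred_pos * (y * w), axis=1)
    return tp, fp, fn


def _divide_no_nan(a, b):
    # Elementwise a / b, with 0 wherever b == 0.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return np.divide(a, b, out=np.zeros_like(a), where=b != 0)


def interpolate_pr_auc_np(tp, fp, fn):
    """Sum of the closed-form PR-AUC increments over all threshold intervals."""
    p = tp + fp                         # predicted positives P at each threshold
    dtp = tp[:-1] - tp[1:]              # TP_B - TP_A
    dp = p[:-1] - p[1:]                 # P_B - P_A
    slope = _divide_no_nan(dtp, np.maximum(dp, 0.0))
    intercept = tp[1:] - slope * p[1:]  # TP_A - slope * P_A
    # P_B / P_A where both are positive; 1 (so log = 0) when either is 0.
    both_positive = (p[:-1] > 0) & (p[1:] > 0)
    ratio = np.where(both_positive, _divide_no_nan(p[:-1], p[1:]), 1.0)
    total_pos_weight = tp[1:] + fn[1:]
    increments = _divide_no_nan(slope * (dtp + intercept * np.log(ratio)), total_pos_weight)
    return float(np.sum(increments))


tp, fp, fn = counts_at_thresholds([0.1, 0.2, 0.8, 0.9], [0, 0, 1, 1], num_thresholds=5)
perfect = interpolate_pr_auc_np(tp, fp, fn)
tp2, fp2, fn2 = counts_at_thresholds([0.9, 0.8, 0.3, 0.2], [1, 0, 1, 0])
mixed = interpolate_pr_auc_np(tp2, fp2, fn2)
```

For the perfectly separated batch the sketch returns 1.0. In the mixed batch, the interval from P = 3 to P = 2 contributes (1 - log 1.5) / 2 through the log term, and the last interval (P_A = 0) falls back to slope * ΔTP / total_pos_weight.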

compute() → Array

Computes the area under the precision-recall curve from the accumulated per-threshold counts.

__init__(true_positives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int) → None
replace(**updates)

Returns a new object replacing the specified fields with new values.

class metrax.AUCROC(true_positives: Array, true_negatives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int)

Bases: Metric

Computes the area under the receiver operating characteristic (ROC) curve for binary classification given predictions and labels.

true_positives

The count of true positive instances from the given data and label at each threshold.

Type:

jax.Array

false_positives

The count of false positive instances from the given data and label at each threshold.

Type:

jax.Array

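true_negatives

The count of true negative instances from the given data and label at each threshold.

Type:

jax.Array

false_negatives

The count of false negative instances from the given data and label at each threshold.

Type:

jax.Array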

true_positives: Array
true_negatives: Array
false_positives: Array
false_negatives: Array
num_thresholds: int
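classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, num_thresholds: int = 200) → AUCROC

Creates an AUCROC metric from a single batch of model outputs. Use merge() to accumulate it with metrics from other batches.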

Parameters:
  • predictions – A floating-point 1D array whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – The ground-truth labels, each expected to be 0 or 1. The shape should be (batch_size,).

  • sample_weights – An optional floating-point 1D array holding the weight of each sample. The shape should be (batch_size,).

  • num_thresholds – The number of thresholds to use. Default is 200.

Returns:

An AUCROC metric holding the per-threshold counts for this batch. Calling compute() on it returns the area under the receiver operating characteristic curve as a scalar.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: AUCROC) → AUCROC

Returns a Metric that accumulates the values of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() → Array

Computes the area under the ROC curve from the accumulated per-threshold counts.

__init__(true_positives: Array, true_negatives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int) → None
replace(**updates)

Returns a new object replacing the specified fields with new values.
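One common way to turn per-threshold counts into an area is the trapezoidal rule over (FPR, TPR) points. This section does not say whether AUCROC uses exactly this rule, so treat the NumPy sketch below as an assumption. It uses a hypothetical threshold convention: thresholds evenly spaced over [0, 1], endpoints nudged by epsilon, and a strict > comparison. The helper names are hypothetical.

```python
import numpy as np


def roc_counts(predictions, labels, num_thresholds=200):
    """TP, TN, FP, FN at each threshold, ordered by ascending threshold."""
    eps = 1e-7
    thresholds = np.linspace(0.0, 1.0, num_thresholds)
    thresholds[0] -= eps
    thresholds[-1] += eps
    preds = np.asarray(predictions, dtype=np.float64)
    y = np.asarray(labels, dtype=np.float64)
    pred_pos = preds[None, :] > thresholds[:, None]  # shape (num_thresholds, batch)
    tp = np.sum(pred_pos * y, axis=1)
    fp = np.sum(pred_pos * (1.0 - y), axis=1)
    fn = np.sum(~pred_pos * y, axis=1)
    tn = np.sum(~pred_pos * (1.0 - y), axis=1)
    return tp, tn, fp, fn


def _divide_no_nan(a, b):
    # Elementwise a / b, with 0 wherever b == 0.
    return np.divide(a, b, out=np.zeros_like(a, dtype=np.float64), where=b != 0)


def trapezoid_roc_auc(tp, tn, fp, fn):
    tpr = _divide_no_nan(tp, tp + fn)  # true positive rate (recall)
    fpr = _divide_no_nan(fp, fp + tn)  # false positive rate
    # With ascending thresholds FPR is non-increasing, so each width is >= 0.
    widths = fpr[:-1] - fpr[1:]
    heights = (tpr[:-1] + tpr[1:]) / 2.0
    return float(np.sum(widths * heights))


perfect = trapezoid_roc_auc(*roc_counts([0.1, 0.2, 0.8, 0.9], [0, 0, 1, 1], num_thresholds=5))
reversed_ = trapezoid_roc_auc(*roc_counts([0.1, 0.2, 0.8, 0.9], [1, 1, 0, 0], num_thresholds=5))
tied = trapezoid_roc_auc(*roc_counts([0.5, 0.5], [0, 1], num_thresholds=5))
```

A perfect ranking gives 1.0, a fully reversed ranking gives 0.0, and constant predictions give 0.5. Finer threshold grids let the trapezoids follow the underlying curve more closely.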