metrax
- class metrax.MSE(total: Array, count: Array)
Bases: Average
Computes the mean squared error for regression problems given predictions and labels.
- classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) → MSE
Updates the metric.
- Parameters:
predictions – A floating point 1D vector representing the prediction generated from the model. The shape should be (batch_size,).
labels – True value. The shape should be (batch_size,).
sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).
- Returns:
Updated MSE metric; its computed value is a single scalar.
- Raises:
ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
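A minimal NumPy sketch of what the accumulated `MSE` state could represent. It assumes, without confirmation from this reference, that `total` holds the weighted sum of squared errors and `count` the sum of the weights, with unit weights when `sample_weights` is None. The helper `mse_state` is hypothetical and is not part of metrax; `jax.numpy` provides the same calls.

```python
import numpy as np


def mse_state(predictions, labels, sample_weights=None):
    """Hypothetical helper: returns (total, count) for one batch.

    Assumes total = sum(w * (p - y)**2) and count = sum(w), with w = 1
    when no sample weights are given.
    """
    predictions = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    if predictions.shape != labels.shape:
        raise ValueError("predictions and labels must have the same shape")
    if sample_weights is None:
        sample_weights = np.ones_like(labels)
    sample_weights = np.asarray(sample_weights, dtype=np.float64)
    squared_error = (predictions - labels) ** 2
    total = np.sum(sample_weights * squared_error)
    count = np.sum(sample_weights)
    return total, count


total, count = mse_state([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
print(total / count)  # mean squared error of the batch
```

Keeping the numerator and denominator separate, rather than storing the mean, is what lets batches be combined by plain addition before the final division.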
- class metrax.RMSE(total: Array, count: Array)
Bases: MSE
Computes the root mean squared error for regression problems given predictions and labels.
- compute() → Array
Computes final metrics from intermediate values.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
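Because `RMSE` derives from `MSE`, it can presumably share the same sum-and-count state and apply the square root only in `compute()`. The NumPy sketch below illustrates that assumption with a hypothetical helper `rmse_from_batches`. It also shows why the square root belongs after pooling the state: averaging per-batch RMSE values gives a different, incorrect number.

```python
import numpy as np


def rmse_from_batches(batches):
    """Hypothetical helper: each batch is a (predictions, labels) pair.

    Accumulates the sum of squared errors and the sample count across
    batches, then takes the square root once at the end.
    """
    total, count = 0.0, 0
    for predictions, labels in batches:
        error = np.asarray(predictions, dtype=np.float64) - np.asarray(labels, dtype=np.float64)
        total += float(np.sum(error ** 2))
        count += error.size
    return float(np.sqrt(total / count))


batches = [([0.0, 0.0], [1.0, 1.0]), ([0.0, 0.0, 0.0, 0.0], [3.0, 3.0, 3.0, 3.0])]
pooled = rmse_from_batches(batches)                            # sqrt(38 / 6)
naive = float(np.mean([rmse_from_batches([b]) for b in batches]))  # mean of 1 and 3
print(pooled, naive)
```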
- class metrax.RSQUARED(total: Array, count: Array, sum_of_squared_error: Array, sum_of_squared_label: Array)
Bases: Metric
Computes the r-squared score of a scalar or a batch of tensors.
R-squared is a measure of how well the regression model fits the data. It measures the proportion of the variance in the dependent variable that is explained by the independent variable(s). It is defined as 1 - SSE / SST, where SSE is the sum of squared errors and SST is the total sum of squares.
- total: Array
- count: Array
- sum_of_squared_error: Array
- sum_of_squared_label: Array
- classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) → RSQUARED
Updates the metric.
- Parameters:
predictions – A floating point 1D vector representing the prediction generated from the model. The shape should be (batch_size,).
labels – True value. The shape should be (batch_size,).
sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).
- Returns:
Updated RSQUARED metric; its computed value is a single scalar.
- Raises:
ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
- merge(other: RSQUARED) → RSQUARED
Returns Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
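- compute() → Array
Computes the r-squared score.
Since we don’t know the mean of the labels before we aggregate all of the data, we manipulate the formula to be:
$$
\begin{aligned}
\mathrm{SST} &= \sum_i (x_i - \mathrm{mean})^2 \\
&= \sum_i \left(x_i^2 - 2 x_i \,\mathrm{mean} + \mathrm{mean}^2\right) \\
&= \sum_i x_i^2 - 2\,\mathrm{mean} \sum_i x_i + N \cdot \mathrm{mean}^2 \\
&= \sum_i x_i^2 - 2\,\mathrm{mean} \cdot N \cdot \mathrm{mean} + N \cdot \mathrm{mean}^2 \\
&= \sum_i x_i^2 - N \cdot \mathrm{mean}^2
\end{aligned}
$$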
- Returns:
The r-squared score.
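The rearranged SST formula means R-squared can be computed from running sums alone. The NumPy sketch below illustrates this under stated assumptions, which the field names suggest but this reference does not confirm: `total` is the sum of labels, `count` the number of samples, `sum_of_squared_error` the SSE, and `sum_of_squared_label` the sum of squared labels. It uses unit weights and merges batches by element-wise addition. All helper names are hypothetical.

```python
import numpy as np


def rsquared_state(predictions, labels):
    """Hypothetical helper: sufficient statistics for one batch (unit weights)."""
    p = np.asarray(predictions, dtype=np.float64)
    y = np.asarray(labels, dtype=np.float64)
    return {
        "total": np.sum(y),
        "count": float(y.size),
        "sum_of_squared_error": np.sum((y - p) ** 2),
        "sum_of_squared_label": np.sum(y ** 2),
    }


def merge(a, b):
    # Every field is a plain sum, so merging two states is element-wise addition.
    return {k: a[k] + b[k] for k in a}


def compute(state):
    mean = state["total"] / state["count"]
    # SST = sum_i x_i^2 - N * mean^2, so the mean is only needed at the end.
    sst = state["sum_of_squared_label"] - state["count"] * mean ** 2
    return 1.0 - state["sum_of_squared_error"] / sst


y = np.array([1.0, 2.0, 3.0, 4.0])
p = np.array([1.1, 1.9, 3.2, 3.7])
streamed = compute(merge(rsquared_state(p[:2], y[:2]), rsquared_state(p[2:], y[2:])))
direct = 1.0 - np.sum((y - p) ** 2) / np.sum((y - y.mean()) ** 2)
print(streamed, direct)
```

One trade-off of the rearranged form: subtracting N · mean² from the sum of squares can lose precision, especially in float32, when the labels have a large mean relative to their spread.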
- __init__(total: Array, count: Array, sum_of_squared_error: Array, sum_of_squared_label: Array) → None
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- class metrax.Precision(true_positives: Array, false_positives: Array)
Bases: Metric
Computes precision for binary classification given predictions and labels.
- true_positives
The count of true positive instances from the given data, label, and threshold.
- Type:
jax.Array
- false_positives
The count of false positive instances from the given data, label, and threshold.
- Type:
jax.Array
- true_positives: Array
- false_positives: Array
- classmethod from_model_output(predictions: Array, labels: Array, threshold: float = 0.5) → Precision
Updates the metric.
- Parameters:
predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).
labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).
threshold – The threshold to use for the binary classification.
- Returns:
Updated Precision metric; its computed value is a single scalar.
- Raises:
ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
- merge(other: Precision) → Precision
Returns Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
- compute() → Array
Computes final metrics from intermediate values.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
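A NumPy sketch of how the two counters relate to precision. It assumes a prediction counts as positive when it is >= threshold (this reference does not say whether the comparison is > or >=) and that `merge` adds the counters. It also mimics the distributed case described for `merge`: per-device states stacked along a new leading axis are reduced by summing over that axis. The helper `precision_counts` is hypothetical.

```python
import numpy as np


def precision_counts(predictions, labels, threshold=0.5):
    """Hypothetical helper: [true_positives, false_positives] for one batch."""
    predicted_positive = np.asarray(predictions) >= threshold  # assumed >=
    actual_positive = np.asarray(labels) == 1
    true_positives = np.sum(predicted_positive & actual_positive)
    false_positives = np.sum(predicted_positive & ~actual_positive)
    return np.array([true_positives, false_positives], dtype=np.float64)


# Two "devices", each holding one batch; stacking adds a leading axis.
per_device = np.stack([
    precision_counts([0.9, 0.4, 0.7], [1, 0, 0]),
    precision_counts([0.6, 0.2, 0.8], [1, 1, 1]),
])
tp, fp = per_device.sum(axis=0)  # reduce over the device axis
print(tp / (tp + fp))
```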
- class metrax.Recall(true_positives: Array, false_negatives: Array)
Bases: Metric
Computes recall for binary classification given predictions and labels.
- true_positives
The count of true positive instances from the given data, label, and threshold.
- Type:
jax.Array
- false_negatives
The count of false negative instances from the given data, label, and threshold.
- Type:
jax.Array
- true_positives: Array
- false_negatives: Array
- classmethod from_model_output(predictions: Array, labels: Array, threshold: float = 0.5) → Recall
Updates the metric.
- Parameters:
predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).
labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).
threshold – The threshold to use for the binary classification.
- Returns:
Updated Recall metric; its computed value is a single scalar.
- Raises:
ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
- merge(other: Recall) → Recall
Returns Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
- compute() → Array
Computes final metrics from intermediate values.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
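The matching NumPy sketch for recall uses the same assumptions: a prediction is positive when it is >= threshold, and the counters merge by addition. The helper name is hypothetical. Sweeping the threshold shows why `threshold` matters: lowering it can only keep or raise recall, because false negatives can only turn into true positives.

```python
import numpy as np


def recall_counts(predictions, labels, threshold=0.5):
    """Hypothetical helper: (true_positives, false_negatives) for one batch."""
    predicted_positive = np.asarray(predictions) >= threshold  # assumed >=
    actual_positive = np.asarray(labels) == 1
    true_positives = np.sum(predicted_positive & actual_positive)
    false_negatives = np.sum(~predicted_positive & actual_positive)
    return true_positives, false_negatives


predictions = [0.95, 0.6, 0.3, 0.1]
labels = [1, 1, 1, 0]
for threshold in (0.9, 0.5, 0.2):
    tp, fn = recall_counts(predictions, labels, threshold)
    print(threshold, tp / (tp + fn))
```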
- class metrax.AUCPR(true_positives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int)
Bases: Metric
Computes the area under the precision-recall curve for binary classification given predictions and labels.
The AUC-PR metric has a number of known issues, so use it with caution:
- PR curves are highly sensitive to class balance.
- PR is a non-monotonic function, so its “area” is not directly proportional to performance.
PR-AUC has no standard implementation, and different libraries give different results. Some libraries interpolate between points; others assume a step function or use the trapezoidal rule (as sklearn does). Some libraries compute the convex hull of the PR curve; others do not. Because PR is non-monotonic, its value is sensitive to the number of samples along the curve (more so than ROC-AUC).
- true_positives
The count of true positive instances from the given data and label at each threshold.
- Type:
jax.Array
- false_positives
The count of false positive instances from the given data and label at each threshold.
- Type:
jax.Array
- false_negatives
The count of false negative instances from the given data and label at each threshold.
- Type:
jax.Array
- true_positives: Array
- false_positives: Array
- false_negatives: Array
- classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, num_thresholds: int = 200) → AUCPR
Updates the metric.
- Parameters:
predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).
labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).
sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).
num_thresholds – The number of thresholds to use. Default is 200.
- Returns:
Updated AUCPR metric; calling compute() on it yields the area under the precision-recall curve as a single scalar.
- Raises:
ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
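`from_model_output` records true positives, false positives and false negatives at each of `num_thresholds` thresholds. The vectorized NumPy sketch below shows that bookkeeping under two assumptions: thresholds are evenly spaced over [0, 1] via `np.linspace` (their actual placement, including any epsilon at the endpoints, is not specified here), and sample weights multiply each sample's contribution. The helper `confusion_at_thresholds` is hypothetical.

```python
import numpy as np


def confusion_at_thresholds(predictions, labels, sample_weights=None, num_thresholds=200):
    """Hypothetical helper: per-threshold TP, FP, FN arrays of shape (num_thresholds,)."""
    predictions = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    if sample_weights is None:
        sample_weights = np.ones_like(labels)
    sample_weights = np.asarray(sample_weights, dtype=np.float64)
    thresholds = np.linspace(0.0, 1.0, num_thresholds)  # assumed placement
    # Shape (num_thresholds, batch_size): is each sample predicted positive?
    predicted_positive = predictions[None, :] >= thresholds[:, None]
    positive = labels[None, :] * sample_weights[None, :]
    negative = (1.0 - labels[None, :]) * sample_weights[None, :]
    true_positives = np.sum(predicted_positive * positive, axis=1)
    false_positives = np.sum(predicted_positive * negative, axis=1)
    false_negatives = np.sum(~predicted_positive * positive, axis=1)
    return true_positives, false_positives, false_negatives


tp, fp, fn = confusion_at_thresholds([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], num_thresholds=5)
print(tp, fp, fn)
```

Broadcasting against all thresholds at once costs O(num_thresholds × batch_size) memory, which is the usual price for avoiding a Python loop over thresholds.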
- merge(other: AUCPR) → AUCPR
Returns Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
- interpolate_pr_auc() → Array
Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
https://minds.wisconsin.edu/handle/1793/60482
Note that here we derive and use a closed formula not present in the paper, as follows:
$$\text{Precision} = \frac{TP}{TP + FP} = \frac{TP}{P}$$
Modeling all of TP (true positive), FP (false positive) and their sum P = TP + FP (predicted positive) as varying linearly within each interval [A, B] between successive thresholds, we get
$$\text{slope} = \frac{dTP}{dP} = \frac{TP_B - TP_A}{P_B - P_A} = \frac{TP - TP_A}{P - P_A}$$
$$\text{Precision} = \frac{TP_A + \text{slope}\cdot(P - P_A)}{P}$$
The area within the interval is (slope / total_pos_weight) times
$$\int_A^B \text{Precision}\,dP = \int_A^B \frac{TP_A + \text{slope}\cdot(P - P_A)}{P}\,dP = \int_A^B \left(\text{slope} + \frac{\text{intercept}}{P}\right) dP$$
where intercept = TP_A - slope · P_A = TP_B - slope · P_B, resulting in
$$\int_A^B \text{Precision}\,dP = TP_B - TP_A + \text{intercept}\cdot\log\frac{P_B}{P_A}$$
Bringing back the factor (slope / total_pos_weight) we put aside, we get
$$\frac{\text{slope}\cdot\left[dTP + \text{intercept}\cdot\log(P_B / P_A)\right]}{\text{total\_pos\_weight}}$$
where dTP = TP_B - TP_A.
Note that when P_A = 0 (so TP_A = 0 and intercept = 0, and the log term drops out), the above calculation simplifies into
$$\int_A^B \text{Precision}\,dTP = \int_A^B \text{slope}\,dTP = \text{slope}\cdot(TP_B - TP_A)$$
which is equivalent to imputing constant precision throughout the first bucket having > 0 true positives.
- Returns:
pr_auc – A float scalar jax.Array that approximates the area under the P-R curve.
- Return type:
jax.Array
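The closed formula for interpolate_pr_auc() can be applied bucket by bucket. The NumPy sketch below is written from the derivation, not from metrax's source, and relies on three assumptions. The count arrays are ordered by increasing threshold, so P = TP + FP does not increase along the array. total_pos_weight is TP + FN at the lowest threshold. The log term is replaced by 0 where P_A = 0, which matches the simplification described for the first bucket. It is a standalone function that shares the method's name for readability only, and the `eps` guard is an illustrative choice.

```python
import numpy as np


def interpolate_pr_auc(tp, fp, fn, eps=1e-12):
    """Sketch of the interpolated PR-AUC from per-threshold counts.

    tp, fp, fn: arrays of shape (num_thresholds,), ordered by increasing threshold.
    Interval [A, B]: A is the higher threshold, B the next lower one.
    """
    tp = np.asarray(tp, dtype=np.float64)
    fp = np.asarray(fp, dtype=np.float64)
    fn = np.asarray(fn, dtype=np.float64)
    p = tp + fp                       # predicted positives at each threshold
    tp_a, tp_b = tp[1:], tp[:-1]
    p_a, p_b = p[1:], p[:-1]
    d_tp = tp_b - tp_a
    d_p = p_b - p_a
    # slope = dTP / dP; dP == 0 implies dTP == 0, so the slope is set to 0 there.
    slope = np.where(d_p > 0, d_tp / np.maximum(d_p, eps), 0.0)
    intercept = tp_a - slope * p_a
    # log(P_B / P_A), replaced by log(1) = 0 where P_A == 0 or P_B == 0.
    safe_ratio = np.where((p_a > 0) & (p_b > 0), p_b / np.maximum(p_a, eps), 1.0)
    total_pos_weight = tp[0] + fn[0]
    area = slope * (d_tp + intercept * np.log(safe_ratio)) / np.maximum(total_pos_weight, eps)
    return float(np.sum(area))


# A perfectly separating classifier evaluated at three thresholds.
print(interpolate_pr_auc(tp=[2, 2, 0], fp=[2, 0, 0], fn=[0, 0, 2]))
```

The log term only contributes when the intercept is non-zero, that is, when precision actually varies within a bucket; for constant precision, each bucket's area reduces to precision times the recall increment.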
- compute() → Array
Computes final metrics from intermediate values.
- __init__(true_positives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int) → None
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- class metrax.AUCROC(true_positives: Array, true_negatives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int)
Bases: Metric
Computes the area under the receiver operating characteristic curve for binary classification given predictions and labels.
- true_positives
The count of true positive instances from the given data and label at each threshold.
- Type:
jax.Array
- false_positives
The count of false positive instances from the given data and label at each threshold.
- Type:
jax.Array
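- true_negatives
The count of true negative instances from the given data and label at each threshold.
- Type:
jax.Array
- false_negatives
The count of false negative instances from the given data and label at each threshold.
- Type:
jax.Array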
- true_positives: Array
- true_negatives: Array
- false_positives: Array
- false_negatives: Array
- classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, num_thresholds: int = 200) → AUCROC
Updates the metric.
- Parameters:
predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).
labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).
sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).
num_thresholds – The number of thresholds to use. Default is 200.
- Returns:
Updated AUCROC metric; calling compute() on it yields the area under the receiver operating characteristic curve as a single scalar.
- Raises:
ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
- merge(other: AUCROC) → AUCROC
Returns Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
- compute() → Array
Computes final metrics from intermediate values.
- __init__(true_positives: Array, true_negatives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int) → None
- replace(**updates)
Returns a new object replacing the specified fields with new values.
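This reference does not state how compute() integrates the ROC curve. As a hedged illustration only, the NumPy sketch below uses the trapezoidal rule, one common choice. It works on per-threshold counts assumed to be ordered by increasing threshold, so both the true positive rate and the false positive rate are non-increasing along the arrays. `roc_auc_trapezoid` is a hypothetical name, not part of metrax.

```python
import numpy as np


def roc_auc_trapezoid(tp, tn, fp, fn, eps=1e-12):
    """Sketch: area under the ROC curve from per-threshold counts."""
    tp, tn, fp, fn = (np.asarray(x, dtype=np.float64) for x in (tp, tn, fp, fn))
    tpr = tp / np.maximum(tp + fn, eps)   # true positive rate (recall)
    fpr = fp / np.maximum(fp + tn, eps)   # false positive rate
    # fpr decreases along the array, so each trapezoid width is fpr[i] - fpr[i + 1].
    widths = fpr[:-1] - fpr[1:]
    heights = (tpr[:-1] + tpr[1:]) / 2.0
    return float(np.sum(widths * heights))


# Perfect classifier: the curve passes through (1, 1), (0, 1) and (0, 0).
print(roc_auc_trapezoid(tp=[2, 2, 0], tn=[0, 2, 2], fp=[2, 0, 0], fn=[0, 0, 2]))
```

The estimate is only as fine as the threshold grid. With `num_thresholds` points the curve is piecewise linear between them, and the grid should reach both corners (0, 0) and (1, 1) for the area to cover the full range.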