Ordered Probability Metrics
Ordered probability metrics evaluate predicted probabilities for ordered factor outcomes. These metrics are specifically designed for ordinal classification where the class levels have a natural ordering.
Overview
Use when:
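Truth is an ordered factor (e.g., `ordered(x, levels = c("low", "medium", "high"))`; without an explicit `levels` argument, `ordered()` sorts the levels alphabetically, which silently breaks the intended ordering)
Predictions are probabilities for each ordered class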
The ordering of classes matters (e.g., severity ratings, performance levels)
Key differences from regular probability metrics:
Uses cumulative probabilities to respect ordering
Penalizes predictions that are “further away” in the ordering
No averaging types (works the same for any number of classes)
Examples: Ranked Probability Score (RPS)
Reference implementation: `R/orderedprob-ranked_prob_score.R` in the yardstick repository
Pattern: Three-Function Approach
1. Implementation Function
# Internal calculation logic (inputs arrive already validated by my_metric_vec())
my_metric_impl <- function(truth, estimate, case_weights = NULL) {
  # Expected inputs:
  #   truth        - ordered factor of observed classes
  #   estimate     - matrix of class probabilities, one column per level
  #   case_weights - numeric vector, or NULL for an unweighted metric
  num_class <- nlevels(truth)

  # Work on the cumulative scale for both the observed classes (one-hot
  # indicators) and the predictions, so the ordering of levels matters.
  cum_ind <- cumulative_rows(hardhat::fct_encode_one_hot(truth))
  cum_estimate <- cumulative_rows(estimate)

  case_weights <- vctrs::vec_cast(case_weights, to = double())

  # Combine cum_ind, cum_estimate and case_weights into the metric here;
  # the finished function should return a single numeric value.
  # ... implementation details ...
}
# Row-wise cumulative sums of a probability or indicator matrix.
#
# Matches `t(apply(x, 1, cumsum))` for ordinary inputs (up to floating-point
# rounding). It accumulates column by column, which means one vectorised
# addition per class level instead of one R call per row. It always returns
# a matrix with the same dimensions and dimnames as `x`. `t(apply(...))`
# returns the wrong shape when `x` has zero rows (1 x 0) or a single column
# (1 x n).
#
# x: matrix (data frames are coerced with as.matrix(), as apply() would do)
#    with one row per observation and one column per ordered level, in
#    level order.
# Returns a matrix whose column j holds the sum of columns 1..j of `x`.
cumulative_rows <- function(x) {
  x <- as.matrix(x)
  # seq_len(...)[-1L] is empty for 0 or 1 columns, so nothing is accumulated
  for (j in seq_len(ncol(x))[-1L]) {
    x[, j] <- x[, j - 1L] + x[, j]
  }
  x
}
# 2. Vector Interface
#' @param estimate A numeric matrix of predicted class probabilities with one
#'   column per level of `truth`, in the same order as `levels(truth)`.
#' @rdname my_metric
#' @export
my_metric_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL,
                          ...) {
  # Validate na_rm parameter
  check_bool(na_rm)
  # NOTE(review): judging by its name, this rejects class_pred objects passed
  # as `truth` -- confirm against yardstick.
  abort_if_class_pred(truth)

  # Determine estimator (typically "standard" for ordered prob metrics)
  estimator <- finalize_estimator(truth, metric_class = "my_metric")

  # Validate truth, estimate and case_weights before touching missing values
  check_ordered_prob_metric(truth, estimate, case_weights, estimator)

  # Handle missing values: drop incomplete cases, or return NA when
  # na_rm = FALSE and anything is missing
  if (na_rm) {
    result <- yardstick_remove_missing(truth, estimate, case_weights)
    truth <- result$truth
    estimate <- result$estimate
    case_weights <- result$case_weights
  } else if (yardstick_any_missing(truth, estimate, case_weights)) {
    return(NA_real_)
  }

  # Call implementation
  my_metric_impl(truth, estimate, case_weights)
}
# 3. Data Frame Method
#' My Ordered Probability Metric
#'
#' Description of what this metric measures.
#'
#' @family class probability metrics
#' @templateVar fn my_metric
#' @template return
#'
#' @param data A data frame containing the `truth` and probability columns.
#' @param truth Unquoted column with the true classes (an ordered factor).
#' @param ... Unquoted column(s) with predicted probabilities, one per level
#'   of `truth`, supplied in level order.
#' @param na_rm Should missing values be removed before computing? (default
#'   `TRUE`)
#' @param case_weights Optional unquoted column of case weights.
#'
#' @export
my_metric <- function(data, ...) {
  UseMethod("my_metric")
}

# Turn the generic into an ordered probability metric. `direction` states
# whether smaller or larger values are better; `range` gives the values the
# metric can take.
my_metric <- new_ordered_prob_metric(
  my_metric,
  direction = "minimize", # or "maximize"
  range = c(0, 1)         # or c(0, Inf), etc.
)

#' @export
#' @rdname my_metric
my_metric.data.frame <- function(data, truth, ..., na_rm = TRUE,
                                 case_weights = NULL) {
  # Capture the user's column expressions so they can be spliced into the
  # summarizer with `!!`.
  truth_quo <- rlang::enquo(truth)
  case_weights_quo <- rlang::enquo(case_weights)

  ordered_prob_metric_summarizer(
    name = "my_metric",
    fn = my_metric_vec,
    data = data,
    truth = !!truth_quo,
    ...,
    na_rm = na_rm,
    case_weights = !!case_weights_quo
  )
}
# Complete Example: Ranked Probability Score
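The ranked probability score (RPS) is a Brier score for ordinal data that uses cumulative probabilities. For one observation with $K$ ordered classes, true class $y$ and predicted probabilities $p_1, \dots, p_K$, define the cumulative prediction $F_k = \sum_{j \le k} p_j$ and the cumulative indicator $O_k = \mathbb{1}(y \le k)$. Then

$$\mathrm{RPS} = \frac{1}{K - 1} \sum_{k=1}^{K} \left(F_k - O_k\right)^2 .$$

The metric averages this quantity over observations, using case weights if they are supplied. Dividing by $K - 1$ keeps the score in $[0, 1]$.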
# R/orderedprob-ranked_prob_score.R
# 1. Implementation function

# Compute the ranked probability score from inputs that were already
# validated by ranked_prob_score_vec().
#
# truth:        ordered factor of observed classes.
# estimate:     matrix of predicted probabilities, one column per level of
#               `truth`, in level order.
# case_weights: numeric case weights, or NULL for an unweighted score.
ranked_prob_score_impl <- function(truth, estimate, case_weights) {
  num_class <- nlevels(truth)

  # Put truth and predictions on the cumulative scale. Column k of `cum_ind`
  # is 1 when the true class is at or below level k. Column k of
  # `cum_estimate` is the predicted P(class <= k).
  inds <- hardhat::fct_encode_one_hot(truth)
  cum_ind <- cumulative_rows(inds)
  cum_estimate <- cumulative_rows(estimate)
  case_weights <- vctrs::vec_cast(case_weights, to = double())

  # RPS divides by number of classes minus one.
  # NOTE(review): the trailing `* 2` presumably undoes a halving inside
  # brier_ind() (not shown here) -- confirm against brier_ind().
  brier_ind(cum_ind, cum_estimate, case_weights) / (num_class - 1) * 2
}
# Row-wise cumulative sums of a probability or indicator matrix.
#
# Matches `t(apply(x, 1, cumsum))` for ordinary inputs (up to floating-point
# rounding). It accumulates column by column, which means one vectorised
# addition per class level instead of one R call per row. It always returns
# a matrix with the same dimensions and dimnames as `x`. `t(apply(...))`
# returns the wrong shape when `x` has zero rows (1 x 0) or a single column
# (1 x n).
#
# x: matrix (data frames are coerced with as.matrix(), as apply() would do)
#    with one row per observation and one column per ordered level, in
#    level order.
# Returns a matrix whose column j holds the sum of columns 1..j of `x`.
cumulative_rows <- function(x) {
  x <- as.matrix(x)
  # seq_len(...)[-1L] is empty for 0 or 1 columns, so nothing is accumulated
  for (j in seq_len(ncol(x))[-1L]) {
    x[, j] <- x[, j - 1L] + x[, j]
  }
  x
}
# 2. Vector interface
#' @param estimate A numeric matrix of predicted class probabilities with one
#'   column per level of `truth`, in the same order as `levels(truth)`.
#' @rdname ranked_prob_score
#' @export
ranked_prob_score_vec <- function(truth, estimate, na_rm = TRUE,
                                  case_weights = NULL, ...) {
  check_bool(na_rm)
  abort_if_class_pred(truth)

  # Resolve the estimator, then validate every input against it before any
  # missing-value handling.
  estimator <- finalize_estimator(truth, metric_class = "ranked_prob_score")
  check_ordered_prob_metric(truth, estimate, case_weights, estimator)

  # Drop incomplete cases, or return NA when na_rm = FALSE and anything is
  # missing.
  if (na_rm) {
    result <- yardstick_remove_missing(truth, estimate, case_weights)
    truth <- result$truth
    estimate <- result$estimate
    case_weights <- result$case_weights
  } else if (yardstick_any_missing(truth, estimate, case_weights)) {
    return(NA_real_)
  }

  ranked_prob_score_impl(truth, estimate, case_weights)
}
# 3. Data frame method

#' Ranked probability score for ordinal classification models
#'
#' Compute the ranked probability score (RPS) for the predicted class
#' probabilities of an ordered factor outcome. RPS is a Brier score computed
#' on cumulative probabilities, so predictions that put probability on
#' classes further from the true class are penalized more. Lower values are
#' better, and the score lies in `[0, 1]`.
#'
#' @family class probability metrics
#' @templateVar fn ranked_prob_score
#' @template return
#'
#' @param data A data frame containing the `truth` and probability columns.
#' @param truth Unquoted column with the true classes (an ordered factor).
#' @param ... Unquoted column(s) with predicted probabilities, one per level
#'   of `truth`, supplied in level order.
#' @param na_rm Should missing values be removed before computing? (default
#'   `TRUE`)
#' @param case_weights Optional unquoted column of case weights.
#'
#' @export
ranked_prob_score <- function(data, ...) {
  UseMethod("ranked_prob_score")
}

ranked_prob_score <- new_ordered_prob_metric(
  ranked_prob_score,
  direction = "minimize",
  range = c(0, 1)
)

#' @export
#' @rdname ranked_prob_score
ranked_prob_score.data.frame <- function(data, truth, ..., na_rm = TRUE,
                                         case_weights = NULL) {
  ordered_prob_metric_summarizer(
    name = "ranked_prob_score",
    fn = ranked_prob_score_vec,
    data = data,
    truth = !!rlang::enquo(truth),
    ...,
    na_rm = na_rm,
    case_weights = !!rlang::enquo(case_weights)
  )
}
# Key Validation Function
check_ordered_prob_metric(truth, estimate, case_weights, estimator)

This validates that:
- `truth` is an ordered factor
- `estimate` is a matrix with the correct dimensions
- `case_weights` are valid (if provided)
- the estimator type is appropriate
Input Format
Truth
# Must be an ordered factor. Pass `levels` explicitly: otherwise `ordered()`
# sorts the levels alphabetically ("high" < "low" < "medium"), which breaks
# the intended ordering and misaligns the probability columns.
truth <- ordered(
  c("low", "medium", "high", "low"),
  levels = c("low", "medium", "high")
)
# Estimate
# Matrix with probabilities for each level: one row per observation (matching
# the four `truth` values above) and one column per level, in the same order
# as `levels(truth)`. Each row sums to 1.
estimate <- matrix(
  c(0.7, 0.2, 0.1,  # First observation  (truth: "low")
    0.1, 0.6, 0.3,  # Second observation (truth: "medium")
    0.1, 0.2, 0.7,  # Third observation  (truth: "high")
    0.6, 0.3, 0.1), # Fourth observation (truth: "low")
  nrow = 4,
  byrow = TRUE,
  dimnames = list(NULL, c("low", "medium", "high"))
)
# Cumulative Probabilities
Ordered probability metrics typically use cumulative probabilities:
# Per-class probabilities: P(class = 1), P(class = 2), P(class = 3)
probs <- c(0.2, 0.5, 0.3)

# Running totals turn them into P(class <= 1), P(class <= 2), P(class <= 3)
cum_probs <- cumsum(probs) # 0.2, 0.7, 1.0

# This respects the ordering: a prediction that is one class away from the
# truth scores better than one that is two classes away.
Testing
# tests/testthat/test-my_metric.R

# Note: every `ordered()` call below passes `levels` explicitly. Without it,
# levels are sorted alphabetically ("high" < "low" < "medium") and no longer
# line up with the low/medium/high probability columns.

test_that("my_metric works with ordered factors", {
  df <- data.frame(
    truth = ordered(
      c("low", "medium", "high"),
      levels = c("low", "medium", "high")
    ),
    low = c(0.8, 0.1, 0.05),
    medium = c(0.15, 0.8, 0.15),
    high = c(0.05, 0.1, 0.8)
  )
  result <- my_metric(df, truth, low:high)
  expect_equal(result$.metric, "my_metric")
  expect_equal(result$.estimator, "standard")
  expect_true(is.numeric(result$.estimate))
})

test_that("my_metric validates ordered factor", {
  df <- data.frame(
    truth = factor(c("a", "b", "c")), # NOT ordered
    a = c(0.8, 0.1, 0.1),
    b = c(0.1, 0.8, 0.1),
    c = c(0.1, 0.1, 0.8)
  )
  expect_error(my_metric(df, truth, a:c))
})

test_that("my_metric handles case weights", {
  df <- data.frame(
    truth = ordered(
      c("low", "medium", "high"),
      levels = c("low", "medium", "high")
    ),
    low = c(0.8, 0.1, 0.05),
    medium = c(0.15, 0.8, 0.15),
    high = c(0.05, 0.1, 0.8),
    weights = c(1, 2, 1)
  )
  result_weighted <- my_metric(df, truth, low:high, case_weights = weights)
  result_unweighted <- my_metric(df, truth, low:high)
  # Weighted should differ from unweighted. Compare with a numeric tolerance
  # rather than `==` on doubles.
  expect_false(isTRUE(all.equal(
    result_weighted$.estimate,
    result_unweighted$.estimate
  )))
})
# Common Patterns
Helper for Cumulative Rows
# Row-wise cumulative sums of a probability or indicator matrix.
#
# Matches `t(apply(x, 1, cumsum))` for ordinary inputs (up to floating-point
# rounding). It accumulates column by column, which means one vectorised
# addition per class level instead of one R call per row. It always returns
# a matrix with the same dimensions and dimnames as `x`. `t(apply(...))`
# returns the wrong shape when `x` has zero rows (1 x 0) or a single column
# (1 x n).
#
# x: matrix (data frames are coerced with as.matrix(), as apply() would do)
#    with one row per observation and one column per ordered level, in
#    level order.
# Returns a matrix whose column j holds the sum of columns 1..j of `x`.
cumulative_rows <- function(x) {
  x <- as.matrix(x)
  # seq_len(...)[-1L] is empty for 0 or 1 columns, so nothing is accumulated
  for (j in seq_len(ncol(x))[-1L]) {
    x[, j] <- x[, j - 1L] + x[, j]
  }
  x
}
# Converting Ordered Factor to Indicators
# Encode the ordered truth as a 0/1 indicator matrix, one column per level
inds <- hardhat::fct_encode_one_hot(truth)
# Accumulate across levels, just as is done for the predicted probabilities
cum_ind <- cumulative_rows(x = inds)
# Key Differences from Regular Probability Metrics
| Aspect | Regular Probability | Ordered Probability |
|---|---|---|
| Truth type | Factor | Ordered factor |
| Calculation | Uses raw probabilities | Uses cumulative probabilities |
| Class ordering | Not considered | Explicitly used |
| Averaging | Macro/micro/weighted | No averaging (standard only) |
| Use case | Nominal classification | Ordinal classification |
Best Practices
- Always validate the ordered factor: use `check_ordered_prob_metric()` to ensure `truth` is ordered.
- Use cumulative probabilities: they respect the ordering of the classes.
- Handle case weights consistently: convert them to numeric with `vctrs::vec_cast()`.
- No averaging types: ordered metrics work the same regardless of the number of classes.
- Document ordering assumptions: make clear that class order matters, and that the probability columns must be supplied in level order.
See Also
Probability Metrics - Regular (nominal) probability metrics
Metric System - Understanding metric architecture
Testing Patterns - Comprehensive test guide