Testing Patterns for Parsnip Engine Source Development
Testing strategies for contributing engines to the parsnip package (source development).
Overview
When contributing engines to parsnip source, testing focuses on engine-specific behavior rather than model constructor patterns (since the model type already exists).
Key testing areas:
Engine registration and setup
Fit interface compatibility (formula vs matrix vs xy)
All prediction types the engine supports
Engine-specific argument handling
Edge cases and error conditions
Test File Organization
File Naming
Engine tests go in existing model test files:
tests/testthat/test-linear_reg.R # Add glmnet engine tests here
tests/testthat/test-boost_tree.R # Add lightgbm engine tests here
Or create engine-specific files for complex engines:
tests/testthat/test-linear_reg-glmnet.R
tests/testthat/test-boost_tree-lightgbm.R
Required Test Categories
1. Engine Setup Tests
Verify engine can be selected and configured:
test_that("linear_reg can use glmnet engine", {
skip_if_not_installed("glmnet")
spec <- linear_reg() |> set_engine("glmnet")
expect_equal(spec$engine, "glmnet")
expect_s3_class(spec, "linear_reg")
})
test_that("glmnet engine accepts engine-specific arguments", {
skip_if_not_installed("glmnet")
spec <- linear_reg() |>
set_engine("glmnet", nlambda = 50, thresh = 1e-10)
expect_equal(spec$eng_args$nlambda, 50)
expect_equal(spec$eng_args$thresh, 1e-10)
})

2. Fit Interface Tests
Test all interfaces the engine supports:
test_that("glmnet engine fits with formula interface", {
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
fit <- fit(spec, mpg ~ ., data = mtcars)
expect_s3_class(fit, "model_fit")
expect_s3_class(fit$fit, "glmnet")
})
test_that("glmnet engine fits with xy interface", {
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
fit <- fit_xy(spec, x = mtcars[, -1], y = mtcars$mpg)
expect_s3_class(fit, "model_fit")
expect_s3_class(fit$fit, "glmnet")
})
test_that("formula and xy interfaces give equivalent results", {
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
fit_formula <- fit(spec, mpg ~ hp + wt, data = mtcars)
fit_xy <- fit_xy(spec, x = mtcars[, c("hp", "wt")], y = mtcars$mpg)
pred_formula <- predict(fit_formula, mtcars[1:5, ])
pred_xy <- predict(fit_xy, mtcars[1:5, ])
expect_equal(pred_formula, pred_xy, tolerance = 1e-5)
})

3. Prediction Type Tests
Test each prediction type the engine supports:
Regression predictions:
test_that("glmnet numeric predictions", {
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
fit <- fit(spec, mpg ~ ., data = mtcars)
preds <- predict(fit, mtcars[1:5, ])
expect_s3_class(preds, "tbl_df")
expect_named(preds, ".pred")
expect_equal(nrow(preds), 5)
expect_type(preds$.pred, "double")
})
test_that("glmnet raw predictions", {
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
fit <- fit(spec, mpg ~ ., data = mtcars)
preds <- predict(fit, mtcars[1:5, ], type = "raw")
expect_type(preds, "double")
expect_true(is.matrix(preds))
})

Classification predictions:
test_that("glmnet class predictions", {
skip_if_not_installed("glmnet")
spec <- logistic_reg(penalty = 0.1) |> set_engine("glmnet")
data <- data.frame(
y = factor(rep(c("A", "B"), each = 10)),
x1 = rnorm(20),
x2 = rnorm(20)
)
fit <- fit(spec, y ~ ., data = data)
preds <- predict(fit, data[1:5, ])
expect_s3_class(preds, "tbl_df")
expect_named(preds, ".pred_class")
expect_s3_class(preds$.pred_class, "factor")
})
test_that("glmnet probability predictions", {
skip_if_not_installed("glmnet")
spec <- logistic_reg(penalty = 0.1) |> set_engine("glmnet")
data <- data.frame(
y = factor(rep(c("A", "B"), each = 10)),
x1 = rnorm(20),
x2 = rnorm(20)
)
fit <- fit(spec, y ~ ., data = data)
preds <- predict(fit, data[1:5, ], type = "prob")
expect_s3_class(preds, "tbl_df")
expect_true(all(grepl("^\\.pred_", names(preds))))
expect_equal(ncol(preds), 2)
# Check probabilities sum to 1
row_sums <- rowSums(preds)
expect_equal(row_sums, rep(1, 5), tolerance = 1e-10)
})

4. Argument Translation Tests
Test that main arguments are correctly translated:
test_that("glmnet penalty argument translates correctly", {
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
fit <- fit(spec, mpg ~ ., data = mtcars)
# Check lambda was set
expect_true("lambda" %in% names(fit$fit$call))
expect_equal(fit$spec$args$penalty, 0.1)
})
test_that("glmnet mixture argument translates correctly", {
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1, mixture = 0.5) |>
set_engine("glmnet")
fit <- fit(spec, mpg ~ ., data = mtcars)
# Check alpha was set
expect_true("alpha" %in% names(fit$fit$call))
expect_equal(fit$spec$args$mixture, 0.5)
})

5. Edge Case Tests
Test boundary conditions:
test_that("glmnet handles single row prediction", {
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
fit <- fit(spec, mpg ~ ., data = mtcars)
preds <- predict(fit, mtcars[1, ])
expect_equal(nrow(preds), 1)
expect_named(preds, ".pred")
})
test_that("glmnet handles factor predictors", {
skip_if_not_installed("glmnet")
data <- mtcars
data$cyl <- factor(data$cyl)
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
fit <- fit(spec, mpg ~ cyl + hp, data = data)
expect_s3_class(fit, "model_fit")
preds <- predict(fit, data[1:5, ])
expect_equal(nrow(preds), 5)
})
test_that("glmnet handles missing values appropriately", {
skip_if_not_installed("glmnet")
data <- mtcars
data$hp[1:3] <- NA
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
# glmnet should error on missing values
expect_error(
fit(spec, mpg ~ ., data = data)
)
})

6. Multi-Mode Testing
If engine supports multiple modes, test each:
# ------------------------------------------------------------------------------
# Regression mode
test_that("xgboost regression mode", {
skip_if_not_installed("xgboost")
spec <- boost_tree(trees = 20) |>
set_engine("xgboost") |>
set_mode("regression")
fit <- fit(spec, mpg ~ ., data = mtcars)
expect_s3_class(fit, "model_fit")
preds <- predict(fit, mtcars[1:5, ])
expect_named(preds, ".pred")
})
# ------------------------------------------------------------------------------
# Classification mode
test_that("xgboost classification mode", {
skip_if_not_installed("xgboost")
spec <- boost_tree(trees = 20) |>
set_engine("xgboost") |>
set_mode("classification")
data <- data.frame(
y = factor(rep(c("A", "B"), each = 10)),
x1 = rnorm(20),
x2 = rnorm(20)
)
fit <- fit(spec, y ~ ., data = data)
expect_s3_class(fit, "model_fit")
preds_class <- predict(fit, data[1:5, ])
expect_named(preds_class, ".pred_class")
preds_prob <- predict(fit, data[1:5, ], type = "prob")
expect_true(all(grepl("^\\.pred_", names(preds_prob))))
})

Snapshot Testing
Use for engine-specific errors:
test_that("glmnet errors informatively on invalid penalty", {
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = -1) |> set_engine("glmnet")
expect_snapshot(
fit(spec, mpg ~ ., data = mtcars),
error = TRUE
)
})
test_that("glmnet errors on incompatible mode-type", {
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
fit <- fit(spec, mpg ~ ., data = mtcars)
expect_snapshot(
predict(fit, mtcars, type = "prob"),
error = TRUE
)
})

Integration Tests
Test engine works with tidymodels ecosystem:
test_that("glmnet works with workflows", {
skip_if_not_installed("workflows")
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
wf <- workflows::workflow() |>
workflows::add_formula(mpg ~ .) |>
workflows::add_model(spec)
fit <- fit(wf, data = mtcars)
preds <- predict(fit, mtcars[1:5, ])
expect_s3_class(preds, "tbl_df")
expect_equal(nrow(preds), 5)
})
test_that("glmnet works with recipes", {
skip_if_not_installed("workflows")
skip_if_not_installed("recipes")
skip_if_not_installed("glmnet")
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
rec <- recipes::recipe(mpg ~ ., data = mtcars) |>
recipes::step_normalize(recipes::all_numeric_predictors())
wf <- workflows::workflow() |>
workflows::add_recipe(rec) |>
workflows::add_model(spec)
fit <- fit(wf, data = mtcars)
preds <- predict(fit, mtcars[1:5, ])
expect_equal(nrow(preds), 5)
})

Test Organization Pattern
# tests/testthat/test-linear_reg-glmnet.R
# ------------------------------------------------------------------------------
# Setup and configuration
test_that("can set glmnet engine", { ... })
test_that("glmnet accepts engine args", { ... })
# ------------------------------------------------------------------------------
# Fitting - regression mode
test_that("glmnet fits with formula", { ... })
test_that("glmnet fits with xy", { ... })
test_that("formula and xy equivalent", { ... })
# ------------------------------------------------------------------------------
# Predictions - regression mode
test_that("numeric predictions", { ... })
test_that("raw predictions", { ... })
# ------------------------------------------------------------------------------
# Argument handling
test_that("penalty translates", { ... })
test_that("mixture translates", { ... })
# ------------------------------------------------------------------------------
# Edge cases
test_that("single row", { ... })
test_that("factor predictors", { ... })
test_that("missing values", { ... })
# ------------------------------------------------------------------------------
# Integration
test_that("works with workflows", { ... })
test_that("works with recipes", { ... })

Testing Checklist
Before submitting engine PR:
Common Patterns
Pattern 1: Test Both Interfaces
test_that("both interfaces work", {
skip_if_not_installed("pkg")
spec <- my_model() |> set_engine("new_engine")
# Formula
fit1 <- fit(spec, y ~ ., data = data)
pred1 <- predict(fit1, data[1:3, ])
# XY
fit2 <- fit_xy(spec, x = data[, -1], y = data$y)
pred2 <- predict(fit2, data[1:3, ])
expect_equal(pred1, pred2, tolerance = 1e-5)
})

Pattern 2: Skip If Package Missing
test_that("engine test", {
skip_if_not_installed("enginepkg")
skip_if_not_installed("helperpkg")
# Test code
})

Pattern 3: Check Prediction Format
# Numeric
expect_named(preds, ".pred")
expect_type(preds$.pred, "double")
# Class
expect_named(preds, ".pred_class")
expect_s3_class(preds$.pred_class, "factor")
# Probability
expect_true(all(grepl("^\\.pred_", names(preds))))
expect_equal(rowSums(preds), rep(1, nrow(preds)), tolerance = 1e-10)

Debugging Engine Tests
Run Specific Tests
# Run engine-specific file
devtools::test_file("tests/testthat/test-linear_reg-glmnet.R")
# Run specific test
devtools::test_file(
"tests/testthat/test-linear_reg-glmnet.R",
filter = "numeric predictions"
)

Interactive Debugging
# Load package
devtools::load_all()
# Run test interactively
test_that("debug test", {
spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
browser() # Stops here
fit <- fit(spec, mpg ~ ., data = mtcars)
})

Additional Resources
Example test files in parsnip:
tests/testthat/test-linear_reg.R - Multiple engines
tests/testthat/test-boost_tree-xgboost.R - Complex engine
tests/testthat/test-mlp.R - Multi-mode testing
Related guides:
Engine Implementation - Implementation guide
Best Practices (Source) - Code conventions
Troubleshooting (Source) - Common issues