Best Practices for Parsnip Source Development
Guidelines and best practices for contributing to the parsnip package source code.
Overview
When contributing to parsnip itself (not creating extensions), follow these practices to maintain code quality and consistency with the existing codebase.
Code Organization
File Structure
Model constructors: R/[model_type].R
R/linear_reg.R # linear_reg() constructor
R/boost_tree.R # boost_tree() constructor
R/rand_forest.R # rand_forest() constructor
Engine registrations: R/[model]_data.R
R/linear_reg_data.R # All linear_reg engines
R/boost_tree_data.R # All boost_tree engines
Infrastructure: Core system files
R/aaa_models.R # Model environment setup
R/misc.R # Helper functions
R/fit.R # Fit methods
R/predict.R # Predict methods
Registration Patterns
Complete Registration Sequence
For each engine, register in this order:
# 1. Declare engine exists
set_model_engine(
model = "linear_reg",
mode = "regression",
eng = "glmnet"
)
# 2. Declare package dependencies
set_dependency(
model = "linear_reg",
eng = "glmnet",
pkg = "glmnet",
mode = "regression"
)
# 3. Translate main arguments
set_model_arg(
model = "linear_reg",
eng = "glmnet",
parsnip = "penalty",
original = "lambda",
func = list(pkg = "dials", fun = "penalty"),
has_submodel = TRUE
)
# 4. Register fit method
set_fit(
model = "linear_reg",
eng = "glmnet",
mode = "regression",
value = list(...)
)
# 5. Register encoding (if needed)
set_encoding(
model = "linear_reg",
eng = "glmnet",
mode = "regression",
options = list(...)
)
# 6. Register each prediction type
set_pred(
model = "linear_reg",
eng = "glmnet",
mode = "regression",
type = "numeric",
value = list(...)
)
set_pred(
model = "linear_reg",
eng = "glmnet",
mode = "regression",
type = "conf_int",
value = list(...)
)
Use Consistent Naming
Main arguments: Follow established parsnip conventions
# ✓ Good - consistent with other models
penalty # not lambda, not reg_param
mixture # not alpha, not l1_ratio
trees # not n_estimators, not num_boost_round
Engine names: Use package or algorithm name
# ✓ Good
"lm" # From stats package
"glmnet" # From glmnet package
"xgboost" # From xgboost package
# ✗ Avoid
"linear_model"
"elastic_net"
"boosted_trees"
Using Internal Functions
When to Use Internal Functions
Source development can use internal parsnip functions:
# ✓ Allowed in parsnip source
func = c(pkg = "parsnip", fun = "xgb_train") # Internal helper
# Helper functions for complex conversions
pre = function(new_data, object) {
parsnip:::prepare_data_for_engine(new_data, object)
}
Common Internal Helpers
Data conversion:
parsnip:::convert_data_to_matrix()
parsnip:::prepare_survival_data()
Prediction post-processing:
parsnip:::format_class_predictions()
parsnip:::format_prob_matrix()
Validation:
parsnip:::check_outcome_type()
parsnip:::validate_prediction_type()
Document Internal Function Usage
When using internal functions, add comments explaining why:
set_pred(
...,
value = list(
# Using internal helper for complex survival curve extraction
post = function(results, object) {
parsnip:::extract_surv_curves(results, object)
},
...
)
)
Error Handling
Use Informative Error Messages
# ✓ Good - explains the problem and solution
# Post-processing hook: checks the class of the engine output before
# formatting it, and fails with an actionable, bulleted message otherwise.
post = function(results, object) {
  if (!inherits(results, "expected_class")) {
    # rlang::abort() expects the whole message as ONE character vector; the
    # named "i" elements are rendered as info bullets under the first line.
    # Passing them as separate named arguments would not produce bullets.
    rlang::abort(
      c(
        "Expected output from engine to be class 'expected_class'",
        "i" = "Check that the engine is returning the correct format",
        "i" = "Consider updating the engine package"
      )
    )
  }
  format_results(results)
}
# ✗ Bad - generic error
post = function(results, object) {
if (!inherits(results, "expected_class")) {
stop("Wrong type")
}
format_results(results)
}
Validate at Registration Time
Check for common issues early:
set_fit(
...,
value = list(
interface = "matrix",
protect = c("x", "y"),
func = c(pkg = "glmnet", fun = "glmnet"),
defaults = list(family = "gaussian")
)
)
# Validate that function exists
if (!requireNamespace("glmnet", quietly = TRUE)) {
rlang::warn("glmnet package not available for testing")
}
Testing
Test Files Organization
Model-specific tests: tests/testthat/test-[model].R
tests/testthat/test-boost_tree.R
tests/testthat/test-linear_reg.R
Engine-specific tests: Within model test file
# In test-linear_reg.R
test_that("lm engine works", { ... })
test_that("glmnet engine works", { ... })
test_that("keras engine works", { ... })
Essential Tests for Each Engine
# Smoke test for the glmnet engine: fit a penalized linear regression on
# mtcars, then check the class, column name and row count of the predictions.
test_that("glmnet engine fits and predicts", {
  skip_if_not_installed("glmnet")

  # Fit
  engine_spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
  regression_spec <- engine_spec |> set_mode("regression")
  glmnet_fit <- fit(regression_spec, mpg ~ ., data = mtcars)
  expect_s3_class(glmnet_fit, "model_fit")

  # Predict
  first_rows <- mtcars[1:5, ]
  predictions <- predict(glmnet_fit, first_rows)
  expect_s3_class(predictions, "tbl_df")
  expect_named(predictions, ".pred")
  expect_equal(nrow(predictions), 5)
})
# Checks that the glmnet engine can fit and predict through the formula
# interface when one of the predictors is a factor.
test_that("glmnet engine handles factors", {
  skip_if_not_installed("glmnet")

  data <- data.frame(
    y = 1:10,
    x1 = 1:10,
    x2 = factor(rep(c("A", "B"), 5))
  )

  # parsnip requires a single `penalty` value for the glmnet engine; without
  # it fit() aborts before the factor handling is ever exercised.
  spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
  fit <- fit(spec, y ~ x1 + x2, data = data)

  # Predictions should work
  preds <- predict(fit, data[1:3, ])
  expect_equal(nrow(preds), 3)
})
# Checks that more than one prediction type can be requested from the same
# glmnet fit.
test_that("glmnet engine supports multiple prediction types", {
  skip_if_not_installed("glmnet")

  # parsnip requires a single `penalty` value for the glmnet engine; without
  # it fit() aborts and no prediction type is tested at all.
  spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")
  fit <- fit(spec, mpg ~ ., data = mtcars)

  # Numeric
  pred_num <- predict(fit, mtcars[1:5, ], type = "numeric")
  expect_named(pred_num, ".pred")

  # Raw
  pred_raw <- predict(fit, mtcars[1:5, ], type = "raw")
  expect_true(!is.null(pred_raw))
})
Test Both Formula and XY Interfaces
# Fits the same specification through fit() and fit_xy() and checks that both
# routes give the same predictions.
test_that("formula and xy interfaces produce same results", {
  skip_if_not_installed("glmnet")

  # parsnip requires a single `penalty` value for the glmnet engine; without
  # it both fits abort before anything can be compared.
  spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")

  # Formula
  fit_formula <- fit(spec, mpg ~ hp + wt, data = mtcars)
  pred_formula <- predict(fit_formula, mtcars[1:5, ])

  # XY (result named so that it does not shadow the fit_xy() generic)
  fit_from_xy <- fit_xy(spec, x = mtcars[, c("hp", "wt")], y = mtcars$mpg)
  pred_xy <- predict(fit_from_xy, mtcars[1:5, ])

  # Should match
  expect_equal(pred_formula, pred_xy, tolerance = 1e-5)
})
Test Error Conditions
# Checks that invalid inputs are rejected with a relevant error message.
test_that("glmnet engine errors appropriately", {
  skip_if_not_installed("glmnet")

  # Supply `penalty` so the spec itself is valid: otherwise fit() can abort on
  # the missing penalty instead of on the condition each case is testing.
  spec <- linear_reg(penalty = 0.1) |> set_engine("glmnet")

  # Wrong outcome type (factor outcome for a regression model)
  expect_error(
    fit(spec, Species ~ ., data = iris),
    "factor"
  )

  # Missing data (if engine doesn't handle it)
  # NOTE(review): the "missing" pattern depends on the engine's own error
  # text - confirm against the installed glmnet version.
  data_na <- mtcars
  data_na$mpg[1] <- NA
  expect_error(
    fit(spec, mpg ~ ., data = data_na),
    "missing"
  )
})
Documentation
Model Constructor Documentation
Follow roxygen2 conventions:
#' Linear Regression
#'
#' `linear_reg()` defines a model that can predict a numeric outcome from
#' one or more predictors.
#'
#' @param mode A single character string for the model type. The only possible
#' value for this model is "regression".
#' @param penalty A non-negative number for the amount of regularization
#' (glmnet, keras engines only). Used by glmnet as `lambda` and by keras
#' as the L2 penalty.
#' @param mixture A number between 0 and 1 for the proportion of L1
#' regularization. Used by glmnet and keras engines.
#' @param engine A character string for the software to fit the model.
#' Default is "lm".
#'
#' @details
#' The available engines are:
#' - `"lm"` (default) - Uses [stats::lm()]
#' - `"glmnet"` - Uses [glmnet::glmnet()]
#' - `"keras"` - Uses keras neural network
#'
#' @seealso [fit.model_spec()], [set_engine()]
#'
#' @examples
#' # Basic linear regression
#' linear_reg() |>
#' set_engine("lm") |>
#' fit(mpg ~ ., data = mtcars)
#'
#' # Regularized regression
#' linear_reg(penalty = 0.1, mixture = 0.5) |>
#' set_engine("glmnet") |>
#' fit(mpg ~ ., data = mtcars)
#'
#' @export
linear_reg <- function(mode = "regression",
penalty = NULL,
mixture = NULL,
engine = "lm") {
# Implementation
}
Document Engine Requirements
Explain what each engine needs:
#' @details
#' ## Engine: glmnet
#'
#' Requires the glmnet package. This engine uses elastic net regularization.
#'
#' **Main arguments:**
#' - `penalty` → `lambda` - Amount of regularization
#' - `mixture` → `alpha` - Mix of L1 and L2 (0 = ridge, 1 = lasso)
#'
#' **Engine-specific arguments:**
#' - `nlambda` - Number of lambda values (default: 100)
#' - `standardize` - Standardize predictors (default: TRUE)
#' - Pass to `set_engine("glmnet", nlambda = 50)`
#'
#' **Prediction types:**
#' - `numeric` - Point predictions
#' - `raw` - Raw glmnet object
Add Examples for Each Engine
#' @examples
#' # lm engine (default)
#' linear_reg() |>
#' fit(mpg ~ ., data = mtcars)
#'
#' # glmnet engine with regularization
#' linear_reg(penalty = 0.1) |>
#' set_engine("glmnet") |>
#' fit(mpg ~ ., data = mtcars)
#'
#' # keras engine with custom architecture
#' linear_reg() |>
#' set_engine("keras", epochs = 100) |>
#' fit(mpg ~ ., data = mtcars)
Argument Translation
Follow Tidymodels Naming
When translating main arguments to engine arguments:
# ✓ Good - clear translation
set_model_arg(
model = "boost_tree",
eng = "xgboost",
parsnip = "trees", # Tidymodels standard
original = "nrounds", # xgboost name
func = list(pkg = "dials", fun = "trees"),
has_submodel = TRUE
)
# ✗ Avoid engine-specific names in main arguments
set_model_arg(
model = "boost_tree",
eng = "xgboost",
parsnip = "nrounds", # Too xgboost-specific
original = "nrounds",
...
)
Document Argument Mappings
# In R/boost_tree_data.R
# xgboost engine
# Argument translations:
# - trees → nrounds
# - tree_depth → max_depth
# - learn_rate → eta
# - loss_reduction → gamma
set_model_arg(...)
Compatibility Considerations
Package Version Requirements
Document minimum versions when needed:
set_dependency(
model = "linear_reg",
eng = "glmnet",
pkg = "glmnet",
mode = "regression"
)
# If specific version needed, add to DESCRIPTION
# Imports: glmnet (>= 4.0)
Handle Package Changes
Add version checks for breaking changes:
set_fit(
...,
value = list(
pre = function(data, object) {
# Handle glmnet version differences
if (packageVersion("glmnet") >= "4.0") {
# New behavior
} else {
# Old behavior
}
data
},
...
)
)
Multi-Mode Implementation
Register Each Mode Separately
# Register both modes
set_model_mode(model = "boost_tree", mode = "regression")
set_model_mode(model = "boost_tree", mode = "classification")
# Fit for regression
set_fit(
model = "boost_tree",
eng = "xgboost",
mode = "regression",
value = list(
defaults = list(objective = "reg:squarederror")
)
)
# Fit for classification
set_fit(
model = "boost_tree",
eng = "xgboost",
mode = "classification",
value = list(
defaults = list(objective = "multi:softprob")
)
)
Performance Considerations
Lazy Evaluation
Use rlang::expr() to delay evaluation:
# ✓ Good - delays evaluation
args = list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data)
)
# ✗ Bad - evaluates immediately
args = list(
object = object$fit, # object doesn't exist yet!
newdata = new_data # new_data doesn't exist yet!
)
Avoid Unnecessary Conversions
# ✓ Good - only convert if needed
post = function(results, object) {
if (is.matrix(results)) {
tibble::as_tibble(results)
} else {
tibble::tibble(.pred = results)
}
}
# ✗ Bad - always converts (unnecessary for vectors)
post = function(results, object) {
results <- as.matrix(results) # Wasteful if already correct type
tibble::as_tibble(results)
}
Summary
Key practices:
- Follow file organization - Constructors in R/[model].R, registrations in R/[model]_data.R
- Complete registration sequence - Engine, dependency, args, fit, encoding, predictions
- Use consistent naming - Follow tidymodels conventions for main arguments
- Can use internal functions - Source development has access to ::: functions
- Write comprehensive tests - Test each engine, both interfaces, error conditions
- Document thoroughly - Model constructor, engine details, argument translations
- Handle multi-mode carefully - Register each mode separately, share common code
- Consider performance - Use lazy evaluation, avoid unnecessary conversions
Before submitting:
Run devtools::check() to verify package integrity
Ensure all tests pass
Update NEWS.md with changes
Follow tidymodels code style
Add examples to documentation