Implementing Fit and Predict Methods
This guide covers how to implement fit and predict methods for parsnip models. This applies to both new model specifications and new engines.
Overview
The fit/predict lifecycle in parsnip:
- User calls fit() with a model spec and formula/data
- parsnip translates arguments and prepares data
- Engine function is called to fit the model
- Result is wrapped in a model_fit object
- User calls predict() with new data and a type
- Engine prediction function is called
- Result is standardized to tibble format with .pred columns
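The lifecycle above can be seen end to end with an engine that ships with parsnip. The following is a minimal sketch, assuming the parsnip package is installed; it uses the built-in linear_reg() specification with the "lm" engine, and the choice of mtcars rows is arbitrary.

```r
library(parsnip)

# Model spec plus engine; nothing is fitted yet
spec <- linear_reg() |> set_engine("lm")

# fit() translates the spec and calls the registered engine function
lm_fit <- fit(spec, mpg ~ hp + wt, data = mtcars)

# predict() returns a tibble with one row per row of new_data
preds <- predict(lm_fit, new_data = mtcars[1:3, ])
```

The underlying lm object stays available as lm_fit$fit, which is what the registered prediction function receives.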
Fit Method Structure
Standard Signature
fit(object, formula, data, control = control_parsnip(), ...)
Arguments:
- object - Model specification
- formula - Model formula
- data - Training data
- control - Control object for verbosity, error handling
- ... - Additional arguments passed to engine
Fit Implementation via set_fit()
Instead of writing fit.model_spec() methods directly, you use set_fit() to register how to fit:
set_fit(
  model = "linear_reg",
  eng = "lm",
  mode = "regression",
  value = list(
    interface = "formula",
    protect = c("formula", "data"),
    func = c(pkg = "stats", fun = "lm"),
    defaults = list()
  )
)
Components:
interface - How data is passed to the engine:
- "formula" - Formula interface: lm(formula, data)
- "matrix" - Matrix interface: glmnet(x, y)
- "data.frame" - Separate x and y (x/y style), with the predictors passed as a data frame: knn(train = x, cl = y)
protect - Arguments that the user should not be able to override
func - The function to call (package and function name)
defaults - Default arguments to pass to the engine
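To see these components for an engine that is already registered, the stored registration can be inspected. This is a sketch that assumes parsnip is installed and that its exported get_fit() helper returns a tibble with engine, mode, and value columns; treat the exact layout as an assumption and check it against your installed version.

```r
library(parsnip)

# Look up the fit registrations for linear_reg (one row per engine/mode)
fit_info <- get_fit("linear_reg")

# Pull out the entry for the "lm" engine and inspect its value list
lm_value <- fit_info$value[fit_info$engine == "lm"][[1]]
lm_value$interface
lm_value$func
```

Reading an existing registration this way is a quick check on what a similar new engine should declare.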
Data Preparation
parsnip handles data conversion based on interface:
Formula interface ("formula"):
# User provides:
fit(spec, mpg ~ hp + wt, data = mtcars)
# Engine receives:
lm(formula = mpg ~ hp + wt, data = mtcars)
Matrix interface ("matrix"):
# User provides:
fit(spec, mpg ~ hp + wt, data = mtcars)
# Engine receives:
glmnet(x = as.matrix(mtcars[, c("hp", "wt")]), y = mtcars$mpg)
parsnip automatically converts the formula to matrices.
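The formula-to-matrix step can be approximated in base R. The sketch below is not parsnip's internal code, which also handles details such as indicator-variable and intercept settings; it only illustrates the kind of x and y objects a matrix-interface engine ends up receiving.

```r
# Build the model frame from the formula (base R only)
mf <- stats::model.frame(mpg ~ hp + wt, data = mtcars)

# Predictor matrix; drop the intercept column that model.matrix() adds
x <- stats::model.matrix(attr(mf, "terms"), data = mf)
x <- x[, colnames(x) != "(Intercept)", drop = FALSE]

# Outcome vector
y <- stats::model.response(mf)
```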
Data frame (x/y) interface ("data.frame"):
# User can provide:
fit_xy(spec, x = predictors, y = outcome)
# Or from formula:
fit(spec, mpg ~ hp + wt, data = mtcars)
# Converted to x/y internally
Calling the Engine Function
The func specification tells parsnip what to call:
func = c(pkg = "stats", fun = "lm")
# Becomes: stats::lm(...)
func = c(pkg = "xgboost", fun = "xgb.train")
# Becomes: xgboost::xgb.train(...)
Wrapping in model_fit
parsnip automatically wraps the result:
# Engine returns lm object
fit_result <- lm(mpg ~ hp, data = mtcars)
# parsnip wraps it (simplified; the real object carries additional elements and classes):
model_fit <- structure(
  list(
    spec = original_spec,
    fit = fit_result,
    preproc = preprocessing_info
  ),
  class = "model_fit"
)
Predict Method Structure
Standard Signature
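predict(object, new_data, type = NULL, opts = list(), ...)
Arguments:
- object - Fitted model (model_fit)
- new_data - Data frame with new observations
- type - Type of prediction (depends on mode; when NULL, parsnip picks a default based on the mode)
- opts - List of optional arguments passed to the underlying predict function when type = "raw"
- ... - Additional arguments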
Prediction Implementation via set_pred()
Register each prediction type separately:
set_pred(
  model = "linear_reg",
  eng = "lm",
  mode = "regression",
  type = "numeric",
  value = list(
    pre = NULL,
    post = NULL,
    func = c(fun = "predict"),
    args = list(
      object = rlang::expr(object$fit),
      newdata = rlang::expr(new_data),
      type = "response"
    )
  )
)
Components:
pre - Function to run before prediction (data preparation)
post - Function to run after prediction (format conversion)
func - The prediction function to call
args - Arguments to pass (using rlang::expr() for delayed evaluation)
Pre-Processing (pre)
Use pre when you need to prepare data before prediction:
pre = function(new_data, object) {
  # Convert factors to integers for this engine
  new_data$category <- as.integer(new_data$category)
  new_data
}
Signature: function(new_data, object)
Returns: Modified new_data
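A pre function is an ordinary R function, so it can be exercised on its own before it is registered. The sketch below uses a made-up data frame; the category and value column names are hypothetical, and NULL stands in for the fitted model because this particular function never touches it.

```r
# Hypothetical pre function; `category` is an illustrative column name
pre <- function(new_data, object) {
  new_data$category <- as.integer(new_data$category)
  new_data
}

new_data <- data.frame(
  category = factor(c("a", "b", "a")),
  value = c(1.5, 2.5, 3.5)
)

# `object` is unused here, so NULL stands in for the fitted model
prepared <- pre(new_data, object = NULL)
prepared$category
```

The integer codes follow the factor's level order, so new data needs the same levels as the training data for the codes to line up.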
Post-Processing (post)
Use post to convert engine output to standard format:
post = function(results, object) {
  # Engine returns matrix with columns "lower", "upper"
  # Convert to tibble with standard names
  tibble::tibble(
    .pred_lower = results[, "lower"],
    .pred_upper = results[, "upper"]
  )
}
Signature: function(results, object)
Returns: Tibble with standardized column names
Prediction Function Arguments
Use rlang::expr() to delay evaluation:
args = list(
  object = rlang::expr(object$fit),  # The fitted model
  newdata = rlang::expr(new_data),   # New data
  type = "response"                  # Static argument
)
Why rlang::expr()?
- Prevents evaluation until prediction time
- Allows access to the actual fitted object
- Enables proper scoping
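The effect of delaying evaluation can be reproduced outside parsnip. The following sketch assumes only the rlang package; parsnip's own call construction is internal and may differ, so read this as an illustration of the idea rather than of the actual implementation.

```r
library(rlang)

# Arguments as they would appear in a set_pred() registration
args <- list(
  object = expr(object$fit),
  newdata = expr(new_data),
  type = "response"
)

# Build the unevaluated call:
# predict(object = object$fit, newdata = new_data, type = "response")
pred_call <- call2("predict", !!!args)

# `object` and `new_data` are only defined now, at "prediction time"
object <- list(fit = lm(mpg ~ hp, data = mtcars))
new_data <- mtcars[1:3, ]
result <- eval(pred_call)
```

Had the arguments been written without expr(), R would have tried to look up object and new_data while building the list, before either existed.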
Multiple Prediction Types
Register each type separately:
# Numeric predictions
set_pred(..., type = "numeric", ...)
# Confidence intervals
set_pred(..., type = "conf_int", ...)
# Raw predictions
set_pred(..., type = "raw", ...)
Each can have different pre, post, and args.
Common Implementation Patterns
Pattern 1: Formula-Based Fit
Example: lm engine for linear_reg
set_fit(
  model = "linear_reg",
  eng = "lm",
  mode = "regression",
  value = list(
    interface = "formula",
    protect = c("formula", "data"),
    func = c(pkg = "stats", fun = "lm"),
    defaults = list()
  )
)
Simple and direct - parsnip handles everything.
Pattern 2: Matrix-Based Fit
Example: glmnet engine
set_fit(
  model = "linear_reg",
  eng = "glmnet",
  mode = "regression",
  value = list(
    interface = "matrix",
    protect = c("x", "y", "weights"),
    func = c(pkg = "glmnet", fun = "glmnet"),
    defaults = list(family = "gaussian")
  )
)
parsnip converts the formula to matrices automatically.
Pattern 3: Simple Prediction
Example: Numeric prediction with lm
set_pred(
  model = "linear_reg",
  eng = "lm",
  mode = "regression",
  type = "numeric",
  value = list(
    pre = NULL,
    post = NULL,
    func = c(fun = "predict"),
    args = list(
      object = rlang::expr(object$fit),
      newdata = rlang::expr(new_data),
      type = "response"
    )
  )
)
Direct call to the predict() method, no transformation needed.
Pattern 4: Prediction with Post-Processing
Example: Confidence intervals
set_pred(
  model = "linear_reg",
  eng = "lm",
  mode = "regression",
  type = "conf_int",
  value = list(
    pre = NULL,
    post = function(results, object) {
      tibble::tibble(
        .pred_lower = results[, "lwr"],
        .pred_upper = results[, "upr"]
      )
    },
    func = c(fun = "predict"),
    args = list(
      object = rlang::expr(object$fit),
      newdata = rlang::expr(new_data),
      interval = "confidence",
      level = 0.95
    )
  )
)
The engine returns a matrix, and post converts it to the standard format.
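To check a post function before registering it, it helps to run it against real engine output. The sketch below calls stats::lm() and predict() directly, outside parsnip, and assumes the tibble package is installed; NULL again stands in for the model_fit object because this post function does not use it.

```r
lm_fit <- lm(mpg ~ hp + wt, data = mtcars)

# predict.lm() with interval = "confidence" returns a matrix
# with columns "fit", "lwr", and "upr"
results <- predict(
  lm_fit,
  newdata = mtcars[1:3, ],
  interval = "confidence",
  level = 0.95
)

# Post function with the same body as in the conf_int registration
post <- function(results, object) {
  tibble::tibble(
    .pred_lower = results[, "lwr"],
    .pred_upper = results[, "upr"]
  )
}

conf_int <- post(results, object = NULL)
```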
Pattern 5: Multi-Mode Model
Example: Model supporting both regression and classification
# Regression fit
set_fit(
  model = "my_model",
  eng = "custom",
  mode = "regression",
  value = list(...)
)
# Classification fit
set_fit(
  model = "my_model",
  eng = "custom",
  mode = "classification",
  value = list(...)
)
# Regression prediction
set_pred(
  model = "my_model",
  eng = "custom",
  mode = "regression",
  type = "numeric",
  value = list(...)
)
# Classification predictions
set_pred(
  model = "my_model",
  eng = "custom",
  mode = "classification",
  type = "class",
  value = list(...)
)
set_pred(
  model = "my_model",
  eng = "custom",
  mode = "classification",
  type = "prob",
  value = list(...)
)
Register fit and predictions separately for each mode.
Column Naming Conventions
parsnip has strict conventions for prediction column names:
Numeric predictions:
tibble::tibble(.pred = c(1.2, 3.4, 5.6))
Class predictions:
tibble::tibble(.pred_class = factor(c("A", "B", "A")))
Probability predictions:
tibble::tibble(
  .pred_A = c(0.8, 0.2, 0.7),
  .pred_B = c(0.2, 0.8, 0.3)
)
Confidence intervals:
tibble::tibble(
  .pred_lower = c(1.0, 3.0, 5.0),
  .pred_upper = c(1.5, 4.0, 6.0)
)
Always use these exact names in post functions.
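For probability predictions, the column names depend on the class levels, so a post function usually builds them rather than hard-coding them. A minimal sketch, assuming a hypothetical engine that returns a probability matrix with one column per class level, and assuming the tibble package is installed:

```r
# Hypothetical engine output: one column of probabilities per class level
prob_mat <- matrix(
  c(0.8, 0.2, 0.7,
    0.2, 0.8, 0.3),
  ncol = 2,
  dimnames = list(NULL, c("A", "B"))
)

# Prefix each class level with ".pred_" to get the standard names
colnames(prob_mat) <- paste0(".pred_", colnames(prob_mat))
probs <- tibble::as_tibble(prob_mat)
```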
Error Handling
Missing Required Packages
Declare package dependencies with set_dependency() so parsnip can check for them before fitting:
set_dependency(
  model = "boost_tree",
  eng = "xgboost",
  pkg = "xgboost",
  mode = "regression"
)
parsnip will error with a helpful message if the package is not installed.
Incompatible Mode/Type Combinations
Only register prediction types that make sense:
# DON'T register "prob" for regression models
# DON'T register "numeric" for classification models
parsnip will error if the user requests an unavailable type.
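This behavior can be observed with a built-in engine. The sketch below assumes parsnip is installed and requests class probabilities from a regression fit; the exact wording of the error message depends on the parsnip version, so the example only checks that an error is signaled.

```r
library(parsnip)

lm_fit <- linear_reg() |>
  set_engine("lm") |>
  fit(mpg ~ hp, data = mtcars)

# Requesting class probabilities from a regression fit signals an error;
# tryCatch() captures the condition so it can be inspected
res <- tryCatch(
  predict(lm_fit, new_data = mtcars[1:3, ], type = "prob"),
  error = function(e) e
)
inherits(res, "error")
```

A test along these lines for your own engine confirms that unsupported types fail cleanly instead of returning malformed output.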
Invalid New Data
Engine functions handle this, but you can add validation in pre:
pre = function(new_data, object) {
  if (nrow(new_data) == 0) {
    rlang::abort("new_data must have at least one row")
  }
  new_data
}
Testing Fit and Predict
Essential tests for any fit/predict implementation:
Fit tests:
test_that("model fits with formula", {
  spec <- my_model() |> set_engine("custom")
  fit <- fit(spec, mpg ~ ., data = mtcars)
  expect_s3_class(fit, "model_fit")
})
test_that("model fits with xy", {
  spec <- my_model() |> set_engine("custom")
  fit <- fit_xy(spec, x = mtcars[, -1], y = mtcars$mpg)
  expect_s3_class(fit, "model_fit")
})
Predict tests:
test_that("numeric predictions work", {
  spec <- my_model() |> set_engine("custom")
  fit <- fit(spec, mpg ~ ., data = mtcars)
  preds <- predict(fit, new_data = mtcars[1:5, ])
  expect_s3_class(preds, "tbl_df")
  expect_named(preds, ".pred")
  expect_equal(nrow(preds), 5)
})
test_that("predictions match new_data rows", {
  spec <- my_model() |> set_engine("custom")
  fit <- fit(spec, mpg ~ ., data = mtcars)
  preds <- predict(fit, new_data = mtcars[1:10, ])
  expect_equal(nrow(preds), 10)
})
Interface tests:
test_that("formula and xy produce same results", {
  spec <- my_model() |> set_engine("custom")
  fit_formula <- fit(spec, mpg ~ hp + wt, data = mtcars)
  fit_xy <- fit_xy(spec, x = mtcars[, c("hp", "wt")], y = mtcars$mpg)
  pred_formula <- predict(fit_formula, mtcars[1:5, ])
  pred_xy <- predict(fit_xy, mtcars[1:5, ])
  expect_equal(pred_formula, pred_xy, tolerance = 1e-5)
})
Extension vs Source Development
Extension Development
When creating engines in your own package:
Always use parsnip:: prefix:
parsnip::set_fit(...)
parsnip::set_pred(...)
You can only use exported functions:
func = c(pkg = "stats", fun = "lm")  # ✅ Exported
func = c(fun = "lm")                 # ✅ Also works
Source Development
When contributing to parsnip itself:
No prefix needed:
set_fit(...)
set_pred(...)
Can reference internal functions:
func = c(pkg = "parsnip", fun = "xgb_train")  # Internal helper
Follow parsnip file organization:
- Fit/predict code in R/[model]_data.R
- Group all engines for a model together
Summary
Implementing fit and predict:
- Use set_fit() to register fitting function and interface
- Use set_pred() for each prediction type
- Use pre for data preparation
- Use post for result formatting
- Use rlang::expr() for argument evaluation
- Follow column naming conventions strictly
- Register each mode separately if multi-mode
- Test thoroughly with different data types and interfaces
The registration system handles most complexity - you just specify what to call and how to format results.