Model Constructor Design
This guide covers how to design and implement model constructor functions for parsnip (like linear_reg(), boost_tree()).
Overview
A model constructor is the user-facing function that creates a model specification object. It defines the API users will interact with.
Key responsibilities:
Accept main arguments
Set default engine and mode
Create model_spec object
Validate inputs
Constructor Function Signature
Basic Template
my_model <- function(mode = "regression",
penalty = NULL,
mixture = NULL,
engine = "default_engine") {
args <- list(
penalty = rlang::enquo(penalty),
mixture = rlang::enquo(mixture)
)
new_model_spec(
"my_model",
args = args,
eng_args = NULL,
mode = mode,
user_specified_mode = !missing(mode),
method = NULL,
engine = engine,
user_specified_engine = !missing(engine)
)
}Required Elements
Function name:
Descriptive of algorithm family:
linear_reg(),rand_forest(),boost_tree()Use snake_case
Avoid package-specific names (not
glmnet_model()orxgboost_model())
Parameters:
mode- Prediction mode (regression, classification, etc.)Main arguments - Tuning parameters
engine- Computational backend (with sensible default)
Mode Parameter
Single-Mode Models
For models that only support one mode, set it automatically:
linear_reg <- function(penalty = NULL,
mixture = NULL,
engine = "lm") {
args <- list(
penalty = rlang::enquo(penalty),
mixture = rlang::enquo(mixture)
)
new_model_spec(
"linear_reg",
args = args,
eng_args = NULL,
mode = "regression", # Fixed mode
user_specified_mode = FALSE,
method = NULL,
engine = engine,
user_specified_engine = !missing(engine)
)
}No mode parameter - Users can’t change it.
Multi-Mode Models
For models supporting multiple modes, include mode parameter:
boost_tree <- function(mode = "unknown",
trees = NULL,
tree_depth = NULL,
learn_rate = NULL,
engine = "xgboost") {
args <- list(
trees = rlang::enquo(trees),
tree_depth = rlang::enquo(tree_depth),
learn_rate = rlang::enquo(learn_rate)
)
new_model_spec(
"boost_tree",
args = args,
eng_args = NULL,
mode = mode, # User must set
user_specified_mode = !missing(mode),
method = NULL,
engine = engine,
user_specified_engine = !missing(engine)
)
}Use "unknown" as default - Forces users to be explicit.
Main Arguments
What Makes a Main Argument?
Include as main argument when:
Common across multiple engines
Important tuning parameter
Worth standardizing for tune package
Users will frequently adjust
Keep as engine argument when:
Engine-specific behavior
Rarely adjusted
No equivalent in other engines
Naming Main Arguments
Follow tidymodels conventions:
Good names (standardized):
penalty- Regularization amount (notlambdaorreg_param)mixture- L1/L2 mix (notalphaorl1_ratio)trees- Number of trees (notn_estimatorsornum_boost_round)tree_depth- Max tree depth (notmax_depth)learn_rate- Learning rate (notetaorshrinkage)mtry- Predictors per split (notmax_features)
Avoid engine-specific names:
Don’t use
lambda- UsepenaltyinsteadDon’t use
alpha- UsemixtureinsteadDon’t use
nrounds- Usetreesinstead
Using rlang::enquo()
Capture arguments with rlang::enquo() to support tidy evaluation:
args <- list(
penalty = rlang::enquo(penalty),
trees = rlang::enquo(trees),
mtry = rlang::enquo(mtry)
)Why?
Enables
tune()placeholdersSupports delayed evaluation
Required by parsnip infrastructure
Setting Defaults
NULL defaults are recommended:
linear_reg <- function(penalty = NULL, mixture = NULL, ...)Why NULL?
Lets engine use its own defaults
Clear when user specified vs default
More flexible across engines
When to use specific defaults:
Only if all engines need the same value
When NULL would be ambiguous
Engine Parameter
Choosing Default Engine
Pick a default engine based on:
Availability (base R > CRAN > GitHub)
Stability (mature packages preferred)
Performance (for typical use cases)
Simplicity (fewer dependencies better)
Examples:
linear_reg(engine = "lm") # base R, always available
boost_tree(engine = "xgboost") # Popular, well-maintained
rand_forest(engine = "ranger") # Fast, reliableTrack User Specification
Always track if user specified engine:
new_model_spec(
...,
engine = engine,
user_specified_engine = !missing(engine)
)Used for:
Error messages
Default behavior
Package loading
Using new_model_spec()
Core Constructor Helper
new_model_spec() creates the S3 object:
new_model_spec(
cls = "my_model", # Model class name
args = args, # Main arguments (enquo'd)
eng_args = NULL, # Engine arguments (initially NULL)
mode = mode, # Prediction mode
user_specified_mode = !missing(mode),
method = NULL, # Fitting method (initially NULL)
engine = engine, # Computational engine
user_specified_engine = !missing(engine)
)Fields Explained
cls - Character string for model type:
Must match function name
Used for class hierarchy
Example:
"linear_reg"forlinear_reg()function
args - List of main arguments:
Each element is output of
rlang::enquo()Names match parameter names
Can be NULL or tune() placeholders
eng_args - Engine-specific arguments:
Initially NULL
Populated by
set_engine()User provides via
set_engine(..., arg = value)
mode - Prediction mode:
“regression”, “classification”, “censored regression”, “quantile regression”
Or “unknown” for multi-mode models without default
user_specified_mode - Boolean:
TRUE if user provided mode parameter
Use
!missing(mode)to detect
method - Fitting method:
Initially NULL
Set internally by parsnip
Don’t modify in constructor
engine - Computational backend:
String naming the engine (e.g., “lm”, “glmnet”, “xgboost”)
Should have sensible default
user_specified_engine - Boolean:
TRUE if user provided engine parameter
Use
!missing(engine)to detect
Return Value
What Gets Returned
The constructor returns a model_spec object:
spec <- linear_reg()
class(spec)
#> [1] "linear_reg" "model_spec"Class Hierarchy
Classes are automatically created:
Primary class: Model name (
"linear_reg")Parent class:
"model_spec"
Created by make_classes(cls) internally.
Object Structure
spec <- linear_reg(penalty = 0.1)
str(spec)
#> List of 7
#> $ args : List of 2
#> ..$ penalty: quosure
#> ..$ mixture: quosure
#> $ eng_args : NULL
#> $ mode : chr "regression"
#> $ user_specified_mode: logi FALSE
#> $ method : NULL
#> $ engine : chr "lm"
#> $ user_specified_engine: logi FALSE
#> - attr(*, "class")= chr [1:2] "linear_reg" "model_spec"Input Validation
What to Validate
In constructor, validate:
Mode is valid (if provided)
Numeric arguments are numeric
Logical arguments are logical
Factors/characters when expected
Don’t validate in constructor:
Engine availability (checked at fit time)
Argument compatibility (checked at fit time)
Data compatibility (checked at fit time)
Validation Examples
Check mode validity:
my_model <- function(mode = "unknown", ...) {
if (!mode %in% c("unknown", "regression", "classification")) {
rlang::abort(
"mode should be 'regression' or 'classification'"
)
}
# Continue with new_model_spec()...
}Check argument types:
linear_reg <- function(penalty = NULL, ...) {
if (!is.null(penalty) && !is.numeric(penalty)) {
rlang::abort("penalty must be numeric")
}
# Continue with new_model_spec()...
}For tidymodels style, check at registration time instead: Most validation happens during registration and fitting, not in constructor.
Documentation
Roxygen2 Template
#' Linear Regression
#'
#' `linear_reg()` defines a model that uses linear regression to predict
#' a numeric outcome from one or more predictors.
#'
#' @param mode A single character string for the type of model. The only
#' possible value for this model is "regression".
#' @param penalty A non-negative number for the amount of regularization.
#' Used by glmnet and keras engines.
#' @param mixture A number between 0 and 1 for the proportion of L1
#' regularization. Used by glmnet and keras engines.
#' @param engine A character string specifying the computational engine.
#' Possible values are "lm" (default), "glmnet", and "keras".
#'
#' @details
#' ## Engines
#'
#' This model can be fit using different computational engines. The
#' available engines are:
#'
#' - **lm** (default): Uses [stats::lm()]. No regularization.
#' - **glmnet**: Uses [glmnet::glmnet()]. Supports regularization via
#' penalty and mixture.
#' - **keras**: Uses keras neural network. Supports regularization.
#'
#' ## Main arguments
#'
#' The main arguments for this model are:
#'
#' - `penalty`: Amount of regularization (lambda in glmnet)
#' - `mixture`: Mix of L1 (lasso) and L2 (ridge) regularization
#'
#' These arguments are only used by engines that support regularization.
#'
#' @return A `linear_reg` model specification.
#'
#' @seealso [set_engine()], [fit.model_spec()]
#'
#' @examples
#' # Basic linear regression
#' linear_reg() |>
#' fit(mpg ~ ., data = mtcars)
#'
#' # With regularization
#' linear_reg(penalty = 0.1, mixture = 0.5) |>
#' set_engine("glmnet") |>
#' fit(mpg ~ ., data = mtcars)
#'
#' @export
linear_reg <- function(mode = "regression",
penalty = NULL,
mixture = NULL,
engine = "lm") {
# Implementation
}Key Documentation Sections
Description: What the model does
Parameters: Each argument with type and purpose
Details: Engine-specific information
Examples: Show both simple and advanced usage
Seealso: Link to related functions
Extension vs Source Patterns
Extension Package Constructor
In your own package, export the constructor:
#' @export
my_model <- function(mode = "regression", penalty = NULL, ...) {
args <- list(penalty = rlang::enquo(penalty))
parsnip::new_model_spec( # Use namespace
"my_model",
args = args,
eng_args = NULL,
mode = mode,
user_specified_mode = !missing(mode),
method = NULL,
engine = "default_engine",
user_specified_engine = !missing(engine)
)
}Register in .onLoad():
.onLoad <- function(libname, pkgname) {
parsnip::set_new_model("my_model")
parsnip::set_model_mode("my_model", "regression")
# ... more registration
}Source Package Constructor
Contributing to parsnip directly:
# R/my_model.R
#' @export
my_model <- function(mode = "regression", penalty = NULL, ...) {
args <- list(penalty = rlang::enquo(penalty))
new_model_spec( # No namespace needed
"my_model",
args = args,
eng_args = NULL,
mode = mode,
user_specified_mode = !missing(mode),
method = NULL,
engine = "default_engine",
user_specified_engine = !missing(engine)
)
}
# R/my_model_data.R contains registrationCommon Patterns
Pattern 1: Single-Mode, Simple Arguments
linear_reg <- function(penalty = NULL,
mixture = NULL,
engine = "lm") {
args <- list(
penalty = rlang::enquo(penalty),
mixture = rlang::enquo(mixture)
)
new_model_spec(
"linear_reg",
args = args,
eng_args = NULL,
mode = "regression", # Fixed
user_specified_mode = FALSE,
method = NULL,
engine = engine,
user_specified_engine = !missing(engine)
)
}Pattern 2: Multi-Mode, Multiple Arguments
boost_tree <- function(mode = "unknown",
trees = NULL,
tree_depth = NULL,
learn_rate = NULL,
mtry = NULL,
min_n = NULL,
loss_reduction = NULL,
sample_size = NULL,
stop_iter = NULL,
engine = "xgboost") {
args <- list(
trees = rlang::enquo(trees),
tree_depth = rlang::enquo(tree_depth),
learn_rate = rlang::enquo(learn_rate),
mtry = rlang::enquo(mtry),
min_n = rlang::enquo(min_n),
loss_reduction = rlang::enquo(loss_reduction),
sample_size = rlang::enquo(sample_size),
stop_iter = rlang::enquo(stop_iter)
)
new_model_spec(
"boost_tree",
args = args,
eng_args = NULL,
mode = mode, # User specifies
user_specified_mode = !missing(mode),
method = NULL,
engine = engine,
user_specified_engine = !missing(engine)
)
}Pattern 3: With Input Validation
my_model <- function(mode = "unknown",
penalty = NULL,
engine = "default") {
# Validate mode
if (!mode %in% c("unknown", "regression", "classification")) {
rlang::abort("`mode` must be 'regression' or 'classification'")
}
# Validate penalty
if (!is.null(penalty) && (!is.numeric(penalty) || penalty < 0)) {
rlang::abort("`penalty` must be a non-negative number")
}
args <- list(penalty = rlang::enquo(penalty))
new_model_spec(
"my_model",
args = args,
eng_args = NULL,
mode = mode,
user_specified_mode = !missing(mode),
method = NULL,
engine = engine,
user_specified_engine = !missing(engine)
)
}Testing Constructors
Essential tests for model constructors:
test_that("constructor creates correct object", {
spec <- my_model()
expect_s3_class(spec, "my_model")
expect_s3_class(spec, "model_spec")
expect_equal(spec$mode, "regression")
expect_equal(spec$engine, "default_engine")
})
test_that("constructor accepts arguments", {
spec <- my_model(penalty = 0.1)
# Args are quosures
expect_true(rlang::is_quosure(spec$args$penalty))
# Can extract value
penalty_val <- rlang::eval_tidy(spec$args$penalty)
expect_equal(penalty_val, 0.1)
})
test_that("constructor tracks user specification", {
spec1 <- my_model()
expect_false(spec1$user_specified_engine)
spec2 <- my_model(engine = "other")
expect_true(spec2$user_specified_engine)
})
test_that("mode validation works", {
expect_error(
my_model(mode = "invalid"),
"mode"
)
})Summary
Key points for constructors:
- Function name - Descriptive, not engine-specific
- Mode parameter - Required for multi-mode, fixed for single-mode
- Main arguments - Standardized names, use
rlang::enquo() - Engine parameter - Sensible default, track if user-specified
- Use
new_model_spec()- Creates proper structure - Minimal validation - Most checks happen at fit time
- Good documentation - Explain engines, arguments, examples
- Return model_spec - With correct class hierarchy
The constructor is the user’s first interaction with your model - make it intuitive and well-documented.