torch-openreml
A PyTorch-based library for AI-REML estimation of linear mixed models.
Overview
torch-openreml fits linear mixed-effects models using the Average Information REML (AI-REML) algorithm on a PyTorch backend. It supports flexible specification of covariance structures through a modular system of matrix builders and operators, automatic or manually specified gradients, and optional parameter transformations for constrained estimation.
Unlike traditional mixed-model software, it does not provide a formula interface. Instead, users define the fixed- and random-effects design matrices and covariance structures directly in code. The library is focused purely on the computational and optimization backend rather than model specification syntax.
Features
Torch-based backend
Built on PyTorch, with execution on CPU, GPU, and other available accelerators.
AI-REML estimation engine
Estimates variance components with the Average Information REML (AI-REML) algorithm, a quasi-Newton optimization method.
Extensible covariance structures
Covariance structures composed from built-in and user-defined matrix builders and operators.
Hybrid differentiation support
Supports both automatic differentiation and manually specified gradients.
Composable parameter transformations
Configurable, chainable transformation pipelines for flexible parameterization.
Installation
```shell
# TODO: replace with actual install command when packaging is set up
pip install torch-openreml
```
Dependencies: torch, pandas, tqdm (Python 3.12).
Getting Started
Dataset
As a quick-start example, we fit a linear mixed model to the john_alpha dataset. It contains field-trial data from a resolvable alpha lattice design conducted at Craibstone, near Aberdeen.
The dataset consists of 72 observations and 7 variables. In this example, we use yield (dry matter yield) as the response variable. The explanatory factors are rep (replicate identifier), block (incomplete block within replicate), and gen (genotype or variety identifier).
```text
    plot rep block  gen   yield  row  col
0      1  R1    B1  G11  4.1172    1    1
1      2  R1    B1  G04  4.4461    2    1
2      3  R1    B1  G05  5.8757    3    1
3      4  R1    B1  G22  4.5784    4    1
4      5  R1    B2  G21  4.6540    5    1
..   ...  ..   ...  ...     ...  ...  ...
67    68  R3    B5  G24  3.5655   68    1
68    69  R3    B6  G03  2.8873   69    1
69    70  R3    B6  G05  4.1972   70    1
70    71  R3    B6  G20  3.7349   71    1
71    72  R3    B6  G07  3.6096   72    1

[72 rows x 7 columns]
```
Model Specification
The model includes an intercept, a fixed effect for the categorical factor rep, a random intercept for gen, and a random effect for block within rep (the rep × block interaction).
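The model is specified as:

\[
\mathbf{y} = \mathbf{X}\boldsymbol{\beta} + \mathbf{Z}_{gen}\mathbf{u}_{gen} + \mathbf{Z}_{rep:block}\mathbf{u}_{rep:block} + \boldsymbol{\varepsilon},
\]

with marginal covariance structure:

\[
\mathbf{V} = \operatorname{Var}(\mathbf{y}) = \mathbf{Z}\mathbf{G}\mathbf{Z}^\top + \mathbf{R}, \quad \mathbf{Z} = \left[\mathbf{Z}_{gen},\ \mathbf{Z}_{rep:block}\right], \quad \mathbf{G} = \operatorname{diag}\left(\mathbf{G}_{gen},\ \mathbf{G}_{rep} \otimes \mathbf{G}_{block}\right),
\]

and distributional assumptions:

\[
\mathbf{u}_{gen} \sim \mathcal{N}(\mathbf{0}, \mathbf{G}_{gen}), \quad \mathbf{u}_{rep:block} \sim \mathcal{N}(\mathbf{0}, \mathbf{G}_{rep} \otimes \mathbf{G}_{block}), \quad \boldsymbol{\varepsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{R}),
\]

mutually independent, so that \(\mathbf{y} \sim \mathcal{N}(\mathbf{X}\boldsymbol{\beta}, \mathbf{V})\).
For the present model, which includes a random genotype effect and a random rep × block effect, the covariance contribution from the random effects is expressed as:

\[
\mathbf{Z}\mathbf{G}\mathbf{Z}^\top = \mathbf{Z}_{gen}\mathbf{G}_{gen}\mathbf{Z}_{gen}^\top + \mathbf{Z}_{rep:block}\left(\mathbf{G}_{rep} \otimes \mathbf{G}_{block}\right)\mathbf{Z}_{rep:block}^\top,
\]

where \(\mathbf{G}_{rep} = \mathbf{I}\) is fixed as the identity matrix for identifiability.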
Import modules
We begin by importing the required modules.
```python
import torch
from torch_openreml import REML
from torch_openreml.utils import augment, n_distinct
from torch_openreml.covariance import (
    DummyMatrix, IdentityMatrix, ScalarMatrix,
    Sum, CovariancePropagation, KroneckerProduct,
)
from torch_openreml.example_data import john_alpha
```
Covariance Builder
Next, we construct \(\mathbf{y}\), \(\mathbf{X}\), and the components required to define \(\mathbf{V}\). Both \(\mathbf{y}\) and \(\mathbf{X}\) are torch tensors.

The DummyMatrix class is a matrix builder. It constructs the dummy (indicator) matrix when evaluated, for example by calling it, as in the construction of \(\mathbf{X}\). It accepts either a pandas.Series or a list of strings as input. Given several factors, it encodes their combinations; here, rep and block yield 3 × 6 = 18 rep:block levels. The argument drop_first=True removes the first column of the dummy matrix, which would otherwise be collinear with the intercept.

ScalarMatrix and IdentityMatrix are also matrix builders, parameterized by the matrix dimension. ScalarMatrix represents \(\sigma^2\mathbf{I}\) with a single free variance parameter, whereas IdentityMatrix has no parameters.
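To make the builder behavior concrete, here is a minimal dense sketch of dummy encoding in plain NumPy. The function dummy_matrix is hypothetical and is not part of torch-openreml. The following are assumptions of this sketch, not documented library behavior:

- levels are sorted;
- several factors are treated as their combined levels;
- columns are created only for combinations present in the data.

The sorted (rep, block) order puts rep on the outside. That is the row order a Kronecker product \(\mathbf{G}_{rep} \otimes \mathbf{G}_{block}\) expects.

```python
import numpy as np


def dummy_matrix(*factors, drop_first=False):
    """Hypothetical dense stand-in for DummyMatrix, for illustration only.

    Each row gets a 1 in the column of its level. With several factors, a level
    is a combination of their values (sorted, so the first factor varies slowest).
    Only combinations that occur in the data get a column.
    """
    keys = list(zip(*factors))            # one tuple of factor values per row
    levels = sorted(set(keys))            # distinct levels in sorted order
    if drop_first:
        levels = levels[1:]               # rows at the dropped level become all-zero
    column = {level: j for j, level in enumerate(levels)}
    Z = np.zeros((len(keys), len(levels)))
    for i, key in enumerate(keys):
        if key in column:
            Z[i, column[key]] = 1.0
    return Z, levels


rep = ["R1", "R1", "R2", "R2", "R3", "R3"]
block = ["B1", "B2", "B1", "B2", "B1", "B2"]

X_rep, _ = dummy_matrix(rep, drop_first=True)   # columns for R2 and R3 only
Z_rb, rb_levels = dummy_matrix(rep, block)      # 3 x 2 = 6 rep:block levels
print(X_rep.shape, Z_rb.shape)                  # (6, 2) (6, 6)
print(rb_levels[:2])                            # [('R1', 'B1'), ('R1', 'B2')]
```

With drop_first=True, the rows for the first rep level are all zero, so that level is represented by the intercept alone.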
We then assemble the covariance structure using composable operators. The CovariancePropagation operator represents the transformation \(\mathbf{Z}\mathbf{G}\mathbf{Z}^\top\). The KroneckerProduct operator computes the direct (Kronecker) product of two matrices, and Sum aggregates multiple matrix components.
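Altogether, the covariance structure can be written as:

\[
\mathbf{V} = \mathbf{Z}_{gen}\mathbf{G}_{gen}\mathbf{Z}_{gen}^\top + \mathbf{Z}_{rep:block}\left(\mathbf{G}_{rep} \otimes \mathbf{G}_{block}\right)\mathbf{Z}_{rep:block}^\top + \mathbf{R},
\]

where \(\mathbf{G}_{rep} = \mathbf{I}\), \(\mathbf{G}_{gen} = \sigma^2_{gen}\mathbf{I}\), \(\mathbf{G}_{block} = \sigma^2_{block}\mathbf{I}\), and \(\mathbf{R} = \sigma^2_{e}\mathbf{I}\).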
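For intuition, the same structure can be written out as dense matrices. The NumPy sketch below uses a hypothetical toy layout (2 reps × 2 blocks × 2 plots, 4 genotypes) and arbitrary variance values. Each term mirrors one of the operators described above, but this is an illustration, not how the library evaluates \(\mathbf{V}\).

Because \(\mathbf{G}_{rep} = \mathbf{I}\), the Kronecker product is block diagonal across reps. Two plots that share a block label in different reps are therefore uncorrelated, which is what makes the term a block-within-rep effect.

```python
import numpy as np

# Hypothetical toy layout: 2 reps x 2 blocks x 2 plots = 8 plots, 4 genotypes
n_rep, n_block, n_gen = 2, 2, 4
n = n_rep * n_block * 2
gen_idx = np.array([0, 1, 2, 3, 1, 0, 3, 2])        # genotype of each plot
rb_idx = np.repeat(np.arange(n_rep * n_block), 2)    # rep:block level, rep-major order

# Dummy (indicator) matrices, one column per level
Z_gen = np.eye(n_gen)[gen_idx]                       # (8, 4)
Z_rb = np.eye(n_rep * n_block)[rb_idx]               # (8, 4)

# Illustrative variance components
s2_gen, s2_block, s2_e = 0.5, 0.2, 0.1

G_gen = s2_gen * np.eye(n_gen)          # like ScalarMatrix
G_rep = np.eye(n_rep)                   # like IdentityMatrix (no parameters)
G_block = s2_block * np.eye(n_block)    # like ScalarMatrix
R = s2_e * np.eye(n)                    # like ScalarMatrix for the residual

# Sum of two Z G Z^T terms (CovariancePropagation) and the residual
V = (Z_gen @ G_gen @ Z_gen.T
     + Z_rb @ np.kron(G_rep, G_block) @ Z_rb.T
     + R)

# Plots 0 and 4 are both in block "B1" but in different reps,
# and they have different genotypes, so they share no variance term
print(V[0, 4])   # 0.0
```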
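The model parameters are defined as:

\[
\boldsymbol{\theta} = \left(\log\sigma_{gen},\ \log\sigma_{block},\ \log\sigma_{e}\right)^\top,
\]

so each variance component is recovered as \(\sigma^2 = \exp(2\theta)\). This is the TransformExpPow2 transform shown in the printed structure below. The logarithmic parameterization keeps the variance components positive for any real \(\boldsymbol{\theta}\) during optimization.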
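The mapping can be checked by hand. The two functions below are hypothetical stand-ins for TransformExpPow2, not the library's code. The forward map \(\sigma^2 = (e^{\theta})^2\) is inferred from the transform's name. It agrees with the estimates printed in the REML Optimizer section.

```python
import math


def exp_pow2(theta):
    """Map an unconstrained theta to a variance: (exp(theta))**2 = exp(2 * theta)."""
    return math.exp(theta) ** 2


def exp_pow2_inverse(sigma2):
    """Inverse map: theta = log(sigma^2) / 2, i.e. the log standard deviation."""
    return 0.5 * math.log(sigma2)


# A zero starting value corresponds to every variance component starting at 1.
print(exp_pow2(0.0))   # 1.0

# Estimates printed by reml.optimize in the REML Optimizer section
theta_hat = [-0.9728, -1.3281, -1.2529]
print([round(exp_pow2(t), 4) for t in theta_hat])   # [0.1429, 0.0702, 0.0816]
```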
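```python
# Response vector and fixed-effects design matrix: intercept plus rep dummies
# (drop_first=True drops the first rep level, which the intercept absorbs)
y = torch.tensor(john_alpha["yield"].values)
X = augment(torch.ones(len(john_alpha), 1),
            DummyMatrix(john_alpha["rep"], drop_first=True)())

# Random-effects design matrices (builders; the matrices are constructed on evaluation)
Z_gen = DummyMatrix(john_alpha["gen"])
Z_rep_block = DummyMatrix(john_alpha["rep"], john_alpha["block"])

# Covariance blocks: sigma^2 * I for gen, block and residual; fixed I for rep
G_gen = ScalarMatrix(n_distinct(john_alpha["gen"]))
G_rep = IdentityMatrix(n_distinct(john_alpha["rep"]))
G_block = ScalarMatrix(n_distinct(john_alpha["block"]))
R = ScalarMatrix(len(john_alpha))

# V = Z_gen G_gen Z_gen^T + Z_rep_block (G_rep ⊗ G_block) Z_rep_block^T + R
V = Sum(
    CovariancePropagation(Z_gen, G_gen),
    CovariancePropagation(
        Z_rep_block,
        KroneckerProduct(G_rep, G_block)
    ),
    R
)
print(V)
```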
```text
Sum(
    operands={
        'op_0': CovariancePropagation(
            operands={
                'op_0': DummyMatrix(shape=(72, 24)),
                'op_1': ScalarMatrix(shape=(24, 24), param_specs={'sigma^2': {'fixed': False, 'default': tensor([0.]), 'trans': TransformExpPow2()}})
            }
        ),
        'op_1': CovariancePropagation(
            operands={
                'op_0': DummyMatrix(shape=(72, 18)),
                'op_1': KroneckerProduct(
                    operands={
                        'op_0': IdentityMatrix(shape=(3, 3)),
                        'op_1': ScalarMatrix(shape=(6, 6), param_specs={'sigma^2': {'fixed': False, 'default': tensor([0.]), 'trans': TransformExpPow2()}})
                    }
                )
            }
        ),
        'op_2': ScalarMatrix(shape=(72, 72), param_specs={'sigma^2': {'fixed': False, 'default': tensor([0.]), 'trans': TransformExpPow2()}})
    }
)
```
REML Optimizer
Once the covariance structure has been defined, it is passed to REML to set up the estimation procedure. The optimize method is then called with \(\mathbf{y}\), \(\mathbf{X}\), and an initial value for \(\boldsymbol{\theta}\). Here the initial value is zeros, so every variance component starts at 1. The method returns the estimated parameters \(\hat{\boldsymbol{\theta}}\), the fixed-effect estimates \(\hat{\boldsymbol{\beta}}\), and the number of iterations. The verbose argument controls the level of diagnostic output.
Because the optimization is performed on the transformed parameter scale, the estimates are mapped back to variance components with V.build_params. The resulting values follow the order of the parameter names in V.free_param_names: the genotype variance, the block-within-rep variance, and the residual variance.
```python
reml = REML(V)
# Three free parameters (gen, block, residual), all initialized at 0
theta_hat, beta_hat, n_iter = reml.optimize(y, X, torch.zeros(3), verbose=2)
print(theta_hat, V.build_params(theta_hat))
print(V.free_param_names)
print(beta_hat)
```
```text
∥∇∥: 41.8197, ∥Δ∥: 8.9438, η: 1.00, ∥Δᶜ∥: 8.9438, log 𝓛: -34.3129
∥∇∥: 315227.4375, ∥Δ∥: 0.8838, η: 1.00, ∥Δᶜ∥: 0.8838, log 𝓛: -201185.3906 (-201151.0777)
∥∇∥: 119229.9453, ∥Δ∥: 0.8520, η: 1.00, ∥Δᶜ∥: 0.8520, log 𝓛: -71844.4531 (+129340.9375)
∥∇∥: 44273.4570, ∥Δ∥: 0.8143, η: 1.00, ∥Δᶜ∥: 0.8143, log 𝓛: -26065.4551 (+45778.9980)
∥∇∥: 16330.5117, ∥Δ∥: 0.7587, η: 1.00, ∥Δᶜ∥: 0.7587, log 𝓛: -9457.1396 (+16608.3154)
∥∇∥: 6001.4102, ∥Δ∥: 0.7142, η: 1.00, ∥Δᶜ∥: 0.7142, log 𝓛: -3392.8696 (+6064.2700)
∥∇∥: 2194.8989, ∥Δ∥: 0.7031, η: 1.00, ∥Δᶜ∥: 0.7031, log 𝓛: -1181.2805 (+2211.5891)
∥∇∥: 793.7851, ∥Δ∥: 0.6952, η: 1.00, ∥Δᶜ∥: 0.6952, log 𝓛: -382.4908 (+798.7897)
∥∇∥: 278.6875, ∥Δ∥: 0.6715, η: 1.00, ∥Δᶜ∥: 0.6715, log 𝓛: -102.2992 (+280.1916)
∥∇∥: 90.4116, ∥Δ∥: 0.5957, η: 1.00, ∥Δᶜ∥: 0.5957, log 𝓛: -11.4274 (+90.8718)
∥∇∥: 23.9216, ∥Δ∥: 0.4006, η: 1.00, ∥Δᶜ∥: 0.4006, log 𝓛: 12.6792 (+24.1066)
∥∇∥: 3.8674, ∥Δ∥: 0.1392, η: 1.00, ∥Δᶜ∥: 0.1392, log 𝓛: 16.5908 (+3.9116)
∥∇∥: 0.2448, ∥Δ∥: 0.0157, η: 1.00, ∥Δᶜ∥: 0.0157, log 𝓛: 16.8078 (+0.2170)
∥∇∥: 0.0104, ∥Δ∥: 0.0007, η: 1.00, ∥Δᶜ∥: 0.0007, log 𝓛: 16.8099 (+0.0021)
∥∇∥: 0.0004, ∥Δ∥: 0.0000, η: 1.00, ∥Δᶜ∥: 0.0000, log 𝓛: 16.8100 (+0.0001)
∥∇∥: 0.0000, ∥Δ∥: 0.0000, η: 1.00, ∥Δᶜ∥: 0.0000, log 𝓛: 16.8099 (-0.0001)
[∇: score, Δ: 𝐉⁻¹∇, η: learning rate, Δᶜ: clip(𝛉 + ηΔ, lb, ub) - 𝛉, 𝓛: restricted likelihood]
✓ Converged at iteration 16
tensor([-0.9728, -1.3281, -1.2529]) tensor([0.1429, 0.0702, 0.0816])
['op_0/op_1/sigma^2', 'op_1/op_1/op_1/sigma^2', 'op_2/sigma^2']
tensor([ 4.5183, 0.2978, -0.4140])
```
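The iteration legend summarizes the update: a step \(\Delta = \mathbf{J}^{-1}\nabla\) (with \(\mathbf{J}\) presumably the average information matrix), scaled by η and clipped to bounds. As a rough illustration of that idea, here is a minimal NumPy sketch of AI-REML for \(\mathbf{V} = \sum_k \sigma^2_k \mathbf{V}_k\) on the untransformed variance scale.

The sketch maximizes the restricted log-likelihood

\[
\ell_R = -\tfrac{1}{2}\left[\log|\mathbf{V}| + \log|\mathbf{X}^\top\mathbf{V}^{-1}\mathbf{X}| + \mathbf{y}^\top\mathbf{P}\mathbf{y}\right] + \text{const},
\]

where \(\mathbf{P} = \mathbf{V}^{-1} - \mathbf{V}^{-1}\mathbf{X}(\mathbf{X}^\top\mathbf{V}^{-1}\mathbf{X})^{-1}\mathbf{X}^\top\mathbf{V}^{-1}\). It uses the score \(\tfrac{1}{2}(\mathbf{y}^\top\mathbf{P}\mathbf{V}_k\mathbf{P}\mathbf{y} - \operatorname{tr}(\mathbf{P}\mathbf{V}_k))\) and the average information \(\tfrac{1}{2}\mathbf{y}^\top\mathbf{P}\mathbf{V}_k\mathbf{P}\mathbf{V}_l\mathbf{P}\mathbf{y}\).

The names reml_parts and ai_reml are hypothetical, and the step-halving rule for η is an assumption of this sketch. torch-openreml works on the transformed θ scale, and its step control may differ.

```python
import numpy as np


def reml_parts(y, X, Vs, s2):
    """Restricted log-likelihood (up to a constant), score and AI matrix.

    The covariance is V = sum_k s2[k] * Vs[k], on the untransformed variance scale.
    """
    V = sum(s * Vk for s, Vk in zip(s2, Vs))
    Vinv = np.linalg.inv(V)
    VinvX = Vinv @ X
    XtVinvX = X.T @ VinvX
    P = Vinv - VinvX @ np.linalg.solve(XtVinvX, VinvX.T)   # REML projection matrix
    Py = P @ y
    loglik = -0.5 * (np.linalg.slogdet(V)[1]
                     + np.linalg.slogdet(XtVinvX)[1]
                     + y @ Py)
    # Score: 1/2 (y'P V_k P y - tr(P V_k))
    grad = np.array([0.5 * (Py @ Vk @ Py - np.trace(P @ Vk)) for Vk in Vs])
    # Average information: AI_kl = 1/2 y'P V_k P V_l P y
    W = np.column_stack([Vk @ Py for Vk in Vs])
    AI = 0.5 * W.T @ P @ W
    return loglik, grad, AI


def ai_reml(y, X, Vs, s2, lb=1e-6, tol=1e-8, max_iter=100):
    """Hypothetical AI-REML loop: delta = AI^-1 grad, step halving on eta, clip at lb."""
    s2 = np.asarray(s2, dtype=float)
    loglik, grad, AI = reml_parts(y, X, Vs, s2)
    for it in range(1, max_iter + 1):
        delta = np.linalg.solve(AI, grad)
        eta = 1.0
        while True:
            cand = np.maximum(s2 + eta * delta, lb)     # clip to the lower bound
            new = reml_parts(y, X, Vs, cand)
            if new[0] >= loglik or eta < 1e-4:
                break
            eta /= 2                                     # halve the step size
        s2 = cand
        loglik, grad, AI = new
        if np.max(np.abs(grad)) < tol:
            break
    return s2, loglik, it


# Balanced one-way example: 4 groups x 3 observations, intercept-only fixed effects
groups = np.repeat(np.arange(4), 3)
Z = np.eye(4)[groups]
y = np.array([1.0, 1.2, 0.8, 2.0, 2.3, 1.7, 3.1, 2.9, 3.0, 0.5, 0.9, 0.7])
X = np.ones((12, 1))
Vs = [Z @ Z.T, np.eye(12)]                               # group and residual terms

s2_hat, loglik, n_iter = ai_reml(y, X, Vs, [1.0, 1.0])
# ANOVA estimates for these data: (MSB - MSW) / 3 = 1.0742 and MSW = 0.045
print(s2_hat, n_iter)
```

On balanced one-way data, the REML estimates coincide with the ANOVA estimates whenever the latter are positive, which makes this a convenient correctness check. Starting from \(\sigma^2 = 1\), the full AI step overshoots the small residual variance into negative values. Clipping and step halving are what keep this variance-scale sketch stable. A log-scale \(\boldsymbol{\theta}\), by contrast, keeps variances positive by construction.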
Documentation
| Section | Description |
|---|---|
| Vignettes | Model formulation, REML and ML theory, score and AI matrix derivations. |
|  | Full documentation for |
Citing
```bibtex
@software{torch_openreml,
  author = {Weihao Li},
  title  = {torch-openreml},
  year   = {2026},
  url    = {https://github.com/anu-aagi/torch-openreml/}
}
```