
Commit 0887cfe

tlienart, ablaom, and adienes authored
Minor release with Gramian training for OLS (#151)
* fix a doc typo
* fixes following discussion around #147
* small adjustments + typo fixes
* first pass at Gramian training for OLS (#146)
* proof of concept
* AbstractMatrix -> AVR
* cleaner impl
* endline
* fix error type
* construct kernels if not passed in
* add test case for implicit gram construction
* last endline
* check for isempty instead of iszero
* Prepare minor release with gramian training

---------

Co-authored-by: Anthony D. Blaom <anthony.blaom@gmail.com>
Co-authored-by: adienes <51664769+adienes@users.noreply.github.com>
1 parent a872d7c commit 0887cfe

File tree: 12 files changed, +86 −16 lines


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "MLJLinearModels"
 uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692"
 authors = ["Thibaut Lienart <tlienart@me.com>"]
-version = "0.9.2"
+version = "0.10.0"

 [deps]
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/fit/default.jl

Lines changed: 14 additions & 5 deletions
@@ -33,13 +33,22 @@ $SIGNATURES
 Fit a generalised linear regression model using an appropriate solver based on
 the loss and penalty of the model. A method can, in some cases, be specified.
 """
-function fit(glr::GLR, X::AbstractMatrix{<:Real}, y::AVR;
+function fit(glr::GLR, X::AbstractMatrix{<:Real}, y::AVR; data=nothing,
              solver::Solver=_solver(glr, size(X)))
-    check_nrows(X, y)
-    n, p = size(X)
-    c = getc(glr, y)
-    return _fit(glr, solver, X, y, scratch(n, p, c, i=glr.fit_intercept))
+    if hasproperty(solver, :gram) && solver.gram
+        # interpret X, y as X'X, X'y
+        data = verify_or_construct_gramian(glr, X, y, data)
+        p = size(data.XX, 2)
+        return _fit(glr, solver, data.XX, data.Xy, (; dims=(data.n, p, 0)))
+    else
+        check_nrows(X, y)
+        n, p = size(X)
+        c = getc(glr, y)
+        return _fit(glr, solver, X, y, scratch(n, p, c, i=glr.fit_intercept))
+    end
 end
+fit(glr::GLR; kwargs...) = fit(glr, zeros((0,0)), zeros((0,)); kwargs...)
+

 function scratch(n, p, c=0; i=false)
     p_ = p + Int(i)
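
A minimal sketch of the two entry points this enables (the synthetic data, the `LassoRegression` constructor, and the penalty value are illustrative assumptions, not part of this diff; the Gramian path requires `fit_intercept=false` and a non-categorical loss, as enforced by `verify_or_construct_gramian` further below):

using MLJLinearModels, LinearAlgebra

n, p = 1_000, 5
X = randn(n, p)
y = X * randn(p) + 0.1 * randn(n)

glr = LassoRegression(0.1; fit_intercept=false)  # intercept unsupported in gram mode
solver = ProxGrad(gram=true)

# Pass X and y directly; X'X and X'y are then formed internally ...
θ1 = fit(glr, X, y; solver=solver)

# ... or pass a precomputed Gramian via `data`, omitting X and y
# thanks to the new zero-argument convenience method.
θ2 = fit(glr; data=(XX=X'X, Xy=X'y, n=n), solver=solver)
# θ1 ≈ θ2: both paths solve the same problem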

src/fit/proxgrad.jl

Lines changed: 13 additions & 4 deletions
@@ -3,7 +3,7 @@
 # Assumption: loss has gradient; penalty has prox e.g.: Lasso
 # J(θ) = f(θ) + r(θ) where f is smooth
 function _fit(glr::GLR, solver::ProxGrad, X, y, scratch)
-    _,p,c = npc(scratch)
+    n,p,c = npc(scratch)
     c > 0 && (p *= c)
     # vector caches + eval cache
     θ = zeros(p)    # θ_k
@@ -19,9 +19,18 @@ function _fit(glr::GLR, solver::ProxGrad, X, y, scratch)
     η = 1.0                               # stepsize (1/L)
     acc = ifelse(solver.accel, 1.0, 0.0)  # if 0, no extrapolation (ISTA)
     # functions
-    _f = smooth_objective(glr, X, y; c=c)
-    _fg! = smooth_fg!(glr, X, y, scratch)
-    _prox! = prox!(glr, size(X, 1))
+    _f = if solver.gram
+        smooth_gram_objective(glr, X, y, n)
+    else
+        smooth_objective(glr, X, y; c=c)
+    end
+
+    _fg! = if solver.gram
+        smooth_gram_fg!(glr, X, y, n)
+    else
+        smooth_fg!(glr, X, y, scratch)
+    end
+    _prox! = prox!(glr, n)
     bt_cond = θ̂ ->
         _f(θ̂) > fθ̄ + dot(θ̂ .- θ̄, ∇fθ̄) + sum(abs2.(θ̂ .- θ̄)) / (2η)
     # loop-related

src/fit/solvers.jl

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ Proximal Gradient solver for non-smooth objective functions.
     tol::Float64 = 1e-4    # tol relative change of θ i.e. norm(θ-θ_)/norm(θ)
     max_inner::Int = 100   # β^max_inner should be > 1e-10
     beta::Float64 = 0.8    # in (0, 1); shrinkage in the backtracking step
+    gram::Bool = false     # use precomputed Gramian for lsq where possible
 end

 FISTA(; kwa...) = ProxGrad(;accel = true, kwa...)
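
Since `FISTA` simply forwards its keyword arguments to `ProxGrad` (last line above), the new flag can be set through either constructor; a quick sketch:

s_ista  = ProxGrad(gram=true)         # plain proximal gradient on the Gramian
s_fista = FISTA(gram=true, tol=1e-6)  # accelerated variant, same flag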

src/glr/d_l2loss.jl

Lines changed: 9 additions & 0 deletions
@@ -72,3 +72,12 @@ function smooth_fg!(glr::GLR{L2Loss,<:ENR}, X, y, scratch)
         return glr.loss(r) + get_l2(glr.penalty)(view_θ(glr, θ))
     end
 end
+
+function smooth_gram_fg!(glr::GLR{L2Loss,<:ENR}, XX, Xy, n)
+    λ = get_penalty_scale_l2(glr, n)
+    (g, θ) -> begin
+        _g = XX * θ .- Xy
+        g .= _g .+ λ .* θ
+        return θ'*_g + get_l2(glr.penalty)(view_θ(glr, θ))
+    end
+end
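
The closure above is the dense gradient rewritten in Gramian form: ∇(½‖Xθ − y‖²) = X'(Xθ − y) = (X'X)θ − X'y, so once XX = X'X and Xy = X'y are precomputed each evaluation costs O(p²) instead of O(np). A standalone sanity check of that identity (illustrative, not part of the diff):

using LinearAlgebra

X = randn(100, 4); y = randn(100); θ = randn(4)
XX, Xy = X'X, X'y           # one-off O(np²) precomputation
g_dense = X' * (X*θ - y)    # O(np) per evaluation
g_gram  = XX*θ - Xy         # O(p²) per evaluation
@assert g_dense ≈ g_gram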

src/glr/utils.jl

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,4 @@
-export objective, smooth_objective
+export objective, smooth_objective, smooth_gram_objective

 # NOTE: RobustLoss are not always everywhere smooth but "smooth-enough".
 const SmoothLoss = Union{L2Loss, LogisticLoss, MultinomialLoss, RobustLoss}
@@ -37,6 +37,9 @@ Return the smooth part of the objective function of a GLR.
 """
 smooth_objective(glr::GLR{<:SmoothLoss,<:ENR}, n) = glr.loss + get_l2(glr.penalty) * ifelse(glr.scale_penalty_with_samples, n, 1.)

+smooth_gram_objective(glr::GLR{<:SmoothLoss,<:ENR}, XX, Xy, n) =
+    θ -> (θ'*XX*θ)/2 - (θ'*Xy) + (get_l2(glr.penalty) * ifelse(glr.scale_penalty_with_samples, n, 1.))(θ)
+
 smooth_objective(::GLR) = @error "Case not implemented yet."

 """

src/mlj/classifiers.jl

Lines changed: 2 additions & 2 deletions
@@ -65,7 +65,7 @@ See also [`MultinomialClassifier`](@ref).
 """some instance of `MLJLinearModels.S` where `S` is one of: `LBFGS`, `Newton`,
 `NewtonCG`, `ProxGrad`; but subject to the following restrictions:

-- If `penalty = :l2`, `ProxGrad` is disallowed. Otherwise, `ProxyGrad` is the only
+- If `penalty = :l2`, `ProxGrad` is disallowed. Otherwise, `ProxGrad` is the only
   option.

 - Unless `scitype(y) <: Finite{2}` (binary target) `Newton` is disallowed.
@@ -142,7 +142,7 @@ See also [`LogisticClassifier`](@ref).
 """some instance of `MLJLinearModels.S` where `S` is one of: `LBFGS`,
 `NewtonCG`, `ProxGrad`; but subject to the following restrictions:

-- If `penalty = :l2`, `ProxGrad` is disallowed. Otherwise, `ProxyGrad` is the only
+- If `penalty = :l2`, `ProxGrad` is disallowed. Otherwise, `ProxGrad` is the only
   option.

 - Unless `scitype(y) <: Finite{2}` (binary target) `Newton` is disallowed.

src/utils.jl

Lines changed: 23 additions & 0 deletions
@@ -9,6 +9,29 @@ function check_nrows(X::AbstractMatrix, y::AbstractVecOrMat)::Nothing
     throw(DimensionMismatch("`X` and `y` must have the same number of rows."))
 end

+function verify_or_construct_gramian(glr, X, y, data)
+    check_nrows(X, y)
+    isnothing(data) && return (; XX = X'X, Xy = X'y, n = length(y))
+
+    !all(hasproperty.(Ref(data), (:XX, :Xy, :n))) && throw(ArgumentError("data must contain XX, Xy, n"))
+    size(data.XX, 1) != size(data.Xy, 1) && throw(DimensionMismatch("`XX` and `Xy` must have the same number of rows."))
+    !issymmetric(data.XX) && throw(ArgumentError("Input `XX` must be symmetric"))
+
+    c = getc(glr, data.Xy)
+    !iszero(c) && throw(ArgumentError("Categorical loss not supported with Gramian kernel"))
+    glr.fit_intercept && throw(ArgumentError("Intercept not supported with Gramian kernel"))
+
+    if any(!isempty, (X, y))
+        all((
+            isapprox(X'X, data.XX; rtol=1e-5),
+            isapprox(X'y, data.Xy; rtol=1e-5),
+            length(y) == data.n
+        )) || throw(ArgumentError("Inputs `X` and `y` do not match inputs `XX` and `Xy`."))
+    end
+
+    return data
+end
+
 """
 $SIGNATURES
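
A sketch of the resulting validation behaviour (the function is internal, hence the module prefix; the model and arrays are illustrative assumptions):

using MLJLinearModels, LinearAlgebra

glr = LassoRegression(0.1; fit_intercept=false)
X = randn(50, 3); y = randn(50)

# Without `data`, the Gramian is constructed from X and y:
d = MLJLinearModels.verify_or_construct_gramian(glr, X, y, nothing)
# d == (XX = X'X, Xy = X'y, n = 50)

# Malformed `data` is rejected, e.g. a non-symmetric XX:
bad = (XX = randn(3, 3), Xy = X'y, n = 50)
MLJLinearModels.verify_or_construct_gramian(glr, zeros(0, 0), zeros(0), bad)
# throws ArgumentError: "Input `XX` must be symmetric"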

test/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
+Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"

test/benchmarks/robust.jl

Lines changed: 2 additions & 2 deletions
@@ -2,9 +2,9 @@

 if DO_COMPARISONS
     @testset "Comp-QR-147" begin
-        using CSV, DataFrames
+        using CSV, DataFrames, Downloads

-        dataset = CSV.read(download("http://freakonometrics.free.fr/rent98_00.txt"), DataFrame)
+        dataset = CSV.read(Downloads.download("http://freakonometrics.free.fr/rent98_00.txt"), DataFrame)
         tau = 0.3

         y = Vector(dataset[!,:rent_euro])
