Skip to content

Commit fc02408

Browse files
committed
refactorize input validation for fit method
1 parent 47eadbd commit fc02408

File tree

2 files changed

+125
-46
lines changed

2 files changed

+125
-46
lines changed

src/robustlinearmodel.jl

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ function StatsAPI.fit(
9595
M1<:Union{Missing,<:Real},
9696
M2<:Union{Missing,<:Real},
9797
}
98+
extra_args = filter_model_extra_arguments(M, kwargs)
99+
98100
X_ismissing = eltype(X) >: Missing
99101
y_ismissing = eltype(y) >: Missing
100102
if any([y_ismissing, X_ismissing])
@@ -105,34 +107,53 @@ function StatsAPI.fit(
105107
)
106108
throw(ArgumentError(msg))
107109
end
108-
X, y, _ = missing_omit(X, y)
110+
X, y, nonmissings = missing_omit(X, y)
111+
112+
else
113+
nonmissings = trues(size(y))
114+
end
115+
# drop (X,y) missing rows in extra_args
116+
if all(nonmissings)
117+
extra_args = NamedTuple(var => _missing_omit(val) for (var, val) in pairs(extra_args))
118+
else
119+
rows = findall(nonmissings)
120+
extra_args = NamedTuple(var => _missing_omit(view(val, rows)) for (var, val) in pairs(extra_args))
109121
end
110122

111123
# Make sure X and y have the same float eltype
112124
pX, py = promote_to_same_float(X, y)
125+
# Make sure extra values in keyword argument have the same float eltype
126+
T = eltype(py)
127+
extra_args = NamedTuple(var => convert_vec_to_float(T, val) for (var, val) in pairs(extra_args))
128+
129+
kwargs = (; kwargs..., extra_args...)
113130
return fit(M, pX, py, args...; kwargs...)
114131
end
115132

116133
## Convert from formula-data to modelmatrix-response calling form
117134
## the `fit` method must allow the `wts`, `contrasts` and `__formula` keyword arguments
118-
## Specialize to allow other keyword arguments (offset, precision...) to be taken from
119-
## a column of the dataframe.
135+
## Specialize the `model_extra_arguments` method to allow other keyword arguments
136+
## (offset, precision...) to be taken from a column of the dataframe.
120137
function StatsAPI.fit(
121138
::Type{M},
122139
f::FormulaTerm,
123140
data,
124141
args...;
125142
dropmissing::Bool=false,
126-
wts::Union{Nothing,Symbol,FPVector}=nothing,
127143
contrasts::AbstractDict{Symbol,Any}=Dict{Symbol,Any}(),
128144
kwargs...,
129145
) where {M<:AbstractRobustModel}
130146
# Extract arrays from data using formula
131-
f, y, X, extra = modelframe(f, data, contrasts, dropmissing, M; wts=wts)
147+
f, y, X, extra = modelframe(M, f, data, contrasts, dropmissing; kwargs...)
132148
# Call the `fit` method with arrays
133149
pX, py = promote_to_same_float(X, y)
150+
# Make sure extra values in keyword argument have the same float eltype
151+
T = eltype(py)
152+
extra = NamedTuple(var => convert_vec_to_float(T, val) for (var, val) in pairs(extra))
153+
154+
kwargs = (; kwargs..., extra...)
134155
return fit(
135-
M, pX, py, args...; wts=extra.wts, contrasts=contrasts, __formula=f, kwargs...
156+
M, pX, py, args...; contrasts=contrasts, __formula=f, kwargs...
136157
)
137158
end
138159

@@ -442,6 +463,19 @@ The arguments `X` and `y` can be a `Matrix` and a `Vector` or a `Formula` and a
442463
rlm(X, y, args...; kwargs...) = fit(RobustLinearModel, X, y, args...; kwargs...)
443464

444465

466+
"""
467+
model_extra_arguments(::Type{M}) where {M<:RobustLinearModel}
468+
469+
Get the names of extra array arguments that are used by the model.
470+
For RobustLinearModel, [:wts, :offset].
471+
472+
Returns an array of extra arguments used by the model.
473+
"""
474+
function model_extra_arguments(::Type{M}; kwargs...) where {M<:RobustLinearModel}
475+
return [:wts, :offset]
476+
end
477+
478+
445479
"""
446480
fit(::Type{M},
447481
X::Union{AbstractMatrix{T},SparseMatrixCSC{T}},
@@ -617,35 +651,6 @@ function StatsAPI.fit(
617651
return dofit ? fit!(m; fitargs...) : m
618652
end
619653

620-
## Convert from formula-data to modelmatrix-response calling form
621-
## the `fit` method must allow the `wts`, `offset`, `contrasts` and `__formula` keyword arguments
622-
function StatsAPI.fit(
623-
::Type{M},
624-
f::FormulaTerm,
625-
data,
626-
args...;
627-
dropmissing::Bool=false,
628-
wts::Union{Nothing,Symbol,FPVector}=nothing,
629-
offset::Union{Nothing,Symbol,FPVector}=nothing,
630-
contrasts::AbstractDict{Symbol,Any}=Dict{Symbol,Any}(),
631-
kwargs...,
632-
) where {M<:RobustLinearModel}
633-
# Extract arrays from data using formula
634-
f, y, X, extra = modelframe(f, data, contrasts, dropmissing, M; wts=wts, offset=offset)
635-
# Call the `fit` method with arrays
636-
return fit(
637-
M,
638-
X,
639-
y,
640-
args...;
641-
wts=extra.wts,
642-
offset=extra.offset,
643-
contrasts=contrasts,
644-
__formula=f,
645-
kwargs...,
646-
)
647-
end
648-
649654

650655
"""
651656
fit(::Type{M}, X, y; kwarg...) where {M<:RobustLinearModel}

src/tools.jl

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ function promote_to_same_float(X::AbstractMatrix, y::AbstractVector)
1515
return convert.(T, X)::MT, convert.(T, y)::VT
1616
end
1717

18+
19+
function convert_vec_to_float(::Type{T}, v::AbstractVector{<:Real}) where {T<:AbstractFloat}
20+
VT = AbstractVector{T}
21+
return convert.(T, v)::VT
22+
end
23+
24+
1825
_missing_omit(x::AbstractArray{T}) where {T} = copyto!(similar(x, nonmissingtype(T)), x)
1926

2027
function StatsModels.missing_omit(X::AbstractMatrix, y::AbstractVector)
@@ -73,6 +80,62 @@ end
7380
################################################
7481

7582
const ModelFrameType = Tuple{FormulaTerm,<:AbstractVector,<:AbstractMatrix,NamedTuple}
83+
const AllowedExtraArgType = Union{Nothing,Symbol,Union{AbstractVector{<:Real},AbstractVector{Union{Missing,<:Real}}}}
84+
85+
86+
"""
87+
model_extra_arguments(::Type{M}) where {M<:AbstractRobustModel}
88+
89+
Get the names of extra array arguments (of the same size than the response ``y``)
90+
that are used by the model.
91+
For general AbstractRobustModel, `wts` is a possible keyword argument.
92+
93+
This method should specialize to the different models to include other arguments (offset, invvar...)
94+
95+
Returns an array of extra arguments used by the model.
96+
"""
97+
function model_extra_arguments(::Type{M}) where {M<:AbstractRobustModel}
98+
return [:wts]
99+
end
100+
101+
102+
"""
103+
filter_model_extra_arguments(::Type{M}, kwargs) where {M<:AbstractRobustModel}
104+
105+
Filter the kwargs to keep only the extra arguments used by the model.
106+
The values are checked to be of type Union{Nothing,Symbol,AbstractVector{<:Real}}.
107+
108+
For general AbstractRobustModel, `wts` keyword argument is fetched from the data table.
109+
110+
Returns a Dict of (key, value) with key, extra arguments used by the model, and value,
111+
given by `kwargs`.
112+
"""
113+
function filter_model_extra_arguments(
114+
::Type{M},
115+
kwargs::Union{Dict{Symbol,Any}, Base.Pairs, NamedTuple},
116+
) where {M<:AbstractRobustModel}
117+
allowed = model_extra_arguments(M)
118+
119+
extra = Dict{Symbol, AllowedExtraArgType}()
120+
for (k, val) in pairs(kwargs)
121+
s = Symbol(k)
122+
if !(s in allowed)
123+
continue
124+
end
125+
if !isa(val, AllowedExtraArgType)
126+
msg = (
127+
"extra argument does not have a compatible type, should be Nothing, " *
128+
"a Symbol or a real array: $(typeof(val))"
129+
)
130+
@warn(msg)
131+
continue
132+
end
133+
# Add to dict
134+
extra[s] = val
135+
end
136+
return extra
137+
end
138+
76139

77140
"""
78141
modelframe(f::FormulaTerm, data, contrasts::AbstractDict, ::Type{M}; kwargs...) where M
@@ -84,15 +147,17 @@ are extracted from the `data` Table using the formula `f`.
84147
Adapted from GLM.jl
85148
"""
86149
function modelframe(
87-
f::FormulaTerm, data, contrasts::AbstractDict, dropmissing::Bool, ::Type{M}; kwargs...
150+
::Type{M}, f::FormulaTerm, data, contrasts::AbstractDict, dropmissing::Bool; kwargs...
88151
)::ModelFrameType where {M<:AbstractRobustModel}
89152
# Check is a Table
90153
Tables.istable(data) ||
91154
throw(ArgumentError("expected data in a Table, got $(typeof(data))"))
92155
t = Tables.columntable(data)
93156

94-
# Check columns exist
157+
# Get columns
95158
cols = collect(termvars(f))
159+
160+
# Check columns exist
96161
msg = ""
97162
for col in cols
98163
msg *= checkcol(t, col)
@@ -101,7 +166,12 @@ function modelframe(
101166
end
102167
end
103168
msg != "" && throw(ArgumentError("Error with formula term names.\n" * msg))
104-
for val in Base.values(kwargs)
169+
170+
# Get extra columns
171+
extra_args = filter_model_extra_arguments(M, kwargs)
172+
173+
# Check extra columns exist
174+
for val in values(extra_args)
105175
if isa(val, Symbol)
106176
msg = checkcol(t, val)
107177
msg != "" && throw(ArgumentError("Error with extra column name.\n" * msg))
@@ -132,15 +202,19 @@ function modelframe(
132202
# response and model matrix
133203
## Do not copy the arrays!
134204
y, X = modelcols(f, t)
135-
extra_vec = NamedTuple(var => (
136-
if isa(val, Symbol)
137-
t[val]
138-
elseif isnothing(val)
139-
similar(y, 0)
140-
else
141-
val
205+
206+
extra_vec = NamedTuple(
207+
var => begin
208+
if isa(val, Symbol)
209+
t[val]
210+
elseif isnothing(val)
211+
similar(y, 0)
212+
else
213+
val
214+
end
142215
end
143-
) for (var, val) in pairs(kwargs))
216+
for (var, val) in pairs(extra_args)
217+
)
144218

145219
return f, y, X, extra_vec
146220
end

0 commit comments

Comments
 (0)