@@ -95,6 +95,8 @@ function StatsAPI.fit(
95
95
M1<: Union{Missing,<:Real} ,
96
96
M2<: Union{Missing,<:Real} ,
97
97
}
98
+ extra_args = filter_model_extra_arguments (M, kwargs)
99
+
98
100
X_ismissing = eltype (X) >: Missing
99
101
y_ismissing = eltype (y) >: Missing
100
102
if any ([y_ismissing, X_ismissing])
@@ -105,34 +107,53 @@ function StatsAPI.fit(
105
107
)
106
108
throw (ArgumentError (msg))
107
109
end
108
- X, y, _ = missing_omit (X, y)
110
+ X, y, nonmissings = missing_omit (X, y)
111
+
112
+ else
113
+ nonmissings = trues (size (y))
114
+ end
115
+ # drop (X,y) missing rows in extra_args
116
+ if all (nonmissings)
117
+ extra_args = NamedTuple (var => _missing_omit (val) for (var, val) in pairs (extra_args))
118
+ else
119
+ rows = findall (nonmissings)
120
+ extra_args = NamedTuple (var => _missing_omit (view (val, rows)) for (var, val) in pairs (extra_args))
109
121
end
110
122
111
123
# Make sure X and y have the same float eltype
112
124
pX, py = promote_to_same_float (X, y)
125
+ # Make sure extra values in keyword argument have the same float eltype
126
+ T = eltype (py)
127
+ extra_args = NamedTuple (var => convert_vec_to_float (T, val) for (var, val) in pairs (extra_args))
128
+
129
+ kwargs = (; kwargs... , extra_args... )
113
130
return fit (M, pX, py, args... ; kwargs... )
114
131
end
115
132
116
133
# # Convert from formula-data to modelmatrix-response calling form
117
134
# # the `fit` method must allow the `wts`, `contrasts` and `__formula` keyword arguments
118
- # # Specialize to allow other keyword arguments (offset, precision...) to be taken from
119
- # # a column of the dataframe.
135
+ # # Specialize the `model_extra_arguments` method to allow other keyword arguments
136
+ # # (offset, precision...) to be taken from a column of the dataframe.
120
137
function StatsAPI. fit (
121
138
:: Type{M} ,
122
139
f:: FormulaTerm ,
123
140
data,
124
141
args... ;
125
142
dropmissing:: Bool = false ,
126
- wts:: Union{Nothing,Symbol,FPVector} = nothing ,
127
143
contrasts:: AbstractDict{Symbol,Any} = Dict {Symbol,Any} (),
128
144
kwargs... ,
129
145
) where {M<: AbstractRobustModel }
130
146
# Extract arrays from data using formula
131
- f, y, X, extra = modelframe (f, data, contrasts, dropmissing, M; wts = wts )
147
+ f, y, X, extra = modelframe (M, f, data, contrasts, dropmissing; kwargs ... )
132
148
# Call the `fit` method with arrays
133
149
pX, py = promote_to_same_float (X, y)
150
+ # Make sure extra values in keyword argument have the same float eltype
151
+ T = eltype (py)
152
+ extra = NamedTuple (var => convert_vec_to_float (T, val) for (var, val) in pairs (extra))
153
+
154
+ kwargs = (; kwargs... , extra... )
134
155
return fit (
135
- M, pX, py, args... ; wts = extra . wts, contrasts= contrasts, __formula= f, kwargs...
156
+ M, pX, py, args... ; contrasts= contrasts, __formula= f, kwargs...
136
157
)
137
158
end
138
159
@@ -442,6 +463,19 @@ The arguments `X` and `y` can be a `Matrix` and a `Vector` or a `Formula` and a
442
463
rlm (X, y, args... ; kwargs... ) = fit (RobustLinearModel, X, y, args... ; kwargs... )
443
464
444
465
466
+ """
467
+ model_extra_arguments(::Type{M}) where {M<:RobustLinearModel}
468
+
469
+ Get the names of extra array arguments that are used by the model.
470
+ For RobustLinearModel, [:wts, :offset].
471
+
472
+ Returns an array of extra arguments used by the model.
473
+ """
474
+ function model_extra_arguments (:: Type{M} ; kwargs... ) where {M<: RobustLinearModel }
475
+ return [:wts , :offset ]
476
+ end
477
+
478
+
445
479
"""
446
480
fit(::Type{M},
447
481
X::Union{AbstractMatrix{T},SparseMatrixCSC{T}},
@@ -617,35 +651,6 @@ function StatsAPI.fit(
617
651
return dofit ? fit! (m; fitargs... ) : m
618
652
end
619
653
620
- # # Convert from formula-data to modelmatrix-response calling form
621
- # # the `fit` method must allow the `wts`, `offset`, `contrasts` and `__formula` keyword arguments
622
- function StatsAPI. fit (
623
- :: Type{M} ,
624
- f:: FormulaTerm ,
625
- data,
626
- args... ;
627
- dropmissing:: Bool = false ,
628
- wts:: Union{Nothing,Symbol,FPVector} = nothing ,
629
- offset:: Union{Nothing,Symbol,FPVector} = nothing ,
630
- contrasts:: AbstractDict{Symbol,Any} = Dict {Symbol,Any} (),
631
- kwargs... ,
632
- ) where {M<: RobustLinearModel }
633
- # Extract arrays from data using formula
634
- f, y, X, extra = modelframe (f, data, contrasts, dropmissing, M; wts= wts, offset= offset)
635
- # Call the `fit` method with arrays
636
- return fit (
637
- M,
638
- X,
639
- y,
640
- args... ;
641
- wts= extra. wts,
642
- offset= extra. offset,
643
- contrasts= contrasts,
644
- __formula= f,
645
- kwargs... ,
646
- )
647
- end
648
-
649
654
650
655
"""
651
656
fit(::Type{M}, X, y; kwarg...) where {M<:RobustLinearModel}
0 commit comments