Skip to content

Commit 18b8593

Browse files
Merge pull request #35 from tylerjthomas9/dev
fix some MLJ warnings, add MMI.reformat for eventual update support
2 parents c9461f7 + 3c9847f commit 18b8593

File tree

7 files changed

+26
-42
lines changed

7 files changed

+26
-42
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
1212
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1313

1414
[compat]
15-
CUDA = "3"
15+
CUDA = "3, 4"
1616
CondaPkg = "0.2"
1717
MLJBase = "0.20, 0.21"
1818
MLJModelInterface = "1"

src/CuML/CuML.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,29 @@ const CUML_MODELS = Union{
2424
CUML_TIME_SERIES,
2525
}
2626

27-
function MMI.reformat(::CUML_MODELS, X)
28-
return (to_numpy(X),)
27+
function MMI.reformat(::CUML_MODELS, X, y)
28+
return (to_numpy(X), to_numpy(y))
2929
end
3030

31-
function MMI.reformat(::CUML_MODELS, X, y)
32-
return to_numpy(X), to_numpy(y)
31+
function MMI.reformat(::CUML_MODELS, X)
32+
return (to_numpy(X), )
3333
end
3434

3535
function MMI.selectrows(::CUML_MODELS, I, X)
36-
py_I = numpy.array(I .- 1)
37-
return (X[py_I,],)
36+
py_I = numpy.array(numpy.array(I .- 1))
37+
return (X[py_I,], )
3838
end
3939

4040
function MMI.selectrows(::CUML_MODELS, I::Colon, X)
41-
return (X,)
41+
return (X, )
42+
end
43+
44+
function MMI.selectrows(::CUML_MODELS, I, X, y)
45+
py_I = numpy.array(numpy.array(I .- 1))
46+
return (X[py_I,], y[py_I])
47+
end
48+
function MMI.selectrows(::CUML_MODELS, I::Colon, X, y)
49+
return (X, y)
4250
end
4351

4452
MMI.clean!(model::CUML_MODELS) = ""

src/CuML/classification.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ function MMI.input_scitype(::Type{<:CUML_CLASSIFICATION})
127127
AbstractMatrix{MMI.Continuous},
128128
}
129129
end
130-
MMI.target_scitype(::Type{<:CUML_CLASSIFICATION}) = AbstractVector{<:Finite}
130+
MMI.target_scitype(::Type{<:CUML_CLASSIFICATION}) = Union{AbstractVector{<:Finite}, AbstractVector{MMI.Continuous}}
131131

132132
function MMI.docstring(::Type{<:LogisticRegression})
133133
return "cuML's LogisticRegression: https://docs.rapids.ai/api/cuml/stable/api.html#logistic-regression"

src/CuML/dimensionality_reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ MMI.load_path(::Type{<:TSNE}) = "$PKG.CuML.TSNE"
122122

123123
function MMI.input_scitype(::Type{<:CUML_DIMENSIONALITY_REDUCTION})
124124
return Union{
125-
MMI.Table(MMI.Continuous, MMI.Count, MMI.OrderedFactor, MMI.Multiclass),
125+
Table{<:Union{AbstractVector{<:Continuous}, AbstractVector{<:Count}, AbstractVector{<:OrderedFactor}, AbstractVector{<:Multiclass}}},
126126
AbstractMatrix{MMI.Continuous},
127127
}
128128
end

src/RAPIDS.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ if Base.VERSION <= v"1.8.3"
1818
@warn warning_msg
1919
end
2020

21-
if !CUDA.functional()
21+
if !CUDA.has_cuda_gpu()
2222
@warn "No CUDA GPU Detected. Unable to load RAPIDS."
2323
const cucim = nothing
2424
const cudf = nothing

test/cuml.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,14 @@ end
142142
end
143143

144144
@testset "SVC" begin
145-
model = SVC()
145+
model = SVC(probability=true)
146146
mach = machine(model, X, y)
147147
fit!(mach)
148148
preds = predict(mach, X)
149149
end
150150

151151
@testset "LinearSVC" begin
152-
model = LinearSVC()
152+
model = LinearSVC(probability=true)
153153
mach = machine(model, X, y)
154154
fit!(mach)
155155
preds = predict(mach, X)

test/cuml_integration.jl

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
],
1717
MLJTestInterface.make_regression()...;
1818
mod=@__MODULE__,
19-
verbosity=0, # bump to debug
19+
verbosity=1, # bump to debug
2020
throw=false,
2121
)
2222
@test isempty(failures)
@@ -30,41 +30,17 @@
3030
y[y_string .== "O"] .= 1.0
3131
failures, summary = MLJTestInterface.test(
3232
[
33-
LogisticRegression,
33+
#LogisticRegression,
3434
MBSGDClassifier,
3535
RandomForestClassifier,
36-
SVC,
37-
LinearSVC,
36+
#SVC,
37+
#LinearSVC,
3838
KNeighborsClassifier,
3939
],
4040
X,
4141
y;
4242
mod=@__MODULE__,
43-
verbosity=0, # bump to debug
44-
throw=false,
45-
)
46-
@test isempty(failures)
47-
end
48-
49-
@testset "Binary Classification" begin
50-
X, y_string = MLJTestInterface.make_binary()
51-
# RAPIDS can only handle numeric values
52-
# TODO: add support for non-numeric labels
53-
y = zeros(200)
54-
y[y_string .== "O"] .= 1.0
55-
failures, summary = MLJTestInterface.test(
56-
[
57-
LogisticRegression,
58-
MBSGDClassifier,
59-
RandomForestClassifier,
60-
SVC,
61-
LinearSVC,
62-
KNeighborsClassifier,
63-
],
64-
X,
65-
y;
66-
mod=@__MODULE__,
67-
verbosity=0, # bump to debug
43+
verbosity=1, # bump to debug
6844
throw=false,
6945
)
7046
@test isempty(failures)

0 commit comments

Comments
 (0)