Skip to content
This repository was archived by the owner on Apr 26, 2021. It is now read-only.

Commit 926da7e

Browse files
committed
updated to MLJ 0.14
1 parent 69e98e5 commit 926da7e

File tree

3 files changed

+23
-23
lines changed

3 files changed

+23
-23
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "JLBoostMLJ"
22
uuid = "8b86df2c-1bc3-481d-95df-1c4d5a20ed95"
33
authors = ["Dai ZJ <zhuojia.dai@gmail.com>"]
4-
version = "0.1.8"
4+
version = "0.1.9"
55

66
[deps]
77
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@@ -15,7 +15,7 @@ ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
1515
DataFrames = "0.21"
1616
JLBoost = "^0.1.8"
1717
LossFunctions = "0.5, 0.6"
18-
MLJ = "0.10, 0.11, 0.12, 0.13"
18+
MLJ = "0.10, 0.11, 0.12, 0.13, 0.14"
1919
MLJBase = "0.12, 0.13, 0.14, 0.15"
2020
ScientificTypes = "0.7, 0.8, 1.0"
2121
julia = "1"

README.md

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ model = JLBoostClassifier()
2121

2222
````
2323
JLBoostClassifier(
24-
loss = LogitLogLoss(),
24+
loss = JLBoost.LogitLogLoss(),
2525
nrounds = 1,
2626
subsample = 1.0,
2727
eta = 1.0,
2828
max_depth = 6,
2929
min_child_weight = 1.0,
3030
lambda = 0.0,
3131
gamma = 0.0,
32-
colsample_bytree = 1) @810
32+
colsample_bytree = 1) @087
3333
````
3434

3535

@@ -47,11 +47,11 @@ mljmachine = machine(model, X, y)
4747

4848

4949
````
50-
Machine{JLBoostClassifier} @772 trained 0 times.
50+
Machine{JLBoostClassifier} @730 trained 0 times.
5151
args:
52-
1: Source @097 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
52+
1: Source @910 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
5353
ontinuous,1}}`
54-
2: Source @077 ⏎ `AbstractArray{ScientificTypes.Count,1}`
54+
2: Source @954 ⏎ `AbstractArray{ScientificTypes.Count,1}`
5555
````
5656

5757

@@ -81,11 +81,11 @@ Choosing a split on SepalLength
8181
Choosing a split on SepalWidth
8282
Choosing a split on PetalLength
8383
Choosing a split on PetalWidth
84-
Machine{JLBoostClassifier} @772 trained 1 time.
84+
Machine{JLBoostClassifier} @730 trained 1 time.
8585
args:
86-
1: Source @097 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
86+
1: Source @910 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
8787
ontinuous,1}}`
88-
2: Source @077 ⏎ `AbstractArray{ScientificTypes.Count,1}`
88+
2: Source @954 ⏎ `AbstractArray{ScientificTypes.Count,1}`
8989
````
9090

9191

@@ -216,11 +216,11 @@ m = machine(tm, X, y_cate)
216216

217217

218218
````
219-
Machine{ProbabilisticTunedModel{Grid,…}} @388 trained 0 times.
219+
Machine{ProbabilisticTunedModel{Grid,…}} @109 trained 0 times.
220220
args:
221-
1: Source @578 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
221+
1: Source @664 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
222222
ontinuous,1}}`
223-
2: Source @226 ⏎ `AbstractArray{ScientificTypes.Multiclass{2},1}`
223+
2: Source @788 ⏎ `AbstractArray{ScientificTypes.Multiclass{2},1}`
224224
````
225225

226226

@@ -235,11 +235,11 @@ fit!(m)
235235

236236

237237
````
238-
Machine{ProbabilisticTunedModel{Grid,…}} @388 trained 1 time.
238+
Machine{ProbabilisticTunedModel{Grid,…}} @109 trained 1 time.
239239
args:
240-
1: Source @578 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
240+
1: Source @664 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
241241
ontinuous,1}}`
242-
2: Source @226 ⏎ `AbstractArray{ScientificTypes.Multiclass{2},1}`
242+
2: Source @788 ⏎ `AbstractArray{ScientificTypes.Multiclass{2},1}`
243243
````
244244

245245

@@ -287,15 +287,15 @@ Choosing a split on SepalLength
287287
Choosing a split on SepalWidth
288288
Choosing a split on PetalLength
289289
Choosing a split on PetalWidth
290-
(fitresult = (treemodel = JLBoostTreeModel(AbstractJLBoostTree[eta = 1.0 (t
291-
ree weight)
290+
(fitresult = (treemodel = JLBoost.JLBoostTrees.JLBoostTreeModel(JLBoost.JLB
291+
oostTrees.AbstractJLBoostTree[eta = 1.0 (tree weight)
292292
293293
-- PetalLength <= 1.9
294294
---- weight = 2.0
295295
296296
-- PetalLength > 1.9
297297
---- weight = -2.0
298-
], LogitLogLoss(), :__y__),
298+
], JLBoost.LogitLogLoss(), :__y__),
299299
target_levels = Bool[0, 1],),
300300
cache = nothing,
301301
report = (AUC = 0.16666666666666669,

src/mlj.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
export fit, predict, fitted_params, JLBoostMLJModel, JLBoostClassifier, JLBoostRegressor, JLBoostCount
22

3-
#using MLJBase
3+
import MLJBase
44
import MLJBase: Probabilistic, Deterministic, clean!, fit, predict, fitted_params, load_path, Table
55
import MLJBase: package_name, package_uuid, package_url, is_pure_julia, package_license
66
import MLJBase: input_scitype, target_scitype, docstring, UnivariateFinite
77

88
using ScientificTypes: Continuous, OrderedFactor, Count, Multiclass, Finite
99

1010
using LossFunctions: PoissonLoss, L2DistLoss
11-
using JLBoost: LogitLogLoss, jlboost, AUC, gini, feature_importance, predict
11+
using JLBoost: LogitLogLoss, jlboost, AUC, gini, feature_importance
1212

1313
using DataFrames: DataFrame, nrow, levels, categorical
1414

@@ -194,15 +194,15 @@ fitted_params(model::JLBoostMLJModel, fitresult) = (fitresult = fitresult.treemo
194194

195195

196196
# see https://alan-turing-institute.github.io/MLJ.jl/stable/adding_models_for_general_use/#The-predict-method-1
197-
predict(model::JLBoostClassifier, fitresult, Xnew) = begin
197+
function MLJBase.predict(model::JLBoostClassifier, fitresult, Xnew)
198198
res = JLBoost.predict(fitresult.treemodel, Xnew)
199199
p = 1 ./ (1 .+ exp.(-res))
200200
levels_cate = categorical(fitresult.target_levels)
201201
[UnivariateFinite(levels_cate, [p, 1-p]) for p in p]
202202
end
203203

204204

205-
predict(model::JLBoostMLJModel, fitresult, Xnew) = begin
205+
function MLJBase.predict(model::JLBoostMLJModel, fitresult, Xnew)
206206
JLBoost.predict(fitresult, Xnew)
207207
end
208208

0 commit comments

Comments
 (0)