This repository was archived by the owner on Apr 26, 2021. It is now read-only.

Commit 69e98e5

Commit message: update
1 parent (5f63a9d), commit 69e98e5

2 files changed (Project.toml, README.md): +80 -61 lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 name = "JLBoostMLJ"
 uuid = "8b86df2c-1bc3-481d-95df-1c4d5a20ed95"
 authors = ["Dai ZJ <zhuojia.dai@gmail.com>"]
-version = "0.1.7"
+version = "0.1.8"

 [deps]
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@@ -15,8 +15,8 @@ ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
 DataFrames = "0.21"
 JLBoost = "^0.1.8"
 LossFunctions = "0.5, 0.6"
-MLJ = "0.10, 0.11, 0.12"
-MLJBase = "0.12, 0.13, 0.14"
+MLJ = "0.10, 0.11, 0.12, 0.13"
+MLJBase = "0.12, 0.13, 0.14, 0.15"
 ScientificTypes = "0.7, 0.8, 1.0"
 julia = "1"

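The `[compat]` widening above only declares that JLBoostMLJ 0.1.8 tolerates MLJ 0.13 and MLJBase 0.15; it does not guarantee that the resolver will actually pick those versions in a given environment. A minimal sketch for exercising the new bounds in a throwaway project, assuming the registered release matches this commit (the package names come from this diff; the pinned versions are illustrative):

````julia
# Resolve JLBoostMLJ together with the newly allowed MLJ release in a
# temporary project, so the check does not disturb the default environment.
using Pkg

Pkg.activate(mktempdir())                                  # throwaway project
Pkg.add([
    PackageSpec(name = "JLBoostMLJ", version = "0.1.8"),   # version bumped in this commit
    PackageSpec(name = "MLJ", version = "0.13"),           # upper end of the widened compat
])
Pkg.status()   # confirm which MLJ / MLJBase versions the resolver picked
````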
README.md

Lines changed: 77 additions & 58 deletions
@@ -6,6 +6,7 @@ The [MLJ.jl](https://github.com/alan-turing-institute/MLJ.jl) interface to [JLBo
 ## Usage Example

 ````julia
+
 using RDatasets;
 iris = dataset("datasets", "iris");
 iris[!, :is_setosa] = iris.Species .== "setosa";
@@ -20,15 +21,15 @@ model = JLBoostClassifier()

 ````
 JLBoostClassifier(
-loss = JLBoost.LogitLogLoss(),
+loss = LogitLogLoss(),
 nrounds = 1,
 subsample = 1.0,
 eta = 1.0,
 max_depth = 6,
 min_child_weight = 1.0,
 lambda = 0.0,
 gamma = 0.0,
-colsample_bytree = 1) @813
+colsample_bytree = 1) @810
 ````


@@ -40,15 +41,17 @@ JLBoostClassifier(
 Put the model and data in a machine

 ````julia
+
 mljmachine = machine(model, X, y)
 ````


 ````
-Machine{JLBoostClassifier} @138 trained 0 times.
+Machine{JLBoostClassifier} @772 trained 0 times.
 args:
-1: Source @021 ⏎ `Table{AbstractArray{Continuous,1}}`
-2: Source @672 ⏎ `AbstractArray{Count,1}`
+1: Source @097 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
+ontinuous,1}}`
+2: Source @077 ⏎ `AbstractArray{ScientificTypes.Count,1}`
 ````


@@ -58,6 +61,7 @@ Machine{JLBoostClassifier} @138 trained 0 times.
 Fit model using machine

 ````julia
+
 fit!(mljmachine)
 ````

@@ -77,10 +81,11 @@ Choosing a split on SepalLength
 Choosing a split on SepalWidth
 Choosing a split on PetalLength
 Choosing a split on PetalWidth
-Machine{JLBoostClassifier} @138 trained 1 time.
+Machine{JLBoostClassifier} @772 trained 1 time.
 args:
-1: Source @021 ⏎ `Table{AbstractArray{Continuous,1}}`
-2: Source @672 ⏎ `AbstractArray{Count,1}`
+1: Source @097 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
+ontinuous,1}}`
+2: Source @077 ⏎ `AbstractArray{ScientificTypes.Count,1}`
 ````


@@ -90,32 +95,34 @@ Machine{JLBoostClassifier} @138 trained 1 time.
 Predict using machine

 ````julia
+
 predict(mljmachine, X)
 ````


 ````
-150-element Array{UnivariateFinite{Multiclass{2},Bool,UInt32,Float64},1}:
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
+150-element Array{MLJBase.UnivariateFinite{ScientificTypes.Multiclass{2},Bo
+ol,UInt32,Float64},1}:
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
 ⋮
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
 ````


@@ -125,6 +132,7 @@ predict(mljmachine, X)
 Feature importance using machine

 ````julia
+
 feature_importance(fitted_params(mljmachine).fitresult, X, y)
 ````

@@ -146,12 +154,13 @@ feature_importance(fitted_params(mljmachine).fitresult, X, y)
 Data preparation: need to convert `y` to categorical

 ````julia
+
 y_cate = categorical(y)
 ````


 ````
-150-element CategoricalArray{Bool,1,UInt32}:
+150-element CategoricalArrays.CategoricalArray{Bool,1,UInt32}:
 true
 true
 true
@@ -181,6 +190,7 @@ y_cate = categorical(y)
 Set up some hyperparameter ranges

 ````julia
+
 using JLBoost, JLBoostMLJ, MLJ
 jlb = JLBoostClassifier()
 r1 = range(jlb, :nrounds, lower=1, upper = 6)
@@ -199,16 +209,18 @@ MLJBase.NumericRange(Float64, :eta, ... )

 Set up the machine
 ````julia
+
 tm = TunedModel(model = jlb, ranges = [r1, r2, r3], measure = cross_entropy)
 m = machine(tm, X, y_cate)
 ````


 ````
-Machine{ProbabilisticTunedModel{Grid,…}} @666 trained 0 times.
+Machine{ProbabilisticTunedModel{Grid,…}} @388 trained 0 times.
 args:
-1: Source @209 ⏎ `Table{AbstractArray{Continuous,1}}`
-2: Source @449 ⏎ `AbstractArray{Multiclass{2},1}`
+1: Source @578 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
+ontinuous,1}}`
+2: Source @226 ⏎ `AbstractArray{ScientificTypes.Multiclass{2},1}`
 ````


@@ -217,15 +229,17 @@ Machine{ProbabilisticTunedModel{Grid,…}} @666 trained 0 times.

 Fit it!
 ````julia
+
 fit!(m)
 ````


 ````
-Machine{ProbabilisticTunedModel{Grid,…}} @666 trained 1 time.
+Machine{ProbabilisticTunedModel{Grid,…}} @388 trained 1 time.
 args:
-1: Source @209 ⏎ `Table{AbstractArray{Continuous,1}}`
-2: Source @449 ⏎ `AbstractArray{Multiclass{2},1}`
+1: Source @578 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
+ontinuous,1}}`
+2: Source @226 ⏎ `AbstractArray{ScientificTypes.Multiclass{2},1}`
 ````


@@ -234,6 +248,7 @@ Machine{ProbabilisticTunedModel{Grid,…}} @666 trained 1 time.

 Inspected the tuned parameters
 ````julia
+
 fitted_params(m).best_model.max_depth
 fitted_params(m).best_model.nrounds
 fitted_params(m).best_model.eta
@@ -252,6 +267,7 @@ fitted_params(m).best_model.eta

 Fit the model with `verbosity = 1`
 ````julia
+
 mljmodel = fit(model, 1, X, y)
 ````

@@ -271,15 +287,15 @@ Choosing a split on SepalLength
 Choosing a split on SepalWidth
 Choosing a split on PetalLength
 Choosing a split on PetalWidth
-(fitresult = (treemodel = JLBoost.JLBoostTrees.JLBoostTreeModel(JLBoost.JLB
-oostTrees.AbstractJLBoostTree[eta = 1.0 (tree weight)
+(fitresult = (treemodel = JLBoostTreeModel(AbstractJLBoostTree[eta = 1.0 (t
+ree weight)

 -- PetalLength <= 1.9
 ---- weight = 2.0

 -- PetalLength > 1.9
 ---- weight = -2.0
-], JLBoost.LogitLogLoss(), :__y__),
+], LogitLogLoss(), :__y__),
 target_levels = Bool[0, 1],),
 cache = nothing,
 report = (AUC = 0.16666666666666669,
@@ -296,32 +312,34 @@ oostTrees.AbstractJLBoostTree[eta = 1.0 (tree weight)
 Predicting using the model

 ````julia
+
 predict(model, mljmodel.fitresult, X)
 ````


 ````
-150-element Array{UnivariateFinite{Multiclass{2},Bool,UInt32,Float64},1}:
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
-UnivariateFinite{Multiclass{2}}(false=>0.881, true=>0.119)
+150-element Array{MLJBase.UnivariateFinite{ScientificTypes.Multiclass{2},Bo
+ol,UInt32,Float64},1}:
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.881, true=>0.119)
 ⋮
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
-UnivariateFinite{Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
+UnivariateFinite{ScientificTypes.Multiclass{2}}(false=>0.119, true=>0.881)
 ````


@@ -332,6 +350,7 @@ Feature Importance for simple fitting
 One can obtain the feature importance using the `feature_importance` function

 ````julia
+
 feature_importance(mljmodel.fitresult.treemodel, X, y)
 ````

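The README hunks above walk through the JLBoostMLJ workflow step by step, but the diff only shows fragments of each step. Below is a consolidated sketch assembled from those fragments. It assumes `X` holds the four numeric iris feature columns and `y` is the `is_setosa` vector, and it guesses the `r2`/`r3` tuning ranges, since none of those definitions appear in the hunks shown here.

````julia
# Consolidated sketch of the README workflow shown in the diff above.
# Lines marked "assumed" do not appear in these hunks and are guesses.
using RDatasets, JLBoost, JLBoostMLJ, MLJ

iris = dataset("datasets", "iris")
iris[!, :is_setosa] = iris.Species .== "setosa"
X = iris[!, [:SepalLength, :SepalWidth, :PetalLength, :PetalWidth]]  # assumed feature columns
y = iris.is_setosa                                                   # assumed target

# Basic machine workflow: wrap model and data, fit, predict, inspect importance.
model      = JLBoostClassifier()
mljmachine = machine(model, X, y)
fit!(mljmachine)
predict(mljmachine, X)
feature_importance(fitted_params(mljmachine).fitresult, X, y)

# Hyperparameter tuning with a TunedModel, as in the tuning hunks.
y_cate = categorical(y)          # target must be categorical for the tuned classifier
jlb = JLBoostClassifier()
r1  = range(jlb, :nrounds, lower = 1, upper = 6)
r2  = range(jlb, :max_depth, lower = 1, upper = 6)   # assumed; not shown in the diff
r3  = range(jlb, :eta, lower = 0.1, upper = 1.0)     # assumed; the hunk header hints at :eta
tm  = TunedModel(model = jlb, ranges = [r1, r2, r3], measure = cross_entropy)
m   = machine(tm, X, y_cate)
fit!(m)
fitted_params(m).best_model.nrounds   # inspect the tuned hyperparameters
````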