Skip to content

Commit 0333bd2

Browse files
committed
update load_data
1 parent 38c77d0 commit 0333bd2

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/load_data.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
function load_data(name)
2+
datasets = ["iris", "adult", "digits"]
23
data_path = joinpath(dirname(pathof(DecisionTree)), "..", "test/data/")
34

45
if name == "digits"
@@ -10,20 +11,24 @@ function load_data(name)
1011
data = hcat(data...)
1112
Y = Int.(data[1, 1:end]) .+ 1
1213
X = convert(Matrix, transpose(data[2:end, 1:end]))
13-
return X, Y
14+
return X, Y
1415
end
1516

1617
if name == "iris"
1718
iris = DelimitedFiles.readdlm(joinpath(data_path, "iris.csv"), ',')
1819
X = iris[:, 1:4]
1920
Y = iris[:, 5]
20-
return X, Y
21+
return X, Y
2122
end
2223

2324
if name == "adult"
2425
adult = DelimitedFiles.readdlm(joinpath(data_path, "adult.csv"), ',');
2526
X = adult[:, 1:14];
2627
Y = adult[:, 15];
27-
return X, Y
28+
return X, Y
29+
end
30+
31+
if !(name in datasets)
32+
throw("Available datasets are $(join(datasets,", "))")
2833
end
2934
end

0 commit comments

Comments
 (0)