Skip to content

Commit 8f791df

Browse files
authored
Merge pull request #98 from yuehhua/datasets
Add datasets
2 parents e433462 + a7cf730 commit 8f791df

File tree

17 files changed

+306
-47
lines changed

17 files changed

+306
-47
lines changed

Manifest.toml

Lines changed: 82 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ version = "0.0.4"
2626

2727
[[ArrayLayouts]]
2828
deps = ["FillArrays", "LinearAlgebra", "SparseArrays"]
29-
git-tree-sha1 = "e3e0a1e7dcbfdb1fc1061bfd889581a1d942cfcb"
29+
git-tree-sha1 = "bd09f450716f55c5a47b24de277a8825e2450729"
3030
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
31-
version = "0.4.5"
31+
version = "0.4.7"
3232

3333
[[Base64]]
3434
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@@ -39,6 +39,24 @@ git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058"
3939
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
4040
version = "0.5.10"
4141

42+
[[Blosc]]
43+
deps = ["Blosc_jll"]
44+
git-tree-sha1 = "84cf7d0f8fd46ca6f1b3e0305b4b4a37afe50fd6"
45+
uuid = "a74b3585-a348-5f62-a45c-50e91977d574"
46+
version = "0.7.0"
47+
48+
[[Blosc_jll]]
49+
deps = ["Libdl", "Lz4_jll", "Pkg", "Zlib_jll", "Zstd_jll"]
50+
git-tree-sha1 = "aa9ef39b54a168c3df1b2911e7797e4feee50fbe"
51+
uuid = "0b7ba130-8d10-5ba8-a3d6-c5182647fed9"
52+
version = "1.14.3+1"
53+
54+
[[BufferedStreams]]
55+
deps = ["Compat", "Test"]
56+
git-tree-sha1 = "5d55b9486590fdda5905c275bb21ce1f0754020f"
57+
uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d"
58+
version = "1.0.0"
59+
4260
[[CEnum]]
4361
git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
4462
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -86,6 +104,12 @@ git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
86104
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
87105
version = "0.3.0"
88106

107+
[[Compat]]
108+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
109+
git-tree-sha1 = "215f1c81cfd1c5416cd78740bff8ef59b24cd7c0"
110+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
111+
version = "3.15.0"
112+
89113
[[CompilerSupportLibraries_jll]]
90114
deps = ["Libdl", "Pkg"]
91115
git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612"
@@ -152,15 +176,21 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
152176
version = "0.8.3"
153177

154178
[[ExprTools]]
155-
git-tree-sha1 = "6f0517056812fd6aa3af23d4b70d5325a2ae4e95"
179+
git-tree-sha1 = "7fce513fcda766962ff67c5596cb16c463dfd371"
156180
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
157-
version = "0.1.1"
181+
version = "0.1.2"
182+
183+
[[FileIO]]
184+
deps = ["Pkg"]
185+
git-tree-sha1 = "992b4aeb62f99b69fcf0cb2085094494cc05dfb3"
186+
uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
187+
version = "1.4.3"
158188

159189
[[FillArrays]]
160190
deps = ["LinearAlgebra", "Random", "SparseArrays"]
161-
git-tree-sha1 = "56dc5338eb79a05e5ea120f510ed77efe7e9784d"
191+
git-tree-sha1 = "b955c227b0d1413a1a97e2ca0635a5de019d7337"
162192
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
163-
version = "0.9.5"
193+
version = "0.9.6"
164194

165195
[[FixedPointNumbers]]
166196
deps = ["Statistics"]
@@ -214,6 +244,18 @@ git-tree-sha1 = "6e62e16c779458412951a71f4d535f05a1e0bb89"
214244
uuid = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
215245
version = "0.1.1"
216246

247+
[[HDF5]]
248+
deps = ["Blosc", "HDF5_jll", "Libdl", "Mmap", "Random"]
249+
git-tree-sha1 = "0713cbabdf855852dfab3ce6447c87145f3d9ea8"
250+
uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
251+
version = "0.13.6"
252+
253+
[[HDF5_jll]]
254+
deps = ["Libdl", "Pkg", "Zlib_jll"]
255+
git-tree-sha1 = "85bd2e586a10ae0eab856125bf5245e0d36384a7"
256+
uuid = "0234f1f7-429e-5d53-9886-15a909be8d59"
257+
version = "1.10.5+5"
258+
217259
[[HTTP]]
218260
deps = ["Base64", "Dates", "IniFile", "MbedTLS", "Sockets"]
219261
git-tree-sha1 = "2ac03263ce44be4222342bca1c51c36ce7566161"
@@ -242,16 +284,16 @@ deps = ["Markdown"]
242284
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
243285

244286
[[JLD2]]
245-
deps = ["CodecZlib", "DataStructures", "MacroTools", "Mmap", "Pkg", "Printf", "Requires", "UUIDs"]
246-
git-tree-sha1 = "d2c0db66530ff444846d6e84bcf948a74ce31635"
287+
deps = ["CodecZlib", "DataStructures", "FileIO", "Mmap", "Pkg", "Printf", "UUIDs"]
288+
git-tree-sha1 = "9353b717ee4e27beab4e902c92a06bb5f160d2cf"
247289
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
248-
version = "0.2.0"
290+
version = "0.1.14"
249291

250292
[[JSON]]
251293
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
252-
git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
294+
git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
253295
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
254-
version = "0.21.0"
296+
version = "0.21.1"
255297

256298
[[Juno]]
257299
deps = ["Base64", "Logging", "Media", "Profile"]
@@ -287,9 +329,21 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
287329

288330
[[LoopVectorization]]
289331
deps = ["DocStringExtensions", "LinearAlgebra", "OffsetArrays", "SIMDPirates", "SLEEFPirates", "UnPack", "VectorizationBase"]
290-
git-tree-sha1 = "224c9768765c2a3b588fec71cff48b8eb1c80c48"
332+
git-tree-sha1 = "3242a8f411e19eda9adc49d0b877681975c11375"
291333
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
292-
version = "0.8.25"
334+
version = "0.8.26"
335+
336+
[[Lz4_jll]]
337+
deps = ["Libdl", "Pkg"]
338+
git-tree-sha1 = "51b1db0732bbdcfabb60e36095cc3ed9c0016932"
339+
uuid = "5ced341a-0733-55b8-9ab6-a4889d929147"
340+
version = "1.9.2+2"
341+
342+
[[MAT]]
343+
deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"]
344+
git-tree-sha1 = "7e36f6a52274ddb8515ec1f559306be3f412d6a6"
345+
uuid = "23992714-dd62-5051-b70f-ba57cb901cac"
346+
version = "0.8.1"
293347

294348
[[MacroTools]]
295349
deps = ["Markdown", "Random"]
@@ -320,10 +374,10 @@ uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
320374
version = "0.5.0"
321375

322376
[[MetaGraphs]]
323-
deps = ["JLD2", "LightGraphs", "Random"]
324-
git-tree-sha1 = "43ebbe06b22d213e4a8750424f9c7d1311bee2a6"
377+
deps = ["JLD2", "LightGraphs"]
378+
git-tree-sha1 = "c6a4c88304e1ecef6fc372f12d3b8e427e128c1a"
325379
uuid = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
326-
version = "0.6.4"
380+
version = "0.6.3"
327381

328382
[[Missings]]
329383
deps = ["DataAPI"]
@@ -415,9 +469,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
415469

416470
[[SIMDPirates]]
417471
deps = ["VectorizationBase"]
418-
git-tree-sha1 = "26ccdd1466f3071e27e81b43216ea238b62c0c42"
472+
git-tree-sha1 = "450d163d3279a1d35e3aad3352a5167ef21b84a4"
419473
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
420-
version = "0.8.24"
474+
version = "0.8.25"
421475

422476
[[SLEEFPirates]]
423477
deps = ["Libdl", "SIMDPirates", "VectorizationBase"]
@@ -515,9 +569,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
515569

516570
[[VectorizationBase]]
517571
deps = ["CpuId", "Libdl", "LinearAlgebra"]
518-
git-tree-sha1 = "c2a34c8065076a867fc36522c1a3441156a63445"
572+
git-tree-sha1 = "03e2fbb479a1ea350398195b6fbf439bae0f8260"
519573
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
520-
version = "0.12.32"
574+
version = "0.12.33"
521575

522576
[[VersionParsing]]
523577
git-tree-sha1 = "80229be1f670524750d905f8fc8148e5a8c4537f"
@@ -536,11 +590,17 @@ git-tree-sha1 = "fdd89e5ab270ea0f2a0174bd9093e557d06d4bfa"
536590
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
537591
version = "1.2.11+16"
538592

593+
[[Zstd_jll]]
594+
deps = ["Libdl", "Pkg"]
595+
git-tree-sha1 = "4de91f4313d9e88162d461e282fe3066ab3a3c09"
596+
uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
597+
version = "1.4.5+1"
598+
539599
[[Zygote]]
540600
deps = ["AbstractFFTs", "ArrayLayouts", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
541-
git-tree-sha1 = "fa45127afd117fca6a8540f92d5e3799daf8339a"
601+
git-tree-sha1 = "b0a948a0a78e3e41515714fa1ef4f40a284ffa06"
542602
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
543-
version = "0.5.5"
603+
version = "0.5.6"
544604

545605
[[ZygoteRules]]
546606
deps = ["MacroTools"]

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
1313
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
1414
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
1515
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
16+
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
1617
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
19+
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
1820
MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
1921
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
2022
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -36,9 +38,11 @@ Flux = "0.10 - 0.11"
3638
GraphSignals = "0.1"
3739
HTTP = "0.8"
3840
IRTools = "0.4"
39-
JLD2 = "0.2"
41+
JLD2 = "0.1 - 0.2"
42+
JSON = "0.21"
4043
LightGraphs = "1.3"
41-
MetaGraphs = "0.6"
44+
MAT = "0.8"
45+
MetaGraphs = "< 0.6.4"
4246
PyCall = "1.91"
4347
Requires = "1.0.0"
4448
ScatterNNlib = "0.1"

src/GeometricFlux.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ using ZygoteRules
1919
import Flux: maxpool, meanpool
2020

2121
export
22-
# datasets
22+
datasets,
2323
traindata,
24+
validdata,
2425
testdata,
2526

2627
# layers/gn

src/datasets/Datasets.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,40 @@ module Datasets
22
using DataDeps: DataDep, register, @datadep_str
33
using HTTP
44
using JLD2
5+
using JSON
6+
using LightGraphs: SimpleDiGraph, add_edge!
7+
using MAT
58
using PyCall
6-
using SparseArrays: SparseMatrixCSC
9+
using SparseArrays: SparseMatrixCSC, sparse
710

811
export
912
Dataset,
1013
Planetoid,
1114
Cora,
15+
PPI,
16+
Reddit,
17+
QM7b,
18+
# Entities,
1219
dataset,
1320
traindata,
21+
validdata,
1422
testdata
1523

16-
abstract type Dataset end
17-
24+
include("./dataset.jl")
1825
include("./planetoid.jl")
1926
include("./cora.jl")
2027
include("./ppi.jl")
2128
include("./reddit.jl")
22-
# include("./qm7b.jl")
29+
include("./qm7b.jl")
2330
# include("./entities.jl")
2431
include("./datautils.jl")
2532

2633
function __init__()
2734
planetoid_init()
2835
cora_init()
36+
ppi_init()
37+
reddit_init()
38+
qm7b_init()
39+
# entities_init()
2940
end
3041
end

src/datasets/dataset.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
abstract type Dataset end
2+
3+
# function dataset(::Dataset)
4+
# throw()
5+
# end
6+
7+
# function traindata(::Dataset)
8+
9+
# end
10+
11+
# function testdata(::Dataset)
12+
13+
# end

src/datasets/datautils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,8 @@ function download_file(url, path)
55
end
66
end
77
end
8+
9+
function unzip(zipfile::String)
10+
f = replace(zipfile, ".zip"=>"")
11+
run(`unzip $f`)
12+
end

src/datasets/ppi.jl

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
ppi_init() = register(DataDep(
2-
"Cora full datasets",
2+
"PPI",
33
"""
44
The protein-protein interaction networks from the `"Predicting
55
Multicellular Function through Multi-layer Tissue Networks"
@@ -8,7 +8,63 @@ ppi_init() = register(DataDep(
88
total) and gene ontology sets as labels (121 in total).
99
""",
1010
"https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/ppi.zip",
11-
"";
12-
fetch_method=http_download,
13-
post_fetch_method=DataDeps.unpack,
11+
"1f5b2b09ac0f897fa6aa1338c64ab75a5473674cbba89380120bede8cddb2a6a";
12+
post_fetch_method=preprocess_ppi,
1413
))
14+
15+
function preprocess_ppi(local_path)
16+
unzip(local_path)
17+
18+
for phase in ["train", "test", "valid"]
19+
graph_file = @datadep_str "PPI/$(phase)_graph.json"
20+
id_file = @datadep_str "PPI/$(phase)_graph_id.npy"
21+
X_file = @datadep_str "PPI/$(phase)_feats.npy"
22+
y_file = @datadep_str "PPI/$(phase)_labels.npy"
23+
24+
py"""
25+
import numpy as np
26+
ids = np.load($id_file)
27+
X = np.load($X_file)
28+
y = np.load($y_file)
29+
"""
30+
31+
X = Matrix{Float32}(py"X")
32+
y = SparseMatrixCSC{Int32,Int64}(Array(py"y"))
33+
ids = Array(py"ids")
34+
graph = read_ppi_graph(graph_file)
35+
36+
jld2file = replace(local_path, "ppi.zip"=>"ppi.$(phase).jld2")
37+
@save jld2file graph X y ids
38+
end
39+
end
40+
41+
function read_ppi_graph(filename::String)
42+
d = JSON.Parser.parsefile(filename)
43+
g = SimpleDiGraph{Int32}(length(d["nodes"]))
44+
45+
for pair in d["links"]
46+
add_edge!(g, pair["source"], pair["target"])
47+
end
48+
g
49+
end
50+
51+
struct PPI <: Dataset
52+
end
53+
54+
function traindata(::PPI)
55+
file = datadep"PPI/ppi.train.jld2"
56+
@load file graph X y ids
57+
graph, X, y, ids
58+
end
59+
60+
function validdata(::PPI)
61+
file = datadep"PPI/ppi.valid.jld2"
62+
@load file graph X y ids
63+
graph, X, y, ids
64+
end
65+
66+
function testdata(::PPI)
67+
file = datadep"PPI/ppi.test.jld2"
68+
@load file graph X y ids
69+
graph, X, y, ids
70+
end

0 commit comments

Comments
 (0)