Skip to content

Commit e97f410

Browse files
MaximilianSchreffphaniarnab
authored andcommitted
[SYSTEMDS-3851] Builtin for embedding Layer
This patch adds the embedding layer as a built-in operator in our nn/layers library. The functionality is similar to pytorch.nn.Embedding (https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) The layer receives indices as input which refer to indices of an embedding dictionary and returns an embedding matrix where row i refers to embedding vector indices[i] of the embedding dictionary. This layer is used in every transformer architecture. Here the indices usually come from a tokenizer and the embedding matrix is the input to the actual transformer model. Closes #2237
1 parent 9fb1967 commit e97f410

File tree

3 files changed

+242
-0
lines changed

3 files changed

+242
-0
lines changed

scripts/nn/layers/embedding.dml

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
forward = function(matrix[double] indices, matrix[double] embedding_dict)
23+
return (matrix[double] embeddings) {
24+
/*
25+
* Forward pass of an embedding layer. An embedding matrix is constructed
26+
* from indices and corresponding embedding vectors from the embedding
27+
* dictionary.
28+
*
29+
* Inputs:
30+
* - indices: Indices referring to embedding vectors of embedding dictionary
31+
* of shape n x 1 with each value in {1, ..., v}.
32+
* - embedding_dict: Dictionary of embedding vectors of shape v x d.
33+
*
34+
* Outputs:
35+
* - embeddings: Embedding matrix where row i is equal to
36+
* embedding_dict[indices[i]].
37+
*/
38+
n = nrow(indices)
39+
v = nrow(embedding_dict)
40+
41+
# Construct permutation-like matrix (one '1' per row, rest '0')
42+
permutation = matrix(0, rows=n, cols=v)
43+
for (i in 1:n) {
44+
permutation[i, as.integer(as.scalar(indices[i]))] = 1
45+
}
46+
47+
embeddings = permutation %*% embedding_dict
48+
}
49+
50+
backward = function(matrix[double] dout, matrix[double] indices, int v,
51+
int padding_idx = -1)
52+
return (matrix[double] dembedding_dict) {
53+
/*
54+
* Backward pass of embedding layer computes the gradients of the embedding
55+
* dictionary.
56+
*
57+
* Inputs:
58+
* - dout: Gradient of the output.
59+
* - indices: Indices referring to embedding vectors of embedding dictionary
60+
* of shape n x 1 with each value in {1, ..., v}.
61+
* - v: Embedding dictionary size.
62+
* - padding_idx: Index of embedding vector of embedding dictionary which
63+
* should not be updated (i.e. gradients are 0). Use -1 if
64+
* there is no padding vector.
65+
*
66+
* Outputs:
67+
* - dembedding_dict: Gradients of the dictionary of embedding vectors of
68+
* shape v x d.
69+
*/
70+
n = nrow(indices)
71+
72+
# Construct permutation-like matrix (one '1' per row, rest '0')
73+
permutation = matrix(0, rows=n, cols=v)
74+
for (i in 1:n) {
75+
permutation[i, as.integer(as.scalar(indices[i]))] = 1
76+
}
77+
78+
dembedding_dict = t(permutation) %*% dout
79+
if (padding_idx != -1) {
80+
dembedding_dict[padding_idx] = matrix(0, rows=1, cols=ncol(dout))
81+
}
82+
}
83+
84+
init = function(int v, int d, int seed = -1)
85+
return (matrix[double] embedding_dict) {
86+
/*
87+
* Initializes embedding dictionary matrix via N(0, 1).
88+
*
89+
* Inputs:
90+
* - v: Embedding dictionary size.
91+
* - d: Embedding vector dimension.
92+
* - seed: Random generation seed.
93+
*
94+
* Output:
95+
* - embedding_dict: Embedding dictionary matrix of shape v x d.
96+
*/
97+
embedding_dict = rand(rows=v, cols=d, pdf="normal", seed=seed)
98+
}
99+

src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ public void gelu() {
129129
run("gelu.dml");
130130
}
131131

132+
@Test
133+
public void embedding() {
134+
run("embedding.dml");
135+
}
136+
132137
@Override
133138
protected void run(String name) {
134139
super.run("component/" + name);
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
source("nn/layers/embedding.dml") as embedding
23+
source("src/test/scripts/applications/nn/util.dml") as test_util
24+
25+
embedding_test_forward = function() {
26+
print("Testing Embedding - Forward Test")
27+
n = 4
28+
v = 7
29+
d = 3
30+
31+
embedding_dict = matrix("-0.78327566 -0.87246466 -0.80580276
32+
-0.17845497 2.1740944 -1.2514428
33+
-0.27202556 -1.3681601 -1.5384313
34+
1.4215976 -0.463162 1.2592019
35+
-1.7417 -0.46109396 -0.06011621
36+
-0.7803316 1.0802858 0.7465289
37+
0. 0. 0.", rows=v, cols=d)
38+
indices = matrix("1 6 7 6", rows=n, cols=1)
39+
40+
embeddings = embedding::forward(indices, embedding_dict)
41+
42+
expected_embeddings = matrix("-0.78327566 -0.87246466 -0.80580276
43+
-0.7803316 1.0802858 0.7465289
44+
0. 0. 0.
45+
-0.7803316 1.0802858 0.7465289", rows=n, cols=d)
46+
47+
test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
48+
}
49+
50+
embedding_test_forward_backward_no_pad = function() {
51+
print("Testing Embedding - Forward & Backward Test w/out Padding")
52+
n = 2
53+
v = 4
54+
d = 3
55+
56+
embedding_dict = matrix("-0.15039968 0.56168836 -0.577436
57+
0.47334725 1.5215642 -0.1924941
58+
1.600819 -1.1331359 -2.58817
59+
0.9779929 -0.82212716 -1.5917081", rows=v, cols=d)
60+
indices = matrix("2 3", rows=n, cols=1)
61+
62+
embeddings = embedding::forward(indices, embedding_dict)
63+
64+
expected_embeddings = matrix("0.47334725 1.5215642 -0.1924941
65+
1.600819 -1.1331359 -2.58817", rows=n, cols=d)
66+
67+
test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
68+
69+
dout = matrix(seq(1, n*d, 1), rows=n, cols=d)
70+
padding_idx = -1
71+
72+
dembedding_dict = embedding::backward(dout, indices, v, padding_idx)
73+
expected_dembedding_dict = matrix("0. 0. 0.
74+
1. 2. 3.
75+
4. 5. 6.
76+
0. 0. 0.", rows=v, cols=d)
77+
test_util::check_all_close(dembedding_dict, expected_dembedding_dict, 1e-05)
78+
}
79+
80+
embedding_test_forward_backward_pad = function() {
81+
print("Testing Embedding - Forward & Backward Test w/ Padding")
82+
n = 5
83+
v = 10
84+
d = 6
85+
86+
embedding_dict = matrix("-1.24377859e+00 -1.10724878e+00 2.35533118e-01 6.65530920e-01
87+
9.80555452e-03 6.31030917e-01
88+
8.16493928e-01 -6.21011078e-01 -5.75569510e-01 -3.93419750e-02
89+
-6.20878041e-01 1.37852756e-02
90+
7.43950903e-01 1.60437262e+00 -2.31788456e-01 1.15943216e-01
91+
-8.83608997e-01 1.11547875e+00
92+
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
93+
0.00000000e+00 0.00000000e+00
94+
1.70598769e+00 1.82770026e+00 1.30581510e+00 1.05738208e-01
95+
4.50116873e-01 3.48498315e-01
96+
1.40551448e+00 3.43091488e-02 1.84714049e-03 -5.52828193e-01
97+
3.65064174e-01 -9.31223869e-01
98+
1.33713937e+00 -3.43729639e+00 -1.22915792e+00 -1.12923630e-01
99+
-1.16292477e+00 -2.16708351e-02
100+
6.63879395e-01 -2.76697308e-01 -9.02738094e-01 -6.85515344e-01
101+
-6.43863618e-01 -2.30419707e+00
102+
1.44121364e-01 5.20578504e-01 -6.53087497e-01 6.62900746e-01
103+
3.82369667e-01 -2.25386508e-02
104+
2.20637798e+00 -6.86733365e-01 -1.27398467e+00 6.28316283e-01
105+
2.70236313e-01 2.20882833e-01", rows=v, cols=d)
106+
indices = matrix("1 1 1 4 6", rows=n, cols=1)
107+
108+
embeddings = embedding::forward(indices, embedding_dict)
109+
110+
expected_embeddings = matrix("-1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309
111+
-1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309
112+
-1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309
113+
0. 0. 0. 0. 0. 0.
114+
1.4055145 0.03430915 0.00184714 -0.5528282 0.36506417 -0.93122387", rows=n, cols=d)
115+
116+
test_util::check_all_close(embeddings, expected_embeddings, 1e-05)
117+
118+
dout = matrix(seq(1, n*d, 1), rows=n, cols=d)
119+
padding_idx = 4
120+
121+
dembedding_dict = embedding::backward(dout, indices, v, padding_idx)
122+
expected_dembedding_dict = matrix("21. 24. 27. 30. 33. 36.
123+
0. 0. 0. 0. 0. 0.
124+
0. 0. 0. 0. 0. 0.
125+
0. 0. 0. 0. 0. 0.
126+
0. 0. 0. 0. 0. 0.
127+
25. 26. 27. 28. 29. 30.
128+
0. 0. 0. 0. 0. 0.
129+
0. 0. 0. 0. 0. 0.
130+
0. 0. 0. 0. 0. 0.
131+
0. 0. 0. 0. 0. 0.", rows=v, cols=d)
132+
test_util::check_all_close(dembedding_dict, expected_dembedding_dict, 1e-05)
133+
}
134+
135+
embedding_test_forward()
136+
embedding_test_forward_backward_no_pad()
137+
embedding_test_forward_backward_pad()
138+

0 commit comments

Comments
 (0)