|
| 1 | +#------------------------------------------------------------- |
| 2 | +# |
| 3 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 4 | +# or more contributor license agreements. See the NOTICE file |
| 5 | +# distributed with this work for additional information |
| 6 | +# regarding copyright ownership. The ASF licenses this file |
| 7 | +# to you under the Apache License, Version 2.0 (the |
| 8 | +# "License"); you may not use this file except in compliance |
| 9 | +# with the License. You may obtain a copy of the License at |
| 10 | +# |
| 11 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | +# |
| 13 | +# Unless required by applicable law or agreed to in writing, |
| 14 | +# software distributed under the License is distributed on an |
| 15 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 16 | +# KIND, either express or implied. See the License for the |
| 17 | +# specific language governing permissions and limitations |
| 18 | +# under the License. |
| 19 | +# |
| 20 | +#------------------------------------------------------------- |
| 21 | + |
| 22 | +source("nn/layers/embedding.dml") as embedding |
| 23 | +source("src/test/scripts/applications/nn/util.dml") as test_util |
| 24 | + |
| 25 | +embedding_test_forward = function() { |
| 26 | + print("Testing Embedding - Forward Test") |
| 27 | + n = 4 |
| 28 | + v = 7 |
| 29 | + d = 3 |
| 30 | + |
| 31 | + embedding_dict = matrix("-0.78327566 -0.87246466 -0.80580276 |
| 32 | + -0.17845497 2.1740944 -1.2514428 |
| 33 | + -0.27202556 -1.3681601 -1.5384313 |
| 34 | + 1.4215976 -0.463162 1.2592019 |
| 35 | + -1.7417 -0.46109396 -0.06011621 |
| 36 | + -0.7803316 1.0802858 0.7465289 |
| 37 | + 0. 0. 0.", rows=v, cols=d) |
| 38 | + indices = matrix("1 6 7 6", rows=n, cols=1) |
| 39 | + |
| 40 | + embeddings = embedding::forward(indices, embedding_dict) |
| 41 | + |
| 42 | + expected_embeddings = matrix("-0.78327566 -0.87246466 -0.80580276 |
| 43 | + -0.7803316 1.0802858 0.7465289 |
| 44 | + 0. 0. 0. |
| 45 | + -0.7803316 1.0802858 0.7465289", rows=n, cols=d) |
| 46 | + |
| 47 | + test_util::check_all_close(embeddings, expected_embeddings, 1e-05) |
| 48 | +} |
| 49 | + |
| 50 | +embedding_test_forward_backward_no_pad = function() { |
| 51 | + print("Testing Embedding - Forward & Backward Test w/out Padding") |
| 52 | + n = 2 |
| 53 | + v = 4 |
| 54 | + d = 3 |
| 55 | + |
| 56 | + embedding_dict = matrix("-0.15039968 0.56168836 -0.577436 |
| 57 | + 0.47334725 1.5215642 -0.1924941 |
| 58 | + 1.600819 -1.1331359 -2.58817 |
| 59 | + 0.9779929 -0.82212716 -1.5917081", rows=v, cols=d) |
| 60 | + indices = matrix("2 3", rows=n, cols=1) |
| 61 | + |
| 62 | + embeddings = embedding::forward(indices, embedding_dict) |
| 63 | + |
| 64 | + expected_embeddings = matrix("0.47334725 1.5215642 -0.1924941 |
| 65 | + 1.600819 -1.1331359 -2.58817", rows=n, cols=d) |
| 66 | + |
| 67 | + test_util::check_all_close(embeddings, expected_embeddings, 1e-05) |
| 68 | + |
| 69 | + dout = matrix(seq(1, n*d, 1), rows=n, cols=d) |
| 70 | + padding_idx = -1 |
| 71 | + |
| 72 | + dembedding_dict = embedding::backward(dout, indices, v, padding_idx) |
| 73 | + expected_dembedding_dict = matrix("0. 0. 0. |
| 74 | + 1. 2. 3. |
| 75 | + 4. 5. 6. |
| 76 | + 0. 0. 0.", rows=v, cols=d) |
| 77 | + test_util::check_all_close(dembedding_dict, expected_dembedding_dict, 1e-05) |
| 78 | +} |
| 79 | + |
| 80 | +embedding_test_forward_backward_pad = function() { |
| 81 | + print("Testing Embedding - Forward & Backward Test w/ Padding") |
| 82 | + n = 5 |
| 83 | + v = 10 |
| 84 | + d = 6 |
| 85 | + |
| 86 | + embedding_dict = matrix("-1.24377859e+00 -1.10724878e+00 2.35533118e-01 6.65530920e-01 |
| 87 | + 9.80555452e-03 6.31030917e-01 |
| 88 | + 8.16493928e-01 -6.21011078e-01 -5.75569510e-01 -3.93419750e-02 |
| 89 | + -6.20878041e-01 1.37852756e-02 |
| 90 | + 7.43950903e-01 1.60437262e+00 -2.31788456e-01 1.15943216e-01 |
| 91 | + -8.83608997e-01 1.11547875e+00 |
| 92 | + 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 |
| 93 | + 0.00000000e+00 0.00000000e+00 |
| 94 | + 1.70598769e+00 1.82770026e+00 1.30581510e+00 1.05738208e-01 |
| 95 | + 4.50116873e-01 3.48498315e-01 |
| 96 | + 1.40551448e+00 3.43091488e-02 1.84714049e-03 -5.52828193e-01 |
| 97 | + 3.65064174e-01 -9.31223869e-01 |
| 98 | + 1.33713937e+00 -3.43729639e+00 -1.22915792e+00 -1.12923630e-01 |
| 99 | + -1.16292477e+00 -2.16708351e-02 |
| 100 | + 6.63879395e-01 -2.76697308e-01 -9.02738094e-01 -6.85515344e-01 |
| 101 | + -6.43863618e-01 -2.30419707e+00 |
| 102 | + 1.44121364e-01 5.20578504e-01 -6.53087497e-01 6.62900746e-01 |
| 103 | + 3.82369667e-01 -2.25386508e-02 |
| 104 | + 2.20637798e+00 -6.86733365e-01 -1.27398467e+00 6.28316283e-01 |
| 105 | + 2.70236313e-01 2.20882833e-01", rows=v, cols=d) |
| 106 | + indices = matrix("1 1 1 4 6", rows=n, cols=1) |
| 107 | + |
| 108 | + embeddings = embedding::forward(indices, embedding_dict) |
| 109 | + |
| 110 | + expected_embeddings = matrix("-1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309 |
| 111 | + -1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309 |
| 112 | + -1.2437786 -1.1072488 0.23553312 0.6655309 0.00980555 0.6310309 |
| 113 | + 0. 0. 0. 0. 0. 0. |
| 114 | + 1.4055145 0.03430915 0.00184714 -0.5528282 0.36506417 -0.93122387", rows=n, cols=d) |
| 115 | + |
| 116 | + test_util::check_all_close(embeddings, expected_embeddings, 1e-05) |
| 117 | + |
| 118 | + dout = matrix(seq(1, n*d, 1), rows=n, cols=d) |
| 119 | + padding_idx = 4 |
| 120 | + |
| 121 | + dembedding_dict = embedding::backward(dout, indices, v, padding_idx) |
| 122 | + expected_dembedding_dict = matrix("21. 24. 27. 30. 33. 36. |
| 123 | + 0. 0. 0. 0. 0. 0. |
| 124 | + 0. 0. 0. 0. 0. 0. |
| 125 | + 0. 0. 0. 0. 0. 0. |
| 126 | + 0. 0. 0. 0. 0. 0. |
| 127 | + 25. 26. 27. 28. 29. 30. |
| 128 | + 0. 0. 0. 0. 0. 0. |
| 129 | + 0. 0. 0. 0. 0. 0. |
| 130 | + 0. 0. 0. 0. 0. 0. |
| 131 | + 0. 0. 0. 0. 0. 0.", rows=v, cols=d) |
| 132 | + test_util::check_all_close(dembedding_dict, expected_dembedding_dict, 1e-05) |
| 133 | +} |
| 134 | + |
| 135 | +embedding_test_forward() |
| 136 | +embedding_test_forward_backward_no_pad() |
| 137 | +embedding_test_forward_backward_pad() |
| 138 | + |
0 commit comments