|
| 1 | +# ****************************************************************************** |
| 2 | +# Copyright 2018 Intel Corporation |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# ****************************************************************************** |
| 16 | + |
| 17 | +from collections import defaultdict |
| 18 | +import glob |
| 19 | +import os |
| 20 | +import shutil |
| 21 | +import tarfile |
| 22 | +import tempfile |
| 23 | + |
| 24 | +from six.moves.urllib.request import urlretrieve, urlopen |
| 25 | + |
| 26 | +import onnx.backend.test |
| 27 | +from onnx.backend.test.case.test_case import TestCase as OnnxTestCase |
| 28 | + |
| 29 | + |
| 30 | +class ModelZooTestRunner(onnx.backend.test.BackendTest): |
| 31 | + |
| 32 | + def __init__(self, backend, models_dict, parent_module=None): |
| 33 | + # type: (Type[Backend], Dict[str,str], Optional[str]) -> None |
| 34 | + self.backend = backend |
| 35 | + self._parent_module = parent_module |
| 36 | + self._include_patterns = set() # type: Set[Pattern[Text]] |
| 37 | + self._exclude_patterns = set() # type: Set[Pattern[Text]] |
| 38 | + self._test_items = defaultdict(dict) # type: Dict[Text, Dict[Text, TestItem]] |
| 39 | + |
| 40 | + for model_name, url in models_dict.items(): |
| 41 | + test_name = 'test_{}'.format(model_name) |
| 42 | + |
| 43 | + test_case = OnnxTestCase( |
| 44 | + name=test_name, |
| 45 | + url=url, |
| 46 | + model_name=model_name, |
| 47 | + model_dir=None, |
| 48 | + model=None, |
| 49 | + data_sets=None, |
| 50 | + kind='OnnxBackendRealModelTest', |
| 51 | + ) |
| 52 | + self._add_model_test(test_case, 'Zoo') |
| 53 | + |
| 54 | + @staticmethod |
| 55 | + def _get_etag_for_url(url): # type: (str) -> str |
| 56 | + request = urlopen(url) |
| 57 | + return request.info().get('ETag') |
| 58 | + |
| 59 | + @staticmethod |
| 60 | + def _read_etag_file(model_dir): # type: (str) -> str |
| 61 | + etag_file_path = os.path.join(model_dir, 'source_tar_etag') |
| 62 | + if os.path.exists(etag_file_path): |
| 63 | + return open(etag_file_path).read() |
| 64 | + |
| 65 | + @staticmethod |
| 66 | + def _write_etag_file(model_dir, etag_value): # type: (str, str) -> None |
| 67 | + etag_file_path = os.path.join(model_dir, 'source_tar_etag') |
| 68 | + open(etag_file_path, 'w').write(etag_value) |
| 69 | + |
| 70 | + @staticmethod |
| 71 | + def _backup_old_version(model_dir): # type: (str) -> None |
| 72 | + if os.path.exists(model_dir): |
| 73 | + backup_index = 0 |
| 74 | + while True: |
| 75 | + dest = '{}.old.{}'.format(model_dir, backup_index) |
| 76 | + if os.path.exists(dest): |
| 77 | + backup_index += 1 |
| 78 | + continue |
| 79 | + shutil.move(model_dir, dest) |
| 80 | + break |
| 81 | + |
| 82 | + def _prepare_model_data(self, model_test): # type: (TestCase) -> Text |
| 83 | + onnx_home = os.path.expanduser(os.getenv('ONNX_HOME', os.path.join('~', '.onnx'))) |
| 84 | + models_dir = os.getenv('ONNX_MODELS', os.path.join(onnx_home, 'models')) |
| 85 | + model_dir = os.path.join(models_dir, model_test.model_name) # type: Text |
| 86 | + current_version_etag = self._get_etag_for_url(model_test.url) |
| 87 | + |
| 88 | + # If model already exists, check if it's the latest version by verifying cached Etag value |
| 89 | + if os.path.exists(os.path.join(model_dir, 'model.onnx')): |
| 90 | + if not current_version_etag or current_version_etag == self._read_etag_file(model_dir): |
| 91 | + return model_dir |
| 92 | + |
| 93 | + # If model does exist, but is not current, backup directory |
| 94 | + self._backup_old_version(model_dir) |
| 95 | + |
| 96 | + # Download and extract model and data |
| 97 | + download_file = tempfile.NamedTemporaryFile(delete=False) |
| 98 | + temp_clean_dir = tempfile.mkdtemp() |
| 99 | + |
| 100 | + try: |
| 101 | + download_file.close() |
| 102 | + print('\nStart downloading model {} from {}'.format( |
| 103 | + model_test.model_name, model_test.url)) |
| 104 | + urlretrieve(model_test.url, download_file.name) |
| 105 | + print('Done') |
| 106 | + |
| 107 | + with tempfile.TemporaryDirectory() as temp_extract_dir: |
| 108 | + with tarfile.open(download_file.name) as tar_file: |
| 109 | + tar_file.extractall(temp_extract_dir) |
| 110 | + |
| 111 | + # Move model `.onnx` file from temp_extract_dir to temp_clean_dir |
| 112 | + model_files = glob.glob(temp_extract_dir + '/**/*.onnx', recursive=True) |
| 113 | + assert len(model_files) > 0, 'Model file not found for {}'.format(model_test.name) |
| 114 | + model_file = model_files[0] |
| 115 | + shutil.move(model_file, temp_clean_dir + '/model.onnx') |
| 116 | + |
| 117 | + # Move extracted test data sets to temp_clean_dir |
| 118 | + test_data_sets = glob.glob(temp_extract_dir + '/**/test_data_set_*', recursive=True) |
| 119 | + test_data_sets.extend( |
| 120 | + glob.glob(temp_extract_dir + '/**/test_data_*.npz', recursive=True)) |
| 121 | + for test_data_set in test_data_sets: |
| 122 | + shutil.move(test_data_set, temp_clean_dir) |
| 123 | + |
| 124 | + # Save Etag value to Etag file |
| 125 | + self._write_etag_file(temp_clean_dir, current_version_etag) |
| 126 | + |
| 127 | + # Move temp_clean_dir to ultimate destination |
| 128 | + shutil.move(temp_clean_dir, model_dir) |
| 129 | + |
| 130 | + except Exception as e: |
| 131 | + print('Failed to prepare data for model {}: {}'.format(model_test.model_name, e)) |
| 132 | + os.remove(temp_clean_dir) |
| 133 | + raise |
| 134 | + finally: |
| 135 | + os.remove(download_file.name) |
| 136 | + return model_dir |
0 commit comments