I noticed that the following code in the source specifies MobileNet as the YOLO network model:
yolo = create_yolo(
    architecture="MobileNet",
    labels=self.labels,
    input_size=self.input_shape[:2],
    anchors=self.anchors,
    coord_scale=1.0,
    class_scale=1.0,
    object_scale=5.0,
    no_object_scale=1.0,
    weights=weights,
    strip_size=strip_size
)
In addition, the following function lists all the supported YOLO architectures:
def create_feature_extractor(architecture, input_size, weights=None, strip_size=32):
    """
    # Args
        architecture : str
        input_size : int

    # Returns
        feature_extractor : BaseFeatureExtractor instance
    """
    if architecture == 'Inception3':
        feature_extractor = Inception3Feature(input_size, weights)
    elif architecture == 'SqueezeNet':
        feature_extractor = SqueezeNetFeature(input_size, weights)
    elif architecture == 'MobileNet':
        feature_extractor = MobileNetFeature(input_size, weights, strip_size=strip_size)
    elif architecture == 'Full Yolo':
        feature_extractor = FullYoloFeature(input_size, weights)
    elif architecture == 'Tiny Yolo':
        feature_extractor = TinyYoloFeature(input_size, weights)
    elif architecture == 'VGG16':
        feature_extractor = VGG16Feature(input_size, weights)
    elif architecture == 'ResNet50':
        feature_extractor = ResNet50Feature(input_size, weights)
    else:
        raise Exception('Architecture not supported! Only support Full Yolo, Tiny Yolo, MobileNet, SqueezeNet, VGG16, ResNet50, and Inception3 at the moment!')
    return feature_extractor
However, when I change the architecture parameter to switch to a different model, I get this error:
ValueError: You are trying to load a weight file containing 54 layers into a model with 25 layers.
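The wording of this error matches the layer-count check Keras performs when load_weights reads an HDF5 file by layer order: the i-th weight-carrying layer in the file is copied into the i-th weight-carrying layer of the model, so both must have the same number of such layers. The sketch below is a pure-Python simulation of that behavior, not Keras code, and the layer names in it are hypothetical. It also suggests why by-name loading (load_weights(..., by_name=True) in Keras) would not help on its own: assuming a MobileNet backbone file and a Tiny Yolo model share no layer names, nothing would be loaded.

```python
from typing import Dict, List, Sequence

Weights = List[float]  # stand-in for one layer's weight arrays


def load_by_order(model_layers: Sequence[str], file_layers: List[Weights]) -> Dict[str, Weights]:
    """Simulate by-order (topological) loading: the i-th layer stored in the
    file is assigned to the i-th weight-carrying layer of the model."""
    if len(file_layers) != len(model_layers):
        # Same check, and same message, as the error reported above.
        raise ValueError(
            f"You are trying to load a weight file containing {len(file_layers)} "
            f"layers into a model with {len(model_layers)} layers."
        )
    return dict(zip(model_layers, file_layers))


def load_by_name(model_layers: Sequence[str], file_layers: Dict[str, Weights]) -> Dict[str, Weights]:
    """Simulate by-name loading: only layers whose names also appear in the
    file receive weights; all other layers keep their initial values."""
    return {name: file_layers[name] for name in model_layers if name in file_layers}


if __name__ == "__main__":
    # Hypothetical layer names: 54 layers in the MobileNet file, 25 in the new model.
    mobilenet = [f"mobilenet_layer_{i}" for i in range(54)]
    tiny_yolo = [f"tiny_yolo_layer_{i}" for i in range(25)]
    try:
        load_by_order(tiny_yolo, [[0.0] for _ in mobilenet])
    except ValueError as err:
        print(err)
    print(len(load_by_name(tiny_yolo, {n: [0.0] for n in mobilenet})))  # 0 layers matched
```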
I also found that train/detector/init.py in the source uses a weights file, mobilenet_7_5_224_tf_no_top.h5, as the default weights argument of train():
def train(self, epochs=100,
          progress_cb=None,
          weights=os.path.join(curr_file_dir, "weights", "mobilenet_7_5_224_tf_no_top.h5"),
          batch_size=5,
          train_times=5,
          valid_times=2,
          learning_rate=1e-4,
          jitter=False,
          is_only_detect=False,
          save_best_weights_path="out/best_weights.h5",
          save_final_weights_path="out/final_weights.h5",
          ):
    ...  # training logic omitted
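Because this default file is tied to the MobileNet backbone, the weights argument presumably has to change together with architecture. One way to make that pairing explicit is a small lookup, sketched below under stated assumptions: backbone_weights_for and BACKBONE_WEIGHTS are hypothetical names, only the MobileNet filename comes from the source, and returning None assumes the feature extractors skip weight loading when weights is None (the weights=None default of create_feature_extractor suggests this, but it is not confirmed here).

```python
import os
from typing import Optional

# Architectures accepted by create_feature_extractor in the source above.
SUPPORTED_ARCHITECTURES = {
    "Inception3", "SqueezeNet", "MobileNet", "Full Yolo",
    "Tiny Yolo", "VGG16", "ResNet50",
}

# Hypothetical registry of pretrained backbone files. Only the MobileNet entry
# comes from the source; add others only for files that match that backbone.
BACKBONE_WEIGHTS = {
    "MobileNet": "mobilenet_7_5_224_tf_no_top.h5",
}


def backbone_weights_for(architecture: str, weights_dir: str) -> Optional[str]:
    """Return a weights path that matches `architecture`, or None so that the
    backbone starts from random initialization (assumed behavior)."""
    if architecture not in SUPPORTED_ARCHITECTURES:
        raise ValueError(f"Architecture not supported: {architecture!r}")
    filename = BACKBONE_WEIGHTS.get(architecture)
    if filename is None:
        # No matching pretrained file: never fall back to the MobileNet one.
        return None
    path = os.path.join(weights_dir, filename)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Pretrained weights not found: {path}")
    return path
```

For example, backbone_weights_for("Tiny Yolo", weights_dir) returns None, which would be passed as weights= to train() while create_yolo is given architecture="Tiny Yolo"; the trade-off is that the backbone then trains without pretrained features, which typically needs more data and epochs.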
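My question: what is the file mobilenet_7_5_224_tf_no_top.h5 used for? Does it supply the default (pretrained) parameters used to initialize the network? And if I switch to another model, such as Tiny Yolo, how should I set the weights parameter so that training runs correctly?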