Hi, I ran into the following error when finetuning a Llama-2-7B model with FSDP + HQQ:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/spawn.py", line 74, in _wrap
fn(i, *args)
File "/workspace/fsdp_qlora/train.py", line 723, in fsdp_main
model = FSDP(
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 481, in __init__
_auto_wrap(
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
[Previous line repeated 1 more time]
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
return wrapper_cls(module, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 481, in __init__
_auto_wrap(
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_wrap_utils.py", line 45, in _auto_wrap
_check_nested_wrapping(root_module)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_wrap_utils.py", line 107, in _check_nested_wrapping
raise ValueError(
ValueError: FSDP auto wrapping requires modules to not already have FSDP applied but found q_proj.lora_AB in
LlamaSdpaAttention(
  (q_proj): LORA(
    (base_layer): HQQLinear()
    (lora_AB): FullyShardedDataParallel(
      (_fsdp_wrapped_module): Sequential(
        (0): Linear(in_features=4096, out_features=64, bias=False)
        (1): Linear(in_features=64, out_features=4096, bias=False)
      )
    )
    (lora_dropout): Dropout(p=0.1, inplace=False)
  )
  (k_proj): LORA(
    (base_layer): HQQLinear()
    (lora_AB): FullyShardedDataParallel(
      (_fsdp_wrapped_module): Sequential(
        (0): Linear(in_features=4096, out_features=64, bias=False)
        (1): Linear(in_features=64, out_features=4096, bias=False)
      )
    )
    (lora_dropout): Dropout(p=0.1, inplace=False)
  )
  (v_proj): LORA(
    (base_layer): HQQLinear()
    (lora_AB): FullyShardedDataParallel(
      (_fsdp_wrapped_module): Sequential(
        (0): Linear(in_features=4096, out_features=64, bias=False)
        (1): Linear(in_features=64, out_features=4096, bias=False)
      )
    )
    (lora_dropout): Dropout(p=0.1, inplace=False)
  )
  (o_proj): HQQLinear()
  (rotary_emb): LlamaRotaryEmbedding()
)
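For reference, here is a minimal standalone sketch of what the check seems to be rejecting: a module tree that already contains an FSDP instance when auto wrapping runs. This is not code from train.py; the module names (ToyLoRA, lora_AB) are illustrative only, and it assumes a single process with a gloo process group on a CUDA-capable machine. As far as I can tell it trips the same _check_nested_wrapping ValueError:

import functools
import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


class ToyLoRA(nn.Module):
    """Toy stand-in for the LORA module in the dump above (hypothetical, not from fsdp_qlora)."""

    def __init__(self):
        super().__init__()
        self.base_layer = nn.Linear(4096, 4096, bias=False)
        self.lora_AB = nn.Sequential(
            nn.Linear(4096, 64, bias=False),
            nn.Linear(64, 4096, bias=False),
        )

    def forward(self, x):
        return self.base_layer(x) + self.lora_AB(x)


def main():
    # Single-process group just so FSDP can be constructed for the demo.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = nn.Sequential(ToyLoRA())

    # Pre-wrap the inner lora_AB in FSDP, mimicking the state shown in the error
    # dump above, where q_proj.lora_AB is already a FullyShardedDataParallel instance.
    model[0].lora_AB = FSDP(model[0].lora_AB)

    # Auto-wrapping the parent now hits _check_nested_wrapping, which refuses
    # module trees that already contain FSDP modules.
    policy = functools.partial(
        lambda_auto_wrap_policy, lambda_fn=lambda m: isinstance(m, ToyLoRA)
    )
    try:
        FSDP(model, auto_wrap_policy=policy)
    except ValueError as e:
        print(e)  # "FSDP auto wrapping requires modules to not already have FSDP applied ..."
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()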
The command is:
export CUDA_VISIBLE_DEVICES=3,4
python train.py \
--world_size 2 \
--model_name /workspace/model/Llama-2-7b-hf \
--gradient_accumulation_steps 2 \
--batch_size 1 \
--context_length 4096 \
--num_epochs 1 \
--sharding_strategy full_shard \
--precision bf16 \
--train_type hqq_lora \
--use_gradient_checkpointing true \
--use_cpu_offload true \
--dataset dummy \
--verbose true
How can I solve this problem?
Looking forward to your reply.