Commit d1c039e

fix accelerator prepare during eval only mode (#24014)

* fix mixed precision prep during eval only mode
* update to address comments
* update to reflect the changes in accelerate

1 parent 2c887cf commit d1c039e

File tree

1 file changed: +38 -6 lines changed

src/transformers/trainer.py

Lines changed: 38 additions & 6 deletions
@@ -3141,14 +3141,30 @@ def evaluation_loop(
 
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
 
-        # if eval is called w/o train init deepspeed here
+        # if eval is called w/o train, handle model prep here
         if self.is_deepspeed_enabled and self.model_wrapped is self.model:
             _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
-            model = self.accelerator.prepare(self.model)
-            self.model_wrapped = self.deepspeed = model
 
         model = self._wrap_model(self.model, training=False, dataloader=dataloader)
 
+        if len(self.accelerator._models) == 0 and model is self.model:
+            model = (
+                self.accelerator.prepare(model)
+                if self.is_deepspeed_enabled
+                else self.accelerator.prepare_model(model, evaluation_mode=True)
+            )
+
+            if self.is_fsdp_enabled:
+                self.model = model
+
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+
+            # backward compatibility
+            if self.is_deepspeed_enabled:
+                self.deepspeed = self.model_wrapped
+
         # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
         # while ``train`` is running, cast it to the right dtype first and then put on device
         if not self.is_in_train:
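
The key behavioral change is which Accelerate entry point prepares the model when evaluation runs without a prior train() call. Below is a minimal standalone sketch (not from this commit) of that decision, assuming a default Accelerator and a single process; is_deepspeed_enabled is a local stand-in for the Trainer attribute of the same name.

# Minimal sketch of the new guard outside the Trainer (assumption: default Accelerator,
# no distributed launch). `is_deepspeed_enabled` stands in for Trainer.is_deepspeed_enabled.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)
is_deepspeed_enabled = False  # hypothetical flag for this sketch

# Only prepare the model if nothing has been prepared yet (i.e. eval-only mode).
if len(accelerator._models) == 0:
    model = (
        # DeepSpeed still needs the full prepare() path to build its engine ...
        accelerator.prepare(model)
        if is_deepspeed_enabled
        # ... while other backends only need device placement / mixed-precision wrapping
        # without training-side state, hence evaluation_mode=True.
        else accelerator.prepare_model(model, evaluation_mode=True)
    )
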
@@ -3736,14 +3752,30 @@ def prediction_loop(
 
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
 
-        # if eval is called w/o train init deepspeed here
+        # if eval is called w/o train, handle model prep here
        if self.is_deepspeed_enabled and self.model_wrapped is self.model:
             _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
-            model = self.accelerator.prepare(self.model)
-            self.model_wrapped = self.deepspeed = model
 
         model = self._wrap_model(self.model, training=False, dataloader=dataloader)
 
+        if len(self.accelerator._models) == 0 and model is self.model:
+            model = (
+                self.accelerator.prepare(model)
+                if self.is_deepspeed_enabled
+                else self.accelerator.prepare_model(model, evaluation_mode=True)
+            )
+
+            if self.is_fsdp_enabled:
+                self.model = model
+
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+
+            # backward compatibility
+            if self.is_deepspeed_enabled:
+                self.deepspeed = self.model_wrapped
+
         # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
         # while ``train`` is running, cast it to the right dtype first and then put on device
         if not self.is_in_train:
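
For reference, this is the eval-only scenario the new guard covers: a Trainer whose evaluate() (or predict()) is called without ever calling train(). A rough self-contained sketch, assuming distilbert-base-uncased and a toy in-memory eval set (both placeholders, not part of the commit):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments


class ToyEvalSet(torch.utils.data.Dataset):
    # tiny in-memory eval set so the sketch stays self-contained
    def __init__(self, tokenizer):
        enc = tokenizer(["a good movie", "a bad movie"], padding=True, return_tensors="pt")
        self.items = [
            {"input_ids": enc["input_ids"][i], "attention_mask": enc["attention_mask"][i], "labels": torch.tensor(i % 2)}
            for i in range(2)
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

args = TrainingArguments(output_dir="eval-only", per_device_eval_batch_size=2)
trainer = Trainer(model=model, args=args, eval_dataset=ToyEvalSet(tokenizer))

# evaluate() without a prior train(): evaluation_loop() has to prepare the model itself,
# which is the branch rewritten above
print(trainer.evaluate())
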
