diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py index 736829fe05f..5bf7637ed0f 100644 --- a/intermediate_source/reinforcement_q_learning.py +++ b/intermediate_source/reinforcement_q_learning.py @@ -310,12 +310,12 @@ def get_screen(): model.cuda() -class Variable(autograd.Variable): +def Variable(data, volatile=False): + if USE_CUDA: + return autograd.Variable(data.cuda(),volatile=volatile) + else: + return autograd.Variable(data, volatile=volatile) - def __init__(self, data, *args, **kwargs): - if USE_CUDA: - data = data.cuda() - super(Variable, self).__init__(data, *args, **kwargs) steps_done = 0