 from torchrl.collectors import SyncDataCollector
 from torchrl.data import NonTensor
 from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement
-from torchrl.data.tensor_specs import Composite
+from torchrl.data.tensor_specs import Box, Composite, TensorSpec
 from torchrl.envs import ChessEnv
 from torchrl.envs.transforms import (
     ConditionalPolicySwitch,
@@ -75,6 +75,41 @@ def _reset(self, tensordict, tensordict_reset):
         return tensordict_reset
 
 
+class Score(Transform):
+    def __init__(self, input_queue, output_queue):
+        super().__init__()
+        self.input_queue = input_queue
+        self.output_queue = output_queue
+
+    def _step(self, tensordict, next_tensordict):
+        fen = next_tensordict["fen"]
+        self.input_queue.put(fen)
+        _, score = self.output_queue.get()
+        next_tensordict["score"] = torch.tensor(
+            score, device="cuda:7", dtype=torch.bfloat16
+        )
+        return next_tensordict
+
+    def _reset(self, tensordict, tensordict_reset):
+        fen = tensordict_reset["fen"]
+        self.input_queue.put(fen)
+        _, score = self.output_queue.get()
+        tensordict_reset["score"] = torch.tensor(
+            score, device="cuda:7", dtype=torch.bfloat16
+        )
+        return tensordict_reset
+
+    def transform_observation_spec(self, observation_spec: Composite):
+        if not isinstance(observation_spec, Composite):
+            raise ValueError(
+                f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead."
+            )
+        observation_spec["observation"] = TensorSpec(
+            (), Box(), dtype=torch.bfloat16, device="cuda:7"
+        )
+        return observation_spec
+
+
 class LLMInputTransform(Transform):
     def __init__(self, san_moves):
         super().__init__()
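For reference, a minimal sketch (not part of the diff) of the put/get handshake the Score transform relies on: the env side posts a FEN string on one queue and blocks until an engine worker replies with a (move, score) pair on the other. The worker name and the fixed reply below are illustrative stand-ins, not the actual run_player loop.

import queue
import threading

def engine_worker(input_queue, output_queue):
    # Illustrative stand-in for run_player: consume a FEN, reply with (move, score).
    while True:
        fen = input_queue.get()          # blocks until the env posts a position
        if fen is None:                  # sentinel used only in this sketch
            break
        output_queue.put(("e2e4", 35))   # dummy best move and score

input_q, output_q = queue.Queue(), queue.Queue()
threading.Thread(target=engine_worker, args=(input_q, output_q), daemon=True).start()

input_q.put("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
move, score = output_q.get()             # same unpacking the Score transform does
input_q.put(None)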
@@ -159,9 +194,14 @@ def run_player(input_queue, output_queue):
 
             output = process.stdout.readline()
             if output:
-                # print(f"Output: {output.strip()}")
+                print(f"Output: {output.strip()}")
                 move = re.search(r"bestmove (.*)", output.strip()).group(1)
-                output_queue.put(move)
+
+                output = process.stdout.readline()
+                print(f"Output scores: {output.strip()}")
+                score = re.search(r"score (.*)", output.strip()).group(1)
+
+                output_queue.put((move, int(score)))
 
         except queue.Empty:
             continue
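A small self-contained sketch of the parsing this hunk adds, with illustrative output lines (the real ones come from process.stdout.readline(); the exact engine wording is assumed here):

import re

lines = ["bestmove e2e4\n", "score 35\n"]  # hypothetical engine stdout
move = re.search(r"bestmove (.*)", lines[0].strip()).group(1)
score = int(re.search(r"score (.*)", lines[1].strip()).group(1))
print(move, score)  # -> e2e4 35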
@@ -179,7 +219,7 @@ def run_player(input_queue, output_queue):
 def setup_env(input_queue, output_queue, tokenizer):
     def policy_sunfish(td):
         input_queue.put(td["fen"])
-        move = output_queue.get()
+        move, _ = output_queue.get()
         san = env.board.san(chess.Move.from_uci(move))
         san_idx = env.san_moves.index(san)
         td["action"] = torch.tensor(san_idx)
@@ -205,6 +245,7 @@ def policy_sunfish(td):
             tokenizer=tokenizer,
         )
     )
+    env.append_transform(Score(input_queue, output_queue))
     env.reset()
     return env
 
@@ -405,16 +446,28 @@ def remove_logits(td):
         return_composite=True,
     )
 
-    class CriticHead(torch.nn.Module):
+    # class CriticHead(torch.nn.Module):
+    #     def __init__(self):
+    #         super().__init__()
+    #         self.m = torch.nn.Linear(3584, 1, dtype=torch.bfloat16)
+
+    #     def forward(self, hidden):
+    #         return self.m(hidden).squeeze(-1).sum(-1, keepdim=True)
+
+    # critic_llm_policy = Seq(
+    #     Mod(CriticHead(), in_keys=["hidden"], out_keys=["state_value"]),
+    # )
+
+    class CriticLLMPolicy(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            self.m = torch.nn.Linear(3584, 1, dtype=torch.bfloat16)
 
-        def forward(self, hidden):
-            return self.m(hidden).squeeze(-1).sum(-1, keepdim=True)
+        def forward(self, score):
+            # breakpoint()
+            return score.unsqueeze(-1)
 
     critic_llm_policy = Seq(
-        Mod(CriticHead(), in_keys=["hidden"], out_keys=["state_value"]),
+        Mod(CriticLLMPolicy(), in_keys=["score"], out_keys=["state_value"]),
     )
 
     return actor_llm_policy, data_llm_policy, critic_llm_policy, tokenizer
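Outside the diff, a minimal sketch (assuming tensordict is installed) of what the new Seq/Mod wiring does: the engine score is passed through unchanged and written out as "state_value". Tensor values and batch size are illustrative.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

class PassThroughCritic(torch.nn.Module):
    # Mirrors CriticLLMPolicy: the engine score itself is used as the value estimate.
    def forward(self, score):
        return score.unsqueeze(-1)

critic = TensorDictSequential(
    TensorDictModule(PassThroughCritic(), in_keys=["score"], out_keys=["state_value"])
)

td = TensorDict({"score": torch.tensor([35.0, -12.0], dtype=torch.bfloat16)}, batch_size=[2])
td = critic(td)
print(td["state_value"].shape)  # torch.Size([2, 1])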
@@ -425,12 +478,21 @@ def play(env, data_llm_policy, actor_llm_policy, tokenizer):
 
     rb = ReplayBuffer(
         storage=LazyStackStorage(100),
-        batch_size=8,
+        batch_size=48,
         sampler=SliceSamplerWithoutReplacement(slice_len=8, end_key=("next", "done")),
     )
 
+    # def breakpointy(td):
+    #     breakpoint()
+    #     return td
+
+    # rb.append_transform(breakpointy)
+
+    # Temporarily patched fbcode/pytorch/tensordict/tensordict/_lazy.py?lines=1502
     rb.append_transform(lambda td: td.densify(layout=torch.jagged))
 
+    # rb.append_transform(breakpointy)
+
     # obs_tokens in layout=torch.jagged errors with Qwen
     # File "/home/mg1998/.conda/envs/rl/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 859, in forward
     #     cache_position = torch.arange(
@@ -486,6 +548,7 @@ def obs_token_transform(td):
 
         data = gae(data)
         loss = loss_module(data)
+        breakpoint()
         loss.sum(reduce=True).backward()
         torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 0.5)
         optim.step()