Commit f5d2b15

Update on "[DRAFT] ppo chess with llm and ConditionalPolicySwitch to sunfish bot"
[ghstack-poisoned]
1 parent 7ee0011 commit f5d2b15


examples/agents/ppo-chess-llm.py

Lines changed: 73 additions & 10 deletions
@@ -22,7 +22,7 @@
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import NonTensor
 from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement
-from torchrl.data.tensor_specs import Composite
+from torchrl.data.tensor_specs import Box, Composite, TensorSpec
 from torchrl.envs import ChessEnv
 from torchrl.envs.transforms import (
     ConditionalPolicySwitch,
@@ -75,6 +75,41 @@ def _reset(self, tensordict, tensordict_reset):
         return tensordict_reset


+class Score(Transform):
+    def __init__(self, input_queue, output_queue):
+        super().__init__()
+        self.input_queue = input_queue
+        self.output_queue = output_queue
+
+    def _step(self, tensordict, next_tensordict):
+        fen = next_tensordict["fen"]
+        self.input_queue.put(fen)
+        _, score = self.output_queue.get()
+        next_tensordict["score"] = torch.tensor(
+            score, device="cuda:7", dtype=torch.bfloat16
+        )
+        return next_tensordict
+
+    def _reset(self, tensordict, tensordict_reset):
+        fen = tensordict_reset["fen"]
+        self.input_queue.put(fen)
+        _, score = self.output_queue.get()
+        tensordict_reset["score"] = torch.tensor(
+            score, device="cuda:7", dtype=torch.bfloat16
+        )
+        return tensordict_reset
+
+    def transform_observation_spec(self, observation_spec: Composite):
+        if not isinstance(observation_spec, Composite):
+            raise ValueError(
+                f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead."
+            )
+        observation_spec["observation"] = TensorSpec(
+            (), Box(), dtype=torch.bfloat16, device="cuda:7"
+        )
+        return observation_spec
+
+
 class LLMInputTransform(Transform):
     def __init__(self, san_moves):
         super().__init__()
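The new Score transform relies on the same queue pair as policy_sunfish: it pushes a FEN string onto input_queue and blocks until the engine worker (run_player, further down) replies with a (move, score) tuple on output_queue. A minimal, self-contained sketch of that round-trip with a stand-in worker; stub_engine and its hard-coded reply are hypothetical, the real worker drives sunfish:

import queue
import threading

def stub_engine(input_queue, output_queue):
    # Hypothetical stand-in for run_player: consume a FEN string, reply with a
    # (uci_move, score) tuple -- the shape that Score._step/_reset and
    # policy_sunfish unpack.
    while True:
        fen = input_queue.get()
        if fen is None:  # sentinel to stop the worker
            break
        output_queue.put(("e2e4", 30))

input_q, output_q = queue.Queue(), queue.Queue()
threading.Thread(target=stub_engine, args=(input_q, output_q), daemon=True).start()

input_q.put("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
move, score = output_q.get()
print(move, score)  # e2e4 30
input_q.put(None)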
@@ -159,9 +194,14 @@ def run_player(input_queue, output_queue):

             output = process.stdout.readline()
             if output:
-                # print(f"Output: {output.strip()}")
+                print(f"Output: {output.strip()}")
                 move = re.search(r"bestmove (.*)", output.strip()).group(1)
-                output_queue.put(move)
+
+                output = process.stdout.readline()
+                print(f"Output scores: {output.strip()}")
+                score = re.search(r"score (.*)", output.strip()).group(1)
+
+                output_queue.put((move, int(score)))

         except queue.Empty:
             continue
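The updated run_player now reads two consecutive lines from the engine process: one carrying the best move and one carrying the evaluation that the Score transform consumes. A small sketch of just the parsing step; the literal output lines below are hypothetical, the exact format depends on how the sunfish wrapper prints its results:

import re

# Hypothetical engine output, in the shape the regexes above expect:
# a "bestmove ..." line followed by a "score ..." line.
lines = ["bestmove e2e4", "score 30"]

move = re.search(r"bestmove (.*)", lines[0].strip()).group(1)
score = int(re.search(r"score (.*)", lines[1].strip()).group(1))
print(move, score)  # e2e4 30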
@@ -179,7 +219,7 @@ def run_player(input_queue, output_queue):
 def setup_env(input_queue, output_queue, tokenizer):
     def policy_sunfish(td):
         input_queue.put(td["fen"])
-        move = output_queue.get()
+        move, _ = output_queue.get()
         san = env.board.san(chess.Move.from_uci(move))
         san_idx = env.san_moves.index(san)
         td["action"] = torch.tensor(san_idx)
@@ -205,6 +245,7 @@ def policy_sunfish(td):
             tokenizer=tokenizer,
         )
     )
+    env.append_transform(Score(input_queue, output_queue))
     env.reset()
     return env

@@ -405,16 +446,28 @@ def remove_logits(td):
         return_composite=True,
     )

-    class CriticHead(torch.nn.Module):
+    # class CriticHead(torch.nn.Module):
+    #     def __init__(self):
+    #         super().__init__()
+    #         self.m = torch.nn.Linear(3584, 1, dtype=torch.bfloat16)
+
+    #     def forward(self, hidden):
+    #         return self.m(hidden).squeeze(-1).sum(-1, keepdim=True)
+
+    # critic_llm_policy = Seq(
+    #     Mod(CriticHead(), in_keys=["hidden"], out_keys=["state_value"]),
+    # )
+
+    class CriticLLMPolicy(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            self.m = torch.nn.Linear(3584, 1, dtype=torch.bfloat16)

-        def forward(self, hidden):
-            return self.m(hidden).squeeze(-1).sum(-1, keepdim=True)
+        def forward(self, score):
+            # breakpoint()
+            return score.unsqueeze(-1)

     critic_llm_policy = Seq(
-        Mod(CriticHead(), in_keys=["hidden"], out_keys=["state_value"]),
+        Mod(CriticLLMPolicy(), in_keys=["score"], out_keys=["state_value"]),
     )

     return actor_llm_policy, data_llm_policy, critic_llm_policy, tokenizer
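With the learned CriticHead commented out, the critic is now a pass-through: the engine evaluation stored under the "score" key is used directly as the state value, so there are no critic parameters to train, and the GAE advantages computed later in the script come from the engine's evaluation rather than a learned value head. A minimal sketch of the wiring in isolation; PassThroughCritic is a hypothetical rename, and Mod/Seq are assumed to alias tensordict.nn.TensorDictModule and TensorDictSequential, as is conventional in torchrl examples:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

class PassThroughCritic(torch.nn.Module):
    # Same idea as CriticLLMPolicy above: no parameters, just reshape the
    # engine score into a trailing value dimension.
    def forward(self, score):
        return score.unsqueeze(-1)

critic = Seq(Mod(PassThroughCritic(), in_keys=["score"], out_keys=["state_value"]))

td = TensorDict(
    {"score": torch.tensor([30.0, -12.0], dtype=torch.bfloat16)}, batch_size=[2]
)
print(critic(td)["state_value"].shape)  # torch.Size([2, 1])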
@@ -425,12 +478,21 @@ def play(env, data_llm_policy, actor_llm_policy, tokenizer):

     rb = ReplayBuffer(
         storage=LazyStackStorage(100),
-        batch_size=8,
+        batch_size=48,
         sampler=SliceSamplerWithoutReplacement(slice_len=8, end_key=("next", "done")),
     )

+    # def breakpointy(td):
+    #     breakpoint()
+    #     return td
+
+    # rb.append_transform(breakpointy)
+
+    # Temporarily patched fbcode/pytorch/tensordict/tensordict/_lazy.py?lines=1502
     rb.append_transform(lambda td: td.densify(layout=torch.jagged))

+    # rb.append_transform(breakpointy)
+
     # obs_tokens in layout=torch.jagged errors with Qwen
     # File "/home/mg1998/.conda/envs/rl/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 859, in forward
     #     cache_position = torch.arange(
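Raising batch_size from 8 to 48 while keeping slice_len=8 means each sampled batch is assembled from several trajectory slices, since SliceSamplerWithoutReplacement cuts stored episodes at the ("next", "done") boundary and requires batch_size to be a multiple of slice_len. A tiny arithmetic sketch of that relationship; the name slices_per_batch is only illustrative:

batch_size = 48   # transitions per sampled batch (was 8)
slice_len = 8     # contiguous steps drawn from a single game

assert batch_size % slice_len == 0, "batch_size must be a multiple of slice_len"
slices_per_batch = batch_size // slice_len
print(slices_per_batch)  # 6 game fragments per PPO minibatch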
@@ -486,6 +548,7 @@ def obs_token_transform(td):

     data = gae(data)
     loss = loss_module(data)
+    breakpoint()
     loss.sum(reduce=True).backward()
     torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 0.5)
     optim.step()
