Skip to content

Commit 8d59183

Browse files
committed
Update IMDB dataloader mapping
1 parent 8b69b97 commit 8d59183

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

beginner_source/t5_tutorial.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,11 @@ def apply_prefix(task, x):
168168
imdb_batch_size = 3
169169
imdb_datapipe = IMDB(split="test")
170170
task = "sst2 sentence"
171-
labels = {"neg": "negative", "pos": "positive"}
171+
labels = {"1": "negative", "2": "positive"}
172172

173173

174174
def process_labels(labels, x):
175-
return x[1], labels[x[0]]
175+
return x[1], labels[str(x[0])]
176176

177177

178178
imdb_datapipe = imdb_datapipe.map(partial(process_labels, labels))
@@ -361,7 +361,7 @@ def process_labels(labels, x):
361361
# really annoying was the constant cuts to VDs daughter during the last fight scene.<br /><br />
362362
# Not bad. Not good. Passable 4.
363363
#
364-
# prediction: negative
364+
# prediction: positive
365365
#
366366
# target: negative
367367
#
@@ -388,13 +388,12 @@ def process_labels(labels, x):
388388
# ---------------------
389389
#
390390
# Finally, we can also use the model to generate English to German translations on the first batch of examples from the Multi30k
391-
# test set using a beam size of 4.
391+
# test set.
392392
#
393393

394394
batch = next(iter(multi_dataloader))
395395
input_text = batch["english"]
396396
target = batch["german"]
397-
beam_size = 4
398397

399398
model_input = transform(input_text)
400399
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)

0 commit comments

Comments
 (0)