Skip to content

Commit fe8118a

Browse files
committed
refactor: for discussion
1 parent 9941fa0 commit fe8118a

37 files changed

+230
-802
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from functools import partial
2+
from pathlib import Path
3+
import numpy as np
4+
import random
5+
import argparse
6+
import torch
7+
import torch.nn.functional as F
8+
import ignite
9+
import logging
10+
import workflow
11+
from workflow.functional import starcompose
12+
from workflow.torch import set_seeds
13+
from workflow.ignite import worker_init
14+
from workflow.ignite.handlers.learning_rate import (
15+
LearningRateScheduler, warmup, cyclical
16+
)
17+
from datastream import Datastream
18+
19+
from {{cookiecutter.package_name}} import (
20+
datastream, architecture, metrics, log_examples
21+
)
22+
23+
24+
def train(config):
    """Train the model, evaluating once per epoch, with early stopping.

    Expected ``config`` keys: ``seed``, ``use_cuda``, ``learning_rate``,
    ``batch_size``, ``eval_batch_size``, ``n_workers``,
    ``n_batches_per_epoch``, ``max_epochs``.

    Side effects: reads an optional checkpoint from ``model/checkpoints``,
    writes TensorBoard logs and ``model_checkpoint.pt``.
    """
    # Local import: tqdm is used below but the module header never imports it.
    from tqdm import tqdm

    set_seeds(config['seed'])
    device = torch.device('cuda' if config['use_cuda'] else 'cpu')

    model = architecture.Model().to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config['learning_rate']
    )

    train_state = dict(model=model, optimizer=optimizer)

    if Path('model').exists():
        print('Loading model checkpoint')
        workflow.ignite.handlers.ModelCheckpoint.load(
            train_state, 'model/checkpoints', device
        )
        # Restore the configured learning rate after the checkpoint's
        # optimizer state (which carries its own lr) has been loaded.
        workflow.torch.set_learning_rate(optimizer, config['learning_rate'])

    # NOTE(review): renamed the comprehension variable — the original rebound
    # the module name `datastream` inside the comprehension, which reads as a
    # shadowing bug even though comprehension scope made it work.
    evaluate_data_loaders = {
        f'evaluate_{name}': stream.data_loader(
            batch_size=config['eval_batch_size'],
            num_workers=config['n_workers'],
            collate_fn=tuple,
        )
        for name, stream in datastream.evaluate_datastreams().items()
    }

    gradient_data_loader = (
        datastream.GradientDatastream()
        .data_loader(
            batch_size=config['batch_size'],
            num_workers=config['n_workers'],
            n_batches_per_epoch=config['n_batches_per_epoch'],
            # bug fix: the original passed `trainer`, which no longer exists
            # in this refactor (NameError); seed alone drives worker seeding.
            worker_init_fn=partial(worker_init, config['seed']),
            collate_fn=tuple,
        )
    )

    tensorboard_logger = torch.utils.tensorboard.SummaryWriter()
    early_stopping = workflow.EarlyStopping(...)

    for epoch in tqdm(range(config['max_epochs'])):
        for examples in tqdm(gradient_data_loader):
            # assumes workflow.train(...) sets train mode and performs
            # optimizer.step()/zero_grad() on exit — TODO confirm, since
            # this loop only calls backward().
            with workflow.train(model, optimizer):
                predictions = model.predictions(
                    architecture.FeatureBatch.from_examples(examples)
                )
                loss = predictions.loss(examples)
                loss.backward()

            metrics.gradient_metrics(
                examples, predictions, loss, tensorboard_logger
            )
            # optional: schedule learning rate

        evaluate_loss = None
        # bug fix: iterate .items() — the original unpacked dict *keys*
        # (strings) into (name, data_loader), a guaranteed ValueError.
        for name, data_loader in evaluate_data_loaders.items():
            for examples in tqdm(data_loader):
                with workflow.eval(model):
                    predictions = model.predictions(
                        architecture.FeatureBatch.from_examples(examples)
                    )
                    loss = predictions.loss(examples)
                evaluate_loss = loss

                # TODO: metrics need state?
                # metrics.evaluate_metrics(
                #     examples, predictions, loss, tensorboard_logger
                # )

        # bug fix: the original scored an undefined `output`; score on the
        # last evaluation loss. NOTE(review): the original also *called*
        # out_of_patience(output) after unpacking it as a flag — treated as
        # a boolean here; confirm EarlyStopping's intended API.
        improved, out_of_patience = early_stopping.score(evaluate_loss)
        if improved:
            torch.save(train_state, 'model_checkpoint.pt')
        elif out_of_patience:
            break
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from functools import partial
2+
from pathlib import Path
3+
import numpy as np
4+
import random
5+
import argparse
6+
import torch
7+
import torch.nn.functional as F
8+
import ignite
9+
import logging
10+
import workflow
11+
from workflow.functional import starcompose
12+
from workflow.torch import set_seeds
13+
from workflow.ignite import worker_init
14+
from workflow.ignite.handlers.learning_rate import (
15+
LearningRateScheduler, warmup, cyclical
16+
)
17+
from datastream import Datastream
18+
19+
from {{cookiecutter.package_name}} import (
20+
datastream, architecture, metrics, log_examples
21+
)
22+
23+
24+
def train(config):
    """Train the model with an explicit optimizer step, evaluate each epoch.

    Variant of the training loop that wraps whole epochs in
    ``workflow.module_train`` / ``workflow.module_eval`` and performs
    ``optimizer.step()`` / ``zero_grad()`` manually.

    Expected ``config`` keys: ``seed``, ``use_cuda``, ``learning_rate``,
    ``batch_size``, ``eval_batch_size``, ``n_workers``,
    ``n_batches_per_epoch``, ``max_epochs``.
    """
    # Local import: tqdm is used below but the module header never imports it.
    from tqdm import tqdm

    set_seeds(config['seed'])
    device = torch.device('cuda' if config['use_cuda'] else 'cpu')

    model = architecture.Model().to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config['learning_rate']
    )

    train_state = dict(model=model, optimizer=optimizer)

    if Path('model').exists():
        print('Loading model checkpoint')
        workflow.ignite.handlers.ModelCheckpoint.load(
            train_state, 'model/checkpoints', device
        )
        # Restore the configured learning rate after the checkpoint's
        # optimizer state (which carries its own lr) has been loaded.
        workflow.torch.set_learning_rate(optimizer, config['learning_rate'])

    # NOTE(review): renamed the comprehension variable — the original rebound
    # the module name `datastream` inside the comprehension.
    evaluate_data_loaders = {
        f'evaluate_{name}': stream.data_loader(
            batch_size=config['eval_batch_size'],
            num_workers=config['n_workers'],
            collate_fn=tuple,
        )
        for name, stream in datastream.evaluate_datastreams().items()
    }

    gradient_data_loader = (
        datastream.GradientDatastream()
        .data_loader(
            batch_size=config['batch_size'],
            num_workers=config['n_workers'],
            n_batches_per_epoch=config['n_batches_per_epoch'],
            # bug fix: the original passed `trainer`, which no longer exists
            # in this refactor (NameError); seed alone drives worker seeding.
            worker_init_fn=partial(worker_init, config['seed']),
            collate_fn=tuple,
        )
    )

    tensorboard_logger = torch.utils.tensorboard.SummaryWriter()
    early_stopping = workflow.EarlyStopping(...)

    for epoch in tqdm(range(config['max_epochs'])):

        with workflow.module_train(model):
            for examples in tqdm(gradient_data_loader):
                predictions = model.predictions(
                    architecture.FeatureBatch.from_examples(examples)
                )
                loss = predictions.loss(examples)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                metrics.gradient_metrics(
                    examples, predictions, loss, tensorboard_logger
                )
                # optional: schedule learning rate

        evaluate_loss = None
        # bug fixes: torch.no_grad must be *instantiated* (the original used
        # the bare class in the with-statement), and the dict must be
        # iterated via .items() (the original unpacked string keys).
        with torch.no_grad(), workflow.module_eval(model):
            for name, data_loader in evaluate_data_loaders.items():
                for examples in tqdm(data_loader):
                    predictions = model.predictions(
                        architecture.FeatureBatch.from_examples(examples)
                    )
                    loss = predictions.loss(examples)
                    evaluate_loss = loss
                    # TODO: metrics need state?
                    # metrics.evaluate_metrics(
                    #     examples, predictions, loss, tensorboard_logger
                    # )

        # bug fix: the original scored an undefined `output`; score on the
        # last evaluation loss. NOTE(review): `out_of_patience` treated as a
        # boolean (the original inconsistently called it) — confirm
        # EarlyStopping's intended API.
        improved, out_of_patience = early_stopping.score(evaluate_loss)
        if improved:
            torch.save(train_state, 'model_checkpoint.pt')
        elif out_of_patience:
            break

template/{{cookiecutter.repository_name}}/{{cookiecutter.package_name}}/train.py

Lines changed: 34 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222

2323

2424
def train(config):
25-
2625
set_seeds(config['seed'])
27-
2826
device = torch.device('cuda' if config['use_cuda'] else 'cpu')
2927

3028
model = architecture.Model().to(device)
@@ -39,23 +37,18 @@ def train(config):
3937
workflow.ignite.handlers.ModelCheckpoint.load(
4038
train_state, 'model/checkpoints', device
4139
)
42-
4340
workflow.torch.set_learning_rate(optimizer, config['learning_rate'])
4441

45-
n_parameters = sum([
46-
p.shape.numel() for p in model.parameters() if p.requires_grad
47-
])
48-
print(f'n_parameters: {n_parameters:,}')
49-
42+
#
5043
def process_batch(examples):
5144
predictions = model.predictions(
5245
architecture.FeatureBatch.from_examples(examples)
5346
)
5447
loss = predictions.loss(examples)
5548
return predictions, loss
5649

57-
@workflow.ignite.decorators.train(model, optimizer)
58-
def train_batch(engine, examples):
50+
@workflow.torch.decorators.train(model, optimizer)
51+
def train_batch(examples):
5952
predictions, loss = process_batch(examples)
6053
loss.backward()
6154
return dict(
@@ -64,8 +57,8 @@ def train_batch(engine, examples):
6457
loss=loss,
6558
)
6659

67-
@workflow.ignite.decorators.evaluate(model)
68-
def evaluate_batch(engine, examples):
60+
@workflow.torch.decorators.evaluate(model)
61+
def evaluate_batch(examples):
6962
predictions, loss = process_batch(examples)
7063
return dict(
7164
examples=examples,
@@ -82,78 +75,34 @@ def evaluate_batch(engine, examples):
8275
for name, datastream in datastream.evaluate_datastreams().items()
8376
}
8477

85-
trainer, evaluators, tensorboard_logger = workflow.ignite.trainer(
86-
train_batch,
87-
evaluate_batch,
88-
evaluate_data_loaders,
89-
metrics=dict(
90-
progress=metrics.progress_metrics(),
91-
train=metrics.train_metrics(),
92-
**{
93-
name: metrics.evaluate_metrics()
94-
for name in evaluate_data_loaders.keys()
95-
}
96-
),
97-
optimizers=optimizer,
98-
)
99-
100-
workflow.ignite.handlers.ModelScore(
101-
lambda: -evaluators['evaluate_early_stopping'].state.metrics['loss'],
102-
train_state,
103-
{
104-
name: metrics.evaluate_metrics()
105-
for name in evaluate_data_loaders.keys()
106-
},
107-
tensorboard_logger,
108-
config,
109-
).attach(trainer, evaluators)
110-
111-
tensorboard_logger.attach(
112-
trainer,
113-
log_examples('train', trainer),
114-
ignite.engine.Events.EPOCH_COMPLETED,
115-
)
116-
tensorboard_logger.attach(
117-
evaluators['evaluate_compare'],
118-
log_examples('evaluate_compare', trainer),
119-
ignite.engine.Events.EPOCH_COMPLETED,
78+
gradient_data_loader = (
79+
datastream.GradientDatastream()
80+
.data_loader(
81+
batch_size=config['batch_size'],
82+
num_workers=config['n_workers'],
83+
n_batches_per_epoch=config['n_batches_per_epoch'],
84+
worker_init_fn=partial(worker_init, config['seed'], trainer),
85+
collate_fn=tuple,
86+
)
12087
)
12188

122-
if config.get('search_learning_rate', False):
123-
124-
def search(config):
125-
def search_(step, multiplier):
126-
return (
127-
step,
128-
(1 / config['minimum_learning_rate'])
129-
** (step / config['n_batches'])
130-
)
131-
return search_
132-
133-
LearningRateScheduler(
134-
optimizer,
135-
search(config),
136-
).attach(trainer)
137-
138-
else:
139-
LearningRateScheduler(
140-
optimizer,
141-
starcompose(
142-
warmup(150),
143-
cyclical(length=500),
144-
),
145-
).attach(trainer)
146-
147-
trainer.run(
148-
data=(
149-
datastream.GradientDatastream()
150-
.data_loader(
151-
batch_size=config['batch_size'],
152-
num_workers=config['n_workers'],
153-
n_batches_per_epoch=config['n_batches_per_epoch'],
154-
worker_init_fn=partial(worker_init, config['seed'], trainer),
155-
collate_fn=tuple,
156-
)
157-
),
158-
max_epochs=config['max_epochs'],
159-
)
89+
tensorboard_logger = torch.utils.tensorboard.SummaryWriter()
90+
early_stopping = workflow.EarlyStopping(...)
91+
92+
for epoch in tqdm(range(config['max_epochs'])):
93+
for examples in tqdm(gradient_data_loader):
94+
output = train_batch(examples)
95+
metrics.gradient_metrics(output, tensorboard_logger)
96+
# optional: schedule learning rate
97+
98+
for name, data_loader in evaluate_data_loaders:
99+
for examples in tqdm(data_loader):
100+
output = evaluate_batch(examples)
101+
# TODO: metrics need state?
102+
metrics.evaluate_metrics(output, tensorboard_logger)
103+
104+
improved, out_of_patience = early_stopping.score(output)
105+
if improved:
106+
torch.save(train_state, 'model_checkpoint.pt')
107+
elif out_of_patience(output):
108+
break

workflow/functional/interleaved.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

0 commit comments

Comments
 (0)