1
+ #include < torch/torch.h>
2
+ #include < math.h>
3
+ #include < iostream>
4
+ #include < cmath>
5
+ #include < limits>
6
+ #include < chrono>
7
+ #include < ctime>
8
+ #include < random>
9
+ #include " scheduler.h"
10
+
11
+ using namespace torch ::indexing;
12
+
13
+ struct PositionalEncodingImpl : torch::nn::Module{
14
+ PositionalEncodingImpl (){
15
+
16
+ }
17
+ PositionalEncodingImpl (int64_t d_model, int64_t max_len=5000 ){
18
+ pe = torch::zeros ({max_len, d_model});
19
+ position = torch::arange (0 , max_len,
20
+ torch::TensorOptions (torch::kFloat32 ).requires_grad (false ));
21
+ position = position.unsqueeze (1 );
22
+ torch::Tensor temp = torch::arange (0 , d_model, 2 , torch::TensorOptions (torch::kFloat32 ).requires_grad (false ));
23
+ div_term = torch::exp (temp * (std::log (10000.0 ) / d_model));
24
+
25
+
26
+ pe.index_put_ ({Slice (), Slice (0 , None, 2 )}, torch::sin (position * div_term));
27
+ pe.index_put_ ({Slice (), Slice (1 , None, 2 )}, torch::cos (position * div_term));
28
+
29
+
30
+
31
+ pe = pe.unsqueeze (0 ).transpose (0 , 1 );
32
+ register_parameter (" pe" , pe);
33
+ register_parameter (" position" , position);
34
+ register_parameter (" div_term" , div_term);
35
+ register_buffer (" pe" , pe);
36
+ }
37
+
38
+ torch::Tensor forward (torch::Tensor x){
39
+ x = x + pe.index ({Slice (0 , x.size (0 )), Slice ()});
40
+ return x;
41
+ }
42
+
43
+ torch::Tensor pe;
44
+ torch::Tensor position;
45
+ torch::Tensor div_term;
46
+ };
47
+
48
+ TORCH_MODULE (PositionalEncoding);
49
+
50
+ struct TransformerModel : torch::nn::Module{
51
+ TransformerModel (int64_t feature_size = 250 , int64_t nlayers = 1 , float dropout_p=0.1 ){
52
+ pos_encoder = PositionalEncoding (feature_size);
53
+ torch::nn::TransformerEncoderLayerOptions elOptions =
54
+ torch::nn::TransformerEncoderLayerOptions (feature_size, 10 );
55
+ torch::nn::TransformerEncoderLayer encoder_layers = torch::nn::TransformerEncoderLayer (
56
+ elOptions.dropout (dropout_p));
57
+ torch::nn::TransformerEncoderOptions enOptions = torch::nn::TransformerEncoderOptions (encoder_layers, nlayers);
58
+ transformer_encoder = torch::nn::TransformerEncoder (enOptions);
59
+ decoder = torch::nn::Linear (feature_size, 1 );
60
+ register_module (" pos_encoder" , pos_encoder);
61
+ register_module (" transformer_encoder" , transformer_encoder);
62
+ register_module (" decoder" , decoder);
63
+ }
64
+
65
+ void init_weights (){
66
+ float initrange = 0.1 ;
67
+ decoder->bias .data ().zero_ ();
68
+ decoder->weight .data ().uniform_ (-initrange, initrange);
69
+ }
70
+
71
+ torch::Tensor _generate_square_subsequent_mask (int sz){
72
+ torch::Tensor mask = (torch::triu (torch::ones ({sz, sz})) == 1 ).transpose (0 , 1 ).to (torch::kFloat32 );
73
+ mask = mask.masked_fill (mask == 0 , -std::numeric_limits<float >::infinity ()).masked_fill (mask == 1 , 0 .f );
74
+ return mask;
75
+ }
76
+
77
+ torch::Tensor forward (torch::Tensor src){
78
+ if (false == is_mask_generated){
79
+ torch::Tensor mask = _generate_square_subsequent_mask (src.size (0 )).to (src.device ());
80
+ src_mask = mask;
81
+ is_mask_generated = true ;
82
+ }
83
+
84
+ src = pos_encoder (src);
85
+ torch::Tensor output = transformer_encoder (src, src_mask);
86
+ output = decoder (output);
87
+ return output;
88
+ }
89
+
90
+ torch::Tensor src_mask;
91
+ bool is_mask_generated = false ;
92
+ PositionalEncoding pos_encoder;
93
+ torch::nn::TransformerEncoder transformer_encoder = nullptr ;
94
+ torch::nn::Linear decoder = nullptr ;
95
+ int64_t ninp;
96
+ };
97
+
98
+ torch::Tensor create_inout_sequences (torch::Tensor input_data, int64_t tw, int64_t output_window = 1 ){
99
+ torch::Tensor temp = torch::empty ({input_data.size (0 ) - tw, 2 , tw}, torch::TensorOptions (torch::kFloat32 ));
100
+ auto len = input_data.numel ();
101
+ auto max_counter = len - tw;
102
+ int64_t k = 0 ;
103
+ for (auto i = 0 ; i < max_counter; i++){
104
+ torch::Tensor train_seq = input_data.index ({Slice (i, i + tw)});
105
+ temp[i][0 ] = input_data.index ({Slice (i, i + tw)});
106
+ temp[i][1 ] = input_data.index ({Slice (i + output_window, i + tw + output_window)});
107
+
108
+ }
109
+
110
+ return temp;
111
+ }
112
+
113
+ std::tuple<torch::Tensor, torch::Tensor> get_data (int64_t output_window = 1 ){
114
+ // construct a little toy dataset
115
+ auto time = torch::arange (0 , 400 , 0.1 );
116
+ auto amplitude = torch::sin (time) + torch::sin (time * 0.05 ) + torch::sin (time * 0.12 );// + dist(mt);
117
+
118
+
119
+ // from sklearn.preprocessing import MinMaxScaler
120
+
121
+
122
+ // looks like normalizing input values curtial for the model
123
+ // scaler = MinMaxScaler(feature_range=(-1, 1))
124
+ // amplitude = scaler.fit_transform(series.to_numpy().reshape(-1, 1)).reshape(-1)
125
+ // amplitude = scaler.fit_transform(amplitude.reshape(-1, 1)).reshape(-1)
126
+
127
+
128
+ auto samples = 2600 ;
129
+
130
+ auto train_data = amplitude.index ({Slice (None, samples)});
131
+ auto test_data = amplitude.index ({Slice (samples, None)});
132
+
133
+ // convert our train data into a pytorch train tensor
134
+ auto input_window = 100 ;
135
+
136
+ auto train_sequence = create_inout_sequences (train_data,input_window);
137
+ train_sequence = train_sequence.index ({Slice (None,-output_window)});
138
+
139
+ auto test_sequence = create_inout_sequences (test_data,input_window);
140
+ test_sequence = test_sequence.index ({Slice (None,-output_window)});
141
+
142
+ auto cuda_available = torch::cuda::is_available ();
143
+ torch::Device device (cuda_available ? torch::kCUDA : torch::kCPU );
144
+
145
+ return std::make_tuple (train_sequence.to (device),test_sequence.to (device));
146
+ }
147
+
148
+ std::tuple<torch::Tensor, torch::Tensor> get_batch (torch::Tensor source, int64_t i, int64_t batch_size, int64_t input_window = 100 ){
149
+ auto seq_len = std::min (batch_size, source.size (0 ) - i);
150
+
151
+ auto data = source.index ({Slice (i, i + seq_len)});
152
+ auto input = data.index ({Slice (), 0 , Slice ()});
153
+ auto target = data.index ({Slice (), 1 , Slice ()});
154
+ auto temp = input.numel ()/100 ;
155
+ if (temp > 10 )
156
+ temp = 10 ;
157
+ input = torch::reshape (input, {100 , temp, 1 });
158
+ target = torch::reshape (target, {100 , temp, 1 });
159
+ return std::make_tuple (input, target);
160
+ }
161
+
162
+
163
+ void train (TransformerModel model, torch::Tensor train_data, int64_t num_epochs = 100 ){
164
+ model.train ();
165
+ auto total_loss = 0.0 ;
166
+ auto start_time = std::chrono::system_clock::now ();
167
+ auto batch_size = 10 ;
168
+ auto batch = 0 ;
169
+
170
+ torch::nn::MSELoss criterion;
171
+
172
+
173
+ auto learning_rate = 0.005 ;
174
+ torch::optim::SGD optimizer (model.parameters (), torch::optim::SGDOptions (learning_rate));
175
+ scheduler::StepLR<decltype (optimizer)> scheduler (optimizer, 1.0 , 0.95 );
176
+
177
+ for (int64_t i = 0 ; i <= num_epochs; i++){
178
+ auto start_time = std::chrono::system_clock::now ();
179
+ std::cout<<" Epoch " <<i<<std::endl;
180
+ batch = 0 ;
181
+ for (int64_t j = 0 ; j < train_data.size (0 ); j = j + batch_size, batch++){
182
+ auto data = get_batch (train_data, j, batch_size);
183
+ optimizer.zero_grad ();
184
+ auto output = model.forward (std::get<0 >(data));
185
+
186
+ auto loss = criterion (output, std::get<1 >(data));
187
+ loss.backward ();
188
+ torch::nn::utils::clip_grad_norm_ (model.parameters (), 0.7 );
189
+ optimizer.step ();
190
+ total_loss += loss.item <double >();
191
+ auto log_interval = int (train_data.size (0 )) / (batch_size * 5 );
192
+ if (batch != 0 && 0 == batch % log_interval){
193
+ auto curr_loss = total_loss / log_interval;
194
+ auto elapsed = std::chrono::system_clock::now () - start_time;
195
+ std::cout<<" |epoch " <<i<<" | " <<batch<<" /" <<train_data.size (0 )/batch_size;
196
+ std::cout<<" batches | " <<(elapsed.count () * 10 )<<" ms | loss" <<curr_loss<<std::endl;;
197
+ total_loss = 0 ;
198
+ start_time = std::chrono::system_clock::now ();
199
+ }
200
+ }
201
+
202
+ scheduler.step ();
203
+ }
204
+
205
+ return ;
206
+ }
207
+
208
+ void evaluate (TransformerModel model, torch::Tensor eval_data){
209
+ model.eval ();
210
+ auto batch_size = 10 ;
211
+ auto total_loss = 0.0 ;
212
+ torch::nn::MSELoss criterion;
213
+
214
+ std::cout<<" Evaluating:" ;
215
+ for (int64_t j = 0 ; j < eval_data.size (0 ); j = j + batch_size){
216
+ auto data = get_batch (eval_data, j, batch_size);
217
+ auto output = model.forward (std::get<0 >(data));
218
+ auto loss = criterion (output, std::get<1 >(data));
219
+ total_loss += loss.item <double >();
220
+ }
221
+
222
+ std::cout<<" Evaluation Loss: " <<total_loss<<std::endl;
223
+ return ;
224
+ }
225
+
226
+ int main (){
227
+ auto cuda_available = torch::cuda::is_available ();
228
+ torch::Device device (cuda_available ? torch::kCUDA : torch::kCPU );
229
+
230
+ auto model = TransformerModel ();
231
+ model.to (device);
232
+
233
+ auto data = get_data ();
234
+ train (model, std::get<0 >(data));
235
+ evaluate (model, std::get<1 >(data));
236
+
237
+ return 0 ;
238
+ }
0 commit comments