#include <torch/torch.h>
#include <cmath>
#include <iostream>
#include <limits>
#include <chrono>
#include <random>
#include <algorithm>
#include "scheduler.h"
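
// A libtorch port of a small Transformer for time-series forecasting: the
// model learns to predict a toy sum-of-sines signal one step ahead.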

using namespace torch::indexing;

// Sinusoidal positional encoding from "Attention Is All You Need": fixed
// sine/cosine waves added to the input so the encoder can see token order.
struct PositionalEncodingImpl : torch::nn::Module{
    PositionalEncodingImpl(int64_t d_model, int64_t max_len=5000){
        pe = torch::zeros({max_len, d_model});
        torch::Tensor position = torch::arange(0, max_len,
            torch::TensorOptions(torch::kFloat32).requires_grad(false));
        position = position.unsqueeze(1);
        torch::Tensor temp = torch::arange(0, d_model, 2,
            torch::TensorOptions(torch::kFloat32).requires_grad(false));
        // Note the minus sign: the frequency must decay as the dimension grows.
        torch::Tensor div_term = torch::exp(temp * (-std::log(10000.0) / d_model));

        // Even feature columns get sin, odd ones get cos.
        pe.index_put_({Slice(), Slice(0, None, 2)}, torch::sin(position * div_term));
        pe.index_put_({Slice(), Slice(1, None, 2)}, torch::cos(position * div_term));

        // (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over batches.
        pe = pe.unsqueeze(0).transpose(0, 1);
        // The encoding is fixed, not learned: register it as a buffer only,
        // never as a parameter (position and div_term are plain temporaries).
        pe = register_buffer("pe", pe);
    }

    torch::Tensor forward(torch::Tensor x){
        // The (seq_len, 1, d_model) slice broadcasts over the batch dimension
        // (and over the feature dimension when x carries a single feature).
        x = x + pe.index({Slice(0, x.size(0)), Slice()});
        return x;
    }

    torch::Tensor pe;
};

TORCH_MODULE(PositionalEncoding);
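
// Minimal usage sketch, with the (seq_len, batch, features) layout used
// throughout this file:
//   PositionalEncoding pe(250);
//   auto y = pe(torch::zeros({100, 10, 1}));  // broadcasts to (100, 10, 250)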

struct TransformerModel : torch::nn::Module{
    TransformerModel(int64_t feature_size = 250, int64_t nlayers = 1, float dropout_p = 0.1){
        pos_encoder = PositionalEncoding(feature_size);
        // feature_size must be divisible by the number of attention heads (10).
        torch::nn::TransformerEncoderLayerOptions elOptions =
            torch::nn::TransformerEncoderLayerOptions(feature_size, 10);
        torch::nn::TransformerEncoderLayer encoder_layers = torch::nn::TransformerEncoderLayer(
            elOptions.dropout(dropout_p));
        torch::nn::TransformerEncoderOptions enOptions = torch::nn::TransformerEncoderOptions(encoder_layers, nlayers);
        transformer_encoder = torch::nn::TransformerEncoder(enOptions);
        decoder = torch::nn::Linear(feature_size, 1);
        register_module("pos_encoder", pos_encoder);
        register_module("transformer_encoder", transformer_encoder);
        register_module("decoder", decoder);
        init_weights();
    }

    void init_weights(){
        float initrange = 0.1;
        decoder->bias.data().zero_();
        decoder->weight.data().uniform_(-initrange, initrange);
    }

    // Causal mask: 0 on and below the diagonal, -inf above it, so that
    // position i can only attend to positions <= i.
    torch::Tensor _generate_square_subsequent_mask(int sz){
        torch::Tensor mask = (torch::triu(torch::ones({sz, sz})) == 1).transpose(0, 1).to(torch::kFloat32);
        mask = mask.masked_fill(mask == 0, -std::numeric_limits<float>::infinity()).masked_fill(mask == 1, 0.f);
        return mask;
    }
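
    // For example, _generate_square_subsequent_mask(3) yields
    //   [[0,   -inf, -inf],
    //    [0,    0,   -inf],
    //    [0,    0,    0  ]]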

    torch::Tensor forward(torch::Tensor src){
        // Build the causal mask lazily and cache it; regenerate it whenever
        // the sequence length changes (e.g. a shorter final batch).
        if (!is_mask_generated || src_mask.size(0) != src.size(0)){
            src_mask = _generate_square_subsequent_mask(src.size(0)).to(src.device());
            is_mask_generated = true;
        }

        src = pos_encoder(src);
        torch::Tensor output = transformer_encoder(src, src_mask);
        output = decoder(output);
        return output;
    }

    torch::Tensor src_mask;
    bool is_mask_generated = false;
    PositionalEncoding pos_encoder = nullptr;
    torch::nn::TransformerEncoder transformer_encoder = nullptr;
    torch::nn::Linear decoder = nullptr;
};

// Slice a 1-D series into (input, target) pairs: row i holds the window that
// starts at i alongside the same window shifted forward by output_window.
torch::Tensor create_inout_sequences(torch::Tensor input_data, int64_t tw, int64_t output_window = 1){
    torch::Tensor temp = torch::empty({input_data.size(0) - tw, 2, tw}, torch::TensorOptions(torch::kFloat32));
    auto max_counter = input_data.numel() - tw;
    for (int64_t i = 0; i < max_counter; i++){
        temp[i][0] = input_data.index({Slice(i, i + tw)});                                  // input window
        temp[i][1] = input_data.index({Slice(i + output_window, i + tw + output_window)});  // target window
    }

    return temp;
}
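
// For example, with tw = 3 and output_window = 1 the series [a, b, c, d, e]
// yields the pairs ([a, b, c], [b, c, d]) and ([b, c, d], [c, d, e]).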

std::tuple<torch::Tensor, torch::Tensor> get_data(int64_t output_window = 1){
    // Construct a little toy dataset: a sum of three sine waves.
    auto time = torch::arange(0, 400, 0.1);
    auto amplitude = torch::sin(time) + torch::sin(time * 0.05) + torch::sin(time * 0.12);// + dist(mt);

    // It looks like normalizing the input values is crucial for the model.
    // The Python original did this with sklearn's MinMaxScaler:
    //   scaler = MinMaxScaler(feature_range=(-1, 1))
    //   amplitude = scaler.fit_transform(amplitude.reshape(-1, 1)).reshape(-1)

    auto samples = 2600;

    auto train_data = amplitude.index({Slice(None, samples)});
    auto test_data = amplitude.index({Slice(samples, None)});

    // Convert the series into (input, target) sequence pairs.
    auto input_window = 100;

    auto train_sequence = create_inout_sequences(train_data, input_window);
    train_sequence = train_sequence.index({Slice(None, -output_window)});

    auto test_sequence = create_inout_sequences(test_data, input_window);
    test_sequence = test_sequence.index({Slice(None, -output_window)});

    auto cuda_available = torch::cuda::is_available();
    torch::Device device(cuda_available ? torch::kCUDA : torch::kCPU);

    return std::make_tuple(train_sequence.to(device), test_sequence.to(device));
}

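// Extract one mini-batch from the output of create_inout_sequences and bring
// it into the (seq_len, batch, features) layout nn::TransformerEncoder expects.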
std::tuple<torch::Tensor, torch::Tensor> get_batch(torch::Tensor source, int64_t i, int64_t batch_size){
    auto seq_len = std::min(batch_size, source.size(0) - i);  // sequences in this batch

    auto data = source.index({Slice(i, i + seq_len)});
    auto input = data.index({Slice(), 0, Slice()});    // (seq_len, input_window)
    auto target = data.index({Slice(), 1, Slice()});
    // (seq_len, input_window) -> (input_window, seq_len, 1): transpose rather
    // than reshape, so each column remains one contiguous sequence.
    input = input.transpose(0, 1).unsqueeze(2);
    target = target.transpose(0, 1).unsqueeze(2);
    return std::make_tuple(input, target);
}
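
// E.g. with batch_size = 10 and input_window = 100, get_batch returns input
// and target tensors of shape (100, 10, 1) each.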

void train(TransformerModel& model, torch::Tensor train_data, int64_t num_epochs = 100){
    model.train();
    auto total_loss = 0.0;
    auto batch_size = 10;
    auto batch = 0;

    torch::nn::MSELoss criterion;

    auto learning_rate = 0.005;
    torch::optim::SGD optimizer(model.parameters(), torch::optim::SGDOptions(learning_rate));
    scheduler::StepLR<decltype(optimizer)> scheduler(optimizer, 1.0, 0.95);
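    // Note: scheduler::StepLR comes from this repo's own scheduler.h; built
    // with step size 1.0 and gamma 0.95 it should decay the learning rate by
    // 5% each epoch, like torch.optim.lr_scheduler.StepLR.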

    for(int64_t i = 1; i <= num_epochs; i++){
        auto start_time = std::chrono::system_clock::now();
        std::cout<<"Epoch "<<i<<std::endl;
        batch = 0;
        for (int64_t j = 0; j < train_data.size(0); j = j + batch_size, batch++){
            auto data = get_batch(train_data, j, batch_size);
            optimizer.zero_grad();
            auto output = model.forward(std::get<0>(data));

            auto loss = criterion(output, std::get<1>(data));
            loss.backward();
            // Clip gradients to keep the updates stable.
            torch::nn::utils::clip_grad_norm_(model.parameters(), 0.7);
            optimizer.step();
            total_loss += loss.item<double>();
            // Log roughly five times per epoch.
            auto log_interval = std::max(int64_t{1}, train_data.size(0) / (batch_size * 5));
            if (batch != 0 && 0 == batch % log_interval){
                auto curr_loss = total_loss / log_interval;
                auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
                    std::chrono::system_clock::now() - start_time).count();
                std::cout<<"| epoch "<<i<<" | "<<batch<<"/"<<train_data.size(0)/batch_size;
                std::cout<<" batches | "<<elapsed<<" ms | loss "<<curr_loss<<std::endl;
                total_loss = 0;
                start_time = std::chrono::system_clock::now();
            }
        }

        scheduler.step();
    }

    return;
}

void evaluate(TransformerModel& model, torch::Tensor eval_data){
    model.eval();
    torch::NoGradGuard no_grad;  // no gradients needed during evaluation
    auto batch_size = 10;
    auto total_loss = 0.0;
    torch::nn::MSELoss criterion;

    std::cout<<"Evaluating:"<<std::endl;
    for (int64_t j = 0; j < eval_data.size(0); j = j + batch_size){
        auto data = get_batch(eval_data, j, batch_size);
        auto output = model.forward(std::get<0>(data));
        auto loss = criterion(output, std::get<1>(data));
        total_loss += loss.item<double>();
    }

    std::cout<<"Evaluation Loss: "<<total_loss<<std::endl;
    return;
}

int main(){
    auto cuda_available = torch::cuda::is_available();
    torch::Device device(cuda_available ? torch::kCUDA : torch::kCPU);

    auto model = TransformerModel();
    model.to(device);

    auto data = get_data();
    train(model, std::get<0>(data));
    evaluate(model, std::get<1>(data));

    return 0;
}
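
// Build sketch (paths are placeholders; point CMAKE_PREFIX_PATH at your own
// libtorch install):
//   cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch -S . -B build
//   cmake --build build --config Release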