Skip to content

Commit c814fe1

Browse files
authored
Merge pull request #1175 from lingbai-kong/ndarrayload
optimize: time complexity of Imdb dataset loader
2 parents eb49be0 + 628b2ce commit c814fe1

File tree

2 files changed

+27
-35
lines changed

2 files changed

+27
-35
lines changed

src/TensorFlowNET.Keras/Datasets/Imdb.cs

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -116,23 +116,13 @@ public DatasetPass load_data(
116116
for (var i = 0; i < x_train_array.GetLength(0); i++)
117117
{
118118
new_x_train_array[i, 0] = (int)start_char;
119-
for (var j = 0; j < x_train_array.GetLength(1); j++)
120-
{
121-
if (x_train_array[i, j] == 0)
122-
break;
123-
new_x_train_array[i, j + 1] = x_train_array[i, j];
124-
}
119+
Array.Copy(x_train_array, i * x_train_array.GetLength(1), new_x_train_array, i * new_x_train_array.GetLength(1) + 1, x_train_array.GetLength(1));
125120
}
126121
int[,] new_x_test_array = new int[x_test_array.GetLength(0), x_test_array.GetLength(1) + 1];
127122
for (var i = 0; i < x_test_array.GetLength(0); i++)
128123
{
129124
new_x_test_array[i, 0] = (int)start_char;
130-
for (var j = 0; j < x_test_array.GetLength(1); j++)
131-
{
132-
if (x_test_array[i, j] == 0)
133-
break;
134-
new_x_test_array[i, j + 1] = x_test_array[i, j];
135-
}
125+
Array.Copy(x_test_array, i * x_test_array.GetLength(1), new_x_test_array, i * new_x_test_array.GetLength(1) + 1, x_test_array.GetLength(1));
136126
}
137127
x_train_array = new_x_train_array;
138128
x_test_array = new_x_test_array;
@@ -163,15 +153,19 @@ public DatasetPass load_data(
163153
{
164154
maxlen = max(x_train_array.GetLength(1), x_test_array.GetLength(1));
165155
}
166-
(x_train, labels_train) = data_utils._remove_long_seq((int)maxlen, x_train_array, labels_train_array);
167-
(x_test, labels_test) = data_utils._remove_long_seq((int)maxlen, x_test_array, labels_test_array);
168-
if (x_train.size == 0 || x_test.size == 0)
156+
(x_train_array, labels_train_array) = data_utils._remove_long_seq((int)maxlen, x_train_array, labels_train_array);
157+
(x_test_array, labels_test_array) = data_utils._remove_long_seq((int)maxlen, x_test_array, labels_test_array);
158+
if (x_train_array.Length == 0 || x_test_array.Length == 0)
169159
throw new ValueError("After filtering for sequences shorter than maxlen=" +
170160
$"{maxlen}, no sequence was kept. Increase maxlen.");
171161

172-
var xs = np.concatenate(new[] { x_train, x_test });
173-
var labels = np.concatenate(new[] { labels_train, labels_test });
174-
var xs_array = (int[,])xs.ToMultiDimArray<int>();
162+
int[,] xs_array = new int[x_train_array.GetLength(0) + x_test_array.GetLength(0), (int)maxlen];
163+
Array.Copy(x_train_array, xs_array, x_train_array.Length);
164+
Array.Copy(x_test_array, 0, xs_array, x_train_array.Length, x_train_array.Length);
165+
166+
long[] labels_array = new long[labels_train_array.Length + labels_test_array.Length];
167+
Array.Copy(labels_train_array, labels_array, labels_train_array.Length);
168+
Array.Copy(labels_test_array, 0, labels_array, labels_train_array.Length, labels_test_array.Length);
175169

176170
if (num_words == null)
177171
{
@@ -197,7 +191,7 @@ public DatasetPass load_data(
197191
new_xs_array[i, j] = (int)oov_char;
198192
}
199193
}
200-
xs = new NDArray(new_xs_array);
194+
xs_array = new_xs_array;
201195
}
202196
else
203197
{
@@ -211,19 +205,19 @@ public DatasetPass load_data(
211205
new_xs_array[i, k++] = xs_array[i, j];
212206
}
213207
}
214-
xs = new NDArray(new_xs_array);
208+
xs_array = new_xs_array;
215209
}
216210

217-
var idx = len(x_train);
218-
x_train = xs[$"0:{idx}"];
219-
x_test = xs[$"{idx}:"];
220-
var y_train = labels[$"0:{idx}"];
221-
var y_test = labels[$"{idx}:"];
211+
Array.Copy(xs_array, x_train_array, x_train_array.Length);
212+
Array.Copy(xs_array, x_train_array.Length, x_test_array, 0, x_train_array.Length);
213+
214+
Array.Copy(labels_array, labels_train_array, labels_train_array.Length);
215+
Array.Copy(labels_array, labels_train_array.Length, labels_test_array, 0, labels_test_array.Length);
222216

223217
return new DatasetPass
224218
{
225-
Train = (x_train, y_train),
226-
Test = (x_test, y_test)
219+
Train = (x_train_array, labels_train_array),
220+
Test = (x_test_array, labels_test_array)
227221
};
228222
}
229223

src/TensorFlowNET.Keras/Utils/data_utils.cs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public static string get_file(string fname, string origin,
4040
return datadir;
4141
}
4242

43-
public static (NDArray, NDArray) _remove_long_seq(int maxlen, NDArray seq, NDArray label)
43+
public static (int[,], long[]) _remove_long_seq(int maxlen, int[,] seq, long[] label)
4444
{
4545
/*Removes sequences that exceed the maximum length.
4646
@@ -56,19 +56,17 @@ public static (NDArray, NDArray) _remove_long_seq(int maxlen, NDArray seq, NDArr
5656
List<int[]> new_seq = new List<int[]>();
5757
List<long> new_label = new List<long>();
5858

59-
var seq_array = (int[,])seq.ToMultiDimArray<int>();
60-
var label_array = (long[])label.ToArray<long>();
61-
for (var i = 0; i < seq_array.GetLength(0); i++)
59+
for (var i = 0; i < seq.GetLength(0); i++)
6260
{
63-
if (maxlen < seq_array.GetLength(1) && seq_array[i,maxlen] != 0)
61+
if (maxlen < seq.GetLength(1) && seq[i, maxlen] != 0)
6462
continue;
6563
int[] sentence = new int[maxlen];
66-
for (var j = 0; j < maxlen && j < seq_array.GetLength(1); j++)
64+
for (var j = 0; j < maxlen && j < seq.GetLength(1); j++)
6765
{
68-
sentence[j] = seq_array[i, j];
66+
sentence[j] = seq[i, j];
6967
}
7068
new_seq.Add(sentence);
71-
new_label.Add(label_array[i]);
69+
new_label.Add(label[i]);
7270
}
7371

7472
int[,] new_seq_array = new int[new_seq.Count, maxlen];

0 commit comments

Comments
 (0)