
Commit ee6256c

ptrblck and Jessica Lin authored
update out of place ops, update act fn, optimizers, schedulers, samplers (#989)
Co-authored-by: pbialecki <pbialecki@nvidia.com>
Co-authored-by: Jessica Lin <jplin@fb.com>
1 parent af754cb commit ee6256c

File tree

1 file changed: +58 −50 lines

beginner_source/ptcheat.rst

Lines changed: 58 additions & 50 deletions
@@ -80,8 +80,8 @@ Distributed Training

 .. code-block:: python

-    import torch.distributed as dist             # distributed communication
-    from multiprocessing import Process          # memory sharing processes
+    import torch.distributed as dist             # distributed communication
+    from torch.multiprocessing import Process    # memory sharing processes

 See `distributed <https://pytorch.org/docs/stable/distributed.html>`__
 and
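
Editor's aside, not part of the commit: a minimal sketch of how the updated import is typically used — two CPU workers started with torch.multiprocessing, joining a gloo process group and all-reducing a tensor. The rendezvous address, port, and world size are arbitrary illustrative choices.

    import os
    import torch
    import torch.distributed as dist
    from torch.multiprocessing import Process

    def run(rank, world_size):
        # join the default process group and sum a tensor across workers
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        t = torch.ones(1)
        dist.all_reduce(t)                        # t now holds world_size on every rank
        dist.destroy_process_group()

    if __name__ == "__main__":
        os.environ["MASTER_ADDR"] = "127.0.0.1"   # illustrative rendezvous settings
        os.environ["MASTER_PORT"] = "29500"
        world_size = 2
        procs = [Process(target=run, args=(r, world_size)) for r in range(world_size)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
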
@@ -95,13 +95,13 @@ Creation

 .. code-block:: python

-    torch.randn(*size)              # tensor with independent N(0,1) entries
-    torch.[ones|zeros](*size)       # tensor with all 1's [or 0's]
-    torch.Tensor(L)                 # create tensor from [nested] list or ndarray L
-    x.clone()                       # clone of x
-    with torch.no_grad():           # code wrap that stops autograd from tracking tensor history
-    requires_grad=True              # arg, when set to True, tracks computation
-                                    # history for future derivative calculations
+    x = torch.randn(*size)          # tensor with independent N(0,1) entries
+    x = torch.[ones|zeros](*size)   # tensor with all 1's [or 0's]
+    x = torch.tensor(L)             # create tensor from [nested] list or ndarray L
+    y = x.clone()                   # clone of x
+    with torch.no_grad():           # code wrap that stops autograd from tracking tensor history
+    requires_grad=True              # arg, when set to True, tracks computation
+                                    # history for future derivative calculations

 See `tensor <https://pytorch.org/docs/stable/tensors.html>`__
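
For reference, a runnable sketch (not from the commit) of the out-of-place creation style the new lines document; shapes and values are arbitrary.

    import torch

    x = torch.randn(2, 3)                 # 2x3 tensor with N(0,1) entries
    x = torch.zeros(2, 3)                 # all zeros
    x = torch.tensor([[1, 2], [3, 4]])    # from a nested Python list
    y = x.clone()                         # independent copy of x

    with torch.no_grad():                 # ops in this block are not tracked by autograd
        z = x * 2

    w = torch.randn(2, 2, requires_grad=True)   # subsequent ops on w are tracked
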

@@ -110,14 +110,16 @@ Dimensionality

 .. code-block:: python

-    x.size()                                  # return tuple-like object of dimensions
-    torch.cat(tensor_seq, dim=0)              # concatenates tensors along dim
-    x.view(a,b,...)                           # reshapes x into size (a,b,...)
-    x.view(-1,a)                              # reshapes x into size (b,a) for some b
-    x.transpose(a,b)                          # swaps dimensions a and b
-    x.permute(*dims)                          # permutes dimensions
-    x.unsqueeze(dim)                          # tensor with added axis
-    x.unsqueeze(dim=2)                        # (a,b,c) tensor -> (a,b,1,c) tensor
+    x.size()                                  # return tuple-like object of dimensions
+    x = torch.cat(tensor_seq, dim=0)          # concatenates tensors along dim
+    y = x.view(a,b,...)                       # reshapes x into size (a,b,...)
+    y = x.view(-1,a)                          # reshapes x into size (b,a) for some b
+    y = x.transpose(a,b)                      # swaps dimensions a and b
+    y = x.permute(*dims)                      # permutes dimensions
+    y = x.unsqueeze(dim)                      # tensor with added axis
+    y = x.unsqueeze(dim=2)                    # (a,b,c) tensor -> (a,b,1,c) tensor
+    y = x.squeeze()                           # removes all dimensions of size 1 (a,1,b,1) -> (a,b)
+    y = x.squeeze(dim=1)                      # removes specified dimension of size 1 (a,1,b,1) -> (a,b,1)

 See `tensor <https://pytorch.org/docs/stable/tensors.html>`__
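
A quick shape walk-through (editor's illustration, not part of the commit) of the calls above, starting from an arbitrary 2x3 tensor.

    import torch

    x = torch.randn(2, 3)
    print(x.size())               # torch.Size([2, 3])

    x = torch.cat([x, x], dim=0)  # (4, 3)
    y = x.view(3, 4)              # (3, 4)
    y = x.view(-1, 6)             # (2, 6); the -1 dimension is inferred
    y = x.transpose(0, 1)         # (3, 4)
    y = x.permute(1, 0)           # (3, 4)
    y = x.unsqueeze(1)            # (4, 1, 3)
    y = y.squeeze()               # (4, 3); all size-1 dims removed
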

@@ -127,9 +129,9 @@ Algebra

 .. code-block:: python

-    A.mm(B)          # matrix multiplication
-    A.mv(x)          # matrix-vector multiplication
-    x.t()            # matrix transpose
+    ret = A.mm(B)    # matrix multiplication
+    ret = A.mv(x)    # matrix-vector multiplication
+    x = x.t()        # matrix transpose

 See `math
 operations <https://pytorch.org/docs/stable/torch.html?highlight=mm#math-operations>`__
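
Illustration only (not in the commit): the same calls with concrete, arbitrary shapes.

    import torch

    A = torch.randn(3, 4)
    B = torch.randn(4, 5)
    v = torch.randn(4)

    ret = A.mm(B)     # (3, 5) matrix product
    ret = A.mv(v)     # (3,) matrix-vector product
    At  = A.t()       # (4, 3) transpose
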
@@ -139,24 +141,24 @@ GPU Usage

 .. code-block:: python

-    torch.cuda.is_available                                 # check for cuda
-    x.cuda()                                                # move x's data from
-                                                            # CPU to GPU and return new object
+    torch.cuda.is_available                                 # check for cuda
+    x = x.cuda()                                            # move x's data from
+                                                            # CPU to GPU and return new object

-    x.cpu()                                                 # move x's data from GPU to CPU
-                                                            # and return new object
+    x = x.cpu()                                             # move x's data from GPU to CPU
+                                                            # and return new object

-    if not args.disable_cuda and torch.cuda.is_available(): # device agnostic code
-        args.device = torch.device('cuda')                  # and modularity
-    else:                                                   #
-        args.device = torch.device('cpu')                   #
+    if not args.disable_cuda and torch.cuda.is_available(): # device agnostic code
+        args.device = torch.device('cuda')                  # and modularity
+    else:                                                   #
+        args.device = torch.device('cpu')                   #

-    net.to(device)                                          # recursively convert their
-                                                            # parameters and buffers to
-                                                            # device specific tensors
+    net.to(device)                                          # recursively convert their
+                                                            # parameters and buffers to
+                                                            # device specific tensors

-    mytensor.to(device)                                     # copy your tensors to a device
-                                                            # (gpu, cpu)
+    x = x.to(device)                                        # copy your tensors to a device
+                                                            # (gpu, cpu)

 See `cuda <https://pytorch.org/docs/stable/cuda.html>`__
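
A device-agnostic sketch (editor's addition, not part of the commit) combining the entries above; the Linear layer and batch size are placeholders.

    import torch
    import torch.nn as nn

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = nn.Linear(4, 2).to(device)   # modules are moved in place
    x = torch.randn(8, 4)
    x = x.to(device)                   # tensors are copied, so reassign the result
    out = net(x)
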

@@ -175,7 +177,7 @@ Deep Learning
     nn.MaxPoolXd(s)                           # X dimension pooling layer
                                               # (notation as above)

-    nn.BatchNorm                              # batch norm layer
+    nn.BatchNormXd                            # batch norm layer
     nn.RNN/LSTM/GRU                           # recurrent layers
     nn.Dropout(p=0.5, inplace=False)          # dropout layer for any dimensional input
     nn.Dropout2d(p=0.5, inplace=False)        # 2-dimensional channel-wise dropout
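
For context, a small illustrative stack (not from the commit) using the corrected nn.BatchNormXd naming; channel counts and input size are arbitrary.

    import torch
    import torch.nn as nn

    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),              # the Xd suffix matches the input dimensionality
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Dropout2d(p=0.5),
    )
    out = model(torch.randn(1, 3, 32, 32))   # -> (1, 16, 16, 16)
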
@@ -189,11 +191,15 @@ Loss Functions

 .. code-block:: python

-    nn.X                                  # where X is BCELoss, CrossEntropyLoss,
-                                          # L1Loss, MSELoss, NLLLoss, SoftMarginLoss,
-                                          # MultiLabelSoftMarginLoss, CosineEmbeddingLoss,
-                                          # KLDivLoss, MarginRankingLoss, HingeEmbeddingLoss
-                                          # or CosineEmbeddingLoss
+    nn.X                                  # where X is L1Loss, MSELoss, CrossEntropyLoss
+                                          # CTCLoss, NLLLoss, PoissonNLLLoss,
+                                          # KLDivLoss, BCELoss, BCEWithLogitsLoss,
+                                          # MarginRankingLoss, HingeEmbeddingLoss,
+                                          # MultiLabelMarginLoss, SmoothL1Loss,
+                                          # SoftMarginLoss, MultiLabelSoftMarginLoss,
+                                          # CosineEmbeddingLoss, MultiMarginLoss,
+                                          # or TripletMarginLoss
+

 See `loss
 functions <https://pytorch.org/docs/stable/nn.html#loss-functions>`__
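
One of the listed losses in use, as an editor's sketch with made-up logits and labels.

    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()       # one of the losses listed above
    logits = torch.randn(8, 10)             # unnormalized scores for 10 classes
    target = torch.randint(0, 10, (8,))     # integer class labels
    loss = criterion(logits, target)
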
@@ -204,10 +210,10 @@ Activation Functions
 .. code-block:: python

     nn.X                                  # where X is ReLU, ReLU6, ELU, SELU, PReLU, LeakyReLU,
-                                          # Threshold, HardTanh, Sigmoid, Tanh,
-                                          # LogSigmoid, Softplus, SoftShrink,
-                                          # Softsign, TanhShrink, Softmin, Softmax,
-                                          # Softmax2d or LogSoftmax
+                                          # RReLu, CELU, GELU, Threshold, Hardshrink, HardTanh,
+                                          # Sigmoid, LogSigmoid, Softplus, SoftShrink,
+                                          # Softsign, Tanh, TanhShrink, Softmin, Softmax,
+                                          # Softmax2d, LogSoftmax or AdaptiveSoftmaxWithLoss

 See `activation
 functions <https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity>`__
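
A short sketch (not part of the commit) using one of the newly listed activations; it assumes a PyTorch release that ships nn.GELU (1.4+).

    import torch
    import torch.nn as nn

    x = torch.randn(4, 5)
    act = nn.GELU()        # module form
    y = act(x)
    y = torch.sigmoid(x)   # many activations also have a functional/tensor form
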
@@ -220,8 +226,8 @@ Optimizers
     opt = optim.x(model.parameters(), ...)      # create optimizer
     opt.step()                                  # update weights
     optim.X                                     # where X is SGD, Adadelta, Adagrad, Adam,
-                                                # SparseAdam, Adamax, ASGD,
-                                                # LBFGS, RMSProp or Rprop
+                                                # AdamW, SparseAdam, Adamax, ASGD,
+                                                # LBFGS, RMSprop or Rprop

 See `optimizers <https://pytorch.org/docs/stable/optim.html>`__
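
The newly listed AdamW in a minimal step, as an illustration with arbitrary hyperparameters and a dummy loss.

    import torch
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    loss = model(torch.randn(8, 4)).pow(2).mean()   # dummy loss for illustration
    opt.zero_grad()
    loss.backward()
    opt.step()                                      # update weights
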

@@ -232,8 +238,10 @@ Learning rate scheduling

     scheduler = optim.X(optimizer,...)      # create lr scheduler
     scheduler.step()                        # update lr at start of epoch
-    optim.lr_scheduler.X                    # where X is LambdaLR, StepLR, MultiStepLR,
-                                            # ExponentialLR or ReduceLROnPLateau
+    optim.lr_scheduler.X                    # where X is LambdaLR, MultiplicativeLR,
+                                            # StepLR, MultiStepLR, ExponentialLR,
+                                            # CosineAnnealingLR, ReduceLROnPlateau, CyclicLR,
+                                            # OneCycleLR, CosineAnnealingWarmRestarts,

 See `learning rate
 scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`__
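
One of the listed schedulers in a sketch loop (not from the commit); the step size, gamma, and epoch count are arbitrary.

    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)
    opt = optim.SGD(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)

    for epoch in range(30):
        # ... forward pass, loss.backward(), opt.step() ...
        scheduler.step()        # multiplies the lr by 0.5 every 10 epochs
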
@@ -264,8 +272,8 @@ Dataloaders and DataSamplers
     sampler.Sampler(dataset,...)        # abstract class dealing with
                                         # ways to sample from dataset

-    sampler.XSampler where ...          # Sequential, Random, Subset,
-                                        # WeightedRandom or Distributed
+    sampler.XSampler where ...          # Sequential, Random, SubsetRandom,
+                                        # WeightedRandom, Batch, Distributed

 See
 `dataloader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`__
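
A small sketch (editor's illustration) wiring one of the listed samplers into a DataLoader; the dataset and weights are synthetic.

    import torch
    from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

    dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
    weights = torch.ones(100)                        # per-sample draw weights
    sampler = WeightedRandomSampler(weights, num_samples=100, replacement=True)
    loader = DataLoader(dataset, batch_size=16, sampler=sampler)

    for xb, yb in loader:
        pass                                         # iterate over resampled batches
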
