The foundations we'll assume throughout this course are:
from pathlib import Path
import pickle, gzip, math, os, time, shutil, matplotlib as mpl, matplotlib.pyplot as plt
MNIST_URL='https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true'
path_data = Path('data')
path_data.mkdir(exist_ok=True)
path_gz = path_data/'mnist.pkl.gz'
We'll use urlretrieve to download the data (read the docs!):
from urllib.request import urlretrieve
if not path_gz.exists(): urlretrieve(MNIST_URL, path_gz)
!ls -l data
total 16656
-rw-rw-r-- 1 jhoward jhoward 17051982 Sep 30 04:37 mnist.pkl.gz
with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
lst1 = list(x_train[0])
vals = lst1[200:210]
vals
[0.0, 0.0, 0.0, 0.19140625, 0.9296875, 0.98828125, 0.98828125, 0.98828125, 0.98828125, 0.98828125]
def chunks(x, sz):
for i in range(0, len(x), sz): yield x[i:i+sz]
list(chunks(vals, 5))
[[0.0, 0.0, 0.0, 0.19140625, 0.9296875], [0.98828125, 0.98828125, 0.98828125, 0.98828125, 0.98828125]]
mpl.rcParams['image.cmap'] = 'gray'
plt.imshow(list(chunks(lst1, 28)));
from itertools import islice
it = iter(vals)
islice(it, 5)
<itertools.islice at 0x7f678a013b30>
list(islice(it, 5))
[0.0, 0.0, 0.0, 0.19140625, 0.9296875]
list(islice(it, 5))
[0.98828125, 0.98828125, 0.98828125, 0.98828125, 0.98828125]
list(islice(it, 5))
[]
it = iter(lst1)
# two-argument iter() keeps calling the lambda until it returns the sentinel [], giving 28-pixel rows
img = list(iter(lambda: list(islice(it, 28)), []))
plt.imshow(img);
img[20][15]
0.98828125
class Matrix:
def __init__(self, xs): self.xs = xs
def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
m = Matrix(img)
m[20,15]
0.98828125
import torch
from torch import tensor
tensor([1,2,3])
tensor([1, 2, 3])
x_train,y_train,x_valid,y_valid = map(tensor, (x_train,y_train,x_valid,y_valid))
x_train.shape
torch.Size([50000, 784])
x_train.type()
'torch.FloatTensor'
imgs = x_train.reshape((-1,28,28))
imgs.shape
torch.Size([50000, 28, 28])
plt.imshow(imgs[0]);
imgs[0,20,15]
tensor(0.9883)
n,c = x_train.shape
y_train, y_train.shape
(tensor([5, 0, 4, ..., 8, 4, 8]), torch.Size([50000]))
min(y_train),max(y_train)
(tensor(0), tensor(9))
y_train.min(), y_train.max()
(tensor(0), tensor(9))
This random number generator is based on the Wichmann-Hill algorithm, which Python used before version 2.3.
rnd_state = None
def seed(a):
global rnd_state
a, x = divmod(a, 30268)
a, y = divmod(a, 30306)
a, z = divmod(a, 30322)
rnd_state = int(x)+1, int(y)+1, int(z)+1
seed(457428938475)
rnd_state
(4976, 20238, 499)
def rand():
global rnd_state
x, y, z = rnd_state
x = (171 * x) % 30269
y = (172 * y) % 30307
z = (170 * z) % 30323
rnd_state = x,y,z
return (x/30269 + y/30307 + z/30323) % 1.0
rand(),rand(),rand()
(0.7645251082582081, 0.7920889799553945, 0.06912886811267205)
if os.fork(): print(f'In parent: {rand()}')
else:
print(f'In child: {rand()}')
os._exit(os.EX_OK)
In parent: 0.9559050644103264
In child: 0.9559050644103264
if os.fork(): print(f'In parent: {torch.rand(1)}')
else:
print(f'In child: {torch.rand(1)}')
os._exit(os.EX_OK)
In parent: tensor([0.1241])
In child: tensor([0.1241])
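Both processes continue from the same RNG state, so they produce identical "random" numbers. Here's a minimal sketch of one way around this, reseeding the child with a process-specific value (os.getpid() is an arbitrary choice here, reusing the seed/rand functions defined above):
if os.fork(): print(f'In parent: {rand()}')
else:
    seed(os.getpid())  # give the child its own state so the streams diverge
    print(f'In child: {rand()}')
    os._exit(os.EX_OK)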
plt.plot([rand() for _ in range(50)]);
plt.hist([rand() for _ in range(10000)]);
%timeit -n 10 list(chunks([rand() for _ in range(7840)], 10))
2.64 ms ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit -n 10 torch.randn(784,10)
The slowest run took 14.01 times longer than the fastest. This could mean that an intermediate result is being cached.
91.3 µs ± 145 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
weights = torch.randn(784,10)
bias = torch.zeros(10)
m1 = x_valid[:5]
m2 = weights
m1.shape,m2.shape
(torch.Size([5, 784]), torch.Size([784, 10]))
ar,ac = m1.shape # n_rows * n_cols
br,bc = m2.shape
(ar,ac),(br,bc)
((5, 784), (784, 10))
t1 = torch.zeros(ar, bc)
t1.shape
torch.Size([5, 10])
for i in range(ar): # 5
for j in range(bc): # 10
for k in range(ac): # 784
t1[i,j] += m1[i,k] * m2[k,j]
t1
tensor([[ 1.6348, 1.5296, 9.7368, 0.7734, 4.8742, -7.7081, -4.1087,
-1.6114, 18.5321, -1.8670],
[ 6.0661, 4.1725, 3.9054, 9.3144, 8.8816, 5.6508, -4.1909,
3.5372, 13.6349, -3.2493],
[ -0.4399, 7.3606, -12.2532, 4.0383, 6.1932, 4.1315, 4.4890,
-2.2943, 6.3479, -1.4134],
[ 4.8634, -5.8807, -13.8575, -16.3413, 2.8746, 11.1938, 16.8231,
-0.0271, 0.0583, -4.9399],
[ -0.1856, -9.0826, 2.4181, 4.3356, 0.6344, -0.7210, -2.1056,
12.1806, 4.3477, -10.4125]])
t1.shape
torch.Size([5, 10])
import numpy as np
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
np.set_printoptions(precision=2, linewidth=140)
t1
tensor([[ 1.63, 1.53, 9.74, 0.77, 4.87, -7.71, -4.11, -1.61, 18.53, -1.87],
[ 6.07, 4.17, 3.91, 9.31, 8.88, 5.65, -4.19, 3.54, 13.63, -3.25],
[ -0.44, 7.36, -12.25, 4.04, 6.19, 4.13, 4.49, -2.29, 6.35, -1.41],
[ 4.86, -5.88, -13.86, -16.34, 2.87, 11.19, 16.82, -0.03, 0.06, -4.94],
[ -0.19, -9.08, 2.42, 4.34, 0.63, -0.72, -2.11, 12.18, 4.35, -10.41]])
def matmul(a,b):
(ar,ac),(br,bc) = a.shape,b.shape
c = torch.zeros(ar, bc)
for i in range(ar):
for j in range(bc):
for k in range(ac): c[i,j] += a[i,k] * b[k,j]
return c
%time _=matmul(m1, m2)
CPU times: user 416 ms, sys: 0 ns, total: 416 ms
Wall time: 416 ms
from numba import njit
@njit
def dot(a,b):
res = 0.
for i in range(len(a)): res+=a[i]*b[i]
return res
from numpy import array
%time dot(array([1.,2,3]),array([2.,3,4]))
CPU times: user 185 ms, sys: 12.3 ms, total: 197 ms
Wall time: 243 ms
20.0
The second call uses the already-compiled function, so it's much faster:
%time dot(array([1.,2,3]),array([2.,3,4]))
CPU times: user 14 µs, sys: 1 µs, total: 15 µs
Wall time: 16.9 µs
20.0
Now only two of our loops are running in Python, not three:
def matmul(a,b):
(ar,ac),(br,bc) = a.shape,b.shape
c = torch.zeros(ar, bc)
for i in range(ar):
for j in range(bc): c[i,j] = dot(a[i,:], b[:,j])
return c
m1a,m2a = m1.numpy(),m2.numpy()
from fastcore.test import *
test_close(t1,matmul(m1a, m2a))
%timeit -n 10 matmul(m1a,m2a)
245 µs ± 36.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
a = tensor([10., 6, -4])
b = tensor([2., 8, 7])
a,b
(tensor([10., 6., -4.]), tensor([2., 8., 7.]))
a + b
tensor([12., 14., 3.])
(a < b).float().mean()
tensor(0.67)
m = tensor([[1., 2, 3], [4,5,6], [7,8,9]]); m
tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]])
Frobenius norm:
$$\| A \|_F = \left( \sum_{i,j=1}^n | a_{ij} |^2 \right)^{1/2}$$
Hint: you don't normally need to write equations in LaTeX yourself; instead, you can click 'edit' in Wikipedia and copy the LaTeX from there (which is what I did for the above equation). Or on arxiv.org, click "Download: Other formats" in the top right, then "Download source"; rename the downloaded file to end in .tgz if it doesn't already, and you should find the source there, including the equations to copy and paste. This is the source LaTeX that I pasted to render the equation above:
$$\| A \|_F = \left( \sum_{i,j=1}^n | a_{ij} |^2 \right)^{1/2}$$
(m*m).sum().sqrt()
tensor(16.88)
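As a cross-check, here's a quick sketch assuming torch.linalg.norm, whose default for a 2-D tensor is the Frobenius norm:
torch.linalg.norm(m)  # tensor(16.88) with the print options set above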
def matmul(a,b):
(ar,ac),(br,bc) = a.shape,b.shape
c = torch.zeros(ar, bc)
for i in range(ar):
for j in range(bc): c[i,j] = (a[i,:] * b[:,j]).sum()
return c
test_close(t1,matmul(m1, m2))
%timeit -n 10 _=matmul(m1, m2)
600 µs ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
def matmul(a,b):
(ar,ac),(br,bc) = a.shape,b.shape
c = torch.zeros(ar, bc)
for i in range(ar):
for j in range(bc): c[i,j] = torch.dot(a[i,:], b[:,j])
return c
test_close(t1,matmul(m1, m2))
%timeit -n 10 _=matmul(m1, m2)
487 µs ± 21.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
The term broadcasting describes how arrays with different shapes are treated during arithmetic operations; the term was first used by Numpy.
From the Numpy Documentation:
The term broadcasting describes how numpy treats arrays with
different shapes during arithmetic operations. Subject to certain
constraints, the smaller array is “broadcast” across the larger
array so that they have compatible shapes. Broadcasting provides a
means of vectorizing array operations so that looping occurs in C
instead of Python. It does this without making needless copies of
data and usually leads to efficient algorithm implementations.
In addition to the efficiency of broadcasting, it allows developers to write less code, which typically leads to fewer errors.
This section was adapted from Chapter 4 of the fast.ai Computational Linear Algebra course.
a
tensor([10., 6., -4.])
a > 0
tensor([ True, True, False])
How are we able to do a > 0? 0 is being broadcast to have the same dimensions as a.
For instance, you can normalize our dataset by subtracting the mean (a scalar) from the entire dataset (a matrix) and dividing by the standard deviation (another scalar), all using broadcasting, as sketched below.
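Here's a minimal sketch of that normalization, using the x_train tensor we already have (the exact statistics aren't important here):
train_mean,train_std = x_train.mean(),x_train.std()
x_norm = (x_train - train_mean) / train_std  # scalars broadcast over a 50000x784 matrix
x_norm.mean(),x_norm.std()                   # roughly tensor(0.) and tensor(1.)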
Other examples of broadcasting with a scalar:
a + 1
tensor([11., 7., -3.])
m
tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]])
2*m
tensor([[ 2., 4., 6.],
[ 8., 10., 12.],
[14., 16., 18.]])
Although broadcasting a scalar is an idea that dates back to APL, the more powerful idea of broadcasting across higher-rank tensors comes from a little-known language called Yorick.
We can also broadcast a vector to a matrix:
c = tensor([10.,20,30]); c
tensor([10., 20., 30.])
m
tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]])
m.shape,c.shape
(torch.Size([3, 3]), torch.Size([3]))
m + c
tensor([[11., 22., 33.],
[14., 25., 36.],
[17., 28., 39.]])
c + m
tensor([[11., 22., 33.],
[14., 25., 36.],
[17., 28., 39.]])
We don't really copy the rows, but it looks as if we did. In fact, the rows are given a stride of 0.
t = c.expand_as(m)
t
tensor([[10., 20., 30.],
[10., 20., 30.],
[10., 20., 30.]])
m + t
tensor([[11., 22., 33.],
[14., 25., 36.],
[17., 28., 39.]])
t.storage()
 10.0
 20.0
 30.0
[torch.storage._TypedStorage(dtype=torch.float32, device=cpu) of size 3]
t.stride(), t.shape
((0, 1), torch.Size([3, 3]))
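We could build the same zero-stride view explicitly; a sketch using torch.as_strided with the (0, 1) stride we just saw:
torch.as_strided(c, size=(3,3), stride=(0,1))  # the same three floats, viewed as three identical rows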
You can index with the special value [None], or use unsqueeze(), to convert a 1-dimensional array into a 2-dimensional array (where the new dimension has size 1).
c.unsqueeze(0), c[None, :]
(tensor([[10., 20., 30.]]), tensor([[10., 20., 30.]]))
c.shape, c.unsqueeze(0).shape
(torch.Size([3]), torch.Size([1, 3]))
c.unsqueeze(1), c[:, None]
(tensor([[10.],
[20.],
[30.]]),
tensor([[10.],
[20.],
[30.]]))
c.shape, c.unsqueeze(1).shape
(torch.Size([3]), torch.Size([3, 1]))
You can always skip trailing ':'s, and '...' means 'all preceding dimensions'.
c[None].shape,c[...,None].shape
(torch.Size([1, 3]), torch.Size([3, 1]))
c[:,None].expand_as(m)
tensor([[10., 10., 10.],
[20., 20., 20.],
[30., 30., 30.]])
m + c[:,None]
tensor([[11., 12., 13.],
[24., 25., 26.],
[37., 38., 39.]])
m + c[None,:]
tensor([[11., 22., 33.],
[14., 25., 36.],
[17., 28., 39.]])
c[None,:]
tensor([[10., 20., 30.]])
c[None,:].shape
torch.Size([1, 3])
c[:,None]
tensor([[10.],
[20.],
[30.]])
c[:,None].shape
torch.Size([3, 1])
c[None,:] * c[:,None]
tensor([[100., 200., 300.],
[200., 400., 600.],
[300., 600., 900.]])
c[None] > c[:,None]
tensor([[False, True, True],
[False, False, True],
[False, False, False]])
When operating on two arrays/tensors, Numpy/PyTorch compares their shapes element-wise. It starts with the trailing dimensions and works its way forward. Two dimensions are compatible when they are equal, or when one of them is 1 (in which case that dimension is broadcast to match the other).
Arrays do not need to have the same number of dimensions. For example, if you have a 256x256x3 array of RGB values, and you want to scale each color in the image by a different value, you can multiply the image by a one-dimensional array with 3 values. Lining up the sizes of the trailing axes of these arrays according to the broadcast rules shows that they are compatible:
Image (3d array): 256 x 256 x 3
Scale (1d array): 3
Result (3d array): 256 x 256 x 3
The numpy documentation includes several examples of what dimensions can and cannot be broadcast together.
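A quick sketch of that image example (the tensors here are just made-up shapes to check the rule):
img = torch.rand(256, 256, 3)    # hypothetical RGB image
scale = tensor([0.5, 1.0, 2.0])  # one factor per colour channel
(img * scale).shape              # torch.Size([256, 256, 3])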
digit = m1[0]
digit.shape,m2.shape
(torch.Size([784]), torch.Size([784, 10]))
digit[:,None].shape
torch.Size([784, 1])
digit[:,None].expand_as(m2).shape
torch.Size([784, 10])
(digit[:,None]*m2).shape
torch.Size([784, 10])
def matmul(a,b):
(ar,ac),(br,bc) = a.shape,b.shape
c = torch.zeros(ar, bc)
for i in range(ar):
# c[i,j] = (a[i,:] * b[:,j]).sum() # previous version
c[i] = (a[i,:,None] * b).sum(dim=0) # broadcast version
return c
test_close(t1,matmul(m1, m2))
%timeit -n 10 _=matmul(m1, m2)
72.4 µs ± 2.58 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Our time has gone from ~500ms to <0.1ms, a more than 5000x improvement! We can now run on the whole dataset.
tr = matmul(x_train, weights)
tr
tensor([[ 6.77, 8.24, 1.10, ..., -11.12, 12.56, -2.75],
[ 5.97, 1.00, 6.21, ..., 13.31, 7.89, 2.27],
[ 11.77, 2.42, 10.78, ..., 3.62, 9.77, -2.33],
...,
[ -7.22, 0.33, 6.40, ..., 3.90, 8.37, -0.94],
[ -0.12, 13.33, -14.11, ..., -8.20, 4.03, 5.79],
[ 2.85, 10.83, 0.60, ..., -3.50, 13.04, 4.38]])
tr.shape
torch.Size([50000, 10])
%time _=matmul(x_train, weights)
CPU times: user 6.81 s, sys: 203 ms, total: 7.02 s
Wall time: 673 ms
Einstein summation (einsum) is a compact representation for combining products and sums in a general way. The key rules are: repeating a letter between input arrays means that values along those axes will be multiplied together, and omitting a letter from the output means that values along that axis will be summed.
# c[i,j] += a[i,k] * b[k,j]
# c[i,j] = (a[i,:] * b[:,j]).sum()
def matmul(a,b): return torch.einsum('ik,kj->ij', a, b)
test_close(tr, matmul(x_train, weights), eps=1e-3)
%timeit -n 1 _=matmul(x_train, weights)
15.6 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
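To see those rules in isolation, here's a small sketch reusing the 3x3 matrix m from earlier (the results are easy to check by hand):
torch.einsum('ij->j', m)       # 'i' is omitted from the output, so rows are summed: tensor([12., 15., 18.])
torch.einsum('ij,ij->', m, m)  # repeated 'ij' multiplies elementwise; empty output sums everything: tensor(285.)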
We can use PyTorch's matmul function, or the @ operator, directly for matrix multiplication.
test_close(tr, x_train@weights, eps=1e-3)
%timeit -n 1 _=torch.matmul(x_train, weights)
15.5 ms ± 68.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
To prepare for running on the GPU, we can restructure matmul as a kernel: a function that fills in a single cell of the output, given that cell's coordinates in a grid.
def matmul(grid, a,b,c):
i,j = grid
if i < c.shape[0] and j < c.shape[1]:
tmp = 0.
for k in range(a.shape[1]): tmp += a[i, k] * b[k, j]
c[i,j] = tmp
res = torch.zeros(ar, bc)
matmul((0,0), m1, m2, res)
res
tensor([[1.63, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
[0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
[0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
[0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
[0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]])
Launching the kernel once for every grid coordinate fills in the whole result:
def launch_kernel(kernel, grid_x, grid_y, *args, **kwargs):
for i in range(grid_x):
for j in range(grid_y): kernel((i,j), *args, **kwargs)
res = torch.zeros(ar, bc)
launch_kernel(matmul, ar, bc, m1, m2, res)
res
tensor([[ 1.63, 1.53, 9.74, 0.77, 4.87, -7.71, -4.11, -1.61, 18.53, -1.87],
[ 6.07, 4.17, 3.91, 9.31, 8.88, 5.65, -4.19, 3.54, 13.63, -3.25],
[ -0.44, 7.36, -12.25, 4.04, 6.19, 4.13, 4.49, -2.29, 6.35, -1.41],
[ 4.86, -5.88, -13.86, -16.34, 2.87, 11.19, 16.82, -0.03, 0.06, -4.94],
[ -0.19, -9.08, 2.42, 4.34, 0.63, -0.72, -2.11, 12.18, 4.35, -10.41]])
Now we can do the same thing on the GPU, using numba's cuda.jit:
from numba import cuda
@cuda.jit
def matmul(a,b,c):
i, j = cuda.grid(2)
if i < c.shape[0] and j < c.shape[1]:
tmp = 0.
for k in range(a.shape[1]): tmp += a[i, k] * b[k, j]
c[i,j] = tmp
cuda.syncthreads()
r = np.zeros(tr.shape)
m1g,m2g,rg = cuda.to_device(x_train),cuda.to_device(weights),cuda.to_device(r)
r.shape
(50000, 10)
TPB = 16
rr,rc = r.shape
blockspergrid = (math.ceil(rr / TPB), math.ceil(rc / TPB))
blockspergrid
(3125, 1)
matmul[blockspergrid, (TPB,TPB)](m1g,m2g,rg)
r = rg.copy_to_host()
test_close(tr, r, eps=1.03)
%%timeit -n 1
matmul[blockspergrid, (TPB,TPB)](m1g,m2g,rg)
r = rg.copy_to_host()
3.69 ms ± 81.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
m1c,m2c = x_train.cuda(),weights.cuda()
%timeit -n 1 r=(m1c@m2c).cpu()
429 µs ± 57.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Our broadcasting version was >500ms, and our CUDA version is around 0.5ms, which is another 1000x improvement compared to broadcasting. So our total speedup is around 5 million times!