PyTorch’s scatter_() Function + One-Hot Encoding (A Visual Explanation)

Scattering the value 1 into a 4×5 tensor of zeros along dimension 0:

import torch

index_tensor = torch.tensor([[1,3,0,2,0],[0,2,2,1,3],[3,0,0,1,1]])
input_tensor = torch.zeros(4,5).scatter_(0,index_tensor,1)
print(input_tensor)

tensor([[1., 1., 1., 0., 1.],
        [1., 0., 0., 1., 1.],
        [0., 1., 1., 1., 0.],
        [1., 1., 0., 0., 1.]])
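With dim=0, scatter_() visits every position (i, j) of index_tensor and writes the value into row index_tensor[i][j] of the same column j, that is self[index[i][j]][j] = value. Index values must lie between 0 and self.size(0) - 1, which is why the target needs 4 rows even though index_tensor has only 3 (its largest value is 3). Position (0, 3) stays 0 because no entry in column 3 of index_tensor (2, 1, 1) equals 0. The pure-Python loop below is a minimal sketch of that rule; the helper name scatter_dim0 is hypothetical and not part of PyTorch.

```python
def scatter_dim0(n_rows, n_cols, index, value):
    # Hypothetical helper: mimics torch.zeros(n_rows, n_cols).scatter_(0, index, value)
    out = [[0.0] * n_cols for _ in range(n_rows)]
    for row in index:
        for j, target_row in enumerate(row):
            # dim=0: the index value picks the row, the column j stays the same
            out[target_row][j] = float(value)
    return out


index = [[1, 3, 0, 2, 0], [0, 2, 2, 1, 3], [3, 0, 0, 1, 1]]
for row in scatter_dim0(4, 5, index, 1):
    print(row)
```

Each printed row matches the corresponding row of the tensor produced by scatter_(0, index_tensor, 1).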
The same index tensor scattered along dimension 1, into a 3×5 tensor of zeros:

import torch

index_tensor = torch.tensor([[1,3,0,2,0],[0,2,2,1,3],[3,0,0,1,1]])
input_tensor = torch.zeros(3,5).scatter_(1,index_tensor,1)
print(input_tensor)

tensor([[1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 0., 1., 0.]])
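With dim=1 the roles swap: row i of index_tensor writes into row i of the output, and each index value picks the column, that is self[i][index[i][j]] = value. No row of index_tensor contains 4, so column 4 stays 0 everywhere, and the third row has no 2, so its column 2 stays 0 as well. Repeated indices, such as the two 0s in the first row, simply write the same 1 twice. The loop below is a minimal pure-Python sketch of this rule; scatter_dim1 is a hypothetical helper, not a PyTorch function.

```python
def scatter_dim1(n_rows, n_cols, index, value):
    # Hypothetical helper: mimics torch.zeros(n_rows, n_cols).scatter_(1, index, value)
    out = [[0.0] * n_cols for _ in range(n_rows)]
    for i, row in enumerate(index):
        for target_col in row:
            # dim=1: the row i stays the same, the index value picks the column
            out[i][target_col] = float(value)
    return out


index = [[1, 3, 0, 2, 0], [0, 2, 2, 1, 3], [3, 0, 0, 1, 1]]
for row in scatter_dim1(3, 5, index, 1):
    print(row)
```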

One-hot encoding

Scattering along dimension 1 with each class label as the column index turns scatter_() into a one-hot encoder: row i receives a 1 in column labels[i] and 0 everywhere else.

import torch

classes = 10  # representing numbers 0-9
labels = torch.tensor([0,1,2,3,4,5,6,7,8,9]).view(10,1)  # shape (10, 1): one label per row
# One row per label and one column per class (here both happen to be 10)
one_hot = torch.zeros(classes,classes).scatter_(1,labels,1)
print(one_hot)

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
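Because the labels above are 0 through 9 in order, the result is just the 10×10 identity matrix, which hides the general pattern. The sketch below uses a few arbitrary labels (chosen only for illustration) and sizes the output by the number of labels rather than the number of classes. It also compares the result with torch.nn.functional.one_hot, PyTorch's built-in helper, which is a different function from the scatter_() approach used in this article.

```python
import torch
import torch.nn.functional as F

classes = 10
labels = torch.tensor([3, 0, 9, 3])  # illustrative labels, not from the original example

# scatter_ needs an int64 index with as many dims as the target,
# so the 1-D labels are reshaped to a column of shape (N, 1)
one_hot_scatter = torch.zeros(len(labels), classes).scatter_(1, labels.view(-1, 1), 1)

# The built-in helper returns an int64 tensor of shape (N, classes); cast to float to compare
one_hot_builtin = F.one_hot(labels, num_classes=classes).float()

print(one_hot_scatter)
print(torch.equal(one_hot_scatter, one_hot_builtin))  # True
```

Both approaches give a 4×10 tensor with a single 1 per row, in columns 3, 0, 9 and 3.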

Dr. Abder-Rahman Ali

Leveraging machine/deep learning and image processing in medical image analysis | https://abder.me