PyTorch’s Scatter_() Function + One-Hot Encoding (A Visual Explanation)

Dr. Abder-Rahman Ali
Nov 8, 2020

The way PyTorch’s scatter_(dim, index, src) function works can be a bit confusing. So, I will take a visual approach to explaining it, as I believe that makes the concept easier to grasp.

The scatter_() function takes three arguments. The first, dim, is the dimension across which we fill in the data: if dim=0, we fill in the data across the rows (each index value picks a row), and if dim=1, we fill it in across the columns (each index value picks a column).

The second argument, index, specifies where along that dimension (which row or which column) each value is placed. The third argument, src, holds the values we want to write into the tensor that scatter_() is applied to. It can be a single scalar (in which case PyTorch's documentation calls the argument value) or a tensor of values.
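To see the difference between a scalar and a tensor src, here is a small sketch with made-up index and src values (chosen only for illustration). With dim=0 and a scalar, every targeted cell receives the same number; with a tensor, cell (index[i][j], j) receives src[i][j].

```python
import torch

# Made-up values for illustration; index must be an integer (int64) tensor.
index = torch.tensor([[0, 2, 1],
                      [2, 0, 0]])
src = torch.tensor([[10., 20., 30.],
                    [40., 50., 60.]])

# dim=0 with a scalar: every targeted cell receives the same value.
with_scalar = torch.zeros(3, 3).scatter_(0, index, 1.0)

# dim=0 with a tensor: cell (index[i][j], j) receives src[i][j].
with_tensor = torch.zeros(3, 3).scatter_(0, index, src)

print(with_scalar)  # rows: [1, 1, 1], [0, 0, 1], [1, 1, 0]
print(with_tensor)  # rows: [10, 50, 60], [0, 0, 30], [40, 20, 0]
```

Note that src has the same float dtype as the zeros tensor; mixing dtypes here would raise an error.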

Let’s take an example to make things clearer.

Let’s break this down step by step. Here, we scatter the value “1” (src) across the rows (dim=0) of the target tensor (input_tensor), which in our case is a tensor of zeros.

I also show here how to determine the shape of input_tensor. Along the dimension we scatter across (here the rows, dim=0), input_tensor must be large enough to hold the largest value in index_tensor: index “3” requires at least 4 rows (indices 0 to 3), one more than the 3 rows of index_tensor. Along every other dimension, input_tensor must be at least as large as index_tensor, so with dim=0 it gets the same number of columns as index_tensor (5).
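The shape rule can be captured in a small helper. This is a minimal sketch: min_target_shape is a hypothetical name, not a PyTorch function, and it assumes the constraint stated in PyTorch's scatter_ documentation (index.size(d) <= self.size(d) for every d other than dim), plus enough room along dim for the largest index value.

```python
import torch

def min_target_shape(index: torch.Tensor, dim: int) -> tuple:
    """Hypothetical helper: smallest target shape that scatter_(dim, index, value) accepts."""
    shape = list(index.shape)
    # Along dim, the target must be able to hold the largest index value.
    shape[dim] = int(index.max()) + 1
    # Along every other dimension, matching index's size is enough.
    return tuple(shape)

index_tensor = torch.tensor([[1, 3, 0, 2, 0],
                             [0, 2, 2, 1, 3],
                             [3, 0, 0, 1, 1]])

for dim in (0, 1):
    shape = min_target_shape(index_tensor, dim)
    # Sanity check: scatter_ accepts a target of exactly this shape.
    result = torch.zeros(shape).scatter_(dim, index_tensor, 1)
    print(dim, shape, tuple(result.shape))
```

For dim=0 this gives (4, 5), the shape used below; for dim=1 it gives (3, 4).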

The figure below depicts how we select the indices in index_tensor across the rows (dim=0), and map them to input_tensor.

Let’s write what we’ve done above in PyTorch, which also gives us an opportunity to check whether our input_tensor is correct.

import torch

index_tensor = torch.tensor([[1,3,0,2,0],[0,2,2,1,3],[3,0,0,1,1]])
# dim=0: each value in index_tensor picks the row of input_tensor that receives a 1
input_tensor = torch.zeros(4,5).scatter_(0,index_tensor,1)
print(input_tensor)
tensor([[1., 1., 1., 0., 1.],
[1., 0., 0., 1., 1.],
[0., 1., 1., 1., 0.],
[1., 1., 0., 0., 1.]])

So, yep, we got that right!
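To double-check the mapping without the figures, the sketch below redoes the dim=0 scatter with a plain Python loop (the names manual and writes are illustrative) and compares the result with scatter_(). It also lists the cells that are targeted more than once.

```python
import torch
from collections import Counter

index_tensor = torch.tensor([[1, 3, 0, 2, 0],
                             [0, 2, 2, 1, 3],
                             [3, 0, 0, 1, 1]])

# Reference loop for dim=0: entry (i, j) of index_tensor writes into
# row index_tensor[i][j], column j of the target.
manual = torch.zeros(4, 5)
writes = []
for i in range(index_tensor.size(0)):
    for j in range(index_tensor.size(1)):
        row = int(index_tensor[i, j])
        manual[row, j] = 1
        writes.append((row, j))

expected = torch.zeros(4, 5).scatter_(0, index_tensor, 1)
print(torch.equal(manual, expected))  # True

# (row, column) cells written more than once: a row index repeated within a column.
print([cell for cell, n in Counter(writes).items() if n > 1])  # [(0, 2), (1, 3)]
```

With a scalar src, writing the same cell twice is harmless. With a tensor src and repeated indices, the PyTorch documentation warns that which value ends up in the cell is non-deterministic.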

What happens when dim=1? In that case, each value in index_tensor picks a column of input_tensor, while the row stays the same as the row of that value in index_tensor. The figures below depict the changes that occur.

This is what the above would look like in PyTorch.

import torch

index_tensor = torch.tensor([[1,3,0,2,0],[0,2,2,1,3],[3,0,0,1,1]])
# dim=1: each value in index_tensor picks the column of input_tensor that receives a 1
input_tensor = torch.zeros(3,5).scatter_(1,index_tensor,1)
print(input_tensor)
tensor([[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 0., 1., 0.]])
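The same shape reasoning applies with dim=1: the index values now pick columns, and since the largest one is 3, a 3x4 target is already large enough. The sketch below (variable names are illustrative) checks that the fifth column of the 3x5 version above is never written.

```python
import torch

index_tensor = torch.tensor([[1, 3, 0, 2, 0],
                             [0, 2, 2, 1, 3],
                             [3, 0, 0, 1, 1]])

# With dim=1, index values pick columns; the largest is 3, so 4 columns suffice.
narrow = torch.zeros(3, 4).scatter_(1, index_tensor, 1)
wide = torch.zeros(3, 5).scatter_(1, index_tensor, 1)

print(narrow)
# The first four columns agree, and column 4 of the 3x5 result stays zero.
print(torch.equal(wide[:, :4], narrow))  # True
print(wide[:, 4].tolist())               # [0.0, 0.0, 0.0]
```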

One-hot encoding

Using the above, we can now easily do one-hot encoding in PyTorch. The following code shows how.

import torch

classes = 10 # representing numbers 0-9
labels = torch.tensor([0,1,2,3,4,5,6,7,8,9]).view(10,1)
# row i of one_hot gets a 1 in column labels[i]
one_hot = torch.zeros(classes,classes).scatter_(1,labels,1)
print(one_hot)
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
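Because the labels above are exactly 0 to 9 in order, the result is simply the 10x10 identity matrix. The sketch below applies the same idea to a made-up batch of labels in arbitrary order, with one row per label rather than one per class, and compares the result with torch.nn.functional.one_hot, PyTorch's built-in helper (which returns an integer tensor, hence the .float()).

```python
import torch
import torch.nn.functional as F

num_classes = 10
# Made-up batch of labels, in no particular order and with a repeat.
labels = torch.tensor([3, 0, 7, 3])

# scatter_ needs index to have as many dimensions as the target,
# so the labels become a column of shape (4, 1).
one_hot = torch.zeros(labels.size(0), num_classes).scatter_(1, labels.unsqueeze(1), 1)
print(one_hot)

# Same result from the built-in helper.
print(torch.equal(one_hot, F.one_hot(labels, num_classes).float()))  # True
```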

So, that was it for this tutorial. I hope you now have a better sense of how PyTorch’s scatter_() function works, and don’t hesitate to ask any questions or comment below.


Dr. Abder-Rahman Ali

Research Fellow @ Massachusetts General Hospital/Harvard Medical School | https://abder.mgh.harvard.edu