# PyTorch’s Scatter_() Function + One-Hot Encoding (A Visual Explanation)

---

The way PyTorch’s *scatter_(dim, index, src)* function works can be a bit confusing. So, I will take a visual approach to explaining the function, as I believe it makes the concept easier to grasp.

The *scatter_()* function takes three arguments. The first is the *dimension* along which we will fill in the data: if dim=0, we fill the data across the rows, and if dim=1, we fill the data across the columns.

The second argument, *index*, specifies where along the chosen dimension (row or column) each value will be placed. The third argument, *src*, holds the values to fill into the tensor on which *scatter_()* is called. It can be a scalar or a tensor of values.

Let’s take an example to make things clearer.

Let’s break this down step by step. What we do here is scatter the value “1” (src) across the rows (dim=0) of the affected tensor (input_tensor), which in our case is a tensor of zeros.

I also show here a way to determine the number of rows and columns in input_tensor: along the dimension we fill the values across (i.e. the rows), input_tensor needs at least as many entries as in index_tensor, unless an index value in index_tensor demands more. For instance, the index “3” means we need 4 rows, which is greater than the number of rows in index_tensor. Since the scattering happens across the rows (dim=0), the number of columns stays the same as in index_tensor.
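The shape reasoning above can be sketched in code. This is an illustrative snippet (the variable names `rows` and `cols` are my own, not part of any PyTorch API): for dim=0, the row count must cover the largest index, and the column count matches index_tensor.

```python
import torch

index_tensor = torch.tensor([[1,3,0,2,0],[0,2,2,1,3],[3,0,0,1,1]])

# Rows must cover the largest index value; "3" means we need 4 rows.
rows = int(index_tensor.max()) + 1
# Columns stay the same as in index_tensor, since we scatter along dim=0.
cols = index_tensor.shape[1]

print(rows, cols)  # 4 5
```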

The figure below depicts how we select the indices in index_tensor across the rows (dim=0), and map them to input_tensor.

Let’s write what we’ve done above in PyTorch; it is also an opportunity to check that our input_tensor is correct.

```python
import torch

index_tensor = torch.tensor([[1,3,0,2,0],[0,2,2,1,3],[3,0,0,1,1]])

input_tensor = torch.zeros(4,5).scatter_(0, index_tensor, 1)
print(input_tensor)
```

```
tensor([[1., 1., 1., 0., 1.],
        [1., 0., 0., 1., 1.],
        [0., 1., 1., 1., 0.],
        [1., 1., 0., 0., 1.]])
```
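To see why this output comes out the way it does, we can replay the dim=0 rule by hand: for every position (i, j) in index_tensor, scatter_ sets input[index[i][j]][j] to the source value. The loop below is just an illustrative check, not how scatter_ is implemented internally.

```python
import torch

index_tensor = torch.tensor([[1,3,0,2,0],[0,2,2,1,3],[3,0,0,1,1]])

# Manual replay of the dim=0 rule: input[index[i][j]][j] = 1
manual = torch.zeros(4, 5)
for i in range(index_tensor.shape[0]):
    for j in range(index_tensor.shape[1]):
        manual[index_tensor[i][j]][j] = 1

built_in = torch.zeros(4, 5).scatter_(0, index_tensor, 1)
print(torch.equal(manual, built_in))  # True
```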

So, yep, we got that right!

What happens when dim=1? That is, filling the values at the respective indices across the *columns* of input_tensor. The figures below depict the changes that occur.

This is what the above would look like in PyTorch.

```python
import torch

index_tensor = torch.tensor([[1,3,0,2,0],[0,2,2,1,3],[3,0,0,1,1]])

input_tensor = torch.zeros(3,5).scatter_(1, index_tensor, 1)
print(input_tensor)
```

```
tensor([[1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 0., 1., 0.]])
```
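As mentioned earlier, src can also be a tensor rather than a scalar. In that case, for dim=1, each element src[i][j] lands at input[i][index[i][j]]. Here is a small sketch with made-up values (the tensors below are my own example, not from the figures above):

```python
import torch

# Each index picks which column of the corresponding row receives
# the matching element of src: input[i][index[i][j]] = src[i][j]
index_tensor = torch.tensor([[0, 2], [1, 0]])
src = torch.tensor([[10., 20.], [30., 40.]])

result = torch.zeros(2, 3).scatter_(1, index_tensor, src)
print(result)
# tensor([[10.,  0., 20.],
#         [40., 30.,  0.]])
```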

# One-hot encoding

Using the above, we can now easily do one-hot encoding using PyTorch. The following code shows us how we can do just that.

```python
import torch

classes = 10  # representing numbers 0-9

labels = torch.tensor([0,1,2,3,4,5,6,7,8,9]).view(10,1)

one_hot = torch.zeros(classes, classes).scatter_(1, labels, 1)
print(one_hot)
```

```
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
```
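As a side note, recent versions of PyTorch also ship a built-in helper, torch.nn.functional.one_hot, which does the same job without the scatter_ dance. Note it returns an integer tensor, so a cast may be needed to match the float result above.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# one_hot returns an int64 tensor; cast to float to match scatter_'s output
one_hot = F.one_hot(labels, num_classes=10).float()
print(one_hot)
```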

So yes, that was it for this tutorial. I hope you now have a better sense of PyTorch’s scatter_() function, and don’t hesitate to ask any questions or comment below.