PyTorch’s scatter_() Function + One-Hot Encoding (A Visual Explanation)

The way PyTorch’s scatter_(dim, index, src) function works can be a bit confusing, so I will take a visual approach to explaining it, which I believe makes the concept much easier to grasp.

The scatter_() function takes three arguments. The first, dim, is the dimension along which the data will be filled in. If dim=0, this means…
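For a 2-D tensor, the PyTorch documentation gives the rule behind scatter_() as self[index[i][j]][j] = src[i][j] when dim=0, and self[i][index[i][j]] = src[i][j] when dim=1. As a rough companion to the visual explanation, here is a minimal plain-Python sketch of that rule (no PyTorch required). The helpers scatter_2d and zeros are hypothetical, written only for illustration and not PyTorch's implementation; the example inputs mirror the ones used in the PyTorch docs.

```python
# Plain-Python sketch of the indexing rule behind tensor.scatter_(dim, index, src)
# for 2-D data. `scatter_2d` and `zeros` are hypothetical helpers for illustration.

def scatter_2d(target, dim, index, src):
    """Write src values into target at positions chosen by index (in place)."""
    for i, row in enumerate(index):
        for j, k in enumerate(row):
            if dim == 0:
                # dim=0: index[i][j] picks the destination row; the column stays j
                target[k][j] = src[i][j]
            elif dim == 1:
                # dim=1: index[i][j] picks the destination column; the row stays i
                target[i][k] = src[i][j]
            else:
                raise ValueError("this sketch only handles dim=0 or dim=1")
    return target


def zeros(rows, cols):
    """Build a rows x cols grid of zeros (stands in for torch.zeros)."""
    return [[0] * cols for _ in range(rows)]


src = [[1, 2, 3, 4, 5],
       [6, 7, 8, 9, 10]]

# dim=0: the first four values of src[0] are sent to rows 0, 1, 2, 0
out0 = scatter_2d(zeros(3, 5), 0, [[0, 1, 2, 0]], src)
print(out0)  # [[1, 0, 0, 4, 0], [0, 2, 0, 0, 0], [0, 0, 3, 0, 0]]

# dim=1: each row of src is sent to the columns listed in the same row of index
out1 = scatter_2d(zeros(3, 5), 1, [[0, 1, 2], [0, 1, 4]], src)
print(out1)  # [[1, 2, 3, 0, 0], [6, 7, 0, 0, 8], [0, 0, 0, 0, 0]]

# One-hot encoding is the dim=1 case with a single 1 written into each row
labels = [0, 2, 1]
one_hot = scatter_2d(zeros(3, 3), 1, [[c] for c in labels], [[1] for _ in labels])
print(one_hot)  # [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
```

In PyTorch itself, the first call corresponds to torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) with src = torch.arange(1, 11).reshape((2, 5)) and index = torch.tensor([[0, 1, 2, 0]]). The trailing underscore in scatter_() marks it as an in-place operation, which is why the sketch mutates target directly.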