PyTorch is a popular machine learning library, and its core data structure is the tensor. Think of a tensor as a multi-dimensional array of numbers: a spreadsheet is a two-dimensional example, and tensors extend the same idea to any number of dimensions. To work with tensors, we often need to read or change specific values inside them, which is where indices come into play.

Let's break down what indices are and how to use them in PyTorch, with simple examples.

What Are Tensors?

Tensors are like flexible boxes that can hold data in different shapes. They can be one-dimensional (like a list), two-dimensional (like a table), or even more complex.
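
To make the idea of dimensions concrete, here is a minimal sketch that builds a 1-D, a 2-D, and a 3-D tensor and prints their shapes. The variable names and values are illustrative only.

```python
import torch

# Tensors of increasing dimensionality (values are arbitrary).
vector = torch.tensor([1, 2, 3])          # 1-D: like a list
matrix = torch.tensor([[1, 2], [3, 4]])   # 2-D: like a table
cube = torch.zeros(2, 3, 4)               # 3-D: a stack of 2 tables, each 3 x 4

for name, t in [("vector", vector), ("matrix", matrix), ("cube", cube)]:
    # ndim is the number of dimensions; shape gives the size of each one.
    print(name, t.ndim, tuple(t.shape))
```

Each extra dimension adds one more index when you look up a value, so an element of cube needs three indices.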

Example Tensor

Here's a simple tensor. Think of it as a table with 2 rows and 2 columns:

import torch
tensor = torch.tensor([[1, 2], [3, 4]])

This tensor looks like this:

[[1, 2],
 [3, 4]]

Basic Indexing

1. Getting a Single Value

You can pick a single value from the tensor by using its row and column numbers. This is called indexing.

value = tensor[0, 1]
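print(value)  # Output: tensor(2)

Here, tensor[0, 1] means "give me the value in the first row and the second column." Indices start at 0, and the result is a zero-dimensional tensor rather than a plain Python number.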
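
A short sketch of two related details: converting an indexed element into a plain Python number with .item(), and using negative indices to count from the end. It reuses the 2 x 2 tensor from above.

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])

value = tensor[0, 1]     # a zero-dimensional tensor, not a plain int
number = value.item()    # convert it to a Python int
last = tensor[-1, -1]    # negative indices count from the end

print(value.ndim)        # 0
print(number)            # 2
print(last.item())       # 4
```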

2. Getting a Whole Row or Column

You can also get a whole row or column from the tensor.

row = tensor[1, :]
print(row)  # Output: tensor([3, 4])

Here, tensor[1, :] means "give me the second row, across all columns."

column = tensor[:, 1]
print(column)  # Output: tensor([2, 4])

And tensor[:, 1] means "give me the second column, across all rows."
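
One thing worth knowing about rows and columns taken this way: basic indexing and slicing return a view that shares memory with the original tensor. The sketch below shows the effect, and uses clone() when an independent copy is wanted. The variable names are illustrative.

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])

row_view = tensor[1, :]           # a view: shares memory with `tensor`
row_copy = tensor[1, :].clone()   # an independent copy

row_view[0] = 30                  # also changes tensor[1, 0]
row_copy[1] = 40                  # leaves `tensor` untouched

print(tensor.tolist())            # [[1, 2], [30, 4]]
print(row_copy.tolist())          # [3, 40]
```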

More Advanced Indexing

1. Selecting Multiple Specific Values

You can pick multiple values at once using lists of indices.

selected = tensor[[0, 1], [0, 1]]
print(selected)  # Output: tensor([1, 4])

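Here, the two lists are paired element by element: the first row index goes with the first column index, and so on. So [0, 1] for rows and [0, 1] for columns means "give me the values at (0,0) and (1,1)," not the whole 2 x 2 block.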
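
The pairing behavior is easy to mix up with "every combination of these rows and columns." The following sketch contrasts the two, assuming the same 2 x 2 tensor. The names rows, cols, paired, and block are illustrative.

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])

rows = torch.tensor([0, 1])
cols = torch.tensor([1, 0])

# Paired indexing: picks (0, 1) and (1, 0), one value per pair.
paired = tensor[rows, cols]

# Every combination of rows and cols: turn the row indices into a
# column vector so the two index tensors broadcast to a 2 x 2 grid.
block = tensor[rows[:, None], cols]

print(paired.tolist())  # [2, 3]
print(block.tolist())   # [[2, 1], [4, 3]]
```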

2. Using Conditions

You can select values based on conditions, like picking all numbers greater than 2.

mask = tensor > 2
filtered = tensor[mask]
print(filtered)  # Output: tensor([3, 4])

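This creates a Boolean mask with the same shape as tensor (True where the value is greater than 2) and uses it to pick those values. The result is always one-dimensional, whatever the shape of the original tensor.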
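
Masks can also be combined, and they work on the left side of an assignment. A minimal sketch, again using the 2 x 2 tensor (the name clipped is illustrative):

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])

# Combine conditions with & (and), | (or), ~ (not).
# The parentheses around each comparison are required.
mask = (tensor > 1) & (tensor < 4)
print(mask.tolist())           # [[False, True], [True, False]]
print(tensor[mask].tolist())   # [2, 3]

# A mask on the left side overwrites only the selected values.
clipped = tensor.clone()
clipped[clipped > 2] = 2
print(clipped.tolist())        # [[1, 2], [2, 2]]
```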

Special Indexing Functions

1. torch.index_select

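Use this to pick entire rows or columns along a chosen dimension.

indices = torch.tensor([1, 0])
selected_rows = torch.index_select(tensor, 0, indices)
print(selected_rows)  # Output: tensor([[3, 4], [1, 2]])

Here, the second argument (0) is the dimension to select along, and indices tells PyTorch which rows to pick and in what order. Every index must be in range: this tensor has 2 rows, so only 0 and 1 are valid.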
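
The same function selects columns when the dimension argument is 1. The sketch below shows both cases on the 2 x 2 tensor, and shows that an index may be repeated. The variable names are illustrative.

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])

# dim=0 selects rows, dim=1 selects columns.
# The index must be a 1-D integer tensor.
rows = torch.index_select(tensor, 0, torch.tensor([1, 0]))  # rows in reverse order
cols = torch.index_select(tensor, 1, torch.tensor([1, 1]))  # second column, twice

print(rows.tolist())  # [[3, 4], [1, 2]]
print(cols.tolist())  # [[2, 2], [4, 4]]
```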

2. torch.gather

Use this to collect specific values based on another tensor of indices.

tensor = torch.tensor([[1, 2], [3, 4]])
indices = torch.tensor([[0, 0], [1, 0]])
gathered = torch.gather(tensor, 1, indices)
print(gathered)  # Output: tensor([[1, 1], [4, 3]])

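This picks values from tensor based on the positions given in indices. With dimension 1, each entry of indices is a column number: gathered[i][j] is tensor[i][indices[i][j]].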
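
A common use of gather is picking one value per row. The sketch below assumes a hypothetical table of integer scores (one row per sample, one column per class) and a label per row saying which column to take. The data and names are made up for illustration.

```python
import torch

# Hypothetical data: one row per sample, one column per class.
scores = torch.tensor([[10, 90], [80, 20], [30, 70]])
labels = torch.tensor([1, 0, 0])   # the column to pick in each row

# gather needs an index with the same number of dimensions as the input,
# so reshape labels from (3,) to (3, 1), then drop the extra dimension.
picked = torch.gather(scores, 1, labels.unsqueeze(1)).squeeze(1)
print(picked.tolist())             # [90, 80, 30]

# The same result using paired list indexing.
same = scores[torch.arange(3), labels]
print(same.tolist())               # [90, 80, 30]
```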

3. torch.scatter

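Use this to place values in specific positions in a tensor. The trailing underscore in scatter_ means the operation modifies the tensor in place.

tensor = torch.zeros(2, 2, dtype=torch.int64)  # dtype must match values
indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([[5, 7], [9, 2]])
tensor.scatter_(1, indices, values)
print(tensor)  # Output: tensor([[5, 7], [2, 9]])

Here, scatter_ puts values into tensor at the positions given by indices: with dimension 1, values[i][j] is written to tensor[i][indices[i][j]]. The target tensor and values must have the same dtype, which is why the zeros are created as int64.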
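
One familiar application is one-hot encoding. The sketch below assumes hypothetical class labels and writes a 1 into the matching column of each row. It also shows the out-of-place scatter method, which returns a new tensor instead of modifying the original.

```python
import torch

labels = torch.tensor([2, 0, 1])   # hypothetical class labels
num_classes = 3

# In place: for each row i, write 1 into column labels[i].
# A single scalar can be used as the value to write.
one_hot = torch.zeros(len(labels), num_classes, dtype=torch.int64)
one_hot.scatter_(1, labels.unsqueeze(1), 1)
print(one_hot.tolist())            # [[0, 0, 1], [1, 0, 0], [0, 1, 0]]

# Out of place: `base` is left unchanged and a new tensor is returned.
base = torch.zeros(len(labels), num_classes, dtype=torch.int64)
result = base.scatter(1, labels.unsqueeze(1), 1)
print(base.sum().item())           # 0
print(result.tolist())             # [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
```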

4. torch.nonzero

This finds the positions of the non-zero values in a tensor.

tensor = torch.tensor([[0, 1], [2, 0]])
nonzero_positions = torch.nonzero(tensor)
print(nonzero_positions)  # Output: tensor([[0, 1], [1, 0]])

Each row of the result is the (row, column) position of one non-zero value: here, (0, 1) holds the 1 and (1, 0) holds the 2.
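
If you want to feed the positions back into an index, the as_tuple=True option is often more convenient, since it returns one index tensor per dimension. A minimal sketch on the same tensor:

```python
import torch

tensor = torch.tensor([[0, 1], [2, 0]])

# as_tuple=True returns one 1-D index tensor per dimension.
rows, cols = torch.nonzero(tensor, as_tuple=True)
print(rows.tolist(), cols.tolist())        # [0, 1] [1, 0]

# These can be passed straight back as paired indices.
print(tensor[rows, cols].tolist())         # [1, 2]

# Applied to a Boolean mask, nonzero gives the positions where a condition holds.
print(torch.nonzero(tensor > 1).tolist())  # [[1, 0]]
```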

Real-Life Example: Managing a Spreadsheet

Let's say you're managing a small spreadsheet of scores:

scores = torch.tensor([[80, 90], [70, 85]])

This tensor has scores for two students on two tests: each row is a student (the first row is Student A, the second is Student B) and each column is a test. Here's how you can use indexing:

Get a Specific Score

To get the score of Student A in the second test:

score = scores[0, 1]
print(score)  # Output: tensor(90)

Get All Scores for a Student

To get all scores for Student B:

student_b_scores = scores[1, :]
print(student_b_scores)  # Output: tensor([70, 85])

Find Scores Above a Threshold

To find scores greater than 80:

high_scores = scores[scores > 80]
print(high_scores)  # Output: tensor([90, 85])
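
The same spreadsheet can also be modified and summarized with the tools covered earlier. The sketch below locates the high scores, corrects one entry, and finds each student's best test. The corrected value 75 is hypothetical.

```python
import torch

scores = torch.tensor([[80, 90], [70, 85]])

# Where are the scores above 80? Each row is a (student, test) position.
positions = torch.nonzero(scores > 80)
print(positions.tolist())      # [[0, 1], [1, 1]]

# Correct Student B's first test on a copy (75 is a made-up value).
updated = scores.clone()
updated[1, 0] = 75
print(updated.tolist())        # [[80, 90], [75, 85]]

# Best score per student, and which test it came from.
best = scores.max(dim=1)
print(best.values.tolist())    # [90, 85]
print(best.indices.tolist())   # [1, 1]
```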

Conclusion

Indices in PyTorch help you efficiently work with data inside tensors, just like pointing to cells in a spreadsheet.

By mastering these simple techniques, you can easily access, modify, and analyze data in your machine learning projects.