PyTorch 的input[range(target.shape[0]), target] 表达式
![PyTorch 的input[range(target.shape[0]), target] 表达式](http://pic.ttrar.cn/nice/PyTorch%e7%9a%84inp.jpg)
在 PyTorch 中,类似 input[range(target.shape[0]), target] 这样的表达式通常用于获取输入张量(input)中特定位置的值,其中 位置由 target 张量指定的。 首先,range(target.shape[0]) 它创建了一个从 0 到 target 张量中第一个维度的大小计算得出的整数范围。例如,如果 target 张量的形状为 (5, 3),则 range(target.shape[0]) 返回一个大小为 5 的整数范围。然后,这个整数范围用于指定要获取的输入中的特定行, 即 input[range(target.shape[0])] 会返回一个张量(input)与 target 张量第一维大小一样, 且包含输入张量(input)中所指定行所有的列。 接下来,使用 target 张量指定列,从而选取每行中的所需位置。input[range(target.shape[0]), target] 最终会返回一个大小为 target.shape 的张量,包含了输入张量中所有位置需要的值。
举例来说, 如果 input 张量的形状为 (5, 4), target 张量为 (5), 并且具有以下值:
input = Tensor([[ 764, 4, 67, 785],[ 311, 101, 911, 199],[ 759, 362, 215, 651],[ 471, 821, 738, 875],[ 109, 828, 994, 675]])
target = Tensor([1, 0, 2, 3, 2])
那么 input[range(target.shape[0]), target] 将返回具有以下值得张量:
Tensor([ 4, 311, 215, 875, 994])
其中每个值都是得到所需位置的结果。


