> 文章列表 > PyTorch 的input[range(target.shape[0]), target] 表达式

PyTorch 的input[range(target.shape[0]), target] 表达式

PyTorch 的input[range(target.shape[0]), target] 表达式

在 PyTorch 中,类似 input[range(target.shape[0]), target] 这样的表达式通常用于获取输入张量(input)中特定位置的值,其中 位置由 target 张量指定的。 首先,range(target.shape[0]) 它创建了一个从 0 到 target 张量中第一个维度的大小计算得出的整数范围。例如,如果 target 张量的形状为 (5, 3),则 range(target.shape[0]) 返回一个大小为 5 的整数范围。然后,这个整数范围用于指定要获取的输入中的特定行, 即 input[range(target.shape[0])] 会返回一个张量(input)与 target 张量第一维大小一样, 且包含输入张量(input)中所指定行所有的列。 接下来,使用 target 张量指定列,从而选取每行中的所需位置。input[range(target.shape[0]), target] 最终会返回一个大小为 target.shape 的张量,包含了输入张量中所有位置需要的值。

举例来说, 如果 input 张量的形状为 (5, 4), target 张量为 (5), 并且具有以下值:

input = Tensor([[ 764,    4,   67,  785],[ 311,  101,  911,  199],[ 759,  362,  215,  651],[ 471,  821,  738,  875],[ 109,  828,  994,  675]])
target = Tensor([1, 0, 2, 3, 2])

那么 input[range(target.shape[0]), target] 将返回具有以下值得张量:

Tensor([  4, 311, 215, 875, 994])

其中每个值都是得到所需位置的结果。