Pytorch - Argmax

Pytorch - Argmax
“Neste tutorial de Pytorch, veremos como retornar posições de índice de valores máximos de um tensor usando argmax ().

Pytorch é uma estrutura de código aberto disponível com uma linguagem de programação Python. Podemos processar os dados em pytorch na forma de um tensor.

Um tensor é uma matriz multidimensional que é usada para armazenar os dados. Então, para usar um tensor, temos que importar o módulo da tocha.

Para criar um tensor, o método usado é tensor () ”

Sintaxe:

tocha.Tensor (dados)

Onde os dados são uma matriz multidimensional.

argmax ()

argmax () em pytorch é usado para retornar o índice do valor máximo de todos os elementos no tensor de entrada.

Sintaxe:

tocha.argmax (tensor, dim, keepdim)

Onde

  1. O tensor é o tensor de entrada
  2. Dim é reduzir a dimensão. Dim = 0 Especifica a comparação de colunas, que obterá o índice para obter o valor máximo ao longo de uma coluna, e Dim = 1 Especifica a comparação de linha, que receberá o índice para o valor máximo ao longo da linha.
  3. KeepDim verifica se o tensor de saída tem dimensão (dim) retida ou não

Exemplo 1

Neste exemplo, criaremos um tensor com 2 dimensões que possuem 3 linhas e 5 colunas e apliquem argmax () em linhas e colunas.

#import módulo tocha
importação de tocha
#Crie um tensor com 2 dimensões (3 * 5)
#com elementos aleatórios usando a função Randn ()
Dados = Torch.Randn (3,5)
#mostrar
Impressão (dados)
#get Índice máximo ao longo de colunas com argmax
Imprimir (tocha.argmax (dados, dim = 0))
#get Índice máximo ao longo de linhas com argmax
Imprimir (tocha.argmax (dados, dim = 1))

Saída:

tensor ([[0.6699, 1.3390, -1.0658, -1.8200, 0.6544],
[-0.3117, 0.2488, 0.2677, 0.2568, 0.5337],
[-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])
tensor ([0, 2, 1, 1, 0])
tensor ([1, 4, 1])

Podemos ver que os valores máximos presentes no índice ao longo das colunas são:

  1. Valor máximo - 0.6699. Seu índice é 0.
  2. Valor máximo - 1.8024. Seu índice é 2.
  3. Valor máximo - 0.2677. Seu índice é 1.
  4. Valor máximo - 0.2568. Seu índice é 1.
  5. Valor máximo - 0.6544. Seu índice é 0.

Da mesma forma, os valores máximos presentes no índice ao longo das linhas são:

  1. Valor máximo - 1.3390. Seu índice é 1.
  2. Valor máximo - 0.5337. Seu índice é 4.
  3. Valor máximo - 1.8024. Seu índice é 1.

Exemplo 2

Crie tensor com 5 * 5 Matrix e aplique Argmax ()

#import módulo tocha
importação de tocha
#Crie um tensor com 2 dimensões (5 * 5)
#com elementos aleatórios usando a função Randn ()
Dados = Torch.Randn (5,5)
#mostrar
Impressão (dados)
#get Índice máximo ao longo de colunas com argmax
Imprimir (tocha.argmax (dados, dim = 0))
#get Índice máximo ao longo de linhas com argmax
Imprimir (tocha.argmax (dados, dim = 1))

Saída:

tensor ([[-0.9553, -0.2611, -2.1233, -0.5208, -0.3458],
[-0.5466, -1.6395, 0.2576, -0.3123, 0.6785],
[-0.4574, 1.5301, 0.4812, 0.3434, 0.1388],
[0.8364, 0.3821, 0.1529, 1.4529, 0.3747],
[-1.4991, -1.8821, -0.2861, -0.4067, 1.1323]])
tensor ([3, 2, 2, 3, 4])
tensor ([1, 4, 1, 3, 4])

Podemos ver que os valores máximos presentes no índice ao longo das colunas são:

  1. Valor máximo - 0.8364. Seu índice é 3.
  2. Valor máximo - 1.5301. Seu índice é 2.
  3. Valor máximo - 0.4812. Seu índice é 2.
  4. Valor máximo - 1.4529. Seu índice é 3.
  5. Valor máximo - 1.1323. Seu índice é 4.

Da mesma forma, os valores máximos presentes no índice ao longo das linhas são:

  1. Valor máximo - -0.2611. Seu índice é 1.
  2. Valor máximo - 0.6785. Seu índice é 4.
  3. Valor máximo - 1.5301. Seu índice é 1.
  4. Valor máximo - 1.4529. Seu índice é 3.
  5. Valor máximo - 1.1323. Seu índice é 4.

Trabalhe com a CPU

Se você deseja executar uma função argmax () na CPU, temos que criar um tensor com uma função CPU (). Isso será executado em uma máquina de CPU.

Quando estamos criando um tensor, neste momento, podemos usar a função CPU ().

Sintaxe:

tocha.Tensor (dados).CPU()

Exemplo 1

Crie tensor com 5 * 5 matriz com CPU () e aplique argmax ()
#import módulo tocha
importação de tocha
#Crie um tensor com 2 dimensões (5 * 5)
#com elementos aleatórios usando a função Randn () com CPU ()
Dados = Torch.Randn (5,5).CPU()
#mostrar
Impressão (dados)
#get Índice máximo ao longo de colunas com argmax
Imprimir (tocha.argmax (dados, dim = 0))
#get Índice máximo ao longo de linhas com argmax
Imprimir (tocha.argmax (dados, dim = 1))

Saída:

tensor ([[-0.2213, 1.6140, -0.0774, 0.4135, 0.1379],
[-0.4415, -2.5789, 0.8294, -0.9309, 1.3535],
[-1.3256, -0.7233, -0.9713, 1.0742, 1.9350],
[-0.7126, -1.3336, 0.7371, -0.2253, 0.1675],
[-0.1174, -0.5773, 0.8887, -0.2563, -1.0667]])
tensor ([4, 0, 4, 2, 2])
tensor ([1, 4, 4, 2, 2])

Podemos ver que os valores máximos presentes no índice ao longo das colunas são:

  1. Valor máximo - -0.1174. Seu índice é 4.
  2. Valor máximo - 1.6140. Seu índice é 0.
  3. Valor máximo - 0.8887. Seu índice é 4.
  4. Valor máximo - 1.0742. Seu índice é 2.
  5. Valor máximo - 1.9350. Seu índice é 2.

Da mesma forma, os valores máximos presentes no índice ao longo das linhas são:

  1. Valor máximo - 1.6140. Seu índice é 1.
  2. Valor máximo - 1.3535. Seu índice é 4.
  3. Valor máximo - 1.9350. Seu índice é 4.
  4. Valor máximo - 0.7371. Seu índice é 2.
  5. Valor máximo - 0.8887. Seu índice é 2.

Exemplo 2

Neste exemplo, criaremos um tensor com 2 dimensões que possuem 3 linhas e 5 colunas usando a função CPU () e apliquem argmax () em linhas e colunas.

#import módulo tocha
importação de tocha
#Crie um tensor com 2 dimensões (3 * 5)
#com elementos aleatórios usando randn () com função cpu ()
Dados = Torch.Randn (3,5).CPU()
#mostrar
Impressão (dados)
#get Índice máximo ao longo de colunas com argmax
Imprimir (tocha.argmax (dados, dim = 0))
#get Índice máximo ao longo de linhas com argmax
Imprimir (tocha.argmax (dados, dim = 1))

Saída:

tensor ([[0.6699, 1.3390, -1.0658, -1.8200, 0.6544],
[-0.3117, 0.2488, 0.2677, 0.2568, 0.5337],
[-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])
tensor ([0, 2, 1, 1, 0])
tensor ([1, 4, 1])

Podemos ver que os valores máximos presentes no índice ao longo das colunas são:

  1. Valor máximo - 0.6699. Seu índice é 0.
  2. Valor máximo - 1.8024. Seu índice é 2.
  3. Valor máximo - 0.2677. Seu índice é 1.
  4. Valor máximo - 0.2568. Seu índice é 1.
  5. Valor máximo - 0.6544. Seu índice é 0.

Da mesma forma, os valores máximos presentes no índice ao longo das linhas são:

  1. Valor máximo - 1.3390. Seu índice é 1.
  2. Valor máximo - 0.5337. Seu índice é 4.
  3. Valor máximo - 1.8024. Seu índice é 1.

Conclusão

Nesta lição de Pytorch, vimos o que argmax () e como aplicar argmax () em um tensor para retornar índices de valores máximos entre colunas e linhas.

Também criamos um tensor com função CPU () e retornamos índices de valores máximos. Dim é o parâmetro usado para retornar índices de valores máximos entre as colunas quando está definido como 0 e retornar índices de valores máximos entre linhas quando está definido como 1.