본문 바로가기
AI Basic/PyTorch

[PyTorch] 02 Basic Operation

by iamzieun 2023. 3. 13.

포스팅 개요

본 포스팅은 PyTorch의 기본적인 연산 중 view, reshape, transpose와 dot, mm, matmul의 차이점을 중심으로 정리한 글입니다. 

view / reshape / transpose

: contiguity의 측면에서 차이를 가짐

  • contiguous: matrix의 눈에 보이는 순차적인 shape information과 실제로 matrix의 각 데이터가 저장된 위치가 같은지 여부
  • view: tensor가 contiguous할 때 shape를 재구성함 → view에 의해 재구성된 tensor는 항상 contiguity가 보장됨
tensor_ex = torch.rand(size=(2, 3, 2))
tensor_ex

# view
tensor_ex.view([-1, 6])
#tensor([[0.2905, 0.4599, 0.0279, 0.6814, 0.3465, 0.4365],
#        [0.5382, 0.2599, 0.3857, 0.5505, 0.8370, 0.0509]])

tensor_ex.view([-1, 6]).is_contiguous()
#True
  • reshape: tensor가 contiguous하지 않아도 shape를 재구성한 후 강제로 contiguous하게 만듦 → reshape에 의해 재구성된 tensor는 항상 contiguity가 보장됨
# reshape
tensor_ex.reshape([-1,6])
#tensor([[0.2905, 0.4599, 0.0279, 0.6814, 0.3465, 0.4365],
#        [0.5382, 0.2599, 0.3857, 0.5505, 0.8370, 0.0509]])

tensor_ex.reshape([-1, 6]).is_contiguous()
#True
  • transpose: tensor의 contiguous 여부와 상관 없이 수학적 의미의 transpose 실행 → transpose에 의해 재구성된 tensor의 contiguity는 보장되지 않음
# transpose
tensor_ex.transpose(1, 2)
#tensor([[[0.2905, 0.0279, 0.3465],
#         [0.4599, 0.6814, 0.4365]],
#
#        [[0.5382, 0.3857, 0.8370],
#         [0.2599, 0.5505, 0.0509]]])

tensor_ex.transpose(1, 2).is_contiguous()
#False

dot / mm / matmul

내적 vs 행렬곱

dot vs mm vs matmul

  • dot: dot production. 내적
    • 두 벡터의 내적을 계산한 결과이므로 스칼라값을 return
import torch
a = torch.rand(10)
b = torch.rand(10)
a.dot(b)
# tensor(1.9150)
  • mm: matrix multiplication(행렬곱). broadcasting 지원 x
a = torch.rand(5, 2, 3)
b = torch.rand(5)
a.mm(b)
# RuntimeError: self must be a matrix
    • b를 5 x 1로 broadcasting 하지 못했기 때문에 RuntimeError 발생
  • matmul: matrix multiplication(행렬곱). broadcasting 지원 o
a = torch.rand(5, 2, 3)
b = torch.rand(3)
a.matmul(b)
# tensor([[0.3146, 0.4498],
#         [0.2307, 0.2622],
#         [0.2187, 0.1017],
#         [0.4165, 0.3774],
#         [0.3401, 0.3856]])
    • mm과 달리 broadcasting이 가능하기 때문에 5 * 2 * 3의 행렬과 3 * 1의 행렬을 행렬곱한 5 * 2의 행렬을 return

 

 

 

'AI Basic > PyTorch' 카테고리의 다른 글

[PyTorch] 01 TensorFlow vs PyTorch  (0) 2023.03.13

댓글