-
[PyTorch] nn.Linear의 Weight 행렬은 왜 전치되어 있을까?카테고리 없음 2023. 3. 26. 19:50
PyTorch 공식문서를 보면 nn.Linear의 계산식이 $y=xA^T+b$라고 명시되어있다. $x$는 입력값, $A$는 weight 행렬, b는 bias 벡터, y는 출력값이다. 여기서 $x$의 크기가 (1, input_features)이고 $A$의 크기가 (output_features, input_features)이기 때문에 행렬곱을 하려면 $A$가 전치가 되어야 한다. 하지만 처음부터 $A$의 크기를 (input_features, output_features)로 잡으면 전치할 필요도 없는데 왜 불필요한 연산을 하는 것일까? 그 이유는 캐시의 지역성에 따라 계산이 더 효율적으로 수행되어 속도가 빨라지기 때문이다.
캐시의 지역성
배열에 있는 값에 차례대로 접근하는 것은 자주 있는 일이다. 그래서 CPU는 특정 위치의 메모리에 접근할 때 그 주변에 있는 메모리도 가까운 미래에 쓰일 것으로 가정하고 미리 한꺼번에 가져와서 필요할 때 빠르게 값을 찾을 수 있도록 보관한다. 캐시를 통해 빠르게 값을 찾는 경우를 가리켜 캐시 히트라 부른다. 아래 코드를 보자.
import timeit def print_time(name, setup, stmt): result = timeit.repeat(stmt, setup, number=10) result_str = f'{name}: ' result_str += ' '.join(f'{i:.2f}' for i in result) result_str += f' | average: {sum(result) / len(result):.2f}' print(result_str) setup_dim2_list = ''' x = 3000 y = 3000 a = [[0 for _ in range(x)] for _ in range(y)] ''' stmt1 = ''' for i in range(x): for j in range(y): a[i][j] += 10 ''' stmt2 = ''' for i in range(x): for j in range(y): a[j][i] += 10 ''' print_time('case1', setup_dim2_list, stmt1) # case1: 5.91 5.66 5.79 5.67 5.97 | average: 5.80 print_time('case2', setup_dim2_list, stmt2) # case2: 8.42 8.52 8.76 7.93 8.33 | average: 8.39
위 코드는 2차원 list를 두 가지 방법으로 순회하면서 실행시간을 측정한다. stmt1은 한 행을 다 보고 다음 행을 보는 방식이고 stmt2는 한 열을 다 보고 다음 열을 보는 방식이다. stmt1이 더 빠른 이유는 한 행 안의 항목의 메모리상 위치가 서로 붙어있어서 캐시의 지역성에 의해 캐시 히트가 많이 발생하기 때문이다. 반면 stmt2는 띄엄띄엄 떨어져있는 값을 참조하기 때문에 느리다.
Tensor.t()의 원리
t()는 전치행렬을 반환하는 함수이다. 하지만 t()를 호출했을 때 전체 데이터를 전치하여 새로 저장하는 일은 하지 않는다. 그저 앞으로 x행 y열을 읽으라는 명령이 들어오면 y행 x열을 반환하기로 약속만 해놓는다. 그래서 전치행렬을 구하는 연산은 비용이 들지 않는다. 아래 코드를 보자.
a = torch.empty([3, 2]) b = a.t() print(a[2][0].data_ptr() == b[0][2].data_ptr()) # True
data_ptr은 데이터 주소를 반환하는 함수이다. a[2][0]과 b[0][2]은 같은 주소를 가리킨다. 원본 데이터에는 변화가 없고 인덱싱에서만 차이가 난다.
행렬곱을 할 때 캐시 히트율 높이기
행렬 $A$, $B$가 있을 때 $C = AB$가 되도록 행렬곱을 한다고 하면 수식 $c_{ij}=\sum^k a_{ik} * b_{kj}$을 통해 계산할 수 있다. 모든 $i$, $j$, $k$에 대하여 순회를 해야 하므로 3중 for문으로 짜야한다. 아래 코드를 보자.
setup_mm = ''' import random x = 300 h = 300 y = 300 a = [[random.randint(-4, 4) for _ in range(x)] for _ in range(h)] b = [[random.randint(-4, 4) for _ in range(h)] for _ in range(y)] c = [[0 for _ in range(x)] for _ in range(y)] ''' stmt3 = ''' for i in range(x): for j in range(y): c[i][j] = 0 for k in range(h): c[i][j] += a[i][k] + b[k][j] ''' print_time('case3', setup_mm, stmt3) # case3: 38.06 40.37 41.82 40.11 39.53 | average: 39.98
위 코드는 (x,h)크기의 행렬 a와 (h,y)크기의 행렬 b를 곱하여 (x,y)크기의 행렬 c를 계산하고 시간을 측정한다. 여기서 $b_{kj}$를 보면 stmt2에서처럼 한 열씩 순회하고 있는 것을 볼 수 있다. 캐시 히트율을 높이기 위해 for j문과 for k문을 바꾼다면 c[i][j] = 0줄의 처리가 애매해진다. 어떻게 해결할 수 있을까? 아래 코드를 보자.
setup_mm_T = ''' import random x = 300 h = 300 y = 300 a = [[random.randint(-4, 4) for _ in range(x)] for _ in range(h)] b = [[random.randint(-4, 4) for _ in range(y)] for _ in range(h)] c = [[0 for _ in range(x)] for _ in range(y)] ''' stmt4 = ''' for i in range(x): for j in range(y): c[i][j] = 0 for k in range(h): c[i][j] += a[i][k] + b[j][k] ''' print_time('case4', setup_mm_T, stmt4) # case4: 35.92 36.13 37.83 36.26 35.45 | average: 36.32
행렬 b의 크기를 (y, h)로 선언하고 전치한 형태의 b[j][k]로 접근하고 있다. 캐시 히트율을 높이면서, 비용이 들지 않는 전치행렬을 이용하여 최적화한 것이다. 약간의 시간 단축이 이뤄진 것을 볼 수 있다.
정리
nn.Linear에서 A를 전치행렬로 저장하고 행렬곱을 수행할 때 전치하면 캐시 히트율이 높아져 행렬곱을 빠르게 할 수 있다.