ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [PyTorch] nn.Linear의 Weight 행렬은 왜 전치되어 있을까?
    카테고리 없음 2023. 3. 26. 19:50
     

    Why does the Linear module seems to do unnecessary transposing?

    I was looking at the code for torch.nn.Linear(in_features, out_features, bias=True) and it seems that it store the matrix one way but then decides that to compute stuff its necessary to transpose (though the transposing seems it could have been avoided). W

    discuss.pytorch.org

     

    PyTorch 공식문서를 보면 nn.Linear의 계산식이 $y=xA^T+b$라고 명시되어있다. $x$는 입력값, $A$는 weight 행렬, b는 bias 벡터, y는 출력값이다. 여기서 $x$의 크기가 (1, input_features)이고 $A$의 크기가 (output_features, input_features)이기 때문에 행렬곱을 하려면 $A$가 전치가 되어야 한다. 하지만 처음부터 $A$의 크기를 (input_features, output_features)로 잡으면 전치할 필요도 없는데 왜 불필요한 연산을 하는 것일까? 그 이유는 캐시의 지역성에 따라 계산이 더 효율적으로 수행되어 속도가 빨라지기 때문이다.


    캐시의 지역성

    배열에 있는 값에 차례대로 접근하는 것은 자주 있는 일이다. 그래서 CPU는 특정 위치의 메모리에 접근할 때 그 주변에 있는 메모리도 가까운 미래에 쓰일 것으로 가정하고 미리 한꺼번에 가져와서 필요할 때 빠르게 값을 찾을 수 있도록 보관한다. 캐시를 통해 빠르게 값을 찾는 경우를 가리켜 캐시 히트라 부른다. 아래 코드를 보자.

    import timeit
    def print_time(name, setup, stmt):
        result = timeit.repeat(stmt, setup, number=10)
        result_str = f'{name}:  '
        result_str += '  '.join(f'{i:.2f}' for i in result)
        result_str += f'  | average:  {sum(result) / len(result):.2f}'
        print(result_str)
    setup_dim2_list = '''
    x = 3000
    y = 3000
    a = [[0 for _ in range(x)] for _ in range(y)]
    '''
    stmt1 = '''
    for i in range(x):
        for j in range(y):
            a[i][j] += 10
    '''
    stmt2 = '''
    for i in range(x):
        for j in range(y):
            a[j][i] += 10
    '''
    print_time('case1', setup_dim2_list, stmt1)
    # case1:  5.91  5.66  5.79  5.67  5.97  | average:  5.80
    
    print_time('case2', setup_dim2_list, stmt2)
    # case2:  8.42  8.52  8.76  7.93  8.33  | average:  8.39

    위 코드는 2차원 list를 두 가지 방법으로 순회하면서 실행시간을 측정한다. stmt1은 한 행을 다 보고 다음 행을 보는 방식이고 stmt2는 한 열을 다 보고 다음 열을 보는 방식이다. stmt1이 더 빠른 이유는 한 행 안의 항목의 메모리상 위치가 서로 붙어있어서 캐시의 지역성에 의해 캐시 히트가 많이 발생하기 때문이다. 반면 stmt2는 띄엄띄엄 떨어져있는 값을 참조하기 때문에 느리다.


    Tensor.t()의 원리

    t()는 전치행렬을 반환하는 함수이다. 하지만 t()를 호출했을 때 전체 데이터를 전치하여 새로 저장하는 일은 하지 않는다. 그저 앞으로 x행 y열을 읽으라는 명령이 들어오면 y행 x열을 반환하기로 약속만 해놓는다. 그래서 전치행렬을 구하는 연산은 비용이 들지 않는다. 아래 코드를 보자.

    a = torch.empty([3, 2])
    b = a.t()
    print(a[2][0].data_ptr() == b[0][2].data_ptr()) # True

     

    data_ptr은 데이터 주소를 반환하는 함수이다. a[2][0]과 b[0][2]은 같은 주소를 가리킨다. 원본 데이터에는 변화가 없고 인덱싱에서만 차이가 난다.


    행렬곱을 할 때 캐시 히트율 높이기

    행렬 $A$, $B$가 있을 때 $C = AB$가 되도록 행렬곱을 한다고 하면 수식 $c_{ij}=\sum^k a_{ik} * b_{kj}$을 통해 계산할 수 있다. 모든 $i$, $j$, $k$에 대하여 순회를 해야 하므로 3중 for문으로 짜야한다. 아래 코드를 보자.

    setup_mm = '''
    import random
    x = 300
    h = 300
    y = 300
    a = [[random.randint(-4, 4) for _ in range(x)] for _ in range(h)]
    b = [[random.randint(-4, 4) for _ in range(h)] for _ in range(y)]
    c = [[0 for _ in range(x)] for _ in range(y)]
    '''
    stmt3 = '''
    for i in range(x):
        for j in range(y):
            c[i][j] = 0
            for k in range(h):
                c[i][j] += a[i][k] + b[k][j]
    '''
    print_time('case3', setup_mm, stmt3)
    # case3:  38.06  40.37  41.82  40.11  39.53  | average:  39.98

     

    위 코드는 (x,h)크기의 행렬 a와 (h,y)크기의 행렬 b를 곱하여 (x,y)크기의 행렬 c를 계산하고 시간을 측정한다. 여기서 $b_{kj}$를 보면 stmt2에서처럼 한 열씩 순회하고 있는 것을 볼 수 있다. 캐시 히트율을 높이기 위해 for j문과 for k문을 바꾼다면 c[i][j] = 0줄의 처리가 애매해진다. 어떻게 해결할 수 있을까? 아래 코드를 보자.

    setup_mm_T = '''
    import random
    x = 300
    h = 300
    y = 300
    a = [[random.randint(-4, 4) for _ in range(x)] for _ in range(h)]
    b = [[random.randint(-4, 4) for _ in range(y)] for _ in range(h)]
    c = [[0 for _ in range(x)] for _ in range(y)]
    '''
    stmt4 = '''
    for i in range(x):
        for j in range(y):
            c[i][j] = 0
            for k in range(h):
                c[i][j] += a[i][k] + b[j][k]
    '''
    print_time('case4', setup_mm_T, stmt4)
    # case4:  35.92  36.13  37.83  36.26  35.45  | average:  36.32

    행렬 b의 크기를 (y, h)로 선언하고 전치한 형태의 b[j][k]로 접근하고 있다. 캐시 히트율을 높이면서, 비용이 들지 않는 전치행렬을 이용하여 최적화한 것이다. 약간의 시간 단축이 이뤄진 것을 볼 수 있다.


    정리

    nn.Linear에서 A를 전치행렬로 저장하고 행렬곱을 수행할 때 전치하면 캐시 히트율이 높아져 행렬곱을 빠르게 할 수 있다.

    댓글

Designed by Tistory.