카테고리 없음

7/2 코테일지 (메모리 초과와 counting sort, array.index)

코테챌린져 2024. 7. 2. 15:54

메모리 초과 Counting Sort 

10989 수 정렬하기 3

 

다음 과 같은 문제인데, 들어오는 숫자의 수가 10,000,000개나 된다는 것을 알 수 있으며

메모리 제한은 8MB이다. 

정수 하나당 4B이므로 8MB/4B = 2MB= 2 *2^20= 약 2*10^6개의 정수를 받을 수 있다. 문제의 10^7개의 문자를 받을 수 없다.

해당 사실을 모르고 초기에는 다음과 같이 작성을 하였고

import sys
n = int(sys.stdin.readline().rstrip())

array = []
for _ in range (n):
  array.append(int(sys.stdin.readline().rstrip()))

array.sort()

for element in array:
  print(element)

메모리 초과가 나서, array.sort()는 mergesort를 사용하기 때문에 메모리 초과가 난다 생각하여 Quick Sort로 시도해보았지만 똑같이 실패했다.

def quickSort(arr: list[int], s: int, e: int) -> list[int]:
    if e - s + 1 <= 1:
        return arr

    pivot = arr[e]
    left = s # pointer for left side

    # Partition: elements smaller than pivot on left side
    for i in range(s, e):
        if arr[i] < pivot:
            tmp = arr[left]
            arr[left] = arr[i]
            arr[i] = tmp
            left += 1

    # Move pivot in-between left & right sides
    arr[e] = arr[left]
    arr[left] = pivot
    
    # Quick sort left side
    quickSort(arr, s, left - 1)

    # Quick sort right side
    quickSort(arr, left + 1, e)

    return arr

import sys
n = int(sys.stdin.readline().rstrip())

array = []
for _ in range (n):
  array.append(int(sys.stdin.readline().rstrip()))

quickSort(array,0,len(array)-1)

for element in array:
  print(element)

 

정답은 바로 Counting Sort를 사용해야 하는데, 그 이유는 숫자의 범위인 10000이 배열의 최대 길이인 10^7보다 현저히 작기 때문에, 메모리 적으로 훨씬 효율적이다.

import sys
n = int(sys.stdin.readline().rstrip())

array = [0 for _ in range(10000)]

for _ in range(n):
  array[int(sys.stdin.readline().rstrip())-1] += 1

for i,element in enumerate(array):
  for _ in range(element):
     print(i+1)

 

array.index

1021 회전하는 큐

array.index(element)로 배열의 요소의 index를 구할 수 있다.

if array.index(element) <= len(array)/2:

 

그리고 이 문제는 꼭 다시한번 풀어보는 걸 추천한다.

내가 적은 코드도 맞는 코드이고 오히려 시간도 더 효율적인데, 너무 어렵게 생각해서 풀어서 좀 더 쉽게 접근하는 답안도 참고하자.

내가 적은 코드 

import sys 
from collections import deque 

n, m = map(int,sys.stdin.readline().rstrip().split(' '))
find_array = list(map(int,sys.stdin.readline().rstrip().split(' ')))

for i,element in enumerate(find_array):
  find_array[i] -= 1

count = 0 
for k,element in enumerate(find_array):
  if element == 0:
    pass
  elif element <= n/2 :
    for i,j in enumerate(find_array):
      find_array[i] -= element
      if (find_array[i] > n):
        find_array[i] -= n
      if (find_array[i] < 0 ):
        find_array[i] += n
    count += element
  elif element > n/2 :
    for i,j in enumerate(find_array):
      find_array[i] += (n-element)
      if (find_array[i] > n):
        find_array[i] -= n
      if (find_array[i] < 0 ):
        find_array[i] += n
    count += (n-element)  
  for i,element in enumerate(find_array):
    find_array[i] -= 1
  n = n - 1
print(count)

 

답안

import sys 
from collections import deque 

n, m = map(int,sys.stdin.readline().rstrip().split(' '))
array= deque()
for i in range(n):
  array.append(i+1)
find_array = list(map(int,sys.stdin.readline().rstrip().split(' ')))

count = 0 
for k,element in enumerate(find_array):
  if array.index(element) <= len(array)/2:
    while array[0] != element:
      array.append(array.popleft())
      count = count+1
  elif array.index(element) > len(array)/2:
    while array[0] != element:
      array.appendleft(array.pop())  
      count=count+1
  array.popleft()
print(count)