문제 해석

\(n\)종류의 사탕이 \(a_1, a_2, \cdots, a_n\)개씩 있다. 이 사탕들을 봉투(들)에 나누어 담는데, 모든 봉투마다 정확히 \(k\)개씩 넣어야 하고 같은 봉투에 든 사탕은 모두 다른 종류여야 한다. 사탕은 어떤 종류든 어떤 개수든 추가로 주문할 수 있지만 사탕이 남아서는 안 된다. 추가로 주문해야 하는 사탕 개수의 최솟값을 구하라.

접근

\(k\)가 \(n\) 이하이고, 사탕을 종류, 개수에 상관없이 자유롭게 추가할 수 있기 때문에 봉투를 \(b\)개 만들 수 있다면 \(b+1\)개는 당연히 만들 수 있다. 즉, 만들 수 있는 봉투 개수의 최솟값이 존재할 것이고, 이를 \(B\)라고 하면 \(Bk - \sum_{i} a_i\)가 답이 되므로, \(B\)를 구하는 법만 알면 된다.

\(b\)봉투를 만들 수 있을 때 성립하는 사실이 무엇일까? 사탕이 모자랄 때는 주문해서 채우면 되지만 남으면 안 되므로 \(a_i\) 중에 \(b\)보다 큰 것이 없어야 하고, \(a_i\)들의 합이 \(bk\) 이하여야 한다.

\[\text{Can make } b \text{ packs} \Rightarrow \begin{cases} \max_{i} a_i \le b \cdots (1)\\ \quad \text{and}\\ \sum_{i} a_i \le bk \cdots (2) \end{cases}\]

이것만으로는 필요조건이라서 별 도움이 안 됐겠지만, 저 역도 성립한다는 것을 지금부터 보일 것이다.

봉투를 만드는 가장 “안전한” 방법이 뭘까? 직관적으로 생각해서, 개수가 적은 종류보다는 많은 종류를 빨리 써서 없애는 게 좋을 것이다. 문제에 있는 예시 \(n=4, k=3, a=[1,3,4,1]\)을 생각해보자. \(b=4\)라고 하면 조건 \((1)\)과 \((2)\)를 만족한다. 하지만 첫 번째 봉투를 \([1,2,4]\)로 만들면 \(3\)번 종류 때문에 봉투를 \(4\)개 더 만들어야 해서 실패한다.

따라서 아래와 같은 greedy 전략을 떠올릴 수 있다.

\(b\)번 반복:
  남아 있는 종류가 \(k\)개 이상이면:
    가장 많이 남은 순서대로 \(k\)종류를 골라 \(1\)개씩 꺼낸다
  그렇지 않으면:
    남은 종류마다 \(1\)개씩 꺼내고, \(k\)개가 되도록 안 겹치게 마음대로 주문한다
    (\(k\)가 \(n\) 이하라서 반드시 가능)
  이 \(k\)개로 봉투 하나를 만든다

\((1) \wedge (2)\)를 만족하는 \(b\)에 대하여 위 전략은 반드시 성공한다. 증명은 아래 참조. 따라서 \((1) \wedge (2)\)는 \(b\)봉투를 만들 수 있는 것과 동치가 되고, 아래 식으로 답을 \(\mathcal{O}(n)\) 시간만에 구할 수 있다.

\[Ans = Bk - \sum_{i} a_i \quad \text{where} \quad B = \max \left( \max_{i} a_i, \left\lceil \frac{\sum_{i} a_i}{k} \right\rceil \right)\]

정확성 증명

봉투의 수 \(b\)에 대하여 수학적 귀납법을 한다. \(b=1\)이면 자명하다.

\(\max_{i} a_i < b\)라면 봉투를 하나 만든 후에도 그럴 것이다(좌변은 증가할 수 없으므로). \(\max_{i} a_i = b\)라면 \((2)\)에 의해 \(a_i = b\)인 \(i\)는 \(k\)개 이하이고, 봉투를 하나 만든 후 이 \(a_i\)들은 모두 \(1\)씩 감소하므로 \(\max_{i} a_i\)도 \(1\) 감소한다.

한편, \(a_i = 0\)인 \(i\)가 존재한다면 \((1)\)에 의해 \(\sum_{i} a_i \le (b-1)k\)이고 봉투를 하나 만든 후에도 그럴 것이다(좌변은 증가할 수 없으므로). 그렇지 않다면, 봉투를 하나 만든 후 \(\sum_{i} a_i\)는 정확히 \(k\)만큼 감소한다.

따라서 어떤 \(b\)에서 \((1) \wedge (2)\)를 만족할 때 위 전략대로 봉투를 하나 만들면 \(b\)가 \(1\) 감소하고 다시 \((1) \wedge (2)\)를 만족한다. 끝.

코드

def main():
    n, k = map(int, input().split())
    a = [int(input()) for _ in range(n)]
    sum_a = sum(a)
    B = max(max(a), (sum_a + k - 1)//k)
    print(B*k - sum_a)

if __name__ == "__main__":
    main()

태그:

카테고리:

업데이트: