BOJ 2473 - 세 용액

문제 링크


CODE
#include <bits/stdc++.h>
#define all(x) (x).begin(), (x).en()
#define rep(i, a, b) for (int i = (a); i < (b); ++i)
typedef long long ll;
const int INF = 0x3f3f3f3f;
using namespace std;
ll N;
ll arr[5050];
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL); std::cout.tie(NULL);
cin >> N;
for (int i = 1; i <= N; i++)
cin >> arr[i];
sort(arr+1, arr+N+1);
ll first, second, third;
ll mn = LONG_MAX;
for (int i = 1; i <= N - 2; i++)
{
int st = i + 1;
int en = N;
while (st < en)
{
ll temp = mn;
ll result = arr[i] + arr[st] + arr[en];
mn = min(mn, abs(result));
if (mn != temp) {
temp = mn;
first = arr[i];
second = arr[st];
third = arr[en];
}
if (result < 0)
st++;
else
en--;
}
}
vector<int>ans;
ans.push_back(first);
ans.push_back(second);
ans.push_back(third);
sort(ans.begin(), ans.end());
for (int i = 0; i < 3; i++)
cout << ans[i] << ' ';
return 0;
}
DESCRIPTION
앞전의 두 용액과 마찬가지로, 투 포인터와 이분 탐색을 둘다 이용하여 문제를 풀이할 수 있다. 두 용액에서 용액 2개를 일일이 고른다면
O(N^2)
로 시간초과가 발생하는 것을 투 포인터를 활용해O(N)
으로 줄여서 통과시켰던 것처럼, 이번 문제도O(N^3)
으로 발생하는 시간 초과를(O(5000 * 5000 * 5000)
) 투 포인터를 활용해O(N^2)
로 줄이는 것으로 풀이가 가능하다.두 용액을 풀고 바로 세 용액을 풀려고 시도했는데, 용액 3개를 고르는 지점에서 내 논리에 오류가 발생해서 결국 구글에 쳐봤었다.

- 이것은 두 용액을 투 포인터로 고르는 대략적인 컨셉의 그림인데, 세 용액은 여기서 하나를 더 골라야한다. 처음엔 두 용액에서 했던 것처럼 양 끝점을 잡고, 나머지 하나를 고르려고 했는데, 그렇게 되면 양 끝점을 어느 순서로 당겨야하는지에 대한 기준을 잡을 수가 없어서, 사실상 N 세제곱의 풀이와 다른 점이 없었다. 애초에 양 끝점을 잡고 나머지 하나를 고르는게 아니라, 배열을 순회하면서(맨 첫번째 용액부터 N-2번째 용액까지) i번째 용액을 첫 번째 용액으로 삼고, 나머지 두 개의 용액을 투 포인터로 찾아내는 방식을 사용해야 했다. 생각해낼 수도 있었을 법한데 아쉬운 부분이다…

- 마찬가지로 이 문제도 이분 탐색을 활용할 수가 있다. 두 용액때 O(N^2)을 O(NlogN)으로 줄였던 것처럼 여기서도 O(N^3)을 O(N^2 * logN)으로 줄일 수가 있다. 이 문제는 N이 5000이고, log 5000은 거의 상수에 가깝기 때문에 사실상 O(N^2)와 다르지 않다고 볼 수 있다. 물론 시간차이는 좀 난다. 아래는 이분 탐색을 사용한 풀이, 골자는 위의 투 포인터와 크게 다르지 않다. 용액 두 개를 먼저 골라두고, 나머지 하나를 이분 탐색으로 찾아나간다.
#include <bits/stdc++.h>
#define all(x) (x).begin(), (x).end()
#define rep(i, a, b) for (int i = (a); i < (b); ++i)
typedef long long ll;
const int INF = 0x3f3f3f3f;
using namespace std;
int N;
ll arr[5050];
ll mn = 3000000000;
ll first, second, third;
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL); std::cout.tie(NULL);
cin >> N;
for (int i = 1; i <= N; i++) {
cin >> arr[i];
}
sort(arr+1, arr+N+1);
for (int i = 1; i <= N - 2; i++) {
for (int j = i + 1; j <= N - 1; j++) {
int st = j + 1;
int en = N;
while (st <= en) // 이분 탐색 적용
{
int mid = (st + en) / 2;
ll sum = arr[i] + arr[j] + arr[mid];
ll result = abs(sum);
if (result < mn) {
mn = result;
first = arr[i];
second = arr[j];
third = arr[mid];
}
if (sum < 0) {
st = mid + 1;
} else {
en = mid - 1;
}
}
}
}
vector<int>v;
v.push_back(first);v.push_back(second);v.push_back(third);
sort(v.begin(), v.end());
for (int i=0;i<3;i++)cout<<v[i]<<' ';
return 0;
}